From 255a0a6ea69aa97cb730b52963209b553da901a4 Mon Sep 17 00:00:00 2001 From: Adrien Piquerez Date: Tue, 12 May 2020 14:33:19 +0200 Subject: [PATCH] send response to the source channel only --- .../scala/sbt/internal/CommandChannel.scala | 4 - .../scala/sbt/internal/ConsoleChannel.scala | 23 +- .../scala/sbt/internal/CommandExchange.scala | 226 ++++++------------ .../sbt/internal/server/Definition.scala | 2 +- .../server/LanguageServerProtocol.scala | 2 +- .../sbt/internal/server/NetworkChannel.scala | 19 +- 6 files changed, 95 insertions(+), 181 deletions(-) diff --git a/main-command/src/main/scala/sbt/internal/CommandChannel.scala b/main-command/src/main/scala/sbt/internal/CommandChannel.scala index fb40245df..dc9f97712 100644 --- a/main-command/src/main/scala/sbt/internal/CommandChannel.scala +++ b/main-command/src/main/scala/sbt/internal/CommandChannel.scala @@ -11,7 +11,6 @@ package internal import java.util.concurrent.ConcurrentLinkedQueue import sbt.protocol.EventMessage -import sjsonnew.JsonFormat /** * A command channel represents an IO device such as network socket or human @@ -42,9 +41,6 @@ abstract class CommandChannel { } def poll: Option[Exec] = Option(commandQueue.poll) - def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit - final def publishEvent[A: JsonFormat](event: A): Unit = publishEvent(event, None) - def publishEventMessage(event: EventMessage): Unit def publishBytes(bytes: Array[Byte]): Unit def shutdown(): Unit def name: String diff --git a/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala b/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala index be1eadab6..be1c40c62 100644 --- a/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala +++ b/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala @@ -14,8 +14,6 @@ import java.util.concurrent.atomic.AtomicReference import sbt.BasicKeys._ import sbt.internal.util._ -import sbt.protocol.EventMessage -import sjsonnew.JsonFormat private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel { private[this] val askUserThread = new AtomicReference[AskUserThread] @@ -62,21 +60,16 @@ private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel def publishBytes(bytes: Array[Byte]): Unit = () - def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit = () - - def publishEventMessage(event: EventMessage): Unit = - event match { - case e: ConsolePromptEvent => - if (Terminal.systemInIsAttached) { - askUserThread.synchronized { - askUserThread.get match { - case null => askUserThread.set(makeAskUserThread(e.state)) - case t => t.redraw() - } - } + def publishEventMessage(event: ConsolePromptEvent): Unit = { + if (Terminal.systemInIsAttached) { + askUserThread.synchronized { + askUserThread.get match { + case null => askUserThread.set(makeAskUserThread(event.state)) + case t => t.redraw() } - case _ => // + } } + } def shutdown(): Unit = askUserThread.synchronized { askUserThread.get match { diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index cb029d8d2..09ac87532 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -16,16 +16,13 @@ import java.util.concurrent.atomic._ import sbt.BasicKeys._ import sbt.nio.Watch.NullLogger import sbt.internal.protocol.JsonRpcResponseError -import sbt.internal.langserver.{ LogMessageParams, MessageType } import sbt.internal.server._ -import sbt.internal.util.codec.JValueFormats import sbt.internal.util.{ ConsoleOut, MainAppender, ObjectEvent, StringEvent, Terminal } import sbt.io.syntax._ import sbt.io.{ Hash, IO } import sbt.protocol.{ EventMessage, ExecStatusEvent } import sbt.util.{ Level, LogExchange, Logger } import sjsonnew.JsonFormat -import sjsonnew.shaded.scalajson.ast.unsafe._ import scala.annotation.tailrec import scala.collection.mutable.ListBuffer @@ -51,17 +48,12 @@ private[sbt] final class CommandExchange { private val commandChannelQueue = new LinkedBlockingQueue[CommandChannel] private val nextChannelId: AtomicInteger = new AtomicInteger(0) private[this] val activePrompt = new AtomicBoolean(false) - private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {} def channels: List[CommandChannel] = channelBuffer.toList - private[this] def removeChannels(toDel: List[CommandChannel]): Unit = { - toDel match { - case Nil => // do nothing - case xs => - channelBufferLock.synchronized { - channelBuffer --= xs - () - } + private[this] def removeChannel(channel: CommandChannel): Unit = { + channelBufferLock.synchronized { + channelBuffer -= channel + () } } @@ -206,19 +198,7 @@ private[sbt] final class CommandExchange { execId: Option[String], source: Option[CommandSource] ): Unit = { - val toDel: ListBuffer[CommandChannel] = ListBuffer.empty - channels.foreach { - case _: ConsoleChannel => - case c: NetworkChannel => - try { - // broadcast to all network channels - c.respondError(code, message, execId) - } catch { - case _: IOException => - toDel += c - } - } - removeChannels(toDel.toList) + respondError(JsonRpcResponseError(code, message), execId, source) } private[sbt] def respondError( @@ -226,19 +206,13 @@ private[sbt] final class CommandExchange { execId: Option[String], source: Option[CommandSource] ): Unit = { - val toDel: ListBuffer[CommandChannel] = ListBuffer.empty - channels.foreach { - case _: ConsoleChannel => - case c: NetworkChannel => - try { - // broadcast to all network channels - c.respondError(err, execId) - } catch { - case _: IOException => - toDel += c - } - } - removeChannels(toDel.toList) + for { + source <- source.map(_.channelName) + channel <- channels.collectFirst { + // broadcast to the source channel only + case c: NetworkChannel if c.name == source => c + } + } tryTo(_.respondError(err, execId), removeChannel)(channel) } // This is an interface to directly respond events. @@ -247,86 +221,57 @@ private[sbt] final class CommandExchange { execId: Option[String], source: Option[CommandSource] ): Unit = { - val toDel: ListBuffer[CommandChannel] = ListBuffer.empty - channels.foreach { - case _: ConsoleChannel => - case c: NetworkChannel => - try { - // broadcast to all network channels - c.respondEvent(event, execId) - } catch { - case _: IOException => - toDel += c - } - } - removeChannels(toDel.toList) + for { + source <- source.map(_.channelName) + channel <- channels.collectFirst { + // broadcast to the source channel only + case c: NetworkChannel if c.name == source => c + } + } tryTo(_.respondResult(event, execId), removeChannel)(channel) } // This is an interface to directly notify events. private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = { - val toDel: ListBuffer[CommandChannel] = ListBuffer.empty - channels.foreach { - case _: ConsoleChannel => - // c.publishEvent(event) - case c: NetworkChannel => - try { - c.notifyEvent(method, params) - } catch { - case _: IOException => - toDel += c - } - } - removeChannels(toDel.toList) + channels + .collect { case c: NetworkChannel => c } + .foreach { + tryTo(_.notifyEvent(method, params), removeChannel) + } } - private def tryTo(x: => Unit, c: CommandChannel, toDel: ListBuffer[CommandChannel]): Unit = - try x - catch { case _: IOException => toDel += c } + private def tryTo(f: NetworkChannel => Unit, fallback: NetworkChannel => Unit)( + channel: NetworkChannel + ): Unit = + try f(channel) + catch { case _: IOException => fallback(channel) } def publishEvent[A: JsonFormat](event: A): Unit = { - val broadcastStringMessage = true - val toDel: ListBuffer[CommandChannel] = ListBuffer.empty - event match { case entry: StringEvent => - val params = toLogMessageParams(entry) - channels collect { - case c: ConsoleChannel => - if (broadcastStringMessage || (entry.channelName forall (_ == c.name))) - c.publishEvent(event) - case c: NetworkChannel => - tryTo( - { - // Note that language server's LogMessageParams does not hold the execid, - // so this is weaker than the StringMessage. We might want to double-send - // in case we have a better client that can utilize the knowledge. - import sbt.internal.langserver.codec.JsonProtocol._ - if (broadcastStringMessage || (entry.channelName contains c.name)) - c.jsonRpcNotify("window/logMessage", params) - }, - c, - toDel - ) - } - case entry: ExecStatusEvent => - channels collect { - case c: ConsoleChannel => - if (entry.channelName forall (_ == c.name)) c.publishEvent(event) - case c: NetworkChannel => - if (entry.channelName contains c.name) tryTo(c.publishEvent(event), c, toDel) - } - case _ => - channels foreach { - case c: ConsoleChannel => c.publishEvent(event) - case c: NetworkChannel => - tryTo(c.publishEvent(event), c, toDel) - } - } - removeChannels(toDel.toList) - } + // Note that language server's LogMessageParams does not hold the execid, + // so this is weaker than the StringMessage. We might want to double-send + // in case we have a better client that can utilize the knowledge. + channels + .collect { case c: NetworkChannel => c } + .foreach { + tryTo(_.logMessage(entry.level, entry.message), removeChannel) + } - private[sbt] def toLogMessageParams(event: StringEvent): LogMessageParams = { - LogMessageParams(MessageType.fromLevelString(event.level), event.message) + case entry: ExecStatusEvent => + for { + source <- entry.channelName + channel <- channels.collectFirst { + case c: NetworkChannel if c.name == source => c + } + } tryTo(_.publishEvent(event), removeChannel)(channel) + + case _ => + channels + .collect { case c: NetworkChannel => c } + .foreach { + tryTo(_.publishEvent(event), removeChannel) + } + } } /** @@ -334,59 +279,38 @@ private[sbt] final class CommandExchange { * erased because it went through logging. */ private[sbt] def publishObjectEvent(event: ObjectEvent[_]): Unit = { - import jsonFormat._ - val toDel: ListBuffer[CommandChannel] = ListBuffer.empty - def json: JValue = JObject( - JField("type", JString(event.contentType)), - Vector(JField("message", event.json), JField("level", JString(event.level.toString))) ++ - (event.channelName.toVector map { channelName => - JField("channelName", JString(channelName)) - }) ++ - (event.execId.toVector map { execId => - JField("execId", JString(execId)) - }): _* - ) - channels collect { - case c: ConsoleChannel => - c.publishEvent(json) - case c: NetworkChannel => - try { - c.publishObjectEvent(event) - } catch { - case _: IOException => - toDel += c - } - } - removeChannels(toDel.toList) + for { + source <- event.channelName + channel <- channels.collectFirst { + case c: NetworkChannel if c.name == source => c + } + } tryTo(_.publishObjectEvent(event), removeChannel)(channel) } // fanout publishEvent def publishEventMessage(event: EventMessage): Unit = { - val toDel: ListBuffer[CommandChannel] = ListBuffer.empty - event match { // Special treatment for ConsolePromptEvent since it's hand coded without codec. case entry: ConsolePromptEvent => - channels collect { - case c: ConsoleChannel => - c.publishEventMessage(entry) - activePrompt.set(Terminal.systemInIsAttached) - } + activePrompt.set(Terminal.systemInIsAttached) + channels + .collect { case c: ConsoleChannel => c } + .foreach { _.publishEventMessage(entry) } case entry: ExecStatusEvent => - channels collect { - case c: ConsoleChannel => - if (entry.channelName forall (_ == c.name)) c.publishEventMessage(event) - case c: NetworkChannel => - if (entry.channelName contains c.name) tryTo(c.publishEventMessage(event), c, toDel) - } - case _ => - channels collect { - case c: ConsoleChannel => c.publishEventMessage(event) - case c: NetworkChannel => tryTo(c.publishEventMessage(event), c, toDel) - } - } + for { + source <- entry.channelName + channel <- channels.collectFirst { + case c: NetworkChannel if c.name == source => c + } + } tryTo(_.publishEventMessage(event), removeChannel)(channel) - removeChannels(toDel.toList) + case _ => + channels + .collect { case c: NetworkChannel => c } + .foreach { + tryTo(_.publishEventMessage(event), removeChannel) + } + } } private[this] def needToFinishPromptLine(): Boolean = activePrompt.compareAndSet(true, false) } diff --git a/main/src/main/scala/sbt/internal/server/Definition.scala b/main/src/main/scala/sbt/internal/server/Definition.scala index 381aa8219..288ddd00b 100644 --- a/main/src/main/scala/sbt/internal/server/Definition.scala +++ b/main/src/main/scala/sbt/internal/server/Definition.scala @@ -40,7 +40,7 @@ private[sbt] object Definition { def send[A: JsonFormat](source: CommandSource, execId: String)(params: A): Unit = { for { channel <- StandardMain.exchange.channels.collectFirst { - case c if c.name == source.channelName => c + case c: NetworkChannel if c.name == source.channelName => c } } { channel.publishEvent(params, Option(execId)) diff --git a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala index f5a4ff19a..8e68dc5a0 100644 --- a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala +++ b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala @@ -116,7 +116,7 @@ private[sbt] trait LanguageServerProtocol { self: NetworkChannel => protected lazy val callbackImpl: ServerCallback = new ServerCallback { def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit = - self.respondEvent(event, execId) + self.respondResult(event, execId) def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit = self.respondError(code, message, execId) diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index 8e4d093d9..3fa007ca6 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -267,7 +267,6 @@ final class NetworkChannel( err: JsonRpcResponseError, execId: Option[String] ): Unit = this.synchronized { - println(s"respond error for $execId") execId match { case Some(id) if onGoingRequests.contains(id) => onGoingRequests -= id @@ -285,11 +284,10 @@ final class NetworkChannel( respondError(JsonRpcResponseError(code, message), execId) } - private[sbt] def respondEvent[A: JsonFormat]( + private[sbt] def respondResult[A: JsonFormat]( event: A, execId: Option[String] ): Unit = this.synchronized { - println(s"respond result for $execId") execId match { case Some(id) if onGoingRequests.contains(id) => onGoingRequests -= id @@ -307,17 +305,20 @@ final class NetworkChannel( } } + def publishEvent[A: JsonFormat](event: A): Unit = + publishEvent(event, None) + def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit = { if (isLanguageServerProtocol) { event match { case entry: StringEvent => logMessage(entry.level, entry.message) case entry: ExecStatusEvent => entry.exitCode match { - case None => respondEvent(event, entry.execId) - case Some(0) => respondEvent(event, entry.execId) + case None => respondResult(event, entry.execId) + case Some(0) => respondResult(event, entry.execId) case Some(exitCode) => respondError(exitCode, entry.message.getOrElse(""), entry.execId) } - case _ => respondEvent(event, execId) + case _ => respondResult(event, execId) } } else { contentType match { @@ -417,7 +418,7 @@ final class NetworkChannel( if (initialized) { import sbt.protocol.codec.JsonProtocol._ SettingQuery.handleSettingQueryEither(req, structure) match { - case Right(x) => respondEvent(x, execId) + case Right(x) => respondResult(x, execId) case Left(s) => respondError(ErrorCodes.InvalidParams, s, execId) } } else { @@ -440,7 +441,7 @@ final class NetworkChannel( } .map(c => cp.query + c) import sbt.protocol.codec.JsonProtocol._ - respondEvent( + respondResult( CompletionResponse( items = completionItems.toVector ), @@ -498,7 +499,7 @@ final class NetworkChannel( runningEngine.cancelAndShutdown() import sbt.protocol.codec.JsonProtocol._ - respondEvent( + respondResult( ExecStatusEvent( "Task cancelled", Some(name),