diff --git a/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala b/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala index be1c40c62..c53440d34 100644 --- a/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala +++ b/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala @@ -60,7 +60,7 @@ private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel def publishBytes(bytes: Array[Byte]): Unit = () - def publishEventMessage(event: ConsolePromptEvent): Unit = { + def prompt(event: ConsolePromptEvent): Unit = { if (Terminal.systemInIsAttached) { askUserThread.synchronized { askUserThread.get match { diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index 19a4ea5eb..23a6e5b36 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -890,7 +890,7 @@ object BuiltinCommands { val exchange = StandardMain.exchange val welcomeState = displayWelcomeBanner(s0) val s1 = exchange run welcomeState - exchange publishEventMessage ConsolePromptEvent(s0) + exchange prompt ConsolePromptEvent(s0) val minGCInterval = Project .extract(s1) .getOpt(Keys.minForcegcInterval) diff --git a/main/src/main/scala/sbt/MainLoop.scala b/main/src/main/scala/sbt/MainLoop.scala index c80b873a9..9320ebb25 100644 --- a/main/src/main/scala/sbt/MainLoop.scala +++ b/main/src/main/scala/sbt/MainLoop.scala @@ -176,7 +176,7 @@ object MainLoop { /** This is the main function State transfer function of the sbt command processing. */ def processCommand(exec: Exec, state: State): State = { val channelName = exec.source map (_.channelName) - StandardMain.exchange publishEventMessage + StandardMain.exchange notifyStatus ExecStatusEvent("Processing", channelName, exec.execId, Vector()) try { def process(): State = { @@ -197,12 +197,7 @@ object MainLoop { newState.remainingCommands.toVector map (_.commandLine), exitCode(newState, state), ) - if (doneEvent.execId.isDefined) { // send back a response or error - import sbt.protocol.codec.JsonProtocol._ - StandardMain.exchange publishEvent doneEvent - } else { // send back a notification - StandardMain.exchange publishEventMessage doneEvent - } + StandardMain.exchange.respondStatus(doneEvent) newState.get(sbt.Keys.currentTaskProgress).foreach(_.progress.stop()) newState.remove(sbt.Keys.currentTaskProgress) } @@ -225,8 +220,7 @@ object MainLoop { ExitCode(ErrorCodes.UnknownError), Option(err.getMessage), ) - import sbt.protocol.codec.JsonProtocol._ - StandardMain.exchange.publishEvent(errorEvent) + StandardMain.exchange.respondStatus(errorEvent) throw err } } diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 09ac87532..d62b94665 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -17,10 +17,10 @@ import sbt.BasicKeys._ import sbt.nio.Watch.NullLogger import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.server._ -import sbt.internal.util.{ ConsoleOut, MainAppender, ObjectEvent, StringEvent, Terminal } +import sbt.internal.util.{ ConsoleOut, MainAppender, ObjectEvent, Terminal } import sbt.io.syntax._ import sbt.io.{ Hash, IO } -import sbt.protocol.{ EventMessage, ExecStatusEvent } +import sbt.protocol.{ ExecStatusEvent, LogEvent } import sbt.util.{ Level, LogExchange, Logger } import sjsonnew.JsonFormat @@ -212,7 +212,7 @@ private[sbt] final class CommandExchange { // broadcast to the source channel only case c: NetworkChannel if c.name == source => c } - } tryTo(_.respondError(err, execId), removeChannel)(channel) + } tryTo(_.respondError(err, execId))(channel) } // This is an interface to directly respond events. @@ -227,7 +227,7 @@ private[sbt] final class CommandExchange { // broadcast to the source channel only case c: NetworkChannel if c.name == source => c } - } tryTo(_.respondResult(event, execId), removeChannel)(channel) + } tryTo(_.respondResult(event, execId))(channel) } // This is an interface to directly notify events. @@ -235,42 +235,36 @@ private[sbt] final class CommandExchange { channels .collect { case c: NetworkChannel => c } .foreach { - tryTo(_.notifyEvent(method, params), removeChannel) + tryTo(_.notifyEvent(method, params)) } } - private def tryTo(f: NetworkChannel => Unit, fallback: NetworkChannel => Unit)( + private def tryTo(f: NetworkChannel => Unit)( channel: NetworkChannel ): Unit = try f(channel) - catch { case _: IOException => fallback(channel) } + catch { case _: IOException => removeChannel(channel) } - def publishEvent[A: JsonFormat](event: A): Unit = { - event match { - case entry: StringEvent => - // Note that language server's LogMessageParams does not hold the execid, - // so this is weaker than the StringMessage. We might want to double-send - // in case we have a better client that can utilize the knowledge. - channels - .collect { case c: NetworkChannel => c } - .foreach { - tryTo(_.logMessage(entry.level, entry.message), removeChannel) - } + def respondStatus(event: ExecStatusEvent): Unit = { + import sbt.protocol.codec.JsonProtocol._ + for { + source <- event.channelName + channel <- channels.collectFirst { + case c: NetworkChannel if c.name == source => c + } + } { + if (event.execId.isEmpty) { + tryTo(_.notifyEvent(event))(channel) + } else { + event.exitCode match { + case None | Some(0) => + tryTo(_.respondResult(event, event.execId))(channel) + case Some(code) => + tryTo(_.respondError(code, event.message.getOrElse(""), event.execId))(channel) + } + } - case entry: ExecStatusEvent => - for { - source <- entry.channelName - channel <- channels.collectFirst { - case c: NetworkChannel if c.name == source => c - } - } tryTo(_.publishEvent(event), removeChannel)(channel) - - case _ => - channels - .collect { case c: NetworkChannel => c } - .foreach { - tryTo(_.publishEvent(event), removeChannel) - } + tryTo(_.respond(event, event.execId))(channel) } } @@ -278,39 +272,38 @@ private[sbt] final class CommandExchange { * This publishes object events. The type information has been * erased because it went through logging. */ - private[sbt] def publishObjectEvent(event: ObjectEvent[_]): Unit = { + private[sbt] def respondObjectEvent(event: ObjectEvent[_]): Unit = { for { source <- event.channelName channel <- channels.collectFirst { case c: NetworkChannel if c.name == source => c } - } tryTo(_.publishObjectEvent(event), removeChannel)(channel) + } tryTo(_.respond(event))(channel) } - // fanout publishEvent - def publishEventMessage(event: EventMessage): Unit = { - event match { - // Special treatment for ConsolePromptEvent since it's hand coded without codec. - case entry: ConsolePromptEvent => - activePrompt.set(Terminal.systemInIsAttached) - channels - .collect { case c: ConsoleChannel => c } - .foreach { _.publishEventMessage(entry) } - case entry: ExecStatusEvent => - for { - source <- entry.channelName - channel <- channels.collectFirst { - case c: NetworkChannel if c.name == source => c - } - } tryTo(_.publishEventMessage(event), removeChannel)(channel) - - case _ => - channels - .collect { case c: NetworkChannel => c } - .foreach { - tryTo(_.publishEventMessage(event), removeChannel) - } - } + def prompt(event: ConsolePromptEvent): Unit = { + activePrompt.set(Terminal.systemInIsAttached) + channels + .collect { case c: ConsoleChannel => c } + .foreach { _.prompt(event) } } + + def logMessage(event: LogEvent): Unit = { + channels + .collect { case c: NetworkChannel => c } + .foreach { + tryTo(_.notifyEvent(event)) + } + } + + def notifyStatus(event: ExecStatusEvent): Unit = { + for { + source <- event.channelName + channel <- channels.collectFirst { + case c: NetworkChannel if c.name == source => c + } + } tryTo(_.notifyEvent(event))(channel) + } + private[this] def needToFinishPromptLine(): Boolean = activePrompt.compareAndSet(true, false) } diff --git a/main/src/main/scala/sbt/internal/RelayAppender.scala b/main/src/main/scala/sbt/internal/RelayAppender.scala index 988d5feb4..97bb2a1c1 100644 --- a/main/src/main/scala/sbt/internal/RelayAppender.scala +++ b/main/src/main/scala/sbt/internal/RelayAppender.scala @@ -17,7 +17,6 @@ import org.apache.logging.log4j.core.config.Property import sbt.util.Level import sbt.internal.util._ import sbt.protocol.LogEvent -import sbt.internal.util.codec._ class RelayAppender(name: String) extends AbstractAppender( @@ -40,15 +39,12 @@ class RelayAppender(name: String) } } def appendLog(level: Level.Value, message: => String): Unit = { - exchange.publishEventMessage(LogEvent(level.toString, message)) + exchange.logMessage(LogEvent(level.toString, message)) } def appendEvent(event: AnyRef): Unit = event match { - case x: StringEvent => { - import JsonProtocol._ - exchange.publishEvent(x: AbstractEntry) - } - case x: ObjectEvent[_] => exchange.publishObjectEvent(x) + case x: StringEvent => exchange.logMessage(LogEvent(x.message, x.level)) + case x: ObjectEvent[_] => exchange.respondObjectEvent(x) case _ => println(s"appendEvent: ${event.getClass}") () diff --git a/main/src/main/scala/sbt/internal/server/Definition.scala b/main/src/main/scala/sbt/internal/server/Definition.scala index 288ddd00b..ad7bdfe90 100644 --- a/main/src/main/scala/sbt/internal/server/Definition.scala +++ b/main/src/main/scala/sbt/internal/server/Definition.scala @@ -43,7 +43,7 @@ private[sbt] object Definition { case c: NetworkChannel if c.name == source.channelName => c } } { - channel.publishEvent(params, Option(execId)) + channel.respond(params, Option(execId)) } } diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index 3fa007ca6..58fff4af3 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -20,7 +20,7 @@ import sbt.internal.protocol.{ } import sbt.internal.util.codec.JValueFormats import sbt.internal.util.complete.Parser -import sbt.internal.util.{ ObjectEvent, StringEvent } +import sbt.internal.util.ObjectEvent import sbt.protocol._ import sbt.util.Logger import sjsonnew._ @@ -293,7 +293,9 @@ final class NetworkChannel( onGoingRequests -= id jsonRpcRespond(event, id) case _ => - log.debug(s"unmatched json response: ${CompactPrinter(Converter.toJsonUnsafe(event))}") + log.debug( + s"unmatched json response for requestId $execId: ${CompactPrinter(Converter.toJsonUnsafe(event))}" + ) } } @@ -305,21 +307,11 @@ final class NetworkChannel( } } - def publishEvent[A: JsonFormat](event: A): Unit = - publishEvent(event, None) + def respond[A: JsonFormat](event: A): Unit = respond(event, None) - def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit = { + def respond[A: JsonFormat](event: A, execId: Option[String]): Unit = { if (isLanguageServerProtocol) { - event match { - case entry: StringEvent => logMessage(entry.level, entry.message) - case entry: ExecStatusEvent => - entry.exitCode match { - case None => respondResult(event, entry.execId) - case Some(0) => respondResult(event, entry.execId) - case Some(exitCode) => respondError(exitCode, entry.message.getOrElse(""), entry.execId) - } - case _ => respondResult(event, execId) - } + respondResult(event, execId) } else { contentType match { case SbtX1Protocol => @@ -330,7 +322,7 @@ final class NetworkChannel( } } - def publishEventMessage(event: EventMessage): Unit = { + def notifyEvent(event: EventMessage): Unit = { if (isLanguageServerProtocol) { event match { case entry: LogEvent => logMessage(entry.level, entry.message) @@ -351,22 +343,22 @@ final class NetworkChannel( * This publishes object events. The type information has been * erased because it went through logging. */ - private[sbt] def publishObjectEvent(event: ObjectEvent[_]): Unit = { + private[sbt] def respond(event: ObjectEvent[_]): Unit = { import sjsonnew.shaded.scalajson.ast.unsafe._ if (isLanguageServerProtocol) onObjectEvent(event) else { import jsonFormat._ val json: JValue = JObject( JField("type", JString(event.contentType)), - (Vector(JField("message", event.json), JField("level", JString(event.level.toString))) ++ - (event.channelName.toVector map { channelName => + Seq(JField("message", event.json), JField("level", JString(event.level.toString))) ++ + (event.channelName map { channelName => JField("channelName", JString(channelName)) }) ++ - (event.execId.toVector map { execId => + (event.execId map { execId => JField("execId", JString(execId)) - })): _* + }): _* ) - publishEvent(json) + respond(json, event.execId) } } @@ -393,7 +385,7 @@ final class NetworkChannel( authenticate(x) match { case true => initialized = true - publishEventMessage(ChannelAcceptedEvent(name)) + notifyEvent(ChannelAcceptedEvent(name)) case _ => sys.error("invalid token") } case None => sys.error("init command but without token.") diff --git a/server-test/src/test/scala/testpkg/EventsTest.scala b/server-test/src/test/scala/testpkg/EventsTest.scala index d31796f73..5359bb2dd 100644 --- a/server-test/src/test/scala/testpkg/EventsTest.scala +++ b/server-test/src/test/scala/testpkg/EventsTest.scala @@ -22,15 +22,6 @@ object EventsTest extends AbstractServerTest { }) } - test("report task failures in case of exceptions") { _ => - svr.sendJsonRpc( - """{ "jsonrpc": "2.0", "id": 11, "method": "sbt/exec", "params": { "commandLine": "hello" } }""" - ) - assert(svr.waitForString(10.seconds) { s => - (s contains """"id":11""") && (s contains """"error":""") - }) - } - test("return error if cancelling non-matched task id") { _ => svr.sendJsonRpc( """{ "jsonrpc": "2.0", "id":12, "method": "sbt/exec", "params": { "commandLine": "run" } }""" diff --git a/server-test/src/test/scala/testpkg/TestServer.scala b/server-test/src/test/scala/testpkg/TestServer.scala index ae2e8c2f1..e9a5faa64 100644 --- a/server-test/src/test/scala/testpkg/TestServer.scala +++ b/server-test/src/test/scala/testpkg/TestServer.scala @@ -189,17 +189,21 @@ case class TestServer( var out = sk.getOutputStream var in = sk.getInputStream - def resetConnection() = { - sk = ClientSocket.socket(portfile)._1 - out = sk.getOutputStream - in = sk.getInputStream - } - // initiate handshake sendJsonRpc( """{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }""" ) + def resetConnection() = { + sk = ClientSocket.socket(portfile)._1 + out = sk.getOutputStream + in = sk.getInputStream + + sendJsonRpc( + """{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }""" + ) + } + def test(f: TestServer => Future[Assertion]): Future[Assertion] = { f(this) }