From 781584d137ebf5115d067ec2bde93d5fc28587ef Mon Sep 17 00:00:00 2001 From: Adrien Piquerez Date: Mon, 11 May 2020 15:50:41 +0200 Subject: [PATCH 1/7] id is mandatory in json rpc responses --- .../sbt/internal/client/NetworkClient.scala | 2 +- .../scala/sbt/internal/CommandExchange.scala | 6 +- .../server/LanguageServerProtocol.scala | 20 +-- .../sbt/internal/server/NetworkChannel.scala | 115 +++++++++++------- .../protocol/JsonRpcResponseMessage.scala | 13 +- protocol/src/main/contraband/jsonrpc.contra | 2 +- .../codec/JsonRpcResponseMessageFormats.scala | 12 +- 7 files changed, 97 insertions(+), 73 deletions(-) diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala index c01d99340..6a53d2837 100644 --- a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala +++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala @@ -119,7 +119,7 @@ class NetworkClient(configuration: xsbti.AppConfiguration, arguments: List[Strin } def onResponse(msg: JsonRpcResponseMessage): Unit = { - msg.id foreach { + msg.id match { case execId if pendingExecIds contains execId => onReturningReponse(msg) lock.synchronized { diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 85aa09f6a..cb029d8d2 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -212,7 +212,7 @@ private[sbt] final class CommandExchange { case c: NetworkChannel => try { // broadcast to all network channels - c.respondError(code, message, execId, source) + c.respondError(code, message, execId) } catch { case _: IOException => toDel += c @@ -232,7 +232,7 @@ private[sbt] final class CommandExchange { case c: NetworkChannel => try { // broadcast to all network channels - c.respondError(err, execId, source) + c.respondError(err, execId) } catch { case _: IOException => toDel += c @@ -253,7 +253,7 @@ private[sbt] final class CommandExchange { case c: NetworkChannel => try { // broadcast to all network channels - c.respondEvent(event, execId, source) + c.respondEvent(event, execId) } catch { case _: IOException => toDel += c diff --git a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala index 86d89e0db..cee5d91fc 100644 --- a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala +++ b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala @@ -103,7 +103,7 @@ private[sbt] object LanguageServerProtocol { } /** Implements Language Server Protocol . */ -private[sbt] trait LanguageServerProtocol extends CommandChannel { self => +private[sbt] trait LanguageServerProtocol { self: NetworkChannel => lazy val internalJsonProtocol = new InitializeOptionFormats with sjsonnew.BasicJsonProtocol {} @@ -117,10 +117,10 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self => protected lazy val callbackImpl: ServerCallback = new ServerCallback { def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit = - self.jsonRpcRespond(event, execId) + self.respondEvent(event, execId) def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit = - self.jsonRpcRespondError(execId, code, message) + self.respondError(code, message, execId) def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit = self.jsonRpcNotify(method, params) @@ -162,20 +162,20 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self => } /** Respond back to Language Server's client. */ - private[sbt] def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit = { - val m = + private[sbt] def jsonRpcRespond[A: JsonFormat](event: A, execId: String): Unit = { + val response = JsonRpcResponseMessage("2.0", execId, Option(Converter.toJson[A](event).get), None) - val bytes = Serialization.serializeResponseMessage(m) + val bytes = Serialization.serializeResponseMessage(response) publishBytes(bytes) } /** Respond back to Language Server's client. */ - private[sbt] def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit = + private[sbt] def jsonRpcRespondError(execId: String, code: Long, message: String): Unit = jsonRpcRespondErrorImpl(execId, code, message, None) /** Respond back to Language Server's client. */ private[sbt] def jsonRpcRespondError[A: JsonFormat]( - execId: Option[String], + execId: String, code: Long, message: String, data: A, @@ -183,7 +183,7 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self => jsonRpcRespondErrorImpl(execId, code, message, Option(Converter.toJson[A](data).get)) private[sbt] def jsonRpcRespondError( - execId: Option[String], + execId: String, err: JsonRpcResponseError ): Unit = { val m = JsonRpcResponseMessage("2.0", execId, None, Option(err)) @@ -192,7 +192,7 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self => } private[this] def jsonRpcRespondErrorImpl( - execId: Option[String], + execId: String, code: Long, message: String, data: Option[JValue], diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index 66a5e01ac..fc423e57e 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -13,18 +13,21 @@ import java.net.{ Socket, SocketTimeoutException } import java.util.concurrent.atomic.AtomicBoolean import sjsonnew._ + import scala.annotation.tailrec import sbt.protocol._ -import sbt.internal.langserver.{ ErrorCodes, CancelRequestParams } +import sbt.internal.langserver.{ CancelRequestParams, ErrorCodes } import sbt.internal.util.{ ObjectEvent, StringEvent } import sbt.internal.util.complete.Parser import sbt.internal.util.codec.JValueFormats import sbt.internal.protocol.{ - JsonRpcResponseError, + JsonRpcNotificationMessage, JsonRpcRequestMessage, - JsonRpcNotificationMessage + JsonRpcResponseError } import sbt.util.Logger +import sjsonnew.support.scalajson.unsafe.Converter + import scala.util.Try import scala.util.control.NonFatal @@ -54,6 +57,8 @@ final class NetworkChannel( private val VsCodeOld = "application/vscode-jsonrpc; charset=utf8" private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {} + private var onGoingRequests: Set[JsonRpcRequestMessage] = Set.empty + def setContentType(ct: String): Unit = synchronized { _contentType = ct } def contentType: String = _contentType @@ -176,6 +181,7 @@ final class NetworkChannel( intents.foldLeft(PartialFunction.empty[JsonRpcRequestMessage, Unit]) { case (f, i) => f orElse i.onRequest } + lazy val onNotification: PartialFunction[JsonRpcNotificationMessage, Unit] = intents.foldLeft(PartialFunction.empty[JsonRpcNotificationMessage, Unit]) { case (f, i) => f orElse i.onNotification @@ -186,25 +192,27 @@ final class NetworkChannel( Serialization.deserializeJsonMessage(chunk) match { case Right(req: JsonRpcRequestMessage) => try { + registerRequest(req) onRequestMessage(req) } catch { case LangServerError(code, message) => log.debug(s"sending error: $code: $message") - jsonRpcRespondError(Option(req.id), code, message) + jsonRpcRespondError(req.id, code, message) } case Right(ntf: JsonRpcNotificationMessage) => try { onNotification(ntf) } catch { case LangServerError(code, message) => - log.debug(s"sending error: $code: $message") - jsonRpcRespondError(None, code, message) // new id? + logMessage("error", s"Error $code while handling notification: $message") } case Right(msg) => log.debug(s"Unhandled message: $msg") case Left(errorDesc) => - val msg = s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): " + errorDesc - jsonRpcRespondError(None, ErrorCodes.ParseError, msg) + logMessage( + "error", + s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): $errorDesc" + ) } } else { contentType match { @@ -213,13 +221,17 @@ final class NetworkChannel( .deserializeCommand(chunk) .fold( errorDesc => - log.error( + logMessage( + "error", s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): " + errorDesc ), onCommand ) case _ => - log.error(s"Unknown Content-Type: $contentType") + logMessage( + "error", + s"Unknown Content-Type: $contentType" + ) } } // if-else } @@ -245,24 +257,43 @@ final class NetworkChannel( } } + private def registerRequest(request: JsonRpcRequestMessage): Unit = { + onGoingRequests += request + } + private[sbt] def respondError( err: JsonRpcResponseError, - execId: Option[String], - source: Option[CommandSource] - ): Unit = jsonRpcRespondError(execId, err) + execId: Option[String] + ): Unit = { + execId match { + case Some(id) => jsonRpcRespondError(id, err) + case None => logMessage("error", s"Error ${err.code}: ${err.message}") + } + + } private[sbt] def respondError( code: Long, message: String, - execId: Option[String], - source: Option[CommandSource] - ): Unit = jsonRpcRespondError(execId, code, message) + execId: Option[String] + ): Unit = { + execId match { + case Some(id) => jsonRpcRespondError(id, code, message) + case None => logMessage("error", s"Error $code: $message") + } + } private[sbt] def respondEvent[A: JsonFormat]( event: A, - execId: Option[String], - source: Option[CommandSource] - ): Unit = jsonRpcRespond(event, execId) + execId: Option[String] + ): Unit = { + execId match { + case Some(id) => jsonRpcRespond(event, id) + case None => + val json = Converter.toJsonUnsafe(event) + log.debug(s"unmatched json response: $json") + } + } private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = { if (isLanguageServerProtocol) { @@ -278,12 +309,11 @@ final class NetworkChannel( case entry: StringEvent => logMessage(entry.level, entry.message) case entry: ExecStatusEvent => entry.exitCode match { - case None => jsonRpcRespond(event, entry.execId) - case Some(0) => jsonRpcRespond(event, entry.execId) - case Some(exitCode) => - jsonRpcRespondError(entry.execId, exitCode, entry.message.getOrElse("")) + case None => respondEvent(event, entry.execId) + case Some(0) => respondEvent(event, entry.execId) + case Some(exitCode) => respondError(exitCode, entry.message.getOrElse(""), entry.execId) } - case _ => jsonRpcRespond(event, execId) + case _ => respondEvent(event, execId) } } else { contentType match { @@ -383,8 +413,8 @@ final class NetworkChannel( if (initialized) { import sbt.protocol.codec.JsonProtocol._ SettingQuery.handleSettingQueryEither(req, structure) match { - case Right(x) => jsonRpcRespond(x, execId) - case Left(s) => jsonRpcRespondError(execId, ErrorCodes.InvalidParams, s) + case Right(x) => respondEvent(x, execId) + case Left(s) => respondError(ErrorCodes.InvalidParams, s, execId) } } else { log.warn(s"ignoring query $req before initialization") @@ -400,32 +430,31 @@ final class NetworkChannel( Parser .completions(sstate.combinedParser, cp.query, 9) .get - .map(c => { + .flatMap { c => if (!c.isEmpty) Some(c.append.replaceAll("\n", " ")) else None - }) - .flatten - .map(c => cp.query + c.toString) + } + .map(c => cp.query + c) import sbt.protocol.codec.JsonProtocol._ - jsonRpcRespond( + respondEvent( CompletionResponse( items = completionItems.toVector ), execId ) case _ => - jsonRpcRespondError( - execId, + respondError( ErrorCodes.UnknownError, - "No available sbt state" + "No available sbt state", + execId ) } } catch { - case NonFatal(e) => - jsonRpcRespondError( - execId, + case NonFatal(_) => + respondError( ErrorCodes.UnknownError, - "Completions request failed" + "Completions request failed", + execId ) } } else { @@ -436,10 +465,10 @@ final class NetworkChannel( protected def onCancellationRequest(execId: Option[String], crp: CancelRequestParams) = { if (initialized) { - def errorRespond(msg: String) = jsonRpcRespondError( - execId, + def errorRespond(msg: String) = respondError( ErrorCodes.RequestCancelled, - msg + msg, + execId ) try { @@ -465,11 +494,11 @@ final class NetworkChannel( runningEngine.cancelAndShutdown() import sbt.protocol.codec.JsonProtocol._ - jsonRpcRespond( + respondEvent( ExecStatusEvent( "Task cancelled", Some(name), - Some(runningExecId.toString), + Some(runningExecId), Vector(), None, ), diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/JsonRpcResponseMessage.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/JsonRpcResponseMessage.scala index 4add26b48..4cab85ffb 100644 --- a/protocol/src/main/contraband-scala/sbt/internal/protocol/JsonRpcResponseMessage.scala +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/JsonRpcResponseMessage.scala @@ -12,7 +12,7 @@ package sbt.internal.protocol */ final class JsonRpcResponseMessage private ( jsonrpc: String, - val id: Option[String], + val id: String, val result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue], val error: Option[sbt.internal.protocol.JsonRpcResponseError]) extends sbt.internal.protocol.JsonRpcMessage(jsonrpc) with Serializable { @@ -28,17 +28,14 @@ final class JsonRpcResponseMessage private ( override def toString: String = { s"""JsonRpcResponseMessage($jsonrpc, $id, ${sbt.protocol.Serialization.compactPrintJsonOpt(result)}, $error)""" } - private[this] def copy(jsonrpc: String = jsonrpc, id: Option[String] = id, result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue] = result, error: Option[sbt.internal.protocol.JsonRpcResponseError] = error): JsonRpcResponseMessage = { + private[this] def copy(jsonrpc: String = jsonrpc, id: String = id, result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue] = result, error: Option[sbt.internal.protocol.JsonRpcResponseError] = error): JsonRpcResponseMessage = { new JsonRpcResponseMessage(jsonrpc, id, result, error) } def withJsonrpc(jsonrpc: String): JsonRpcResponseMessage = { copy(jsonrpc = jsonrpc) } - def withId(id: Option[String]): JsonRpcResponseMessage = { - copy(id = id) - } def withId(id: String): JsonRpcResponseMessage = { - copy(id = Option(id)) + copy(id = id) } def withResult(result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue]): JsonRpcResponseMessage = { copy(result = result) @@ -55,6 +52,6 @@ final class JsonRpcResponseMessage private ( } object JsonRpcResponseMessage { - def apply(jsonrpc: String, id: Option[String], result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue], error: Option[sbt.internal.protocol.JsonRpcResponseError]): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, id, result, error) - def apply(jsonrpc: String, id: String, result: sjsonnew.shaded.scalajson.ast.unsafe.JValue, error: sbt.internal.protocol.JsonRpcResponseError): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, Option(id), Option(result), Option(error)) + def apply(jsonrpc: String, id: String, result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue], error: Option[sbt.internal.protocol.JsonRpcResponseError]): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, id, result, error) + def apply(jsonrpc: String, id: String, result: sjsonnew.shaded.scalajson.ast.unsafe.JValue, error: sbt.internal.protocol.JsonRpcResponseError): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, id, Option(result), Option(error)) } diff --git a/protocol/src/main/contraband/jsonrpc.contra b/protocol/src/main/contraband/jsonrpc.contra index 54a134b40..983f6e747 100644 --- a/protocol/src/main/contraband/jsonrpc.contra +++ b/protocol/src/main/contraband/jsonrpc.contra @@ -31,7 +31,7 @@ type JsonRpcResponseMessage implements JsonRpcMessage jsonrpc: String! ## The request id. - id: String + id: String! ## The result of a request. This can be omitted in ## the case of an error. diff --git a/protocol/src/main/scala/sbt/internal/protocol/codec/JsonRpcResponseMessageFormats.scala b/protocol/src/main/scala/sbt/internal/protocol/codec/JsonRpcResponseMessageFormats.scala index 3ea29d13b..c93b31e22 100644 --- a/protocol/src/main/scala/sbt/internal/protocol/codec/JsonRpcResponseMessageFormats.scala +++ b/protocol/src/main/scala/sbt/internal/protocol/codec/JsonRpcResponseMessageFormats.scala @@ -32,10 +32,10 @@ trait JsonRpcResponseMessageFormats { unbuilder.beginObject(js) val jsonrpc = unbuilder.readField[String]("jsonrpc") val id = try { - unbuilder.readField[Option[String]]("id") + unbuilder.readField[String]("id") } catch { case _: DeserializationException => - unbuilder.readField[Option[Long]]("id") map { _.toString } + unbuilder.readField[Long]("id").toString } val result = unbuilder.lookupField("result") map { @@ -77,11 +77,9 @@ trait JsonRpcResponseMessageFormats { } builder.beginObject() builder.addField("jsonrpc", obj.jsonrpc) - obj.id foreach { id => - parseId(id) match { - case Right(strId) => builder.addField("id", strId) - case Left(longId) => builder.addField("id", longId) - } + parseId(obj.id) match { + case Right(strId) => builder.addField("id", strId) + case Left(longId) => builder.addField("id", longId) } builder.addField("result", obj.result map parseResult) builder.addField("error", obj.error) From e040eebd21be8f199ad0fe27cae506ba6f5255f8 Mon Sep 17 00:00:00 2001 From: Adrien Piquerez Date: Tue, 12 May 2020 09:00:44 +0200 Subject: [PATCH 2/7] "add failing json rpc response tests" --- .../src/server-test/response/build.sbt | 17 ++++++- .../src/test/scala/testpkg/ResponseTest.scala | 50 +++++++++++++++++++ .../src/test/scala/testpkg/TestServer.scala | 46 ++++++++++++----- 3 files changed, 99 insertions(+), 14 deletions(-) diff --git a/server-test/src/server-test/response/build.sbt b/server-test/src/server-test/response/build.sbt index 5d431a084..ee95370c3 100644 --- a/server-test/src/server-test/response/build.sbt +++ b/server-test/src/server-test/response/build.sbt @@ -25,12 +25,25 @@ Global / serverHandlers += ServerHandler({ callback => case r: JsonRpcRequestMessage if r.method == "foo/rootClasspath" => appendExec(Exec("fooClasspath", Some(r.id), Some(CommandSource(callback.name)))) () + case r if r.method == "foo/respondTwice" => + appendExec(Exec("fooClasspath", Some(r.id), Some(CommandSource(callback.name)))) + jsonRpcRespond("concurrent response", Some(r.id)) + () + case r if r.method == "foo/resultAndError" => + appendExec(Exec("fooCustomFail", Some(r.id), Some(CommandSource(callback.name)))) + jsonRpcRespond("concurrent response", Some(r.id)) + () }, - PartialFunction.empty + { + case r if r.method == "foo/customNotification" => + jsonRpcRespond("notification result", None) + () + } ) }) lazy val fooClasspath = taskKey[Unit]("") + lazy val root = (project in file(".")) .settings( name := "response", @@ -55,5 +68,5 @@ lazy val root = (project in file(".")) val s = state.value val cp = (Compile / fullClasspath).value s.respondEvent(cp.map(_.data)) - }, + } ) diff --git a/server-test/src/test/scala/testpkg/ResponseTest.scala b/server-test/src/test/scala/testpkg/ResponseTest.scala index 8f8ee3459..75cec4250 100644 --- a/server-test/src/test/scala/testpkg/ResponseTest.scala +++ b/server-test/src/test/scala/testpkg/ResponseTest.scala @@ -64,4 +64,54 @@ object ResponseTest extends AbstractServerTest { (s contains """{"jsonrpc":"2.0","method":"foo/something","params":"something"}""") }) } + + test("respond concurrently from a task and the handler") { _ => + svr.sendJsonRpc( + """{ "jsonrpc": "2.0", "id": "15", "method": "foo/respondTwice", "params": {} }""" + ) + assert { + svr.waitForString(1.seconds) { s => + println(s) + s contains "\"id\":\"15\"" + } + } + assert { + // the second response should never be sent + svr.neverReceive(500.milliseconds) { s => + println(s) + s contains "\"id\":\"15\"" + } + } + } + + test("concurrent result and error") { _ => + svr.sendJsonRpc( + """{ "jsonrpc": "2.0", "id": "16", "method": "foo/resultAndError", "params": {} }""" + ) + assert { + svr.waitForString(1.seconds) { s => + println(s) + s contains "\"id\":\"16\"" + } + } + assert { + // the second response (result or error) should never be sent + svr.neverReceive(500.milliseconds) { s => + println(s) + s contains "\"id\":\"16\"" + } + } + } + + test("response to a notification should not be sent") { _ => + svr.sendJsonRpc( + """{ "jsonrpc": "2.0", "method": "foo/customNotification", "params": {} }""" + ) + assert { + svr.neverReceive(500.milliseconds) { s => + println(s) + s contains "\"result\":\"notification result\"" + } + } + } } diff --git a/server-test/src/test/scala/testpkg/TestServer.scala b/server-test/src/test/scala/testpkg/TestServer.scala index ea8625026..ae2e8c2f1 100644 --- a/server-test/src/test/scala/testpkg/TestServer.scala +++ b/server-test/src/test/scala/testpkg/TestServer.scala @@ -8,13 +8,14 @@ package testpkg import java.io.{ File, IOException } +import java.util.concurrent.TimeoutException import verify._ import sbt.RunFromSourceMain import sbt.io.IO import sbt.io.syntax._ import sbt.protocol.ClientSocket -import scala.annotation.tailrec + import scala.concurrent._ import scala.concurrent.duration._ import scala.util.{ Success, Try } @@ -150,6 +151,7 @@ case class TestServer( sbtVersion: String, classpath: Seq[File] ) { + import scala.concurrent.ExecutionContext.Implicits._ import TestServer.hostLog val readBuffer = new Array[Byte](40960) @@ -183,9 +185,15 @@ case class TestServer( waitForPortfile(90.seconds) // make connection to the socket described in the portfile - val (sk, tkn) = ClientSocket.socket(portfile) - val out = sk.getOutputStream - val in = sk.getInputStream + var (sk, _) = ClientSocket.socket(portfile) + var out = sk.getOutputStream + var in = sk.getInputStream + + def resetConnection() = { + sk = ClientSocket.socket(portfile)._1 + out = sk.getOutputStream + in = sk.getInputStream + } // initiate handshake sendJsonRpc( @@ -230,7 +238,7 @@ case class TestServer( writeEndLine } - def readFrame: Option[String] = { + def readFrame: Future[Option[String]] = Future { def getContentLength: Int = { readLine map { line => line.drop(16).toInt @@ -244,14 +252,28 @@ case class TestServer( final def waitForString(duration: FiniteDuration)(f: String => Boolean): Boolean = { val deadline = duration.fromNow - @tailrec def impl(): Boolean = { - if (deadline.isOverdue || !process.isAlive) false - else - readFrame.fold(false)(f) || { - Thread.sleep(100) - impl - } + try { + Await.result(readFrame, deadline.timeLeft).fold(false)(f) || impl + } catch { + case _: TimeoutException => + resetConnection() // create a new connection to invalidate the running readFrame future + false + } + } + impl() + } + + final def neverReceive(duration: FiniteDuration)(f: String => Boolean): Boolean = { + val deadline = duration.fromNow + def impl(): Boolean = { + try { + Await.result(readFrame, deadline.timeLeft).fold(true)(s => !f(s)) && impl + } catch { + case _: TimeoutException => + resetConnection() // create a new connection to invalidate the running readFrame future + true + } } impl() } From df293fbfd552259169860ad6a3e71505cdf6d7ab Mon Sep 17 00:00:00 2001 From: Adrien Piquerez Date: Tue, 12 May 2020 10:38:47 +0200 Subject: [PATCH 3/7] prevent multiple response to a single request --- .../server/LanguageServerProtocol.scala | 25 --------- .../sbt/internal/server/NetworkChannel.scala | 54 ++++++++++--------- 2 files changed, 29 insertions(+), 50 deletions(-) diff --git a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala index cee5d91fc..f5a4ff19a 100644 --- a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala +++ b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala @@ -10,7 +10,6 @@ package internal package server import sjsonnew.JsonFormat -import sjsonnew.shaded.scalajson.ast.unsafe.JValue import sjsonnew.support.scalajson.unsafe.Converter import sbt.protocol.Serialization import sbt.protocol.{ CompletionParams => CP, SettingQuery => Q } @@ -170,18 +169,6 @@ private[sbt] trait LanguageServerProtocol { self: NetworkChannel => } /** Respond back to Language Server's client. */ - private[sbt] def jsonRpcRespondError(execId: String, code: Long, message: String): Unit = - jsonRpcRespondErrorImpl(execId, code, message, None) - - /** Respond back to Language Server's client. */ - private[sbt] def jsonRpcRespondError[A: JsonFormat]( - execId: String, - code: Long, - message: String, - data: A, - ): Unit = - jsonRpcRespondErrorImpl(execId, code, message, Option(Converter.toJson[A](data).get)) - private[sbt] def jsonRpcRespondError( execId: String, err: JsonRpcResponseError @@ -191,18 +178,6 @@ private[sbt] trait LanguageServerProtocol { self: NetworkChannel => publishBytes(bytes) } - private[this] def jsonRpcRespondErrorImpl( - execId: String, - code: Long, - message: String, - data: Option[JValue], - ): Unit = { - val e = JsonRpcResponseError(code, message, data) - val m = JsonRpcResponseMessage("2.0", execId, None, Option(e)) - val bytes = Serialization.serializeResponseMessage(m) - publishBytes(bytes) - } - /** Notify to Language Server's client. */ private[sbt] def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit = { val m = diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index fc423e57e..8e4d093d9 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -12,22 +12,22 @@ package server import java.net.{ Socket, SocketTimeoutException } import java.util.concurrent.atomic.AtomicBoolean -import sjsonnew._ - -import scala.annotation.tailrec -import sbt.protocol._ import sbt.internal.langserver.{ CancelRequestParams, ErrorCodes } -import sbt.internal.util.{ ObjectEvent, StringEvent } -import sbt.internal.util.complete.Parser -import sbt.internal.util.codec.JValueFormats import sbt.internal.protocol.{ JsonRpcNotificationMessage, JsonRpcRequestMessage, JsonRpcResponseError } +import sbt.internal.util.codec.JValueFormats +import sbt.internal.util.complete.Parser +import sbt.internal.util.{ ObjectEvent, StringEvent } +import sbt.protocol._ import sbt.util.Logger -import sjsonnew.support.scalajson.unsafe.Converter +import sjsonnew._ +import sjsonnew.support.scalajson.unsafe.{ CompactPrinter, Converter } +import scala.annotation.tailrec +import scala.collection.mutable import scala.util.Try import scala.util.control.NonFatal @@ -56,8 +56,7 @@ final class NetworkChannel( private val VsCode = sbt.protocol.Serialization.VsCode private val VsCodeOld = "application/vscode-jsonrpc; charset=utf8" private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {} - - private var onGoingRequests: Set[JsonRpcRequestMessage] = Set.empty + private val onGoingRequests: mutable.Set[String] = mutable.Set() def setContentType(ct: String): Unit = synchronized { _contentType = ct } def contentType: String = _contentType @@ -197,7 +196,7 @@ final class NetworkChannel( } catch { case LangServerError(code, message) => log.debug(s"sending error: $code: $message") - jsonRpcRespondError(req.id, code, message) + respondError(code, message, Some(req.id)) } case Right(ntf: JsonRpcNotificationMessage) => try { @@ -258,18 +257,24 @@ final class NetworkChannel( } private def registerRequest(request: JsonRpcRequestMessage): Unit = { - onGoingRequests += request + this.synchronized { + onGoingRequests += request.id + () + } } private[sbt] def respondError( err: JsonRpcResponseError, execId: Option[String] - ): Unit = { + ): Unit = this.synchronized { + println(s"respond error for $execId") execId match { - case Some(id) => jsonRpcRespondError(id, err) - case None => logMessage("error", s"Error ${err.code}: ${err.message}") + case Some(id) if onGoingRequests.contains(id) => + onGoingRequests -= id + jsonRpcRespondError(id, err) + case _ => + logMessage("error", s"Error ${err.code}: ${err.message}") } - } private[sbt] def respondError( @@ -277,21 +282,20 @@ final class NetworkChannel( message: String, execId: Option[String] ): Unit = { - execId match { - case Some(id) => jsonRpcRespondError(id, code, message) - case None => logMessage("error", s"Error $code: $message") - } + respondError(JsonRpcResponseError(code, message), execId) } private[sbt] def respondEvent[A: JsonFormat]( event: A, execId: Option[String] - ): Unit = { + ): Unit = this.synchronized { + println(s"respond result for $execId") execId match { - case Some(id) => jsonRpcRespond(event, id) - case None => - val json = Converter.toJsonUnsafe(event) - log.debug(s"unmatched json response: $json") + case Some(id) if onGoingRequests.contains(id) => + onGoingRequests -= id + jsonRpcRespond(event, id) + case _ => + log.debug(s"unmatched json response: ${CompactPrinter(Converter.toJsonUnsafe(event))}") } } From 255a0a6ea69aa97cb730b52963209b553da901a4 Mon Sep 17 00:00:00 2001 From: Adrien Piquerez Date: Tue, 12 May 2020 14:33:19 +0200 Subject: [PATCH 4/7] send response to the source channel only --- .../scala/sbt/internal/CommandChannel.scala | 4 - .../scala/sbt/internal/ConsoleChannel.scala | 23 +- .../scala/sbt/internal/CommandExchange.scala | 226 ++++++------------ .../sbt/internal/server/Definition.scala | 2 +- .../server/LanguageServerProtocol.scala | 2 +- .../sbt/internal/server/NetworkChannel.scala | 19 +- 6 files changed, 95 insertions(+), 181 deletions(-) diff --git a/main-command/src/main/scala/sbt/internal/CommandChannel.scala b/main-command/src/main/scala/sbt/internal/CommandChannel.scala index fb40245df..dc9f97712 100644 --- a/main-command/src/main/scala/sbt/internal/CommandChannel.scala +++ b/main-command/src/main/scala/sbt/internal/CommandChannel.scala @@ -11,7 +11,6 @@ package internal import java.util.concurrent.ConcurrentLinkedQueue import sbt.protocol.EventMessage -import sjsonnew.JsonFormat /** * A command channel represents an IO device such as network socket or human @@ -42,9 +41,6 @@ abstract class CommandChannel { } def poll: Option[Exec] = Option(commandQueue.poll) - def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit - final def publishEvent[A: JsonFormat](event: A): Unit = publishEvent(event, None) - def publishEventMessage(event: EventMessage): Unit def publishBytes(bytes: Array[Byte]): Unit def shutdown(): Unit def name: String diff --git a/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala b/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala index be1eadab6..be1c40c62 100644 --- a/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala +++ b/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala @@ -14,8 +14,6 @@ import java.util.concurrent.atomic.AtomicReference import sbt.BasicKeys._ import sbt.internal.util._ -import sbt.protocol.EventMessage -import sjsonnew.JsonFormat private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel { private[this] val askUserThread = new AtomicReference[AskUserThread] @@ -62,21 +60,16 @@ private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel def publishBytes(bytes: Array[Byte]): Unit = () - def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit = () - - def publishEventMessage(event: EventMessage): Unit = - event match { - case e: ConsolePromptEvent => - if (Terminal.systemInIsAttached) { - askUserThread.synchronized { - askUserThread.get match { - case null => askUserThread.set(makeAskUserThread(e.state)) - case t => t.redraw() - } - } + def publishEventMessage(event: ConsolePromptEvent): Unit = { + if (Terminal.systemInIsAttached) { + askUserThread.synchronized { + askUserThread.get match { + case null => askUserThread.set(makeAskUserThread(event.state)) + case t => t.redraw() } - case _ => // + } } + } def shutdown(): Unit = askUserThread.synchronized { askUserThread.get match { diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index cb029d8d2..09ac87532 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -16,16 +16,13 @@ import java.util.concurrent.atomic._ import sbt.BasicKeys._ import sbt.nio.Watch.NullLogger import sbt.internal.protocol.JsonRpcResponseError -import sbt.internal.langserver.{ LogMessageParams, MessageType } import sbt.internal.server._ -import sbt.internal.util.codec.JValueFormats import sbt.internal.util.{ ConsoleOut, MainAppender, ObjectEvent, StringEvent, Terminal } import sbt.io.syntax._ import sbt.io.{ Hash, IO } import sbt.protocol.{ EventMessage, ExecStatusEvent } import sbt.util.{ Level, LogExchange, Logger } import sjsonnew.JsonFormat -import sjsonnew.shaded.scalajson.ast.unsafe._ import scala.annotation.tailrec import scala.collection.mutable.ListBuffer @@ -51,17 +48,12 @@ private[sbt] final class CommandExchange { private val commandChannelQueue = new LinkedBlockingQueue[CommandChannel] private val nextChannelId: AtomicInteger = new AtomicInteger(0) private[this] val activePrompt = new AtomicBoolean(false) - private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {} def channels: List[CommandChannel] = channelBuffer.toList - private[this] def removeChannels(toDel: List[CommandChannel]): Unit = { - toDel match { - case Nil => // do nothing - case xs => - channelBufferLock.synchronized { - channelBuffer --= xs - () - } + private[this] def removeChannel(channel: CommandChannel): Unit = { + channelBufferLock.synchronized { + channelBuffer -= channel + () } } @@ -206,19 +198,7 @@ private[sbt] final class CommandExchange { execId: Option[String], source: Option[CommandSource] ): Unit = { - val toDel: ListBuffer[CommandChannel] = ListBuffer.empty - channels.foreach { - case _: ConsoleChannel => - case c: NetworkChannel => - try { - // broadcast to all network channels - c.respondError(code, message, execId) - } catch { - case _: IOException => - toDel += c - } - } - removeChannels(toDel.toList) + respondError(JsonRpcResponseError(code, message), execId, source) } private[sbt] def respondError( @@ -226,19 +206,13 @@ private[sbt] final class CommandExchange { execId: Option[String], source: Option[CommandSource] ): Unit = { - val toDel: ListBuffer[CommandChannel] = ListBuffer.empty - channels.foreach { - case _: ConsoleChannel => - case c: NetworkChannel => - try { - // broadcast to all network channels - c.respondError(err, execId) - } catch { - case _: IOException => - toDel += c - } - } - removeChannels(toDel.toList) + for { + source <- source.map(_.channelName) + channel <- channels.collectFirst { + // broadcast to the source channel only + case c: NetworkChannel if c.name == source => c + } + } tryTo(_.respondError(err, execId), removeChannel)(channel) } // This is an interface to directly respond events. @@ -247,86 +221,57 @@ private[sbt] final class CommandExchange { execId: Option[String], source: Option[CommandSource] ): Unit = { - val toDel: ListBuffer[CommandChannel] = ListBuffer.empty - channels.foreach { - case _: ConsoleChannel => - case c: NetworkChannel => - try { - // broadcast to all network channels - c.respondEvent(event, execId) - } catch { - case _: IOException => - toDel += c - } - } - removeChannels(toDel.toList) + for { + source <- source.map(_.channelName) + channel <- channels.collectFirst { + // broadcast to the source channel only + case c: NetworkChannel if c.name == source => c + } + } tryTo(_.respondResult(event, execId), removeChannel)(channel) } // This is an interface to directly notify events. private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = { - val toDel: ListBuffer[CommandChannel] = ListBuffer.empty - channels.foreach { - case _: ConsoleChannel => - // c.publishEvent(event) - case c: NetworkChannel => - try { - c.notifyEvent(method, params) - } catch { - case _: IOException => - toDel += c - } - } - removeChannels(toDel.toList) + channels + .collect { case c: NetworkChannel => c } + .foreach { + tryTo(_.notifyEvent(method, params), removeChannel) + } } - private def tryTo(x: => Unit, c: CommandChannel, toDel: ListBuffer[CommandChannel]): Unit = - try x - catch { case _: IOException => toDel += c } + private def tryTo(f: NetworkChannel => Unit, fallback: NetworkChannel => Unit)( + channel: NetworkChannel + ): Unit = + try f(channel) + catch { case _: IOException => fallback(channel) } def publishEvent[A: JsonFormat](event: A): Unit = { - val broadcastStringMessage = true - val toDel: ListBuffer[CommandChannel] = ListBuffer.empty - event match { case entry: StringEvent => - val params = toLogMessageParams(entry) - channels collect { - case c: ConsoleChannel => - if (broadcastStringMessage || (entry.channelName forall (_ == c.name))) - c.publishEvent(event) - case c: NetworkChannel => - tryTo( - { - // Note that language server's LogMessageParams does not hold the execid, - // so this is weaker than the StringMessage. We might want to double-send - // in case we have a better client that can utilize the knowledge. - import sbt.internal.langserver.codec.JsonProtocol._ - if (broadcastStringMessage || (entry.channelName contains c.name)) - c.jsonRpcNotify("window/logMessage", params) - }, - c, - toDel - ) - } - case entry: ExecStatusEvent => - channels collect { - case c: ConsoleChannel => - if (entry.channelName forall (_ == c.name)) c.publishEvent(event) - case c: NetworkChannel => - if (entry.channelName contains c.name) tryTo(c.publishEvent(event), c, toDel) - } - case _ => - channels foreach { - case c: ConsoleChannel => c.publishEvent(event) - case c: NetworkChannel => - tryTo(c.publishEvent(event), c, toDel) - } - } - removeChannels(toDel.toList) - } + // Note that language server's LogMessageParams does not hold the execid, + // so this is weaker than the StringMessage. We might want to double-send + // in case we have a better client that can utilize the knowledge. + channels + .collect { case c: NetworkChannel => c } + .foreach { + tryTo(_.logMessage(entry.level, entry.message), removeChannel) + } - private[sbt] def toLogMessageParams(event: StringEvent): LogMessageParams = { - LogMessageParams(MessageType.fromLevelString(event.level), event.message) + case entry: ExecStatusEvent => + for { + source <- entry.channelName + channel <- channels.collectFirst { + case c: NetworkChannel if c.name == source => c + } + } tryTo(_.publishEvent(event), removeChannel)(channel) + + case _ => + channels + .collect { case c: NetworkChannel => c } + .foreach { + tryTo(_.publishEvent(event), removeChannel) + } + } } /** @@ -334,59 +279,38 @@ private[sbt] final class CommandExchange { * erased because it went through logging. */ private[sbt] def publishObjectEvent(event: ObjectEvent[_]): Unit = { - import jsonFormat._ - val toDel: ListBuffer[CommandChannel] = ListBuffer.empty - def json: JValue = JObject( - JField("type", JString(event.contentType)), - Vector(JField("message", event.json), JField("level", JString(event.level.toString))) ++ - (event.channelName.toVector map { channelName => - JField("channelName", JString(channelName)) - }) ++ - (event.execId.toVector map { execId => - JField("execId", JString(execId)) - }): _* - ) - channels collect { - case c: ConsoleChannel => - c.publishEvent(json) - case c: NetworkChannel => - try { - c.publishObjectEvent(event) - } catch { - case _: IOException => - toDel += c - } - } - removeChannels(toDel.toList) + for { + source <- event.channelName + channel <- channels.collectFirst { + case c: NetworkChannel if c.name == source => c + } + } tryTo(_.publishObjectEvent(event), removeChannel)(channel) } // fanout publishEvent def publishEventMessage(event: EventMessage): Unit = { - val toDel: ListBuffer[CommandChannel] = ListBuffer.empty - event match { // Special treatment for ConsolePromptEvent since it's hand coded without codec. case entry: ConsolePromptEvent => - channels collect { - case c: ConsoleChannel => - c.publishEventMessage(entry) - activePrompt.set(Terminal.systemInIsAttached) - } + activePrompt.set(Terminal.systemInIsAttached) + channels + .collect { case c: ConsoleChannel => c } + .foreach { _.publishEventMessage(entry) } case entry: ExecStatusEvent => - channels collect { - case c: ConsoleChannel => - if (entry.channelName forall (_ == c.name)) c.publishEventMessage(event) - case c: NetworkChannel => - if (entry.channelName contains c.name) tryTo(c.publishEventMessage(event), c, toDel) - } - case _ => - channels collect { - case c: ConsoleChannel => c.publishEventMessage(event) - case c: NetworkChannel => tryTo(c.publishEventMessage(event), c, toDel) - } - } + for { + source <- entry.channelName + channel <- channels.collectFirst { + case c: NetworkChannel if c.name == source => c + } + } tryTo(_.publishEventMessage(event), removeChannel)(channel) - removeChannels(toDel.toList) + case _ => + channels + .collect { case c: NetworkChannel => c } + .foreach { + tryTo(_.publishEventMessage(event), removeChannel) + } + } } private[this] def needToFinishPromptLine(): Boolean = activePrompt.compareAndSet(true, false) } diff --git a/main/src/main/scala/sbt/internal/server/Definition.scala b/main/src/main/scala/sbt/internal/server/Definition.scala index 381aa8219..288ddd00b 100644 --- a/main/src/main/scala/sbt/internal/server/Definition.scala +++ b/main/src/main/scala/sbt/internal/server/Definition.scala @@ -40,7 +40,7 @@ private[sbt] object Definition { def send[A: JsonFormat](source: CommandSource, execId: String)(params: A): Unit = { for { channel <- StandardMain.exchange.channels.collectFirst { - case c if c.name == source.channelName => c + case c: NetworkChannel if c.name == source.channelName => c } } { channel.publishEvent(params, Option(execId)) diff --git a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala index f5a4ff19a..8e68dc5a0 100644 --- a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala +++ b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala @@ -116,7 +116,7 @@ private[sbt] trait LanguageServerProtocol { self: NetworkChannel => protected lazy val callbackImpl: ServerCallback = new ServerCallback { def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit = - self.respondEvent(event, execId) + self.respondResult(event, execId) def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit = self.respondError(code, message, execId) diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index 8e4d093d9..3fa007ca6 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -267,7 +267,6 @@ final class NetworkChannel( err: JsonRpcResponseError, execId: Option[String] ): Unit = this.synchronized { - println(s"respond error for $execId") execId match { case Some(id) if onGoingRequests.contains(id) => onGoingRequests -= id @@ -285,11 +284,10 @@ final class NetworkChannel( respondError(JsonRpcResponseError(code, message), execId) } - private[sbt] def respondEvent[A: JsonFormat]( + private[sbt] def respondResult[A: JsonFormat]( event: A, execId: Option[String] ): Unit = this.synchronized { - println(s"respond result for $execId") execId match { case Some(id) if onGoingRequests.contains(id) => onGoingRequests -= id @@ -307,17 +305,20 @@ final class NetworkChannel( } } + def publishEvent[A: JsonFormat](event: A): Unit = + publishEvent(event, None) + def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit = { if (isLanguageServerProtocol) { event match { case entry: StringEvent => logMessage(entry.level, entry.message) case entry: ExecStatusEvent => entry.exitCode match { - case None => respondEvent(event, entry.execId) - case Some(0) => respondEvent(event, entry.execId) + case None => respondResult(event, entry.execId) + case Some(0) => respondResult(event, entry.execId) case Some(exitCode) => respondError(exitCode, entry.message.getOrElse(""), entry.execId) } - case _ => respondEvent(event, execId) + case _ => respondResult(event, execId) } } else { contentType match { @@ -417,7 +418,7 @@ final class NetworkChannel( if (initialized) { import sbt.protocol.codec.JsonProtocol._ SettingQuery.handleSettingQueryEither(req, structure) match { - case Right(x) => respondEvent(x, execId) + case Right(x) => respondResult(x, execId) case Left(s) => respondError(ErrorCodes.InvalidParams, s, execId) } } else { @@ -440,7 +441,7 @@ final class NetworkChannel( } .map(c => cp.query + c) import sbt.protocol.codec.JsonProtocol._ - respondEvent( + respondResult( CompletionResponse( items = completionItems.toVector ), @@ -498,7 +499,7 @@ final class NetworkChannel( runningEngine.cancelAndShutdown() import sbt.protocol.codec.JsonProtocol._ - respondEvent( + respondResult( ExecStatusEvent( "Task cancelled", Some(name), From 8df754eeb1280c2ea4b2fd21eafc6a7ac30444e9 Mon Sep 17 00:00:00 2001 From: Adrien Piquerez Date: Tue, 12 May 2020 16:26:33 +0200 Subject: [PATCH 5/7] rename publish to either respond or notify --- .../scala/sbt/internal/ConsoleChannel.scala | 2 +- main/src/main/scala/sbt/Main.scala | 2 +- main/src/main/scala/sbt/MainLoop.scala | 12 +- .../scala/sbt/internal/CommandExchange.scala | 109 ++++++++---------- .../scala/sbt/internal/RelayAppender.scala | 10 +- .../sbt/internal/server/Definition.scala | 2 +- .../sbt/internal/server/NetworkChannel.scala | 38 +++--- .../src/test/scala/testpkg/EventsTest.scala | 9 -- .../src/test/scala/testpkg/TestServer.scala | 16 ++- 9 files changed, 85 insertions(+), 115 deletions(-) diff --git a/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala b/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala index be1c40c62..c53440d34 100644 --- a/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala +++ b/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala @@ -60,7 +60,7 @@ private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel def publishBytes(bytes: Array[Byte]): Unit = () - def publishEventMessage(event: ConsolePromptEvent): Unit = { + def prompt(event: ConsolePromptEvent): Unit = { if (Terminal.systemInIsAttached) { askUserThread.synchronized { askUserThread.get match { diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index 19a4ea5eb..23a6e5b36 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -890,7 +890,7 @@ object BuiltinCommands { val exchange = StandardMain.exchange val welcomeState = displayWelcomeBanner(s0) val s1 = exchange run welcomeState - exchange publishEventMessage ConsolePromptEvent(s0) + exchange prompt ConsolePromptEvent(s0) val minGCInterval = Project .extract(s1) .getOpt(Keys.minForcegcInterval) diff --git a/main/src/main/scala/sbt/MainLoop.scala b/main/src/main/scala/sbt/MainLoop.scala index c80b873a9..9320ebb25 100644 --- a/main/src/main/scala/sbt/MainLoop.scala +++ b/main/src/main/scala/sbt/MainLoop.scala @@ -176,7 +176,7 @@ object MainLoop { /** This is the main function State transfer function of the sbt command processing. */ def processCommand(exec: Exec, state: State): State = { val channelName = exec.source map (_.channelName) - StandardMain.exchange publishEventMessage + StandardMain.exchange notifyStatus ExecStatusEvent("Processing", channelName, exec.execId, Vector()) try { def process(): State = { @@ -197,12 +197,7 @@ object MainLoop { newState.remainingCommands.toVector map (_.commandLine), exitCode(newState, state), ) - if (doneEvent.execId.isDefined) { // send back a response or error - import sbt.protocol.codec.JsonProtocol._ - StandardMain.exchange publishEvent doneEvent - } else { // send back a notification - StandardMain.exchange publishEventMessage doneEvent - } + StandardMain.exchange.respondStatus(doneEvent) newState.get(sbt.Keys.currentTaskProgress).foreach(_.progress.stop()) newState.remove(sbt.Keys.currentTaskProgress) } @@ -225,8 +220,7 @@ object MainLoop { ExitCode(ErrorCodes.UnknownError), Option(err.getMessage), ) - import sbt.protocol.codec.JsonProtocol._ - StandardMain.exchange.publishEvent(errorEvent) + StandardMain.exchange.respondStatus(errorEvent) throw err } } diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 09ac87532..d62b94665 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -17,10 +17,10 @@ import sbt.BasicKeys._ import sbt.nio.Watch.NullLogger import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.server._ -import sbt.internal.util.{ ConsoleOut, MainAppender, ObjectEvent, StringEvent, Terminal } +import sbt.internal.util.{ ConsoleOut, MainAppender, ObjectEvent, Terminal } import sbt.io.syntax._ import sbt.io.{ Hash, IO } -import sbt.protocol.{ EventMessage, ExecStatusEvent } +import sbt.protocol.{ ExecStatusEvent, LogEvent } import sbt.util.{ Level, LogExchange, Logger } import sjsonnew.JsonFormat @@ -212,7 +212,7 @@ private[sbt] final class CommandExchange { // broadcast to the source channel only case c: NetworkChannel if c.name == source => c } - } tryTo(_.respondError(err, execId), removeChannel)(channel) + } tryTo(_.respondError(err, execId))(channel) } // This is an interface to directly respond events. @@ -227,7 +227,7 @@ private[sbt] final class CommandExchange { // broadcast to the source channel only case c: NetworkChannel if c.name == source => c } - } tryTo(_.respondResult(event, execId), removeChannel)(channel) + } tryTo(_.respondResult(event, execId))(channel) } // This is an interface to directly notify events. @@ -235,42 +235,36 @@ private[sbt] final class CommandExchange { channels .collect { case c: NetworkChannel => c } .foreach { - tryTo(_.notifyEvent(method, params), removeChannel) + tryTo(_.notifyEvent(method, params)) } } - private def tryTo(f: NetworkChannel => Unit, fallback: NetworkChannel => Unit)( + private def tryTo(f: NetworkChannel => Unit)( channel: NetworkChannel ): Unit = try f(channel) - catch { case _: IOException => fallback(channel) } + catch { case _: IOException => removeChannel(channel) } - def publishEvent[A: JsonFormat](event: A): Unit = { - event match { - case entry: StringEvent => - // Note that language server's LogMessageParams does not hold the execid, - // so this is weaker than the StringMessage. We might want to double-send - // in case we have a better client that can utilize the knowledge. - channels - .collect { case c: NetworkChannel => c } - .foreach { - tryTo(_.logMessage(entry.level, entry.message), removeChannel) - } + def respondStatus(event: ExecStatusEvent): Unit = { + import sbt.protocol.codec.JsonProtocol._ + for { + source <- event.channelName + channel <- channels.collectFirst { + case c: NetworkChannel if c.name == source => c + } + } { + if (event.execId.isEmpty) { + tryTo(_.notifyEvent(event))(channel) + } else { + event.exitCode match { + case None | Some(0) => + tryTo(_.respondResult(event, event.execId))(channel) + case Some(code) => + tryTo(_.respondError(code, event.message.getOrElse(""), event.execId))(channel) + } + } - case entry: ExecStatusEvent => - for { - source <- entry.channelName - channel <- channels.collectFirst { - case c: NetworkChannel if c.name == source => c - } - } tryTo(_.publishEvent(event), removeChannel)(channel) - - case _ => - channels - .collect { case c: NetworkChannel => c } - .foreach { - tryTo(_.publishEvent(event), removeChannel) - } + tryTo(_.respond(event, event.execId))(channel) } } @@ -278,39 +272,38 @@ private[sbt] final class CommandExchange { * This publishes object events. The type information has been * erased because it went through logging. */ - private[sbt] def publishObjectEvent(event: ObjectEvent[_]): Unit = { + private[sbt] def respondObjectEvent(event: ObjectEvent[_]): Unit = { for { source <- event.channelName channel <- channels.collectFirst { case c: NetworkChannel if c.name == source => c } - } tryTo(_.publishObjectEvent(event), removeChannel)(channel) + } tryTo(_.respond(event))(channel) } - // fanout publishEvent - def publishEventMessage(event: EventMessage): Unit = { - event match { - // Special treatment for ConsolePromptEvent since it's hand coded without codec. - case entry: ConsolePromptEvent => - activePrompt.set(Terminal.systemInIsAttached) - channels - .collect { case c: ConsoleChannel => c } - .foreach { _.publishEventMessage(entry) } - case entry: ExecStatusEvent => - for { - source <- entry.channelName - channel <- channels.collectFirst { - case c: NetworkChannel if c.name == source => c - } - } tryTo(_.publishEventMessage(event), removeChannel)(channel) - - case _ => - channels - .collect { case c: NetworkChannel => c } - .foreach { - tryTo(_.publishEventMessage(event), removeChannel) - } - } + def prompt(event: ConsolePromptEvent): Unit = { + activePrompt.set(Terminal.systemInIsAttached) + channels + .collect { case c: ConsoleChannel => c } + .foreach { _.prompt(event) } } + + def logMessage(event: LogEvent): Unit = { + channels + .collect { case c: NetworkChannel => c } + .foreach { + tryTo(_.notifyEvent(event)) + } + } + + def notifyStatus(event: ExecStatusEvent): Unit = { + for { + source <- event.channelName + channel <- channels.collectFirst { + case c: NetworkChannel if c.name == source => c + } + } tryTo(_.notifyEvent(event))(channel) + } + private[this] def needToFinishPromptLine(): Boolean = activePrompt.compareAndSet(true, false) } diff --git a/main/src/main/scala/sbt/internal/RelayAppender.scala b/main/src/main/scala/sbt/internal/RelayAppender.scala index 988d5feb4..97bb2a1c1 100644 --- a/main/src/main/scala/sbt/internal/RelayAppender.scala +++ b/main/src/main/scala/sbt/internal/RelayAppender.scala @@ -17,7 +17,6 @@ import org.apache.logging.log4j.core.config.Property import sbt.util.Level import sbt.internal.util._ import sbt.protocol.LogEvent -import sbt.internal.util.codec._ class RelayAppender(name: String) extends AbstractAppender( @@ -40,15 +39,12 @@ class RelayAppender(name: String) } } def appendLog(level: Level.Value, message: => String): Unit = { - exchange.publishEventMessage(LogEvent(level.toString, message)) + exchange.logMessage(LogEvent(level.toString, message)) } def appendEvent(event: AnyRef): Unit = event match { - case x: StringEvent => { - import JsonProtocol._ - exchange.publishEvent(x: AbstractEntry) - } - case x: ObjectEvent[_] => exchange.publishObjectEvent(x) + case x: StringEvent => exchange.logMessage(LogEvent(x.message, x.level)) + case x: ObjectEvent[_] => exchange.respondObjectEvent(x) case _ => println(s"appendEvent: ${event.getClass}") () diff --git a/main/src/main/scala/sbt/internal/server/Definition.scala b/main/src/main/scala/sbt/internal/server/Definition.scala index 288ddd00b..ad7bdfe90 100644 --- a/main/src/main/scala/sbt/internal/server/Definition.scala +++ b/main/src/main/scala/sbt/internal/server/Definition.scala @@ -43,7 +43,7 @@ private[sbt] object Definition { case c: NetworkChannel if c.name == source.channelName => c } } { - channel.publishEvent(params, Option(execId)) + channel.respond(params, Option(execId)) } } diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index 3fa007ca6..58fff4af3 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -20,7 +20,7 @@ import sbt.internal.protocol.{ } import sbt.internal.util.codec.JValueFormats import sbt.internal.util.complete.Parser -import sbt.internal.util.{ ObjectEvent, StringEvent } +import sbt.internal.util.ObjectEvent import sbt.protocol._ import sbt.util.Logger import sjsonnew._ @@ -293,7 +293,9 @@ final class NetworkChannel( onGoingRequests -= id jsonRpcRespond(event, id) case _ => - log.debug(s"unmatched json response: ${CompactPrinter(Converter.toJsonUnsafe(event))}") + log.debug( + s"unmatched json response for requestId $execId: ${CompactPrinter(Converter.toJsonUnsafe(event))}" + ) } } @@ -305,21 +307,11 @@ final class NetworkChannel( } } - def publishEvent[A: JsonFormat](event: A): Unit = - publishEvent(event, None) + def respond[A: JsonFormat](event: A): Unit = respond(event, None) - def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit = { + def respond[A: JsonFormat](event: A, execId: Option[String]): Unit = { if (isLanguageServerProtocol) { - event match { - case entry: StringEvent => logMessage(entry.level, entry.message) - case entry: ExecStatusEvent => - entry.exitCode match { - case None => respondResult(event, entry.execId) - case Some(0) => respondResult(event, entry.execId) - case Some(exitCode) => respondError(exitCode, entry.message.getOrElse(""), entry.execId) - } - case _ => respondResult(event, execId) - } + respondResult(event, execId) } else { contentType match { case SbtX1Protocol => @@ -330,7 +322,7 @@ final class NetworkChannel( } } - def publishEventMessage(event: EventMessage): Unit = { + def notifyEvent(event: EventMessage): Unit = { if (isLanguageServerProtocol) { event match { case entry: LogEvent => logMessage(entry.level, entry.message) @@ -351,22 +343,22 @@ final class NetworkChannel( * This publishes object events. The type information has been * erased because it went through logging. */ - private[sbt] def publishObjectEvent(event: ObjectEvent[_]): Unit = { + private[sbt] def respond(event: ObjectEvent[_]): Unit = { import sjsonnew.shaded.scalajson.ast.unsafe._ if (isLanguageServerProtocol) onObjectEvent(event) else { import jsonFormat._ val json: JValue = JObject( JField("type", JString(event.contentType)), - (Vector(JField("message", event.json), JField("level", JString(event.level.toString))) ++ - (event.channelName.toVector map { channelName => + Seq(JField("message", event.json), JField("level", JString(event.level.toString))) ++ + (event.channelName map { channelName => JField("channelName", JString(channelName)) }) ++ - (event.execId.toVector map { execId => + (event.execId map { execId => JField("execId", JString(execId)) - })): _* + }): _* ) - publishEvent(json) + respond(json, event.execId) } } @@ -393,7 +385,7 @@ final class NetworkChannel( authenticate(x) match { case true => initialized = true - publishEventMessage(ChannelAcceptedEvent(name)) + notifyEvent(ChannelAcceptedEvent(name)) case _ => sys.error("invalid token") } case None => sys.error("init command but without token.") diff --git a/server-test/src/test/scala/testpkg/EventsTest.scala b/server-test/src/test/scala/testpkg/EventsTest.scala index d31796f73..5359bb2dd 100644 --- a/server-test/src/test/scala/testpkg/EventsTest.scala +++ b/server-test/src/test/scala/testpkg/EventsTest.scala @@ -22,15 +22,6 @@ object EventsTest extends AbstractServerTest { }) } - test("report task failures in case of exceptions") { _ => - svr.sendJsonRpc( - """{ "jsonrpc": "2.0", "id": 11, "method": "sbt/exec", "params": { "commandLine": "hello" } }""" - ) - assert(svr.waitForString(10.seconds) { s => - (s contains """"id":11""") && (s contains """"error":""") - }) - } - test("return error if cancelling non-matched task id") { _ => svr.sendJsonRpc( """{ "jsonrpc": "2.0", "id":12, "method": "sbt/exec", "params": { "commandLine": "run" } }""" diff --git a/server-test/src/test/scala/testpkg/TestServer.scala b/server-test/src/test/scala/testpkg/TestServer.scala index ae2e8c2f1..e9a5faa64 100644 --- a/server-test/src/test/scala/testpkg/TestServer.scala +++ b/server-test/src/test/scala/testpkg/TestServer.scala @@ -189,17 +189,21 @@ case class TestServer( var out = sk.getOutputStream var in = sk.getInputStream - def resetConnection() = { - sk = ClientSocket.socket(portfile)._1 - out = sk.getOutputStream - in = sk.getInputStream - } - // initiate handshake sendJsonRpc( """{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }""" ) + def resetConnection() = { + sk = ClientSocket.socket(portfile)._1 + out = sk.getOutputStream + in = sk.getInputStream + + sendJsonRpc( + """{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }""" + ) + } + def test(f: TestServer => Future[Assertion]): Future[Assertion] = { f(this) } From 42e4c5a7c03e32d1314eaffd74e7d9abcfcd820f Mon Sep 17 00:00:00 2001 From: Adrien Piquerez Date: Tue, 12 May 2020 17:34:34 +0200 Subject: [PATCH 6/7] evict some mima errors in sbt.internal --- .gitignore | 3 +++ build.sbt | 5 +++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index e9b7d861f..5350c8ef6 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,6 @@ npm-debug.log !sbt/src/server-test/completions/target .big .idea +.bloop +.metals +metals.sbt diff --git a/build.sbt b/build.sbt index ed583c659..05b926cbe 100644 --- a/build.sbt +++ b/build.sbt @@ -677,7 +677,8 @@ lazy val protocolProj = (project in file("protocol")) exclude[DirectMissingMethodProblem]("sbt.protocol.SettingQueryFailure.copy$default$*"), exclude[DirectMissingMethodProblem]("sbt.protocol.SettingQuerySuccess.copy"), exclude[DirectMissingMethodProblem]("sbt.protocol.SettingQuerySuccess.copy$default$*"), - // ignore missing methods in sbt.internal + // ignore missing or incompatible methods in sbt.internal + exclude[IncompatibleMethTypeProblem]("sbt.internal.*"), exclude[DirectMissingMethodProblem]("sbt.internal.*"), exclude[MissingTypesProblem]("sbt.internal.protocol.JsonRpcResponseError"), ) @@ -876,7 +877,7 @@ lazy val mainProj = (project in file("main")) // New and changed methods on KeyIndex. internal. exclude[ReversedMissingMethodProblem]("sbt.internal.KeyIndex.*"), // internal - exclude[IncompatibleMethTypeProblem]("sbt.internal.server.LanguageServerReporter.*"), + exclude[IncompatibleMethTypeProblem]("sbt.internal.*"), // Changed signature or removed private[sbt] methods exclude[DirectMissingMethodProblem]("sbt.Classpaths.unmanagedLibs0"), exclude[DirectMissingMethodProblem]("sbt.Defaults.allTestGroupsTask"), From c221f578120931b3b249fd14b40f82f0e64c984d Mon Sep 17 00:00:00 2001 From: Adrien Piquerez Date: Tue, 12 May 2020 17:40:16 +0200 Subject: [PATCH 7/7] code review --- .../scala/sbt/internal/server/NetworkChannel.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index 58fff4af3..fe44bffcf 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -56,7 +56,7 @@ final class NetworkChannel( private val VsCode = sbt.protocol.Serialization.VsCode private val VsCodeOld = "application/vscode-jsonrpc; charset=utf8" private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {} - private val onGoingRequests: mutable.Set[String] = mutable.Set() + private val pendingRequests: mutable.Map[String, JsonRpcRequestMessage] = mutable.Map() def setContentType(ct: String): Unit = synchronized { _contentType = ct } def contentType: String = _contentType @@ -258,7 +258,7 @@ final class NetworkChannel( private def registerRequest(request: JsonRpcRequestMessage): Unit = { this.synchronized { - onGoingRequests += request.id + pendingRequests += (request.id -> request) () } } @@ -268,8 +268,8 @@ final class NetworkChannel( execId: Option[String] ): Unit = this.synchronized { execId match { - case Some(id) if onGoingRequests.contains(id) => - onGoingRequests -= id + case Some(id) if pendingRequests.contains(id) => + pendingRequests -= id jsonRpcRespondError(id, err) case _ => logMessage("error", s"Error ${err.code}: ${err.message}") @@ -289,8 +289,8 @@ final class NetworkChannel( execId: Option[String] ): Unit = this.synchronized { execId match { - case Some(id) if onGoingRequests.contains(id) => - onGoingRequests -= id + case Some(id) if pendingRequests.contains(id) => + pendingRequests -= id jsonRpcRespond(event, id) case _ => log.debug(