From 781584d137ebf5115d067ec2bde93d5fc28587ef Mon Sep 17 00:00:00 2001 From: Adrien Piquerez Date: Mon, 11 May 2020 15:50:41 +0200 Subject: [PATCH] id is mandatory in json rpc responses --- .../sbt/internal/client/NetworkClient.scala | 2 +- .../scala/sbt/internal/CommandExchange.scala | 6 +- .../server/LanguageServerProtocol.scala | 20 +-- .../sbt/internal/server/NetworkChannel.scala | 115 +++++++++++------- .../protocol/JsonRpcResponseMessage.scala | 13 +- protocol/src/main/contraband/jsonrpc.contra | 2 +- .../codec/JsonRpcResponseMessageFormats.scala | 12 +- 7 files changed, 97 insertions(+), 73 deletions(-) diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala index c01d99340..6a53d2837 100644 --- a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala +++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala @@ -119,7 +119,7 @@ class NetworkClient(configuration: xsbti.AppConfiguration, arguments: List[Strin } def onResponse(msg: JsonRpcResponseMessage): Unit = { - msg.id foreach { + msg.id match { case execId if pendingExecIds contains execId => onReturningReponse(msg) lock.synchronized { diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 85aa09f6a..cb029d8d2 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -212,7 +212,7 @@ private[sbt] final class CommandExchange { case c: NetworkChannel => try { // broadcast to all network channels - c.respondError(code, message, execId, source) + c.respondError(code, message, execId) } catch { case _: IOException => toDel += c @@ -232,7 +232,7 @@ private[sbt] final class CommandExchange { case c: NetworkChannel => try { // broadcast to all network channels - c.respondError(err, execId, source) + c.respondError(err, execId) } catch { case _: IOException => toDel += c @@ -253,7 +253,7 @@ private[sbt] final class CommandExchange { case c: NetworkChannel => try { // broadcast to all network channels - c.respondEvent(event, execId, source) + c.respondEvent(event, execId) } catch { case _: IOException => toDel += c diff --git a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala index 86d89e0db..cee5d91fc 100644 --- a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala +++ b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala @@ -103,7 +103,7 @@ private[sbt] object LanguageServerProtocol { } /** Implements Language Server Protocol . */ -private[sbt] trait LanguageServerProtocol extends CommandChannel { self => +private[sbt] trait LanguageServerProtocol { self: NetworkChannel => lazy val internalJsonProtocol = new InitializeOptionFormats with sjsonnew.BasicJsonProtocol {} @@ -117,10 +117,10 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self => protected lazy val callbackImpl: ServerCallback = new ServerCallback { def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit = - self.jsonRpcRespond(event, execId) + self.respondEvent(event, execId) def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit = - self.jsonRpcRespondError(execId, code, message) + self.respondError(code, message, execId) def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit = self.jsonRpcNotify(method, params) @@ -162,20 +162,20 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self => } /** Respond back to Language Server's client. */ - private[sbt] def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit = { - val m = + private[sbt] def jsonRpcRespond[A: JsonFormat](event: A, execId: String): Unit = { + val response = JsonRpcResponseMessage("2.0", execId, Option(Converter.toJson[A](event).get), None) - val bytes = Serialization.serializeResponseMessage(m) + val bytes = Serialization.serializeResponseMessage(response) publishBytes(bytes) } /** Respond back to Language Server's client. */ - private[sbt] def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit = + private[sbt] def jsonRpcRespondError(execId: String, code: Long, message: String): Unit = jsonRpcRespondErrorImpl(execId, code, message, None) /** Respond back to Language Server's client. */ private[sbt] def jsonRpcRespondError[A: JsonFormat]( - execId: Option[String], + execId: String, code: Long, message: String, data: A, @@ -183,7 +183,7 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self => jsonRpcRespondErrorImpl(execId, code, message, Option(Converter.toJson[A](data).get)) private[sbt] def jsonRpcRespondError( - execId: Option[String], + execId: String, err: JsonRpcResponseError ): Unit = { val m = JsonRpcResponseMessage("2.0", execId, None, Option(err)) @@ -192,7 +192,7 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self => } private[this] def jsonRpcRespondErrorImpl( - execId: Option[String], + execId: String, code: Long, message: String, data: Option[JValue], diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index 66a5e01ac..fc423e57e 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -13,18 +13,21 @@ import java.net.{ Socket, SocketTimeoutException } import java.util.concurrent.atomic.AtomicBoolean import sjsonnew._ + import scala.annotation.tailrec import sbt.protocol._ -import sbt.internal.langserver.{ ErrorCodes, CancelRequestParams } +import sbt.internal.langserver.{ CancelRequestParams, ErrorCodes } import sbt.internal.util.{ ObjectEvent, StringEvent } import sbt.internal.util.complete.Parser import sbt.internal.util.codec.JValueFormats import sbt.internal.protocol.{ - JsonRpcResponseError, + JsonRpcNotificationMessage, JsonRpcRequestMessage, - JsonRpcNotificationMessage + JsonRpcResponseError } import sbt.util.Logger +import sjsonnew.support.scalajson.unsafe.Converter + import scala.util.Try import scala.util.control.NonFatal @@ -54,6 +57,8 @@ final class NetworkChannel( private val VsCodeOld = "application/vscode-jsonrpc; charset=utf8" private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {} + private var onGoingRequests: Set[JsonRpcRequestMessage] = Set.empty + def setContentType(ct: String): Unit = synchronized { _contentType = ct } def contentType: String = _contentType @@ -176,6 +181,7 @@ final class NetworkChannel( intents.foldLeft(PartialFunction.empty[JsonRpcRequestMessage, Unit]) { case (f, i) => f orElse i.onRequest } + lazy val onNotification: PartialFunction[JsonRpcNotificationMessage, Unit] = intents.foldLeft(PartialFunction.empty[JsonRpcNotificationMessage, Unit]) { case (f, i) => f orElse i.onNotification @@ -186,25 +192,27 @@ final class NetworkChannel( Serialization.deserializeJsonMessage(chunk) match { case Right(req: JsonRpcRequestMessage) => try { + registerRequest(req) onRequestMessage(req) } catch { case LangServerError(code, message) => log.debug(s"sending error: $code: $message") - jsonRpcRespondError(Option(req.id), code, message) + jsonRpcRespondError(req.id, code, message) } case Right(ntf: JsonRpcNotificationMessage) => try { onNotification(ntf) } catch { case LangServerError(code, message) => - log.debug(s"sending error: $code: $message") - jsonRpcRespondError(None, code, message) // new id? + logMessage("error", s"Error $code while handling notification: $message") } case Right(msg) => log.debug(s"Unhandled message: $msg") case Left(errorDesc) => - val msg = s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): " + errorDesc - jsonRpcRespondError(None, ErrorCodes.ParseError, msg) + logMessage( + "error", + s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): $errorDesc" + ) } } else { contentType match { @@ -213,13 +221,17 @@ final class NetworkChannel( .deserializeCommand(chunk) .fold( errorDesc => - log.error( + logMessage( + "error", s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): " + errorDesc ), onCommand ) case _ => - log.error(s"Unknown Content-Type: $contentType") + logMessage( + "error", + s"Unknown Content-Type: $contentType" + ) } } // if-else } @@ -245,24 +257,43 @@ final class NetworkChannel( } } + private def registerRequest(request: JsonRpcRequestMessage): Unit = { + onGoingRequests += request + } + private[sbt] def respondError( err: JsonRpcResponseError, - execId: Option[String], - source: Option[CommandSource] - ): Unit = jsonRpcRespondError(execId, err) + execId: Option[String] + ): Unit = { + execId match { + case Some(id) => jsonRpcRespondError(id, err) + case None => logMessage("error", s"Error ${err.code}: ${err.message}") + } + + } private[sbt] def respondError( code: Long, message: String, - execId: Option[String], - source: Option[CommandSource] - ): Unit = jsonRpcRespondError(execId, code, message) + execId: Option[String] + ): Unit = { + execId match { + case Some(id) => jsonRpcRespondError(id, code, message) + case None => logMessage("error", s"Error $code: $message") + } + } private[sbt] def respondEvent[A: JsonFormat]( event: A, - execId: Option[String], - source: Option[CommandSource] - ): Unit = jsonRpcRespond(event, execId) + execId: Option[String] + ): Unit = { + execId match { + case Some(id) => jsonRpcRespond(event, id) + case None => + val json = Converter.toJsonUnsafe(event) + log.debug(s"unmatched json response: $json") + } + } private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = { if (isLanguageServerProtocol) { @@ -278,12 +309,11 @@ final class NetworkChannel( case entry: StringEvent => logMessage(entry.level, entry.message) case entry: ExecStatusEvent => entry.exitCode match { - case None => jsonRpcRespond(event, entry.execId) - case Some(0) => jsonRpcRespond(event, entry.execId) - case Some(exitCode) => - jsonRpcRespondError(entry.execId, exitCode, entry.message.getOrElse("")) + case None => respondEvent(event, entry.execId) + case Some(0) => respondEvent(event, entry.execId) + case Some(exitCode) => respondError(exitCode, entry.message.getOrElse(""), entry.execId) } - case _ => jsonRpcRespond(event, execId) + case _ => respondEvent(event, execId) } } else { contentType match { @@ -383,8 +413,8 @@ final class NetworkChannel( if (initialized) { import sbt.protocol.codec.JsonProtocol._ SettingQuery.handleSettingQueryEither(req, structure) match { - case Right(x) => jsonRpcRespond(x, execId) - case Left(s) => jsonRpcRespondError(execId, ErrorCodes.InvalidParams, s) + case Right(x) => respondEvent(x, execId) + case Left(s) => respondError(ErrorCodes.InvalidParams, s, execId) } } else { log.warn(s"ignoring query $req before initialization") @@ -400,32 +430,31 @@ final class NetworkChannel( Parser .completions(sstate.combinedParser, cp.query, 9) .get - .map(c => { + .flatMap { c => if (!c.isEmpty) Some(c.append.replaceAll("\n", " ")) else None - }) - .flatten - .map(c => cp.query + c.toString) + } + .map(c => cp.query + c) import sbt.protocol.codec.JsonProtocol._ - jsonRpcRespond( + respondEvent( CompletionResponse( items = completionItems.toVector ), execId ) case _ => - jsonRpcRespondError( - execId, + respondError( ErrorCodes.UnknownError, - "No available sbt state" + "No available sbt state", + execId ) } } catch { - case NonFatal(e) => - jsonRpcRespondError( - execId, + case NonFatal(_) => + respondError( ErrorCodes.UnknownError, - "Completions request failed" + "Completions request failed", + execId ) } } else { @@ -436,10 +465,10 @@ final class NetworkChannel( protected def onCancellationRequest(execId: Option[String], crp: CancelRequestParams) = { if (initialized) { - def errorRespond(msg: String) = jsonRpcRespondError( - execId, + def errorRespond(msg: String) = respondError( ErrorCodes.RequestCancelled, - msg + msg, + execId ) try { @@ -465,11 +494,11 @@ final class NetworkChannel( runningEngine.cancelAndShutdown() import sbt.protocol.codec.JsonProtocol._ - jsonRpcRespond( + respondEvent( ExecStatusEvent( "Task cancelled", Some(name), - Some(runningExecId.toString), + Some(runningExecId), Vector(), None, ), diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/JsonRpcResponseMessage.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/JsonRpcResponseMessage.scala index 4add26b48..4cab85ffb 100644 --- a/protocol/src/main/contraband-scala/sbt/internal/protocol/JsonRpcResponseMessage.scala +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/JsonRpcResponseMessage.scala @@ -12,7 +12,7 @@ package sbt.internal.protocol */ final class JsonRpcResponseMessage private ( jsonrpc: String, - val id: Option[String], + val id: String, val result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue], val error: Option[sbt.internal.protocol.JsonRpcResponseError]) extends sbt.internal.protocol.JsonRpcMessage(jsonrpc) with Serializable { @@ -28,17 +28,14 @@ final class JsonRpcResponseMessage private ( override def toString: String = { s"""JsonRpcResponseMessage($jsonrpc, $id, ${sbt.protocol.Serialization.compactPrintJsonOpt(result)}, $error)""" } - private[this] def copy(jsonrpc: String = jsonrpc, id: Option[String] = id, result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue] = result, error: Option[sbt.internal.protocol.JsonRpcResponseError] = error): JsonRpcResponseMessage = { + private[this] def copy(jsonrpc: String = jsonrpc, id: String = id, result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue] = result, error: Option[sbt.internal.protocol.JsonRpcResponseError] = error): JsonRpcResponseMessage = { new JsonRpcResponseMessage(jsonrpc, id, result, error) } def withJsonrpc(jsonrpc: String): JsonRpcResponseMessage = { copy(jsonrpc = jsonrpc) } - def withId(id: Option[String]): JsonRpcResponseMessage = { - copy(id = id) - } def withId(id: String): JsonRpcResponseMessage = { - copy(id = Option(id)) + copy(id = id) } def withResult(result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue]): JsonRpcResponseMessage = { copy(result = result) @@ -55,6 +52,6 @@ final class JsonRpcResponseMessage private ( } object JsonRpcResponseMessage { - def apply(jsonrpc: String, id: Option[String], result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue], error: Option[sbt.internal.protocol.JsonRpcResponseError]): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, id, result, error) - def apply(jsonrpc: String, id: String, result: sjsonnew.shaded.scalajson.ast.unsafe.JValue, error: sbt.internal.protocol.JsonRpcResponseError): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, Option(id), Option(result), Option(error)) + def apply(jsonrpc: String, id: String, result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue], error: Option[sbt.internal.protocol.JsonRpcResponseError]): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, id, result, error) + def apply(jsonrpc: String, id: String, result: sjsonnew.shaded.scalajson.ast.unsafe.JValue, error: sbt.internal.protocol.JsonRpcResponseError): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, id, Option(result), Option(error)) } diff --git a/protocol/src/main/contraband/jsonrpc.contra b/protocol/src/main/contraband/jsonrpc.contra index 54a134b40..983f6e747 100644 --- a/protocol/src/main/contraband/jsonrpc.contra +++ b/protocol/src/main/contraband/jsonrpc.contra @@ -31,7 +31,7 @@ type JsonRpcResponseMessage implements JsonRpcMessage jsonrpc: String! ## The request id. - id: String + id: String! ## The result of a request. This can be omitted in ## the case of an error. diff --git a/protocol/src/main/scala/sbt/internal/protocol/codec/JsonRpcResponseMessageFormats.scala b/protocol/src/main/scala/sbt/internal/protocol/codec/JsonRpcResponseMessageFormats.scala index 3ea29d13b..c93b31e22 100644 --- a/protocol/src/main/scala/sbt/internal/protocol/codec/JsonRpcResponseMessageFormats.scala +++ b/protocol/src/main/scala/sbt/internal/protocol/codec/JsonRpcResponseMessageFormats.scala @@ -32,10 +32,10 @@ trait JsonRpcResponseMessageFormats { unbuilder.beginObject(js) val jsonrpc = unbuilder.readField[String]("jsonrpc") val id = try { - unbuilder.readField[Option[String]]("id") + unbuilder.readField[String]("id") } catch { case _: DeserializationException => - unbuilder.readField[Option[Long]]("id") map { _.toString } + unbuilder.readField[Long]("id").toString } val result = unbuilder.lookupField("result") map { @@ -77,11 +77,9 @@ trait JsonRpcResponseMessageFormats { } builder.beginObject() builder.addField("jsonrpc", obj.jsonrpc) - obj.id foreach { id => - parseId(id) match { - case Right(strId) => builder.addField("id", strId) - case Left(longId) => builder.addField("id", longId) - } + parseId(obj.id) match { + case Right(strId) => builder.addField("id", strId) + case Left(longId) => builder.addField("id", longId) } builder.addField("result", obj.result map parseResult) builder.addField("error", obj.error)