id is mandatory in json rpc responses

This commit is contained in:
Adrien Piquerez 2020-05-11 15:50:41 +02:00
parent e84e414328
commit 781584d137
7 changed files with 97 additions and 73 deletions

View File

@ -119,7 +119,7 @@ class NetworkClient(configuration: xsbti.AppConfiguration, arguments: List[Strin
} }
def onResponse(msg: JsonRpcResponseMessage): Unit = { def onResponse(msg: JsonRpcResponseMessage): Unit = {
msg.id foreach { msg.id match {
case execId if pendingExecIds contains execId => case execId if pendingExecIds contains execId =>
onReturningReponse(msg) onReturningReponse(msg)
lock.synchronized { lock.synchronized {

View File

@ -212,7 +212,7 @@ private[sbt] final class CommandExchange {
case c: NetworkChannel => case c: NetworkChannel =>
try { try {
// broadcast to all network channels // broadcast to all network channels
c.respondError(code, message, execId, source) c.respondError(code, message, execId)
} catch { } catch {
case _: IOException => case _: IOException =>
toDel += c toDel += c
@ -232,7 +232,7 @@ private[sbt] final class CommandExchange {
case c: NetworkChannel => case c: NetworkChannel =>
try { try {
// broadcast to all network channels // broadcast to all network channels
c.respondError(err, execId, source) c.respondError(err, execId)
} catch { } catch {
case _: IOException => case _: IOException =>
toDel += c toDel += c
@ -253,7 +253,7 @@ private[sbt] final class CommandExchange {
case c: NetworkChannel => case c: NetworkChannel =>
try { try {
// broadcast to all network channels // broadcast to all network channels
c.respondEvent(event, execId, source) c.respondEvent(event, execId)
} catch { } catch {
case _: IOException => case _: IOException =>
toDel += c toDel += c

View File

@ -103,7 +103,7 @@ private[sbt] object LanguageServerProtocol {
} }
/** Implements Language Server Protocol <https://github.com/Microsoft/language-server-protocol>. */ /** Implements Language Server Protocol <https://github.com/Microsoft/language-server-protocol>. */
private[sbt] trait LanguageServerProtocol extends CommandChannel { self => private[sbt] trait LanguageServerProtocol { self: NetworkChannel =>
lazy val internalJsonProtocol = new InitializeOptionFormats with sjsonnew.BasicJsonProtocol {} lazy val internalJsonProtocol = new InitializeOptionFormats with sjsonnew.BasicJsonProtocol {}
@ -117,10 +117,10 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self =>
protected lazy val callbackImpl: ServerCallback = new ServerCallback { protected lazy val callbackImpl: ServerCallback = new ServerCallback {
def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit = def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit =
self.jsonRpcRespond(event, execId) self.respondEvent(event, execId)
def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit = def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit =
self.jsonRpcRespondError(execId, code, message) self.respondError(code, message, execId)
def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit = def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit =
self.jsonRpcNotify(method, params) self.jsonRpcNotify(method, params)
@ -162,20 +162,20 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self =>
} }
/** Respond back to Language Server's client. */ /** Respond back to Language Server's client. */
private[sbt] def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit = { private[sbt] def jsonRpcRespond[A: JsonFormat](event: A, execId: String): Unit = {
val m = val response =
JsonRpcResponseMessage("2.0", execId, Option(Converter.toJson[A](event).get), None) JsonRpcResponseMessage("2.0", execId, Option(Converter.toJson[A](event).get), None)
val bytes = Serialization.serializeResponseMessage(m) val bytes = Serialization.serializeResponseMessage(response)
publishBytes(bytes) publishBytes(bytes)
} }
/** Respond back to Language Server's client. */ /** Respond back to Language Server's client. */
private[sbt] def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit = private[sbt] def jsonRpcRespondError(execId: String, code: Long, message: String): Unit =
jsonRpcRespondErrorImpl(execId, code, message, None) jsonRpcRespondErrorImpl(execId, code, message, None)
/** Respond back to Language Server's client. */ /** Respond back to Language Server's client. */
private[sbt] def jsonRpcRespondError[A: JsonFormat]( private[sbt] def jsonRpcRespondError[A: JsonFormat](
execId: Option[String], execId: String,
code: Long, code: Long,
message: String, message: String,
data: A, data: A,
@ -183,7 +183,7 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self =>
jsonRpcRespondErrorImpl(execId, code, message, Option(Converter.toJson[A](data).get)) jsonRpcRespondErrorImpl(execId, code, message, Option(Converter.toJson[A](data).get))
private[sbt] def jsonRpcRespondError( private[sbt] def jsonRpcRespondError(
execId: Option[String], execId: String,
err: JsonRpcResponseError err: JsonRpcResponseError
): Unit = { ): Unit = {
val m = JsonRpcResponseMessage("2.0", execId, None, Option(err)) val m = JsonRpcResponseMessage("2.0", execId, None, Option(err))
@ -192,7 +192,7 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self =>
} }
private[this] def jsonRpcRespondErrorImpl( private[this] def jsonRpcRespondErrorImpl(
execId: Option[String], execId: String,
code: Long, code: Long,
message: String, message: String,
data: Option[JValue], data: Option[JValue],

View File

@ -13,18 +13,21 @@ import java.net.{ Socket, SocketTimeoutException }
import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicBoolean
import sjsonnew._ import sjsonnew._
import scala.annotation.tailrec import scala.annotation.tailrec
import sbt.protocol._ import sbt.protocol._
import sbt.internal.langserver.{ ErrorCodes, CancelRequestParams } import sbt.internal.langserver.{ CancelRequestParams, ErrorCodes }
import sbt.internal.util.{ ObjectEvent, StringEvent } import sbt.internal.util.{ ObjectEvent, StringEvent }
import sbt.internal.util.complete.Parser import sbt.internal.util.complete.Parser
import sbt.internal.util.codec.JValueFormats import sbt.internal.util.codec.JValueFormats
import sbt.internal.protocol.{ import sbt.internal.protocol.{
JsonRpcResponseError, JsonRpcNotificationMessage,
JsonRpcRequestMessage, JsonRpcRequestMessage,
JsonRpcNotificationMessage JsonRpcResponseError
} }
import sbt.util.Logger import sbt.util.Logger
import sjsonnew.support.scalajson.unsafe.Converter
import scala.util.Try import scala.util.Try
import scala.util.control.NonFatal import scala.util.control.NonFatal
@ -54,6 +57,8 @@ final class NetworkChannel(
private val VsCodeOld = "application/vscode-jsonrpc; charset=utf8" private val VsCodeOld = "application/vscode-jsonrpc; charset=utf8"
private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {} private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {}
private var onGoingRequests: Set[JsonRpcRequestMessage] = Set.empty
def setContentType(ct: String): Unit = synchronized { _contentType = ct } def setContentType(ct: String): Unit = synchronized { _contentType = ct }
def contentType: String = _contentType def contentType: String = _contentType
@ -176,6 +181,7 @@ final class NetworkChannel(
intents.foldLeft(PartialFunction.empty[JsonRpcRequestMessage, Unit]) { intents.foldLeft(PartialFunction.empty[JsonRpcRequestMessage, Unit]) {
case (f, i) => f orElse i.onRequest case (f, i) => f orElse i.onRequest
} }
lazy val onNotification: PartialFunction[JsonRpcNotificationMessage, Unit] = lazy val onNotification: PartialFunction[JsonRpcNotificationMessage, Unit] =
intents.foldLeft(PartialFunction.empty[JsonRpcNotificationMessage, Unit]) { intents.foldLeft(PartialFunction.empty[JsonRpcNotificationMessage, Unit]) {
case (f, i) => f orElse i.onNotification case (f, i) => f orElse i.onNotification
@ -186,25 +192,27 @@ final class NetworkChannel(
Serialization.deserializeJsonMessage(chunk) match { Serialization.deserializeJsonMessage(chunk) match {
case Right(req: JsonRpcRequestMessage) => case Right(req: JsonRpcRequestMessage) =>
try { try {
registerRequest(req)
onRequestMessage(req) onRequestMessage(req)
} catch { } catch {
case LangServerError(code, message) => case LangServerError(code, message) =>
log.debug(s"sending error: $code: $message") log.debug(s"sending error: $code: $message")
jsonRpcRespondError(Option(req.id), code, message) jsonRpcRespondError(req.id, code, message)
} }
case Right(ntf: JsonRpcNotificationMessage) => case Right(ntf: JsonRpcNotificationMessage) =>
try { try {
onNotification(ntf) onNotification(ntf)
} catch { } catch {
case LangServerError(code, message) => case LangServerError(code, message) =>
log.debug(s"sending error: $code: $message") logMessage("error", s"Error $code while handling notification: $message")
jsonRpcRespondError(None, code, message) // new id?
} }
case Right(msg) => case Right(msg) =>
log.debug(s"Unhandled message: $msg") log.debug(s"Unhandled message: $msg")
case Left(errorDesc) => case Left(errorDesc) =>
val msg = s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): " + errorDesc logMessage(
jsonRpcRespondError(None, ErrorCodes.ParseError, msg) "error",
s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): $errorDesc"
)
} }
} else { } else {
contentType match { contentType match {
@ -213,13 +221,17 @@ final class NetworkChannel(
.deserializeCommand(chunk) .deserializeCommand(chunk)
.fold( .fold(
errorDesc => errorDesc =>
log.error( logMessage(
"error",
s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): " + errorDesc s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): " + errorDesc
), ),
onCommand onCommand
) )
case _ => case _ =>
log.error(s"Unknown Content-Type: $contentType") logMessage(
"error",
s"Unknown Content-Type: $contentType"
)
} }
} // if-else } // if-else
} }
@ -245,24 +257,43 @@ final class NetworkChannel(
} }
} }
private def registerRequest(request: JsonRpcRequestMessage): Unit = {
onGoingRequests += request
}
private[sbt] def respondError( private[sbt] def respondError(
err: JsonRpcResponseError, err: JsonRpcResponseError,
execId: Option[String], execId: Option[String]
source: Option[CommandSource] ): Unit = {
): Unit = jsonRpcRespondError(execId, err) execId match {
case Some(id) => jsonRpcRespondError(id, err)
case None => logMessage("error", s"Error ${err.code}: ${err.message}")
}
}
private[sbt] def respondError( private[sbt] def respondError(
code: Long, code: Long,
message: String, message: String,
execId: Option[String], execId: Option[String]
source: Option[CommandSource] ): Unit = {
): Unit = jsonRpcRespondError(execId, code, message) execId match {
case Some(id) => jsonRpcRespondError(id, code, message)
case None => logMessage("error", s"Error $code: $message")
}
}
private[sbt] def respondEvent[A: JsonFormat]( private[sbt] def respondEvent[A: JsonFormat](
event: A, event: A,
execId: Option[String], execId: Option[String]
source: Option[CommandSource] ): Unit = {
): Unit = jsonRpcRespond(event, execId) execId match {
case Some(id) => jsonRpcRespond(event, id)
case None =>
val json = Converter.toJsonUnsafe(event)
log.debug(s"unmatched json response: $json")
}
}
private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = { private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = {
if (isLanguageServerProtocol) { if (isLanguageServerProtocol) {
@ -278,12 +309,11 @@ final class NetworkChannel(
case entry: StringEvent => logMessage(entry.level, entry.message) case entry: StringEvent => logMessage(entry.level, entry.message)
case entry: ExecStatusEvent => case entry: ExecStatusEvent =>
entry.exitCode match { entry.exitCode match {
case None => jsonRpcRespond(event, entry.execId) case None => respondEvent(event, entry.execId)
case Some(0) => jsonRpcRespond(event, entry.execId) case Some(0) => respondEvent(event, entry.execId)
case Some(exitCode) => case Some(exitCode) => respondError(exitCode, entry.message.getOrElse(""), entry.execId)
jsonRpcRespondError(entry.execId, exitCode, entry.message.getOrElse(""))
} }
case _ => jsonRpcRespond(event, execId) case _ => respondEvent(event, execId)
} }
} else { } else {
contentType match { contentType match {
@ -383,8 +413,8 @@ final class NetworkChannel(
if (initialized) { if (initialized) {
import sbt.protocol.codec.JsonProtocol._ import sbt.protocol.codec.JsonProtocol._
SettingQuery.handleSettingQueryEither(req, structure) match { SettingQuery.handleSettingQueryEither(req, structure) match {
case Right(x) => jsonRpcRespond(x, execId) case Right(x) => respondEvent(x, execId)
case Left(s) => jsonRpcRespondError(execId, ErrorCodes.InvalidParams, s) case Left(s) => respondError(ErrorCodes.InvalidParams, s, execId)
} }
} else { } else {
log.warn(s"ignoring query $req before initialization") log.warn(s"ignoring query $req before initialization")
@ -400,32 +430,31 @@ final class NetworkChannel(
Parser Parser
.completions(sstate.combinedParser, cp.query, 9) .completions(sstate.combinedParser, cp.query, 9)
.get .get
.map(c => { .flatMap { c =>
if (!c.isEmpty) Some(c.append.replaceAll("\n", " ")) if (!c.isEmpty) Some(c.append.replaceAll("\n", " "))
else None else None
}) }
.flatten .map(c => cp.query + c)
.map(c => cp.query + c.toString)
import sbt.protocol.codec.JsonProtocol._ import sbt.protocol.codec.JsonProtocol._
jsonRpcRespond( respondEvent(
CompletionResponse( CompletionResponse(
items = completionItems.toVector items = completionItems.toVector
), ),
execId execId
) )
case _ => case _ =>
jsonRpcRespondError( respondError(
execId,
ErrorCodes.UnknownError, ErrorCodes.UnknownError,
"No available sbt state" "No available sbt state",
execId
) )
} }
} catch { } catch {
case NonFatal(e) => case NonFatal(_) =>
jsonRpcRespondError( respondError(
execId,
ErrorCodes.UnknownError, ErrorCodes.UnknownError,
"Completions request failed" "Completions request failed",
execId
) )
} }
} else { } else {
@ -436,10 +465,10 @@ final class NetworkChannel(
protected def onCancellationRequest(execId: Option[String], crp: CancelRequestParams) = { protected def onCancellationRequest(execId: Option[String], crp: CancelRequestParams) = {
if (initialized) { if (initialized) {
def errorRespond(msg: String) = jsonRpcRespondError( def errorRespond(msg: String) = respondError(
execId,
ErrorCodes.RequestCancelled, ErrorCodes.RequestCancelled,
msg msg,
execId
) )
try { try {
@ -465,11 +494,11 @@ final class NetworkChannel(
runningEngine.cancelAndShutdown() runningEngine.cancelAndShutdown()
import sbt.protocol.codec.JsonProtocol._ import sbt.protocol.codec.JsonProtocol._
jsonRpcRespond( respondEvent(
ExecStatusEvent( ExecStatusEvent(
"Task cancelled", "Task cancelled",
Some(name), Some(name),
Some(runningExecId.toString), Some(runningExecId),
Vector(), Vector(),
None, None,
), ),

View File

@ -12,7 +12,7 @@ package sbt.internal.protocol
*/ */
final class JsonRpcResponseMessage private ( final class JsonRpcResponseMessage private (
jsonrpc: String, jsonrpc: String,
val id: Option[String], val id: String,
val result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue], val result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue],
val error: Option[sbt.internal.protocol.JsonRpcResponseError]) extends sbt.internal.protocol.JsonRpcMessage(jsonrpc) with Serializable { val error: Option[sbt.internal.protocol.JsonRpcResponseError]) extends sbt.internal.protocol.JsonRpcMessage(jsonrpc) with Serializable {
@ -28,17 +28,14 @@ final class JsonRpcResponseMessage private (
override def toString: String = { override def toString: String = {
s"""JsonRpcResponseMessage($jsonrpc, $id, ${sbt.protocol.Serialization.compactPrintJsonOpt(result)}, $error)""" s"""JsonRpcResponseMessage($jsonrpc, $id, ${sbt.protocol.Serialization.compactPrintJsonOpt(result)}, $error)"""
} }
private[this] def copy(jsonrpc: String = jsonrpc, id: Option[String] = id, result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue] = result, error: Option[sbt.internal.protocol.JsonRpcResponseError] = error): JsonRpcResponseMessage = { private[this] def copy(jsonrpc: String = jsonrpc, id: String = id, result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue] = result, error: Option[sbt.internal.protocol.JsonRpcResponseError] = error): JsonRpcResponseMessage = {
new JsonRpcResponseMessage(jsonrpc, id, result, error) new JsonRpcResponseMessage(jsonrpc, id, result, error)
} }
def withJsonrpc(jsonrpc: String): JsonRpcResponseMessage = { def withJsonrpc(jsonrpc: String): JsonRpcResponseMessage = {
copy(jsonrpc = jsonrpc) copy(jsonrpc = jsonrpc)
} }
def withId(id: Option[String]): JsonRpcResponseMessage = {
copy(id = id)
}
def withId(id: String): JsonRpcResponseMessage = { def withId(id: String): JsonRpcResponseMessage = {
copy(id = Option(id)) copy(id = id)
} }
def withResult(result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue]): JsonRpcResponseMessage = { def withResult(result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue]): JsonRpcResponseMessage = {
copy(result = result) copy(result = result)
@ -55,6 +52,6 @@ final class JsonRpcResponseMessage private (
} }
object JsonRpcResponseMessage { object JsonRpcResponseMessage {
def apply(jsonrpc: String, id: Option[String], result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue], error: Option[sbt.internal.protocol.JsonRpcResponseError]): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, id, result, error) def apply(jsonrpc: String, id: String, result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue], error: Option[sbt.internal.protocol.JsonRpcResponseError]): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, id, result, error)
def apply(jsonrpc: String, id: String, result: sjsonnew.shaded.scalajson.ast.unsafe.JValue, error: sbt.internal.protocol.JsonRpcResponseError): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, Option(id), Option(result), Option(error)) def apply(jsonrpc: String, id: String, result: sjsonnew.shaded.scalajson.ast.unsafe.JValue, error: sbt.internal.protocol.JsonRpcResponseError): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, id, Option(result), Option(error))
} }

View File

@ -31,7 +31,7 @@ type JsonRpcResponseMessage implements JsonRpcMessage
jsonrpc: String! jsonrpc: String!
## The request id. ## The request id.
id: String id: String!
## The result of a request. This can be omitted in ## The result of a request. This can be omitted in
## the case of an error. ## the case of an error.

View File

@ -32,10 +32,10 @@ trait JsonRpcResponseMessageFormats {
unbuilder.beginObject(js) unbuilder.beginObject(js)
val jsonrpc = unbuilder.readField[String]("jsonrpc") val jsonrpc = unbuilder.readField[String]("jsonrpc")
val id = try { val id = try {
unbuilder.readField[Option[String]]("id") unbuilder.readField[String]("id")
} catch { } catch {
case _: DeserializationException => case _: DeserializationException =>
unbuilder.readField[Option[Long]]("id") map { _.toString } unbuilder.readField[Long]("id").toString
} }
val result = unbuilder.lookupField("result") map { val result = unbuilder.lookupField("result") map {
@ -77,11 +77,9 @@ trait JsonRpcResponseMessageFormats {
} }
builder.beginObject() builder.beginObject()
builder.addField("jsonrpc", obj.jsonrpc) builder.addField("jsonrpc", obj.jsonrpc)
obj.id foreach { id => parseId(obj.id) match {
parseId(id) match { case Right(strId) => builder.addField("id", strId)
case Right(strId) => builder.addField("id", strId) case Left(longId) => builder.addField("id", longId)
case Left(longId) => builder.addField("id", longId)
}
} }
builder.addField("result", obj.result map parseResult) builder.addField("result", obj.result map parseResult)
builder.addField("error", obj.error) builder.addField("error", obj.error)