Merge pull request #5549 from adpi2/issue/json-response

Prevent more than one response per json RPC request
This commit is contained in:
eugene yokota 2020-05-12 22:07:11 -04:00 committed by GitHub
commit 4592493617
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 316 additions and 340 deletions

3
.gitignore vendored
View File

@ -8,3 +8,6 @@ npm-debug.log
!sbt/src/server-test/completions/target
.big
.idea
.bloop
.metals
metals.sbt

View File

@ -677,7 +677,8 @@ lazy val protocolProj = (project in file("protocol"))
exclude[DirectMissingMethodProblem]("sbt.protocol.SettingQueryFailure.copy$default$*"),
exclude[DirectMissingMethodProblem]("sbt.protocol.SettingQuerySuccess.copy"),
exclude[DirectMissingMethodProblem]("sbt.protocol.SettingQuerySuccess.copy$default$*"),
// ignore missing methods in sbt.internal
// ignore missing or incompatible methods in sbt.internal
exclude[IncompatibleMethTypeProblem]("sbt.internal.*"),
exclude[DirectMissingMethodProblem]("sbt.internal.*"),
exclude[MissingTypesProblem]("sbt.internal.protocol.JsonRpcResponseError"),
)
@ -876,7 +877,7 @@ lazy val mainProj = (project in file("main"))
// New and changed methods on KeyIndex. internal.
exclude[ReversedMissingMethodProblem]("sbt.internal.KeyIndex.*"),
// internal
exclude[IncompatibleMethTypeProblem]("sbt.internal.server.LanguageServerReporter.*"),
exclude[IncompatibleMethTypeProblem]("sbt.internal.*"),
// Changed signature or removed private[sbt] methods
exclude[DirectMissingMethodProblem]("sbt.Classpaths.unmanagedLibs0"),
exclude[DirectMissingMethodProblem]("sbt.Defaults.allTestGroupsTask"),

View File

@ -11,7 +11,6 @@ package internal
import java.util.concurrent.ConcurrentLinkedQueue
import sbt.protocol.EventMessage
import sjsonnew.JsonFormat
/**
* A command channel represents an IO device such as network socket or human
@ -42,9 +41,6 @@ abstract class CommandChannel {
}
def poll: Option[Exec] = Option(commandQueue.poll)
def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit
final def publishEvent[A: JsonFormat](event: A): Unit = publishEvent(event, None)
def publishEventMessage(event: EventMessage): Unit
def publishBytes(bytes: Array[Byte]): Unit
def shutdown(): Unit
def name: String

View File

@ -14,8 +14,6 @@ import java.util.concurrent.atomic.AtomicReference
import sbt.BasicKeys._
import sbt.internal.util._
import sbt.protocol.EventMessage
import sjsonnew.JsonFormat
private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel {
private[this] val askUserThread = new AtomicReference[AskUserThread]
@ -62,21 +60,16 @@ private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel
def publishBytes(bytes: Array[Byte]): Unit = ()
def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit = ()
def publishEventMessage(event: EventMessage): Unit =
event match {
case e: ConsolePromptEvent =>
if (Terminal.systemInIsAttached) {
askUserThread.synchronized {
askUserThread.get match {
case null => askUserThread.set(makeAskUserThread(e.state))
case t => t.redraw()
}
}
def prompt(event: ConsolePromptEvent): Unit = {
if (Terminal.systemInIsAttached) {
askUserThread.synchronized {
askUserThread.get match {
case null => askUserThread.set(makeAskUserThread(event.state))
case t => t.redraw()
}
case _ => //
}
}
}
def shutdown(): Unit = askUserThread.synchronized {
askUserThread.get match {

View File

@ -119,7 +119,7 @@ class NetworkClient(configuration: xsbti.AppConfiguration, arguments: List[Strin
}
def onResponse(msg: JsonRpcResponseMessage): Unit = {
msg.id foreach {
msg.id match {
case execId if pendingExecIds contains execId =>
onReturningReponse(msg)
lock.synchronized {

View File

@ -890,7 +890,7 @@ object BuiltinCommands {
val exchange = StandardMain.exchange
val welcomeState = displayWelcomeBanner(s0)
val s1 = exchange run welcomeState
exchange publishEventMessage ConsolePromptEvent(s0)
exchange prompt ConsolePromptEvent(s0)
val minGCInterval = Project
.extract(s1)
.getOpt(Keys.minForcegcInterval)

View File

@ -176,7 +176,7 @@ object MainLoop {
/** This is the main function State transfer function of the sbt command processing. */
def processCommand(exec: Exec, state: State): State = {
val channelName = exec.source map (_.channelName)
StandardMain.exchange publishEventMessage
StandardMain.exchange notifyStatus
ExecStatusEvent("Processing", channelName, exec.execId, Vector())
try {
def process(): State = {
@ -197,12 +197,7 @@ object MainLoop {
newState.remainingCommands.toVector map (_.commandLine),
exitCode(newState, state),
)
if (doneEvent.execId.isDefined) { // send back a response or error
import sbt.protocol.codec.JsonProtocol._
StandardMain.exchange publishEvent doneEvent
} else { // send back a notification
StandardMain.exchange publishEventMessage doneEvent
}
StandardMain.exchange.respondStatus(doneEvent)
newState.get(sbt.Keys.currentTaskProgress).foreach(_.progress.stop())
newState.remove(sbt.Keys.currentTaskProgress)
}
@ -225,8 +220,7 @@ object MainLoop {
ExitCode(ErrorCodes.UnknownError),
Option(err.getMessage),
)
import sbt.protocol.codec.JsonProtocol._
StandardMain.exchange.publishEvent(errorEvent)
StandardMain.exchange.respondStatus(errorEvent)
throw err
}
}

View File

@ -16,16 +16,13 @@ import java.util.concurrent.atomic._
import sbt.BasicKeys._
import sbt.nio.Watch.NullLogger
import sbt.internal.protocol.JsonRpcResponseError
import sbt.internal.langserver.{ LogMessageParams, MessageType }
import sbt.internal.server._
import sbt.internal.util.codec.JValueFormats
import sbt.internal.util.{ ConsoleOut, MainAppender, ObjectEvent, StringEvent, Terminal }
import sbt.internal.util.{ ConsoleOut, MainAppender, ObjectEvent, Terminal }
import sbt.io.syntax._
import sbt.io.{ Hash, IO }
import sbt.protocol.{ EventMessage, ExecStatusEvent }
import sbt.protocol.{ ExecStatusEvent, LogEvent }
import sbt.util.{ Level, LogExchange, Logger }
import sjsonnew.JsonFormat
import sjsonnew.shaded.scalajson.ast.unsafe._
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
@ -51,17 +48,12 @@ private[sbt] final class CommandExchange {
private val commandChannelQueue = new LinkedBlockingQueue[CommandChannel]
private val nextChannelId: AtomicInteger = new AtomicInteger(0)
private[this] val activePrompt = new AtomicBoolean(false)
private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {}
def channels: List[CommandChannel] = channelBuffer.toList
private[this] def removeChannels(toDel: List[CommandChannel]): Unit = {
toDel match {
case Nil => // do nothing
case xs =>
channelBufferLock.synchronized {
channelBuffer --= xs
()
}
private[this] def removeChannel(channel: CommandChannel): Unit = {
channelBufferLock.synchronized {
channelBuffer -= channel
()
}
}
@ -206,19 +198,7 @@ private[sbt] final class CommandExchange {
execId: Option[String],
source: Option[CommandSource]
): Unit = {
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
channels.foreach {
case _: ConsoleChannel =>
case c: NetworkChannel =>
try {
// broadcast to all network channels
c.respondError(code, message, execId, source)
} catch {
case _: IOException =>
toDel += c
}
}
removeChannels(toDel.toList)
respondError(JsonRpcResponseError(code, message), execId, source)
}
private[sbt] def respondError(
@ -226,19 +206,13 @@ private[sbt] final class CommandExchange {
execId: Option[String],
source: Option[CommandSource]
): Unit = {
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
channels.foreach {
case _: ConsoleChannel =>
case c: NetworkChannel =>
try {
// broadcast to all network channels
c.respondError(err, execId, source)
} catch {
case _: IOException =>
toDel += c
}
}
removeChannels(toDel.toList)
for {
source <- source.map(_.channelName)
channel <- channels.collectFirst {
// broadcast to the source channel only
case c: NetworkChannel if c.name == source => c
}
} tryTo(_.respondError(err, execId))(channel)
}
// This is an interface to directly respond events.
@ -247,146 +221,89 @@ private[sbt] final class CommandExchange {
execId: Option[String],
source: Option[CommandSource]
): Unit = {
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
channels.foreach {
case _: ConsoleChannel =>
case c: NetworkChannel =>
try {
// broadcast to all network channels
c.respondEvent(event, execId, source)
} catch {
case _: IOException =>
toDel += c
}
}
removeChannels(toDel.toList)
for {
source <- source.map(_.channelName)
channel <- channels.collectFirst {
// broadcast to the source channel only
case c: NetworkChannel if c.name == source => c
}
} tryTo(_.respondResult(event, execId))(channel)
}
// This is an interface to directly notify events.
private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = {
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
channels.foreach {
case _: ConsoleChannel =>
// c.publishEvent(event)
case c: NetworkChannel =>
try {
c.notifyEvent(method, params)
} catch {
case _: IOException =>
toDel += c
}
}
removeChannels(toDel.toList)
channels
.collect { case c: NetworkChannel => c }
.foreach {
tryTo(_.notifyEvent(method, params))
}
}
private def tryTo(x: => Unit, c: CommandChannel, toDel: ListBuffer[CommandChannel]): Unit =
try x
catch { case _: IOException => toDel += c }
private def tryTo(f: NetworkChannel => Unit)(
channel: NetworkChannel
): Unit =
try f(channel)
catch { case _: IOException => removeChannel(channel) }
def publishEvent[A: JsonFormat](event: A): Unit = {
val broadcastStringMessage = true
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
def respondStatus(event: ExecStatusEvent): Unit = {
import sbt.protocol.codec.JsonProtocol._
for {
source <- event.channelName
channel <- channels.collectFirst {
case c: NetworkChannel if c.name == source => c
}
} {
if (event.execId.isEmpty) {
tryTo(_.notifyEvent(event))(channel)
} else {
event.exitCode match {
case None | Some(0) =>
tryTo(_.respondResult(event, event.execId))(channel)
case Some(code) =>
tryTo(_.respondError(code, event.message.getOrElse(""), event.execId))(channel)
}
}
event match {
case entry: StringEvent =>
val params = toLogMessageParams(entry)
channels collect {
case c: ConsoleChannel =>
if (broadcastStringMessage || (entry.channelName forall (_ == c.name)))
c.publishEvent(event)
case c: NetworkChannel =>
tryTo(
{
// Note that language server's LogMessageParams does not hold the execid,
// so this is weaker than the StringMessage. We might want to double-send
// in case we have a better client that can utilize the knowledge.
import sbt.internal.langserver.codec.JsonProtocol._
if (broadcastStringMessage || (entry.channelName contains c.name))
c.jsonRpcNotify("window/logMessage", params)
},
c,
toDel
)
}
case entry: ExecStatusEvent =>
channels collect {
case c: ConsoleChannel =>
if (entry.channelName forall (_ == c.name)) c.publishEvent(event)
case c: NetworkChannel =>
if (entry.channelName contains c.name) tryTo(c.publishEvent(event), c, toDel)
}
case _ =>
channels foreach {
case c: ConsoleChannel => c.publishEvent(event)
case c: NetworkChannel =>
tryTo(c.publishEvent(event), c, toDel)
}
tryTo(_.respond(event, event.execId))(channel)
}
removeChannels(toDel.toList)
}
private[sbt] def toLogMessageParams(event: StringEvent): LogMessageParams = {
LogMessageParams(MessageType.fromLevelString(event.level), event.message)
}
/**
* This publishes object events. The type information has been
* erased because it went through logging.
*/
private[sbt] def publishObjectEvent(event: ObjectEvent[_]): Unit = {
import jsonFormat._
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
def json: JValue = JObject(
JField("type", JString(event.contentType)),
Vector(JField("message", event.json), JField("level", JString(event.level.toString))) ++
(event.channelName.toVector map { channelName =>
JField("channelName", JString(channelName))
}) ++
(event.execId.toVector map { execId =>
JField("execId", JString(execId))
}): _*
)
channels collect {
case c: ConsoleChannel =>
c.publishEvent(json)
case c: NetworkChannel =>
try {
c.publishObjectEvent(event)
} catch {
case _: IOException =>
toDel += c
}
}
removeChannels(toDel.toList)
private[sbt] def respondObjectEvent(event: ObjectEvent[_]): Unit = {
for {
source <- event.channelName
channel <- channels.collectFirst {
case c: NetworkChannel if c.name == source => c
}
} tryTo(_.respond(event))(channel)
}
// fanout publishEvent
def publishEventMessage(event: EventMessage): Unit = {
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
event match {
// Special treatment for ConsolePromptEvent since it's hand coded without codec.
case entry: ConsolePromptEvent =>
channels collect {
case c: ConsoleChannel =>
c.publishEventMessage(entry)
activePrompt.set(Terminal.systemInIsAttached)
}
case entry: ExecStatusEvent =>
channels collect {
case c: ConsoleChannel =>
if (entry.channelName forall (_ == c.name)) c.publishEventMessage(event)
case c: NetworkChannel =>
if (entry.channelName contains c.name) tryTo(c.publishEventMessage(event), c, toDel)
}
case _ =>
channels collect {
case c: ConsoleChannel => c.publishEventMessage(event)
case c: NetworkChannel => tryTo(c.publishEventMessage(event), c, toDel)
}
}
removeChannels(toDel.toList)
def prompt(event: ConsolePromptEvent): Unit = {
activePrompt.set(Terminal.systemInIsAttached)
channels
.collect { case c: ConsoleChannel => c }
.foreach { _.prompt(event) }
}
def logMessage(event: LogEvent): Unit = {
channels
.collect { case c: NetworkChannel => c }
.foreach {
tryTo(_.notifyEvent(event))
}
}
def notifyStatus(event: ExecStatusEvent): Unit = {
for {
source <- event.channelName
channel <- channels.collectFirst {
case c: NetworkChannel if c.name == source => c
}
} tryTo(_.notifyEvent(event))(channel)
}
private[this] def needToFinishPromptLine(): Boolean = activePrompt.compareAndSet(true, false)
}

View File

@ -17,7 +17,6 @@ import org.apache.logging.log4j.core.config.Property
import sbt.util.Level
import sbt.internal.util._
import sbt.protocol.LogEvent
import sbt.internal.util.codec._
class RelayAppender(name: String)
extends AbstractAppender(
@ -40,15 +39,12 @@ class RelayAppender(name: String)
}
}
def appendLog(level: Level.Value, message: => String): Unit = {
exchange.publishEventMessage(LogEvent(level.toString, message))
exchange.logMessage(LogEvent(level.toString, message))
}
def appendEvent(event: AnyRef): Unit =
event match {
case x: StringEvent => {
import JsonProtocol._
exchange.publishEvent(x: AbstractEntry)
}
case x: ObjectEvent[_] => exchange.publishObjectEvent(x)
case x: StringEvent => exchange.logMessage(LogEvent(x.message, x.level))
case x: ObjectEvent[_] => exchange.respondObjectEvent(x)
case _ =>
println(s"appendEvent: ${event.getClass}")
()

View File

@ -40,10 +40,10 @@ private[sbt] object Definition {
def send[A: JsonFormat](source: CommandSource, execId: String)(params: A): Unit = {
for {
channel <- StandardMain.exchange.channels.collectFirst {
case c if c.name == source.channelName => c
case c: NetworkChannel if c.name == source.channelName => c
}
} {
channel.publishEvent(params, Option(execId))
channel.respond(params, Option(execId))
}
}

View File

@ -10,7 +10,6 @@ package internal
package server
import sjsonnew.JsonFormat
import sjsonnew.shaded.scalajson.ast.unsafe.JValue
import sjsonnew.support.scalajson.unsafe.Converter
import sbt.protocol.Serialization
import sbt.protocol.{ CompletionParams => CP, SettingQuery => Q }
@ -103,7 +102,7 @@ private[sbt] object LanguageServerProtocol {
}
/** Implements Language Server Protocol <https://github.com/Microsoft/language-server-protocol>. */
private[sbt] trait LanguageServerProtocol extends CommandChannel { self =>
private[sbt] trait LanguageServerProtocol { self: NetworkChannel =>
lazy val internalJsonProtocol = new InitializeOptionFormats with sjsonnew.BasicJsonProtocol {}
@ -117,10 +116,10 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self =>
protected lazy val callbackImpl: ServerCallback = new ServerCallback {
def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit =
self.jsonRpcRespond(event, execId)
self.respondResult(event, execId)
def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit =
self.jsonRpcRespondError(execId, code, message)
self.respondError(code, message, execId)
def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit =
self.jsonRpcNotify(method, params)
@ -162,28 +161,16 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self =>
}
/** Respond back to Language Server's client. */
private[sbt] def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit = {
val m =
private[sbt] def jsonRpcRespond[A: JsonFormat](event: A, execId: String): Unit = {
val response =
JsonRpcResponseMessage("2.0", execId, Option(Converter.toJson[A](event).get), None)
val bytes = Serialization.serializeResponseMessage(m)
val bytes = Serialization.serializeResponseMessage(response)
publishBytes(bytes)
}
/** Respond back to Language Server's client. */
private[sbt] def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit =
jsonRpcRespondErrorImpl(execId, code, message, None)
/** Respond back to Language Server's client. */
private[sbt] def jsonRpcRespondError[A: JsonFormat](
execId: Option[String],
code: Long,
message: String,
data: A,
): Unit =
jsonRpcRespondErrorImpl(execId, code, message, Option(Converter.toJson[A](data).get))
private[sbt] def jsonRpcRespondError(
execId: Option[String],
execId: String,
err: JsonRpcResponseError
): Unit = {
val m = JsonRpcResponseMessage("2.0", execId, None, Option(err))
@ -191,18 +178,6 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self =>
publishBytes(bytes)
}
private[this] def jsonRpcRespondErrorImpl(
execId: Option[String],
code: Long,
message: String,
data: Option[JValue],
): Unit = {
val e = JsonRpcResponseError(code, message, data)
val m = JsonRpcResponseMessage("2.0", execId, None, Option(e))
val bytes = Serialization.serializeResponseMessage(m)
publishBytes(bytes)
}
/** Notify to Language Server's client. */
private[sbt] def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit = {
val m =

View File

@ -12,19 +12,22 @@ package server
import java.net.{ Socket, SocketTimeoutException }
import java.util.concurrent.atomic.AtomicBoolean
import sjsonnew._
import scala.annotation.tailrec
import sbt.protocol._
import sbt.internal.langserver.{ ErrorCodes, CancelRequestParams }
import sbt.internal.util.{ ObjectEvent, StringEvent }
import sbt.internal.util.complete.Parser
import sbt.internal.util.codec.JValueFormats
import sbt.internal.langserver.{ CancelRequestParams, ErrorCodes }
import sbt.internal.protocol.{
JsonRpcResponseError,
JsonRpcNotificationMessage,
JsonRpcRequestMessage,
JsonRpcNotificationMessage
JsonRpcResponseError
}
import sbt.internal.util.codec.JValueFormats
import sbt.internal.util.complete.Parser
import sbt.internal.util.ObjectEvent
import sbt.protocol._
import sbt.util.Logger
import sjsonnew._
import sjsonnew.support.scalajson.unsafe.{ CompactPrinter, Converter }
import scala.annotation.tailrec
import scala.collection.mutable
import scala.util.Try
import scala.util.control.NonFatal
@ -53,6 +56,7 @@ final class NetworkChannel(
private val VsCode = sbt.protocol.Serialization.VsCode
private val VsCodeOld = "application/vscode-jsonrpc; charset=utf8"
private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {}
private val pendingRequests: mutable.Map[String, JsonRpcRequestMessage] = mutable.Map()
def setContentType(ct: String): Unit = synchronized { _contentType = ct }
def contentType: String = _contentType
@ -176,6 +180,7 @@ final class NetworkChannel(
intents.foldLeft(PartialFunction.empty[JsonRpcRequestMessage, Unit]) {
case (f, i) => f orElse i.onRequest
}
lazy val onNotification: PartialFunction[JsonRpcNotificationMessage, Unit] =
intents.foldLeft(PartialFunction.empty[JsonRpcNotificationMessage, Unit]) {
case (f, i) => f orElse i.onNotification
@ -186,25 +191,27 @@ final class NetworkChannel(
Serialization.deserializeJsonMessage(chunk) match {
case Right(req: JsonRpcRequestMessage) =>
try {
registerRequest(req)
onRequestMessage(req)
} catch {
case LangServerError(code, message) =>
log.debug(s"sending error: $code: $message")
jsonRpcRespondError(Option(req.id), code, message)
respondError(code, message, Some(req.id))
}
case Right(ntf: JsonRpcNotificationMessage) =>
try {
onNotification(ntf)
} catch {
case LangServerError(code, message) =>
log.debug(s"sending error: $code: $message")
jsonRpcRespondError(None, code, message) // new id?
logMessage("error", s"Error $code while handling notification: $message")
}
case Right(msg) =>
log.debug(s"Unhandled message: $msg")
case Left(errorDesc) =>
val msg = s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): " + errorDesc
jsonRpcRespondError(None, ErrorCodes.ParseError, msg)
logMessage(
"error",
s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): $errorDesc"
)
}
} else {
contentType match {
@ -213,13 +220,17 @@ final class NetworkChannel(
.deserializeCommand(chunk)
.fold(
errorDesc =>
log.error(
logMessage(
"error",
s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): " + errorDesc
),
onCommand
)
case _ =>
log.error(s"Unknown Content-Type: $contentType")
logMessage(
"error",
s"Unknown Content-Type: $contentType"
)
}
} // if-else
}
@ -245,24 +256,48 @@ final class NetworkChannel(
}
}
private def registerRequest(request: JsonRpcRequestMessage): Unit = {
this.synchronized {
pendingRequests += (request.id -> request)
()
}
}
private[sbt] def respondError(
err: JsonRpcResponseError,
execId: Option[String],
source: Option[CommandSource]
): Unit = jsonRpcRespondError(execId, err)
execId: Option[String]
): Unit = this.synchronized {
execId match {
case Some(id) if pendingRequests.contains(id) =>
pendingRequests -= id
jsonRpcRespondError(id, err)
case _ =>
logMessage("error", s"Error ${err.code}: ${err.message}")
}
}
private[sbt] def respondError(
code: Long,
message: String,
execId: Option[String],
source: Option[CommandSource]
): Unit = jsonRpcRespondError(execId, code, message)
execId: Option[String]
): Unit = {
respondError(JsonRpcResponseError(code, message), execId)
}
private[sbt] def respondEvent[A: JsonFormat](
private[sbt] def respondResult[A: JsonFormat](
event: A,
execId: Option[String],
source: Option[CommandSource]
): Unit = jsonRpcRespond(event, execId)
execId: Option[String]
): Unit = this.synchronized {
execId match {
case Some(id) if pendingRequests.contains(id) =>
pendingRequests -= id
jsonRpcRespond(event, id)
case _ =>
log.debug(
s"unmatched json response for requestId $execId: ${CompactPrinter(Converter.toJsonUnsafe(event))}"
)
}
}
private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = {
if (isLanguageServerProtocol) {
@ -272,19 +307,11 @@ final class NetworkChannel(
}
}
def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit = {
def respond[A: JsonFormat](event: A): Unit = respond(event, None)
def respond[A: JsonFormat](event: A, execId: Option[String]): Unit = {
if (isLanguageServerProtocol) {
event match {
case entry: StringEvent => logMessage(entry.level, entry.message)
case entry: ExecStatusEvent =>
entry.exitCode match {
case None => jsonRpcRespond(event, entry.execId)
case Some(0) => jsonRpcRespond(event, entry.execId)
case Some(exitCode) =>
jsonRpcRespondError(entry.execId, exitCode, entry.message.getOrElse(""))
}
case _ => jsonRpcRespond(event, execId)
}
respondResult(event, execId)
} else {
contentType match {
case SbtX1Protocol =>
@ -295,7 +322,7 @@ final class NetworkChannel(
}
}
def publishEventMessage(event: EventMessage): Unit = {
def notifyEvent(event: EventMessage): Unit = {
if (isLanguageServerProtocol) {
event match {
case entry: LogEvent => logMessage(entry.level, entry.message)
@ -316,22 +343,22 @@ final class NetworkChannel(
* This publishes object events. The type information has been
* erased because it went through logging.
*/
private[sbt] def publishObjectEvent(event: ObjectEvent[_]): Unit = {
private[sbt] def respond(event: ObjectEvent[_]): Unit = {
import sjsonnew.shaded.scalajson.ast.unsafe._
if (isLanguageServerProtocol) onObjectEvent(event)
else {
import jsonFormat._
val json: JValue = JObject(
JField("type", JString(event.contentType)),
(Vector(JField("message", event.json), JField("level", JString(event.level.toString))) ++
(event.channelName.toVector map { channelName =>
Seq(JField("message", event.json), JField("level", JString(event.level.toString))) ++
(event.channelName map { channelName =>
JField("channelName", JString(channelName))
}) ++
(event.execId.toVector map { execId =>
(event.execId map { execId =>
JField("execId", JString(execId))
})): _*
}): _*
)
publishEvent(json)
respond(json, event.execId)
}
}
@ -358,7 +385,7 @@ final class NetworkChannel(
authenticate(x) match {
case true =>
initialized = true
publishEventMessage(ChannelAcceptedEvent(name))
notifyEvent(ChannelAcceptedEvent(name))
case _ => sys.error("invalid token")
}
case None => sys.error("init command but without token.")
@ -383,8 +410,8 @@ final class NetworkChannel(
if (initialized) {
import sbt.protocol.codec.JsonProtocol._
SettingQuery.handleSettingQueryEither(req, structure) match {
case Right(x) => jsonRpcRespond(x, execId)
case Left(s) => jsonRpcRespondError(execId, ErrorCodes.InvalidParams, s)
case Right(x) => respondResult(x, execId)
case Left(s) => respondError(ErrorCodes.InvalidParams, s, execId)
}
} else {
log.warn(s"ignoring query $req before initialization")
@ -400,32 +427,31 @@ final class NetworkChannel(
Parser
.completions(sstate.combinedParser, cp.query, 9)
.get
.map(c => {
.flatMap { c =>
if (!c.isEmpty) Some(c.append.replaceAll("\n", " "))
else None
})
.flatten
.map(c => cp.query + c.toString)
}
.map(c => cp.query + c)
import sbt.protocol.codec.JsonProtocol._
jsonRpcRespond(
respondResult(
CompletionResponse(
items = completionItems.toVector
),
execId
)
case _ =>
jsonRpcRespondError(
execId,
respondError(
ErrorCodes.UnknownError,
"No available sbt state"
"No available sbt state",
execId
)
}
} catch {
case NonFatal(e) =>
jsonRpcRespondError(
execId,
case NonFatal(_) =>
respondError(
ErrorCodes.UnknownError,
"Completions request failed"
"Completions request failed",
execId
)
}
} else {
@ -436,10 +462,10 @@ final class NetworkChannel(
protected def onCancellationRequest(execId: Option[String], crp: CancelRequestParams) = {
if (initialized) {
def errorRespond(msg: String) = jsonRpcRespondError(
execId,
def errorRespond(msg: String) = respondError(
ErrorCodes.RequestCancelled,
msg
msg,
execId
)
try {
@ -465,11 +491,11 @@ final class NetworkChannel(
runningEngine.cancelAndShutdown()
import sbt.protocol.codec.JsonProtocol._
jsonRpcRespond(
respondResult(
ExecStatusEvent(
"Task cancelled",
Some(name),
Some(runningExecId.toString),
Some(runningExecId),
Vector(),
None,
),

View File

@ -12,7 +12,7 @@ package sbt.internal.protocol
*/
final class JsonRpcResponseMessage private (
jsonrpc: String,
val id: Option[String],
val id: String,
val result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue],
val error: Option[sbt.internal.protocol.JsonRpcResponseError]) extends sbt.internal.protocol.JsonRpcMessage(jsonrpc) with Serializable {
@ -28,17 +28,14 @@ final class JsonRpcResponseMessage private (
override def toString: String = {
s"""JsonRpcResponseMessage($jsonrpc, $id, ${sbt.protocol.Serialization.compactPrintJsonOpt(result)}, $error)"""
}
private[this] def copy(jsonrpc: String = jsonrpc, id: Option[String] = id, result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue] = result, error: Option[sbt.internal.protocol.JsonRpcResponseError] = error): JsonRpcResponseMessage = {
private[this] def copy(jsonrpc: String = jsonrpc, id: String = id, result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue] = result, error: Option[sbt.internal.protocol.JsonRpcResponseError] = error): JsonRpcResponseMessage = {
new JsonRpcResponseMessage(jsonrpc, id, result, error)
}
def withJsonrpc(jsonrpc: String): JsonRpcResponseMessage = {
copy(jsonrpc = jsonrpc)
}
def withId(id: Option[String]): JsonRpcResponseMessage = {
copy(id = id)
}
def withId(id: String): JsonRpcResponseMessage = {
copy(id = Option(id))
copy(id = id)
}
def withResult(result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue]): JsonRpcResponseMessage = {
copy(result = result)
@ -55,6 +52,6 @@ final class JsonRpcResponseMessage private (
}
object JsonRpcResponseMessage {
def apply(jsonrpc: String, id: Option[String], result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue], error: Option[sbt.internal.protocol.JsonRpcResponseError]): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, id, result, error)
def apply(jsonrpc: String, id: String, result: sjsonnew.shaded.scalajson.ast.unsafe.JValue, error: sbt.internal.protocol.JsonRpcResponseError): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, Option(id), Option(result), Option(error))
def apply(jsonrpc: String, id: String, result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue], error: Option[sbt.internal.protocol.JsonRpcResponseError]): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, id, result, error)
def apply(jsonrpc: String, id: String, result: sjsonnew.shaded.scalajson.ast.unsafe.JValue, error: sbt.internal.protocol.JsonRpcResponseError): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, id, Option(result), Option(error))
}

View File

@ -31,7 +31,7 @@ type JsonRpcResponseMessage implements JsonRpcMessage
jsonrpc: String!
## The request id.
id: String
id: String!
## The result of a request. This can be omitted in
## the case of an error.

View File

@ -32,10 +32,10 @@ trait JsonRpcResponseMessageFormats {
unbuilder.beginObject(js)
val jsonrpc = unbuilder.readField[String]("jsonrpc")
val id = try {
unbuilder.readField[Option[String]]("id")
unbuilder.readField[String]("id")
} catch {
case _: DeserializationException =>
unbuilder.readField[Option[Long]]("id") map { _.toString }
unbuilder.readField[Long]("id").toString
}
val result = unbuilder.lookupField("result") map {
@ -77,11 +77,9 @@ trait JsonRpcResponseMessageFormats {
}
builder.beginObject()
builder.addField("jsonrpc", obj.jsonrpc)
obj.id foreach { id =>
parseId(id) match {
case Right(strId) => builder.addField("id", strId)
case Left(longId) => builder.addField("id", longId)
}
parseId(obj.id) match {
case Right(strId) => builder.addField("id", strId)
case Left(longId) => builder.addField("id", longId)
}
builder.addField("result", obj.result map parseResult)
builder.addField("error", obj.error)

View File

@ -25,12 +25,25 @@ Global / serverHandlers += ServerHandler({ callback =>
case r: JsonRpcRequestMessage if r.method == "foo/rootClasspath" =>
appendExec(Exec("fooClasspath", Some(r.id), Some(CommandSource(callback.name))))
()
case r if r.method == "foo/respondTwice" =>
appendExec(Exec("fooClasspath", Some(r.id), Some(CommandSource(callback.name))))
jsonRpcRespond("concurrent response", Some(r.id))
()
case r if r.method == "foo/resultAndError" =>
appendExec(Exec("fooCustomFail", Some(r.id), Some(CommandSource(callback.name))))
jsonRpcRespond("concurrent response", Some(r.id))
()
},
PartialFunction.empty
{
case r if r.method == "foo/customNotification" =>
jsonRpcRespond("notification result", None)
()
}
)
})
lazy val fooClasspath = taskKey[Unit]("")
lazy val root = (project in file("."))
.settings(
name := "response",
@ -55,5 +68,5 @@ lazy val root = (project in file("."))
val s = state.value
val cp = (Compile / fullClasspath).value
s.respondEvent(cp.map(_.data))
},
}
)

View File

@ -22,15 +22,6 @@ object EventsTest extends AbstractServerTest {
})
}
test("report task failures in case of exceptions") { _ =>
svr.sendJsonRpc(
"""{ "jsonrpc": "2.0", "id": 11, "method": "sbt/exec", "params": { "commandLine": "hello" } }"""
)
assert(svr.waitForString(10.seconds) { s =>
(s contains """"id":11""") && (s contains """"error":""")
})
}
test("return error if cancelling non-matched task id") { _ =>
svr.sendJsonRpc(
"""{ "jsonrpc": "2.0", "id":12, "method": "sbt/exec", "params": { "commandLine": "run" } }"""

View File

@ -64,4 +64,54 @@ object ResponseTest extends AbstractServerTest {
(s contains """{"jsonrpc":"2.0","method":"foo/something","params":"something"}""")
})
}
test("respond concurrently from a task and the handler") { _ =>
svr.sendJsonRpc(
"""{ "jsonrpc": "2.0", "id": "15", "method": "foo/respondTwice", "params": {} }"""
)
assert {
svr.waitForString(1.seconds) { s =>
println(s)
s contains "\"id\":\"15\""
}
}
assert {
// the second response should never be sent
svr.neverReceive(500.milliseconds) { s =>
println(s)
s contains "\"id\":\"15\""
}
}
}
test("concurrent result and error") { _ =>
svr.sendJsonRpc(
"""{ "jsonrpc": "2.0", "id": "16", "method": "foo/resultAndError", "params": {} }"""
)
assert {
svr.waitForString(1.seconds) { s =>
println(s)
s contains "\"id\":\"16\""
}
}
assert {
// the second response (result or error) should never be sent
svr.neverReceive(500.milliseconds) { s =>
println(s)
s contains "\"id\":\"16\""
}
}
}
test("response to a notification should not be sent") { _ =>
svr.sendJsonRpc(
"""{ "jsonrpc": "2.0", "method": "foo/customNotification", "params": {} }"""
)
assert {
svr.neverReceive(500.milliseconds) { s =>
println(s)
s contains "\"result\":\"notification result\""
}
}
}
}

View File

@ -8,13 +8,14 @@
package testpkg
import java.io.{ File, IOException }
import java.util.concurrent.TimeoutException
import verify._
import sbt.RunFromSourceMain
import sbt.io.IO
import sbt.io.syntax._
import sbt.protocol.ClientSocket
import scala.annotation.tailrec
import scala.concurrent._
import scala.concurrent.duration._
import scala.util.{ Success, Try }
@ -150,6 +151,7 @@ case class TestServer(
sbtVersion: String,
classpath: Seq[File]
) {
import scala.concurrent.ExecutionContext.Implicits._
import TestServer.hostLog
val readBuffer = new Array[Byte](40960)
@ -183,15 +185,25 @@ case class TestServer(
waitForPortfile(90.seconds)
// make connection to the socket described in the portfile
val (sk, tkn) = ClientSocket.socket(portfile)
val out = sk.getOutputStream
val in = sk.getInputStream
var (sk, _) = ClientSocket.socket(portfile)
var out = sk.getOutputStream
var in = sk.getInputStream
// initiate handshake
sendJsonRpc(
"""{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }"""
)
def resetConnection() = {
sk = ClientSocket.socket(portfile)._1
out = sk.getOutputStream
in = sk.getInputStream
sendJsonRpc(
"""{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }"""
)
}
def test(f: TestServer => Future[Assertion]): Future[Assertion] = {
f(this)
}
@ -230,7 +242,7 @@ case class TestServer(
writeEndLine
}
def readFrame: Option[String] = {
def readFrame: Future[Option[String]] = Future {
def getContentLength: Int = {
readLine map { line =>
line.drop(16).toInt
@ -244,14 +256,28 @@ case class TestServer(
final def waitForString(duration: FiniteDuration)(f: String => Boolean): Boolean = {
val deadline = duration.fromNow
@tailrec
def impl(): Boolean = {
if (deadline.isOverdue || !process.isAlive) false
else
readFrame.fold(false)(f) || {
Thread.sleep(100)
impl
}
try {
Await.result(readFrame, deadline.timeLeft).fold(false)(f) || impl
} catch {
case _: TimeoutException =>
resetConnection() // create a new connection to invalidate the running readFrame future
false
}
}
impl()
}
final def neverReceive(duration: FiniteDuration)(f: String => Boolean): Boolean = {
val deadline = duration.fromNow
def impl(): Boolean = {
try {
Await.result(readFrame, deadline.timeLeft).fold(true)(s => !f(s)) && impl
} catch {
case _: TimeoutException =>
resetConnection() // create a new connection to invalidate the running readFrame future
true
}
}
impl()
}