mirror of https://github.com/sbt/sbt.git
Merge pull request #5549 from adpi2/issue/json-response
Prevent more than one response per json RPC request
This commit is contained in:
commit
4592493617
|
|
@ -8,3 +8,6 @@ npm-debug.log
|
|||
!sbt/src/server-test/completions/target
|
||||
.big
|
||||
.idea
|
||||
.bloop
|
||||
.metals
|
||||
metals.sbt
|
||||
|
|
|
|||
|
|
@ -677,7 +677,8 @@ lazy val protocolProj = (project in file("protocol"))
|
|||
exclude[DirectMissingMethodProblem]("sbt.protocol.SettingQueryFailure.copy$default$*"),
|
||||
exclude[DirectMissingMethodProblem]("sbt.protocol.SettingQuerySuccess.copy"),
|
||||
exclude[DirectMissingMethodProblem]("sbt.protocol.SettingQuerySuccess.copy$default$*"),
|
||||
// ignore missing methods in sbt.internal
|
||||
// ignore missing or incompatible methods in sbt.internal
|
||||
exclude[IncompatibleMethTypeProblem]("sbt.internal.*"),
|
||||
exclude[DirectMissingMethodProblem]("sbt.internal.*"),
|
||||
exclude[MissingTypesProblem]("sbt.internal.protocol.JsonRpcResponseError"),
|
||||
)
|
||||
|
|
@ -876,7 +877,7 @@ lazy val mainProj = (project in file("main"))
|
|||
// New and changed methods on KeyIndex. internal.
|
||||
exclude[ReversedMissingMethodProblem]("sbt.internal.KeyIndex.*"),
|
||||
// internal
|
||||
exclude[IncompatibleMethTypeProblem]("sbt.internal.server.LanguageServerReporter.*"),
|
||||
exclude[IncompatibleMethTypeProblem]("sbt.internal.*"),
|
||||
// Changed signature or removed private[sbt] methods
|
||||
exclude[DirectMissingMethodProblem]("sbt.Classpaths.unmanagedLibs0"),
|
||||
exclude[DirectMissingMethodProblem]("sbt.Defaults.allTestGroupsTask"),
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ package internal
|
|||
import java.util.concurrent.ConcurrentLinkedQueue
|
||||
|
||||
import sbt.protocol.EventMessage
|
||||
import sjsonnew.JsonFormat
|
||||
|
||||
/**
|
||||
* A command channel represents an IO device such as network socket or human
|
||||
|
|
@ -42,9 +41,6 @@ abstract class CommandChannel {
|
|||
}
|
||||
def poll: Option[Exec] = Option(commandQueue.poll)
|
||||
|
||||
def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit
|
||||
final def publishEvent[A: JsonFormat](event: A): Unit = publishEvent(event, None)
|
||||
def publishEventMessage(event: EventMessage): Unit
|
||||
def publishBytes(bytes: Array[Byte]): Unit
|
||||
def shutdown(): Unit
|
||||
def name: String
|
||||
|
|
|
|||
|
|
@ -14,8 +14,6 @@ import java.util.concurrent.atomic.AtomicReference
|
|||
|
||||
import sbt.BasicKeys._
|
||||
import sbt.internal.util._
|
||||
import sbt.protocol.EventMessage
|
||||
import sjsonnew.JsonFormat
|
||||
|
||||
private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel {
|
||||
private[this] val askUserThread = new AtomicReference[AskUserThread]
|
||||
|
|
@ -62,21 +60,16 @@ private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel
|
|||
|
||||
def publishBytes(bytes: Array[Byte]): Unit = ()
|
||||
|
||||
def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit = ()
|
||||
|
||||
def publishEventMessage(event: EventMessage): Unit =
|
||||
event match {
|
||||
case e: ConsolePromptEvent =>
|
||||
if (Terminal.systemInIsAttached) {
|
||||
askUserThread.synchronized {
|
||||
askUserThread.get match {
|
||||
case null => askUserThread.set(makeAskUserThread(e.state))
|
||||
case t => t.redraw()
|
||||
}
|
||||
}
|
||||
def prompt(event: ConsolePromptEvent): Unit = {
|
||||
if (Terminal.systemInIsAttached) {
|
||||
askUserThread.synchronized {
|
||||
askUserThread.get match {
|
||||
case null => askUserThread.set(makeAskUserThread(event.state))
|
||||
case t => t.redraw()
|
||||
}
|
||||
case _ => //
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def shutdown(): Unit = askUserThread.synchronized {
|
||||
askUserThread.get match {
|
||||
|
|
|
|||
|
|
@ -119,7 +119,7 @@ class NetworkClient(configuration: xsbti.AppConfiguration, arguments: List[Strin
|
|||
}
|
||||
|
||||
def onResponse(msg: JsonRpcResponseMessage): Unit = {
|
||||
msg.id foreach {
|
||||
msg.id match {
|
||||
case execId if pendingExecIds contains execId =>
|
||||
onReturningReponse(msg)
|
||||
lock.synchronized {
|
||||
|
|
|
|||
|
|
@ -890,7 +890,7 @@ object BuiltinCommands {
|
|||
val exchange = StandardMain.exchange
|
||||
val welcomeState = displayWelcomeBanner(s0)
|
||||
val s1 = exchange run welcomeState
|
||||
exchange publishEventMessage ConsolePromptEvent(s0)
|
||||
exchange prompt ConsolePromptEvent(s0)
|
||||
val minGCInterval = Project
|
||||
.extract(s1)
|
||||
.getOpt(Keys.minForcegcInterval)
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ object MainLoop {
|
|||
/** This is the main function State transfer function of the sbt command processing. */
|
||||
def processCommand(exec: Exec, state: State): State = {
|
||||
val channelName = exec.source map (_.channelName)
|
||||
StandardMain.exchange publishEventMessage
|
||||
StandardMain.exchange notifyStatus
|
||||
ExecStatusEvent("Processing", channelName, exec.execId, Vector())
|
||||
try {
|
||||
def process(): State = {
|
||||
|
|
@ -197,12 +197,7 @@ object MainLoop {
|
|||
newState.remainingCommands.toVector map (_.commandLine),
|
||||
exitCode(newState, state),
|
||||
)
|
||||
if (doneEvent.execId.isDefined) { // send back a response or error
|
||||
import sbt.protocol.codec.JsonProtocol._
|
||||
StandardMain.exchange publishEvent doneEvent
|
||||
} else { // send back a notification
|
||||
StandardMain.exchange publishEventMessage doneEvent
|
||||
}
|
||||
StandardMain.exchange.respondStatus(doneEvent)
|
||||
newState.get(sbt.Keys.currentTaskProgress).foreach(_.progress.stop())
|
||||
newState.remove(sbt.Keys.currentTaskProgress)
|
||||
}
|
||||
|
|
@ -225,8 +220,7 @@ object MainLoop {
|
|||
ExitCode(ErrorCodes.UnknownError),
|
||||
Option(err.getMessage),
|
||||
)
|
||||
import sbt.protocol.codec.JsonProtocol._
|
||||
StandardMain.exchange.publishEvent(errorEvent)
|
||||
StandardMain.exchange.respondStatus(errorEvent)
|
||||
throw err
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,16 +16,13 @@ import java.util.concurrent.atomic._
|
|||
import sbt.BasicKeys._
|
||||
import sbt.nio.Watch.NullLogger
|
||||
import sbt.internal.protocol.JsonRpcResponseError
|
||||
import sbt.internal.langserver.{ LogMessageParams, MessageType }
|
||||
import sbt.internal.server._
|
||||
import sbt.internal.util.codec.JValueFormats
|
||||
import sbt.internal.util.{ ConsoleOut, MainAppender, ObjectEvent, StringEvent, Terminal }
|
||||
import sbt.internal.util.{ ConsoleOut, MainAppender, ObjectEvent, Terminal }
|
||||
import sbt.io.syntax._
|
||||
import sbt.io.{ Hash, IO }
|
||||
import sbt.protocol.{ EventMessage, ExecStatusEvent }
|
||||
import sbt.protocol.{ ExecStatusEvent, LogEvent }
|
||||
import sbt.util.{ Level, LogExchange, Logger }
|
||||
import sjsonnew.JsonFormat
|
||||
import sjsonnew.shaded.scalajson.ast.unsafe._
|
||||
|
||||
import scala.annotation.tailrec
|
||||
import scala.collection.mutable.ListBuffer
|
||||
|
|
@ -51,17 +48,12 @@ private[sbt] final class CommandExchange {
|
|||
private val commandChannelQueue = new LinkedBlockingQueue[CommandChannel]
|
||||
private val nextChannelId: AtomicInteger = new AtomicInteger(0)
|
||||
private[this] val activePrompt = new AtomicBoolean(false)
|
||||
private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {}
|
||||
|
||||
def channels: List[CommandChannel] = channelBuffer.toList
|
||||
private[this] def removeChannels(toDel: List[CommandChannel]): Unit = {
|
||||
toDel match {
|
||||
case Nil => // do nothing
|
||||
case xs =>
|
||||
channelBufferLock.synchronized {
|
||||
channelBuffer --= xs
|
||||
()
|
||||
}
|
||||
private[this] def removeChannel(channel: CommandChannel): Unit = {
|
||||
channelBufferLock.synchronized {
|
||||
channelBuffer -= channel
|
||||
()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -206,19 +198,7 @@ private[sbt] final class CommandExchange {
|
|||
execId: Option[String],
|
||||
source: Option[CommandSource]
|
||||
): Unit = {
|
||||
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
|
||||
channels.foreach {
|
||||
case _: ConsoleChannel =>
|
||||
case c: NetworkChannel =>
|
||||
try {
|
||||
// broadcast to all network channels
|
||||
c.respondError(code, message, execId, source)
|
||||
} catch {
|
||||
case _: IOException =>
|
||||
toDel += c
|
||||
}
|
||||
}
|
||||
removeChannels(toDel.toList)
|
||||
respondError(JsonRpcResponseError(code, message), execId, source)
|
||||
}
|
||||
|
||||
private[sbt] def respondError(
|
||||
|
|
@ -226,19 +206,13 @@ private[sbt] final class CommandExchange {
|
|||
execId: Option[String],
|
||||
source: Option[CommandSource]
|
||||
): Unit = {
|
||||
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
|
||||
channels.foreach {
|
||||
case _: ConsoleChannel =>
|
||||
case c: NetworkChannel =>
|
||||
try {
|
||||
// broadcast to all network channels
|
||||
c.respondError(err, execId, source)
|
||||
} catch {
|
||||
case _: IOException =>
|
||||
toDel += c
|
||||
}
|
||||
}
|
||||
removeChannels(toDel.toList)
|
||||
for {
|
||||
source <- source.map(_.channelName)
|
||||
channel <- channels.collectFirst {
|
||||
// broadcast to the source channel only
|
||||
case c: NetworkChannel if c.name == source => c
|
||||
}
|
||||
} tryTo(_.respondError(err, execId))(channel)
|
||||
}
|
||||
|
||||
// This is an interface to directly respond events.
|
||||
|
|
@ -247,146 +221,89 @@ private[sbt] final class CommandExchange {
|
|||
execId: Option[String],
|
||||
source: Option[CommandSource]
|
||||
): Unit = {
|
||||
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
|
||||
channels.foreach {
|
||||
case _: ConsoleChannel =>
|
||||
case c: NetworkChannel =>
|
||||
try {
|
||||
// broadcast to all network channels
|
||||
c.respondEvent(event, execId, source)
|
||||
} catch {
|
||||
case _: IOException =>
|
||||
toDel += c
|
||||
}
|
||||
}
|
||||
removeChannels(toDel.toList)
|
||||
for {
|
||||
source <- source.map(_.channelName)
|
||||
channel <- channels.collectFirst {
|
||||
// broadcast to the source channel only
|
||||
case c: NetworkChannel if c.name == source => c
|
||||
}
|
||||
} tryTo(_.respondResult(event, execId))(channel)
|
||||
}
|
||||
|
||||
// This is an interface to directly notify events.
|
||||
private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = {
|
||||
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
|
||||
channels.foreach {
|
||||
case _: ConsoleChannel =>
|
||||
// c.publishEvent(event)
|
||||
case c: NetworkChannel =>
|
||||
try {
|
||||
c.notifyEvent(method, params)
|
||||
} catch {
|
||||
case _: IOException =>
|
||||
toDel += c
|
||||
}
|
||||
}
|
||||
removeChannels(toDel.toList)
|
||||
channels
|
||||
.collect { case c: NetworkChannel => c }
|
||||
.foreach {
|
||||
tryTo(_.notifyEvent(method, params))
|
||||
}
|
||||
}
|
||||
|
||||
private def tryTo(x: => Unit, c: CommandChannel, toDel: ListBuffer[CommandChannel]): Unit =
|
||||
try x
|
||||
catch { case _: IOException => toDel += c }
|
||||
private def tryTo(f: NetworkChannel => Unit)(
|
||||
channel: NetworkChannel
|
||||
): Unit =
|
||||
try f(channel)
|
||||
catch { case _: IOException => removeChannel(channel) }
|
||||
|
||||
def publishEvent[A: JsonFormat](event: A): Unit = {
|
||||
val broadcastStringMessage = true
|
||||
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
|
||||
def respondStatus(event: ExecStatusEvent): Unit = {
|
||||
import sbt.protocol.codec.JsonProtocol._
|
||||
for {
|
||||
source <- event.channelName
|
||||
channel <- channels.collectFirst {
|
||||
case c: NetworkChannel if c.name == source => c
|
||||
}
|
||||
} {
|
||||
if (event.execId.isEmpty) {
|
||||
tryTo(_.notifyEvent(event))(channel)
|
||||
} else {
|
||||
event.exitCode match {
|
||||
case None | Some(0) =>
|
||||
tryTo(_.respondResult(event, event.execId))(channel)
|
||||
case Some(code) =>
|
||||
tryTo(_.respondError(code, event.message.getOrElse(""), event.execId))(channel)
|
||||
}
|
||||
}
|
||||
|
||||
event match {
|
||||
case entry: StringEvent =>
|
||||
val params = toLogMessageParams(entry)
|
||||
channels collect {
|
||||
case c: ConsoleChannel =>
|
||||
if (broadcastStringMessage || (entry.channelName forall (_ == c.name)))
|
||||
c.publishEvent(event)
|
||||
case c: NetworkChannel =>
|
||||
tryTo(
|
||||
{
|
||||
// Note that language server's LogMessageParams does not hold the execid,
|
||||
// so this is weaker than the StringMessage. We might want to double-send
|
||||
// in case we have a better client that can utilize the knowledge.
|
||||
import sbt.internal.langserver.codec.JsonProtocol._
|
||||
if (broadcastStringMessage || (entry.channelName contains c.name))
|
||||
c.jsonRpcNotify("window/logMessage", params)
|
||||
},
|
||||
c,
|
||||
toDel
|
||||
)
|
||||
}
|
||||
case entry: ExecStatusEvent =>
|
||||
channels collect {
|
||||
case c: ConsoleChannel =>
|
||||
if (entry.channelName forall (_ == c.name)) c.publishEvent(event)
|
||||
case c: NetworkChannel =>
|
||||
if (entry.channelName contains c.name) tryTo(c.publishEvent(event), c, toDel)
|
||||
}
|
||||
case _ =>
|
||||
channels foreach {
|
||||
case c: ConsoleChannel => c.publishEvent(event)
|
||||
case c: NetworkChannel =>
|
||||
tryTo(c.publishEvent(event), c, toDel)
|
||||
}
|
||||
tryTo(_.respond(event, event.execId))(channel)
|
||||
}
|
||||
removeChannels(toDel.toList)
|
||||
}
|
||||
|
||||
private[sbt] def toLogMessageParams(event: StringEvent): LogMessageParams = {
|
||||
LogMessageParams(MessageType.fromLevelString(event.level), event.message)
|
||||
}
|
||||
|
||||
/**
|
||||
* This publishes object events. The type information has been
|
||||
* erased because it went through logging.
|
||||
*/
|
||||
private[sbt] def publishObjectEvent(event: ObjectEvent[_]): Unit = {
|
||||
import jsonFormat._
|
||||
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
|
||||
def json: JValue = JObject(
|
||||
JField("type", JString(event.contentType)),
|
||||
Vector(JField("message", event.json), JField("level", JString(event.level.toString))) ++
|
||||
(event.channelName.toVector map { channelName =>
|
||||
JField("channelName", JString(channelName))
|
||||
}) ++
|
||||
(event.execId.toVector map { execId =>
|
||||
JField("execId", JString(execId))
|
||||
}): _*
|
||||
)
|
||||
channels collect {
|
||||
case c: ConsoleChannel =>
|
||||
c.publishEvent(json)
|
||||
case c: NetworkChannel =>
|
||||
try {
|
||||
c.publishObjectEvent(event)
|
||||
} catch {
|
||||
case _: IOException =>
|
||||
toDel += c
|
||||
}
|
||||
}
|
||||
removeChannels(toDel.toList)
|
||||
private[sbt] def respondObjectEvent(event: ObjectEvent[_]): Unit = {
|
||||
for {
|
||||
source <- event.channelName
|
||||
channel <- channels.collectFirst {
|
||||
case c: NetworkChannel if c.name == source => c
|
||||
}
|
||||
} tryTo(_.respond(event))(channel)
|
||||
}
|
||||
|
||||
// fanout publishEvent
|
||||
def publishEventMessage(event: EventMessage): Unit = {
|
||||
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
|
||||
|
||||
event match {
|
||||
// Special treatment for ConsolePromptEvent since it's hand coded without codec.
|
||||
case entry: ConsolePromptEvent =>
|
||||
channels collect {
|
||||
case c: ConsoleChannel =>
|
||||
c.publishEventMessage(entry)
|
||||
activePrompt.set(Terminal.systemInIsAttached)
|
||||
}
|
||||
case entry: ExecStatusEvent =>
|
||||
channels collect {
|
||||
case c: ConsoleChannel =>
|
||||
if (entry.channelName forall (_ == c.name)) c.publishEventMessage(event)
|
||||
case c: NetworkChannel =>
|
||||
if (entry.channelName contains c.name) tryTo(c.publishEventMessage(event), c, toDel)
|
||||
}
|
||||
case _ =>
|
||||
channels collect {
|
||||
case c: ConsoleChannel => c.publishEventMessage(event)
|
||||
case c: NetworkChannel => tryTo(c.publishEventMessage(event), c, toDel)
|
||||
}
|
||||
}
|
||||
|
||||
removeChannels(toDel.toList)
|
||||
def prompt(event: ConsolePromptEvent): Unit = {
|
||||
activePrompt.set(Terminal.systemInIsAttached)
|
||||
channels
|
||||
.collect { case c: ConsoleChannel => c }
|
||||
.foreach { _.prompt(event) }
|
||||
}
|
||||
|
||||
def logMessage(event: LogEvent): Unit = {
|
||||
channels
|
||||
.collect { case c: NetworkChannel => c }
|
||||
.foreach {
|
||||
tryTo(_.notifyEvent(event))
|
||||
}
|
||||
}
|
||||
|
||||
def notifyStatus(event: ExecStatusEvent): Unit = {
|
||||
for {
|
||||
source <- event.channelName
|
||||
channel <- channels.collectFirst {
|
||||
case c: NetworkChannel if c.name == source => c
|
||||
}
|
||||
} tryTo(_.notifyEvent(event))(channel)
|
||||
}
|
||||
|
||||
private[this] def needToFinishPromptLine(): Boolean = activePrompt.compareAndSet(true, false)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ import org.apache.logging.log4j.core.config.Property
|
|||
import sbt.util.Level
|
||||
import sbt.internal.util._
|
||||
import sbt.protocol.LogEvent
|
||||
import sbt.internal.util.codec._
|
||||
|
||||
class RelayAppender(name: String)
|
||||
extends AbstractAppender(
|
||||
|
|
@ -40,15 +39,12 @@ class RelayAppender(name: String)
|
|||
}
|
||||
}
|
||||
def appendLog(level: Level.Value, message: => String): Unit = {
|
||||
exchange.publishEventMessage(LogEvent(level.toString, message))
|
||||
exchange.logMessage(LogEvent(level.toString, message))
|
||||
}
|
||||
def appendEvent(event: AnyRef): Unit =
|
||||
event match {
|
||||
case x: StringEvent => {
|
||||
import JsonProtocol._
|
||||
exchange.publishEvent(x: AbstractEntry)
|
||||
}
|
||||
case x: ObjectEvent[_] => exchange.publishObjectEvent(x)
|
||||
case x: StringEvent => exchange.logMessage(LogEvent(x.message, x.level))
|
||||
case x: ObjectEvent[_] => exchange.respondObjectEvent(x)
|
||||
case _ =>
|
||||
println(s"appendEvent: ${event.getClass}")
|
||||
()
|
||||
|
|
|
|||
|
|
@ -40,10 +40,10 @@ private[sbt] object Definition {
|
|||
def send[A: JsonFormat](source: CommandSource, execId: String)(params: A): Unit = {
|
||||
for {
|
||||
channel <- StandardMain.exchange.channels.collectFirst {
|
||||
case c if c.name == source.channelName => c
|
||||
case c: NetworkChannel if c.name == source.channelName => c
|
||||
}
|
||||
} {
|
||||
channel.publishEvent(params, Option(execId))
|
||||
channel.respond(params, Option(execId))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ package internal
|
|||
package server
|
||||
|
||||
import sjsonnew.JsonFormat
|
||||
import sjsonnew.shaded.scalajson.ast.unsafe.JValue
|
||||
import sjsonnew.support.scalajson.unsafe.Converter
|
||||
import sbt.protocol.Serialization
|
||||
import sbt.protocol.{ CompletionParams => CP, SettingQuery => Q }
|
||||
|
|
@ -103,7 +102,7 @@ private[sbt] object LanguageServerProtocol {
|
|||
}
|
||||
|
||||
/** Implements Language Server Protocol <https://github.com/Microsoft/language-server-protocol>. */
|
||||
private[sbt] trait LanguageServerProtocol extends CommandChannel { self =>
|
||||
private[sbt] trait LanguageServerProtocol { self: NetworkChannel =>
|
||||
|
||||
lazy val internalJsonProtocol = new InitializeOptionFormats with sjsonnew.BasicJsonProtocol {}
|
||||
|
||||
|
|
@ -117,10 +116,10 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self =>
|
|||
|
||||
protected lazy val callbackImpl: ServerCallback = new ServerCallback {
|
||||
def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit =
|
||||
self.jsonRpcRespond(event, execId)
|
||||
self.respondResult(event, execId)
|
||||
|
||||
def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit =
|
||||
self.jsonRpcRespondError(execId, code, message)
|
||||
self.respondError(code, message, execId)
|
||||
|
||||
def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit =
|
||||
self.jsonRpcNotify(method, params)
|
||||
|
|
@ -162,28 +161,16 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self =>
|
|||
}
|
||||
|
||||
/** Respond back to Language Server's client. */
|
||||
private[sbt] def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit = {
|
||||
val m =
|
||||
private[sbt] def jsonRpcRespond[A: JsonFormat](event: A, execId: String): Unit = {
|
||||
val response =
|
||||
JsonRpcResponseMessage("2.0", execId, Option(Converter.toJson[A](event).get), None)
|
||||
val bytes = Serialization.serializeResponseMessage(m)
|
||||
val bytes = Serialization.serializeResponseMessage(response)
|
||||
publishBytes(bytes)
|
||||
}
|
||||
|
||||
/** Respond back to Language Server's client. */
|
||||
private[sbt] def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit =
|
||||
jsonRpcRespondErrorImpl(execId, code, message, None)
|
||||
|
||||
/** Respond back to Language Server's client. */
|
||||
private[sbt] def jsonRpcRespondError[A: JsonFormat](
|
||||
execId: Option[String],
|
||||
code: Long,
|
||||
message: String,
|
||||
data: A,
|
||||
): Unit =
|
||||
jsonRpcRespondErrorImpl(execId, code, message, Option(Converter.toJson[A](data).get))
|
||||
|
||||
private[sbt] def jsonRpcRespondError(
|
||||
execId: Option[String],
|
||||
execId: String,
|
||||
err: JsonRpcResponseError
|
||||
): Unit = {
|
||||
val m = JsonRpcResponseMessage("2.0", execId, None, Option(err))
|
||||
|
|
@ -191,18 +178,6 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self =>
|
|||
publishBytes(bytes)
|
||||
}
|
||||
|
||||
private[this] def jsonRpcRespondErrorImpl(
|
||||
execId: Option[String],
|
||||
code: Long,
|
||||
message: String,
|
||||
data: Option[JValue],
|
||||
): Unit = {
|
||||
val e = JsonRpcResponseError(code, message, data)
|
||||
val m = JsonRpcResponseMessage("2.0", execId, None, Option(e))
|
||||
val bytes = Serialization.serializeResponseMessage(m)
|
||||
publishBytes(bytes)
|
||||
}
|
||||
|
||||
/** Notify to Language Server's client. */
|
||||
private[sbt] def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit = {
|
||||
val m =
|
||||
|
|
|
|||
|
|
@ -12,19 +12,22 @@ package server
|
|||
import java.net.{ Socket, SocketTimeoutException }
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
||||
import sjsonnew._
|
||||
import scala.annotation.tailrec
|
||||
import sbt.protocol._
|
||||
import sbt.internal.langserver.{ ErrorCodes, CancelRequestParams }
|
||||
import sbt.internal.util.{ ObjectEvent, StringEvent }
|
||||
import sbt.internal.util.complete.Parser
|
||||
import sbt.internal.util.codec.JValueFormats
|
||||
import sbt.internal.langserver.{ CancelRequestParams, ErrorCodes }
|
||||
import sbt.internal.protocol.{
|
||||
JsonRpcResponseError,
|
||||
JsonRpcNotificationMessage,
|
||||
JsonRpcRequestMessage,
|
||||
JsonRpcNotificationMessage
|
||||
JsonRpcResponseError
|
||||
}
|
||||
import sbt.internal.util.codec.JValueFormats
|
||||
import sbt.internal.util.complete.Parser
|
||||
import sbt.internal.util.ObjectEvent
|
||||
import sbt.protocol._
|
||||
import sbt.util.Logger
|
||||
import sjsonnew._
|
||||
import sjsonnew.support.scalajson.unsafe.{ CompactPrinter, Converter }
|
||||
|
||||
import scala.annotation.tailrec
|
||||
import scala.collection.mutable
|
||||
import scala.util.Try
|
||||
import scala.util.control.NonFatal
|
||||
|
||||
|
|
@ -53,6 +56,7 @@ final class NetworkChannel(
|
|||
private val VsCode = sbt.protocol.Serialization.VsCode
|
||||
private val VsCodeOld = "application/vscode-jsonrpc; charset=utf8"
|
||||
private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {}
|
||||
private val pendingRequests: mutable.Map[String, JsonRpcRequestMessage] = mutable.Map()
|
||||
|
||||
def setContentType(ct: String): Unit = synchronized { _contentType = ct }
|
||||
def contentType: String = _contentType
|
||||
|
|
@ -176,6 +180,7 @@ final class NetworkChannel(
|
|||
intents.foldLeft(PartialFunction.empty[JsonRpcRequestMessage, Unit]) {
|
||||
case (f, i) => f orElse i.onRequest
|
||||
}
|
||||
|
||||
lazy val onNotification: PartialFunction[JsonRpcNotificationMessage, Unit] =
|
||||
intents.foldLeft(PartialFunction.empty[JsonRpcNotificationMessage, Unit]) {
|
||||
case (f, i) => f orElse i.onNotification
|
||||
|
|
@ -186,25 +191,27 @@ final class NetworkChannel(
|
|||
Serialization.deserializeJsonMessage(chunk) match {
|
||||
case Right(req: JsonRpcRequestMessage) =>
|
||||
try {
|
||||
registerRequest(req)
|
||||
onRequestMessage(req)
|
||||
} catch {
|
||||
case LangServerError(code, message) =>
|
||||
log.debug(s"sending error: $code: $message")
|
||||
jsonRpcRespondError(Option(req.id), code, message)
|
||||
respondError(code, message, Some(req.id))
|
||||
}
|
||||
case Right(ntf: JsonRpcNotificationMessage) =>
|
||||
try {
|
||||
onNotification(ntf)
|
||||
} catch {
|
||||
case LangServerError(code, message) =>
|
||||
log.debug(s"sending error: $code: $message")
|
||||
jsonRpcRespondError(None, code, message) // new id?
|
||||
logMessage("error", s"Error $code while handling notification: $message")
|
||||
}
|
||||
case Right(msg) =>
|
||||
log.debug(s"Unhandled message: $msg")
|
||||
case Left(errorDesc) =>
|
||||
val msg = s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): " + errorDesc
|
||||
jsonRpcRespondError(None, ErrorCodes.ParseError, msg)
|
||||
logMessage(
|
||||
"error",
|
||||
s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): $errorDesc"
|
||||
)
|
||||
}
|
||||
} else {
|
||||
contentType match {
|
||||
|
|
@ -213,13 +220,17 @@ final class NetworkChannel(
|
|||
.deserializeCommand(chunk)
|
||||
.fold(
|
||||
errorDesc =>
|
||||
log.error(
|
||||
logMessage(
|
||||
"error",
|
||||
s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): " + errorDesc
|
||||
),
|
||||
onCommand
|
||||
)
|
||||
case _ =>
|
||||
log.error(s"Unknown Content-Type: $contentType")
|
||||
logMessage(
|
||||
"error",
|
||||
s"Unknown Content-Type: $contentType"
|
||||
)
|
||||
}
|
||||
} // if-else
|
||||
}
|
||||
|
|
@ -245,24 +256,48 @@ final class NetworkChannel(
|
|||
}
|
||||
}
|
||||
|
||||
private def registerRequest(request: JsonRpcRequestMessage): Unit = {
|
||||
this.synchronized {
|
||||
pendingRequests += (request.id -> request)
|
||||
()
|
||||
}
|
||||
}
|
||||
|
||||
private[sbt] def respondError(
|
||||
err: JsonRpcResponseError,
|
||||
execId: Option[String],
|
||||
source: Option[CommandSource]
|
||||
): Unit = jsonRpcRespondError(execId, err)
|
||||
execId: Option[String]
|
||||
): Unit = this.synchronized {
|
||||
execId match {
|
||||
case Some(id) if pendingRequests.contains(id) =>
|
||||
pendingRequests -= id
|
||||
jsonRpcRespondError(id, err)
|
||||
case _ =>
|
||||
logMessage("error", s"Error ${err.code}: ${err.message}")
|
||||
}
|
||||
}
|
||||
|
||||
private[sbt] def respondError(
|
||||
code: Long,
|
||||
message: String,
|
||||
execId: Option[String],
|
||||
source: Option[CommandSource]
|
||||
): Unit = jsonRpcRespondError(execId, code, message)
|
||||
execId: Option[String]
|
||||
): Unit = {
|
||||
respondError(JsonRpcResponseError(code, message), execId)
|
||||
}
|
||||
|
||||
private[sbt] def respondEvent[A: JsonFormat](
|
||||
private[sbt] def respondResult[A: JsonFormat](
|
||||
event: A,
|
||||
execId: Option[String],
|
||||
source: Option[CommandSource]
|
||||
): Unit = jsonRpcRespond(event, execId)
|
||||
execId: Option[String]
|
||||
): Unit = this.synchronized {
|
||||
execId match {
|
||||
case Some(id) if pendingRequests.contains(id) =>
|
||||
pendingRequests -= id
|
||||
jsonRpcRespond(event, id)
|
||||
case _ =>
|
||||
log.debug(
|
||||
s"unmatched json response for requestId $execId: ${CompactPrinter(Converter.toJsonUnsafe(event))}"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = {
|
||||
if (isLanguageServerProtocol) {
|
||||
|
|
@ -272,19 +307,11 @@ final class NetworkChannel(
|
|||
}
|
||||
}
|
||||
|
||||
def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit = {
|
||||
def respond[A: JsonFormat](event: A): Unit = respond(event, None)
|
||||
|
||||
def respond[A: JsonFormat](event: A, execId: Option[String]): Unit = {
|
||||
if (isLanguageServerProtocol) {
|
||||
event match {
|
||||
case entry: StringEvent => logMessage(entry.level, entry.message)
|
||||
case entry: ExecStatusEvent =>
|
||||
entry.exitCode match {
|
||||
case None => jsonRpcRespond(event, entry.execId)
|
||||
case Some(0) => jsonRpcRespond(event, entry.execId)
|
||||
case Some(exitCode) =>
|
||||
jsonRpcRespondError(entry.execId, exitCode, entry.message.getOrElse(""))
|
||||
}
|
||||
case _ => jsonRpcRespond(event, execId)
|
||||
}
|
||||
respondResult(event, execId)
|
||||
} else {
|
||||
contentType match {
|
||||
case SbtX1Protocol =>
|
||||
|
|
@ -295,7 +322,7 @@ final class NetworkChannel(
|
|||
}
|
||||
}
|
||||
|
||||
def publishEventMessage(event: EventMessage): Unit = {
|
||||
def notifyEvent(event: EventMessage): Unit = {
|
||||
if (isLanguageServerProtocol) {
|
||||
event match {
|
||||
case entry: LogEvent => logMessage(entry.level, entry.message)
|
||||
|
|
@ -316,22 +343,22 @@ final class NetworkChannel(
|
|||
* This publishes object events. The type information has been
|
||||
* erased because it went through logging.
|
||||
*/
|
||||
private[sbt] def publishObjectEvent(event: ObjectEvent[_]): Unit = {
|
||||
private[sbt] def respond(event: ObjectEvent[_]): Unit = {
|
||||
import sjsonnew.shaded.scalajson.ast.unsafe._
|
||||
if (isLanguageServerProtocol) onObjectEvent(event)
|
||||
else {
|
||||
import jsonFormat._
|
||||
val json: JValue = JObject(
|
||||
JField("type", JString(event.contentType)),
|
||||
(Vector(JField("message", event.json), JField("level", JString(event.level.toString))) ++
|
||||
(event.channelName.toVector map { channelName =>
|
||||
Seq(JField("message", event.json), JField("level", JString(event.level.toString))) ++
|
||||
(event.channelName map { channelName =>
|
||||
JField("channelName", JString(channelName))
|
||||
}) ++
|
||||
(event.execId.toVector map { execId =>
|
||||
(event.execId map { execId =>
|
||||
JField("execId", JString(execId))
|
||||
})): _*
|
||||
}): _*
|
||||
)
|
||||
publishEvent(json)
|
||||
respond(json, event.execId)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -358,7 +385,7 @@ final class NetworkChannel(
|
|||
authenticate(x) match {
|
||||
case true =>
|
||||
initialized = true
|
||||
publishEventMessage(ChannelAcceptedEvent(name))
|
||||
notifyEvent(ChannelAcceptedEvent(name))
|
||||
case _ => sys.error("invalid token")
|
||||
}
|
||||
case None => sys.error("init command but without token.")
|
||||
|
|
@ -383,8 +410,8 @@ final class NetworkChannel(
|
|||
if (initialized) {
|
||||
import sbt.protocol.codec.JsonProtocol._
|
||||
SettingQuery.handleSettingQueryEither(req, structure) match {
|
||||
case Right(x) => jsonRpcRespond(x, execId)
|
||||
case Left(s) => jsonRpcRespondError(execId, ErrorCodes.InvalidParams, s)
|
||||
case Right(x) => respondResult(x, execId)
|
||||
case Left(s) => respondError(ErrorCodes.InvalidParams, s, execId)
|
||||
}
|
||||
} else {
|
||||
log.warn(s"ignoring query $req before initialization")
|
||||
|
|
@ -400,32 +427,31 @@ final class NetworkChannel(
|
|||
Parser
|
||||
.completions(sstate.combinedParser, cp.query, 9)
|
||||
.get
|
||||
.map(c => {
|
||||
.flatMap { c =>
|
||||
if (!c.isEmpty) Some(c.append.replaceAll("\n", " "))
|
||||
else None
|
||||
})
|
||||
.flatten
|
||||
.map(c => cp.query + c.toString)
|
||||
}
|
||||
.map(c => cp.query + c)
|
||||
import sbt.protocol.codec.JsonProtocol._
|
||||
jsonRpcRespond(
|
||||
respondResult(
|
||||
CompletionResponse(
|
||||
items = completionItems.toVector
|
||||
),
|
||||
execId
|
||||
)
|
||||
case _ =>
|
||||
jsonRpcRespondError(
|
||||
execId,
|
||||
respondError(
|
||||
ErrorCodes.UnknownError,
|
||||
"No available sbt state"
|
||||
"No available sbt state",
|
||||
execId
|
||||
)
|
||||
}
|
||||
} catch {
|
||||
case NonFatal(e) =>
|
||||
jsonRpcRespondError(
|
||||
execId,
|
||||
case NonFatal(_) =>
|
||||
respondError(
|
||||
ErrorCodes.UnknownError,
|
||||
"Completions request failed"
|
||||
"Completions request failed",
|
||||
execId
|
||||
)
|
||||
}
|
||||
} else {
|
||||
|
|
@ -436,10 +462,10 @@ final class NetworkChannel(
|
|||
protected def onCancellationRequest(execId: Option[String], crp: CancelRequestParams) = {
|
||||
if (initialized) {
|
||||
|
||||
def errorRespond(msg: String) = jsonRpcRespondError(
|
||||
execId,
|
||||
def errorRespond(msg: String) = respondError(
|
||||
ErrorCodes.RequestCancelled,
|
||||
msg
|
||||
msg,
|
||||
execId
|
||||
)
|
||||
|
||||
try {
|
||||
|
|
@ -465,11 +491,11 @@ final class NetworkChannel(
|
|||
runningEngine.cancelAndShutdown()
|
||||
|
||||
import sbt.protocol.codec.JsonProtocol._
|
||||
jsonRpcRespond(
|
||||
respondResult(
|
||||
ExecStatusEvent(
|
||||
"Task cancelled",
|
||||
Some(name),
|
||||
Some(runningExecId.toString),
|
||||
Some(runningExecId),
|
||||
Vector(),
|
||||
None,
|
||||
),
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ package sbt.internal.protocol
|
|||
*/
|
||||
final class JsonRpcResponseMessage private (
|
||||
jsonrpc: String,
|
||||
val id: Option[String],
|
||||
val id: String,
|
||||
val result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue],
|
||||
val error: Option[sbt.internal.protocol.JsonRpcResponseError]) extends sbt.internal.protocol.JsonRpcMessage(jsonrpc) with Serializable {
|
||||
|
||||
|
|
@ -28,17 +28,14 @@ final class JsonRpcResponseMessage private (
|
|||
override def toString: String = {
|
||||
s"""JsonRpcResponseMessage($jsonrpc, $id, ${sbt.protocol.Serialization.compactPrintJsonOpt(result)}, $error)"""
|
||||
}
|
||||
private[this] def copy(jsonrpc: String = jsonrpc, id: Option[String] = id, result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue] = result, error: Option[sbt.internal.protocol.JsonRpcResponseError] = error): JsonRpcResponseMessage = {
|
||||
private[this] def copy(jsonrpc: String = jsonrpc, id: String = id, result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue] = result, error: Option[sbt.internal.protocol.JsonRpcResponseError] = error): JsonRpcResponseMessage = {
|
||||
new JsonRpcResponseMessage(jsonrpc, id, result, error)
|
||||
}
|
||||
def withJsonrpc(jsonrpc: String): JsonRpcResponseMessage = {
|
||||
copy(jsonrpc = jsonrpc)
|
||||
}
|
||||
def withId(id: Option[String]): JsonRpcResponseMessage = {
|
||||
copy(id = id)
|
||||
}
|
||||
def withId(id: String): JsonRpcResponseMessage = {
|
||||
copy(id = Option(id))
|
||||
copy(id = id)
|
||||
}
|
||||
def withResult(result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue]): JsonRpcResponseMessage = {
|
||||
copy(result = result)
|
||||
|
|
@ -55,6 +52,6 @@ final class JsonRpcResponseMessage private (
|
|||
}
|
||||
object JsonRpcResponseMessage {
|
||||
|
||||
def apply(jsonrpc: String, id: Option[String], result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue], error: Option[sbt.internal.protocol.JsonRpcResponseError]): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, id, result, error)
|
||||
def apply(jsonrpc: String, id: String, result: sjsonnew.shaded.scalajson.ast.unsafe.JValue, error: sbt.internal.protocol.JsonRpcResponseError): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, Option(id), Option(result), Option(error))
|
||||
def apply(jsonrpc: String, id: String, result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue], error: Option[sbt.internal.protocol.JsonRpcResponseError]): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, id, result, error)
|
||||
def apply(jsonrpc: String, id: String, result: sjsonnew.shaded.scalajson.ast.unsafe.JValue, error: sbt.internal.protocol.JsonRpcResponseError): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, id, Option(result), Option(error))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ type JsonRpcResponseMessage implements JsonRpcMessage
|
|||
jsonrpc: String!
|
||||
|
||||
## The request id.
|
||||
id: String
|
||||
id: String!
|
||||
|
||||
## The result of a request. This can be omitted in
|
||||
## the case of an error.
|
||||
|
|
|
|||
|
|
@ -32,10 +32,10 @@ trait JsonRpcResponseMessageFormats {
|
|||
unbuilder.beginObject(js)
|
||||
val jsonrpc = unbuilder.readField[String]("jsonrpc")
|
||||
val id = try {
|
||||
unbuilder.readField[Option[String]]("id")
|
||||
unbuilder.readField[String]("id")
|
||||
} catch {
|
||||
case _: DeserializationException =>
|
||||
unbuilder.readField[Option[Long]]("id") map { _.toString }
|
||||
unbuilder.readField[Long]("id").toString
|
||||
}
|
||||
|
||||
val result = unbuilder.lookupField("result") map {
|
||||
|
|
@ -77,11 +77,9 @@ trait JsonRpcResponseMessageFormats {
|
|||
}
|
||||
builder.beginObject()
|
||||
builder.addField("jsonrpc", obj.jsonrpc)
|
||||
obj.id foreach { id =>
|
||||
parseId(id) match {
|
||||
case Right(strId) => builder.addField("id", strId)
|
||||
case Left(longId) => builder.addField("id", longId)
|
||||
}
|
||||
parseId(obj.id) match {
|
||||
case Right(strId) => builder.addField("id", strId)
|
||||
case Left(longId) => builder.addField("id", longId)
|
||||
}
|
||||
builder.addField("result", obj.result map parseResult)
|
||||
builder.addField("error", obj.error)
|
||||
|
|
|
|||
|
|
@ -25,12 +25,25 @@ Global / serverHandlers += ServerHandler({ callback =>
|
|||
case r: JsonRpcRequestMessage if r.method == "foo/rootClasspath" =>
|
||||
appendExec(Exec("fooClasspath", Some(r.id), Some(CommandSource(callback.name))))
|
||||
()
|
||||
case r if r.method == "foo/respondTwice" =>
|
||||
appendExec(Exec("fooClasspath", Some(r.id), Some(CommandSource(callback.name))))
|
||||
jsonRpcRespond("concurrent response", Some(r.id))
|
||||
()
|
||||
case r if r.method == "foo/resultAndError" =>
|
||||
appendExec(Exec("fooCustomFail", Some(r.id), Some(CommandSource(callback.name))))
|
||||
jsonRpcRespond("concurrent response", Some(r.id))
|
||||
()
|
||||
},
|
||||
PartialFunction.empty
|
||||
{
|
||||
case r if r.method == "foo/customNotification" =>
|
||||
jsonRpcRespond("notification result", None)
|
||||
()
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
lazy val fooClasspath = taskKey[Unit]("")
|
||||
|
||||
lazy val root = (project in file("."))
|
||||
.settings(
|
||||
name := "response",
|
||||
|
|
@ -55,5 +68,5 @@ lazy val root = (project in file("."))
|
|||
val s = state.value
|
||||
val cp = (Compile / fullClasspath).value
|
||||
s.respondEvent(cp.map(_.data))
|
||||
},
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -22,15 +22,6 @@ object EventsTest extends AbstractServerTest {
|
|||
})
|
||||
}
|
||||
|
||||
test("report task failures in case of exceptions") { _ =>
|
||||
svr.sendJsonRpc(
|
||||
"""{ "jsonrpc": "2.0", "id": 11, "method": "sbt/exec", "params": { "commandLine": "hello" } }"""
|
||||
)
|
||||
assert(svr.waitForString(10.seconds) { s =>
|
||||
(s contains """"id":11""") && (s contains """"error":""")
|
||||
})
|
||||
}
|
||||
|
||||
test("return error if cancelling non-matched task id") { _ =>
|
||||
svr.sendJsonRpc(
|
||||
"""{ "jsonrpc": "2.0", "id":12, "method": "sbt/exec", "params": { "commandLine": "run" } }"""
|
||||
|
|
|
|||
|
|
@ -64,4 +64,54 @@ object ResponseTest extends AbstractServerTest {
|
|||
(s contains """{"jsonrpc":"2.0","method":"foo/something","params":"something"}""")
|
||||
})
|
||||
}
|
||||
|
||||
test("respond concurrently from a task and the handler") { _ =>
|
||||
svr.sendJsonRpc(
|
||||
"""{ "jsonrpc": "2.0", "id": "15", "method": "foo/respondTwice", "params": {} }"""
|
||||
)
|
||||
assert {
|
||||
svr.waitForString(1.seconds) { s =>
|
||||
println(s)
|
||||
s contains "\"id\":\"15\""
|
||||
}
|
||||
}
|
||||
assert {
|
||||
// the second response should never be sent
|
||||
svr.neverReceive(500.milliseconds) { s =>
|
||||
println(s)
|
||||
s contains "\"id\":\"15\""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("concurrent result and error") { _ =>
|
||||
svr.sendJsonRpc(
|
||||
"""{ "jsonrpc": "2.0", "id": "16", "method": "foo/resultAndError", "params": {} }"""
|
||||
)
|
||||
assert {
|
||||
svr.waitForString(1.seconds) { s =>
|
||||
println(s)
|
||||
s contains "\"id\":\"16\""
|
||||
}
|
||||
}
|
||||
assert {
|
||||
// the second response (result or error) should never be sent
|
||||
svr.neverReceive(500.milliseconds) { s =>
|
||||
println(s)
|
||||
s contains "\"id\":\"16\""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("response to a notification should not be sent") { _ =>
|
||||
svr.sendJsonRpc(
|
||||
"""{ "jsonrpc": "2.0", "method": "foo/customNotification", "params": {} }"""
|
||||
)
|
||||
assert {
|
||||
svr.neverReceive(500.milliseconds) { s =>
|
||||
println(s)
|
||||
s contains "\"result\":\"notification result\""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,13 +8,14 @@
|
|||
package testpkg
|
||||
|
||||
import java.io.{ File, IOException }
|
||||
import java.util.concurrent.TimeoutException
|
||||
|
||||
import verify._
|
||||
import sbt.RunFromSourceMain
|
||||
import sbt.io.IO
|
||||
import sbt.io.syntax._
|
||||
import sbt.protocol.ClientSocket
|
||||
import scala.annotation.tailrec
|
||||
|
||||
import scala.concurrent._
|
||||
import scala.concurrent.duration._
|
||||
import scala.util.{ Success, Try }
|
||||
|
|
@ -150,6 +151,7 @@ case class TestServer(
|
|||
sbtVersion: String,
|
||||
classpath: Seq[File]
|
||||
) {
|
||||
import scala.concurrent.ExecutionContext.Implicits._
|
||||
import TestServer.hostLog
|
||||
|
||||
val readBuffer = new Array[Byte](40960)
|
||||
|
|
@ -183,15 +185,25 @@ case class TestServer(
|
|||
waitForPortfile(90.seconds)
|
||||
|
||||
// make connection to the socket described in the portfile
|
||||
val (sk, tkn) = ClientSocket.socket(portfile)
|
||||
val out = sk.getOutputStream
|
||||
val in = sk.getInputStream
|
||||
var (sk, _) = ClientSocket.socket(portfile)
|
||||
var out = sk.getOutputStream
|
||||
var in = sk.getInputStream
|
||||
|
||||
// initiate handshake
|
||||
sendJsonRpc(
|
||||
"""{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }"""
|
||||
)
|
||||
|
||||
def resetConnection() = {
|
||||
sk = ClientSocket.socket(portfile)._1
|
||||
out = sk.getOutputStream
|
||||
in = sk.getInputStream
|
||||
|
||||
sendJsonRpc(
|
||||
"""{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }"""
|
||||
)
|
||||
}
|
||||
|
||||
def test(f: TestServer => Future[Assertion]): Future[Assertion] = {
|
||||
f(this)
|
||||
}
|
||||
|
|
@ -230,7 +242,7 @@ case class TestServer(
|
|||
writeEndLine
|
||||
}
|
||||
|
||||
def readFrame: Option[String] = {
|
||||
def readFrame: Future[Option[String]] = Future {
|
||||
def getContentLength: Int = {
|
||||
readLine map { line =>
|
||||
line.drop(16).toInt
|
||||
|
|
@ -244,14 +256,28 @@ case class TestServer(
|
|||
|
||||
final def waitForString(duration: FiniteDuration)(f: String => Boolean): Boolean = {
|
||||
val deadline = duration.fromNow
|
||||
@tailrec
|
||||
def impl(): Boolean = {
|
||||
if (deadline.isOverdue || !process.isAlive) false
|
||||
else
|
||||
readFrame.fold(false)(f) || {
|
||||
Thread.sleep(100)
|
||||
impl
|
||||
}
|
||||
try {
|
||||
Await.result(readFrame, deadline.timeLeft).fold(false)(f) || impl
|
||||
} catch {
|
||||
case _: TimeoutException =>
|
||||
resetConnection() // create a new connection to invalidate the running readFrame future
|
||||
false
|
||||
}
|
||||
}
|
||||
impl()
|
||||
}
|
||||
|
||||
final def neverReceive(duration: FiniteDuration)(f: String => Boolean): Boolean = {
|
||||
val deadline = duration.fromNow
|
||||
def impl(): Boolean = {
|
||||
try {
|
||||
Await.result(readFrame, deadline.timeLeft).fold(true)(s => !f(s)) && impl
|
||||
} catch {
|
||||
case _: TimeoutException =>
|
||||
resetConnection() // create a new connection to invalidate the running readFrame future
|
||||
true
|
||||
}
|
||||
}
|
||||
impl()
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue