Merge pull request #5549 from adpi2/issue/json-response

Prevent more than one response per json RPC request
This commit is contained in:
eugene yokota 2020-05-12 22:07:11 -04:00 committed by GitHub
commit 4592493617
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 316 additions and 340 deletions

3
.gitignore vendored
View File

@ -8,3 +8,6 @@ npm-debug.log
!sbt/src/server-test/completions/target !sbt/src/server-test/completions/target
.big .big
.idea .idea
.bloop
.metals
metals.sbt

View File

@ -677,7 +677,8 @@ lazy val protocolProj = (project in file("protocol"))
exclude[DirectMissingMethodProblem]("sbt.protocol.SettingQueryFailure.copy$default$*"), exclude[DirectMissingMethodProblem]("sbt.protocol.SettingQueryFailure.copy$default$*"),
exclude[DirectMissingMethodProblem]("sbt.protocol.SettingQuerySuccess.copy"), exclude[DirectMissingMethodProblem]("sbt.protocol.SettingQuerySuccess.copy"),
exclude[DirectMissingMethodProblem]("sbt.protocol.SettingQuerySuccess.copy$default$*"), exclude[DirectMissingMethodProblem]("sbt.protocol.SettingQuerySuccess.copy$default$*"),
// ignore missing methods in sbt.internal // ignore missing or incompatible methods in sbt.internal
exclude[IncompatibleMethTypeProblem]("sbt.internal.*"),
exclude[DirectMissingMethodProblem]("sbt.internal.*"), exclude[DirectMissingMethodProblem]("sbt.internal.*"),
exclude[MissingTypesProblem]("sbt.internal.protocol.JsonRpcResponseError"), exclude[MissingTypesProblem]("sbt.internal.protocol.JsonRpcResponseError"),
) )
@ -876,7 +877,7 @@ lazy val mainProj = (project in file("main"))
// New and changed methods on KeyIndex. internal. // New and changed methods on KeyIndex. internal.
exclude[ReversedMissingMethodProblem]("sbt.internal.KeyIndex.*"), exclude[ReversedMissingMethodProblem]("sbt.internal.KeyIndex.*"),
// internal // internal
exclude[IncompatibleMethTypeProblem]("sbt.internal.server.LanguageServerReporter.*"), exclude[IncompatibleMethTypeProblem]("sbt.internal.*"),
// Changed signature or removed private[sbt] methods // Changed signature or removed private[sbt] methods
exclude[DirectMissingMethodProblem]("sbt.Classpaths.unmanagedLibs0"), exclude[DirectMissingMethodProblem]("sbt.Classpaths.unmanagedLibs0"),
exclude[DirectMissingMethodProblem]("sbt.Defaults.allTestGroupsTask"), exclude[DirectMissingMethodProblem]("sbt.Defaults.allTestGroupsTask"),

View File

@ -11,7 +11,6 @@ package internal
import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.ConcurrentLinkedQueue
import sbt.protocol.EventMessage import sbt.protocol.EventMessage
import sjsonnew.JsonFormat
/** /**
* A command channel represents an IO device such as network socket or human * A command channel represents an IO device such as network socket or human
@ -42,9 +41,6 @@ abstract class CommandChannel {
} }
def poll: Option[Exec] = Option(commandQueue.poll) def poll: Option[Exec] = Option(commandQueue.poll)
def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit
final def publishEvent[A: JsonFormat](event: A): Unit = publishEvent(event, None)
def publishEventMessage(event: EventMessage): Unit
def publishBytes(bytes: Array[Byte]): Unit def publishBytes(bytes: Array[Byte]): Unit
def shutdown(): Unit def shutdown(): Unit
def name: String def name: String

View File

@ -14,8 +14,6 @@ import java.util.concurrent.atomic.AtomicReference
import sbt.BasicKeys._ import sbt.BasicKeys._
import sbt.internal.util._ import sbt.internal.util._
import sbt.protocol.EventMessage
import sjsonnew.JsonFormat
private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel { private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel {
private[this] val askUserThread = new AtomicReference[AskUserThread] private[this] val askUserThread = new AtomicReference[AskUserThread]
@ -62,21 +60,16 @@ private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel
def publishBytes(bytes: Array[Byte]): Unit = () def publishBytes(bytes: Array[Byte]): Unit = ()
def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit = () def prompt(event: ConsolePromptEvent): Unit = {
if (Terminal.systemInIsAttached) {
def publishEventMessage(event: EventMessage): Unit = askUserThread.synchronized {
event match { askUserThread.get match {
case e: ConsolePromptEvent => case null => askUserThread.set(makeAskUserThread(event.state))
if (Terminal.systemInIsAttached) { case t => t.redraw()
askUserThread.synchronized {
askUserThread.get match {
case null => askUserThread.set(makeAskUserThread(e.state))
case t => t.redraw()
}
}
} }
case _ => // }
} }
}
def shutdown(): Unit = askUserThread.synchronized { def shutdown(): Unit = askUserThread.synchronized {
askUserThread.get match { askUserThread.get match {

View File

@ -119,7 +119,7 @@ class NetworkClient(configuration: xsbti.AppConfiguration, arguments: List[Strin
} }
def onResponse(msg: JsonRpcResponseMessage): Unit = { def onResponse(msg: JsonRpcResponseMessage): Unit = {
msg.id foreach { msg.id match {
case execId if pendingExecIds contains execId => case execId if pendingExecIds contains execId =>
onReturningReponse(msg) onReturningReponse(msg)
lock.synchronized { lock.synchronized {

View File

@ -890,7 +890,7 @@ object BuiltinCommands {
val exchange = StandardMain.exchange val exchange = StandardMain.exchange
val welcomeState = displayWelcomeBanner(s0) val welcomeState = displayWelcomeBanner(s0)
val s1 = exchange run welcomeState val s1 = exchange run welcomeState
exchange publishEventMessage ConsolePromptEvent(s0) exchange prompt ConsolePromptEvent(s0)
val minGCInterval = Project val minGCInterval = Project
.extract(s1) .extract(s1)
.getOpt(Keys.minForcegcInterval) .getOpt(Keys.minForcegcInterval)

View File

@ -176,7 +176,7 @@ object MainLoop {
/** This is the main function State transfer function of the sbt command processing. */ /** This is the main function State transfer function of the sbt command processing. */
def processCommand(exec: Exec, state: State): State = { def processCommand(exec: Exec, state: State): State = {
val channelName = exec.source map (_.channelName) val channelName = exec.source map (_.channelName)
StandardMain.exchange publishEventMessage StandardMain.exchange notifyStatus
ExecStatusEvent("Processing", channelName, exec.execId, Vector()) ExecStatusEvent("Processing", channelName, exec.execId, Vector())
try { try {
def process(): State = { def process(): State = {
@ -197,12 +197,7 @@ object MainLoop {
newState.remainingCommands.toVector map (_.commandLine), newState.remainingCommands.toVector map (_.commandLine),
exitCode(newState, state), exitCode(newState, state),
) )
if (doneEvent.execId.isDefined) { // send back a response or error StandardMain.exchange.respondStatus(doneEvent)
import sbt.protocol.codec.JsonProtocol._
StandardMain.exchange publishEvent doneEvent
} else { // send back a notification
StandardMain.exchange publishEventMessage doneEvent
}
newState.get(sbt.Keys.currentTaskProgress).foreach(_.progress.stop()) newState.get(sbt.Keys.currentTaskProgress).foreach(_.progress.stop())
newState.remove(sbt.Keys.currentTaskProgress) newState.remove(sbt.Keys.currentTaskProgress)
} }
@ -225,8 +220,7 @@ object MainLoop {
ExitCode(ErrorCodes.UnknownError), ExitCode(ErrorCodes.UnknownError),
Option(err.getMessage), Option(err.getMessage),
) )
import sbt.protocol.codec.JsonProtocol._ StandardMain.exchange.respondStatus(errorEvent)
StandardMain.exchange.publishEvent(errorEvent)
throw err throw err
} }
} }

View File

@ -16,16 +16,13 @@ import java.util.concurrent.atomic._
import sbt.BasicKeys._ import sbt.BasicKeys._
import sbt.nio.Watch.NullLogger import sbt.nio.Watch.NullLogger
import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.protocol.JsonRpcResponseError
import sbt.internal.langserver.{ LogMessageParams, MessageType }
import sbt.internal.server._ import sbt.internal.server._
import sbt.internal.util.codec.JValueFormats import sbt.internal.util.{ ConsoleOut, MainAppender, ObjectEvent, Terminal }
import sbt.internal.util.{ ConsoleOut, MainAppender, ObjectEvent, StringEvent, Terminal }
import sbt.io.syntax._ import sbt.io.syntax._
import sbt.io.{ Hash, IO } import sbt.io.{ Hash, IO }
import sbt.protocol.{ EventMessage, ExecStatusEvent } import sbt.protocol.{ ExecStatusEvent, LogEvent }
import sbt.util.{ Level, LogExchange, Logger } import sbt.util.{ Level, LogExchange, Logger }
import sjsonnew.JsonFormat import sjsonnew.JsonFormat
import sjsonnew.shaded.scalajson.ast.unsafe._
import scala.annotation.tailrec import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer import scala.collection.mutable.ListBuffer
@ -51,17 +48,12 @@ private[sbt] final class CommandExchange {
private val commandChannelQueue = new LinkedBlockingQueue[CommandChannel] private val commandChannelQueue = new LinkedBlockingQueue[CommandChannel]
private val nextChannelId: AtomicInteger = new AtomicInteger(0) private val nextChannelId: AtomicInteger = new AtomicInteger(0)
private[this] val activePrompt = new AtomicBoolean(false) private[this] val activePrompt = new AtomicBoolean(false)
private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {}
def channels: List[CommandChannel] = channelBuffer.toList def channels: List[CommandChannel] = channelBuffer.toList
private[this] def removeChannels(toDel: List[CommandChannel]): Unit = { private[this] def removeChannel(channel: CommandChannel): Unit = {
toDel match { channelBufferLock.synchronized {
case Nil => // do nothing channelBuffer -= channel
case xs => ()
channelBufferLock.synchronized {
channelBuffer --= xs
()
}
} }
} }
@ -206,19 +198,7 @@ private[sbt] final class CommandExchange {
execId: Option[String], execId: Option[String],
source: Option[CommandSource] source: Option[CommandSource]
): Unit = { ): Unit = {
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty respondError(JsonRpcResponseError(code, message), execId, source)
channels.foreach {
case _: ConsoleChannel =>
case c: NetworkChannel =>
try {
// broadcast to all network channels
c.respondError(code, message, execId, source)
} catch {
case _: IOException =>
toDel += c
}
}
removeChannels(toDel.toList)
} }
private[sbt] def respondError( private[sbt] def respondError(
@ -226,19 +206,13 @@ private[sbt] final class CommandExchange {
execId: Option[String], execId: Option[String],
source: Option[CommandSource] source: Option[CommandSource]
): Unit = { ): Unit = {
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty for {
channels.foreach { source <- source.map(_.channelName)
case _: ConsoleChannel => channel <- channels.collectFirst {
case c: NetworkChannel => // broadcast to the source channel only
try { case c: NetworkChannel if c.name == source => c
// broadcast to all network channels }
c.respondError(err, execId, source) } tryTo(_.respondError(err, execId))(channel)
} catch {
case _: IOException =>
toDel += c
}
}
removeChannels(toDel.toList)
} }
// This is an interface to directly respond events. // This is an interface to directly respond events.
@ -247,146 +221,89 @@ private[sbt] final class CommandExchange {
execId: Option[String], execId: Option[String],
source: Option[CommandSource] source: Option[CommandSource]
): Unit = { ): Unit = {
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty for {
channels.foreach { source <- source.map(_.channelName)
case _: ConsoleChannel => channel <- channels.collectFirst {
case c: NetworkChannel => // broadcast to the source channel only
try { case c: NetworkChannel if c.name == source => c
// broadcast to all network channels }
c.respondEvent(event, execId, source) } tryTo(_.respondResult(event, execId))(channel)
} catch {
case _: IOException =>
toDel += c
}
}
removeChannels(toDel.toList)
} }
// This is an interface to directly notify events. // This is an interface to directly notify events.
private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = { private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = {
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty channels
channels.foreach { .collect { case c: NetworkChannel => c }
case _: ConsoleChannel => .foreach {
// c.publishEvent(event) tryTo(_.notifyEvent(method, params))
case c: NetworkChannel => }
try {
c.notifyEvent(method, params)
} catch {
case _: IOException =>
toDel += c
}
}
removeChannels(toDel.toList)
} }
private def tryTo(x: => Unit, c: CommandChannel, toDel: ListBuffer[CommandChannel]): Unit = private def tryTo(f: NetworkChannel => Unit)(
try x channel: NetworkChannel
catch { case _: IOException => toDel += c } ): Unit =
try f(channel)
catch { case _: IOException => removeChannel(channel) }
def publishEvent[A: JsonFormat](event: A): Unit = { def respondStatus(event: ExecStatusEvent): Unit = {
val broadcastStringMessage = true import sbt.protocol.codec.JsonProtocol._
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty for {
source <- event.channelName
channel <- channels.collectFirst {
case c: NetworkChannel if c.name == source => c
}
} {
if (event.execId.isEmpty) {
tryTo(_.notifyEvent(event))(channel)
} else {
event.exitCode match {
case None | Some(0) =>
tryTo(_.respondResult(event, event.execId))(channel)
case Some(code) =>
tryTo(_.respondError(code, event.message.getOrElse(""), event.execId))(channel)
}
}
event match { tryTo(_.respond(event, event.execId))(channel)
case entry: StringEvent =>
val params = toLogMessageParams(entry)
channels collect {
case c: ConsoleChannel =>
if (broadcastStringMessage || (entry.channelName forall (_ == c.name)))
c.publishEvent(event)
case c: NetworkChannel =>
tryTo(
{
// Note that language server's LogMessageParams does not hold the execid,
// so this is weaker than the StringMessage. We might want to double-send
// in case we have a better client that can utilize the knowledge.
import sbt.internal.langserver.codec.JsonProtocol._
if (broadcastStringMessage || (entry.channelName contains c.name))
c.jsonRpcNotify("window/logMessage", params)
},
c,
toDel
)
}
case entry: ExecStatusEvent =>
channels collect {
case c: ConsoleChannel =>
if (entry.channelName forall (_ == c.name)) c.publishEvent(event)
case c: NetworkChannel =>
if (entry.channelName contains c.name) tryTo(c.publishEvent(event), c, toDel)
}
case _ =>
channels foreach {
case c: ConsoleChannel => c.publishEvent(event)
case c: NetworkChannel =>
tryTo(c.publishEvent(event), c, toDel)
}
} }
removeChannels(toDel.toList)
}
private[sbt] def toLogMessageParams(event: StringEvent): LogMessageParams = {
LogMessageParams(MessageType.fromLevelString(event.level), event.message)
} }
/** /**
* This publishes object events. The type information has been * This publishes object events. The type information has been
* erased because it went through logging. * erased because it went through logging.
*/ */
private[sbt] def publishObjectEvent(event: ObjectEvent[_]): Unit = { private[sbt] def respondObjectEvent(event: ObjectEvent[_]): Unit = {
import jsonFormat._ for {
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty source <- event.channelName
def json: JValue = JObject( channel <- channels.collectFirst {
JField("type", JString(event.contentType)), case c: NetworkChannel if c.name == source => c
Vector(JField("message", event.json), JField("level", JString(event.level.toString))) ++ }
(event.channelName.toVector map { channelName => } tryTo(_.respond(event))(channel)
JField("channelName", JString(channelName))
}) ++
(event.execId.toVector map { execId =>
JField("execId", JString(execId))
}): _*
)
channels collect {
case c: ConsoleChannel =>
c.publishEvent(json)
case c: NetworkChannel =>
try {
c.publishObjectEvent(event)
} catch {
case _: IOException =>
toDel += c
}
}
removeChannels(toDel.toList)
} }
// fanout publishEvent def prompt(event: ConsolePromptEvent): Unit = {
def publishEventMessage(event: EventMessage): Unit = { activePrompt.set(Terminal.systemInIsAttached)
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty channels
.collect { case c: ConsoleChannel => c }
event match { .foreach { _.prompt(event) }
// Special treatment for ConsolePromptEvent since it's hand coded without codec.
case entry: ConsolePromptEvent =>
channels collect {
case c: ConsoleChannel =>
c.publishEventMessage(entry)
activePrompt.set(Terminal.systemInIsAttached)
}
case entry: ExecStatusEvent =>
channels collect {
case c: ConsoleChannel =>
if (entry.channelName forall (_ == c.name)) c.publishEventMessage(event)
case c: NetworkChannel =>
if (entry.channelName contains c.name) tryTo(c.publishEventMessage(event), c, toDel)
}
case _ =>
channels collect {
case c: ConsoleChannel => c.publishEventMessage(event)
case c: NetworkChannel => tryTo(c.publishEventMessage(event), c, toDel)
}
}
removeChannels(toDel.toList)
} }
def logMessage(event: LogEvent): Unit = {
channels
.collect { case c: NetworkChannel => c }
.foreach {
tryTo(_.notifyEvent(event))
}
}
def notifyStatus(event: ExecStatusEvent): Unit = {
for {
source <- event.channelName
channel <- channels.collectFirst {
case c: NetworkChannel if c.name == source => c
}
} tryTo(_.notifyEvent(event))(channel)
}
private[this] def needToFinishPromptLine(): Boolean = activePrompt.compareAndSet(true, false) private[this] def needToFinishPromptLine(): Boolean = activePrompt.compareAndSet(true, false)
} }

View File

@ -17,7 +17,6 @@ import org.apache.logging.log4j.core.config.Property
import sbt.util.Level import sbt.util.Level
import sbt.internal.util._ import sbt.internal.util._
import sbt.protocol.LogEvent import sbt.protocol.LogEvent
import sbt.internal.util.codec._
class RelayAppender(name: String) class RelayAppender(name: String)
extends AbstractAppender( extends AbstractAppender(
@ -40,15 +39,12 @@ class RelayAppender(name: String)
} }
} }
def appendLog(level: Level.Value, message: => String): Unit = { def appendLog(level: Level.Value, message: => String): Unit = {
exchange.publishEventMessage(LogEvent(level.toString, message)) exchange.logMessage(LogEvent(level.toString, message))
} }
def appendEvent(event: AnyRef): Unit = def appendEvent(event: AnyRef): Unit =
event match { event match {
case x: StringEvent => { case x: StringEvent => exchange.logMessage(LogEvent(x.message, x.level))
import JsonProtocol._ case x: ObjectEvent[_] => exchange.respondObjectEvent(x)
exchange.publishEvent(x: AbstractEntry)
}
case x: ObjectEvent[_] => exchange.publishObjectEvent(x)
case _ => case _ =>
println(s"appendEvent: ${event.getClass}") println(s"appendEvent: ${event.getClass}")
() ()

View File

@ -40,10 +40,10 @@ private[sbt] object Definition {
def send[A: JsonFormat](source: CommandSource, execId: String)(params: A): Unit = { def send[A: JsonFormat](source: CommandSource, execId: String)(params: A): Unit = {
for { for {
channel <- StandardMain.exchange.channels.collectFirst { channel <- StandardMain.exchange.channels.collectFirst {
case c if c.name == source.channelName => c case c: NetworkChannel if c.name == source.channelName => c
} }
} { } {
channel.publishEvent(params, Option(execId)) channel.respond(params, Option(execId))
} }
} }

View File

@ -10,7 +10,6 @@ package internal
package server package server
import sjsonnew.JsonFormat import sjsonnew.JsonFormat
import sjsonnew.shaded.scalajson.ast.unsafe.JValue
import sjsonnew.support.scalajson.unsafe.Converter import sjsonnew.support.scalajson.unsafe.Converter
import sbt.protocol.Serialization import sbt.protocol.Serialization
import sbt.protocol.{ CompletionParams => CP, SettingQuery => Q } import sbt.protocol.{ CompletionParams => CP, SettingQuery => Q }
@ -103,7 +102,7 @@ private[sbt] object LanguageServerProtocol {
} }
/** Implements Language Server Protocol <https://github.com/Microsoft/language-server-protocol>. */ /** Implements Language Server Protocol <https://github.com/Microsoft/language-server-protocol>. */
private[sbt] trait LanguageServerProtocol extends CommandChannel { self => private[sbt] trait LanguageServerProtocol { self: NetworkChannel =>
lazy val internalJsonProtocol = new InitializeOptionFormats with sjsonnew.BasicJsonProtocol {} lazy val internalJsonProtocol = new InitializeOptionFormats with sjsonnew.BasicJsonProtocol {}
@ -117,10 +116,10 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self =>
protected lazy val callbackImpl: ServerCallback = new ServerCallback { protected lazy val callbackImpl: ServerCallback = new ServerCallback {
def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit = def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit =
self.jsonRpcRespond(event, execId) self.respondResult(event, execId)
def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit = def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit =
self.jsonRpcRespondError(execId, code, message) self.respondError(code, message, execId)
def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit = def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit =
self.jsonRpcNotify(method, params) self.jsonRpcNotify(method, params)
@ -162,28 +161,16 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self =>
} }
/** Respond back to Language Server's client. */ /** Respond back to Language Server's client. */
private[sbt] def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit = { private[sbt] def jsonRpcRespond[A: JsonFormat](event: A, execId: String): Unit = {
val m = val response =
JsonRpcResponseMessage("2.0", execId, Option(Converter.toJson[A](event).get), None) JsonRpcResponseMessage("2.0", execId, Option(Converter.toJson[A](event).get), None)
val bytes = Serialization.serializeResponseMessage(m) val bytes = Serialization.serializeResponseMessage(response)
publishBytes(bytes) publishBytes(bytes)
} }
/** Respond back to Language Server's client. */ /** Respond back to Language Server's client. */
private[sbt] def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit =
jsonRpcRespondErrorImpl(execId, code, message, None)
/** Respond back to Language Server's client. */
private[sbt] def jsonRpcRespondError[A: JsonFormat](
execId: Option[String],
code: Long,
message: String,
data: A,
): Unit =
jsonRpcRespondErrorImpl(execId, code, message, Option(Converter.toJson[A](data).get))
private[sbt] def jsonRpcRespondError( private[sbt] def jsonRpcRespondError(
execId: Option[String], execId: String,
err: JsonRpcResponseError err: JsonRpcResponseError
): Unit = { ): Unit = {
val m = JsonRpcResponseMessage("2.0", execId, None, Option(err)) val m = JsonRpcResponseMessage("2.0", execId, None, Option(err))
@ -191,18 +178,6 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self =>
publishBytes(bytes) publishBytes(bytes)
} }
private[this] def jsonRpcRespondErrorImpl(
execId: Option[String],
code: Long,
message: String,
data: Option[JValue],
): Unit = {
val e = JsonRpcResponseError(code, message, data)
val m = JsonRpcResponseMessage("2.0", execId, None, Option(e))
val bytes = Serialization.serializeResponseMessage(m)
publishBytes(bytes)
}
/** Notify to Language Server's client. */ /** Notify to Language Server's client. */
private[sbt] def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit = { private[sbt] def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit = {
val m = val m =

View File

@ -12,19 +12,22 @@ package server
import java.net.{ Socket, SocketTimeoutException } import java.net.{ Socket, SocketTimeoutException }
import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicBoolean
import sjsonnew._ import sbt.internal.langserver.{ CancelRequestParams, ErrorCodes }
import scala.annotation.tailrec
import sbt.protocol._
import sbt.internal.langserver.{ ErrorCodes, CancelRequestParams }
import sbt.internal.util.{ ObjectEvent, StringEvent }
import sbt.internal.util.complete.Parser
import sbt.internal.util.codec.JValueFormats
import sbt.internal.protocol.{ import sbt.internal.protocol.{
JsonRpcResponseError, JsonRpcNotificationMessage,
JsonRpcRequestMessage, JsonRpcRequestMessage,
JsonRpcNotificationMessage JsonRpcResponseError
} }
import sbt.internal.util.codec.JValueFormats
import sbt.internal.util.complete.Parser
import sbt.internal.util.ObjectEvent
import sbt.protocol._
import sbt.util.Logger import sbt.util.Logger
import sjsonnew._
import sjsonnew.support.scalajson.unsafe.{ CompactPrinter, Converter }
import scala.annotation.tailrec
import scala.collection.mutable
import scala.util.Try import scala.util.Try
import scala.util.control.NonFatal import scala.util.control.NonFatal
@ -53,6 +56,7 @@ final class NetworkChannel(
private val VsCode = sbt.protocol.Serialization.VsCode private val VsCode = sbt.protocol.Serialization.VsCode
private val VsCodeOld = "application/vscode-jsonrpc; charset=utf8" private val VsCodeOld = "application/vscode-jsonrpc; charset=utf8"
private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {} private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {}
private val pendingRequests: mutable.Map[String, JsonRpcRequestMessage] = mutable.Map()
def setContentType(ct: String): Unit = synchronized { _contentType = ct } def setContentType(ct: String): Unit = synchronized { _contentType = ct }
def contentType: String = _contentType def contentType: String = _contentType
@ -176,6 +180,7 @@ final class NetworkChannel(
intents.foldLeft(PartialFunction.empty[JsonRpcRequestMessage, Unit]) { intents.foldLeft(PartialFunction.empty[JsonRpcRequestMessage, Unit]) {
case (f, i) => f orElse i.onRequest case (f, i) => f orElse i.onRequest
} }
lazy val onNotification: PartialFunction[JsonRpcNotificationMessage, Unit] = lazy val onNotification: PartialFunction[JsonRpcNotificationMessage, Unit] =
intents.foldLeft(PartialFunction.empty[JsonRpcNotificationMessage, Unit]) { intents.foldLeft(PartialFunction.empty[JsonRpcNotificationMessage, Unit]) {
case (f, i) => f orElse i.onNotification case (f, i) => f orElse i.onNotification
@ -186,25 +191,27 @@ final class NetworkChannel(
Serialization.deserializeJsonMessage(chunk) match { Serialization.deserializeJsonMessage(chunk) match {
case Right(req: JsonRpcRequestMessage) => case Right(req: JsonRpcRequestMessage) =>
try { try {
registerRequest(req)
onRequestMessage(req) onRequestMessage(req)
} catch { } catch {
case LangServerError(code, message) => case LangServerError(code, message) =>
log.debug(s"sending error: $code: $message") log.debug(s"sending error: $code: $message")
jsonRpcRespondError(Option(req.id), code, message) respondError(code, message, Some(req.id))
} }
case Right(ntf: JsonRpcNotificationMessage) => case Right(ntf: JsonRpcNotificationMessage) =>
try { try {
onNotification(ntf) onNotification(ntf)
} catch { } catch {
case LangServerError(code, message) => case LangServerError(code, message) =>
log.debug(s"sending error: $code: $message") logMessage("error", s"Error $code while handling notification: $message")
jsonRpcRespondError(None, code, message) // new id?
} }
case Right(msg) => case Right(msg) =>
log.debug(s"Unhandled message: $msg") log.debug(s"Unhandled message: $msg")
case Left(errorDesc) => case Left(errorDesc) =>
val msg = s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): " + errorDesc logMessage(
jsonRpcRespondError(None, ErrorCodes.ParseError, msg) "error",
s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): $errorDesc"
)
} }
} else { } else {
contentType match { contentType match {
@ -213,13 +220,17 @@ final class NetworkChannel(
.deserializeCommand(chunk) .deserializeCommand(chunk)
.fold( .fold(
errorDesc => errorDesc =>
log.error( logMessage(
"error",
s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): " + errorDesc s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): " + errorDesc
), ),
onCommand onCommand
) )
case _ => case _ =>
log.error(s"Unknown Content-Type: $contentType") logMessage(
"error",
s"Unknown Content-Type: $contentType"
)
} }
} // if-else } // if-else
} }
@ -245,24 +256,48 @@ final class NetworkChannel(
} }
} }
private def registerRequest(request: JsonRpcRequestMessage): Unit = {
this.synchronized {
pendingRequests += (request.id -> request)
()
}
}
private[sbt] def respondError( private[sbt] def respondError(
err: JsonRpcResponseError, err: JsonRpcResponseError,
execId: Option[String], execId: Option[String]
source: Option[CommandSource] ): Unit = this.synchronized {
): Unit = jsonRpcRespondError(execId, err) execId match {
case Some(id) if pendingRequests.contains(id) =>
pendingRequests -= id
jsonRpcRespondError(id, err)
case _ =>
logMessage("error", s"Error ${err.code}: ${err.message}")
}
}
private[sbt] def respondError( private[sbt] def respondError(
code: Long, code: Long,
message: String, message: String,
execId: Option[String], execId: Option[String]
source: Option[CommandSource] ): Unit = {
): Unit = jsonRpcRespondError(execId, code, message) respondError(JsonRpcResponseError(code, message), execId)
}
private[sbt] def respondEvent[A: JsonFormat]( private[sbt] def respondResult[A: JsonFormat](
event: A, event: A,
execId: Option[String], execId: Option[String]
source: Option[CommandSource] ): Unit = this.synchronized {
): Unit = jsonRpcRespond(event, execId) execId match {
case Some(id) if pendingRequests.contains(id) =>
pendingRequests -= id
jsonRpcRespond(event, id)
case _ =>
log.debug(
s"unmatched json response for requestId $execId: ${CompactPrinter(Converter.toJsonUnsafe(event))}"
)
}
}
private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = { private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = {
if (isLanguageServerProtocol) { if (isLanguageServerProtocol) {
@ -272,19 +307,11 @@ final class NetworkChannel(
} }
} }
def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit = { def respond[A: JsonFormat](event: A): Unit = respond(event, None)
def respond[A: JsonFormat](event: A, execId: Option[String]): Unit = {
if (isLanguageServerProtocol) { if (isLanguageServerProtocol) {
event match { respondResult(event, execId)
case entry: StringEvent => logMessage(entry.level, entry.message)
case entry: ExecStatusEvent =>
entry.exitCode match {
case None => jsonRpcRespond(event, entry.execId)
case Some(0) => jsonRpcRespond(event, entry.execId)
case Some(exitCode) =>
jsonRpcRespondError(entry.execId, exitCode, entry.message.getOrElse(""))
}
case _ => jsonRpcRespond(event, execId)
}
} else { } else {
contentType match { contentType match {
case SbtX1Protocol => case SbtX1Protocol =>
@ -295,7 +322,7 @@ final class NetworkChannel(
} }
} }
def publishEventMessage(event: EventMessage): Unit = { def notifyEvent(event: EventMessage): Unit = {
if (isLanguageServerProtocol) { if (isLanguageServerProtocol) {
event match { event match {
case entry: LogEvent => logMessage(entry.level, entry.message) case entry: LogEvent => logMessage(entry.level, entry.message)
@ -316,22 +343,22 @@ final class NetworkChannel(
* This publishes object events. The type information has been * This publishes object events. The type information has been
* erased because it went through logging. * erased because it went through logging.
*/ */
private[sbt] def publishObjectEvent(event: ObjectEvent[_]): Unit = { private[sbt] def respond(event: ObjectEvent[_]): Unit = {
import sjsonnew.shaded.scalajson.ast.unsafe._ import sjsonnew.shaded.scalajson.ast.unsafe._
if (isLanguageServerProtocol) onObjectEvent(event) if (isLanguageServerProtocol) onObjectEvent(event)
else { else {
import jsonFormat._ import jsonFormat._
val json: JValue = JObject( val json: JValue = JObject(
JField("type", JString(event.contentType)), JField("type", JString(event.contentType)),
(Vector(JField("message", event.json), JField("level", JString(event.level.toString))) ++ Seq(JField("message", event.json), JField("level", JString(event.level.toString))) ++
(event.channelName.toVector map { channelName => (event.channelName map { channelName =>
JField("channelName", JString(channelName)) JField("channelName", JString(channelName))
}) ++ }) ++
(event.execId.toVector map { execId => (event.execId map { execId =>
JField("execId", JString(execId)) JField("execId", JString(execId))
})): _* }): _*
) )
publishEvent(json) respond(json, event.execId)
} }
} }
@ -358,7 +385,7 @@ final class NetworkChannel(
authenticate(x) match { authenticate(x) match {
case true => case true =>
initialized = true initialized = true
publishEventMessage(ChannelAcceptedEvent(name)) notifyEvent(ChannelAcceptedEvent(name))
case _ => sys.error("invalid token") case _ => sys.error("invalid token")
} }
case None => sys.error("init command but without token.") case None => sys.error("init command but without token.")
@ -383,8 +410,8 @@ final class NetworkChannel(
if (initialized) { if (initialized) {
import sbt.protocol.codec.JsonProtocol._ import sbt.protocol.codec.JsonProtocol._
SettingQuery.handleSettingQueryEither(req, structure) match { SettingQuery.handleSettingQueryEither(req, structure) match {
case Right(x) => jsonRpcRespond(x, execId) case Right(x) => respondResult(x, execId)
case Left(s) => jsonRpcRespondError(execId, ErrorCodes.InvalidParams, s) case Left(s) => respondError(ErrorCodes.InvalidParams, s, execId)
} }
} else { } else {
log.warn(s"ignoring query $req before initialization") log.warn(s"ignoring query $req before initialization")
@ -400,32 +427,31 @@ final class NetworkChannel(
Parser Parser
.completions(sstate.combinedParser, cp.query, 9) .completions(sstate.combinedParser, cp.query, 9)
.get .get
.map(c => { .flatMap { c =>
if (!c.isEmpty) Some(c.append.replaceAll("\n", " ")) if (!c.isEmpty) Some(c.append.replaceAll("\n", " "))
else None else None
}) }
.flatten .map(c => cp.query + c)
.map(c => cp.query + c.toString)
import sbt.protocol.codec.JsonProtocol._ import sbt.protocol.codec.JsonProtocol._
jsonRpcRespond( respondResult(
CompletionResponse( CompletionResponse(
items = completionItems.toVector items = completionItems.toVector
), ),
execId execId
) )
case _ => case _ =>
jsonRpcRespondError( respondError(
execId,
ErrorCodes.UnknownError, ErrorCodes.UnknownError,
"No available sbt state" "No available sbt state",
execId
) )
} }
} catch { } catch {
case NonFatal(e) => case NonFatal(_) =>
jsonRpcRespondError( respondError(
execId,
ErrorCodes.UnknownError, ErrorCodes.UnknownError,
"Completions request failed" "Completions request failed",
execId
) )
} }
} else { } else {
@ -436,10 +462,10 @@ final class NetworkChannel(
protected def onCancellationRequest(execId: Option[String], crp: CancelRequestParams) = { protected def onCancellationRequest(execId: Option[String], crp: CancelRequestParams) = {
if (initialized) { if (initialized) {
def errorRespond(msg: String) = jsonRpcRespondError( def errorRespond(msg: String) = respondError(
execId,
ErrorCodes.RequestCancelled, ErrorCodes.RequestCancelled,
msg msg,
execId
) )
try { try {
@ -465,11 +491,11 @@ final class NetworkChannel(
runningEngine.cancelAndShutdown() runningEngine.cancelAndShutdown()
import sbt.protocol.codec.JsonProtocol._ import sbt.protocol.codec.JsonProtocol._
jsonRpcRespond( respondResult(
ExecStatusEvent( ExecStatusEvent(
"Task cancelled", "Task cancelled",
Some(name), Some(name),
Some(runningExecId.toString), Some(runningExecId),
Vector(), Vector(),
None, None,
), ),

View File

@ -12,7 +12,7 @@ package sbt.internal.protocol
*/ */
final class JsonRpcResponseMessage private ( final class JsonRpcResponseMessage private (
jsonrpc: String, jsonrpc: String,
val id: Option[String], val id: String,
val result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue], val result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue],
val error: Option[sbt.internal.protocol.JsonRpcResponseError]) extends sbt.internal.protocol.JsonRpcMessage(jsonrpc) with Serializable { val error: Option[sbt.internal.protocol.JsonRpcResponseError]) extends sbt.internal.protocol.JsonRpcMessage(jsonrpc) with Serializable {
@ -28,17 +28,14 @@ final class JsonRpcResponseMessage private (
override def toString: String = { override def toString: String = {
s"""JsonRpcResponseMessage($jsonrpc, $id, ${sbt.protocol.Serialization.compactPrintJsonOpt(result)}, $error)""" s"""JsonRpcResponseMessage($jsonrpc, $id, ${sbt.protocol.Serialization.compactPrintJsonOpt(result)}, $error)"""
} }
private[this] def copy(jsonrpc: String = jsonrpc, id: Option[String] = id, result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue] = result, error: Option[sbt.internal.protocol.JsonRpcResponseError] = error): JsonRpcResponseMessage = { private[this] def copy(jsonrpc: String = jsonrpc, id: String = id, result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue] = result, error: Option[sbt.internal.protocol.JsonRpcResponseError] = error): JsonRpcResponseMessage = {
new JsonRpcResponseMessage(jsonrpc, id, result, error) new JsonRpcResponseMessage(jsonrpc, id, result, error)
} }
def withJsonrpc(jsonrpc: String): JsonRpcResponseMessage = { def withJsonrpc(jsonrpc: String): JsonRpcResponseMessage = {
copy(jsonrpc = jsonrpc) copy(jsonrpc = jsonrpc)
} }
def withId(id: Option[String]): JsonRpcResponseMessage = {
copy(id = id)
}
def withId(id: String): JsonRpcResponseMessage = { def withId(id: String): JsonRpcResponseMessage = {
copy(id = Option(id)) copy(id = id)
} }
def withResult(result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue]): JsonRpcResponseMessage = { def withResult(result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue]): JsonRpcResponseMessage = {
copy(result = result) copy(result = result)
@ -55,6 +52,6 @@ final class JsonRpcResponseMessage private (
} }
object JsonRpcResponseMessage { object JsonRpcResponseMessage {
def apply(jsonrpc: String, id: Option[String], result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue], error: Option[sbt.internal.protocol.JsonRpcResponseError]): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, id, result, error) def apply(jsonrpc: String, id: String, result: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue], error: Option[sbt.internal.protocol.JsonRpcResponseError]): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, id, result, error)
def apply(jsonrpc: String, id: String, result: sjsonnew.shaded.scalajson.ast.unsafe.JValue, error: sbt.internal.protocol.JsonRpcResponseError): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, Option(id), Option(result), Option(error)) def apply(jsonrpc: String, id: String, result: sjsonnew.shaded.scalajson.ast.unsafe.JValue, error: sbt.internal.protocol.JsonRpcResponseError): JsonRpcResponseMessage = new JsonRpcResponseMessage(jsonrpc, id, Option(result), Option(error))
} }

View File

@ -31,7 +31,7 @@ type JsonRpcResponseMessage implements JsonRpcMessage
jsonrpc: String! jsonrpc: String!
## The request id. ## The request id.
id: String id: String!
## The result of a request. This can be omitted in ## The result of a request. This can be omitted in
## the case of an error. ## the case of an error.

View File

@ -32,10 +32,10 @@ trait JsonRpcResponseMessageFormats {
unbuilder.beginObject(js) unbuilder.beginObject(js)
val jsonrpc = unbuilder.readField[String]("jsonrpc") val jsonrpc = unbuilder.readField[String]("jsonrpc")
val id = try { val id = try {
unbuilder.readField[Option[String]]("id") unbuilder.readField[String]("id")
} catch { } catch {
case _: DeserializationException => case _: DeserializationException =>
unbuilder.readField[Option[Long]]("id") map { _.toString } unbuilder.readField[Long]("id").toString
} }
val result = unbuilder.lookupField("result") map { val result = unbuilder.lookupField("result") map {
@ -77,11 +77,9 @@ trait JsonRpcResponseMessageFormats {
} }
builder.beginObject() builder.beginObject()
builder.addField("jsonrpc", obj.jsonrpc) builder.addField("jsonrpc", obj.jsonrpc)
obj.id foreach { id => parseId(obj.id) match {
parseId(id) match { case Right(strId) => builder.addField("id", strId)
case Right(strId) => builder.addField("id", strId) case Left(longId) => builder.addField("id", longId)
case Left(longId) => builder.addField("id", longId)
}
} }
builder.addField("result", obj.result map parseResult) builder.addField("result", obj.result map parseResult)
builder.addField("error", obj.error) builder.addField("error", obj.error)

View File

@ -25,12 +25,25 @@ Global / serverHandlers += ServerHandler({ callback =>
case r: JsonRpcRequestMessage if r.method == "foo/rootClasspath" => case r: JsonRpcRequestMessage if r.method == "foo/rootClasspath" =>
appendExec(Exec("fooClasspath", Some(r.id), Some(CommandSource(callback.name)))) appendExec(Exec("fooClasspath", Some(r.id), Some(CommandSource(callback.name))))
() ()
case r if r.method == "foo/respondTwice" =>
appendExec(Exec("fooClasspath", Some(r.id), Some(CommandSource(callback.name))))
jsonRpcRespond("concurrent response", Some(r.id))
()
case r if r.method == "foo/resultAndError" =>
appendExec(Exec("fooCustomFail", Some(r.id), Some(CommandSource(callback.name))))
jsonRpcRespond("concurrent response", Some(r.id))
()
}, },
PartialFunction.empty {
case r if r.method == "foo/customNotification" =>
jsonRpcRespond("notification result", None)
()
}
) )
}) })
lazy val fooClasspath = taskKey[Unit]("") lazy val fooClasspath = taskKey[Unit]("")
lazy val root = (project in file(".")) lazy val root = (project in file("."))
.settings( .settings(
name := "response", name := "response",
@ -55,5 +68,5 @@ lazy val root = (project in file("."))
val s = state.value val s = state.value
val cp = (Compile / fullClasspath).value val cp = (Compile / fullClasspath).value
s.respondEvent(cp.map(_.data)) s.respondEvent(cp.map(_.data))
}, }
) )

View File

@ -22,15 +22,6 @@ object EventsTest extends AbstractServerTest {
}) })
} }
test("report task failures in case of exceptions") { _ =>
svr.sendJsonRpc(
"""{ "jsonrpc": "2.0", "id": 11, "method": "sbt/exec", "params": { "commandLine": "hello" } }"""
)
assert(svr.waitForString(10.seconds) { s =>
(s contains """"id":11""") && (s contains """"error":""")
})
}
test("return error if cancelling non-matched task id") { _ => test("return error if cancelling non-matched task id") { _ =>
svr.sendJsonRpc( svr.sendJsonRpc(
"""{ "jsonrpc": "2.0", "id":12, "method": "sbt/exec", "params": { "commandLine": "run" } }""" """{ "jsonrpc": "2.0", "id":12, "method": "sbt/exec", "params": { "commandLine": "run" } }"""

View File

@ -64,4 +64,54 @@ object ResponseTest extends AbstractServerTest {
(s contains """{"jsonrpc":"2.0","method":"foo/something","params":"something"}""") (s contains """{"jsonrpc":"2.0","method":"foo/something","params":"something"}""")
}) })
} }
test("respond concurrently from a task and the handler") { _ =>
svr.sendJsonRpc(
"""{ "jsonrpc": "2.0", "id": "15", "method": "foo/respondTwice", "params": {} }"""
)
assert {
svr.waitForString(1.seconds) { s =>
println(s)
s contains "\"id\":\"15\""
}
}
assert {
// the second response should never be sent
svr.neverReceive(500.milliseconds) { s =>
println(s)
s contains "\"id\":\"15\""
}
}
}
test("concurrent result and error") { _ =>
svr.sendJsonRpc(
"""{ "jsonrpc": "2.0", "id": "16", "method": "foo/resultAndError", "params": {} }"""
)
assert {
svr.waitForString(1.seconds) { s =>
println(s)
s contains "\"id\":\"16\""
}
}
assert {
// the second response (result or error) should never be sent
svr.neverReceive(500.milliseconds) { s =>
println(s)
s contains "\"id\":\"16\""
}
}
}
test("response to a notification should not be sent") { _ =>
svr.sendJsonRpc(
"""{ "jsonrpc": "2.0", "method": "foo/customNotification", "params": {} }"""
)
assert {
svr.neverReceive(500.milliseconds) { s =>
println(s)
s contains "\"result\":\"notification result\""
}
}
}
} }

View File

@ -8,13 +8,14 @@
package testpkg package testpkg
import java.io.{ File, IOException } import java.io.{ File, IOException }
import java.util.concurrent.TimeoutException
import verify._ import verify._
import sbt.RunFromSourceMain import sbt.RunFromSourceMain
import sbt.io.IO import sbt.io.IO
import sbt.io.syntax._ import sbt.io.syntax._
import sbt.protocol.ClientSocket import sbt.protocol.ClientSocket
import scala.annotation.tailrec
import scala.concurrent._ import scala.concurrent._
import scala.concurrent.duration._ import scala.concurrent.duration._
import scala.util.{ Success, Try } import scala.util.{ Success, Try }
@ -150,6 +151,7 @@ case class TestServer(
sbtVersion: String, sbtVersion: String,
classpath: Seq[File] classpath: Seq[File]
) { ) {
import scala.concurrent.ExecutionContext.Implicits._
import TestServer.hostLog import TestServer.hostLog
val readBuffer = new Array[Byte](40960) val readBuffer = new Array[Byte](40960)
@ -183,15 +185,25 @@ case class TestServer(
waitForPortfile(90.seconds) waitForPortfile(90.seconds)
// make connection to the socket described in the portfile // make connection to the socket described in the portfile
val (sk, tkn) = ClientSocket.socket(portfile) var (sk, _) = ClientSocket.socket(portfile)
val out = sk.getOutputStream var out = sk.getOutputStream
val in = sk.getInputStream var in = sk.getInputStream
// initiate handshake // initiate handshake
sendJsonRpc( sendJsonRpc(
"""{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }""" """{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }"""
) )
def resetConnection() = {
sk = ClientSocket.socket(portfile)._1
out = sk.getOutputStream
in = sk.getInputStream
sendJsonRpc(
"""{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }"""
)
}
def test(f: TestServer => Future[Assertion]): Future[Assertion] = { def test(f: TestServer => Future[Assertion]): Future[Assertion] = {
f(this) f(this)
} }
@ -230,7 +242,7 @@ case class TestServer(
writeEndLine writeEndLine
} }
def readFrame: Option[String] = { def readFrame: Future[Option[String]] = Future {
def getContentLength: Int = { def getContentLength: Int = {
readLine map { line => readLine map { line =>
line.drop(16).toInt line.drop(16).toInt
@ -244,14 +256,28 @@ case class TestServer(
final def waitForString(duration: FiniteDuration)(f: String => Boolean): Boolean = { final def waitForString(duration: FiniteDuration)(f: String => Boolean): Boolean = {
val deadline = duration.fromNow val deadline = duration.fromNow
@tailrec
def impl(): Boolean = { def impl(): Boolean = {
if (deadline.isOverdue || !process.isAlive) false try {
else Await.result(readFrame, deadline.timeLeft).fold(false)(f) || impl
readFrame.fold(false)(f) || { } catch {
Thread.sleep(100) case _: TimeoutException =>
impl resetConnection() // create a new connection to invalidate the running readFrame future
} false
}
}
impl()
}
final def neverReceive(duration: FiniteDuration)(f: String => Boolean): Boolean = {
val deadline = duration.fromNow
def impl(): Boolean = {
try {
Await.result(readFrame, deadline.timeLeft).fold(true)(s => !f(s)) && impl
} catch {
case _: TimeoutException =>
resetConnection() // create a new connection to invalidate the running readFrame future
true
}
} }
impl() impl()
} }