prevent multiple response to a single request

This commit is contained in:
Adrien Piquerez 2020-05-12 10:38:47 +02:00
parent e040eebd21
commit df293fbfd5
2 changed files with 29 additions and 50 deletions

View File

@ -10,7 +10,6 @@ package internal
package server package server
import sjsonnew.JsonFormat import sjsonnew.JsonFormat
import sjsonnew.shaded.scalajson.ast.unsafe.JValue
import sjsonnew.support.scalajson.unsafe.Converter import sjsonnew.support.scalajson.unsafe.Converter
import sbt.protocol.Serialization import sbt.protocol.Serialization
import sbt.protocol.{ CompletionParams => CP, SettingQuery => Q } import sbt.protocol.{ CompletionParams => CP, SettingQuery => Q }
@ -170,18 +169,6 @@ private[sbt] trait LanguageServerProtocol { self: NetworkChannel =>
} }
/** Respond back to Language Server's client. */ /** Respond back to Language Server's client. */
private[sbt] def jsonRpcRespondError(execId: String, code: Long, message: String): Unit =
jsonRpcRespondErrorImpl(execId, code, message, None)
/** Respond back to Language Server's client. */
private[sbt] def jsonRpcRespondError[A: JsonFormat](
execId: String,
code: Long,
message: String,
data: A,
): Unit =
jsonRpcRespondErrorImpl(execId, code, message, Option(Converter.toJson[A](data).get))
private[sbt] def jsonRpcRespondError( private[sbt] def jsonRpcRespondError(
execId: String, execId: String,
err: JsonRpcResponseError err: JsonRpcResponseError
@ -191,18 +178,6 @@ private[sbt] trait LanguageServerProtocol { self: NetworkChannel =>
publishBytes(bytes) publishBytes(bytes)
} }
private[this] def jsonRpcRespondErrorImpl(
execId: String,
code: Long,
message: String,
data: Option[JValue],
): Unit = {
val e = JsonRpcResponseError(code, message, data)
val m = JsonRpcResponseMessage("2.0", execId, None, Option(e))
val bytes = Serialization.serializeResponseMessage(m)
publishBytes(bytes)
}
/** Notify to Language Server's client. */ /** Notify to Language Server's client. */
private[sbt] def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit = { private[sbt] def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit = {
val m = val m =

View File

@ -12,22 +12,22 @@ package server
import java.net.{ Socket, SocketTimeoutException } import java.net.{ Socket, SocketTimeoutException }
import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicBoolean
import sjsonnew._
import scala.annotation.tailrec
import sbt.protocol._
import sbt.internal.langserver.{ CancelRequestParams, ErrorCodes } import sbt.internal.langserver.{ CancelRequestParams, ErrorCodes }
import sbt.internal.util.{ ObjectEvent, StringEvent }
import sbt.internal.util.complete.Parser
import sbt.internal.util.codec.JValueFormats
import sbt.internal.protocol.{ import sbt.internal.protocol.{
JsonRpcNotificationMessage, JsonRpcNotificationMessage,
JsonRpcRequestMessage, JsonRpcRequestMessage,
JsonRpcResponseError JsonRpcResponseError
} }
import sbt.internal.util.codec.JValueFormats
import sbt.internal.util.complete.Parser
import sbt.internal.util.{ ObjectEvent, StringEvent }
import sbt.protocol._
import sbt.util.Logger import sbt.util.Logger
import sjsonnew.support.scalajson.unsafe.Converter import sjsonnew._
import sjsonnew.support.scalajson.unsafe.{ CompactPrinter, Converter }
import scala.annotation.tailrec
import scala.collection.mutable
import scala.util.Try import scala.util.Try
import scala.util.control.NonFatal import scala.util.control.NonFatal
@ -56,8 +56,7 @@ final class NetworkChannel(
private val VsCode = sbt.protocol.Serialization.VsCode private val VsCode = sbt.protocol.Serialization.VsCode
private val VsCodeOld = "application/vscode-jsonrpc; charset=utf8" private val VsCodeOld = "application/vscode-jsonrpc; charset=utf8"
private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {} private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {}
private val onGoingRequests: mutable.Set[String] = mutable.Set()
private var onGoingRequests: Set[JsonRpcRequestMessage] = Set.empty
def setContentType(ct: String): Unit = synchronized { _contentType = ct } def setContentType(ct: String): Unit = synchronized { _contentType = ct }
def contentType: String = _contentType def contentType: String = _contentType
@ -197,7 +196,7 @@ final class NetworkChannel(
} catch { } catch {
case LangServerError(code, message) => case LangServerError(code, message) =>
log.debug(s"sending error: $code: $message") log.debug(s"sending error: $code: $message")
jsonRpcRespondError(req.id, code, message) respondError(code, message, Some(req.id))
} }
case Right(ntf: JsonRpcNotificationMessage) => case Right(ntf: JsonRpcNotificationMessage) =>
try { try {
@ -258,18 +257,24 @@ final class NetworkChannel(
} }
private def registerRequest(request: JsonRpcRequestMessage): Unit = { private def registerRequest(request: JsonRpcRequestMessage): Unit = {
onGoingRequests += request this.synchronized {
onGoingRequests += request.id
()
}
} }
private[sbt] def respondError( private[sbt] def respondError(
err: JsonRpcResponseError, err: JsonRpcResponseError,
execId: Option[String] execId: Option[String]
): Unit = { ): Unit = this.synchronized {
println(s"respond error for $execId")
execId match { execId match {
case Some(id) => jsonRpcRespondError(id, err) case Some(id) if onGoingRequests.contains(id) =>
case None => logMessage("error", s"Error ${err.code}: ${err.message}") onGoingRequests -= id
jsonRpcRespondError(id, err)
case _ =>
logMessage("error", s"Error ${err.code}: ${err.message}")
} }
} }
private[sbt] def respondError( private[sbt] def respondError(
@ -277,21 +282,20 @@ final class NetworkChannel(
message: String, message: String,
execId: Option[String] execId: Option[String]
): Unit = { ): Unit = {
execId match { respondError(JsonRpcResponseError(code, message), execId)
case Some(id) => jsonRpcRespondError(id, code, message)
case None => logMessage("error", s"Error $code: $message")
}
} }
private[sbt] def respondEvent[A: JsonFormat]( private[sbt] def respondEvent[A: JsonFormat](
event: A, event: A,
execId: Option[String] execId: Option[String]
): Unit = { ): Unit = this.synchronized {
println(s"respond result for $execId")
execId match { execId match {
case Some(id) => jsonRpcRespond(event, id) case Some(id) if onGoingRequests.contains(id) =>
case None => onGoingRequests -= id
val json = Converter.toJsonUnsafe(event) jsonRpcRespond(event, id)
log.debug(s"unmatched json response: $json") case _ =>
log.debug(s"unmatched json response: ${CompactPrinter(Converter.toJsonUnsafe(event))}")
} }
} }