From df293fbfd552259169860ad6a3e71505cdf6d7ab Mon Sep 17 00:00:00 2001 From: Adrien Piquerez Date: Tue, 12 May 2020 10:38:47 +0200 Subject: [PATCH] prevent multiple response to a single request --- .../server/LanguageServerProtocol.scala | 25 --------- .../sbt/internal/server/NetworkChannel.scala | 54 ++++++++++--------- 2 files changed, 29 insertions(+), 50 deletions(-) diff --git a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala index cee5d91fc..f5a4ff19a 100644 --- a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala +++ b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala @@ -10,7 +10,6 @@ package internal package server import sjsonnew.JsonFormat -import sjsonnew.shaded.scalajson.ast.unsafe.JValue import sjsonnew.support.scalajson.unsafe.Converter import sbt.protocol.Serialization import sbt.protocol.{ CompletionParams => CP, SettingQuery => Q } @@ -170,18 +169,6 @@ private[sbt] trait LanguageServerProtocol { self: NetworkChannel => } /** Respond back to Language Server's client. */ - private[sbt] def jsonRpcRespondError(execId: String, code: Long, message: String): Unit = - jsonRpcRespondErrorImpl(execId, code, message, None) - - /** Respond back to Language Server's client. */ - private[sbt] def jsonRpcRespondError[A: JsonFormat]( - execId: String, - code: Long, - message: String, - data: A, - ): Unit = - jsonRpcRespondErrorImpl(execId, code, message, Option(Converter.toJson[A](data).get)) - private[sbt] def jsonRpcRespondError( execId: String, err: JsonRpcResponseError @@ -191,18 +178,6 @@ private[sbt] trait LanguageServerProtocol { self: NetworkChannel => publishBytes(bytes) } - private[this] def jsonRpcRespondErrorImpl( - execId: String, - code: Long, - message: String, - data: Option[JValue], - ): Unit = { - val e = JsonRpcResponseError(code, message, data) - val m = JsonRpcResponseMessage("2.0", execId, None, Option(e)) - val bytes = Serialization.serializeResponseMessage(m) - publishBytes(bytes) - } - /** Notify to Language Server's client. */ private[sbt] def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit = { val m = diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index fc423e57e..8e4d093d9 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -12,22 +12,22 @@ package server import java.net.{ Socket, SocketTimeoutException } import java.util.concurrent.atomic.AtomicBoolean -import sjsonnew._ - -import scala.annotation.tailrec -import sbt.protocol._ import sbt.internal.langserver.{ CancelRequestParams, ErrorCodes } -import sbt.internal.util.{ ObjectEvent, StringEvent } -import sbt.internal.util.complete.Parser -import sbt.internal.util.codec.JValueFormats import sbt.internal.protocol.{ JsonRpcNotificationMessage, JsonRpcRequestMessage, JsonRpcResponseError } +import sbt.internal.util.codec.JValueFormats +import sbt.internal.util.complete.Parser +import sbt.internal.util.{ ObjectEvent, StringEvent } +import sbt.protocol._ import sbt.util.Logger -import sjsonnew.support.scalajson.unsafe.Converter +import sjsonnew._ +import sjsonnew.support.scalajson.unsafe.{ CompactPrinter, Converter } +import scala.annotation.tailrec +import scala.collection.mutable import scala.util.Try import scala.util.control.NonFatal @@ -56,8 +56,7 @@ final class NetworkChannel( private val VsCode = sbt.protocol.Serialization.VsCode private val VsCodeOld = "application/vscode-jsonrpc; charset=utf8" private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {} - - private var onGoingRequests: Set[JsonRpcRequestMessage] = Set.empty + private val onGoingRequests: mutable.Set[String] = mutable.Set() def setContentType(ct: String): Unit = synchronized { _contentType = ct } def contentType: String = _contentType @@ -197,7 +196,7 @@ final class NetworkChannel( } catch { case LangServerError(code, message) => log.debug(s"sending error: $code: $message") - jsonRpcRespondError(req.id, code, message) + respondError(code, message, Some(req.id)) } case Right(ntf: JsonRpcNotificationMessage) => try { @@ -258,18 +257,24 @@ final class NetworkChannel( } private def registerRequest(request: JsonRpcRequestMessage): Unit = { - onGoingRequests += request + this.synchronized { + onGoingRequests += request.id + () + } } private[sbt] def respondError( err: JsonRpcResponseError, execId: Option[String] - ): Unit = { + ): Unit = this.synchronized { + println(s"respond error for $execId") execId match { - case Some(id) => jsonRpcRespondError(id, err) - case None => logMessage("error", s"Error ${err.code}: ${err.message}") + case Some(id) if onGoingRequests.contains(id) => + onGoingRequests -= id + jsonRpcRespondError(id, err) + case _ => + logMessage("error", s"Error ${err.code}: ${err.message}") } - } private[sbt] def respondError( @@ -277,21 +282,20 @@ final class NetworkChannel( message: String, execId: Option[String] ): Unit = { - execId match { - case Some(id) => jsonRpcRespondError(id, code, message) - case None => logMessage("error", s"Error $code: $message") - } + respondError(JsonRpcResponseError(code, message), execId) } private[sbt] def respondEvent[A: JsonFormat]( event: A, execId: Option[String] - ): Unit = { + ): Unit = this.synchronized { + println(s"respond result for $execId") execId match { - case Some(id) => jsonRpcRespond(event, id) - case None => - val json = Converter.toJsonUnsafe(event) - log.debug(s"unmatched json response: $json") + case Some(id) if onGoingRequests.contains(id) => + onGoingRequests -= id + jsonRpcRespond(event, id) + case _ => + log.debug(s"unmatched json response: ${CompactPrinter(Converter.toJsonUnsafe(event))}") } }