From 02b19752eb0ba95a4945bb9cfb69acd848e62ff2 Mon Sep 17 00:00:00 2001 From: andrea Date: Thu, 11 Oct 2018 14:03:41 +0100 Subject: [PATCH] refactoring of server cancellation request --- .../sbt/internal/server/ServerHandler.scala | 2 + main/src/main/scala/sbt/EvaluateTask.scala | 5 +- .../server/LanguageServerProtocol.scala | 63 +++---------------- .../sbt/internal/server/NetworkChannel.scala | 61 +++++++++++++++++- 4 files changed, 73 insertions(+), 58 deletions(-) diff --git a/main-command/src/main/scala/sbt/internal/server/ServerHandler.scala b/main-command/src/main/scala/sbt/internal/server/ServerHandler.scala index befc7fabf..2edf6c357 100644 --- a/main-command/src/main/scala/sbt/internal/server/ServerHandler.scala +++ b/main-command/src/main/scala/sbt/internal/server/ServerHandler.scala @@ -13,6 +13,7 @@ import sjsonnew.JsonFormat import sbt.internal.protocol._ import sbt.util.Logger import sbt.protocol.{ SettingQuery => Q, CompletionParams => CP } +import sbt.internal.langserver.{ CancelRequestParams => CRP } /** * ServerHandler allows plugins to extend sbt server. @@ -71,4 +72,5 @@ trait ServerCallback { private[sbt] def setInitialized(value: Boolean): Unit private[sbt] def onSettingQuery(execId: Option[String], req: Q): Unit private[sbt] def onCompletionRequest(execId: Option[String], cp: CP): Unit + private[sbt] def onCancellationRequest(execId: Option[String], crp: CRP): Unit } diff --git a/main/src/main/scala/sbt/EvaluateTask.scala b/main/src/main/scala/sbt/EvaluateTask.scala index d0a630e1c..1f861067a 100644 --- a/main/src/main/scala/sbt/EvaluateTask.scala +++ b/main/src/main/scala/sbt/EvaluateTask.scala @@ -384,7 +384,8 @@ object EvaluateTask { ) val lastEvaluatedState: AtomicReference[SafeState] = new AtomicReference() - val currentlyRunningEngine: AtomicReference[(State, RunningTaskEngine)] = new AtomicReference() + val currentlyRunningEngine: AtomicReference[(SafeState, RunningTaskEngine)] = + new AtomicReference() /** * The main method for the task engine. @@ -445,7 +446,7 @@ object EvaluateTask { shutdown() } } - currentlyRunningEngine.set((state, runningEngine)) + currentlyRunningEngine.set((SafeState(state), runningEngine)) // Register with our cancel handler we're about to start. val strat = config.cancelStrategy val cancelState = strat.onTaskEngineStart(runningEngine) diff --git a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala index 80948027e..a23435173 100644 --- a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala +++ b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala @@ -13,14 +13,13 @@ import sjsonnew.JsonFormat import sjsonnew.shaded.scalajson.ast.unsafe.JValue import sjsonnew.support.scalajson.unsafe.Converter import sbt.protocol.Serialization -import sbt.protocol.{ SettingQuery => Q, ExecStatusEvent, CompletionParams => CP } +import sbt.protocol.{ SettingQuery => Q, CompletionParams => CP } +import sbt.internal.langserver.{ CancelRequestParams => CRP } import sbt.internal.protocol._ import sbt.internal.protocol.codec._ import sbt.internal.langserver._ import sbt.internal.util.ObjectEvent import sbt.util.Logger -import scala.util.Try -import scala.util.control.NonFatal private[sbt] final case class LangServerError(code: Long, message: String) extends Throwable(message) @@ -83,58 +82,9 @@ private[sbt] object LanguageServerProtocol { val param = Converter.fromJson[Q](json(r)).get onSettingQuery(Option(r.id), param) case r: JsonRpcRequestMessage if r.method == "sbt/cancelRequest" => - val param = Converter.fromJson[CancelRequestParams](json(r)).get - - def errorRespond(msg: String) = jsonRpcRespondError( - Some(r.id), - ErrorCodes.RequestCancelled, - msg - ) - - try { - Option(EvaluateTask.currentlyRunningEngine.get) match { - case Some((state, runningEngine)) => - val execId: String = state.currentCommand.map(_.execId).flatten.getOrElse("") - - def checkId(): Boolean = { - if (execId.startsWith("\u2668")) { - ( - Try { param.id.toLong }.toOption, - Try { execId.substring(1).toLong }.toOption - ) match { - case (Some(id), Some(eid)) => id == eid - case _ => false - } - } else execId == param.id - } - - // direct comparison on strings and - // remove hotspring unicode added character for numbers - if (checkId) { - runningEngine.cancelAndShutdown() - - import sbt.protocol.codec.JsonProtocol._ - jsonRpcRespond( - ExecStatusEvent( - "Task cancelled", - Some(name), - Some(execId.toString), - Vector(), - None, - ), - Option(r.id) - ) - } else { - errorRespond("Task ID not matched") - } - - case None => - errorRespond("No tasks under execution") - } - } catch { - case NonFatal(e) => - errorRespond("Cancel request failed") - } + import sbt.protocol.codec.JsonProtocol._ + val param = Converter.fromJson[CRP](json(r)).get + onCancellationRequest(Option(r.id), param) case r: JsonRpcRequestMessage if r.method == "sbt/completion" => import sbt.protocol.codec.JsonProtocol._ val param = Converter.fromJson[CP](json(r)).get @@ -160,6 +110,7 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self => protected def log: Logger protected def onSettingQuery(execId: Option[String], req: Q): Unit protected def onCompletionRequest(execId: Option[String], cp: CP): Unit + protected def onCancellationRequest(execId: Option[String], crp: CRP): Unit protected lazy val callbackImpl: ServerCallback = new ServerCallback { def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit = @@ -181,6 +132,8 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self => self.onSettingQuery(execId, req) private[sbt] def onCompletionRequest(execId: Option[String], cp: CP): Unit = self.onCompletionRequest(execId, cp) + private[sbt] def onCancellationRequest(execId: Option[String], crp: CancelRequestParams): Unit = + self.onCancellationRequest(execId, crp) } /** diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index 2cd81c050..34aac577d 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -15,12 +15,13 @@ import java.util.concurrent.atomic.AtomicBoolean import sjsonnew._ import scala.annotation.tailrec import sbt.protocol._ -import sbt.internal.langserver.ErrorCodes +import sbt.internal.langserver.{ ErrorCodes, CancelRequestParams } import sbt.internal.util.{ ObjectEvent, StringEvent } import sbt.internal.util.complete.Parser import sbt.internal.util.codec.JValueFormats import sbt.internal.protocol.{ JsonRpcRequestMessage, JsonRpcNotificationMessage } import sbt.util.Logger +import scala.util.Try import scala.util.control.NonFatal final class NetworkChannel( @@ -408,6 +409,64 @@ final class NetworkChannel( } } + protected def onCancellationRequest(execId: Option[String], crp: CancelRequestParams) = { + if (initialized) { + + def errorRespond(msg: String) = jsonRpcRespondError( + execId, + ErrorCodes.RequestCancelled, + msg + ) + + try { + Option(EvaluateTask.currentlyRunningEngine.get) match { + case Some((state, runningEngine)) => + val runningExecId = state.currentExecId.getOrElse("") + + def checkId(): Boolean = { + if (runningExecId.startsWith("\u2668")) { + ( + Try { crp.id.toLong }.toOption, + Try { runningExecId.substring(1).toLong }.toOption + ) match { + case (Some(id), Some(eid)) => id == eid + case _ => false + } + } else runningExecId == crp.id + } + + // direct comparison on strings and + // remove hotspring unicode added character for numbers + if (checkId) { + runningEngine.cancelAndShutdown() + + import sbt.protocol.codec.JsonProtocol._ + jsonRpcRespond( + ExecStatusEvent( + "Task cancelled", + Some(name), + Some(runningExecId.toString), + Vector(), + None, + ), + execId + ) + } else { + errorRespond("Task ID not matched") + } + + case None => + errorRespond("No tasks under execution") + } + } catch { + case NonFatal(e) => + errorRespond("Cancel request failed") + } + } else { + log.warn(s"ignoring cancellation request $crp before initialization") + } + } + def shutdown(): Unit = { log.info("Shutting down client connection") running.set(false)