diff --git a/main/src/main/scala/sbt/EvaluateTask.scala b/main/src/main/scala/sbt/EvaluateTask.scala index c155bdf5b..e75a8d724 100644 --- a/main/src/main/scala/sbt/EvaluateTask.scala +++ b/main/src/main/scala/sbt/EvaluateTask.scala @@ -14,6 +14,7 @@ import sbt.librarymanagement.{ Resolver, UpdateReport } import scala.concurrent.duration.Duration import java.io.File +import java.util.concurrent.atomic.AtomicReference import Def.{ dummyState, ScopedKey, Setting } import Keys.{ Streams, @@ -377,6 +378,8 @@ object EvaluateTask { (dummyRoots, roots) :: (Def.dummyStreamsManager, streams) :: (dummyState, state) :: dummies ) + val currentlyRunningEngine: AtomicReference[(State, RunningTaskEngine)] = new AtomicReference() + def runTask[T]( root: Task[T], state: State, @@ -432,11 +435,15 @@ object EvaluateTask { shutdown() } } + currentlyRunningEngine.set((state, runningEngine)) // Register with our cancel handler we're about to start. val strat = config.cancelStrategy val cancelState = strat.onTaskEngineStart(runningEngine) try run() - finally strat.onTaskEngineFinish(cancelState) + finally { + strat.onTaskEngineFinish(cancelState) + currentlyRunningEngine.set(null) + } } private[this] def storeValuesForPrevious( diff --git a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala index 6293af0e8..ce035e158 100644 --- a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala +++ b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala @@ -13,12 +13,14 @@ import sjsonnew.JsonFormat import sjsonnew.shaded.scalajson.ast.unsafe.JValue import sjsonnew.support.scalajson.unsafe.Converter import sbt.protocol.Serialization -import sbt.protocol.{ SettingQuery => Q } +import sbt.protocol.{ SettingQuery => Q, ExecStatusEvent } import sbt.internal.protocol._ import sbt.internal.protocol.codec._ import sbt.internal.langserver._ import sbt.internal.util.ObjectEvent import sbt.util.Logger +import scala.util.Try +import scala.util.control.NonFatal private[sbt] final case class LangServerError(code: Long, message: String) extends Throwable(message) @@ -80,6 +82,59 @@ private[sbt] object LanguageServerProtocol { import sbt.protocol.codec.JsonProtocol._ val param = Converter.fromJson[Q](json(r)).get onSettingQuery(Option(r.id), param) + case r: JsonRpcRequestMessage if r.method == "sbt/cancelRequest" => + val param = Converter.fromJson[CancelRequestParams](json(r)).get + + def errorRespond(msg: String) = jsonRpcRespondError( + Some(r.id), + ErrorCodes.RequestCancelled, + msg + ) + + try { + Option(EvaluateTask.currentlyRunningEngine.get) match { + case Some((state, runningEngine)) => + val execId: String = state.currentCommand.map(_.execId).flatten.getOrElse("") + + def checkId(): Boolean = { + if (execId.startsWith("\u2668")) { + ( + Try { param.id.toLong }.toOption, + Try { execId.substring(1).toLong }.toOption + ) match { + case (Some(id), Some(eid)) => id == eid + case _ => false + } + } else execId == param.id + } + + // direct comparison on strings and + // remove hotspring unicode added character for numbers + if (checkId) { + runningEngine.cancelAndShutdown() + + import sbt.protocol.codec.JsonProtocol._ + jsonRpcRespond( + ExecStatusEvent( + "Task cancelled", + Some(name), + Some(execId.toString), + Vector(), + None, + ), + Option(r.id) + ) + } else { + errorRespond("Task ID not matched") + } + + case None => + errorRespond("No tasks under execution") + } + } catch { + case NonFatal(e) => + errorRespond("Cancel request failed") + } } }, { case n: JsonRpcNotificationMessage if n.method == "textDocument/didSave" => diff --git a/protocol/src/main/contraband-scala/sbt/internal/langserver/CancelRequestParams.scala b/protocol/src/main/contraband-scala/sbt/internal/langserver/CancelRequestParams.scala new file mode 100644 index 000000000..56acbf231 --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/internal/langserver/CancelRequestParams.scala @@ -0,0 +1,33 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.internal.langserver +/** Id for a cancel request */ +final class CancelRequestParams private ( + val id: String) extends Serializable { + + + + override def equals(o: Any): Boolean = o match { + case x: CancelRequestParams => (this.id == x.id) + case _ => false + } + override def hashCode: Int = { + 37 * (37 * (17 + "sbt.internal.langserver.CancelRequestParams".##) + id.##) + } + override def toString: String = { + "CancelRequestParams(" + id + ")" + } + private[this] def copy(id: String = id): CancelRequestParams = { + new CancelRequestParams(id) + } + def withId(id: String): CancelRequestParams = { + copy(id = id) + } +} +object CancelRequestParams { + + def apply(id: String): CancelRequestParams = new CancelRequestParams(id) +} diff --git a/protocol/src/main/contraband-scala/sbt/internal/langserver/codec/CancelRequestParamsFormats.scala b/protocol/src/main/contraband-scala/sbt/internal/langserver/codec/CancelRequestParamsFormats.scala new file mode 100644 index 000000000..582fce9b9 --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/internal/langserver/codec/CancelRequestParamsFormats.scala @@ -0,0 +1,27 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.internal.langserver.codec +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait CancelRequestParamsFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val CancelRequestParamsFormat: JsonFormat[sbt.internal.langserver.CancelRequestParams] = new JsonFormat[sbt.internal.langserver.CancelRequestParams] { + override def read[J](jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.internal.langserver.CancelRequestParams = { + jsOpt match { + case Some(js) => + unbuilder.beginObject(js) + val id = unbuilder.readField[String]("id") + unbuilder.endObject() + sbt.internal.langserver.CancelRequestParams(id) + case None => + deserializationError("Expected JsObject but found None") + } + } + override def write[J](obj: sbt.internal.langserver.CancelRequestParams, builder: Builder[J]): Unit = { + builder.beginObject() + builder.addField("id", obj.id) + builder.endObject() + } +} +} diff --git a/protocol/src/main/contraband-scala/sbt/internal/langserver/codec/JsonProtocol.scala b/protocol/src/main/contraband-scala/sbt/internal/langserver/codec/JsonProtocol.scala index 7898d1d72..512b67fab 100644 --- a/protocol/src/main/contraband-scala/sbt/internal/langserver/codec/JsonProtocol.scala +++ b/protocol/src/main/contraband-scala/sbt/internal/langserver/codec/JsonProtocol.scala @@ -19,6 +19,7 @@ trait JsonProtocol extends sjsonnew.BasicJsonProtocol with sbt.internal.langserver.codec.LogMessageParamsFormats with sbt.internal.langserver.codec.PublishDiagnosticsParamsFormats with sbt.internal.langserver.codec.SbtExecParamsFormats + with sbt.internal.langserver.codec.CancelRequestParamsFormats with sbt.internal.langserver.codec.TextDocumentIdentifierFormats with sbt.internal.langserver.codec.TextDocumentPositionParamsFormats object JsonProtocol extends JsonProtocol \ No newline at end of file diff --git a/protocol/src/main/contraband/lsp.contra b/protocol/src/main/contraband/lsp.contra index a46d27d1d..314c2a163 100644 --- a/protocol/src/main/contraband/lsp.contra +++ b/protocol/src/main/contraband/lsp.contra @@ -131,6 +131,11 @@ type SbtExecParams { commandLine: String! } +## Id for a cancel request +type CancelRequestParams { + id: String! +} + ## Goto definition params model type TextDocumentPositionParams { ## The text document. diff --git a/sbt/src/server-test/events/Main.scala b/sbt/src/server-test/events/Main.scala new file mode 100644 index 000000000..dca34c07f --- /dev/null +++ b/sbt/src/server-test/events/Main.scala @@ -0,0 +1,8 @@ + +object Main extends App { + + while (true) { + Thread.sleep(1000) + } + +} diff --git a/sbt/src/server-test/events/build.sbt b/sbt/src/server-test/events/build.sbt index 5f9436df3..b31ca73aa 100644 --- a/sbt/src/server-test/events/build.sbt +++ b/sbt/src/server-test/events/build.sbt @@ -1,2 +1,4 @@ commands += Command.command("hello") { state => ??? } + +Global / cancelable := true diff --git a/sbt/src/test/scala/testpkg/ServerSpec.scala b/sbt/src/test/scala/testpkg/ServerSpec.scala index ab6c17264..e13532489 100644 --- a/sbt/src/test/scala/testpkg/ServerSpec.scala +++ b/sbt/src/test/scala/testpkg/ServerSpec.scala @@ -47,6 +47,51 @@ class ServerSpec extends AsyncFreeSpec with Matchers { (s contains """"id":11""") && (s contains """"error":""") }) } + + "return error if cancelling non-matched task id" in withTestServer("events") { p => + p.writeLine( + """{ "jsonrpc": "2.0", "id":12, "method": "sbt/exec", "params": { "commandLine": "run" } }""" + ) + p.writeLine( + """{ "jsonrpc": "2.0", "id":13, "method": "sbt/cancelRequest", "params": { "id": "55" } }""" + ) + + assert(p.waitForString(20) { s => + (s contains """"error":{"code":-32800""") + }) + } + + "cancel on-going task with numeric id" in withTestServer("events") { p => + p.writeLine( + """{ "jsonrpc": "2.0", "id":12, "method": "sbt/exec", "params": { "commandLine": "run" } }""" + ) + + Thread.sleep(1000) + + p.writeLine( + """{ "jsonrpc": "2.0", "id":13, "method": "sbt/cancelRequest", "params": { "id": "12" } }""" + ) + + assert(p.waitForString(30) { s => + s contains """"result":{"status":"Task cancelled"""" + }) + } + + "cancel on-going task with string id" in withTestServer("events") { p => + p.writeLine( + """{ "jsonrpc": "2.0", "id": "foo", "method": "sbt/exec", "params": { "commandLine": "run" } }""" + ) + + Thread.sleep(1000) + + p.writeLine( + """{ "jsonrpc": "2.0", "id": "bar", "method": "sbt/cancelRequest", "params": { "id": "foo" } }""" + ) + + assert(p.waitForString(30) { s => + s contains """"result":{"status":"Task cancelled"""" + }) + } } }