Implementing cancellation requests for LSP server

This commit is contained in:
andrea 2018-09-26 10:41:59 +01:00
parent 144a146f07
commit d7c55a3d82
9 changed files with 185 additions and 2 deletions

View File

@ -14,6 +14,7 @@ import sbt.librarymanagement.{ Resolver, UpdateReport }
import scala.concurrent.duration.Duration
import java.io.File
import java.util.concurrent.atomic.AtomicReference
import Def.{ dummyState, ScopedKey, Setting }
import Keys.{
Streams,
@ -377,6 +378,8 @@ object EvaluateTask {
(dummyRoots, roots) :: (Def.dummyStreamsManager, streams) :: (dummyState, state) :: dummies
)
val currentlyRunningEngine: AtomicReference[(State, RunningTaskEngine)] = new AtomicReference()
def runTask[T](
root: Task[T],
state: State,
@ -432,11 +435,15 @@ object EvaluateTask {
shutdown()
}
}
currentlyRunningEngine.set((state, runningEngine))
// Register with our cancel handler we're about to start.
val strat = config.cancelStrategy
val cancelState = strat.onTaskEngineStart(runningEngine)
try run()
finally strat.onTaskEngineFinish(cancelState)
finally {
strat.onTaskEngineFinish(cancelState)
currentlyRunningEngine.set(null)
}
}
private[this] def storeValuesForPrevious(

View File

@ -13,12 +13,14 @@ import sjsonnew.JsonFormat
import sjsonnew.shaded.scalajson.ast.unsafe.JValue
import sjsonnew.support.scalajson.unsafe.Converter
import sbt.protocol.Serialization
import sbt.protocol.{ SettingQuery => Q }
import sbt.protocol.{ SettingQuery => Q, ExecStatusEvent }
import sbt.internal.protocol._
import sbt.internal.protocol.codec._
import sbt.internal.langserver._
import sbt.internal.util.ObjectEvent
import sbt.util.Logger
import scala.util.Try
import scala.util.control.NonFatal
private[sbt] final case class LangServerError(code: Long, message: String)
extends Throwable(message)
@ -80,6 +82,59 @@ private[sbt] object LanguageServerProtocol {
import sbt.protocol.codec.JsonProtocol._
val param = Converter.fromJson[Q](json(r)).get
onSettingQuery(Option(r.id), param)
case r: JsonRpcRequestMessage if r.method == "sbt/cancelRequest" =>
val param = Converter.fromJson[CancelRequestParams](json(r)).get
def errorRespond(msg: String) = jsonRpcRespondError(
Some(r.id),
ErrorCodes.RequestCancelled,
msg
)
try {
Option(EvaluateTask.currentlyRunningEngine.get) match {
case Some((state, runningEngine)) =>
val execId: String = state.currentCommand.map(_.execId).flatten.getOrElse("")
def checkId(): Boolean = {
if (execId.startsWith("\u2668")) {
(
Try { param.id.toLong }.toOption,
Try { execId.substring(1).toLong }.toOption
) match {
case (Some(id), Some(eid)) => id == eid
case _ => false
}
} else execId == param.id
}
// direct comparison on strings and
// remove hotspring unicode added character for numbers
if (checkId) {
runningEngine.cancelAndShutdown()
import sbt.protocol.codec.JsonProtocol._
jsonRpcRespond(
ExecStatusEvent(
"Task cancelled",
Some(name),
Some(execId.toString),
Vector(),
None,
),
Option(r.id)
)
} else {
errorRespond("Task ID not matched")
}
case None =>
errorRespond("No tasks under execution")
}
} catch {
case NonFatal(e) =>
errorRespond("Cancel request failed")
}
}
}, {
case n: JsonRpcNotificationMessage if n.method == "textDocument/didSave" =>

View File

@ -0,0 +1,33 @@
/**
* This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]].
*/
// DO NOT EDIT MANUALLY
package sbt.internal.langserver
/** Id for a cancel request */
final class CancelRequestParams private (
val id: String) extends Serializable {
override def equals(o: Any): Boolean = o match {
case x: CancelRequestParams => (this.id == x.id)
case _ => false
}
override def hashCode: Int = {
37 * (37 * (17 + "sbt.internal.langserver.CancelRequestParams".##) + id.##)
}
override def toString: String = {
"CancelRequestParams(" + id + ")"
}
private[this] def copy(id: String = id): CancelRequestParams = {
new CancelRequestParams(id)
}
def withId(id: String): CancelRequestParams = {
copy(id = id)
}
}
object CancelRequestParams {
def apply(id: String): CancelRequestParams = new CancelRequestParams(id)
}

View File

@ -0,0 +1,27 @@
/**
* This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]].
*/
// DO NOT EDIT MANUALLY
package sbt.internal.langserver.codec
import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError }
trait CancelRequestParamsFormats { self: sjsonnew.BasicJsonProtocol =>
implicit lazy val CancelRequestParamsFormat: JsonFormat[sbt.internal.langserver.CancelRequestParams] = new JsonFormat[sbt.internal.langserver.CancelRequestParams] {
override def read[J](jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.internal.langserver.CancelRequestParams = {
jsOpt match {
case Some(js) =>
unbuilder.beginObject(js)
val id = unbuilder.readField[String]("id")
unbuilder.endObject()
sbt.internal.langserver.CancelRequestParams(id)
case None =>
deserializationError("Expected JsObject but found None")
}
}
override def write[J](obj: sbt.internal.langserver.CancelRequestParams, builder: Builder[J]): Unit = {
builder.beginObject()
builder.addField("id", obj.id)
builder.endObject()
}
}
}

View File

@ -19,6 +19,7 @@ trait JsonProtocol extends sjsonnew.BasicJsonProtocol
with sbt.internal.langserver.codec.LogMessageParamsFormats
with sbt.internal.langserver.codec.PublishDiagnosticsParamsFormats
with sbt.internal.langserver.codec.SbtExecParamsFormats
with sbt.internal.langserver.codec.CancelRequestParamsFormats
with sbt.internal.langserver.codec.TextDocumentIdentifierFormats
with sbt.internal.langserver.codec.TextDocumentPositionParamsFormats
object JsonProtocol extends JsonProtocol

View File

@ -131,6 +131,11 @@ type SbtExecParams {
commandLine: String!
}
## Id for a cancel request
type CancelRequestParams {
id: String!
}
## Goto definition params model
type TextDocumentPositionParams {
## The text document.

View File

@ -0,0 +1,8 @@
object Main extends App {
while (true) {
Thread.sleep(1000)
}
}

View File

@ -1,2 +1,4 @@
commands += Command.command("hello") { state => ??? }
Global / cancelable := true

View File

@ -47,6 +47,51 @@ class ServerSpec extends AsyncFreeSpec with Matchers {
(s contains """"id":11""") && (s contains """"error":""")
})
}
"return error if cancelling non-matched task id" in withTestServer("events") { p =>
p.writeLine(
"""{ "jsonrpc": "2.0", "id":12, "method": "sbt/exec", "params": { "commandLine": "run" } }"""
)
p.writeLine(
"""{ "jsonrpc": "2.0", "id":13, "method": "sbt/cancelRequest", "params": { "id": "55" } }"""
)
assert(p.waitForString(20) { s =>
(s contains """"error":{"code":-32800""")
})
}
"cancel on-going task with numeric id" in withTestServer("events") { p =>
p.writeLine(
"""{ "jsonrpc": "2.0", "id":12, "method": "sbt/exec", "params": { "commandLine": "run" } }"""
)
Thread.sleep(1000)
p.writeLine(
"""{ "jsonrpc": "2.0", "id":13, "method": "sbt/cancelRequest", "params": { "id": "12" } }"""
)
assert(p.waitForString(30) { s =>
s contains """"result":{"status":"Task cancelled""""
})
}
"cancel on-going task with string id" in withTestServer("events") { p =>
p.writeLine(
"""{ "jsonrpc": "2.0", "id": "foo", "method": "sbt/exec", "params": { "commandLine": "run" } }"""
)
Thread.sleep(1000)
p.writeLine(
"""{ "jsonrpc": "2.0", "id": "bar", "method": "sbt/cancelRequest", "params": { "id": "foo" } }"""
)
assert(p.waitForString(30) { s =>
s contains """"result":{"status":"Task cancelled""""
})
}
}
}