Add ctrl+c support thin client

When the user presses ctrl+c, we want to cancel any running tasks that
were initiated by that client. This is a bit tricky because we may not
be sure what is running if the client is in interactive mode. To work
around this, we send a cancellation request with the special id
__CancelAll. When the NetworkChannel receives this request, it cancels
the active task if was initiated by the client that sent the
cancellation request. The result it returns to the client indicates if
there were any tasks to be cancelled. If there were and the client was
in interactive mode, we do not exit. Otherwise we exit.
This commit is contained in:
Ethan Atkins 2020-06-24 16:38:18 -07:00
parent d0842711e4
commit 43e4fa85e3
3 changed files with 65 additions and 19 deletions

View File

@ -21,7 +21,7 @@ import java.util.concurrent.{ ConcurrentHashMap, LinkedBlockingQueue, TimeUnit }
import sbt.internal.client.NetworkClient.Arguments import sbt.internal.client.NetworkClient.Arguments
import sbt.internal.langserver.{ LogMessageParams, MessageType, PublishDiagnosticsParams } import sbt.internal.langserver.{ LogMessageParams, MessageType, PublishDiagnosticsParams }
import sbt.internal.protocol._ import sbt.internal.protocol._
import sbt.internal.util.{ ConsoleAppender, ConsoleOut, Terminal, Util } import sbt.internal.util.{ ConsoleAppender, ConsoleOut, Signals, Terminal, Util }
import sbt.io.IO import sbt.io.IO
import sbt.io.syntax._ import sbt.io.syntax._
import sbt.protocol._ import sbt.protocol._
@ -36,7 +36,9 @@ import scala.concurrent.duration._
import scala.util.control.NonFatal import scala.util.control.NonFatal
import scala.util.{ Failure, Properties, Success } import scala.util.{ Failure, Properties, Success }
import Serialization.{ import Serialization.{
CancelAll,
attach, attach,
cancelRequest,
systemIn, systemIn,
systemOut, systemOut,
terminalCapabilities, terminalCapabilities,
@ -82,6 +84,7 @@ class NetworkClient(
private val lock: AnyRef = new AnyRef {} private val lock: AnyRef = new AnyRef {}
private val running = new AtomicBoolean(true) private val running = new AtomicBoolean(true)
private val pendingResults = new ConcurrentHashMap[String, (LinkedBlockingQueue[Integer], Long)] private val pendingResults = new ConcurrentHashMap[String, (LinkedBlockingQueue[Integer], Long)]
private val pendingCancellations = new ConcurrentHashMap[String, LinkedBlockingQueue[Boolean]]
private val pendingCompletions = new ConcurrentHashMap[String, CompletionResponse => Unit] private val pendingCompletions = new ConcurrentHashMap[String, CompletionResponse => Unit]
private val attached = new AtomicBoolean(false) private val attached = new AtomicBoolean(false)
private val attachUUID = new AtomicReference[String](null) private val attachUUID = new AtomicReference[String](null)
@ -247,6 +250,10 @@ class NetworkClient(
} }
q.offer(exitCode) q.offer(exitCode)
} }
pendingCancellations.remove(msg.id) match {
case null =>
case q => q.offer(msg.toString.contains("Task cancelled"))
}
msg.id match { msg.id match {
case execId => case execId =>
if (attachUUID.get == msg.id) { if (attachUUID.get == msg.id) {
@ -390,24 +397,52 @@ class NetworkClient(
() ()
} }
def run(): Int = { private[this] val contHandler: () => Unit = () => {
interactiveThread.set(Thread.currentThread) if (Terminal.console.getLastLine.nonEmpty)
val cleaned = arguments.commandArguments printStream.print(ConsoleAppender.DeleteLine + Terminal.console.getLastLine.get)
val userCommands = cleaned.takeWhile(_ != "exit")
val interactive = cleaned.isEmpty
val exit = cleaned.nonEmpty && userCommands.isEmpty
attachUUID.set(sendJson(attach, s"""{"interactive": $interactive}"""))
if (interactive) {
try this.synchronized(this.wait)
catch { case _: InterruptedException => }
if (exitClean.get) 0 else 1
} else if (exit) {
0
} else {
batchMode.set(true)
batchExecute(userCommands.toList)
}
} }
private[this] def withSignalHandler[R](handler: () => Unit, sig: String)(f: => R): R = {
val registration = Signals.register(handler, sig)
try f
finally registration.remove()
}
private[this] val cancelled = new AtomicBoolean(false)
def run(): Int =
withSignalHandler(contHandler, Signals.CONT) {
interactiveThread.set(Thread.currentThread)
val cleaned = arguments.commandArguments
val userCommands = cleaned.takeWhile(_ != "exit")
val interactive = cleaned.isEmpty
val exit = cleaned.nonEmpty && userCommands.isEmpty
attachUUID.set(sendJson(attach, s"""{"interactive": $interactive}"""))
val handler: () => Unit = () => {
def exitAbruptly() = {
exitClean.set(false)
close()
}
if (cancelled.compareAndSet(false, true)) {
val cancelledTasks = {
val queue = sendCancelAllCommand()
Option(queue.poll(1, TimeUnit.SECONDS)).getOrElse(true)
}
if ((!interactive && pendingResults.isEmpty) || !cancelledTasks) exitAbruptly()
else cancelled.set(false)
} else exitAbruptly() // handles double ctrl+c to force a shutdown
}
withSignalHandler(handler, Signals.INT) {
if (interactive) {
try this.synchronized(this.wait)
catch { case _: InterruptedException => }
if (exitClean.get) 0 else 1
} else if (exit) {
0
} else {
batchMode.set(true)
batchExecute(userCommands.toList)
}
}
}
def batchExecute(userCommands: List[String]): Int = { def batchExecute(userCommands: List[String]): Int = {
val cmd = userCommands mkString " " val cmd = userCommands mkString " "
@ -440,6 +475,13 @@ class NetworkClient(
queue queue
} }
def sendCancelAllCommand(): LinkedBlockingQueue[Boolean] = {
val queue = new LinkedBlockingQueue[Boolean]
val execId = sendJson(cancelRequest, s"""{"id":"$CancelAll"}""")
pendingCancellations.put(execId, queue)
queue
}
def sendCommand(command: CommandMessage): Unit = { def sendCommand(command: CommandMessage): Unit = {
try { try {
val s = Serialization.serializeCommandAsJsonMessage(command) val s = Serialization.serializeCommandAsJsonMessage(command)
@ -538,7 +580,9 @@ class NetworkClient(
s"Total time: $totalString, completed $nowString" s"Total time: $totalString, completed $nowString"
} }
} }
object NetworkClient { object NetworkClient {
private[sbt] val CancelAll = "__CancelAll"
private def consoleAppenderInterface(printStream: PrintStream): ConsoleInterface = { private def consoleAppenderInterface(printStream: PrintStream): ConsoleInterface = {
val appender = ConsoleAppender("thin", ConsoleOut.printStreamOut(printStream)) val appender = ConsoleAppender("thin", ConsoleOut.printStreamOut(printStream))
new ConsoleInterface { new ConsoleInterface {

View File

@ -451,7 +451,8 @@ final class NetworkChannel(
// direct comparison on strings and // direct comparison on strings and
// remove hotspring unicode added character for numbers // remove hotspring unicode added character for numbers
if (checkId) { if (checkId || (crp.id == Serialization.CancelAll &&
StandardMain.exchange.currentExec.exists(_.source.exists(_.channelName == name)))) {
runningEngine.cancelAndShutdown() runningEngine.cancelAndShutdown()
import sbt.protocol.codec.JsonProtocol._ import sbt.protocol.codec.JsonProtocol._

View File

@ -33,6 +33,7 @@ object Serialization {
val attach = "sbt/attach" val attach = "sbt/attach"
val attachResponse = "sbt/attachResponse" val attachResponse = "sbt/attachResponse"
val cancelRequest = "sbt/cancelRequest" val cancelRequest = "sbt/cancelRequest"
val CancelAll = "__CancelAll"
@deprecated("unused", since = "1.4.0") @deprecated("unused", since = "1.4.0")
def serializeEvent[A: JsonFormat](event: A): Array[Byte] = { def serializeEvent[A: JsonFormat](event: A): Array[Byte] = {