Add ctrl+c support thin client

When the user presses ctrl+c, we want to cancel any running tasks that
were initiated by that client. This is a bit tricky because we may not
be sure what is running if the client is in interactive mode. To work
around this, we send a cancellation request with the special id
__CancelAll. When the NetworkChannel receives this request, it cancels
the active task if was initiated by the client that sent the
cancellation request. The result it returns to the client indicates if
there were any tasks to be cancelled. If there were and the client was
in interactive mode, we do not exit. Otherwise we exit.
This commit is contained in:
Ethan Atkins 2020-06-24 16:38:18 -07:00
parent d0842711e4
commit 43e4fa85e3
3 changed files with 65 additions and 19 deletions

View File

@ -21,7 +21,7 @@ import java.util.concurrent.{ ConcurrentHashMap, LinkedBlockingQueue, TimeUnit }
import sbt.internal.client.NetworkClient.Arguments
import sbt.internal.langserver.{ LogMessageParams, MessageType, PublishDiagnosticsParams }
import sbt.internal.protocol._
import sbt.internal.util.{ ConsoleAppender, ConsoleOut, Terminal, Util }
import sbt.internal.util.{ ConsoleAppender, ConsoleOut, Signals, Terminal, Util }
import sbt.io.IO
import sbt.io.syntax._
import sbt.protocol._
@ -36,7 +36,9 @@ import scala.concurrent.duration._
import scala.util.control.NonFatal
import scala.util.{ Failure, Properties, Success }
import Serialization.{
CancelAll,
attach,
cancelRequest,
systemIn,
systemOut,
terminalCapabilities,
@ -82,6 +84,7 @@ class NetworkClient(
private val lock: AnyRef = new AnyRef {}
private val running = new AtomicBoolean(true)
private val pendingResults = new ConcurrentHashMap[String, (LinkedBlockingQueue[Integer], Long)]
private val pendingCancellations = new ConcurrentHashMap[String, LinkedBlockingQueue[Boolean]]
private val pendingCompletions = new ConcurrentHashMap[String, CompletionResponse => Unit]
private val attached = new AtomicBoolean(false)
private val attachUUID = new AtomicReference[String](null)
@ -247,6 +250,10 @@ class NetworkClient(
}
q.offer(exitCode)
}
pendingCancellations.remove(msg.id) match {
case null =>
case q => q.offer(msg.toString.contains("Task cancelled"))
}
msg.id match {
case execId =>
if (attachUUID.get == msg.id) {
@ -390,24 +397,52 @@ class NetworkClient(
()
}
def run(): Int = {
interactiveThread.set(Thread.currentThread)
val cleaned = arguments.commandArguments
val userCommands = cleaned.takeWhile(_ != "exit")
val interactive = cleaned.isEmpty
val exit = cleaned.nonEmpty && userCommands.isEmpty
attachUUID.set(sendJson(attach, s"""{"interactive": $interactive}"""))
if (interactive) {
try this.synchronized(this.wait)
catch { case _: InterruptedException => }
if (exitClean.get) 0 else 1
} else if (exit) {
0
} else {
batchMode.set(true)
batchExecute(userCommands.toList)
}
private[this] val contHandler: () => Unit = () => {
if (Terminal.console.getLastLine.nonEmpty)
printStream.print(ConsoleAppender.DeleteLine + Terminal.console.getLastLine.get)
}
private[this] def withSignalHandler[R](handler: () => Unit, sig: String)(f: => R): R = {
val registration = Signals.register(handler, sig)
try f
finally registration.remove()
}
private[this] val cancelled = new AtomicBoolean(false)
def run(): Int =
withSignalHandler(contHandler, Signals.CONT) {
interactiveThread.set(Thread.currentThread)
val cleaned = arguments.commandArguments
val userCommands = cleaned.takeWhile(_ != "exit")
val interactive = cleaned.isEmpty
val exit = cleaned.nonEmpty && userCommands.isEmpty
attachUUID.set(sendJson(attach, s"""{"interactive": $interactive}"""))
val handler: () => Unit = () => {
def exitAbruptly() = {
exitClean.set(false)
close()
}
if (cancelled.compareAndSet(false, true)) {
val cancelledTasks = {
val queue = sendCancelAllCommand()
Option(queue.poll(1, TimeUnit.SECONDS)).getOrElse(true)
}
if ((!interactive && pendingResults.isEmpty) || !cancelledTasks) exitAbruptly()
else cancelled.set(false)
} else exitAbruptly() // handles double ctrl+c to force a shutdown
}
withSignalHandler(handler, Signals.INT) {
if (interactive) {
try this.synchronized(this.wait)
catch { case _: InterruptedException => }
if (exitClean.get) 0 else 1
} else if (exit) {
0
} else {
batchMode.set(true)
batchExecute(userCommands.toList)
}
}
}
def batchExecute(userCommands: List[String]): Int = {
val cmd = userCommands mkString " "
@ -440,6 +475,13 @@ class NetworkClient(
queue
}
def sendCancelAllCommand(): LinkedBlockingQueue[Boolean] = {
val queue = new LinkedBlockingQueue[Boolean]
val execId = sendJson(cancelRequest, s"""{"id":"$CancelAll"}""")
pendingCancellations.put(execId, queue)
queue
}
def sendCommand(command: CommandMessage): Unit = {
try {
val s = Serialization.serializeCommandAsJsonMessage(command)
@ -538,7 +580,9 @@ class NetworkClient(
s"Total time: $totalString, completed $nowString"
}
}
object NetworkClient {
private[sbt] val CancelAll = "__CancelAll"
private def consoleAppenderInterface(printStream: PrintStream): ConsoleInterface = {
val appender = ConsoleAppender("thin", ConsoleOut.printStreamOut(printStream))
new ConsoleInterface {

View File

@ -451,7 +451,8 @@ final class NetworkChannel(
// direct comparison on strings and
// remove hotspring unicode added character for numbers
if (checkId) {
if (checkId || (crp.id == Serialization.CancelAll &&
StandardMain.exchange.currentExec.exists(_.source.exists(_.channelName == name)))) {
runningEngine.cancelAndShutdown()
import sbt.protocol.codec.JsonProtocol._

View File

@ -33,6 +33,7 @@ object Serialization {
val attach = "sbt/attach"
val attachResponse = "sbt/attachResponse"
val cancelRequest = "sbt/cancelRequest"
val CancelAll = "__CancelAll"
@deprecated("unused", since = "1.4.0")
def serializeEvent[A: JsonFormat](event: A): Array[Byte] = {