mirror of https://github.com/sbt/sbt.git
Add ctrl+c support thin client
When the user presses ctrl+c, we want to cancel any running tasks that were initiated by that client. This is a bit tricky because we may not be sure what is running if the client is in interactive mode. To work around this, we send a cancellation request with the special id __CancelAll. When the NetworkChannel receives this request, it cancels the active task if was initiated by the client that sent the cancellation request. The result it returns to the client indicates if there were any tasks to be cancelled. If there were and the client was in interactive mode, we do not exit. Otherwise we exit.
This commit is contained in:
parent
d0842711e4
commit
43e4fa85e3
|
|
@ -21,7 +21,7 @@ import java.util.concurrent.{ ConcurrentHashMap, LinkedBlockingQueue, TimeUnit }
|
|||
import sbt.internal.client.NetworkClient.Arguments
|
||||
import sbt.internal.langserver.{ LogMessageParams, MessageType, PublishDiagnosticsParams }
|
||||
import sbt.internal.protocol._
|
||||
import sbt.internal.util.{ ConsoleAppender, ConsoleOut, Terminal, Util }
|
||||
import sbt.internal.util.{ ConsoleAppender, ConsoleOut, Signals, Terminal, Util }
|
||||
import sbt.io.IO
|
||||
import sbt.io.syntax._
|
||||
import sbt.protocol._
|
||||
|
|
@ -36,7 +36,9 @@ import scala.concurrent.duration._
|
|||
import scala.util.control.NonFatal
|
||||
import scala.util.{ Failure, Properties, Success }
|
||||
import Serialization.{
|
||||
CancelAll,
|
||||
attach,
|
||||
cancelRequest,
|
||||
systemIn,
|
||||
systemOut,
|
||||
terminalCapabilities,
|
||||
|
|
@ -82,6 +84,7 @@ class NetworkClient(
|
|||
private val lock: AnyRef = new AnyRef {}
|
||||
private val running = new AtomicBoolean(true)
|
||||
private val pendingResults = new ConcurrentHashMap[String, (LinkedBlockingQueue[Integer], Long)]
|
||||
private val pendingCancellations = new ConcurrentHashMap[String, LinkedBlockingQueue[Boolean]]
|
||||
private val pendingCompletions = new ConcurrentHashMap[String, CompletionResponse => Unit]
|
||||
private val attached = new AtomicBoolean(false)
|
||||
private val attachUUID = new AtomicReference[String](null)
|
||||
|
|
@ -247,6 +250,10 @@ class NetworkClient(
|
|||
}
|
||||
q.offer(exitCode)
|
||||
}
|
||||
pendingCancellations.remove(msg.id) match {
|
||||
case null =>
|
||||
case q => q.offer(msg.toString.contains("Task cancelled"))
|
||||
}
|
||||
msg.id match {
|
||||
case execId =>
|
||||
if (attachUUID.get == msg.id) {
|
||||
|
|
@ -390,24 +397,52 @@ class NetworkClient(
|
|||
()
|
||||
}
|
||||
|
||||
def run(): Int = {
|
||||
interactiveThread.set(Thread.currentThread)
|
||||
val cleaned = arguments.commandArguments
|
||||
val userCommands = cleaned.takeWhile(_ != "exit")
|
||||
val interactive = cleaned.isEmpty
|
||||
val exit = cleaned.nonEmpty && userCommands.isEmpty
|
||||
attachUUID.set(sendJson(attach, s"""{"interactive": $interactive}"""))
|
||||
if (interactive) {
|
||||
try this.synchronized(this.wait)
|
||||
catch { case _: InterruptedException => }
|
||||
if (exitClean.get) 0 else 1
|
||||
} else if (exit) {
|
||||
0
|
||||
} else {
|
||||
batchMode.set(true)
|
||||
batchExecute(userCommands.toList)
|
||||
}
|
||||
private[this] val contHandler: () => Unit = () => {
|
||||
if (Terminal.console.getLastLine.nonEmpty)
|
||||
printStream.print(ConsoleAppender.DeleteLine + Terminal.console.getLastLine.get)
|
||||
}
|
||||
private[this] def withSignalHandler[R](handler: () => Unit, sig: String)(f: => R): R = {
|
||||
val registration = Signals.register(handler, sig)
|
||||
try f
|
||||
finally registration.remove()
|
||||
}
|
||||
private[this] val cancelled = new AtomicBoolean(false)
|
||||
|
||||
def run(): Int =
|
||||
withSignalHandler(contHandler, Signals.CONT) {
|
||||
interactiveThread.set(Thread.currentThread)
|
||||
val cleaned = arguments.commandArguments
|
||||
val userCommands = cleaned.takeWhile(_ != "exit")
|
||||
val interactive = cleaned.isEmpty
|
||||
val exit = cleaned.nonEmpty && userCommands.isEmpty
|
||||
attachUUID.set(sendJson(attach, s"""{"interactive": $interactive}"""))
|
||||
val handler: () => Unit = () => {
|
||||
def exitAbruptly() = {
|
||||
exitClean.set(false)
|
||||
close()
|
||||
}
|
||||
if (cancelled.compareAndSet(false, true)) {
|
||||
val cancelledTasks = {
|
||||
val queue = sendCancelAllCommand()
|
||||
Option(queue.poll(1, TimeUnit.SECONDS)).getOrElse(true)
|
||||
}
|
||||
if ((!interactive && pendingResults.isEmpty) || !cancelledTasks) exitAbruptly()
|
||||
else cancelled.set(false)
|
||||
} else exitAbruptly() // handles double ctrl+c to force a shutdown
|
||||
}
|
||||
withSignalHandler(handler, Signals.INT) {
|
||||
if (interactive) {
|
||||
try this.synchronized(this.wait)
|
||||
catch { case _: InterruptedException => }
|
||||
if (exitClean.get) 0 else 1
|
||||
} else if (exit) {
|
||||
0
|
||||
} else {
|
||||
batchMode.set(true)
|
||||
batchExecute(userCommands.toList)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def batchExecute(userCommands: List[String]): Int = {
|
||||
val cmd = userCommands mkString " "
|
||||
|
|
@ -440,6 +475,13 @@ class NetworkClient(
|
|||
queue
|
||||
}
|
||||
|
||||
def sendCancelAllCommand(): LinkedBlockingQueue[Boolean] = {
|
||||
val queue = new LinkedBlockingQueue[Boolean]
|
||||
val execId = sendJson(cancelRequest, s"""{"id":"$CancelAll"}""")
|
||||
pendingCancellations.put(execId, queue)
|
||||
queue
|
||||
}
|
||||
|
||||
def sendCommand(command: CommandMessage): Unit = {
|
||||
try {
|
||||
val s = Serialization.serializeCommandAsJsonMessage(command)
|
||||
|
|
@ -538,7 +580,9 @@ class NetworkClient(
|
|||
s"Total time: $totalString, completed $nowString"
|
||||
}
|
||||
}
|
||||
|
||||
object NetworkClient {
|
||||
private[sbt] val CancelAll = "__CancelAll"
|
||||
private def consoleAppenderInterface(printStream: PrintStream): ConsoleInterface = {
|
||||
val appender = ConsoleAppender("thin", ConsoleOut.printStreamOut(printStream))
|
||||
new ConsoleInterface {
|
||||
|
|
|
|||
|
|
@ -451,7 +451,8 @@ final class NetworkChannel(
|
|||
|
||||
// direct comparison on strings and
|
||||
// remove hotspring unicode added character for numbers
|
||||
if (checkId) {
|
||||
if (checkId || (crp.id == Serialization.CancelAll &&
|
||||
StandardMain.exchange.currentExec.exists(_.source.exists(_.channelName == name)))) {
|
||||
runningEngine.cancelAndShutdown()
|
||||
|
||||
import sbt.protocol.codec.JsonProtocol._
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ object Serialization {
|
|||
val attach = "sbt/attach"
|
||||
val attachResponse = "sbt/attachResponse"
|
||||
val cancelRequest = "sbt/cancelRequest"
|
||||
val CancelAll = "__CancelAll"
|
||||
|
||||
@deprecated("unused", since = "1.4.0")
|
||||
def serializeEvent[A: JsonFormat](event: A): Array[Byte] = {
|
||||
|
|
|
|||
Loading…
Reference in New Issue