From 5fd3c1d2e5ba8aa9d4c669b1fb2c6c1755611b98 Mon Sep 17 00:00:00 2001 From: Mark Harrah Date: Tue, 18 Oct 2011 22:43:25 -0400 Subject: [PATCH] task execution interruptible using ctrl+c. fixes #228,#229 - interrupts task execution only - no further tasks scheduled - existing tasks interrupted - a task must terminate any other started threads when interrupted - set cancelable to true to enable - currently, 'run' properly terminates if the application properly terminates when interrupted - 'console' does not, 'test' depends on the test framework - also bundled: set connectInput to true to connect standard input to forked run --- main/Aggregation.scala | 15 ++++---- main/Defaults.scala | 7 ++-- main/EvaluateTask.scala | 62 ++++++++++++++++++++++++++-------- main/GlobalPlugin.scala | 3 +- main/Keys.scala | 2 ++ main/Project.scala | 6 ++-- run/Fork.scala | 17 ++++++---- run/Run.scala | 8 ++++- run/TrapExit.scala | 11 +++++- util/collection/Signal.scala | 43 +++++++++++++++++++++++ util/process/Process.scala | 2 ++ util/process/ProcessImpl.scala | 8 ++--- 12 files changed, 145 insertions(+), 39 deletions(-) create mode 100644 util/collection/Signal.scala diff --git a/main/Aggregation.scala b/main/Aggregation.scala index 21268684b..7cc1aeacf 100644 --- a/main/Aggregation.scala +++ b/main/Aggregation.scala @@ -76,11 +76,16 @@ final object Aggregation { import EvaluateTask._ import std.TaskExtra._ + val extracted = Project extract s val toRun = ts map { case KeyValue(k,t) => t.map(v => KeyValue(k,v)) } join; - val workers = maxWorkers(extracted, structure) + val config = extractedConfig(extracted, structure) + val start = System.currentTimeMillis - val (newS, result) = withStreams(structure){ str => runTask(toRun, s,str, structure.index.triggers, maxWorkers = workers)(nodeView(s, str, extra.tasks, extra.values)) } + val (newS, result) = withStreams(structure){ str => + val transform = nodeView(s, str, extra.tasks, extra.values) + runTask(toRun, s,str, structure.index.triggers, config)(transform) + } val stop = System.currentTimeMillis val log = newS.log @@ -90,11 +95,7 @@ final object Aggregation newS } - def maxWorkers(extracted: Extracted, structure: Load.BuildStructure): Int = - (Keys.parallelExecution in extracted.currentRef get structure.data) match { - case Some(true) | None => EvaluateTask.SystemProcessors - case Some(false) => 1 - } + def printSuccess(start: Long, stop: Long, extracted: Extracted, success: Boolean, log: Logger) { import extracted._ diff --git a/main/Defaults.scala b/main/Defaults.scala index 3d9d00e8e..a2c925b30 100644 --- a/main/Defaults.scala +++ b/main/Defaults.scala @@ -55,6 +55,8 @@ object Defaults extends BuildCommon sbtResolver in GlobalScope <<= sbtVersion { sbtV => if(sbtV endsWith "-SNAPSHOT") Classpaths.typesafeSnapshots else Classpaths.typesafeResolver }, pollInterval :== 500, logBuffered :== false, + connectInput :== false, + cancelable :== false, autoScalaLibrary :== true, onLoad <<= onLoad ?? idFun[State], onUnload <<= onUnload ?? idFun[State], @@ -426,9 +428,10 @@ object Defaults extends BuildCommon def runnerTask = runner <<= runnerInit def runnerInit: Initialize[Task[ScalaRun]] = - (taskTemporaryDirectory, scalaInstance, baseDirectory, javaOptions, outputStrategy, fork, javaHome, trapExit) map { (tmp, si, base, options, strategy, forkRun, javaHomeDir, trap) => + (taskTemporaryDirectory, scalaInstance, baseDirectory, javaOptions, outputStrategy, fork, javaHome, trapExit, connectInput) map { + (tmp, si, base, options, strategy, forkRun, javaHomeDir, trap, connectIn) => if(forkRun) { - new ForkRun( ForkOptions(scalaJars = si.jars, javaHome = javaHomeDir, outputStrategy = strategy, + new ForkRun( ForkOptions(scalaJars = si.jars, javaHome = javaHomeDir, connectInput = connectIn, outputStrategy = strategy, runJVMOptions = options, workingDirectory = Some(base)) ) } else new Run(si, trap, tmp) diff --git a/main/EvaluateTask.scala b/main/EvaluateTask.scala index b909598cb..d97314369 100644 --- a/main/EvaluateTask.scala +++ b/main/EvaluateTask.scala @@ -8,8 +8,10 @@ package sbt import Keys.{globalLogging, streams, Streams, TaskStreams} import Keys.{dummyState, dummyStreamsManager, streamsManager, taskDefinitionKey, transformState} import Scope.{GlobalScope, ThisScope} + import Types.const import scala.Console.{RED, RESET} +final case class EvaluateConfig(cancelable: Boolean, checkCycles: Boolean = false, maxWorkers: Int = EvaluateTask.SystemProcessors) object EvaluateTask { import Load.BuildStructure @@ -19,6 +21,23 @@ object EvaluateTask import Keys.state val SystemProcessors = Runtime.getRuntime.availableProcessors + def defaultConfig = EvaluateConfig(false) + def extractedConfig(extracted: Extracted, structure: BuildStructure): EvaluateConfig = + { + val workers = maxWorkers(extracted, structure) + val canCancel = cancelable(extracted, structure) + EvaluateConfig(cancelable = canCancel, maxWorkers = workers) + } + + def maxWorkers(extracted: Extracted, structure: Load.BuildStructure): Int = + if(getBoolean(Keys.parallelExecution, true, extracted, structure)) + EvaluateTask.SystemProcessors + else + 1 + def cancelable(extracted: Extracted, structure: Load.BuildStructure): Boolean = + getBoolean(Keys.cancelable, false, extracted, structure) + def getBoolean(key: SettingKey[Boolean], default: Boolean, extracted: Extracted, structure: Load.BuildStructure): Boolean = + (key in extracted.currentRef get structure.data) getOrElse default def injectSettings: Seq[Setting[_]] = Seq( (state in GlobalScope) ::= dummyState, @@ -29,19 +48,19 @@ object EvaluateTask { val root = ProjectRef(pluginDef.root, Load.getRootProject(pluginDef.units)(pluginDef.root)) val pluginKey = Keys.fullClasspath in Configurations.Runtime - val evaluated = apply(pluginDef, ScopedKey(pluginKey.scope, pluginKey.key), state, root) + val evaluated = apply(pluginDef, ScopedKey(pluginKey.scope, pluginKey.key), state, root, defaultConfig) val (newS, result) = evaluated getOrElse error("Plugin classpath does not exist for plugin definition at " + pluginDef.root) - Project.runUnloadHooks(newS) // discard state + Project.runUnloadHooks(newS) // discard states processResult(result, log) } @deprecated("This method does not apply state changes requested during task execution. Use 'apply' instead, which does.", "0.11.1") def evaluateTask[T](structure: BuildStructure, taskKey: ScopedKey[Task[T]], state: State, ref: ProjectRef, checkCycles: Boolean = false, maxWorkers: Int = SystemProcessors): Option[Result[T]] = - apply(structure, taskKey, state, ref, checkCycles, maxWorkers).map(_._2) - def apply[T](structure: BuildStructure, taskKey: ScopedKey[Task[T]], state: State, ref: ProjectRef, checkCycles: Boolean = false, maxWorkers: Int = SystemProcessors): Option[(State, Result[T])] = + apply(structure, taskKey, state, ref, EvaluateConfig(false, checkCycles, maxWorkers)).map(_._2) + def apply[T](structure: BuildStructure, taskKey: ScopedKey[Task[T]], state: State, ref: ProjectRef, config: EvaluateConfig = defaultConfig): Option[(State, Result[T])] = withStreams(structure) { str => for( (task, toNode) <- getTask(structure, taskKey, state, str, ref) ) yield - runTask(task, state, str, structure.index.triggers, checkCycles, maxWorkers)(toNode) + runTask(task, state, str, structure.index.triggers, config)(toNode) } def logIncResult(result: Result[_], streams: Streams) = result match { case Inc(i) => logIncomplete(i, streams); case _ => () } def logIncomplete(result: Incomplete, streams: Streams) @@ -78,18 +97,31 @@ object EvaluateTask def nodeView[HL <: HList](state: State, streams: Streams, extraDummies: KList[Task, HL] = KNil, extraValues: HL = HNil): Execute.NodeView[Task] = Transform(dummyStreamsManager :^: KCons(dummyState, extraDummies), streams :+: HCons(state, extraValues)) - def runTask[T](root: Task[T], state: State, streams: Streams, triggers: Triggers[Task], checkCycles: Boolean = false, maxWorkers: Int = SystemProcessors)(implicit taskToNode: Execute.NodeView[Task]): (State, Result[T]) = + def runTask[T](root: Task[T], state: State, streams: Streams, triggers: Triggers[Task], config: EvaluateConfig = defaultConfig)(implicit taskToNode: Execute.NodeView[Task]): (State, Result[T]) = { - val (service, shutdown) = CompletionService[Task[_], Completed](maxWorkers) + val log = state.log + log.debug("Running task... Cancelable: " + config.cancelable + ", max worker threads: " + config.maxWorkers + ", check cycles: " + config.checkCycles) + val (service, shutdown) = CompletionService[Task[_], Completed](config.maxWorkers) - val x = new Execute[Task](checkCycles, triggers)(taskToNode) - val (newState, result) = - try applyResults(x.runKeep(root)(service), state, root) - catch { case inc: Incomplete => (state, Inc(inc)) } - finally shutdown() - val replaced = transformInc(result) - logIncResult(replaced, streams) - (newState, replaced) + def run() = { + val x = new Execute[Task](config.checkCycles, triggers)(taskToNode) + val (newState, result) = + try applyResults(x.runKeep(root)(service), state, root) + catch { case inc: Incomplete => (state, Inc(inc)) } + finally shutdown() + val replaced = transformInc(result) + logIncResult(replaced, streams) + (newState, replaced) + } + val cancel = () => { + println("") + log.warn("Canceling execution...") + shutdown() + } + if(config.cancelable) + Signals.withHandler(cancel) { run } + else + run() } def applyResults[T](results: RMap[Task, Result], state: State, root: Task[T]): (State, Result[T]) = diff --git a/main/GlobalPlugin.scala b/main/GlobalPlugin.scala index a317c73ef..ae107c852 100644 --- a/main/GlobalPlugin.scala +++ b/main/GlobalPlugin.scala @@ -53,7 +53,8 @@ object GlobalPlugin import EvaluateTask._ withStreams(structure) { str => val nv = nodeView(state, str) - val (newS, result) = runTask(t, state, str, structure.index.triggers)(nv) + val config = EvaluateTask.defaultConfig + val (newS, result) = runTask(t, state, str, structure.index.triggers, config)(nv) (newS, processResult(result, newS.log)) } } diff --git a/main/Keys.scala b/main/Keys.scala index f04c55782..f5f1a1deb 100644 --- a/main/Keys.scala +++ b/main/Keys.scala @@ -175,6 +175,7 @@ object Keys val fork = SettingKey[Boolean]("fork", "If true, forks a new JVM when running. If false, runs in the same JVM as the build.") val outputStrategy = SettingKey[Option[sbt.OutputStrategy]]("output-strategy", "Selects how to log output when running a main class.") + val connectInput = SettingKey[Boolean]("connect-input", "If true, connects standard input when running a main class forked.") val javaHome = SettingKey[Option[File]]("java-home", "Selects the Java installation used for compiling and forking. If None, uses the Java installation running the build.") val javaOptions = SettingKey[Seq[String]]("java-options", "Options passed to a new JVM when forking.") @@ -297,6 +298,7 @@ object Keys // special val sessionVars = AttributeKey[SessionVar.Map]("session-vars", "Bindings that exist for the duration of the session.") val parallelExecution = SettingKey[Boolean]("parallel-execution", "Enables (true) or disables (false) parallel execution of tasks.") + val cancelable = SettingKey[Boolean]("cancelable", "Enables (true) or disables (false) the ability to interrupt task execution with CTRL+C.") val settings = TaskKey[Settings[Scope]]("settings", "Provides access to the project data for the build.") val streams = TaskKey[TaskStreams]("streams", "Provides streams for logging and persisting data.") val isDummyTask = AttributeKey[Boolean]("is-dummy-task", "Internal: used to identify dummy tasks. sbt injects values for these tasks at the start of task execution.") diff --git a/main/Project.scala b/main/Project.scala index 7724dd473..694b81958 100644 --- a/main/Project.scala +++ b/main/Project.scala @@ -76,7 +76,8 @@ final case class Extracted(structure: BuildStructure, session: SessionSettings, { import EvaluateTask._ val rkey = resolve(key.scopedKey) - val value: Option[(State, Result[T])] = apply(structure, key.task.scopedKey, state, currentRef) + val config = extractedConfig(this, structure) + val value: Option[(State, Result[T])] = apply(structure, key.task.scopedKey, state, currentRef, config) val (newS, result) = getOrError(rkey.scope, rkey.key, value) (newS, processResult(result, newS.log)) } @@ -342,7 +343,8 @@ object Project extends Init[Scope] with ProjectExtra def runTask[T](taskKey: ScopedKey[Task[T]], state: State, checkCycles: Boolean = false, maxWorkers: Int = EvaluateTask.SystemProcessors): Option[(State, Result[T])] = { val extracted = Project.extract(state) - EvaluateTask(extracted.structure, taskKey, state, extracted.currentRef, checkCycles, maxWorkers) + val config = EvaluateConfig(true, checkCycles, maxWorkers) + EvaluateTask(extracted.structure, taskKey, state, extracted.currentRef, config) } // this is here instead of Scoped so that it is considered without need for import (because of Project.Initialize) implicit def richInitializeTask[T](init: Initialize[Task[T]]): Scoped.RichInitializeTask[T] = new Scoped.RichInitializeTask(init) diff --git a/run/Fork.scala b/run/Fork.scala index 32708d179..52216bcb0 100644 --- a/run/Fork.scala +++ b/run/Fork.scala @@ -9,6 +9,7 @@ trait ForkJava { def javaHome: Option[File] def outputStrategy: Option[OutputStrategy] + def connectInput: Boolean } trait ForkScala extends ForkJava { @@ -19,7 +20,7 @@ trait ForkScalaRun extends ForkScala def workingDirectory: Option[File] def runJVMOptions: Seq[String] } -final case class ForkOptions(javaHome: Option[File] = None, outputStrategy: Option[OutputStrategy] = None, scalaJars: Iterable[File] = Nil, workingDirectory: Option[File] = None, runJVMOptions: Seq[String] = Nil) extends ForkScalaRun +final case class ForkOptions(javaHome: Option[File] = None, outputStrategy: Option[OutputStrategy] = None, scalaJars: Iterable[File] = Nil, workingDirectory: Option[File] = None, runJVMOptions: Seq[String] = Nil, connectInput: Boolean = false) extends ForkScalaRun sealed abstract class OutputStrategy extends NotNull case object StdoutOutput extends OutputStrategy @@ -55,6 +56,8 @@ object Fork def apply(javaHome: Option[File], options: Seq[String], workingDirectory: Option[File], outputStrategy: OutputStrategy): Int = apply(javaHome, options, workingDirectory, Map.empty, outputStrategy) def apply(javaHome: Option[File], options: Seq[String], workingDirectory: Option[File], env: Map[String, String], outputStrategy: OutputStrategy): Int = + fork(javaHome, options, workingDirectory, env, false, outputStrategy).exitValue + def fork(javaHome: Option[File], options: Seq[String], workingDirectory: Option[File], env: Map[String, String], connectInput: Boolean, outputStrategy: OutputStrategy): Process = { val executable = javaCommand(javaHome, commandName).getAbsolutePath val command = (executable :: options.toList).toArray @@ -64,10 +67,10 @@ object Fork for( (key, value) <- env ) environment.put(key, value) outputStrategy match { - case StdoutOutput => Process(builder) ! - case BufferedOutput(logger) => Process(builder) ! logger - case LoggedOutput(logger) => Process(builder).run(logger).exitValue() - case CustomOutput(output) => (Process(builder) #> output).run.exitValue() + case StdoutOutput => Process(builder).run(connectInput) + case BufferedOutput(logger) => Process(builder).runBuffered(logger, connectInput) + case LoggedOutput(logger) => Process(builder).run(logger, connectInput) + case CustomOutput(output) => (Process(builder) #> output).run(connectInput) } } } @@ -79,12 +82,14 @@ object Fork def apply(javaHome: Option[File], jvmOptions: Seq[String], scalaJars: Iterable[File], arguments: Seq[String], workingDirectory: Option[File], log: Logger): Int = apply(javaHome, jvmOptions, scalaJars, arguments, workingDirectory, BufferedOutput(log)) def apply(javaHome: Option[File], jvmOptions: Seq[String], scalaJars: Iterable[File], arguments: Seq[String], workingDirectory: Option[File], outputStrategy: OutputStrategy): Int = + fork(javaHome, jvmOptions, scalaJars, arguments, workingDirectory, false, outputStrategy).exitValue() + def fork(javaHome: Option[File], jvmOptions: Seq[String], scalaJars: Iterable[File], arguments: Seq[String], workingDirectory: Option[File], connectInput: Boolean, outputStrategy: OutputStrategy): Process = { if(scalaJars.isEmpty) error("Scala jars not specified") val scalaClasspathString = "-Xbootclasspath/a:" + scalaJars.map(_.getAbsolutePath).mkString(File.pathSeparator) val mainClass = if(mainClassName.isEmpty) Nil else mainClassName :: Nil val options = jvmOptions ++ (scalaClasspathString :: mainClass ::: arguments.toList) - Fork.java(javaHome, options, workingDirectory, Map.empty, outputStrategy) + Fork.java.fork(javaHome, options, workingDirectory, Map.empty, connectInput, outputStrategy) } } } diff --git a/run/Run.scala b/run/Run.scala index c9f16b6b1..033f88886 100644 --- a/run/Run.scala +++ b/run/Run.scala @@ -19,7 +19,13 @@ class ForkRun(config: ForkScalaRun) extends ScalaRun { val scalaOptions = classpathOption(classpath) ::: mainClass :: options.toList val strategy = config.outputStrategy getOrElse LoggedOutput(log) - val exitCode = Fork.scala(config.javaHome, config.runJVMOptions, config.scalaJars, scalaOptions, config.workingDirectory, strategy) + val process = Fork.scala.fork(config.javaHome, config.runJVMOptions, config.scalaJars, scalaOptions, config.workingDirectory, config.connectInput, strategy) + def cancel() = { + log.warn("Run canceled.") + process.destroy() + 1 + } + val exitCode = try process.exitValue() catch { case e: InterruptedException => cancel() } processExitCode(exitCode, "runner") } private def classpathOption(classpath: Seq[File]) = "-cp" :: Path.makeString(classpath) :: Nil diff --git a/run/TrapExit.scala b/run/TrapExit.scala index bcd79d68a..7c8b65e7d 100644 --- a/run/TrapExit.scala +++ b/run/TrapExit.scala @@ -56,8 +56,17 @@ object TrapExit log.debug("Sandboxed run complete..") code.value.getOrElse(0) } - finally { System.setSecurityManager(originalSecurityManager) } + catch { case e: InterruptedException => cancel(executionThread, allThreads, log) } + finally System.setSecurityManager(originalSecurityManager) } + private[this] def cancel(executionThread: Thread, originalThreads: Set[Thread], log: Logger): Int = + { + log.warn("Run canceled.") + executionThread.interrupt() + stopAll(originalThreads) + 1 + } + // wait for all non-daemon threads to terminate private def waitForExit(originalThreads: Set[Thread], log: Logger) { diff --git a/util/collection/Signal.scala b/util/collection/Signal.scala new file mode 100644 index 000000000..09756249d --- /dev/null +++ b/util/collection/Signal.scala @@ -0,0 +1,43 @@ +package sbt + +object Signals +{ + def withHandler[T](handler: () => Unit)(action: () => T): T = + { + val result = + try + { + val signals = new Signals0 + signals.withHandler(handler)(action) + } + catch { case e: LinkageError => Right(action()) } + + result match { + case Left(e) => throw e + case Right(v) => v + } + } +} + +// Must only be referenced using a +// try { } catch { case e: LinkageError => ... } +// block to +private final class Signals0 +{ + // returns a LinkageError in `action` as Left(t) in order to avoid it being + // incorrectly swallowed as missing Signal/SignalHandler + def withHandler[T](handler: () => Unit)(action: () => T): Either[Throwable, T] = + { + import sun.misc.{Signal,SignalHandler} + val intSignal = new Signal("INT") + val newHandler = new SignalHandler { + def handle(sig: Signal) { handler() } + } + + val oldHandler = Signal.handle(intSignal, newHandler) + + try Right(action()) + catch { case e: LinkageError => Left(e) } + finally Signal.handle(intSignal, oldHandler) + } +} \ No newline at end of file diff --git a/util/process/Process.scala b/util/process/Process.scala index 2dd70484c..5a1f46b4a 100644 --- a/util/process/Process.scala +++ b/util/process/Process.scala @@ -166,6 +166,8 @@ trait ProcessBuilder extends SourcePartialBuilder with SinkPartialBuilder * The newly started process reads from standard input of the current process if `connectInput` is true.*/ def run(log: ProcessLogger, connectInput: Boolean): Process + def runBuffered(log: ProcessLogger, connectInput: Boolean): Process + /** Constructs a command that runs this command first and then `other` if this command succeeds.*/ def #&& (other: ProcessBuilder): ProcessBuilder /** Constructs a command that runs this command first and then `other` if this command does not succeed.*/ diff --git a/util/process/ProcessImpl.scala b/util/process/ProcessImpl.scala index c20b23f20..69191d054 100644 --- a/util/process/ProcessImpl.scala +++ b/util/process/ProcessImpl.scala @@ -159,10 +159,10 @@ private abstract class AbstractProcessBuilder extends ProcessBuilder with SinkPa def ! = run(false).exitValue() def !< = run(true).exitValue() - def !(log: ProcessLogger) = runBuffered(log, false) - def !<(log: ProcessLogger) = runBuffered(log, true) - private[this] def runBuffered(log: ProcessLogger, connectInput: Boolean) = - log.buffer { run(log, connectInput).exitValue() } + def !(log: ProcessLogger) = runBuffered(log, false).exitValue() + def !<(log: ProcessLogger) = runBuffered(log, true).exitValue() + def runBuffered(log: ProcessLogger, connectInput: Boolean) = + log.buffer { run(log, connectInput) } def !(io: ProcessIO) = run(io).exitValue() def canPipeTo = false