From 60426facba697b24d0d3d490f48525cdefc79daf Mon Sep 17 00:00:00 2001 From: Mark Harrah Date: Wed, 2 Oct 2013 09:10:38 -0400 Subject: [PATCH] TrapExit support for multiple, concurrent managed applications. Fixes #831. --- main/src/main/scala/sbt/Main.scala | 13 +- run/src/main/scala/sbt/TrapExit.scala | 502 +++++++++++------- .../scala/sbt/TrapExitSecurityException.scala | 26 + sbt/src/sbt-test/run/concurrent/build.sbt | 40 ++ .../sbt-test/run/concurrent/changes/B.scala | 7 + .../sbt-test/run/concurrent/changes/C.scala | 24 + sbt/src/sbt-test/run/concurrent/test | 16 + util/log/src/main/scala/sbt/Logger.scala | 6 +- 8 files changed, 437 insertions(+), 197 deletions(-) create mode 100644 run/src/main/scala/sbt/TrapExitSecurityException.scala create mode 100644 sbt/src/sbt-test/run/concurrent/build.sbt create mode 100644 sbt/src/sbt-test/run/concurrent/changes/B.scala create mode 100644 sbt/src/sbt-test/run/concurrent/changes/C.scala create mode 100644 sbt/src/sbt-test/run/concurrent/test diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index d1bd323d1..af6293202 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -24,7 +24,7 @@ final class xMain extends xsbti.AppMain { import BuiltinCommands.{initialize, defaults} import CommandStrings.{BootCommand, DefaultsCommand, InitCommand} - MainLoop.runLogged( initialState(configuration, + runManaged( initialState(configuration, Seq(initialize, defaults), DefaultsCommand :: InitCommand :: BootCommand :: Nil) ) @@ -33,7 +33,7 @@ final class xMain extends xsbti.AppMain final class ScriptMain extends xsbti.AppMain { def run(configuration: xsbti.AppConfiguration): xsbti.MainResult = - MainLoop.runLogged( initialState(configuration, + runManaged( initialState(configuration, BuiltinCommands.ScriptCommands, Script.Name :: Nil) ) @@ -41,7 +41,7 @@ final class ScriptMain extends xsbti.AppMain final class ConsoleMain extends xsbti.AppMain { def run(configuration: xsbti.AppConfiguration): xsbti.MainResult = - MainLoop.runLogged( initialState(configuration, + runManaged( initialState(configuration, BuiltinCommands.ConsoleCommands, IvyConsole.Name :: Nil) ) @@ -49,6 +49,13 @@ final class ConsoleMain extends xsbti.AppMain object StandardMain { + def runManaged(s: State): xsbti.MainResult = + { + val previous = TrapExit.installManager() + try MainLoop.runLogged(s) + finally TrapExit.uninstallManager(previous) + } + /** The common interface to standard output, used for all built-in ConsoleLoggers. */ val console = ConsoleOut.systemOutOverwrite(ConsoleOut.overwriteContaining("Resolving ")) diff --git a/run/src/main/scala/sbt/TrapExit.scala b/run/src/main/scala/sbt/TrapExit.scala index e5a133338..4b0de4f63 100644 --- a/run/src/main/scala/sbt/TrapExit.scala +++ b/run/src/main/scala/sbt/TrapExit.scala @@ -9,91 +9,96 @@ package sbt import scala.collection.Set import scala.reflect.Manifest +import scala.collection.concurrent.TrieMap -/** This provides functionality to catch System.exit calls to prevent the JVM from terminating. -* This is useful for executing user code that may call System.exit, but actually exiting is -* undesirable. This file handles the call to exit by disposing all top-level windows and interrupting -* all user started threads. It does not stop the threads and does not call shutdown hooks. It is +import java.lang.ref.WeakReference +import Thread.currentThread +import java.security.Permission +import java.util.concurrent.{ConcurrentHashMap => CMap} +import java.lang.Integer.{toHexString => hex} +import java.lang.Long.{toHexString => hexL} + +import TrapExit._ + +/** Provides an approximation to isolated execution within a single JVM. +* System.exit calls are trapped to prevent the JVM from terminating. This is useful for executing +* user code that may call System.exit, but actually exiting is undesirable. +* +* Exit is simulated by disposing all top-level windows and interrupting user-started threads. +* Threads are not stopped and shutdown hooks are not called. It is * therefore inappropriate to use this with code that requires shutdown hooks or creates threads that -* do not terminate. This category of code should only be called by forking the JVM. */ +* do not terminate. This category of code should only be called by forking a new JVM. */ object TrapExit { - /** Executes the given thunk in a context where System.exit(code) throws - * a custom SecurityException, which is then caught and the exit code returned. - * Otherwise, 0 is returned. No other exceptions are handled by this method.*/ + /** Run `execute` in a managed context, using `log` for debugging messages. + * `installManager` must be called before calling this method. */ def apply(execute: => Unit, log: Logger): Int = - { - log.debug("Starting sandboxed run...") - - /** Take a snapshot of the threads that existed before execution in order to determine - * the threads that were created by 'execute'.*/ - val originalThreads = allThreads - val code = new ExitCode - def executeMain() = - try { execute } - catch - { - case e: TrapExitSecurityException => throw e - case x: Throwable => - code.set(1) //exceptions in the main thread cause the exit code to be 1 - throw x - } - val customThreadGroup = new ExitThreadGroup(new ExitHandler(originalThreads, code, log)) - val executionThread = new Thread(customThreadGroup, "run-main") { override def run() { executeMain } } - - val originalSecurityManager = System.getSecurityManager - try - { - val newSecurityManager = new TrapExitSecurityManager(originalSecurityManager, customThreadGroup) - System.setSecurityManager(newSecurityManager) - - executionThread.start() - - log.debug("Waiting for threads to exit or System.exit to be called.") - waitForExit(originalThreads, log) - log.debug("Interrupting remaining threads (should be all daemons).") - interruptAll(originalThreads) // should only be daemon threads left now - log.debug("Sandboxed run complete..") - code.value.getOrElse(0) + System.getSecurityManager match { + case m: TrapExit => m.runManaged(Logger.f0(execute), log) + case _ => runUnmanaged(execute, log) } - catch { case e: InterruptedException => cancel(executionThread, allThreads, log) } - finally System.setSecurityManager(originalSecurityManager) - } - private[this] def cancel(executionThread: Thread, originalThreads: Set[Thread], log: Logger): Int = + + /** Installs the SecurityManager that implements the isolation and returns the previously installed SecurityManager, which may be null. + * This method must be called before using `apply`. */ + def installManager(): SecurityManager = + System.getSecurityManager match { + case m: TrapExit => m + case m => System.setSecurityManager(new TrapExit(m)); m + } + + /** Uninstalls the isolation SecurityManager and restores the old security manager. */ + def uninstallManager(previous: SecurityManager): Unit = + System.setSecurityManager(previous) + + private[this] def runUnmanaged(execute: => Unit, log: Logger): Int = { - log.warn("Run canceled.") - executionThread.interrupt() - stopAll(originalThreads) - 1 + log.warn("Managed execution not possible: security manager not installed.") + try { execute; 0 } + catch { case e: Exception => + log.error("Error during execution: " + e.toString) + log.trace(e) + 1 + } } - - // wait for all non-daemon threads to terminate - private def waitForExit(originalThreads: Set[Thread], log: Logger) - { - var daemonsOnly = true - processThreads(originalThreads, thread => - if(!thread.isDaemon) - { - daemonsOnly = false - waitOnThread(thread, log) - } - ) - if(!daemonsOnly) - waitForExit(originalThreads, log) - } - /** Waits for the given thread to exit. */ + + private type ThreadID = String + + /** `true` if the thread `t` is in the TERMINATED state.x*/ + private def isDone(t: Thread): Boolean = t.getState == Thread.State.TERMINATED + + /** Computes an identifier for a Thread that has a high probability of being unique within a single JVM execution. */ + private def computeID(t: Thread): ThreadID = + // can't use t.getId because when getAccess first sees a Thread, it hasn't been initialized yet + s"${hex(System.identityHashCode(t))}:${t.getName}" + + /** Waits for the given `thread` to terminate. However, if the thread state is NEW, this method returns immediately. */ private def waitOnThread(thread: Thread, log: Logger) { - log.debug("Waiting for thread " + thread.getName + " to exit") + log.debug("Waiting for thread " + thread.getName + " to terminate.") thread.join log.debug("\tThread " + thread.getName + " exited.") } - /** Returns the exit code of the System.exit that caused the given Exception, or rethrows the exception - * if its cause was not calling System.exit.*/ - private def exitCode(e: Throwable) = - withCause[TrapExitSecurityException, Int](e) - {exited => exited.exitCode} - {other => throw other} + + // interrupts the given thread, but first replaces the exception handler so that the InterruptedException is not printed + private def safeInterrupt(thread: Thread, log: Logger) + { + val name = thread.getName + log.debug("Interrupting thread " + thread.getName) + thread.setUncaughtExceptionHandler(new TrapInterrupt(thread.getUncaughtExceptionHandler)) + thread.interrupt + log.debug("\tInterrupted " + thread.getName) + } + // an uncaught exception handler that swallows InterruptedExceptions and otherwise defers to originalHandler + private final class TrapInterrupt(originalHandler: Thread.UncaughtExceptionHandler) extends Thread.UncaughtExceptionHandler + { + def uncaughtException(thread: Thread, e: Throwable) + { + withCause[InterruptedException, Unit](e) + {interrupted => ()} + {other => originalHandler.uncaughtException(thread, e) } + thread.setUncaughtExceptionHandler(originalHandler) + } + } /** Recurses into the causes of the given exception looking for a cause of type CauseType. If one is found, `withType` is called with that cause. * If not, `notType` is called with the root cause.*/ private def withCause[CauseType <: Throwable, T](e: Throwable)(withType: CauseType => T)(notType: Throwable => T)(implicit mf: Manifest[CauseType]): T = @@ -110,146 +115,257 @@ object TrapExit withCause(cause)(withType)(notType)(mf) } } - - /** Returns all threads that are not in the 'system' thread group and are not the AWT implementation - * thread (AWT-XAWT, AWT-Windows, ...)*/ - private def allThreads: Set[Thread] = + +} + +/** Simulates isolation via a SecurityManager. +* Multiple applications are supported by tracking Thread constructions via `checkAccess`. +* The Thread that constructed each Thread is used to map a new Thread to an application. +* This association of Threads with an application allows properly waiting for +* non-daemon threads to terminate or to interrupt the correct threads when terminating.*/ +private final class TrapExit(delegateManager: SecurityManager) extends SecurityManager +{ + /** Tracks the number of running applications in order to short-cut SecurityManager checks when no applications are active.*/ + private[this] val running = new java.util.concurrent.atomic.AtomicInteger + + /** Maps a thread to its originating application. The thread is represented by a unique identifier to avoid leaks. */ + private[this] val threadToApp = new CMap[ThreadID, App] + + /** Executes `f` in a managed context. */ + def runManaged(f: xsbti.F0[Unit], xlog: xsbti.Logger): Int = { - import collection.JavaConversions._ - Thread.getAllStackTraces.keySet.filter(thread => !isSystemThread(thread)) + val _ = running.incrementAndGet() + try runManaged0(f, xlog) + finally running.decrementAndGet() } - /** Returns true if the given thread is in the 'system' thread group and is an AWT thread other than - * AWT-EventQueue or AWT-Shutdown.*/ + private[this] def runManaged0(f: xsbti.F0[Unit], xlog: xsbti.Logger): Int = + { + val log: Logger = xlog + val app = newApp(f, log) + val executionThread = new Thread(app, "run-main") + try { + executionThread.start() // thread actually evaluating `f` + finish(app, log) + } + catch { case e: InterruptedException => // here, the thread that started the run has been interrupted, not the main thread of the executing code + cancel(executionThread, app, log) + } + finally app.cleanUp() + } + /** Interrupt all threads and indicate failure in the exit code. */ + private[this] def cancel(executionThread: Thread, app: App, log: Logger): Int = + { + log.warn("Run canceled.") + executionThread.interrupt() + stopAllThreads(app) + 1 + } + + /** Wait for all non-daemon threads for `app` to exit, for an exception to be thrown in the main thread, + * or for `System.exit` to be called in a thread started by `app`. */ + private[this] def finish(app: App, log: Logger): Int = + { + log.debug("Waiting for threads to exit or System.exit to be called.") + waitForExit(app) + log.debug("Interrupting remaining threads (should be all daemons).") + stopAllThreads(app) // should only be daemon threads left now + log.debug("Sandboxed run complete..") + app.exitCode.value.getOrElse(0) + } + + // wait for all non-daemon threads to terminate + private[this] def waitForExit(app: App) + { + var daemonsOnly = true + app.processThreads { thread => + // check isAlive because calling `join` on a thread that hasn't started returns immediately + // and App will only remove threads that have terminated, which will make this method loop continuously + // if a thread is created but not started + if(thread.isAlive && !thread.isDaemon) + { + daemonsOnly = false + waitOnThread(thread, app.log) + } + } + // processThreads takes a snapshot of the threads at a given moment, so if there were only daemons, the application should shut down + if(!daemonsOnly) + waitForExit(app) + } + + /** Represents an isolated application as simulated by [[TrapExit]]. + * `starterThread` is the user thread that called TrapExit and not the main application thread. + * `execute` is the application code to evalute. + * `log` is used for debug logging. */ + private final class App(val starterThread: ThreadID, val execute: xsbti.F0[Unit], val log: Logger) extends Runnable + { + val exitCode = new ExitCode + /** Tracks threads created by this application. To avoid leaks, keys are a unique identifier and values are held via WeakReference. + * A TrieMap supports the necessary concurrent updates and snapshots.*/ + private[this] val threads = new TrieMap[ThreadID, WeakReference[Thread]] + + def run() { + try execute() + catch + { + case x: Throwable => + exitCode.set(1) //exceptions in the main thread cause the exit code to be 1 + throw x + } + } + + /** Records a new thread both in the global [[TrapExit]] manager and for this [[App]]. + * Its uncaught exception handler is configured to log exceptions through `log`. */ + def register(t: Thread, threadID: ThreadID): Unit = if(!isDone(t)) { + threadToApp.put(threadID, this) + threads.put(threadID, new WeakReference(t)) + t.setUncaughtExceptionHandler(new LoggingExceptionHandler(log)) + } + /** Removes a thread from this [[App]] and the global [[TrapExit]] manager. */ + private[this] def unregister(id: ThreadID): Unit = { + threadToApp.remove(id) + threads.remove(id) + } + /** Final cleanup for this application after it has terminated. */ + def cleanUp(): Unit = { + val snap = threads.readOnlySnapshot + threads.clear() + for( (id, _) <- snap) + unregister(id) + threadToApp.remove(starterThread) + } + + // only want to operate on unterminated threads + // want to drop terminated threads, including those that have been gc'd + /** Evaluates `f` on each `Thread` started by this [[App]] at single instant shortly after this method is called. */ + def processThreads(f: Thread => Unit) { + for((id, tref) <- threads.readOnlySnapshot) { + val t = tref.get + if( (t eq null) || isDone(t)) + unregister(id) + else + f(t) + if(isDone(t)) + unregister(id) + } + } + } + /** Constructs a new application for `f` that will use `log` for debug logging.*/ + private[this] def newApp(f: xsbti.F0[Unit], log: Logger): App = + { + val threadID = computeID(currentThread) + val a = new App(threadID, f, log) + threadToApp.put(threadID, a) + a + } + + private[this] def stopAllThreads(app: App) + { + disposeAllFrames(app) + interruptAllThreads(app) + } + + private[this] def interruptAllThreads(app: App): Unit = + app processThreads { t => if(!isSystemThread(t)) safeInterrupt(t, app.log) else println(s"Not interrupting system thread $t") } + + /** Records a thread if it is not already associated with an application. */ + private[this] def recordThread(t: Thread, threadID: ThreadID) + { + val callerID = computeID(Thread.currentThread) + val app = threadToApp.get(callerID) + if(app ne null) + app.register(t, threadID) + } + + private[this] def getApp(t: Thread): Option[App] = + Option(threadToApp.get(computeID(t))) + + /** Handles a valid call to `System.exit` by setting the exit code and + * interrupting remaining threads for the application associated with `t`, if one exists. */ + private[this] def exitApp(t: Thread, status: Int): Unit = getApp(t) match { + case None => System.err.println(s"Could not exit($status): no application associated with $t") + case Some(a) => + a.exitCode.set(status) + stopAllThreads(a) + } + + /** SecurityManager hook to trap calls to `System.exit` to avoid shutting down the whole JVM.*/ + override def checkExit(status: Int): Unit = if(active) { + val t = currentThread + val stack = t.getStackTrace + if(stack == null || stack.exists(isRealExit)) { + exitApp(t, status) + throw new TrapExitSecurityException(status) + } + } + /** This ensures that only actual calls to exit are trapped and not just calls to check if exit is allowed.*/ + private def isRealExit(element: StackTraceElement): Boolean = + element.getClassName == "java.lang.Runtime" && element.getMethodName == "exit" + + override def checkPermission(perm: Permission) + { + if(delegateManager ne null) + delegateManager.checkPermission(perm) + } + override def checkPermission(perm: Permission, context: AnyRef) + { + if(delegateManager ne null) + delegateManager.checkPermission(perm, context) + } + + /** SecurityManager hook that is abused to record every created Thread and associate it with a managed application. */ + override def checkAccess(t: Thread) { + if(active) { + val id = computeID(t) + if(threadToApp.get(id) eq null) + recordThread(t, id) + } + if(delegateManager ne null) + delegateManager.checkAccess(t) + } + /** `true` if there is at least one application currently being managed. */ + private[this] def active = running.get > 0 + + private def disposeAllFrames(app: App) // TODO: allow multiple graphical applications + { + val allFrames = java.awt.Frame.getFrames + if(allFrames.length > 0) + { + app.log.debug(s"Disposing ${allFrames.length} top-level windows...") + allFrames.foreach(_.dispose) // dispose all top-level windows, which will cause the AWT-EventQueue-* threads to exit + val waitSeconds = 2 + app.log.debug(s"Waiting $waitSeconds s to let AWT thread exit.") + Thread.sleep(waitSeconds * 1000) // AWT Thread doesn't exit immediately, so wait to interrupt it + } + } + /** Returns true if the given thread is in the 'system' thread group and is an AWT thread other than AWT-EventQueue.*/ private def isSystemThread(t: Thread) = { val name = t.getName if(name.startsWith("AWT-")) - !(name.startsWith("AWT-EventQueue") || name.startsWith("AWT-Shutdown")) + !name.startsWith("AWT-EventQueue") else { val group = t.getThreadGroup (group != null) && (group.getName == "system") } } - /** Calls the provided function for each thread in the system as provided by the - * allThreads function except those in ignoreThreads.*/ - private def processThreads(ignoreThreads: Set[Thread], process: Thread => Unit) - { - allThreads.filter(thread => !ignoreThreads.contains(thread)).foreach(process) - } - /** Handles System.exit by disposing all frames and calling interrupt on all user threads */ - private def stopAll(originalThreads: Set[Thread]) - { - disposeAllFrames() - interruptAll(originalThreads) - } - private def disposeAllFrames() - { - val allFrames = java.awt.Frame.getFrames - if(allFrames.length > 0) - { - allFrames.foreach(_.dispose) // dispose all top-level windows, which will cause the AWT-EventQueue-* threads to exit - Thread.sleep(2000) // AWT Thread doesn't exit immediately, so wait to interrupt it - } - } - // interrupt all threads that appear to have been started by the user - private def interruptAll(originalThreads: Set[Thread]): Unit = - processThreads(originalThreads, safeInterrupt) - // interrupts the given thread, but first replaces the exception handler so that the InterruptedException is not printed - private def safeInterrupt(thread: Thread) - { - if(!thread.getName.startsWith("AWT-")) - { - thread.setUncaughtExceptionHandler(new TrapInterrupt(thread.getUncaughtExceptionHandler)) - thread.interrupt - } - } - // an uncaught exception handler that swallows InterruptedExceptions and otherwise defers to originalHandler - private final class TrapInterrupt(originalHandler: Thread.UncaughtExceptionHandler) extends Thread.UncaughtExceptionHandler - { - def uncaughtException(thread: Thread, e: Throwable) - { - withCause[InterruptedException, Unit](e) - {interrupted => ()} - {other => originalHandler.uncaughtException(thread, e) } - thread.setUncaughtExceptionHandler(originalHandler) - } - } - /** An uncaught exception handler that delegates to the original uncaught exception handler except when - * the cause was a call to System.exit (which generated a SecurityException)*/ - private final class ExitHandler(originalThreads: Set[Thread], codeHolder: ExitCode, log: Logger) extends Thread.UncaughtExceptionHandler - { - def uncaughtException(t: Thread, e: Throwable) - { - try - { - codeHolder.set(exitCode(e)) // will rethrow e if it was not because of a call to System.exit - stopAll(originalThreads) - } - catch - { - case _ => - log.error("(" + t.getName + ") " + e.toString) - log.trace(e) - } - } - } - private final class ExitThreadGroup(handler: Thread.UncaughtExceptionHandler) extends ThreadGroup("trap.exit") - { - override def uncaughtException(t: Thread, e: Throwable) = handler.uncaughtException(t, e) - } } -private final class ExitCode extends NotNull + +/** A thread-safe, write-once, optional cell for tracking an application's exit code.*/ +private final class ExitCode { private var code: Option[Int] = None def set(c: Int): Unit = synchronized { code = code orElse Some(c) } def value: Option[Int] = synchronized { code } } -/////// These two classes are based on similar classes in Nailgun -/** A custom SecurityManager to disallow System.exit. */ -private final class TrapExitSecurityManager(delegateManager: SecurityManager, group: ThreadGroup) extends SecurityManager + +/** The default uncaught exception handler for managed executions. +* It logs the thread and the exception. */ +private final class LoggingExceptionHandler(log: Logger) extends Thread.UncaughtExceptionHandler { - import java.security.Permission - override def checkExit(status: Int) + def uncaughtException(t: Thread, e: Throwable) { - val stack = Thread.currentThread.getStackTrace - if(stack == null || stack.exists(isRealExit)) - throw new TrapExitSecurityException(status) + log.error("(" + t.getName + ") " + e.toString) + log.trace(e) } - /** This ensures that only actual calls to exit are trapped and not just calls to check if exit is allowed.*/ - private def isRealExit(element: StackTraceElement): Boolean = - element.getClassName == "java.lang.Runtime" && element.getMethodName == "exit" - override def checkPermission(perm: Permission) - { - if(delegateManager != null) - delegateManager.checkPermission(perm) - } - override def checkPermission(perm: Permission, context: AnyRef) - { - if(delegateManager != null) - delegateManager.checkPermission(perm, context) - } - override def getThreadGroup = group } -/** A custom SecurityException that tries not to be caught.*/ -private final class TrapExitSecurityException(val exitCode: Int) extends SecurityException -{ - private var accessAllowed = false - def allowAccess() - { - accessAllowed = true - } - override def printStackTrace = ifAccessAllowed(super.printStackTrace) - override def toString = ifAccessAllowed(super.toString) - override def getCause = ifAccessAllowed(super.getCause) - override def getMessage = ifAccessAllowed(super.getMessage) - override def fillInStackTrace = ifAccessAllowed(super.fillInStackTrace) - override def getLocalizedMessage = ifAccessAllowed(super.getLocalizedMessage) - private def ifAccessAllowed[T](f: => T): T = - { - if(accessAllowed) - f - else - throw this - } -} \ No newline at end of file diff --git a/run/src/main/scala/sbt/TrapExitSecurityException.scala b/run/src/main/scala/sbt/TrapExitSecurityException.scala new file mode 100644 index 000000000..ad16e9988 --- /dev/null +++ b/run/src/main/scala/sbt/TrapExitSecurityException.scala @@ -0,0 +1,26 @@ +package sbt + +/** A custom SecurityException that tries not to be caught. Closely based on a similar class in Nailgun. +* The main goal of this exception is that once thrown, it propagates all of the way up the call stack, +* terminating the thread's execution. */ +private final class TrapExitSecurityException(val exitCode: Int) extends SecurityException +{ + private var accessAllowed = false + def allowAccess() + { + accessAllowed = true + } + override def printStackTrace = ifAccessAllowed(super.printStackTrace) + override def toString = ifAccessAllowed(super.toString) + override def getCause = ifAccessAllowed(super.getCause) + override def getMessage = ifAccessAllowed(super.getMessage) + override def fillInStackTrace = ifAccessAllowed(super.fillInStackTrace) + override def getLocalizedMessage = ifAccessAllowed(super.getLocalizedMessage) + private def ifAccessAllowed[T](f: => T): T = + { + if(accessAllowed) + f + else + throw this + } +} diff --git a/sbt/src/sbt-test/run/concurrent/build.sbt b/sbt/src/sbt-test/run/concurrent/build.sbt new file mode 100644 index 000000000..113eb15fd --- /dev/null +++ b/sbt/src/sbt-test/run/concurrent/build.sbt @@ -0,0 +1,40 @@ +lazy val runTest = taskKey[Unit]("Run the test applications.") + +def runTestTask(pre: Def.Initialize[Task[Unit]]) = + runTest := { + val _ = pre.value + val r = (runner in (Compile, run)).value + val cp = (fullClasspath in Compile).value + val main = (mainClass in Compile).value getOrElse error("No main class found") + val args = baseDirectory.value.getAbsolutePath :: Nil + r.run(main, cp.files, args, streams.value.log) foreach error + } + +lazy val b = project.settings( + runTestTask( waitForCStart ), + runTest := { + val _ = runTest.value + val cFinished = (baseDirectory in c).value / "finished" + assert( !cFinished.exists, "C finished before B") + IO.touch(baseDirectory.value / "finished") + } +) + +lazy val c = project.settings( runTestTask( Def.task() ) ) + +// need at least 2 concurrently executing tasks to proceed +concurrentRestrictions in Global := Seq( + Tags.limitAll(math.max(EvaluateTask.SystemProcessors, 2) ) +) + +def waitForCStart = + Def.task { + waitFor( (baseDirectory in c).value / "started" ) + } + +def waitFor(f: File) { + if(!f.exists) { + Thread.sleep(300) + waitFor(f) + } +} diff --git a/sbt/src/sbt-test/run/concurrent/changes/B.scala b/sbt/src/sbt-test/run/concurrent/changes/B.scala new file mode 100644 index 000000000..8527ab04f --- /dev/null +++ b/sbt/src/sbt-test/run/concurrent/changes/B.scala @@ -0,0 +1,7 @@ +import java.io.File + +object B { + def main(args: Array[String]) { + Thread.sleep(1000) + } +} \ No newline at end of file diff --git a/sbt/src/sbt-test/run/concurrent/changes/C.scala b/sbt/src/sbt-test/run/concurrent/changes/C.scala new file mode 100644 index 000000000..54af861f6 --- /dev/null +++ b/sbt/src/sbt-test/run/concurrent/changes/C.scala @@ -0,0 +1,24 @@ +import java.io.File + +object C { + def main(args: Array[String]) { + val base = new File(args(0)) + create(new File(base, "started")) + val bFin = new File(base, "../b/finished") + waitFor(bFin) + create(new File(base, "finished")) + } + + def create(f: File) { + val fabs = f.getAbsoluteFile + fabs.getParentFile.mkdirs + fabs.createNewFile + } + + def waitFor(f: File) { + if(!f.exists) { + Thread.sleep(300) + waitFor(f) + } + } +} \ No newline at end of file diff --git a/sbt/src/sbt-test/run/concurrent/test b/sbt/src/sbt-test/run/concurrent/test new file mode 100644 index 000000000..88d32319e --- /dev/null +++ b/sbt/src/sbt-test/run/concurrent/test @@ -0,0 +1,16 @@ +# Tests that TrapExit properly handles multiple `run` executions + +# The goal is to start run B, then run C, exit run B, and finally, exit run C. +# By interleaving them, we test that TrapExit properly tracks the threads +# created by each and doesn't interrupt/wait for threads from C when B is done. + +# b/run waits for c/started before starting the B application +# c/run has no dependencies +# c/run creates c/started on startup and then waits for b/finished to be created +# this allows b/run to start the B application, which then exits +# the b/run task creates b/finished when the B application finishes and verifies c/finished doesn't exist +# c/run then exits, creating c/finished + +$ copy-file changes/B.scala b/B.scala +$ copy-file changes/C.scala c/C.scala +> runTest \ No newline at end of file diff --git a/util/log/src/main/scala/sbt/Logger.scala b/util/log/src/main/scala/sbt/Logger.scala index 29d965e76..ce8201e9c 100644 --- a/util/log/src/main/scala/sbt/Logger.scala +++ b/util/log/src/main/scala/sbt/Logger.scala @@ -57,7 +57,11 @@ object Logger implicit def absLog2PLog(log: AbstractLogger): ProcessLogger = new BufferedLogger(log) with ProcessLogger implicit def log2PLog(log: Logger): ProcessLogger = absLog2PLog(new FullLogger(log)) - implicit def xlog2Log(lg: xLogger): Logger = new Logger { + implicit def xlog2Log(lg: xLogger): Logger = lg match { + case l: Logger => l + case _ => wrapXLogger(lg) + } + private[this] def wrapXLogger(lg: xLogger): Logger = new Logger { override def debug(msg: F0[String]): Unit = lg.debug(msg) override def warn(msg: F0[String]): Unit = lg.warn(msg) override def info(msg: F0[String]): Unit = lg.info(msg)