// Mirror of https://github.com/sbt/sbt.git (Scala source, 246 lines, 9.2 KiB)
/* sbt -- Simple Build Tool
|
|
* Copyright 2008 Mark Harrah
|
|
*
|
|
* Partially based on exit trapping in Nailgun by Pete Kirkham,
|
|
* copyright 2004, Martian Software, Inc
|
|
* licensed under Apache 2.0 License.
|
|
*/
|
|
package sbt
|
|
|
|
import scala.collection.Set
|
|
import scala.reflect.Manifest
|
|
|
|
/** This provides functionality to catch System.exit calls to prevent the JVM from terminating.
* This is useful for executing user code that may call System.exit, but actually exiting is
* undesirable. This file handles the call to exit by disposing all top-level windows and interrupting
* all user started threads. It does not stop the threads and does not call shutdown hooks. It is
* therefore inappropriate to use this with code that requires shutdown hooks or creates threads that
* do not terminate. This category of code should only be called by forking the JVM. */
object TrapExit
{
  /** Executes the given thunk on a new thread named "run-main" in a context where System.exit(code)
  * throws a custom SecurityException, which is then caught and the exit code returned.
  *
  * If `execute` itself throws any other exception, that exception is reported through the
  * execution thread's uncaught-exception handling and the returned code is 1 (unless an exit code
  * was already recorded). Otherwise, 0 is returned. Exceptions are never rethrown to the caller.
  *
  * This method blocks until all non-daemon threads started during execution have terminated,
  * then interrupts whatever user threads remain.
  *
  * @param execute the code to run in the sandbox
  * @param log receives debug/trace output about the sandboxed run
  * @return the first recorded exit code, or 0 if none was recorded */
  def apply(execute: => Unit, log: Logger): Int =
  {
    log.debug("Starting sandboxed run...")

    // Take a snapshot of the threads that existed before execution in order to determine
    // the threads that were created by 'execute'.
    val originalThreads = allThreads
    val code = new ExitCode
    def executeMain =
      try { execute }
      catch
      {
        case e: TrapExitSecurityException => throw e
        case x: Throwable =>
          code.set(1) //exceptions in the main thread cause the exit code to be 1
          throw x
      }
    val customThreadGroup = new ExitThreadGroup(new ExitHandler(Thread.getDefaultUncaughtExceptionHandler, originalThreads, code, log))
    val executionThread = new Thread(customThreadGroup, "run-main") { override def run(): Unit = executeMain }
    // A new thread inherits the daemon status of the thread that creates it. waitForExit only waits
    // on non-daemon threads, so a daemon "run-main" would not be waited for at all.
    executionThread.setDaemon(false)

    val originalSecurityManager = System.getSecurityManager
    try
    {
      val newSecurityManager = new TrapExitSecurityManager(originalSecurityManager, customThreadGroup)
      System.setSecurityManager(newSecurityManager)

      executionThread.start()

      log.debug("Waiting for threads to exit or System.exit to be called.")
      waitForExit(originalThreads, log)
      log.debug("Interrupting remaining threads (should be all daemons).")
      interruptAll(originalThreads) // should only be daemon threads left now
      log.debug("Sandboxed run complete..")
      code.value.getOrElse(0)
    }
    finally { System.setSecurityManager(originalSecurityManager) }
  }

  /** Blocks until no non-daemon thread outside `originalThreads` remains.
  * After each pass that waited on at least one thread, it rescans, because the threads just
  * joined may have started further threads. */
  @scala.annotation.tailrec
  private def waitForExit(originalThreads: Set[Thread], log: Logger): Unit =
  {
    var daemonsOnly = true
    processThreads(originalThreads, thread =>
      if(!thread.isDaemon)
      {
        daemonsOnly = false
        waitOnThread(thread, log)
      }
    )
    if(!daemonsOnly)
      waitForExit(originalThreads, log)
  }

  /** Waits for the given thread to exit, logging before and after at debug level. */
  private def waitOnThread(thread: Thread, log: Logger): Unit =
  {
    log.debug("Waiting for thread " + thread.getName + " to exit")
    thread.join()
    log.debug("\tThread " + thread.getName + " exited.")
  }

  /** Returns the exit code of the System.exit that caused the given Exception, or rethrows the exception
  * if its cause was not calling System.exit.*/
  private def exitCode(e: Throwable): Int =
    withCause[TrapExitSecurityException, Int](e)(exited => exited.exitCode)(other => throw other)

  /** Recurses into the causes of the given exception looking for a cause of type CauseType. If one is found, `withType` is called with that cause.
  * If not, `notType` is called with the root cause (the last exception in the chain).*/
  private def withCause[CauseType <: Throwable, T](e: Throwable)(withType: CauseType => T)(notType: Throwable => T)(implicit mf: Manifest[CauseType]): T =
  {
    val clazz = mf.erasure
    if(clazz.isInstance(e))
      withType(e.asInstanceOf[CauseType])
    else
    {
      val cause = e.getCause
      if(cause == null)
        notType(e)
      else
        withCause(cause)(withType)(notType)(mf)
    }
  }

  /** Returns all live threads (from Thread.getAllStackTraces) except those for which
  * isSystemThread is true, i.e. excluding 'system' group threads and AWT implementation
  * threads (AWT-XAWT, AWT-Windows, ...).*/
  private def allThreads: Set[Thread] =
  {
    import collection.JavaConversions._
    Thread.getAllStackTraces.keySet.filter(thread => !isSystemThread(thread))
  }

  /** Returns true if the given thread should not be treated as a user thread, in either of two cases:
  * - it is an AWT thread (name starts with "AWT-") other than AWT-EventQueue* or AWT-Shutdown*;
  * - it is not an AWT thread and its thread group is named "system".*/
  private def isSystemThread(t: Thread): Boolean =
  {
    val name = t.getName
    if(name.startsWith("AWT-"))
      !(name.startsWith("AWT-EventQueue") || name.startsWith("AWT-Shutdown"))
    else
    {
      val group = t.getThreadGroup
      (group != null) && (group.getName == "system")
    }
  }

  /** Calls the provided function for each thread in the system as provided by the
  * allThreads function except those in ignoreThreads.*/
  private def processThreads(ignoreThreads: Set[Thread], process: Thread => Unit): Unit =
    allThreads.filter(thread => !ignoreThreads.contains(thread)).foreach(process)

  /** Handles System.exit by disposing all frames and calling interrupt on all user threads */
  private def stopAll(originalThreads: Set[Thread]): Unit =
  {
    disposeAllFrames()
    interruptAll(originalThreads)
  }

  /** Disposes every top-level AWT Frame, then sleeps 2 seconds if there were any. */
  private def disposeAllFrames(): Unit =
  {
    val allFrames = java.awt.Frame.getFrames
    if(allFrames.length > 0)
    {
      allFrames.foreach(_.dispose()) // dispose all top-level windows, which will cause the AWT-EventQueue-* threads to exit
      Thread.sleep(2000) // AWT Thread doesn't exit immediately, so wait to interrupt it
    }
  }

  // interrupt all threads that appear to have been started by the user
  private def interruptAll(originalThreads: Set[Thread]): Unit =
    processThreads(originalThreads, safeInterrupt)

  // interrupts the given thread (AWT-* threads are skipped), but first replaces the exception
  // handler so that the InterruptedException is not printed
  private def safeInterrupt(thread: Thread): Unit =
  {
    if(!thread.getName.startsWith("AWT-"))
    {
      thread.setUncaughtExceptionHandler(new TrapInterrupt(thread.getUncaughtExceptionHandler))
      thread.interrupt()
    }
  }

  // an uncaught exception handler that swallows InterruptedExceptions and otherwise defers to originalHandler.
  // It restores originalHandler on the thread after handling one exception.
  private final class TrapInterrupt(originalHandler: Thread.UncaughtExceptionHandler) extends Thread.UncaughtExceptionHandler
  {
    def uncaughtException(thread: Thread, e: Throwable): Unit =
    {
      withCause[InterruptedException, Unit](e)(_ => ())(_ => originalHandler.uncaughtException(thread, e))
      thread.setUncaughtExceptionHandler(originalHandler)
    }
  }

  /** An uncaught exception handler that delegates to the original uncaught exception handler except when
  * the cause was a call to System.exit (which generated a SecurityException). In that case the exit
  * code is recorded in `codeHolder` and all user threads are stopped.*/
  private final class ExitHandler(originalHandler: Thread.UncaughtExceptionHandler, originalThreads: Set[Thread], codeHolder: ExitCode, log: Logger) extends Thread.UncaughtExceptionHandler
  {
    def uncaughtException(t: Thread, e: Throwable): Unit =
    {
      try
      {
        codeHolder.set(exitCode(e)) // will rethrow e if it was not because of a call to System.exit
        stopAll(originalThreads)
      }
      catch
      {
        case _: Throwable =>
          log.trace(e)
          delegate(t, e)
      }
    }

    /** Passes `e` to the original default handler. If none was installed (the captured
    * Thread.getDefaultUncaughtExceptionHandler is null), it prints the exception to System.err
    * instead, mirroring java.lang.ThreadGroup's default reporting. Without this fallback the call
    * would fail with a NullPointerException and the user's exception would be lost.*/
    private def delegate(t: Thread, e: Throwable): Unit =
      if(originalHandler ne null)
        originalHandler.uncaughtException(t, e)
      else
      {
        System.err.print("Exception in thread \"" + t.getName + "\" ")
        e.printStackTrace(System.err)
      }
  }

  /** Thread group named "trap.exit" that routes every uncaught exception of its threads to `handler`. */
  private final class ExitThreadGroup(handler: Thread.UncaughtExceptionHandler) extends ThreadGroup("trap.exit")
  {
    override def uncaughtException(t: Thread, e: Throwable): Unit = handler.uncaughtException(t, e)
  }
}
|
|
/** Thread-safe holder for the exit code of a sandboxed run.
* Only the first value passed to `set` is kept; later calls leave it unchanged.
* Both accessors synchronize on this instance. */
private final class ExitCode extends NotNull
{
  private var recorded: Option[Int] = None

  /** Records `c` unless a code has already been recorded. */
  def set(c: Int): Unit = synchronized {
    if(recorded.isEmpty)
      recorded = Some(c)
  }

  /** The recorded code, or None if `set` was never called. */
  def value: Option[Int] = synchronized { recorded }
}
|
|
/////// These two classes are based on similar classes in Nailgun
|
|
/** A custom SecurityManager to disallow System.exit.
* Permission checks are forwarded to `delegateManager` when there is one (otherwise everything is
* allowed), and `getThreadGroup` reports `group`. */
private final class TrapExitSecurityManager(delegateManager: SecurityManager, group: ThreadGroup) extends SecurityManager
{
  import java.security.Permission

  /** Throws a TrapExitSecurityException carrying `status` if the current stack contains a
  * java.lang.Runtime.exit frame, or if no stack trace is available. Otherwise it returns
  * normally, so a plain "may I exit?" probe is not trapped. */
  override def checkExit(status: Int): Unit =
  {
    val frames = Thread.currentThread.getStackTrace
    val calledFromExit = (frames eq null) || frames.exists(isRealExit)
    if(calledFromExit) throw new TrapExitSecurityException(status)
  }

  /** This ensures that only actual calls to exit are trapped and not just calls to check if exit is allowed.*/
  private def isRealExit(element: StackTraceElement): Boolean =
    element.getMethodName == "exit" && element.getClassName == "java.lang.Runtime"

  override def checkPermission(perm: Permission): Unit =
    if(delegateManager ne null) delegateManager.checkPermission(perm)

  override def checkPermission(perm: Permission, context: AnyRef): Unit =
    if(delegateManager ne null) delegateManager.checkPermission(perm, context)

  /** Returns `group`. Presumably this places threads created without an explicit group into the
  * trap.exit group while this manager is installed (SecurityManager.getThreadGroup contract).
  * Confirm. */
  override def getThreadGroup: ThreadGroup = group
}
|
|
/** A custom SecurityException that tries not to be caught.
* Until `allowAccess` has been called, each overridden Throwable accessor below (printStackTrace,
* toString, getCause, getMessage, fillInStackTrace, getLocalizedMessage) rethrows this exception
* instead of returning. Code that catches it and then tries to print, inspect or wrap it therefore
* ends up rethrowing it.
* NOTE(review): Throwable's no-arg constructor calls fillInStackTrace(). Because the flag below is
* still false at that point, constructing an instance appears to throw the instance itself, so
* `throw new TrapExitSecurityException(n)` still throws a TrapExitSecurityException. Confirm that
* this is intended. */
private final class TrapExitSecurityException(val exitCode: Int) extends SecurityException
{
  private var unlocked = false

  /** Makes the overridden accessors delegate to Throwable normally from now on.
  * No caller is visible in this file. */
  def allowAccess: Unit =
  {
    unlocked = true
  }

  override def fillInStackTrace: Throwable = guarded(super.fillInStackTrace)
  override def printStackTrace: Unit = guarded(super.printStackTrace)
  override def toString: String = guarded(super.toString)
  override def getCause: Throwable = guarded(super.getCause)
  override def getMessage: String = guarded(super.getMessage)
  override def getLocalizedMessage: String = guarded(super.getLocalizedMessage)

  // Evaluates `f` once access is allowed; before that, rethrows this exception.
  private def guarded[T](f: => T): T =
    if(unlocked) f else throw this
}