diff --git a/build.sbt b/build.sbt index d7cdc2cc4..cfce92d0b 100644 --- a/build.sbt +++ b/build.sbt @@ -666,6 +666,7 @@ lazy val mainProj = (project in file("main")) exclude[DirectMissingMethodProblem]("sbt.Defaults.allTestGroupsTask"), exclude[DirectMissingMethodProblem]("sbt.Plugins.topologicalSort"), exclude[IncompatibleMethTypeProblem]("sbt.Defaults.allTestGroupsTask"), + exclude[DirectMissingMethodProblem]("sbt.StandardMain.shutdownHook") ) ) .configure( diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index 32b314ed4..3d350fc64 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -98,33 +98,36 @@ final class xMain extends xsbti.AppMain { } } private[sbt] object xMainImpl { - private[sbt] def run(configuration: xsbti.AppConfiguration): xsbti.MainResult = { - import BasicCommandStrings.{ DashClient, DashDashClient, runEarly } - import BasicCommands.early - import BuiltinCommands.defaults - import sbt.internal.CommandStrings.{ BootCommand, DefaultsCommand, InitCommand } - import sbt.internal.client.NetworkClient + private[sbt] def run(configuration: xsbti.AppConfiguration): xsbti.MainResult = + try { + import BasicCommandStrings.{ DashClient, DashDashClient, runEarly } + import BasicCommands.early + import BuiltinCommands.defaults + import sbt.internal.CommandStrings.{ BootCommand, DefaultsCommand, InitCommand } + import sbt.internal.client.NetworkClient - // if we detect -Dsbt.client=true or -client, run thin client. - val clientModByEnv = java.lang.Boolean.getBoolean("sbt.client") - val userCommands = configuration.arguments.map(_.trim) - if (clientModByEnv || (userCommands.exists { cmd => + // if we detect -Dsbt.client=true or -client, run thin client. + val clientModByEnv = java.lang.Boolean.getBoolean("sbt.client") + val userCommands = configuration.arguments.map(_.trim) + if (clientModByEnv || (userCommands.exists { cmd => + (cmd == DashClient) || (cmd == DashDashClient) + })) { + val args = userCommands.toList filterNot { cmd => (cmd == DashClient) || (cmd == DashDashClient) - })) { - val args = userCommands.toList filterNot { cmd => - (cmd == DashClient) || (cmd == DashDashClient) + } + NetworkClient.run(configuration, args) + Exit(0) + } else { + val state = StandardMain.initialState( + configuration, + Seq(defaults, early), + runEarly(DefaultsCommand) :: runEarly(InitCommand) :: BootCommand :: Nil + ) + StandardMain.runManaged(state) } - NetworkClient.run(configuration, args) - Exit(0) - } else { - val state = StandardMain.initialState( - configuration, - Seq(defaults, early), - runEarly(DefaultsCommand) :: runEarly(InitCommand) :: BootCommand :: Nil - ) - StandardMain.runManaged(state) + } finally { + ShutdownHooks.close() } - } } final class ScriptMain extends xsbti.AppMain { @@ -155,30 +158,21 @@ object StandardMain { import scalacache.caffeine._ private[sbt] lazy val cache: scalacache.Cache[Any] = CaffeineCache[Any] - private[this] val closeRunnable: Runnable = () => { + private[this] val closeRunnable = () => { cache.close()(scalacache.modes.sync.mode) cache.close()(scalacache.modes.scalaFuture.mode(ExecutionContext.global)) exchange.shutdown() } - private[sbt] val shutdownHook = new Thread(closeRunnable) def runManaged(s: State): xsbti.MainResult = { val previous = TrapExit.installManager() try { try { - val hooked = try { - Runtime.getRuntime.addShutdownHook(shutdownHook) - true - } catch { - case _: IllegalArgumentException => false - } + val hook = ShutdownHooks.add(closeRunnable) try { MainLoop.runLogged(s) } finally { - closeRunnable.run() - if (hooked) { - Runtime.getRuntime.removeShutdownHook(shutdownHook) - } + hook.close() () } } finally DefaultBackgroundJobService.backgroundJobService.shutdown() diff --git a/main/src/main/scala/sbt/MainLoop.scala b/main/src/main/scala/sbt/MainLoop.scala index d2ea976be..e4d4aac28 100644 --- a/main/src/main/scala/sbt/MainLoop.scala +++ b/main/src/main/scala/sbt/MainLoop.scala @@ -11,6 +11,7 @@ import java.io.PrintWriter import java.util.Properties import jline.TerminalFactory +import sbt.internal.ShutdownHooks import sbt.internal.langserver.ErrorCodes import sbt.internal.util.{ ErrorHandling, GlobalLogBacking } import sbt.io.{ IO, Using } @@ -27,13 +28,12 @@ object MainLoop { // We've disabled jline shutdown hooks to prevent classloader leaks, and have been careful to always restore // the jline terminal in finally blocks, but hitting ctrl+c prevents finally blocks from being executed, in that // case the only way to restore the terminal is in a shutdown hook. - val shutdownHook = new Thread(() => TerminalFactory.get().restore()) + val shutdownHook = ShutdownHooks.add(() => TerminalFactory.get().restore()) try { - Runtime.getRuntime.addShutdownHook(shutdownHook) runLoggedLoop(state, state.globalLogging.backing) } finally { - Runtime.getRuntime.removeShutdownHook(shutdownHook) + shutdownHook.close() () } } diff --git a/main/src/main/scala/sbt/internal/LayeredClassLoader.scala b/main/src/main/scala/sbt/internal/LayeredClassLoader.scala index a2c00de52..6ca12c6a7 100644 --- a/main/src/main/scala/sbt/internal/LayeredClassLoader.scala +++ b/main/src/main/scala/sbt/internal/LayeredClassLoader.scala @@ -49,12 +49,10 @@ private[sbt] class LayeredClassLoader( private[internal] object NativeLibs { private[this] val nativeLibs = new jutil.HashSet[File].asScala - Runtime.getRuntime.addShutdownHook(new Thread("sbt.internal.native-library-deletion") { - override def run(): Unit = { - nativeLibs.foreach(IO.delete) - IO.deleteIfEmpty(nativeLibs.map(_.getParentFile).toSet) - nativeLibs.clear() - } + ShutdownHooks.add(() => { + nativeLibs.foreach(IO.delete) + IO.deleteIfEmpty(nativeLibs.map(_.getParentFile).toSet) + nativeLibs.clear() }) def addNativeLib(lib: String): Unit = { nativeLibs.add(new File(lib)) diff --git a/main/src/main/scala/sbt/internal/ShutdownHooks.scala b/main/src/main/scala/sbt/internal/ShutdownHooks.scala new file mode 100644 index 000000000..eb29674d1 --- /dev/null +++ b/main/src/main/scala/sbt/internal/ShutdownHooks.scala @@ -0,0 +1,45 @@ +/* + * sbt + * Copyright 2011 - 2018, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal + +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.{ AtomicBoolean, AtomicInteger } + +import scala.util.control.NonFatal + +private[sbt] object ShutdownHooks extends AutoCloseable { + private[this] val idGenerator = new AtomicInteger(0) + private[this] val hooks = new ConcurrentHashMap[Int, () => Any] + private[this] val ranHooks = new AtomicBoolean(false) + private[this] val thread = new Thread("shutdown-hooks-run-all") { + override def run(): Unit = runAll() + } + private[this] val runtime = Runtime.getRuntime + runtime.addShutdownHook(thread) + private[sbt] def add[R](task: () => R): AutoCloseable = { + val id = idGenerator.getAndIncrement() + hooks.put( + id, + () => + try task() + catch { + case NonFatal(e) => + System.err.println(s"Caught exception running shutdown hook: $e") + e.printStackTrace(System.err) + } + ) + () => Option(hooks.remove(id)).foreach(_.apply()) + } + private def runAll(): Unit = if (ranHooks.compareAndSet(false, true)) { + hooks.forEachValue(Runtime.getRuntime.availableProcessors, _.apply()) + } + override def close(): Unit = { + runtime.removeShutdownHook(thread) + runAll() + } +} diff --git a/main/src/main/scala/sbt/internal/TaskTimings.scala b/main/src/main/scala/sbt/internal/TaskTimings.scala index 6b4976194..a4b0b3184 100644 --- a/main/src/main/scala/sbt/internal/TaskTimings.scala +++ b/main/src/main/scala/sbt/internal/TaskTimings.scala @@ -38,9 +38,7 @@ private[sbt] final class TaskTimings(reportOnShutdown: Boolean) if (reportOnShutdown) { start = System.nanoTime - Runtime.getRuntime.addShutdownHook(new Thread { - override def run() = report() - }) + ShutdownHooks.add(() => report()) } override def initial(): Unit = { diff --git a/main/src/main/scala/sbt/internal/TaskTraceEvent.scala b/main/src/main/scala/sbt/internal/TaskTraceEvent.scala index 63c4522ea..de8b4eede 100644 --- a/main/src/main/scala/sbt/internal/TaskTraceEvent.scala +++ b/main/src/main/scala/sbt/internal/TaskTraceEvent.scala @@ -36,9 +36,7 @@ private[sbt] final class TaskTraceEvent override def stop(): Unit = () start = System.nanoTime - Runtime.getRuntime.addShutdownHook(new Thread { - override def run() = report() - }) + ShutdownHooks.add(() => report()) private[this] def report() = { if (timings.asScala.nonEmpty) {