diff --git a/main-command/src/main/scala/sbt/internal/ui/UserThread.scala b/main-command/src/main/scala/sbt/internal/ui/UserThread.scala index fd6d45778..1e251d1cc 100644 --- a/main-command/src/main/scala/sbt/internal/ui/UserThread.scala +++ b/main-command/src/main/scala/sbt/internal/ui/UserThread.scala @@ -13,8 +13,9 @@ import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } import java.util.concurrent.Executors import sbt.State -import sbt.internal.util.{ ConsoleAppender, ProgressEvent, ProgressState, Util } -import sbt.internal.util.Prompt +import scala.concurrent.duration._ +import sbt.internal.util.JoinThread._ +import sbt.internal.util.{ ConsoleAppender, ProgressEvent, ProgressState, Prompt } private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable { private[this] val uiThread = new AtomicReference[(UITask, Thread)] @@ -65,14 +66,7 @@ private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable case null => case (t, thread) => t.close() - Util.ignoreResult(thread.interrupt()) - try thread.join(1000) - catch { case _: InterruptedException => } - - // This join should always work, but if it doesn't log an error because - // it can cause problems if the thread isn't joined - if (thread.isAlive) System.err.println(s"Unable to join thread $thread") - () + thread.joinFor(1.second) } } private[sbt] def stopThread(): Unit = uiThread.synchronized(stopThreadImpl()) diff --git a/main-command/src/main/scala/sbt/internal/util/JoinThread.scala b/main-command/src/main/scala/sbt/internal/util/JoinThread.scala new file mode 100644 index 000000000..e2bee2f07 --- /dev/null +++ b/main-command/src/main/scala/sbt/internal/util/JoinThread.scala @@ -0,0 +1,30 @@ +/* + * sbt + * Copyright 2011 - 2018, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal.util + +import scala.annotation.tailrec +import scala.concurrent.duration._ + +object JoinThread { + implicit class ThreadOps(val t: Thread) extends AnyVal { + def joinFor(duration: FiniteDuration): Unit = { + val deadline = duration.fromNow + var exception: Option[InterruptedException] = None + @tailrec def impl(): Unit = { + try { + t.interrupt() + t.join(10) + } catch { case e: InterruptedException => exception = Some(e) } + if (t.isAlive) impl() + } + impl() + if (t.isAlive) System.err.println(s"Unable to join thread $t after $duration") + exception.foreach(throw _) + } + } +} diff --git a/main/src/main/scala/sbt/internal/Continuous.scala b/main/src/main/scala/sbt/internal/Continuous.scala index be0c1cea1..e7ec8c644 100644 --- a/main/src/main/scala/sbt/internal/Continuous.scala +++ b/main/src/main/scala/sbt/internal/Continuous.scala @@ -21,6 +21,7 @@ import sbt.internal.LabeledFunctions._ import sbt.internal.io.WatchState import sbt.internal.nio._ import sbt.internal.ui.UITask +import sbt.internal.util.JoinThread._ import sbt.internal.util.complete.DefaultParsers.{ Space, matched } import sbt.internal.util.complete.Parser._ import sbt.internal.util.complete.{ Parser, Parsers } @@ -666,15 +667,13 @@ private[sbt] object Continuous extends DeprecatedContinuous { } def removeThread(thread: Thread): Unit = Util.ignoreResult(threads.remove(thread)) def close(): Unit = if (closed.compareAndSet(false, true)) { - threads.forEach { t => - val deadline = 1.second.fromNow - while (t.isAlive && !deadline.isOverdue) { - t.interrupt() - t.join(10) - } - if (t.isAlive) System.err.println(s"Couldn't join watch thread $t") + var exception: Option[InterruptedException] = None + threads.forEach { thread => + try thread.joinFor(1.second) + catch { case e: InterruptedException => exception = Some(e) } } threads.clear() + exception.foreach(throw _) } } private object WatchExecutor { @@ -692,8 +691,8 @@ private[sbt] object Continuous extends DeprecatedContinuous { executor: WatchExecutor ) extends Future[R] { def cancel(): Unit = { - thread.interrupt() executor.removeThread(thread) + thread.joinFor(1.second) } def result: Try[R] = try queue.take match {