From 329baf4b0bac2d6f1d3089da3a988faaa9e8b064 Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Wed, 19 Aug 2020 09:34:04 -0700 Subject: [PATCH] Use more aggressive strategy to join ui threads There can be race conditions where we try to interrupt and join a ui thread before it becomes interruptible by blockign on a queue. To workaround this, we can add the JoinThread class which adds an extension method Thread.joinFor that takes a FiniteDuration parameter. This variant of join will repeatedly interrupt and attempt to join the thread for up to 10 milliseconds before retrying until the limit is reached. If the limit is reached, we print a noisy error to the console. I'm not 100% sure if we are leaking threads in the latest sbt version but this gives me more piece of mind that either we are always successfully joining the threads or we will be alerted if the joining fails. --- .../scala/sbt/internal/ui/UserThread.scala | 14 +++------ .../scala/sbt/internal/util/JoinThread.scala | 30 +++++++++++++++++++ .../main/scala/sbt/internal/Continuous.scala | 15 +++++----- 3 files changed, 41 insertions(+), 18 deletions(-) create mode 100644 main-command/src/main/scala/sbt/internal/util/JoinThread.scala diff --git a/main-command/src/main/scala/sbt/internal/ui/UserThread.scala b/main-command/src/main/scala/sbt/internal/ui/UserThread.scala index fd6d45778..1e251d1cc 100644 --- a/main-command/src/main/scala/sbt/internal/ui/UserThread.scala +++ b/main-command/src/main/scala/sbt/internal/ui/UserThread.scala @@ -13,8 +13,9 @@ import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } import java.util.concurrent.Executors import sbt.State -import sbt.internal.util.{ ConsoleAppender, ProgressEvent, ProgressState, Util } -import sbt.internal.util.Prompt +import scala.concurrent.duration._ +import sbt.internal.util.JoinThread._ +import sbt.internal.util.{ ConsoleAppender, ProgressEvent, ProgressState, Prompt } private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable { private[this] val uiThread = new AtomicReference[(UITask, Thread)] @@ -65,14 +66,7 @@ private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable case null => case (t, thread) => t.close() - Util.ignoreResult(thread.interrupt()) - try thread.join(1000) - catch { case _: InterruptedException => } - - // This join should always work, but if it doesn't log an error because - // it can cause problems if the thread isn't joined - if (thread.isAlive) System.err.println(s"Unable to join thread $thread") - () + thread.joinFor(1.second) } } private[sbt] def stopThread(): Unit = uiThread.synchronized(stopThreadImpl()) diff --git a/main-command/src/main/scala/sbt/internal/util/JoinThread.scala b/main-command/src/main/scala/sbt/internal/util/JoinThread.scala new file mode 100644 index 000000000..e2bee2f07 --- /dev/null +++ b/main-command/src/main/scala/sbt/internal/util/JoinThread.scala @@ -0,0 +1,30 @@ +/* + * sbt + * Copyright 2011 - 2018, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal.util + +import scala.annotation.tailrec +import scala.concurrent.duration._ + +object JoinThread { + implicit class ThreadOps(val t: Thread) extends AnyVal { + def joinFor(duration: FiniteDuration): Unit = { + val deadline = duration.fromNow + var exception: Option[InterruptedException] = None + @tailrec def impl(): Unit = { + try { + t.interrupt() + t.join(10) + } catch { case e: InterruptedException => exception = Some(e) } + if (t.isAlive) impl() + } + impl() + if (t.isAlive) System.err.println(s"Unable to join thread $t after $duration") + exception.foreach(throw _) + } + } +} diff --git a/main/src/main/scala/sbt/internal/Continuous.scala b/main/src/main/scala/sbt/internal/Continuous.scala index be0c1cea1..e7ec8c644 100644 --- a/main/src/main/scala/sbt/internal/Continuous.scala +++ b/main/src/main/scala/sbt/internal/Continuous.scala @@ -21,6 +21,7 @@ import sbt.internal.LabeledFunctions._ import sbt.internal.io.WatchState import sbt.internal.nio._ import sbt.internal.ui.UITask +import sbt.internal.util.JoinThread._ import sbt.internal.util.complete.DefaultParsers.{ Space, matched } import sbt.internal.util.complete.Parser._ import sbt.internal.util.complete.{ Parser, Parsers } @@ -666,15 +667,13 @@ private[sbt] object Continuous extends DeprecatedContinuous { } def removeThread(thread: Thread): Unit = Util.ignoreResult(threads.remove(thread)) def close(): Unit = if (closed.compareAndSet(false, true)) { - threads.forEach { t => - val deadline = 1.second.fromNow - while (t.isAlive && !deadline.isOverdue) { - t.interrupt() - t.join(10) - } - if (t.isAlive) System.err.println(s"Couldn't join watch thread $t") + var exception: Option[InterruptedException] = None + threads.forEach { thread => + try thread.joinFor(1.second) + catch { case e: InterruptedException => exception = Some(e) } } threads.clear() + exception.foreach(throw _) } } private object WatchExecutor { @@ -692,8 +691,8 @@ private[sbt] object Continuous extends DeprecatedContinuous { executor: WatchExecutor ) extends Future[R] { def cancel(): Unit = { - thread.interrupt() executor.removeThread(thread) + thread.joinFor(1.second) } def result: Try[R] = try queue.take match {