Use more aggressive strategy to join ui threads

There can be race conditions where we try to interrupt and join a ui
thread before it becomes interruptible by blockign on a queue. To
workaround this, we can add the JoinThread class which adds an
extension method Thread.joinFor that takes a FiniteDuration parameter.
This variant of join will repeatedly interrupt and attempt to join the
thread for up to 10 milliseconds before retrying until the limit is
reached. If the limit is reached, we print a noisy error to the console.

I'm not 100% sure if we are leaking threads in the latest sbt version
but this gives me more piece of mind that either we are always
successfully joining the threads or we will be alerted if the joining
fails.
This commit is contained in:
Ethan Atkins 2020-08-19 09:34:04 -07:00
parent b8dac52338
commit 329baf4b0b
3 changed files with 41 additions and 18 deletions

View File

@ -13,8 +13,9 @@ import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference }
import java.util.concurrent.Executors
import sbt.State
import sbt.internal.util.{ ConsoleAppender, ProgressEvent, ProgressState, Util }
import sbt.internal.util.Prompt
import scala.concurrent.duration._
import sbt.internal.util.JoinThread._
import sbt.internal.util.{ ConsoleAppender, ProgressEvent, ProgressState, Prompt }
private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable {
private[this] val uiThread = new AtomicReference[(UITask, Thread)]
@ -65,14 +66,7 @@ private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable
case null =>
case (t, thread) =>
t.close()
Util.ignoreResult(thread.interrupt())
try thread.join(1000)
catch { case _: InterruptedException => }
// This join should always work, but if it doesn't log an error because
// it can cause problems if the thread isn't joined
if (thread.isAlive) System.err.println(s"Unable to join thread $thread")
()
thread.joinFor(1.second)
}
}
private[sbt] def stopThread(): Unit = uiThread.synchronized(stopThreadImpl())

View File

@ -0,0 +1,30 @@
/*
* sbt
* Copyright 2011 - 2018, Lightbend, Inc.
* Copyright 2008 - 2010, Mark Harrah
* Licensed under Apache License 2.0 (see LICENSE)
*/
package sbt.internal.util
import scala.annotation.tailrec
import scala.concurrent.duration._
object JoinThread {
implicit class ThreadOps(val t: Thread) extends AnyVal {
def joinFor(duration: FiniteDuration): Unit = {
val deadline = duration.fromNow
var exception: Option[InterruptedException] = None
@tailrec def impl(): Unit = {
try {
t.interrupt()
t.join(10)
} catch { case e: InterruptedException => exception = Some(e) }
if (t.isAlive) impl()
}
impl()
if (t.isAlive) System.err.println(s"Unable to join thread $t after $duration")
exception.foreach(throw _)
}
}
}

View File

@ -21,6 +21,7 @@ import sbt.internal.LabeledFunctions._
import sbt.internal.io.WatchState
import sbt.internal.nio._
import sbt.internal.ui.UITask
import sbt.internal.util.JoinThread._
import sbt.internal.util.complete.DefaultParsers.{ Space, matched }
import sbt.internal.util.complete.Parser._
import sbt.internal.util.complete.{ Parser, Parsers }
@ -666,15 +667,13 @@ private[sbt] object Continuous extends DeprecatedContinuous {
}
def removeThread(thread: Thread): Unit = Util.ignoreResult(threads.remove(thread))
def close(): Unit = if (closed.compareAndSet(false, true)) {
threads.forEach { t =>
val deadline = 1.second.fromNow
while (t.isAlive && !deadline.isOverdue) {
t.interrupt()
t.join(10)
}
if (t.isAlive) System.err.println(s"Couldn't join watch thread $t")
var exception: Option[InterruptedException] = None
threads.forEach { thread =>
try thread.joinFor(1.second)
catch { case e: InterruptedException => exception = Some(e) }
}
threads.clear()
exception.foreach(throw _)
}
}
private object WatchExecutor {
@ -692,8 +691,8 @@ private[sbt] object Continuous extends DeprecatedContinuous {
executor: WatchExecutor
) extends Future[R] {
def cancel(): Unit = {
thread.interrupt()
executor.removeThread(thread)
thread.joinFor(1.second)
}
def result: Try[R] =
try queue.take match {