Close line reader when interrupted

There are cases where if the ui state is changing rapidly, that an
AskUserThread can be created and cancelled in a short time windows. This
could cause problems if the AskUserThread is interrupted during
`LineReader.createReader` which I think can shell out to run some
commands so it is relatively slow. If the thread was interrupted during
the call to `LineReader.createReader` and the interruption was not
handled, then the thread would go into `LineReader.readLine`, which
wouldn't exit until the user pressed enter. This ultimately caused the
ui to break until enter because this zombie line reader would be holding
the lock on the terminal input stream.
This commit is contained in:
Ethan Atkins 2020-07-25 12:20:45 -07:00
parent e4cd6a38fc
commit 6dd69a54ae
8 changed files with 175 additions and 113 deletions

View File

@ -10,9 +10,9 @@ package sbt.internal.util
import java.io.{ EOFException, InputStream, OutputStream, PrintWriter }
import java.nio.charset.Charset
import java.util.{ Arrays, EnumSet }
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference }
import org.jline.utils.InfoCmp.Capability
import org.jline.utils.{ NonBlocking, OSUtils }
import org.jline.utils.{ ClosedException, NonBlockingReader, OSUtils }
import org.jline.terminal.{ Attributes, Size, Terminal => JTerminal }
import org.jline.terminal.Terminal.SignalHandler
import org.jline.terminal.impl.AbstractTerminal
@ -20,6 +20,7 @@ import org.jline.terminal.impl.jansi.JansiSupportImpl
import org.jline.terminal.impl.jansi.win.JansiWinSysTerminal
import scala.collection.JavaConverters._
import scala.util.Try
import java.util.concurrent.LinkedBlockingQueue
private[util] object JLine3 {
private val capabilityMap = Capability
@ -77,6 +78,8 @@ private[util] object JLine3 {
new AbstractTerminal(term.name, "ansi", Charset.forName("UTF-8"), SignalHandler.SIG_DFL) {
val closed = new AtomicBoolean(false)
setOnClose { () =>
doClose()
reader.close()
if (closed.compareAndSet(false, true)) {
// This is necessary to shutdown the non blocking input reader
// so that it doesn't keep blocking
@ -89,8 +92,22 @@ private[util] object JLine3 {
parseInfoCmp()
override val input: InputStream = new InputStream {
override def read: Int = {
val res = try term.inputStream.read
catch { case _: InterruptedException => -2 }
val res = term.inputStream match {
case w: Terminal.WriteableInputStream =>
val result = new LinkedBlockingQueue[Integer]
try {
w.read(result)
result.poll match {
case null => throw new ClosedException
case i => i.toInt
}
} catch {
case _: InterruptedException =>
w.cancel()
throw new ClosedException
}
case _ => throw new ClosedException
}
if (res == 4 && term.prompt.render().endsWith(term.prompt.mkPrompt()))
throw new EOFException
res
@ -110,8 +127,41 @@ private[util] object JLine3 {
override def flush(): Unit = term.withPrintStream(_.flush())
}
override val reader =
NonBlocking.nonBlocking(term.name, input, Charset.defaultCharset())
override val reader = new NonBlockingReader {
val buffer = new LinkedBlockingQueue[Integer]
val thread = new AtomicReference[Thread]
private def fillBuffer(): Unit = thread.synchronized {
thread.set(Thread.currentThread)
buffer.put(
try input.read()
catch { case _: InterruptedException => -3 }
)
}
override def close(): Unit = thread.get match {
case null =>
case t => t.interrupt()
}
override def read(timeout: Long, peek: Boolean) = {
if (buffer.isEmpty && !peek) fillBuffer()
(if (peek) buffer.peek else buffer.take) match {
case null => -2
case i => if (i == -3) throw new ClosedException else i
}
}
override def peek(timeout: Long): Int = buffer.peek() match {
case null => -1
case i => i.toInt
}
override def readBuffered(buf: Array[Char]): Int = {
if (buffer.isEmpty) fillBuffer()
buffer.take match {
case i if i == -1 => -1
case i =>
buf(0) = i.toChar
1
}
}
}
override val writer: PrintWriter = new PrintWriter(output, true)
/*
* For now assume that the terminal capabilities for client and server

View File

@ -158,7 +158,7 @@ private[sbt] object ProgressState {
if (!pe.skipIfActive.getOrElse(false) || (!isRunning && !isBatch)) {
terminal.withPrintStream { ps =>
val commandFromThisTerminal = pe.channelName.fold(true)(_ == terminal.name)
val info = if ((isRunning || isBatch || noPrompt) && commandFromThisTerminal) {
val info = if (commandFromThisTerminal) {
pe.items.map { item =>
val elapsed = item.elapsedMicros / 1000000L
s" | => ${item.name} ${elapsed}s"

View File

@ -34,5 +34,6 @@ private[sbt] object Prompt {
private[sbt] case object Running extends NoPrompt
private[sbt] case object Batch extends NoPrompt
private[sbt] case object Watch extends NoPrompt
private[sbt] case object Pending extends NoPrompt
private[sbt] case object NoPrompt extends NoPrompt
}

View File

@ -11,7 +11,7 @@ import java.io.{ InputStream, InterruptedIOException, IOException, OutputStream,
import java.nio.channels.ClosedChannelException
import java.util.{ Arrays, Locale }
import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference }
import java.util.concurrent.{ ArrayBlockingQueue, Executors, LinkedBlockingQueue, TimeUnit }
import java.util.concurrent.{ Executors, LinkedBlockingQueue, TimeUnit }
import jline.DefaultTerminal2
import jline.console.ConsoleReader
@ -141,7 +141,7 @@ trait Terminal extends AutoCloseable {
private[sbt] def withRawOutput[R](f: => R): R
private[sbt] def restore(): Unit = {}
private[sbt] val progressState = new ProgressState(1)
private[this] val promptHolder: AtomicReference[Prompt] = new AtomicReference(Prompt.Running)
private[this] val promptHolder: AtomicReference[Prompt] = new AtomicReference(Prompt.Pending)
private[sbt] final def prompt: Prompt = promptHolder.get
private[sbt] final def setPrompt(newPrompt: Prompt): Unit =
if (prompt != Prompt.NoPrompt) promptHolder.set(newPrompt)
@ -396,50 +396,31 @@ object Terminal {
private[sbt] class WriteableInputStream(in: InputStream, name: String)
extends InputStream
with AutoCloseable {
final def write(bytes: Int*): Unit = waiting.synchronized {
waiting.poll match {
case null =>
bytes.foreach(b => buffer.put(b))
case w =>
if (bytes.length > 1) bytes.tail.foreach(b => buffer.put(b))
bytes.headOption.foreach(b => w.put(b))
}
final def write(bytes: Int*): Unit = readThread.synchronized {
bytes.foreach(b => buffer.put(b))
}
private[this] val executor =
Executors.newSingleThreadExecutor(r => new Thread(r, s"sbt-$name-input-reader"))
private[this] val buffer = new LinkedBlockingQueue[Integer]
private[this] val closed = new AtomicBoolean(false)
private[this] val readQueue = new LinkedBlockingQueue[Unit]
private[this] val waiting = new ArrayBlockingQueue[LinkedBlockingQueue[Integer]](1)
private[this] val readThread = new AtomicReference[Thread]
/*
* Starts a loop that waits for consumers of the InputStream to call read.
* When read is called, we enqueue a `LinkedBlockingQueue[Int]` to which
* the runnable can return a byte from stdin. If the read caller is interrupted,
* they remove the result from the waiting set and any byte read will be
* enqueued in the buffer. It is done this way so that we only read from
* System.in when a caller actually asks for bytes. If we constantly poll
* from System.in, then when the user calls reboot from the console, the
* first character they type after reboot is swallowed by the previous
* sbt main program. If the user calls reboot from a remote client, we
* can't avoid losing the first byte inputted in the console. A more
* robust fix would be to override System.in at the launcher level instead
* of at the sbt level. At the moment, the use case of a user calling
* reboot from a network client and the adding input at the server console
* seems pathological enough that it isn't worth putting more effort into
* fixing.
*
* Starts a loop that fills a buffer with bytes from stdin. We only read from
* the underlying stream when the buffer is empty and there is an active reader.
* If the reader detaches without consuming any bytes, we just buffer the
* next byte that we read from the stream. One known issue with this approach
* is that if a remote client triggers a reboot, we cannot necessarily stop this
* loop from consuming the next byte from standard in even if sbt has fully
* rebooted and the byte will never be consumed. We try to fix this in withStreams
* by setting the terminal to raw mode, which the input stream makes it non blocking,
* but this approach only works on posix platforms.
*/
private[this] val runnable: Runnable = () => {
@tailrec def impl(): Unit = {
val _ = readQueue.take
val b = in.read
// The downstream consumer may have been interrupted. Buffer the result
// when that hapens.
waiting.poll match {
case null => buffer.put(b)
case q => q.put(b)
}
buffer.put(b)
if (b != -1 && !Thread.interrupted()) impl()
else closed.set(true)
}
@ -447,30 +428,28 @@ object Terminal {
catch { case _: InterruptedException => closed.set(true) }
}
executor.submit(runnable)
override def read(): Int =
if (closed.get) -1
else
synchronized {
def read(result: LinkedBlockingQueue[Integer]): Unit =
if (!closed.get)
readThread.synchronized {
readThread.set(Thread.currentThread)
try buffer.poll match {
case null =>
val result = new LinkedBlockingQueue[Integer]
waiting.synchronized(waiting.put(result))
readQueue.put(())
try result.take.toInt
catch {
case e: InterruptedException =>
waiting.remove(result)
-1
}
result.put(buffer.take)
case b if b == -1 => throw new ClosedChannelException
case b => b.toInt
case b => result.put(b)
} finally readThread.set(null)
}
def cancel(): Unit = waiting.synchronized {
override def read(): Int = {
val result = new LinkedBlockingQueue[Integer]
read(result)
result.poll match {
case null => -1
case i => i.toInt
}
}
def cancel(): Unit = readThread.synchronized {
Option(readThread.getAndSet(null)).foreach(_.interrupt())
waiting.forEach(_.put(-2))
waiting.clear()
readQueue.clear()
}
@ -732,7 +711,11 @@ object Terminal {
}
term.restore()
term.setEchoEnabled(true)
new ConsoleTerminal(term, nonBlockingIn, originalOut)
new ConsoleTerminal(
term,
if (System.console == null) nullWriteableInputStream else nonBlockingIn,
originalOut
)
}
private[sbt] def reset(): Unit = {
@ -780,7 +763,7 @@ object Terminal {
private[sbt] def deprecatedTeminal: jline.Terminal = console.toJLine
private class ConsoleTerminal(
val term: jline.Terminal with jline.Terminal2,
in: InputStream,
in: WriteableInputStream,
out: OutputStream
) extends TerminalImpl(in, out, originalErr, "console0") {
private[util] lazy val system = JLine3.system
@ -837,17 +820,13 @@ object Terminal {
}
}
private[sbt] abstract class TerminalImpl private[sbt] (
val in: InputStream,
val in: WriteableInputStream,
val out: OutputStream,
override val errorStream: OutputStream,
override private[sbt] val name: String
) extends Terminal {
private[this] val rawMode = new AtomicBoolean(false)
private[this] val writeLock = new AnyRef
private[this] val writeableInputStream = in match {
case w: WriteableInputStream => w
case _ => new WriteableInputStream(in, name)
}
def throwIfClosed[R](f: => R): R = if (isStopped.get) throw new ClosedChannelException else f
override def getLastLine: Option[String] = progressState.currentLine
override def getLines: Seq[String] = progressState.getLines
@ -886,9 +865,9 @@ object Terminal {
progressState.write(TerminalImpl.this, bytes, ps, hasProgress.get && !rawMode.get)
}
override private[sbt] val printStream: PrintStream = new LinePrintStream(outputStream)
override def inputStream: InputStream = writeableInputStream
override def inputStream: InputStream = in
private[sbt] def write(bytes: Int*): Unit = writeableInputStream.write(bytes: _*)
private[sbt] def write(bytes: Int*): Unit = in.write(bytes: _*)
private[this] val isStopped = new AtomicBoolean(false)
override def getLineHeightAndWidth(line: String): (Int, Int) = getWidth match {
@ -909,9 +888,16 @@ object Terminal {
writeLock.synchronized(f(rawPrintStream))
override def close(): Unit = if (isStopped.compareAndSet(false, true)) {
writeableInputStream.close()
in.close()
}
}
private lazy val nullInputStream: InputStream = () => {
try this.synchronized(this.wait)
catch { case _: InterruptedException => }
-1
}
private lazy val nullWriteableInputStream =
new WriteableInputStream(nullInputStream, "null-writeable-input-stream")
private[sbt] val NullTerminal = new Terminal {
override def close(): Unit = {}
override def getBooleanCapability(capability: String, jline3: Boolean): Boolean = false
@ -922,11 +908,7 @@ object Terminal {
override def getNumericCapability(capability: String, jline3: Boolean): Integer = null
override def getStringCapability(capability: String, jline3: Boolean): String = null
override def getWidth: Int = 0
override def inputStream: java.io.InputStream = () => {
try this.synchronized(this.wait)
catch { case _: InterruptedException => }
-1
}
override def inputStream: java.io.InputStream = nullInputStream
override def isAnsiSupported: Boolean = false
override def isColorEnabled: Boolean = false
override def isEchoEnabled: Boolean = false

View File

@ -11,7 +11,6 @@ import java.io.File
import java.nio.channels.ClosedChannelException
import java.util.concurrent.atomic.AtomicBoolean
//import jline.console.history.PersistentHistory
import sbt.BasicCommandStrings.{ Cancel, TerminateAction, Shutdown }
import sbt.BasicKeys.{ historyPath, terminalShellPrompt }
import sbt.State
@ -23,55 +22,67 @@ import sbt.internal.util.complete.{ Parser }
import scala.annotation.tailrec
private[sbt] trait UITask extends Runnable with AutoCloseable {
private[sbt] def channel: CommandChannel
private[sbt] def reader: UITask.Reader
private[sbt] val channel: CommandChannel
private[sbt] val reader: UITask.Reader
private[this] final def handleInput(s: Either[String, String]): Boolean = s match {
case Left(m) => channel.onFastTrackTask(m)
case Right(cmd) => channel.onCommand(cmd)
}
private[this] val isStopped = new AtomicBoolean(false)
override def run(): Unit = {
@tailrec def impl(): Unit = {
@tailrec def impl(): Unit = if (!isStopped.get) {
val res = reader.readLine()
if (!handleInput(res) && !isStopped.get) impl()
}
try impl()
catch { case _: InterruptedException | _: ClosedChannelException => isStopped.set(true) }
}
override def close(): Unit = isStopped.set(true)
override def close(): Unit = {
isStopped.set(true)
reader.close()
}
}
private[sbt] object UITask {
trait Reader { def readLine(): Either[String, String] }
trait Reader extends AutoCloseable {
def readLine(): Either[String, String]
override def close(): Unit = {}
}
object Reader {
def terminalReader(parser: Parser[_])(
terminal: Terminal,
state: State
): Reader = { () =>
try {
val clear = terminal.ansi(ClearPromptLine, "")
@tailrec def impl(): Either[String, String] = {
val reader = LineReader.createReader(history(state), parser, terminal, terminal.prompt)
(try reader.readLine(clear + terminal.prompt.mkPrompt())
finally reader.close) match {
case None if terminal == Terminal.console && System.console == null =>
// No stdin is attached to the process so just ignore the result and
// block until the thread is interrupted.
this.synchronized(this.wait())
Right("") // should be unreachable
// JLine returns null on ctrl+d when there is no other input. This interprets
// ctrl+d with no imput as an exit
case None => Left(TerminateAction)
case Some(s: String) =>
s.trim() match {
case "" => impl()
case cmd @ (`Shutdown` | `TerminateAction` | `Cancel`) => Left(cmd)
case cmd => Right(cmd)
}
): Reader = new Reader {
val closed = new AtomicBoolean(false)
def readLine(): Either[String, String] =
try {
val clear = terminal.ansi(ClearPromptLine, "")
@tailrec def impl(): Either[String, String] = {
val thread = Thread.currentThread
if (thread.isInterrupted || closed.get) throw new InterruptedException
val reader = LineReader.createReader(history(state), parser, terminal)
if (thread.isInterrupted || closed.get) throw new InterruptedException
(try reader.readLine(clear + terminal.prompt.mkPrompt())
finally reader.close) match {
case None if terminal == Terminal.console && System.console == null =>
// No stdin is attached to the process so just ignore the result and
// block until the thread is interrupted.
this.synchronized(this.wait())
Right("") // should be unreachable
// JLine returns null on ctrl+d when there is no other input. This interprets
// ctrl+d with no imput as an exit
case None => Left(TerminateAction)
case Some(s: String) =>
s.trim() match {
case "" => impl()
case cmd @ (`Shutdown` | `TerminateAction` | `Cancel`) => Left(cmd)
case cmd => Right(cmd)
}
}
}
}
impl()
} catch { case e: InterruptedException => Right("") }
impl()
} catch { case e: InterruptedException => Right("") }
override def close(): Unit = closed.set(true)
}
}
private[this] def history(s: State): Option[File] =
@ -87,7 +98,7 @@ private[sbt] object UITask {
state: State,
override val channel: CommandChannel,
) extends UITask {
override private[sbt] def reader: UITask.Reader = {
override private[sbt] lazy val reader: UITask.Reader = {
UITask.Reader.terminalReader(state.combinedParser)(channel.terminal, state)
}
}

View File

@ -14,7 +14,7 @@ import java.util.concurrent.Executors
import sbt.State
import sbt.internal.util.{ ConsoleAppender, ProgressEvent, ProgressState, Util }
import sbt.internal.util.Prompt.{ AskUser, Running }
import sbt.internal.util.Prompt
private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable {
private[this] val uiThread = new AtomicReference[(UITask, Thread)]
@ -31,9 +31,15 @@ private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable
uiThread.synchronized {
val task = channel.makeUIThread(state)
def submit(): Thread = {
val thread = new Thread(() => {
task.run()
uiThread.set(null)
def close(): Unit = {
uiThread.get match {
case (_, t) if t == thread => uiThread.set(null)
case _ =>
}
}
lazy val thread = new Thread(() => {
try task.run()
finally close()
}, s"sbt-$name-ui-thread")
thread.setDaemon(true)
thread.start()
@ -60,6 +66,13 @@ private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable
case (t, thread) =>
t.close()
Util.ignoreResult(thread.interrupt())
try thread.join(1000)
catch { case _: InterruptedException => }
// This join should always work, but if it doesn't log an error because
// it can cause problems if the thread isn't joined
if (thread.isAlive) System.err.println(s"Unable to join thread $thread")
()
}
}
@ -70,8 +83,9 @@ private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable
}
val state = consolePromptEvent.state
terminal.prompt match {
case Running => terminal.setPrompt(AskUser(() => UITask.shellPrompt(terminal, state)))
case _ =>
case Prompt.Running | Prompt.Pending =>
terminal.setPrompt(Prompt.AskUser(() => UITask.shellPrompt(terminal, state)))
case _ =>
}
onProgressEvent(ProgressEvent("Info", Vector(), None, None, None))
reset(state)
@ -80,6 +94,7 @@ private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable
private[sbt] def onConsoleUnpromptEvent(
consoleUnpromptEvent: ConsoleUnpromptEvent
): Unit = {
terminal.setPrompt(Prompt.Pending)
if (consoleUnpromptEvent.lastSource.fold(true)(_.channelName != name)) {
terminal.progressState.reset()
} else stopThread()

View File

@ -1231,7 +1231,7 @@ private[sbt] object ContinuousCommands {
state: State
) extends Thread(s"sbt-${channel.name}-watch-ui-thread")
with UITask {
override private[sbt] def reader: UITask.Reader = () => {
override private[sbt] lazy val reader: UITask.Reader = () => {
def stop = Right(s"${ContinuousCommands.stopWatch} ${channel.name}")
val exitAction: Watch.Action = {
Watch.apply(

View File

@ -110,7 +110,7 @@ final class NetworkChannel(
}
private[sbt] def write(byte: Byte) = inputBuffer.add(byte)
private[this] val terminalHolder = new AtomicReference(Terminal.NullTerminal)
private[this] val terminalHolder = new AtomicReference[Terminal](Terminal.NullTerminal)
override private[sbt] def terminal: Terminal = terminalHolder.get
override val userThread: UserThread = new UserThread(this)
@ -152,8 +152,8 @@ final class NetworkChannel(
if (interactive.get || ContinuousCommands.isInWatch(state, this)) mkUIThreadImpl(state, command)
else
new UITask {
override private[sbt] def channel = NetworkChannel.this
override def reader: UITask.Reader = () => {
override private[sbt] val channel = NetworkChannel.this
override private[sbt] lazy val reader: UITask.Reader = () => {
try {
this.synchronized(this.wait)
Left(TerminateAction)
@ -650,6 +650,8 @@ final class NetworkChannel(
}
override def available(): Int = inputBuffer.size
}
private[this] lazy val writeableInputStream: Terminal.WriteableInputStream =
new Terminal.WriteableInputStream(inputStream, name)
import sjsonnew.BasicJsonProtocol._
import scala.collection.JavaConverters._
@ -726,7 +728,8 @@ final class NetworkChannel(
write(java.util.Arrays.copyOfRange(b, off, off + len))
}
}
private class NetworkTerminal extends TerminalImpl(inputStream, outputStream, errorStream, name) {
private class NetworkTerminal
extends TerminalImpl(writeableInputStream, outputStream, errorStream, name) {
private[this] val pending = new AtomicBoolean(false)
private[this] val closed = new AtomicBoolean(false)
private[this] val properties = new AtomicReference[TerminalPropertiesResponse]