diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/JLine3.scala b/internal/util-logging/src/main/scala/sbt/internal/util/JLine3.scala index eadd09ae6..2ef394a54 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/JLine3.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/JLine3.scala @@ -10,9 +10,9 @@ package sbt.internal.util import java.io.{ EOFException, InputStream, OutputStream, PrintWriter } import java.nio.charset.Charset import java.util.{ Arrays, EnumSet } -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } import org.jline.utils.InfoCmp.Capability -import org.jline.utils.{ NonBlocking, OSUtils } +import org.jline.utils.{ ClosedException, NonBlockingReader, OSUtils } import org.jline.terminal.{ Attributes, Size, Terminal => JTerminal } import org.jline.terminal.Terminal.SignalHandler import org.jline.terminal.impl.AbstractTerminal @@ -20,6 +20,7 @@ import org.jline.terminal.impl.jansi.JansiSupportImpl import org.jline.terminal.impl.jansi.win.JansiWinSysTerminal import scala.collection.JavaConverters._ import scala.util.Try +import java.util.concurrent.LinkedBlockingQueue private[util] object JLine3 { private val capabilityMap = Capability @@ -77,6 +78,8 @@ private[util] object JLine3 { new AbstractTerminal(term.name, "ansi", Charset.forName("UTF-8"), SignalHandler.SIG_DFL) { val closed = new AtomicBoolean(false) setOnClose { () => + doClose() + reader.close() if (closed.compareAndSet(false, true)) { // This is necessary to shutdown the non blocking input reader // so that it doesn't keep blocking @@ -89,8 +92,22 @@ private[util] object JLine3 { parseInfoCmp() override val input: InputStream = new InputStream { override def read: Int = { - val res = try term.inputStream.read - catch { case _: InterruptedException => -2 } + val res = term.inputStream match { + case w: Terminal.WriteableInputStream => + val result = new LinkedBlockingQueue[Integer] + try { + w.read(result) + result.poll match { + case null => throw new ClosedException + case i => i.toInt + } + } catch { + case _: InterruptedException => + w.cancel() + throw new ClosedException + } + case _ => throw new ClosedException + } if (res == 4 && term.prompt.render().endsWith(term.prompt.mkPrompt())) throw new EOFException res @@ -110,8 +127,41 @@ private[util] object JLine3 { override def flush(): Unit = term.withPrintStream(_.flush()) } - override val reader = - NonBlocking.nonBlocking(term.name, input, Charset.defaultCharset()) + override val reader = new NonBlockingReader { + val buffer = new LinkedBlockingQueue[Integer] + val thread = new AtomicReference[Thread] + private def fillBuffer(): Unit = thread.synchronized { + thread.set(Thread.currentThread) + buffer.put( + try input.read() + catch { case _: InterruptedException => -3 } + ) + } + override def close(): Unit = thread.get match { + case null => + case t => t.interrupt() + } + override def read(timeout: Long, peek: Boolean) = { + if (buffer.isEmpty && !peek) fillBuffer() + (if (peek) buffer.peek else buffer.take) match { + case null => -2 + case i => if (i == -3) throw new ClosedException else i + } + } + override def peek(timeout: Long): Int = buffer.peek() match { + case null => -1 + case i => i.toInt + } + override def readBuffered(buf: Array[Char]): Int = { + if (buffer.isEmpty) fillBuffer() + buffer.take match { + case i if i == -1 => -1 + case i => + buf(0) = i.toChar + 1 + } + } + } override val writer: PrintWriter = new PrintWriter(output, true) /* * For now assume that the terminal capabilities for client and server diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/ProgressState.scala b/internal/util-logging/src/main/scala/sbt/internal/util/ProgressState.scala index f6fb154dc..f3a70a8a1 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/ProgressState.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/ProgressState.scala @@ -158,7 +158,7 @@ private[sbt] object ProgressState { if (!pe.skipIfActive.getOrElse(false) || (!isRunning && !isBatch)) { terminal.withPrintStream { ps => val commandFromThisTerminal = pe.channelName.fold(true)(_ == terminal.name) - val info = if ((isRunning || isBatch || noPrompt) && commandFromThisTerminal) { + val info = if (commandFromThisTerminal) { pe.items.map { item => val elapsed = item.elapsedMicros / 1000000L s" | => ${item.name} ${elapsed}s" diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/Prompt.scala b/internal/util-logging/src/main/scala/sbt/internal/util/Prompt.scala index 90f02f66c..89c1872cb 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/Prompt.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/Prompt.scala @@ -34,5 +34,6 @@ private[sbt] object Prompt { private[sbt] case object Running extends NoPrompt private[sbt] case object Batch extends NoPrompt private[sbt] case object Watch extends NoPrompt + private[sbt] case object Pending extends NoPrompt private[sbt] case object NoPrompt extends NoPrompt } diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala index deb6489f6..f612a7730 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala @@ -11,7 +11,7 @@ import java.io.{ InputStream, InterruptedIOException, IOException, OutputStream, import java.nio.channels.ClosedChannelException import java.util.{ Arrays, Locale } import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } -import java.util.concurrent.{ ArrayBlockingQueue, Executors, LinkedBlockingQueue, TimeUnit } +import java.util.concurrent.{ Executors, LinkedBlockingQueue, TimeUnit } import jline.DefaultTerminal2 import jline.console.ConsoleReader @@ -141,7 +141,7 @@ trait Terminal extends AutoCloseable { private[sbt] def withRawOutput[R](f: => R): R private[sbt] def restore(): Unit = {} private[sbt] val progressState = new ProgressState(1) - private[this] val promptHolder: AtomicReference[Prompt] = new AtomicReference(Prompt.Running) + private[this] val promptHolder: AtomicReference[Prompt] = new AtomicReference(Prompt.Pending) private[sbt] final def prompt: Prompt = promptHolder.get private[sbt] final def setPrompt(newPrompt: Prompt): Unit = if (prompt != Prompt.NoPrompt) promptHolder.set(newPrompt) @@ -396,50 +396,31 @@ object Terminal { private[sbt] class WriteableInputStream(in: InputStream, name: String) extends InputStream with AutoCloseable { - final def write(bytes: Int*): Unit = waiting.synchronized { - waiting.poll match { - case null => - bytes.foreach(b => buffer.put(b)) - case w => - if (bytes.length > 1) bytes.tail.foreach(b => buffer.put(b)) - bytes.headOption.foreach(b => w.put(b)) - } + final def write(bytes: Int*): Unit = readThread.synchronized { + bytes.foreach(b => buffer.put(b)) } private[this] val executor = Executors.newSingleThreadExecutor(r => new Thread(r, s"sbt-$name-input-reader")) private[this] val buffer = new LinkedBlockingQueue[Integer] private[this] val closed = new AtomicBoolean(false) private[this] val readQueue = new LinkedBlockingQueue[Unit] - private[this] val waiting = new ArrayBlockingQueue[LinkedBlockingQueue[Integer]](1) private[this] val readThread = new AtomicReference[Thread] /* - * Starts a loop that waits for consumers of the InputStream to call read. - * When read is called, we enqueue a `LinkedBlockingQueue[Int]` to which - * the runnable can return a byte from stdin. If the read caller is interrupted, - * they remove the result from the waiting set and any byte read will be - * enqueued in the buffer. It is done this way so that we only read from - * System.in when a caller actually asks for bytes. If we constantly poll - * from System.in, then when the user calls reboot from the console, the - * first character they type after reboot is swallowed by the previous - * sbt main program. If the user calls reboot from a remote client, we - * can't avoid losing the first byte inputted in the console. A more - * robust fix would be to override System.in at the launcher level instead - * of at the sbt level. At the moment, the use case of a user calling - * reboot from a network client and the adding input at the server console - * seems pathological enough that it isn't worth putting more effort into - * fixing. - * + * Starts a loop that fills a buffer with bytes from stdin. We only read from + * the underlying stream when the buffer is empty and there is an active reader. + * If the reader detaches without consuming any bytes, we just buffer the + * next byte that we read from the stream. One known issue with this approach + * is that if a remote client triggers a reboot, we cannot necessarily stop this + * loop from consuming the next byte from standard in even if sbt has fully + * rebooted and the byte will never be consumed. We try to fix this in withStreams + * by setting the terminal to raw mode, which the input stream makes it non blocking, + * but this approach only works on posix platforms. */ private[this] val runnable: Runnable = () => { @tailrec def impl(): Unit = { val _ = readQueue.take val b = in.read - // The downstream consumer may have been interrupted. Buffer the result - // when that hapens. - waiting.poll match { - case null => buffer.put(b) - case q => q.put(b) - } + buffer.put(b) if (b != -1 && !Thread.interrupted()) impl() else closed.set(true) } @@ -447,30 +428,28 @@ object Terminal { catch { case _: InterruptedException => closed.set(true) } } executor.submit(runnable) - override def read(): Int = - if (closed.get) -1 - else - synchronized { + def read(result: LinkedBlockingQueue[Integer]): Unit = + if (!closed.get) + readThread.synchronized { readThread.set(Thread.currentThread) try buffer.poll match { case null => - val result = new LinkedBlockingQueue[Integer] - waiting.synchronized(waiting.put(result)) readQueue.put(()) - try result.take.toInt - catch { - case e: InterruptedException => - waiting.remove(result) - -1 - } + result.put(buffer.take) case b if b == -1 => throw new ClosedChannelException - case b => b.toInt + case b => result.put(b) } finally readThread.set(null) } - def cancel(): Unit = waiting.synchronized { + override def read(): Int = { + val result = new LinkedBlockingQueue[Integer] + read(result) + result.poll match { + case null => -1 + case i => i.toInt + } + } + def cancel(): Unit = readThread.synchronized { Option(readThread.getAndSet(null)).foreach(_.interrupt()) - waiting.forEach(_.put(-2)) - waiting.clear() readQueue.clear() } @@ -732,7 +711,11 @@ object Terminal { } term.restore() term.setEchoEnabled(true) - new ConsoleTerminal(term, nonBlockingIn, originalOut) + new ConsoleTerminal( + term, + if (System.console == null) nullWriteableInputStream else nonBlockingIn, + originalOut + ) } private[sbt] def reset(): Unit = { @@ -780,7 +763,7 @@ object Terminal { private[sbt] def deprecatedTeminal: jline.Terminal = console.toJLine private class ConsoleTerminal( val term: jline.Terminal with jline.Terminal2, - in: InputStream, + in: WriteableInputStream, out: OutputStream ) extends TerminalImpl(in, out, originalErr, "console0") { private[util] lazy val system = JLine3.system @@ -837,17 +820,13 @@ object Terminal { } } private[sbt] abstract class TerminalImpl private[sbt] ( - val in: InputStream, + val in: WriteableInputStream, val out: OutputStream, override val errorStream: OutputStream, override private[sbt] val name: String ) extends Terminal { private[this] val rawMode = new AtomicBoolean(false) private[this] val writeLock = new AnyRef - private[this] val writeableInputStream = in match { - case w: WriteableInputStream => w - case _ => new WriteableInputStream(in, name) - } def throwIfClosed[R](f: => R): R = if (isStopped.get) throw new ClosedChannelException else f override def getLastLine: Option[String] = progressState.currentLine override def getLines: Seq[String] = progressState.getLines @@ -886,9 +865,9 @@ object Terminal { progressState.write(TerminalImpl.this, bytes, ps, hasProgress.get && !rawMode.get) } override private[sbt] val printStream: PrintStream = new LinePrintStream(outputStream) - override def inputStream: InputStream = writeableInputStream + override def inputStream: InputStream = in - private[sbt] def write(bytes: Int*): Unit = writeableInputStream.write(bytes: _*) + private[sbt] def write(bytes: Int*): Unit = in.write(bytes: _*) private[this] val isStopped = new AtomicBoolean(false) override def getLineHeightAndWidth(line: String): (Int, Int) = getWidth match { @@ -909,9 +888,16 @@ object Terminal { writeLock.synchronized(f(rawPrintStream)) override def close(): Unit = if (isStopped.compareAndSet(false, true)) { - writeableInputStream.close() + in.close() } } + private lazy val nullInputStream: InputStream = () => { + try this.synchronized(this.wait) + catch { case _: InterruptedException => } + -1 + } + private lazy val nullWriteableInputStream = + new WriteableInputStream(nullInputStream, "null-writeable-input-stream") private[sbt] val NullTerminal = new Terminal { override def close(): Unit = {} override def getBooleanCapability(capability: String, jline3: Boolean): Boolean = false @@ -922,11 +908,7 @@ object Terminal { override def getNumericCapability(capability: String, jline3: Boolean): Integer = null override def getStringCapability(capability: String, jline3: Boolean): String = null override def getWidth: Int = 0 - override def inputStream: java.io.InputStream = () => { - try this.synchronized(this.wait) - catch { case _: InterruptedException => } - -1 - } + override def inputStream: java.io.InputStream = nullInputStream override def isAnsiSupported: Boolean = false override def isColorEnabled: Boolean = false override def isEchoEnabled: Boolean = false diff --git a/main-command/src/main/scala/sbt/internal/ui/UITask.scala b/main-command/src/main/scala/sbt/internal/ui/UITask.scala index 0b5618cd2..a4144b3f8 100644 --- a/main-command/src/main/scala/sbt/internal/ui/UITask.scala +++ b/main-command/src/main/scala/sbt/internal/ui/UITask.scala @@ -11,7 +11,6 @@ import java.io.File import java.nio.channels.ClosedChannelException import java.util.concurrent.atomic.AtomicBoolean -//import jline.console.history.PersistentHistory import sbt.BasicCommandStrings.{ Cancel, TerminateAction, Shutdown } import sbt.BasicKeys.{ historyPath, terminalShellPrompt } import sbt.State @@ -23,55 +22,67 @@ import sbt.internal.util.complete.{ Parser } import scala.annotation.tailrec private[sbt] trait UITask extends Runnable with AutoCloseable { - private[sbt] def channel: CommandChannel - private[sbt] def reader: UITask.Reader + private[sbt] val channel: CommandChannel + private[sbt] val reader: UITask.Reader private[this] final def handleInput(s: Either[String, String]): Boolean = s match { case Left(m) => channel.onFastTrackTask(m) case Right(cmd) => channel.onCommand(cmd) } private[this] val isStopped = new AtomicBoolean(false) override def run(): Unit = { - @tailrec def impl(): Unit = { + @tailrec def impl(): Unit = if (!isStopped.get) { val res = reader.readLine() if (!handleInput(res) && !isStopped.get) impl() } try impl() catch { case _: InterruptedException | _: ClosedChannelException => isStopped.set(true) } } - override def close(): Unit = isStopped.set(true) + override def close(): Unit = { + isStopped.set(true) + reader.close() + } } private[sbt] object UITask { - trait Reader { def readLine(): Either[String, String] } + trait Reader extends AutoCloseable { + def readLine(): Either[String, String] + override def close(): Unit = {} + } object Reader { def terminalReader(parser: Parser[_])( terminal: Terminal, state: State - ): Reader = { () => - try { - val clear = terminal.ansi(ClearPromptLine, "") - @tailrec def impl(): Either[String, String] = { - val reader = LineReader.createReader(history(state), parser, terminal, terminal.prompt) - (try reader.readLine(clear + terminal.prompt.mkPrompt()) - finally reader.close) match { - case None if terminal == Terminal.console && System.console == null => - // No stdin is attached to the process so just ignore the result and - // block until the thread is interrupted. - this.synchronized(this.wait()) - Right("") // should be unreachable - // JLine returns null on ctrl+d when there is no other input. This interprets - // ctrl+d with no imput as an exit - case None => Left(TerminateAction) - case Some(s: String) => - s.trim() match { - case "" => impl() - case cmd @ (`Shutdown` | `TerminateAction` | `Cancel`) => Left(cmd) - case cmd => Right(cmd) - } + ): Reader = new Reader { + val closed = new AtomicBoolean(false) + def readLine(): Either[String, String] = + try { + val clear = terminal.ansi(ClearPromptLine, "") + @tailrec def impl(): Either[String, String] = { + val thread = Thread.currentThread + if (thread.isInterrupted || closed.get) throw new InterruptedException + val reader = LineReader.createReader(history(state), parser, terminal) + if (thread.isInterrupted || closed.get) throw new InterruptedException + (try reader.readLine(clear + terminal.prompt.mkPrompt()) + finally reader.close) match { + case None if terminal == Terminal.console && System.console == null => + // No stdin is attached to the process so just ignore the result and + // block until the thread is interrupted. + this.synchronized(this.wait()) + Right("") // should be unreachable + // JLine returns null on ctrl+d when there is no other input. This interprets + // ctrl+d with no imput as an exit + case None => Left(TerminateAction) + case Some(s: String) => + s.trim() match { + case "" => impl() + case cmd @ (`Shutdown` | `TerminateAction` | `Cancel`) => Left(cmd) + case cmd => Right(cmd) + } + } } - } - impl() - } catch { case e: InterruptedException => Right("") } + impl() + } catch { case e: InterruptedException => Right("") } + override def close(): Unit = closed.set(true) } } private[this] def history(s: State): Option[File] = @@ -87,7 +98,7 @@ private[sbt] object UITask { state: State, override val channel: CommandChannel, ) extends UITask { - override private[sbt] def reader: UITask.Reader = { + override private[sbt] lazy val reader: UITask.Reader = { UITask.Reader.terminalReader(state.combinedParser)(channel.terminal, state) } } diff --git a/main-command/src/main/scala/sbt/internal/ui/UserThread.scala b/main-command/src/main/scala/sbt/internal/ui/UserThread.scala index e868c3682..7f1d007f5 100644 --- a/main-command/src/main/scala/sbt/internal/ui/UserThread.scala +++ b/main-command/src/main/scala/sbt/internal/ui/UserThread.scala @@ -14,7 +14,7 @@ import java.util.concurrent.Executors import sbt.State import sbt.internal.util.{ ConsoleAppender, ProgressEvent, ProgressState, Util } -import sbt.internal.util.Prompt.{ AskUser, Running } +import sbt.internal.util.Prompt private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable { private[this] val uiThread = new AtomicReference[(UITask, Thread)] @@ -31,9 +31,15 @@ private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable uiThread.synchronized { val task = channel.makeUIThread(state) def submit(): Thread = { - val thread = new Thread(() => { - task.run() - uiThread.set(null) + def close(): Unit = { + uiThread.get match { + case (_, t) if t == thread => uiThread.set(null) + case _ => + } + } + lazy val thread = new Thread(() => { + try task.run() + finally close() }, s"sbt-$name-ui-thread") thread.setDaemon(true) thread.start() @@ -60,6 +66,13 @@ private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable case (t, thread) => t.close() Util.ignoreResult(thread.interrupt()) + try thread.join(1000) + catch { case _: InterruptedException => } + + // This join should always work, but if it doesn't log an error because + // it can cause problems if the thread isn't joined + if (thread.isAlive) System.err.println(s"Unable to join thread $thread") + () } } @@ -70,8 +83,9 @@ private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable } val state = consolePromptEvent.state terminal.prompt match { - case Running => terminal.setPrompt(AskUser(() => UITask.shellPrompt(terminal, state))) - case _ => + case Prompt.Running | Prompt.Pending => + terminal.setPrompt(Prompt.AskUser(() => UITask.shellPrompt(terminal, state))) + case _ => } onProgressEvent(ProgressEvent("Info", Vector(), None, None, None)) reset(state) @@ -80,6 +94,7 @@ private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable private[sbt] def onConsoleUnpromptEvent( consoleUnpromptEvent: ConsoleUnpromptEvent ): Unit = { + terminal.setPrompt(Prompt.Pending) if (consoleUnpromptEvent.lastSource.fold(true)(_.channelName != name)) { terminal.progressState.reset() } else stopThread() diff --git a/main/src/main/scala/sbt/internal/Continuous.scala b/main/src/main/scala/sbt/internal/Continuous.scala index 0f13a36b4..351490f80 100644 --- a/main/src/main/scala/sbt/internal/Continuous.scala +++ b/main/src/main/scala/sbt/internal/Continuous.scala @@ -1231,7 +1231,7 @@ private[sbt] object ContinuousCommands { state: State ) extends Thread(s"sbt-${channel.name}-watch-ui-thread") with UITask { - override private[sbt] def reader: UITask.Reader = () => { + override private[sbt] lazy val reader: UITask.Reader = () => { def stop = Right(s"${ContinuousCommands.stopWatch} ${channel.name}") val exitAction: Watch.Action = { Watch.apply( diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index f7583f9e1..fd461c06f 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -110,7 +110,7 @@ final class NetworkChannel( } private[sbt] def write(byte: Byte) = inputBuffer.add(byte) - private[this] val terminalHolder = new AtomicReference(Terminal.NullTerminal) + private[this] val terminalHolder = new AtomicReference[Terminal](Terminal.NullTerminal) override private[sbt] def terminal: Terminal = terminalHolder.get override val userThread: UserThread = new UserThread(this) @@ -152,8 +152,8 @@ final class NetworkChannel( if (interactive.get || ContinuousCommands.isInWatch(state, this)) mkUIThreadImpl(state, command) else new UITask { - override private[sbt] def channel = NetworkChannel.this - override def reader: UITask.Reader = () => { + override private[sbt] val channel = NetworkChannel.this + override private[sbt] lazy val reader: UITask.Reader = () => { try { this.synchronized(this.wait) Left(TerminateAction) @@ -650,6 +650,8 @@ final class NetworkChannel( } override def available(): Int = inputBuffer.size } + private[this] lazy val writeableInputStream: Terminal.WriteableInputStream = + new Terminal.WriteableInputStream(inputStream, name) import sjsonnew.BasicJsonProtocol._ import scala.collection.JavaConverters._ @@ -726,7 +728,8 @@ final class NetworkChannel( write(java.util.Arrays.copyOfRange(b, off, off + len)) } } - private class NetworkTerminal extends TerminalImpl(inputStream, outputStream, errorStream, name) { + private class NetworkTerminal + extends TerminalImpl(writeableInputStream, outputStream, errorStream, name) { private[this] val pending = new AtomicBoolean(false) private[this] val closed = new AtomicBoolean(false) private[this] val properties = new AtomicReference[TerminalPropertiesResponse]