Only read stdin bytes on demand to fix reboot

When running reboot at the console, the first character that the user
enters after the reboot has completed is lost. This is because it isn't
possible to interrupt System.in and we have a thread that is blocking on
reads to System.in in WriteableInputStream. That thread cannot be
shutdown during normal sbt shutdown while it is reading. When sbt next
starts up (in the same jvm), the previous thread gets the byte but has
nowhere to write it so the byte is lost. This commit fixes that behavior
by ensuring that we only poll from System.in when there is actually a
downstream consumer.

The behavior of reboot is still a little wonky if the user issues a
reboot from a network client and then tries to input commands at the
console. In that case, sbt will have been polling System.in in the ask
user thread prior to the reboot and the ask user thread will be
uninterruptible for the reason described above so the first byte will
again by swallowed by the previous sbt instance. This use case is
sufficiently pathological that it doesn't feel worth the effort to fix.
As annoying as it is, it doesn't break the sbt session. The user will
either submit an invalid command with the missing leading character or
notice the character is missing, possibly think they missed the key,
type backspace a few times and re-type the command.
This commit is contained in:
Ethan Atkins 2020-06-27 13:29:45 -07:00
parent eb66906dad
commit b6b2c3096d
3 changed files with 58 additions and 20 deletions

View File

@ -11,7 +11,7 @@ import java.io.{ InputStream, OutputStream, PrintStream }
import java.nio.channels.ClosedChannelException
import java.util.Locale
import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference }
import java.util.concurrent.{ CountDownLatch, Executors, LinkedBlockingQueue, TimeUnit }
import java.util.concurrent.{ ConcurrentHashMap, Executors, LinkedBlockingQueue, TimeUnit }
import jline.DefaultTerminal2
import jline.console.ConsoleReader
@ -324,39 +324,69 @@ object Terminal {
private[this] val originalOut = System.out
private[this] val originalIn = System.in
private[this] class WriteableInputStream(in: InputStream, name: String)
private[sbt] class WriteableInputStream(in: InputStream, name: String)
extends InputStream
with AutoCloseable {
final def write(bytes: Int*): Unit = bytes.foreach(buffer.put)
final def write(bytes: Int*): Unit = bytes.foreach(i => buffer.put(i))
private[this] val executor =
Executors.newSingleThreadExecutor(r => new Thread(r, s"sbt-$name-input-reader"))
private[this] val buffer = new LinkedBlockingQueue[Int]
private[this] val latch = new CountDownLatch(1)
private[this] val buffer = new LinkedBlockingQueue[Integer]
private[this] val closed = new AtomicBoolean(false)
private[this] def takeOne: Int = if (closed.get) -1 else buffer.take
private[this] val resultQueue = new LinkedBlockingQueue[LinkedBlockingQueue[Int]]
private[this] val waiting = ConcurrentHashMap.newKeySet[LinkedBlockingQueue[Int]]
/*
* Starts a loop that waits for consumers of the InputStream to call read.
* When read is called, we enqueue a `LinkedBlockingQueue[Int]` to which
* the runnable can return a byte from stdin. If the read caller is interrupted,
* they remove the result from the waiting set and any byte read will be
* enqueued in the buffer. It is done this way so that we only read from
* System.in when a caller actually asks for bytes. If we constantly poll
* from System.in, then when the user calls reboot from the console, the
* first character they type after reboot is swallowed by the previous
* sbt main program. If the user calls reboot from a remote client, we
* can't avoid losing the first byte inputted in the console. A more
* robust fix would be to override System.in at the launcher level instead
* of at the sbt level. At the moment, the use case of a user calling
* reboot from a network client and the adding input at the server console
* seems pathological enough that it isn't worth putting more effort into
* fixing.
*
*/
private[this] val runnable: Runnable = () => {
@tailrec def impl(): Unit = {
val result = resultQueue.take
val b = in.read
buffer.put(b)
if (b != -1) impl()
// The downstream consumer may have been interrupted. Buffer the result
// when that hapens.
if (waiting.contains(result)) result.put(b) else buffer.put(b)
if (b != -1 && !Thread.interrupted()) impl()
else closed.set(true)
}
try {
latch.await()
impl()
} catch { case _: InterruptedException => }
try impl()
catch { case _: InterruptedException => closed.set(true) }
}
executor.submit(runnable)
override def read(): Int = {
latch.countDown()
takeOne match {
case -1 => throw new ClosedChannelException
case b => b
}
}
override def read(): Int =
if (closed.get) -1
else
synchronized {
buffer.poll match {
case null =>
val result = new LinkedBlockingQueue[Int]
waiting.add(result)
resultQueue.offer(result)
try result.take
catch {
case e: InterruptedException =>
waiting.remove(result)
throw e
}
case b if b == -1 => throw new ClosedChannelException
case b => b
}
}
override def available(): Int = {
latch.countDown()
buffer.size
}
override def close(): Unit = if (closed.compareAndSet(false, true)) {

View File

@ -155,6 +155,7 @@ class NetworkClient(
val (sk, tkn) = connect(0)
val conn = new ServerConnection(sk) {
override def onNotification(msg: JsonRpcNotificationMessage): Unit = {
if (msg.toString.contains("shutdown")) System.err.println(msg)
msg.method match {
case `Shutdown` =>
val log = msg.params match {

View File

@ -55,6 +55,13 @@ private[sbt] object UITask {
try {
@tailrec def impl(): Either[String, String] = {
lineReader.readLine(clear + terminal.prompt.mkPrompt()) match {
case null if terminal == Terminal.console && System.console == null =>
// No stdin is attached to the process so just ignore the result and
// block until the thread is interrupted.
this.synchronized(this.wait())
Right("") // should be unreachable
// JLine returns null on ctrl+d when there is no other input. This interprets
// ctrl+d with no imput as an exit
case null => Left(TerminateAction)
case s: String =>
lineReader.getHistory match {