diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala index b1ded7e0e..76bb7ec39 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala @@ -11,7 +11,7 @@ import java.io.{ InputStream, OutputStream, PrintStream } import java.nio.channels.ClosedChannelException import java.util.Locale import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } -import java.util.concurrent.{ CountDownLatch, Executors, LinkedBlockingQueue, TimeUnit } +import java.util.concurrent.{ ConcurrentHashMap, Executors, LinkedBlockingQueue, TimeUnit } import jline.DefaultTerminal2 import jline.console.ConsoleReader @@ -324,39 +324,69 @@ object Terminal { private[this] val originalOut = System.out private[this] val originalIn = System.in - private[this] class WriteableInputStream(in: InputStream, name: String) + private[sbt] class WriteableInputStream(in: InputStream, name: String) extends InputStream with AutoCloseable { - final def write(bytes: Int*): Unit = bytes.foreach(buffer.put) + final def write(bytes: Int*): Unit = bytes.foreach(i => buffer.put(i)) private[this] val executor = Executors.newSingleThreadExecutor(r => new Thread(r, s"sbt-$name-input-reader")) - private[this] val buffer = new LinkedBlockingQueue[Int] - private[this] val latch = new CountDownLatch(1) + private[this] val buffer = new LinkedBlockingQueue[Integer] private[this] val closed = new AtomicBoolean(false) - private[this] def takeOne: Int = if (closed.get) -1 else buffer.take + private[this] val resultQueue = new LinkedBlockingQueue[LinkedBlockingQueue[Int]] + private[this] val waiting = ConcurrentHashMap.newKeySet[LinkedBlockingQueue[Int]] + /* + * Starts a loop that waits for consumers of the InputStream to call read. + * When read is called, we enqueue a `LinkedBlockingQueue[Int]` to which + * the runnable can return a byte from stdin. If the read caller is interrupted, + * they remove the result from the waiting set and any byte read will be + * enqueued in the buffer. It is done this way so that we only read from + * System.in when a caller actually asks for bytes. If we constantly poll + * from System.in, then when the user calls reboot from the console, the + * first character they type after reboot is swallowed by the previous + * sbt main program. If the user calls reboot from a remote client, we + * can't avoid losing the first byte inputted in the console. A more + * robust fix would be to override System.in at the launcher level instead + * of at the sbt level. At the moment, the use case of a user calling + * reboot from a network client and the adding input at the server console + * seems pathological enough that it isn't worth putting more effort into + * fixing. + * + */ private[this] val runnable: Runnable = () => { @tailrec def impl(): Unit = { + val result = resultQueue.take val b = in.read - buffer.put(b) - if (b != -1) impl() + // The downstream consumer may have been interrupted. Buffer the result + // when that hapens. + if (waiting.contains(result)) result.put(b) else buffer.put(b) + if (b != -1 && !Thread.interrupted()) impl() else closed.set(true) } - try { - latch.await() - impl() - } catch { case _: InterruptedException => } + try impl() + catch { case _: InterruptedException => closed.set(true) } } executor.submit(runnable) - override def read(): Int = { - latch.countDown() - takeOne match { - case -1 => throw new ClosedChannelException - case b => b - } - } + override def read(): Int = + if (closed.get) -1 + else + synchronized { + buffer.poll match { + case null => + val result = new LinkedBlockingQueue[Int] + waiting.add(result) + resultQueue.offer(result) + try result.take + catch { + case e: InterruptedException => + waiting.remove(result) + throw e + } + case b if b == -1 => throw new ClosedChannelException + case b => b + } + } override def available(): Int = { - latch.countDown() buffer.size } override def close(): Unit = if (closed.compareAndSet(false, true)) { diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala index 40a7d5569..6b9f3d8a5 100644 --- a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala +++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala @@ -155,6 +155,7 @@ class NetworkClient( val (sk, tkn) = connect(0) val conn = new ServerConnection(sk) { override def onNotification(msg: JsonRpcNotificationMessage): Unit = { + if (msg.toString.contains("shutdown")) System.err.println(msg) msg.method match { case `Shutdown` => val log = msg.params match { diff --git a/main-command/src/main/scala/sbt/internal/ui/UITask.scala b/main-command/src/main/scala/sbt/internal/ui/UITask.scala index 9c5f40c55..64e4a8439 100644 --- a/main-command/src/main/scala/sbt/internal/ui/UITask.scala +++ b/main-command/src/main/scala/sbt/internal/ui/UITask.scala @@ -55,6 +55,13 @@ private[sbt] object UITask { try { @tailrec def impl(): Either[String, String] = { lineReader.readLine(clear + terminal.prompt.mkPrompt()) match { + case null if terminal == Terminal.console && System.console == null => + // No stdin is attached to the process so just ignore the result and + // block until the thread is interrupted. + this.synchronized(this.wait()) + Right("") // should be unreachable + // JLine returns null on ctrl+d when there is no other input. This interprets + // ctrl+d with no imput as an exit case null => Left(TerminateAction) case s: String => lineReader.getHistory match {