From 44a605198bc3bbb2e2afea0767f308d4911b0e05 Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Mon, 3 Aug 2020 14:42:40 -0700 Subject: [PATCH 1/7] Add toString implementation to ConsoleOut instances This makes it easier to debug which ConsoleOut is printing output. --- .../src/main/scala/sbt/internal/util/ConsoleOut.scala | 7 +++++++ .../src/main/scala/sbt/internal/util/Terminal.scala | 1 + 2 files changed, 8 insertions(+) diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/ConsoleOut.scala b/internal/util-logging/src/main/scala/sbt/internal/util/ConsoleOut.scala index 2ea84f182..17f74c53c 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/ConsoleOut.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/ConsoleOut.scala @@ -33,6 +33,7 @@ object ConsoleOut { override def println(s: String): Unit = get.println(s) override def println(): Unit = get.println() override def flush(): Unit = get.flush() + override def toString: String = s"ProxyConsoleOut" } def overwriteContaining(s: String): (String, String) => Boolean = @@ -70,6 +71,7 @@ object ConsoleOut { last = Some(s) current.setLength(0) } + override def toString: String = s"SystemOutOverwrite@${System.identityHashCode(this)}" } def terminalOut: ConsoleOut = new ConsoleOut { @@ -78,6 +80,7 @@ object ConsoleOut { override def println(s: String): Unit = Terminal.get.printStream.println(s) override def println(): Unit = Terminal.get.printStream.println() override def flush(): Unit = Terminal.get.printStream.flush() + override def toString: String = s"TerminalOut" } private[this] val consoleOutPerTerminal = new ConcurrentHashMap[Terminal, ConsoleOut] @@ -89,6 +92,7 @@ object ConsoleOut { override def println(s: String): Unit = terminal.printStream.println(s) override def println(): Unit = terminal.printStream.println() override def flush(): Unit = terminal.printStream.flush() + override def toString: String = s"TerminalOut($terminal)" } consoleOutPerTerminal.put(terminal, res) res @@ -100,6 +104,7 @@ object ConsoleOut { def println(s: String) = out.println(s) def println() = out.println() def flush() = out.flush() + override def toString: String = s"PrintStreamConsoleOut($out)" } def printWriterOut(out: PrintWriter): ConsoleOut = new ConsoleOut { val lockObject = out @@ -107,6 +112,7 @@ object ConsoleOut { def println(s: String) = { out.println(s); flush() } def println() = { out.println(); flush() } def flush() = { out.flush() } + override def toString: String = s"PrintWriterConsoleOut($out)" } def bufferedWriterOut(out: BufferedWriter): ConsoleOut = new ConsoleOut { val lockObject = out @@ -114,5 +120,6 @@ object ConsoleOut { def println(s: String) = { out.write(s); println() } def println() = { out.newLine(); flush() } def flush() = { out.flush() } + override def toString: String = s"BufferedWriterConsoleOut($out)" } } diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala index 8b7a99df6..9c832ec5e 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala @@ -350,6 +350,7 @@ object Terminal { override def getLastLine: Option[String] = t.getLastLine override def getLines: Seq[String] = t.getLines override private[sbt] def name: String = t.name + override def toString: String = s"ProxyTerminal(current = $t)" } private[sbt] def get: Terminal = ProxyTerminal From e4cd6a38fcaf2af2e0a0879d5c2136a05c3eb055 Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Sat, 25 Jul 2020 14:32:38 -0700 Subject: [PATCH 2/7] Hold lock while writing bytes to stdout We should always hold the print stream lock when calling progressState.write because otherwise the task progress thread could concurrently write to stdout. --- .../src/main/scala/sbt/internal/util/Terminal.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala index 9c832ec5e..deb6489f6 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala @@ -882,8 +882,9 @@ object Terminal { } override def flush(): Unit = combinedOutputStream.flush() } - private def doWrite(bytes: Array[Byte]): Unit = - progressState.write(TerminalImpl.this, bytes, rawPrintStream, hasProgress.get && !rawMode.get) + private def doWrite(bytes: Array[Byte]): Unit = withPrintStream { ps => + progressState.write(TerminalImpl.this, bytes, ps, hasProgress.get && !rawMode.get) + } override private[sbt] val printStream: PrintStream = new LinePrintStream(outputStream) override def inputStream: InputStream = writeableInputStream From 6dd69a54ae81ac6bd01b21f23d33366cea708016 Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Sat, 25 Jul 2020 12:20:45 -0700 Subject: [PATCH 3/7] Close line reader when interrupted There are cases where if the ui state is changing rapidly, that an AskUserThread can be created and cancelled in a short time windows. This could cause problems if the AskUserThread is interrupted during `LineReader.createReader` which I think can shell out to run some commands so it is relatively slow. If the thread was interrupted during the call to `LineReader.createReader` and the interruption was not handled, then the thread would go into `LineReader.readLine`, which wouldn't exit until the user pressed enter. This ultimately caused the ui to break until enter because this zombie line reader would be holding the lock on the terminal input stream. --- .../main/scala/sbt/internal/util/JLine3.scala | 62 +++++++++- .../sbt/internal/util/ProgressState.scala | 2 +- .../main/scala/sbt/internal/util/Prompt.scala | 1 + .../scala/sbt/internal/util/Terminal.scala | 110 ++++++++---------- .../main/scala/sbt/internal/ui/UITask.scala | 73 +++++++----- .../scala/sbt/internal/ui/UserThread.scala | 27 ++++- .../main/scala/sbt/internal/Continuous.scala | 2 +- .../sbt/internal/server/NetworkChannel.scala | 11 +- 8 files changed, 175 insertions(+), 113 deletions(-) diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/JLine3.scala b/internal/util-logging/src/main/scala/sbt/internal/util/JLine3.scala index eadd09ae6..2ef394a54 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/JLine3.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/JLine3.scala @@ -10,9 +10,9 @@ package sbt.internal.util import java.io.{ EOFException, InputStream, OutputStream, PrintWriter } import java.nio.charset.Charset import java.util.{ Arrays, EnumSet } -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } import org.jline.utils.InfoCmp.Capability -import org.jline.utils.{ NonBlocking, OSUtils } +import org.jline.utils.{ ClosedException, NonBlockingReader, OSUtils } import org.jline.terminal.{ Attributes, Size, Terminal => JTerminal } import org.jline.terminal.Terminal.SignalHandler import org.jline.terminal.impl.AbstractTerminal @@ -20,6 +20,7 @@ import org.jline.terminal.impl.jansi.JansiSupportImpl import org.jline.terminal.impl.jansi.win.JansiWinSysTerminal import scala.collection.JavaConverters._ import scala.util.Try +import java.util.concurrent.LinkedBlockingQueue private[util] object JLine3 { private val capabilityMap = Capability @@ -77,6 +78,8 @@ private[util] object JLine3 { new AbstractTerminal(term.name, "ansi", Charset.forName("UTF-8"), SignalHandler.SIG_DFL) { val closed = new AtomicBoolean(false) setOnClose { () => + doClose() + reader.close() if (closed.compareAndSet(false, true)) { // This is necessary to shutdown the non blocking input reader // so that it doesn't keep blocking @@ -89,8 +92,22 @@ private[util] object JLine3 { parseInfoCmp() override val input: InputStream = new InputStream { override def read: Int = { - val res = try term.inputStream.read - catch { case _: InterruptedException => -2 } + val res = term.inputStream match { + case w: Terminal.WriteableInputStream => + val result = new LinkedBlockingQueue[Integer] + try { + w.read(result) + result.poll match { + case null => throw new ClosedException + case i => i.toInt + } + } catch { + case _: InterruptedException => + w.cancel() + throw new ClosedException + } + case _ => throw new ClosedException + } if (res == 4 && term.prompt.render().endsWith(term.prompt.mkPrompt())) throw new EOFException res @@ -110,8 +127,41 @@ private[util] object JLine3 { override def flush(): Unit = term.withPrintStream(_.flush()) } - override val reader = - NonBlocking.nonBlocking(term.name, input, Charset.defaultCharset()) + override val reader = new NonBlockingReader { + val buffer = new LinkedBlockingQueue[Integer] + val thread = new AtomicReference[Thread] + private def fillBuffer(): Unit = thread.synchronized { + thread.set(Thread.currentThread) + buffer.put( + try input.read() + catch { case _: InterruptedException => -3 } + ) + } + override def close(): Unit = thread.get match { + case null => + case t => t.interrupt() + } + override def read(timeout: Long, peek: Boolean) = { + if (buffer.isEmpty && !peek) fillBuffer() + (if (peek) buffer.peek else buffer.take) match { + case null => -2 + case i => if (i == -3) throw new ClosedException else i + } + } + override def peek(timeout: Long): Int = buffer.peek() match { + case null => -1 + case i => i.toInt + } + override def readBuffered(buf: Array[Char]): Int = { + if (buffer.isEmpty) fillBuffer() + buffer.take match { + case i if i == -1 => -1 + case i => + buf(0) = i.toChar + 1 + } + } + } override val writer: PrintWriter = new PrintWriter(output, true) /* * For now assume that the terminal capabilities for client and server diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/ProgressState.scala b/internal/util-logging/src/main/scala/sbt/internal/util/ProgressState.scala index f6fb154dc..f3a70a8a1 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/ProgressState.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/ProgressState.scala @@ -158,7 +158,7 @@ private[sbt] object ProgressState { if (!pe.skipIfActive.getOrElse(false) || (!isRunning && !isBatch)) { terminal.withPrintStream { ps => val commandFromThisTerminal = pe.channelName.fold(true)(_ == terminal.name) - val info = if ((isRunning || isBatch || noPrompt) && commandFromThisTerminal) { + val info = if (commandFromThisTerminal) { pe.items.map { item => val elapsed = item.elapsedMicros / 1000000L s" | => ${item.name} ${elapsed}s" diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/Prompt.scala b/internal/util-logging/src/main/scala/sbt/internal/util/Prompt.scala index 90f02f66c..89c1872cb 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/Prompt.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/Prompt.scala @@ -34,5 +34,6 @@ private[sbt] object Prompt { private[sbt] case object Running extends NoPrompt private[sbt] case object Batch extends NoPrompt private[sbt] case object Watch extends NoPrompt + private[sbt] case object Pending extends NoPrompt private[sbt] case object NoPrompt extends NoPrompt } diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala index deb6489f6..f612a7730 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala @@ -11,7 +11,7 @@ import java.io.{ InputStream, InterruptedIOException, IOException, OutputStream, import java.nio.channels.ClosedChannelException import java.util.{ Arrays, Locale } import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } -import java.util.concurrent.{ ArrayBlockingQueue, Executors, LinkedBlockingQueue, TimeUnit } +import java.util.concurrent.{ Executors, LinkedBlockingQueue, TimeUnit } import jline.DefaultTerminal2 import jline.console.ConsoleReader @@ -141,7 +141,7 @@ trait Terminal extends AutoCloseable { private[sbt] def withRawOutput[R](f: => R): R private[sbt] def restore(): Unit = {} private[sbt] val progressState = new ProgressState(1) - private[this] val promptHolder: AtomicReference[Prompt] = new AtomicReference(Prompt.Running) + private[this] val promptHolder: AtomicReference[Prompt] = new AtomicReference(Prompt.Pending) private[sbt] final def prompt: Prompt = promptHolder.get private[sbt] final def setPrompt(newPrompt: Prompt): Unit = if (prompt != Prompt.NoPrompt) promptHolder.set(newPrompt) @@ -396,50 +396,31 @@ object Terminal { private[sbt] class WriteableInputStream(in: InputStream, name: String) extends InputStream with AutoCloseable { - final def write(bytes: Int*): Unit = waiting.synchronized { - waiting.poll match { - case null => - bytes.foreach(b => buffer.put(b)) - case w => - if (bytes.length > 1) bytes.tail.foreach(b => buffer.put(b)) - bytes.headOption.foreach(b => w.put(b)) - } + final def write(bytes: Int*): Unit = readThread.synchronized { + bytes.foreach(b => buffer.put(b)) } private[this] val executor = Executors.newSingleThreadExecutor(r => new Thread(r, s"sbt-$name-input-reader")) private[this] val buffer = new LinkedBlockingQueue[Integer] private[this] val closed = new AtomicBoolean(false) private[this] val readQueue = new LinkedBlockingQueue[Unit] - private[this] val waiting = new ArrayBlockingQueue[LinkedBlockingQueue[Integer]](1) private[this] val readThread = new AtomicReference[Thread] /* - * Starts a loop that waits for consumers of the InputStream to call read. - * When read is called, we enqueue a `LinkedBlockingQueue[Int]` to which - * the runnable can return a byte from stdin. If the read caller is interrupted, - * they remove the result from the waiting set and any byte read will be - * enqueued in the buffer. It is done this way so that we only read from - * System.in when a caller actually asks for bytes. If we constantly poll - * from System.in, then when the user calls reboot from the console, the - * first character they type after reboot is swallowed by the previous - * sbt main program. If the user calls reboot from a remote client, we - * can't avoid losing the first byte inputted in the console. A more - * robust fix would be to override System.in at the launcher level instead - * of at the sbt level. At the moment, the use case of a user calling - * reboot from a network client and the adding input at the server console - * seems pathological enough that it isn't worth putting more effort into - * fixing. - * + * Starts a loop that fills a buffer with bytes from stdin. We only read from + * the underlying stream when the buffer is empty and there is an active reader. + * If the reader detaches without consuming any bytes, we just buffer the + * next byte that we read from the stream. One known issue with this approach + * is that if a remote client triggers a reboot, we cannot necessarily stop this + * loop from consuming the next byte from standard in even if sbt has fully + * rebooted and the byte will never be consumed. We try to fix this in withStreams + * by setting the terminal to raw mode, which the input stream makes it non blocking, + * but this approach only works on posix platforms. */ private[this] val runnable: Runnable = () => { @tailrec def impl(): Unit = { val _ = readQueue.take val b = in.read - // The downstream consumer may have been interrupted. Buffer the result - // when that hapens. - waiting.poll match { - case null => buffer.put(b) - case q => q.put(b) - } + buffer.put(b) if (b != -1 && !Thread.interrupted()) impl() else closed.set(true) } @@ -447,30 +428,28 @@ object Terminal { catch { case _: InterruptedException => closed.set(true) } } executor.submit(runnable) - override def read(): Int = - if (closed.get) -1 - else - synchronized { + def read(result: LinkedBlockingQueue[Integer]): Unit = + if (!closed.get) + readThread.synchronized { readThread.set(Thread.currentThread) try buffer.poll match { case null => - val result = new LinkedBlockingQueue[Integer] - waiting.synchronized(waiting.put(result)) readQueue.put(()) - try result.take.toInt - catch { - case e: InterruptedException => - waiting.remove(result) - -1 - } + result.put(buffer.take) case b if b == -1 => throw new ClosedChannelException - case b => b.toInt + case b => result.put(b) } finally readThread.set(null) } - def cancel(): Unit = waiting.synchronized { + override def read(): Int = { + val result = new LinkedBlockingQueue[Integer] + read(result) + result.poll match { + case null => -1 + case i => i.toInt + } + } + def cancel(): Unit = readThread.synchronized { Option(readThread.getAndSet(null)).foreach(_.interrupt()) - waiting.forEach(_.put(-2)) - waiting.clear() readQueue.clear() } @@ -732,7 +711,11 @@ object Terminal { } term.restore() term.setEchoEnabled(true) - new ConsoleTerminal(term, nonBlockingIn, originalOut) + new ConsoleTerminal( + term, + if (System.console == null) nullWriteableInputStream else nonBlockingIn, + originalOut + ) } private[sbt] def reset(): Unit = { @@ -780,7 +763,7 @@ object Terminal { private[sbt] def deprecatedTeminal: jline.Terminal = console.toJLine private class ConsoleTerminal( val term: jline.Terminal with jline.Terminal2, - in: InputStream, + in: WriteableInputStream, out: OutputStream ) extends TerminalImpl(in, out, originalErr, "console0") { private[util] lazy val system = JLine3.system @@ -837,17 +820,13 @@ object Terminal { } } private[sbt] abstract class TerminalImpl private[sbt] ( - val in: InputStream, + val in: WriteableInputStream, val out: OutputStream, override val errorStream: OutputStream, override private[sbt] val name: String ) extends Terminal { private[this] val rawMode = new AtomicBoolean(false) private[this] val writeLock = new AnyRef - private[this] val writeableInputStream = in match { - case w: WriteableInputStream => w - case _ => new WriteableInputStream(in, name) - } def throwIfClosed[R](f: => R): R = if (isStopped.get) throw new ClosedChannelException else f override def getLastLine: Option[String] = progressState.currentLine override def getLines: Seq[String] = progressState.getLines @@ -886,9 +865,9 @@ object Terminal { progressState.write(TerminalImpl.this, bytes, ps, hasProgress.get && !rawMode.get) } override private[sbt] val printStream: PrintStream = new LinePrintStream(outputStream) - override def inputStream: InputStream = writeableInputStream + override def inputStream: InputStream = in - private[sbt] def write(bytes: Int*): Unit = writeableInputStream.write(bytes: _*) + private[sbt] def write(bytes: Int*): Unit = in.write(bytes: _*) private[this] val isStopped = new AtomicBoolean(false) override def getLineHeightAndWidth(line: String): (Int, Int) = getWidth match { @@ -909,9 +888,16 @@ object Terminal { writeLock.synchronized(f(rawPrintStream)) override def close(): Unit = if (isStopped.compareAndSet(false, true)) { - writeableInputStream.close() + in.close() } } + private lazy val nullInputStream: InputStream = () => { + try this.synchronized(this.wait) + catch { case _: InterruptedException => } + -1 + } + private lazy val nullWriteableInputStream = + new WriteableInputStream(nullInputStream, "null-writeable-input-stream") private[sbt] val NullTerminal = new Terminal { override def close(): Unit = {} override def getBooleanCapability(capability: String, jline3: Boolean): Boolean = false @@ -922,11 +908,7 @@ object Terminal { override def getNumericCapability(capability: String, jline3: Boolean): Integer = null override def getStringCapability(capability: String, jline3: Boolean): String = null override def getWidth: Int = 0 - override def inputStream: java.io.InputStream = () => { - try this.synchronized(this.wait) - catch { case _: InterruptedException => } - -1 - } + override def inputStream: java.io.InputStream = nullInputStream override def isAnsiSupported: Boolean = false override def isColorEnabled: Boolean = false override def isEchoEnabled: Boolean = false diff --git a/main-command/src/main/scala/sbt/internal/ui/UITask.scala b/main-command/src/main/scala/sbt/internal/ui/UITask.scala index 0b5618cd2..a4144b3f8 100644 --- a/main-command/src/main/scala/sbt/internal/ui/UITask.scala +++ b/main-command/src/main/scala/sbt/internal/ui/UITask.scala @@ -11,7 +11,6 @@ import java.io.File import java.nio.channels.ClosedChannelException import java.util.concurrent.atomic.AtomicBoolean -//import jline.console.history.PersistentHistory import sbt.BasicCommandStrings.{ Cancel, TerminateAction, Shutdown } import sbt.BasicKeys.{ historyPath, terminalShellPrompt } import sbt.State @@ -23,55 +22,67 @@ import sbt.internal.util.complete.{ Parser } import scala.annotation.tailrec private[sbt] trait UITask extends Runnable with AutoCloseable { - private[sbt] def channel: CommandChannel - private[sbt] def reader: UITask.Reader + private[sbt] val channel: CommandChannel + private[sbt] val reader: UITask.Reader private[this] final def handleInput(s: Either[String, String]): Boolean = s match { case Left(m) => channel.onFastTrackTask(m) case Right(cmd) => channel.onCommand(cmd) } private[this] val isStopped = new AtomicBoolean(false) override def run(): Unit = { - @tailrec def impl(): Unit = { + @tailrec def impl(): Unit = if (!isStopped.get) { val res = reader.readLine() if (!handleInput(res) && !isStopped.get) impl() } try impl() catch { case _: InterruptedException | _: ClosedChannelException => isStopped.set(true) } } - override def close(): Unit = isStopped.set(true) + override def close(): Unit = { + isStopped.set(true) + reader.close() + } } private[sbt] object UITask { - trait Reader { def readLine(): Either[String, String] } + trait Reader extends AutoCloseable { + def readLine(): Either[String, String] + override def close(): Unit = {} + } object Reader { def terminalReader(parser: Parser[_])( terminal: Terminal, state: State - ): Reader = { () => - try { - val clear = terminal.ansi(ClearPromptLine, "") - @tailrec def impl(): Either[String, String] = { - val reader = LineReader.createReader(history(state), parser, terminal, terminal.prompt) - (try reader.readLine(clear + terminal.prompt.mkPrompt()) - finally reader.close) match { - case None if terminal == Terminal.console && System.console == null => - // No stdin is attached to the process so just ignore the result and - // block until the thread is interrupted. - this.synchronized(this.wait()) - Right("") // should be unreachable - // JLine returns null on ctrl+d when there is no other input. This interprets - // ctrl+d with no imput as an exit - case None => Left(TerminateAction) - case Some(s: String) => - s.trim() match { - case "" => impl() - case cmd @ (`Shutdown` | `TerminateAction` | `Cancel`) => Left(cmd) - case cmd => Right(cmd) - } + ): Reader = new Reader { + val closed = new AtomicBoolean(false) + def readLine(): Either[String, String] = + try { + val clear = terminal.ansi(ClearPromptLine, "") + @tailrec def impl(): Either[String, String] = { + val thread = Thread.currentThread + if (thread.isInterrupted || closed.get) throw new InterruptedException + val reader = LineReader.createReader(history(state), parser, terminal) + if (thread.isInterrupted || closed.get) throw new InterruptedException + (try reader.readLine(clear + terminal.prompt.mkPrompt()) + finally reader.close) match { + case None if terminal == Terminal.console && System.console == null => + // No stdin is attached to the process so just ignore the result and + // block until the thread is interrupted. + this.synchronized(this.wait()) + Right("") // should be unreachable + // JLine returns null on ctrl+d when there is no other input. This interprets + // ctrl+d with no imput as an exit + case None => Left(TerminateAction) + case Some(s: String) => + s.trim() match { + case "" => impl() + case cmd @ (`Shutdown` | `TerminateAction` | `Cancel`) => Left(cmd) + case cmd => Right(cmd) + } + } } - } - impl() - } catch { case e: InterruptedException => Right("") } + impl() + } catch { case e: InterruptedException => Right("") } + override def close(): Unit = closed.set(true) } } private[this] def history(s: State): Option[File] = @@ -87,7 +98,7 @@ private[sbt] object UITask { state: State, override val channel: CommandChannel, ) extends UITask { - override private[sbt] def reader: UITask.Reader = { + override private[sbt] lazy val reader: UITask.Reader = { UITask.Reader.terminalReader(state.combinedParser)(channel.terminal, state) } } diff --git a/main-command/src/main/scala/sbt/internal/ui/UserThread.scala b/main-command/src/main/scala/sbt/internal/ui/UserThread.scala index e868c3682..7f1d007f5 100644 --- a/main-command/src/main/scala/sbt/internal/ui/UserThread.scala +++ b/main-command/src/main/scala/sbt/internal/ui/UserThread.scala @@ -14,7 +14,7 @@ import java.util.concurrent.Executors import sbt.State import sbt.internal.util.{ ConsoleAppender, ProgressEvent, ProgressState, Util } -import sbt.internal.util.Prompt.{ AskUser, Running } +import sbt.internal.util.Prompt private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable { private[this] val uiThread = new AtomicReference[(UITask, Thread)] @@ -31,9 +31,15 @@ private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable uiThread.synchronized { val task = channel.makeUIThread(state) def submit(): Thread = { - val thread = new Thread(() => { - task.run() - uiThread.set(null) + def close(): Unit = { + uiThread.get match { + case (_, t) if t == thread => uiThread.set(null) + case _ => + } + } + lazy val thread = new Thread(() => { + try task.run() + finally close() }, s"sbt-$name-ui-thread") thread.setDaemon(true) thread.start() @@ -60,6 +66,13 @@ private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable case (t, thread) => t.close() Util.ignoreResult(thread.interrupt()) + try thread.join(1000) + catch { case _: InterruptedException => } + + // This join should always work, but if it doesn't log an error because + // it can cause problems if the thread isn't joined + if (thread.isAlive) System.err.println(s"Unable to join thread $thread") + () } } @@ -70,8 +83,9 @@ private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable } val state = consolePromptEvent.state terminal.prompt match { - case Running => terminal.setPrompt(AskUser(() => UITask.shellPrompt(terminal, state))) - case _ => + case Prompt.Running | Prompt.Pending => + terminal.setPrompt(Prompt.AskUser(() => UITask.shellPrompt(terminal, state))) + case _ => } onProgressEvent(ProgressEvent("Info", Vector(), None, None, None)) reset(state) @@ -80,6 +94,7 @@ private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable private[sbt] def onConsoleUnpromptEvent( consoleUnpromptEvent: ConsoleUnpromptEvent ): Unit = { + terminal.setPrompt(Prompt.Pending) if (consoleUnpromptEvent.lastSource.fold(true)(_.channelName != name)) { terminal.progressState.reset() } else stopThread() diff --git a/main/src/main/scala/sbt/internal/Continuous.scala b/main/src/main/scala/sbt/internal/Continuous.scala index 0f13a36b4..351490f80 100644 --- a/main/src/main/scala/sbt/internal/Continuous.scala +++ b/main/src/main/scala/sbt/internal/Continuous.scala @@ -1231,7 +1231,7 @@ private[sbt] object ContinuousCommands { state: State ) extends Thread(s"sbt-${channel.name}-watch-ui-thread") with UITask { - override private[sbt] def reader: UITask.Reader = () => { + override private[sbt] lazy val reader: UITask.Reader = () => { def stop = Right(s"${ContinuousCommands.stopWatch} ${channel.name}") val exitAction: Watch.Action = { Watch.apply( diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index f7583f9e1..fd461c06f 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -110,7 +110,7 @@ final class NetworkChannel( } private[sbt] def write(byte: Byte) = inputBuffer.add(byte) - private[this] val terminalHolder = new AtomicReference(Terminal.NullTerminal) + private[this] val terminalHolder = new AtomicReference[Terminal](Terminal.NullTerminal) override private[sbt] def terminal: Terminal = terminalHolder.get override val userThread: UserThread = new UserThread(this) @@ -152,8 +152,8 @@ final class NetworkChannel( if (interactive.get || ContinuousCommands.isInWatch(state, this)) mkUIThreadImpl(state, command) else new UITask { - override private[sbt] def channel = NetworkChannel.this - override def reader: UITask.Reader = () => { + override private[sbt] val channel = NetworkChannel.this + override private[sbt] lazy val reader: UITask.Reader = () => { try { this.synchronized(this.wait) Left(TerminateAction) @@ -650,6 +650,8 @@ final class NetworkChannel( } override def available(): Int = inputBuffer.size } + private[this] lazy val writeableInputStream: Terminal.WriteableInputStream = + new Terminal.WriteableInputStream(inputStream, name) import sjsonnew.BasicJsonProtocol._ import scala.collection.JavaConverters._ @@ -726,7 +728,8 @@ final class NetworkChannel( write(java.util.Arrays.copyOfRange(b, off, off + len)) } } - private class NetworkTerminal extends TerminalImpl(inputStream, outputStream, errorStream, name) { + private class NetworkTerminal + extends TerminalImpl(writeableInputStream, outputStream, errorStream, name) { private[this] val pending = new AtomicBoolean(false) private[this] val closed = new AtomicBoolean(false) private[this] val properties = new AtomicReference[TerminalPropertiesResponse] From 90dacc339c0bc5afdfa80b98775f235cff56101f Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Mon, 20 Jul 2020 10:12:04 -0700 Subject: [PATCH 4/7] Support scala 2.13 console in thin client In order to make the console task work with scala 2.13 and the thin client, we need to provide a way for the scala repl to use an sbt provided jline3 terminal instead of the default terminal typically built by the repl. We also need to put jline 3 higher up in the classloading hierarchy to ensure that two versions of jline 3 are not loaded (which makes it impossible to share the sbt terminal with the scala terminal). One impact of this change is the decoupling of the version of jline-terminal used by the in process scala console and the version of jline-terminal specified by the scala version itself. It is possible to override this by setting the `useScalaReplJLine` flag to true. When that is set, the scala REPL will run in a fully isolated classloader. That will ensure that the versions are consistent. It will, however, for sure break the thin client and may interfere with the embedded shell ui. As part of this work, I also discovered that jline 3 Terminal.getSize is very slow. In jline 2, the terminal attributes were automatically cached with a timeout of, I think, 1 second so it wasn't a big deal to call Terminal.getAttributes. The getSize method in jline 3 is not cached and it shells out to run a tty command. This caused a significant performance regression in sbt because when progress is enabled, we call Terminal.getSize whenever we log any messages. I added caching of getSize at the TerminalImpl level to address this. The timeout is 1 second, which seems responsive enough for most use cases. We could also move the calculation onto a background thread and have it periodically updated, but that seems like overkill. --- build.sbt | 7 ++-- .../sbt/internal/util/DeprecatedJLine.java | 21 ++++++++++ .../main/scala/sbt/internal/util/JLine3.scala | 8 ++-- .../scala/sbt/internal/util/Terminal.scala | 36 +++++++++++------ main-actions/src/main/scala/sbt/Console.scala | 4 +- main-command/src/main/scala/sbt/State.scala | 4 ++ .../internal/classpath/ClassLoaderCache.scala | 25 +++++++++--- .../sbt/internal/client/NetworkClient.scala | 11 +++++- .../main/scala/sbt/internal/ui/UITask.scala | 6 ++- .../main/java/sbt/internal/JLineLoader.java | 34 ++++++++++++++++ .../java/sbt/internal/MetaBuildLoader.java | 23 ++++++++--- main/src/main/scala/sbt/Defaults.scala | 39 ++++++++++++------- main/src/main/scala/sbt/Keys.scala | 3 ++ main/src/main/scala/sbt/Main.scala | 3 ++ .../sbt/internal/XMainConfiguration.scala | 7 ++-- .../sbt/internal/server/NetworkChannel.scala | 8 ++++ .../sbt/internal/server/VirtualTerminal.scala | 29 +++++++++++++- project/Dependencies.scala | 6 ++- .../sbt/protocol/TerminalGetSizeQuery.scala | 29 ++++++++++++++ .../protocol/TerminalGetSizeResponse.scala | 36 +++++++++++++++++ .../codec/CommandMessageFormats.scala | 4 +- .../protocol/codec/EventMessageFormats.scala | 4 +- .../sbt/protocol/codec/JsonProtocol.scala | 2 + .../codec/TerminalGetSizeQueryFormats.scala | 27 +++++++++++++ .../TerminalGetSizeResponseFormats.scala | 29 ++++++++++++++ protocol/src/main/contraband/server.contra | 6 +++ .../scala/sbt/protocol/Serialization.scala | 3 +- 27 files changed, 352 insertions(+), 62 deletions(-) create mode 100644 internal/util-logging/src/main/java/sbt/internal/util/DeprecatedJLine.java create mode 100644 main/src/main/java/sbt/internal/JLineLoader.java create mode 100644 protocol/src/main/contraband-scala/sbt/protocol/TerminalGetSizeQuery.scala create mode 100644 protocol/src/main/contraband-scala/sbt/protocol/TerminalGetSizeResponse.scala create mode 100644 protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalGetSizeQueryFormats.scala create mode 100644 protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalGetSizeResponseFormats.scala diff --git a/build.sbt b/build.sbt index 582a7ef93..5b0fcd298 100644 --- a/build.sbt +++ b/build.sbt @@ -304,7 +304,7 @@ val completeProj = (project in file("internal") / "util-complete") testedBaseSettings, name := "Completion", libraryDependencies += jline, - libraryDependencies += jline3, + libraryDependencies += jline3Reader, mimaSettings, // Parser is used publicly, so we can't break bincompat. mimaBinaryIssueFilters := Seq( @@ -366,7 +366,8 @@ lazy val utilLogging = (project in file("internal") / "util-logging") libraryDependencies ++= Seq( jline, - jline3, + jline3Terminal, + jline3Jansi, log4jApi, log4jCore, disruptor, @@ -661,6 +662,7 @@ lazy val actionsProj = (project in file("main-actions")) testedBaseSettings, name := "Actions", libraryDependencies += sjsonNewScalaJson.value, + libraryDependencies += jline3Terminal, mimaSettings, mimaBinaryIssueFilters ++= Seq( // Removed unused private[sbt] nested class @@ -1103,7 +1105,6 @@ lazy val sbtClientProj = (project in file("client")) crossPaths := false, exportJars := true, libraryDependencies += jansi, - libraryDependencies += jline3Jansi, libraryDependencies += scalatest % "test", /* * On windows, the raw classpath is too large to be a command argument to an diff --git a/internal/util-logging/src/main/java/sbt/internal/util/DeprecatedJLine.java b/internal/util-logging/src/main/java/sbt/internal/util/DeprecatedJLine.java new file mode 100644 index 000000000..4e091a2a5 --- /dev/null +++ b/internal/util-logging/src/main/java/sbt/internal/util/DeprecatedJLine.java @@ -0,0 +1,21 @@ +/* + * sbt + * Copyright 2011 - 2018, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal.util; + +import org.jline.terminal.TerminalBuilder; + +/** + * This exists to a provide a wrapper to TerminalBuilder.setTerminalOverride that will not emit a + * deprecation warning when called from scala. + */ +public class DeprecatedJLine { + @SuppressWarnings("deprecation") + public static void setTerminalOverride(final org.jline.terminal.Terminal terminal) { + TerminalBuilder.setTerminalOverride(terminal); + } +} diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/JLine3.scala b/internal/util-logging/src/main/scala/sbt/internal/util/JLine3.scala index 2ef394a54..953042b72 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/JLine3.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/JLine3.scala @@ -7,7 +7,7 @@ package sbt.internal.util -import java.io.{ EOFException, InputStream, OutputStream, PrintWriter } +import java.io.{ InputStream, OutputStream, PrintWriter } import java.nio.charset.Charset import java.util.{ Arrays, EnumSet } import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import scala.util.Try import java.util.concurrent.LinkedBlockingQueue -private[util] object JLine3 { +private[sbt] object JLine3 { private val capabilityMap = Capability .values() .map { c => @@ -109,18 +109,18 @@ private[util] object JLine3 { case _ => throw new ClosedException } if (res == 4 && term.prompt.render().endsWith(term.prompt.mkPrompt())) - throw new EOFException + throw new ClosedException res } } override val output: OutputStream = new OutputStream { override def write(b: Int): Unit = write(Array[Byte](b.toByte)) override def write(b: Array[Byte]): Unit = if (!closed.get) term.withPrintStream { ps => + ps.write(b) term.prompt match { case a: Prompt.AskUser => a.write(b) case _ => } - ps.write(b) } override def write(b: Array[Byte], offset: Int, len: Int) = write(Arrays.copyOfRange(b, offset, offset + len)) diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala index f612a7730..a440fc77b 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala @@ -16,6 +16,7 @@ import java.util.concurrent.{ Executors, LinkedBlockingQueue, TimeUnit } import jline.DefaultTerminal2 import jline.console.ConsoleReader import scala.annotation.tailrec +import scala.concurrent.duration._ import scala.util.Try import scala.util.control.NonFatal @@ -174,10 +175,7 @@ object Terminal { try Terminal.console.printStream.println(s"[info] $string") catch { case _: IOException => } } - private[sbt] def set(terminal: Terminal): Terminal = { - jline.TerminalFactory.set(terminal.toJLine) - activeTerminal.getAndSet(terminal) - } + private[sbt] def set(terminal: Terminal): Terminal = activeTerminal.getAndSet(terminal) implicit class TerminalOps(private val term: Terminal) extends AnyVal { def ansi(richString: => String, string: => String): String = if (term.isAnsiSupported) richString else string @@ -500,7 +498,6 @@ object Terminal { * System.out through the terminal's input and output streams. */ private[this] val activeTerminal = new AtomicReference[Terminal](consoleTerminalHolder.get) - jline.TerminalFactory.set(consoleTerminalHolder.get.toJLine) /** * The boot input stream allows a remote client to forward input to the sbt process while @@ -674,13 +671,13 @@ object Terminal { if (alive) try terminal.init() catch { - case _: InterruptedException => + case _: InterruptedException | _: java.io.IOError => } override def restore(): Unit = if (alive) try terminal.restore() catch { - case _: InterruptedException => + case _: InterruptedException | _: java.io.IOError => } override def reset(): Unit = try terminal.reset() @@ -767,10 +764,12 @@ object Terminal { out: OutputStream ) extends TerminalImpl(in, out, originalErr, "console0") { private[util] lazy val system = JLine3.system - private[this] def isCI = sys.env.contains("BUILD_NUMBER") || sys.env.contains("CI") - override def getWidth: Int = system.getSize.getColumns - override def getHeight: Int = system.getSize.getRows - override def isAnsiSupported: Boolean = term.isAnsiSupported && !isCI + override private[sbt] def getSizeImpl: (Int, Int) = { + val size = system.getSize + (size.getColumns, size.getRows) + } + private[this] val isCI = sys.env.contains("BUILD_NUMBER") || sys.env.contains("CI") + override lazy val isAnsiSupported: Boolean = term.isAnsiSupported && !isCI override def isEchoEnabled: Boolean = system.echo() override def isSuccessEnabled: Boolean = true override def getBooleanCapability(capability: String, jline3: Boolean): Boolean = @@ -785,7 +784,7 @@ object Terminal { override private[sbt] def restore(): Unit = term.restore() override private[sbt] def getAttributes: Map[String, String] = - JLine3.toMap(system.getAttributes) + Try(JLine3.toMap(system.getAttributes)).getOrElse(Map.empty) override private[sbt] def setAttributes(attributes: Map[String, String]): Unit = system.setAttributes(JLine3.attributesFromMap(attributes)) override private[sbt] def setSize(width: Int, height: Int): Unit = @@ -825,6 +824,19 @@ object Terminal { override val errorStream: OutputStream, override private[sbt] val name: String ) extends Terminal { + private[sbt] def getSizeImpl: (Int, Int) + private[this] val sizeRefreshPeriod = 1.second + private[this] val size = + new AtomicReference[((Int, Int), Deadline)](((1, 1), Deadline.now - 1.day)) + private[this] def setSize() = size.set((Try(getSizeImpl).getOrElse((1, 1)), Deadline.now)) + private[this] def getSize = size.get match { + case (s, d) if (d + sizeRefreshPeriod).isOverdue => + setSize() + size.get._1 + case (s, _) => s + } + override def getWidth: Int = getSize._1 + override def getHeight: Int = getSize._2 private[this] val rawMode = new AtomicBoolean(false) private[this] val writeLock = new AnyRef def throwIfClosed[R](f: => R): R = if (isStopped.get) throw new ClosedChannelException else f diff --git a/main-actions/src/main/scala/sbt/Console.scala b/main-actions/src/main/scala/sbt/Console.scala index 8e9d40984..43ca79b2a 100644 --- a/main-actions/src/main/scala/sbt/Console.scala +++ b/main-actions/src/main/scala/sbt/Console.scala @@ -10,7 +10,7 @@ package sbt import java.io.File import java.nio.channels.ClosedChannelException import sbt.internal.inc.{ AnalyzingCompiler, PlainVirtualFile } -import sbt.internal.util.Terminal +import sbt.internal.util.{ DeprecatedJLine, Terminal } import sbt.util.Logger import xsbti.compile.{ Compilers, Inputs } @@ -67,6 +67,8 @@ final class Console(compiler: AnalyzingCompiler) { try { sys.props("scala.color") = if (terminal.isColorEnabled) "true" else "false" terminal.withRawOutput { + jline.TerminalFactory.set(terminal.toJLine) + DeprecatedJLine.setTerminalOverride(sbt.internal.util.JLine3(terminal)) terminal.withRawInput(Run.executeTrapExit(console0, log)) } } finally { diff --git a/main-command/src/main/scala/sbt/State.scala b/main-command/src/main/scala/sbt/State.scala index 409bac887..32fd3c7a8 100644 --- a/main-command/src/main/scala/sbt/State.scala +++ b/main-command/src/main/scala/sbt/State.scala @@ -389,6 +389,10 @@ object State { s get BasicKeys.classLoaderCache getOrElse (throw new IllegalStateException( "Tried to get classloader cache for uninitialized state." )) + private[sbt] def extendedClassLoaderCache: ClassLoaderCache = + s get BasicKeys.extendedClassLoaderCache getOrElse (throw new IllegalStateException( + "Tried to get extended classloader cache for uninitialized state." + )) def initializeClassLoaderCache: State = { s.get(BasicKeys.extendedClassLoaderCache).foreach(_.close()) val cache = newClassLoaderCache diff --git a/main-command/src/main/scala/sbt/internal/classpath/ClassLoaderCache.scala b/main-command/src/main/scala/sbt/internal/classpath/ClassLoaderCache.scala index dd730db87..0ef7d3911 100644 --- a/main-command/src/main/scala/sbt/internal/classpath/ClassLoaderCache.scala +++ b/main-command/src/main/scala/sbt/internal/classpath/ClassLoaderCache.scala @@ -11,7 +11,7 @@ import java.io.File import java.lang.management.ManagementFactory import java.lang.ref.{ Reference, ReferenceQueue, SoftReference } import java.net.URLClassLoader -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.{ AtomicInteger, AtomicReference } import sbt.internal.inc.classpath.{ AbstractClassLoaderCache, @@ -30,9 +30,12 @@ private object ClassLoaderCache { private def threadID = new AtomicInteger(0) } private[sbt] class ClassLoaderCache( - override val commonParent: ClassLoader, + val parent: ClassLoader, private val miniProvider: Option[(File, ClassLoader)] ) extends AbstractClassLoaderCache { + private[this] val parentHolder = new AtomicReference(parent) + def commonParent = parentHolder.get() + def setParent(parent: ClassLoader): Unit = parentHolder.set(parent) def this(commonParent: ClassLoader) = this(commonParent, None) def this(scalaProvider: ScalaProvider) = this(scalaProvider.launcher.topLoader, { @@ -51,8 +54,9 @@ private[sbt] class ClassLoaderCache( } } private class Key(val fileStamps: Seq[(File, Long)], val parent: ClassLoader) { - def this(files: List[File]) = - this(files.map(f => f -> IO.getModifiedTimeOrZero(f)), commonParent) + def this(files: List[File], parent: ClassLoader) = + this(files.map(f => f -> IO.getModifiedTimeOrZero(f)), parent) + def this(files: List[File]) = this(files, commonParent) lazy val files: Seq[File] = fileStamps.map(_._1) lazy val maxStamp: Long = fileStamps.maxBy(_._2)._2 class CachedClassLoader @@ -169,10 +173,19 @@ private[sbt] class ClassLoaderCache( val key = new Key(files, parent) get(key, mkLoader) } - override def apply(files: List[File]): ClassLoader = { - val key = new Key(files) + def apply(files: List[File], parent: ClassLoader): ClassLoader = { + val key = new Key(files, parent) get(key, () => key.toClassLoader) } + override def apply(files: List[File]): ClassLoader = { + files match { + case d :: s :: Nil if d.getName.startsWith("dotty-library") => + apply(files, classOf[org.jline.terminal.Terminal].getClassLoader) + case _ => + val key = new Key(files) + get(key, () => key.toClassLoader) + } + } override def cachedCustomClassloader( files: List[File], mkLoader: () => ClassLoader diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala index b95b6f55c..46fc1fa31 100644 --- a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala +++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala @@ -47,11 +47,12 @@ import Serialization.{ systemErrFlush, terminalCapabilities, terminalCapabilitiesResponse, + terminalGetSize, terminalPropertiesQuery, terminalPropertiesResponse, + terminalSetSize, getTerminalAttributes, setTerminalAttributes, - setTerminalSize, } import NetworkClient.Arguments @@ -657,7 +658,13 @@ class NetworkClient( cchars = attrs.getOrElse("cchars", ""), ) sendCommandResponse("", response, msg.id) - case (`setTerminalSize`, Some(json)) => + case (`terminalGetSize`, _) => + val response = TerminalGetSizeResponse( + Terminal.console.getWidth, + Terminal.console.getHeight, + ) + sendCommandResponse("", response, msg.id) + case (`terminalSetSize`, Some(json)) => Converter.fromJson[TerminalSetSizeCommand](json) match { case Success(size) => Terminal.console.setSize(size.width, size.height) diff --git a/main-command/src/main/scala/sbt/internal/ui/UITask.scala b/main-command/src/main/scala/sbt/internal/ui/UITask.scala index a4144b3f8..dce4b1012 100644 --- a/main-command/src/main/scala/sbt/internal/ui/UITask.scala +++ b/main-command/src/main/scala/sbt/internal/ui/UITask.scala @@ -49,6 +49,8 @@ private[sbt] object UITask { override def close(): Unit = {} } object Reader { + // Avoid filling the stack trace since it isn't helpful here + object interrupted extends InterruptedException def terminalReader(parser: Parser[_])( terminal: Terminal, state: State @@ -59,9 +61,9 @@ private[sbt] object UITask { val clear = terminal.ansi(ClearPromptLine, "") @tailrec def impl(): Either[String, String] = { val thread = Thread.currentThread - if (thread.isInterrupted || closed.get) throw new InterruptedException + if (thread.isInterrupted || closed.get) throw interrupted val reader = LineReader.createReader(history(state), parser, terminal) - if (thread.isInterrupted || closed.get) throw new InterruptedException + if (thread.isInterrupted || closed.get) throw interrupted (try reader.readLine(clear + terminal.prompt.mkPrompt()) finally reader.close) match { case None if terminal == Terminal.console && System.console == null => diff --git a/main/src/main/java/sbt/internal/JLineLoader.java b/main/src/main/java/sbt/internal/JLineLoader.java new file mode 100644 index 000000000..098b21bc9 --- /dev/null +++ b/main/src/main/java/sbt/internal/JLineLoader.java @@ -0,0 +1,34 @@ +/* + * sbt + * Copyright 2011 - 2018, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal; + +import java.net.URL; +import java.net.URLClassLoader; + +class JLineLoader extends URLClassLoader { + JLineLoader(final URL[] urls, final ClassLoader parent) { + super(urls, parent); + } + + @Override + public String toString() { + final StringBuilder result = new StringBuilder(); + result.append("JLineLoader("); + final URL[] urls = getURLs(); + for (int i = 0; i < urls.length; ++i) { + result.append(urls[i].toString()); + if (i < urls.length - 1) result.append(", "); + } + result.append(")"); + return result.toString(); + } + + static { + registerAsParallelCapable(); + } +} diff --git a/main/src/main/java/sbt/internal/MetaBuildLoader.java b/main/src/main/java/sbt/internal/MetaBuildLoader.java index a027314b1..2ca2a5088 100644 --- a/main/src/main/java/sbt/internal/MetaBuildLoader.java +++ b/main/src/main/java/sbt/internal/MetaBuildLoader.java @@ -22,16 +22,19 @@ public final class MetaBuildLoader extends URLClassLoader { private final URLClassLoader fullScalaLoader; private final URLClassLoader libraryLoader; private final URLClassLoader interfaceLoader; + private final URLClassLoader jlineLoader; MetaBuildLoader( final URL[] urls, final URLClassLoader fullScalaLoader, final URLClassLoader libraryLoader, - final URLClassLoader interfaceLoader) { + final URLClassLoader interfaceLoader, + final URLClassLoader jlineLoader) { super(urls, fullScalaLoader); this.fullScalaLoader = fullScalaLoader; this.libraryLoader = libraryLoader; this.interfaceLoader = interfaceLoader; + this.jlineLoader = jlineLoader; } @Override @@ -45,6 +48,7 @@ public final class MetaBuildLoader extends URLClassLoader { fullScalaLoader.close(); libraryLoader.close(); interfaceLoader.close(); + jlineLoader.close(); } static { @@ -61,20 +65,26 @@ public final class MetaBuildLoader extends URLClassLoader { */ public static MetaBuildLoader makeLoader(final AppProvider appProvider) throws IOException { final Pattern pattern = - Pattern.compile("^(test-interface-[0-9.]+|jline-[0-9.]+-sbt-.*|jansi-[0-9.]+)\\.jar"); + Pattern.compile( + "^(test-interface-[0-9.]+|jline-(terminal-)?[0-9.]+-sbt-.*|jansi-[0-9.]+)\\.jar"); final File[] cp = appProvider.mainClasspath(); - final URL[] interfaceURLs = new URL[3]; + final URL[] interfaceURLs = new URL[1]; + final URL[] jlineURLs = new URL[3]; final File[] extra = appProvider.id().classpathExtra() == null ? new File[0] : appProvider.id().classpathExtra(); final Set bottomClasspath = new LinkedHashSet<>(); { int interfaceIndex = 0; + int jlineIndex = 0; for (final File file : cp) { final String name = file.getName(); - if (pattern.matcher(name).find()) { + if (name.contains("test-interface") && pattern.matcher(name).find()) { interfaceURLs[interfaceIndex] = file.toURI().toURL(); interfaceIndex += 1; + } else if (pattern.matcher(name).find()) { + jlineURLs[jlineIndex] = file.toURI().toURL(); + jlineIndex += 1; } else { bottomClasspath.add(file); } @@ -108,6 +118,7 @@ public final class MetaBuildLoader extends URLClassLoader { if (topLoader == null) topLoader = scalaProvider.launcher().topLoader(); final TestInterfaceLoader interfaceLoader = new TestInterfaceLoader(interfaceURLs, topLoader); + final JLineLoader jlineLoader = new JLineLoader(jlineURLs, interfaceLoader); final File[] siJars = scalaProvider.jars(); final URL[] lib = new URL[1]; int scalaRestCount = siJars.length - 1; @@ -131,8 +142,8 @@ public final class MetaBuildLoader extends URLClassLoader { } } assert lib[0] != null : "no scala-library.jar"; - final ScalaLibraryClassLoader libraryLoader = new ScalaLibraryClassLoader(lib, interfaceLoader); + final ScalaLibraryClassLoader libraryLoader = new ScalaLibraryClassLoader(lib, jlineLoader); final FullScalaLoader fullScalaLoader = new FullScalaLoader(scalaRest, libraryLoader); - return new MetaBuildLoader(rest, fullScalaLoader, libraryLoader, interfaceLoader); + return new MetaBuildLoader(rest, fullScalaLoader, libraryLoader, interfaceLoader, jlineLoader); } } diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index cb3c28f05..7fbf1aec9 100644 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -8,7 +8,7 @@ package sbt import java.io.{ File, PrintWriter } -import java.net.{ URI, URL, URLClassLoader } +import java.net.{ URI, URL } import java.nio.file.{ Paths, Path => NioPath } import java.util.Optional import java.util.concurrent.TimeUnit @@ -34,9 +34,8 @@ import sbt.Scope.{ GlobalScope, ThisScope, fillTaskAxis } import sbt.coursierint._ import sbt.internal.CommandStrings.ExportStream import sbt.internal._ -import sbt.internal.classpath.AlternativeZincUtil +import sbt.internal.classpath.{ AlternativeZincUtil, ClassLoaderCache } import sbt.internal.inc.JavaInterfaceUtil._ -import sbt.internal.inc.classpath.{ ClassLoaderCache, ClasspathFilter, ClasspathUtil } import sbt.internal.inc.{ CompileOutput, MappedFileConverter, @@ -45,6 +44,8 @@ import sbt.internal.inc.{ ZincLmUtil, ZincUtil } +import sbt.internal.inc.classpath.{ ClasspathFilter, ClasspathUtil } +import sbt.internal.inc.{ MappedFileConverter, PlainVirtualFile, Stamps, ZincLmUtil, ZincUtil } import sbt.internal.io.{ Source, WatchState } import sbt.internal.librarymanagement.mavenint.{ PomExtraDependencyAttributes, @@ -386,6 +387,11 @@ object Defaults extends BuildCommon { }, turbo :== SysProp.turbo, usePipelining :== SysProp.pipelining, + useScalaReplJLine :== false, + scalaInstanceTopLoader := { + if (!useScalaReplJLine.value) classOf[org.jline.terminal.Terminal].getClassLoader + else appConfiguration.value.provider.scalaProvider.launcher.topLoader.getParent + }, useSuperShell := { if (insideCI.value) false else Terminal.console.isSupershellEnabled }, progressReports := { val rs = EvaluateTask.taskTimingProgress.toVector ++ EvaluateTask.taskTraceEvent.toVector @@ -888,8 +894,15 @@ object Defaults extends BuildCommon { val libraryJars = allJars.filter(_.getName == "scala-library.jar") allJars.filter(_.getName == "scala-compiler.jar") match { case Array(compilerJar) if libraryJars.nonEmpty => - val cache = state.value.classLoaderCache - mkScalaInstance(version, allJars, libraryJars, compilerJar, cache) + val cache = state.value.extendedClassLoaderCache + mkScalaInstance( + version, + allJars, + libraryJars, + compilerJar, + cache, + scalaInstanceTopLoader.value + ) case _ => ScalaInstance(version, scalaProvider) } } else @@ -931,7 +944,8 @@ object Defaults extends BuildCommon { allJars, Array(libraryJar), compilerJar, - state.value.classLoaderCache + state.value.extendedClassLoaderCache, + scalaInstanceTopLoader.value, ) } private[this] def mkScalaInstance( @@ -940,15 +954,11 @@ object Defaults extends BuildCommon { libraryJars: Array[File], compilerJar: File, classLoaderCache: ClassLoaderCache, + topLoader: ClassLoader, ): ScalaInstance = { val allJarsDistinct = allJars.distinct - val libraryLoader = classLoaderCache(libraryJars.toList) - class ScalaLoader - extends URLClassLoader(allJarsDistinct.map(_.toURI.toURL).toArray, libraryLoader) - val fullLoader = classLoaderCache.cachedCustomClassloader( - allJarsDistinct.toList, - () => new ScalaLoader - ) + val libraryLoader = classLoaderCache(libraryJars.toList, topLoader) + val fullLoader = classLoaderCache(allJarsDistinct.toList, libraryLoader) new ScalaInstance( version, fullLoader, @@ -970,7 +980,8 @@ object Defaults extends BuildCommon { dummy.allJars, dummy.libraryJars, dummy.compilerJar, - state.value.classLoaderCache + state.value.extendedClassLoaderCache, + scalaInstanceTopLoader.value, ) } diff --git a/main/src/main/scala/sbt/Keys.scala b/main/src/main/scala/sbt/Keys.scala index 3609cf245..5da958c30 100644 --- a/main/src/main/scala/sbt/Keys.scala +++ b/main/src/main/scala/sbt/Keys.scala @@ -570,6 +570,9 @@ object Keys { val includeLintKeys = settingKey[Set[Def.KeyedInitialize[_]]]("Task keys that are included into lintUnused task") val lintUnusedKeysOnLoad = settingKey[Boolean]("Toggles whether or not to check for unused keys during startup") + val useScalaReplJLine = settingKey[Boolean]("Toggles whether or not to use sbt's forked jline in the scala repl. Enabling this flag may break the thin client in the scala console.").withRank(KeyRanks.Invisible) + val scalaInstanceTopLoader = settingKey[ClassLoader]("The top classloader for the scala instance").withRank(KeyRanks.Invisible) + val stateStreams = AttributeKey[Streams]("stateStreams", "Streams manager, which provides streams for different contexts. Setting this on State will override the default Streams implementation.") val resolvedScoped = Def.resolvedScoped val pluginData = taskKey[PluginData]("Information from the plugin build needed in the main build definition.").withRank(DTask) diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index 70e291940..c746856fe 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -932,6 +932,9 @@ object BuiltinCommands { val s3 = addCacheStoreFactoryFactory(Project.setProject(session, structure, s2)) val s4 = s3.put(Keys.useLog4J.key, Project.extract(s3).get(Keys.useLog4J)) val s5 = setupGlobalFileTreeRepository(s4) + // This is a workaround for the console task in dotty which uses the classloader cache. + // We need to override the top loader in that case so that it gets the forked jline. + s5.extendedClassLoaderCache.setParent(Project.extract(s5).get(Keys.scalaInstanceTopLoader)) CheckBuildSources.init(LintUnused.lintUnusedFunc(s5)) } diff --git a/main/src/main/scala/sbt/internal/XMainConfiguration.scala b/main/src/main/scala/sbt/internal/XMainConfiguration.scala index e4ebdf6bc..93dcb4e3b 100644 --- a/main/src/main/scala/sbt/internal/XMainConfiguration.scala +++ b/main/src/main/scala/sbt/internal/XMainConfiguration.scala @@ -59,17 +59,16 @@ private[sbt] class XMainConfiguration { val topLoader = configuration.provider.scalaProvider.launcher.topLoader val updatedConfiguration = try { - val method = topLoader.getClass.getMethod("getEarlyJars") + val method = topLoader.getClass.getMethod("getJLineJars") val jars = method.invoke(topLoader).asInstanceOf[Array[URL]] var canReuseConfiguration = jars.length == 3 var j = 0 while (j < jars.length && canReuseConfiguration) { val s = jars(j).toString - canReuseConfiguration = - s.contains("jline") || s.contains("test-interface") || s.contains("jansi") + canReuseConfiguration = s.contains("jline") || s.contains("jansi") j += 1 } - if (canReuseConfiguration) configuration else makeConfiguration(configuration) + if (canReuseConfiguration && j == 3) configuration else makeConfiguration(configuration) } catch { case _: NoSuchMethodException => makeConfiguration(configuration) } diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index fd461c06f..ca75b369a 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -875,6 +875,14 @@ final class NetworkChannel( try queue.take catch { case _: InterruptedException => } } + override private[sbt] def getSizeImpl: (Int, Int) = + if (!closed.get) { + import sbt.protocol.codec.JsonProtocol._ + val queue = VirtualTerminal.getTerminalSize(name, jsonRpcRequest) + val res = try queue.take + catch { case _: InterruptedException => TerminalGetSizeResponse(1, 1) } + (res.width, res.height) + } else (1, 1) override def setSize(width: Int, height: Int): Unit = if (!closed.get) { import sbt.protocol.codec.JsonProtocol._ diff --git a/main/src/main/scala/sbt/internal/server/VirtualTerminal.scala b/main/src/main/scala/sbt/internal/server/VirtualTerminal.scala index c299c0f4c..497b78c0c 100644 --- a/main/src/main/scala/sbt/internal/server/VirtualTerminal.scala +++ b/main/src/main/scala/sbt/internal/server/VirtualTerminal.scala @@ -20,7 +20,9 @@ import sbt.protocol.Serialization.{ attach, systemIn, terminalCapabilities, + terminalGetSize, terminalPropertiesQuery, + terminalSetSize, } import sjsonnew.support.scalajson.unsafe.Converter import sbt.protocol.{ @@ -30,10 +32,13 @@ import sbt.protocol.{ TerminalCapabilitiesQuery, TerminalCapabilitiesResponse, TerminalPropertiesResponse, + TerminalGetSizeQuery, + TerminalGetSizeResponse, TerminalSetAttributesCommand, TerminalSetSizeCommand, } import sbt.protocol.codec.JsonProtocol._ +import sbt.protocol.TerminalGetSizeResponse object VirtualTerminal { private[this] val pendingTerminalProperties = @@ -46,6 +51,8 @@ object VirtualTerminal { new ConcurrentHashMap[(String, String), ArrayBlockingQueue[Unit]] private[this] val pendingTerminalSetSize = new ConcurrentHashMap[(String, String), ArrayBlockingQueue[Unit]] + private[this] val pendingTerminalGetSize = + new ConcurrentHashMap[(String, String), ArrayBlockingQueue[TerminalGetSizeResponse]] private[sbt] def sendTerminalPropertiesQuery( channelName: String, jsonRpcRequest: (String, String, String) => Unit @@ -111,9 +118,22 @@ object VirtualTerminal { val id = UUID.randomUUID.toString val queue = new ArrayBlockingQueue[Unit](1) pendingTerminalSetSize.put((channelName, id), queue) - jsonRpcRequest(id, terminalCapabilities, query) + jsonRpcRequest(id, terminalSetSize, query) queue } + + private[sbt] def getTerminalSize( + channelName: String, + jsonRpcRequest: (String, String, TerminalGetSizeQuery) => Unit, + ): ArrayBlockingQueue[TerminalGetSizeResponse] = { + val id = UUID.randomUUID.toString + val query = TerminalGetSizeQuery() + val queue = new ArrayBlockingQueue[TerminalGetSizeResponse](1) + pendingTerminalGetSize.put((channelName, id), queue) + jsonRpcRequest(id, terminalGetSize, query) + queue + } + val handler = ServerHandler { cb => ServerIntent(requestHandler(cb), responseHandler(cb), notificationHandler(cb)) } @@ -166,6 +186,13 @@ object VirtualTerminal { case null => case buffer => buffer.put(()) } + case r if pendingTerminalGetSize.get((callback.name, r.id)) != null => + val response = + r.result.flatMap(Converter.fromJson[TerminalGetSizeResponse](_).toOption) + pendingTerminalGetSize.remove((callback.name, r.id)) match { + case null => + case buffer => buffer.put(response.getOrElse(TerminalGetSizeResponse(1, 1))) + } } private val notificationHandler: Handler[JsonRpcNotificationMessage] = callback => { diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 63cd768e4..7f1f1c68c 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -84,8 +84,10 @@ object Dependencies { val sjsonNewMurmurhash = sjsonNew("sjson-new-murmurhash") val jline = "org.scala-sbt.jline" % "jline" % "2.14.7-sbt-5e51b9d4f9631ebfa29753ce4accc57808e7fd6b" - val jline3 = "org.jline" % "jline" % "3.15.0" - val jline3Jansi = "org.jline" % "jline-terminal-jansi" % "3.15.0" + val jline3Version = "3.16.0" // Once the base jline version is upgraded, we can use the official jline-terminal + val jline3Terminal = "org.scala-sbt.jline3" % "jline-terminal" % s"$jline3Version-sbt-211a082ed6326908dc84ca017ce4430728f18a8a" + val jline3Jansi = "org.jline" % "jline-terminal-jansi" % jline3Version + val jline3Reader = "org.jline" % "jline-reader" % jline3Version val jansi = "org.fusesource.jansi" % "jansi" % "1.18" val scalatest = "org.scalatest" %% "scalatest" % "3.0.8" val scalacheck = "org.scalacheck" %% "scalacheck" % "1.14.0" diff --git a/protocol/src/main/contraband-scala/sbt/protocol/TerminalGetSizeQuery.scala b/protocol/src/main/contraband-scala/sbt/protocol/TerminalGetSizeQuery.scala new file mode 100644 index 000000000..70fa74148 --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/protocol/TerminalGetSizeQuery.scala @@ -0,0 +1,29 @@ +/** + * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.protocol +final class TerminalGetSizeQuery private () extends sbt.protocol.CommandMessage() with Serializable { + + + +override def equals(o: Any): Boolean = o match { + case _: TerminalGetSizeQuery => true + case _ => false +} +override def hashCode: Int = { + 37 * (17 + "sbt.protocol.TerminalGetSizeQuery".##) +} +override def toString: String = { + "TerminalGetSizeQuery()" +} +private[this] def copy(): TerminalGetSizeQuery = { + new TerminalGetSizeQuery() +} + +} +object TerminalGetSizeQuery { + + def apply(): TerminalGetSizeQuery = new TerminalGetSizeQuery() +} diff --git a/protocol/src/main/contraband-scala/sbt/protocol/TerminalGetSizeResponse.scala b/protocol/src/main/contraband-scala/sbt/protocol/TerminalGetSizeResponse.scala new file mode 100644 index 000000000..d3ebcffe9 --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/protocol/TerminalGetSizeResponse.scala @@ -0,0 +1,36 @@ +/** + * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.protocol +final class TerminalGetSizeResponse private ( + val width: Int, + val height: Int) extends sbt.protocol.EventMessage() with Serializable { + + + + override def equals(o: Any): Boolean = o match { + case x: TerminalGetSizeResponse => (this.width == x.width) && (this.height == x.height) + case _ => false + } + override def hashCode: Int = { + 37 * (37 * (37 * (17 + "sbt.protocol.TerminalGetSizeResponse".##) + width.##) + height.##) + } + override def toString: String = { + "TerminalGetSizeResponse(" + width + ", " + height + ")" + } + private[this] def copy(width: Int = width, height: Int = height): TerminalGetSizeResponse = { + new TerminalGetSizeResponse(width, height) + } + def withWidth(width: Int): TerminalGetSizeResponse = { + copy(width = width) + } + def withHeight(height: Int): TerminalGetSizeResponse = { + copy(height = height) + } +} +object TerminalGetSizeResponse { + + def apply(width: Int, height: Int): TerminalGetSizeResponse = new TerminalGetSizeResponse(width, height) +} diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala index 1ecd02122..56d3b1757 100644 --- a/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala @@ -6,6 +6,6 @@ package sbt.protocol.codec import _root_.sjsonnew.JsonFormat -trait CommandMessageFormats { self: sjsonnew.BasicJsonProtocol with sbt.protocol.codec.InitCommandFormats with sbt.protocol.codec.ExecCommandFormats with sbt.protocol.codec.SettingQueryFormats with sbt.protocol.codec.AttachFormats with sbt.protocol.codec.TerminalCapabilitiesQueryFormats with sbt.protocol.codec.TerminalSetAttributesCommandFormats with sbt.protocol.codec.TerminalAttributesQueryFormats with sbt.protocol.codec.TerminalSetSizeCommandFormats => -implicit lazy val CommandMessageFormat: JsonFormat[sbt.protocol.CommandMessage] = flatUnionFormat8[sbt.protocol.CommandMessage, sbt.protocol.InitCommand, sbt.protocol.ExecCommand, sbt.protocol.SettingQuery, sbt.protocol.Attach, sbt.protocol.TerminalCapabilitiesQuery, sbt.protocol.TerminalSetAttributesCommand, sbt.protocol.TerminalAttributesQuery, sbt.protocol.TerminalSetSizeCommand]("type") +trait CommandMessageFormats { self: sjsonnew.BasicJsonProtocol with sbt.protocol.codec.InitCommandFormats with sbt.protocol.codec.ExecCommandFormats with sbt.protocol.codec.SettingQueryFormats with sbt.protocol.codec.AttachFormats with sbt.protocol.codec.TerminalCapabilitiesQueryFormats with sbt.protocol.codec.TerminalSetAttributesCommandFormats with sbt.protocol.codec.TerminalAttributesQueryFormats with sbt.protocol.codec.TerminalGetSizeQueryFormats with sbt.protocol.codec.TerminalSetSizeCommandFormats => +implicit lazy val CommandMessageFormat: JsonFormat[sbt.protocol.CommandMessage] = flatUnionFormat9[sbt.protocol.CommandMessage, sbt.protocol.InitCommand, sbt.protocol.ExecCommand, sbt.protocol.SettingQuery, sbt.protocol.Attach, sbt.protocol.TerminalCapabilitiesQuery, sbt.protocol.TerminalSetAttributesCommand, sbt.protocol.TerminalAttributesQuery, sbt.protocol.TerminalGetSizeQuery, sbt.protocol.TerminalSetSizeCommand]("type") } diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/EventMessageFormats.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/EventMessageFormats.scala index 5475a901b..9aff6f3cf 100644 --- a/protocol/src/main/contraband-scala/sbt/protocol/codec/EventMessageFormats.scala +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/EventMessageFormats.scala @@ -6,6 +6,6 @@ package sbt.protocol.codec import _root_.sjsonnew.JsonFormat -trait EventMessageFormats { self: sjsonnew.BasicJsonProtocol with sbt.protocol.codec.ChannelAcceptedEventFormats with sbt.protocol.codec.LogEventFormats with sbt.protocol.codec.ExecStatusEventFormats with sbt.internal.util.codec.JValueFormats with sbt.protocol.codec.SettingQuerySuccessFormats with sbt.protocol.codec.SettingQueryFailureFormats with sbt.protocol.codec.TerminalPropertiesResponseFormats with sbt.protocol.codec.TerminalCapabilitiesResponseFormats with sbt.protocol.codec.TerminalSetAttributesResponseFormats with sbt.protocol.codec.TerminalAttributesResponseFormats with sbt.protocol.codec.TerminalSetSizeResponseFormats => -implicit lazy val EventMessageFormat: JsonFormat[sbt.protocol.EventMessage] = flatUnionFormat10[sbt.protocol.EventMessage, sbt.protocol.ChannelAcceptedEvent, sbt.protocol.LogEvent, sbt.protocol.ExecStatusEvent, sbt.protocol.SettingQuerySuccess, sbt.protocol.SettingQueryFailure, sbt.protocol.TerminalPropertiesResponse, sbt.protocol.TerminalCapabilitiesResponse, sbt.protocol.TerminalSetAttributesResponse, sbt.protocol.TerminalAttributesResponse, sbt.protocol.TerminalSetSizeResponse]("type") +trait EventMessageFormats { self: sjsonnew.BasicJsonProtocol with sbt.protocol.codec.ChannelAcceptedEventFormats with sbt.protocol.codec.LogEventFormats with sbt.protocol.codec.ExecStatusEventFormats with sbt.internal.util.codec.JValueFormats with sbt.protocol.codec.SettingQuerySuccessFormats with sbt.protocol.codec.SettingQueryFailureFormats with sbt.protocol.codec.TerminalPropertiesResponseFormats with sbt.protocol.codec.TerminalCapabilitiesResponseFormats with sbt.protocol.codec.TerminalSetAttributesResponseFormats with sbt.protocol.codec.TerminalAttributesResponseFormats with sbt.protocol.codec.TerminalGetSizeResponseFormats with sbt.protocol.codec.TerminalSetSizeResponseFormats => +implicit lazy val EventMessageFormat: JsonFormat[sbt.protocol.EventMessage] = flatUnionFormat11[sbt.protocol.EventMessage, sbt.protocol.ChannelAcceptedEvent, sbt.protocol.LogEvent, sbt.protocol.ExecStatusEvent, sbt.protocol.SettingQuerySuccess, sbt.protocol.SettingQueryFailure, sbt.protocol.TerminalPropertiesResponse, sbt.protocol.TerminalCapabilitiesResponse, sbt.protocol.TerminalSetAttributesResponse, sbt.protocol.TerminalAttributesResponse, sbt.protocol.TerminalGetSizeResponse, sbt.protocol.TerminalSetSizeResponse]("type") } diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala index e3a6e2b99..7c5046c6d 100644 --- a/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala @@ -12,6 +12,7 @@ trait JsonProtocol extends sjsonnew.BasicJsonProtocol with sbt.protocol.codec.TerminalCapabilitiesQueryFormats with sbt.protocol.codec.TerminalSetAttributesCommandFormats with sbt.protocol.codec.TerminalAttributesQueryFormats + with sbt.protocol.codec.TerminalGetSizeQueryFormats with sbt.protocol.codec.TerminalSetSizeCommandFormats with sbt.protocol.codec.CommandMessageFormats with sbt.protocol.codec.CompletionParamsFormats @@ -25,6 +26,7 @@ trait JsonProtocol extends sjsonnew.BasicJsonProtocol with sbt.protocol.codec.TerminalCapabilitiesResponseFormats with sbt.protocol.codec.TerminalSetAttributesResponseFormats with sbt.protocol.codec.TerminalAttributesResponseFormats + with sbt.protocol.codec.TerminalGetSizeResponseFormats with sbt.protocol.codec.TerminalSetSizeResponseFormats with sbt.protocol.codec.EventMessageFormats with sbt.protocol.codec.SettingQueryResponseFormats diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalGetSizeQueryFormats.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalGetSizeQueryFormats.scala new file mode 100644 index 000000000..9989ad96c --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalGetSizeQueryFormats.scala @@ -0,0 +1,27 @@ +/** + * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.protocol.codec +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait TerminalGetSizeQueryFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val TerminalGetSizeQueryFormat: JsonFormat[sbt.protocol.TerminalGetSizeQuery] = new JsonFormat[sbt.protocol.TerminalGetSizeQuery] { + override def read[J](__jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.protocol.TerminalGetSizeQuery = { + __jsOpt match { + case Some(__js) => + unbuilder.beginObject(__js) + + unbuilder.endObject() + sbt.protocol.TerminalGetSizeQuery() + case None => + deserializationError("Expected JsObject but found None") + } + } + override def write[J](obj: sbt.protocol.TerminalGetSizeQuery, builder: Builder[J]): Unit = { + builder.beginObject() + + builder.endObject() + } +} +} diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalGetSizeResponseFormats.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalGetSizeResponseFormats.scala new file mode 100644 index 000000000..c5c489a8d --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalGetSizeResponseFormats.scala @@ -0,0 +1,29 @@ +/** + * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.protocol.codec +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait TerminalGetSizeResponseFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val TerminalGetSizeResponseFormat: JsonFormat[sbt.protocol.TerminalGetSizeResponse] = new JsonFormat[sbt.protocol.TerminalGetSizeResponse] { + override def read[J](__jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.protocol.TerminalGetSizeResponse = { + __jsOpt match { + case Some(__js) => + unbuilder.beginObject(__js) + val width = unbuilder.readField[Int]("width") + val height = unbuilder.readField[Int]("height") + unbuilder.endObject() + sbt.protocol.TerminalGetSizeResponse(width, height) + case None => + deserializationError("Expected JsObject but found None") + } + } + override def write[J](obj: sbt.protocol.TerminalGetSizeResponse, builder: Builder[J]): Unit = { + builder.beginObject() + builder.addField("width", obj.width) + builder.addField("height", obj.height) + builder.endObject() + } +} +} diff --git a/protocol/src/main/contraband/server.contra b/protocol/src/main/contraband/server.contra index 2cde1ec82..37c2b7e08 100644 --- a/protocol/src/main/contraband/server.contra +++ b/protocol/src/main/contraband/server.contra @@ -126,6 +126,12 @@ type TerminalAttributesResponse implements EventMessage { cchars: String!, } +type TerminalGetSizeQuery implements CommandMessage {} +type TerminalGetSizeResponse implements EventMessage { + width: Int! + height: Int! +} + type TerminalSetSizeCommand implements CommandMessage { width: Int! height: Int! diff --git a/protocol/src/main/scala/sbt/protocol/Serialization.scala b/protocol/src/main/scala/sbt/protocol/Serialization.scala index a8f2e54b6..35dd19149 100644 --- a/protocol/src/main/scala/sbt/protocol/Serialization.scala +++ b/protocol/src/main/scala/sbt/protocol/Serialization.scala @@ -39,7 +39,8 @@ object Serialization { val promptChannel = "sbt/promptChannel" val setTerminalAttributes = "sbt/setTerminalAttributes" val getTerminalAttributes = "sbt/getTerminalAttributes" - val setTerminalSize = "sbt/setTerminalSize" + val terminalGetSize = "sbt/terminalGetSize" + val terminalSetSize = "sbt/terminalSetSize" val CancelAll = "__CancelAll" @deprecated("unused", since = "1.4.0") From d569abe70ae9b5e95fb0d699fb01ffd492d9345e Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Mon, 3 Aug 2020 10:37:32 -0700 Subject: [PATCH 5/7] Consolidate terminal prompt management It was a bit tricky to reason about the state of the prompt for a terminal. To help make things more clear, I reworked things so that the LineReader always sets the prompt to Pending after it reads a command. In MainLoop, we cache the prompt value and temporarily set it to Running while the command is running, which is really how it should have always been. --- .../scala/sbt/internal/util/LineReader.scala | 1 - .../main/scala/sbt/internal/ui/UITask.scala | 4 +- .../scala/sbt/internal/ui/UserThread.scala | 60 ++++++++++--------- main/src/main/scala/sbt/MainLoop.scala | 13 ++-- .../scala/sbt/internal/CommandExchange.scala | 9 +-- 5 files changed, 44 insertions(+), 43 deletions(-) diff --git a/internal/util-complete/src/main/scala/sbt/internal/util/LineReader.scala b/internal/util-complete/src/main/scala/sbt/internal/util/LineReader.scala index 52a36ca8d..4b0235572 100644 --- a/internal/util-complete/src/main/scala/sbt/internal/util/LineReader.scala +++ b/internal/util-complete/src/main/scala/sbt/internal/util/LineReader.scala @@ -73,7 +73,6 @@ object LineReader { historyPath: Option[File], parser: Parser[_], terminal: Terminal, - prompt: Prompt = Prompt.Running, ): LineReader = { val term = JLine3(terminal) // We may want to consider insourcing LineReader.java from jline. We don't otherwise diff --git a/main-command/src/main/scala/sbt/internal/ui/UITask.scala b/main-command/src/main/scala/sbt/internal/ui/UITask.scala index dce4b1012..8fd4d4177 100644 --- a/main-command/src/main/scala/sbt/internal/ui/UITask.scala +++ b/main-command/src/main/scala/sbt/internal/ui/UITask.scala @@ -82,7 +82,9 @@ private[sbt] object UITask { } } } - impl() + val res = impl() + terminal.setPrompt(Prompt.Pending) + res } catch { case e: InterruptedException => Right("") } override def close(): Unit = closed.set(true) } diff --git a/main-command/src/main/scala/sbt/internal/ui/UserThread.scala b/main-command/src/main/scala/sbt/internal/ui/UserThread.scala index 7f1d007f5..fd6d45778 100644 --- a/main-command/src/main/scala/sbt/internal/ui/UserThread.scala +++ b/main-command/src/main/scala/sbt/internal/ui/UserThread.scala @@ -31,21 +31,21 @@ private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable uiThread.synchronized { val task = channel.makeUIThread(state) def submit(): Thread = { - def close(): Unit = { - uiThread.get match { - case (_, t) if t == thread => uiThread.set(null) - case _ => - } + val thread: Thread = new Thread(s"sbt-$name-ui-thread") { + setDaemon(true) + override def run(): Unit = + try task.run() + finally uiThread.get match { + case (_, t) if t == this => uiThread.set(null) + case _ => + } } - lazy val thread = new Thread(() => { - try task.run() - finally close() - }, s"sbt-$name-ui-thread") - thread.setDaemon(true) - thread.start() uiThread.getAndSet((task, thread)) match { - case null => - case (_, t) => t.interrupt() + case null => thread.start() + case (task, t) if t.getClass != task.getClass => + stopThreadImpl() + thread.start() + case t => uiThread.set(t) } thread } @@ -53,14 +53,14 @@ private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable case null => uiThread.set((task, submit())) case (t, _) if t.getClass == task.getClass => case (t, thread) => - thread.interrupt() + stopThreadImpl() uiThread.set((task, submit())) } } Option(lastProgressEvent.get).foreach(onProgressEvent) } - private[sbt] def stopThread(): Unit = uiThread.synchronized { + private[sbt] def stopThreadImpl(): Unit = uiThread.synchronized { uiThread.getAndSet(null) match { case null => case (t, thread) => @@ -75,21 +75,25 @@ private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable () } } + private[sbt] def stopThread(): Unit = uiThread.synchronized(stopThreadImpl()) - private[sbt] def onConsolePromptEvent(consolePromptEvent: ConsolePromptEvent): Unit = { - channel.terminal.withPrintStream { ps => - ps.print(ConsoleAppender.ClearScreenAfterCursor) - ps.flush() + private[sbt] def onConsolePromptEvent(consolePromptEvent: ConsolePromptEvent): Unit = + // synchronize to ensure that the state isn't modified during the call to reset + // at the bottom + synchronized { + channel.terminal.withPrintStream { ps => + ps.print(ConsoleAppender.ClearScreenAfterCursor) + ps.flush() + } + val state = consolePromptEvent.state + terminal.prompt match { + case Prompt.Running | Prompt.Pending => + terminal.setPrompt(Prompt.AskUser(() => UITask.shellPrompt(terminal, state))) + case _ => + } + onProgressEvent(ProgressEvent("Info", Vector(), None, None, None)) + reset(state) } - val state = consolePromptEvent.state - terminal.prompt match { - case Prompt.Running | Prompt.Pending => - terminal.setPrompt(Prompt.AskUser(() => UITask.shellPrompt(terminal, state))) - case _ => - } - onProgressEvent(ProgressEvent("Info", Vector(), None, None, None)) - reset(state) - } private[sbt] def onConsoleUnpromptEvent( consoleUnpromptEvent: ConsoleUnpromptEvent diff --git a/main/src/main/scala/sbt/MainLoop.scala b/main/src/main/scala/sbt/MainLoop.scala index 0d2f57c2e..de8b5dfb7 100644 --- a/main/src/main/scala/sbt/MainLoop.scala +++ b/main/src/main/scala/sbt/MainLoop.scala @@ -15,8 +15,8 @@ import sbt.internal.ShutdownHooks import sbt.internal.langserver.ErrorCodes import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.nio.CheckBuildSources.CheckBuildSourcesKey -import sbt.internal.util.{ ErrorHandling, GlobalLogBacking, Terminal } -import sbt.internal.{ ConsoleUnpromptEvent, ShutdownHooks } +import sbt.internal.util.{ ErrorHandling, GlobalLogBacking, Prompt, Terminal } +import sbt.internal.ShutdownHooks import sbt.io.{ IO, Using } import sbt.protocol._ import sbt.util.{ Logger, LoggerContext } @@ -206,13 +206,16 @@ object MainLoop { state.put(sbt.Keys.currentTaskProgress, new Keys.TaskProgress(progress)) } else state } - StandardMain.exchange.setState(progressState) - StandardMain.exchange.setExec(Some(exec)) - StandardMain.exchange.unprompt(ConsoleUnpromptEvent(exec.source)) + exchange.setState(progressState) + exchange.setExec(Some(exec)) val restoreTerminal = channelName.flatMap(exchange.channelForName) match { case Some(c) => val prevTerminal = Terminal.set(c.terminal) + val prevPrompt = c.terminal.prompt + // temporarily set the prompt to running during task evaluation + c.terminal.setPrompt(Prompt.Running) () => { + c.terminal.setPrompt(prevPrompt) Terminal.set(prevTerminal) c.terminal.flush() } diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 3a5eaea79..6cb4f4ece 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -133,17 +133,10 @@ private[sbt] final class CommandExchange { } } // Do not manually run GC until the user has been idling for at least the min gc interval. - val exec = impl(interval match { + impl(interval match { case d: FiniteDuration => Some(d.fromNow) case _ => None }, idleDeadline) - exec.source.foreach { s => - channelForName(s.channelName).foreach { - case c if c.terminal.prompt != Prompt.Batch => c.terminal.setPrompt(Prompt.Running) - case _ => - } - } - exec } private def addConsoleChannel(): Unit = From 102e3d1969a521c774b0dfec63efeac7e8318be9 Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Sat, 25 Jul 2020 11:53:27 -0700 Subject: [PATCH 6/7] Improve supershell performance It turns out that task progress actually introduces a fair bit of overhead. The biggest issue is that the task progress callbacks block the Execute main thread. This means that time in those callbacks delays task evaluation, slowing down sbt. This was not negligible, I was seeing a lot of the total time of a no-op compile in https://github.com/jtjeferreira/sbt-multi-module-sample was spent in TaskProgress callbacks. Prior to these changes, I ran 30 no-op compiles in that project and the average time was about 570ms. This number got worse and worse because there were memory leaks in the TaskProgress object. After these changes, it dropped to 250ms and after jit-ing, it would drop to about 200ms. I also successfully ran 5000 consecutive no-op compiles without leaking any memory. A lot of the overhead of task progress was in adding tasks to the timings map in AbstractTaskProgress. Tasks were never removed and ConcurrentHashMap insertion time is proportional to the size of the map (not sure if it's linear, quadratic or other) which was why sbt actually got slower and slower the longer it ran. Much of the time was spent adding tasks to the progress timings. To fix this, I did something similar to what I did to manage logger state in https://github.com/jtjeferreira/sbt-multi-module-sample. In MainLoop, we create a new TaskProgress instance before command evaluation and clean it up after. Earlier I made TaskProgress an object to try to ensure there was only one progress thread at a time, and that introduced the memory leak. In addition to removing the leak, I was able to improve performance by removing tasks from the timings map when they completed. Unlike TaskTimings and TaskTraceEvent, we don't care about tasks that have completed for TaskProgress so it is safe to remove them. In addition to the memory leaks, I also reworked how the background threads work. Instead of having one thread that sleeps and prints progress reports, we now use two single threaded executors. One is a scheduled executor that is used to schedule progress reports and the other is the actual thread on which the report is generated. When progress starts, we schedule a recurring report that is generated every sleep interval until task evaluation completes. Whenever we add a new task, if we have haven't previously generated a progress report, we schedule a report in threshold milliseconds. If the task completes before the threshold period has elapsed, we just cancel the schedule report. By doing things this way, we reduce the total number of reports that are generated. Because reports need to effectively lock System.out, the less we generate them, the better. I also modified the internal data structures of AbstractTaskProgress so that there is a single task map of timings instead of one map for timings and one for active tasks. --- build.sbt | 1 + .../sbt/internal/util/ProgressState.scala | 14 +- .../scala/sbt/internal/util/Terminal.scala | 2 +- main/src/main/scala/sbt/EvaluateTask.scala | 2 +- main/src/main/scala/sbt/Keys.scala | 1 + main/src/main/scala/sbt/Main.scala | 10 +- main/src/main/scala/sbt/MainLoop.scala | 14 +- .../sbt/internal/AbstractTaskProgress.scala | 81 +++++-- .../src/main/scala/sbt/internal/SysProp.scala | 1 + .../scala/sbt/internal/TaskProgress.scala | 205 +++++++++--------- .../main/scala/sbt/internal/TaskTimings.scala | 8 +- .../scala/sbt/internal/TaskTraceEvent.scala | 9 +- 12 files changed, 197 insertions(+), 151 deletions(-) diff --git a/build.sbt b/build.sbt index 5b0fcd298..af9868dc6 100644 --- a/build.sbt +++ b/build.sbt @@ -1019,6 +1019,7 @@ lazy val mainProj = (project in file("main")) // internal logging apis, exclude[IncompatibleSignatureProblem]("sbt.internal.LogManager*"), exclude[MissingTypesProblem]("sbt.internal.RelayAppender"), + exclude[MissingClassProblem]("sbt.internal.TaskProgress$ProgressThread") ) ) .configure( diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/ProgressState.scala b/internal/util-logging/src/main/scala/sbt/internal/util/ProgressState.scala index f3a70a8a1..3b3b4d125 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/ProgressState.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/ProgressState.scala @@ -78,7 +78,7 @@ private[sbt] final class ProgressState( } private[util] def getPrompt(terminal: Terminal): Array[Byte] = { - if (terminal.prompt != Prompt.Running && terminal.prompt != Prompt.Batch) { + if (terminal.prompt.isInstanceOf[Prompt.AskUser]) { val prefix = if (terminal.isAnsiSupported) s"$DeleteLine$CursorLeft1000" else "" prefix.getBytes ++ terminal.prompt.render().getBytes("UTF-8") } else Array.empty @@ -108,8 +108,8 @@ private[sbt] final class ProgressState( val lines = printProgress(terminal, lastLine) toWrite ++= (ClearScreenAfterCursor + lines).getBytes("UTF-8") } + toWrite ++= getPrompt(terminal) } - toWrite ++= getPrompt(terminal) printStream.write(toWrite.toArray) printStream.flush() } else printStream.write(bytes) @@ -136,6 +136,9 @@ private[sbt] final class ProgressState( } private[sbt] object ProgressState { + private val MIN_COMMAND_WIDTH = 10 + private val SERVER_IS_RUNNING = "sbt server is running " + private val SERVER_IS_RUNNING_LENGTH = SERVER_IS_RUNNING.length + 2 /** * Receives a new task report and replaces the old one. In the event that the new @@ -165,8 +168,13 @@ private[sbt] object ProgressState { } } else { pe.command.toSeq.flatMap { cmd => + val width = terminal.getWidth + val sanitized = if ((cmd.length + SERVER_IS_RUNNING_LENGTH) < width) { + if (SERVER_IS_RUNNING_LENGTH + cmd.length < width) cmd + else cmd.take(MIN_COMMAND_WIDTH) + "..." + } else cmd val tail = if (isWatch) Nil else "enter 'cancel' to stop evaluation" :: Nil - s"sbt server is running '$cmd'" :: tail + s"$SERVER_IS_RUNNING '$sanitized'" :: tail } } diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala index a440fc77b..1137829d3 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala @@ -467,7 +467,7 @@ object Terminal { try { System.setOut(proxyPrintStream) System.setErr(proxyErrorStream) - scala.Console.withErr(proxyErrorStream)(scala.Console.withOut(proxyOutputStream)(f)) + scala.Console.withErr(proxyErrorStream)(scala.Console.withOut(proxyPrintStream)(f)) } finally { System.setOut(originalOut) System.setErr(originalErr) diff --git a/main/src/main/scala/sbt/EvaluateTask.scala b/main/src/main/scala/sbt/EvaluateTask.scala index e4d94e431..7292ff694 100644 --- a/main/src/main/scala/sbt/EvaluateTask.scala +++ b/main/src/main/scala/sbt/EvaluateTask.scala @@ -255,7 +255,7 @@ object EvaluateTask { extracted, structure ) - val reporters = maker.map(_.progress) ++ Some(TaskProgress) ++ + val reporters = maker.map(_.progress) ++ state.get(Keys.taskProgress) ++ (if (SysProp.taskTimings) new TaskTimings(reportOnShutdown = false, state.globalLogging.full) :: Nil else Nil) diff --git a/main/src/main/scala/sbt/Keys.scala b/main/src/main/scala/sbt/Keys.scala index 5da958c30..bdc25000c 100644 --- a/main/src/main/scala/sbt/Keys.scala +++ b/main/src/main/scala/sbt/Keys.scala @@ -555,6 +555,7 @@ object Keys { def apply(progress: ExecuteProgress[Task]): TaskProgress = new TaskProgress(progress) } private[sbt] val currentTaskProgress = AttributeKey[TaskProgress]("current-task-progress") + private[sbt] val taskProgress = AttributeKey[sbt.internal.TaskProgress]("active-task-progress") val useSuperShell = settingKey[Boolean]("Enables (true) or disables the super shell.") val turbo = settingKey[Boolean]("Enables (true) or disables optional performance features.") // This key can be used to add custom ExecuteProgress instances diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index c746856fe..8361d1492 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -15,7 +15,7 @@ import java.util.Properties import java.util.concurrent.ForkJoinPool import java.util.concurrent.atomic.AtomicBoolean -import sbt.BasicCommandStrings.{ Shell, Shutdown, TemplateCommand, networkExecPrefix } +import sbt.BasicCommandStrings.{ Shell, Shutdown, TemplateCommand } import sbt.Project.LoadAction import sbt.compiler.EvalImports import sbt.internal.Aggregation.AnyKeys @@ -999,13 +999,7 @@ object BuiltinCommands { } private def getExec(state: State, interval: Duration): Exec = { - val exec: Exec = - StandardMain.exchange.blockUntilNextExec(interval, Some(state), state.globalLogging.full) - if (exec.source.fold(true)(_.channelName != ConsoleChannel.defaultName) && - !exec.commandLine.startsWith(networkExecPrefix)) { - Terminal.consoleLog(s"received remote command: ${exec.commandLine}") - } - exec + StandardMain.exchange.blockUntilNextExec(interval, Some(state), state.globalLogging.full) } def shell: Command = Command.command(Shell, Help.more(Shell, ShellDetailed)) { s0 => diff --git a/main/src/main/scala/sbt/MainLoop.scala b/main/src/main/scala/sbt/MainLoop.scala index de8b5dfb7..801465b8d 100644 --- a/main/src/main/scala/sbt/MainLoop.scala +++ b/main/src/main/scala/sbt/MainLoop.scala @@ -16,7 +16,7 @@ import sbt.internal.langserver.ErrorCodes import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.nio.CheckBuildSources.CheckBuildSourcesKey import sbt.internal.util.{ ErrorHandling, GlobalLogBacking, Prompt, Terminal } -import sbt.internal.ShutdownHooks +import sbt.internal.{ ShutdownHooks, TaskProgress } import sbt.io.{ IO, Using } import sbt.protocol._ import sbt.util.{ Logger, LoggerContext } @@ -150,9 +150,13 @@ object MainLoop { def next(state: State): State = { val context = LoggerContext(useLog4J = state.get(Keys.useLog4J.key).getOrElse(false)) + val taskProgress = new TaskProgress try { ErrorHandling.wideConvert { - state.put(Keys.loggerContext, context).process(processCommand) + state + .put(Keys.loggerContext, context) + .put(Keys.taskProgress, taskProgress) + .process(processCommand) } match { case Right(s) => s.remove(Keys.loggerContext) case Left(t: xsbti.FullReload) => throw t @@ -186,7 +190,10 @@ object MainLoop { state.log.error(msg) state.log.error("\n") state.handleError(oom) - } finally context.close() + } finally { + context.close() + taskProgress.close() + } } /** This is the main function State transfer function of the sbt command processing. */ @@ -217,6 +224,7 @@ object MainLoop { () => { c.terminal.setPrompt(prevPrompt) Terminal.set(prevTerminal) + c.terminal.setPrompt(prevPrompt) c.terminal.flush() } case _ => () => () diff --git a/main/src/main/scala/sbt/internal/AbstractTaskProgress.scala b/main/src/main/scala/sbt/internal/AbstractTaskProgress.scala index 61e93050b..dfac9b2e5 100644 --- a/main/src/main/scala/sbt/internal/AbstractTaskProgress.scala +++ b/main/src/main/scala/sbt/internal/AbstractTaskProgress.scala @@ -9,8 +9,11 @@ package sbt package internal import java.util.concurrent.ConcurrentHashMap -import scala.collection.concurrent.TrieMap +import java.util.concurrent.atomic.AtomicLong import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.immutable.VectorBuilder +import scala.concurrent.duration._ private[sbt] abstract class AbstractTaskExecuteProgress extends ExecuteProgress[Task] { import AbstractTaskExecuteProgress.Timer @@ -18,10 +21,51 @@ private[sbt] abstract class AbstractTaskExecuteProgress extends ExecuteProgress[ private[this] val showScopedKey = Def.showShortKey(None) private[this] val anonOwners = new ConcurrentHashMap[Task[_], Task[_]] private[this] val calledBy = new ConcurrentHashMap[Task[_], Task[_]] - private[this] val activeTasksMap = new ConcurrentHashMap[Task[_], Unit] - protected val timings = new ConcurrentHashMap[Task[_], Timer] + private[this] val timings = new ConcurrentHashMap[Task[_], Timer] + private[sbt] def timingsByName: mutable.Map[String, AtomicLong] = { + val result = new ConcurrentHashMap[String, AtomicLong] + timings.forEach { (task, timing) => + val duration = timing.durationNanos + result.putIfAbsent(taskName(task), new AtomicLong(duration)) match { + case null => + case t => t.getAndAdd(duration); () + } + } + result.asScala + } + private[sbt] def anyTimings = !timings.isEmpty + def currentTimings: Iterator[(Task[_], Timer)] = timings.asScala.iterator - def activeTasks: Set[Task[_]] = activeTasksMap.keySet.asScala.toSet + private[internal] def exceededThreshold(task: Task[_], threshold: FiniteDuration): Boolean = + timings.get(task) match { + case null => false + case t => t.durationMicros > threshold.toMicros + } + private[internal] def timings( + tasks: java.util.Set[Task[_]], + thresholdMicros: Long + ): Vector[(Task[_], Long)] = { + val result = new VectorBuilder[(Task[_], Long)] + val now = System.nanoTime + tasks.forEach { t => + timings.get(t) match { + case null => + case timing => + if (timing.isActive) { + val elapsed = (now - timing.startNanos) / 1000 + if (elapsed > thresholdMicros) result += t -> elapsed + } + } + } + result.result() + } + def activeTasks(now: Long) = { + val result = new VectorBuilder[(Task[_], FiniteDuration)] + timings.forEach { (task, timing) => + if (timing.isActive) result += task -> (now - timing.startNanos).nanos + } + result.result + } override def afterRegistered( task: Task[_], @@ -38,15 +82,17 @@ private[sbt] abstract class AbstractTaskExecuteProgress extends ExecuteProgress[ override def beforeWork(task: Task[_]): Unit = { timings.put(task, new Timer) - activeTasksMap.put(task, ()) + () } + protected def clearTimings: Boolean = false override def afterWork[A](task: Task[A], result: Either[Task[A], Result[A]]): Unit = { - timings.get(task) match { - case null => - case t => t.stop() - } - activeTasksMap.remove(task) + if (clearTimings) timings.remove(task) + else + timings.get(task) match { + case null => + case t => t.stop() + } // we need this to infer anonymous task names result.left.foreach { t => @@ -54,14 +100,14 @@ private[sbt] abstract class AbstractTaskExecuteProgress extends ExecuteProgress[ } } - protected def reset(): Unit = { - activeTasksMap.clear() - timings.clear() + private[this] val taskNameCache = new ConcurrentHashMap[Task[_], String] + protected def taskName(t: Task[_]): String = taskNameCache.get(t) match { + case null => + val name = taskName0(t) + taskNameCache.putIfAbsent(t, name) + name + case name => name } - - private[this] val taskNameCache = TrieMap.empty[Task[_], String] - protected def taskName(t: Task[_]): String = - taskNameCache.getOrElseUpdate(t, taskName0(t)) private[this] def taskName0(t: Task[_]): String = { def definedName(node: Task[_]): Option[String] = node.info.name orElse TaskName.transformNode(node).map(showScopedKey.show) @@ -80,6 +126,7 @@ object AbstractTaskExecuteProgress { def stop(): Unit = { endNanos = System.nanoTime() } + def isActive = endNanos == 0L def durationNanos: Long = endNanos - startNanos def startMicros: Long = (startNanos.toDouble / 1000).toLong def durationMicros: Long = (durationNanos.toDouble / 1000).toLong diff --git a/main/src/main/scala/sbt/internal/SysProp.scala b/main/src/main/scala/sbt/internal/SysProp.scala index c02038e87..737485965 100644 --- a/main/src/main/scala/sbt/internal/SysProp.scala +++ b/main/src/main/scala/sbt/internal/SysProp.scala @@ -11,6 +11,7 @@ package internal import java.util.Locale import scala.util.control.NonFatal +import scala.concurrent.duration._ import sbt.internal.util.ConsoleAppender import sbt.internal.util.complete.SizeParser diff --git a/main/src/main/scala/sbt/internal/TaskProgress.scala b/main/src/main/scala/sbt/internal/TaskProgress.scala index 494520fcc..f292db125 100644 --- a/main/src/main/scala/sbt/internal/TaskProgress.scala +++ b/main/src/main/scala/sbt/internal/TaskProgress.scala @@ -9,99 +9,102 @@ package sbt package internal import java.util.concurrent.atomic.{ AtomicBoolean, AtomicInteger, AtomicReference } -import java.util.concurrent.{ LinkedBlockingQueue, TimeUnit } +import java.util.concurrent.TimeUnit import sbt.internal.util._ -import scala.annotation.tailrec +import scala.collection.JavaConverters._ import scala.concurrent.duration._ - -object TaskProgress extends TaskProgress +import java.util.concurrent.{ ConcurrentHashMap, Executors, TimeoutException } /** * implements task progress display on the shell. */ -private[sbt] class TaskProgress private () +private[sbt] class TaskProgress extends AbstractTaskExecuteProgress - with ExecuteProgress[Task] { + with ExecuteProgress[Task] + with AutoCloseable { private[this] val lastTaskCount = new AtomicInteger(0) - private[this] val currentProgressThread = new AtomicReference[Option[ProgressThread]](None) private[this] val sleepDuration = SysProp.supershellSleep.millis private[this] val threshold = 10.millis - private[this] val tasks = new LinkedBlockingQueue[Task[_]] - private[this] final class ProgressThread - extends Thread("task-progress-report-thread") - with AutoCloseable { - private[this] val isClosed = new AtomicBoolean(false) - private[this] val firstTime = new AtomicBoolean(true) - private[this] val hasReported = new AtomicBoolean(false) - private[this] def doReport(): Unit = { hasReported.set(true); report() } - setDaemon(true) - start() - private def resetThread(): Unit = - currentProgressThread.synchronized { - currentProgressThread.getAndSet(None) match { - case Some(t) if t != this => currentProgressThread.set(Some(t)) - case _ => - } - } - @tailrec override def run(): Unit = { - if (!isClosed.get() && (!hasReported.get || active.nonEmpty)) { - try { - if (activeExceedingThreshold.nonEmpty) doReport() - val duration = - if (firstTime.compareAndSet(true, activeExceedingThreshold.isEmpty)) threshold - else sleepDuration - val limit = duration.fromNow - while (Deadline.now < limit && !isClosed.get && active.nonEmpty) { - var task = tasks.poll((limit - Deadline.now).toMillis, TimeUnit.MILLISECONDS) - while (task != null) { - if (containsSkipTasks(Vector(task)) || lastTaskCount.get == 0) doReport() - task = tasks.poll - tasks.clear() - } - } - } catch { - case _: InterruptedException => - isClosed.set(true) - // One last report after close in case the last one hadn't gone through yet. - doReport() - - } - run() - } else { - resetThread() + private[this] val reportLoop = new AtomicReference[AutoCloseable] + private[this] val active = new ConcurrentHashMap[Task[_], AutoCloseable] + private[this] val nextReport = new AtomicReference(Deadline.now) + private[this] val scheduler = + Executors.newSingleThreadScheduledExecutor(r => new Thread(r, "sbt-progress-report-scheduler")) + private[this] val pending = new java.util.Vector[java.util.concurrent.Future[_]] + private def schedule[R](duration: FiniteDuration, recurring: Boolean)(f: => R): AutoCloseable = { + val cancelled = new AtomicBoolean(false) + val runnable: Runnable = () => { + if (!cancelled.get) { + try Util.ignoreResult(f) + catch { case _: InterruptedException => } } } - - def addTask(task: Task[_]): Unit = tasks.put(task) - - override def close(): Unit = { - isClosed.set(true) - interrupt() - report() - appendProgress(ProgressEvent("Info", Vector(), None, None, None)) - resetThread() + val delay = duration.toMillis + val future = + if (recurring) scheduler.schedule(runnable, delay, TimeUnit.MILLISECONDS) + else scheduler.scheduleAtFixedRate(runnable, delay, delay, TimeUnit.MILLISECONDS) + pending.add(future) + () => Util.ignoreResult(future.cancel(true)) + } + private[this] val executor = + Executors.newSingleThreadExecutor(r => new Thread(r, "sbt-task-progress-report-thread")) + override def close(): Unit = { + Option(reportLoop.get).foreach(_.close()) + pending.forEach(f => Util.ignoreResult(f.cancel(true))) + pending.clear() + scheduler.shutdownNow() + executor.shutdownNow() + if (!executor.awaitTermination(1, TimeUnit.SECONDS) || + !scheduler.awaitTermination(1, TimeUnit.SECONDS)) { + throw new TimeoutException } } + override protected def clearTimings: Boolean = true override def initial(): Unit = () + private[this] def doReport(): Unit = { + val runnable: Runnable = () => { + if (nextReport.get.isOverdue) { + report() + } + } + Util.ignoreResult(pending.add(executor.submit(runnable))) + } override def beforeWork(task: Task[_]): Unit = { - maybeStartThread() super.beforeWork(task) - tasks.put(task) + reportLoop.get match { + case null => + val loop = schedule(sleepDuration, recurring = true)(doReport()) + reportLoop.getAndSet(loop) match { + case null => + case l => + reportLoop.set(l) + loop.close() + } + case s => + } } - override def afterReady(task: Task[_]): Unit = maybeStartThread() - override def afterCompleted[A](task: Task[A], result: Result[A]): Unit = maybeStartThread() + override def afterReady(task: Task[_]): Unit = + Util.ignoreResult(active.put(task, schedule(threshold, recurring = false)(doReport()))) + override def stop(): Unit = {} - override def stop(): Unit = currentProgressThread.synchronized { - currentProgressThread.getAndSet(None).foreach(_.close()) - } + override def afterCompleted[A](task: Task[A], result: Result[A]): Unit = + active.remove(task) match { + case null => + case a => + a.close() + if (exceededThreshold(task, threshold)) report() + } override def afterAllCompleted(results: RMap[Task, Result]): Unit = { - reset() + reportLoop.getAndSet(null) match { + case null => + case l => l.close() + } // send an empty progress report to clear out the previous report appendProgress(ProgressEvent("Info", Vector(), Some(lastTaskCount.get), None, None)) } @@ -117,51 +120,39 @@ private[sbt] class TaskProgress private () "consoleQuick", "state" ) - private[this] def maybeStartThread(): Unit = { - currentProgressThread.get() match { - case None => - currentProgressThread.synchronized { - currentProgressThread.get() match { - case None => currentProgressThread.set(Some(new ProgressThread)) - case _ => - } - } - case _ => - } - } private[this] def appendProgress(event: ProgressEvent): Unit = StandardMain.exchange.updateProgress(event) - private[this] def active: Vector[Task[_]] = activeTasks.toVector.filterNot(Def.isDummy) - private[this] def activeExceedingThreshold: Vector[(Task[_], Long)] = active.flatMap { task => - timings.get(task) match { - case null => None - case t => - val elapsed = t.currentElapsedMicros - if (elapsed.micros > threshold) Some[(Task[_], Long)](task -> elapsed) else None + private[this] def report(): Unit = { + val currentTasks = timings(active.keySet, threshold.toMicros) + val ltc = lastTaskCount.get + if (currentTasks.nonEmpty || ltc != 0) { + val currentTasksCount = currentTasks.size + def event(tasks: Vector[(Task[_], Long)]): ProgressEvent = { + if (tasks.nonEmpty) nextReport.set(Deadline.now + sleepDuration) + val toWrite = tasks.sortBy(_._2) + val distinct = new java.util.LinkedHashMap[String, ProgressItem] + toWrite.foreach { + case (task, elapsed) => + val name = taskName(task) + distinct.put(name, ProgressItem(name, elapsed)) + } + ProgressEvent( + "Info", + distinct.values.asScala.toVector, + Some(ltc), + None, + None, + None, + Some(containsSkipTasks(active.keySet)) + ) + } + lastTaskCount.set(currentTasksCount) + appendProgress(event(currentTasks)) } } - private[this] def report(): Unit = { - val currentTasks = activeExceedingThreshold - val ltc = lastTaskCount.get - val currentTasksCount = currentTasks.size - def event(tasks: Vector[(Task[_], Long)]): ProgressEvent = ProgressEvent( - "Info", - tasks - .map { case (task, elapsed) => ProgressItem(taskName(task), elapsed) } - .sortBy(_.elapsedMicros), - Some(ltc), - None, - None, - None, - Some(containsSkipTasks(active)) - ) - if (active.nonEmpty) maybeStartThread() - lastTaskCount.set(currentTasksCount) - appendProgress(event(currentTasks)) - } - private[this] def containsSkipTasks(tasks: Vector[Task[_]]): Boolean = { - tasks.map(taskName).exists { n => + private[this] def containsSkipTasks(tasks: java.util.Set[Task[_]]): Boolean = { + tasks.iterator.asScala.map(taskName).exists { n => val shortName = n.lastIndexOf('/') match { case -1 => n case i => diff --git a/main/src/main/scala/sbt/internal/TaskTimings.scala b/main/src/main/scala/sbt/internal/TaskTimings.scala index 1723dba83..6f421e829 100644 --- a/main/src/main/scala/sbt/internal/TaskTimings.scala +++ b/main/src/main/scala/sbt/internal/TaskTimings.scala @@ -31,7 +31,6 @@ private[sbt] final class TaskTimings(reportOnShutdown: Boolean, logger: Logger) override def log(level: Level.Value, message: => String): Unit = ConsoleOut.systemOut.println(message) }) - import AbstractTaskExecuteProgress.Timer private[this] var start = 0L private[this] val threshold = SysProp.taskTimingsThreshold private[this] val omitPaths = SysProp.taskTimingsOmitPaths @@ -61,15 +60,12 @@ private[sbt] final class TaskTimings(reportOnShutdown: Boolean, logger: Logger) private[this] def report() = { val total = divide(System.nanoTime - start) logger.info(s"Total time: $total $unit") - import collection.JavaConverters._ - def sumTimes(in: Seq[(Task[_], Timer)]) = in.map(_._2.durationNanos).sum - val timingsByName = timings.asScala.toSeq.groupBy { case (t, _) => taskName(t) } mapValues (sumTimes) val times = timingsByName.toSeq - .sortBy(_._2) + .sortBy(_._2.get) .reverse .map { case (name, time) => - (if (omitPaths) reFilePath.replaceFirstIn(name, "") else name, divide(time)) + (if (omitPaths) reFilePath.replaceFirstIn(name, "") else name, divide(time.get)) } .filter { _._2 > threshold } if (times.size > 0) { diff --git a/main/src/main/scala/sbt/internal/TaskTraceEvent.scala b/main/src/main/scala/sbt/internal/TaskTraceEvent.scala index b623c70cc..7034c0af6 100644 --- a/main/src/main/scala/sbt/internal/TaskTraceEvent.scala +++ b/main/src/main/scala/sbt/internal/TaskTraceEvent.scala @@ -13,7 +13,6 @@ import java.nio.file.Files import sbt.internal.util.{ RMap, ConsoleOut } import sbt.io.IO import sbt.io.syntax._ -import scala.collection.JavaConverters._ import sjsonnew.shaded.scalajson.ast.unsafe.JString import sjsonnew.support.scalajson.unsafe.CompactPrinter @@ -39,7 +38,7 @@ private[sbt] final class TaskTraceEvent ShutdownHooks.add(() => report()) private[this] def report() = { - if (timings.asScala.nonEmpty) { + if (anyTimings) { writeTraceEvent() } } @@ -63,10 +62,10 @@ private[sbt] final class TaskTraceEvent CompactPrinter.print(new JString(name), sb) s"""{"name": ${sb.toString}, "cat": "$cat", "ph": "X", "ts": ${(t.startMicros)}, "dur": ${(t.durationMicros)}, "pid": 0, "tid": ${t.threadId}}""" } - val entryIterator = timings.entrySet().iterator() + val entryIterator = currentTimings while (entryIterator.hasNext) { - val entry = entryIterator.next() - trace.append(durationEvent(taskName(entry.getKey), "task", entry.getValue)) + val (key, value) = entryIterator.next() + trace.append(durationEvent(taskName(key), "task", value)) if (entryIterator.hasNext) trace.append(",") } trace.append("]}") From d58aab5d8471c4ed75002fbab79a46d1fce075e0 Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Sun, 9 Aug 2020 17:39:15 -0700 Subject: [PATCH 7/7] Add super shell options This commit adds a few options to supershell: 1. Max items -- sets the max number of tasks to display in the progress reports. It is pretty hard to read more than a few items in the progress reports so I set the default limit to 8 and made that configurable via the superShellMaxTasks parameter. If there are more than the limit, there is an additional line telling how many additional tasks are running 2. sleep -- sets how long to sleep between reports. The default is 500ms to ensure that it updates at least once per second but the previous value of 100ms is more frequent than necessary 3. threshold -- sets the minimum duration a task has to run before being printed in the progress reports. The default threshold is increased from 10ms to 100ms. This introduces a delay of threshold milliseconds before any progress lines appear and also means that if no tasks ever exceed the threshold, then no progress is ever displayed. --- .../sbt/internal/util/ProgressState.scala | 24 ++++++++++++------- .../scala/sbt/internal/util/Terminal.scala | 8 ++++++- main/src/main/scala/sbt/Defaults.scala | 8 +++++-- main/src/main/scala/sbt/Keys.scala | 4 ++++ main/src/main/scala/sbt/Main.scala | 14 ++++++++++- main/src/main/scala/sbt/MainLoop.scala | 8 ++++++- .../scala/sbt/internal/CommandExchange.scala | 3 ++- .../src/main/scala/sbt/internal/SysProp.scala | 4 +++- .../scala/sbt/internal/TaskProgress.scala | 4 +--- .../sbt/internal/server/NetworkChannel.scala | 11 +++++++-- 10 files changed, 67 insertions(+), 21 deletions(-) diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/ProgressState.scala b/internal/util-logging/src/main/scala/sbt/internal/util/ProgressState.scala index 3b3b4d125..d6b594578 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/ProgressState.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/ProgressState.scala @@ -26,14 +26,16 @@ private[sbt] final class ProgressState( val padding: AtomicInteger, val blankZone: Int, val currentLineBytes: AtomicReference[ArrayBuffer[Byte]], + val maxItems: Int, ) { - def this(blankZone: Int) = - this( - new AtomicReference(Nil), - new AtomicInteger(0), - blankZone, - new AtomicReference(new ArrayBuffer[Byte]), - ) + def this(blankZone: Int, maxItems: Int) = this( + new AtomicReference(Nil), + new AtomicInteger(0), + blankZone, + new AtomicReference(new ArrayBuffer[Byte]), + maxItems, + ) + def this(blankZone: Int) = this(blankZone, 8) def currentLine: Option[String] = new String(currentLineBytes.get.toArray, "UTF-8").linesIterator.toSeq.lastOption .map(EscHelpers.stripColorsAndMoves) @@ -162,14 +164,18 @@ private[sbt] object ProgressState { terminal.withPrintStream { ps => val commandFromThisTerminal = pe.channelName.fold(true)(_ == terminal.name) val info = if (commandFromThisTerminal) { - pe.items.map { item => + val base = pe.items.map { item => val elapsed = item.elapsedMicros / 1000000L s" | => ${item.name} ${elapsed}s" } + val limit = state.maxItems + if (base.size > limit) + s" | ... (${base.size - limit} other tasks)" +: base.takeRight(limit) + else base } else { pe.command.toSeq.flatMap { cmd => val width = terminal.getWidth - val sanitized = if ((cmd.length + SERVER_IS_RUNNING_LENGTH) < width) { + val sanitized = if ((cmd.length + SERVER_IS_RUNNING_LENGTH) > width) { if (SERVER_IS_RUNNING_LENGTH + cmd.length < width) cmd else cmd.take(MIN_COMMAND_WIDTH) + "..." } else cmd diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala index 1137829d3..3c7d7a008 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala @@ -141,7 +141,7 @@ trait Terminal extends AutoCloseable { private[sbt] def withPrintStream[T](f: PrintStream => T): T private[sbt] def withRawOutput[R](f: => R): R private[sbt] def restore(): Unit = {} - private[sbt] val progressState = new ProgressState(1) + private[sbt] def progressState: ProgressState private[this] val promptHolder: AtomicReference[Prompt] = new AtomicReference(Prompt.Pending) private[sbt] final def prompt: Prompt = promptHolder.get private[sbt] final def setPrompt(newPrompt: Prompt): Unit = @@ -315,6 +315,7 @@ object Terminal { private[this] object ProxyTerminal extends Terminal { private def t: Terminal = activeTerminal.get + override private[sbt] def progressState: ProgressState = t.progressState override def getWidth: Int = t.getWidth override def getHeight: Int = t.getHeight override def getLineHeightAndWidth(line: String): (Int, Int) = t.getLineHeightAndWidth(line) @@ -755,6 +756,9 @@ object Terminal { private val capabilityMap = org.jline.utils.InfoCmp.Capability.values().map(c => c.toString -> c).toMap + private val consoleProgressState = new AtomicReference[ProgressState](new ProgressState(1)) + private[sbt] def setConsoleProgressState(progressState: ProgressState): Unit = + consoleProgressState.set(progressState) @deprecated("For compatibility only", "1.4.0") private[sbt] def deprecatedTeminal: jline.Terminal = console.toJLine @@ -770,6 +774,7 @@ object Terminal { } private[this] val isCI = sys.env.contains("BUILD_NUMBER") || sys.env.contains("CI") override lazy val isAnsiSupported: Boolean = term.isAnsiSupported && !isCI + override private[sbt] def progressState: ProgressState = consoleProgressState.get override def isEchoEnabled: Boolean = system.echo() override def isSuccessEnabled: Boolean = true override def getBooleanCapability(capability: String, jline3: Boolean): Boolean = @@ -912,6 +917,7 @@ object Terminal { new WriteableInputStream(nullInputStream, "null-writeable-input-stream") private[sbt] val NullTerminal = new Terminal { override def close(): Unit = {} + override private[sbt] def progressState: ProgressState = new ProgressState(1) override def getBooleanCapability(capability: String, jline3: Boolean): Boolean = false override def getHeight: Int = 0 override def getLastLine: Option[String] = None diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index 7fbf1aec9..2f9a4498b 100644 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -97,7 +97,7 @@ import sjsonnew._ import sjsonnew.support.scalajson.unsafe.Converter import scala.collection.immutable.ListMap -import scala.concurrent.duration.FiniteDuration +import scala.concurrent.duration._ import scala.util.control.NonFatal import scala.xml.NodeSeq @@ -393,11 +393,15 @@ object Defaults extends BuildCommon { else appConfiguration.value.provider.scalaProvider.launcher.topLoader.getParent }, useSuperShell := { if (insideCI.value) false else Terminal.console.isSupershellEnabled }, + superShellThreshold :== SysProp.supershellThreshold, + superShellMaxTasks :== SysProp.supershellMaxTasks, + superShellSleep :== SysProp.supershellSleep.millis, progressReports := { val rs = EvaluateTask.taskTimingProgress.toVector ++ EvaluateTask.taskTraceEvent.toVector rs map { Keys.TaskProgress(_) } }, - progressState := Some(new ProgressState(SysProp.supershellBlankZone)), + // progressState is deprecated + SettingKey[Option[ProgressState]]("progressState") := None, Previous.cache := new Previous( Def.streamsManagerKey.value, Previous.references.value.getReferences diff --git a/main/src/main/scala/sbt/Keys.scala b/main/src/main/scala/sbt/Keys.scala index bdc25000c..30cfe2425 100644 --- a/main/src/main/scala/sbt/Keys.scala +++ b/main/src/main/scala/sbt/Keys.scala @@ -557,9 +557,13 @@ object Keys { private[sbt] val currentTaskProgress = AttributeKey[TaskProgress]("current-task-progress") private[sbt] val taskProgress = AttributeKey[sbt.internal.TaskProgress]("active-task-progress") val useSuperShell = settingKey[Boolean]("Enables (true) or disables the super shell.") + val superShellMaxTasks = settingKey[Int]("The max number of tasks to display in the supershell progress report") + val superShellSleep = settingKey[FiniteDuration]("The minimum duration to sleep between progress reports") + val superShellThreshold = settingKey[FiniteDuration]("The minimum amount of time a task must be running to appear in the supershell progress report") val turbo = settingKey[Boolean]("Enables (true) or disables optional performance features.") // This key can be used to add custom ExecuteProgress instances val progressReports = settingKey[Seq[TaskProgress]]("A function that returns a list of progress reporters.").withRank(DTask) + @deprecated("unused", "1.4.0") private[sbt] val progressState = settingKey[Option[ProgressState]]("The optional progress state if supershell is enabled.").withRank(Invisible) private[sbt] val postProgressReports = settingKey[Unit]("Internally used to modify logger.").withRank(DTask) @deprecated("No longer used", "1.3.0") diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index 8361d1492..34e6333c6 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -935,13 +935,25 @@ object BuiltinCommands { // This is a workaround for the console task in dotty which uses the classloader cache. // We need to override the top loader in that case so that it gets the forked jline. s5.extendedClassLoaderCache.setParent(Project.extract(s5).get(Keys.scalaInstanceTopLoader)) - CheckBuildSources.init(LintUnused.lintUnusedFunc(s5)) + addSuperShellParams(CheckBuildSources.init(LintUnused.lintUnusedFunc(s5))) } private val setupGlobalFileTreeRepository: State => State = { state => state.get(sbt.nio.Keys.globalFileTreeRepository).foreach(_.close()) state.put(sbt.nio.Keys.globalFileTreeRepository, FileTreeRepository.default) } + private val addSuperShellParams: State => State = (s: State) => { + val extracted = Project.extract(s) + import scala.concurrent.duration._ + val sleep = extracted.getOpt(Keys.superShellSleep).getOrElse(SysProp.supershellSleep.millis) + val threshold = + extracted.getOpt(Keys.superShellThreshold).getOrElse(SysProp.supershellThreshold) + val maxItems = extracted.getOpt(Keys.superShellMaxTasks).getOrElse(SysProp.supershellMaxTasks) + Terminal.setConsoleProgressState(new ProgressState(1, maxItems)) + s.put(Keys.superShellSleep.key, sleep) + .put(Keys.superShellThreshold.key, threshold) + .put(Keys.superShellMaxTasks.key, maxItems) + } private val addCacheStoreFactoryFactory: State => State = (s: State) => { val size = Project .extract(s) diff --git a/main/src/main/scala/sbt/MainLoop.scala b/main/src/main/scala/sbt/MainLoop.scala index 801465b8d..7bcea26ce 100644 --- a/main/src/main/scala/sbt/MainLoop.scala +++ b/main/src/main/scala/sbt/MainLoop.scala @@ -22,8 +22,10 @@ import sbt.protocol._ import sbt.util.{ Logger, LoggerContext } import scala.annotation.tailrec +import scala.concurrent.duration._ import scala.util.control.NonFatal import sbt.internal.FastTrackCommands +import sbt.internal.SysProp object MainLoop { @@ -150,7 +152,11 @@ object MainLoop { def next(state: State): State = { val context = LoggerContext(useLog4J = state.get(Keys.useLog4J.key).getOrElse(false)) - val taskProgress = new TaskProgress + val superShellSleep = + state.get(Keys.superShellSleep.key).getOrElse(SysProp.supershellSleep.millis) + val superShellThreshold = + state.get(Keys.superShellThreshold.key).getOrElse(SysProp.supershellThreshold) + val taskProgress = new TaskProgress(superShellSleep, superShellThreshold) try { ErrorHandling.wideConvert { state diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 6cb4f4ece..505c913c2 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -201,7 +201,8 @@ private[sbt] final class CommandExchange { instance, handlers, s.log, - mkAskUser(name) + mkAskUser(name), + Option(lastState.get), ) subscribe(channel) } diff --git a/main/src/main/scala/sbt/internal/SysProp.scala b/main/src/main/scala/sbt/internal/SysProp.scala index 737485965..2f4e95953 100644 --- a/main/src/main/scala/sbt/internal/SysProp.scala +++ b/main/src/main/scala/sbt/internal/SysProp.scala @@ -104,7 +104,9 @@ object SysProp { def dumbTerm: Boolean = sys.env.get("TERM").contains("dumb") def supershell: Boolean = booleanOpt("sbt.supershell").getOrElse(!dumbTerm && color) - def supershellSleep: Long = long("sbt.supershell.sleep", 100L) + def supershellMaxTasks: Int = int("sbt.supershell.maxitems", 8) + def supershellSleep: Long = long("sbt.supershell.sleep", 500.millis.toMillis) + def supershellThreshold: FiniteDuration = long("sbt.supershell.threshold", 100L).millis def supershellBlankZone: Int = int("sbt.supershell.blankzone", 1) def defaultUseCoursier: Boolean = { diff --git a/main/src/main/scala/sbt/internal/TaskProgress.scala b/main/src/main/scala/sbt/internal/TaskProgress.scala index f292db125..0c3cfbc72 100644 --- a/main/src/main/scala/sbt/internal/TaskProgress.scala +++ b/main/src/main/scala/sbt/internal/TaskProgress.scala @@ -20,13 +20,11 @@ import java.util.concurrent.{ ConcurrentHashMap, Executors, TimeoutException } /** * implements task progress display on the shell. */ -private[sbt] class TaskProgress +private[sbt] class TaskProgress(sleepDuration: FiniteDuration, threshold: FiniteDuration) extends AbstractTaskExecuteProgress with ExecuteProgress[Task] with AutoCloseable { private[this] val lastTaskCount = new AtomicInteger(0) - private[this] val sleepDuration = SysProp.supershellSleep.millis - private[this] val threshold = 10.millis private[this] val reportLoop = new AtomicReference[AutoCloseable] private[this] val active = new ConcurrentHashMap[Task[_], AutoCloseable] private[this] val nextReport = new AtomicReference(Deadline.now) diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index ca75b369a..1c7d93484 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -49,6 +49,7 @@ import sjsonnew.support.scalajson.unsafe.{ CompactPrinter, Converter } import BasicJsonProtocol._ import Serialization.{ attach, promptChannel } +import sbt.internal.util.ProgressState final class NetworkChannel( val name: String, @@ -58,7 +59,8 @@ final class NetworkChannel( instance: ServerInstance, handlers: Seq[ServerHandler], val log: Logger, - mkUIThreadImpl: (State, CommandChannel) => UITask + mkUIThreadImpl: (State, CommandChannel) => UITask, + state: Option[State], ) extends CommandChannel { self => def this( name: String, @@ -77,7 +79,8 @@ final class NetworkChannel( instance, handlers, log, - new UITask.AskUserTask(_, _) + new UITask.AskUserTask(_, _), + None ) private val running = new AtomicBoolean(true) @@ -787,6 +790,10 @@ final class NetworkChannel( ) } private[this] val blockedThreads = ConcurrentHashMap.newKeySet[Thread] + override private[sbt] val progressState: ProgressState = new ProgressState( + 1, + state.flatMap(_.get(Keys.superShellMaxTasks.key)).getOrElse(SysProp.supershellMaxTasks) + ) override def getWidth: Int = getProperty(_.width, 0).getOrElse(0) override def getHeight: Int = getProperty(_.height, 0).getOrElse(0) override def isAnsiSupported: Boolean = getProperty(_.isAnsiSupported, false).getOrElse(false)