From 58822cc3f5402205531bf8d909ed6806c763bacb Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Tue, 17 Dec 2019 12:00:09 -0800 Subject: [PATCH] Add virtual System.out for supershell In order to make supershell work with println, this commit introduces a virtual System.out to sbt. While sbt is running, we override the default java.lang.System.out, java.lang.System.in, scala.Console.out and scala.Console.in (unless the property `sbt.io.virtual` is set to something other than true). When using virtual io, we buffer all of the bytes that are written to System.out and Console.out until flush is called. When flushing the output, we check if there are any progress lines. If so, we interleave them with the new lines to print. The flushing happens on a background thread so it should hopefully not impede task progress. This commit also adds logic for handling progress when the cursor is not all the way to the left. We now track all of the bytes that have been written since the last new line. Supershell will then calculate the cursor position from those bytes* and move the cursor back to the correct position. The motivation for this was to make the run command work with supershell even when multiple main classes were specified. * This might not be completely reliable if the string contains ansi cursor movement characters. --- .../sbt/internal/util/ConsoleAppender.scala | 193 ++++++++++-------- .../scala/sbt/internal/util/Terminal.scala | 121 ++++++++++- main/src/main/scala/sbt/EvaluateTask.scala | 5 +- main/src/main/scala/sbt/Main.scala | 27 ++- .../main/scala/sbt/internal/LogManager.scala | 15 +- 5 files changed, 249 insertions(+), 112 deletions(-) diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/ConsoleAppender.scala b/internal/util-logging/src/main/scala/sbt/internal/util/ConsoleAppender.scala index d7b8fadc1..45b156085 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/ConsoleAppender.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/ConsoleAppender.scala @@ -18,6 +18,8 @@ import org.apache.logging.log4j.{ Level => XLevel } import sbt.internal.util.ConsoleAppender._ import sbt.util._ +import scala.collection.mutable.ArrayBuffer + object ConsoleLogger { // These are provided so other modules do not break immediately. @deprecated("Use EscHelpers.ESC instead", "0.13.x") @@ -103,12 +105,15 @@ class ConsoleLogger private[ConsoleLogger] ( } object ConsoleAppender { + private[sbt] def cursorLeft(n: Int): String = s"\u001B[${n}D" + private[sbt] def cursorRight(n: Int): String = s"\u001B[${n}C" private[sbt] def cursorUp(n: Int): String = s"\u001B[${n}A" private[sbt] def cursorDown(n: Int): String = s"\u001B[${n}B" private[sbt] def scrollUp(n: Int): String = s"\u001B[${n}S" private[sbt] def clearScreen(n: Int): String = s"\u001B[${n}J" + private[sbt] def clearLine(n: Int): String = s"\u001B[${n}K" private[sbt] final val DeleteLine = "\u001B[2K" - private[sbt] final val CursorLeft1000 = "\u001B[1000D" + private[sbt] final val CursorLeft1000 = cursorLeft(1000) private[sbt] final val CursorDown1 = cursorDown(1) private[this] val showProgressHolder: AtomicBoolean = new AtomicBoolean(false) def setShowProgress(b: Boolean): Unit = showProgressHolder.set(b) @@ -313,77 +318,6 @@ class ConsoleAppender private[ConsoleAppender] ( ) extends AbstractAppender(name, null, LogExchange.dummyLayout, true, Array.empty) { import scala.Console.{ BLUE, GREEN, RED, YELLOW } - private val progressState: AtomicReference[ProgressState] = new AtomicReference(null) - private[sbt] def setProgressState(state: ProgressState) = progressState.set(state) - - /** - * Splits a log message into individual lines and interlaces each line with - * the task progress report to reduce the appearance of flickering. It is assumed - * that this method is only called while holding the out.lockObject. - */ - private def supershellInterlaceMsg(msg: String): Unit = { - val state = progressState.get - import state._ - val progress = progressLines.get - msg.linesIterator.foreach { l => - out.println(s"$DeleteLine$l") - if (progress.length > 0) { - val pad = if (padding.get > 0) padding.decrementAndGet() else 0 - val width = Terminal.getWidth - val len: Int = progress.foldLeft(progress.length)(_ + terminalLines(width)(_)) - deleteConsoleLines(blankZone + pad) - progress.foreach(printProgressLine) - out.print(cursorUp(blankZone + len + padding.get)) - } - } - out.flush() - } - - private def printProgressLine(line: String): Unit = { - out.print(DeleteLine) - out.println(line) - } - - /** - * Receives a new task report and replaces the old one. In the event that the new - * report has fewer lines than the previous report, padding lines are added on top - * so that the console log lines remain contiguous. When a console line is printed - * at the info or greater level, we can decrement the padding because the console - * line will have filled in the blank line. - */ - private def updateProgressState(pe: ProgressEvent): Unit = { - val state = progressState.get - import state._ - val sorted = pe.items.sortBy(x => x.elapsedMicros) - val info = sorted map { item => - val elapsed = item.elapsedMicros / 1000000L - s" | => ${item.name} ${elapsed}s" - } - - val width = Terminal.getWidth - val currentLength = info.foldLeft(info.length)(_ + terminalLines(width)(_)) - val previousLines = progressLines.getAndSet(info) - val prevLength = previousLines.foldLeft(previousLines.length)(_ + terminalLines(width)(_)) - - val prevPadding = padding.get - val newPadding = math.max(0, prevLength + prevPadding - currentLength) - padding.set(newPadding) - - deleteConsoleLines(newPadding) - deleteConsoleLines(blankZone) - info.foreach(printProgressLine) - - out.print(cursorUp(blankZone + currentLength + newPadding)) - out.flush() - } - private def terminalLines(width: Int): String => Int = - (progressLine: String) => if (width > 0) (progressLine.length - 1) / width else 0 - private def deleteConsoleLines(n: Int): Unit = { - (1 to n) foreach { _ => - out.println(DeleteLine) - } - } - private val reset: String = { if (ansiCodesSupported && useFormat) scala.Console.RESET else "" @@ -514,11 +448,7 @@ class ConsoleAppender private[ConsoleAppender] ( private def write(msg: String): Unit = { val toWrite = if (!useFormat || !ansiCodesSupported) EscHelpers.removeEscapeSequences(msg) else msg - if (progressState.get != null) { - supershellInterlaceMsg(toWrite) - } else { - out.println(toWrite) - } + out.println(toWrite) } private def appendMessage(level: Level.Value, msg: Message): Unit = @@ -548,18 +478,16 @@ class ConsoleAppender private[ConsoleAppender] ( } } - private def appendProgressEvent(pe: ProgressEvent): Unit = - if (progressState.get != null) { - out.lockObject.synchronized(updateProgressState(pe)) - } - private def appendMessageContent(level: Level.Value, o: AnyRef): Unit = { def appendEvent(oe: ObjectEvent[_]): Unit = { val contentType = oe.contentType contentType match { case "sbt.internal.util.TraceEvent" => appendTraceEvent(oe.message.asInstanceOf[TraceEvent]) case "sbt.internal.util.ProgressEvent" => - appendProgressEvent(oe.message.asInstanceOf[ProgressEvent]) + oe.message match { + case pe: ProgressEvent => ProgressState.updateProgressState(pe) + case _ => + } case _ => LogExchange.stringCodec[AnyRef](contentType) match { case Some(codec) if contentType == "sbt.internal.util.SuccessEvent" => @@ -586,11 +514,106 @@ final class SuppressedTraceContext(val traceLevel: Int, val useFormat: Boolean) private[sbt] final class ProgressState( val progressLines: AtomicReference[Seq[String]], val padding: AtomicInteger, - val blankZone: Int + val blankZone: Int, + val currentLineBytes: AtomicReference[ArrayBuffer[Byte]], ) { - def this(blankZone: Int) = this(new AtomicReference(Nil), new AtomicInteger(0), blankZone) + def this(blankZone: Int) = + this( + new AtomicReference(Nil), + new AtomicInteger(0), + blankZone, + new AtomicReference(new ArrayBuffer[Byte]) + ) def reset(): Unit = { progressLines.set(Nil) padding.set(0) } } +private[sbt] object ProgressState { + private val progressState: AtomicReference[ProgressState] = new AtomicReference(null) + private[util] def clearBytes(): Unit = progressState.get match { + case null => + case state => + val pad = state.padding.get + if (state.currentLineBytes.get.isEmpty && pad > 0) state.padding.decrementAndGet() + state.currentLineBytes.set(new ArrayBuffer[Byte]) + } + + private[util] def addBytes(bytes: ArrayBuffer[Byte]): Unit = progressState.get match { + case null => + case state => + val previous = state.currentLineBytes.get + val padding = state.padding.get + val prevLineCount = if (padding > 0) Terminal.lineCount(new String(previous.toArray)) else 0 + previous ++= bytes + if (padding > 0) { + val newLineCount = Terminal.lineCount(new String(previous.toArray)) + val diff = newLineCount - prevLineCount + state.padding.set(math.max(padding - diff, 0)) + } + } + + private[util] def reprint(printStream: PrintStream): Unit = progressState.get match { + case null => printStream.write('\n') + case state => + if (state.progressLines.get.nonEmpty) { + val lines = printProgress(0, 0) + printStream.print(ConsoleAppender.clearScreen(0) + "\n" + lines) + } else printStream.write('\n') + } + + /** + * Receives a new task report and replaces the old one. In the event that the new + * report has fewer lines than the previous report, padding lines are added on top + * so that the console log lines remain contiguous. When a console line is printed + * at the info or greater level, we can decrement the padding because the console + * line will have filled in the blank line. + */ + private[util] def updateProgressState(pe: ProgressEvent): Unit = Terminal.withPrintStream { ps => + progressState.get match { + case null => + case state => + val info = pe.items.map { item => + val elapsed = item.elapsedMicros / 1000000L + s" | => ${item.name} ${elapsed}s" + } + + val currentLength = info.foldLeft(0)(_ + Terminal.lineCount(_)) + val previousLines = state.progressLines.getAndSet(info) + val prevLength = previousLines.foldLeft(0)(_ + Terminal.lineCount(_)) + + val (height, width) = Terminal.getLineHeightAndWidth + val prevSize = prevLength + state.padding.get + + val newPadding = math.max(0, prevSize - currentLength) + state.padding.set(newPadding) + ps.print(printProgress(height, width)) + ps.flush() + } + } + + private[sbt] def set(state: ProgressState): Unit = progressState.set(state) + + private[util] def printProgress(height: Int, width: Int): String = progressState.get match { + case null => "" + case state => + val previousLines = state.progressLines.get + if (previousLines.nonEmpty) { + val currentLength = previousLines.foldLeft(0)(_ + Terminal.lineCount(_)) + val left = cursorLeft(1000) // resets the position to the left + val offset = width > 0 + val pad = math.max(state.padding.get - height, 0) + val start = clearScreen(0) + (if (offset) "\n" else "") + val totalSize = currentLength + state.blankZone + pad + val blank = left + s"\n$DeleteLine" * (totalSize - currentLength) + val lines = previousLines.mkString(DeleteLine, s"\n$DeleteLine", s"\n$DeleteLine") + val resetCursorUp = cursorUp(totalSize + (if (offset) 1 else 0)) + val resetCursorRight = left + (if (offset) cursorRight(width) else "") + val resetCursor = resetCursorUp + resetCursorRight + start + blank + lines + resetCursor + } else { + clearScreen(0) + } + } + +} diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala index 44a15caf4..680ad1956 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala @@ -7,14 +7,17 @@ package sbt.internal.util -import java.io.{ InputStream, OutputStream } +import java.io.{ InputStream, OutputStream, PrintStream } import java.nio.channels.ClosedChannelException import java.util.Locale +import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } import java.util.concurrent.locks.ReentrantLock import jline.console.ConsoleReader +import scala.annotation.tailrec +import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal object Terminal { @@ -37,6 +40,38 @@ object Terminal { */ def getHeight: Int = terminal.getHeight + /** + * Returns the height and width of the current line that is displayed on the terminal. If the + * most recently flushed byte is a newline, this will be `(0, 0)`. + * + * @return the (height, width) pair + */ + def getLineHeightAndWidth: (Int, Int) = currentLine.get.toArray match { + case bytes if bytes.isEmpty => (0, 0) + case bytes => + val width = getWidth + val line = EscHelpers.removeEscapeSequences(new String(bytes)) + val count = lineCount(line) + (count, line.length - ((count - 1) * width)) + } + + /** + * Returns the number of lines that the input string will cover given the current width of the + * terminal. + * + * @param line the input line + * @return the number of lines that the line will cover on the terminal + */ + def lineCount(line: String): Int = { + val width = getWidth + val lines = EscHelpers.removeEscapeSequences(line).split('\n') + def count(l: String): Int = { + val len = l.length + if (width > 0 && len > 0) (len - 1 + width) / width else 0 + } + lines.tail.foldLeft(lines.headOption.fold(0)(count))(_ + count(_)) + } + /** * Returns true if the current terminal supports ansi characters. * @@ -101,6 +136,17 @@ object Terminal { } } + /** + * + * @param f the thunk to run + * @tparam T the result type of the thunk + * @return the result of the thunk + */ + private[sbt] def withStreams[T](f: => T): T = + if (System.getProperty("sbt.io.virtual", "true") == "true") { + withOut(withIn(f)) + } else f + /** * Runs a thunk ensuring that the terminal is in canonical mode: * [[https://www.gnu.org/software/libc/manual/html_node/Canonical-or-Not.html Canonical or Not]]. @@ -139,6 +185,79 @@ object Terminal { } } + private[this] val originalOut = System.out + private[this] val originalIn = System.in + private[this] val currentLine = new AtomicReference(new ArrayBuffer[Byte]) + private[this] val lineBuffer = new LinkedBlockingQueue[Byte] + private[this] val flushQueue = new LinkedBlockingQueue[Unit] + private[this] val writeLock = new AnyRef + private[this] final class WriteThread extends Thread("sbt-stdout-write-thread") { + setDaemon(true) + start() + private[this] val isStopped = new AtomicBoolean(false) + def close(): Unit = { + isStopped.set(true) + flushQueue.put(()) + () + } + @tailrec override def run(): Unit = { + try { + flushQueue.take() + val bytes = new java.util.ArrayList[Byte] + writeLock.synchronized { + lineBuffer.drainTo(bytes) + import scala.collection.JavaConverters._ + val remaining = bytes.asScala.foldLeft(new ArrayBuffer[Byte]) { (buf, i) => + if (i == 10) { + ProgressState.addBytes(buf) + ProgressState.clearBytes() + buf.foreach(b => originalOut.write(b & 0xFF)) + ProgressState.reprint(originalOut) + currentLine.set(new ArrayBuffer[Byte]) + new ArrayBuffer[Byte] + } else buf += i + } + if (remaining.nonEmpty) { + currentLine.get ++= remaining + originalOut.write(remaining.toArray) + } + originalOut.flush() + } + } catch { case _: InterruptedException => isStopped.set(true) } + if (!isStopped.get) run() + } + } + private[this] def withOut[T](f: => T): T = { + val thread = new WriteThread + try { + System.setOut(SystemPrintStream) + scala.Console.withOut(SystemPrintStream)(f) + } finally { + thread.close() + System.setOut(originalOut) + } + } + private[this] def withIn[T](f: => T): T = + try { + System.setIn(Terminal.wrappedSystemIn) + scala.Console.withIn(Terminal.wrappedSystemIn)(f) + } finally System.setIn(originalIn) + + private[sbt] def withPrintStream[T](f: PrintStream => T): T = writeLock.synchronized { + f(originalOut) + } + private object SystemOutputStream extends OutputStream { + override def write(b: Int): Unit = writeLock.synchronized(lineBuffer.put(b.toByte)) + override def write(b: Array[Byte]): Unit = writeLock.synchronized(b.foreach(lineBuffer.put)) + override def write(b: Array[Byte], off: Int, len: Int): Unit = writeLock.synchronized { + val lo = math.max(0, off) + val hi = math.min(math.max(off + len, 0), b.length) + (lo until hi).foreach(i => lineBuffer.put(b(i))) + } + def write(s: String): Unit = s.getBytes.foreach(lineBuffer.put) + override def flush(): Unit = writeLock.synchronized(flushQueue.put(())) + } + private object SystemPrintStream extends PrintStream(SystemOutputStream, true) private[this] object WrappedSystemIn extends InputStream { private[this] val in = terminal.wrapInIfNeeded(System.in) override def available(): Int = if (attached.get) in.available else 0 diff --git a/main/src/main/scala/sbt/EvaluateTask.scala b/main/src/main/scala/sbt/EvaluateTask.scala index 005b2776a..2816acfe9 100644 --- a/main/src/main/scala/sbt/EvaluateTask.scala +++ b/main/src/main/scala/sbt/EvaluateTask.scala @@ -260,10 +260,7 @@ object EvaluateTask { ps.reset() ConsoleAppender.setShowProgress(true) val appender = MainAppender.defaultScreen(StandardMain.console) - appender match { - case c: ConsoleAppender => c.setProgressState(ps) - case _ => - } + ProgressState.set(ps) val log = LogManager.progressLogger(appender) Some(new TaskProgress(log)) case _ => None diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index 463da05e7..9f76527d7 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -50,21 +50,20 @@ private[sbt] object xMain { // if we detect -Dsbt.client=true or -client, run thin client. val clientModByEnv = SysProp.client val userCommands = configuration.arguments.map(_.trim) - if (clientModByEnv || (userCommands.exists { cmd => - (cmd == DashClient) || (cmd == DashDashClient) - })) { - val args = userCommands.toList filterNot { cmd => - (cmd == DashClient) || (cmd == DashDashClient) + val isClient: String => Boolean = cmd => (cmd == DashClient) || (cmd == DashDashClient) + Terminal.withStreams { + if (clientModByEnv || userCommands.exists(isClient)) { + val args = userCommands.toList.filterNot(isClient) + NetworkClient.run(configuration, args) + Exit(0) + } else { + val state = StandardMain.initialState( + configuration, + Seq(defaults, early), + runEarly(DefaultsCommand) :: runEarly(InitCommand) :: BootCommand :: Nil + ) + StandardMain.runManaged(state) } - NetworkClient.run(configuration, args) - Exit(0) - } else { - val state = StandardMain.initialState( - configuration, - Seq(defaults, early), - runEarly(DefaultsCommand) :: runEarly(InitCommand) :: BootCommand :: Nil - ) - StandardMain.runManaged(state) } } finally { ShutdownHooks.close() diff --git a/main/src/main/scala/sbt/internal/LogManager.scala b/main/src/main/scala/sbt/internal/LogManager.scala index 2563e6c1e..7781e5b0f 100644 --- a/main/src/main/scala/sbt/internal/LogManager.scala +++ b/main/src/main/scala/sbt/internal/LogManager.scala @@ -9,6 +9,7 @@ package sbt package internal import java.io.PrintWriter + import Def.ScopedKey import Scope.GlobalScope import Keys.{ logLevel, logManager, persistLogLevel, persistTraceLevel, sLog, traceLevel } @@ -16,13 +17,14 @@ import sbt.internal.util.{ AttributeKey, ConsoleAppender, ConsoleOut, + MainAppender, + ManagedLogger, + ProgressState, Settings, - SuppressedTraceContext, - MainAppender + SuppressedTraceContext } import MainAppender._ -import sbt.util.{ Level, Logger, LogExchange } -import sbt.internal.util.ManagedLogger +import sbt.util.{ Level, LogExchange, Logger } import org.apache.logging.log4j.core.Appender sealed abstract class LogManager { @@ -142,10 +144,7 @@ object LogManager { val extraBacked = state.globalLogging.backed :: relay :: Nil val ps = Project.extract(state).get(sbt.Keys.progressState in ThisBuild) val consoleOpt = consoleLocally(state, console) - consoleOpt foreach { - case a: ConsoleAppender => ps.foreach(a.setProgressState) - case _ => - } + ps.foreach(ProgressState.set) val config = MainAppender.MainAppenderConfig( consoleOpt, backed,