Add virtual System.out for supershell

In order to make supershell work with println, this commit introduces a
virtual System.out to sbt. While sbt is running, we override the default
java.lang.System.out, java.lang.System.in, scala.Console.out and
scala.Console.in (unless the property `sbt.io.virtual` is set to
something other than true). When using virtual io, we buffer all of the
bytes that are written to System.out and Console.out until flush is
called. When flushing the output, we check if there are any progress
lines. If so, we interleave them with the new lines to print.

The flushing happens on a background thread so it should hopefully not
impede task progress.

This commit also adds logic for handling progress when the cursor is not
all the way to the left. We now track all of the bytes that have been
written since the last new line. Supershell will then calculate the
cursor position from those bytes* and move the cursor back to the
correct position. The motivation for this was to make the run command
work with supershell even when multiple main classes were specified.

* This might not be completely reliable if the string contains ansi
cursor movement characters.
This commit is contained in:
Ethan Atkins 2019-12-17 12:00:09 -08:00
parent a0012bab75
commit 58822cc3f5
5 changed files with 249 additions and 112 deletions

View File

@ -18,6 +18,8 @@ import org.apache.logging.log4j.{ Level => XLevel }
import sbt.internal.util.ConsoleAppender._
import sbt.util._
import scala.collection.mutable.ArrayBuffer
object ConsoleLogger {
// These are provided so other modules do not break immediately.
@deprecated("Use EscHelpers.ESC instead", "0.13.x")
@ -103,12 +105,15 @@ class ConsoleLogger private[ConsoleLogger] (
}
object ConsoleAppender {
private[sbt] def cursorLeft(n: Int): String = s"\u001B[${n}D"
private[sbt] def cursorRight(n: Int): String = s"\u001B[${n}C"
private[sbt] def cursorUp(n: Int): String = s"\u001B[${n}A"
private[sbt] def cursorDown(n: Int): String = s"\u001B[${n}B"
private[sbt] def scrollUp(n: Int): String = s"\u001B[${n}S"
private[sbt] def clearScreen(n: Int): String = s"\u001B[${n}J"
private[sbt] def clearLine(n: Int): String = s"\u001B[${n}K"
private[sbt] final val DeleteLine = "\u001B[2K"
private[sbt] final val CursorLeft1000 = "\u001B[1000D"
private[sbt] final val CursorLeft1000 = cursorLeft(1000)
private[sbt] final val CursorDown1 = cursorDown(1)
private[this] val showProgressHolder: AtomicBoolean = new AtomicBoolean(false)
def setShowProgress(b: Boolean): Unit = showProgressHolder.set(b)
@ -313,77 +318,6 @@ class ConsoleAppender private[ConsoleAppender] (
) extends AbstractAppender(name, null, LogExchange.dummyLayout, true, Array.empty) {
import scala.Console.{ BLUE, GREEN, RED, YELLOW }
private val progressState: AtomicReference[ProgressState] = new AtomicReference(null)
private[sbt] def setProgressState(state: ProgressState) = progressState.set(state)
/**
* Splits a log message into individual lines and interlaces each line with
* the task progress report to reduce the appearance of flickering. It is assumed
* that this method is only called while holding the out.lockObject.
*/
private def supershellInterlaceMsg(msg: String): Unit = {
val state = progressState.get
import state._
val progress = progressLines.get
msg.linesIterator.foreach { l =>
out.println(s"$DeleteLine$l")
if (progress.length > 0) {
val pad = if (padding.get > 0) padding.decrementAndGet() else 0
val width = Terminal.getWidth
val len: Int = progress.foldLeft(progress.length)(_ + terminalLines(width)(_))
deleteConsoleLines(blankZone + pad)
progress.foreach(printProgressLine)
out.print(cursorUp(blankZone + len + padding.get))
}
}
out.flush()
}
private def printProgressLine(line: String): Unit = {
out.print(DeleteLine)
out.println(line)
}
/**
* Receives a new task report and replaces the old one. In the event that the new
* report has fewer lines than the previous report, padding lines are added on top
* so that the console log lines remain contiguous. When a console line is printed
* at the info or greater level, we can decrement the padding because the console
* line will have filled in the blank line.
*/
private def updateProgressState(pe: ProgressEvent): Unit = {
val state = progressState.get
import state._
val sorted = pe.items.sortBy(x => x.elapsedMicros)
val info = sorted map { item =>
val elapsed = item.elapsedMicros / 1000000L
s" | => ${item.name} ${elapsed}s"
}
val width = Terminal.getWidth
val currentLength = info.foldLeft(info.length)(_ + terminalLines(width)(_))
val previousLines = progressLines.getAndSet(info)
val prevLength = previousLines.foldLeft(previousLines.length)(_ + terminalLines(width)(_))
val prevPadding = padding.get
val newPadding = math.max(0, prevLength + prevPadding - currentLength)
padding.set(newPadding)
deleteConsoleLines(newPadding)
deleteConsoleLines(blankZone)
info.foreach(printProgressLine)
out.print(cursorUp(blankZone + currentLength + newPadding))
out.flush()
}
private def terminalLines(width: Int): String => Int =
(progressLine: String) => if (width > 0) (progressLine.length - 1) / width else 0
private def deleteConsoleLines(n: Int): Unit = {
(1 to n) foreach { _ =>
out.println(DeleteLine)
}
}
private val reset: String = {
if (ansiCodesSupported && useFormat) scala.Console.RESET
else ""
@ -514,11 +448,7 @@ class ConsoleAppender private[ConsoleAppender] (
private def write(msg: String): Unit = {
val toWrite =
if (!useFormat || !ansiCodesSupported) EscHelpers.removeEscapeSequences(msg) else msg
if (progressState.get != null) {
supershellInterlaceMsg(toWrite)
} else {
out.println(toWrite)
}
out.println(toWrite)
}
private def appendMessage(level: Level.Value, msg: Message): Unit =
@ -548,18 +478,16 @@ class ConsoleAppender private[ConsoleAppender] (
}
}
private def appendProgressEvent(pe: ProgressEvent): Unit =
if (progressState.get != null) {
out.lockObject.synchronized(updateProgressState(pe))
}
private def appendMessageContent(level: Level.Value, o: AnyRef): Unit = {
def appendEvent(oe: ObjectEvent[_]): Unit = {
val contentType = oe.contentType
contentType match {
case "sbt.internal.util.TraceEvent" => appendTraceEvent(oe.message.asInstanceOf[TraceEvent])
case "sbt.internal.util.ProgressEvent" =>
appendProgressEvent(oe.message.asInstanceOf[ProgressEvent])
oe.message match {
case pe: ProgressEvent => ProgressState.updateProgressState(pe)
case _ =>
}
case _ =>
LogExchange.stringCodec[AnyRef](contentType) match {
case Some(codec) if contentType == "sbt.internal.util.SuccessEvent" =>
@ -586,11 +514,106 @@ final class SuppressedTraceContext(val traceLevel: Int, val useFormat: Boolean)
private[sbt] final class ProgressState(
val progressLines: AtomicReference[Seq[String]],
val padding: AtomicInteger,
val blankZone: Int
val blankZone: Int,
val currentLineBytes: AtomicReference[ArrayBuffer[Byte]],
) {
def this(blankZone: Int) = this(new AtomicReference(Nil), new AtomicInteger(0), blankZone)
def this(blankZone: Int) =
this(
new AtomicReference(Nil),
new AtomicInteger(0),
blankZone,
new AtomicReference(new ArrayBuffer[Byte])
)
def reset(): Unit = {
progressLines.set(Nil)
padding.set(0)
}
}
private[sbt] object ProgressState {
private val progressState: AtomicReference[ProgressState] = new AtomicReference(null)
private[util] def clearBytes(): Unit = progressState.get match {
case null =>
case state =>
val pad = state.padding.get
if (state.currentLineBytes.get.isEmpty && pad > 0) state.padding.decrementAndGet()
state.currentLineBytes.set(new ArrayBuffer[Byte])
}
private[util] def addBytes(bytes: ArrayBuffer[Byte]): Unit = progressState.get match {
case null =>
case state =>
val previous = state.currentLineBytes.get
val padding = state.padding.get
val prevLineCount = if (padding > 0) Terminal.lineCount(new String(previous.toArray)) else 0
previous ++= bytes
if (padding > 0) {
val newLineCount = Terminal.lineCount(new String(previous.toArray))
val diff = newLineCount - prevLineCount
state.padding.set(math.max(padding - diff, 0))
}
}
private[util] def reprint(printStream: PrintStream): Unit = progressState.get match {
case null => printStream.write('\n')
case state =>
if (state.progressLines.get.nonEmpty) {
val lines = printProgress(0, 0)
printStream.print(ConsoleAppender.clearScreen(0) + "\n" + lines)
} else printStream.write('\n')
}
/**
* Receives a new task report and replaces the old one. In the event that the new
* report has fewer lines than the previous report, padding lines are added on top
* so that the console log lines remain contiguous. When a console line is printed
* at the info or greater level, we can decrement the padding because the console
* line will have filled in the blank line.
*/
private[util] def updateProgressState(pe: ProgressEvent): Unit = Terminal.withPrintStream { ps =>
progressState.get match {
case null =>
case state =>
val info = pe.items.map { item =>
val elapsed = item.elapsedMicros / 1000000L
s" | => ${item.name} ${elapsed}s"
}
val currentLength = info.foldLeft(0)(_ + Terminal.lineCount(_))
val previousLines = state.progressLines.getAndSet(info)
val prevLength = previousLines.foldLeft(0)(_ + Terminal.lineCount(_))
val (height, width) = Terminal.getLineHeightAndWidth
val prevSize = prevLength + state.padding.get
val newPadding = math.max(0, prevSize - currentLength)
state.padding.set(newPadding)
ps.print(printProgress(height, width))
ps.flush()
}
}
private[sbt] def set(state: ProgressState): Unit = progressState.set(state)
private[util] def printProgress(height: Int, width: Int): String = progressState.get match {
case null => ""
case state =>
val previousLines = state.progressLines.get
if (previousLines.nonEmpty) {
val currentLength = previousLines.foldLeft(0)(_ + Terminal.lineCount(_))
val left = cursorLeft(1000) // resets the position to the left
val offset = width > 0
val pad = math.max(state.padding.get - height, 0)
val start = clearScreen(0) + (if (offset) "\n" else "")
val totalSize = currentLength + state.blankZone + pad
val blank = left + s"\n$DeleteLine" * (totalSize - currentLength)
val lines = previousLines.mkString(DeleteLine, s"\n$DeleteLine", s"\n$DeleteLine")
val resetCursorUp = cursorUp(totalSize + (if (offset) 1 else 0))
val resetCursorRight = left + (if (offset) cursorRight(width) else "")
val resetCursor = resetCursorUp + resetCursorRight
start + blank + lines + resetCursor
} else {
clearScreen(0)
}
}
}

View File

@ -7,14 +7,17 @@
package sbt.internal.util
import java.io.{ InputStream, OutputStream }
import java.io.{ InputStream, OutputStream, PrintStream }
import java.nio.channels.ClosedChannelException
import java.util.Locale
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference }
import java.util.concurrent.locks.ReentrantLock
import jline.console.ConsoleReader
import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
object Terminal {
@ -37,6 +40,38 @@ object Terminal {
*/
def getHeight: Int = terminal.getHeight
/**
* Returns the height and width of the current line that is displayed on the terminal. If the
* most recently flushed byte is a newline, this will be `(0, 0)`.
*
* @return the (height, width) pair
*/
def getLineHeightAndWidth: (Int, Int) = currentLine.get.toArray match {
case bytes if bytes.isEmpty => (0, 0)
case bytes =>
val width = getWidth
val line = EscHelpers.removeEscapeSequences(new String(bytes))
val count = lineCount(line)
(count, line.length - ((count - 1) * width))
}
/**
* Returns the number of lines that the input string will cover given the current width of the
* terminal.
*
* @param line the input line
* @return the number of lines that the line will cover on the terminal
*/
def lineCount(line: String): Int = {
val width = getWidth
val lines = EscHelpers.removeEscapeSequences(line).split('\n')
def count(l: String): Int = {
val len = l.length
if (width > 0 && len > 0) (len - 1 + width) / width else 0
}
lines.tail.foldLeft(lines.headOption.fold(0)(count))(_ + count(_))
}
/**
* Returns true if the current terminal supports ansi characters.
*
@ -101,6 +136,17 @@ object Terminal {
}
}
/**
*
* @param f the thunk to run
* @tparam T the result type of the thunk
* @return the result of the thunk
*/
private[sbt] def withStreams[T](f: => T): T =
if (System.getProperty("sbt.io.virtual", "true") == "true") {
withOut(withIn(f))
} else f
/**
* Runs a thunk ensuring that the terminal is in canonical mode:
* [[https://www.gnu.org/software/libc/manual/html_node/Canonical-or-Not.html Canonical or Not]].
@ -139,6 +185,79 @@ object Terminal {
}
}
private[this] val originalOut = System.out
private[this] val originalIn = System.in
private[this] val currentLine = new AtomicReference(new ArrayBuffer[Byte])
private[this] val lineBuffer = new LinkedBlockingQueue[Byte]
private[this] val flushQueue = new LinkedBlockingQueue[Unit]
private[this] val writeLock = new AnyRef
private[this] final class WriteThread extends Thread("sbt-stdout-write-thread") {
setDaemon(true)
start()
private[this] val isStopped = new AtomicBoolean(false)
def close(): Unit = {
isStopped.set(true)
flushQueue.put(())
()
}
@tailrec override def run(): Unit = {
try {
flushQueue.take()
val bytes = new java.util.ArrayList[Byte]
writeLock.synchronized {
lineBuffer.drainTo(bytes)
import scala.collection.JavaConverters._
val remaining = bytes.asScala.foldLeft(new ArrayBuffer[Byte]) { (buf, i) =>
if (i == 10) {
ProgressState.addBytes(buf)
ProgressState.clearBytes()
buf.foreach(b => originalOut.write(b & 0xFF))
ProgressState.reprint(originalOut)
currentLine.set(new ArrayBuffer[Byte])
new ArrayBuffer[Byte]
} else buf += i
}
if (remaining.nonEmpty) {
currentLine.get ++= remaining
originalOut.write(remaining.toArray)
}
originalOut.flush()
}
} catch { case _: InterruptedException => isStopped.set(true) }
if (!isStopped.get) run()
}
}
private[this] def withOut[T](f: => T): T = {
val thread = new WriteThread
try {
System.setOut(SystemPrintStream)
scala.Console.withOut(SystemPrintStream)(f)
} finally {
thread.close()
System.setOut(originalOut)
}
}
private[this] def withIn[T](f: => T): T =
try {
System.setIn(Terminal.wrappedSystemIn)
scala.Console.withIn(Terminal.wrappedSystemIn)(f)
} finally System.setIn(originalIn)
private[sbt] def withPrintStream[T](f: PrintStream => T): T = writeLock.synchronized {
f(originalOut)
}
private object SystemOutputStream extends OutputStream {
override def write(b: Int): Unit = writeLock.synchronized(lineBuffer.put(b.toByte))
override def write(b: Array[Byte]): Unit = writeLock.synchronized(b.foreach(lineBuffer.put))
override def write(b: Array[Byte], off: Int, len: Int): Unit = writeLock.synchronized {
val lo = math.max(0, off)
val hi = math.min(math.max(off + len, 0), b.length)
(lo until hi).foreach(i => lineBuffer.put(b(i)))
}
def write(s: String): Unit = s.getBytes.foreach(lineBuffer.put)
override def flush(): Unit = writeLock.synchronized(flushQueue.put(()))
}
private object SystemPrintStream extends PrintStream(SystemOutputStream, true)
private[this] object WrappedSystemIn extends InputStream {
private[this] val in = terminal.wrapInIfNeeded(System.in)
override def available(): Int = if (attached.get) in.available else 0

View File

@ -260,10 +260,7 @@ object EvaluateTask {
ps.reset()
ConsoleAppender.setShowProgress(true)
val appender = MainAppender.defaultScreen(StandardMain.console)
appender match {
case c: ConsoleAppender => c.setProgressState(ps)
case _ =>
}
ProgressState.set(ps)
val log = LogManager.progressLogger(appender)
Some(new TaskProgress(log))
case _ => None

View File

@ -50,21 +50,20 @@ private[sbt] object xMain {
// if we detect -Dsbt.client=true or -client, run thin client.
val clientModByEnv = SysProp.client
val userCommands = configuration.arguments.map(_.trim)
if (clientModByEnv || (userCommands.exists { cmd =>
(cmd == DashClient) || (cmd == DashDashClient)
})) {
val args = userCommands.toList filterNot { cmd =>
(cmd == DashClient) || (cmd == DashDashClient)
val isClient: String => Boolean = cmd => (cmd == DashClient) || (cmd == DashDashClient)
Terminal.withStreams {
if (clientModByEnv || userCommands.exists(isClient)) {
val args = userCommands.toList.filterNot(isClient)
NetworkClient.run(configuration, args)
Exit(0)
} else {
val state = StandardMain.initialState(
configuration,
Seq(defaults, early),
runEarly(DefaultsCommand) :: runEarly(InitCommand) :: BootCommand :: Nil
)
StandardMain.runManaged(state)
}
NetworkClient.run(configuration, args)
Exit(0)
} else {
val state = StandardMain.initialState(
configuration,
Seq(defaults, early),
runEarly(DefaultsCommand) :: runEarly(InitCommand) :: BootCommand :: Nil
)
StandardMain.runManaged(state)
}
} finally {
ShutdownHooks.close()

View File

@ -9,6 +9,7 @@ package sbt
package internal
import java.io.PrintWriter
import Def.ScopedKey
import Scope.GlobalScope
import Keys.{ logLevel, logManager, persistLogLevel, persistTraceLevel, sLog, traceLevel }
@ -16,13 +17,14 @@ import sbt.internal.util.{
AttributeKey,
ConsoleAppender,
ConsoleOut,
MainAppender,
ManagedLogger,
ProgressState,
Settings,
SuppressedTraceContext,
MainAppender
SuppressedTraceContext
}
import MainAppender._
import sbt.util.{ Level, Logger, LogExchange }
import sbt.internal.util.ManagedLogger
import sbt.util.{ Level, LogExchange, Logger }
import org.apache.logging.log4j.core.Appender
sealed abstract class LogManager {
@ -142,10 +144,7 @@ object LogManager {
val extraBacked = state.globalLogging.backed :: relay :: Nil
val ps = Project.extract(state).get(sbt.Keys.progressState in ThisBuild)
val consoleOpt = consoleLocally(state, console)
consoleOpt foreach {
case a: ConsoleAppender => ps.foreach(a.setProgressState)
case _ =>
}
ps.foreach(ProgressState.set)
val config = MainAppender.MainAppenderConfig(
consoleOpt,
backed,