Route TaskProgress through CommandExchange

Rather than going through the console appender logging to make
TaskProgress work, we can instead use the CommandExchange. This will be
useful in future commits where there are multiple terminals that all
need to receive progress. By organizing the TaskProgress this way, we
can store a separate progress state for each terminal and update the
progress for all of the active terminals. We also can set the current
running command in command exchange which will be useful in future
commits to show what command is currently running.

This commit also reworks TaskProgress to always kill its thread when
there are no active tasks. It will start a new thread as soon as there
is another active task.
This commit is contained in:
Ethan Atkins 2020-06-24 13:13:02 -07:00
parent 034b9690c1
commit 120e6eb63d
9 changed files with 86 additions and 74 deletions

View File

@ -10,22 +10,24 @@ final class ProgressEvent private (
val items: Vector[sbt.internal.util.ProgressItem],
val lastTaskCount: Option[Int],
channelName: Option[String],
execId: Option[String]) extends sbt.internal.util.AbstractEntry(channelName, execId) with Serializable {
execId: Option[String],
val command: Option[String],
val skipIfActive: Option[Boolean]) extends sbt.internal.util.AbstractEntry(channelName, execId) with Serializable {
private def this(level: String, items: Vector[sbt.internal.util.ProgressItem], lastTaskCount: Option[Int], channelName: Option[String], execId: Option[String]) = this(level, items, lastTaskCount, channelName, execId, None, None)
override def equals(o: Any): Boolean = o match {
case x: ProgressEvent => (this.level == x.level) && (this.items == x.items) && (this.lastTaskCount == x.lastTaskCount) && (this.channelName == x.channelName) && (this.execId == x.execId)
case x: ProgressEvent => (this.level == x.level) && (this.items == x.items) && (this.lastTaskCount == x.lastTaskCount) && (this.channelName == x.channelName) && (this.execId == x.execId) && (this.command == x.command) && (this.skipIfActive == x.skipIfActive)
case _ => false
}
override def hashCode: Int = {
37 * (37 * (37 * (37 * (37 * (37 * (17 + "sbt.internal.util.ProgressEvent".##) + level.##) + items.##) + lastTaskCount.##) + channelName.##) + execId.##)
37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (17 + "sbt.internal.util.ProgressEvent".##) + level.##) + items.##) + lastTaskCount.##) + channelName.##) + execId.##) + command.##) + skipIfActive.##)
}
override def toString: String = {
"ProgressEvent(" + level + ", " + items + ", " + lastTaskCount + ", " + channelName + ", " + execId + ")"
"ProgressEvent(" + level + ", " + items + ", " + lastTaskCount + ", " + channelName + ", " + execId + ", " + command + ", " + skipIfActive + ")"
}
private[this] def copy(level: String = level, items: Vector[sbt.internal.util.ProgressItem] = items, lastTaskCount: Option[Int] = lastTaskCount, channelName: Option[String] = channelName, execId: Option[String] = execId): ProgressEvent = {
new ProgressEvent(level, items, lastTaskCount, channelName, execId)
private[this] def copy(level: String = level, items: Vector[sbt.internal.util.ProgressItem] = items, lastTaskCount: Option[Int] = lastTaskCount, channelName: Option[String] = channelName, execId: Option[String] = execId, command: Option[String] = command, skipIfActive: Option[Boolean] = skipIfActive): ProgressEvent = {
new ProgressEvent(level, items, lastTaskCount, channelName, execId, command, skipIfActive)
}
def withLevel(level: String): ProgressEvent = {
copy(level = level)
@ -51,9 +53,23 @@ final class ProgressEvent private (
def withExecId(execId: String): ProgressEvent = {
copy(execId = Option(execId))
}
def withCommand(command: Option[String]): ProgressEvent = {
copy(command = command)
}
def withCommand(command: String): ProgressEvent = {
copy(command = Option(command))
}
def withSkipIfActive(skipIfActive: Option[Boolean]): ProgressEvent = {
copy(skipIfActive = skipIfActive)
}
def withSkipIfActive(skipIfActive: Boolean): ProgressEvent = {
copy(skipIfActive = Option(skipIfActive))
}
}
object ProgressEvent {
def apply(level: String, items: Vector[sbt.internal.util.ProgressItem], lastTaskCount: Option[Int], channelName: Option[String], execId: Option[String]): ProgressEvent = new ProgressEvent(level, items, lastTaskCount, channelName, execId)
def apply(level: String, items: Vector[sbt.internal.util.ProgressItem], lastTaskCount: Int, channelName: String, execId: String): ProgressEvent = new ProgressEvent(level, items, Option(lastTaskCount), Option(channelName), Option(execId))
def apply(level: String, items: Vector[sbt.internal.util.ProgressItem], lastTaskCount: Option[Int], channelName: Option[String], execId: Option[String], command: Option[String], skipIfActive: Option[Boolean]): ProgressEvent = new ProgressEvent(level, items, lastTaskCount, channelName, execId, command, skipIfActive)
def apply(level: String, items: Vector[sbt.internal.util.ProgressItem], lastTaskCount: Int, channelName: String, execId: String, command: String, skipIfActive: Boolean): ProgressEvent = new ProgressEvent(level, items, Option(lastTaskCount), Option(channelName), Option(execId), Option(command), Option(skipIfActive))
}

View File

@ -16,8 +16,10 @@ implicit lazy val ProgressEventFormat: JsonFormat[sbt.internal.util.ProgressEven
val lastTaskCount = unbuilder.readField[Option[Int]]("lastTaskCount")
val channelName = unbuilder.readField[Option[String]]("channelName")
val execId = unbuilder.readField[Option[String]]("execId")
val command = unbuilder.readField[Option[String]]("command")
val skipIfActive = unbuilder.readField[Option[Boolean]]("skipIfActive")
unbuilder.endObject()
sbt.internal.util.ProgressEvent(level, items, lastTaskCount, channelName, execId)
sbt.internal.util.ProgressEvent(level, items, lastTaskCount, channelName, execId, command, skipIfActive)
case None =>
deserializationError("Expected JsObject but found None")
}
@ -29,6 +31,8 @@ implicit lazy val ProgressEventFormat: JsonFormat[sbt.internal.util.ProgressEven
builder.addField("lastTaskCount", obj.lastTaskCount)
builder.addField("channelName", obj.channelName)
builder.addField("execId", obj.execId)
builder.addField("command", obj.command)
builder.addField("skipIfActive", obj.skipIfActive)
builder.endObject()
}
}

View File

@ -29,6 +29,8 @@ type ProgressEvent implements sbt.internal.util.AbstractEntry {
lastTaskCount: Int
channelName: String
execId: String
command: String @since("1.4.0")
skipIfActive: Boolean @since("1.4.0")
}
## used by super shell

View File

@ -483,12 +483,8 @@ class ConsoleAppender private[ConsoleAppender] (
def appendEvent(oe: ObjectEvent[_]): Unit = {
val contentType = oe.contentType
contentType match {
case "sbt.internal.util.TraceEvent" => appendTraceEvent(oe.message.asInstanceOf[TraceEvent])
case "sbt.internal.util.TraceEvent" => appendTraceEvent(oe.message.asInstanceOf[TraceEvent])
case "sbt.internal.util.ProgressEvent" =>
oe.message match {
case pe: ProgressEvent => ProgressState.updateProgressState(pe)
case _ =>
}
case _ =>
LogExchange.stringCodec[AnyRef](contentType) match {
case Some(codec) if contentType == "sbt.internal.util.SuccessEvent" =>
@ -570,7 +566,7 @@ private[sbt] object ProgressState {
* at the info or greater level, we can decrement the padding because the console
* line will have filled in the blank line.
*/
private[util] def updateProgressState(pe: ProgressEvent): Unit = Terminal.withPrintStream { ps =>
private[sbt] def updateProgressState(pe: ProgressEvent): Unit = Terminal.withPrintStream { ps =>
progressState.get match {
case null =>
case state =>

View File

@ -255,17 +255,7 @@ object EvaluateTask {
extracted,
structure
)
val progressReporter = extracted.getOpt(progressState in ThisBuild).flatMap {
case Some(ps) =>
ps.reset()
ConsoleAppender.setShowProgress(true)
val appender = MainAppender.defaultScreen(StandardMain.console)
ProgressState.set(ps)
val log = LogManager.progressLogger(appender)
Some(new TaskProgress(log))
case _ => None
}
val reporters = maker.map(_.progress) ++ progressReporter ++
val reporters = maker.map(_.progress) ++ Some(new TaskProgress) ++
(if (SysProp.taskTimings)
new TaskTimings(reportOnShutdown = false, state.globalLogging.full) :: Nil
else Nil)

View File

@ -195,6 +195,7 @@ object MainLoop {
state.put(sbt.Keys.currentTaskProgress, new Keys.TaskProgress(progress))
} else state
}
StandardMain.exchange.setExec(Some(exec))
val newState = Command.process(exec.commandLine, progressState)
val doneEvent = ExecStatusEvent(
"Done",
@ -204,6 +205,7 @@ object MainLoop {
exitCode(newState, state),
)
StandardMain.exchange.respondStatus(doneEvent)
StandardMain.exchange.setExec(None)
newState.get(sbt.Keys.currentTaskProgress).foreach(_.progress.stop())
newState.remove(sbt.Keys.currentTaskProgress)
}

View File

@ -6,8 +6,8 @@
*/
package sbt
package internal
package internal
import java.io.IOException
import java.net.Socket
import java.util.concurrent.{ ConcurrentLinkedQueue, LinkedBlockingQueue, TimeUnit }
@ -17,7 +17,7 @@ import sbt.BasicKeys._
import sbt.nio.Watch.NullLogger
import sbt.internal.protocol.JsonRpcResponseError
import sbt.internal.server._
import sbt.internal.util.{ ConsoleOut, MainAppender, Terminal }
import sbt.internal.util.{ ConsoleOut, MainAppender, ProgressEvent, ProgressState, Terminal }
import sbt.io.syntax._
import sbt.io.{ Hash, IO }
import sbt.protocol.{ ExecStatusEvent, LogEvent }
@ -48,6 +48,7 @@ private[sbt] final class CommandExchange {
private val commandChannelQueue = new LinkedBlockingQueue[CommandChannel]
private val nextChannelId: AtomicInteger = new AtomicInteger(0)
private[this] val activePrompt = new AtomicBoolean(false)
private[this] val currentExecRef = new AtomicReference[Exec]
def channels: List[CommandChannel] = channelBuffer.toList
private[this] def removeChannel(channel: CommandChannel): Unit = {
@ -112,6 +113,8 @@ private[sbt] final class CommandExchange {
private def newNetworkName: String = s"network-${nextChannelId.incrementAndGet()}"
private[sbt] def currentExec = Option(currentExecRef.get)
/**
* Check if a server instance is running already, and start one if it isn't.
*/
@ -271,6 +274,8 @@ private[sbt] final class CommandExchange {
}
}
private[sbt] def setExec(exec: Option[Exec]): Unit = currentExecRef.set(exec.orNull)
def prompt(event: ConsolePromptEvent): Unit = {
activePrompt.set(Terminal.systemInIsAttached)
channels
@ -296,4 +301,14 @@ private[sbt] final class CommandExchange {
}
private[this] def needToFinishPromptLine(): Boolean = activePrompt.compareAndSet(true, false)
private[sbt] def updateProgress(pe: ProgressEvent): Unit = {
val newPE = currentExec match {
case Some(e) =>
pe.withCommand(currentExec.map(_.commandLine))
.withExecId(currentExec.flatMap(_.execId))
.withChannelName(currentExec.flatMap(_.source.map(_.channelName)))
case _ => pe
}
ProgressState.updateProgressState(newPE)
}
}

View File

@ -10,22 +10,13 @@ package internal
import java.io.PrintWriter
import Def.ScopedKey
import Scope.GlobalScope
import Keys.{ logLevel, logManager, persistLogLevel, persistTraceLevel, sLog, traceLevel }
import sbt.internal.util.{
AttributeKey,
ConsoleAppender,
ConsoleOut,
MainAppender,
ManagedLogger,
ProgressState,
Settings,
SuppressedTraceContext
}
import MainAppender._
import sbt.util.{ Level, LogExchange, Logger }
import org.apache.logging.log4j.core.Appender
import sbt.Def.ScopedKey
import sbt.Keys._
import sbt.Scope.GlobalScope
import sbt.internal.util.MainAppender._
import sbt.internal.util._
import sbt.util.{ Level, LogExchange, Logger }
sealed abstract class LogManager {
def apply(
@ -142,9 +133,7 @@ object LogManager {
val screenTrace = getOr(traceLevel.key, data, scope, state, defaultTraceLevel(state))
val backingTrace = getOr(persistTraceLevel.key, data, scope, state, Int.MaxValue)
val extraBacked = state.globalLogging.backed :: relay :: Nil
val ps = Project.extract(state).get(sbt.Keys.progressState in ThisBuild)
val consoleOpt = consoleLocally(state, console)
ps.foreach(ProgressState.set)
val config = MainAppender.MainAppenderConfig(
consoleOpt,
backed,
@ -164,7 +153,6 @@ object LogManager {
x.source match {
// TODO: Fix this stringliness
case Some(x: CommandSource) if x.channelName == "console0" => Option(console)
case Some(_: CommandSource) => None
case _ => Option(console)
}
case _ => Option(console)
@ -254,7 +242,8 @@ object LogManager {
s1
}
def progressLogger(appender: Appender): ManagedLogger = {
@deprecated("No longer used.", "1.4.0")
private[sbt] def progressLogger(appender: Appender): ManagedLogger = {
val log = LogExchange.logger("progress", None, None)
LogExchange.unbindLoggerAppenders("progress")
LogExchange.bindLoggerAppenders(

View File

@ -12,7 +12,6 @@ import java.util.concurrent.atomic.{ AtomicBoolean, AtomicInteger, AtomicReferen
import java.util.concurrent.{ LinkedBlockingQueue, TimeUnit }
import sbt.internal.util._
import sbt.util.Level
import scala.annotation.tailrec
import scala.concurrent.duration._
@ -20,33 +19,37 @@ import scala.concurrent.duration._
/**
* implements task progress display on the shell.
*/
private[sbt] final class TaskProgress(log: ManagedLogger)
private[sbt] final class TaskProgress
extends AbstractTaskExecuteProgress
with ExecuteProgress[Task] {
@deprecated("Use the constructor taking an ExecID.", "1.4.0")
def this(log: ManagedLogger) = this()
private[this] val lastTaskCount = new AtomicInteger(0)
private[this] val currentProgressThread = new AtomicReference[Option[ProgressThread]](None)
private[this] val sleepDuration = SysProp.supershellSleep.millis
private[this] val threshold = 10.millis
private[this] val tasks = new LinkedBlockingQueue[Task[_]]
private[this] final class ProgressThread
extends Thread("task-progress-report-thread")
with AutoCloseable {
private[this] val isClosed = new AtomicBoolean(false)
private[this] val firstTime = new AtomicBoolean(true)
private[this] val tasks = new LinkedBlockingQueue[Task[_]]
private[this] val hasReported = new AtomicBoolean(false)
private[this] def doReport(): Unit = { hasReported.set(true); report() }
setDaemon(true)
start()
@tailrec override def run(): Unit = {
if (!isClosed.get()) {
if (!isClosed.get() && (!hasReported.get || active.nonEmpty)) {
try {
report()
if (activeExceedingThreshold.nonEmpty) doReport()
val duration =
if (firstTime.compareAndSet(true, activeExceedingThreshold.nonEmpty)) threshold
if (firstTime.compareAndSet(true, activeExceedingThreshold.isEmpty)) threshold
else sleepDuration
val limit = duration.fromNow
while (Deadline.now < limit) {
var task = tasks.poll((limit - Deadline.now).toMillis, TimeUnit.MILLISECONDS)
while (task != null) {
if (containsSkipTasks(Vector(task)) || lastTaskCount.get == 0) report()
if (containsSkipTasks(Vector(task)) || lastTaskCount.get == 0) doReport()
task = tasks.poll
}
}
@ -54,7 +57,8 @@ private[sbt] final class TaskProgress(log: ManagedLogger)
case _: InterruptedException =>
isClosed.set(true)
// One last report after close in case the last one hadn't gone through yet.
report()
doReport()
}
run()
}
@ -65,21 +69,22 @@ private[sbt] final class TaskProgress(log: ManagedLogger)
override def close(): Unit = {
isClosed.set(true)
interrupt()
report()
appendProgress(ProgressEvent("Info", Vector(), None, None, None))
()
}
}
override def initial(): Unit = ()
override def beforeWork(task: Task[_]): Unit = {
maybeStartThread()
super.beforeWork(task)
currentProgressThread.get match {
case Some(t) => t.addTask(task)
case _ => maybeStartThread()
}
tasks.put(task)
}
override def afterReady(task: Task[_]): Unit = ()
override def afterReady(task: Task[_]): Unit = maybeStartThread()
override def afterCompleted[A](task: Task[A], result: Result[A]): Unit = ()
override def afterCompleted[A](task: Task[A], result: Result[A]): Unit = maybeStartThread()
override def stop(): Unit = currentProgressThread.synchronized {
currentProgressThread.getAndSet(None).foreach(_.close())
@ -113,10 +118,8 @@ private[sbt] final class TaskProgress(log: ManagedLogger)
case _ =>
}
}
private[this] def appendProgress(event: ProgressEvent): Unit = {
import sbt.internal.util.codec.JsonProtocol._
log.logEvent(Level.Info, event)
}
private[this] def appendProgress(event: ProgressEvent): Unit =
StandardMain.exchange.updateProgress(event)
private[this] def active: Vector[Task[_]] = activeTasks.toVector.filterNot(Def.isDummy)
private[this] def activeExceedingThreshold: Vector[(Task[_], Long)] = active.flatMap { task =>
val elapsed = timings.get(task).currentElapsedMicros
@ -133,18 +136,13 @@ private[sbt] final class TaskProgress(log: ManagedLogger)
.sortBy(_.elapsedMicros),
Some(ltc),
None,
None
None,
None,
Some(containsSkipTasks(active))
)
if (active.nonEmpty) maybeStartThread()
if (containsSkipTasks(active)) {
if (ltc > 0) {
lastTaskCount.set(0)
appendProgress(event(Vector.empty))
}
} else {
lastTaskCount.set(currentTasksCount)
appendProgress(event(currentTasks))
}
lastTaskCount.set(currentTasksCount)
appendProgress(event(currentTasks))
}
private[this] def containsSkipTasks(tasks: Vector[Task[_]]): Boolean = {