From 120e6eb63dfad904b16cdeb4abf240732b23c0c1 Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Wed, 24 Jun 2020 13:13:02 -0700 Subject: [PATCH] Route TaskProgress through CommandExchange Rather than going through the console appender logging to make TaskProgress work, we can instead use the CommandExchange. This will be useful in future commits where there are multiple terminals that all need to receive progress. By organizing the TaskProgress this way, we can store a separate progress state for each terminal and update the progress for all of the active terminals. We also can set the current running command in command exchange which will be useful in future commits to show what command is currently running. This commit also reworks TaskProgress to always kill its thread when there are no active tasks. It will start a new thread as soon as there is another active task. --- .../sbt/internal/util/ProgressEvent.scala | 30 ++++++++--- .../util/codec/ProgressEventFormats.scala | 6 ++- .../src/main/contraband/logging.contra | 2 + .../sbt/internal/util/ConsoleAppender.scala | 8 +-- main/src/main/scala/sbt/EvaluateTask.scala | 12 +---- main/src/main/scala/sbt/MainLoop.scala | 2 + .../scala/sbt/internal/CommandExchange.scala | 19 ++++++- .../main/scala/sbt/internal/LogManager.scala | 27 +++------- .../scala/sbt/internal/TaskProgress.scala | 54 +++++++++---------- 9 files changed, 86 insertions(+), 74 deletions(-) diff --git a/internal/util-logging/src/main/contraband-scala/sbt/internal/util/ProgressEvent.scala b/internal/util-logging/src/main/contraband-scala/sbt/internal/util/ProgressEvent.scala index b2b2ffc30..0886bbb48 100644 --- a/internal/util-logging/src/main/contraband-scala/sbt/internal/util/ProgressEvent.scala +++ b/internal/util-logging/src/main/contraband-scala/sbt/internal/util/ProgressEvent.scala @@ -10,22 +10,24 @@ final class ProgressEvent private ( val items: Vector[sbt.internal.util.ProgressItem], val lastTaskCount: Option[Int], channelName: Option[String], - execId: Option[String]) extends sbt.internal.util.AbstractEntry(channelName, execId) with Serializable { - + execId: Option[String], + val command: Option[String], + val skipIfActive: Option[Boolean]) extends sbt.internal.util.AbstractEntry(channelName, execId) with Serializable { + private def this(level: String, items: Vector[sbt.internal.util.ProgressItem], lastTaskCount: Option[Int], channelName: Option[String], execId: Option[String]) = this(level, items, lastTaskCount, channelName, execId, None, None) override def equals(o: Any): Boolean = o match { - case x: ProgressEvent => (this.level == x.level) && (this.items == x.items) && (this.lastTaskCount == x.lastTaskCount) && (this.channelName == x.channelName) && (this.execId == x.execId) + case x: ProgressEvent => (this.level == x.level) && (this.items == x.items) && (this.lastTaskCount == x.lastTaskCount) && (this.channelName == x.channelName) && (this.execId == x.execId) && (this.command == x.command) && (this.skipIfActive == x.skipIfActive) case _ => false } override def hashCode: Int = { - 37 * (37 * (37 * (37 * (37 * (37 * (17 + "sbt.internal.util.ProgressEvent".##) + level.##) + items.##) + lastTaskCount.##) + channelName.##) + execId.##) + 37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (17 + "sbt.internal.util.ProgressEvent".##) + level.##) + items.##) + lastTaskCount.##) + channelName.##) + execId.##) + command.##) + skipIfActive.##) } override def toString: String = { - "ProgressEvent(" + level + ", " + items + ", " + lastTaskCount + ", " + channelName + ", " + execId + ")" + "ProgressEvent(" + level + ", " + items + ", " + lastTaskCount + ", " + channelName + ", " + execId + ", " + command + ", " + skipIfActive + ")" } - private[this] def copy(level: String = level, items: Vector[sbt.internal.util.ProgressItem] = items, lastTaskCount: Option[Int] = lastTaskCount, channelName: Option[String] = channelName, execId: Option[String] = execId): ProgressEvent = { - new ProgressEvent(level, items, lastTaskCount, channelName, execId) + private[this] def copy(level: String = level, items: Vector[sbt.internal.util.ProgressItem] = items, lastTaskCount: Option[Int] = lastTaskCount, channelName: Option[String] = channelName, execId: Option[String] = execId, command: Option[String] = command, skipIfActive: Option[Boolean] = skipIfActive): ProgressEvent = { + new ProgressEvent(level, items, lastTaskCount, channelName, execId, command, skipIfActive) } def withLevel(level: String): ProgressEvent = { copy(level = level) @@ -51,9 +53,23 @@ final class ProgressEvent private ( def withExecId(execId: String): ProgressEvent = { copy(execId = Option(execId)) } + def withCommand(command: Option[String]): ProgressEvent = { + copy(command = command) + } + def withCommand(command: String): ProgressEvent = { + copy(command = Option(command)) + } + def withSkipIfActive(skipIfActive: Option[Boolean]): ProgressEvent = { + copy(skipIfActive = skipIfActive) + } + def withSkipIfActive(skipIfActive: Boolean): ProgressEvent = { + copy(skipIfActive = Option(skipIfActive)) + } } object ProgressEvent { def apply(level: String, items: Vector[sbt.internal.util.ProgressItem], lastTaskCount: Option[Int], channelName: Option[String], execId: Option[String]): ProgressEvent = new ProgressEvent(level, items, lastTaskCount, channelName, execId) def apply(level: String, items: Vector[sbt.internal.util.ProgressItem], lastTaskCount: Int, channelName: String, execId: String): ProgressEvent = new ProgressEvent(level, items, Option(lastTaskCount), Option(channelName), Option(execId)) + def apply(level: String, items: Vector[sbt.internal.util.ProgressItem], lastTaskCount: Option[Int], channelName: Option[String], execId: Option[String], command: Option[String], skipIfActive: Option[Boolean]): ProgressEvent = new ProgressEvent(level, items, lastTaskCount, channelName, execId, command, skipIfActive) + def apply(level: String, items: Vector[sbt.internal.util.ProgressItem], lastTaskCount: Int, channelName: String, execId: String, command: String, skipIfActive: Boolean): ProgressEvent = new ProgressEvent(level, items, Option(lastTaskCount), Option(channelName), Option(execId), Option(command), Option(skipIfActive)) } diff --git a/internal/util-logging/src/main/contraband-scala/sbt/internal/util/codec/ProgressEventFormats.scala b/internal/util-logging/src/main/contraband-scala/sbt/internal/util/codec/ProgressEventFormats.scala index a6679d93a..ea9c7b825 100644 --- a/internal/util-logging/src/main/contraband-scala/sbt/internal/util/codec/ProgressEventFormats.scala +++ b/internal/util-logging/src/main/contraband-scala/sbt/internal/util/codec/ProgressEventFormats.scala @@ -16,8 +16,10 @@ implicit lazy val ProgressEventFormat: JsonFormat[sbt.internal.util.ProgressEven val lastTaskCount = unbuilder.readField[Option[Int]]("lastTaskCount") val channelName = unbuilder.readField[Option[String]]("channelName") val execId = unbuilder.readField[Option[String]]("execId") + val command = unbuilder.readField[Option[String]]("command") + val skipIfActive = unbuilder.readField[Option[Boolean]]("skipIfActive") unbuilder.endObject() - sbt.internal.util.ProgressEvent(level, items, lastTaskCount, channelName, execId) + sbt.internal.util.ProgressEvent(level, items, lastTaskCount, channelName, execId, command, skipIfActive) case None => deserializationError("Expected JsObject but found None") } @@ -29,6 +31,8 @@ implicit lazy val ProgressEventFormat: JsonFormat[sbt.internal.util.ProgressEven builder.addField("lastTaskCount", obj.lastTaskCount) builder.addField("channelName", obj.channelName) builder.addField("execId", obj.execId) + builder.addField("command", obj.command) + builder.addField("skipIfActive", obj.skipIfActive) builder.endObject() } } diff --git a/internal/util-logging/src/main/contraband/logging.contra b/internal/util-logging/src/main/contraband/logging.contra index 34fa75c24..447daa3d0 100644 --- a/internal/util-logging/src/main/contraband/logging.contra +++ b/internal/util-logging/src/main/contraband/logging.contra @@ -29,6 +29,8 @@ type ProgressEvent implements sbt.internal.util.AbstractEntry { lastTaskCount: Int channelName: String execId: String + command: String @since("1.4.0") + skipIfActive: Boolean @since("1.4.0") } ## used by super shell diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/ConsoleAppender.scala b/internal/util-logging/src/main/scala/sbt/internal/util/ConsoleAppender.scala index 807292a93..e403c9213 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/ConsoleAppender.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/ConsoleAppender.scala @@ -483,12 +483,8 @@ class ConsoleAppender private[ConsoleAppender] ( def appendEvent(oe: ObjectEvent[_]): Unit = { val contentType = oe.contentType contentType match { - case "sbt.internal.util.TraceEvent" => appendTraceEvent(oe.message.asInstanceOf[TraceEvent]) + case "sbt.internal.util.TraceEvent" => appendTraceEvent(oe.message.asInstanceOf[TraceEvent]) case "sbt.internal.util.ProgressEvent" => - oe.message match { - case pe: ProgressEvent => ProgressState.updateProgressState(pe) - case _ => - } case _ => LogExchange.stringCodec[AnyRef](contentType) match { case Some(codec) if contentType == "sbt.internal.util.SuccessEvent" => @@ -570,7 +566,7 @@ private[sbt] object ProgressState { * at the info or greater level, we can decrement the padding because the console * line will have filled in the blank line. */ - private[util] def updateProgressState(pe: ProgressEvent): Unit = Terminal.withPrintStream { ps => + private[sbt] def updateProgressState(pe: ProgressEvent): Unit = Terminal.withPrintStream { ps => progressState.get match { case null => case state => diff --git a/main/src/main/scala/sbt/EvaluateTask.scala b/main/src/main/scala/sbt/EvaluateTask.scala index c16ef6b1c..ca78e4240 100644 --- a/main/src/main/scala/sbt/EvaluateTask.scala +++ b/main/src/main/scala/sbt/EvaluateTask.scala @@ -255,17 +255,7 @@ object EvaluateTask { extracted, structure ) - val progressReporter = extracted.getOpt(progressState in ThisBuild).flatMap { - case Some(ps) => - ps.reset() - ConsoleAppender.setShowProgress(true) - val appender = MainAppender.defaultScreen(StandardMain.console) - ProgressState.set(ps) - val log = LogManager.progressLogger(appender) - Some(new TaskProgress(log)) - case _ => None - } - val reporters = maker.map(_.progress) ++ progressReporter ++ + val reporters = maker.map(_.progress) ++ Some(new TaskProgress) ++ (if (SysProp.taskTimings) new TaskTimings(reportOnShutdown = false, state.globalLogging.full) :: Nil else Nil) diff --git a/main/src/main/scala/sbt/MainLoop.scala b/main/src/main/scala/sbt/MainLoop.scala index b9d9fa43c..41f312f0c 100644 --- a/main/src/main/scala/sbt/MainLoop.scala +++ b/main/src/main/scala/sbt/MainLoop.scala @@ -195,6 +195,7 @@ object MainLoop { state.put(sbt.Keys.currentTaskProgress, new Keys.TaskProgress(progress)) } else state } + StandardMain.exchange.setExec(Some(exec)) val newState = Command.process(exec.commandLine, progressState) val doneEvent = ExecStatusEvent( "Done", @@ -204,6 +205,7 @@ object MainLoop { exitCode(newState, state), ) StandardMain.exchange.respondStatus(doneEvent) + StandardMain.exchange.setExec(None) newState.get(sbt.Keys.currentTaskProgress).foreach(_.progress.stop()) newState.remove(sbt.Keys.currentTaskProgress) } diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 8b7da864a..96cb3844e 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -6,8 +6,8 @@ */ package sbt -package internal +package internal import java.io.IOException import java.net.Socket import java.util.concurrent.{ ConcurrentLinkedQueue, LinkedBlockingQueue, TimeUnit } @@ -17,7 +17,7 @@ import sbt.BasicKeys._ import sbt.nio.Watch.NullLogger import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.server._ -import sbt.internal.util.{ ConsoleOut, MainAppender, Terminal } +import sbt.internal.util.{ ConsoleOut, MainAppender, ProgressEvent, ProgressState, Terminal } import sbt.io.syntax._ import sbt.io.{ Hash, IO } import sbt.protocol.{ ExecStatusEvent, LogEvent } @@ -48,6 +48,7 @@ private[sbt] final class CommandExchange { private val commandChannelQueue = new LinkedBlockingQueue[CommandChannel] private val nextChannelId: AtomicInteger = new AtomicInteger(0) private[this] val activePrompt = new AtomicBoolean(false) + private[this] val currentExecRef = new AtomicReference[Exec] def channels: List[CommandChannel] = channelBuffer.toList private[this] def removeChannel(channel: CommandChannel): Unit = { @@ -112,6 +113,8 @@ private[sbt] final class CommandExchange { private def newNetworkName: String = s"network-${nextChannelId.incrementAndGet()}" + private[sbt] def currentExec = Option(currentExecRef.get) + /** * Check if a server instance is running already, and start one if it isn't. */ @@ -271,6 +274,8 @@ private[sbt] final class CommandExchange { } } + private[sbt] def setExec(exec: Option[Exec]): Unit = currentExecRef.set(exec.orNull) + def prompt(event: ConsolePromptEvent): Unit = { activePrompt.set(Terminal.systemInIsAttached) channels @@ -296,4 +301,14 @@ private[sbt] final class CommandExchange { } private[this] def needToFinishPromptLine(): Boolean = activePrompt.compareAndSet(true, false) + private[sbt] def updateProgress(pe: ProgressEvent): Unit = { + val newPE = currentExec match { + case Some(e) => + pe.withCommand(currentExec.map(_.commandLine)) + .withExecId(currentExec.flatMap(_.execId)) + .withChannelName(currentExec.flatMap(_.source.map(_.channelName))) + case _ => pe + } + ProgressState.updateProgressState(newPE) + } } diff --git a/main/src/main/scala/sbt/internal/LogManager.scala b/main/src/main/scala/sbt/internal/LogManager.scala index 7781e5b0f..f44bd0371 100644 --- a/main/src/main/scala/sbt/internal/LogManager.scala +++ b/main/src/main/scala/sbt/internal/LogManager.scala @@ -10,22 +10,13 @@ package internal import java.io.PrintWriter -import Def.ScopedKey -import Scope.GlobalScope -import Keys.{ logLevel, logManager, persistLogLevel, persistTraceLevel, sLog, traceLevel } -import sbt.internal.util.{ - AttributeKey, - ConsoleAppender, - ConsoleOut, - MainAppender, - ManagedLogger, - ProgressState, - Settings, - SuppressedTraceContext -} -import MainAppender._ -import sbt.util.{ Level, LogExchange, Logger } import org.apache.logging.log4j.core.Appender +import sbt.Def.ScopedKey +import sbt.Keys._ +import sbt.Scope.GlobalScope +import sbt.internal.util.MainAppender._ +import sbt.internal.util._ +import sbt.util.{ Level, LogExchange, Logger } sealed abstract class LogManager { def apply( @@ -142,9 +133,7 @@ object LogManager { val screenTrace = getOr(traceLevel.key, data, scope, state, defaultTraceLevel(state)) val backingTrace = getOr(persistTraceLevel.key, data, scope, state, Int.MaxValue) val extraBacked = state.globalLogging.backed :: relay :: Nil - val ps = Project.extract(state).get(sbt.Keys.progressState in ThisBuild) val consoleOpt = consoleLocally(state, console) - ps.foreach(ProgressState.set) val config = MainAppender.MainAppenderConfig( consoleOpt, backed, @@ -164,7 +153,6 @@ object LogManager { x.source match { // TODO: Fix this stringliness case Some(x: CommandSource) if x.channelName == "console0" => Option(console) - case Some(_: CommandSource) => None case _ => Option(console) } case _ => Option(console) @@ -254,7 +242,8 @@ object LogManager { s1 } - def progressLogger(appender: Appender): ManagedLogger = { + @deprecated("No longer used.", "1.4.0") + private[sbt] def progressLogger(appender: Appender): ManagedLogger = { val log = LogExchange.logger("progress", None, None) LogExchange.unbindLoggerAppenders("progress") LogExchange.bindLoggerAppenders( diff --git a/main/src/main/scala/sbt/internal/TaskProgress.scala b/main/src/main/scala/sbt/internal/TaskProgress.scala index 1f89eaf80..e208f4a03 100644 --- a/main/src/main/scala/sbt/internal/TaskProgress.scala +++ b/main/src/main/scala/sbt/internal/TaskProgress.scala @@ -12,7 +12,6 @@ import java.util.concurrent.atomic.{ AtomicBoolean, AtomicInteger, AtomicReferen import java.util.concurrent.{ LinkedBlockingQueue, TimeUnit } import sbt.internal.util._ -import sbt.util.Level import scala.annotation.tailrec import scala.concurrent.duration._ @@ -20,33 +19,37 @@ import scala.concurrent.duration._ /** * implements task progress display on the shell. */ -private[sbt] final class TaskProgress(log: ManagedLogger) +private[sbt] final class TaskProgress extends AbstractTaskExecuteProgress with ExecuteProgress[Task] { + @deprecated("Use the constructor taking an ExecID.", "1.4.0") + def this(log: ManagedLogger) = this() private[this] val lastTaskCount = new AtomicInteger(0) private[this] val currentProgressThread = new AtomicReference[Option[ProgressThread]](None) private[this] val sleepDuration = SysProp.supershellSleep.millis private[this] val threshold = 10.millis + private[this] val tasks = new LinkedBlockingQueue[Task[_]] private[this] final class ProgressThread extends Thread("task-progress-report-thread") with AutoCloseable { private[this] val isClosed = new AtomicBoolean(false) private[this] val firstTime = new AtomicBoolean(true) - private[this] val tasks = new LinkedBlockingQueue[Task[_]] + private[this] val hasReported = new AtomicBoolean(false) + private[this] def doReport(): Unit = { hasReported.set(true); report() } setDaemon(true) start() @tailrec override def run(): Unit = { - if (!isClosed.get()) { + if (!isClosed.get() && (!hasReported.get || active.nonEmpty)) { try { - report() + if (activeExceedingThreshold.nonEmpty) doReport() val duration = - if (firstTime.compareAndSet(true, activeExceedingThreshold.nonEmpty)) threshold + if (firstTime.compareAndSet(true, activeExceedingThreshold.isEmpty)) threshold else sleepDuration val limit = duration.fromNow while (Deadline.now < limit) { var task = tasks.poll((limit - Deadline.now).toMillis, TimeUnit.MILLISECONDS) while (task != null) { - if (containsSkipTasks(Vector(task)) || lastTaskCount.get == 0) report() + if (containsSkipTasks(Vector(task)) || lastTaskCount.get == 0) doReport() task = tasks.poll } } @@ -54,7 +57,8 @@ private[sbt] final class TaskProgress(log: ManagedLogger) case _: InterruptedException => isClosed.set(true) // One last report after close in case the last one hadn't gone through yet. - report() + doReport() + } run() } @@ -65,21 +69,22 @@ private[sbt] final class TaskProgress(log: ManagedLogger) override def close(): Unit = { isClosed.set(true) interrupt() + report() + appendProgress(ProgressEvent("Info", Vector(), None, None, None)) + () } } override def initial(): Unit = () override def beforeWork(task: Task[_]): Unit = { + maybeStartThread() super.beforeWork(task) - currentProgressThread.get match { - case Some(t) => t.addTask(task) - case _ => maybeStartThread() - } + tasks.put(task) } - override def afterReady(task: Task[_]): Unit = () + override def afterReady(task: Task[_]): Unit = maybeStartThread() - override def afterCompleted[A](task: Task[A], result: Result[A]): Unit = () + override def afterCompleted[A](task: Task[A], result: Result[A]): Unit = maybeStartThread() override def stop(): Unit = currentProgressThread.synchronized { currentProgressThread.getAndSet(None).foreach(_.close()) @@ -113,10 +118,8 @@ private[sbt] final class TaskProgress(log: ManagedLogger) case _ => } } - private[this] def appendProgress(event: ProgressEvent): Unit = { - import sbt.internal.util.codec.JsonProtocol._ - log.logEvent(Level.Info, event) - } + private[this] def appendProgress(event: ProgressEvent): Unit = + StandardMain.exchange.updateProgress(event) private[this] def active: Vector[Task[_]] = activeTasks.toVector.filterNot(Def.isDummy) private[this] def activeExceedingThreshold: Vector[(Task[_], Long)] = active.flatMap { task => val elapsed = timings.get(task).currentElapsedMicros @@ -133,18 +136,13 @@ private[sbt] final class TaskProgress(log: ManagedLogger) .sortBy(_.elapsedMicros), Some(ltc), None, - None + None, + None, + Some(containsSkipTasks(active)) ) if (active.nonEmpty) maybeStartThread() - if (containsSkipTasks(active)) { - if (ltc > 0) { - lastTaskCount.set(0) - appendProgress(event(Vector.empty)) - } - } else { - lastTaskCount.set(currentTasksCount) - appendProgress(event(currentTasks)) - } + lastTaskCount.set(currentTasksCount) + appendProgress(event(currentTasks)) } private[this] def containsSkipTasks(tasks: Vector[Task[_]]): Boolean = {