Add a new CommandProgress API.

In addition to ExecuteProgress, this new interface allows builds and plugins to receive events when commands start and finish, including the State before and after each command. It also makes cancellation visible to clients by making the Cancelled type top-level.
This commit is contained in:
Iulian Dragos 2023-07-07 10:59:19 +02:00
parent e3b7870b2d
commit 82ae293e3f
No known key found for this signature in database
GPG Key ID: A38C8E571FA4621E
6 changed files with 102 additions and 27 deletions

View File

@ -184,12 +184,13 @@ object Command {
}
)
def process(command: String, state: State): State = {
def process(command: String, state: State, onParseError: String => Unit = _ => ()): State = {
(if (command.contains(";")) parse(command, state.combinedParser)
else parse(command, state.nonMultiParser)) match {
case Right(s) => s() // apply command. command side effects happen here
case Left(errMsg) =>
state.log error errMsg
onParseError(errMsg)
state.fail
}
}

View File

@ -0,0 +1,27 @@
package sbt
/**
* Tracks command execution progress. In addition to ExecuteProgress, this interface
* adds command start and end events, and gives access to the sbt.State at the beginning
* and end of each command.
*/
trait CommandProgress extends ExecuteProgress[Task] {
/**
* Called before a command starts processing. The command has not yet been parsed.
*
* @param cmd The command string
* @param state The sbt.State before the command starts executing.
*/
def beforeCommand(cmd: String, state: State): Unit
/**
* Called after a command finished execution.
*
* @param cmd The command string.
* @param result Left in case of an error. If the command cannot be parsed, it will be
* signalled as a ParseException with a detailed message. If the command
* was cancelled by the user, as sbt.Cancelled.
*/
def afterCommand(cmd: String, result: Either[Throwable, State]): Unit
}

View File

@ -341,6 +341,7 @@ object Defaults extends BuildCommon {
val rs = EvaluateTask.taskTimingProgress.toVector ++ EvaluateTask.taskTraceEvent.toVector
rs map { Keys.TaskProgress(_) }
},
commandProgress := Seq(),
// progressState is deprecated
SettingKey[Option[ProgressState]]("progressState") := None,
Previous.cache := new Previous(

View File

@ -289,7 +289,12 @@ object EvaluateTask {
extracted,
structure
)
val reporters = maker.map(_.progress) ++ state.get(Keys.taskProgress) ++
val reporters = maker.map(_.progress) ++ state.get(Keys.taskProgress) ++ getSetting(
Keys.commandProgress,
Seq(),
extracted,
structure
) ++
(if (SysProp.taskTimings)
new TaskTimings(reportOnShutdown = false, state.globalLogging.full) :: Nil
else Nil)

View File

@ -600,6 +600,7 @@ object Keys {
def apply(progress: ExecuteProgress[Task]): TaskProgress = new TaskProgress(progress)
}
private[sbt] val currentTaskProgress = AttributeKey[TaskProgress]("current-task-progress")
private[sbt] val currentCommandProgress = AttributeKey[Seq[CommandProgress]]("current-command-progress")
private[sbt] val taskProgress = AttributeKey[sbt.internal.TaskProgress]("active-task-progress")
val useSuperShell = settingKey[Boolean]("Enables (true) or disables the super shell.")
val superShellMaxTasks = settingKey[Int]("The max number of tasks to display in the supershell progress report")
@ -613,6 +614,7 @@ object Keys {
private[sbt] val postProgressReports = settingKey[Unit]("Internally used to modify logger.").withRank(DTask)
@deprecated("No longer used", "1.3.0")
private[sbt] val executeProgress = settingKey[State => TaskProgress]("Experimental task execution listener.").withRank(DTask)
val commandProgress = settingKey[Seq[CommandProgress]]("Command progress listeners receive events when commands start and end, in addition to task progress events.")
val lintUnused = inputKey[Unit]("Check for keys unused by other settings and tasks.")
val lintIncludeFilter = settingKey[String => Boolean]("Filters key names that should be included in the lint check.")
val lintExcludeFilter = settingKey[String => Boolean]("Filters key names that should be excluded in the lint check.")

View File

@ -11,17 +11,16 @@ package sbt
import java.io.PrintWriter
import java.util.concurrent.RejectedExecutionException
import java.util.Properties
import sbt.BasicCommandStrings.{ StashOnFailure, networkExecPrefix }
import sbt.BasicCommandStrings.{StashOnFailure, networkExecPrefix}
import sbt.internal.ShutdownHooks
import sbt.internal.langserver.ErrorCodes
import sbt.internal.protocol.JsonRpcResponseError
import sbt.internal.nio.CheckBuildSources.CheckBuildSourcesKey
import sbt.internal.util.{ ErrorHandling, GlobalLogBacking, Prompt, Terminal => ITerminal }
import sbt.internal.{ ShutdownHooks, TaskProgress }
import sbt.io.{ IO, Using }
import sbt.internal.util.{AttributeKey, ErrorHandling, GlobalLogBacking, Prompt, Terminal => ITerminal}
import sbt.internal.{ShutdownHooks, TaskProgress}
import sbt.io.{IO, Using}
import sbt.protocol._
import sbt.util.{ Logger, LoggerContext }
import sbt.util.{Logger, LoggerContext}
import scala.annotation.tailrec
import scala.concurrent.duration._
@ -29,6 +28,8 @@ import scala.util.control.NonFatal
import sbt.internal.FastTrackCommands
import sbt.internal.SysProp
import java.text.ParseException
object MainLoop {
/** Entry point to run the remaining commands in State with managed global logging.*/
@ -212,16 +213,29 @@ object MainLoop {
)
try {
def process(): State = {
val progressState = state.get(sbt.Keys.currentTaskProgress) match {
case Some(_) => state
case _ =>
if (state.get(Keys.stateBuildStructure).isDefined) {
val extracted = Project.extract(state)
val progress = EvaluateTask.executeProgress(extracted, extracted.structure, state)
state.put(sbt.Keys.currentTaskProgress, new Keys.TaskProgress(progress))
} else state
def getOrSet[T](state: State, key: AttributeKey[T], value: Extracted => T): State = {
state.get(key) match {
case Some(_) => state
case _ =>
if (state.get(Keys.stateBuildStructure).isDefined) {
val extracted = Project.extract(state)
state.put(key, value(extracted))
} else state
}
}
exchange.setState(progressState)
val progressState = getOrSet(
state,
sbt.Keys.currentTaskProgress,
extracted =>
new Keys.TaskProgress(
EvaluateTask.executeProgress(extracted, extracted.structure, state)
)
)
val cmdProgressState =
getOrSet(progressState, sbt.Keys.currentCommandProgress, _.get(Keys.commandProgress))
exchange.setState(cmdProgressState)
exchange.setExec(Some(exec))
val (restoreTerminal, termState) = channelName.flatMap(exchange.channelForName) match {
case Some(c) =>
@ -231,9 +245,13 @@ object MainLoop {
(() => {
ITerminal.set(prevTerminal)
c.terminal.flush()
}) -> progressState.put(Keys.terminalKey, Terminal(c.terminal))
case _ => (() => ()) -> progressState.put(Keys.terminalKey, Terminal(ITerminal.get))
}) -> cmdProgressState.put(Keys.terminalKey, Terminal(c.terminal))
case _ => (() => ()) -> cmdProgressState.put(Keys.terminalKey, Terminal(ITerminal.get))
}
val currentCmdProgress =
cmdProgressState.get(sbt.Keys.currentCommandProgress).getOrElse(Nil)
currentCmdProgress.foreach(_.beforeCommand(exec.commandLine, progressState))
/*
* FastTrackCommands.evaluate can be significantly faster than Command.process because
* it avoids an expensive parsing step for internal commands that are easy to parse.
@ -241,16 +259,29 @@ object MainLoop {
* but slower.
*/
val newState = try {
FastTrackCommands
var errorMsg: Option[String] = None
val res = FastTrackCommands
.evaluate(termState, exec.commandLine)
.getOrElse(Command.process(exec.commandLine, termState))
.getOrElse(Command.process(exec.commandLine, termState, m => errorMsg = Some(m)))
errorMsg match {
case Some(msg) =>
currentCmdProgress.foreach(
_.afterCommand(exec.commandLine, Left(new ParseException(msg, 0)))
)
case None => currentCmdProgress.foreach(_.afterCommand(exec.commandLine, Right(res)))
}
res
} catch {
case _: RejectedExecutionException =>
// No stack trace since this is just to notify the user which command they cancelled
object Cancelled extends Throwable(exec.commandLine, null, true, false) {
override def toString: String = s"Cancelled: ${exec.commandLine}"
}
throw Cancelled
val cancelled = new Cancelled(exec.commandLine)
currentCmdProgress
.foreach(_.afterCommand(exec.commandLine, Left(cancelled)))
throw cancelled
case e: Throwable =>
currentCmdProgress
.foreach(_.afterCommand(exec.commandLine, Left(e)))
throw e
} finally {
// Flush the terminal output after command evaluation to ensure that all output
// is displayed in the thin client before we report the command status. Also
@ -270,7 +301,10 @@ object MainLoop {
}
exchange.setExec(None)
newState.get(sbt.Keys.currentTaskProgress).foreach(_.progress.stop())
newState.remove(sbt.Keys.currentTaskProgress).remove(Keys.terminalKey)
newState
.remove(sbt.Keys.currentTaskProgress)
.remove(Keys.terminalKey)
.remove(Keys.currentCommandProgress)
}
state.get(CheckBuildSourcesKey) match {
case Some(cbs) =>
@ -341,3 +375,8 @@ object MainLoop {
ExitCode(ErrorCodes.UnknownError)
} else ExitCode.Success
}
// No stack trace since this is just to notify the user which command they cancelled
class Cancelled(cmdLine: String) extends Throwable(cmdLine, null, true, false) {
override def toString: String = s"Cancelled: $cmdLine"
}