From 82ae293e3fcac7f3faa3a611fa38ccf062a216d6 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Fri, 7 Jul 2023 10:59:19 +0200 Subject: [PATCH] Add a new CommandProgress API. In addition to ExecuteProgress, this new interface allows builds and plugins to receive events when commands start and finish, including the State before and after each command. It also makes cancellation visible to clients by making the Cancelled type top-level. --- main-command/src/main/scala/sbt/Command.scala | 3 +- main/src/main/scala/sbt/CommandProgress.scala | 27 ++++++ main/src/main/scala/sbt/Defaults.scala | 1 + main/src/main/scala/sbt/EvaluateTask.scala | 7 +- main/src/main/scala/sbt/Keys.scala | 2 + main/src/main/scala/sbt/MainLoop.scala | 89 +++++++++++++------ 6 files changed, 102 insertions(+), 27 deletions(-) create mode 100644 main/src/main/scala/sbt/CommandProgress.scala diff --git a/main-command/src/main/scala/sbt/Command.scala b/main-command/src/main/scala/sbt/Command.scala index f4a1ebc6f..3ef6ae3f2 100644 --- a/main-command/src/main/scala/sbt/Command.scala +++ b/main-command/src/main/scala/sbt/Command.scala @@ -184,12 +184,13 @@ object Command { } ) - def process(command: String, state: State): State = { + def process(command: String, state: State, onParseError: String => Unit = _ => ()): State = { (if (command.contains(";")) parse(command, state.combinedParser) else parse(command, state.nonMultiParser)) match { case Right(s) => s() // apply command. command side effects happen here case Left(errMsg) => state.log error errMsg + onParseError(errMsg) state.fail } } diff --git a/main/src/main/scala/sbt/CommandProgress.scala b/main/src/main/scala/sbt/CommandProgress.scala new file mode 100644 index 000000000..5b353d972 --- /dev/null +++ b/main/src/main/scala/sbt/CommandProgress.scala @@ -0,0 +1,27 @@ +package sbt + +/** + * Tracks command execution progress. In addition to ExecuteProgress, this interface + * adds command start and end events, and gives access to the sbt.State at the beginning + * and end of each command. + */ +trait CommandProgress extends ExecuteProgress[Task] { + + /** + * Called before a command starts processing. The command has not yet been parsed. + * + * @param cmd The command string + * @param state The sbt.State before the command starts executing. + */ + def beforeCommand(cmd: String, state: State): Unit + + /** + * Called after a command finished execution. + * + * @param cmd The command string. + * @param result Left in case of an error. If the command cannot be parsed, it will be + * signalled as a ParseException with a detailed message. If the command + * was cancelled by the user, as sbt.Cancelled. + */ + def afterCommand(cmd: String, result: Either[Throwable, State]): Unit +} diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index e7a601c30..8a4a1abcb 100644 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -341,6 +341,7 @@ object Defaults extends BuildCommon { val rs = EvaluateTask.taskTimingProgress.toVector ++ EvaluateTask.taskTraceEvent.toVector rs map { Keys.TaskProgress(_) } }, + commandProgress := Seq(), // progressState is deprecated SettingKey[Option[ProgressState]]("progressState") := None, Previous.cache := new Previous( diff --git a/main/src/main/scala/sbt/EvaluateTask.scala b/main/src/main/scala/sbt/EvaluateTask.scala index 29749cfde..dfeaf9f18 100644 --- a/main/src/main/scala/sbt/EvaluateTask.scala +++ b/main/src/main/scala/sbt/EvaluateTask.scala @@ -289,7 +289,12 @@ object EvaluateTask { extracted, structure ) - val reporters = maker.map(_.progress) ++ state.get(Keys.taskProgress) ++ + val reporters = maker.map(_.progress) ++ state.get(Keys.taskProgress) ++ getSetting( + Keys.commandProgress, + Seq(), + extracted, + structure + ) ++ (if (SysProp.taskTimings) new TaskTimings(reportOnShutdown = false, state.globalLogging.full) :: Nil else Nil) diff --git a/main/src/main/scala/sbt/Keys.scala b/main/src/main/scala/sbt/Keys.scala index 0fb3c3bdc..e320f7147 100644 --- a/main/src/main/scala/sbt/Keys.scala +++ b/main/src/main/scala/sbt/Keys.scala @@ -600,6 +600,7 @@ object Keys { def apply(progress: ExecuteProgress[Task]): TaskProgress = new TaskProgress(progress) } private[sbt] val currentTaskProgress = AttributeKey[TaskProgress]("current-task-progress") + private[sbt] val currentCommandProgress = AttributeKey[Seq[CommandProgress]]("current-command-progress") private[sbt] val taskProgress = AttributeKey[sbt.internal.TaskProgress]("active-task-progress") val useSuperShell = settingKey[Boolean]("Enables (true) or disables the super shell.") val superShellMaxTasks = settingKey[Int]("The max number of tasks to display in the supershell progress report") @@ -613,6 +614,7 @@ object Keys { private[sbt] val postProgressReports = settingKey[Unit]("Internally used to modify logger.").withRank(DTask) @deprecated("No longer used", "1.3.0") private[sbt] val executeProgress = settingKey[State => TaskProgress]("Experimental task execution listener.").withRank(DTask) + val commandProgress = settingKey[Seq[CommandProgress]]("Command progress listeners receive events when commands start and end, in addition to task progress events.") val lintUnused = inputKey[Unit]("Check for keys unused by other settings and tasks.") val lintIncludeFilter = settingKey[String => Boolean]("Filters key names that should be included in the lint check.") val lintExcludeFilter = settingKey[String => Boolean]("Filters key names that should be excluded in the lint check.") diff --git a/main/src/main/scala/sbt/MainLoop.scala b/main/src/main/scala/sbt/MainLoop.scala index 3c7a800a6..cf02ae2b4 100644 --- a/main/src/main/scala/sbt/MainLoop.scala +++ b/main/src/main/scala/sbt/MainLoop.scala @@ -11,17 +11,16 @@ package sbt import java.io.PrintWriter import java.util.concurrent.RejectedExecutionException import java.util.Properties - -import sbt.BasicCommandStrings.{ StashOnFailure, networkExecPrefix } +import sbt.BasicCommandStrings.{StashOnFailure, networkExecPrefix} import sbt.internal.ShutdownHooks import sbt.internal.langserver.ErrorCodes import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.nio.CheckBuildSources.CheckBuildSourcesKey -import sbt.internal.util.{ ErrorHandling, GlobalLogBacking, Prompt, Terminal => ITerminal } -import sbt.internal.{ ShutdownHooks, TaskProgress } -import sbt.io.{ IO, Using } +import sbt.internal.util.{AttributeKey, ErrorHandling, GlobalLogBacking, Prompt, Terminal => ITerminal} +import sbt.internal.{ShutdownHooks, TaskProgress} +import sbt.io.{IO, Using} import sbt.protocol._ -import sbt.util.{ Logger, LoggerContext } +import sbt.util.{Logger, LoggerContext} import scala.annotation.tailrec import scala.concurrent.duration._ @@ -29,6 +28,8 @@ import scala.util.control.NonFatal import sbt.internal.FastTrackCommands import sbt.internal.SysProp +import java.text.ParseException + object MainLoop { /** Entry point to run the remaining commands in State with managed global logging.*/ @@ -212,16 +213,29 @@ object MainLoop { ) try { def process(): State = { - val progressState = state.get(sbt.Keys.currentTaskProgress) match { - case Some(_) => state - case _ => - if (state.get(Keys.stateBuildStructure).isDefined) { - val extracted = Project.extract(state) - val progress = EvaluateTask.executeProgress(extracted, extracted.structure, state) - state.put(sbt.Keys.currentTaskProgress, new Keys.TaskProgress(progress)) - } else state + def getOrSet[T](state: State, key: AttributeKey[T], value: Extracted => T): State = { + state.get(key) match { + case Some(_) => state + case _ => + if (state.get(Keys.stateBuildStructure).isDefined) { + val extracted = Project.extract(state) + state.put(key, value(extracted)) + } else state + } } - exchange.setState(progressState) + val progressState = getOrSet( + state, + sbt.Keys.currentTaskProgress, + extracted => + new Keys.TaskProgress( + EvaluateTask.executeProgress(extracted, extracted.structure, state) + ) + ) + + val cmdProgressState = + getOrSet(progressState, sbt.Keys.currentCommandProgress, _.get(Keys.commandProgress)) + + exchange.setState(cmdProgressState) exchange.setExec(Some(exec)) val (restoreTerminal, termState) = channelName.flatMap(exchange.channelForName) match { case Some(c) => @@ -231,9 +245,13 @@ object MainLoop { (() => { ITerminal.set(prevTerminal) c.terminal.flush() - }) -> progressState.put(Keys.terminalKey, Terminal(c.terminal)) - case _ => (() => ()) -> progressState.put(Keys.terminalKey, Terminal(ITerminal.get)) + }) -> cmdProgressState.put(Keys.terminalKey, Terminal(c.terminal)) + case _ => (() => ()) -> cmdProgressState.put(Keys.terminalKey, Terminal(ITerminal.get)) } + + val currentCmdProgress = + cmdProgressState.get(sbt.Keys.currentCommandProgress).getOrElse(Nil) + currentCmdProgress.foreach(_.beforeCommand(exec.commandLine, progressState)) /* * FastTrackCommands.evaluate can be significantly faster than Command.process because * it avoids an expensive parsing step for internal commands that are easy to parse. @@ -241,16 +259,29 @@ object MainLoop { * but slower. */ val newState = try { - FastTrackCommands + var errorMsg: Option[String] = None + val res = FastTrackCommands .evaluate(termState, exec.commandLine) - .getOrElse(Command.process(exec.commandLine, termState)) + .getOrElse(Command.process(exec.commandLine, termState, m => errorMsg = Some(m))) + errorMsg match { + case Some(msg) => + currentCmdProgress.foreach( + _.afterCommand(exec.commandLine, Left(new ParseException(msg, 0))) + ) + case None => currentCmdProgress.foreach(_.afterCommand(exec.commandLine, Right(res))) + } + res } catch { case _: RejectedExecutionException => - // No stack trace since this is just to notify the user which command they cancelled - object Cancelled extends Throwable(exec.commandLine, null, true, false) { - override def toString: String = s"Cancelled: ${exec.commandLine}" - } - throw Cancelled + val cancelled = new Cancelled(exec.commandLine) + currentCmdProgress + .foreach(_.afterCommand(exec.commandLine, Left(cancelled))) + throw cancelled + + case e: Throwable => + currentCmdProgress + .foreach(_.afterCommand(exec.commandLine, Left(e))) + throw e } finally { // Flush the terminal output after command evaluation to ensure that all output // is displayed in the thin client before we report the command status. Also @@ -270,7 +301,10 @@ object MainLoop { } exchange.setExec(None) newState.get(sbt.Keys.currentTaskProgress).foreach(_.progress.stop()) - newState.remove(sbt.Keys.currentTaskProgress).remove(Keys.terminalKey) + newState + .remove(sbt.Keys.currentTaskProgress) + .remove(Keys.terminalKey) + .remove(Keys.currentCommandProgress) } state.get(CheckBuildSourcesKey) match { case Some(cbs) => @@ -341,3 +375,8 @@ object MainLoop { ExitCode(ErrorCodes.UnknownError) } else ExitCode.Success } + +// No stack trace since this is just to notify the user which command they cancelled +class Cancelled(cmdLine: String) extends Throwable(cmdLine, null, true, false) { + override def toString: String = s"Cancelled: $cmdLine" +}