diff --git a/main-command/src/main/scala/sbt/Command.scala b/main-command/src/main/scala/sbt/Command.scala index f4a1ebc6f..3ef6ae3f2 100644 --- a/main-command/src/main/scala/sbt/Command.scala +++ b/main-command/src/main/scala/sbt/Command.scala @@ -184,12 +184,13 @@ object Command { } ) - def process(command: String, state: State): State = { + def process(command: String, state: State, onParseError: String => Unit = _ => ()): State = { (if (command.contains(";")) parse(command, state.combinedParser) else parse(command, state.nonMultiParser)) match { case Right(s) => s() // apply command. command side effects happen here case Left(errMsg) => state.log error errMsg + onParseError(errMsg) state.fail } } diff --git a/main/src/main/scala/sbt/CommandProgress.scala b/main/src/main/scala/sbt/CommandProgress.scala new file mode 100644 index 000000000..5b353d972 --- /dev/null +++ b/main/src/main/scala/sbt/CommandProgress.scala @@ -0,0 +1,27 @@ +package sbt + +/** + * Tracks command execution progress. In addition to ExecuteProgress, this interface + * adds command start and end events, and gives access to the sbt.State at the beginning + * and end of each command. + */ +trait CommandProgress extends ExecuteProgress[Task] { + + /** + * Called before a command starts processing. The command has not yet been parsed. + * + * @param cmd The command string + * @param state The sbt.State before the command starts executing. + */ + def beforeCommand(cmd: String, state: State): Unit + + /** + * Called after a command finished execution. + * + * @param cmd The command string. + * @param result Left in case of an error. If the command cannot be parsed, it will be + * signalled as a ParseException with a detailed message. If the command + * was cancelled by the user, as sbt.Cancelled. + */ + def afterCommand(cmd: String, result: Either[Throwable, State]): Unit +} diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index e7a601c30..8a4a1abcb 100644 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -341,6 +341,7 @@ object Defaults extends BuildCommon { val rs = EvaluateTask.taskTimingProgress.toVector ++ EvaluateTask.taskTraceEvent.toVector rs map { Keys.TaskProgress(_) } }, + commandProgress := Seq(), // progressState is deprecated SettingKey[Option[ProgressState]]("progressState") := None, Previous.cache := new Previous( diff --git a/main/src/main/scala/sbt/EvaluateTask.scala b/main/src/main/scala/sbt/EvaluateTask.scala index 29749cfde..dfeaf9f18 100644 --- a/main/src/main/scala/sbt/EvaluateTask.scala +++ b/main/src/main/scala/sbt/EvaluateTask.scala @@ -289,7 +289,12 @@ object EvaluateTask { extracted, structure ) - val reporters = maker.map(_.progress) ++ state.get(Keys.taskProgress) ++ + val reporters = maker.map(_.progress) ++ state.get(Keys.taskProgress) ++ getSetting( + Keys.commandProgress, + Seq(), + extracted, + structure + ) ++ (if (SysProp.taskTimings) new TaskTimings(reportOnShutdown = false, state.globalLogging.full) :: Nil else Nil) diff --git a/main/src/main/scala/sbt/Keys.scala b/main/src/main/scala/sbt/Keys.scala index 0fb3c3bdc..e320f7147 100644 --- a/main/src/main/scala/sbt/Keys.scala +++ b/main/src/main/scala/sbt/Keys.scala @@ -600,6 +600,7 @@ object Keys { def apply(progress: ExecuteProgress[Task]): TaskProgress = new TaskProgress(progress) } private[sbt] val currentTaskProgress = AttributeKey[TaskProgress]("current-task-progress") + private[sbt] val currentCommandProgress = AttributeKey[Seq[CommandProgress]]("current-command-progress") private[sbt] val taskProgress = AttributeKey[sbt.internal.TaskProgress]("active-task-progress") val useSuperShell = settingKey[Boolean]("Enables (true) or disables the super shell.") val superShellMaxTasks = settingKey[Int]("The max number of tasks to display in the supershell progress report") @@ -613,6 +614,7 @@ object Keys { private[sbt] val postProgressReports = settingKey[Unit]("Internally used to modify logger.").withRank(DTask) @deprecated("No longer used", "1.3.0") private[sbt] val executeProgress = settingKey[State => TaskProgress]("Experimental task execution listener.").withRank(DTask) + val commandProgress = settingKey[Seq[CommandProgress]]("Command progress listeners receive events when commands start and end, in addition to task progress events.") val lintUnused = inputKey[Unit]("Check for keys unused by other settings and tasks.") val lintIncludeFilter = settingKey[String => Boolean]("Filters key names that should be included in the lint check.") val lintExcludeFilter = settingKey[String => Boolean]("Filters key names that should be excluded in the lint check.") diff --git a/main/src/main/scala/sbt/MainLoop.scala b/main/src/main/scala/sbt/MainLoop.scala index 3c7a800a6..cf02ae2b4 100644 --- a/main/src/main/scala/sbt/MainLoop.scala +++ b/main/src/main/scala/sbt/MainLoop.scala @@ -11,17 +11,16 @@ package sbt import java.io.PrintWriter import java.util.concurrent.RejectedExecutionException import java.util.Properties - -import sbt.BasicCommandStrings.{ StashOnFailure, networkExecPrefix } +import sbt.BasicCommandStrings.{StashOnFailure, networkExecPrefix} import sbt.internal.ShutdownHooks import sbt.internal.langserver.ErrorCodes import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.nio.CheckBuildSources.CheckBuildSourcesKey -import sbt.internal.util.{ ErrorHandling, GlobalLogBacking, Prompt, Terminal => ITerminal } -import sbt.internal.{ ShutdownHooks, TaskProgress } -import sbt.io.{ IO, Using } +import sbt.internal.util.{AttributeKey, ErrorHandling, GlobalLogBacking, Prompt, Terminal => ITerminal} +import sbt.internal.{ShutdownHooks, TaskProgress} +import sbt.io.{IO, Using} import sbt.protocol._ -import sbt.util.{ Logger, LoggerContext } +import sbt.util.{Logger, LoggerContext} import scala.annotation.tailrec import scala.concurrent.duration._ @@ -29,6 +28,8 @@ import scala.util.control.NonFatal import sbt.internal.FastTrackCommands import sbt.internal.SysProp +import java.text.ParseException + object MainLoop { /** Entry point to run the remaining commands in State with managed global logging.*/ @@ -212,16 +213,29 @@ object MainLoop { ) try { def process(): State = { - val progressState = state.get(sbt.Keys.currentTaskProgress) match { - case Some(_) => state - case _ => - if (state.get(Keys.stateBuildStructure).isDefined) { - val extracted = Project.extract(state) - val progress = EvaluateTask.executeProgress(extracted, extracted.structure, state) - state.put(sbt.Keys.currentTaskProgress, new Keys.TaskProgress(progress)) - } else state + def getOrSet[T](state: State, key: AttributeKey[T], value: Extracted => T): State = { + state.get(key) match { + case Some(_) => state + case _ => + if (state.get(Keys.stateBuildStructure).isDefined) { + val extracted = Project.extract(state) + state.put(key, value(extracted)) + } else state + } } - exchange.setState(progressState) + val progressState = getOrSet( + state, + sbt.Keys.currentTaskProgress, + extracted => + new Keys.TaskProgress( + EvaluateTask.executeProgress(extracted, extracted.structure, state) + ) + ) + + val cmdProgressState = + getOrSet(progressState, sbt.Keys.currentCommandProgress, _.get(Keys.commandProgress)) + + exchange.setState(cmdProgressState) exchange.setExec(Some(exec)) val (restoreTerminal, termState) = channelName.flatMap(exchange.channelForName) match { case Some(c) => @@ -231,9 +245,13 @@ object MainLoop { (() => { ITerminal.set(prevTerminal) c.terminal.flush() - }) -> progressState.put(Keys.terminalKey, Terminal(c.terminal)) - case _ => (() => ()) -> progressState.put(Keys.terminalKey, Terminal(ITerminal.get)) + }) -> cmdProgressState.put(Keys.terminalKey, Terminal(c.terminal)) + case _ => (() => ()) -> cmdProgressState.put(Keys.terminalKey, Terminal(ITerminal.get)) } + + val currentCmdProgress = + cmdProgressState.get(sbt.Keys.currentCommandProgress).getOrElse(Nil) + currentCmdProgress.foreach(_.beforeCommand(exec.commandLine, progressState)) /* * FastTrackCommands.evaluate can be significantly faster than Command.process because * it avoids an expensive parsing step for internal commands that are easy to parse. @@ -241,16 +259,29 @@ object MainLoop { * but slower. */ val newState = try { - FastTrackCommands + var errorMsg: Option[String] = None + val res = FastTrackCommands .evaluate(termState, exec.commandLine) - .getOrElse(Command.process(exec.commandLine, termState)) + .getOrElse(Command.process(exec.commandLine, termState, m => errorMsg = Some(m))) + errorMsg match { + case Some(msg) => + currentCmdProgress.foreach( + _.afterCommand(exec.commandLine, Left(new ParseException(msg, 0))) + ) + case None => currentCmdProgress.foreach(_.afterCommand(exec.commandLine, Right(res))) + } + res } catch { case _: RejectedExecutionException => - // No stack trace since this is just to notify the user which command they cancelled - object Cancelled extends Throwable(exec.commandLine, null, true, false) { - override def toString: String = s"Cancelled: ${exec.commandLine}" - } - throw Cancelled + val cancelled = new Cancelled(exec.commandLine) + currentCmdProgress + .foreach(_.afterCommand(exec.commandLine, Left(cancelled))) + throw cancelled + + case e: Throwable => + currentCmdProgress + .foreach(_.afterCommand(exec.commandLine, Left(e))) + throw e } finally { // Flush the terminal output after command evaluation to ensure that all output // is displayed in the thin client before we report the command status. Also @@ -270,7 +301,10 @@ object MainLoop { } exchange.setExec(None) newState.get(sbt.Keys.currentTaskProgress).foreach(_.progress.stop()) - newState.remove(sbt.Keys.currentTaskProgress).remove(Keys.terminalKey) + newState + .remove(sbt.Keys.currentTaskProgress) + .remove(Keys.terminalKey) + .remove(Keys.currentCommandProgress) } state.get(CheckBuildSourcesKey) match { case Some(cbs) => @@ -341,3 +375,8 @@ object MainLoop { ExitCode(ErrorCodes.UnknownError) } else ExitCode.Success } + +// No stack trace since this is just to notify the user which command they cancelled +class Cancelled(cmdLine: String) extends Throwable(cmdLine, null, true, false) { + override def toString: String = s"Cancelled: $cmdLine" +}