diff --git a/build.sbt b/build.sbt index 446206ea1..a43a65b2e 100644 --- a/build.sbt +++ b/build.sbt @@ -176,6 +176,8 @@ def mimaSettingsSince(versions: Seq[String]): Seq[Def.Setting[_]] = Def settings exclude[DirectMissingMethodProblem]("sbt.PluginData.apply"), exclude[DirectMissingMethodProblem]("sbt.PluginData.copy"), exclude[DirectMissingMethodProblem]("sbt.PluginData.this"), + exclude[IncompatibleResultTypeProblem]("sbt.EvaluateTask.executeProgress"), + exclude[DirectMissingMethodProblem]("sbt.Keys.currentTaskProgress"), exclude[IncompatibleResultTypeProblem]("sbt.PluginData.copy$default$10") ), ) @@ -189,11 +191,11 @@ lazy val sbtRoot: Project = (project in file(".")) minimalSettings, onLoadMessage := { val version = sys.props("java.specification.version") - """ __ __ + """ __ __ | _____/ /_ / /_ | / ___/ __ \/ __/ - | (__ ) /_/ / /_ - | /____/_.___/\__/ + | (__ ) /_/ / /_ + | /____/_.___/\__/ |Welcome to the build for sbt. |""".stripMargin + (if (version != "1.8") diff --git a/main-command/src/main/scala/sbt/Command.scala b/main-command/src/main/scala/sbt/Command.scala index f4a1ebc6f..15e910bb1 100644 --- a/main-command/src/main/scala/sbt/Command.scala +++ b/main-command/src/main/scala/sbt/Command.scala @@ -184,12 +184,17 @@ object Command { } ) - def process(command: String, state: State): State = { + // overload instead of default parameter to keep binary compatibility + @deprecated("Use overload that takes the onParseError callback", since = "1.9.4") + def process(command: String, state: State): State = process(command, state, _ => ()) + + def process(command: String, state: State, onParseError: String => Unit): State = { (if (command.contains(";")) parse(command, state.combinedParser) else parse(command, state.nonMultiParser)) match { case Right(s) => s() // apply command. command side effects happen here case Left(errMsg) => state.log error errMsg + onParseError(errMsg) state.fail } } diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index bf30d5f71..78d93f12a 100644 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -341,6 +341,7 @@ object Defaults extends BuildCommon { val rs = EvaluateTask.taskTimingProgress.toVector ++ EvaluateTask.taskTraceEvent.toVector rs map { Keys.TaskProgress(_) } }, + commandProgress := Seq(), // progressState is deprecated SettingKey[Option[ProgressState]]("progressState") := None, Previous.cache := new Previous( diff --git a/main/src/main/scala/sbt/EvaluateTask.scala b/main/src/main/scala/sbt/EvaluateTask.scala index eea3b357a..d6eb8a994 100644 --- a/main/src/main/scala/sbt/EvaluateTask.scala +++ b/main/src/main/scala/sbt/EvaluateTask.scala @@ -260,12 +260,15 @@ object EvaluateTask { extracted: Extracted, structure: BuildStructure, state: State - ): ExecuteProgress[Task] = { + ): ExecuteProgress2 = { state - .get(currentTaskProgress) - .map { tp => - new ExecuteProgress[Task] { - val progress = tp.progress + .get(currentCommandProgress) + .map { progress => + new ExecuteProgress2 { + override def beforeCommand(cmd: String, state: State): Unit = + progress.beforeCommand(cmd, state) + override def afterCommand(cmd: String, result: Either[Throwable, State]): Unit = + progress.afterCommand(cmd, result) override def initial(): Unit = progress.initial() override def afterRegistered( task: Task[_], @@ -281,7 +284,9 @@ object EvaluateTask { progress.afterCompleted(task, result) override def afterAllCompleted(results: RMap[Task, Result]): Unit = progress.afterAllCompleted(results) - override def stop(): Unit = {} + override def stop(): Unit = { + // TODO: this is not a typo, but a questionable decision in 6559c3a0 that is probably obsolete + } } } .getOrElse { @@ -295,11 +300,12 @@ object EvaluateTask { (if (SysProp.taskTimings) new TaskTimings(reportOnShutdown = false, state.globalLogging.full) :: Nil else Nil) - reporters match { - case xs if xs.isEmpty => ExecuteProgress.empty[Task] - case xs if xs.size == 1 => xs.head - case xs => ExecuteProgress.aggregate[Task](xs) - } + val cmdProgress = getSetting(Keys.commandProgress, Seq(), extracted, structure) + ExecuteProgress2.aggregate(reporters match { + case xs if xs.isEmpty => cmdProgress + case xs if xs.size == 1 => cmdProgress :+ new ExecuteProgressAdapter(xs.head) + case xs => cmdProgress :+ new ExecuteProgressAdapter(ExecuteProgress.aggregate[Task](xs)) + }) } } // TODO - Should this pull from Global or from the project itself? diff --git a/main/src/main/scala/sbt/ExecuteProgress2.scala b/main/src/main/scala/sbt/ExecuteProgress2.scala new file mode 100644 index 000000000..006b8f829 --- /dev/null +++ b/main/src/main/scala/sbt/ExecuteProgress2.scala @@ -0,0 +1,85 @@ +/* + * sbt + * Copyright 2023, Scala center + * Copyright 2011 - 2022, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt +import sbt.internal.util.RMap + +/** + * Tracks command execution progress. In addition to ExecuteProgress, this interface + * adds command start and end events, and gives access to the sbt.State at the beginning + * and end of each command. + * + * Command progress callbacks are wrapping task progress callbacks. That is, the `beforeCommand` + * callback will be called before the `initial` callback from ExecuteProgress, and the + * `afterCommand` callback will be called after the `stop` callback from ExecuteProgress. + */ +trait ExecuteProgress2 extends ExecuteProgress[Task] { + + /** + * Called before a command starts processing. The command has not yet been parsed. + * + * @param cmd The command string + * @param state The sbt.State before the command starts executing. + */ + def beforeCommand(cmd: String, state: State): Unit + + /** + * Called after a command finished execution. + * + * @param cmd The command string. + * @param result Left in case of an error. If the command cannot be parsed, it will be + * signalled as a ParseException with a detailed message. If the command + * was cancelled by the user, as sbt.Cancelled. If the command succeeded, + * Right with the new state after command execution. + * + */ + def afterCommand(cmd: String, result: Either[Throwable, State]): Unit +} + +class ExecuteProgressAdapter(ep: ExecuteProgress[Task]) extends ExecuteProgress2 { + override def beforeCommand(cmd: String, state: State): Unit = {} + override def afterCommand(cmd: String, result: Either[Throwable, State]): Unit = {} + override def initial(): Unit = ep.initial() + override def afterRegistered( + task: Task[_], + allDeps: Iterable[Task[_]], + pendingDeps: Iterable[Task[_]] + ): Unit = ep.afterRegistered(task, allDeps, pendingDeps) + override def afterReady(task: Task[_]): Unit = ep.afterReady(task) + override def beforeWork(task: Task[_]): Unit = ep.beforeWork(task) + override def afterWork[A](task: Task[A], result: Either[Task[A], Result[A]]): Unit = + ep.afterWork(task, result) + override def afterCompleted[A](task: Task[A], result: Result[A]): Unit = + ep.afterCompleted(task, result) + override def afterAllCompleted(results: RMap[Task, Result]): Unit = ep.afterAllCompleted(results) + override def stop(): Unit = ep.stop() +} + +object ExecuteProgress2 { + def aggregate(xs: Seq[ExecuteProgress2]): ExecuteProgress2 = new ExecuteProgress2 { + override def beforeCommand(cmd: String, state: State): Unit = + xs.foreach(_.beforeCommand(cmd, state)) + override def afterCommand(cmd: String, result: Either[Throwable, State]): Unit = + xs.foreach(_.afterCommand(cmd, result)) + override def initial(): Unit = xs.foreach(_.initial()) + override def afterRegistered( + task: Task[_], + allDeps: Iterable[Task[_]], + pendingDeps: Iterable[Task[_]] + ): Unit = xs.foreach(_.afterRegistered(task, allDeps, pendingDeps)) + override def afterReady(task: Task[_]): Unit = xs.foreach(_.afterReady(task)) + override def beforeWork(task: Task[_]): Unit = xs.foreach(_.beforeWork(task)) + override def afterWork[A](task: Task[A], result: Either[Task[A], Result[A]]): Unit = + xs.foreach(_.afterWork(task, result)) + override def afterCompleted[A](task: Task[A], result: Result[A]): Unit = + xs.foreach(_.afterCompleted(task, result)) + override def afterAllCompleted(results: RMap[Task, Result]): Unit = + xs.foreach(_.afterAllCompleted(results)) + override def stop(): Unit = xs.foreach(_.stop()) + } +} diff --git a/main/src/main/scala/sbt/Keys.scala b/main/src/main/scala/sbt/Keys.scala index 3720a3693..4d19d798d 100644 --- a/main/src/main/scala/sbt/Keys.scala +++ b/main/src/main/scala/sbt/Keys.scala @@ -601,7 +601,7 @@ object Keys { object TaskProgress { def apply(progress: ExecuteProgress[Task]): TaskProgress = new TaskProgress(progress) } - private[sbt] val currentTaskProgress = AttributeKey[TaskProgress]("current-task-progress") + private[sbt] val currentCommandProgress = AttributeKey[ExecuteProgress2]("current-command-progress") private[sbt] val taskProgress = AttributeKey[sbt.internal.TaskProgress]("active-task-progress") val useSuperShell = settingKey[Boolean]("Enables (true) or disables the super shell.") val superShellMaxTasks = settingKey[Int]("The max number of tasks to display in the supershell progress report") @@ -615,6 +615,7 @@ object Keys { private[sbt] val postProgressReports = settingKey[Unit]("Internally used to modify logger.").withRank(DTask) @deprecated("No longer used", "1.3.0") private[sbt] val executeProgress = settingKey[State => TaskProgress]("Experimental task execution listener.").withRank(DTask) + val commandProgress = settingKey[Seq[ExecuteProgress2]]("Command progress listeners receive events when commands start and end, in addition to task progress events.") val lintUnused = inputKey[Unit]("Check for keys unused by other settings and tasks.") val lintIncludeFilter = settingKey[String => Boolean]("Filters key names that should be included in the lint check.") val lintExcludeFilter = settingKey[String => Boolean]("Filters key names that should be excluded in the lint check.") diff --git a/main/src/main/scala/sbt/MainLoop.scala b/main/src/main/scala/sbt/MainLoop.scala index 3c7a800a6..aaa1d98ec 100644 --- a/main/src/main/scala/sbt/MainLoop.scala +++ b/main/src/main/scala/sbt/MainLoop.scala @@ -11,13 +11,18 @@ package sbt import java.io.PrintWriter import java.util.concurrent.RejectedExecutionException import java.util.Properties - import sbt.BasicCommandStrings.{ StashOnFailure, networkExecPrefix } import sbt.internal.ShutdownHooks import sbt.internal.langserver.ErrorCodes import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.nio.CheckBuildSources.CheckBuildSourcesKey -import sbt.internal.util.{ ErrorHandling, GlobalLogBacking, Prompt, Terminal => ITerminal } +import sbt.internal.util.{ + AttributeKey, + ErrorHandling, + GlobalLogBacking, + Prompt, + Terminal => ITerminal +} import sbt.internal.{ ShutdownHooks, TaskProgress } import sbt.io.{ IO, Using } import sbt.protocol._ @@ -29,6 +34,8 @@ import scala.util.control.NonFatal import sbt.internal.FastTrackCommands import sbt.internal.SysProp +import java.text.ParseException + object MainLoop { /** Entry point to run the remaining commands in State with managed global logging.*/ @@ -212,16 +219,25 @@ object MainLoop { ) try { def process(): State = { - val progressState = state.get(sbt.Keys.currentTaskProgress) match { - case Some(_) => state - case _ => - if (state.get(Keys.stateBuildStructure).isDefined) { - val extracted = Project.extract(state) - val progress = EvaluateTask.executeProgress(extracted, extracted.structure, state) - state.put(sbt.Keys.currentTaskProgress, new Keys.TaskProgress(progress)) - } else state + def getOrSet[T](state: State, key: AttributeKey[T], value: Extracted => T): State = { + state.get(key) match { + case Some(_) => state + case _ => + if (state.get(Keys.stateBuildStructure).isDefined) { + val extracted = Project.extract(state) + state.put(key, value(extracted)) + } else state + } } - exchange.setState(progressState) + + val cmdProgressState = + getOrSet( + state, + sbt.Keys.currentCommandProgress, + extracted => EvaluateTask.executeProgress(extracted, extracted.structure, state) + ) + + exchange.setState(cmdProgressState) exchange.setExec(Some(exec)) val (restoreTerminal, termState) = channelName.flatMap(exchange.channelForName) match { case Some(c) => @@ -231,9 +247,13 @@ object MainLoop { (() => { ITerminal.set(prevTerminal) c.terminal.flush() - }) -> progressState.put(Keys.terminalKey, Terminal(c.terminal)) - case _ => (() => ()) -> progressState.put(Keys.terminalKey, Terminal(ITerminal.get)) + }) -> cmdProgressState.put(Keys.terminalKey, Terminal(c.terminal)) + case _ => (() => ()) -> cmdProgressState.put(Keys.terminalKey, Terminal(ITerminal.get)) } + + val currentCmdProgress = + cmdProgressState.get(sbt.Keys.currentCommandProgress) + currentCmdProgress.foreach(_.beforeCommand(exec.commandLine, cmdProgressState)) /* * FastTrackCommands.evaluate can be significantly faster than Command.process because * it avoids an expensive parsing step for internal commands that are easy to parse. @@ -241,16 +261,29 @@ object MainLoop { * but slower. */ val newState = try { - FastTrackCommands + var errorMsg: Option[String] = None + val res = FastTrackCommands .evaluate(termState, exec.commandLine) - .getOrElse(Command.process(exec.commandLine, termState)) + .getOrElse(Command.process(exec.commandLine, termState, m => errorMsg = Some(m))) + errorMsg match { + case Some(msg) => + currentCmdProgress.foreach( + _.afterCommand(exec.commandLine, Left(new ParseException(msg, 0))) + ) + case None => currentCmdProgress.foreach(_.afterCommand(exec.commandLine, Right(res))) + } + res } catch { case _: RejectedExecutionException => - // No stack trace since this is just to notify the user which command they cancelled - object Cancelled extends Throwable(exec.commandLine, null, true, false) { - override def toString: String = s"Cancelled: ${exec.commandLine}" - } - throw Cancelled + val cancelled = new Cancelled(exec.commandLine) + currentCmdProgress + .foreach(_.afterCommand(exec.commandLine, Left(cancelled))) + throw cancelled + + case e: Throwable => + currentCmdProgress + .foreach(_.afterCommand(exec.commandLine, Left(e))) + throw e } finally { // Flush the terminal output after command evaluation to ensure that all output // is displayed in the thin client before we report the command status. Also @@ -269,8 +302,10 @@ object MainLoop { exchange.respondStatus(doneEvent) } exchange.setExec(None) - newState.get(sbt.Keys.currentTaskProgress).foreach(_.progress.stop()) - newState.remove(sbt.Keys.currentTaskProgress).remove(Keys.terminalKey) + newState.get(sbt.Keys.currentCommandProgress).foreach(_.stop()) + newState + .remove(Keys.terminalKey) + .remove(Keys.currentCommandProgress) } state.get(CheckBuildSourcesKey) match { case Some(cbs) => @@ -341,3 +376,8 @@ object MainLoop { ExitCode(ErrorCodes.UnknownError) } else ExitCode.Success } + +// No stack trace since this is just to notify the user which command they cancelled +class Cancelled(cmdLine: String) extends Throwable(cmdLine, null, true, false) { + override def toString: String = s"Cancelled: $cmdLine" +} diff --git a/sbt-app/src/sbt-test/reporter/command-progress/build.sbt b/sbt-app/src/sbt-test/reporter/command-progress/build.sbt new file mode 100644 index 000000000..08f7a7492 --- /dev/null +++ b/sbt-app/src/sbt-test/reporter/command-progress/build.sbt @@ -0,0 +1,39 @@ +commandProgress += new ExecuteProgress2 { + override def beforeCommand(cmd: String, state: State): Unit = { + EventLog.logs += (s"BEFORE: $cmd") + // assert that `state` is the current state indeed + assert(state.currentCommand.isDefined) + assert(state.currentCommand.get.commandLine == cmd) + } + override def afterCommand(cmd: String, result: Either[Throwable, State]): Unit = { + EventLog.logs += (s"AFTER: $cmd") + result.left.foreach(EventLog.errors +=) + } + override def initial(): Unit = {} + override def afterRegistered( + task: Task[_], + allDeps: Iterable[Task[_]], + pendingDeps: Iterable[Task[_]] + ): Unit = {} + override def afterReady(task: Task[_]): Unit = {} + override def beforeWork(task: Task[_]): Unit = {} + override def afterWork[A](task: Task[A], result: Either[Task[A], Result[A]]): Unit = {} + override def afterCompleted[A](task: Task[A], result: Result[A]): Unit = {} + override def afterAllCompleted(results: RMap[Task, Result]): Unit = {} + override def stop(): Unit = {} +} + +val check = taskKey[Unit]("Check basic command events") +val checkParseError = taskKey[Unit]("Check that parse error is recorded") + +check := { + def hasEvent(cmd: String): Boolean = + EventLog.logs.contains(s"BEFORE: $cmd") && EventLog.logs.contains(s"AFTER: $cmd") + + assert(hasEvent("compile"), "Missing command event `compile`") + assert(hasEvent("+compile"), "Missing command event `+compile`") +} + +checkParseError := { + assert(EventLog.errors.exists(_.getMessage.toLowerCase.contains("not a valid command"))) +} diff --git a/sbt-app/src/sbt-test/reporter/command-progress/project/EventLog.scala b/sbt-app/src/sbt-test/reporter/command-progress/project/EventLog.scala new file mode 100644 index 000000000..88b00c491 --- /dev/null +++ b/sbt-app/src/sbt-test/reporter/command-progress/project/EventLog.scala @@ -0,0 +1,8 @@ +import scala.collection.mutable.ListBuffer + +object EventLog { + val logs = ListBuffer.empty[String] + val errors = ListBuffer.empty[Throwable] + + def reset(): Unit = logs.clear() +} diff --git a/sbt-app/src/sbt-test/reporter/command-progress/test b/sbt-app/src/sbt-test/reporter/command-progress/test new file mode 100644 index 000000000..9011065e1 --- /dev/null +++ b/sbt-app/src/sbt-test/reporter/command-progress/test @@ -0,0 +1,5 @@ +> compile +> +compile +> check +-> asdf +> checkParseError