diff --git a/main-command/src/main/scala/sbt/Watched.scala b/main-command/src/main/scala/sbt/Watched.scala index bd31f11e5..b501a1f36 100644 --- a/main-command/src/main/scala/sbt/Watched.scala +++ b/main-command/src/main/scala/sbt/Watched.scala @@ -10,7 +10,12 @@ package sbt import java.io.File import java.nio.file.{ FileSystems, Path } -import sbt.BasicCommandStrings.{ ContinuousExecutePrefix, continuousBriefHelp, continuousDetail } +import sbt.BasicCommandStrings.{ + ContinuousExecutePrefix, + FailureWall, + continuousBriefHelp, + continuousDetail +} import sbt.BasicCommands.otherCommandParser import sbt.internal.LegacyWatched import sbt.internal.io.{ EventMonitor, Source, WatchState } @@ -69,12 +74,24 @@ object Watched { */ case object CancelWatch extends Action + /** + * Action that indicates that an error has occurred. The watch will be terminated when this action + * is produced. + */ + case object HandleError extends Action + /** * Action that indicates that the watch should continue as though nothing happened. This may be * because, for example, no user input was yet available in [[WatchConfig.handleInput]]. */ case object Ignore extends Action + /** + * Action that indicates that the watch should pause while the build is reloaded. This is used to + * automatically reload the project when the build files (e.g. build.sbt) are changed. + */ + case object Reload extends Action + /** * Action that indicates that the watch process should re-run the command. */ @@ -172,6 +189,22 @@ object Watched { Watched.executeContinuously(state, command, setup) } + /** + * Default handler to transform the state when the watch terminates. When the [[Watched.Action]] is + * [[Reload]], the handler will prepend the original command (prefixed by ~) to the + * [[State.remainingCommands]] and then invoke the [[StateOps.reload]] method. When the + * [[Watched.Action]] is [[HandleError]], the handler returns the result of [[StateOps.fail]]. Otherwise + * the original state is returned. + */ + private[sbt] val onTermination: (Action, String, State) => State = (action, command, state) => + action match { + case Reload => + val continuousCommand = Exec(ContinuousExecutePrefix + command, None) + state.copy(remainingCommands = continuousCommand +: state.remainingCommands).reload + case HandleError => state.fail + case _ => state + } + /** * Implements continuous execution. It works by first parsing the command and generating a task to * run with each build. It can run multiple commands that are separated by ";" in the command @@ -190,7 +223,13 @@ object Watched { command: String, setup: WatchSetup, ): State = { - val (s, config, newState) = setup(state, command) + val (s0, config, newState) = setup(state, command) + val failureCommandName = "SbtContinuousWatchOnFail" + val onFail = Command.command(failureCommandName)(identity) + val s = (FailureWall :: s0).copy( + onFailure = Some(Exec(failureCommandName, None)), + definedCommands = s0.definedCommands :+ onFail + ) val commands = command.split(";") match { case Array("", rest @ _*) => rest case Array(cmd) => Seq(cmd) @@ -202,8 +241,7 @@ object Watched { case Right(task) => Right { () => try { - newState(task) - Right(true) + Right(newState(task).remainingCommands.forall(_.commandLine != failureCommandName)) } catch { case e: Exception => Left(e) } } case Left(_) => Left(cmd) @@ -216,8 +254,8 @@ object Watched { case (status, Right(t)) => if (status.getOrElse(true)) t() else status case _ => throw new IllegalStateException("Should be unreachable") } - watch(task, config) - state + val terminationAction = watch(task, config) + config.onWatchTerminated(terminationAction, command, state) } else { config.logger.error( s"Terminating watch due to invalid command(s): ${invalid.mkString("'", "', '", "'")}" @@ -227,19 +265,19 @@ object Watched { } private[sbt] def watch( - task: () => Either[Exception, _], - config: WatchConfig, - ): Unit = { + task: () => Either[Exception, Boolean], + config: WatchConfig + ): Action = { val logger = config.logger def info(msg: String): Unit = if (msg.nonEmpty) logger.info(msg) @tailrec - def impl(count: Int): Unit = { + def impl(count: Int): Action = { @tailrec def nextAction(): Action = { config.handleInput() match { - case CancelWatch => CancelWatch - case Trigger => Trigger + case action @ (CancelWatch | HandleError | Reload) => action + case Trigger => Trigger case _ => val events = config.fileEventMonitor.poll(10.millis) val next = events match { @@ -248,42 +286,53 @@ object Watched { /* * We traverse all of the events and find the one for which we give the highest * weight. - * CancelWatch > Trigger > Ignore + * HandleError > CancelWatch > Reload > Trigger > Ignore */ tail.foldLeft((config.onWatchEvent(head), Some(head))) { - case (r @ (CancelWatch, _), _) => r - // If we've found a trigger, only change the accumulator if we find a CancelWatch. - case ((action, event), e) => - config.onWatchEvent(e) match { - case Trigger if action == Ignore => (Trigger, Some(e)) - case _ => (action, event) + case (current @ (action, _), event) => + config.onWatchEvent(event) match { + case HandleError => (HandleError, Some(event)) + case CancelWatch if action != HandleError => (CancelWatch, Some(event)) + case Reload if action != HandleError && action != CancelWatch => + (Reload, Some(event)) + case Trigger if action == Ignore => (Trigger, Some(event)) + case _ => current } } } + // Note that nextAction should never return Ignore. next match { - case (CancelWatch, Some(event)) => - logger.debug(s"Stopping watch due to event from ${event.entry.typedPath.getPath}.") - CancelWatch + case (action @ (HandleError | CancelWatch), Some(event)) => + val cause = if (action == HandleError) "error" else "cancellation" + logger.debug(s"Stopping watch due to $cause from ${event.entry.typedPath.getPath}") + action case (Trigger, Some(event)) => logger.debug(s"Triggered by ${event.entry.typedPath.getPath}") config.triggeredMessage(event.entry.typedPath, count).foreach(info) Trigger + case (Reload, Some(event)) => + logger.info(s"Reload triggered by ${event.entry.typedPath.getPath}") + Reload case _ => nextAction() } } } task() match { - case Right(status) if !config.shouldTerminate(count) => - config.watchingMessage(count).foreach(info) - nextAction() match { - case CancelWatch => () - case _ => impl(count + 1) + case Right(status) => + config.preWatch(count, status) match { + case Ignore => + config.watchingMessage(count).foreach(info) + nextAction() match { + case action @ (CancelWatch | HandleError | Reload) => action + case _ => impl(count + 1) + } + case Trigger => impl(count + 1) + case action @ (CancelWatch | HandleError | Reload) => action } case Left(e) => logger.error(s"Terminating watch due to Unexpected error: $e") - case _ => - logger.debug("Terminating watch due to WatchConfig.shouldTerminate") + HandleError } } try { @@ -365,10 +414,11 @@ trait WatchConfig { /** * This is run before each watch iteration and if it returns true, the watch is terminated. - * @param count The current number of watch iterstaions. - * @return true if the watch should stop. + * @param count The current number of watch iterations. + * @param lastStatus true if the previous task execution completed successfully + * @return the Action to apply */ - def shouldTerminate(count: Int): Boolean + def preWatch(count: Int, lastStatus: Boolean): Watched.Action /** * Callback that is invoked whenever a file system vent is detected. The next step of the watch @@ -378,6 +428,15 @@ trait WatchConfig { */ def onWatchEvent(event: Event[Path]): Watched.Action + /** + * Transforms the state after the watch terminates. + * @param action the [[Watched.Action Action]] that caused the build to terminate + * @param command the command that the watch was repeating + * @param state the initial state prior to the start of continuous execution + * @return the updated state. + */ + def onWatchTerminated(action: Watched.Action, command: String, state: State): State + /** * The optional message to log when a build is triggered. * @param typedPath the path that triggered the build @@ -400,13 +459,20 @@ trait WatchConfig { object WatchConfig { /** - * Create an instance of [[WatchConfig]]. + * Create an instance of [[WatchConfig]]. * @param logger logger for watch events * @param fileEventMonitor the monitor for file system events. * @param handleInput callback that is periodically invoked to check whether to continue or * terminate the watch based on user input. It is also possible to, for * example time out the watch using this callback. + * @param preWatch callback to invoke before waiting for updates from the sbt.io.FileEventMonitor. + * The input parameters are the current iteration count and whether or not + * the last invocation of the command was successful. Typical uses would be to + * terminate the watch after a fixed number of iterations or to terminate the + * watch if the command was unsuccessful. * @param onWatchEvent callback that is invoked when + * @param onWatchTerminated callback that is invoked to update the state after the watch + * terminates. * @param triggeredMessage optional message that will be logged when a new build is triggered. * The input parameters are the sbt.io.TypedPath that triggered the new * build and the current iteration count. @@ -418,25 +484,29 @@ object WatchConfig { logger: Logger, fileEventMonitor: FileEventMonitor[Path], handleInput: () => Watched.Action, - shouldTerminate: Int => Boolean, + preWatch: (Int, Boolean) => Watched.Action, onWatchEvent: Event[Path] => Watched.Action, + onWatchTerminated: (Watched.Action, String, State) => State, triggeredMessage: (TypedPath, Int) => Option[String], watchingMessage: Int => Option[String] ): WatchConfig = { val l = logger val fem = fileEventMonitor val hi = handleInput - val st = shouldTerminate + val pw = preWatch val owe = onWatchEvent + val owt = onWatchTerminated val tm = triggeredMessage val wm = watchingMessage new WatchConfig { override def logger: Logger = l override def fileEventMonitor: FileEventMonitor[Path] = fem override def handleInput(): Watched.Action = hi() - override def shouldTerminate(count: Int): Boolean = - st(count) + override def preWatch(count: Int, lastResult: Boolean): Watched.Action = + pw(count, lastResult) override def onWatchEvent(event: Event[Path]): Watched.Action = owe(event) + override def onWatchTerminated(action: Watched.Action, command: String, state: State): State = + owt(action, command, state) override def triggeredMessage(typedPath: TypedPath, count: Int): Option[String] = tm(typedPath, count) override def watchingMessage(count: Int): Option[String] = wm(count) diff --git a/main-command/src/test/scala/sbt/WatchedSpec.scala b/main-command/src/test/scala/sbt/WatchedSpec.scala index 1ea7f36f7..612d232b3 100644 --- a/main-command/src/test/scala/sbt/WatchedSpec.scala +++ b/main-command/src/test/scala/sbt/WatchedSpec.scala @@ -29,7 +29,7 @@ class WatchedSpec extends FlatSpec with Matchers { fileEventMonitor: Option[FileEventMonitor[Path]] = None, logger: Logger = NullLogger, handleInput: () => Action = () => Ignore, - shouldTerminate: Int => Boolean = _ => true, + preWatch: (Int, Boolean) => Action = (_, _) => CancelWatch, onWatchEvent: Event[Path] => Action = _ => Ignore, triggeredMessage: (TypedPath, Int) => Option[String] = (_, _) => None, watchingMessage: Int => Option[String] = _ => None @@ -41,8 +41,9 @@ class WatchedSpec extends FlatSpec with Matchers { logger = logger, monitor, handleInput, - shouldTerminate, + preWatch, onWatchEvent, + (_, _, state) => state, triggeredMessage, watchingMessage ) @@ -50,19 +51,19 @@ class WatchedSpec extends FlatSpec with Matchers { } "Watched.watch" should "stop" in IO.withTemporaryDirectory { dir => val config = Defaults.config(sources = Seq(WatchSource(dir.toRealPath))) - Watched.watch(() => Right(true), config) should be(()) + Watched.watch(() => Right(true), config) shouldBe CancelWatch } it should "trigger" in IO.withTemporaryDirectory { dir => val triggered = new AtomicBoolean(false) val config = Defaults.config( sources = Seq(WatchSource(dir.toRealPath)), - shouldTerminate = count => count == 2, + preWatch = (count, _) => if (count == 2) CancelWatch else Ignore, onWatchEvent = _ => { triggered.set(true); Trigger }, watchingMessage = _ => { new File(dir, "file").createNewFile; None } ) - Watched.watch(() => Right(true), config) should be(()) + Watched.watch(() => Right(true), config) shouldBe CancelWatch assert(triggered.get()) } it should "filter events" in IO.withTemporaryDirectory { dir => @@ -72,12 +73,12 @@ class WatchedSpec extends FlatSpec with Matchers { val bar = realDir.toPath.resolve("bar") val config = Defaults.config( sources = Seq(WatchSource(realDir)), - shouldTerminate = count => count == 2, + preWatch = (count, _) => if (count == 2) CancelWatch else Ignore, onWatchEvent = e => if (e.entry.typedPath.getPath == foo) Trigger else Ignore, triggeredMessage = (tp, _) => { queue += tp; None }, watchingMessage = _ => { Files.createFile(bar); Thread.sleep(5); Files.createFile(foo); None } ) - Watched.watch(() => Right(true), config) should be(()) + Watched.watch(() => Right(true), config) shouldBe CancelWatch queue.toIndexedSeq.map(_.getPath) shouldBe Seq(foo) } it should "enforce anti-entropy" in IO.withTemporaryDirectory { dir => @@ -87,21 +88,41 @@ class WatchedSpec extends FlatSpec with Matchers { val bar = realDir.toPath.resolve("bar") val config = Defaults.config( sources = Seq(WatchSource(realDir)), - shouldTerminate = count => count == 3, + preWatch = (count, _) => if (count == 3) CancelWatch else Ignore, onWatchEvent = _ => Trigger, triggeredMessage = (tp, _) => { queue += tp; None }, watchingMessage = count => { - if (count == 1) Files.createFile(bar) - else if (count == 2) { - bar.toFile.setLastModified(5000) - Files.createFile(foo) + count match { + case 1 => Files.createFile(bar) + case 2 => + bar.toFile.setLastModified(5000) + Files.createFile(foo) + case _ => } None } ) - Watched.watch(() => Right(true), config) should be(()) + Watched.watch(() => Right(true), config) shouldBe CancelWatch queue.toIndexedSeq.map(_.getPath) shouldBe Seq(bar, foo) } + it should "halt on error" in IO.withTemporaryDirectory { dir => + val halted = new AtomicBoolean(false) + val config = Defaults.config( + sources = Seq(WatchSource(dir.toRealPath)), + preWatch = (_, lastStatus) => if (lastStatus) Ignore else { halted.set(true); HandleError } + ) + Watched.watch(() => Right(false), config) shouldBe HandleError + assert(halted.get()) + } + it should "reload" in IO.withTemporaryDirectory { dir => + val config = Defaults.config( + sources = Seq(WatchSource(dir.toRealPath)), + preWatch = (_, _) => Ignore, + onWatchEvent = _ => Reload, + watchingMessage = _ => { new File(dir, "file").createNewFile(); None } + ) + Watched.watch(() => Right(true), config) shouldBe Reload + } } object WatchedSpec { diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index 8e41ada1d..76fbb59a8 100755 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -250,6 +250,7 @@ object Defaults extends BuildCommon { Nil }, watchSources :== Nil, + watchProjectSources :== Nil, skip :== false, taskTemporaryDirectory := { val dir = IO.createTemporaryDirectory; dir.deleteOnExit(); dir }, onComplete := { @@ -384,6 +385,13 @@ object Defaults extends BuildCommon { else Nil bases.map(b => new Source(b, include, exclude)) ++ baseSources }, + watchProjectSources in ConfigGlobal := (watchProjectSources in ConfigGlobal).value ++ { + val baseDir = baseDirectory.value + Seq( + new Source(baseDir, "*.sbt", HiddenFileFilter, recursive = false), + new Source(baseDir / "project", "*.sbt" || "*.scala", HiddenFileFilter, recursive = true) + ) + }, managedSourceDirectories := Seq(sourceManaged.value), managedSources := generate(sourceGenerators).value, sourceGenerators :== Nil, @@ -606,17 +614,22 @@ object Defaults extends BuildCommon { clean := (Def.task { IO.delete(cleanFiles.value) } tag (Tags.Clean)).value, consoleProject := consoleProjectTask.value, watchTransitiveSources := watchTransitiveSourcesTask.value, + watchProjectTransitiveSources := watchTransitiveSourcesTaskImpl(watchProjectSources).value, watchOnEvent := { val sources = watchTransitiveSources.value + val projectSources = watchProjectTransitiveSources.value e => - if (sources.exists(_.accept(e.entry.typedPath.getPath))) Watched.Trigger else Watched.Ignore + if (sources.exists(_.accept(e.entry.typedPath.getPath))) Watched.Trigger + else if (projectSources.exists(_.accept(e.entry.typedPath.getPath))) Watched.Reload + else Watched.Ignore }, watchHandleInput := Watched.handleInput, - watchShouldTerminate := { _ => - false + watchPreWatch := { (_, _) => + Watched.Ignore }, + watchOnTermination := Watched.onTermination, watchConfig := { - val sources = watchTransitiveSources.value + val sources = watchTransitiveSources.value ++ watchProjectTransitiveSources.value val extracted = Project.extract(state.value) val wm = extracted .getOpt(watchingMessage) @@ -634,8 +647,9 @@ object Defaults extends BuildCommon { logger, viewConfig.newMonitor(viewConfig.newDataView(), sources, logger), watchHandleInput.value, - watchShouldTerminate.value, + watchPreWatch.value, watchOnEvent.value, + watchOnTermination.value, tm, wm ) @@ -648,10 +662,15 @@ object Defaults extends BuildCommon { def generate(generators: SettingKey[Seq[Task[Seq[File]]]]): Initialize[Task[Seq[File]]] = generators { _.join.map(_.flatten) } - def watchTransitiveSourcesTask: Initialize[Task[Seq[Source]]] = { + def watchTransitiveSourcesTask: Initialize[Task[Seq[Source]]] = + watchTransitiveSourcesTaskImpl(watchSources) + + private def watchTransitiveSourcesTaskImpl( + key: TaskKey[Seq[Source]] + ): Initialize[Task[Seq[Source]]] = { import ScopeFilter.Make.{ inDependencies => inDeps, _ } val selectDeps = ScopeFilter(inAggregates(ThisProject) || inDeps(ThisProject)) - val allWatched = (watchSources ?? Nil).all(selectDeps) + val allWatched = (key ?? Nil).all(selectDeps) Def.task { allWatched.value.flatten } } diff --git a/main/src/main/scala/sbt/Keys.scala b/main/src/main/scala/sbt/Keys.scala index 66ffb8808..d2875ca9c 100644 --- a/main/src/main/scala/sbt/Keys.scala +++ b/main/src/main/scala/sbt/Keys.scala @@ -152,8 +152,11 @@ object Keys { val watchLogger = taskKey[Logger]("A logger that reports watch events.").withRank(DSetting) val watchHandleInput = settingKey[() => Watched.Action]("Function that is periodically invoked to determine if the continous build should be stopped or if a build should be triggered. It will usually read from stdin to respond to user commands.").withRank(BMinusSetting) val watchOnEvent = taskKey[Event[JPath] => Watched.Action]("Determines how to handle a file event").withRank(BMinusSetting) + val watchOnTermination = taskKey[(Watched.Action, String, State) => State]("Transforms the input state after the continuous build completes.").withRank(BMinusSetting) val watchService = settingKey[() => WatchService]("Service to use to monitor file system changes.").withRank(BMinusSetting) - val watchShouldTerminate = settingKey[Int => Boolean]("Function that may terminate a continuous build based on the number of iterations.").withRank(BMinusSetting) + val watchProjectSources = taskKey[Seq[Watched.WatchSource]]("Defines the sources for the sbt meta project to watch to trigger a reload.").withRank(CSetting) + val watchProjectTransitiveSources = taskKey[Seq[Watched.WatchSource]]("Defines the sources in all projects for the sbt meta project to watch to trigger a reload.").withRank(CSetting) + val watchPreWatch = settingKey[(Int, Boolean) => Watched.Action]("Function that may terminate a continuous build based on the number of iterations and the last result").withRank(BMinusSetting) val watchSources = taskKey[Seq[Watched.WatchSource]]("Defines the sources in this project for continuous execution to watch for changes.").withRank(BMinusSetting) val watchStartMessage = settingKey[Int => Option[String]]("The message to show when triggered execution waits for sources to change. The parameter is the current watch iteration count.").withRank(DSetting) val watchTransitiveSources = taskKey[Seq[Watched.WatchSource]]("Defines the sources in all projects for continuous execution to watch.").withRank(CSetting) diff --git a/sbt/src/sbt-test/watch/on-start-watch/build.sbt b/sbt/src/sbt-test/watch/on-start-watch/build.sbt index d992f3473..1c6dab6c1 100644 --- a/sbt/src/sbt-test/watch/on-start-watch/build.sbt +++ b/sbt/src/sbt-test/watch/on-start-watch/build.sbt @@ -1,10 +1,50 @@ +import scala.util.Try + val checkCount = inputKey[Unit]("check that compile has run a specified number of times") +val checkReloadCount = inputKey[Unit]("check whether the project was reloaded") +val failingTask = taskKey[Unit]("should always fail") +val maybeReload = settingKey[(Int, Boolean) => Watched.Action]("possibly reload") +val resetCount = taskKey[Unit]("reset compile count") +val reloadFile = settingKey[File]("get the current reload file") checkCount := { val expected = Def.spaceDelimited().parsed.head.toInt - assert(Count.get == expected) + if (Count.get != expected) + throw new IllegalStateException(s"Expected ${expected} compilation runs, got ${Count.get}") } +maybeReload := { (_, _) => + if (Count.reloadCount(reloadFile.value) == 0) Watched.Reload else Watched.CancelWatch +} + +reloadFile := baseDirectory.value / "reload-count" + +resetCount := { + Count.reset() +} + +failingTask := { + throw new IllegalStateException("failed") +} + +watchPreWatch := maybeReload.value + +checkReloadCount := { + val expected = Def.spaceDelimited().parsed.head.toInt + assert(Count.reloadCount(reloadFile.value) == expected) +} + +val addReloadShutdownHook = Command.command("addReloadShutdownHook") { state => + state.addExitHook { + val base = Project.extract(state).get(baseDirectory) + val file = base / "reload-count" + val currentCount = Try(Count.reloadCount(file)).getOrElse(0) + IO.write(file, s"${currentCount + 1}".getBytes) + } +} + +commands += addReloadShutdownHook + Compile / compile := { Count.increment() // Trigger a new build by updating the last modified time diff --git a/sbt/src/sbt-test/watch/on-start-watch/project/Count.scala b/sbt/src/sbt-test/watch/on-start-watch/project/Count.scala index 0698b75ff..67d3bf940 100644 --- a/sbt/src/sbt-test/watch/on-start-watch/project/Count.scala +++ b/sbt/src/sbt-test/watch/on-start-watch/project/Count.scala @@ -1,6 +1,10 @@ +import sbt._ +import scala.util.Try + object Count { private var count = 0 def get: Int = count def increment(): Unit = count += 1 def reset(): Unit = count = 0 -} \ No newline at end of file + def reloadCount(file: File): Int = Try(IO.read(file).toInt).getOrElse(0) +} diff --git a/sbt/src/sbt-test/watch/on-start-watch/test b/sbt/src/sbt-test/watch/on-start-watch/test index f5fa900e7..37781fce3 100644 --- a/sbt/src/sbt-test/watch/on-start-watch/test +++ b/sbt/src/sbt-test/watch/on-start-watch/test @@ -1,4 +1,28 @@ +# verify that reloading occurs if watchPreWatch returns Watched.Reload +> addReloadShutdownHook +> checkReloadCount 0 +> ~compile +> checkReloadCount 1 + # verify that the watch terminates when we reach the specified count -> set watchShouldTerminate := { count => count == 2 } +> resetCount +> set watchPreWatch := { (count: Int, _) => if (count == 2) Watched.CancelWatch else Watched.Ignore } > ~compile > checkCount 2 + +# verify that the watch terminates and returns an error when we reach the specified count +> resetCount +> set watchPreWatch := { (count: Int, _) => if (count == 2) Watched.HandleError else Watched.Ignore } +# Returning Watched.HandleError causes the '~' command to fail +-> ~compile +> checkCount 2 + +# verify that a re-build is triggered when we reach the specified count +> resetCount +> set watchPreWatch := { (count: Int, _) => if (count == 2) Watched.Trigger else if (count == 3) Watched.CancelWatch else Watched.Ignore } +> ~compile +> checkCount 3 + +# verify that the watch exits and returns an error if the task fails +> set watchPreWatch := { (_, lastStatus: Boolean) => if (lastStatus) Watched.Ignore else Watched.HandleError } +-> ~failingTask