diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index 763d6cef6..9184670bf 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -971,22 +971,23 @@ object BuiltinCommands { } private[sbt] def waitCmd: Command = - Command.arb(_ => (ContinuousCommands.waitWatch: Parser[String]).examples()) { (s0, _) => + Command.arb( + _ => ContinuousCommands.waitWatch.examples() ~> " ".examples() ~> matched(any.*).examples() + ) { (s0, channel) => val exchange = StandardMain.exchange - if (exchange.channels.exists(ContinuousCommands.isInWatch)) { - val s1 = exchange.run(s0) - exchange.channels.foreach { - case c if ContinuousCommands.isPending(c) => - case c => c.prompt(ConsolePromptEvent(s1)) - } - val exec: Exec = getExec(s1, Duration.Inf) - val remaining: List[Exec] = - Exec(ContinuousCommands.waitWatch, None) :: - Exec(FailureWall, None) :: s1.remainingCommands - val newState = s1.copy(remainingCommands = exec +: remaining) - if (exec.commandLine.trim.isEmpty) newState - else newState.clearGlobalLog - } else s0 + exchange.channelForName(channel) match { + case Some(c) if ContinuousCommands.isInWatch(s0, c) => + c.prompt(ConsolePromptEvent(s0)) + val s1 = exchange.run(s0) + val exec: Exec = getExec(s1, Duration.Inf) + val remaining: List[Exec] = + Exec(s"${ContinuousCommands.waitWatch} $channel", None) :: + Exec(FailureWall, None) :: s1.remainingCommands + val newState = s1.copy(remainingCommands = exec +: remaining) + if (exec.commandLine.trim.isEmpty) newState + else newState.clearGlobalLog + case _ => s0 + } } private[sbt] def promptChannel = Command.arb(_ => reportParser(PromptChannel)) { diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index d4c6bbf27..3a5eaea79 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -170,17 +170,15 @@ private[sbt] final class CommandExchange { currentExec.filter(_.source.map(_.channelName) == Some(c.name)).foreach { e => Util.ignoreResult(NetworkChannel.cancel(e.execId, e.execId.getOrElse("0"))) } - if (ContinuousCommands.isInWatch(c)) { - try commandQueue.put(Exec(s"${ContinuousCommands.stopWatch} ${c.name}", None)) - catch { case _: InterruptedException => } - } + try commandQueue.put(Exec(s"${ContinuousCommands.stopWatch} ${c.name}", None)) + catch { case _: InterruptedException => } } private[this] def mkAskUser( name: String, ): (State, CommandChannel) => UITask = { (state, channel) => ContinuousCommands - .watchUITaskFor(channel) + .watchUITaskFor(state, channel) .getOrElse(new UITask.AskUserTask(state, channel)) } @@ -353,8 +351,8 @@ private[sbt] final class CommandExchange { def prompt(event: ConsolePromptEvent): Unit = { currentExecRef.set(null) channels.foreach { - case c if ContinuousCommands.isInWatch(c) => - case c => c.prompt(event) + case c if ContinuousCommands.isInWatch(lastState.get, c) => + case c => c.prompt(event) } } def unprompt(event: ConsoleUnpromptEvent): Unit = channels.foreach(_.unprompt(event)) @@ -459,10 +457,9 @@ private[sbt] final class CommandExchange { Option(currentExecRef.get).foreach(cancel) mt.channel.prompt(ConsolePromptEvent(lastState.get)) case t if t.startsWith(ContinuousCommands.stopWatch) => - ContinuousCommands.stopWatchImpl(mt.channel.name) mt.channel match { case c: NetworkChannel if !c.isInteractive => exit(mt) - case _ => mt.channel.prompt(ConsolePromptEvent(lastState.get)) + case _ => } commandQueue.add(Exec(t, None, None)) case `TerminateAction` => exit(mt) diff --git a/main/src/main/scala/sbt/internal/Continuous.scala b/main/src/main/scala/sbt/internal/Continuous.scala index 86618ceb2..557fa3300 100644 --- a/main/src/main/scala/sbt/internal/Continuous.scala +++ b/main/src/main/scala/sbt/internal/Continuous.scala @@ -108,8 +108,8 @@ private[sbt] object Continuous extends DeprecatedContinuous { case Some(c) => s -> c case None => StandardMain.exchange.run(s) -> ConsoleChannel.defaultName } - ContinuousCommands.setupWatchState(channel, initialCount, commands, s1) - s"${ContinuousCommands.runWatch} $channel" :: s1 + val ws = ContinuousCommands.setupWatchState(channel, initialCount, commands, s1) + s"${ContinuousCommands.runWatch} $channel" :: ws } @deprecated("The input task version of watch is no longer available", "1.4.0") @@ -1056,7 +1056,7 @@ private[sbt] object Continuous extends DeprecatedContinuous { val commands: Seq[String], beforeCommandImpl: (State, mutable.Set[DynamicInput]) => State, val afterCommand: State => State, - val afterWatch: () => Unit, + val afterWatch: State => State, val callbacks: Callbacks, val dynamicInputs: mutable.Set[DynamicInput], val pending: Boolean, @@ -1102,7 +1102,8 @@ private[sbt] object ContinuousCommands { "", Int.MaxValue ) - private[this] val watchStates = new ConcurrentHashMap[String, ContinuousState] + private[this] val watchStates = + AttributeKey[Map[String, ContinuousState]]("sbt-watch-states", Int.MaxValue) private[sbt] val runWatch = networkExecPrefix + "runWatch" private[sbt] val preWatch = networkExecPrefix + "preWatch" private[sbt] val postWatch = networkExecPrefix + "postWatch" @@ -1120,10 +1121,10 @@ private[sbt] object ContinuousCommands { "", Int.MaxValue ) - private[sbt] val setupWatchState: (String, Int, Seq[String], State) => Unit = + private[sbt] val setupWatchState: (String, Int, Seq[String], State) => State = (channelName, count, commands, state) => { - watchStates.get(channelName) match { - case null => + state.get(watchStates).flatMap(_.get(channelName)) match { + case None => val extracted = Project.extract(state) val repo = state.get(globalFileTreeRepository) match { case Some(r) => localRepo(r) @@ -1161,27 +1162,37 @@ private[sbt] object ContinuousCommands { stateWithCache.put(Continuous.DynamicInputs, dynamicInputs) }, afterCommand = state => { - watchStates.get(channelName) match { - case null => - case ws => watchStates.put(channelName, ws.incremented) + val newWatchState = state.get(watchStates) match { + case None => state + case Some(ws) => + ws.get(channelName) match { + case None => state + case Some(cs) => state.put(watchStates, ws + (channelName -> cs.incremented)) + } } - val restoredState = state.get(stashedRepo) match { + val restoredState = newWatchState.get(stashedRepo) match { case None => throw new IllegalStateException(s"No stashed repository for $state") - case Some(r) => state.put(globalFileTreeRepository, r) + case Some(r) => newWatchState.put(globalFileTreeRepository, r) } restoredState.remove(persistentFileStampCache).remove(Continuous.DynamicInputs) }, - afterWatch = () => { - watchStates.remove(channelName) + afterWatch = state => { LogExchange.unbindLoggerAppenders(channelName + "-watch") repo.close() + state.get(watchStates) match { + case None => state + case Some(ws) => state.put(watchStates, ws - channelName) + } }, callbacks = cb, dynamicInputs = dynamicInputs, pending = false, ) - Util.ignoreResult(watchStates.put(channelName, s)) - case cs => + state.get(watchStates) match { + case None => state.put(watchStates, Map(channelName -> s)) + case Some(ws) => state.put(watchStates, ws + (channelName -> s)) + } + case Some(cs) => val cmd = cs.commands.mkString("; ") val msg = s"Tried to start new watch while channel, '$channelName', was already watching '$cmd'" @@ -1194,28 +1205,26 @@ private[sbt] object ContinuousCommands { Command.arb { state => (cmdParser(name) ~> channelParser).map(channel => () => updateState(channel, state)) } { case (_, newState) => newState() } - private[this] val runWatchCommand = watchCommand(runWatch) { (channel, state) => - watchStates.get(channel) match { - case null => state - case cs => + private[sbt] val runWatchCommand = watchCommand(runWatch) { (channel, state) => + state.get(watchStates).flatMap(_.get(channel)) match { + case None => state + case Some(cs) => val pre = StashOnFailure :: s"$SetTerminal $channel" :: s"$preWatch $channel" :: Nil val post = FailureWall :: PopOnFailure :: s"$SetTerminal ${ConsoleChannel.defaultName}" :: - s"$postWatch $channel" :: waitWatch :: Nil + s"$postWatch $channel" :: s"$waitWatch $channel" :: Nil pre ::: cs.commands.toList ::: post ::: state } } - private[sbt] def watchUITaskFor(channel: CommandChannel): Option[UITask] = - watchStates.get(channel.name) match { - case null => None - case cs => Some(new WatchUITask(channel, cs)) - } - private[sbt] def isInWatch(channel: CommandChannel): Boolean = - watchStates.get(channel.name) != null - private[sbt] def isPending(channel: CommandChannel): Boolean = - Option(watchStates.get(channel.name)).fold(false)(_.pending) + private[sbt] def watchUITaskFor(state: State, channel: CommandChannel): Option[UITask] = + state.get(watchStates).flatMap(_.get(channel.name)).map(new WatchUITask(channel, _, state)) + private[sbt] def isInWatch(state: State, channel: CommandChannel): Boolean = + state.get(watchStates).exists(_.contains(channel.name)) + private[sbt] def isPending(state: State, channel: CommandChannel): Boolean = + state.get(watchStates).exists(_.get(channel.name).exists(_.pending)) private[this] class WatchUITask( override private[sbt] val channel: CommandChannel, cs: ContinuousState, + state: State ) extends Thread(s"sbt-${channel.name}-watch-ui-thread") with UITask { override private[sbt] def reader: UITask.Reader = () => { @@ -1229,8 +1238,12 @@ private[sbt] object ContinuousCommands { recursive = false ) } - val ws = watchState(channel.name) - watchStates.put(channel.name, ws.withPending(true)) + val ws = state.get(watchStates) match { + case None => throw new IllegalStateException("no watch states") + case Some(ws) => + ws.get(channel.name) + .getOrElse(throw new IllegalStateException(s"no watch state for ${channel.name}")) + } exitAction match { // Use a Left so that the client can immediately exit watch via case Watch.CancelWatch => Left(s"$stopWatch ${channel.name}") @@ -1248,30 +1261,40 @@ private[sbt] object ContinuousCommands { } } @inline - private[this] def watchState(channel: String): ContinuousState = watchStates.get(channel) match { - case null => throw new IllegalStateException(s"No watch state for $channel") - case s => s - } + private[this] def watchState(state: State, channel: String): ContinuousState = + state.get(watchStates).flatMap(_.get(channel)) match { + case None => throw new IllegalStateException(s"no watch state for $channel") + case Some(s) => s + } - private[this] val preWatchCommand = watchCommand(preWatch) { (channel, state) => - StandardMain.exchange.channelForName(channel).foreach(_.terminal.setPrompt(Prompt.Watch)) - watchState(channel).beforeCommand(state) + private[sbt] val preWatchCommand = watchCommand(preWatch) { (channel, state) => + watchState(state, channel).beforeCommand(state) } - private[this] val postWatchCommand = watchCommand(postWatch) { (channel, state) => - StandardMain.exchange.unprompt(ConsoleUnpromptEvent(Some(CommandSource(channel)))) - val ws = watchState(channel) - watchStates.put(channel, ws.withPending(false)) - ws.afterCommand(state) + private[sbt] val postWatchCommand = watchCommand(postWatch) { (channel, state) => + val cs = watchState(state, channel) + StandardMain.exchange.channelForName(channel).foreach { c => + c.terminal.setPrompt(Prompt.Watch) + c.unprompt(ConsoleUnpromptEvent(Some(CommandSource(channel)))) + } + val postState = state.get(watchStates) match { + case None => state + case Some(ws) => state.put(watchStates, ws + (channel -> cs.withPending(false))) + } + cs.afterCommand(postState) } - private[this] val stopWatchCommand = watchCommand(stopWatch) { (channel, state) => - stopWatchImpl(channel) - state - } - private[sbt] def stopWatchImpl(channelName: String): Unit = { - StandardMain.exchange.unprompt(ConsoleUnpromptEvent(Some(CommandSource(channelName)))) - Option(watchStates.get(channelName)).foreach { ws => - ws.afterWatch() - ws.callbacks.onExit() + private[sbt] val stopWatchCommand = watchCommand(stopWatch) { (channel, state) => + state.get(watchStates).flatMap(_.get(channel)) match { + case Some(cs) => + val afterWatchState = cs.afterWatch(state) + cs.callbacks.onExit() + StandardMain.exchange + .channelForName(channel) + .foreach(_.unprompt(ConsoleUnpromptEvent(Some(CommandSource(channel))))) + afterWatchState.get(watchStates) match { + case None => afterWatchState + case Some(w) => afterWatchState.put(watchStates, w - channel) + } + case _ => state } } private[this] val failWatchCommand = watchCommand(failWatch) { (channel, state) => diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index ddb42df1c..f7583f9e1 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -149,7 +149,7 @@ final class NetworkChannel( protected def authOptions: Set[ServerAuthentication] = auth override def mkUIThread: (State, CommandChannel) => UITask = (state, command) => { - if (interactive.get || ContinuousCommands.isInWatch(this)) mkUIThreadImpl(state, command) + if (interactive.get || ContinuousCommands.isInWatch(state, this)) mkUIThreadImpl(state, command) else new UITask { override private[sbt] def channel = NetworkChannel.this @@ -789,7 +789,8 @@ final class NetworkChannel( override def isAnsiSupported: Boolean = getProperty(_.isAnsiSupported, false).getOrElse(false) override def isEchoEnabled: Boolean = waitForPending(_.isEchoEnabled) override def isSuccessEnabled: Boolean = - prompt != Prompt.Batch || ContinuousCommands.isInWatch(NetworkChannel.this) + prompt != Prompt.Batch || + StandardMain.exchange.withState(ContinuousCommands.isInWatch(_, NetworkChannel.this)) override lazy val isColorEnabled: Boolean = waitForPending(_.isColorEnabled) override lazy val isSupershellEnabled: Boolean = waitForPending(_.isSupershellEnabled) getProperties(false) diff --git a/sbt/src/sbt-test/watch/legacy-sources/build.sbt b/sbt/src/sbt-test/watch/legacy-sources/build.sbt index 3ee3097d9..9c20c4496 100644 --- a/sbt/src/sbt-test/watch/legacy-sources/build.sbt +++ b/sbt/src/sbt-test/watch/legacy-sources/build.sbt @@ -1,7 +1,17 @@ -import sbt.legacy.sources.Build._ - Global / watchSources += new sbt.internal.io.Source(baseDirectory.value, "global.txt", NothingFilter, false) +val setStringValue = inputKey[Unit]("set a global string to a value") +val checkStringValue = inputKey[Unit]("check the value of a global") + +def setStringValueImpl: Def.Initialize[InputTask[Unit]] = Def.inputTask { + val Seq(stringFile, string) = Def.spaceDelimited().parsed.map(_.trim) + IO.write(file(stringFile), string) +} +def checkStringValueImpl: Def.Initialize[InputTask[Unit]] = Def.inputTask { + val Seq(stringFile, string) = Def.spaceDelimited().parsed + assert(IO.read(file(stringFile)) == string) +} + watchSources in setStringValue += new sbt.internal.io.Source(baseDirectory.value, "foo.txt", NothingFilter, false) setStringValue := setStringValueImpl.evaluated diff --git a/sbt/src/sbt-test/watch/legacy-sources/project/Build.scala b/sbt/src/sbt-test/watch/legacy-sources/project/Build.scala deleted file mode 100644 index 17643092a..000000000 --- a/sbt/src/sbt-test/watch/legacy-sources/project/Build.scala +++ /dev/null @@ -1,17 +0,0 @@ -package sbt.legacy.sources - -import sbt._ -import Keys._ - -object Build { - val setStringValue = inputKey[Unit]("set a global string to a value") - val checkStringValue = inputKey[Unit]("check the value of a global") - def setStringValueImpl: Def.Initialize[InputTask[Unit]] = Def.inputTask { - val Seq(stringFile, string) = Def.spaceDelimited().parsed.map(_.trim) - IO.write(file(stringFile), string) - } - def checkStringValueImpl: Def.Initialize[InputTask[Unit]] = Def.inputTask { - val Seq(stringFile, string) = Def.spaceDelimited().parsed - assert(IO.read(file(stringFile)) == string) - } -} \ No newline at end of file