Make watch implementation more sbt idiomatic

The 1.4.0 implementation of watch uses a concurrent hash map to maintain
the global watch state which manages the state for an arbitrary number
of clients. Using a mutable map is not idiomatic sbt and I found it
difficult to reason about when the map was updated. This commit reworks
the feature so that the global state is instead stored in an immutable
map that is only modified during the internal watch commands, which is
easier to reason about.
This commit is contained in:
Ethan Atkins 2020-08-02 09:35:02 -07:00
parent e76f61bec5
commit eb48f24f3a
6 changed files with 113 additions and 98 deletions

View File

@ -971,22 +971,23 @@ object BuiltinCommands {
}
private[sbt] def waitCmd: Command =
Command.arb(_ => (ContinuousCommands.waitWatch: Parser[String]).examples()) { (s0, _) =>
Command.arb(
_ => ContinuousCommands.waitWatch.examples() ~> " ".examples() ~> matched(any.*).examples()
) { (s0, channel) =>
val exchange = StandardMain.exchange
if (exchange.channels.exists(ContinuousCommands.isInWatch)) {
val s1 = exchange.run(s0)
exchange.channels.foreach {
case c if ContinuousCommands.isPending(c) =>
case c => c.prompt(ConsolePromptEvent(s1))
}
val exec: Exec = getExec(s1, Duration.Inf)
val remaining: List[Exec] =
Exec(ContinuousCommands.waitWatch, None) ::
Exec(FailureWall, None) :: s1.remainingCommands
val newState = s1.copy(remainingCommands = exec +: remaining)
if (exec.commandLine.trim.isEmpty) newState
else newState.clearGlobalLog
} else s0
exchange.channelForName(channel) match {
case Some(c) if ContinuousCommands.isInWatch(s0, c) =>
c.prompt(ConsolePromptEvent(s0))
val s1 = exchange.run(s0)
val exec: Exec = getExec(s1, Duration.Inf)
val remaining: List[Exec] =
Exec(s"${ContinuousCommands.waitWatch} $channel", None) ::
Exec(FailureWall, None) :: s1.remainingCommands
val newState = s1.copy(remainingCommands = exec +: remaining)
if (exec.commandLine.trim.isEmpty) newState
else newState.clearGlobalLog
case _ => s0
}
}
private[sbt] def promptChannel = Command.arb(_ => reportParser(PromptChannel)) {

View File

@ -170,17 +170,15 @@ private[sbt] final class CommandExchange {
currentExec.filter(_.source.map(_.channelName) == Some(c.name)).foreach { e =>
Util.ignoreResult(NetworkChannel.cancel(e.execId, e.execId.getOrElse("0")))
}
if (ContinuousCommands.isInWatch(c)) {
try commandQueue.put(Exec(s"${ContinuousCommands.stopWatch} ${c.name}", None))
catch { case _: InterruptedException => }
}
try commandQueue.put(Exec(s"${ContinuousCommands.stopWatch} ${c.name}", None))
catch { case _: InterruptedException => }
}
private[this] def mkAskUser(
name: String,
): (State, CommandChannel) => UITask = { (state, channel) =>
ContinuousCommands
.watchUITaskFor(channel)
.watchUITaskFor(state, channel)
.getOrElse(new UITask.AskUserTask(state, channel))
}
@ -353,8 +351,8 @@ private[sbt] final class CommandExchange {
def prompt(event: ConsolePromptEvent): Unit = {
currentExecRef.set(null)
channels.foreach {
case c if ContinuousCommands.isInWatch(c) =>
case c => c.prompt(event)
case c if ContinuousCommands.isInWatch(lastState.get, c) =>
case c => c.prompt(event)
}
}
def unprompt(event: ConsoleUnpromptEvent): Unit = channels.foreach(_.unprompt(event))
@ -459,10 +457,9 @@ private[sbt] final class CommandExchange {
Option(currentExecRef.get).foreach(cancel)
mt.channel.prompt(ConsolePromptEvent(lastState.get))
case t if t.startsWith(ContinuousCommands.stopWatch) =>
ContinuousCommands.stopWatchImpl(mt.channel.name)
mt.channel match {
case c: NetworkChannel if !c.isInteractive => exit(mt)
case _ => mt.channel.prompt(ConsolePromptEvent(lastState.get))
case _ =>
}
commandQueue.add(Exec(t, None, None))
case `TerminateAction` => exit(mt)

View File

@ -108,8 +108,8 @@ private[sbt] object Continuous extends DeprecatedContinuous {
case Some(c) => s -> c
case None => StandardMain.exchange.run(s) -> ConsoleChannel.defaultName
}
ContinuousCommands.setupWatchState(channel, initialCount, commands, s1)
s"${ContinuousCommands.runWatch} $channel" :: s1
val ws = ContinuousCommands.setupWatchState(channel, initialCount, commands, s1)
s"${ContinuousCommands.runWatch} $channel" :: ws
}
@deprecated("The input task version of watch is no longer available", "1.4.0")
@ -1056,7 +1056,7 @@ private[sbt] object Continuous extends DeprecatedContinuous {
val commands: Seq[String],
beforeCommandImpl: (State, mutable.Set[DynamicInput]) => State,
val afterCommand: State => State,
val afterWatch: () => Unit,
val afterWatch: State => State,
val callbacks: Callbacks,
val dynamicInputs: mutable.Set[DynamicInput],
val pending: Boolean,
@ -1102,7 +1102,8 @@ private[sbt] object ContinuousCommands {
"",
Int.MaxValue
)
private[this] val watchStates = new ConcurrentHashMap[String, ContinuousState]
private[this] val watchStates =
AttributeKey[Map[String, ContinuousState]]("sbt-watch-states", Int.MaxValue)
private[sbt] val runWatch = networkExecPrefix + "runWatch"
private[sbt] val preWatch = networkExecPrefix + "preWatch"
private[sbt] val postWatch = networkExecPrefix + "postWatch"
@ -1120,10 +1121,10 @@ private[sbt] object ContinuousCommands {
"",
Int.MaxValue
)
private[sbt] val setupWatchState: (String, Int, Seq[String], State) => Unit =
private[sbt] val setupWatchState: (String, Int, Seq[String], State) => State =
(channelName, count, commands, state) => {
watchStates.get(channelName) match {
case null =>
state.get(watchStates).flatMap(_.get(channelName)) match {
case None =>
val extracted = Project.extract(state)
val repo = state.get(globalFileTreeRepository) match {
case Some(r) => localRepo(r)
@ -1161,27 +1162,37 @@ private[sbt] object ContinuousCommands {
stateWithCache.put(Continuous.DynamicInputs, dynamicInputs)
},
afterCommand = state => {
watchStates.get(channelName) match {
case null =>
case ws => watchStates.put(channelName, ws.incremented)
val newWatchState = state.get(watchStates) match {
case None => state
case Some(ws) =>
ws.get(channelName) match {
case None => state
case Some(cs) => state.put(watchStates, ws + (channelName -> cs.incremented))
}
}
val restoredState = state.get(stashedRepo) match {
val restoredState = newWatchState.get(stashedRepo) match {
case None => throw new IllegalStateException(s"No stashed repository for $state")
case Some(r) => state.put(globalFileTreeRepository, r)
case Some(r) => newWatchState.put(globalFileTreeRepository, r)
}
restoredState.remove(persistentFileStampCache).remove(Continuous.DynamicInputs)
},
afterWatch = () => {
watchStates.remove(channelName)
afterWatch = state => {
LogExchange.unbindLoggerAppenders(channelName + "-watch")
repo.close()
state.get(watchStates) match {
case None => state
case Some(ws) => state.put(watchStates, ws - channelName)
}
},
callbacks = cb,
dynamicInputs = dynamicInputs,
pending = false,
)
Util.ignoreResult(watchStates.put(channelName, s))
case cs =>
state.get(watchStates) match {
case None => state.put(watchStates, Map(channelName -> s))
case Some(ws) => state.put(watchStates, ws + (channelName -> s))
}
case Some(cs) =>
val cmd = cs.commands.mkString("; ")
val msg =
s"Tried to start new watch while channel, '$channelName', was already watching '$cmd'"
@ -1194,28 +1205,26 @@ private[sbt] object ContinuousCommands {
Command.arb { state =>
(cmdParser(name) ~> channelParser).map(channel => () => updateState(channel, state))
} { case (_, newState) => newState() }
private[this] val runWatchCommand = watchCommand(runWatch) { (channel, state) =>
watchStates.get(channel) match {
case null => state
case cs =>
private[sbt] val runWatchCommand = watchCommand(runWatch) { (channel, state) =>
state.get(watchStates).flatMap(_.get(channel)) match {
case None => state
case Some(cs) =>
val pre = StashOnFailure :: s"$SetTerminal $channel" :: s"$preWatch $channel" :: Nil
val post = FailureWall :: PopOnFailure :: s"$SetTerminal ${ConsoleChannel.defaultName}" ::
s"$postWatch $channel" :: waitWatch :: Nil
s"$postWatch $channel" :: s"$waitWatch $channel" :: Nil
pre ::: cs.commands.toList ::: post ::: state
}
}
private[sbt] def watchUITaskFor(channel: CommandChannel): Option[UITask] =
watchStates.get(channel.name) match {
case null => None
case cs => Some(new WatchUITask(channel, cs))
}
private[sbt] def isInWatch(channel: CommandChannel): Boolean =
watchStates.get(channel.name) != null
private[sbt] def isPending(channel: CommandChannel): Boolean =
Option(watchStates.get(channel.name)).fold(false)(_.pending)
private[sbt] def watchUITaskFor(state: State, channel: CommandChannel): Option[UITask] =
state.get(watchStates).flatMap(_.get(channel.name)).map(new WatchUITask(channel, _, state))
private[sbt] def isInWatch(state: State, channel: CommandChannel): Boolean =
state.get(watchStates).exists(_.contains(channel.name))
private[sbt] def isPending(state: State, channel: CommandChannel): Boolean =
state.get(watchStates).exists(_.get(channel.name).exists(_.pending))
private[this] class WatchUITask(
override private[sbt] val channel: CommandChannel,
cs: ContinuousState,
state: State
) extends Thread(s"sbt-${channel.name}-watch-ui-thread")
with UITask {
override private[sbt] def reader: UITask.Reader = () => {
@ -1229,8 +1238,12 @@ private[sbt] object ContinuousCommands {
recursive = false
)
}
val ws = watchState(channel.name)
watchStates.put(channel.name, ws.withPending(true))
val ws = state.get(watchStates) match {
case None => throw new IllegalStateException("no watch states")
case Some(ws) =>
ws.get(channel.name)
.getOrElse(throw new IllegalStateException(s"no watch state for ${channel.name}"))
}
exitAction match {
// Use a Left so that the client can immediately exit watch via <enter>
case Watch.CancelWatch => Left(s"$stopWatch ${channel.name}")
@ -1248,30 +1261,40 @@ private[sbt] object ContinuousCommands {
}
}
@inline
private[this] def watchState(channel: String): ContinuousState = watchStates.get(channel) match {
case null => throw new IllegalStateException(s"No watch state for $channel")
case s => s
}
private[this] def watchState(state: State, channel: String): ContinuousState =
state.get(watchStates).flatMap(_.get(channel)) match {
case None => throw new IllegalStateException(s"no watch state for $channel")
case Some(s) => s
}
private[this] val preWatchCommand = watchCommand(preWatch) { (channel, state) =>
StandardMain.exchange.channelForName(channel).foreach(_.terminal.setPrompt(Prompt.Watch))
watchState(channel).beforeCommand(state)
private[sbt] val preWatchCommand = watchCommand(preWatch) { (channel, state) =>
watchState(state, channel).beforeCommand(state)
}
private[this] val postWatchCommand = watchCommand(postWatch) { (channel, state) =>
StandardMain.exchange.unprompt(ConsoleUnpromptEvent(Some(CommandSource(channel))))
val ws = watchState(channel)
watchStates.put(channel, ws.withPending(false))
ws.afterCommand(state)
private[sbt] val postWatchCommand = watchCommand(postWatch) { (channel, state) =>
val cs = watchState(state, channel)
StandardMain.exchange.channelForName(channel).foreach { c =>
c.terminal.setPrompt(Prompt.Watch)
c.unprompt(ConsoleUnpromptEvent(Some(CommandSource(channel))))
}
val postState = state.get(watchStates) match {
case None => state
case Some(ws) => state.put(watchStates, ws + (channel -> cs.withPending(false)))
}
cs.afterCommand(postState)
}
private[this] val stopWatchCommand = watchCommand(stopWatch) { (channel, state) =>
stopWatchImpl(channel)
state
}
private[sbt] def stopWatchImpl(channelName: String): Unit = {
StandardMain.exchange.unprompt(ConsoleUnpromptEvent(Some(CommandSource(channelName))))
Option(watchStates.get(channelName)).foreach { ws =>
ws.afterWatch()
ws.callbacks.onExit()
private[sbt] val stopWatchCommand = watchCommand(stopWatch) { (channel, state) =>
state.get(watchStates).flatMap(_.get(channel)) match {
case Some(cs) =>
val afterWatchState = cs.afterWatch(state)
cs.callbacks.onExit()
StandardMain.exchange
.channelForName(channel)
.foreach(_.unprompt(ConsoleUnpromptEvent(Some(CommandSource(channel)))))
afterWatchState.get(watchStates) match {
case None => afterWatchState
case Some(w) => afterWatchState.put(watchStates, w - channel)
}
case _ => state
}
}
private[this] val failWatchCommand = watchCommand(failWatch) { (channel, state) =>

View File

@ -149,7 +149,7 @@ final class NetworkChannel(
protected def authOptions: Set[ServerAuthentication] = auth
override def mkUIThread: (State, CommandChannel) => UITask = (state, command) => {
if (interactive.get || ContinuousCommands.isInWatch(this)) mkUIThreadImpl(state, command)
if (interactive.get || ContinuousCommands.isInWatch(state, this)) mkUIThreadImpl(state, command)
else
new UITask {
override private[sbt] def channel = NetworkChannel.this
@ -789,7 +789,8 @@ final class NetworkChannel(
override def isAnsiSupported: Boolean = getProperty(_.isAnsiSupported, false).getOrElse(false)
override def isEchoEnabled: Boolean = waitForPending(_.isEchoEnabled)
override def isSuccessEnabled: Boolean =
prompt != Prompt.Batch || ContinuousCommands.isInWatch(NetworkChannel.this)
prompt != Prompt.Batch ||
StandardMain.exchange.withState(ContinuousCommands.isInWatch(_, NetworkChannel.this))
override lazy val isColorEnabled: Boolean = waitForPending(_.isColorEnabled)
override lazy val isSupershellEnabled: Boolean = waitForPending(_.isSupershellEnabled)
getProperties(false)

View File

@ -1,7 +1,17 @@
import sbt.legacy.sources.Build._
Global / watchSources += new sbt.internal.io.Source(baseDirectory.value, "global.txt", NothingFilter, false)
val setStringValue = inputKey[Unit]("set a global string to a value")
val checkStringValue = inputKey[Unit]("check the value of a global")
def setStringValueImpl: Def.Initialize[InputTask[Unit]] = Def.inputTask {
val Seq(stringFile, string) = Def.spaceDelimited().parsed.map(_.trim)
IO.write(file(stringFile), string)
}
def checkStringValueImpl: Def.Initialize[InputTask[Unit]] = Def.inputTask {
val Seq(stringFile, string) = Def.spaceDelimited().parsed
assert(IO.read(file(stringFile)) == string)
}
watchSources in setStringValue += new sbt.internal.io.Source(baseDirectory.value, "foo.txt", NothingFilter, false)
setStringValue := setStringValueImpl.evaluated

View File

@ -1,17 +0,0 @@
package sbt.legacy.sources
import sbt._
import Keys._
object Build {
val setStringValue = inputKey[Unit]("set a global string to a value")
val checkStringValue = inputKey[Unit]("check the value of a global")
def setStringValueImpl: Def.Initialize[InputTask[Unit]] = Def.inputTask {
val Seq(stringFile, string) = Def.spaceDelimited().parsed.map(_.trim)
IO.write(file(stringFile), string)
}
def checkStringValueImpl: Def.Initialize[InputTask[Unit]] = Def.inputTask {
val Seq(stringFile, string) = Def.spaceDelimited().parsed
assert(IO.read(file(stringFile)) == string)
}
}