mirror of https://github.com/sbt/sbt.git
Make watch implementation more sbt idiomatic
The 1.4.0 implementation of watch uses a concurrent hash map to maintain the global watch state which manages the state for an arbitrary number of clients. Using a mutable map is not idiomatic sbt and I found it difficult to reason about when the map was updated. This commit reworks the feature so that the global state is instead stored in an immutable map that is only modified during the internal watch commands, which is easier to reason about.
This commit is contained in:
parent
e76f61bec5
commit
eb48f24f3a
|
|
@ -971,22 +971,23 @@ object BuiltinCommands {
|
|||
}
|
||||
|
||||
private[sbt] def waitCmd: Command =
|
||||
Command.arb(_ => (ContinuousCommands.waitWatch: Parser[String]).examples()) { (s0, _) =>
|
||||
Command.arb(
|
||||
_ => ContinuousCommands.waitWatch.examples() ~> " ".examples() ~> matched(any.*).examples()
|
||||
) { (s0, channel) =>
|
||||
val exchange = StandardMain.exchange
|
||||
if (exchange.channels.exists(ContinuousCommands.isInWatch)) {
|
||||
val s1 = exchange.run(s0)
|
||||
exchange.channels.foreach {
|
||||
case c if ContinuousCommands.isPending(c) =>
|
||||
case c => c.prompt(ConsolePromptEvent(s1))
|
||||
}
|
||||
val exec: Exec = getExec(s1, Duration.Inf)
|
||||
val remaining: List[Exec] =
|
||||
Exec(ContinuousCommands.waitWatch, None) ::
|
||||
Exec(FailureWall, None) :: s1.remainingCommands
|
||||
val newState = s1.copy(remainingCommands = exec +: remaining)
|
||||
if (exec.commandLine.trim.isEmpty) newState
|
||||
else newState.clearGlobalLog
|
||||
} else s0
|
||||
exchange.channelForName(channel) match {
|
||||
case Some(c) if ContinuousCommands.isInWatch(s0, c) =>
|
||||
c.prompt(ConsolePromptEvent(s0))
|
||||
val s1 = exchange.run(s0)
|
||||
val exec: Exec = getExec(s1, Duration.Inf)
|
||||
val remaining: List[Exec] =
|
||||
Exec(s"${ContinuousCommands.waitWatch} $channel", None) ::
|
||||
Exec(FailureWall, None) :: s1.remainingCommands
|
||||
val newState = s1.copy(remainingCommands = exec +: remaining)
|
||||
if (exec.commandLine.trim.isEmpty) newState
|
||||
else newState.clearGlobalLog
|
||||
case _ => s0
|
||||
}
|
||||
}
|
||||
|
||||
private[sbt] def promptChannel = Command.arb(_ => reportParser(PromptChannel)) {
|
||||
|
|
|
|||
|
|
@ -170,17 +170,15 @@ private[sbt] final class CommandExchange {
|
|||
currentExec.filter(_.source.map(_.channelName) == Some(c.name)).foreach { e =>
|
||||
Util.ignoreResult(NetworkChannel.cancel(e.execId, e.execId.getOrElse("0")))
|
||||
}
|
||||
if (ContinuousCommands.isInWatch(c)) {
|
||||
try commandQueue.put(Exec(s"${ContinuousCommands.stopWatch} ${c.name}", None))
|
||||
catch { case _: InterruptedException => }
|
||||
}
|
||||
try commandQueue.put(Exec(s"${ContinuousCommands.stopWatch} ${c.name}", None))
|
||||
catch { case _: InterruptedException => }
|
||||
}
|
||||
|
||||
private[this] def mkAskUser(
|
||||
name: String,
|
||||
): (State, CommandChannel) => UITask = { (state, channel) =>
|
||||
ContinuousCommands
|
||||
.watchUITaskFor(channel)
|
||||
.watchUITaskFor(state, channel)
|
||||
.getOrElse(new UITask.AskUserTask(state, channel))
|
||||
}
|
||||
|
||||
|
|
@ -353,8 +351,8 @@ private[sbt] final class CommandExchange {
|
|||
def prompt(event: ConsolePromptEvent): Unit = {
|
||||
currentExecRef.set(null)
|
||||
channels.foreach {
|
||||
case c if ContinuousCommands.isInWatch(c) =>
|
||||
case c => c.prompt(event)
|
||||
case c if ContinuousCommands.isInWatch(lastState.get, c) =>
|
||||
case c => c.prompt(event)
|
||||
}
|
||||
}
|
||||
def unprompt(event: ConsoleUnpromptEvent): Unit = channels.foreach(_.unprompt(event))
|
||||
|
|
@ -459,10 +457,9 @@ private[sbt] final class CommandExchange {
|
|||
Option(currentExecRef.get).foreach(cancel)
|
||||
mt.channel.prompt(ConsolePromptEvent(lastState.get))
|
||||
case t if t.startsWith(ContinuousCommands.stopWatch) =>
|
||||
ContinuousCommands.stopWatchImpl(mt.channel.name)
|
||||
mt.channel match {
|
||||
case c: NetworkChannel if !c.isInteractive => exit(mt)
|
||||
case _ => mt.channel.prompt(ConsolePromptEvent(lastState.get))
|
||||
case _ =>
|
||||
}
|
||||
commandQueue.add(Exec(t, None, None))
|
||||
case `TerminateAction` => exit(mt)
|
||||
|
|
|
|||
|
|
@ -108,8 +108,8 @@ private[sbt] object Continuous extends DeprecatedContinuous {
|
|||
case Some(c) => s -> c
|
||||
case None => StandardMain.exchange.run(s) -> ConsoleChannel.defaultName
|
||||
}
|
||||
ContinuousCommands.setupWatchState(channel, initialCount, commands, s1)
|
||||
s"${ContinuousCommands.runWatch} $channel" :: s1
|
||||
val ws = ContinuousCommands.setupWatchState(channel, initialCount, commands, s1)
|
||||
s"${ContinuousCommands.runWatch} $channel" :: ws
|
||||
}
|
||||
|
||||
@deprecated("The input task version of watch is no longer available", "1.4.0")
|
||||
|
|
@ -1056,7 +1056,7 @@ private[sbt] object Continuous extends DeprecatedContinuous {
|
|||
val commands: Seq[String],
|
||||
beforeCommandImpl: (State, mutable.Set[DynamicInput]) => State,
|
||||
val afterCommand: State => State,
|
||||
val afterWatch: () => Unit,
|
||||
val afterWatch: State => State,
|
||||
val callbacks: Callbacks,
|
||||
val dynamicInputs: mutable.Set[DynamicInput],
|
||||
val pending: Boolean,
|
||||
|
|
@ -1102,7 +1102,8 @@ private[sbt] object ContinuousCommands {
|
|||
"",
|
||||
Int.MaxValue
|
||||
)
|
||||
private[this] val watchStates = new ConcurrentHashMap[String, ContinuousState]
|
||||
private[this] val watchStates =
|
||||
AttributeKey[Map[String, ContinuousState]]("sbt-watch-states", Int.MaxValue)
|
||||
private[sbt] val runWatch = networkExecPrefix + "runWatch"
|
||||
private[sbt] val preWatch = networkExecPrefix + "preWatch"
|
||||
private[sbt] val postWatch = networkExecPrefix + "postWatch"
|
||||
|
|
@ -1120,10 +1121,10 @@ private[sbt] object ContinuousCommands {
|
|||
"",
|
||||
Int.MaxValue
|
||||
)
|
||||
private[sbt] val setupWatchState: (String, Int, Seq[String], State) => Unit =
|
||||
private[sbt] val setupWatchState: (String, Int, Seq[String], State) => State =
|
||||
(channelName, count, commands, state) => {
|
||||
watchStates.get(channelName) match {
|
||||
case null =>
|
||||
state.get(watchStates).flatMap(_.get(channelName)) match {
|
||||
case None =>
|
||||
val extracted = Project.extract(state)
|
||||
val repo = state.get(globalFileTreeRepository) match {
|
||||
case Some(r) => localRepo(r)
|
||||
|
|
@ -1161,27 +1162,37 @@ private[sbt] object ContinuousCommands {
|
|||
stateWithCache.put(Continuous.DynamicInputs, dynamicInputs)
|
||||
},
|
||||
afterCommand = state => {
|
||||
watchStates.get(channelName) match {
|
||||
case null =>
|
||||
case ws => watchStates.put(channelName, ws.incremented)
|
||||
val newWatchState = state.get(watchStates) match {
|
||||
case None => state
|
||||
case Some(ws) =>
|
||||
ws.get(channelName) match {
|
||||
case None => state
|
||||
case Some(cs) => state.put(watchStates, ws + (channelName -> cs.incremented))
|
||||
}
|
||||
}
|
||||
val restoredState = state.get(stashedRepo) match {
|
||||
val restoredState = newWatchState.get(stashedRepo) match {
|
||||
case None => throw new IllegalStateException(s"No stashed repository for $state")
|
||||
case Some(r) => state.put(globalFileTreeRepository, r)
|
||||
case Some(r) => newWatchState.put(globalFileTreeRepository, r)
|
||||
}
|
||||
restoredState.remove(persistentFileStampCache).remove(Continuous.DynamicInputs)
|
||||
},
|
||||
afterWatch = () => {
|
||||
watchStates.remove(channelName)
|
||||
afterWatch = state => {
|
||||
LogExchange.unbindLoggerAppenders(channelName + "-watch")
|
||||
repo.close()
|
||||
state.get(watchStates) match {
|
||||
case None => state
|
||||
case Some(ws) => state.put(watchStates, ws - channelName)
|
||||
}
|
||||
},
|
||||
callbacks = cb,
|
||||
dynamicInputs = dynamicInputs,
|
||||
pending = false,
|
||||
)
|
||||
Util.ignoreResult(watchStates.put(channelName, s))
|
||||
case cs =>
|
||||
state.get(watchStates) match {
|
||||
case None => state.put(watchStates, Map(channelName -> s))
|
||||
case Some(ws) => state.put(watchStates, ws + (channelName -> s))
|
||||
}
|
||||
case Some(cs) =>
|
||||
val cmd = cs.commands.mkString("; ")
|
||||
val msg =
|
||||
s"Tried to start new watch while channel, '$channelName', was already watching '$cmd'"
|
||||
|
|
@ -1194,28 +1205,26 @@ private[sbt] object ContinuousCommands {
|
|||
Command.arb { state =>
|
||||
(cmdParser(name) ~> channelParser).map(channel => () => updateState(channel, state))
|
||||
} { case (_, newState) => newState() }
|
||||
private[this] val runWatchCommand = watchCommand(runWatch) { (channel, state) =>
|
||||
watchStates.get(channel) match {
|
||||
case null => state
|
||||
case cs =>
|
||||
private[sbt] val runWatchCommand = watchCommand(runWatch) { (channel, state) =>
|
||||
state.get(watchStates).flatMap(_.get(channel)) match {
|
||||
case None => state
|
||||
case Some(cs) =>
|
||||
val pre = StashOnFailure :: s"$SetTerminal $channel" :: s"$preWatch $channel" :: Nil
|
||||
val post = FailureWall :: PopOnFailure :: s"$SetTerminal ${ConsoleChannel.defaultName}" ::
|
||||
s"$postWatch $channel" :: waitWatch :: Nil
|
||||
s"$postWatch $channel" :: s"$waitWatch $channel" :: Nil
|
||||
pre ::: cs.commands.toList ::: post ::: state
|
||||
}
|
||||
}
|
||||
private[sbt] def watchUITaskFor(channel: CommandChannel): Option[UITask] =
|
||||
watchStates.get(channel.name) match {
|
||||
case null => None
|
||||
case cs => Some(new WatchUITask(channel, cs))
|
||||
}
|
||||
private[sbt] def isInWatch(channel: CommandChannel): Boolean =
|
||||
watchStates.get(channel.name) != null
|
||||
private[sbt] def isPending(channel: CommandChannel): Boolean =
|
||||
Option(watchStates.get(channel.name)).fold(false)(_.pending)
|
||||
private[sbt] def watchUITaskFor(state: State, channel: CommandChannel): Option[UITask] =
|
||||
state.get(watchStates).flatMap(_.get(channel.name)).map(new WatchUITask(channel, _, state))
|
||||
private[sbt] def isInWatch(state: State, channel: CommandChannel): Boolean =
|
||||
state.get(watchStates).exists(_.contains(channel.name))
|
||||
private[sbt] def isPending(state: State, channel: CommandChannel): Boolean =
|
||||
state.get(watchStates).exists(_.get(channel.name).exists(_.pending))
|
||||
private[this] class WatchUITask(
|
||||
override private[sbt] val channel: CommandChannel,
|
||||
cs: ContinuousState,
|
||||
state: State
|
||||
) extends Thread(s"sbt-${channel.name}-watch-ui-thread")
|
||||
with UITask {
|
||||
override private[sbt] def reader: UITask.Reader = () => {
|
||||
|
|
@ -1229,8 +1238,12 @@ private[sbt] object ContinuousCommands {
|
|||
recursive = false
|
||||
)
|
||||
}
|
||||
val ws = watchState(channel.name)
|
||||
watchStates.put(channel.name, ws.withPending(true))
|
||||
val ws = state.get(watchStates) match {
|
||||
case None => throw new IllegalStateException("no watch states")
|
||||
case Some(ws) =>
|
||||
ws.get(channel.name)
|
||||
.getOrElse(throw new IllegalStateException(s"no watch state for ${channel.name}"))
|
||||
}
|
||||
exitAction match {
|
||||
// Use a Left so that the client can immediately exit watch via <enter>
|
||||
case Watch.CancelWatch => Left(s"$stopWatch ${channel.name}")
|
||||
|
|
@ -1248,30 +1261,40 @@ private[sbt] object ContinuousCommands {
|
|||
}
|
||||
}
|
||||
@inline
|
||||
private[this] def watchState(channel: String): ContinuousState = watchStates.get(channel) match {
|
||||
case null => throw new IllegalStateException(s"No watch state for $channel")
|
||||
case s => s
|
||||
}
|
||||
private[this] def watchState(state: State, channel: String): ContinuousState =
|
||||
state.get(watchStates).flatMap(_.get(channel)) match {
|
||||
case None => throw new IllegalStateException(s"no watch state for $channel")
|
||||
case Some(s) => s
|
||||
}
|
||||
|
||||
private[this] val preWatchCommand = watchCommand(preWatch) { (channel, state) =>
|
||||
StandardMain.exchange.channelForName(channel).foreach(_.terminal.setPrompt(Prompt.Watch))
|
||||
watchState(channel).beforeCommand(state)
|
||||
private[sbt] val preWatchCommand = watchCommand(preWatch) { (channel, state) =>
|
||||
watchState(state, channel).beforeCommand(state)
|
||||
}
|
||||
private[this] val postWatchCommand = watchCommand(postWatch) { (channel, state) =>
|
||||
StandardMain.exchange.unprompt(ConsoleUnpromptEvent(Some(CommandSource(channel))))
|
||||
val ws = watchState(channel)
|
||||
watchStates.put(channel, ws.withPending(false))
|
||||
ws.afterCommand(state)
|
||||
private[sbt] val postWatchCommand = watchCommand(postWatch) { (channel, state) =>
|
||||
val cs = watchState(state, channel)
|
||||
StandardMain.exchange.channelForName(channel).foreach { c =>
|
||||
c.terminal.setPrompt(Prompt.Watch)
|
||||
c.unprompt(ConsoleUnpromptEvent(Some(CommandSource(channel))))
|
||||
}
|
||||
val postState = state.get(watchStates) match {
|
||||
case None => state
|
||||
case Some(ws) => state.put(watchStates, ws + (channel -> cs.withPending(false)))
|
||||
}
|
||||
cs.afterCommand(postState)
|
||||
}
|
||||
private[this] val stopWatchCommand = watchCommand(stopWatch) { (channel, state) =>
|
||||
stopWatchImpl(channel)
|
||||
state
|
||||
}
|
||||
private[sbt] def stopWatchImpl(channelName: String): Unit = {
|
||||
StandardMain.exchange.unprompt(ConsoleUnpromptEvent(Some(CommandSource(channelName))))
|
||||
Option(watchStates.get(channelName)).foreach { ws =>
|
||||
ws.afterWatch()
|
||||
ws.callbacks.onExit()
|
||||
private[sbt] val stopWatchCommand = watchCommand(stopWatch) { (channel, state) =>
|
||||
state.get(watchStates).flatMap(_.get(channel)) match {
|
||||
case Some(cs) =>
|
||||
val afterWatchState = cs.afterWatch(state)
|
||||
cs.callbacks.onExit()
|
||||
StandardMain.exchange
|
||||
.channelForName(channel)
|
||||
.foreach(_.unprompt(ConsoleUnpromptEvent(Some(CommandSource(channel)))))
|
||||
afterWatchState.get(watchStates) match {
|
||||
case None => afterWatchState
|
||||
case Some(w) => afterWatchState.put(watchStates, w - channel)
|
||||
}
|
||||
case _ => state
|
||||
}
|
||||
}
|
||||
private[this] val failWatchCommand = watchCommand(failWatch) { (channel, state) =>
|
||||
|
|
|
|||
|
|
@ -149,7 +149,7 @@ final class NetworkChannel(
|
|||
protected def authOptions: Set[ServerAuthentication] = auth
|
||||
|
||||
override def mkUIThread: (State, CommandChannel) => UITask = (state, command) => {
|
||||
if (interactive.get || ContinuousCommands.isInWatch(this)) mkUIThreadImpl(state, command)
|
||||
if (interactive.get || ContinuousCommands.isInWatch(state, this)) mkUIThreadImpl(state, command)
|
||||
else
|
||||
new UITask {
|
||||
override private[sbt] def channel = NetworkChannel.this
|
||||
|
|
@ -789,7 +789,8 @@ final class NetworkChannel(
|
|||
override def isAnsiSupported: Boolean = getProperty(_.isAnsiSupported, false).getOrElse(false)
|
||||
override def isEchoEnabled: Boolean = waitForPending(_.isEchoEnabled)
|
||||
override def isSuccessEnabled: Boolean =
|
||||
prompt != Prompt.Batch || ContinuousCommands.isInWatch(NetworkChannel.this)
|
||||
prompt != Prompt.Batch ||
|
||||
StandardMain.exchange.withState(ContinuousCommands.isInWatch(_, NetworkChannel.this))
|
||||
override lazy val isColorEnabled: Boolean = waitForPending(_.isColorEnabled)
|
||||
override lazy val isSupershellEnabled: Boolean = waitForPending(_.isSupershellEnabled)
|
||||
getProperties(false)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,17 @@
|
|||
import sbt.legacy.sources.Build._
|
||||
|
||||
Global / watchSources += new sbt.internal.io.Source(baseDirectory.value, "global.txt", NothingFilter, false)
|
||||
|
||||
val setStringValue = inputKey[Unit]("set a global string to a value")
|
||||
val checkStringValue = inputKey[Unit]("check the value of a global")
|
||||
|
||||
def setStringValueImpl: Def.Initialize[InputTask[Unit]] = Def.inputTask {
|
||||
val Seq(stringFile, string) = Def.spaceDelimited().parsed.map(_.trim)
|
||||
IO.write(file(stringFile), string)
|
||||
}
|
||||
def checkStringValueImpl: Def.Initialize[InputTask[Unit]] = Def.inputTask {
|
||||
val Seq(stringFile, string) = Def.spaceDelimited().parsed
|
||||
assert(IO.read(file(stringFile)) == string)
|
||||
}
|
||||
|
||||
watchSources in setStringValue += new sbt.internal.io.Source(baseDirectory.value, "foo.txt", NothingFilter, false)
|
||||
|
||||
setStringValue := setStringValueImpl.evaluated
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
package sbt.legacy.sources
|
||||
|
||||
import sbt._
|
||||
import Keys._
|
||||
|
||||
object Build {
|
||||
val setStringValue = inputKey[Unit]("set a global string to a value")
|
||||
val checkStringValue = inputKey[Unit]("check the value of a global")
|
||||
def setStringValueImpl: Def.Initialize[InputTask[Unit]] = Def.inputTask {
|
||||
val Seq(stringFile, string) = Def.spaceDelimited().parsed.map(_.trim)
|
||||
IO.write(file(stringFile), string)
|
||||
}
|
||||
def checkStringValueImpl: Def.Initialize[InputTask[Unit]] = Def.inputTask {
|
||||
val Seq(stringFile, string) = Def.spaceDelimited().parsed
|
||||
assert(IO.read(file(stringFile)) == string)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue