Merge pull request #5723 from eatkins/watch-state-map

Make watch implementation more sbt idiomatic
This commit is contained in:
eugene yokota 2020-08-04 22:17:31 -04:00 committed by GitHub
commit 6c89d7416d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 113 additions and 98 deletions

View File

@ -971,22 +971,23 @@ object BuiltinCommands {
}
private[sbt] def waitCmd: Command =
Command.arb(_ => (ContinuousCommands.waitWatch: Parser[String]).examples()) { (s0, _) =>
Command.arb(
_ => ContinuousCommands.waitWatch.examples() ~> " ".examples() ~> matched(any.*).examples()
) { (s0, channel) =>
val exchange = StandardMain.exchange
if (exchange.channels.exists(ContinuousCommands.isInWatch)) {
val s1 = exchange.run(s0)
exchange.channels.foreach {
case c if ContinuousCommands.isPending(c) =>
case c => c.prompt(ConsolePromptEvent(s1))
}
val exec: Exec = getExec(s1, Duration.Inf)
val remaining: List[Exec] =
Exec(ContinuousCommands.waitWatch, None) ::
Exec(FailureWall, None) :: s1.remainingCommands
val newState = s1.copy(remainingCommands = exec +: remaining)
if (exec.commandLine.trim.isEmpty) newState
else newState.clearGlobalLog
} else s0
exchange.channelForName(channel) match {
case Some(c) if ContinuousCommands.isInWatch(s0, c) =>
c.prompt(ConsolePromptEvent(s0))
val s1 = exchange.run(s0)
val exec: Exec = getExec(s1, Duration.Inf)
val remaining: List[Exec] =
Exec(s"${ContinuousCommands.waitWatch} $channel", None) ::
Exec(FailureWall, None) :: s1.remainingCommands
val newState = s1.copy(remainingCommands = exec +: remaining)
if (exec.commandLine.trim.isEmpty) newState
else newState.clearGlobalLog
case _ => s0
}
}
private[sbt] def promptChannel = Command.arb(_ => reportParser(PromptChannel)) {

View File

@ -170,17 +170,15 @@ private[sbt] final class CommandExchange {
currentExec.filter(_.source.map(_.channelName) == Some(c.name)).foreach { e =>
Util.ignoreResult(NetworkChannel.cancel(e.execId, e.execId.getOrElse("0")))
}
if (ContinuousCommands.isInWatch(c)) {
try commandQueue.put(Exec(s"${ContinuousCommands.stopWatch} ${c.name}", None))
catch { case _: InterruptedException => }
}
try commandQueue.put(Exec(s"${ContinuousCommands.stopWatch} ${c.name}", None))
catch { case _: InterruptedException => }
}
private[this] def mkAskUser(
name: String,
): (State, CommandChannel) => UITask = { (state, channel) =>
ContinuousCommands
.watchUITaskFor(channel)
.watchUITaskFor(state, channel)
.getOrElse(new UITask.AskUserTask(state, channel))
}
@ -353,8 +351,8 @@ private[sbt] final class CommandExchange {
def prompt(event: ConsolePromptEvent): Unit = {
currentExecRef.set(null)
channels.foreach {
case c if ContinuousCommands.isInWatch(c) =>
case c => c.prompt(event)
case c if ContinuousCommands.isInWatch(lastState.get, c) =>
case c => c.prompt(event)
}
}
def unprompt(event: ConsoleUnpromptEvent): Unit = channels.foreach(_.unprompt(event))
@ -459,10 +457,9 @@ private[sbt] final class CommandExchange {
Option(currentExecRef.get).foreach(cancel)
mt.channel.prompt(ConsolePromptEvent(lastState.get))
case t if t.startsWith(ContinuousCommands.stopWatch) =>
ContinuousCommands.stopWatchImpl(mt.channel.name)
mt.channel match {
case c: NetworkChannel if !c.isInteractive => exit(mt)
case _ => mt.channel.prompt(ConsolePromptEvent(lastState.get))
case _ =>
}
commandQueue.add(Exec(t, None, None))
case `TerminateAction` => exit(mt)

View File

@ -108,8 +108,8 @@ private[sbt] object Continuous extends DeprecatedContinuous {
case Some(c) => s -> c
case None => StandardMain.exchange.run(s) -> ConsoleChannel.defaultName
}
ContinuousCommands.setupWatchState(channel, initialCount, commands, s1)
s"${ContinuousCommands.runWatch} $channel" :: s1
val ws = ContinuousCommands.setupWatchState(channel, initialCount, commands, s1)
s"${ContinuousCommands.runWatch} $channel" :: ws
}
@deprecated("The input task version of watch is no longer available", "1.4.0")
@ -1056,7 +1056,7 @@ private[sbt] object Continuous extends DeprecatedContinuous {
val commands: Seq[String],
beforeCommandImpl: (State, mutable.Set[DynamicInput]) => State,
val afterCommand: State => State,
val afterWatch: () => Unit,
val afterWatch: State => State,
val callbacks: Callbacks,
val dynamicInputs: mutable.Set[DynamicInput],
val pending: Boolean,
@ -1102,7 +1102,8 @@ private[sbt] object ContinuousCommands {
"",
Int.MaxValue
)
private[this] val watchStates = new ConcurrentHashMap[String, ContinuousState]
private[this] val watchStates =
AttributeKey[Map[String, ContinuousState]]("sbt-watch-states", Int.MaxValue)
private[sbt] val runWatch = networkExecPrefix + "runWatch"
private[sbt] val preWatch = networkExecPrefix + "preWatch"
private[sbt] val postWatch = networkExecPrefix + "postWatch"
@ -1120,10 +1121,10 @@ private[sbt] object ContinuousCommands {
"",
Int.MaxValue
)
private[sbt] val setupWatchState: (String, Int, Seq[String], State) => Unit =
private[sbt] val setupWatchState: (String, Int, Seq[String], State) => State =
(channelName, count, commands, state) => {
watchStates.get(channelName) match {
case null =>
state.get(watchStates).flatMap(_.get(channelName)) match {
case None =>
val extracted = Project.extract(state)
val repo = state.get(globalFileTreeRepository) match {
case Some(r) => localRepo(r)
@ -1161,27 +1162,37 @@ private[sbt] object ContinuousCommands {
stateWithCache.put(Continuous.DynamicInputs, dynamicInputs)
},
afterCommand = state => {
watchStates.get(channelName) match {
case null =>
case ws => watchStates.put(channelName, ws.incremented)
val newWatchState = state.get(watchStates) match {
case None => state
case Some(ws) =>
ws.get(channelName) match {
case None => state
case Some(cs) => state.put(watchStates, ws + (channelName -> cs.incremented))
}
}
val restoredState = state.get(stashedRepo) match {
val restoredState = newWatchState.get(stashedRepo) match {
case None => throw new IllegalStateException(s"No stashed repository for $state")
case Some(r) => state.put(globalFileTreeRepository, r)
case Some(r) => newWatchState.put(globalFileTreeRepository, r)
}
restoredState.remove(persistentFileStampCache).remove(Continuous.DynamicInputs)
},
afterWatch = () => {
watchStates.remove(channelName)
afterWatch = state => {
LogExchange.unbindLoggerAppenders(channelName + "-watch")
repo.close()
state.get(watchStates) match {
case None => state
case Some(ws) => state.put(watchStates, ws - channelName)
}
},
callbacks = cb,
dynamicInputs = dynamicInputs,
pending = false,
)
Util.ignoreResult(watchStates.put(channelName, s))
case cs =>
state.get(watchStates) match {
case None => state.put(watchStates, Map(channelName -> s))
case Some(ws) => state.put(watchStates, ws + (channelName -> s))
}
case Some(cs) =>
val cmd = cs.commands.mkString("; ")
val msg =
s"Tried to start new watch while channel, '$channelName', was already watching '$cmd'"
@ -1194,28 +1205,26 @@ private[sbt] object ContinuousCommands {
Command.arb { state =>
(cmdParser(name) ~> channelParser).map(channel => () => updateState(channel, state))
} { case (_, newState) => newState() }
private[this] val runWatchCommand = watchCommand(runWatch) { (channel, state) =>
watchStates.get(channel) match {
case null => state
case cs =>
private[sbt] val runWatchCommand = watchCommand(runWatch) { (channel, state) =>
state.get(watchStates).flatMap(_.get(channel)) match {
case None => state
case Some(cs) =>
val pre = StashOnFailure :: s"$SetTerminal $channel" :: s"$preWatch $channel" :: Nil
val post = FailureWall :: PopOnFailure :: s"$SetTerminal ${ConsoleChannel.defaultName}" ::
s"$postWatch $channel" :: waitWatch :: Nil
s"$postWatch $channel" :: s"$waitWatch $channel" :: Nil
pre ::: cs.commands.toList ::: post ::: state
}
}
private[sbt] def watchUITaskFor(channel: CommandChannel): Option[UITask] =
watchStates.get(channel.name) match {
case null => None
case cs => Some(new WatchUITask(channel, cs))
}
private[sbt] def isInWatch(channel: CommandChannel): Boolean =
watchStates.get(channel.name) != null
private[sbt] def isPending(channel: CommandChannel): Boolean =
Option(watchStates.get(channel.name)).fold(false)(_.pending)
private[sbt] def watchUITaskFor(state: State, channel: CommandChannel): Option[UITask] =
state.get(watchStates).flatMap(_.get(channel.name)).map(new WatchUITask(channel, _, state))
private[sbt] def isInWatch(state: State, channel: CommandChannel): Boolean =
state.get(watchStates).exists(_.contains(channel.name))
private[sbt] def isPending(state: State, channel: CommandChannel): Boolean =
state.get(watchStates).exists(_.get(channel.name).exists(_.pending))
private[this] class WatchUITask(
override private[sbt] val channel: CommandChannel,
cs: ContinuousState,
state: State
) extends Thread(s"sbt-${channel.name}-watch-ui-thread")
with UITask {
override private[sbt] def reader: UITask.Reader = () => {
@ -1229,8 +1238,12 @@ private[sbt] object ContinuousCommands {
recursive = false
)
}
val ws = watchState(channel.name)
watchStates.put(channel.name, ws.withPending(true))
val ws = state.get(watchStates) match {
case None => throw new IllegalStateException("no watch states")
case Some(ws) =>
ws.get(channel.name)
.getOrElse(throw new IllegalStateException(s"no watch state for ${channel.name}"))
}
exitAction match {
// Use a Left so that the client can immediately exit watch via <enter>
case Watch.CancelWatch => Left(s"$stopWatch ${channel.name}")
@ -1248,30 +1261,40 @@ private[sbt] object ContinuousCommands {
}
}
@inline
private[this] def watchState(channel: String): ContinuousState = watchStates.get(channel) match {
case null => throw new IllegalStateException(s"No watch state for $channel")
case s => s
}
private[this] def watchState(state: State, channel: String): ContinuousState =
state.get(watchStates).flatMap(_.get(channel)) match {
case None => throw new IllegalStateException(s"no watch state for $channel")
case Some(s) => s
}
private[this] val preWatchCommand = watchCommand(preWatch) { (channel, state) =>
StandardMain.exchange.channelForName(channel).foreach(_.terminal.setPrompt(Prompt.Watch))
watchState(channel).beforeCommand(state)
private[sbt] val preWatchCommand = watchCommand(preWatch) { (channel, state) =>
watchState(state, channel).beforeCommand(state)
}
private[this] val postWatchCommand = watchCommand(postWatch) { (channel, state) =>
StandardMain.exchange.unprompt(ConsoleUnpromptEvent(Some(CommandSource(channel))))
val ws = watchState(channel)
watchStates.put(channel, ws.withPending(false))
ws.afterCommand(state)
private[sbt] val postWatchCommand = watchCommand(postWatch) { (channel, state) =>
val cs = watchState(state, channel)
StandardMain.exchange.channelForName(channel).foreach { c =>
c.terminal.setPrompt(Prompt.Watch)
c.unprompt(ConsoleUnpromptEvent(Some(CommandSource(channel))))
}
val postState = state.get(watchStates) match {
case None => state
case Some(ws) => state.put(watchStates, ws + (channel -> cs.withPending(false)))
}
cs.afterCommand(postState)
}
private[this] val stopWatchCommand = watchCommand(stopWatch) { (channel, state) =>
stopWatchImpl(channel)
state
}
private[sbt] def stopWatchImpl(channelName: String): Unit = {
StandardMain.exchange.unprompt(ConsoleUnpromptEvent(Some(CommandSource(channelName))))
Option(watchStates.get(channelName)).foreach { ws =>
ws.afterWatch()
ws.callbacks.onExit()
private[sbt] val stopWatchCommand = watchCommand(stopWatch) { (channel, state) =>
state.get(watchStates).flatMap(_.get(channel)) match {
case Some(cs) =>
val afterWatchState = cs.afterWatch(state)
cs.callbacks.onExit()
StandardMain.exchange
.channelForName(channel)
.foreach(_.unprompt(ConsoleUnpromptEvent(Some(CommandSource(channel)))))
afterWatchState.get(watchStates) match {
case None => afterWatchState
case Some(w) => afterWatchState.put(watchStates, w - channel)
}
case _ => state
}
}
private[this] val failWatchCommand = watchCommand(failWatch) { (channel, state) =>

View File

@ -149,7 +149,7 @@ final class NetworkChannel(
protected def authOptions: Set[ServerAuthentication] = auth
override def mkUIThread: (State, CommandChannel) => UITask = (state, command) => {
if (interactive.get || ContinuousCommands.isInWatch(this)) mkUIThreadImpl(state, command)
if (interactive.get || ContinuousCommands.isInWatch(state, this)) mkUIThreadImpl(state, command)
else
new UITask {
override private[sbt] def channel = NetworkChannel.this
@ -789,7 +789,8 @@ final class NetworkChannel(
override def isAnsiSupported: Boolean = getProperty(_.isAnsiSupported, false).getOrElse(false)
override def isEchoEnabled: Boolean = waitForPending(_.isEchoEnabled)
override def isSuccessEnabled: Boolean =
prompt != Prompt.Batch || ContinuousCommands.isInWatch(NetworkChannel.this)
prompt != Prompt.Batch ||
StandardMain.exchange.withState(ContinuousCommands.isInWatch(_, NetworkChannel.this))
override lazy val isColorEnabled: Boolean = waitForPending(_.isColorEnabled)
override lazy val isSupershellEnabled: Boolean = waitForPending(_.isSupershellEnabled)
getProperties(false)

View File

@ -1,7 +1,17 @@
import sbt.legacy.sources.Build._
Global / watchSources += new sbt.internal.io.Source(baseDirectory.value, "global.txt", NothingFilter, false)
val setStringValue = inputKey[Unit]("set a global string to a value")
val checkStringValue = inputKey[Unit]("check the value of a global")
def setStringValueImpl: Def.Initialize[InputTask[Unit]] = Def.inputTask {
val Seq(stringFile, string) = Def.spaceDelimited().parsed.map(_.trim)
IO.write(file(stringFile), string)
}
def checkStringValueImpl: Def.Initialize[InputTask[Unit]] = Def.inputTask {
val Seq(stringFile, string) = Def.spaceDelimited().parsed
assert(IO.read(file(stringFile)) == string)
}
watchSources in setStringValue += new sbt.internal.io.Source(baseDirectory.value, "foo.txt", NothingFilter, false)
setStringValue := setStringValueImpl.evaluated

View File

@ -1,17 +0,0 @@
package sbt.legacy.sources
import sbt._
import Keys._
object Build {
val setStringValue = inputKey[Unit]("set a global string to a value")
val checkStringValue = inputKey[Unit]("check the value of a global")
def setStringValueImpl: Def.Initialize[InputTask[Unit]] = Def.inputTask {
val Seq(stringFile, string) = Def.spaceDelimited().parsed.map(_.trim)
IO.write(file(stringFile), string)
}
def checkStringValueImpl: Def.Initialize[InputTask[Unit]] = Def.inputTask {
val Seq(stringFile, string) = Def.spaceDelimited().parsed
assert(IO.read(file(stringFile)) == string)
}
}