From 18cb839c476b81d21e0ec9255031473848cab67d Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Wed, 24 Jun 2020 15:43:31 -0700 Subject: [PATCH] Wrap network commands for reporting Running multi commands (input commands delimited by semi-colons) did not work with the thin client. The commands would actually run on the server, but the thin client would exit immediately without displaying the output. The reason was that MainLoop would report the exec complete when all it had done was split the original command into its constituent parts and prepended them to the state command list. To work around this, when we detect a network source command, we can remap its exec id to a different id and only report the original exec id after the commands complete. We also have to keep track of whether or not the command succeeded or failed so that the reporting command reports the correct result. The way its implemented is with the the following steps: 1. set the terminal to the network terminal 2. stash the current onFailure so that we can properly report failures 3. add the new exec id to a map of the original exec id to the generated id 4. actually run the command 5. if the command succeeds, add the original exec id to a result map 6. pop the onFailure 7. restore the terminal to console 8. report the result -- if the original exec id is in the result map we report success. Otherwise we report failure. There is also logic in NetworkChannel for finding the original exec id if reporting one of the artificially generated exec ids because the client will not be aware of that id. --- .../main/scala/sbt/BasicCommandStrings.scala | 5 ++ .../src/main/scala/sbt/BasicCommands.scala | 41 ++++++++++++++ main-command/src/main/scala/sbt/State.scala | 48 +++++++++++++++- .../sbt/internal/client/NetworkClient.scala | 52 ++++++++++++++++++ main/src/main/scala/sbt/Main.scala | 9 ++- main/src/main/scala/sbt/MainLoop.scala | 9 ++- .../sbt/internal/nio/CheckBuildSources.scala | 14 +++-- .../sbt/internal/server/NetworkChannel.scala | 55 ++++++++++++++----- .../src/test/scala/testpkg/ClientTest.scala | 46 ++++++++++++++++ 9 files changed, 256 insertions(+), 23 deletions(-) create mode 100644 server-test/src/test/scala/testpkg/ClientTest.scala diff --git a/main-command/src/main/scala/sbt/BasicCommandStrings.scala b/main-command/src/main/scala/sbt/BasicCommandStrings.scala index 939b95e9e..77a3b6fc3 100644 --- a/main-command/src/main/scala/sbt/BasicCommandStrings.scala +++ b/main-command/src/main/scala/sbt/BasicCommandStrings.scala @@ -209,6 +209,11 @@ $AliasCommand name= def FailureWall: String = "resumeFromFailure" + def SetTerminal = "sbtSetTerminal" + def ReportResult = "sbtReportResult" + def CompleteExec = "sbtCompleteExec" + def MapExec = "sbtMapExec" + def ClearOnFailure: String = "sbtClearOnFailure" def OnFailure: String = "onFailure" def OnFailureDetailed: String = diff --git a/main-command/src/main/scala/sbt/BasicCommands.scala b/main-command/src/main/scala/sbt/BasicCommands.scala index cba9c73c1..594ebcca1 100644 --- a/main-command/src/main/scala/sbt/BasicCommands.scala +++ b/main-command/src/main/scala/sbt/BasicCommands.scala @@ -61,6 +61,9 @@ object BasicCommands { client, read, alias, + reportResultsCommand, + mapExecCommand, + completeExecCommand, ) def nop: Command = Command.custom(s => success(() => s)) @@ -544,4 +547,42 @@ object BasicCommands { "is-command-alias", "Internal: marker for Commands created as aliases for another command." ) + + private[sbt] def reportParser(key: String) = + (key: Parser[String]).examples() ~> " ".examples() ~> matched(any.*).examples() + def reportResultsCommand = + Command.arb(_ => reportParser(ReportResult)) { (state, id) => + val newState = state.get(execMap) match { + case Some(m) => state.put(execMap, m - id) + case _ => state + } + newState.get(execResults) match { + case Some(m) if m.contains(id) => state.put(execResults, m - id) + case _ => state.fail + } + } + def mapExecCommand = + Command.arb(_ => reportParser(MapExec)) { (state, mapping) => + mapping.split(" ") match { + case Array(key, value) => + state.get(execMap) match { + case Some(m) => state.put(execMap, m + (key -> value)) + case None => state.put(execMap, Map(key -> value)) + } + case _ => state + } + } + def completeExecCommand = + Command.arb(_ => reportParser(CompleteExec)) { (state, id) => + val newState = state.get(execResults) match { + case Some(m) => state.put(execResults, m + (id -> true)) + case _ => state.put(execResults, Map(id -> true)) + } + newState.get(execMap) match { + case Some(m) => newState.put(execMap, m - id) + case _ => newState + } + } + private[sbt] val execResults = AttributeKey[Map[String, Boolean]]("execResults", Int.MaxValue) + private[sbt] val execMap = AttributeKey[Map[String, String]]("execMap", Int.MaxValue) } diff --git a/main-command/src/main/scala/sbt/State.scala b/main-command/src/main/scala/sbt/State.scala index 784c0f162..fa75749a7 100644 --- a/main-command/src/main/scala/sbt/State.scala +++ b/main-command/src/main/scala/sbt/State.scala @@ -16,6 +16,15 @@ import sbt.internal.inc.classpath.{ ClassLoaderCache => IncClassLoaderCache } import sbt.internal.util.complete.{ HistoryCommands, Parser } import sbt.internal.util._ import sbt.util.Logger +import BasicCommandStrings.{ + CompleteExec, + MapExec, + PopOnFailure, + ReportResult, + SetTerminal, + StashOnFailure, + networkExecPrefix, +} /** * Data structure representing all command execution information. @@ -273,8 +282,43 @@ object State { f(cmd, s1) } s.remainingCommands match { - case Nil => exit(true) - case x :: xs => runCmd(x, xs) + case Nil => exit(true) + case x :: xs => + (x.execId, x.source) match { + /* + * If the command is coming from a network source, it might be a multi-command. To handle + * that, we need to give the command a new exec id and wrap some commands around the + * actual command that are used to report it. To make this work, we add a map of exec + * results as well as a mapping of exec ids to the exec id that spawned the exec. + * We add a command that fills the result map for the original exec. If the command fails, + * that map filling command (called sbtCompleteExec) is skipped so the map is never filled + * for the original event. The report command (called sbtReportResult) checks the result + * map and, if it finds an entry, it succeeds and removes the entry. Otherwise it fails. + * The exec for the report command is given the original exec id so the result reported + * to the client will be the result of the report command (which should correspond to + * the result of the underlying multi-command, which succeeds only if all of the commands + * succeed) + * + */ + case (Some(id), Some(s)) + if s.channelName.startsWith("network") && + !x.commandLine.startsWith(ReportResult) && + !x.commandLine.startsWith(networkExecPrefix) && + !id.startsWith(networkExecPrefix) => + val newID = networkExecPrefix + Exec.newExecId + val cmd = x.withExecId(newID) + val map = Exec(s"$MapExec $id $newID", None) + val complete = Exec(s"$CompleteExec $id", None) + val report = Exec(s"$ReportResult $id", Some(id), x.source) + val stash = Exec(StashOnFailure, None) + val failureWall = Exec(FailureWall, None) + val pop = Exec(PopOnFailure, None) + val setTerm = Exec(s"$SetTerminal ${s.channelName}", None) + val setConsole = Exec(s"$SetTerminal console0", None) + val remaining = stash :: map :: cmd :: complete :: failureWall :: pop :: setConsole :: report :: xs + runCmd(setTerm, remaining) + case _ => runCmd(x, xs) + } } } def :::(newCommands: List[String]): State = ++:(newCommands map { Exec(_, s.source) }) diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala index 520a4e593..10634bb83 100644 --- a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala +++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala @@ -641,6 +641,58 @@ object NetworkClient { new Arguments(new File("").getCanonicalFile, sbtArguments, commandArgs, sbtScript) } + def client( + baseDirectory: File, + args: Array[String], + inputStream: InputStream, + errorStream: PrintStream, + printStream: PrintStream, + useJNI: Boolean + ): Int = { + val client = + simpleClient( + NetworkClient.parseArgs(args).withBaseDirectory(baseDirectory), + inputStream, + errorStream, + printStream, + useJNI, + ) + try { + client.connect(log = true) + client.run() + } catch { case _: Exception => 1 } finally client.close() + } + private def simpleClient( + arguments: Arguments, + inputStream: InputStream, + errorStream: PrintStream, + printStream: PrintStream, + useJNI: Boolean, + ): NetworkClient = + new NetworkClient( + NetworkClient.simpleConsoleInterface(printStream), + arguments, + inputStream, + errorStream, + printStream, + useJNI + ) + def main(useJNI: Boolean, args: Array[String]): Unit = { + val hook = new Thread(() => { + System.out.print(ConsoleAppender.ClearScreenAfterCursor) + System.out.flush() + }) + Runtime.getRuntime.addShutdownHook(hook) + System.exit(Terminal.withStreams { + val base = new File("").getCanonicalFile() + try client(base, args, System.in, System.err, System.out, useJNI) + finally { + Runtime.getRuntime.removeShutdownHook(hook) + hook.run() + } + }) + } + def run(configuration: xsbti.AppConfiguration, arguments: List[String]): Int = try { val client = new NetworkClient(configuration, parseArgs(arguments.toArray)) diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index fe155b856..cd513f456 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -15,7 +15,7 @@ import java.util.Properties import java.util.concurrent.ForkJoinPool import java.util.concurrent.atomic.AtomicBoolean -import sbt.BasicCommandStrings.{ Shell, TemplateCommand, networkExecPrefix } +import sbt.BasicCommandStrings.{ SetTerminal, Shell, TemplateCommand, networkExecPrefix } import sbt.Project.LoadAction import sbt.compiler.EvalImports import sbt.internal.Aggregation.AnyKeys @@ -265,6 +265,7 @@ object BuiltinCommands { continuous, clearCaches, NetworkChannel.disconnect, + setTerminalCommand, ) ++ allBasicCommands def DefaultBootCommands: Seq[String] = @@ -915,6 +916,12 @@ object BuiltinCommands { Command.command(ClearCaches, help)(f) } + def setTerminalCommand = Command.arb(_ => BasicCommands.reportParser(SetTerminal)) { + (s, channel) => + StandardMain.exchange.channelForName(channel).foreach(c => Terminal.set(c.terminal)) + s + } + private def getExec(state: State, interval: Duration): Exec = { val exec: Exec = StandardMain.exchange.blockUntilNextExec(interval, Some(state), state.globalLogging.full) diff --git a/main/src/main/scala/sbt/MainLoop.scala b/main/src/main/scala/sbt/MainLoop.scala index fd3005138..8b144f4d5 100644 --- a/main/src/main/scala/sbt/MainLoop.scala +++ b/main/src/main/scala/sbt/MainLoop.scala @@ -10,7 +10,7 @@ package sbt import java.io.PrintWriter import java.util.Properties -import sbt.BasicCommandStrings.{ StashOnFailure, networkExecPrefix } +import sbt.BasicCommandStrings.{ SetTerminal, StashOnFailure, networkExecPrefix } import sbt.internal.ShutdownHooks import sbt.internal.langserver.ErrorCodes import sbt.internal.protocol.JsonRpcResponseError @@ -219,7 +219,12 @@ object MainLoop { state.get(CheckBuildSourcesKey) match { case Some(cbs) => if (!cbs.needsReload(state, exec.commandLine)) process() - else Exec("reload", None, None) +: exec +: state.remove(CheckBuildSourcesKey) + else { + if (exec.commandLine.startsWith(SetTerminal)) + exec +: Exec("reload", None, None) +: state.remove(CheckBuildSourcesKey) + else + Exec("reload", None, None) +: exec +: state.remove(CheckBuildSourcesKey) + } case _ => process() } } catch { diff --git a/main/src/main/scala/sbt/internal/nio/CheckBuildSources.scala b/main/src/main/scala/sbt/internal/nio/CheckBuildSources.scala index 7a4f7b25a..6cc9bcc38 100644 --- a/main/src/main/scala/sbt/internal/nio/CheckBuildSources.scala +++ b/main/src/main/scala/sbt/internal/nio/CheckBuildSources.scala @@ -83,11 +83,15 @@ private[sbt] class CheckBuildSources extends AutoCloseable { previousStamps.set(getStamps(force = true)) } } - private def needCheck(cmd: String): Boolean = { - val commands = cmd.split(";").flatMap(_.trim.split(" ").headOption).filterNot(_.isEmpty) - val res = !commands.exists { c => + private def needCheck(state: State, cmd: String): Boolean = { + val allCmds = state.remainingCommands + .map(_.commandLine) + .dropWhile(!_.startsWith(BasicCommandStrings.MapExec)) :+ cmd + val commands = + allCmds.flatMap(_.split(";").flatMap(_.trim.split(" ").headOption).filterNot(_.isEmpty)) + val filter = (c: String) => c == LoadProject || c == RebootCommand || c == TerminateAction || c == "shutdown" - } + val res = !commands.exists(filter) if (!res) { previousStamps.set(getStamps(force = true)) needUpdate.set(false) @@ -96,7 +100,7 @@ private[sbt] class CheckBuildSources extends AutoCloseable { } @inline private def forceCheck = fileTreeRepository.isEmpty private[sbt] def needsReload(state: State, cmd: String) = { - (needCheck(cmd) && (forceCheck || needUpdate.compareAndSet(true, false))) && { + (needCheck(state, cmd) && (forceCheck || needUpdate.compareAndSet(true, false))) && { val extracted = Project.extract(state) val onChanges = extracted.get(Global / onChangedBuildSource) val logger = state.globalLogging.full diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index 9aff64443..d3700a666 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -246,12 +246,23 @@ final class NetworkChannel( err: JsonRpcResponseError, execId: Option[String] ): Unit = this.synchronized { + def respond(id: String) = { + pendingRequests -= id + jsonRpcRespondError(id, err) + } + def error(): Unit = logMessage("error", s"Error ${err.code}: ${err.message}") execId match { - case Some(id) if pendingRequests.contains(id) => - pendingRequests -= id - jsonRpcRespondError(id, err) - case _ => - logMessage("error", s"Error ${err.code}: ${err.message}") + case Some(id) if pendingRequests.contains(id) => respond(id) + // This handles multi commands from the network that were remapped to a different + // exec id for reporting purposes. + case Some(id) if id.startsWith(BasicCommandStrings.networkExecPrefix) => + StandardMain.exchange.withState { s => + s.get(BasicCommands.execMap).flatMap(_.collectFirst { case (k, `id`) => k }) match { + case Some(id) if pendingRequests.contains(id) => respond(id) + case _ => error() + } + } + case _ => error() } } @@ -267,14 +278,27 @@ final class NetworkChannel( event: A, execId: Option[String] ): Unit = this.synchronized { + def error(): Unit = { + val msg = + s"unmatched json response for requestId $execId: ${CompactPrinter(Converter.toJsonUnsafe(event))}" + log.debug(msg) + } + def respond(id: String): Unit = { + pendingRequests -= id + jsonRpcRespond(event, id) + } execId match { - case Some(id) if pendingRequests.contains(id) => - pendingRequests -= id - jsonRpcRespond(event, id) - case _ => - log.debug( - s"unmatched json response for requestId $execId: ${CompactPrinter(Converter.toJsonUnsafe(event))}" - ) + case Some(id) if pendingRequests.contains(id) => respond(id) + // This handles multi commands from the network that were remapped to a different + // exec id for reporting purposes. + case Some(id) if id.startsWith(BasicCommandStrings.networkExecPrefix) => + StandardMain.exchange.withState { s => + s.get(BasicCommands.execMap).flatMap(_.collectFirst { case (k, `id`) => k }) match { + case Some(id) if pendingRequests.contains(id) => respond(id) + case _ => error() + } + } + case _ => error() } } @@ -436,6 +460,11 @@ final class NetworkChannel( Option(EvaluateTask.currentlyRunningEngine.get) match { case Some((state, runningEngine)) => val runningExecId = state.currentExecId.getOrElse("") + val expected = StandardMain.exchange.withState( + _.get(BasicCommands.execMap) + .flatMap(s => s.get(crp.id) orElse s.get("\u2668" + crp.id)) + .getOrElse(crp.id) + ) def checkId(): Boolean = { if (runningExecId.startsWith("\u2668")) { @@ -446,7 +475,7 @@ final class NetworkChannel( case (Some(id), Some(eid)) => id == eid case _ => false } - } else runningExecId == crp.id + } else runningExecId == expected } // direct comparison on strings and diff --git a/server-test/src/test/scala/testpkg/ClientTest.scala b/server-test/src/test/scala/testpkg/ClientTest.scala new file mode 100644 index 000000000..a48f623c5 --- /dev/null +++ b/server-test/src/test/scala/testpkg/ClientTest.scala @@ -0,0 +1,46 @@ +package testpkg + +import java.io.{ InputStream, PrintStream } +import sbt.internal.client.NetworkClient + +object ClientTest extends AbstractServerTest { + override val testDirectory: String = "client" + object NullInputStream extends InputStream { + override def read(): Int = { + try this.synchronized(this.wait) + catch { case _: InterruptedException => } + -1 + } + } + val NullPrintStream = new PrintStream(_ => {}, false) + private def client(args: String*) = + NetworkClient.client( + testPath.toFile, + args.toArray, + NullInputStream, + NullPrintStream, + NullPrintStream, + false + ) + test("exit success") { c => + assert(client("willSucceed") == 0) + } + test("exit failure") { _ => + assert(client("willFail") == 1) + } + test("two commands") { _ => + assert(client("compile;willSucceed") == 0) + } + test("two commands with failing second") { _ => + assert(client("compile;willFail") == 1) + } + test("two commands with leading failure") { _ => + assert(client("willFail;willSucceed") == 1) + } + test("three commands") { _ => + assert(client("compile;clean;willSucceed") == 0) + } + test("three commands with middle failure") { _ => + assert(client("compile;willFail;willSucceed") == 1) + } +}