diff --git a/main-command/src/main/scala/sbt/BasicCommandStrings.scala b/main-command/src/main/scala/sbt/BasicCommandStrings.scala index 939b95e9e..77a3b6fc3 100644 --- a/main-command/src/main/scala/sbt/BasicCommandStrings.scala +++ b/main-command/src/main/scala/sbt/BasicCommandStrings.scala @@ -209,6 +209,11 @@ $AliasCommand name= def FailureWall: String = "resumeFromFailure" + def SetTerminal = "sbtSetTerminal" + def ReportResult = "sbtReportResult" + def CompleteExec = "sbtCompleteExec" + def MapExec = "sbtMapExec" + def ClearOnFailure: String = "sbtClearOnFailure" def OnFailure: String = "onFailure" def OnFailureDetailed: String = diff --git a/main-command/src/main/scala/sbt/BasicCommands.scala b/main-command/src/main/scala/sbt/BasicCommands.scala index cba9c73c1..594ebcca1 100644 --- a/main-command/src/main/scala/sbt/BasicCommands.scala +++ b/main-command/src/main/scala/sbt/BasicCommands.scala @@ -61,6 +61,9 @@ object BasicCommands { client, read, alias, + reportResultsCommand, + mapExecCommand, + completeExecCommand, ) def nop: Command = Command.custom(s => success(() => s)) @@ -544,4 +547,42 @@ object BasicCommands { "is-command-alias", "Internal: marker for Commands created as aliases for another command." ) + + private[sbt] def reportParser(key: String) = + (key: Parser[String]).examples() ~> " ".examples() ~> matched(any.*).examples() + def reportResultsCommand = + Command.arb(_ => reportParser(ReportResult)) { (state, id) => + val newState = state.get(execMap) match { + case Some(m) => state.put(execMap, m - id) + case _ => state + } + newState.get(execResults) match { + case Some(m) if m.contains(id) => state.put(execResults, m - id) + case _ => state.fail + } + } + def mapExecCommand = + Command.arb(_ => reportParser(MapExec)) { (state, mapping) => + mapping.split(" ") match { + case Array(key, value) => + state.get(execMap) match { + case Some(m) => state.put(execMap, m + (key -> value)) + case None => state.put(execMap, Map(key -> value)) + } + case _ => state + } + } + def completeExecCommand = + Command.arb(_ => reportParser(CompleteExec)) { (state, id) => + val newState = state.get(execResults) match { + case Some(m) => state.put(execResults, m + (id -> true)) + case _ => state.put(execResults, Map(id -> true)) + } + newState.get(execMap) match { + case Some(m) => newState.put(execMap, m - id) + case _ => newState + } + } + private[sbt] val execResults = AttributeKey[Map[String, Boolean]]("execResults", Int.MaxValue) + private[sbt] val execMap = AttributeKey[Map[String, String]]("execMap", Int.MaxValue) } diff --git a/main-command/src/main/scala/sbt/State.scala b/main-command/src/main/scala/sbt/State.scala index 784c0f162..fa75749a7 100644 --- a/main-command/src/main/scala/sbt/State.scala +++ b/main-command/src/main/scala/sbt/State.scala @@ -16,6 +16,15 @@ import sbt.internal.inc.classpath.{ ClassLoaderCache => IncClassLoaderCache } import sbt.internal.util.complete.{ HistoryCommands, Parser } import sbt.internal.util._ import sbt.util.Logger +import BasicCommandStrings.{ + CompleteExec, + MapExec, + PopOnFailure, + ReportResult, + SetTerminal, + StashOnFailure, + networkExecPrefix, +} /** * Data structure representing all command execution information. @@ -273,8 +282,43 @@ object State { f(cmd, s1) } s.remainingCommands match { - case Nil => exit(true) - case x :: xs => runCmd(x, xs) + case Nil => exit(true) + case x :: xs => + (x.execId, x.source) match { + /* + * If the command is coming from a network source, it might be a multi-command. To handle + * that, we need to give the command a new exec id and wrap some commands around the + * actual command that are used to report it. To make this work, we add a map of exec + * results as well as a mapping of exec ids to the exec id that spawned the exec. + * We add a command that fills the result map for the original exec. If the command fails, + * that map filling command (called sbtCompleteExec) is skipped so the map is never filled + * for the original event. The report command (called sbtReportResult) checks the result + * map and, if it finds an entry, it succeeds and removes the entry. Otherwise it fails. + * The exec for the report command is given the original exec id so the result reported + * to the client will be the result of the report command (which should correspond to + * the result of the underlying multi-command, which succeeds only if all of the commands + * succeed) + * + */ + case (Some(id), Some(s)) + if s.channelName.startsWith("network") && + !x.commandLine.startsWith(ReportResult) && + !x.commandLine.startsWith(networkExecPrefix) && + !id.startsWith(networkExecPrefix) => + val newID = networkExecPrefix + Exec.newExecId + val cmd = x.withExecId(newID) + val map = Exec(s"$MapExec $id $newID", None) + val complete = Exec(s"$CompleteExec $id", None) + val report = Exec(s"$ReportResult $id", Some(id), x.source) + val stash = Exec(StashOnFailure, None) + val failureWall = Exec(FailureWall, None) + val pop = Exec(PopOnFailure, None) + val setTerm = Exec(s"$SetTerminal ${s.channelName}", None) + val setConsole = Exec(s"$SetTerminal console0", None) + val remaining = stash :: map :: cmd :: complete :: failureWall :: pop :: setConsole :: report :: xs + runCmd(setTerm, remaining) + case _ => runCmd(x, xs) + } } } def :::(newCommands: List[String]): State = ++:(newCommands map { Exec(_, s.source) }) diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala index 520a4e593..10634bb83 100644 --- a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala +++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala @@ -641,6 +641,58 @@ object NetworkClient { new Arguments(new File("").getCanonicalFile, sbtArguments, commandArgs, sbtScript) } + def client( + baseDirectory: File, + args: Array[String], + inputStream: InputStream, + errorStream: PrintStream, + printStream: PrintStream, + useJNI: Boolean + ): Int = { + val client = + simpleClient( + NetworkClient.parseArgs(args).withBaseDirectory(baseDirectory), + inputStream, + errorStream, + printStream, + useJNI, + ) + try { + client.connect(log = true) + client.run() + } catch { case _: Exception => 1 } finally client.close() + } + private def simpleClient( + arguments: Arguments, + inputStream: InputStream, + errorStream: PrintStream, + printStream: PrintStream, + useJNI: Boolean, + ): NetworkClient = + new NetworkClient( + NetworkClient.simpleConsoleInterface(printStream), + arguments, + inputStream, + errorStream, + printStream, + useJNI + ) + def main(useJNI: Boolean, args: Array[String]): Unit = { + val hook = new Thread(() => { + System.out.print(ConsoleAppender.ClearScreenAfterCursor) + System.out.flush() + }) + Runtime.getRuntime.addShutdownHook(hook) + System.exit(Terminal.withStreams { + val base = new File("").getCanonicalFile() + try client(base, args, System.in, System.err, System.out, useJNI) + finally { + Runtime.getRuntime.removeShutdownHook(hook) + hook.run() + } + }) + } + def run(configuration: xsbti.AppConfiguration, arguments: List[String]): Int = try { val client = new NetworkClient(configuration, parseArgs(arguments.toArray)) diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index fe155b856..cd513f456 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -15,7 +15,7 @@ import java.util.Properties import java.util.concurrent.ForkJoinPool import java.util.concurrent.atomic.AtomicBoolean -import sbt.BasicCommandStrings.{ Shell, TemplateCommand, networkExecPrefix } +import sbt.BasicCommandStrings.{ SetTerminal, Shell, TemplateCommand, networkExecPrefix } import sbt.Project.LoadAction import sbt.compiler.EvalImports import sbt.internal.Aggregation.AnyKeys @@ -265,6 +265,7 @@ object BuiltinCommands { continuous, clearCaches, NetworkChannel.disconnect, + setTerminalCommand, ) ++ allBasicCommands def DefaultBootCommands: Seq[String] = @@ -915,6 +916,12 @@ object BuiltinCommands { Command.command(ClearCaches, help)(f) } + def setTerminalCommand = Command.arb(_ => BasicCommands.reportParser(SetTerminal)) { + (s, channel) => + StandardMain.exchange.channelForName(channel).foreach(c => Terminal.set(c.terminal)) + s + } + private def getExec(state: State, interval: Duration): Exec = { val exec: Exec = StandardMain.exchange.blockUntilNextExec(interval, Some(state), state.globalLogging.full) diff --git a/main/src/main/scala/sbt/MainLoop.scala b/main/src/main/scala/sbt/MainLoop.scala index fd3005138..8b144f4d5 100644 --- a/main/src/main/scala/sbt/MainLoop.scala +++ b/main/src/main/scala/sbt/MainLoop.scala @@ -10,7 +10,7 @@ package sbt import java.io.PrintWriter import java.util.Properties -import sbt.BasicCommandStrings.{ StashOnFailure, networkExecPrefix } +import sbt.BasicCommandStrings.{ SetTerminal, StashOnFailure, networkExecPrefix } import sbt.internal.ShutdownHooks import sbt.internal.langserver.ErrorCodes import sbt.internal.protocol.JsonRpcResponseError @@ -219,7 +219,12 @@ object MainLoop { state.get(CheckBuildSourcesKey) match { case Some(cbs) => if (!cbs.needsReload(state, exec.commandLine)) process() - else Exec("reload", None, None) +: exec +: state.remove(CheckBuildSourcesKey) + else { + if (exec.commandLine.startsWith(SetTerminal)) + exec +: Exec("reload", None, None) +: state.remove(CheckBuildSourcesKey) + else + Exec("reload", None, None) +: exec +: state.remove(CheckBuildSourcesKey) + } case _ => process() } } catch { diff --git a/main/src/main/scala/sbt/internal/nio/CheckBuildSources.scala b/main/src/main/scala/sbt/internal/nio/CheckBuildSources.scala index 7a4f7b25a..6cc9bcc38 100644 --- a/main/src/main/scala/sbt/internal/nio/CheckBuildSources.scala +++ b/main/src/main/scala/sbt/internal/nio/CheckBuildSources.scala @@ -83,11 +83,15 @@ private[sbt] class CheckBuildSources extends AutoCloseable { previousStamps.set(getStamps(force = true)) } } - private def needCheck(cmd: String): Boolean = { - val commands = cmd.split(";").flatMap(_.trim.split(" ").headOption).filterNot(_.isEmpty) - val res = !commands.exists { c => + private def needCheck(state: State, cmd: String): Boolean = { + val allCmds = state.remainingCommands + .map(_.commandLine) + .dropWhile(!_.startsWith(BasicCommandStrings.MapExec)) :+ cmd + val commands = + allCmds.flatMap(_.split(";").flatMap(_.trim.split(" ").headOption).filterNot(_.isEmpty)) + val filter = (c: String) => c == LoadProject || c == RebootCommand || c == TerminateAction || c == "shutdown" - } + val res = !commands.exists(filter) if (!res) { previousStamps.set(getStamps(force = true)) needUpdate.set(false) @@ -96,7 +100,7 @@ private[sbt] class CheckBuildSources extends AutoCloseable { } @inline private def forceCheck = fileTreeRepository.isEmpty private[sbt] def needsReload(state: State, cmd: String) = { - (needCheck(cmd) && (forceCheck || needUpdate.compareAndSet(true, false))) && { + (needCheck(state, cmd) && (forceCheck || needUpdate.compareAndSet(true, false))) && { val extracted = Project.extract(state) val onChanges = extracted.get(Global / onChangedBuildSource) val logger = state.globalLogging.full diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index 9aff64443..d3700a666 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -246,12 +246,23 @@ final class NetworkChannel( err: JsonRpcResponseError, execId: Option[String] ): Unit = this.synchronized { + def respond(id: String) = { + pendingRequests -= id + jsonRpcRespondError(id, err) + } + def error(): Unit = logMessage("error", s"Error ${err.code}: ${err.message}") execId match { - case Some(id) if pendingRequests.contains(id) => - pendingRequests -= id - jsonRpcRespondError(id, err) - case _ => - logMessage("error", s"Error ${err.code}: ${err.message}") + case Some(id) if pendingRequests.contains(id) => respond(id) + // This handles multi commands from the network that were remapped to a different + // exec id for reporting purposes. + case Some(id) if id.startsWith(BasicCommandStrings.networkExecPrefix) => + StandardMain.exchange.withState { s => + s.get(BasicCommands.execMap).flatMap(_.collectFirst { case (k, `id`) => k }) match { + case Some(id) if pendingRequests.contains(id) => respond(id) + case _ => error() + } + } + case _ => error() } } @@ -267,14 +278,27 @@ final class NetworkChannel( event: A, execId: Option[String] ): Unit = this.synchronized { + def error(): Unit = { + val msg = + s"unmatched json response for requestId $execId: ${CompactPrinter(Converter.toJsonUnsafe(event))}" + log.debug(msg) + } + def respond(id: String): Unit = { + pendingRequests -= id + jsonRpcRespond(event, id) + } execId match { - case Some(id) if pendingRequests.contains(id) => - pendingRequests -= id - jsonRpcRespond(event, id) - case _ => - log.debug( - s"unmatched json response for requestId $execId: ${CompactPrinter(Converter.toJsonUnsafe(event))}" - ) + case Some(id) if pendingRequests.contains(id) => respond(id) + // This handles multi commands from the network that were remapped to a different + // exec id for reporting purposes. + case Some(id) if id.startsWith(BasicCommandStrings.networkExecPrefix) => + StandardMain.exchange.withState { s => + s.get(BasicCommands.execMap).flatMap(_.collectFirst { case (k, `id`) => k }) match { + case Some(id) if pendingRequests.contains(id) => respond(id) + case _ => error() + } + } + case _ => error() } } @@ -436,6 +460,11 @@ final class NetworkChannel( Option(EvaluateTask.currentlyRunningEngine.get) match { case Some((state, runningEngine)) => val runningExecId = state.currentExecId.getOrElse("") + val expected = StandardMain.exchange.withState( + _.get(BasicCommands.execMap) + .flatMap(s => s.get(crp.id) orElse s.get("\u2668" + crp.id)) + .getOrElse(crp.id) + ) def checkId(): Boolean = { if (runningExecId.startsWith("\u2668")) { @@ -446,7 +475,7 @@ final class NetworkChannel( case (Some(id), Some(eid)) => id == eid case _ => false } - } else runningExecId == crp.id + } else runningExecId == expected } // direct comparison on strings and diff --git a/server-test/src/test/scala/testpkg/ClientTest.scala b/server-test/src/test/scala/testpkg/ClientTest.scala new file mode 100644 index 000000000..a48f623c5 --- /dev/null +++ b/server-test/src/test/scala/testpkg/ClientTest.scala @@ -0,0 +1,46 @@ +package testpkg + +import java.io.{ InputStream, PrintStream } +import sbt.internal.client.NetworkClient + +object ClientTest extends AbstractServerTest { + override val testDirectory: String = "client" + object NullInputStream extends InputStream { + override def read(): Int = { + try this.synchronized(this.wait) + catch { case _: InterruptedException => } + -1 + } + } + val NullPrintStream = new PrintStream(_ => {}, false) + private def client(args: String*) = + NetworkClient.client( + testPath.toFile, + args.toArray, + NullInputStream, + NullPrintStream, + NullPrintStream, + false + ) + test("exit success") { c => + assert(client("willSucceed") == 0) + } + test("exit failure") { _ => + assert(client("willFail") == 1) + } + test("two commands") { _ => + assert(client("compile;willSucceed") == 0) + } + test("two commands with failing second") { _ => + assert(client("compile;willFail") == 1) + } + test("two commands with leading failure") { _ => + assert(client("willFail;willSucceed") == 1) + } + test("three commands") { _ => + assert(client("compile;clean;willSucceed") == 0) + } + test("three commands with middle failure") { _ => + assert(client("compile;willFail;willSucceed") == 1) + } +}