From f9d5fbf29b3dc2d090171b3d9ea646076b50334f Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Sat, 27 Jun 2020 15:54:50 -0700 Subject: [PATCH] Support reboot from remote client Reboot is a bit tricky for the remote client because the sbt server is actually shut down during reboot. When sbt shuts down the client, it can notify the client that the reason is a reboot. The client can then connect to the recently introduced boot control socket to display the reboot output and supply input in case the build fails to load. Once the server has brought back up the server, the client can reconnect. When the client session is interactive, we're done once we reconnect. When it's a batch session, the client needs to resend the remaing commands that have submitted that it hasn't yet run. --- .../main/scala/sbt/BasicCommandStrings.scala | 2 + .../src/main/scala/sbt/BasicCommands.scala | 7 + main-command/src/main/scala/sbt/State.scala | 12 +- .../sbt/internal/client/NetworkClient.scala | 140 +++++++++++++----- .../scala/sbt/internal/server/Server.scala | 3 +- main/src/main/scala/sbt/Main.scala | 13 +- .../scala/sbt/internal/CommandExchange.scala | 35 ++++- .../sbt/internal/nio/CheckBuildSources.scala | 3 +- .../sbt/internal/server/NetworkChannel.scala | 10 +- 9 files changed, 175 insertions(+), 50 deletions(-) diff --git a/main-command/src/main/scala/sbt/BasicCommandStrings.scala b/main-command/src/main/scala/sbt/BasicCommandStrings.scala index 1c5de693e..d00b756be 100644 --- a/main-command/src/main/scala/sbt/BasicCommandStrings.scala +++ b/main-command/src/main/scala/sbt/BasicCommandStrings.scala @@ -137,6 +137,8 @@ $HelpCommand If a classpath is provided, modules are loaded from a new class loader for this classpath. """ + private[sbt] def RebootNetwork: String = "sbtRebootNetwork" + private[sbt] def RebootImpl: String = "sbtRebootImpl" def RebootCommand: String = "reboot" def RebootDetailed: String = RebootCommand + """ [dev | full] diff --git a/main-command/src/main/scala/sbt/BasicCommands.scala b/main-command/src/main/scala/sbt/BasicCommands.scala index db0812aaf..6ea7a6f9c 100644 --- a/main-command/src/main/scala/sbt/BasicCommands.scala +++ b/main-command/src/main/scala/sbt/BasicCommands.scala @@ -53,6 +53,7 @@ object BasicCommands { stashOnFailure, popOnFailure, reboot, + rebootImpl, call, early, exit, @@ -304,6 +305,12 @@ object BasicCommands { def reboot: Command = Command(RebootCommand, Help.more(RebootCommand, RebootDetailed))(_ => rebootOptionParser) { + case (s, (full, currentOnly)) => + val option = if (full) " full" else if (currentOnly) " dev" else "" + RebootNetwork :: s"$RebootImpl$option" :: s + } + def rebootImpl: Command = + Command.arb(_ => (RebootImpl ~> rebootOptionParser).examples()) { case (s, (full, currentOnly)) => s.reboot(full, currentOnly) } diff --git a/main-command/src/main/scala/sbt/State.scala b/main-command/src/main/scala/sbt/State.scala index fa75749a7..a0d21b949 100644 --- a/main-command/src/main/scala/sbt/State.scala +++ b/main-command/src/main/scala/sbt/State.scala @@ -22,6 +22,7 @@ import BasicCommandStrings.{ PopOnFailure, ReportResult, SetTerminal, + StartServer, StashOnFailure, networkExecPrefix, } @@ -339,9 +340,14 @@ object State { /** Implementation of reboot. */ private[sbt] def reboot(full: Boolean, currentOnly: Boolean): State = { runExitHooks() - val rs = s.remainingCommands map { case e: Exec => e.commandLine } - if (currentOnly) throw new RebootCurrent(rs) - else throw new xsbti.FullReload(rs.toArray, full) + val remaining: List[String] = s.remainingCommands.map(_.commandLine) + val fullRemaining = s.source match { + case Some(s) if s.channelName.startsWith("network") => + StartServer :: remaining.dropWhile(!_.startsWith(ReportResult)).tail ::: "shell" :: Nil + case _ => remaining + } + if (currentOnly) throw new RebootCurrent(fullRemaining) + else throw new xsbti.FullReload(fullRemaining.toArray, full) } def reload = runExitHooks().setNext(new Return(defaultReload(s))) diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala index 6b9f3d8a5..2acacb441 100644 --- a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala +++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala @@ -85,7 +85,8 @@ class NetworkClient( private val status = new AtomicReference("Ready") private val lock: AnyRef = new AnyRef {} private val running = new AtomicBoolean(true) - private val pendingResults = new ConcurrentHashMap[String, (LinkedBlockingQueue[Integer], Long)] + private val pendingResults = + new ConcurrentHashMap[String, (LinkedBlockingQueue[Integer], Long, String)] private val pendingCancellations = new ConcurrentHashMap[String, LinkedBlockingQueue[Boolean]] private val pendingCompletions = new ConcurrentHashMap[String, CompletionResponse => Unit] private val attached = new AtomicBoolean(false) @@ -93,6 +94,7 @@ class NetworkClient( private val connectionHolder = new AtomicReference[ServerConnection] private val batchMode = new AtomicBoolean(false) private val interactiveThread = new AtomicReference[Thread](null) + private val rebooting = new AtomicBoolean(false) private lazy val noTab = arguments.completionArguments.contains("--no-tab") private lazy val noStdErr = arguments.completionArguments.contains("--no-stderr") && System.getenv("SBTC_AUTO_COMPLETE") == null @@ -109,6 +111,7 @@ class NetworkClient( } private[this] val stdinBytes = new LinkedBlockingQueue[Int] + private[this] val inLock = new Object private[this] val inputThread = new AtomicReference(new RawInputThread) private[this] val exitClean = new AtomicBoolean(true) private[this] val sbtProcess = new AtomicReference[Process](null) @@ -123,17 +126,17 @@ class NetworkClient( val msg = if (noTab) "" else "No sbt server is running. Press to start one..." errorStream.print(s"\n$msg") if (noStdErr) System.exit(0) - else if (noTab) forkServer(portfile, log = true) + else if (noTab) waitForServer(portfile, log = true, startServer = true) else { stdinBytes.take match { case 9 => errorStream.println("\nStarting server...") - forkServer(portfile, !prompt) + waitForServer(portfile, !prompt, startServer = true) case _ => System.exit(0) } } } else { - forkServer(portfile, log = true) + waitForServer(portfile, log = true, startServer = true) } } @tailrec def connect(attempt: Int): (Socket, Option[String]) = { @@ -155,32 +158,66 @@ class NetworkClient( val (sk, tkn) = connect(0) val conn = new ServerConnection(sk) { override def onNotification(msg: JsonRpcNotificationMessage): Unit = { - if (msg.toString.contains("shutdown")) System.err.println(msg) msg.method match { case `Shutdown` => - val log = msg.params match { - case Some(jvalue) => Converter.fromJson[Boolean](jvalue).getOrElse(true) - case _ => false + val (log, rebootCommands) = msg.params match { + case Some(jvalue) => + Converter + .fromJson[(Boolean, Option[(String, String)])](jvalue) + .getOrElse((true, None)) + case _ => (false, None) } - if (running.compareAndSet(true, false) && log) { - if (!arguments.commandArguments.contains(Shutdown)) { - if (Terminal.console.getLastLine.fold(true)(_.nonEmpty)) errorStream.println() - console.appendLog(Level.Error, "sbt server disconnected") - exitClean.set(false) + if (rebootCommands.nonEmpty) { + if (Terminal.console.getLastLine.isDefined) Terminal.console.printStream.println() + rebooting.set(true) + attached.set(false) + connectionHolder.getAndSet(null) match { + case null => + case c => c.shutdown() + } + waitForServer(portfile, true, false) + init(prompt = false, retry = false) + attachUUID.set(sendJson(attach, s"""{"interactive": ${!batchMode.get}}""")) + rebooting.set(false) + rebootCommands match { + case Some((execId, cmd)) if execId.nonEmpty => + if (batchMode.get && !pendingResults.contains(execId) && cmd.isEmpty) { + console.appendLog( + Level.Error, + s"received request to re-run unknown command '$cmd' after reboot" + ) + } else if (cmd.nonEmpty) { + if (batchMode.get) sendCommand(ExecCommand(cmd, execId)) + else + inLock.synchronized { + val toSend = cmd.getBytes :+ '\r'.toByte + toSend.foreach(b => sendNotification(systemIn, b.toString)) + } + } else completeExec(execId, 0) + case _ => } } else { - console.appendLog(Level.Info, "sbt server disconnected") + if (!rebooting.get() && running.compareAndSet(true, false) && log) { + if (!arguments.commandArguments.contains(Shutdown)) { + if (Terminal.console.getLastLine.isDefined) + Terminal.console.printStream.println() + console.appendLog(Level.Error, "sbt server disconnected") + exitClean.set(false) + } + } else { + console.appendLog(Level.Info, s"${if (log) "sbt server " else ""}disconnected") + } + stdinBytes.offer(-1) + Option(inputThread.get).foreach(_.close()) + Option(interactiveThread.get).foreach(_.interrupt) } - stdinBytes.offer(-1) - Option(inputThread.get).foreach(_.close()) - Option(interactiveThread.get).foreach(_.interrupt) case "readInput" => case _ => self.onNotification(msg) } } override def onRequest(msg: JsonRpcRequestMessage): Unit = self.onRequest(msg) override def onResponse(msg: JsonRpcResponseMessage): Unit = self.onResponse(msg) - override def onShutdown(): Unit = { + override def onShutdown(): Unit = if (!rebooting.get) { if (exitClean.get != false) exitClean.set(!running.get) running.set(false) Option(interactiveThread.get).foreach(_.interrupt()) @@ -202,14 +239,22 @@ class NetworkClient( * Forks another instance of sbt in the background. * This instance must be shutdown explicitly via `sbt -client shutdown` */ - def forkServer(portfile: File, log: Boolean): Unit = { + def waitForServer(portfile: File, log: Boolean, startServer: Boolean): Unit = { val bootSocketName = BootServerSocket.socketLocation(arguments.baseDirectory.toPath.toRealPath()) - var socket: Option[Socket] = Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption + + /* + * For unknown reasons, linux sometimes struggles to connect to the socket in some + * scenarios. + */ + var socket: Option[Socket] = + if (!Properties.isLinux) Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption + else None val process = socket match { - case None => + case None if startServer => val term = Terminal.console if (log) console.appendLog(Level.Info, "server was not detected. starting an instance") + val props = Seq( term.getWidth, @@ -229,12 +274,21 @@ class NetworkClient( sbtProcess.set(process) Some(process) case _ => - if (log) console.appendLog(Level.Info, "sbt server is booting up") + if (log) { + if (Terminal.console.getLastLine.isDefined) Terminal.console.printStream.println() + console.appendLog(Level.Info, "sbt server is booting up") + } None } + if (!startServer) { + val deadline = 5.seconds.fromNow + while (socket.isEmpty && !deadline.isOverdue) { + socket = Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption + if (socket.isEmpty) Thread.sleep(20) + } + } val hook = new Thread(() => Option(sbtProcess.get).foreach(_.destroyForcibly())) Runtime.getRuntime.addShutdownHook(hook) - val isWin = Properties.isWin var gotInputBack = false val readThreadAlive = new AtomicBoolean(true) /* @@ -248,6 +302,9 @@ class NetworkClient( override def run(): Unit = { try { while (readThreadAlive.get) { + if (socket.isEmpty) { + socket = Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption + } socket.foreach { s => try { s.getInputStream.read match { @@ -272,9 +329,6 @@ class NetworkClient( } @tailrec def blockUntilStart(): Unit = { - if (socket.isEmpty) { - socket = Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption - } val stop = try { socket match { case None => @@ -311,8 +365,9 @@ class NetworkClient( * will return with exit value 2. In that case, we can treat the process as alive * even if it is actually dead. */ - val existsValidProcess = process.fold(socket.isDefined)(p => p.isAlive || p.exitValue == 2) - if (!portfile.exists && !stop && readThreadAlive.get && existsValidProcess) { + val existsValidProcess = + process.fold(readThreadAlive.get)(p => p.isAlive || (Properties.isWin || p.exitValue == 2)) + if (!portfile.exists && !stop && existsValidProcess) { blockUntilStart() } else { socket.foreach { s => @@ -367,19 +422,22 @@ class NetworkClient( .getOrElse(1) case _ => 1 } - def onResponse(msg: JsonRpcResponseMessage): Unit = { - pendingResults.remove(msg.id) match { + private def completeExec(execId: String, exitCode: => Int): Unit = + pendingResults.remove(execId) match { case null => - case (q, startTime) => + case (q, startTime, name) => val now = System.currentTimeMillis val message = timing(startTime, now) - val exitCode = getExitCode(msg.result) + val ec = exitCode if (batchMode.get || !attached.get) { - if (exitCode == 0) console.success(message) - else if (!attached.get) console.appendLog(Level.Error, message) + console.appendLog(Level.Info, s"$name completed") + if (ec == 0) console.success(message) + else console.appendLog(Level.Error, message) } - q.offer(exitCode) + Util.ignoreResult(q.offer(ec)) } + def onResponse(msg: JsonRpcResponseMessage): Unit = { + completeExec(msg.id, getExitCode(msg.result)) pendingCancellations.remove(msg.id) match { case null => case q => q.offer(msg.toString.contains("Task cancelled")) @@ -681,7 +739,7 @@ class NetworkClient( val execId = UUID.randomUUID.toString val queue = new LinkedBlockingQueue[Integer] sendCommand(ExecCommand(commandLine, execId)) - pendingResults.put(execId, (queue, System.currentTimeMillis)) + pendingResults.put(execId, (queue, System.currentTimeMillis, commandLine)) queue } @@ -718,9 +776,12 @@ class NetworkClient( } def sendJson(method: String, params: String): String = { val uuid = UUID.randomUUID.toString + sendJson(method, params, uuid) + uuid + } + def sendJson(method: String, params: String, uuid: String): Unit = { val msg = s"""{ "jsonrpc": "2.0", "id": "$uuid", "method": "$method", "params": $params }""" connection.sendString(msg) - uuid } def sendNotification(method: String, params: String): Unit = { @@ -746,13 +807,12 @@ class NetworkClient( setDaemon(true) start() val stopped = new AtomicBoolean(false) - val lock = new Object override final def run(): Unit = { @tailrec def read(): Unit = { inputStream.read match { case -1 => case b => - lock.synchronized(stdinBytes.offer(b)) + inLock.synchronized(stdinBytes.offer(b)) if (attached.get()) drain() if (!stopped.get()) read() } @@ -761,7 +821,7 @@ class NetworkClient( catch { case _: InterruptedException | _: ClosedChannelException => stopped.set(true) } } - def drain(): Unit = lock.synchronized { + def drain(): Unit = inLock.synchronized { while (!stdinBytes.isEmpty) { val byte = stdinBytes.poll() sendNotification(systemIn, byte.toString) diff --git a/main-command/src/main/scala/sbt/internal/server/Server.scala b/main-command/src/main/scala/sbt/internal/server/Server.scala index 74376ea49..10d0a519a 100644 --- a/main-command/src/main/scala/sbt/internal/server/Server.scala +++ b/main-command/src/main/scala/sbt/internal/server/Server.scala @@ -106,7 +106,8 @@ private[sbt] object Server { val socket = serverSocket.accept() onIncomingSocket(socket, self) } catch { - case _: SocketTimeoutException => // its ok + case e: IOException if e.getMessage.contains("connect") => + case _: SocketTimeoutException => // its ok } } serverSocketHolder.get match { diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index 56a582cdc..3d886a3ca 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -111,7 +111,8 @@ private[sbt] object xMain { ): (Option[BootServerSocket], Option[Exit]) = try (Some(new BootServerSocket(configuration)) -> None) catch { - case _: ServerAlreadyBootingException if System.console != null => + case _: ServerAlreadyBootingException + if System.console != null && !Terminal.startedByRemoteClient => println("sbt server is already booting. Create a new server? y/n (default y)") val exit = Terminal.get.withRawSystemIn(System.in.read) match { case 110 => Some(Exit(1)) @@ -296,6 +297,7 @@ object BuiltinCommands { skipBanner, notifyUsersAboutShell, shell, + rebootNetwork, startServer, eval, last, @@ -1046,6 +1048,11 @@ object BuiltinCommands { } } + def rebootNetwork: Command = Command.arb(_ => (RebootNetwork: Parser[String]).examples()) { + (s, _) => + StandardMain.exchange.reboot(s) + s + } def startServer: Command = Command.command(StartServer, Help.more(StartServer, StartServerDetailed)) { s0 => val exchange = StandardMain.exchange @@ -1115,10 +1122,12 @@ object BuiltinCommands { private def intendsToInvokeCompile(state: State) = state.remainingCommands exists (_.commandLine == Keys.compile.key.label) + private def hasRebooted(state: State) = + state.remainingCommands exists (_.commandLine == StartServer) private def notifyUsersAboutShell(state: State): Unit = { val suppress = Project extract state getOpt Keys.suppressSbtShellNotification getOrElse false - if (!suppress && intendsToInvokeCompile(state)) + if (!suppress && intendsToInvokeCompile(state) && !hasRebooted(state)) state.log info "Executing in batch mode. For better performance use sbt's shell" } diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 358488605..29091ce96 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -13,7 +13,13 @@ import java.net.Socket import java.util.concurrent.atomic._ import java.util.concurrent.{ LinkedBlockingQueue, TimeUnit } -import sbt.BasicCommandStrings.{ Cancel, Shutdown, TerminateAction, networkExecPrefix } +import sbt.BasicCommandStrings.{ + Cancel, + CompleteExec, + Shutdown, + TerminateAction, + networkExecPrefix +} import sbt.BasicKeys._ import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.server._ @@ -377,6 +383,33 @@ private[sbt] final class CommandExchange { channels.foreach(c => ProgressState.updateProgressState(newPE, c.terminal)) } + /** + * When a reboot is initiated by a network client, we need to communicate + * to it which + * + * @param state + */ + private[sbt] def reboot(state: State): Unit = state.source match { + case Some(s) if s.channelName.startsWith("network") => + channels.foreach { + case nc: NetworkChannel if nc.name == s.channelName => + val remainingCommands = + state.remainingCommands + .takeWhile(!_.commandLine.startsWith(CompleteExec)) + .map(_.commandLine) + .filterNot(_.startsWith("sbtReboot")) + .mkString(";") + val execId = state.remainingCommands.collectFirst { + case e if e.commandLine.startsWith(CompleteExec) => + e.commandLine.split(CompleteExec).last.trim + } + nc.shutdown(true, execId.map(_ -> remainingCommands)) + case nc: NetworkChannel => nc.shutdown(true, Some(("", ""))) + case _ => + } + case _ => + } + private[sbt] def shutdown(name: String): Unit = { Option(currentExecRef.get).foreach(cancel) commandQueue.clear() diff --git a/main/src/main/scala/sbt/internal/nio/CheckBuildSources.scala b/main/src/main/scala/sbt/internal/nio/CheckBuildSources.scala index 363aa7736..928f0ca21 100644 --- a/main/src/main/scala/sbt/internal/nio/CheckBuildSources.scala +++ b/main/src/main/scala/sbt/internal/nio/CheckBuildSources.scala @@ -90,7 +90,8 @@ private[sbt] class CheckBuildSources extends AutoCloseable { val commands = allCmds.flatMap(_.split(";").flatMap(_.trim.split(" ").headOption).filterNot(_.isEmpty)) val filter = (c: String) => - c == LoadProject || c == RebootCommand || c == TerminateAction || c == Shutdown + c == LoadProject || c == RebootCommand || c == TerminateAction || c == Shutdown || + c.startsWith("sbtReboot") val res = !commands.exists(filter) if (!res) { previousStamps.set(getStamps(force = true)) diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index 549ccd022..50b3c7bb6 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -551,6 +551,9 @@ final class NetworkChannel( } import sjsonnew.BasicJsonProtocol.BooleanJsonFormat + override def shutdown(logShutdown: Boolean): Unit = + shutdown(logShutdown, remainingCommands = None) + /** * Closes down the channel. Before closing the socket, it sends a notification to * the client to shutdown. If the client initiated the shutdown, we don't want the @@ -559,13 +562,16 @@ final class NetworkChannel( * easily be done client side because when the client is in interactive session, * it doesn't know commands it has sent to the server. */ - override def shutdown(logShutdown: Boolean): Unit = { + private[sbt] def shutdown( + logShutdown: Boolean, + remainingCommands: Option[(String, String)] + ): Unit = { terminal.close() StandardMain.exchange.removeChannel(this) super.shutdown(logShutdown) if (logShutdown) Terminal.consoleLog(s"shutting down client connection $name") VirtualTerminal.cancelRequests(name) - try jsonRpcNotify(Shutdown, logShutdown) + try jsonRpcNotify(Shutdown, (logShutdown, remainingCommands)) catch { case _: IOException => } running.set(false) out.close()