diff --git a/main-command/src/main/scala/sbt/BasicCommandStrings.scala b/main-command/src/main/scala/sbt/BasicCommandStrings.scala index 1c5de693e..d00b756be 100644 --- a/main-command/src/main/scala/sbt/BasicCommandStrings.scala +++ b/main-command/src/main/scala/sbt/BasicCommandStrings.scala @@ -137,6 +137,8 @@ $HelpCommand If a classpath is provided, modules are loaded from a new class loader for this classpath. """ + private[sbt] def RebootNetwork: String = "sbtRebootNetwork" + private[sbt] def RebootImpl: String = "sbtRebootImpl" def RebootCommand: String = "reboot" def RebootDetailed: String = RebootCommand + """ [dev | full] diff --git a/main-command/src/main/scala/sbt/BasicCommands.scala b/main-command/src/main/scala/sbt/BasicCommands.scala index db0812aaf..6ea7a6f9c 100644 --- a/main-command/src/main/scala/sbt/BasicCommands.scala +++ b/main-command/src/main/scala/sbt/BasicCommands.scala @@ -53,6 +53,7 @@ object BasicCommands { stashOnFailure, popOnFailure, reboot, + rebootImpl, call, early, exit, @@ -304,6 +305,12 @@ object BasicCommands { def reboot: Command = Command(RebootCommand, Help.more(RebootCommand, RebootDetailed))(_ => rebootOptionParser) { + case (s, (full, currentOnly)) => + val option = if (full) " full" else if (currentOnly) " dev" else "" + RebootNetwork :: s"$RebootImpl$option" :: s + } + def rebootImpl: Command = + Command.arb(_ => (RebootImpl ~> rebootOptionParser).examples()) { case (s, (full, currentOnly)) => s.reboot(full, currentOnly) } diff --git a/main-command/src/main/scala/sbt/State.scala b/main-command/src/main/scala/sbt/State.scala index fa75749a7..a0d21b949 100644 --- a/main-command/src/main/scala/sbt/State.scala +++ b/main-command/src/main/scala/sbt/State.scala @@ -22,6 +22,7 @@ import BasicCommandStrings.{ PopOnFailure, ReportResult, SetTerminal, + StartServer, StashOnFailure, networkExecPrefix, } @@ -339,9 +340,14 @@ object State { /** Implementation of reboot. */ private[sbt] def reboot(full: Boolean, currentOnly: Boolean): State = { runExitHooks() - val rs = s.remainingCommands map { case e: Exec => e.commandLine } - if (currentOnly) throw new RebootCurrent(rs) - else throw new xsbti.FullReload(rs.toArray, full) + val remaining: List[String] = s.remainingCommands.map(_.commandLine) + val fullRemaining = s.source match { + case Some(s) if s.channelName.startsWith("network") => + StartServer :: remaining.dropWhile(!_.startsWith(ReportResult)).tail ::: "shell" :: Nil + case _ => remaining + } + if (currentOnly) throw new RebootCurrent(fullRemaining) + else throw new xsbti.FullReload(fullRemaining.toArray, full) } def reload = runExitHooks().setNext(new Return(defaultReload(s))) diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala index 6b9f3d8a5..2acacb441 100644 --- a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala +++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala @@ -85,7 +85,8 @@ class NetworkClient( private val status = new AtomicReference("Ready") private val lock: AnyRef = new AnyRef {} private val running = new AtomicBoolean(true) - private val pendingResults = new ConcurrentHashMap[String, (LinkedBlockingQueue[Integer], Long)] + private val pendingResults = + new ConcurrentHashMap[String, (LinkedBlockingQueue[Integer], Long, String)] private val pendingCancellations = new ConcurrentHashMap[String, LinkedBlockingQueue[Boolean]] private val pendingCompletions = new ConcurrentHashMap[String, CompletionResponse => Unit] private val attached = new AtomicBoolean(false) @@ -93,6 +94,7 @@ class NetworkClient( private val connectionHolder = new AtomicReference[ServerConnection] private val batchMode = new AtomicBoolean(false) private val interactiveThread = new AtomicReference[Thread](null) + private val rebooting = new AtomicBoolean(false) private lazy val noTab = arguments.completionArguments.contains("--no-tab") private lazy val noStdErr = arguments.completionArguments.contains("--no-stderr") && System.getenv("SBTC_AUTO_COMPLETE") == null @@ -109,6 +111,7 @@ class NetworkClient( } private[this] val stdinBytes = new LinkedBlockingQueue[Int] + private[this] val inLock = new Object private[this] val inputThread = new AtomicReference(new RawInputThread) private[this] val exitClean = new AtomicBoolean(true) private[this] val sbtProcess = new AtomicReference[Process](null) @@ -123,17 +126,17 @@ class NetworkClient( val msg = if (noTab) "" else "No sbt server is running. Press to start one..." errorStream.print(s"\n$msg") if (noStdErr) System.exit(0) - else if (noTab) forkServer(portfile, log = true) + else if (noTab) waitForServer(portfile, log = true, startServer = true) else { stdinBytes.take match { case 9 => errorStream.println("\nStarting server...") - forkServer(portfile, !prompt) + waitForServer(portfile, !prompt, startServer = true) case _ => System.exit(0) } } } else { - forkServer(portfile, log = true) + waitForServer(portfile, log = true, startServer = true) } } @tailrec def connect(attempt: Int): (Socket, Option[String]) = { @@ -155,32 +158,66 @@ class NetworkClient( val (sk, tkn) = connect(0) val conn = new ServerConnection(sk) { override def onNotification(msg: JsonRpcNotificationMessage): Unit = { - if (msg.toString.contains("shutdown")) System.err.println(msg) msg.method match { case `Shutdown` => - val log = msg.params match { - case Some(jvalue) => Converter.fromJson[Boolean](jvalue).getOrElse(true) - case _ => false + val (log, rebootCommands) = msg.params match { + case Some(jvalue) => + Converter + .fromJson[(Boolean, Option[(String, String)])](jvalue) + .getOrElse((true, None)) + case _ => (false, None) } - if (running.compareAndSet(true, false) && log) { - if (!arguments.commandArguments.contains(Shutdown)) { - if (Terminal.console.getLastLine.fold(true)(_.nonEmpty)) errorStream.println() - console.appendLog(Level.Error, "sbt server disconnected") - exitClean.set(false) + if (rebootCommands.nonEmpty) { + if (Terminal.console.getLastLine.isDefined) Terminal.console.printStream.println() + rebooting.set(true) + attached.set(false) + connectionHolder.getAndSet(null) match { + case null => + case c => c.shutdown() + } + waitForServer(portfile, true, false) + init(prompt = false, retry = false) + attachUUID.set(sendJson(attach, s"""{"interactive": ${!batchMode.get}}""")) + rebooting.set(false) + rebootCommands match { + case Some((execId, cmd)) if execId.nonEmpty => + if (batchMode.get && !pendingResults.contains(execId) && cmd.isEmpty) { + console.appendLog( + Level.Error, + s"received request to re-run unknown command '$cmd' after reboot" + ) + } else if (cmd.nonEmpty) { + if (batchMode.get) sendCommand(ExecCommand(cmd, execId)) + else + inLock.synchronized { + val toSend = cmd.getBytes :+ '\r'.toByte + toSend.foreach(b => sendNotification(systemIn, b.toString)) + } + } else completeExec(execId, 0) + case _ => } } else { - console.appendLog(Level.Info, "sbt server disconnected") + if (!rebooting.get() && running.compareAndSet(true, false) && log) { + if (!arguments.commandArguments.contains(Shutdown)) { + if (Terminal.console.getLastLine.isDefined) + Terminal.console.printStream.println() + console.appendLog(Level.Error, "sbt server disconnected") + exitClean.set(false) + } + } else { + console.appendLog(Level.Info, s"${if (log) "sbt server " else ""}disconnected") + } + stdinBytes.offer(-1) + Option(inputThread.get).foreach(_.close()) + Option(interactiveThread.get).foreach(_.interrupt) } - stdinBytes.offer(-1) - Option(inputThread.get).foreach(_.close()) - Option(interactiveThread.get).foreach(_.interrupt) case "readInput" => case _ => self.onNotification(msg) } } override def onRequest(msg: JsonRpcRequestMessage): Unit = self.onRequest(msg) override def onResponse(msg: JsonRpcResponseMessage): Unit = self.onResponse(msg) - override def onShutdown(): Unit = { + override def onShutdown(): Unit = if (!rebooting.get) { if (exitClean.get != false) exitClean.set(!running.get) running.set(false) Option(interactiveThread.get).foreach(_.interrupt()) @@ -202,14 +239,22 @@ class NetworkClient( * Forks another instance of sbt in the background. * This instance must be shutdown explicitly via `sbt -client shutdown` */ - def forkServer(portfile: File, log: Boolean): Unit = { + def waitForServer(portfile: File, log: Boolean, startServer: Boolean): Unit = { val bootSocketName = BootServerSocket.socketLocation(arguments.baseDirectory.toPath.toRealPath()) - var socket: Option[Socket] = Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption + + /* + * For unknown reasons, linux sometimes struggles to connect to the socket in some + * scenarios. + */ + var socket: Option[Socket] = + if (!Properties.isLinux) Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption + else None val process = socket match { - case None => + case None if startServer => val term = Terminal.console if (log) console.appendLog(Level.Info, "server was not detected. starting an instance") + val props = Seq( term.getWidth, @@ -229,12 +274,21 @@ class NetworkClient( sbtProcess.set(process) Some(process) case _ => - if (log) console.appendLog(Level.Info, "sbt server is booting up") + if (log) { + if (Terminal.console.getLastLine.isDefined) Terminal.console.printStream.println() + console.appendLog(Level.Info, "sbt server is booting up") + } None } + if (!startServer) { + val deadline = 5.seconds.fromNow + while (socket.isEmpty && !deadline.isOverdue) { + socket = Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption + if (socket.isEmpty) Thread.sleep(20) + } + } val hook = new Thread(() => Option(sbtProcess.get).foreach(_.destroyForcibly())) Runtime.getRuntime.addShutdownHook(hook) - val isWin = Properties.isWin var gotInputBack = false val readThreadAlive = new AtomicBoolean(true) /* @@ -248,6 +302,9 @@ class NetworkClient( override def run(): Unit = { try { while (readThreadAlive.get) { + if (socket.isEmpty) { + socket = Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption + } socket.foreach { s => try { s.getInputStream.read match { @@ -272,9 +329,6 @@ class NetworkClient( } @tailrec def blockUntilStart(): Unit = { - if (socket.isEmpty) { - socket = Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption - } val stop = try { socket match { case None => @@ -311,8 +365,9 @@ class NetworkClient( * will return with exit value 2. In that case, we can treat the process as alive * even if it is actually dead. */ - val existsValidProcess = process.fold(socket.isDefined)(p => p.isAlive || p.exitValue == 2) - if (!portfile.exists && !stop && readThreadAlive.get && existsValidProcess) { + val existsValidProcess = + process.fold(readThreadAlive.get)(p => p.isAlive || (Properties.isWin || p.exitValue == 2)) + if (!portfile.exists && !stop && existsValidProcess) { blockUntilStart() } else { socket.foreach { s => @@ -367,19 +422,22 @@ class NetworkClient( .getOrElse(1) case _ => 1 } - def onResponse(msg: JsonRpcResponseMessage): Unit = { - pendingResults.remove(msg.id) match { + private def completeExec(execId: String, exitCode: => Int): Unit = + pendingResults.remove(execId) match { case null => - case (q, startTime) => + case (q, startTime, name) => val now = System.currentTimeMillis val message = timing(startTime, now) - val exitCode = getExitCode(msg.result) + val ec = exitCode if (batchMode.get || !attached.get) { - if (exitCode == 0) console.success(message) - else if (!attached.get) console.appendLog(Level.Error, message) + console.appendLog(Level.Info, s"$name completed") + if (ec == 0) console.success(message) + else console.appendLog(Level.Error, message) } - q.offer(exitCode) + Util.ignoreResult(q.offer(ec)) } + def onResponse(msg: JsonRpcResponseMessage): Unit = { + completeExec(msg.id, getExitCode(msg.result)) pendingCancellations.remove(msg.id) match { case null => case q => q.offer(msg.toString.contains("Task cancelled")) @@ -681,7 +739,7 @@ class NetworkClient( val execId = UUID.randomUUID.toString val queue = new LinkedBlockingQueue[Integer] sendCommand(ExecCommand(commandLine, execId)) - pendingResults.put(execId, (queue, System.currentTimeMillis)) + pendingResults.put(execId, (queue, System.currentTimeMillis, commandLine)) queue } @@ -718,9 +776,12 @@ class NetworkClient( } def sendJson(method: String, params: String): String = { val uuid = UUID.randomUUID.toString + sendJson(method, params, uuid) + uuid + } + def sendJson(method: String, params: String, uuid: String): Unit = { val msg = s"""{ "jsonrpc": "2.0", "id": "$uuid", "method": "$method", "params": $params }""" connection.sendString(msg) - uuid } def sendNotification(method: String, params: String): Unit = { @@ -746,13 +807,12 @@ class NetworkClient( setDaemon(true) start() val stopped = new AtomicBoolean(false) - val lock = new Object override final def run(): Unit = { @tailrec def read(): Unit = { inputStream.read match { case -1 => case b => - lock.synchronized(stdinBytes.offer(b)) + inLock.synchronized(stdinBytes.offer(b)) if (attached.get()) drain() if (!stopped.get()) read() } @@ -761,7 +821,7 @@ class NetworkClient( catch { case _: InterruptedException | _: ClosedChannelException => stopped.set(true) } } - def drain(): Unit = lock.synchronized { + def drain(): Unit = inLock.synchronized { while (!stdinBytes.isEmpty) { val byte = stdinBytes.poll() sendNotification(systemIn, byte.toString) diff --git a/main-command/src/main/scala/sbt/internal/server/Server.scala b/main-command/src/main/scala/sbt/internal/server/Server.scala index 74376ea49..10d0a519a 100644 --- a/main-command/src/main/scala/sbt/internal/server/Server.scala +++ b/main-command/src/main/scala/sbt/internal/server/Server.scala @@ -106,7 +106,8 @@ private[sbt] object Server { val socket = serverSocket.accept() onIncomingSocket(socket, self) } catch { - case _: SocketTimeoutException => // its ok + case e: IOException if e.getMessage.contains("connect") => + case _: SocketTimeoutException => // its ok } } serverSocketHolder.get match { diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index 56a582cdc..3d886a3ca 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -111,7 +111,8 @@ private[sbt] object xMain { ): (Option[BootServerSocket], Option[Exit]) = try (Some(new BootServerSocket(configuration)) -> None) catch { - case _: ServerAlreadyBootingException if System.console != null => + case _: ServerAlreadyBootingException + if System.console != null && !Terminal.startedByRemoteClient => println("sbt server is already booting. Create a new server? y/n (default y)") val exit = Terminal.get.withRawSystemIn(System.in.read) match { case 110 => Some(Exit(1)) @@ -296,6 +297,7 @@ object BuiltinCommands { skipBanner, notifyUsersAboutShell, shell, + rebootNetwork, startServer, eval, last, @@ -1046,6 +1048,11 @@ object BuiltinCommands { } } + def rebootNetwork: Command = Command.arb(_ => (RebootNetwork: Parser[String]).examples()) { + (s, _) => + StandardMain.exchange.reboot(s) + s + } def startServer: Command = Command.command(StartServer, Help.more(StartServer, StartServerDetailed)) { s0 => val exchange = StandardMain.exchange @@ -1115,10 +1122,12 @@ object BuiltinCommands { private def intendsToInvokeCompile(state: State) = state.remainingCommands exists (_.commandLine == Keys.compile.key.label) + private def hasRebooted(state: State) = + state.remainingCommands exists (_.commandLine == StartServer) private def notifyUsersAboutShell(state: State): Unit = { val suppress = Project extract state getOpt Keys.suppressSbtShellNotification getOrElse false - if (!suppress && intendsToInvokeCompile(state)) + if (!suppress && intendsToInvokeCompile(state) && !hasRebooted(state)) state.log info "Executing in batch mode. For better performance use sbt's shell" } diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 358488605..29091ce96 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -13,7 +13,13 @@ import java.net.Socket import java.util.concurrent.atomic._ import java.util.concurrent.{ LinkedBlockingQueue, TimeUnit } -import sbt.BasicCommandStrings.{ Cancel, Shutdown, TerminateAction, networkExecPrefix } +import sbt.BasicCommandStrings.{ + Cancel, + CompleteExec, + Shutdown, + TerminateAction, + networkExecPrefix +} import sbt.BasicKeys._ import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.server._ @@ -377,6 +383,33 @@ private[sbt] final class CommandExchange { channels.foreach(c => ProgressState.updateProgressState(newPE, c.terminal)) } + /** + * When a reboot is initiated by a network client, we need to communicate + * to it which + * + * @param state + */ + private[sbt] def reboot(state: State): Unit = state.source match { + case Some(s) if s.channelName.startsWith("network") => + channels.foreach { + case nc: NetworkChannel if nc.name == s.channelName => + val remainingCommands = + state.remainingCommands + .takeWhile(!_.commandLine.startsWith(CompleteExec)) + .map(_.commandLine) + .filterNot(_.startsWith("sbtReboot")) + .mkString(";") + val execId = state.remainingCommands.collectFirst { + case e if e.commandLine.startsWith(CompleteExec) => + e.commandLine.split(CompleteExec).last.trim + } + nc.shutdown(true, execId.map(_ -> remainingCommands)) + case nc: NetworkChannel => nc.shutdown(true, Some(("", ""))) + case _ => + } + case _ => + } + private[sbt] def shutdown(name: String): Unit = { Option(currentExecRef.get).foreach(cancel) commandQueue.clear() diff --git a/main/src/main/scala/sbt/internal/nio/CheckBuildSources.scala b/main/src/main/scala/sbt/internal/nio/CheckBuildSources.scala index 363aa7736..928f0ca21 100644 --- a/main/src/main/scala/sbt/internal/nio/CheckBuildSources.scala +++ b/main/src/main/scala/sbt/internal/nio/CheckBuildSources.scala @@ -90,7 +90,8 @@ private[sbt] class CheckBuildSources extends AutoCloseable { val commands = allCmds.flatMap(_.split(";").flatMap(_.trim.split(" ").headOption).filterNot(_.isEmpty)) val filter = (c: String) => - c == LoadProject || c == RebootCommand || c == TerminateAction || c == Shutdown + c == LoadProject || c == RebootCommand || c == TerminateAction || c == Shutdown || + c.startsWith("sbtReboot") val res = !commands.exists(filter) if (!res) { previousStamps.set(getStamps(force = true)) diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index 549ccd022..50b3c7bb6 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -551,6 +551,9 @@ final class NetworkChannel( } import sjsonnew.BasicJsonProtocol.BooleanJsonFormat + override def shutdown(logShutdown: Boolean): Unit = + shutdown(logShutdown, remainingCommands = None) + /** * Closes down the channel. Before closing the socket, it sends a notification to * the client to shutdown. If the client initiated the shutdown, we don't want the @@ -559,13 +562,16 @@ final class NetworkChannel( * easily be done client side because when the client is in interactive session, * it doesn't know commands it has sent to the server. */ - override def shutdown(logShutdown: Boolean): Unit = { + private[sbt] def shutdown( + logShutdown: Boolean, + remainingCommands: Option[(String, String)] + ): Unit = { terminal.close() StandardMain.exchange.removeChannel(this) super.shutdown(logShutdown) if (logShutdown) Terminal.consoleLog(s"shutting down client connection $name") VirtualTerminal.cancelRequests(name) - try jsonRpcNotify(Shutdown, logShutdown) + try jsonRpcNotify(Shutdown, (logShutdown, remainingCommands)) catch { case _: IOException => } running.set(false) out.close()