Support reboot from remote client

Reboot is a bit tricky for the remote client because the sbt server is
actually shut down during reboot. When sbt shuts down the client, it can
notify the client that the reason is a reboot. The client can then
connect to the recently introduced boot control socket to display the
reboot output and supply input in case the build fails to load. Once the
server has brought back up the server, the client can reconnect. When
the client session is interactive, we're done once we reconnect. When
it's a batch session, the client needs to resend the remaing commands
that have submitted that it hasn't yet run.
This commit is contained in:
Ethan Atkins 2020-06-27 15:54:50 -07:00
parent 332a757682
commit f9d5fbf29b
9 changed files with 175 additions and 50 deletions

View File

@ -137,6 +137,8 @@ $HelpCommand <regular expression>
If a classpath is provided, modules are loaded from a new class loader for this classpath.
"""
private[sbt] def RebootNetwork: String = "sbtRebootNetwork"
private[sbt] def RebootImpl: String = "sbtRebootImpl"
def RebootCommand: String = "reboot"
def RebootDetailed: String =
RebootCommand + """ [dev | full]

View File

@ -53,6 +53,7 @@ object BasicCommands {
stashOnFailure,
popOnFailure,
reboot,
rebootImpl,
call,
early,
exit,
@ -304,6 +305,12 @@ object BasicCommands {
def reboot: Command =
Command(RebootCommand, Help.more(RebootCommand, RebootDetailed))(_ => rebootOptionParser) {
case (s, (full, currentOnly)) =>
val option = if (full) " full" else if (currentOnly) " dev" else ""
RebootNetwork :: s"$RebootImpl$option" :: s
}
def rebootImpl: Command =
Command.arb(_ => (RebootImpl ~> rebootOptionParser).examples()) {
case (s, (full, currentOnly)) =>
s.reboot(full, currentOnly)
}

View File

@ -22,6 +22,7 @@ import BasicCommandStrings.{
PopOnFailure,
ReportResult,
SetTerminal,
StartServer,
StashOnFailure,
networkExecPrefix,
}
@ -339,9 +340,14 @@ object State {
/** Implementation of reboot. */
private[sbt] def reboot(full: Boolean, currentOnly: Boolean): State = {
runExitHooks()
val rs = s.remainingCommands map { case e: Exec => e.commandLine }
if (currentOnly) throw new RebootCurrent(rs)
else throw new xsbti.FullReload(rs.toArray, full)
val remaining: List[String] = s.remainingCommands.map(_.commandLine)
val fullRemaining = s.source match {
case Some(s) if s.channelName.startsWith("network") =>
StartServer :: remaining.dropWhile(!_.startsWith(ReportResult)).tail ::: "shell" :: Nil
case _ => remaining
}
if (currentOnly) throw new RebootCurrent(fullRemaining)
else throw new xsbti.FullReload(fullRemaining.toArray, full)
}
def reload = runExitHooks().setNext(new Return(defaultReload(s)))

View File

@ -85,7 +85,8 @@ class NetworkClient(
private val status = new AtomicReference("Ready")
private val lock: AnyRef = new AnyRef {}
private val running = new AtomicBoolean(true)
private val pendingResults = new ConcurrentHashMap[String, (LinkedBlockingQueue[Integer], Long)]
private val pendingResults =
new ConcurrentHashMap[String, (LinkedBlockingQueue[Integer], Long, String)]
private val pendingCancellations = new ConcurrentHashMap[String, LinkedBlockingQueue[Boolean]]
private val pendingCompletions = new ConcurrentHashMap[String, CompletionResponse => Unit]
private val attached = new AtomicBoolean(false)
@ -93,6 +94,7 @@ class NetworkClient(
private val connectionHolder = new AtomicReference[ServerConnection]
private val batchMode = new AtomicBoolean(false)
private val interactiveThread = new AtomicReference[Thread](null)
private val rebooting = new AtomicBoolean(false)
private lazy val noTab = arguments.completionArguments.contains("--no-tab")
private lazy val noStdErr = arguments.completionArguments.contains("--no-stderr") &&
System.getenv("SBTC_AUTO_COMPLETE") == null
@ -109,6 +111,7 @@ class NetworkClient(
}
private[this] val stdinBytes = new LinkedBlockingQueue[Int]
private[this] val inLock = new Object
private[this] val inputThread = new AtomicReference(new RawInputThread)
private[this] val exitClean = new AtomicBoolean(true)
private[this] val sbtProcess = new AtomicReference[Process](null)
@ -123,17 +126,17 @@ class NetworkClient(
val msg = if (noTab) "" else "No sbt server is running. Press <tab> to start one..."
errorStream.print(s"\n$msg")
if (noStdErr) System.exit(0)
else if (noTab) forkServer(portfile, log = true)
else if (noTab) waitForServer(portfile, log = true, startServer = true)
else {
stdinBytes.take match {
case 9 =>
errorStream.println("\nStarting server...")
forkServer(portfile, !prompt)
waitForServer(portfile, !prompt, startServer = true)
case _ => System.exit(0)
}
}
} else {
forkServer(portfile, log = true)
waitForServer(portfile, log = true, startServer = true)
}
}
@tailrec def connect(attempt: Int): (Socket, Option[String]) = {
@ -155,32 +158,66 @@ class NetworkClient(
val (sk, tkn) = connect(0)
val conn = new ServerConnection(sk) {
override def onNotification(msg: JsonRpcNotificationMessage): Unit = {
if (msg.toString.contains("shutdown")) System.err.println(msg)
msg.method match {
case `Shutdown` =>
val log = msg.params match {
case Some(jvalue) => Converter.fromJson[Boolean](jvalue).getOrElse(true)
case _ => false
val (log, rebootCommands) = msg.params match {
case Some(jvalue) =>
Converter
.fromJson[(Boolean, Option[(String, String)])](jvalue)
.getOrElse((true, None))
case _ => (false, None)
}
if (running.compareAndSet(true, false) && log) {
if (!arguments.commandArguments.contains(Shutdown)) {
if (Terminal.console.getLastLine.fold(true)(_.nonEmpty)) errorStream.println()
console.appendLog(Level.Error, "sbt server disconnected")
exitClean.set(false)
if (rebootCommands.nonEmpty) {
if (Terminal.console.getLastLine.isDefined) Terminal.console.printStream.println()
rebooting.set(true)
attached.set(false)
connectionHolder.getAndSet(null) match {
case null =>
case c => c.shutdown()
}
waitForServer(portfile, true, false)
init(prompt = false, retry = false)
attachUUID.set(sendJson(attach, s"""{"interactive": ${!batchMode.get}}"""))
rebooting.set(false)
rebootCommands match {
case Some((execId, cmd)) if execId.nonEmpty =>
if (batchMode.get && !pendingResults.contains(execId) && cmd.isEmpty) {
console.appendLog(
Level.Error,
s"received request to re-run unknown command '$cmd' after reboot"
)
} else if (cmd.nonEmpty) {
if (batchMode.get) sendCommand(ExecCommand(cmd, execId))
else
inLock.synchronized {
val toSend = cmd.getBytes :+ '\r'.toByte
toSend.foreach(b => sendNotification(systemIn, b.toString))
}
} else completeExec(execId, 0)
case _ =>
}
} else {
console.appendLog(Level.Info, "sbt server disconnected")
if (!rebooting.get() && running.compareAndSet(true, false) && log) {
if (!arguments.commandArguments.contains(Shutdown)) {
if (Terminal.console.getLastLine.isDefined)
Terminal.console.printStream.println()
console.appendLog(Level.Error, "sbt server disconnected")
exitClean.set(false)
}
} else {
console.appendLog(Level.Info, s"${if (log) "sbt server " else ""}disconnected")
}
stdinBytes.offer(-1)
Option(inputThread.get).foreach(_.close())
Option(interactiveThread.get).foreach(_.interrupt)
}
stdinBytes.offer(-1)
Option(inputThread.get).foreach(_.close())
Option(interactiveThread.get).foreach(_.interrupt)
case "readInput" =>
case _ => self.onNotification(msg)
}
}
override def onRequest(msg: JsonRpcRequestMessage): Unit = self.onRequest(msg)
override def onResponse(msg: JsonRpcResponseMessage): Unit = self.onResponse(msg)
override def onShutdown(): Unit = {
override def onShutdown(): Unit = if (!rebooting.get) {
if (exitClean.get != false) exitClean.set(!running.get)
running.set(false)
Option(interactiveThread.get).foreach(_.interrupt())
@ -202,14 +239,22 @@ class NetworkClient(
* Forks another instance of sbt in the background.
* This instance must be shutdown explicitly via `sbt -client shutdown`
*/
def forkServer(portfile: File, log: Boolean): Unit = {
def waitForServer(portfile: File, log: Boolean, startServer: Boolean): Unit = {
val bootSocketName =
BootServerSocket.socketLocation(arguments.baseDirectory.toPath.toRealPath())
var socket: Option[Socket] = Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption
/*
* For unknown reasons, linux sometimes struggles to connect to the socket in some
* scenarios.
*/
var socket: Option[Socket] =
if (!Properties.isLinux) Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption
else None
val process = socket match {
case None =>
case None if startServer =>
val term = Terminal.console
if (log) console.appendLog(Level.Info, "server was not detected. starting an instance")
val props =
Seq(
term.getWidth,
@ -229,12 +274,21 @@ class NetworkClient(
sbtProcess.set(process)
Some(process)
case _ =>
if (log) console.appendLog(Level.Info, "sbt server is booting up")
if (log) {
if (Terminal.console.getLastLine.isDefined) Terminal.console.printStream.println()
console.appendLog(Level.Info, "sbt server is booting up")
}
None
}
if (!startServer) {
val deadline = 5.seconds.fromNow
while (socket.isEmpty && !deadline.isOverdue) {
socket = Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption
if (socket.isEmpty) Thread.sleep(20)
}
}
val hook = new Thread(() => Option(sbtProcess.get).foreach(_.destroyForcibly()))
Runtime.getRuntime.addShutdownHook(hook)
val isWin = Properties.isWin
var gotInputBack = false
val readThreadAlive = new AtomicBoolean(true)
/*
@ -248,6 +302,9 @@ class NetworkClient(
override def run(): Unit = {
try {
while (readThreadAlive.get) {
if (socket.isEmpty) {
socket = Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption
}
socket.foreach { s =>
try {
s.getInputStream.read match {
@ -272,9 +329,6 @@ class NetworkClient(
}
@tailrec
def blockUntilStart(): Unit = {
if (socket.isEmpty) {
socket = Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption
}
val stop = try {
socket match {
case None =>
@ -311,8 +365,9 @@ class NetworkClient(
* will return with exit value 2. In that case, we can treat the process as alive
* even if it is actually dead.
*/
val existsValidProcess = process.fold(socket.isDefined)(p => p.isAlive || p.exitValue == 2)
if (!portfile.exists && !stop && readThreadAlive.get && existsValidProcess) {
val existsValidProcess =
process.fold(readThreadAlive.get)(p => p.isAlive || (Properties.isWin || p.exitValue == 2))
if (!portfile.exists && !stop && existsValidProcess) {
blockUntilStart()
} else {
socket.foreach { s =>
@ -367,19 +422,22 @@ class NetworkClient(
.getOrElse(1)
case _ => 1
}
def onResponse(msg: JsonRpcResponseMessage): Unit = {
pendingResults.remove(msg.id) match {
private def completeExec(execId: String, exitCode: => Int): Unit =
pendingResults.remove(execId) match {
case null =>
case (q, startTime) =>
case (q, startTime, name) =>
val now = System.currentTimeMillis
val message = timing(startTime, now)
val exitCode = getExitCode(msg.result)
val ec = exitCode
if (batchMode.get || !attached.get) {
if (exitCode == 0) console.success(message)
else if (!attached.get) console.appendLog(Level.Error, message)
console.appendLog(Level.Info, s"$name completed")
if (ec == 0) console.success(message)
else console.appendLog(Level.Error, message)
}
q.offer(exitCode)
Util.ignoreResult(q.offer(ec))
}
def onResponse(msg: JsonRpcResponseMessage): Unit = {
completeExec(msg.id, getExitCode(msg.result))
pendingCancellations.remove(msg.id) match {
case null =>
case q => q.offer(msg.toString.contains("Task cancelled"))
@ -681,7 +739,7 @@ class NetworkClient(
val execId = UUID.randomUUID.toString
val queue = new LinkedBlockingQueue[Integer]
sendCommand(ExecCommand(commandLine, execId))
pendingResults.put(execId, (queue, System.currentTimeMillis))
pendingResults.put(execId, (queue, System.currentTimeMillis, commandLine))
queue
}
@ -718,9 +776,12 @@ class NetworkClient(
}
def sendJson(method: String, params: String): String = {
val uuid = UUID.randomUUID.toString
sendJson(method, params, uuid)
uuid
}
def sendJson(method: String, params: String, uuid: String): Unit = {
val msg = s"""{ "jsonrpc": "2.0", "id": "$uuid", "method": "$method", "params": $params }"""
connection.sendString(msg)
uuid
}
def sendNotification(method: String, params: String): Unit = {
@ -746,13 +807,12 @@ class NetworkClient(
setDaemon(true)
start()
val stopped = new AtomicBoolean(false)
val lock = new Object
override final def run(): Unit = {
@tailrec def read(): Unit = {
inputStream.read match {
case -1 =>
case b =>
lock.synchronized(stdinBytes.offer(b))
inLock.synchronized(stdinBytes.offer(b))
if (attached.get()) drain()
if (!stopped.get()) read()
}
@ -761,7 +821,7 @@ class NetworkClient(
catch { case _: InterruptedException | _: ClosedChannelException => stopped.set(true) }
}
def drain(): Unit = lock.synchronized {
def drain(): Unit = inLock.synchronized {
while (!stdinBytes.isEmpty) {
val byte = stdinBytes.poll()
sendNotification(systemIn, byte.toString)

View File

@ -106,7 +106,8 @@ private[sbt] object Server {
val socket = serverSocket.accept()
onIncomingSocket(socket, self)
} catch {
case _: SocketTimeoutException => // its ok
case e: IOException if e.getMessage.contains("connect") =>
case _: SocketTimeoutException => // its ok
}
}
serverSocketHolder.get match {

View File

@ -111,7 +111,8 @@ private[sbt] object xMain {
): (Option[BootServerSocket], Option[Exit]) =
try (Some(new BootServerSocket(configuration)) -> None)
catch {
case _: ServerAlreadyBootingException if System.console != null =>
case _: ServerAlreadyBootingException
if System.console != null && !Terminal.startedByRemoteClient =>
println("sbt server is already booting. Create a new server? y/n (default y)")
val exit = Terminal.get.withRawSystemIn(System.in.read) match {
case 110 => Some(Exit(1))
@ -296,6 +297,7 @@ object BuiltinCommands {
skipBanner,
notifyUsersAboutShell,
shell,
rebootNetwork,
startServer,
eval,
last,
@ -1046,6 +1048,11 @@ object BuiltinCommands {
}
}
def rebootNetwork: Command = Command.arb(_ => (RebootNetwork: Parser[String]).examples()) {
(s, _) =>
StandardMain.exchange.reboot(s)
s
}
def startServer: Command =
Command.command(StartServer, Help.more(StartServer, StartServerDetailed)) { s0 =>
val exchange = StandardMain.exchange
@ -1115,10 +1122,12 @@ object BuiltinCommands {
private def intendsToInvokeCompile(state: State) =
state.remainingCommands exists (_.commandLine == Keys.compile.key.label)
private def hasRebooted(state: State) =
state.remainingCommands exists (_.commandLine == StartServer)
private def notifyUsersAboutShell(state: State): Unit = {
val suppress = Project extract state getOpt Keys.suppressSbtShellNotification getOrElse false
if (!suppress && intendsToInvokeCompile(state))
if (!suppress && intendsToInvokeCompile(state) && !hasRebooted(state))
state.log info "Executing in batch mode. For better performance use sbt's shell"
}

View File

@ -13,7 +13,13 @@ import java.net.Socket
import java.util.concurrent.atomic._
import java.util.concurrent.{ LinkedBlockingQueue, TimeUnit }
import sbt.BasicCommandStrings.{ Cancel, Shutdown, TerminateAction, networkExecPrefix }
import sbt.BasicCommandStrings.{
Cancel,
CompleteExec,
Shutdown,
TerminateAction,
networkExecPrefix
}
import sbt.BasicKeys._
import sbt.internal.protocol.JsonRpcResponseError
import sbt.internal.server._
@ -377,6 +383,33 @@ private[sbt] final class CommandExchange {
channels.foreach(c => ProgressState.updateProgressState(newPE, c.terminal))
}
/**
* When a reboot is initiated by a network client, we need to communicate
* to it which
*
* @param state
*/
private[sbt] def reboot(state: State): Unit = state.source match {
case Some(s) if s.channelName.startsWith("network") =>
channels.foreach {
case nc: NetworkChannel if nc.name == s.channelName =>
val remainingCommands =
state.remainingCommands
.takeWhile(!_.commandLine.startsWith(CompleteExec))
.map(_.commandLine)
.filterNot(_.startsWith("sbtReboot"))
.mkString(";")
val execId = state.remainingCommands.collectFirst {
case e if e.commandLine.startsWith(CompleteExec) =>
e.commandLine.split(CompleteExec).last.trim
}
nc.shutdown(true, execId.map(_ -> remainingCommands))
case nc: NetworkChannel => nc.shutdown(true, Some(("", "")))
case _ =>
}
case _ =>
}
private[sbt] def shutdown(name: String): Unit = {
Option(currentExecRef.get).foreach(cancel)
commandQueue.clear()

View File

@ -90,7 +90,8 @@ private[sbt] class CheckBuildSources extends AutoCloseable {
val commands =
allCmds.flatMap(_.split(";").flatMap(_.trim.split(" ").headOption).filterNot(_.isEmpty))
val filter = (c: String) =>
c == LoadProject || c == RebootCommand || c == TerminateAction || c == Shutdown
c == LoadProject || c == RebootCommand || c == TerminateAction || c == Shutdown ||
c.startsWith("sbtReboot")
val res = !commands.exists(filter)
if (!res) {
previousStamps.set(getStamps(force = true))

View File

@ -551,6 +551,9 @@ final class NetworkChannel(
}
import sjsonnew.BasicJsonProtocol.BooleanJsonFormat
override def shutdown(logShutdown: Boolean): Unit =
shutdown(logShutdown, remainingCommands = None)
/**
* Closes down the channel. Before closing the socket, it sends a notification to
* the client to shutdown. If the client initiated the shutdown, we don't want the
@ -559,13 +562,16 @@ final class NetworkChannel(
* easily be done client side because when the client is in interactive session,
* it doesn't know commands it has sent to the server.
*/
override def shutdown(logShutdown: Boolean): Unit = {
private[sbt] def shutdown(
logShutdown: Boolean,
remainingCommands: Option[(String, String)]
): Unit = {
terminal.close()
StandardMain.exchange.removeChannel(this)
super.shutdown(logShutdown)
if (logShutdown) Terminal.consoleLog(s"shutting down client connection $name")
VirtualTerminal.cancelRequests(name)
try jsonRpcNotify(Shutdown, logShutdown)
try jsonRpcNotify(Shutdown, (logShutdown, remainingCommands))
catch { case _: IOException => }
running.set(false)
out.close()