fix: reconnect thin client on reboot

This commit is contained in:
jonathanchang31 2026-04-16 15:14:12 +02:00
parent 7218b2a1ac
commit 9ef618266c
2 changed files with 71 additions and 32 deletions

View File

@ -156,13 +156,42 @@ class NetworkClient(
private def mkSocket(file: File): (Socket, Option[String]) = ClientSocket.socket(file, useJNI)
@tailrec
private def connect(file: File, attempt: Int = 0): (Socket, Option[String]) = {
val res =
try Some(mkSocket(file))
catch {
case e: IOException
if Option(e.getMessage).exists(_.contains("Couldn't open")) && attempt < 10 =>
if (
Option(e.getMessage).exists(m => m.contains("Access is denied") || m.contains("(5)"))
) {
errorStream.println(s"Access denied for portfile $file")
throw new NetworkClient.AccessDeniedException
}
None
case e: IOException => throw new ConnectionRefusedException(e)
}
res match {
case Some(r) => r
case None =>
Thread.sleep(new java.util.Random().nextInt(20).toLong)
connect(file, attempt + 1)
}
}
private def portfile = arguments.baseDirectory / "project" / "target" / "active.json"
def connection: ServerSession = connectionHolder.synchronized {
connectionHolder.get match {
case null => init(promptCompleteUsers = false, retry = true)
case c => c
}
@tailrec def getConnection(): ServerSession =
connectionHolder.get match {
case null if rebooting.get && running.get =>
connectionHolder.wait(20)
getConnection()
case null => init(promptCompleteUsers = false, retry = true)
case c => c
}
getConnection()
}
private val stdinBytes = new LinkedBlockingQueue[Integer]
@ -213,29 +242,7 @@ class NetworkClient(
waitForServer(portfile, log = true, startServer = true)
}
}
@tailrec def connect(attempt: Int): (Socket, Option[String]) = {
val res =
try Some(mkSocket(portfile))
catch {
// This catches a pipe busy exception which can happen if two windows clients
// attempt to connect in rapid succession
case e: IOException if e.getMessage.contains("Couldn't open") && attempt < 10 =>
if (e.getMessage.contains("Access is denied") || e.getMessage.contains("(5)")) {
errorStream.println(s"Access denied for portfile $portfile")
throw new NetworkClient.AccessDeniedException
}
None
case e: IOException => throw new ConnectionRefusedException(e)
}
res match {
case Some(r) => r
case None =>
// Use a random sleep to spread out the competing processes
Thread.sleep(new java.util.Random().nextInt(20).toLong)
connect(attempt + 1)
}
}
connect(0)
connect(portfile)
} catch {
case e: ConnectionRefusedException if retry =>
if (Files.deleteIfExists(portfile.toPath))
@ -243,9 +250,35 @@ class NetworkClient(
else throw e
}
private def connectToRebootedServer(): (Socket, Option[String]) = {
val deadline = 10.seconds.fromNow
@tailrec def loop(lastError: Option[Throwable]): (Socket, Option[String]) =
if (deadline.isOverdue()) {
lastError match {
case Some(e) => throw e
case _ =>
throw new TimeoutException(
s"timed out reconnecting to rebooted sbt server via $portfile"
)
}
} else if (portfile.exists) {
try connect(portfile)
catch {
case e: NetworkClient.AccessDeniedException => throw e
case NonFatal(e) =>
Thread.sleep(20)
loop(Some(e))
}
} else {
Thread.sleep(20)
loop(lastError)
}
loop(None)
}
// Open server connection based on the portfile
def init(promptCompleteUsers: Boolean, retry: Boolean): ServerSession = {
val (sk, tkn) = connectOrStartServerAndConnect(promptCompleteUsers, retry)
private def init(connectionInfo: => (Socket, Option[String])): ServerSession = {
val (sk, tkn) = connectionInfo
val conn = new ServerSessionImpl(sk, s"sbt-serverconnection-${sk.getPort}") {
override protected def onNotification(msg: JsonRpcNotificationMessage): Unit = {
msg.method match {
@ -265,7 +298,7 @@ class NetworkClient(
case c => c.close()
}
waitForServer(portfile, true, false)
init(promptCompleteUsers = false, retry = false)
init(connectToRebootedServer())
attachUUID.set(sendJson(attach, s"""{"interactive": ${!batchMode.get}}"""))
rebooting.set(false)
rebootCommands match {
@ -331,10 +364,16 @@ class NetworkClient(
initializationOptions = Some(opts),
)
conn.sendCommand(initCommand)
connectionHolder.set(conn)
connectionHolder.synchronized {
connectionHolder.set(conn)
connectionHolder.notifyAll()
}
conn
}
def init(promptCompleteUsers: Boolean, retry: Boolean): ServerSession =
init(connectOrStartServerAndConnect(promptCompleteUsers, retry))
/**
* Forks another instance of sbt in the background.
* This instance must be shutdown explicitly via `sbt -client shutdown`

View File

@ -470,7 +470,7 @@ private[sbt] final class CommandExchange {
case e if e.commandLine.startsWith(CompleteExec) =>
e.commandLine.split(CompleteExec).last.trim
}
nc.shutdown(true, execId.map(_ -> remainingCommands))
nc.shutdown(true, Some(execId.getOrElse("") -> remainingCommands))
case nc: NetworkChannel => nc.shutdown(true, Some(("", "")))
case _ =>
}