diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala index 4c3fc0613..170d661a9 100644 --- a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala +++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala @@ -156,13 +156,42 @@ class NetworkClient( private def mkSocket(file: File): (Socket, Option[String]) = ClientSocket.socket(file, useJNI) + @tailrec + private def connect(file: File, attempt: Int = 0): (Socket, Option[String]) = { + val res = + try Some(mkSocket(file)) + catch { + case e: IOException + if Option(e.getMessage).exists(_.contains("Couldn't open")) && attempt < 10 => + if ( + Option(e.getMessage).exists(m => m.contains("Access is denied") || m.contains("(5)")) + ) { + errorStream.println(s"Access denied for portfile $file") + throw new NetworkClient.AccessDeniedException + } + None + case e: IOException => throw new ConnectionRefusedException(e) + } + res match { + case Some(r) => r + case None => + Thread.sleep(new java.util.Random().nextInt(20).toLong) + connect(file, attempt + 1) + } + } + private def portfile = arguments.baseDirectory / "project" / "target" / "active.json" def connection: ServerSession = connectionHolder.synchronized { - connectionHolder.get match { - case null => init(promptCompleteUsers = false, retry = true) - case c => c - } + @tailrec def getConnection(): ServerSession = + connectionHolder.get match { + case null if rebooting.get && running.get => + connectionHolder.wait(20) + getConnection() + case null => init(promptCompleteUsers = false, retry = true) + case c => c + } + getConnection() } private val stdinBytes = new LinkedBlockingQueue[Integer] @@ -215,29 +244,7 @@ class NetworkClient( waitForServer(portfile, log = true, startServer = true) } } - @tailrec def connect(attempt: Int): (Socket, Option[String]) = { - val res = - try Some(mkSocket(portfile)) - catch { - // This catches a pipe busy exception which can happen if two windows clients - // attempt to connect in rapid succession - case e: IOException if e.getMessage.contains("Couldn't open") && attempt < 10 => - if (e.getMessage.contains("Access is denied") || e.getMessage.contains("(5)")) { - errorStream.println(s"Access denied for portfile $portfile") - throw new NetworkClient.AccessDeniedException - } - None - case e: IOException => throw new ConnectionRefusedException(e) - } - res match { - case Some(r) => r - case None => - // Use a random sleep to spread out the competing processes - Thread.sleep(new java.util.Random().nextInt(20).toLong) - connect(attempt + 1) - } - } - connect(0) + connect(portfile) } catch { case e: ConnectionRefusedException if retry => if (Files.deleteIfExists(portfile.toPath)) @@ -245,9 +252,35 @@ class NetworkClient( else throw e } + private def connectToRebootedServer(): (Socket, Option[String]) = { + val deadline = 10.seconds.fromNow + @tailrec def loop(lastError: Option[Throwable]): (Socket, Option[String]) = + if (deadline.isOverdue()) { + lastError match { + case Some(e) => throw e + case _ => + throw new TimeoutException( + s"timed out reconnecting to rebooted sbt server via $portfile" + ) + } + } else if (portfile.exists) { + try connect(portfile) + catch { + case e: NetworkClient.AccessDeniedException => throw e + case NonFatal(e) => + Thread.sleep(20) + loop(Some(e)) + } + } else { + Thread.sleep(20) + loop(lastError) + } + loop(None) + } + // Open server connection based on the portfile - def init(promptCompleteUsers: Boolean, retry: Boolean): ServerSession = { - val (sk, tkn) = connectOrStartServerAndConnect(promptCompleteUsers, retry) + private def init(connectionInfo: => (Socket, Option[String])): ServerSession = { + val (sk, tkn) = connectionInfo val conn = new ServerSessionImpl(sk, s"sbt-serverconnection-${sk.getPort}") { override protected def onNotification(msg: JsonRpcNotificationMessage): Unit = { msg.method match { @@ -267,7 +300,7 @@ class NetworkClient( case c => c.close() } waitForServer(portfile, true, false) - init(promptCompleteUsers = false, retry = false) + init(connectToRebootedServer()) attachUUID.set(sendJson(attach, s"""{"interactive": ${!batchMode.get}}""")) rebooting.set(false) rebootCommands match { @@ -333,10 +366,16 @@ class NetworkClient( initializationOptions = Some(opts), ) conn.sendCommand(initCommand) - connectionHolder.set(conn) + connectionHolder.synchronized { + connectionHolder.set(conn) + connectionHolder.notifyAll() + } conn } + def init(promptCompleteUsers: Boolean, retry: Boolean): ServerSession = + init(connectOrStartServerAndConnect(promptCompleteUsers, retry)) + /** * Forks another instance of sbt in the background. * This instance must be shutdown explicitly via `sbt -client shutdown` diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 8a542fb8c..9c3088eb2 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -479,7 +479,7 @@ private[sbt] final class CommandExchange { case e if e.commandLine.startsWith(CompleteExec) => e.commandLine.split(CompleteExec).last.trim } - nc.shutdown(true, execId.map(_ -> remainingCommands)) + nc.shutdown(true, Some(execId.getOrElse("") -> remainingCommands)) case nc: NetworkChannel => nc.shutdown(true, Some(("", ""))) case _ => }