diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala index b7cd0ad92..16d3041f0 100644 --- a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala +++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala @@ -10,21 +10,24 @@ package internal package client import java.io.{ File, IOException, InputStream, PrintStream } +import java.lang.ProcessBuilder.Redirect +import java.net.Socket +import java.nio.file.Files import java.util.UUID import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } import sbt.internal.langserver.{ LogMessageParams, MessageType, PublishDiagnosticsParams } import sbt.internal.protocol._ -import sbt.internal.util.{ ConsoleAppender, ConsoleOut, LineReader } +import sbt.internal.util.{ ConsoleAppender, ConsoleOut, LineReader, Terminal, Util } import sbt.io.IO import sbt.io.syntax._ import sbt.protocol._ import sbt.util.Level import sjsonnew.support.scalajson.unsafe.Converter +import scala.annotation.tailrec import scala.collection.mutable.ListBuffer import scala.collection.mutable -import scala.sys.process.{ BasicIO, Process, ProcessLogger } import scala.util.Properties import scala.util.control.NonFatal import scala.util.{ Failure, Success } @@ -65,70 +68,114 @@ class NetworkClient( private val status = new AtomicReference("Ready") private val lock: AnyRef = new AnyRef {} private val running = new AtomicBoolean(true) + private val connectionHolder = new AtomicReference[ServerConnection] + private def mkSocket(file: File): (Socket, Option[String]) = ClientSocket.socket(file, useJNI) private val pendingExecIds = ListBuffer.empty[String] - private def baseDirectory: File = arguments.baseDirectory + private def portfile = arguments.baseDirectory / "project" / "target" / "active.json" - lazy val connection = init() + def connection: ServerConnection = connectionHolder.synchronized { + connectionHolder.get match { + case null => init(true) + case c => c + } + } - start() + private[this] val sbtProcess = new AtomicReference[Process](null) + private class ConnectionRefusedException(t: Throwable) extends Throwable(t) // Open server connection based on the portfile - def init(): ServerConnection = { - val portfile = baseDirectory / "project" / "target" / "active.json" - if (!portfile.exists) { - forkServer(portfile) - } - val (sk, tkn) = ClientSocket.socket(portfile) - val conn = new ServerConnection(sk) { - override def onNotification(msg: JsonRpcNotificationMessage): Unit = self.onNotification(msg) - override def onRequest(msg: JsonRpcRequestMessage): Unit = self.onRequest(msg) - override def onResponse(msg: JsonRpcResponseMessage): Unit = self.onResponse(msg) - override def onShutdown(): Unit = { - running.set(false) + def init(retry: Boolean): ServerConnection = + try { + if (!portfile.exists) { + forkServer(portfile, log = true) } + val (sk, tkn) = + try mkSocket(portfile) + catch { case e: IOException => throw new ConnectionRefusedException(e) } + val conn = new ServerConnection(sk) { + override def onNotification(msg: JsonRpcNotificationMessage): Unit = + self.onNotification(msg) + override def onRequest(msg: JsonRpcRequestMessage): Unit = self.onRequest(msg) + override def onResponse(msg: JsonRpcResponseMessage): Unit = self.onResponse(msg) + override def onShutdown(): Unit = { + running.set(false) + } + } + // initiate handshake + val execId = UUID.randomUUID.toString + val initCommand = InitCommand(tkn, Option(execId), Some(true)) + conn.sendString(Serialization.serializeCommandAsJsonMessage(initCommand)) + connectionHolder.set(conn) + conn + } catch { + case e: ConnectionRefusedException if retry => + if (Files.deleteIfExists(portfile.toPath)) init(retry = false) + else throw e } - // initiate handshake - val execId = UUID.randomUUID.toString - val initCommand = InitCommand(tkn, Option(execId), Some(true)) - conn.sendString(Serialization.serializeCommandAsJsonMessage(initCommand)) - conn - } /** * Forks another instance of sbt in the background. * This instance must be shutdown explicitly via `sbt -client shutdown` */ - def forkServer(portfile: File): Unit = { - console.appendLog(Level.Info, "server was not detected. starting an instance") - val args = List[String]() - val launchOpts = List("-Xms2048M", "-Xmx2048M", "-Xss2M") - val launcherJarString = sys.props.get("java.class.path") match { - case Some(cp) => - cp.split(File.pathSeparator) - .toList - .headOption - .getOrElse(sys.error("launcher JAR classpath not found")) - case _ => sys.error("property java.class.path expected") - } - val cmd = "java" :: launchOpts ::: "-jar" :: launcherJarString :: args - // val cmd = "sbt" - val io = BasicIO(false, ProcessLogger(_ => ())) - val _ = Process(cmd, baseDirectory).run(io) - def waitForPortfile(n: Int): Unit = - if (portfile.exists) { - console.appendLog(Level.Info, "server found") - } else { - if (n <= 0) sys.error(s"timeout. $portfile is not found.") - else { - Thread.sleep(1000) - if ((n - 1) % 10 == 0) { - console.appendLog(Level.Info, "waiting for the server...") - } - waitForPortfile(n - 1) + def forkServer(portfile: File, log: Boolean): Unit = { + if (log) console.appendLog(Level.Info, "server was not detected. starting an instance") + val color = + if (!arguments.sbtArguments.exists(_.startsWith("-Dsbt.color="))) + s"-Dsbt.color=${Terminal.console.isColorEnabled}" :: Nil + else Nil + val superShell = + if (!arguments.sbtArguments.exists(_.startsWith("-Dsbt.supershell="))) + s"-Dsbt.supershell=${Terminal.console.isColorEnabled}" :: Nil + else Nil + + val args = color ++ superShell ++ arguments.sbtArguments + val cmd = arguments.sbtScript +: args + val process = + new ProcessBuilder(cmd: _*) + .directory(arguments.baseDirectory) + .redirectInput(Redirect.PIPE) + .start() + sbtProcess.set(process) + val hook = new Thread(() => Option(sbtProcess.get).foreach(_.destroyForcibly())) + Runtime.getRuntime.addShutdownHook(hook) + val stdout = process.getInputStream + val stderr = process.getErrorStream + val stdin = process.getOutputStream + @tailrec + def blockUntilStart(): Unit = { + val stop = try { + while (stdout.available > 0) { + val byte = stdout.read + printStream.write(byte) } + while (stderr.available > 0) { + val byte = stderr.read + errorStream.write(byte) + } + while (System.in.available > 0) { + val byte = System.in.read + stdin.write(byte) + } + false + } catch { + case _: IOException => true } - waitForPortfile(90) + Thread.sleep(10) + if (!portfile.exists && !stop) blockUntilStart() + else { + stdin.close() + stdout.close() + stderr.close() + process.getOutputStream.close() + } + } + + try blockUntilStart() + catch { case t: Throwable => t.printStackTrace() } finally { + sbtProcess.set(null) + Util.ignoreResult(Runtime.getRuntime.removeShutdownHook(hook)) + } } /** Called on the response for a returning message. */