diff --git a/main-command/src/main/scala/sbt/internal/client/BspClient.scala b/main-command/src/main/scala/sbt/internal/client/BspClient.scala index d9568304c..613e5de96 100644 --- a/main-command/src/main/scala/sbt/internal/client/BspClient.scala +++ b/main-command/src/main/scala/sbt/internal/client/BspClient.scala @@ -9,6 +9,7 @@ package sbt.internal.client import java.io.{ File, InputStream, OutputStream } import java.net.Socket +import java.util.concurrent.atomic.AtomicBoolean import sbt.Exit import sbt.io.syntax._ @@ -18,18 +19,37 @@ import scala.sys.process.Process import scala.util.control.NonFatal class BspClient private (sbtServer: Socket) { - private val lock = new AnyRef - private var terminated = false + private def run(): Exit = Exit(BspClient.bspRun(sbtServer)) +} - private def transferTo(input: InputStream, output: OutputStream): Thread = { +object BspClient { + private[sbt] def bspRun(sbtServer: Socket): Int = { + val lock = new AnyRef + val terminated = new AtomicBoolean(false) + transferTo(terminated, lock, sbtServer.getInputStream, System.out).start() + transferTo(terminated, lock, System.in, sbtServer.getOutputStream).start() + try { + lock.synchronized { + while (!terminated.get) lock.wait() + } + 0 + } catch { case _: Throwable => 1 } finally sbtServer.close() + } + + private[sbt] def transferTo( + terminated: AtomicBoolean, + lock: AnyRef, + input: InputStream, + output: OutputStream + ): Thread = { val thread = new Thread { override def run(): Unit = { val buffer = Array.ofDim[Byte](1024) try { - while (!terminated) { + while (!terminated.get) { val size = input.read(buffer) if (size == -1) { - terminated = true + terminated.set(true) } else { output.write(buffer, 0, size) output.flush() @@ -38,10 +58,11 @@ class BspClient private (sbtServer: Socket) { input.close() output.close() } catch { - case NonFatal(_) => () + case _: InterruptedException => terminated.set(true) + case NonFatal(_) => () } finally { lock.synchronized { - terminated = true + terminated.set(true) lock.notify() } } @@ -50,24 +71,6 @@ class BspClient private (sbtServer: Socket) { thread.setDaemon(true) thread } - - private def run(): Exit = { - try { - transferTo(sbtServer.getInputStream, System.out).start() - transferTo(System.in, sbtServer.getOutputStream).start() - - lock.synchronized { - while (!terminated) lock.wait() - } - - Exit(0) - } catch { - case NonFatal(_) => Exit(1) - } - } -} - -object BspClient { def run(configuration: xsbti.AppConfiguration): Exit = { val baseDirectory = configuration.baseDirectory val portFile = baseDirectory / "project" / "target" / "active.json" diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala index 17d5e87fb..5774f9e3a 100644 --- a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala +++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala @@ -164,8 +164,10 @@ class NetworkClient( case _ => } - // Open server connection based on the portfile - def init(promptCompleteUsers: Boolean, retry: Boolean): ServerConnection = + private[sbt] def connectOrStartServerAndConnect( + promptCompleteUsers: Boolean, + retry: Boolean + ): (Socket, Option[String]) = try { if (!portfile.exists) { if (promptCompleteUsers) { @@ -208,88 +210,94 @@ class NetworkClient( connect(attempt + 1) } } - val (sk, tkn) = connect(0) - val conn = new ServerConnection(sk) { - override def onNotification(msg: JsonRpcNotificationMessage): Unit = { - msg.method match { - case `Shutdown` => - val (log, rebootCommands) = msg.params match { - case Some(jvalue) => - Converter - .fromJson[(Boolean, Option[(String, String)])](jvalue) - .getOrElse((true, None)) - case _ => (false, None) - } - if (rebootCommands.nonEmpty) { - rebooting.set(true) - attached.set(false) - connectionHolder.getAndSet(null) match { - case null => - case c => c.shutdown() - } - waitForServer(portfile, true, false) - init(promptCompleteUsers = false, retry = false) - attachUUID.set(sendJson(attach, s"""{"interactive": ${!batchMode.get}}""")) - rebooting.set(false) - rebootCommands match { - case Some((execId, cmd)) if execId.nonEmpty => - if (batchMode.get && !pendingResults.containsKey(execId) && cmd.nonEmpty) { - console.appendLog( - Level.Error, - s"received request to re-run unknown command '$cmd' after reboot" - ) - } else if (cmd.nonEmpty) { - if (batchMode.get) sendCommand(ExecCommand(cmd, execId)) - else - inLock.synchronized { - val toSend = cmd.getBytes :+ '\r'.toByte - toSend.foreach(b => sendNotification(systemIn, b.toString)) - } - } else completeExec(execId, 0) - case _ => - } - } else { - if (!rebooting.get() && running.compareAndSet(true, false) && log) { - if (!arguments.commandArguments.contains(Shutdown)) { - console.appendLog(Level.Error, "sbt server disconnected") - exitClean.set(false) - } - } else { - console.appendLog(Level.Info, s"${if (log) "sbt server " else ""}disconnected") - } - stdinBytes.offer(-1) - Option(inputThread.get).foreach(_.close()) - Option(interactiveThread.get).foreach(_.interrupt) - } - case `readSystemIn` => startInputThread() - case `cancelReadSystemIn` => - inputThread.get match { - case null => - case t => t.close() - } - case _ => self.onNotification(msg) - } - } - override def onRequest(msg: JsonRpcRequestMessage): Unit = self.onRequest(msg) - override def onResponse(msg: JsonRpcResponseMessage): Unit = self.onResponse(msg) - override def onShutdown(): Unit = if (!rebooting.get) { - if (exitClean.get != false) exitClean.set(!running.get) - running.set(false) - Option(interactiveThread.get).foreach(_.interrupt()) - } - } - // initiate handshake - val execId = UUID.randomUUID.toString - val initCommand = InitCommand(tkn, Option(execId), Some(true)) - conn.sendString(Serialization.serializeCommandAsJsonMessage(initCommand)) - connectionHolder.set(conn) - conn + connect(0) } catch { case e: ConnectionRefusedException if retry => - if (Files.deleteIfExists(portfile.toPath)) init(promptCompleteUsers, retry = false) + if (Files.deleteIfExists(portfile.toPath)) + connectOrStartServerAndConnect(promptCompleteUsers, retry = false) else throw e } + // Open server connection based on the portfile + def init(promptCompleteUsers: Boolean, retry: Boolean): ServerConnection = { + val (sk, tkn) = connectOrStartServerAndConnect(promptCompleteUsers, retry) + val conn = new ServerConnection(sk) { + override def onNotification(msg: JsonRpcNotificationMessage): Unit = { + msg.method match { + case `Shutdown` => + val (log, rebootCommands) = msg.params match { + case Some(jvalue) => + Converter + .fromJson[(Boolean, Option[(String, String)])](jvalue) + .getOrElse((true, None)) + case _ => (false, None) + } + if (rebootCommands.nonEmpty) { + rebooting.set(true) + attached.set(false) + connectionHolder.getAndSet(null) match { + case null => + case c => c.shutdown() + } + waitForServer(portfile, true, false) + init(promptCompleteUsers = false, retry = false) + attachUUID.set(sendJson(attach, s"""{"interactive": ${!batchMode.get}}""")) + rebooting.set(false) + rebootCommands match { + case Some((execId, cmd)) if execId.nonEmpty => + if (batchMode.get && !pendingResults.containsKey(execId) && cmd.nonEmpty) { + console.appendLog( + Level.Error, + s"received request to re-run unknown command '$cmd' after reboot" + ) + } else if (cmd.nonEmpty) { + if (batchMode.get) sendCommand(ExecCommand(cmd, execId)) + else + inLock.synchronized { + val toSend = cmd.getBytes :+ '\r'.toByte + toSend.foreach(b => sendNotification(systemIn, b.toString)) + } + } else completeExec(execId, 0) + case _ => + } + } else { + if (!rebooting.get() && running.compareAndSet(true, false) && log) { + if (!arguments.commandArguments.contains(Shutdown)) { + console.appendLog(Level.Error, "sbt server disconnected") + exitClean.set(false) + } + } else { + console.appendLog(Level.Info, s"${if (log) "sbt server " else ""}disconnected") + } + stdinBytes.offer(-1) + Option(inputThread.get).foreach(_.close()) + Option(interactiveThread.get).foreach(_.interrupt) + } + case `readSystemIn` => startInputThread() + case `cancelReadSystemIn` => + inputThread.get match { + case null => + case t => t.close() + } + case _ => self.onNotification(msg) + } + } + override def onRequest(msg: JsonRpcRequestMessage): Unit = self.onRequest(msg) + override def onResponse(msg: JsonRpcResponseMessage): Unit = self.onResponse(msg) + override def onShutdown(): Unit = if (!rebooting.get) { + if (exitClean.get != false) exitClean.set(!running.get) + running.set(false) + Option(interactiveThread.get).foreach(_.interrupt()) + } + } + // initiate handshake + val execId = UUID.randomUUID.toString + val initCommand = InitCommand(tkn, Option(execId), Some(true)) + conn.sendString(Serialization.serializeCommandAsJsonMessage(initCommand)) + connectionHolder.set(conn) + conn + } + /** * Forks another instance of sbt in the background. * This instance must be shutdown explicitly via `sbt -client shutdown` @@ -1006,9 +1014,10 @@ object NetworkClient { val commandArguments: Seq[String], val completionArguments: Seq[String], val sbtScript: String, + val bsp: Boolean, ) { def withBaseDirectory(file: File): Arguments = - new Arguments(file, sbtArguments, commandArguments, completionArguments, sbtScript) + new Arguments(file, sbtArguments, commandArguments, completionArguments, sbtScript, bsp) } private[client] val completions = "--completions" private[client] val noTab = "--no-tab" @@ -1016,6 +1025,7 @@ object NetworkClient { private[client] val sbtBase = "--sbt-base-directory" private[client] def parseArgs(args: Array[String]): Arguments = { var sbtScript = if (Properties.isWin) "sbt.bat" else "sbt" + var bsp = false val commandArgs = new mutable.ArrayBuffer[String] val sbtArguments = new mutable.ArrayBuffer[String] val completionArguments = new mutable.ArrayBuffer[String] @@ -1037,6 +1047,7 @@ object NetworkClient { .lastOption .map(_.replaceAllLiterally("%20", " ")) .getOrElse(sbtScript) + case "-bsp" | "--bsp" => bsp = true case "--sbt-script" if i + 1 < sanitized.length => i += 1 sbtScript = sanitized(i).replaceAllLiterally("%20", " ") @@ -1050,7 +1061,7 @@ object NetworkClient { } val base = new File("").getCanonicalFile if (!sbtArguments.contains("-Dsbt.io.virtual=true")) sbtArguments += "-Dsbt.io.virtual=true" - new Arguments(base, sbtArguments, commandArgs, completionArguments, sbtScript) + new Arguments(base, sbtArguments, commandArgs, completionArguments, sbtScript, bsp) } def client( @@ -1091,8 +1102,14 @@ object NetworkClient { terminal ) try { - if (client.connect(log = true, promptCompleteUsers = false)) client.run() - else 1 + if (args.bsp) { + val (socket, _) = + client.connectOrStartServerAndConnect(promptCompleteUsers = false, retry = true) + BspClient.bspRun(socket) + } else { + if (client.connect(log = true, promptCompleteUsers = false)) client.run() + else 1 + } } catch { case _: Exception => 1 } finally client.close() } def client(