mirror of https://github.com/sbt/sbt.git
Rework NetworkClient
This commit integrates the NetworkClient with the server side rendered ui. Rather than implementing its own shell method, it will now connect to the server and register itself as a virtual terminal. If there are command arguments, those will be sent to the server as execs. Otherwise it will enter a shell mode where it just acts as a relay for io. In batch mode, it will return the exit code of the last exec sent to the server. If the server disconnects, the client will exit with an error code.
This commit is contained in:
parent
ab362397ba
commit
d0842711e4
|
|
@ -12,25 +12,38 @@ package client
|
|||
import java.io.{ File, IOException, InputStream, PrintStream }
|
||||
import java.lang.ProcessBuilder.Redirect
|
||||
import java.net.Socket
|
||||
import java.nio.channels.ClosedChannelException
|
||||
import java.nio.file.Files
|
||||
import java.util.UUID
|
||||
import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference }
|
||||
import java.util.concurrent.{ ConcurrentHashMap, LinkedBlockingQueue, TimeUnit }
|
||||
|
||||
import sbt.internal.client.NetworkClient.Arguments
|
||||
import sbt.internal.langserver.{ LogMessageParams, MessageType, PublishDiagnosticsParams }
|
||||
import sbt.internal.protocol._
|
||||
import sbt.internal.util.{ ConsoleAppender, ConsoleOut, LineReader, Terminal, Util }
|
||||
import sbt.internal.util.{ ConsoleAppender, ConsoleOut, Terminal, Util }
|
||||
import sbt.io.IO
|
||||
import sbt.io.syntax._
|
||||
import sbt.protocol._
|
||||
import sbt.util.Level
|
||||
import sjsonnew.BasicJsonProtocol._
|
||||
import sjsonnew.shaded.scalajson.ast.unsafe.{ JObject, JValue }
|
||||
import sjsonnew.support.scalajson.unsafe.Converter
|
||||
|
||||
import scala.annotation.tailrec
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.collection.mutable
|
||||
import scala.util.Properties
|
||||
import scala.concurrent.duration._
|
||||
import scala.util.control.NonFatal
|
||||
import scala.util.{ Failure, Success }
|
||||
import scala.util.{ Failure, Properties, Success }
|
||||
import Serialization.{
|
||||
attach,
|
||||
systemIn,
|
||||
systemOut,
|
||||
terminalCapabilities,
|
||||
terminalCapabilitiesResponse,
|
||||
terminalPropertiesQuery,
|
||||
terminalPropertiesResponse
|
||||
}
|
||||
import NetworkClient.Arguments
|
||||
|
||||
trait ConsoleInterface {
|
||||
|
|
@ -68,9 +81,14 @@ class NetworkClient(
|
|||
private val status = new AtomicReference("Ready")
|
||||
private val lock: AnyRef = new AnyRef {}
|
||||
private val running = new AtomicBoolean(true)
|
||||
private val pendingResults = new ConcurrentHashMap[String, (LinkedBlockingQueue[Integer], Long)]
|
||||
private val pendingCompletions = new ConcurrentHashMap[String, CompletionResponse => Unit]
|
||||
private val attached = new AtomicBoolean(false)
|
||||
private val attachUUID = new AtomicReference[String](null)
|
||||
private val connectionHolder = new AtomicReference[ServerConnection]
|
||||
private val batchMode = new AtomicBoolean(false)
|
||||
private val interactiveThread = new AtomicReference[Thread](null)
|
||||
private def mkSocket(file: File): (Socket, Option[String]) = ClientSocket.socket(file, useJNI)
|
||||
private val pendingExecIds = ListBuffer.empty[String]
|
||||
|
||||
private def portfile = arguments.baseDirectory / "project" / "target" / "active.json"
|
||||
|
||||
|
|
@ -81,6 +99,13 @@ class NetworkClient(
|
|||
}
|
||||
}
|
||||
|
||||
private[this] val stdinBytes = new LinkedBlockingQueue[Int]
|
||||
private[this] val stdin: InputStream = new InputStream {
|
||||
override def available(): Int = stdinBytes.size
|
||||
override def read: Int = stdinBytes.take
|
||||
}
|
||||
private[this] val inputThread = new AtomicReference(new RawInputThread)
|
||||
private[this] val exitClean = new AtomicBoolean(true)
|
||||
private[this] val sbtProcess = new AtomicReference[Process](null)
|
||||
private class ConnectionRefusedException(t: Throwable) extends Throwable(t)
|
||||
|
||||
|
|
@ -99,7 +124,9 @@ class NetworkClient(
|
|||
override def onRequest(msg: JsonRpcRequestMessage): Unit = self.onRequest(msg)
|
||||
override def onResponse(msg: JsonRpcResponseMessage): Unit = self.onResponse(msg)
|
||||
override def onShutdown(): Unit = {
|
||||
if (exitClean.get != false) exitClean.set(!running.get)
|
||||
running.set(false)
|
||||
Option(interactiveThread.get).foreach(_.interrupt())
|
||||
}
|
||||
}
|
||||
// initiate handshake
|
||||
|
|
@ -153,9 +180,9 @@ class NetworkClient(
|
|||
val byte = stderr.read
|
||||
errorStream.write(byte)
|
||||
}
|
||||
while (System.in.available > 0) {
|
||||
val byte = System.in.read
|
||||
stdin.write(byte)
|
||||
while (!stdinBytes.isEmpty) {
|
||||
stdin.write(stdinBytes.take)
|
||||
stdin.flush()
|
||||
}
|
||||
false
|
||||
} catch {
|
||||
|
|
@ -197,15 +224,55 @@ class NetworkClient(
|
|||
printResponse()
|
||||
}
|
||||
|
||||
def onResponse(msg: JsonRpcResponseMessage): Unit = {
|
||||
msg.id match {
|
||||
case execId if pendingExecIds contains execId =>
|
||||
onReturningReponse(msg)
|
||||
lock.synchronized {
|
||||
pendingExecIds -= execId
|
||||
private def getExitCode(jvalue: Option[JValue]): Integer = jvalue match {
|
||||
case Some(o: JObject) =>
|
||||
o.value
|
||||
.collectFirst {
|
||||
case v if v.field == "exitCode" =>
|
||||
Converter.fromJson[Integer](v.value).getOrElse(Integer.valueOf(1))
|
||||
}
|
||||
.getOrElse(1)
|
||||
case _ => 1
|
||||
}
|
||||
def onResponse(msg: JsonRpcResponseMessage): Unit = {
|
||||
pendingResults.remove(msg.id) match {
|
||||
case null =>
|
||||
case (q, startTime) =>
|
||||
val now = System.currentTimeMillis
|
||||
val message = timing(startTime, now)
|
||||
val exitCode = getExitCode(msg.result)
|
||||
if (batchMode.get || !attached.get) {
|
||||
if (exitCode == 0) console.success(message)
|
||||
else if (!attached.get) console.appendLog(Level.Error, message)
|
||||
}
|
||||
q.offer(exitCode)
|
||||
}
|
||||
msg.id match {
|
||||
case execId =>
|
||||
if (attachUUID.get == msg.id) {
|
||||
attachUUID.set(null)
|
||||
attached.set(true)
|
||||
Option(inputThread.get).foreach(_.drain())
|
||||
}
|
||||
pendingCompletions.remove(execId) match {
|
||||
case null =>
|
||||
case completions =>
|
||||
completions(msg.result match {
|
||||
case Some(o: JObject) =>
|
||||
o.value
|
||||
.foldLeft(CompletionResponse(Vector.empty[String])) {
|
||||
case (resp, i) =>
|
||||
if (i.field == "items")
|
||||
resp.withItems(
|
||||
Converter
|
||||
.fromJson[Vector[String]](i.value)
|
||||
.getOrElse(Vector.empty[String])
|
||||
)
|
||||
else resp
|
||||
}
|
||||
case _ => CompletionResponse(Vector.empty[String])
|
||||
})
|
||||
}
|
||||
()
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -213,17 +280,33 @@ class NetworkClient(
|
|||
def splitToMessage: Vector[(Level.Value, String)] =
|
||||
(msg.method, msg.params) match {
|
||||
case ("build/logMessage", Some(json)) =>
|
||||
import sbt.internal.langserver.codec.JsonProtocol._
|
||||
Converter.fromJson[LogMessageParams](json) match {
|
||||
case Success(params) => splitLogMessage(params)
|
||||
case Failure(e) => Vector()
|
||||
if (!attached.get) {
|
||||
import sbt.internal.langserver.codec.JsonProtocol._
|
||||
Converter.fromJson[LogMessageParams](json) match {
|
||||
case Success(params) => splitLogMessage(params)
|
||||
case Failure(_) => Vector()
|
||||
}
|
||||
} else Vector()
|
||||
case (`systemOut`, Some(json)) =>
|
||||
Converter.fromJson[Seq[Byte]](json) match {
|
||||
case Success(params) =>
|
||||
if (params.nonEmpty) {
|
||||
if (attached.get) {
|
||||
printStream.write(params.toArray)
|
||||
printStream.flush()
|
||||
}
|
||||
}
|
||||
case Failure(_) =>
|
||||
}
|
||||
Vector.empty
|
||||
case ("textDocument/publishDiagnostics", Some(json)) =>
|
||||
import sbt.internal.langserver.codec.JsonProtocol._
|
||||
Converter.fromJson[PublishDiagnosticsParams](json) match {
|
||||
case Success(params) => splitDiagnostics(params)
|
||||
case Failure(e) => Vector()
|
||||
case Success(params) => splitDiagnostics(params); Vector()
|
||||
case Failure(_) => Vector()
|
||||
}
|
||||
case ("shutdown", Some(_)) => Vector.empty
|
||||
case (msg, _) if msg.startsWith("build/") => Vector.empty
|
||||
case _ =>
|
||||
Vector(
|
||||
(
|
||||
|
|
@ -269,73 +352,191 @@ class NetworkClient(
|
|||
}
|
||||
|
||||
def onRequest(msg: JsonRpcRequestMessage): Unit = {
|
||||
// ignore
|
||||
}
|
||||
|
||||
def start(): Unit = {
|
||||
console.appendLog(Level.Info, "entering *experimental* thin client - BEEP WHIRR")
|
||||
val _ = connection
|
||||
val userCommands = arguments.commandArguments.toList
|
||||
if (userCommands.isEmpty) shell()
|
||||
else batchExecute(userCommands)
|
||||
}
|
||||
|
||||
def batchExecute(userCommands: List[String]): Unit = {
|
||||
userCommands foreach { cmd =>
|
||||
println("> " + cmd)
|
||||
val execId =
|
||||
if (cmd == "shutdown") sendExecCommand("exit")
|
||||
else sendExecCommand(cmd)
|
||||
while (pendingExecIds contains execId) {
|
||||
Thread.sleep(100)
|
||||
}
|
||||
(msg.method, msg.params) match {
|
||||
case (`terminalCapabilities`, Some(json)) =>
|
||||
import sbt.protocol.codec.JsonProtocol._
|
||||
Converter.fromJson[TerminalCapabilitiesQuery](json) match {
|
||||
case Success(terminalCapabilitiesQuery) =>
|
||||
val response = TerminalCapabilitiesResponse(
|
||||
terminalCapabilitiesQuery.boolean.map(Terminal.console.getBooleanCapability),
|
||||
terminalCapabilitiesQuery.numeric.map(Terminal.console.getNumericCapability),
|
||||
terminalCapabilitiesQuery.string
|
||||
.map(s => Option(Terminal.console.getStringCapability(s)).getOrElse("null")),
|
||||
)
|
||||
sendCommandResponse(
|
||||
terminalCapabilitiesResponse,
|
||||
response,
|
||||
msg.id,
|
||||
)
|
||||
case Failure(_) =>
|
||||
}
|
||||
case (`terminalPropertiesQuery`, _) =>
|
||||
val response = TerminalPropertiesResponse.apply(
|
||||
width = Terminal.console.getWidth,
|
||||
height = Terminal.console.getHeight,
|
||||
isAnsiSupported = Terminal.console.isAnsiSupported,
|
||||
isColorEnabled = Terminal.console.isColorEnabled,
|
||||
isSupershellEnabled = Terminal.console.isSupershellEnabled,
|
||||
isEchoEnabled = Terminal.console.isEchoEnabled
|
||||
)
|
||||
sendCommandResponse(terminalPropertiesResponse, response, msg.id)
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
|
||||
def shell(): Unit = {
|
||||
val reader = LineReader.simple(None, LineReader.HandleCONT, injectThreadSleep = true)
|
||||
while (running.get) {
|
||||
reader.readLine("> ", None) match {
|
||||
case Some("shutdown") =>
|
||||
// `sbt -client shutdown` shuts down the server
|
||||
sendExecCommand("exit")
|
||||
Thread.sleep(100)
|
||||
running.set(false)
|
||||
case Some("exit") =>
|
||||
running.set(false)
|
||||
case Some(s) if s.trim.nonEmpty =>
|
||||
val execId = sendExecCommand(s)
|
||||
while (pendingExecIds contains execId) {
|
||||
Thread.sleep(100)
|
||||
}
|
||||
case _ => //
|
||||
}
|
||||
def connect(log: Boolean): Unit = {
|
||||
if (log) console.appendLog(Level.Info, "entering *experimental* thin client - BEEP WHIRR")
|
||||
init(retry = true)
|
||||
()
|
||||
}
|
||||
|
||||
def run(): Int = {
|
||||
interactiveThread.set(Thread.currentThread)
|
||||
val cleaned = arguments.commandArguments
|
||||
val userCommands = cleaned.takeWhile(_ != "exit")
|
||||
val interactive = cleaned.isEmpty
|
||||
val exit = cleaned.nonEmpty && userCommands.isEmpty
|
||||
attachUUID.set(sendJson(attach, s"""{"interactive": $interactive}"""))
|
||||
if (interactive) {
|
||||
try this.synchronized(this.wait)
|
||||
catch { case _: InterruptedException => }
|
||||
if (exitClean.get) 0 else 1
|
||||
} else if (exit) {
|
||||
0
|
||||
} else {
|
||||
batchMode.set(true)
|
||||
batchExecute(userCommands.toList)
|
||||
}
|
||||
}
|
||||
|
||||
def sendExecCommand(commandLine: String): String = {
|
||||
def batchExecute(userCommands: List[String]): Int = {
|
||||
val cmd = userCommands mkString " "
|
||||
printStream.println("> " + cmd)
|
||||
sendAndWait(cmd, None)
|
||||
}
|
||||
|
||||
private def sendAndWait(cmd: String, limit: Option[Deadline]): Int = {
|
||||
val queue = sendExecCommand(cmd)
|
||||
var result: Integer = null
|
||||
while (running.get && result == null && limit.fold(true)(!_.isOverdue())) {
|
||||
try {
|
||||
result = limit match {
|
||||
case Some(l) => queue.poll((l - Deadline.now).toMillis, TimeUnit.MILLISECONDS)
|
||||
case _ => queue.take
|
||||
}
|
||||
} catch {
|
||||
case _: InterruptedException if cmd == "shutdown" => result = 0
|
||||
case _: InterruptedException => result = if (exitClean.get) 0 else 1
|
||||
}
|
||||
}
|
||||
if (result == null) 1 else result
|
||||
}
|
||||
|
||||
def sendExecCommand(commandLine: String): LinkedBlockingQueue[Integer] = {
|
||||
val execId = UUID.randomUUID.toString
|
||||
val queue = new LinkedBlockingQueue[Integer]
|
||||
sendCommand(ExecCommand(commandLine, execId))
|
||||
lock.synchronized {
|
||||
pendingExecIds += execId
|
||||
}
|
||||
execId
|
||||
pendingResults.put(execId, (queue, System.currentTimeMillis))
|
||||
queue
|
||||
}
|
||||
|
||||
def sendCommand(command: CommandMessage): Unit = {
|
||||
try {
|
||||
val s = Serialization.serializeCommandAsJsonMessage(command)
|
||||
connection.sendString(s)
|
||||
lock.synchronized {
|
||||
status.set("Processing")
|
||||
}
|
||||
} catch {
|
||||
case _: IOException =>
|
||||
// log.debug(e.getMessage)
|
||||
// toDel += client
|
||||
}
|
||||
lock.synchronized {
|
||||
status.set("Processing")
|
||||
case e: IOException =>
|
||||
errorStream.println(s"Caught exception writing command to server: $e")
|
||||
running.set(false)
|
||||
}
|
||||
}
|
||||
override def close(): Unit = {}
|
||||
def sendCommandResponse(method: String, command: EventMessage, id: String): Unit = {
|
||||
try {
|
||||
val s = new String(Serialization.serializeEventMessage(command))
|
||||
val msg = s"""{ "jsonrpc": "2.0", "id": "$id", "result": $s }"""
|
||||
connection.sendString(msg)
|
||||
} catch {
|
||||
case e: IOException =>
|
||||
errorStream.println(s"Caught exception writing command to server: $e")
|
||||
running.set(false)
|
||||
}
|
||||
}
|
||||
def sendJson(method: String, params: String): String = {
|
||||
val uuid = UUID.randomUUID.toString
|
||||
val msg = s"""{ "jsonrpc": "2.0", "id": "$uuid", "method": "$method", "params": $params }"""
|
||||
connection.sendString(msg)
|
||||
uuid
|
||||
}
|
||||
|
||||
def sendNotification(method: String, params: String): Unit = {
|
||||
connection.sendString(s"""{ "jsonrpc": "2.0", "method": "$method", "params": $params }""")
|
||||
}
|
||||
|
||||
override def close(): Unit =
|
||||
try {
|
||||
running.set(false)
|
||||
stdinBytes.offer(-1)
|
||||
val mainThread = interactiveThread.getAndSet(null)
|
||||
if (mainThread != null && mainThread != Thread.currentThread) mainThread.interrupt
|
||||
connection.shutdown()
|
||||
Option(inputThread.get).foreach(_.interrupt())
|
||||
} catch {
|
||||
case t: Throwable => t.printStackTrace(); throw t
|
||||
}
|
||||
|
||||
private[this] class RawInputThread extends Thread("sbt-read-input-thread") with AutoCloseable {
|
||||
setDaemon(true)
|
||||
start()
|
||||
val stopped = new AtomicBoolean(false)
|
||||
val lock = new Object
|
||||
override final def run(): Unit = {
|
||||
@tailrec def read(): Unit = {
|
||||
inputStream.read match {
|
||||
case -1 =>
|
||||
case b =>
|
||||
lock.synchronized(stdinBytes.offer(b))
|
||||
if (attached.get()) drain()
|
||||
if (!stopped.get()) read()
|
||||
}
|
||||
}
|
||||
try Terminal.console.withRawSystemIn(read())
|
||||
catch { case _: InterruptedException | _: ClosedChannelException => stopped.set(true) }
|
||||
}
|
||||
|
||||
def drain(): Unit = lock.synchronized {
|
||||
while (!stdinBytes.isEmpty) {
|
||||
val byte = stdinBytes.poll()
|
||||
sendNotification(systemIn, byte.toString)
|
||||
}
|
||||
}
|
||||
|
||||
override def close(): Unit = {
|
||||
RawInputThread.this.interrupt()
|
||||
}
|
||||
}
|
||||
|
||||
// copied from Aggregation
|
||||
private def timing(startTime: Long, endTime: Long): String = {
|
||||
import java.text.DateFormat
|
||||
val format = DateFormat.getDateTimeInstance(DateFormat.MEDIUM, DateFormat.MEDIUM)
|
||||
val nowString = format.format(new java.util.Date(endTime))
|
||||
val total = math.max(0, (endTime - startTime + 500) / 1000)
|
||||
val totalString = s"$total s" +
|
||||
(if (total <= 60) ""
|
||||
else {
|
||||
val maybeHours = total / 3600 match {
|
||||
case 0 => ""
|
||||
case h => f"$h%02d:"
|
||||
}
|
||||
val mins = f"${total % 3600 / 60}%02d"
|
||||
val secs = f"${total % 60}%02d"
|
||||
s" ($maybeHours$mins:$secs)"
|
||||
})
|
||||
s"Total time: $totalString, completed $nowString"
|
||||
}
|
||||
}
|
||||
object NetworkClient {
|
||||
private def consoleAppenderInterface(printStream: PrintStream): ConsoleInterface = {
|
||||
|
|
@ -400,8 +601,8 @@ object NetworkClient {
|
|||
try {
|
||||
val client = new NetworkClient(configuration, parseArgs(arguments.toArray))
|
||||
try {
|
||||
client.start()
|
||||
0
|
||||
client.connect(log = true)
|
||||
client.run()
|
||||
} catch { case _: Throwable => 1 } finally client.close()
|
||||
} catch {
|
||||
case NonFatal(e) =>
|
||||
|
|
|
|||
Loading…
Reference in New Issue