Rework NetworkClient

This commit integrates the NetworkClient with the server side rendered
ui. Rather than implementing its own shell method, it will now connect
to the server and register itself as a virtual terminal. If there are
command arguments, those will be sent to the server as execs. Otherwise
it will enter a shell mode where it just acts as a relay for io.

In batch mode, it will return the exit code of the last exec sent to the
server. If the server disconnects, the client will exit with an error code.
This commit is contained in:
Ethan Atkins 2020-06-24 18:32:58 -07:00
parent ab362397ba
commit d0842711e4
1 changed files with 275 additions and 74 deletions

View File

@ -12,25 +12,38 @@ package client
import java.io.{ File, IOException, InputStream, PrintStream }
import java.lang.ProcessBuilder.Redirect
import java.net.Socket
import java.nio.channels.ClosedChannelException
import java.nio.file.Files
import java.util.UUID
import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference }
import java.util.concurrent.{ ConcurrentHashMap, LinkedBlockingQueue, TimeUnit }
import sbt.internal.client.NetworkClient.Arguments
import sbt.internal.langserver.{ LogMessageParams, MessageType, PublishDiagnosticsParams }
import sbt.internal.protocol._
import sbt.internal.util.{ ConsoleAppender, ConsoleOut, LineReader, Terminal, Util }
import sbt.internal.util.{ ConsoleAppender, ConsoleOut, Terminal, Util }
import sbt.io.IO
import sbt.io.syntax._
import sbt.protocol._
import sbt.util.Level
import sjsonnew.BasicJsonProtocol._
import sjsonnew.shaded.scalajson.ast.unsafe.{ JObject, JValue }
import sjsonnew.support.scalajson.unsafe.Converter
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
import scala.collection.mutable
import scala.util.Properties
import scala.concurrent.duration._
import scala.util.control.NonFatal
import scala.util.{ Failure, Success }
import scala.util.{ Failure, Properties, Success }
import Serialization.{
attach,
systemIn,
systemOut,
terminalCapabilities,
terminalCapabilitiesResponse,
terminalPropertiesQuery,
terminalPropertiesResponse
}
import NetworkClient.Arguments
trait ConsoleInterface {
@ -68,9 +81,14 @@ class NetworkClient(
private val status = new AtomicReference("Ready")
private val lock: AnyRef = new AnyRef {}
private val running = new AtomicBoolean(true)
private val pendingResults = new ConcurrentHashMap[String, (LinkedBlockingQueue[Integer], Long)]
private val pendingCompletions = new ConcurrentHashMap[String, CompletionResponse => Unit]
private val attached = new AtomicBoolean(false)
private val attachUUID = new AtomicReference[String](null)
private val connectionHolder = new AtomicReference[ServerConnection]
private val batchMode = new AtomicBoolean(false)
private val interactiveThread = new AtomicReference[Thread](null)
private def mkSocket(file: File): (Socket, Option[String]) = ClientSocket.socket(file, useJNI)
private val pendingExecIds = ListBuffer.empty[String]
private def portfile = arguments.baseDirectory / "project" / "target" / "active.json"
@ -81,6 +99,13 @@ class NetworkClient(
}
}
private[this] val stdinBytes = new LinkedBlockingQueue[Int]
private[this] val stdin: InputStream = new InputStream {
override def available(): Int = stdinBytes.size
override def read: Int = stdinBytes.take
}
private[this] val inputThread = new AtomicReference(new RawInputThread)
private[this] val exitClean = new AtomicBoolean(true)
private[this] val sbtProcess = new AtomicReference[Process](null)
private class ConnectionRefusedException(t: Throwable) extends Throwable(t)
@ -99,7 +124,9 @@ class NetworkClient(
override def onRequest(msg: JsonRpcRequestMessage): Unit = self.onRequest(msg)
override def onResponse(msg: JsonRpcResponseMessage): Unit = self.onResponse(msg)
override def onShutdown(): Unit = {
if (exitClean.get != false) exitClean.set(!running.get)
running.set(false)
Option(interactiveThread.get).foreach(_.interrupt())
}
}
// initiate handshake
@ -153,9 +180,9 @@ class NetworkClient(
val byte = stderr.read
errorStream.write(byte)
}
while (System.in.available > 0) {
val byte = System.in.read
stdin.write(byte)
while (!stdinBytes.isEmpty) {
stdin.write(stdinBytes.take)
stdin.flush()
}
false
} catch {
@ -197,15 +224,55 @@ class NetworkClient(
printResponse()
}
def onResponse(msg: JsonRpcResponseMessage): Unit = {
msg.id match {
case execId if pendingExecIds contains execId =>
onReturningReponse(msg)
lock.synchronized {
pendingExecIds -= execId
private def getExitCode(jvalue: Option[JValue]): Integer = jvalue match {
case Some(o: JObject) =>
o.value
.collectFirst {
case v if v.field == "exitCode" =>
Converter.fromJson[Integer](v.value).getOrElse(Integer.valueOf(1))
}
.getOrElse(1)
case _ => 1
}
def onResponse(msg: JsonRpcResponseMessage): Unit = {
pendingResults.remove(msg.id) match {
case null =>
case (q, startTime) =>
val now = System.currentTimeMillis
val message = timing(startTime, now)
val exitCode = getExitCode(msg.result)
if (batchMode.get || !attached.get) {
if (exitCode == 0) console.success(message)
else if (!attached.get) console.appendLog(Level.Error, message)
}
q.offer(exitCode)
}
msg.id match {
case execId =>
if (attachUUID.get == msg.id) {
attachUUID.set(null)
attached.set(true)
Option(inputThread.get).foreach(_.drain())
}
pendingCompletions.remove(execId) match {
case null =>
case completions =>
completions(msg.result match {
case Some(o: JObject) =>
o.value
.foldLeft(CompletionResponse(Vector.empty[String])) {
case (resp, i) =>
if (i.field == "items")
resp.withItems(
Converter
.fromJson[Vector[String]](i.value)
.getOrElse(Vector.empty[String])
)
else resp
}
case _ => CompletionResponse(Vector.empty[String])
})
}
()
case _ =>
}
}
@ -213,17 +280,33 @@ class NetworkClient(
def splitToMessage: Vector[(Level.Value, String)] =
(msg.method, msg.params) match {
case ("build/logMessage", Some(json)) =>
import sbt.internal.langserver.codec.JsonProtocol._
Converter.fromJson[LogMessageParams](json) match {
case Success(params) => splitLogMessage(params)
case Failure(e) => Vector()
if (!attached.get) {
import sbt.internal.langserver.codec.JsonProtocol._
Converter.fromJson[LogMessageParams](json) match {
case Success(params) => splitLogMessage(params)
case Failure(_) => Vector()
}
} else Vector()
case (`systemOut`, Some(json)) =>
Converter.fromJson[Seq[Byte]](json) match {
case Success(params) =>
if (params.nonEmpty) {
if (attached.get) {
printStream.write(params.toArray)
printStream.flush()
}
}
case Failure(_) =>
}
Vector.empty
case ("textDocument/publishDiagnostics", Some(json)) =>
import sbt.internal.langserver.codec.JsonProtocol._
Converter.fromJson[PublishDiagnosticsParams](json) match {
case Success(params) => splitDiagnostics(params)
case Failure(e) => Vector()
case Success(params) => splitDiagnostics(params); Vector()
case Failure(_) => Vector()
}
case ("shutdown", Some(_)) => Vector.empty
case (msg, _) if msg.startsWith("build/") => Vector.empty
case _ =>
Vector(
(
@ -269,73 +352,191 @@ class NetworkClient(
}
def onRequest(msg: JsonRpcRequestMessage): Unit = {
// ignore
}
def start(): Unit = {
console.appendLog(Level.Info, "entering *experimental* thin client - BEEP WHIRR")
val _ = connection
val userCommands = arguments.commandArguments.toList
if (userCommands.isEmpty) shell()
else batchExecute(userCommands)
}
def batchExecute(userCommands: List[String]): Unit = {
userCommands foreach { cmd =>
println("> " + cmd)
val execId =
if (cmd == "shutdown") sendExecCommand("exit")
else sendExecCommand(cmd)
while (pendingExecIds contains execId) {
Thread.sleep(100)
}
(msg.method, msg.params) match {
case (`terminalCapabilities`, Some(json)) =>
import sbt.protocol.codec.JsonProtocol._
Converter.fromJson[TerminalCapabilitiesQuery](json) match {
case Success(terminalCapabilitiesQuery) =>
val response = TerminalCapabilitiesResponse(
terminalCapabilitiesQuery.boolean.map(Terminal.console.getBooleanCapability),
terminalCapabilitiesQuery.numeric.map(Terminal.console.getNumericCapability),
terminalCapabilitiesQuery.string
.map(s => Option(Terminal.console.getStringCapability(s)).getOrElse("null")),
)
sendCommandResponse(
terminalCapabilitiesResponse,
response,
msg.id,
)
case Failure(_) =>
}
case (`terminalPropertiesQuery`, _) =>
val response = TerminalPropertiesResponse.apply(
width = Terminal.console.getWidth,
height = Terminal.console.getHeight,
isAnsiSupported = Terminal.console.isAnsiSupported,
isColorEnabled = Terminal.console.isColorEnabled,
isSupershellEnabled = Terminal.console.isSupershellEnabled,
isEchoEnabled = Terminal.console.isEchoEnabled
)
sendCommandResponse(terminalPropertiesResponse, response, msg.id)
case _ =>
}
}
def shell(): Unit = {
val reader = LineReader.simple(None, LineReader.HandleCONT, injectThreadSleep = true)
while (running.get) {
reader.readLine("> ", None) match {
case Some("shutdown") =>
// `sbt -client shutdown` shuts down the server
sendExecCommand("exit")
Thread.sleep(100)
running.set(false)
case Some("exit") =>
running.set(false)
case Some(s) if s.trim.nonEmpty =>
val execId = sendExecCommand(s)
while (pendingExecIds contains execId) {
Thread.sleep(100)
}
case _ => //
}
def connect(log: Boolean): Unit = {
if (log) console.appendLog(Level.Info, "entering *experimental* thin client - BEEP WHIRR")
init(retry = true)
()
}
def run(): Int = {
interactiveThread.set(Thread.currentThread)
val cleaned = arguments.commandArguments
val userCommands = cleaned.takeWhile(_ != "exit")
val interactive = cleaned.isEmpty
val exit = cleaned.nonEmpty && userCommands.isEmpty
attachUUID.set(sendJson(attach, s"""{"interactive": $interactive}"""))
if (interactive) {
try this.synchronized(this.wait)
catch { case _: InterruptedException => }
if (exitClean.get) 0 else 1
} else if (exit) {
0
} else {
batchMode.set(true)
batchExecute(userCommands.toList)
}
}
def sendExecCommand(commandLine: String): String = {
def batchExecute(userCommands: List[String]): Int = {
val cmd = userCommands mkString " "
printStream.println("> " + cmd)
sendAndWait(cmd, None)
}
private def sendAndWait(cmd: String, limit: Option[Deadline]): Int = {
val queue = sendExecCommand(cmd)
var result: Integer = null
while (running.get && result == null && limit.fold(true)(!_.isOverdue())) {
try {
result = limit match {
case Some(l) => queue.poll((l - Deadline.now).toMillis, TimeUnit.MILLISECONDS)
case _ => queue.take
}
} catch {
case _: InterruptedException if cmd == "shutdown" => result = 0
case _: InterruptedException => result = if (exitClean.get) 0 else 1
}
}
if (result == null) 1 else result
}
def sendExecCommand(commandLine: String): LinkedBlockingQueue[Integer] = {
val execId = UUID.randomUUID.toString
val queue = new LinkedBlockingQueue[Integer]
sendCommand(ExecCommand(commandLine, execId))
lock.synchronized {
pendingExecIds += execId
}
execId
pendingResults.put(execId, (queue, System.currentTimeMillis))
queue
}
def sendCommand(command: CommandMessage): Unit = {
try {
val s = Serialization.serializeCommandAsJsonMessage(command)
connection.sendString(s)
lock.synchronized {
status.set("Processing")
}
} catch {
case _: IOException =>
// log.debug(e.getMessage)
// toDel += client
}
lock.synchronized {
status.set("Processing")
case e: IOException =>
errorStream.println(s"Caught exception writing command to server: $e")
running.set(false)
}
}
override def close(): Unit = {}
def sendCommandResponse(method: String, command: EventMessage, id: String): Unit = {
try {
val s = new String(Serialization.serializeEventMessage(command))
val msg = s"""{ "jsonrpc": "2.0", "id": "$id", "result": $s }"""
connection.sendString(msg)
} catch {
case e: IOException =>
errorStream.println(s"Caught exception writing command to server: $e")
running.set(false)
}
}
def sendJson(method: String, params: String): String = {
val uuid = UUID.randomUUID.toString
val msg = s"""{ "jsonrpc": "2.0", "id": "$uuid", "method": "$method", "params": $params }"""
connection.sendString(msg)
uuid
}
def sendNotification(method: String, params: String): Unit = {
connection.sendString(s"""{ "jsonrpc": "2.0", "method": "$method", "params": $params }""")
}
override def close(): Unit =
try {
running.set(false)
stdinBytes.offer(-1)
val mainThread = interactiveThread.getAndSet(null)
if (mainThread != null && mainThread != Thread.currentThread) mainThread.interrupt
connection.shutdown()
Option(inputThread.get).foreach(_.interrupt())
} catch {
case t: Throwable => t.printStackTrace(); throw t
}
private[this] class RawInputThread extends Thread("sbt-read-input-thread") with AutoCloseable {
setDaemon(true)
start()
val stopped = new AtomicBoolean(false)
val lock = new Object
override final def run(): Unit = {
@tailrec def read(): Unit = {
inputStream.read match {
case -1 =>
case b =>
lock.synchronized(stdinBytes.offer(b))
if (attached.get()) drain()
if (!stopped.get()) read()
}
}
try Terminal.console.withRawSystemIn(read())
catch { case _: InterruptedException | _: ClosedChannelException => stopped.set(true) }
}
def drain(): Unit = lock.synchronized {
while (!stdinBytes.isEmpty) {
val byte = stdinBytes.poll()
sendNotification(systemIn, byte.toString)
}
}
override def close(): Unit = {
RawInputThread.this.interrupt()
}
}
// copied from Aggregation
private def timing(startTime: Long, endTime: Long): String = {
import java.text.DateFormat
val format = DateFormat.getDateTimeInstance(DateFormat.MEDIUM, DateFormat.MEDIUM)
val nowString = format.format(new java.util.Date(endTime))
val total = math.max(0, (endTime - startTime + 500) / 1000)
val totalString = s"$total s" +
(if (total <= 60) ""
else {
val maybeHours = total / 3600 match {
case 0 => ""
case h => f"$h%02d:"
}
val mins = f"${total % 3600 / 60}%02d"
val secs = f"${total % 60}%02d"
s" ($maybeHours$mins:$secs)"
})
s"Total time: $totalString, completed $nowString"
}
}
object NetworkClient {
private def consoleAppenderInterface(printStream: PrintStream): ConsoleInterface = {
@ -400,8 +601,8 @@ object NetworkClient {
try {
val client = new NetworkClient(configuration, parseArgs(arguments.toArray))
try {
client.start()
0
client.connect(log = true)
client.run()
} catch { case _: Throwable => 1 } finally client.close()
} catch {
case NonFatal(e) =>