From d0842711e4add4508257b985d708d8d70ecd6be8 Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Wed, 24 Jun 2020 18:32:58 -0700 Subject: [PATCH] Rework NetworkClient This commit integrates the NetworkClient with the server side rendered ui. Rather than implementing its own shell method, it will now connect to the server and register itself as a virtual terminal. If there are command arguments, those will be sent to the server as execs. Otherwise it will enter a shell mode where it just acts as a relay for io. In batch mode, it will return the exit code of the last exec sent to the server. If the server disconnects, the client will exit with an error code. --- .../sbt/internal/client/NetworkClient.scala | 349 ++++++++++++++---- 1 file changed, 275 insertions(+), 74 deletions(-) diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala index 16d3041f0..dc69b0ff9 100644 --- a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala +++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala @@ -12,25 +12,38 @@ package client import java.io.{ File, IOException, InputStream, PrintStream } import java.lang.ProcessBuilder.Redirect import java.net.Socket +import java.nio.channels.ClosedChannelException import java.nio.file.Files import java.util.UUID import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } +import java.util.concurrent.{ ConcurrentHashMap, LinkedBlockingQueue, TimeUnit } +import sbt.internal.client.NetworkClient.Arguments import sbt.internal.langserver.{ LogMessageParams, MessageType, PublishDiagnosticsParams } import sbt.internal.protocol._ -import sbt.internal.util.{ ConsoleAppender, ConsoleOut, LineReader, Terminal, Util } +import sbt.internal.util.{ ConsoleAppender, ConsoleOut, Terminal, Util } import sbt.io.IO import sbt.io.syntax._ import sbt.protocol._ import sbt.util.Level +import sjsonnew.BasicJsonProtocol._ +import sjsonnew.shaded.scalajson.ast.unsafe.{ JObject, JValue } import sjsonnew.support.scalajson.unsafe.Converter import scala.annotation.tailrec -import scala.collection.mutable.ListBuffer import scala.collection.mutable -import scala.util.Properties +import scala.concurrent.duration._ import scala.util.control.NonFatal -import scala.util.{ Failure, Success } +import scala.util.{ Failure, Properties, Success } +import Serialization.{ + attach, + systemIn, + systemOut, + terminalCapabilities, + terminalCapabilitiesResponse, + terminalPropertiesQuery, + terminalPropertiesResponse +} import NetworkClient.Arguments trait ConsoleInterface { @@ -68,9 +81,14 @@ class NetworkClient( private val status = new AtomicReference("Ready") private val lock: AnyRef = new AnyRef {} private val running = new AtomicBoolean(true) + private val pendingResults = new ConcurrentHashMap[String, (LinkedBlockingQueue[Integer], Long)] + private val pendingCompletions = new ConcurrentHashMap[String, CompletionResponse => Unit] + private val attached = new AtomicBoolean(false) + private val attachUUID = new AtomicReference[String](null) private val connectionHolder = new AtomicReference[ServerConnection] + private val batchMode = new AtomicBoolean(false) + private val interactiveThread = new AtomicReference[Thread](null) private def mkSocket(file: File): (Socket, Option[String]) = ClientSocket.socket(file, useJNI) - private val pendingExecIds = ListBuffer.empty[String] private def portfile = arguments.baseDirectory / "project" / "target" / "active.json" @@ -81,6 +99,13 @@ class NetworkClient( } } + private[this] val stdinBytes = new LinkedBlockingQueue[Int] + private[this] val stdin: InputStream = new InputStream { + override def available(): Int = stdinBytes.size + override def read: Int = stdinBytes.take + } + private[this] val inputThread = new AtomicReference(new RawInputThread) + private[this] val exitClean = new AtomicBoolean(true) private[this] val sbtProcess = new AtomicReference[Process](null) private class ConnectionRefusedException(t: Throwable) extends Throwable(t) @@ -99,7 +124,9 @@ class NetworkClient( override def onRequest(msg: JsonRpcRequestMessage): Unit = self.onRequest(msg) override def onResponse(msg: JsonRpcResponseMessage): Unit = self.onResponse(msg) override def onShutdown(): Unit = { + if (exitClean.get != false) exitClean.set(!running.get) running.set(false) + Option(interactiveThread.get).foreach(_.interrupt()) } } // initiate handshake @@ -153,9 +180,9 @@ class NetworkClient( val byte = stderr.read errorStream.write(byte) } - while (System.in.available > 0) { - val byte = System.in.read - stdin.write(byte) + while (!stdinBytes.isEmpty) { + stdin.write(stdinBytes.take) + stdin.flush() } false } catch { @@ -197,15 +224,55 @@ class NetworkClient( printResponse() } - def onResponse(msg: JsonRpcResponseMessage): Unit = { - msg.id match { - case execId if pendingExecIds contains execId => - onReturningReponse(msg) - lock.synchronized { - pendingExecIds -= execId + private def getExitCode(jvalue: Option[JValue]): Integer = jvalue match { + case Some(o: JObject) => + o.value + .collectFirst { + case v if v.field == "exitCode" => + Converter.fromJson[Integer](v.value).getOrElse(Integer.valueOf(1)) + } + .getOrElse(1) + case _ => 1 + } + def onResponse(msg: JsonRpcResponseMessage): Unit = { + pendingResults.remove(msg.id) match { + case null => + case (q, startTime) => + val now = System.currentTimeMillis + val message = timing(startTime, now) + val exitCode = getExitCode(msg.result) + if (batchMode.get || !attached.get) { + if (exitCode == 0) console.success(message) + else if (!attached.get) console.appendLog(Level.Error, message) + } + q.offer(exitCode) + } + msg.id match { + case execId => + if (attachUUID.get == msg.id) { + attachUUID.set(null) + attached.set(true) + Option(inputThread.get).foreach(_.drain()) + } + pendingCompletions.remove(execId) match { + case null => + case completions => + completions(msg.result match { + case Some(o: JObject) => + o.value + .foldLeft(CompletionResponse(Vector.empty[String])) { + case (resp, i) => + if (i.field == "items") + resp.withItems( + Converter + .fromJson[Vector[String]](i.value) + .getOrElse(Vector.empty[String]) + ) + else resp + } + case _ => CompletionResponse(Vector.empty[String]) + }) } - () - case _ => } } @@ -213,17 +280,33 @@ class NetworkClient( def splitToMessage: Vector[(Level.Value, String)] = (msg.method, msg.params) match { case ("build/logMessage", Some(json)) => - import sbt.internal.langserver.codec.JsonProtocol._ - Converter.fromJson[LogMessageParams](json) match { - case Success(params) => splitLogMessage(params) - case Failure(e) => Vector() + if (!attached.get) { + import sbt.internal.langserver.codec.JsonProtocol._ + Converter.fromJson[LogMessageParams](json) match { + case Success(params) => splitLogMessage(params) + case Failure(_) => Vector() + } + } else Vector() + case (`systemOut`, Some(json)) => + Converter.fromJson[Seq[Byte]](json) match { + case Success(params) => + if (params.nonEmpty) { + if (attached.get) { + printStream.write(params.toArray) + printStream.flush() + } + } + case Failure(_) => } + Vector.empty case ("textDocument/publishDiagnostics", Some(json)) => import sbt.internal.langserver.codec.JsonProtocol._ Converter.fromJson[PublishDiagnosticsParams](json) match { - case Success(params) => splitDiagnostics(params) - case Failure(e) => Vector() + case Success(params) => splitDiagnostics(params); Vector() + case Failure(_) => Vector() } + case ("shutdown", Some(_)) => Vector.empty + case (msg, _) if msg.startsWith("build/") => Vector.empty case _ => Vector( ( @@ -269,73 +352,191 @@ class NetworkClient( } def onRequest(msg: JsonRpcRequestMessage): Unit = { - // ignore - } - - def start(): Unit = { - console.appendLog(Level.Info, "entering *experimental* thin client - BEEP WHIRR") - val _ = connection - val userCommands = arguments.commandArguments.toList - if (userCommands.isEmpty) shell() - else batchExecute(userCommands) - } - - def batchExecute(userCommands: List[String]): Unit = { - userCommands foreach { cmd => - println("> " + cmd) - val execId = - if (cmd == "shutdown") sendExecCommand("exit") - else sendExecCommand(cmd) - while (pendingExecIds contains execId) { - Thread.sleep(100) - } + (msg.method, msg.params) match { + case (`terminalCapabilities`, Some(json)) => + import sbt.protocol.codec.JsonProtocol._ + Converter.fromJson[TerminalCapabilitiesQuery](json) match { + case Success(terminalCapabilitiesQuery) => + val response = TerminalCapabilitiesResponse( + terminalCapabilitiesQuery.boolean.map(Terminal.console.getBooleanCapability), + terminalCapabilitiesQuery.numeric.map(Terminal.console.getNumericCapability), + terminalCapabilitiesQuery.string + .map(s => Option(Terminal.console.getStringCapability(s)).getOrElse("null")), + ) + sendCommandResponse( + terminalCapabilitiesResponse, + response, + msg.id, + ) + case Failure(_) => + } + case (`terminalPropertiesQuery`, _) => + val response = TerminalPropertiesResponse.apply( + width = Terminal.console.getWidth, + height = Terminal.console.getHeight, + isAnsiSupported = Terminal.console.isAnsiSupported, + isColorEnabled = Terminal.console.isColorEnabled, + isSupershellEnabled = Terminal.console.isSupershellEnabled, + isEchoEnabled = Terminal.console.isEchoEnabled + ) + sendCommandResponse(terminalPropertiesResponse, response, msg.id) + case _ => } } - def shell(): Unit = { - val reader = LineReader.simple(None, LineReader.HandleCONT, injectThreadSleep = true) - while (running.get) { - reader.readLine("> ", None) match { - case Some("shutdown") => - // `sbt -client shutdown` shuts down the server - sendExecCommand("exit") - Thread.sleep(100) - running.set(false) - case Some("exit") => - running.set(false) - case Some(s) if s.trim.nonEmpty => - val execId = sendExecCommand(s) - while (pendingExecIds contains execId) { - Thread.sleep(100) - } - case _ => // - } + def connect(log: Boolean): Unit = { + if (log) console.appendLog(Level.Info, "entering *experimental* thin client - BEEP WHIRR") + init(retry = true) + () + } + + def run(): Int = { + interactiveThread.set(Thread.currentThread) + val cleaned = arguments.commandArguments + val userCommands = cleaned.takeWhile(_ != "exit") + val interactive = cleaned.isEmpty + val exit = cleaned.nonEmpty && userCommands.isEmpty + attachUUID.set(sendJson(attach, s"""{"interactive": $interactive}""")) + if (interactive) { + try this.synchronized(this.wait) + catch { case _: InterruptedException => } + if (exitClean.get) 0 else 1 + } else if (exit) { + 0 + } else { + batchMode.set(true) + batchExecute(userCommands.toList) } } - def sendExecCommand(commandLine: String): String = { + def batchExecute(userCommands: List[String]): Int = { + val cmd = userCommands mkString " " + printStream.println("> " + cmd) + sendAndWait(cmd, None) + } + + private def sendAndWait(cmd: String, limit: Option[Deadline]): Int = { + val queue = sendExecCommand(cmd) + var result: Integer = null + while (running.get && result == null && limit.fold(true)(!_.isOverdue())) { + try { + result = limit match { + case Some(l) => queue.poll((l - Deadline.now).toMillis, TimeUnit.MILLISECONDS) + case _ => queue.take + } + } catch { + case _: InterruptedException if cmd == "shutdown" => result = 0 + case _: InterruptedException => result = if (exitClean.get) 0 else 1 + } + } + if (result == null) 1 else result + } + + def sendExecCommand(commandLine: String): LinkedBlockingQueue[Integer] = { val execId = UUID.randomUUID.toString + val queue = new LinkedBlockingQueue[Integer] sendCommand(ExecCommand(commandLine, execId)) - lock.synchronized { - pendingExecIds += execId - } - execId + pendingResults.put(execId, (queue, System.currentTimeMillis)) + queue } def sendCommand(command: CommandMessage): Unit = { try { val s = Serialization.serializeCommandAsJsonMessage(command) connection.sendString(s) + lock.synchronized { + status.set("Processing") + } } catch { - case _: IOException => - // log.debug(e.getMessage) - // toDel += client - } - lock.synchronized { - status.set("Processing") + case e: IOException => + errorStream.println(s"Caught exception writing command to server: $e") + running.set(false) } } - override def close(): Unit = {} + def sendCommandResponse(method: String, command: EventMessage, id: String): Unit = { + try { + val s = new String(Serialization.serializeEventMessage(command)) + val msg = s"""{ "jsonrpc": "2.0", "id": "$id", "result": $s }""" + connection.sendString(msg) + } catch { + case e: IOException => + errorStream.println(s"Caught exception writing command to server: $e") + running.set(false) + } + } + def sendJson(method: String, params: String): String = { + val uuid = UUID.randomUUID.toString + val msg = s"""{ "jsonrpc": "2.0", "id": "$uuid", "method": "$method", "params": $params }""" + connection.sendString(msg) + uuid + } + + def sendNotification(method: String, params: String): Unit = { + connection.sendString(s"""{ "jsonrpc": "2.0", "method": "$method", "params": $params }""") + } + + override def close(): Unit = + try { + running.set(false) + stdinBytes.offer(-1) + val mainThread = interactiveThread.getAndSet(null) + if (mainThread != null && mainThread != Thread.currentThread) mainThread.interrupt + connection.shutdown() + Option(inputThread.get).foreach(_.interrupt()) + } catch { + case t: Throwable => t.printStackTrace(); throw t + } + + private[this] class RawInputThread extends Thread("sbt-read-input-thread") with AutoCloseable { + setDaemon(true) + start() + val stopped = new AtomicBoolean(false) + val lock = new Object + override final def run(): Unit = { + @tailrec def read(): Unit = { + inputStream.read match { + case -1 => + case b => + lock.synchronized(stdinBytes.offer(b)) + if (attached.get()) drain() + if (!stopped.get()) read() + } + } + try Terminal.console.withRawSystemIn(read()) + catch { case _: InterruptedException | _: ClosedChannelException => stopped.set(true) } + } + + def drain(): Unit = lock.synchronized { + while (!stdinBytes.isEmpty) { + val byte = stdinBytes.poll() + sendNotification(systemIn, byte.toString) + } + } + + override def close(): Unit = { + RawInputThread.this.interrupt() + } + } + + // copied from Aggregation + private def timing(startTime: Long, endTime: Long): String = { + import java.text.DateFormat + val format = DateFormat.getDateTimeInstance(DateFormat.MEDIUM, DateFormat.MEDIUM) + val nowString = format.format(new java.util.Date(endTime)) + val total = math.max(0, (endTime - startTime + 500) / 1000) + val totalString = s"$total s" + + (if (total <= 60) "" + else { + val maybeHours = total / 3600 match { + case 0 => "" + case h => f"$h%02d:" + } + val mins = f"${total % 3600 / 60}%02d" + val secs = f"${total % 60}%02d" + s" ($maybeHours$mins:$secs)" + }) + s"Total time: $totalString, completed $nowString" + } } object NetworkClient { private def consoleAppenderInterface(printStream: PrintStream): ConsoleInterface = { @@ -400,8 +601,8 @@ object NetworkClient { try { val client = new NetworkClient(configuration, parseArgs(arguments.toArray)) try { - client.start() - 0 + client.connect(log = true) + client.run() } catch { case _: Throwable => 1 } finally client.close() } catch { case NonFatal(e) =>