diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala index 16d3041f0..dc69b0ff9 100644 --- a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala +++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala @@ -12,25 +12,38 @@ package client import java.io.{ File, IOException, InputStream, PrintStream } import java.lang.ProcessBuilder.Redirect import java.net.Socket +import java.nio.channels.ClosedChannelException import java.nio.file.Files import java.util.UUID import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } +import java.util.concurrent.{ ConcurrentHashMap, LinkedBlockingQueue, TimeUnit } +import sbt.internal.client.NetworkClient.Arguments import sbt.internal.langserver.{ LogMessageParams, MessageType, PublishDiagnosticsParams } import sbt.internal.protocol._ -import sbt.internal.util.{ ConsoleAppender, ConsoleOut, LineReader, Terminal, Util } +import sbt.internal.util.{ ConsoleAppender, ConsoleOut, Terminal, Util } import sbt.io.IO import sbt.io.syntax._ import sbt.protocol._ import sbt.util.Level +import sjsonnew.BasicJsonProtocol._ +import sjsonnew.shaded.scalajson.ast.unsafe.{ JObject, JValue } import sjsonnew.support.scalajson.unsafe.Converter import scala.annotation.tailrec -import scala.collection.mutable.ListBuffer import scala.collection.mutable -import scala.util.Properties +import scala.concurrent.duration._ import scala.util.control.NonFatal -import scala.util.{ Failure, Success } +import scala.util.{ Failure, Properties, Success } +import Serialization.{ + attach, + systemIn, + systemOut, + terminalCapabilities, + terminalCapabilitiesResponse, + terminalPropertiesQuery, + terminalPropertiesResponse +} import NetworkClient.Arguments trait ConsoleInterface { @@ -68,9 +81,14 @@ class NetworkClient( private val status = new AtomicReference("Ready") private val lock: AnyRef = new AnyRef {} private val running = new AtomicBoolean(true) + private val pendingResults = new ConcurrentHashMap[String, (LinkedBlockingQueue[Integer], Long)] + private val pendingCompletions = new ConcurrentHashMap[String, CompletionResponse => Unit] + private val attached = new AtomicBoolean(false) + private val attachUUID = new AtomicReference[String](null) private val connectionHolder = new AtomicReference[ServerConnection] + private val batchMode = new AtomicBoolean(false) + private val interactiveThread = new AtomicReference[Thread](null) private def mkSocket(file: File): (Socket, Option[String]) = ClientSocket.socket(file, useJNI) - private val pendingExecIds = ListBuffer.empty[String] private def portfile = arguments.baseDirectory / "project" / "target" / "active.json" @@ -81,6 +99,13 @@ class NetworkClient( } } + private[this] val stdinBytes = new LinkedBlockingQueue[Int] + private[this] val stdin: InputStream = new InputStream { + override def available(): Int = stdinBytes.size + override def read: Int = stdinBytes.take + } + private[this] val inputThread = new AtomicReference(new RawInputThread) + private[this] val exitClean = new AtomicBoolean(true) private[this] val sbtProcess = new AtomicReference[Process](null) private class ConnectionRefusedException(t: Throwable) extends Throwable(t) @@ -99,7 +124,9 @@ class NetworkClient( override def onRequest(msg: JsonRpcRequestMessage): Unit = self.onRequest(msg) override def onResponse(msg: JsonRpcResponseMessage): Unit = self.onResponse(msg) override def onShutdown(): Unit = { + if (exitClean.get != false) exitClean.set(!running.get) running.set(false) + Option(interactiveThread.get).foreach(_.interrupt()) } } // initiate handshake @@ -153,9 +180,9 @@ class NetworkClient( val byte = stderr.read errorStream.write(byte) } - while (System.in.available > 0) { - val byte = System.in.read - stdin.write(byte) + while (!stdinBytes.isEmpty) { + stdin.write(stdinBytes.take) + stdin.flush() } false } catch { @@ -197,15 +224,55 @@ class NetworkClient( printResponse() } - def onResponse(msg: JsonRpcResponseMessage): Unit = { - msg.id match { - case execId if pendingExecIds contains execId => - onReturningReponse(msg) - lock.synchronized { - pendingExecIds -= execId + private def getExitCode(jvalue: Option[JValue]): Integer = jvalue match { + case Some(o: JObject) => + o.value + .collectFirst { + case v if v.field == "exitCode" => + Converter.fromJson[Integer](v.value).getOrElse(Integer.valueOf(1)) + } + .getOrElse(1) + case _ => 1 + } + def onResponse(msg: JsonRpcResponseMessage): Unit = { + pendingResults.remove(msg.id) match { + case null => + case (q, startTime) => + val now = System.currentTimeMillis + val message = timing(startTime, now) + val exitCode = getExitCode(msg.result) + if (batchMode.get || !attached.get) { + if (exitCode == 0) console.success(message) + else if (!attached.get) console.appendLog(Level.Error, message) + } + q.offer(exitCode) + } + msg.id match { + case execId => + if (attachUUID.get == msg.id) { + attachUUID.set(null) + attached.set(true) + Option(inputThread.get).foreach(_.drain()) + } + pendingCompletions.remove(execId) match { + case null => + case completions => + completions(msg.result match { + case Some(o: JObject) => + o.value + .foldLeft(CompletionResponse(Vector.empty[String])) { + case (resp, i) => + if (i.field == "items") + resp.withItems( + Converter + .fromJson[Vector[String]](i.value) + .getOrElse(Vector.empty[String]) + ) + else resp + } + case _ => CompletionResponse(Vector.empty[String]) + }) } - () - case _ => } } @@ -213,17 +280,33 @@ class NetworkClient( def splitToMessage: Vector[(Level.Value, String)] = (msg.method, msg.params) match { case ("build/logMessage", Some(json)) => - import sbt.internal.langserver.codec.JsonProtocol._ - Converter.fromJson[LogMessageParams](json) match { - case Success(params) => splitLogMessage(params) - case Failure(e) => Vector() + if (!attached.get) { + import sbt.internal.langserver.codec.JsonProtocol._ + Converter.fromJson[LogMessageParams](json) match { + case Success(params) => splitLogMessage(params) + case Failure(_) => Vector() + } + } else Vector() + case (`systemOut`, Some(json)) => + Converter.fromJson[Seq[Byte]](json) match { + case Success(params) => + if (params.nonEmpty) { + if (attached.get) { + printStream.write(params.toArray) + printStream.flush() + } + } + case Failure(_) => } + Vector.empty case ("textDocument/publishDiagnostics", Some(json)) => import sbt.internal.langserver.codec.JsonProtocol._ Converter.fromJson[PublishDiagnosticsParams](json) match { - case Success(params) => splitDiagnostics(params) - case Failure(e) => Vector() + case Success(params) => splitDiagnostics(params); Vector() + case Failure(_) => Vector() } + case ("shutdown", Some(_)) => Vector.empty + case (msg, _) if msg.startsWith("build/") => Vector.empty case _ => Vector( ( @@ -269,73 +352,191 @@ class NetworkClient( } def onRequest(msg: JsonRpcRequestMessage): Unit = { - // ignore - } - - def start(): Unit = { - console.appendLog(Level.Info, "entering *experimental* thin client - BEEP WHIRR") - val _ = connection - val userCommands = arguments.commandArguments.toList - if (userCommands.isEmpty) shell() - else batchExecute(userCommands) - } - - def batchExecute(userCommands: List[String]): Unit = { - userCommands foreach { cmd => - println("> " + cmd) - val execId = - if (cmd == "shutdown") sendExecCommand("exit") - else sendExecCommand(cmd) - while (pendingExecIds contains execId) { - Thread.sleep(100) - } + (msg.method, msg.params) match { + case (`terminalCapabilities`, Some(json)) => + import sbt.protocol.codec.JsonProtocol._ + Converter.fromJson[TerminalCapabilitiesQuery](json) match { + case Success(terminalCapabilitiesQuery) => + val response = TerminalCapabilitiesResponse( + terminalCapabilitiesQuery.boolean.map(Terminal.console.getBooleanCapability), + terminalCapabilitiesQuery.numeric.map(Terminal.console.getNumericCapability), + terminalCapabilitiesQuery.string + .map(s => Option(Terminal.console.getStringCapability(s)).getOrElse("null")), + ) + sendCommandResponse( + terminalCapabilitiesResponse, + response, + msg.id, + ) + case Failure(_) => + } + case (`terminalPropertiesQuery`, _) => + val response = TerminalPropertiesResponse.apply( + width = Terminal.console.getWidth, + height = Terminal.console.getHeight, + isAnsiSupported = Terminal.console.isAnsiSupported, + isColorEnabled = Terminal.console.isColorEnabled, + isSupershellEnabled = Terminal.console.isSupershellEnabled, + isEchoEnabled = Terminal.console.isEchoEnabled + ) + sendCommandResponse(terminalPropertiesResponse, response, msg.id) + case _ => } } - def shell(): Unit = { - val reader = LineReader.simple(None, LineReader.HandleCONT, injectThreadSleep = true) - while (running.get) { - reader.readLine("> ", None) match { - case Some("shutdown") => - // `sbt -client shutdown` shuts down the server - sendExecCommand("exit") - Thread.sleep(100) - running.set(false) - case Some("exit") => - running.set(false) - case Some(s) if s.trim.nonEmpty => - val execId = sendExecCommand(s) - while (pendingExecIds contains execId) { - Thread.sleep(100) - } - case _ => // - } + def connect(log: Boolean): Unit = { + if (log) console.appendLog(Level.Info, "entering *experimental* thin client - BEEP WHIRR") + init(retry = true) + () + } + + def run(): Int = { + interactiveThread.set(Thread.currentThread) + val cleaned = arguments.commandArguments + val userCommands = cleaned.takeWhile(_ != "exit") + val interactive = cleaned.isEmpty + val exit = cleaned.nonEmpty && userCommands.isEmpty + attachUUID.set(sendJson(attach, s"""{"interactive": $interactive}""")) + if (interactive) { + try this.synchronized(this.wait) + catch { case _: InterruptedException => } + if (exitClean.get) 0 else 1 + } else if (exit) { + 0 + } else { + batchMode.set(true) + batchExecute(userCommands.toList) } } - def sendExecCommand(commandLine: String): String = { + def batchExecute(userCommands: List[String]): Int = { + val cmd = userCommands mkString " " + printStream.println("> " + cmd) + sendAndWait(cmd, None) + } + + private def sendAndWait(cmd: String, limit: Option[Deadline]): Int = { + val queue = sendExecCommand(cmd) + var result: Integer = null + while (running.get && result == null && limit.fold(true)(!_.isOverdue())) { + try { + result = limit match { + case Some(l) => queue.poll((l - Deadline.now).toMillis, TimeUnit.MILLISECONDS) + case _ => queue.take + } + } catch { + case _: InterruptedException if cmd == "shutdown" => result = 0 + case _: InterruptedException => result = if (exitClean.get) 0 else 1 + } + } + if (result == null) 1 else result + } + + def sendExecCommand(commandLine: String): LinkedBlockingQueue[Integer] = { val execId = UUID.randomUUID.toString + val queue = new LinkedBlockingQueue[Integer] sendCommand(ExecCommand(commandLine, execId)) - lock.synchronized { - pendingExecIds += execId - } - execId + pendingResults.put(execId, (queue, System.currentTimeMillis)) + queue } def sendCommand(command: CommandMessage): Unit = { try { val s = Serialization.serializeCommandAsJsonMessage(command) connection.sendString(s) + lock.synchronized { + status.set("Processing") + } } catch { - case _: IOException => - // log.debug(e.getMessage) - // toDel += client - } - lock.synchronized { - status.set("Processing") + case e: IOException => + errorStream.println(s"Caught exception writing command to server: $e") + running.set(false) } } - override def close(): Unit = {} + def sendCommandResponse(method: String, command: EventMessage, id: String): Unit = { + try { + val s = new String(Serialization.serializeEventMessage(command)) + val msg = s"""{ "jsonrpc": "2.0", "id": "$id", "result": $s }""" + connection.sendString(msg) + } catch { + case e: IOException => + errorStream.println(s"Caught exception writing command to server: $e") + running.set(false) + } + } + def sendJson(method: String, params: String): String = { + val uuid = UUID.randomUUID.toString + val msg = s"""{ "jsonrpc": "2.0", "id": "$uuid", "method": "$method", "params": $params }""" + connection.sendString(msg) + uuid + } + + def sendNotification(method: String, params: String): Unit = { + connection.sendString(s"""{ "jsonrpc": "2.0", "method": "$method", "params": $params }""") + } + + override def close(): Unit = + try { + running.set(false) + stdinBytes.offer(-1) + val mainThread = interactiveThread.getAndSet(null) + if (mainThread != null && mainThread != Thread.currentThread) mainThread.interrupt + connection.shutdown() + Option(inputThread.get).foreach(_.interrupt()) + } catch { + case t: Throwable => t.printStackTrace(); throw t + } + + private[this] class RawInputThread extends Thread("sbt-read-input-thread") with AutoCloseable { + setDaemon(true) + start() + val stopped = new AtomicBoolean(false) + val lock = new Object + override final def run(): Unit = { + @tailrec def read(): Unit = { + inputStream.read match { + case -1 => + case b => + lock.synchronized(stdinBytes.offer(b)) + if (attached.get()) drain() + if (!stopped.get()) read() + } + } + try Terminal.console.withRawSystemIn(read()) + catch { case _: InterruptedException | _: ClosedChannelException => stopped.set(true) } + } + + def drain(): Unit = lock.synchronized { + while (!stdinBytes.isEmpty) { + val byte = stdinBytes.poll() + sendNotification(systemIn, byte.toString) + } + } + + override def close(): Unit = { + RawInputThread.this.interrupt() + } + } + + // copied from Aggregation + private def timing(startTime: Long, endTime: Long): String = { + import java.text.DateFormat + val format = DateFormat.getDateTimeInstance(DateFormat.MEDIUM, DateFormat.MEDIUM) + val nowString = format.format(new java.util.Date(endTime)) + val total = math.max(0, (endTime - startTime + 500) / 1000) + val totalString = s"$total s" + + (if (total <= 60) "" + else { + val maybeHours = total / 3600 match { + case 0 => "" + case h => f"$h%02d:" + } + val mins = f"${total % 3600 / 60}%02d" + val secs = f"${total % 60}%02d" + s" ($maybeHours$mins:$secs)" + }) + s"Total time: $totalString, completed $nowString" + } } object NetworkClient { private def consoleAppenderInterface(printStream: PrintStream): ConsoleInterface = { @@ -400,8 +601,8 @@ object NetworkClient { try { val client = new NetworkClient(configuration, parseArgs(arguments.toArray)) try { - client.start() - 0 + client.connect(log = true) + client.run() } catch { case _: Throwable => 1 } finally client.close() } catch { case NonFatal(e) =>