From 734a1e76417ad594514fc0380a19683fee8afe17 Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Wed, 24 Jun 2020 16:57:08 -0700 Subject: [PATCH] Add virtual terminal support for network clients This commit adds support for remote clients to connect to the sbt server and attach themselves as a virtual terminal. In order to make this work, each connection must send a json rpc request to attach to the server. When this is received, the server will periodically query the remote client to get the terminal properties and capabilities that allow the remote client to act as a jline terminal proxy. There is also support for json messages with ids sbt/systemIn and sbt/systemOut that allow io to be relayed from the remote terminal to the sbt server and back. Certain commands such as `exit` should be evaluated immediately. To make this work, we add the concept of a MaintenanceTask. The CommandExchange has a background thread that reads MaintenanceTasks and evaluates them on demand. This allows maintenance tasks to be evaluated even when sbt is evaluating an exec. If it weren't done this way, when the user typed exit while a different remote connection was running a command, they wouldn't be able to exit until the command completed. The ServerIntents in ServerHandler did not handle JsonRpcResponseMessage because prior to this commit, sbt clients were primarily making requests to the server. But now the server sends requests to the client for the terminal properties and terminal capabilities so it was necessary to add an onResponse handler to ServerIntent. I had to move the network channel publishBytes method to run on a background thread because there were scenarios in which the client socket would get blocked because the server was trying to write on the same thread that the read the bytes from the client. To make the console command work, it is necessary to hijack the classloader for JLine. In MetaBuildLoader, we put a custom forked JLine that has a setter for the TerminalFactory singleton. This allows us to change the terminal that is used by JLine in ConsoleReader. Without this hack, the scala console would not work for remote clients. --- .../scala/sbt/internal/CommandChannel.scala | 67 +++-- .../sbt/internal/server/ServerHandler.scala | 17 +- .../java/sbt/internal/MetaBuildLoader.java | 38 ++- main/src/main/scala/sbt/Defaults.scala | 13 +- .../scala/sbt/internal/CommandExchange.scala | 49 +++- .../internal/server/BuildServerProtocol.scala | 5 +- .../server/LanguageServerProtocol.scala | 6 +- .../sbt/internal/server/NetworkChannel.scala | 231 +++++++++++++++++- .../sbt/internal/server/VirtualTerminal.scala | 122 +++++++++ .../sbt/protocol/Attach.scala | 32 +++ .../protocol/TerminalCapabilitiesQuery.scala | 50 ++++ .../TerminalCapabilitiesResponse.scala | 50 ++++ .../protocol/TerminalPropertiesResponse.scala | 52 ++++ .../sbt/protocol/codec/AttachFormats.scala | 27 ++ .../codec/CommandMessageFormats.scala | 4 +- .../protocol/codec/EventMessageFormats.scala | 4 +- .../sbt/protocol/codec/JsonProtocol.scala | 4 + .../TerminalCapabilitiesQueryFormats.scala | 31 +++ .../TerminalCapabilitiesResponseFormats.scala | 31 +++ .../TerminalPropertiesResponseFormats.scala | 37 +++ protocol/src/main/contraband/server.contra | 25 ++ .../scala/sbt/protocol/Serialization.scala | 22 ++ .../src/server-test/response/build.sbt | 5 +- 23 files changed, 863 insertions(+), 59 deletions(-) create mode 100644 main/src/main/scala/sbt/internal/server/VirtualTerminal.scala create mode 100644 protocol/src/main/contraband-scala/sbt/protocol/Attach.scala create mode 100644 protocol/src/main/contraband-scala/sbt/protocol/TerminalCapabilitiesQuery.scala create mode 100644 protocol/src/main/contraband-scala/sbt/protocol/TerminalCapabilitiesResponse.scala create mode 100644 protocol/src/main/contraband-scala/sbt/protocol/TerminalPropertiesResponse.scala create mode 100644 protocol/src/main/contraband-scala/sbt/protocol/codec/AttachFormats.scala create mode 100644 protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalCapabilitiesQueryFormats.scala create mode 100644 protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalCapabilitiesResponseFormats.scala create mode 100644 protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalPropertiesResponseFormats.scala diff --git a/main-command/src/main/scala/sbt/internal/CommandChannel.scala b/main-command/src/main/scala/sbt/internal/CommandChannel.scala index d8ccca5b7..9096c8a20 100644 --- a/main-command/src/main/scala/sbt/internal/CommandChannel.scala +++ b/main-command/src/main/scala/sbt/internal/CommandChannel.scala @@ -12,6 +12,7 @@ import java.util.concurrent.ConcurrentLinkedQueue import sbt.internal.util.Terminal import sbt.protocol.EventMessage +import scala.collection.JavaConverters._ /** * A command channel represents an IO device such as network socket or human @@ -20,31 +21,59 @@ import sbt.protocol.EventMessage */ abstract class CommandChannel { private val commandQueue: ConcurrentLinkedQueue[Exec] = new ConcurrentLinkedQueue() - private val registered: java.util.Set[java.util.Queue[CommandChannel]] = new java.util.HashSet - private[sbt] final def register(queue: java.util.Queue[CommandChannel]): Unit = { - registered.add(queue) - () + private val registered: java.util.Set[java.util.Queue[Exec]] = new java.util.HashSet + private val maintenance: java.util.Set[java.util.Queue[MaintenanceTask]] = new java.util.HashSet + private[sbt] final def register( + queue: java.util.Queue[Exec], + maintenanceQueue: java.util.Queue[MaintenanceTask] + ): Unit = + registered.synchronized { + registered.add(queue) + if (!commandQueue.isEmpty) { + queue.addAll(commandQueue) + commandQueue.clear() + } + maintenance.add(maintenanceQueue) + () + } + private[sbt] final def unregister( + queue: java.util.Queue[CommandChannel], + maintenanceQueue: java.util.Queue[MaintenanceTask] + ): Unit = + registered.synchronized { + registered.remove(queue) + maintenance.remove(maintenanceQueue) + () + } + private[sbt] final def initiateMaintenance(task: String): Unit = { + maintenance.forEach(q => q.synchronized { q.add(new MaintenanceTask(this, task)); () }) } - private[sbt] final def unregister(queue: java.util.Queue[CommandChannel]): Unit = { - registered.remove(queue) - () - } - def append(exec: Exec): Boolean = { - registered.forEach( - q => - q.synchronized { - if (!q.contains(this)) { - q.add(this); () - } - } - ) - commandQueue.add(exec) + final def append(exec: Exec): Boolean = { + registered.synchronized { + exec.commandLine.nonEmpty && { + if (registered.isEmpty) commandQueue.add(exec) + else registered.asScala.forall(_.add(exec)) + } + } } def poll: Option[Exec] = Option(commandQueue.poll) def publishBytes(bytes: Array[Byte]): Unit def shutdown(): Unit def name: String + private[sbt] def onCommand: String => Boolean = { + case cmd => + if (cmd.nonEmpty) append(Exec(cmd, Some(Exec.newExecId), Some(CommandSource(name)))) + else false + } + private[sbt] def onMaintenance: String => Boolean = { s: String => + maintenance.synchronized(maintenance.forEach { q => + q.add(new MaintenanceTask(this, s)) + () + }) + true + } + private[sbt] def terminal: Terminal } @@ -62,3 +91,5 @@ case class ConsolePromptEvent(state: State) extends EventMessage */ @deprecated("No longer used", "1.4.0") case class ConsoleUnpromptEvent(lastSource: Option[CommandSource]) extends EventMessage + +private[internal] class MaintenanceTask(val channel: CommandChannel, val task: String) diff --git a/main-command/src/main/scala/sbt/internal/server/ServerHandler.scala b/main-command/src/main/scala/sbt/internal/server/ServerHandler.scala index 78e7d849f..4cfbd9e3f 100644 --- a/main-command/src/main/scala/sbt/internal/server/ServerHandler.scala +++ b/main-command/src/main/scala/sbt/internal/server/ServerHandler.scala @@ -29,14 +29,18 @@ object ServerHandler { lazy val fallback: ServerHandler = ServerHandler({ handler => ServerIntent( - { case x => handler.log.debug(s"Unhandled notification received: ${x.method}: $x") }, - { case x => handler.log.debug(s"Unhandled request received: ${x.method}: $x") } + onRequest = { case x => handler.log.debug(s"Unhandled request received: ${x.method}: $x") }, + onResponse = { case x => handler.log.debug(s"Unhandled responce received") }, + onNotification = { + case x => handler.log.debug(s"Unhandled notification received: ${x.method}: $x") + }, ) }) } final class ServerIntent( val onRequest: PartialFunction[JsonRpcRequestMessage, Unit], + val onResponse: PartialFunction[JsonRpcResponseMessage, Unit], val onNotification: PartialFunction[JsonRpcNotificationMessage, Unit] ) { override def toString: String = s"ServerIntent(...)" @@ -45,15 +49,18 @@ final class ServerIntent( object ServerIntent { def apply( onRequest: PartialFunction[JsonRpcRequestMessage, Unit], + onResponse: PartialFunction[JsonRpcResponseMessage, Unit], onNotification: PartialFunction[JsonRpcNotificationMessage, Unit] ): ServerIntent = - new ServerIntent(onRequest, onNotification) + new ServerIntent(onRequest, onResponse, onNotification) def request(onRequest: PartialFunction[JsonRpcRequestMessage, Unit]): ServerIntent = - new ServerIntent(onRequest, PartialFunction.empty) + new ServerIntent(onRequest, PartialFunction.empty, PartialFunction.empty) + def response(onResponse: PartialFunction[JsonRpcResponseMessage, Unit]): ServerIntent = + new ServerIntent(PartialFunction.empty, onResponse, PartialFunction.empty) def notify(onNotification: PartialFunction[JsonRpcNotificationMessage, Unit]): ServerIntent = - new ServerIntent(PartialFunction.empty, onNotification) + new ServerIntent(PartialFunction.empty, PartialFunction.empty, onNotification) } /** diff --git a/main/src/main/java/sbt/internal/MetaBuildLoader.java b/main/src/main/java/sbt/internal/MetaBuildLoader.java index 253165605..ca67b1286 100644 --- a/main/src/main/java/sbt/internal/MetaBuildLoader.java +++ b/main/src/main/java/sbt/internal/MetaBuildLoader.java @@ -60,17 +60,21 @@ public final class MetaBuildLoader extends URLClassLoader { * library. */ public static MetaBuildLoader makeLoader(final AppProvider appProvider) throws IOException { - final Pattern pattern = Pattern.compile("test-interface-[0-9.]+\\.jar"); + final Pattern pattern = + Pattern.compile("(test-interface-[0-9.]+|jline-[0-9.]+-sbt-.*|jansi-[0-9.]+)\\.jar"); final File[] cp = appProvider.mainClasspath(); - final URL[] interfaceURL = new URL[1]; + final URL[] interfaceURLs = new URL[3]; final File[] extra = appProvider.id().classpathExtra() == null ? new File[0] : appProvider.id().classpathExtra(); final Set bottomClasspath = new LinkedHashSet<>(); { + int interfaceIndex = 0; for (final File file : cp) { - if (pattern.matcher(file.getName()).find()) { - interfaceURL[0] = file.toURI().toURL(); + final String name = file.getName(); + if (pattern.matcher(name).find()) { + interfaceURLs[interfaceIndex] = file.toURI().toURL(); + interfaceIndex += 1; } else { bottomClasspath.add(file); } @@ -88,11 +92,29 @@ public final class MetaBuildLoader extends URLClassLoader { } } final ScalaProvider scalaProvider = appProvider.scalaProvider(); - final ClassLoader topLoader = scalaProvider.launcher().topLoader(); - final TestInterfaceLoader interfaceLoader = new TestInterfaceLoader(interfaceURL, topLoader); + ClassLoader topLoader = scalaProvider.launcher().topLoader(); + boolean foundSBTLoader = false; + while (!foundSBTLoader && topLoader != null) { + if (topLoader instanceof URLClassLoader) { + for (final URL u : ((URLClassLoader) topLoader).getURLs()) { + if (u.toString().contains("test-interface")) { + topLoader = topLoader.getParent(); + foundSBTLoader = true; + } + } + } + if (!foundSBTLoader) topLoader = topLoader.getParent(); + } + if (topLoader == null) topLoader = scalaProvider.launcher().topLoader(); + + final TestInterfaceLoader interfaceLoader = new TestInterfaceLoader(interfaceURLs, topLoader); final File[] siJars = scalaProvider.jars(); final URL[] lib = new URL[1]; - final URL[] scalaRest = new URL[Math.max(0, siJars.length - 1)]; + int scalaRestCount = siJars.length - 1; + for (final File file : siJars) { + if (pattern.matcher(file.getName()).find()) scalaRestCount -= 1; + } + final URL[] scalaRest = new URL[Math.max(0, scalaRestCount)]; { int i = 0; @@ -101,7 +123,7 @@ public final class MetaBuildLoader extends URLClassLoader { final File file = siJars[i]; if (file.getName().equals("scala-library.jar")) { lib[0] = file.toURI().toURL(); - } else { + } else if (!pattern.matcher(file.getName()).find()) { scalaRest[j] = file.toURI().toURL(); j += 1; } diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index 5755451a5..b3d6f3e92 100755 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -48,7 +48,8 @@ import sbt.internal.server.{ BuildServerReporter, Definition, LanguageServerProtocol, - ServerHandler + ServerHandler, + VirtualTerminal, } import sbt.internal.testing.TestLogger import sbt.internal.util.Attributed.data @@ -208,7 +209,8 @@ object Defaults extends BuildCommon { Seq( LanguageServerProtocol.handler(fileConverter.value), BuildServerProtocol - .handler(sbtVersion.value, semanticdbEnabled.value, semanticdbVersion.value) + .handler(sbtVersion.value, semanticdbEnabled.value, semanticdbVersion.value), + VirtualTerminal.handler, ) ++ serverHandlers.value :+ ServerHandler.fallback }, uncachedStamper := Stamps.uncachedStamps(fileConverter.value), @@ -342,15 +344,12 @@ object Defaults extends BuildCommon { () => Clean.deleteContents(tempDirectory, _ => false) }, turbo :== SysProp.turbo, - useSuperShell := { if (insideCI.value) false else SysProp.supershell }, + useSuperShell := { if (insideCI.value) false else Terminal.console.isSupershellEnabled }, progressReports := { val rs = EvaluateTask.taskTimingProgress.toVector ++ EvaluateTask.taskTraceEvent.toVector rs map { Keys.TaskProgress(_) } }, - progressState := { - if ((ThisBuild / useSuperShell).value) Some(new ProgressState(SysProp.supershellBlankZone)) - else None - }, + progressState := Some(new ProgressState(SysProp.supershellBlankZone)), Previous.cache := new Previous( Def.streamsManagerKey.value, Previous.references.value.getReferences diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index af2cb478a..6b3178062 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -46,8 +46,10 @@ private[sbt] final class CommandExchange { private val channelBuffer: ListBuffer[CommandChannel] = new ListBuffer() private val channelBufferLock = new AnyRef {} private val commandChannelQueue = new LinkedBlockingQueue[CommandChannel] + private val maintenanceChannelQueue = new LinkedBlockingQueue[MaintenanceTask] private val nextChannelId: AtomicInteger = new AtomicInteger(0) private[this] val activePrompt = new AtomicBoolean(false) + private[this] val lastState = new AtomicReference[State] private[this] val currentExecRef = new AtomicReference[Exec] def channels: List[CommandChannel] = channelBuffer.toList @@ -60,9 +62,10 @@ private[sbt] final class CommandExchange { def subscribe(c: CommandChannel): Unit = channelBufferLock.synchronized { channelBuffer.append(c) - c.register(commandChannelQueue) + c.register(commandQueue, maintenanceChannelQueue) } + private[sbt] def withState[T](f: State => T): T = f(lastState.get) def blockUntilNextExec: Exec = blockUntilNextExec(Duration.Inf, NullLogger) // periodically move all messages from all the channels private[sbt] def blockUntilNextExec(interval: Duration, logger: Logger): Exec = { @@ -110,6 +113,7 @@ private[sbt] final class CommandExchange { if (autoStartServerSysProp && autoStartServerAttr) runServer(s) else s } + private[sbt] def setState(s: State): Unit = lastState.set(s) private def newNetworkName: String = s"network-${nextChannelId.incrementAndGet()}" @@ -191,6 +195,7 @@ private[sbt] final class CommandExchange { } def shutdown(): Unit = { + maintenanceThread.close() channels foreach (_.shutdown()) // interrupt and kill the thread server.foreach(_.shutdown()) @@ -311,4 +316,46 @@ private[sbt] final class CommandExchange { } channels.foreach(c => ProgressState.updateProgressState(newPE, c.terminal)) } + + private[sbt] def shutdown(name: String): Unit = { + commandQueue.clear() + val exit = + Exec("shutdown", Some(Exec.newExecId), Some(CommandSource(name))) + commandQueue.add(exit) + () + } + + private[this] class MaintenanceThread + extends Thread("sbt-command-exchange-maintenance") + with AutoCloseable { + setDaemon(true) + start() + private[this] val isStopped = new AtomicBoolean(false) + override def run(): Unit = { + def exit(mt: MaintenanceTask): Unit = { + mt.channel.shutdown() + if (mt.channel.name.contains("console")) shutdown(mt.channel.name) + } + @tailrec def impl(): Unit = { + maintenanceChannelQueue.take match { + case null => + case mt: MaintenanceTask => + mt.task match { + case "exit" => exit(mt) + case "shutdown" => shutdown(mt.channel.name) + case _ => + } + } + if (!isStopped.get) impl() + } + try impl() + catch { case _: InterruptedException => } + } + override def close(): Unit = if (isStopped.compareAndSet(false, true)) { + interrupt() + } + } + private[sbt] def channelForName(channelName: String): Option[CommandChannel] = + channels.find(_.name == channelName) + private[this] val maintenanceThread = new MaintenanceThread } diff --git a/main/src/main/scala/sbt/internal/server/BuildServerProtocol.scala b/main/src/main/scala/sbt/internal/server/BuildServerProtocol.scala index 5af24d4a6..d4d8af20a 100644 --- a/main/src/main/scala/sbt/internal/server/BuildServerProtocol.scala +++ b/main/src/main/scala/sbt/internal/server/BuildServerProtocol.scala @@ -132,7 +132,7 @@ object BuildServerProtocol { semanticdbVersion: String ): ServerHandler = ServerHandler { callback => ServerIntent( - { + onRequest = { case r: JsonRpcRequestMessage if r.method == "build/initialize" => val params = Converter.fromJson[InitializeBuildParams](json(r)).get checkMetalsCompatibility(semanticdbEnabled, semanticdbVersion, params, callback.log) @@ -180,7 +180,8 @@ object BuildServerProtocol { val command = Keys.bspBuildTargetScalacOptions.key val _ = callback.appendExec(s"$command $targets", Some(r.id)) }, - PartialFunction.empty + onResponse = PartialFunction.empty, + onNotification = PartialFunction.empty, ) } diff --git a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala index 5ab0e491b..57c1c9c13 100644 --- a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala +++ b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala @@ -45,7 +45,7 @@ private[sbt] object LanguageServerProtocol { def handler(converter: FileConverter): ServerHandler = ServerHandler { callback => import callback._ ServerIntent( - { + onRequest = { case r: JsonRpcRequestMessage if r.method == "initialize" => val param = Converter.fromJson[InitializeParams](json(r)).get val optionJson = param.initializationOptions.getOrElse( @@ -86,7 +86,9 @@ private[sbt] object LanguageServerProtocol { val param = Converter.fromJson[CP](json(r)).get onCompletionRequest(Option(r.id), param) - }, { + }, + onResponse = PartialFunction.empty, + onNotification = { case n: JsonRpcNotificationMessage if n.method == "textDocument/didSave" => val _ = appendExec(";Test/compile; collectAnalyses", None) } diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index c82476d92..edfc488cc 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -9,9 +9,11 @@ package sbt package internal package server -import java.io.IOException +import java.io.{ IOException, InputStream, OutputStream } import java.net.{ Socket, SocketTimeoutException } -import java.util.concurrent.atomic.AtomicBoolean +import java.nio.channels.ClosedChannelException +import java.util.concurrent.{ ConcurrentHashMap, LinkedBlockingQueue } +import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } import sbt.internal.langserver.{ CancelRequestParams, ErrorCodes, LogMessageParams, MessageType } import sbt.internal.protocol.{ @@ -20,16 +22,20 @@ import sbt.internal.protocol.{ JsonRpcResponseError, JsonRpcResponseMessage } -import sbt.internal.util.{ ReadJsonFromInputStream, Terminal } +import sbt.internal.util.{ ReadJsonFromInputStream, Prompt, Terminal, Util } +import sbt.internal.util.Terminal.TerminalImpl import sbt.internal.util.complete.Parser import sbt.protocol._ import sbt.util.Logger import sjsonnew._ import sjsonnew.support.scalajson.unsafe.{ CompactPrinter, Converter } +import scala.annotation.tailrec import scala.collection.mutable +import scala.concurrent.duration._ import scala.util.Try import scala.util.control.NonFatal +import Serialization.attach final class NetworkChannel( val name: String, @@ -47,7 +53,29 @@ final class NetworkChannel( private var initialized = false private val pendingRequests: mutable.Map[String, JsonRpcRequestMessage] = mutable.Map() - override private[sbt] def terminal: Terminal = Terminal.NullTerminal + private[this] val inputBuffer = new LinkedBlockingQueue[Byte]() + private[this] val pendingWrites = new LinkedBlockingQueue[(Array[Byte], Boolean)]() + private[this] val attached = new AtomicBoolean(false) + private[this] val alive = new AtomicBoolean(true) + private[sbt] def isInteractive = interactive.get + private[this] val interactive = new AtomicBoolean(false) + private[sbt] def setInteractive(id: String, value: Boolean) = { + terminalHolder.getAndSet(new NetworkTerminal) match { + case null => + case t => t.close() + } + interactive.set(value) + if (!isInteractive) terminal.setPrompt(Prompt.Batch) + attached.set(true) + pendingRequests.remove(id) + import sjsonnew.BasicJsonProtocol._ + jsonRpcRespond("", id) + initiateMaintenance(attach) + } + private[sbt] def write(byte: Byte) = inputBuffer.add(byte) + + private[this] val terminalHolder = new AtomicReference(Terminal.NullTerminal) + override private[sbt] def terminal: Terminal = terminalHolder.get private lazy val callback: ServerCallback = new ServerCallback { def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit = @@ -121,6 +149,10 @@ final class NetworkChannel( intents.foldLeft(PartialFunction.empty[JsonRpcRequestMessage, Unit]) { case (f, i) => f orElse i.onRequest } + lazy val onResponseMessage: PartialFunction[JsonRpcResponseMessage, Unit] = + intents.foldLeft(PartialFunction.empty[JsonRpcResponseMessage, Unit]) { + case (f, i) => f orElse i.onResponse + } lazy val onNotification: PartialFunction[JsonRpcNotificationMessage, Unit] = intents.foldLeft(PartialFunction.empty[JsonRpcNotificationMessage, Unit]) { @@ -138,6 +170,8 @@ final class NetworkChannel( log.debug(s"sending error: $code: $message") respondError(code, message, Some(req.id)) } + case Right(res: JsonRpcResponseMessage) => + onResponseMessage(res) case Right(ntf: JsonRpcNotificationMessage) => try { onNotification(ntf) @@ -222,13 +256,43 @@ final class NetworkChannel( def publishBytes(event: Array[Byte]): Unit = publishBytes(event, false) - def publishBytes(event: Array[Byte], delimit: Boolean): Unit = { - out.write(event) - if (delimit) { - out.write(delimiter.toInt) + /* + * Do writes on a background thread because otherwise the client socket can get blocked. + */ + private[this] val writeThread = new Thread(() => { + @tailrec def impl(): Unit = { + val (event, delimit) = + try pendingWrites.take + catch { + case _: InterruptedException => + alive.set(false) + (Array.empty[Byte], false) + } + if (alive.get) { + try { + out.write(event) + if (delimit) { + out.write(delimiter.toInt) + } + out.flush() + } catch { + case _: IOException => + alive.set(false) + shutdown() + case _: InterruptedException => + alive.set(false) + } + impl() + } } - out.flush() - } + impl() + }, s"sbt-$name-write-thread") + writeThread.setDaemon(true) + writeThread.start() + + def publishBytes(event: Array[Byte], delimit: Boolean): Unit = + try pendingWrites.put(event -> delimit) + catch { case _: InterruptedException => } def onCommand(command: CommandMessage): Unit = command match { case x: InitCommand => onInitCommand(x) @@ -418,6 +482,15 @@ final class NetworkChannel( publishBytes(bytes) } + /** Notify to Language Server's client. */ + private[sbt] def jsonRpcRequest[A: JsonFormat](id: String, method: String, params: A): Unit = { + val m = + JsonRpcRequestMessage("2.0", id, method, Option(Converter.toJson[A](params).get)) + log.debug(s"jsonRpcRequest: $m") + val bytes = Serialization.serializeRequestMessage(m) + publishBytes(bytes) + } + def logMessage(level: String, message: String): Unit = { import sbt.internal.langserver.codec.JsonProtocol._ jsonRpcNotify( @@ -425,6 +498,144 @@ final class NetworkChannel( LogMessageParams(MessageType.fromLevelString(level), message) ) } + private[this] lazy val inputStream: InputStream = new InputStream { + override def read(): Int = { + try { + inputBuffer.take & 0xFF match { + case -1 => throw new ClosedChannelException() + case b => b + } + } catch { case _: IOException => -1 } + } + override def available(): Int = inputBuffer.size + } + import sjsonnew.BasicJsonProtocol._ + + import scala.collection.JavaConverters._ + private[this] lazy val outputStream: OutputStream = new OutputStream { + private[this] val buffer = new LinkedBlockingQueue[Byte]() + override def write(b: Int): Unit = buffer.put(b.toByte) + override def flush(): Unit = { + jsonRpcNotify(Serialization.systemOut, buffer.asScala) + buffer.clear() + } + override def write(b: Array[Byte]): Unit = write(b, 0, b.length) + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + var i = off + while (i < len) { + buffer.put(b(i)) + i += 1 + } + } + } + private class NetworkTerminal extends TerminalImpl(inputStream, outputStream, name) { + private[this] val pending = new AtomicBoolean(false) + private[this] val closed = new AtomicBoolean(false) + private[this] val properties = new AtomicReference[TerminalPropertiesResponse] + private[this] val lastUpdate = new AtomicReference[Deadline] + private def empty = TerminalPropertiesResponse(0, 0, false, false, false, false) + def getProperties(block: Boolean): Unit = { + if (alive.get) { + if (!pending.get && Option(lastUpdate.get).fold(true)(d => (d + 1.second).isOverdue)) { + pending.set(true) + val queue = VirtualTerminal.sendTerminalPropertiesQuery(name, jsonRpcRequest) + val update: Runnable = () => { + queue.poll(5, java.util.concurrent.TimeUnit.SECONDS) match { + case null => + case t => properties.set(t) + } + pending.synchronized { + lastUpdate.set(Deadline.now) + pending.set(false) + pending.notifyAll() + } + } + new Thread(update, s"network-terminal-$name-update") { + setDaemon(true) + }.start() + } + while (block && properties.get == null) pending.synchronized(pending.wait()) + () + } else throw new InterruptedException + } + private def withThread[R](f: => R, default: R): R = { + val t = Thread.currentThread + try { + blockedThreads.synchronized(blockedThreads.add(t)) + f + } catch { case _: InterruptedException => default } finally { + Util.ignoreResult(blockedThreads.synchronized(blockedThreads.remove(t))) + } + } + def getProperty[T](f: TerminalPropertiesResponse => T, default: T): Option[T] = { + if (closed.get || !isAttached) None + else + withThread({ + getProperties(true); + Some(f(Option(properties.get).getOrElse(empty))) + }, None) + } + private[this] def waitForPending(f: TerminalPropertiesResponse => Boolean): Boolean = { + if (closed.get || !isAttached) false + withThread( + { + if (pending.get) pending.synchronized(pending.wait()) + Option(properties.get).map(f).getOrElse(false) + }, + false + ) + } + private[this] val blockedThreads = ConcurrentHashMap.newKeySet[Thread] + override def getWidth: Int = getProperty(_.width, 0).getOrElse(0) + override def getHeight: Int = getProperty(_.height, 0).getOrElse(0) + override def isAnsiSupported: Boolean = getProperty(_.isAnsiSupported, false).getOrElse(false) + override def isEchoEnabled: Boolean = waitForPending(_.isEchoEnabled) + override def isSuccessEnabled: Boolean = prompt != Prompt.Batch + override lazy val isColorEnabled: Boolean = waitForPending(_.isColorEnabled) + override lazy val isSupershellEnabled: Boolean = waitForPending(_.isSupershellEnabled) + getProperties(false) + private def getCapability[T]( + query: TerminalCapabilitiesQuery, + result: TerminalCapabilitiesResponse => T + ): Option[T] = { + if (closed.get) None + else { + import sbt.protocol.codec.JsonProtocol._ + val queue = VirtualTerminal.sendTerminalCapabilitiesQuery(name, jsonRpcRequest, query) + Some(result(queue.take)) + } + } + override def getBooleanCapability(capability: String): Boolean = + getCapability( + TerminalCapabilitiesQuery(boolean = Some(capability), numeric = None, string = None), + _.boolean.getOrElse(false) + ).getOrElse(false) + override def getNumericCapability(capability: String): Int = + getCapability( + TerminalCapabilitiesQuery(boolean = None, numeric = Some(capability), string = None), + _.numeric.getOrElse(-1) + ).getOrElse(-1) + override def getStringCapability(capability: String): String = + getCapability( + TerminalCapabilitiesQuery(boolean = None, numeric = None, string = Some(capability)), + _.string.flatMap { + case "null" => None + case s => Some(s) + }.orNull + ).getOrElse("") + + override def toString: String = s"NetworkTerminal($name)" + override def close(): Unit = if (closed.compareAndSet(false, true)) { + val threads = blockedThreads.synchronized { + val t = blockedThreads.asScala.toVector + blockedThreads.clear() + t + } + threads.foreach(_.interrupt()) + super.close() + } + } + private[sbt] def isAttached: Boolean = attached.get } object NetworkChannel { diff --git a/main/src/main/scala/sbt/internal/server/VirtualTerminal.scala b/main/src/main/scala/sbt/internal/server/VirtualTerminal.scala new file mode 100644 index 000000000..55659ef61 --- /dev/null +++ b/main/src/main/scala/sbt/internal/server/VirtualTerminal.scala @@ -0,0 +1,122 @@ +/* + * sbt + * Copyright 2011 - 2018, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt +package internal +package server + +import java.util.concurrent.{ ArrayBlockingQueue, ConcurrentHashMap } +import java.util.UUID +import sbt.internal.protocol.{ + JsonRpcNotificationMessage, + JsonRpcRequestMessage, + JsonRpcResponseMessage +} +import sbt.protocol.Serialization.{ + attach, + systemIn, + terminalCapabilities, + terminalPropertiesQuery, +} +import sjsonnew.support.scalajson.unsafe.Converter +import sbt.protocol.{ + Attach, + TerminalCapabilitiesQuery, + TerminalCapabilitiesResponse, + TerminalPropertiesResponse +} + +object VirtualTerminal { + private[this] val pendingTerminalProperties = + new ConcurrentHashMap[(String, String), ArrayBlockingQueue[TerminalPropertiesResponse]]() + private[this] val pendingTerminalCapabilities = + new ConcurrentHashMap[(String, String), ArrayBlockingQueue[TerminalCapabilitiesResponse]] + private[sbt] def sendTerminalPropertiesQuery( + channelName: String, + jsonRpcRequest: (String, String, String) => Unit + ): ArrayBlockingQueue[TerminalPropertiesResponse] = { + val id = UUID.randomUUID.toString + val queue = new ArrayBlockingQueue[TerminalPropertiesResponse](1) + pendingTerminalProperties.put((channelName, id), queue) + jsonRpcRequest(id, terminalPropertiesQuery, "") + queue + } + private[sbt] def sendTerminalCapabilitiesQuery( + channelName: String, + jsonRpcRequest: (String, String, TerminalCapabilitiesQuery) => Unit, + query: TerminalCapabilitiesQuery, + ): ArrayBlockingQueue[TerminalCapabilitiesResponse] = { + val id = UUID.randomUUID.toString + val queue = new ArrayBlockingQueue[TerminalCapabilitiesResponse](1) + pendingTerminalCapabilities.put((channelName, id), queue) + jsonRpcRequest(id, terminalCapabilities, query) + queue + } + private[sbt] def cancelRequests(name: String): Unit = { + pendingTerminalCapabilities.forEach { + case (k @ (`name`, _), q) => + pendingTerminalCapabilities.remove(k) + q.put(TerminalCapabilitiesResponse(None, None, None)) + case _ => + } + pendingTerminalProperties.forEach { + case (k @ (`name`, _), q) => + pendingTerminalProperties.remove(k) + q.put(TerminalPropertiesResponse(0, 0, false, false, false, false)) + case _ => + } + } + val handler = ServerHandler { cb => + ServerIntent(requestHandler(cb), responseHandler(cb), notificationHandler(cb)) + } + type Handler[R] = ServerCallback => PartialFunction[R, Unit] + private val requestHandler: Handler[JsonRpcRequestMessage] = + callback => { + case r if r.method == attach => + import sbt.protocol.codec.JsonProtocol.AttachFormat + val isInteractive = r.params + .flatMap(Converter.fromJson[Attach](_).toOption.map(_.interactive)) + .exists(identity) + StandardMain.exchange.channelForName(callback.name) match { + case Some(nc: NetworkChannel) => nc.setInteractive(r.id, isInteractive) + case _ => + } + } + private val responseHandler: Handler[JsonRpcResponseMessage] = + callback => { + case r if pendingTerminalProperties.get((callback.name, r.id)) != null => + import sbt.protocol.codec.JsonProtocol._ + val response = + r.result.flatMap(Converter.fromJson[TerminalPropertiesResponse](_).toOption) + pendingTerminalProperties.remove((callback.name, r.id)) match { + case null => + case buffer => response.foreach(buffer.put) + } + case r if pendingTerminalCapabilities.get((callback.name, r.id)) != null => + import sbt.protocol.codec.JsonProtocol._ + val response = + r.result.flatMap( + Converter.fromJson[TerminalCapabilitiesResponse](_).toOption + ) + pendingTerminalCapabilities.remove((callback.name, r.id)) match { + case null => + case buffer => + buffer.put(response.getOrElse(TerminalCapabilitiesResponse(None, None, None))) + } + } + private val notificationHandler: Handler[JsonRpcNotificationMessage] = + callback => { + case n if n.method == systemIn => + import sjsonnew.BasicJsonProtocol._ + n.params.flatMap(Converter.fromJson[Byte](_).toOption).foreach { byte => + StandardMain.exchange.channelForName(callback.name) match { + case Some(nc: NetworkChannel) => nc.write(byte) + case _ => + } + } + } +} diff --git a/protocol/src/main/contraband-scala/sbt/protocol/Attach.scala b/protocol/src/main/contraband-scala/sbt/protocol/Attach.scala new file mode 100644 index 000000000..e06846ae8 --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/protocol/Attach.scala @@ -0,0 +1,32 @@ +/** + * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.protocol +final class Attach private ( + val interactive: Boolean) extends sbt.protocol.CommandMessage() with Serializable { + + + + override def equals(o: Any): Boolean = o match { + case x: Attach => (this.interactive == x.interactive) + case _ => false + } + override def hashCode: Int = { + 37 * (37 * (17 + "sbt.protocol.Attach".##) + interactive.##) + } + override def toString: String = { + "Attach(" + interactive + ")" + } + private[this] def copy(interactive: Boolean = interactive): Attach = { + new Attach(interactive) + } + def withInteractive(interactive: Boolean): Attach = { + copy(interactive = interactive) + } +} +object Attach { + + def apply(interactive: Boolean): Attach = new Attach(interactive) +} diff --git a/protocol/src/main/contraband-scala/sbt/protocol/TerminalCapabilitiesQuery.scala b/protocol/src/main/contraband-scala/sbt/protocol/TerminalCapabilitiesQuery.scala new file mode 100644 index 000000000..2e270924c --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/protocol/TerminalCapabilitiesQuery.scala @@ -0,0 +1,50 @@ +/** + * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.protocol +final class TerminalCapabilitiesQuery private ( + val boolean: Option[String], + val numeric: Option[String], + val string: Option[String]) extends sbt.protocol.CommandMessage() with Serializable { + + + + override def equals(o: Any): Boolean = o match { + case x: TerminalCapabilitiesQuery => (this.boolean == x.boolean) && (this.numeric == x.numeric) && (this.string == x.string) + case _ => false + } + override def hashCode: Int = { + 37 * (37 * (37 * (37 * (17 + "sbt.protocol.TerminalCapabilitiesQuery".##) + boolean.##) + numeric.##) + string.##) + } + override def toString: String = { + "TerminalCapabilitiesQuery(" + boolean + ", " + numeric + ", " + string + ")" + } + private[this] def copy(boolean: Option[String] = boolean, numeric: Option[String] = numeric, string: Option[String] = string): TerminalCapabilitiesQuery = { + new TerminalCapabilitiesQuery(boolean, numeric, string) + } + def withBoolean(boolean: Option[String]): TerminalCapabilitiesQuery = { + copy(boolean = boolean) + } + def withBoolean(boolean: String): TerminalCapabilitiesQuery = { + copy(boolean = Option(boolean)) + } + def withNumeric(numeric: Option[String]): TerminalCapabilitiesQuery = { + copy(numeric = numeric) + } + def withNumeric(numeric: String): TerminalCapabilitiesQuery = { + copy(numeric = Option(numeric)) + } + def withString(string: Option[String]): TerminalCapabilitiesQuery = { + copy(string = string) + } + def withString(string: String): TerminalCapabilitiesQuery = { + copy(string = Option(string)) + } +} +object TerminalCapabilitiesQuery { + + def apply(boolean: Option[String], numeric: Option[String], string: Option[String]): TerminalCapabilitiesQuery = new TerminalCapabilitiesQuery(boolean, numeric, string) + def apply(boolean: String, numeric: String, string: String): TerminalCapabilitiesQuery = new TerminalCapabilitiesQuery(Option(boolean), Option(numeric), Option(string)) +} diff --git a/protocol/src/main/contraband-scala/sbt/protocol/TerminalCapabilitiesResponse.scala b/protocol/src/main/contraband-scala/sbt/protocol/TerminalCapabilitiesResponse.scala new file mode 100644 index 000000000..14eda0982 --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/protocol/TerminalCapabilitiesResponse.scala @@ -0,0 +1,50 @@ +/** + * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.protocol +final class TerminalCapabilitiesResponse private ( + val boolean: Option[Boolean], + val numeric: Option[Int], + val string: Option[String]) extends sbt.protocol.EventMessage() with Serializable { + + + + override def equals(o: Any): Boolean = o match { + case x: TerminalCapabilitiesResponse => (this.boolean == x.boolean) && (this.numeric == x.numeric) && (this.string == x.string) + case _ => false + } + override def hashCode: Int = { + 37 * (37 * (37 * (37 * (17 + "sbt.protocol.TerminalCapabilitiesResponse".##) + boolean.##) + numeric.##) + string.##) + } + override def toString: String = { + "TerminalCapabilitiesResponse(" + boolean + ", " + numeric + ", " + string + ")" + } + private[this] def copy(boolean: Option[Boolean] = boolean, numeric: Option[Int] = numeric, string: Option[String] = string): TerminalCapabilitiesResponse = { + new TerminalCapabilitiesResponse(boolean, numeric, string) + } + def withBoolean(boolean: Option[Boolean]): TerminalCapabilitiesResponse = { + copy(boolean = boolean) + } + def withBoolean(boolean: Boolean): TerminalCapabilitiesResponse = { + copy(boolean = Option(boolean)) + } + def withNumeric(numeric: Option[Int]): TerminalCapabilitiesResponse = { + copy(numeric = numeric) + } + def withNumeric(numeric: Int): TerminalCapabilitiesResponse = { + copy(numeric = Option(numeric)) + } + def withString(string: Option[String]): TerminalCapabilitiesResponse = { + copy(string = string) + } + def withString(string: String): TerminalCapabilitiesResponse = { + copy(string = Option(string)) + } +} +object TerminalCapabilitiesResponse { + + def apply(boolean: Option[Boolean], numeric: Option[Int], string: Option[String]): TerminalCapabilitiesResponse = new TerminalCapabilitiesResponse(boolean, numeric, string) + def apply(boolean: Boolean, numeric: Int, string: String): TerminalCapabilitiesResponse = new TerminalCapabilitiesResponse(Option(boolean), Option(numeric), Option(string)) +} diff --git a/protocol/src/main/contraband-scala/sbt/protocol/TerminalPropertiesResponse.scala b/protocol/src/main/contraband-scala/sbt/protocol/TerminalPropertiesResponse.scala new file mode 100644 index 000000000..c0552ab8c --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/protocol/TerminalPropertiesResponse.scala @@ -0,0 +1,52 @@ +/** + * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.protocol +final class TerminalPropertiesResponse private ( + val width: Int, + val height: Int, + val isAnsiSupported: Boolean, + val isColorEnabled: Boolean, + val isSupershellEnabled: Boolean, + val isEchoEnabled: Boolean) extends sbt.protocol.EventMessage() with Serializable { + + + + override def equals(o: Any): Boolean = o match { + case x: TerminalPropertiesResponse => (this.width == x.width) && (this.height == x.height) && (this.isAnsiSupported == x.isAnsiSupported) && (this.isColorEnabled == x.isColorEnabled) && (this.isSupershellEnabled == x.isSupershellEnabled) && (this.isEchoEnabled == x.isEchoEnabled) + case _ => false + } + override def hashCode: Int = { + 37 * (37 * (37 * (37 * (37 * (37 * (37 * (17 + "sbt.protocol.TerminalPropertiesResponse".##) + width.##) + height.##) + isAnsiSupported.##) + isColorEnabled.##) + isSupershellEnabled.##) + isEchoEnabled.##) + } + override def toString: String = { + "TerminalPropertiesResponse(" + width + ", " + height + ", " + isAnsiSupported + ", " + isColorEnabled + ", " + isSupershellEnabled + ", " + isEchoEnabled + ")" + } + private[this] def copy(width: Int = width, height: Int = height, isAnsiSupported: Boolean = isAnsiSupported, isColorEnabled: Boolean = isColorEnabled, isSupershellEnabled: Boolean = isSupershellEnabled, isEchoEnabled: Boolean = isEchoEnabled): TerminalPropertiesResponse = { + new TerminalPropertiesResponse(width, height, isAnsiSupported, isColorEnabled, isSupershellEnabled, isEchoEnabled) + } + def withWidth(width: Int): TerminalPropertiesResponse = { + copy(width = width) + } + def withHeight(height: Int): TerminalPropertiesResponse = { + copy(height = height) + } + def withIsAnsiSupported(isAnsiSupported: Boolean): TerminalPropertiesResponse = { + copy(isAnsiSupported = isAnsiSupported) + } + def withIsColorEnabled(isColorEnabled: Boolean): TerminalPropertiesResponse = { + copy(isColorEnabled = isColorEnabled) + } + def withIsSupershellEnabled(isSupershellEnabled: Boolean): TerminalPropertiesResponse = { + copy(isSupershellEnabled = isSupershellEnabled) + } + def withIsEchoEnabled(isEchoEnabled: Boolean): TerminalPropertiesResponse = { + copy(isEchoEnabled = isEchoEnabled) + } +} +object TerminalPropertiesResponse { + + def apply(width: Int, height: Int, isAnsiSupported: Boolean, isColorEnabled: Boolean, isSupershellEnabled: Boolean, isEchoEnabled: Boolean): TerminalPropertiesResponse = new TerminalPropertiesResponse(width, height, isAnsiSupported, isColorEnabled, isSupershellEnabled, isEchoEnabled) +} diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/AttachFormats.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/AttachFormats.scala new file mode 100644 index 000000000..a8be3f15e --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/AttachFormats.scala @@ -0,0 +1,27 @@ +/** + * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.protocol.codec +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait AttachFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val AttachFormat: JsonFormat[sbt.protocol.Attach] = new JsonFormat[sbt.protocol.Attach] { + override def read[J](__jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.protocol.Attach = { + __jsOpt match { + case Some(__js) => + unbuilder.beginObject(__js) + val interactive = unbuilder.readField[Boolean]("interactive") + unbuilder.endObject() + sbt.protocol.Attach(interactive) + case None => + deserializationError("Expected JsObject but found None") + } + } + override def write[J](obj: sbt.protocol.Attach, builder: Builder[J]): Unit = { + builder.beginObject() + builder.addField("interactive", obj.interactive) + builder.endObject() + } +} +} diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala index 8c6ae04ac..ee79ca457 100644 --- a/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala @@ -6,6 +6,6 @@ package sbt.protocol.codec import _root_.sjsonnew.JsonFormat -trait CommandMessageFormats { self: sjsonnew.BasicJsonProtocol with sbt.protocol.codec.InitCommandFormats with sbt.protocol.codec.ExecCommandFormats with sbt.protocol.codec.SettingQueryFormats => -implicit lazy val CommandMessageFormat: JsonFormat[sbt.protocol.CommandMessage] = flatUnionFormat3[sbt.protocol.CommandMessage, sbt.protocol.InitCommand, sbt.protocol.ExecCommand, sbt.protocol.SettingQuery]("type") +trait CommandMessageFormats { self: sjsonnew.BasicJsonProtocol with sbt.protocol.codec.InitCommandFormats with sbt.protocol.codec.ExecCommandFormats with sbt.protocol.codec.SettingQueryFormats with sbt.protocol.codec.AttachFormats with sbt.protocol.codec.TerminalCapabilitiesQueryFormats => +implicit lazy val CommandMessageFormat: JsonFormat[sbt.protocol.CommandMessage] = flatUnionFormat5[sbt.protocol.CommandMessage, sbt.protocol.InitCommand, sbt.protocol.ExecCommand, sbt.protocol.SettingQuery, sbt.protocol.Attach, sbt.protocol.TerminalCapabilitiesQuery]("type") } diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/EventMessageFormats.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/EventMessageFormats.scala index 4d17fb72c..2694b0078 100644 --- a/protocol/src/main/contraband-scala/sbt/protocol/codec/EventMessageFormats.scala +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/EventMessageFormats.scala @@ -6,6 +6,6 @@ package sbt.protocol.codec import _root_.sjsonnew.JsonFormat -trait EventMessageFormats { self: sjsonnew.BasicJsonProtocol with sbt.protocol.codec.ChannelAcceptedEventFormats with sbt.protocol.codec.LogEventFormats with sbt.protocol.codec.ExecStatusEventFormats with sbt.internal.util.codec.JValueFormats with sbt.protocol.codec.SettingQuerySuccessFormats with sbt.protocol.codec.SettingQueryFailureFormats => -implicit lazy val EventMessageFormat: JsonFormat[sbt.protocol.EventMessage] = flatUnionFormat5[sbt.protocol.EventMessage, sbt.protocol.ChannelAcceptedEvent, sbt.protocol.LogEvent, sbt.protocol.ExecStatusEvent, sbt.protocol.SettingQuerySuccess, sbt.protocol.SettingQueryFailure]("type") +trait EventMessageFormats { self: sjsonnew.BasicJsonProtocol with sbt.protocol.codec.ChannelAcceptedEventFormats with sbt.protocol.codec.LogEventFormats with sbt.protocol.codec.ExecStatusEventFormats with sbt.internal.util.codec.JValueFormats with sbt.protocol.codec.SettingQuerySuccessFormats with sbt.protocol.codec.SettingQueryFailureFormats with sbt.protocol.codec.TerminalPropertiesResponseFormats with sbt.protocol.codec.TerminalCapabilitiesResponseFormats => +implicit lazy val EventMessageFormat: JsonFormat[sbt.protocol.EventMessage] = flatUnionFormat7[sbt.protocol.EventMessage, sbt.protocol.ChannelAcceptedEvent, sbt.protocol.LogEvent, sbt.protocol.ExecStatusEvent, sbt.protocol.SettingQuerySuccess, sbt.protocol.SettingQueryFailure, sbt.protocol.TerminalPropertiesResponse, sbt.protocol.TerminalCapabilitiesResponse]("type") } diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala index 67584f980..de4aba238 100644 --- a/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala @@ -8,6 +8,8 @@ trait JsonProtocol extends sjsonnew.BasicJsonProtocol with sbt.protocol.codec.InitCommandFormats with sbt.protocol.codec.ExecCommandFormats with sbt.protocol.codec.SettingQueryFormats + with sbt.protocol.codec.AttachFormats + with sbt.protocol.codec.TerminalCapabilitiesQueryFormats with sbt.protocol.codec.CommandMessageFormats with sbt.protocol.codec.CompletionParamsFormats with sbt.protocol.codec.ChannelAcceptedEventFormats @@ -16,6 +18,8 @@ trait JsonProtocol extends sjsonnew.BasicJsonProtocol with sbt.internal.util.codec.JValueFormats with sbt.protocol.codec.SettingQuerySuccessFormats with sbt.protocol.codec.SettingQueryFailureFormats + with sbt.protocol.codec.TerminalPropertiesResponseFormats + with sbt.protocol.codec.TerminalCapabilitiesResponseFormats with sbt.protocol.codec.EventMessageFormats with sbt.protocol.codec.SettingQueryResponseFormats with sbt.protocol.codec.CompletionResponseFormats diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalCapabilitiesQueryFormats.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalCapabilitiesQueryFormats.scala new file mode 100644 index 000000000..a26886a46 --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalCapabilitiesQueryFormats.scala @@ -0,0 +1,31 @@ +/** + * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.protocol.codec +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait TerminalCapabilitiesQueryFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val TerminalCapabilitiesQueryFormat: JsonFormat[sbt.protocol.TerminalCapabilitiesQuery] = new JsonFormat[sbt.protocol.TerminalCapabilitiesQuery] { + override def read[J](__jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.protocol.TerminalCapabilitiesQuery = { + __jsOpt match { + case Some(__js) => + unbuilder.beginObject(__js) + val boolean = unbuilder.readField[Option[String]]("boolean") + val numeric = unbuilder.readField[Option[String]]("numeric") + val string = unbuilder.readField[Option[String]]("string") + unbuilder.endObject() + sbt.protocol.TerminalCapabilitiesQuery(boolean, numeric, string) + case None => + deserializationError("Expected JsObject but found None") + } + } + override def write[J](obj: sbt.protocol.TerminalCapabilitiesQuery, builder: Builder[J]): Unit = { + builder.beginObject() + builder.addField("boolean", obj.boolean) + builder.addField("numeric", obj.numeric) + builder.addField("string", obj.string) + builder.endObject() + } +} +} diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalCapabilitiesResponseFormats.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalCapabilitiesResponseFormats.scala new file mode 100644 index 000000000..eff1f27f4 --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalCapabilitiesResponseFormats.scala @@ -0,0 +1,31 @@ +/** + * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.protocol.codec +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait TerminalCapabilitiesResponseFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val TerminalCapabilitiesResponseFormat: JsonFormat[sbt.protocol.TerminalCapabilitiesResponse] = new JsonFormat[sbt.protocol.TerminalCapabilitiesResponse] { + override def read[J](__jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.protocol.TerminalCapabilitiesResponse = { + __jsOpt match { + case Some(__js) => + unbuilder.beginObject(__js) + val boolean = unbuilder.readField[Option[Boolean]]("boolean") + val numeric = unbuilder.readField[Option[Int]]("numeric") + val string = unbuilder.readField[Option[String]]("string") + unbuilder.endObject() + sbt.protocol.TerminalCapabilitiesResponse(boolean, numeric, string) + case None => + deserializationError("Expected JsObject but found None") + } + } + override def write[J](obj: sbt.protocol.TerminalCapabilitiesResponse, builder: Builder[J]): Unit = { + builder.beginObject() + builder.addField("boolean", obj.boolean) + builder.addField("numeric", obj.numeric) + builder.addField("string", obj.string) + builder.endObject() + } +} +} diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalPropertiesResponseFormats.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalPropertiesResponseFormats.scala new file mode 100644 index 000000000..ee4712e0d --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/TerminalPropertiesResponseFormats.scala @@ -0,0 +1,37 @@ +/** + * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.protocol.codec +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait TerminalPropertiesResponseFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val TerminalPropertiesResponseFormat: JsonFormat[sbt.protocol.TerminalPropertiesResponse] = new JsonFormat[sbt.protocol.TerminalPropertiesResponse] { + override def read[J](__jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.protocol.TerminalPropertiesResponse = { + __jsOpt match { + case Some(__js) => + unbuilder.beginObject(__js) + val width = unbuilder.readField[Int]("width") + val height = unbuilder.readField[Int]("height") + val isAnsiSupported = unbuilder.readField[Boolean]("isAnsiSupported") + val isColorEnabled = unbuilder.readField[Boolean]("isColorEnabled") + val isSupershellEnabled = unbuilder.readField[Boolean]("isSupershellEnabled") + val isEchoEnabled = unbuilder.readField[Boolean]("isEchoEnabled") + unbuilder.endObject() + sbt.protocol.TerminalPropertiesResponse(width, height, isAnsiSupported, isColorEnabled, isSupershellEnabled, isEchoEnabled) + case None => + deserializationError("Expected JsObject but found None") + } + } + override def write[J](obj: sbt.protocol.TerminalPropertiesResponse, builder: Builder[J]): Unit = { + builder.beginObject() + builder.addField("width", obj.width) + builder.addField("height", obj.height) + builder.addField("isAnsiSupported", obj.isAnsiSupported) + builder.addField("isColorEnabled", obj.isColorEnabled) + builder.addField("isSupershellEnabled", obj.isSupershellEnabled) + builder.addField("isEchoEnabled", obj.isEchoEnabled) + builder.endObject() + } +} +} diff --git a/protocol/src/main/contraband/server.contra b/protocol/src/main/contraband/server.contra index aec8c9a70..08fe78032 100644 --- a/protocol/src/main/contraband/server.contra +++ b/protocol/src/main/contraband/server.contra @@ -23,6 +23,10 @@ type SettingQuery implements CommandMessage { setting: String! } +type Attach implements CommandMessage { + interactive: Boolean! +} + type CompletionParams { query: String! } @@ -76,3 +80,24 @@ type ExecutionEvent { success: String! commandLine: String! } + +type TerminalPropertiesResponse implements EventMessage { + width: Int! + height: Int! + isAnsiSupported: Boolean! + isColorEnabled: Boolean! + isSupershellEnabled: Boolean! + isEchoEnabled: Boolean! +} + +type TerminalCapabilitiesQuery implements CommandMessage { + boolean: String + numeric: String + string: String +} + +type TerminalCapabilitiesResponse implements EventMessage { + boolean: Boolean + numeric: Int + string: String +} diff --git a/protocol/src/main/scala/sbt/protocol/Serialization.scala b/protocol/src/main/scala/sbt/protocol/Serialization.scala index 02927f067..cb99d0f3b 100644 --- a/protocol/src/main/scala/sbt/protocol/Serialization.scala +++ b/protocol/src/main/scala/sbt/protocol/Serialization.scala @@ -24,6 +24,15 @@ import sbt.internal.protocol.{ object Serialization { private[sbt] val VsCode = "application/vscode-jsonrpc; charset=utf-8" + val systemIn = "sbt/systemIn" + val systemOut = "sbt/systemOut" + val terminalPropertiesQuery = "sbt/terminalPropertiesQuery" + val terminalPropertiesResponse = "sbt/terminalPropertiesResponse" + val terminalCapabilities = "sbt/terminalCapabilities" + val terminalCapabilitiesResponse = "sbt/terminalCapabilitiesResponse" + val attach = "sbt/attach" + val attachResponse = "sbt/attachResponse" + val cancelRequest = "sbt/cancelRequest" @deprecated("unused", since = "1.4.0") def serializeEvent[A: JsonFormat](event: A): Array[Byte] = { @@ -63,6 +72,13 @@ object Serialization { val json: JValue = Converter.toJson[String](x.setting).get val v = CompactPrinter(json) s"""{ "jsonrpc": "2.0", "id": "$execId", "method": "sbt/setting", "params": { "setting": $v } }""" + + case x: Attach => + val execId = UUID.randomUUID.toString + val json: JValue = Converter.toJson[Boolean](x.interactive).get + val v = CompactPrinter(json) + s"""{ "jsonrpc": "2.0", "id": "$execId", "method": "$attach", "params": { "interactive": $v } }""" + } } @@ -78,6 +94,12 @@ object Serialization { serializeResponse(message) } + /** This formats the message according to JSON-RPC. https://www.jsonrpc.org/specification */ + private[sbt] def serializeRequestMessage(message: JsonRpcRequestMessage): Array[Byte] = { + import sbt.internal.protocol.codec.JsonRPCProtocol._ + serializeResponse(message) + } + /** This formats the message according to JSON-RPC. https://www.jsonrpc.org/specification */ private[sbt] def serializeNotificationMessage( message: JsonRpcNotificationMessage, diff --git a/server-test/src/server-test/response/build.sbt b/server-test/src/server-test/response/build.sbt index ee95370c3..499a85809 100644 --- a/server-test/src/server-test/response/build.sbt +++ b/server-test/src/server-test/response/build.sbt @@ -9,7 +9,7 @@ Global / serverHandlers += ServerHandler({ callback => import sjsonnew.BasicJsonProtocol._ import sbt.internal.protocol.JsonRpcRequestMessage ServerIntent( - { + onRequest = { case r: JsonRpcRequestMessage if r.method == "foo/export" => appendExec(Exec("fooExport", Some(r.id), Some(CommandSource(callback.name)))) () @@ -34,7 +34,8 @@ Global / serverHandlers += ServerHandler({ callback => jsonRpcRespond("concurrent response", Some(r.id)) () }, - { + onResponse = PartialFunction.empty, + onNotification = { case r if r.method == "foo/customNotification" => jsonRpcRespond("notification result", None) ()