From fa7253ece34cb7ec196b684f2a8a749e4a72e49f Mon Sep 17 00:00:00 2001 From: Eugene Yokota Date: Fri, 2 Dec 2016 18:09:16 -0500 Subject: [PATCH] Start lightweight client This is the beginning of a lightweight client, which talks to the server over Contraband-generated JSON API. Given that the server is started on port 5173: ``` $ cd /tmp/bogus $ sbt client localhost:5173 > compile StatusEvent(Processing, Vector(compile, server)) StatusEvent(Ready, Vector()) StatusEvent(Processing, Vector(, server)) StatusEvent(Ready, Vector()) ``` --- .../main/scala/sbt/BasicCommandStrings.scala | 3 + .../src/main/scala/sbt/BasicCommands.scala | 21 +++- .../sbt/internal/client/NetworkClient.scala | 96 +++++++++++++++++++ .../internal/client/ServerConnection.scala | 73 ++++++++++++++ .../internal/server/ClientConnection.scala | 12 +-- .../scala/sbt/internal/server/Server.scala | 2 +- main/src/main/scala/sbt/Main.scala | 2 +- .../scala/sbt/protocol}/Serialization.scala | 32 ++++++- 8 files changed, 224 insertions(+), 17 deletions(-) create mode 100644 main-command/src/main/scala/sbt/internal/client/NetworkClient.scala create mode 100644 main-command/src/main/scala/sbt/internal/client/ServerConnection.scala rename {main-command/src/main/scala/sbt/internal/server => protocol/src/main/scala/sbt/protocol}/Serialization.scala (51%) diff --git a/main-command/src/main/scala/sbt/BasicCommandStrings.scala b/main-command/src/main/scala/sbt/BasicCommandStrings.scala index bfd424a65..d99793259 100644 --- a/main-command/src/main/scala/sbt/BasicCommandStrings.scala +++ b/main-command/src/main/scala/sbt/BasicCommandStrings.scala @@ -152,6 +152,9 @@ object BasicCommandStrings { def Server = "server" def ServerDetailed = "Provides a network server and an interactive prompt from which commands can be run." + def Client = "client" + def ClientDetailed = "Provides an interactive prompt from which commands can be run on a server." + def StashOnFailure = "sbtStashOnFailure" def PopOnFailure = "sbtPopOnFailure" diff --git a/main-command/src/main/scala/sbt/BasicCommands.scala b/main-command/src/main/scala/sbt/BasicCommands.scala index 12e0ef49b..40f75d2eb 100644 --- a/main-command/src/main/scala/sbt/BasicCommands.scala +++ b/main-command/src/main/scala/sbt/BasicCommands.scala @@ -7,6 +7,7 @@ import sbt.internal.util.Types.{ const, idFun } import sbt.internal.inc.classpath.ClasspathUtilities.toLoader import sbt.internal.inc.ModuleUtilities import sbt.internal.{ Exec, CommandSource, CommandStatus } +import sbt.internal.client.NetworkClient import DefaultParsers._ import Function.tupled import Command.applyEffect @@ -16,12 +17,11 @@ import BasicKeys._ import java.io.File import sbt.io.IO -import java.util.concurrent.atomic.AtomicBoolean - import scala.util.control.NonFatal object BasicCommands { - lazy val allBasicCommands = Seq(nop, ignore, help, completionsCommand, multi, ifLast, append, setOnFailure, clearOnFailure, stashOnFailure, popOnFailure, reboot, call, early, exit, continuous, history, shell, server, read, alias) ++ compatCommands + lazy val allBasicCommands = Seq(nop, ignore, help, completionsCommand, multi, ifLast, append, setOnFailure, clearOnFailure, + stashOnFailure, popOnFailure, reboot, call, early, exit, continuous, history, shell, server, client, read, alias) ++ compatCommands def nop = Command.custom(s => success(() => s)) def ignore = Command.command(FailureWall)(idFun) @@ -204,6 +204,21 @@ object BasicCommands { else newState.clearGlobalLog } + def client = Command.make(Client, Help.more(Client, ClientDetailed))(clientParser) + def clientParser(s0: State) = + { + val p = (token(Space) ~> repsep(StringBasic, token(Space))) | (token(EOF) map { case _ => Nil }) + applyEffect(p)({ inputArg => + val arguments = inputArg.toList ++ + (s0.remainingCommands.toList match { + case "shell" :: Nil => Nil + case xs => xs + }) + NetworkClient.run(arguments) + "exit" :: s0.copy(remainingCommands = Nil) + }) + } + def read = Command.make(ReadCommand, Help.more(ReadCommand, ReadDetailed))(s => applyEffect(readParser(s))(doRead(s))) def readParser(s: State) = { diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala new file mode 100644 index 000000000..1ee665764 --- /dev/null +++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala @@ -0,0 +1,96 @@ +/* + * Copyright (C) 2016 Lightbend Inc. + */ +package sbt +package internal +package client + +import java.net.{ URI, Socket, InetAddress, SocketException } +import sbt.protocol._ +import sbt.internal.util.JLine +import java.util.concurrent.atomic.AtomicBoolean + +class NetworkClient(arguments: List[String]) { + private var status: String = "Ready" + private val lock: AnyRef = new AnyRef {} + private val running = new AtomicBoolean(true) + def usageError = sys.error("Expecting: sbt client 127.0.0.1:port") + val connection = init() + start() + + def init(): ServerConnection = { + val u = arguments match { + case List(x) => + if (x contains "://") new URI(x) + else new URI("tcp://" + x) + case _ => usageError + } + val host = Option(u.getHost) match { + case None => usageError + case Some(x) => x + } + val port = Option(u.getPort) match { + case None => usageError + case Some(x) if x == -1 => usageError + case Some(x) => x + } + println(s"client on port $port") + val socket = new Socket(InetAddress.getByName(host), port) + new ServerConnection(socket) { + override def onEvent(event: EventMessage): Unit = + event match { + case e: StatusEvent => + lock.synchronized { + status = e.status + } + println(event) + case e => println(e.toString) + } + override def onShutdown: Unit = + { + running.set(false) + } + } + } + + def start(): Unit = + { + val reader = JLine.simple(None, JLine.HandleCONT, injectThreadSleep = true) + while (running.get) { + reader.readLine("> ", None) match { + case Some("exit") => + running.set(false) + case Some(s) => + publishCommand(ExecCommand(s)) + case _ => // + } + while (status != "Ready") { + Thread.sleep(100) + } + } + } + + def publishCommand(command: CommandMessage): Unit = + { + val bytes = Serialization.serializeCommand(command) + try { + connection.publish(bytes) + } catch { + case e: SocketException => + // log.debug(e.getMessage) + // toDel += client + } + lock.synchronized { + status = "Processing" + } + } +} + +object NetworkClient { + def run(arguments: List[String]): Unit = + try { + new NetworkClient(arguments) + } catch { + case e: Exception => println(e.getMessage) + } +} diff --git a/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala b/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala new file mode 100644 index 000000000..21fc72a12 --- /dev/null +++ b/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2016 Lightbend Inc. + */ +package sbt +package internal +package client + +import java.net.{ SocketTimeoutException, Socket } +import java.util.concurrent.atomic.AtomicBoolean +import sbt.protocol._ + +abstract class ServerConnection(connection: Socket) { + + private val running = new AtomicBoolean(true) + private val delimiter: Byte = '\n'.toByte + + private val out = connection.getOutputStream + + val thread = new Thread(s"sbt-serverconnection-${connection.getPort}") { + override def run(): Unit = { + try { + val readBuffer = new Array[Byte](4096) + val in = connection.getInputStream + connection.setSoTimeout(5000) + var buffer: Vector[Byte] = Vector.empty + var bytesRead = 0 + while (bytesRead != -1 && running.get) { + try { + bytesRead = in.read(readBuffer) + buffer = buffer ++ readBuffer.toVector.take(bytesRead) + // handle un-framing + val delimPos = buffer.indexOf(delimiter) + if (delimPos > 0) { + val chunk = buffer.take(delimPos) + buffer = buffer.drop(delimPos + 1) + + Serialization.deserializeEvent(chunk).fold({ errorDesc => + val s = new String(chunk.toArray, "UTF-8") + println(s"Got invalid chunk from server: $s \n" + errorDesc) + }, + onEvent + ) + } + + } catch { + case _: SocketTimeoutException => // its ok + } + } + } finally { + shutdown() + } + } + } + thread.start() + + def publish(command: Array[Byte]): Unit = { + out.write(command) + out.write(delimiter.toInt) + out.flush() + } + + def onEvent(event: EventMessage): Unit + + def onShutdown: Unit + + def shutdown(): Unit = { + println("Shutting down client connection") + running.set(false) + out.close() + onShutdown + } + +} diff --git a/main-command/src/main/scala/sbt/internal/server/ClientConnection.scala b/main-command/src/main/scala/sbt/internal/server/ClientConnection.scala index 37380b72c..836f1dd21 100644 --- a/main-command/src/main/scala/sbt/internal/server/ClientConnection.scala +++ b/main-command/src/main/scala/sbt/internal/server/ClientConnection.scala @@ -16,7 +16,7 @@ abstract class ClientConnection(connection: Socket) { private val out = connection.getOutputStream - val thread = new Thread(s"sbt-client-${connection.getPort}") { + val thread = new Thread(s"sbt-clientconnection-${connection.getPort}") { override def run(): Unit = { try { val readBuffer = new Array[Byte](4096) @@ -27,16 +27,14 @@ abstract class ClientConnection(connection: Socket) { while (bytesRead != -1 && running.get) { try { bytesRead = in.read(readBuffer) - val bytes = readBuffer.toVector.take(bytesRead) - buffer = buffer ++ bytes - + buffer = buffer ++ readBuffer.toVector.take(bytesRead) // handle un-framing - val delimPos = bytes.indexOf(delimiter) + val delimPos = buffer.indexOf(delimiter) if (delimPos > 0) { val chunk = buffer.take(delimPos) buffer = buffer.drop(delimPos + 1) - Serialization.deserialize(chunk).fold( + Serialization.deserializeCommand(chunk).fold( errorDesc => println("Got invalid chunk from client: " + errorDesc), onCommand ) @@ -56,7 +54,7 @@ abstract class ClientConnection(connection: Socket) { def publish(event: Array[Byte]): Unit = { out.write(event) - out.write(delimiter) + out.write(delimiter.toInt) out.flush() } diff --git a/main-command/src/main/scala/sbt/internal/server/Server.scala b/main-command/src/main/scala/sbt/internal/server/Server.scala index 55c38768f..078f07f5f 100644 --- a/main-command/src/main/scala/sbt/internal/server/Server.scala +++ b/main-command/src/main/scala/sbt/internal/server/Server.scala @@ -58,7 +58,7 @@ private[sbt] object Server { /** Publish an event to all connected clients */ def publish(event: EventMessage): Unit = { // TODO do not do this on the calling thread - val bytes = Serialization.serialize(event) + val bytes = Serialization.serializeEvent(event) lock.synchronized { val toDel: mutable.ListBuffer[ClientConnection] = mutable.ListBuffer.empty clients.foreach { client => diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index 87bd3fd4e..fb3e99c9a 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -115,7 +115,7 @@ object BuiltinCommands { def DefaultCommands: Seq[Command] = Seq(ignore, help, completionsCommand, about, tasks, settingsCommand, loadProject, projects, project, reboot, read, history, set, sessionCommand, inspect, loadProjectImpl, loadFailed, Cross.crossBuild, Cross.switchVersion, Cross.crossRestoreSession, setOnFailure, clearOnFailure, stashOnFailure, popOnFailure, setLogLevel, plugin, plugins, - ifLast, multi, shell, BasicCommands.server, continuous, eval, alias, append, last, lastGrep, export, boot, nop, call, exit, early, initialize, act) ++ + ifLast, multi, shell, BasicCommands.server, BasicCommands.client, continuous, eval, alias, append, last, lastGrep, export, boot, nop, call, exit, early, initialize, act) ++ compatCommands def DefaultBootCommands: Seq[String] = LoadProject :: (IfLast + " " + Shell) :: Nil diff --git a/main-command/src/main/scala/sbt/internal/server/Serialization.scala b/protocol/src/main/scala/sbt/protocol/Serialization.scala similarity index 51% rename from main-command/src/main/scala/sbt/internal/server/Serialization.scala rename to protocol/src/main/scala/sbt/protocol/Serialization.scala index cc1c7e14d..0ca01cb56 100644 --- a/main-command/src/main/scala/sbt/internal/server/Serialization.scala +++ b/protocol/src/main/scala/sbt/protocol/Serialization.scala @@ -2,19 +2,23 @@ * Copyright (C) 2016 Lightbend Inc. */ package sbt -package internal -package server +package protocol import sjsonnew.support.scalajson.unsafe.{ Converter, CompactPrinter } import scala.json.ast.unsafe.JValue import sjsonnew.support.scalajson.unsafe.Parser import java.nio.ByteBuffer import scala.util.{ Success, Failure } -import sbt.protocol._ object Serialization { + def serializeCommand(command: CommandMessage): Array[Byte] = + { + import codec.JsonProtocol._ + val json: JValue = Converter.toJson[CommandMessage](command).get + CompactPrinter(json).getBytes("UTF-8") + } - def serialize(event: EventMessage): Array[Byte] = + def serializeEvent(event: EventMessage): Array[Byte] = { import codec.JsonProtocol._ val json: JValue = Converter.toJson[EventMessage](event).get @@ -24,7 +28,7 @@ object Serialization { /** * @return A command or an invalid input description */ - def deserialize(bytes: Seq[Byte]): Either[String, CommandMessage] = + def deserializeCommand(bytes: Seq[Byte]): Either[String, CommandMessage] = { val buffer = ByteBuffer.wrap(bytes.toArray) Parser.parseFromByteBuffer(buffer) match { @@ -38,4 +42,22 @@ object Serialization { Left(s"Parse error: ${e.getMessage}") } } + + /** + * @return A command or an invalid input description + */ + def deserializeEvent(bytes: Seq[Byte]): Either[String, EventMessage] = + { + val buffer = ByteBuffer.wrap(bytes.toArray) + Parser.parseFromByteBuffer(buffer) match { + case Success(json) => + import codec.JsonProtocol._ + Converter.fromJson[EventMessage](json) match { + case Success(event) => Right(event) + case Failure(e) => Left(e.getMessage) + } + case Failure(e) => + Left(s"Parse error: ${e.getMessage}") + } + } }