Refactor to make NetworkChannel per client connection

This commit is contained in:
Eugene Yokota 2016-12-05 16:33:57 -05:00
parent d618f91c6d
commit 46d8f952e4
7 changed files with 105 additions and 106 deletions

View File

@ -14,9 +14,8 @@ abstract class CommandChannel {
commandQueue.add(exec)
def poll: Option[Exec] = Option(commandQueue.poll)
/** start listening for a command exec. */
def run(s: State): State
def publishStatus(status: CommandStatus, lastSource: Option[CommandSource]): Unit
def publishBytes(bytes: Array[Byte]): Unit
def shutdown(): Unit
}

View File

@ -1,9 +1,16 @@
package sbt
package internal
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
import java.net.SocketException
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicInteger
import sbt.internal.server._
import sbt.protocol.Serialization
import scala.collection.mutable.ListBuffer
import scala.annotation.tailrec
import BasicKeys.serverPort
import sbt.protocol.StatusEvent
import java.net.Socket
/**
* The command exchange merges multiple command channels (e.g. network and console),
@ -12,14 +19,18 @@ import java.util.concurrent.ConcurrentLinkedQueue
* this exchange, which could serve command request from either of the channel.
*/
private[sbt] final class CommandExchange {
private val lock = new AnyRef {}
private var server: Option[ServerInstance] = None
private val commandQueue: ConcurrentLinkedQueue[Exec] = new ConcurrentLinkedQueue()
private val channelBuffer: ListBuffer[CommandChannel] = new ListBuffer()
private val nextChannelId: AtomicInteger = new AtomicInteger(0)
def channels: List[CommandChannel] = channelBuffer.toList
def subscribe(c: CommandChannel): Unit =
channelBuffer.append(c)
lock.synchronized {
channelBuffer.append(c)
}
subscribe(new ConsoleChannel())
subscribe(new NetworkChannel())
// periodically move all messages from all the channels
@tailrec def blockUntilNextExec: Exec =
@ -40,13 +51,69 @@ private[sbt] final class CommandExchange {
}
}
// fanout run to all channels
def run(s: State): State =
(s /: channels) { (acc, c) => c.run(acc) }
def run(s: State): State = runServer(s)
private def newChannelName: String = s"channel-${nextChannelId.incrementAndGet()}"
private def runServer(s: State): State =
{
val port = (s get serverPort) match {
case Some(x) => x
case None => 5001
}
def onIncomingSocket(socket: Socket): Unit =
{
s.log.info(s"new client connected from: ${socket.getPort}")
val channel = new NetworkChannel(newChannelName, socket)
subscribe(channel)
}
server match {
case Some(x) => // do nothing
case _ =>
server = Some(Server.start("127.0.0.1", port, onIncomingSocket, s.log))
}
s
}
def shutdown(): Unit =
{
channels foreach { c =>
c.shutdown()
}
// interrupt and kill the thread
server.foreach(_.shutdown())
server = None
}
// fanout publishStatus to all channels
def publishStatus(status: CommandStatus, lastSource: Option[CommandSource]): Unit =
channels foreach { c =>
c.publishStatus(status, lastSource)
{
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
val event =
if (status.canEnter) StatusEvent("Ready", Vector())
else StatusEvent("Processing", status.state.remainingCommands.toVector)
// TODO do not do this on the calling thread
val bytes = Serialization.serializeEvent(event)
channels.foreach {
case c: ConsoleChannel =>
c.publishStatus(status, lastSource)
case c: NetworkChannel =>
try {
c.publishBytes(bytes)
} catch {
case e: SocketException =>
// log.debug(e.getMessage)
toDel += c
}
}
toDel.toList match {
case Nil => // do nothing
case xs =>
lock.synchronized {
channelBuffer --= xs
}
}
}
}

View File

@ -28,6 +28,8 @@ private[sbt] final class ConsoleChannel extends CommandChannel {
def run(s: State): State = s
def publishBytes(bytes: Array[Byte]): Unit = ()
def publishStatus(status: CommandStatus, lastSource: Option[CommandSource]): Unit =
if (status.canEnter) {
askUserThread match {

View File

@ -1,43 +0,0 @@
package sbt
package internal
import sbt.internal.server._
import sbt.protocol._
import BasicKeys._
private[sbt] final class NetworkChannel extends CommandChannel {
private var server: Option[ServerInstance] = None
def run(s: State): State =
{
val port = (s get serverPort) match {
case Some(x) => x
case None => 5001
}
def onCommand(command: CommandMessage): Unit =
command match {
case x: ExecCommand => append(Exec(CommandSource.Network, x.commandLine))
}
server match {
case Some(x) => // do nothing
case _ =>
server = Some(Server.start("127.0.0.1", port, onCommand, s.log))
}
s
}
def shutdown(): Unit =
{
// interrupt and kill the thread
server.foreach(_.shutdown())
server = None
}
def publishStatus(cmdStatus: CommandStatus, lastSource: Option[CommandSource]): Unit = {
server.foreach(server =>
server.publish(
if (cmdStatus.canEnter) StatusEvent("Ready", Vector())
else StatusEvent("Processing", cmdStatus.state.remainingCommands.toVector)
))
}
}

View File

@ -34,7 +34,8 @@ abstract class ServerConnection(connection: Socket) {
val chunk = buffer.take(delimPos)
buffer = buffer.drop(delimPos + 1)
Serialization.deserializeEvent(chunk).fold({ errorDesc =>
Serialization.deserializeEvent(chunk).fold(
{ errorDesc =>
val s = new String(chunk.toArray, "UTF-8")
println(s"Got invalid chunk from server: $s \n" + errorDesc)
},

View File

@ -5,18 +5,16 @@ package sbt
package internal
package server
import java.net.{ SocketTimeoutException, Socket }
import java.net.{ Socket, SocketTimeoutException }
import java.util.concurrent.atomic.AtomicBoolean
import sbt.protocol._
abstract class ClientConnection(connection: Socket) {
import sbt.protocol.{ Serialization, CommandMessage, ExecCommand }
final class NetworkChannel(name: String, connection: Socket) extends CommandChannel {
private val running = new AtomicBoolean(true)
private val delimiter: Byte = '\n'.toByte
private val out = connection.getOutputStream
val thread = new Thread(s"sbt-clientconnection-${connection.getPort}") {
val thread = new Thread(s"sbt-networkchannel-${connection.getPort}") {
override def run(): Unit = {
try {
val readBuffer = new Array[Byte](4096)
@ -52,18 +50,25 @@ abstract class ClientConnection(connection: Socket) {
}
thread.start()
def publish(event: Array[Byte]): Unit = {
out.write(event)
out.write(delimiter.toInt)
out.flush()
}
def publishStatus(status: CommandStatus, lastSource: Option[CommandSource]): Unit =
{
()
}
def publishBytes(event: Array[Byte]): Unit =
{
out.write(event)
out.write(delimiter.toInt)
out.flush()
}
def onCommand(command: CommandMessage): Unit
def onCommand(command: CommandMessage): Unit =
command match {
case x: ExecCommand => append(Exec(CommandSource.Network, x.commandLine))
}
def shutdown(): Unit = {
println("Shutting down client connection")
running.set(false)
out.close()
}
}

View File

@ -5,23 +5,21 @@ package sbt
package internal
package server
import java.net.{ SocketTimeoutException, InetAddress, ServerSocket, SocketException }
import java.net.{ SocketTimeoutException, InetAddress, ServerSocket, Socket }
import java.util.concurrent.atomic.AtomicBoolean
import sbt.util.Logger
import sbt.protocol._
import scala.collection.mutable
private[sbt] sealed trait ServerInstance {
def shutdown(): Unit
def publish(event: EventMessage): Unit
}
private[sbt] object Server {
def start(host: String, port: Int, onIncommingCommand: CommandMessage => Unit, log: Logger): ServerInstance =
def start(host: String, port: Int, onIncomingSocket: Socket => Unit,
/*onIncommingCommand: CommandMessage => Unit,*/ log: Logger): ServerInstance =
new ServerInstance {
val lock = new AnyRef {}
val clients: mutable.ListBuffer[ClientConnection] = mutable.ListBuffer.empty
// val lock = new AnyRef {}
// val clients: mutable.ListBuffer[ClientConnection] = mutable.ListBuffer.empty
val running = new AtomicBoolean(true)
val serverThread = new Thread("sbt-socket-server") {
@ -34,18 +32,7 @@ private[sbt] object Server {
while (running.get()) {
try {
val socket = serverSocket.accept()
log.info(s"new client connected from: ${socket.getPort}")
val connection = new ClientConnection(socket) {
override def onCommand(command: CommandMessage): Unit = {
onIncommingCommand(command)
}
}
lock.synchronized {
clients += connection
}
onIncomingSocket(socket)
} catch {
case _: SocketTimeoutException => // its ok
}
@ -55,25 +42,6 @@ private[sbt] object Server {
}
serverThread.start()
/** Publish an event to all connected clients */
def publish(event: EventMessage): Unit = {
// TODO do not do this on the calling thread
val bytes = Serialization.serializeEvent(event)
lock.synchronized {
val toDel: mutable.ListBuffer[ClientConnection] = mutable.ListBuffer.empty
clients.foreach { client =>
try {
client.publish(bytes)
} catch {
case e: SocketException =>
log.debug(e.getMessage)
toDel += client
}
}
clients --= toDel.toList
}
}
override def shutdown(): Unit = {
log.info("shutting down server")
running.set(false)