mirror of https://github.com/sbt/sbt.git
Refactor to make NetworkChannel per client connection
This commit is contained in:
parent
d618f91c6d
commit
46d8f952e4
|
|
@ -14,9 +14,8 @@ abstract class CommandChannel {
|
|||
commandQueue.add(exec)
|
||||
def poll: Option[Exec] = Option(commandQueue.poll)
|
||||
|
||||
/** start listening for a command exec. */
|
||||
def run(s: State): State
|
||||
def publishStatus(status: CommandStatus, lastSource: Option[CommandSource]): Unit
|
||||
def publishBytes(bytes: Array[Byte]): Unit
|
||||
def shutdown(): Unit
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,16 @@
|
|||
package sbt
|
||||
package internal
|
||||
|
||||
import scala.annotation.tailrec
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import java.net.SocketException
|
||||
import java.util.concurrent.ConcurrentLinkedQueue
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import sbt.internal.server._
|
||||
import sbt.protocol.Serialization
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.annotation.tailrec
|
||||
import BasicKeys.serverPort
|
||||
import sbt.protocol.StatusEvent
|
||||
import java.net.Socket
|
||||
|
||||
/**
|
||||
* The command exchange merges multiple command channels (e.g. network and console),
|
||||
|
|
@ -12,14 +19,18 @@ import java.util.concurrent.ConcurrentLinkedQueue
|
|||
* this exchange, which could serve command request from either of the channel.
|
||||
*/
|
||||
private[sbt] final class CommandExchange {
|
||||
private val lock = new AnyRef {}
|
||||
private var server: Option[ServerInstance] = None
|
||||
private val commandQueue: ConcurrentLinkedQueue[Exec] = new ConcurrentLinkedQueue()
|
||||
private val channelBuffer: ListBuffer[CommandChannel] = new ListBuffer()
|
||||
private val nextChannelId: AtomicInteger = new AtomicInteger(0)
|
||||
def channels: List[CommandChannel] = channelBuffer.toList
|
||||
def subscribe(c: CommandChannel): Unit =
|
||||
channelBuffer.append(c)
|
||||
lock.synchronized {
|
||||
channelBuffer.append(c)
|
||||
}
|
||||
|
||||
subscribe(new ConsoleChannel())
|
||||
subscribe(new NetworkChannel())
|
||||
|
||||
// periodically move all messages from all the channels
|
||||
@tailrec def blockUntilNextExec: Exec =
|
||||
|
|
@ -40,13 +51,69 @@ private[sbt] final class CommandExchange {
|
|||
}
|
||||
}
|
||||
|
||||
// fanout run to all channels
|
||||
def run(s: State): State =
|
||||
(s /: channels) { (acc, c) => c.run(acc) }
|
||||
def run(s: State): State = runServer(s)
|
||||
|
||||
private def newChannelName: String = s"channel-${nextChannelId.incrementAndGet()}"
|
||||
|
||||
private def runServer(s: State): State =
|
||||
{
|
||||
val port = (s get serverPort) match {
|
||||
case Some(x) => x
|
||||
case None => 5001
|
||||
}
|
||||
def onIncomingSocket(socket: Socket): Unit =
|
||||
{
|
||||
s.log.info(s"new client connected from: ${socket.getPort}")
|
||||
val channel = new NetworkChannel(newChannelName, socket)
|
||||
subscribe(channel)
|
||||
}
|
||||
server match {
|
||||
case Some(x) => // do nothing
|
||||
case _ =>
|
||||
server = Some(Server.start("127.0.0.1", port, onIncomingSocket, s.log))
|
||||
}
|
||||
s
|
||||
}
|
||||
|
||||
def shutdown(): Unit =
|
||||
{
|
||||
channels foreach { c =>
|
||||
c.shutdown()
|
||||
}
|
||||
// interrupt and kill the thread
|
||||
server.foreach(_.shutdown())
|
||||
server = None
|
||||
}
|
||||
|
||||
// fanout publishStatus to all channels
|
||||
def publishStatus(status: CommandStatus, lastSource: Option[CommandSource]): Unit =
|
||||
channels foreach { c =>
|
||||
c.publishStatus(status, lastSource)
|
||||
{
|
||||
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
|
||||
|
||||
val event =
|
||||
if (status.canEnter) StatusEvent("Ready", Vector())
|
||||
else StatusEvent("Processing", status.state.remainingCommands.toVector)
|
||||
|
||||
// TODO do not do this on the calling thread
|
||||
val bytes = Serialization.serializeEvent(event)
|
||||
channels.foreach {
|
||||
case c: ConsoleChannel =>
|
||||
c.publishStatus(status, lastSource)
|
||||
case c: NetworkChannel =>
|
||||
try {
|
||||
c.publishBytes(bytes)
|
||||
} catch {
|
||||
case e: SocketException =>
|
||||
// log.debug(e.getMessage)
|
||||
toDel += c
|
||||
}
|
||||
}
|
||||
toDel.toList match {
|
||||
case Nil => // do nothing
|
||||
case xs =>
|
||||
lock.synchronized {
|
||||
channelBuffer --= xs
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,6 +28,8 @@ private[sbt] final class ConsoleChannel extends CommandChannel {
|
|||
|
||||
def run(s: State): State = s
|
||||
|
||||
def publishBytes(bytes: Array[Byte]): Unit = ()
|
||||
|
||||
def publishStatus(status: CommandStatus, lastSource: Option[CommandSource]): Unit =
|
||||
if (status.canEnter) {
|
||||
askUserThread match {
|
||||
|
|
|
|||
|
|
@ -1,43 +0,0 @@
|
|||
package sbt
|
||||
package internal
|
||||
|
||||
import sbt.internal.server._
|
||||
import sbt.protocol._
|
||||
import BasicKeys._
|
||||
|
||||
private[sbt] final class NetworkChannel extends CommandChannel {
|
||||
private var server: Option[ServerInstance] = None
|
||||
|
||||
def run(s: State): State =
|
||||
{
|
||||
val port = (s get serverPort) match {
|
||||
case Some(x) => x
|
||||
case None => 5001
|
||||
}
|
||||
def onCommand(command: CommandMessage): Unit =
|
||||
command match {
|
||||
case x: ExecCommand => append(Exec(CommandSource.Network, x.commandLine))
|
||||
}
|
||||
server match {
|
||||
case Some(x) => // do nothing
|
||||
case _ =>
|
||||
server = Some(Server.start("127.0.0.1", port, onCommand, s.log))
|
||||
}
|
||||
s
|
||||
}
|
||||
|
||||
def shutdown(): Unit =
|
||||
{
|
||||
// interrupt and kill the thread
|
||||
server.foreach(_.shutdown())
|
||||
server = None
|
||||
}
|
||||
|
||||
def publishStatus(cmdStatus: CommandStatus, lastSource: Option[CommandSource]): Unit = {
|
||||
server.foreach(server =>
|
||||
server.publish(
|
||||
if (cmdStatus.canEnter) StatusEvent("Ready", Vector())
|
||||
else StatusEvent("Processing", cmdStatus.state.remainingCommands.toVector)
|
||||
))
|
||||
}
|
||||
}
|
||||
|
|
@ -34,7 +34,8 @@ abstract class ServerConnection(connection: Socket) {
|
|||
val chunk = buffer.take(delimPos)
|
||||
buffer = buffer.drop(delimPos + 1)
|
||||
|
||||
Serialization.deserializeEvent(chunk).fold({ errorDesc =>
|
||||
Serialization.deserializeEvent(chunk).fold(
|
||||
{ errorDesc =>
|
||||
val s = new String(chunk.toArray, "UTF-8")
|
||||
println(s"Got invalid chunk from server: $s \n" + errorDesc)
|
||||
},
|
||||
|
|
|
|||
|
|
@ -5,18 +5,16 @@ package sbt
|
|||
package internal
|
||||
package server
|
||||
|
||||
import java.net.{ SocketTimeoutException, Socket }
|
||||
import java.net.{ Socket, SocketTimeoutException }
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
import sbt.protocol._
|
||||
|
||||
abstract class ClientConnection(connection: Socket) {
|
||||
import sbt.protocol.{ Serialization, CommandMessage, ExecCommand }
|
||||
|
||||
final class NetworkChannel(name: String, connection: Socket) extends CommandChannel {
|
||||
private val running = new AtomicBoolean(true)
|
||||
private val delimiter: Byte = '\n'.toByte
|
||||
|
||||
private val out = connection.getOutputStream
|
||||
|
||||
val thread = new Thread(s"sbt-clientconnection-${connection.getPort}") {
|
||||
val thread = new Thread(s"sbt-networkchannel-${connection.getPort}") {
|
||||
override def run(): Unit = {
|
||||
try {
|
||||
val readBuffer = new Array[Byte](4096)
|
||||
|
|
@ -52,18 +50,25 @@ abstract class ClientConnection(connection: Socket) {
|
|||
}
|
||||
thread.start()
|
||||
|
||||
def publish(event: Array[Byte]): Unit = {
|
||||
out.write(event)
|
||||
out.write(delimiter.toInt)
|
||||
out.flush()
|
||||
}
|
||||
def publishStatus(status: CommandStatus, lastSource: Option[CommandSource]): Unit =
|
||||
{
|
||||
()
|
||||
}
|
||||
def publishBytes(event: Array[Byte]): Unit =
|
||||
{
|
||||
out.write(event)
|
||||
out.write(delimiter.toInt)
|
||||
out.flush()
|
||||
}
|
||||
|
||||
def onCommand(command: CommandMessage): Unit
|
||||
def onCommand(command: CommandMessage): Unit =
|
||||
command match {
|
||||
case x: ExecCommand => append(Exec(CommandSource.Network, x.commandLine))
|
||||
}
|
||||
|
||||
def shutdown(): Unit = {
|
||||
println("Shutting down client connection")
|
||||
running.set(false)
|
||||
out.close()
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -5,23 +5,21 @@ package sbt
|
|||
package internal
|
||||
package server
|
||||
|
||||
import java.net.{ SocketTimeoutException, InetAddress, ServerSocket, SocketException }
|
||||
import java.net.{ SocketTimeoutException, InetAddress, ServerSocket, Socket }
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
import sbt.util.Logger
|
||||
import sbt.protocol._
|
||||
import scala.collection.mutable
|
||||
|
||||
private[sbt] sealed trait ServerInstance {
|
||||
def shutdown(): Unit
|
||||
def publish(event: EventMessage): Unit
|
||||
}
|
||||
|
||||
private[sbt] object Server {
|
||||
def start(host: String, port: Int, onIncommingCommand: CommandMessage => Unit, log: Logger): ServerInstance =
|
||||
def start(host: String, port: Int, onIncomingSocket: Socket => Unit,
|
||||
/*onIncommingCommand: CommandMessage => Unit,*/ log: Logger): ServerInstance =
|
||||
new ServerInstance {
|
||||
|
||||
val lock = new AnyRef {}
|
||||
val clients: mutable.ListBuffer[ClientConnection] = mutable.ListBuffer.empty
|
||||
// val lock = new AnyRef {}
|
||||
// val clients: mutable.ListBuffer[ClientConnection] = mutable.ListBuffer.empty
|
||||
val running = new AtomicBoolean(true)
|
||||
|
||||
val serverThread = new Thread("sbt-socket-server") {
|
||||
|
|
@ -34,18 +32,7 @@ private[sbt] object Server {
|
|||
while (running.get()) {
|
||||
try {
|
||||
val socket = serverSocket.accept()
|
||||
log.info(s"new client connected from: ${socket.getPort}")
|
||||
|
||||
val connection = new ClientConnection(socket) {
|
||||
override def onCommand(command: CommandMessage): Unit = {
|
||||
onIncommingCommand(command)
|
||||
}
|
||||
}
|
||||
|
||||
lock.synchronized {
|
||||
clients += connection
|
||||
}
|
||||
|
||||
onIncomingSocket(socket)
|
||||
} catch {
|
||||
case _: SocketTimeoutException => // its ok
|
||||
}
|
||||
|
|
@ -55,25 +42,6 @@ private[sbt] object Server {
|
|||
}
|
||||
serverThread.start()
|
||||
|
||||
/** Publish an event to all connected clients */
|
||||
def publish(event: EventMessage): Unit = {
|
||||
// TODO do not do this on the calling thread
|
||||
val bytes = Serialization.serializeEvent(event)
|
||||
lock.synchronized {
|
||||
val toDel: mutable.ListBuffer[ClientConnection] = mutable.ListBuffer.empty
|
||||
clients.foreach { client =>
|
||||
try {
|
||||
client.publish(bytes)
|
||||
} catch {
|
||||
case e: SocketException =>
|
||||
log.debug(e.getMessage)
|
||||
toDel += client
|
||||
}
|
||||
}
|
||||
clients --= toDel.toList
|
||||
}
|
||||
}
|
||||
|
||||
override def shutdown(): Unit = {
|
||||
log.info("shutting down server")
|
||||
running.set(false)
|
||||
|
|
|
|||
Loading…
Reference in New Issue