diff --git a/main-command/src/main/scala/sbt/internal/CommandChannel.scala b/main-command/src/main/scala/sbt/internal/CommandChannel.scala index f01d9ffbd..3de22c90f 100644 --- a/main-command/src/main/scala/sbt/internal/CommandChannel.scala +++ b/main-command/src/main/scala/sbt/internal/CommandChannel.scala @@ -14,9 +14,8 @@ abstract class CommandChannel { commandQueue.add(exec) def poll: Option[Exec] = Option(commandQueue.poll) - /** start listening for a command exec. */ - def run(s: State): State def publishStatus(status: CommandStatus, lastSource: Option[CommandSource]): Unit + def publishBytes(bytes: Array[Byte]): Unit def shutdown(): Unit } diff --git a/main-command/src/main/scala/sbt/internal/CommandExchange.scala b/main-command/src/main/scala/sbt/internal/CommandExchange.scala index 6bce31b0f..63d67ce42 100644 --- a/main-command/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main-command/src/main/scala/sbt/internal/CommandExchange.scala @@ -1,9 +1,16 @@ package sbt package internal -import scala.annotation.tailrec -import scala.collection.mutable.ListBuffer +import java.net.SocketException import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicInteger +import sbt.internal.server._ +import sbt.protocol.Serialization +import scala.collection.mutable.ListBuffer +import scala.annotation.tailrec +import BasicKeys.serverPort +import sbt.protocol.StatusEvent +import java.net.Socket /** * The command exchange merges multiple command channels (e.g. network and console), @@ -12,14 +19,18 @@ import java.util.concurrent.ConcurrentLinkedQueue * this exchange, which could serve command request from either of the channel. */ private[sbt] final class CommandExchange { + private val lock = new AnyRef {} + private var server: Option[ServerInstance] = None private val commandQueue: ConcurrentLinkedQueue[Exec] = new ConcurrentLinkedQueue() private val channelBuffer: ListBuffer[CommandChannel] = new ListBuffer() + private val nextChannelId: AtomicInteger = new AtomicInteger(0) def channels: List[CommandChannel] = channelBuffer.toList def subscribe(c: CommandChannel): Unit = - channelBuffer.append(c) + lock.synchronized { + channelBuffer.append(c) + } subscribe(new ConsoleChannel()) - subscribe(new NetworkChannel()) // periodically move all messages from all the channels @tailrec def blockUntilNextExec: Exec = @@ -40,13 +51,69 @@ private[sbt] final class CommandExchange { } } - // fanout run to all channels - def run(s: State): State = - (s /: channels) { (acc, c) => c.run(acc) } + def run(s: State): State = runServer(s) + + private def newChannelName: String = s"channel-${nextChannelId.incrementAndGet()}" + + private def runServer(s: State): State = + { + val port = (s get serverPort) match { + case Some(x) => x + case None => 5001 + } + def onIncomingSocket(socket: Socket): Unit = + { + s.log.info(s"new client connected from: ${socket.getPort}") + val channel = new NetworkChannel(newChannelName, socket) + subscribe(channel) + } + server match { + case Some(x) => // do nothing + case _ => + server = Some(Server.start("127.0.0.1", port, onIncomingSocket, s.log)) + } + s + } + + def shutdown(): Unit = + { + channels foreach { c => + c.shutdown() + } + // interrupt and kill the thread + server.foreach(_.shutdown()) + server = None + } // fanout publishStatus to all channels def publishStatus(status: CommandStatus, lastSource: Option[CommandSource]): Unit = - channels foreach { c => - c.publishStatus(status, lastSource) + { + val toDel: ListBuffer[CommandChannel] = ListBuffer.empty + + val event = + if (status.canEnter) StatusEvent("Ready", Vector()) + else StatusEvent("Processing", status.state.remainingCommands.toVector) + + // TODO do not do this on the calling thread + val bytes = Serialization.serializeEvent(event) + channels.foreach { + case c: ConsoleChannel => + c.publishStatus(status, lastSource) + case c: NetworkChannel => + try { + c.publishBytes(bytes) + } catch { + case e: SocketException => + // log.debug(e.getMessage) + toDel += c + } + } + toDel.toList match { + case Nil => // do nothing + case xs => + lock.synchronized { + channelBuffer --= xs + } + } } } diff --git a/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala b/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala index 5b0c99497..825198d46 100644 --- a/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala +++ b/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala @@ -28,6 +28,8 @@ private[sbt] final class ConsoleChannel extends CommandChannel { def run(s: State): State = s + def publishBytes(bytes: Array[Byte]): Unit = () + def publishStatus(status: CommandStatus, lastSource: Option[CommandSource]): Unit = if (status.canEnter) { askUserThread match { diff --git a/main-command/src/main/scala/sbt/internal/NetworkChannel.scala b/main-command/src/main/scala/sbt/internal/NetworkChannel.scala deleted file mode 100644 index e2e475086..000000000 --- a/main-command/src/main/scala/sbt/internal/NetworkChannel.scala +++ /dev/null @@ -1,43 +0,0 @@ -package sbt -package internal - -import sbt.internal.server._ -import sbt.protocol._ -import BasicKeys._ - -private[sbt] final class NetworkChannel extends CommandChannel { - private var server: Option[ServerInstance] = None - - def run(s: State): State = - { - val port = (s get serverPort) match { - case Some(x) => x - case None => 5001 - } - def onCommand(command: CommandMessage): Unit = - command match { - case x: ExecCommand => append(Exec(CommandSource.Network, x.commandLine)) - } - server match { - case Some(x) => // do nothing - case _ => - server = Some(Server.start("127.0.0.1", port, onCommand, s.log)) - } - s - } - - def shutdown(): Unit = - { - // interrupt and kill the thread - server.foreach(_.shutdown()) - server = None - } - - def publishStatus(cmdStatus: CommandStatus, lastSource: Option[CommandSource]): Unit = { - server.foreach(server => - server.publish( - if (cmdStatus.canEnter) StatusEvent("Ready", Vector()) - else StatusEvent("Processing", cmdStatus.state.remainingCommands.toVector) - )) - } -} diff --git a/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala b/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala index 348761235..b7c5368ab 100644 --- a/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala +++ b/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala @@ -34,7 +34,8 @@ abstract class ServerConnection(connection: Socket) { val chunk = buffer.take(delimPos) buffer = buffer.drop(delimPos + 1) - Serialization.deserializeEvent(chunk).fold({ errorDesc => + Serialization.deserializeEvent(chunk).fold( + { errorDesc => val s = new String(chunk.toArray, "UTF-8") println(s"Got invalid chunk from server: $s \n" + errorDesc) }, diff --git a/main-command/src/main/scala/sbt/internal/server/ClientConnection.scala b/main-command/src/main/scala/sbt/internal/server/NetworkChannel.scala similarity index 68% rename from main-command/src/main/scala/sbt/internal/server/ClientConnection.scala rename to main-command/src/main/scala/sbt/internal/server/NetworkChannel.scala index 836f1dd21..4de5fe5ef 100644 --- a/main-command/src/main/scala/sbt/internal/server/ClientConnection.scala +++ b/main-command/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -5,18 +5,16 @@ package sbt package internal package server -import java.net.{ SocketTimeoutException, Socket } +import java.net.{ Socket, SocketTimeoutException } import java.util.concurrent.atomic.AtomicBoolean -import sbt.protocol._ - -abstract class ClientConnection(connection: Socket) { +import sbt.protocol.{ Serialization, CommandMessage, ExecCommand } +final class NetworkChannel(name: String, connection: Socket) extends CommandChannel { private val running = new AtomicBoolean(true) private val delimiter: Byte = '\n'.toByte - private val out = connection.getOutputStream - val thread = new Thread(s"sbt-clientconnection-${connection.getPort}") { + val thread = new Thread(s"sbt-networkchannel-${connection.getPort}") { override def run(): Unit = { try { val readBuffer = new Array[Byte](4096) @@ -52,18 +50,25 @@ abstract class ClientConnection(connection: Socket) { } thread.start() - def publish(event: Array[Byte]): Unit = { - out.write(event) - out.write(delimiter.toInt) - out.flush() - } + def publishStatus(status: CommandStatus, lastSource: Option[CommandSource]): Unit = + { + () + } + def publishBytes(event: Array[Byte]): Unit = + { + out.write(event) + out.write(delimiter.toInt) + out.flush() + } - def onCommand(command: CommandMessage): Unit + def onCommand(command: CommandMessage): Unit = + command match { + case x: ExecCommand => append(Exec(CommandSource.Network, x.commandLine)) + } def shutdown(): Unit = { println("Shutting down client connection") running.set(false) out.close() } - } diff --git a/main-command/src/main/scala/sbt/internal/server/Server.scala b/main-command/src/main/scala/sbt/internal/server/Server.scala index 078f07f5f..f00ffcf53 100644 --- a/main-command/src/main/scala/sbt/internal/server/Server.scala +++ b/main-command/src/main/scala/sbt/internal/server/Server.scala @@ -5,23 +5,21 @@ package sbt package internal package server -import java.net.{ SocketTimeoutException, InetAddress, ServerSocket, SocketException } +import java.net.{ SocketTimeoutException, InetAddress, ServerSocket, Socket } import java.util.concurrent.atomic.AtomicBoolean import sbt.util.Logger -import sbt.protocol._ -import scala.collection.mutable private[sbt] sealed trait ServerInstance { def shutdown(): Unit - def publish(event: EventMessage): Unit } private[sbt] object Server { - def start(host: String, port: Int, onIncommingCommand: CommandMessage => Unit, log: Logger): ServerInstance = + def start(host: String, port: Int, onIncomingSocket: Socket => Unit, + /*onIncommingCommand: CommandMessage => Unit,*/ log: Logger): ServerInstance = new ServerInstance { - val lock = new AnyRef {} - val clients: mutable.ListBuffer[ClientConnection] = mutable.ListBuffer.empty + // val lock = new AnyRef {} + // val clients: mutable.ListBuffer[ClientConnection] = mutable.ListBuffer.empty val running = new AtomicBoolean(true) val serverThread = new Thread("sbt-socket-server") { @@ -34,18 +32,7 @@ private[sbt] object Server { while (running.get()) { try { val socket = serverSocket.accept() - log.info(s"new client connected from: ${socket.getPort}") - - val connection = new ClientConnection(socket) { - override def onCommand(command: CommandMessage): Unit = { - onIncommingCommand(command) - } - } - - lock.synchronized { - clients += connection - } - + onIncomingSocket(socket) } catch { case _: SocketTimeoutException => // its ok } @@ -55,25 +42,6 @@ private[sbt] object Server { } serverThread.start() - /** Publish an event to all connected clients */ - def publish(event: EventMessage): Unit = { - // TODO do not do this on the calling thread - val bytes = Serialization.serializeEvent(event) - lock.synchronized { - val toDel: mutable.ListBuffer[ClientConnection] = mutable.ListBuffer.empty - clients.foreach { client => - try { - client.publish(bytes) - } catch { - case e: SocketException => - log.debug(e.getMessage) - toDel += client - } - } - clients --= toDel.toList - } - } - override def shutdown(): Unit = { log.info("shutting down server") running.set(false)