IPC Unix domain socket for sbt server

In addition to TCP, this adds sbt server support for IPC (interprocess communication) using Unix domain socket and Windows named pipe.

The use of Unix domain socket has performance and security benefits.
This commit is contained in:
Eugene Yokota 2017-11-27 21:37:31 -05:00
parent 0c803214aa
commit f785750fc4
14 changed files with 169 additions and 31 deletions

View File

@ -55,7 +55,7 @@ def commonSettings: Seq[Setting[_]] =
concurrentRestrictions in Global += Util.testExclusiveRestriction,
testOptions in Test += Tests.Argument(TestFrameworks.ScalaCheck, "-w", "1"),
testOptions in Test += Tests.Argument(TestFrameworks.ScalaCheck, "-verbosity", "2"),
javacOptions in compile ++= Seq("-target", "6", "-source", "6", "-Xlint", "-Xlint:-serial"),
javacOptions in compile ++= Seq("-Xlint", "-Xlint:-serial"),
crossScalaVersions := Seq(baseScalaVersion),
bintrayPackage := (bintrayPackage in ThisBuild).value,
bintrayRepository := (bintrayRepository in ThisBuild).value,
@ -309,7 +309,8 @@ lazy val commandProj = (project in file("main-command"))
.settings(
testedBaseSettings,
name := "Command",
libraryDependencies ++= Seq(launcherInterface, sjsonNewScalaJson.value, templateResolverApi),
libraryDependencies ++= Seq(launcherInterface, sjsonNewScalaJson.value, templateResolverApi,
jna, jnaPlatform),
managedSourceDirectories in Compile +=
baseDirectory.value / "src" / "main" / "contraband-scala",
sourceManaged in (Compile, generateContrabands) := baseDirectory.value / "src" / "main" / "contraband-scala",
@ -324,7 +325,11 @@ lazy val commandProj = (project in file("main-command"))
exclude[ReversedMissingMethodProblem]("sbt.internal.CommandChannel.*"),
// Added an overload to reboot. The overload is private[sbt].
exclude[ReversedMissingMethodProblem]("sbt.StateOps.reboot"),
)
),
unmanagedSources in (Compile, headerCreate) := {
val old = (unmanagedSources in (Compile, headerCreate)).value
old filterNot { x => (x.getName startsWith "NG") || (x.getName == "ReferenceCountedFileDescriptor.java") }
},
)
.configure(
addSbtIO,

View File

@ -0,0 +1,28 @@
/**
* This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]].
*/
// DO NOT EDIT MANUALLY
import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError }
trait ConnectionTypeFormats { self: sjsonnew.BasicJsonProtocol =>
implicit lazy val ConnectionTypeFormat: JsonFormat[sbt.ConnectionType] = new JsonFormat[sbt.ConnectionType] {
override def read[J](jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.ConnectionType = {
jsOpt match {
case Some(js) =>
unbuilder.readString(js) match {
case "Local" => sbt.ConnectionType.Local
case "Tcp" => sbt.ConnectionType.Tcp
}
case None =>
deserializationError("Expected JsString but found None")
}
}
override def write[J](obj: sbt.ConnectionType, builder: Builder[J]): Unit = {
val str = obj match {
case sbt.ConnectionType.Local => "Local"
case sbt.ConnectionType.Tcp => "Tcp"
}
builder.writeString(str)
}
}
}

View File

@ -0,0 +1,13 @@
/**
* This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]].
*/
// DO NOT EDIT MANUALLY
package sbt
sealed abstract class ConnectionType extends Serializable
object ConnectionType {
/** This uses Unix domain socket on POSIX, and named pipe on Windows. */
case object Local extends ConnectionType
case object Tcp extends ConnectionType
}

View File

@ -16,3 +16,10 @@ type CommandSource {
enum ServerAuthentication {
Token
}
enum ConnectionType {
## This uses Unix domain socket on POSIX, and named pipe on Windows.
Local
Tcp
# Ssh
}

View File

@ -33,6 +33,11 @@ object BasicKeys {
"Method of authenticating server command.",
10000)
val serverConnectionType =
AttributeKey[ConnectionType]("serverConnectionType",
"The wire protocol for the server command.",
10000)
private[sbt] val interactive = AttributeKey[Boolean](
"interactive",
"True if commands are currently being entered from an interactive environment.",

View File

@ -17,13 +17,14 @@ import java.security.SecureRandom
import java.math.BigInteger
import scala.concurrent.{ Future, Promise }
import scala.util.{ Try, Success, Failure }
import sbt.internal.util.ErrorHandling
import sbt.internal.protocol.{ PortFile, TokenFile }
import sbt.util.Logger
import sbt.io.IO
import sbt.io.syntax._
import sjsonnew.support.scalajson.unsafe.{ Converter, CompactPrinter }
import sbt.internal.protocol.codec._
import sbt.internal.util.ErrorHandling
import sbt.internal.util.Util.isWindows
private[sbt] sealed trait ServerInstance {
def shutdown(): Unit
@ -38,31 +39,37 @@ private[sbt] object Server {
with TokenFileFormats
object JsonProtocol extends JsonProtocol
def start(host: String,
port: Int,
def start(connection: ServerConnection,
onIncomingSocket: (Socket, ServerInstance) => Unit,
auth: Set[ServerAuthentication],
portfile: File,
tokenfile: File,
log: Logger): ServerInstance =
new ServerInstance { self =>
import connection._
val running = new AtomicBoolean(false)
val p: Promise[Unit] = Promise[Unit]()
val ready: Future[Unit] = p.future
private[this] val rand = new SecureRandom
private[this] var token: String = nextToken
private[this] var serverSocketOpt: Option[ServerSocket] = None
val serverThread = new Thread("sbt-socket-server") {
override def run(): Unit = {
Try {
ErrorHandling.translate(s"server failed to start on $host:$port. ") {
new ServerSocket(port, 50, InetAddress.getByName(host))
ErrorHandling.translate(s"server failed to start on ${connection.shortName}. ") {
connection.connectionType match {
case ConnectionType.Local if isWindows =>
new NGWin32NamedPipeServerSocket(pipeName)
case ConnectionType.Local =>
prepareSocketfile()
new NGUnixDomainServerSocket(socketfile.getAbsolutePath)
case ConnectionType.Tcp => new ServerSocket(port, 50, InetAddress.getByName(host))
}
}
} match {
case Failure(e) => p.failure(e)
case Success(serverSocket) =>
serverSocket.setSoTimeout(5000)
log.info(s"sbt server started at $host:$port")
serverSocketOpt = Option(serverSocket)
log.info(s"sbt server started at ${connection.shortName}")
writePortfile()
running.set(true)
p.success(())
@ -74,6 +81,7 @@ private[sbt] object Server {
case _: SocketTimeoutException => // its ok
}
}
serverSocket.close()
}
}
}
@ -106,7 +114,7 @@ private[sbt] object Server {
private[this] def writeTokenfile(): Unit = {
import JsonProtocol._
val uri = s"tcp://$host:$port"
val uri = connection.shortName
val t = TokenFile(uri, token)
val jsonToken = Converter.toJson(t).get
@ -141,7 +149,7 @@ private[sbt] object Server {
private[this] def writePortfile(): Unit = {
import JsonProtocol._
val uri = s"tcp://$host:$port"
val uri = connection.shortName
val p =
auth match {
case _ if auth(ServerAuthentication.Token) =>
@ -153,5 +161,32 @@ private[sbt] object Server {
val json = Converter.toJson(p).get
IO.write(portfile, CompactPrinter(json))
}
private[sbt] def prepareSocketfile(): Unit = {
if (socketfile.exists) {
IO.delete(socketfile)
}
IO.createDirectory(socketfile.getParentFile)
}
}
}
private[sbt] case class ServerConnection(
connectionType: ConnectionType,
host: String,
port: Int,
auth: Set[ServerAuthentication],
portfile: File,
tokenfile: File,
socketfile: File,
pipeName: String
) {
def shortName: String = {
connectionType match {
case ConnectionType.Local if isWindows => s"local:$pipeName"
case ConnectionType.Local => s"local://$socketfile"
case ConnectionType.Tcp => s"tcp://$host:$port"
// case ConnectionType.Ssh => s"ssh://$host:$port"
}
}
}

View File

@ -272,7 +272,11 @@ object Defaults extends BuildCommon {
serverPort := 5000 + (Hash
.toHex(Hash(appConfiguration.value.baseDirectory.toString))
.## % 1000),
serverAuthentication := Set(ServerAuthentication.Token),
serverConnectionType := ConnectionType.Local,
serverAuthentication := {
if (serverConnectionType.value == ConnectionType.Tcp) Set(ServerAuthentication.Token)
else Set()
},
insideCI :== sys.env.contains("BUILD_NUMBER") || sys.env.contains("CI"),
))

View File

@ -133,6 +133,8 @@ object Keys {
val serverPort = SettingKey(BasicKeys.serverPort)
val serverHost = SettingKey(BasicKeys.serverHost)
val serverAuthentication = SettingKey(BasicKeys.serverAuthentication)
val serverConnectionType = SettingKey(BasicKeys.serverConnectionType)
val analysis = AttributeKey[CompileAnalysis]("analysis", "Analysis of compilation, including dependencies and generated outputs.", DSetting)
val watch = SettingKey(BasicKeys.watch)
val suppressSbtShellNotification = settingKey[Boolean]("""True to suppress the "Executing in batch mode.." message.""").withRank(CSetting)

View File

@ -23,6 +23,7 @@ import Keys.{
serverHost,
serverPort,
serverAuthentication,
serverConnectionType,
watch
}
import Scope.{ Global, ThisScope }
@ -461,6 +462,7 @@ object Project extends ProjectExtra {
val host: Option[String] = get(serverHost)
val port: Option[Int] = get(serverPort)
val authentication: Option[Set[ServerAuthentication]] = get(serverAuthentication)
val connectionType: Option[ConnectionType] = get(serverConnectionType)
val commandDefs = allCommands.distinct.flatten[Command].map(_ tag (projectCommand, true))
val newDefinedCommands = commandDefs ++ BasicCommands.removeTagged(s.definedCommands,
projectCommand)
@ -471,6 +473,7 @@ object Project extends ProjectExtra {
.setCond(serverPort.key, port)
.setCond(serverHost.key, host)
.setCond(serverAuthentication.key, authentication)
.setCond(serverConnectionType.key, connectionType)
.put(historyPath.key, history)
.put(templateResolverInfos.key, trs)
.setCond(shellPrompt.key, prompt)

View File

@ -13,7 +13,7 @@ import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.ListBuffer
import scala.annotation.tailrec
import BasicKeys.{ serverHost, serverPort, serverAuthentication }
import BasicKeys.{ serverHost, serverPort, serverAuthentication, serverConnectionType }
import java.net.Socket
import sjsonnew.JsonFormat
import sjsonnew.shaded.scalajson.ast.unsafe._
@ -83,6 +83,7 @@ private[sbt] final class CommandExchange {
}
private def newChannelName: String = s"channel-${nextChannelId.incrementAndGet()}"
private def newNetworkName: String = s"network-${nextChannelId.incrementAndGet()}"
/**
* Check if a server instance is running already, and start one if it isn't.
@ -100,19 +101,23 @@ private[sbt] final class CommandExchange {
case Some(xs) => xs
case None => Set(ServerAuthentication.Token)
}
lazy val connectionType = (s get serverConnectionType) match {
case Some(x) => x
case None => ConnectionType.Tcp
}
val serverLogLevel: Level.Value = Level.Debug
def onIncomingSocket(socket: Socket, instance: ServerInstance): Unit = {
s.log.info(s"new client connected from: ${socket.getPort}")
val name = newNetworkName
s.log.info(s"new client connected: $name")
val logger: Logger = {
val loggerName = s"network-${socket.getPort}"
val log = LogExchange.logger(loggerName, None, None)
LogExchange.unbindLoggerAppenders(loggerName)
val log = LogExchange.logger(name, None, None)
LogExchange.unbindLoggerAppenders(name)
val appender = MainAppender.defaultScreen(s.globalLogging.console)
LogExchange.bindLoggerAppenders(loggerName, List(appender -> serverLogLevel))
LogExchange.bindLoggerAppenders(name, List(appender -> serverLogLevel))
log
}
val channel =
new NetworkChannel(newChannelName, socket, Project structure s, auth, instance, logger)
new NetworkChannel(name, socket, Project structure s, auth, instance, logger)
subscribe(channel)
}
server match {
@ -121,7 +126,18 @@ private[sbt] final class CommandExchange {
val portfile = (new File(".")).getAbsoluteFile / "project" / "target" / "active.json"
val h = Hash.halfHashString(portfile.toURI.toString)
val tokenfile = BuildPaths.getGlobalBase(s) / "server" / h / "token.json"
val x = Server.start(host, port, onIncomingSocket, auth, portfile, tokenfile, s.log)
val socketfile = BuildPaths.getGlobalBase(s) / "server" / h / "sock"
val pipeName = "sbt-server-" + h
val connection =
ServerConnection(connectionType,
host,
port,
auth,
portfile,
tokenfile,
socketfile,
pipeName)
val x = Server.start(connection, onIncomingSocket, s.log)
Await.ready(x.ready, Duration("10s"))
x.ready.value match {
case Some(Success(_)) =>

View File

@ -106,6 +106,8 @@ object Dependencies {
val specs2 = "org.specs2" %% "specs2" % "2.4.17"
val junit = "junit" % "junit" % "4.11"
val templateResolverApi = "org.scala-sbt" % "template-resolver" % "0.1"
val jna = "net.java.dev.jna" % "jna" % "4.1.0"
val jnaPlatform = "net.java.dev.jna" % "jna-platform" % "4.1.0"
private def scala211Module(name: String, moduleVersion: String) = Def setting (
scalaBinaryVersion.value match {

View File

@ -2,6 +2,7 @@ lazy val runClient = taskKey[Unit]("")
lazy val root = (project in file("."))
.settings(
serverConnectionType in Global := ConnectionType.Tcp,
scalaVersion := "2.12.3",
serverPort in Global := 5123,
libraryDependencies += "org.scala-sbt" %% "io" % "1.0.1",

View File

@ -23,19 +23,25 @@ export function activate(context: ExtensionContext) {
let clientOptions: LanguageClientOptions = {
documentSelector: [{ language: 'scala', scheme: 'file' }, { language: 'java', scheme: 'file' }],
initializationOptions: () => {
return {
token: discoverToken()
};
return discoverToken();
}
}
// the port file is hardcoded to a particular location relative to the build.
function discoverToken(): String {
function discoverToken(): any {
let pf = path.join(workspace.rootPath, 'project', 'target', 'active.json');
let portfile = JSON.parse(fs.readFileSync(pf));
let tf = portfile.tokenfilePath;
let tokenfile = JSON.parse(fs.readFileSync(tf));
return tokenfile.token;
// if tokenfilepath exists, return the token.
if (portfile.hasOwnProperty('tokenfilePath')) {
let tf = portfile.tokenfilePath;
let tokenfile = JSON.parse(fs.readFileSync(tf));
return {
token: tokenfile.token
};
} else {
return {};
}
}
// Create the language client and start the client.

View File

@ -4,6 +4,7 @@ import * as path from 'path';
import * as url from 'url';
let net = require('net'),
fs = require('fs'),
os = require('os'),
stdin = process.stdin,
stdout = process.stdout;
@ -16,7 +17,17 @@ socket.on('data', (chunk: any) => {
}).on('end', () => {
stdin.pause();
});
socket.connect(u.port, '127.0.0.1');
if (u.protocol == 'tcp:') {
socket.connect(u.port, '127.0.0.1');
} else if (u.protocol == 'local:' && os.platform() == 'win32') {
let pipePath = '\\\\.\\pipe\\' + u.hostname;
socket.connect(pipePath);
} else if (u.protocol == 'local:') {
socket.connect(u.path);
} else {
throw 'Unknown protocol ' + u.protocol;
}
stdin.resume();
stdin.on('data', (chunk: any) => {