mirror of https://github.com/sbt/sbt.git
IPC Unix domain socket for sbt server
In addition to TCP, this adds sbt server support for IPC (interprocess communication) using Unix domain socket and Windows named pipe. The use of Unix domain socket has performance and security benefits.
This commit is contained in:
parent
0c803214aa
commit
f785750fc4
11
build.sbt
11
build.sbt
|
|
@ -55,7 +55,7 @@ def commonSettings: Seq[Setting[_]] =
|
|||
concurrentRestrictions in Global += Util.testExclusiveRestriction,
|
||||
testOptions in Test += Tests.Argument(TestFrameworks.ScalaCheck, "-w", "1"),
|
||||
testOptions in Test += Tests.Argument(TestFrameworks.ScalaCheck, "-verbosity", "2"),
|
||||
javacOptions in compile ++= Seq("-target", "6", "-source", "6", "-Xlint", "-Xlint:-serial"),
|
||||
javacOptions in compile ++= Seq("-Xlint", "-Xlint:-serial"),
|
||||
crossScalaVersions := Seq(baseScalaVersion),
|
||||
bintrayPackage := (bintrayPackage in ThisBuild).value,
|
||||
bintrayRepository := (bintrayRepository in ThisBuild).value,
|
||||
|
|
@ -309,7 +309,8 @@ lazy val commandProj = (project in file("main-command"))
|
|||
.settings(
|
||||
testedBaseSettings,
|
||||
name := "Command",
|
||||
libraryDependencies ++= Seq(launcherInterface, sjsonNewScalaJson.value, templateResolverApi),
|
||||
libraryDependencies ++= Seq(launcherInterface, sjsonNewScalaJson.value, templateResolverApi,
|
||||
jna, jnaPlatform),
|
||||
managedSourceDirectories in Compile +=
|
||||
baseDirectory.value / "src" / "main" / "contraband-scala",
|
||||
sourceManaged in (Compile, generateContrabands) := baseDirectory.value / "src" / "main" / "contraband-scala",
|
||||
|
|
@ -324,7 +325,11 @@ lazy val commandProj = (project in file("main-command"))
|
|||
exclude[ReversedMissingMethodProblem]("sbt.internal.CommandChannel.*"),
|
||||
// Added an overload to reboot. The overload is private[sbt].
|
||||
exclude[ReversedMissingMethodProblem]("sbt.StateOps.reboot"),
|
||||
)
|
||||
),
|
||||
unmanagedSources in (Compile, headerCreate) := {
|
||||
val old = (unmanagedSources in (Compile, headerCreate)).value
|
||||
old filterNot { x => (x.getName startsWith "NG") || (x.getName == "ReferenceCountedFileDescriptor.java") }
|
||||
},
|
||||
)
|
||||
.configure(
|
||||
addSbtIO,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]].
|
||||
*/
|
||||
|
||||
// DO NOT EDIT MANUALLY
|
||||
import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError }
|
||||
trait ConnectionTypeFormats { self: sjsonnew.BasicJsonProtocol =>
|
||||
implicit lazy val ConnectionTypeFormat: JsonFormat[sbt.ConnectionType] = new JsonFormat[sbt.ConnectionType] {
|
||||
override def read[J](jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.ConnectionType = {
|
||||
jsOpt match {
|
||||
case Some(js) =>
|
||||
unbuilder.readString(js) match {
|
||||
case "Local" => sbt.ConnectionType.Local
|
||||
case "Tcp" => sbt.ConnectionType.Tcp
|
||||
}
|
||||
case None =>
|
||||
deserializationError("Expected JsString but found None")
|
||||
}
|
||||
}
|
||||
override def write[J](obj: sbt.ConnectionType, builder: Builder[J]): Unit = {
|
||||
val str = obj match {
|
||||
case sbt.ConnectionType.Local => "Local"
|
||||
case sbt.ConnectionType.Tcp => "Tcp"
|
||||
}
|
||||
builder.writeString(str)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
/**
|
||||
* This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]].
|
||||
*/
|
||||
|
||||
// DO NOT EDIT MANUALLY
|
||||
package sbt
|
||||
sealed abstract class ConnectionType extends Serializable
|
||||
object ConnectionType {
|
||||
|
||||
/** This uses Unix domain socket on POSIX, and named pipe on Windows. */
|
||||
case object Local extends ConnectionType
|
||||
case object Tcp extends ConnectionType
|
||||
}
|
||||
|
|
@ -16,3 +16,10 @@ type CommandSource {
|
|||
enum ServerAuthentication {
|
||||
Token
|
||||
}
|
||||
|
||||
enum ConnectionType {
|
||||
## This uses Unix domain socket on POSIX, and named pipe on Windows.
|
||||
Local
|
||||
Tcp
|
||||
# Ssh
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,6 +33,11 @@ object BasicKeys {
|
|||
"Method of authenticating server command.",
|
||||
10000)
|
||||
|
||||
val serverConnectionType =
|
||||
AttributeKey[ConnectionType]("serverConnectionType",
|
||||
"The wire protocol for the server command.",
|
||||
10000)
|
||||
|
||||
private[sbt] val interactive = AttributeKey[Boolean](
|
||||
"interactive",
|
||||
"True if commands are currently being entered from an interactive environment.",
|
||||
|
|
|
|||
|
|
@ -17,13 +17,14 @@ import java.security.SecureRandom
|
|||
import java.math.BigInteger
|
||||
import scala.concurrent.{ Future, Promise }
|
||||
import scala.util.{ Try, Success, Failure }
|
||||
import sbt.internal.util.ErrorHandling
|
||||
import sbt.internal.protocol.{ PortFile, TokenFile }
|
||||
import sbt.util.Logger
|
||||
import sbt.io.IO
|
||||
import sbt.io.syntax._
|
||||
import sjsonnew.support.scalajson.unsafe.{ Converter, CompactPrinter }
|
||||
import sbt.internal.protocol.codec._
|
||||
import sbt.internal.util.ErrorHandling
|
||||
import sbt.internal.util.Util.isWindows
|
||||
|
||||
private[sbt] sealed trait ServerInstance {
|
||||
def shutdown(): Unit
|
||||
|
|
@ -38,31 +39,37 @@ private[sbt] object Server {
|
|||
with TokenFileFormats
|
||||
object JsonProtocol extends JsonProtocol
|
||||
|
||||
def start(host: String,
|
||||
port: Int,
|
||||
def start(connection: ServerConnection,
|
||||
onIncomingSocket: (Socket, ServerInstance) => Unit,
|
||||
auth: Set[ServerAuthentication],
|
||||
portfile: File,
|
||||
tokenfile: File,
|
||||
log: Logger): ServerInstance =
|
||||
new ServerInstance { self =>
|
||||
import connection._
|
||||
val running = new AtomicBoolean(false)
|
||||
val p: Promise[Unit] = Promise[Unit]()
|
||||
val ready: Future[Unit] = p.future
|
||||
private[this] val rand = new SecureRandom
|
||||
private[this] var token: String = nextToken
|
||||
private[this] var serverSocketOpt: Option[ServerSocket] = None
|
||||
|
||||
val serverThread = new Thread("sbt-socket-server") {
|
||||
override def run(): Unit = {
|
||||
Try {
|
||||
ErrorHandling.translate(s"server failed to start on $host:$port. ") {
|
||||
new ServerSocket(port, 50, InetAddress.getByName(host))
|
||||
ErrorHandling.translate(s"server failed to start on ${connection.shortName}. ") {
|
||||
connection.connectionType match {
|
||||
case ConnectionType.Local if isWindows =>
|
||||
new NGWin32NamedPipeServerSocket(pipeName)
|
||||
case ConnectionType.Local =>
|
||||
prepareSocketfile()
|
||||
new NGUnixDomainServerSocket(socketfile.getAbsolutePath)
|
||||
case ConnectionType.Tcp => new ServerSocket(port, 50, InetAddress.getByName(host))
|
||||
}
|
||||
}
|
||||
} match {
|
||||
case Failure(e) => p.failure(e)
|
||||
case Success(serverSocket) =>
|
||||
serverSocket.setSoTimeout(5000)
|
||||
log.info(s"sbt server started at $host:$port")
|
||||
serverSocketOpt = Option(serverSocket)
|
||||
log.info(s"sbt server started at ${connection.shortName}")
|
||||
writePortfile()
|
||||
running.set(true)
|
||||
p.success(())
|
||||
|
|
@ -74,6 +81,7 @@ private[sbt] object Server {
|
|||
case _: SocketTimeoutException => // its ok
|
||||
}
|
||||
}
|
||||
serverSocket.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -106,7 +114,7 @@ private[sbt] object Server {
|
|||
private[this] def writeTokenfile(): Unit = {
|
||||
import JsonProtocol._
|
||||
|
||||
val uri = s"tcp://$host:$port"
|
||||
val uri = connection.shortName
|
||||
val t = TokenFile(uri, token)
|
||||
val jsonToken = Converter.toJson(t).get
|
||||
|
||||
|
|
@ -141,7 +149,7 @@ private[sbt] object Server {
|
|||
private[this] def writePortfile(): Unit = {
|
||||
import JsonProtocol._
|
||||
|
||||
val uri = s"tcp://$host:$port"
|
||||
val uri = connection.shortName
|
||||
val p =
|
||||
auth match {
|
||||
case _ if auth(ServerAuthentication.Token) =>
|
||||
|
|
@ -153,5 +161,32 @@ private[sbt] object Server {
|
|||
val json = Converter.toJson(p).get
|
||||
IO.write(portfile, CompactPrinter(json))
|
||||
}
|
||||
|
||||
private[sbt] def prepareSocketfile(): Unit = {
|
||||
if (socketfile.exists) {
|
||||
IO.delete(socketfile)
|
||||
}
|
||||
IO.createDirectory(socketfile.getParentFile)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private[sbt] case class ServerConnection(
|
||||
connectionType: ConnectionType,
|
||||
host: String,
|
||||
port: Int,
|
||||
auth: Set[ServerAuthentication],
|
||||
portfile: File,
|
||||
tokenfile: File,
|
||||
socketfile: File,
|
||||
pipeName: String
|
||||
) {
|
||||
def shortName: String = {
|
||||
connectionType match {
|
||||
case ConnectionType.Local if isWindows => s"local:$pipeName"
|
||||
case ConnectionType.Local => s"local://$socketfile"
|
||||
case ConnectionType.Tcp => s"tcp://$host:$port"
|
||||
// case ConnectionType.Ssh => s"ssh://$host:$port"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -272,7 +272,11 @@ object Defaults extends BuildCommon {
|
|||
serverPort := 5000 + (Hash
|
||||
.toHex(Hash(appConfiguration.value.baseDirectory.toString))
|
||||
.## % 1000),
|
||||
serverAuthentication := Set(ServerAuthentication.Token),
|
||||
serverConnectionType := ConnectionType.Local,
|
||||
serverAuthentication := {
|
||||
if (serverConnectionType.value == ConnectionType.Tcp) Set(ServerAuthentication.Token)
|
||||
else Set()
|
||||
},
|
||||
insideCI :== sys.env.contains("BUILD_NUMBER") || sys.env.contains("CI"),
|
||||
))
|
||||
|
||||
|
|
|
|||
|
|
@ -133,6 +133,8 @@ object Keys {
|
|||
val serverPort = SettingKey(BasicKeys.serverPort)
|
||||
val serverHost = SettingKey(BasicKeys.serverHost)
|
||||
val serverAuthentication = SettingKey(BasicKeys.serverAuthentication)
|
||||
val serverConnectionType = SettingKey(BasicKeys.serverConnectionType)
|
||||
|
||||
val analysis = AttributeKey[CompileAnalysis]("analysis", "Analysis of compilation, including dependencies and generated outputs.", DSetting)
|
||||
val watch = SettingKey(BasicKeys.watch)
|
||||
val suppressSbtShellNotification = settingKey[Boolean]("""True to suppress the "Executing in batch mode.." message.""").withRank(CSetting)
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import Keys.{
|
|||
serverHost,
|
||||
serverPort,
|
||||
serverAuthentication,
|
||||
serverConnectionType,
|
||||
watch
|
||||
}
|
||||
import Scope.{ Global, ThisScope }
|
||||
|
|
@ -461,6 +462,7 @@ object Project extends ProjectExtra {
|
|||
val host: Option[String] = get(serverHost)
|
||||
val port: Option[Int] = get(serverPort)
|
||||
val authentication: Option[Set[ServerAuthentication]] = get(serverAuthentication)
|
||||
val connectionType: Option[ConnectionType] = get(serverConnectionType)
|
||||
val commandDefs = allCommands.distinct.flatten[Command].map(_ tag (projectCommand, true))
|
||||
val newDefinedCommands = commandDefs ++ BasicCommands.removeTagged(s.definedCommands,
|
||||
projectCommand)
|
||||
|
|
@ -471,6 +473,7 @@ object Project extends ProjectExtra {
|
|||
.setCond(serverPort.key, port)
|
||||
.setCond(serverHost.key, host)
|
||||
.setCond(serverAuthentication.key, authentication)
|
||||
.setCond(serverConnectionType.key, connectionType)
|
||||
.put(historyPath.key, history)
|
||||
.put(templateResolverInfos.key, trs)
|
||||
.setCond(shellPrompt.key, prompt)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import java.util.concurrent.ConcurrentLinkedQueue
|
|||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.annotation.tailrec
|
||||
import BasicKeys.{ serverHost, serverPort, serverAuthentication }
|
||||
import BasicKeys.{ serverHost, serverPort, serverAuthentication, serverConnectionType }
|
||||
import java.net.Socket
|
||||
import sjsonnew.JsonFormat
|
||||
import sjsonnew.shaded.scalajson.ast.unsafe._
|
||||
|
|
@ -83,6 +83,7 @@ private[sbt] final class CommandExchange {
|
|||
}
|
||||
|
||||
private def newChannelName: String = s"channel-${nextChannelId.incrementAndGet()}"
|
||||
private def newNetworkName: String = s"network-${nextChannelId.incrementAndGet()}"
|
||||
|
||||
/**
|
||||
* Check if a server instance is running already, and start one if it isn't.
|
||||
|
|
@ -100,19 +101,23 @@ private[sbt] final class CommandExchange {
|
|||
case Some(xs) => xs
|
||||
case None => Set(ServerAuthentication.Token)
|
||||
}
|
||||
lazy val connectionType = (s get serverConnectionType) match {
|
||||
case Some(x) => x
|
||||
case None => ConnectionType.Tcp
|
||||
}
|
||||
val serverLogLevel: Level.Value = Level.Debug
|
||||
def onIncomingSocket(socket: Socket, instance: ServerInstance): Unit = {
|
||||
s.log.info(s"new client connected from: ${socket.getPort}")
|
||||
val name = newNetworkName
|
||||
s.log.info(s"new client connected: $name")
|
||||
val logger: Logger = {
|
||||
val loggerName = s"network-${socket.getPort}"
|
||||
val log = LogExchange.logger(loggerName, None, None)
|
||||
LogExchange.unbindLoggerAppenders(loggerName)
|
||||
val log = LogExchange.logger(name, None, None)
|
||||
LogExchange.unbindLoggerAppenders(name)
|
||||
val appender = MainAppender.defaultScreen(s.globalLogging.console)
|
||||
LogExchange.bindLoggerAppenders(loggerName, List(appender -> serverLogLevel))
|
||||
LogExchange.bindLoggerAppenders(name, List(appender -> serverLogLevel))
|
||||
log
|
||||
}
|
||||
val channel =
|
||||
new NetworkChannel(newChannelName, socket, Project structure s, auth, instance, logger)
|
||||
new NetworkChannel(name, socket, Project structure s, auth, instance, logger)
|
||||
subscribe(channel)
|
||||
}
|
||||
server match {
|
||||
|
|
@ -121,7 +126,18 @@ private[sbt] final class CommandExchange {
|
|||
val portfile = (new File(".")).getAbsoluteFile / "project" / "target" / "active.json"
|
||||
val h = Hash.halfHashString(portfile.toURI.toString)
|
||||
val tokenfile = BuildPaths.getGlobalBase(s) / "server" / h / "token.json"
|
||||
val x = Server.start(host, port, onIncomingSocket, auth, portfile, tokenfile, s.log)
|
||||
val socketfile = BuildPaths.getGlobalBase(s) / "server" / h / "sock"
|
||||
val pipeName = "sbt-server-" + h
|
||||
val connection =
|
||||
ServerConnection(connectionType,
|
||||
host,
|
||||
port,
|
||||
auth,
|
||||
portfile,
|
||||
tokenfile,
|
||||
socketfile,
|
||||
pipeName)
|
||||
val x = Server.start(connection, onIncomingSocket, s.log)
|
||||
Await.ready(x.ready, Duration("10s"))
|
||||
x.ready.value match {
|
||||
case Some(Success(_)) =>
|
||||
|
|
|
|||
|
|
@ -106,6 +106,8 @@ object Dependencies {
|
|||
val specs2 = "org.specs2" %% "specs2" % "2.4.17"
|
||||
val junit = "junit" % "junit" % "4.11"
|
||||
val templateResolverApi = "org.scala-sbt" % "template-resolver" % "0.1"
|
||||
val jna = "net.java.dev.jna" % "jna" % "4.1.0"
|
||||
val jnaPlatform = "net.java.dev.jna" % "jna-platform" % "4.1.0"
|
||||
|
||||
private def scala211Module(name: String, moduleVersion: String) = Def setting (
|
||||
scalaBinaryVersion.value match {
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ lazy val runClient = taskKey[Unit]("")
|
|||
|
||||
lazy val root = (project in file("."))
|
||||
.settings(
|
||||
serverConnectionType in Global := ConnectionType.Tcp,
|
||||
scalaVersion := "2.12.3",
|
||||
serverPort in Global := 5123,
|
||||
libraryDependencies += "org.scala-sbt" %% "io" % "1.0.1",
|
||||
|
|
|
|||
|
|
@ -23,19 +23,25 @@ export function activate(context: ExtensionContext) {
|
|||
let clientOptions: LanguageClientOptions = {
|
||||
documentSelector: [{ language: 'scala', scheme: 'file' }, { language: 'java', scheme: 'file' }],
|
||||
initializationOptions: () => {
|
||||
return {
|
||||
token: discoverToken()
|
||||
};
|
||||
return discoverToken();
|
||||
}
|
||||
}
|
||||
|
||||
// the port file is hardcoded to a particular location relative to the build.
|
||||
function discoverToken(): String {
|
||||
function discoverToken(): any {
|
||||
let pf = path.join(workspace.rootPath, 'project', 'target', 'active.json');
|
||||
let portfile = JSON.parse(fs.readFileSync(pf));
|
||||
let tf = portfile.tokenfilePath;
|
||||
let tokenfile = JSON.parse(fs.readFileSync(tf));
|
||||
return tokenfile.token;
|
||||
|
||||
// if tokenfilepath exists, return the token.
|
||||
if (portfile.hasOwnProperty('tokenfilePath')) {
|
||||
let tf = portfile.tokenfilePath;
|
||||
let tokenfile = JSON.parse(fs.readFileSync(tf));
|
||||
return {
|
||||
token: tokenfile.token
|
||||
};
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
// Create the language client and start the client.
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import * as path from 'path';
|
|||
import * as url from 'url';
|
||||
let net = require('net'),
|
||||
fs = require('fs'),
|
||||
os = require('os'),
|
||||
stdin = process.stdin,
|
||||
stdout = process.stdout;
|
||||
|
||||
|
|
@ -16,7 +17,17 @@ socket.on('data', (chunk: any) => {
|
|||
}).on('end', () => {
|
||||
stdin.pause();
|
||||
});
|
||||
socket.connect(u.port, '127.0.0.1');
|
||||
|
||||
if (u.protocol == 'tcp:') {
|
||||
socket.connect(u.port, '127.0.0.1');
|
||||
} else if (u.protocol == 'local:' && os.platform() == 'win32') {
|
||||
let pipePath = '\\\\.\\pipe\\' + u.hostname;
|
||||
socket.connect(pipePath);
|
||||
} else if (u.protocol == 'local:') {
|
||||
socket.connect(u.path);
|
||||
} else {
|
||||
throw 'Unknown protocol ' + u.protocol;
|
||||
}
|
||||
|
||||
stdin.resume();
|
||||
stdin.on('data', (chunk: any) => {
|
||||
|
|
|
|||
Loading…
Reference in New Issue