diff --git a/.travis.yml b/.travis.yml index a33abc64e..32110046e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -26,7 +26,7 @@ env: - SBT_CMD="scripted dependency-management/*4of4" - SBT_CMD="scripted java/* package/* reporter/* run/* project-load/*" - SBT_CMD="scripted project/*1of2" - - SBT_CMD="scripted project/*2of2" + - SBT_CMD="scripted project/*2of2 server/*" - SBT_CMD="scripted source-dependencies/*1of3" - SBT_CMD="scripted source-dependencies/*2of3" - SBT_CMD="scripted source-dependencies/*3of3" diff --git a/build.sbt b/build.sbt index 7e4e6d751..6f2185958 100644 --- a/build.sbt +++ b/build.sbt @@ -132,6 +132,10 @@ val collectionProj = (project in file("internal") / "util-collection") name := "Collections", libraryDependencies ++= Seq(sjsonNewScalaJson.value), mimaSettings, + mimaBinaryIssueFilters ++= Seq( + // Added private[sbt] method to capture State attributes. + exclude[ReversedMissingMethodProblem]("sbt.internal.util.AttributeMap.setCond"), + ), ) .configure(addSbtUtilPosition) @@ -290,6 +294,12 @@ lazy val commandProj = (project in file("main-command")) sourceManaged in (Compile, generateContrabands) := baseDirectory.value / "src" / "main" / "contraband-scala", contrabandFormatsForType in generateContrabands in Compile := ContrabandConfig.getFormats, mimaSettings, + mimaBinaryIssueFilters ++= Vector( + // Changed the signature of Server method. nacho cheese. + exclude[DirectMissingMethodProblem]("sbt.internal.server.Server.*"), + // Added method to ServerInstance. This is also internal. + exclude[ReversedMissingMethodProblem]("sbt.internal.server.ServerInstance.*"), + ) ) .configure( addSbtIO, @@ -361,6 +371,10 @@ lazy val mainProj = (project in file("main")) baseDirectory.value / "src" / "main" / "contraband-scala", sourceManaged in (Compile, generateContrabands) := baseDirectory.value / "src" / "main" / "contraband-scala", mimaSettings, + mimaBinaryIssueFilters ++= Vector( + // Changed the signature of NetworkChannel ctor. internal. + exclude[DirectMissingMethodProblem]("sbt.internal.server.NetworkChannel.*"), + ) ) .configure( addSbtIO, diff --git a/internal/util-collection/src/main/scala/sbt/internal/util/Attributes.scala b/internal/util-collection/src/main/scala/sbt/internal/util/Attributes.scala index d70b3df8a..7c24cfd29 100644 --- a/internal/util-collection/src/main/scala/sbt/internal/util/Attributes.scala +++ b/internal/util-collection/src/main/scala/sbt/internal/util/Attributes.scala @@ -168,6 +168,11 @@ trait AttributeMap { /** `true` if there are no mappings in this map, `false` if there are. */ def isEmpty: Boolean + /** + * Adds the mapping `k -> opt.get` if opt is Some. + * Otherwise, it returns this map without the mapping for `k`. + */ + private[sbt] def setCond[T](k: AttributeKey[T], opt: Option[T]): AttributeMap } object AttributeMap { @@ -217,6 +222,12 @@ private class BasicAttributeMap(private val backing: Map[AttributeKey[_], Any]) def entries: Iterable[AttributeEntry[_]] = for ((k: AttributeKey[kt], v) <- backing) yield AttributeEntry(k, v.asInstanceOf[kt]) + private[sbt] def setCond[T](k: AttributeKey[T], opt: Option[T]): AttributeMap = + opt match { + case Some(v) => put(k, v) + case None => remove(k) + } + override def toString = entries.mkString("(", ", ", ")") } diff --git a/main-command/src/main/contraband-scala/ServerAuthenticationFormats.scala b/main-command/src/main/contraband-scala/ServerAuthenticationFormats.scala new file mode 100644 index 000000000..38bc9f9e1 --- /dev/null +++ b/main-command/src/main/contraband-scala/ServerAuthenticationFormats.scala @@ -0,0 +1,26 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait ServerAuthenticationFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val ServerAuthenticationFormat: JsonFormat[sbt.ServerAuthentication] = new JsonFormat[sbt.ServerAuthentication] { + override def read[J](jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.ServerAuthentication = { + jsOpt match { + case Some(js) => + unbuilder.readString(js) match { + case "Token" => sbt.ServerAuthentication.Token + } + case None => + deserializationError("Expected JsString but found None") + } + } + override def write[J](obj: sbt.ServerAuthentication, builder: Builder[J]): Unit = { + val str = obj match { + case sbt.ServerAuthentication.Token => "Token" + } + builder.writeString(str) + } +} +} diff --git a/main-command/src/main/contraband-scala/sbt/ServerAuthentication.scala b/main-command/src/main/contraband-scala/sbt/ServerAuthentication.scala new file mode 100644 index 000000000..b6f074b75 --- /dev/null +++ b/main-command/src/main/contraband-scala/sbt/ServerAuthentication.scala @@ -0,0 +1,12 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt +sealed abstract class ServerAuthentication extends Serializable +object ServerAuthentication { + + + case object Token extends ServerAuthentication +} diff --git a/main-command/src/main/contraband/state.contra b/main-command/src/main/contraband/state.contra index 929523505..79d0bcaab 100644 --- a/main-command/src/main/contraband/state.contra +++ b/main-command/src/main/contraband/state.contra @@ -12,3 +12,7 @@ type Exec { type CommandSource { channelName: String! } + +enum ServerAuthentication { + Token +} diff --git a/main-command/src/main/scala/sbt/BasicKeys.scala b/main-command/src/main/scala/sbt/BasicKeys.scala index 10f1215b1..f59f3f605 100644 --- a/main-command/src/main/scala/sbt/BasicKeys.scala +++ b/main-command/src/main/scala/sbt/BasicKeys.scala @@ -17,6 +17,15 @@ object BasicKeys { val watch = AttributeKey[Watched]("watch", "Continuous execution configuration.", 1000) val serverPort = AttributeKey[Int]("server-port", "The port number used by server command.", 10000) + + val serverHost = + AttributeKey[String]("serverHost", "The host used by server command.", 10000) + + val serverAuthentication = + AttributeKey[Set[ServerAuthentication]]("serverAuthentication", + "Method of authenticating server command.", + 10000) + private[sbt] val interactive = AttributeKey[Boolean]( "interactive", "True if commands are currently being entered from an interactive environment.", diff --git a/main-command/src/main/scala/sbt/internal/server/Server.scala b/main-command/src/main/scala/sbt/internal/server/Server.scala index 7b764225f..9c8d3ee95 100644 --- a/main-command/src/main/scala/sbt/internal/server/Server.scala +++ b/main-command/src/main/scala/sbt/internal/server/Server.scala @@ -5,30 +5,48 @@ package sbt package internal package server +import java.io.File import java.net.{ SocketTimeoutException, InetAddress, ServerSocket, Socket } import java.util.concurrent.atomic.AtomicBoolean -import sbt.util.Logger -import sbt.internal.util.ErrorHandling +import java.nio.file.attribute.{ UserPrincipal, AclEntry, AclEntryPermission, AclEntryType } +import java.security.SecureRandom +import java.math.BigInteger import scala.concurrent.{ Future, Promise } import scala.util.{ Try, Success, Failure } +import sbt.internal.util.ErrorHandling +import sbt.internal.protocol.{ PortFile, TokenFile } +import sbt.util.Logger +import sbt.io.IO +import sbt.io.syntax._ +import sjsonnew.support.scalajson.unsafe.{ Converter, CompactPrinter } +import sbt.internal.protocol.codec._ private[sbt] sealed trait ServerInstance { def shutdown(): Unit def ready: Future[Unit] + def authenticate(challenge: String): Boolean } private[sbt] object Server { + sealed trait JsonProtocol + extends sjsonnew.BasicJsonProtocol + with PortFileFormats + with TokenFileFormats + object JsonProtocol extends JsonProtocol + def start(host: String, port: Int, - onIncomingSocket: Socket => Unit, - /*onIncomingCommand: CommandMessage => Unit,*/ log: Logger): ServerInstance = - new ServerInstance { - - // val lock = new AnyRef {} - // val clients: mutable.ListBuffer[ClientConnection] = mutable.ListBuffer.empty + onIncomingSocket: (Socket, ServerInstance) => Unit, + auth: Set[ServerAuthentication], + portfile: File, + tokenfile: File, + log: Logger): ServerInstance = + new ServerInstance { self => val running = new AtomicBoolean(false) val p: Promise[Unit] = Promise[Unit]() val ready: Future[Unit] = p.future + private[this] val rand = new SecureRandom + private[this] var token: String = nextToken val serverThread = new Thread("sbt-socket-server") { override def run(): Unit = { @@ -41,12 +59,13 @@ private[sbt] object Server { case Success(serverSocket) => serverSocket.setSoTimeout(5000) log.info(s"sbt server started at $host:$port") + writePortfile() running.set(true) p.success(()) while (running.get()) { try { val socket = serverSocket.accept() - onIncomingSocket(socket) + onIncomingSocket(socket, self) } catch { case _: SocketTimeoutException => // its ok } @@ -56,10 +75,79 @@ private[sbt] object Server { } serverThread.start() + override def authenticate(challenge: String): Boolean = synchronized { + if (token == challenge) { + token = nextToken + writeTokenfile() + true + } else false + } + + /** Generates 128-bit non-negative integer, and represent it as decimal string. */ + private[this] def nextToken: String = { + new BigInteger(128, rand).toString + } + override def shutdown(): Unit = { log.info("shutting down server") + if (portfile.exists) { + IO.delete(portfile) + } + if (tokenfile.exists) { + IO.delete(tokenfile) + } running.set(false) } - } + private[this] def writeTokenfile(): Unit = { + import JsonProtocol._ + + val uri = s"tcp://$host:$port" + val t = TokenFile(uri, token) + val jsonToken = Converter.toJson(t).get + + if (tokenfile.exists) { + IO.delete(tokenfile) + } + IO.touch(tokenfile) + ownerOnly(tokenfile) + IO.write(tokenfile, CompactPrinter(jsonToken), IO.utf8, true) + } + + /** Set the persmission of the file such that the only the owner can read/write it. */ + private[this] def ownerOnly(file: File): Unit = { + def acl(owner: UserPrincipal) = { + val builder = AclEntry.newBuilder + builder.setPrincipal(owner) + builder.setPermissions(AclEntryPermission.values(): _*) + builder.setType(AclEntryType.ALLOW) + builder.build + } + file match { + case _ if IO.isPosix => + IO.chmod("rw-------", file) + case _ if IO.hasAclFileAttributeView => + val view = file.aclFileAttributeView + view.setAcl(java.util.Collections.singletonList(acl(view.getOwner))) + case _ => () + } + } + + // This file exists through the lifetime of the server. + private[this] def writePortfile(): Unit = { + import JsonProtocol._ + + val uri = s"tcp://$host:$port" + val p = + auth match { + case _ if auth(ServerAuthentication.Token) => + writeTokenfile() + PortFile(uri, Option(tokenfile.toString), Option(tokenfile.toURI.toString)) + case _ => + PortFile(uri, None, None) + } + val json = Converter.toJson(p).get + IO.write(portfile, CompactPrinter(json)) + } + } } diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index 1936641bf..638ccbe21 100755 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -264,9 +264,11 @@ object Defaults extends BuildCommon { .getOrElse(GCUtil.defaultForceGarbageCollection), minForcegcInterval :== GCUtil.defaultMinForcegcInterval, interactionService :== CommandLineUIService, + serverHost := "127.0.0.1", serverPort := 5000 + (Hash .toHex(Hash(appConfiguration.value.baseDirectory.toString)) - .## % 1000) + .## % 1000), + serverAuthentication := Set(ServerAuthentication.Token), )) def defaultTestTasks(key: Scoped): Seq[Setting[_]] = diff --git a/main/src/main/scala/sbt/Keys.scala b/main/src/main/scala/sbt/Keys.scala index 07d6d4d92..66c06b82e 100644 --- a/main/src/main/scala/sbt/Keys.scala +++ b/main/src/main/scala/sbt/Keys.scala @@ -127,6 +127,8 @@ object Keys { val historyPath = SettingKey(BasicKeys.historyPath) val shellPrompt = SettingKey(BasicKeys.shellPrompt) val serverPort = SettingKey(BasicKeys.serverPort) + val serverHost = SettingKey(BasicKeys.serverHost) + val serverAuthentication = SettingKey(BasicKeys.serverAuthentication) val analysis = AttributeKey[CompileAnalysis]("analysis", "Analysis of compilation, including dependencies and generated outputs.", DSetting) val watch = SettingKey(BasicKeys.watch) val suppressSbtShellNotification = settingKey[Boolean]("""True to suppress the "Executing in batch mode.." message.""").withRank(CSetting) diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index f4043f989..47a67a0fd 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -101,7 +101,9 @@ object StandardMain { val previous = TrapExit.installManager() try { try { - MainLoop.runLogged(s) + try { + MainLoop.runLogged(s) + } finally exchange.shutdown } finally DefaultBackgroundJobService.backgroundJobService.shutdown() } finally TrapExit.uninstallManager(previous) } diff --git a/main/src/main/scala/sbt/Project.scala b/main/src/main/scala/sbt/Project.scala index b1aa71d06..6ad6e9d45 100755 --- a/main/src/main/scala/sbt/Project.scala +++ b/main/src/main/scala/sbt/Project.scala @@ -16,7 +16,9 @@ import Keys.{ sessionSettings, shellPrompt, templateResolverInfos, + serverHost, serverPort, + serverAuthentication, watch } import Scope.{ Global, ThisScope } @@ -509,23 +511,30 @@ object Project extends ProjectExtra { val prompt = get(shellPrompt) val trs = (templateResolverInfos in Global get structure.data).toList.flatten val watched = get(watch) + val host: Option[String] = get(serverHost) val port: Option[Int] = get(serverPort) + val authentication: Option[Set[ServerAuthentication]] = get(serverAuthentication) val commandDefs = allCommands.distinct.flatten[Command].map(_ tag (projectCommand, true)) val newDefinedCommands = commandDefs ++ BasicCommands.removeTagged(s.definedCommands, projectCommand) - val newAttrs0 = - setCond(Watched.Configuration, watched, s.attributes).put(historyPath.key, history) - val newAttrs = setCond(serverPort.key, port, newAttrs0) - .put(historyPath.key, history) - .put(templateResolverInfos.key, trs) + val newAttrs = + s.attributes + .setCond(Watched.Configuration, watched) + .put(historyPath.key, history) + .setCond(serverPort.key, port) + .setCond(serverHost.key, host) + .setCond(serverAuthentication.key, authentication) + .put(historyPath.key, history) + .put(templateResolverInfos.key, trs) + .setCond(shellPrompt.key, prompt) s.copy( - attributes = setCond(shellPrompt.key, prompt, newAttrs), + attributes = newAttrs, definedCommands = newDefinedCommands ) } def setCond[T](key: AttributeKey[T], vopt: Option[T], attributes: AttributeMap): AttributeMap = - vopt match { case Some(v) => attributes.put(key, v); case None => attributes.remove(key) } + attributes.setCond(key, vopt) private[sbt] def checkTargets(data: Settings[Scope]): Option[String] = { val dups = overlappingTargets(allTargets(data)) diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index d56edb245..574dfd8ce 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -6,15 +6,17 @@ import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger import sbt.internal.server._ import sbt.internal.util.StringEvent -import sbt.protocol.{ EventMessage, Serialization, ChannelAcceptedEvent } +import sbt.protocol.{ EventMessage, Serialization } import scala.collection.mutable.ListBuffer import scala.annotation.tailrec -import BasicKeys.serverPort +import BasicKeys.{ serverHost, serverPort, serverAuthentication } import java.net.Socket import sjsonnew.JsonFormat import scala.concurrent.Await import scala.concurrent.duration.Duration import scala.util.{ Success, Failure } +import sbt.io.syntax._ +import sbt.io.Hash /** * The command exchange merges multiple command channels (e.g. network and console), @@ -74,20 +76,30 @@ private[sbt] final class CommandExchange { * Check if a server instance is running already, and start one if it isn't. */ private[sbt] def runServer(s: State): State = { - def port = (s get serverPort) match { + lazy val port = (s get serverPort) match { case Some(x) => x case None => 5001 } - def onIncomingSocket(socket: Socket): Unit = { + lazy val host = (s get serverHost) match { + case Some(x) => x + case None => "127.0.0.1" + } + lazy val auth: Set[ServerAuthentication] = (s get serverAuthentication) match { + case Some(xs) => xs + case None => Set(ServerAuthentication.Token) + } + def onIncomingSocket(socket: Socket, instance: ServerInstance): Unit = { s.log.info(s"new client connected from: ${socket.getPort}") - val channel = new NetworkChannel(newChannelName, socket, Project structure s) + val channel = new NetworkChannel(newChannelName, socket, Project structure s, auth, instance) subscribe(channel) - channel.publishEventMessage(ChannelAcceptedEvent(channel.name)) } server match { case Some(x) => // do nothing case _ => - val x = Server.start("127.0.0.1", port, onIncomingSocket, s.log) + val portfile = (new File(".")).getAbsoluteFile / "project" / "target" / "active.json" + val h = Hash.halfHashString(portfile.toURI.toString) + val tokenfile = BuildPaths.getGlobalBase(s) / "server" / h / "token.json" + val x = Server.start(host, port, onIncomingSocket, auth, portfile, tokenfile, s.log) Await.ready(x.ready, Duration("10s")) x.ready.value match { case Some(Success(_)) => diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index 22d51c5ce..eeaa8ed7c 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -10,11 +10,16 @@ import java.util.concurrent.atomic.AtomicBoolean import sbt.protocol._ import sjsonnew._ -final class NetworkChannel(val name: String, connection: Socket, structure: BuildStructure) +final class NetworkChannel(val name: String, + connection: Socket, + structure: BuildStructure, + auth: Set[ServerAuthentication], + instance: ServerInstance) extends CommandChannel { private val running = new AtomicBoolean(true) private val delimiter: Byte = '\n'.toByte private val out = connection.getOutputStream + private var initialized = false val thread = new Thread(s"sbt-networkchannel-${connection.getPort}") { override def run(): Unit = { @@ -42,12 +47,10 @@ final class NetworkChannel(val name: String, connection: Socket, structure: Buil ) delimPos = buffer.indexOf(delimiter) } - } catch { case _: SocketTimeoutException => // its ok } } - } finally { shutdown() } @@ -72,15 +75,44 @@ final class NetworkChannel(val name: String, connection: Socket, structure: Buil } def onCommand(command: CommandMessage): Unit = command match { + case x: InitCommand => onInitCommand(x) case x: ExecCommand => onExecCommand(x) case x: SettingQuery => onSettingQuery(x) } - private def onExecCommand(cmd: ExecCommand) = - append(Exec(cmd.commandLine, cmd.execId orElse Some(Exec.newExecId), Some(CommandSource(name)))) + private def onInitCommand(cmd: InitCommand): Unit = { + if (auth(ServerAuthentication.Token)) { + cmd.token match { + case Some(x) => + instance.authenticate(x) match { + case true => + initialized = true + publishEventMessage(ChannelAcceptedEvent(name)) + case _ => sys.error("invalid token") + } + case None => sys.error("init command but without token.") + } + } else { + initialized = true + } + } - private def onSettingQuery(req: SettingQuery) = - StandardMain.exchange publishEventMessage SettingQuery.handleSettingQuery(req, structure) + private def onExecCommand(cmd: ExecCommand) = { + if (initialized) { + append( + Exec(cmd.commandLine, cmd.execId orElse Some(Exec.newExecId), Some(CommandSource(name)))) + } else { + println(s"ignoring command $cmd before initialization") + } + } + + private def onSettingQuery(req: SettingQuery) = { + if (initialized) { + StandardMain.exchange publishEventMessage SettingQuery.handleSettingQuery(req, structure) + } else { + println(s"ignoring query $req before initialization") + } + } def shutdown(): Unit = { println("Shutting down client connection") diff --git a/main/src/test/scala/sbt/internal/server/SettingQueryTest.scala b/main/src/test/scala/sbt/internal/server/SettingQueryTest.scala index f254cad6e..3ef94c83b 100644 --- a/main/src/test/scala/sbt/internal/server/SettingQueryTest.scala +++ b/main/src/test/scala/sbt/internal/server/SettingQueryTest.scala @@ -172,7 +172,7 @@ object SettingQueryTest extends org.specs2.mutable.Specification { def query(setting: String): String = { import sbt.protocol._ - val req: SettingQuery = protocol.SettingQuery(setting) + val req: SettingQuery = sbt.protocol.SettingQuery(setting) val rsp: SettingQueryResponse = server.SettingQuery.handleSettingQuery(req, structure) val bytes: Array[Byte] = Serialization serializeEventMessage rsp val payload: String = new String(bytes, java.nio.charset.StandardCharsets.UTF_8) diff --git a/project/Dependencies.scala b/project/Dependencies.scala index b1b31a0b4..8745f1dfb 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -12,7 +12,7 @@ object Dependencies { val baseScalaVersion = scala212 // sbt modules - private val ioVersion = "1.0.1" + private val ioVersion = "1.1.0" private val utilVersion = "1.0.1" private val lmVersion = "1.0.2" private val zincVersion = "1.0.1" diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/PortFile.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/PortFile.scala new file mode 100644 index 000000000..218aefdfb --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/PortFile.scala @@ -0,0 +1,52 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.internal.protocol +/** + * This file should exist throughout the lifetime of the server. + * It can be used to find out the transport protocol (port number etc). + */ +final class PortFile private ( + /** URI of the sbt server. */ + val uri: String, + val tokenfilePath: Option[String], + val tokenfileUri: Option[String]) extends Serializable { + + + + override def equals(o: Any): Boolean = o match { + case x: PortFile => (this.uri == x.uri) && (this.tokenfilePath == x.tokenfilePath) && (this.tokenfileUri == x.tokenfileUri) + case _ => false + } + override def hashCode: Int = { + 37 * (37 * (37 * (37 * (17 + "sbt.internal.protocol.PortFile".##) + uri.##) + tokenfilePath.##) + tokenfileUri.##) + } + override def toString: String = { + "PortFile(" + uri + ", " + tokenfilePath + ", " + tokenfileUri + ")" + } + protected[this] def copy(uri: String = uri, tokenfilePath: Option[String] = tokenfilePath, tokenfileUri: Option[String] = tokenfileUri): PortFile = { + new PortFile(uri, tokenfilePath, tokenfileUri) + } + def withUri(uri: String): PortFile = { + copy(uri = uri) + } + def withTokenfilePath(tokenfilePath: Option[String]): PortFile = { + copy(tokenfilePath = tokenfilePath) + } + def withTokenfilePath(tokenfilePath: String): PortFile = { + copy(tokenfilePath = Option(tokenfilePath)) + } + def withTokenfileUri(tokenfileUri: Option[String]): PortFile = { + copy(tokenfileUri = tokenfileUri) + } + def withTokenfileUri(tokenfileUri: String): PortFile = { + copy(tokenfileUri = Option(tokenfileUri)) + } +} +object PortFile { + + def apply(uri: String, tokenfilePath: Option[String], tokenfileUri: Option[String]): PortFile = new PortFile(uri, tokenfilePath, tokenfileUri) + def apply(uri: String, tokenfilePath: String, tokenfileUri: String): PortFile = new PortFile(uri, Option(tokenfilePath), Option(tokenfileUri)) +} diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/TokenFile.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/TokenFile.scala new file mode 100644 index 000000000..e2019147e --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/TokenFile.scala @@ -0,0 +1,36 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.internal.protocol +final class TokenFile private ( + val uri: String, + val token: String) extends Serializable { + + + + override def equals(o: Any): Boolean = o match { + case x: TokenFile => (this.uri == x.uri) && (this.token == x.token) + case _ => false + } + override def hashCode: Int = { + 37 * (37 * (37 * (17 + "sbt.internal.protocol.TokenFile".##) + uri.##) + token.##) + } + override def toString: String = { + "TokenFile(" + uri + ", " + token + ")" + } + protected[this] def copy(uri: String = uri, token: String = token): TokenFile = { + new TokenFile(uri, token) + } + def withUri(uri: String): TokenFile = { + copy(uri = uri) + } + def withToken(token: String): TokenFile = { + copy(token = token) + } +} +object TokenFile { + + def apply(uri: String, token: String): TokenFile = new TokenFile(uri, token) +} diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/PortFileFormats.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/PortFileFormats.scala new file mode 100644 index 000000000..0a60963d2 --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/PortFileFormats.scala @@ -0,0 +1,31 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.internal.protocol.codec +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait PortFileFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val PortFileFormat: JsonFormat[sbt.internal.protocol.PortFile] = new JsonFormat[sbt.internal.protocol.PortFile] { + override def read[J](jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.internal.protocol.PortFile = { + jsOpt match { + case Some(js) => + unbuilder.beginObject(js) + val uri = unbuilder.readField[String]("uri") + val tokenfilePath = unbuilder.readField[Option[String]]("tokenfilePath") + val tokenfileUri = unbuilder.readField[Option[String]]("tokenfileUri") + unbuilder.endObject() + sbt.internal.protocol.PortFile(uri, tokenfilePath, tokenfileUri) + case None => + deserializationError("Expected JsObject but found None") + } + } + override def write[J](obj: sbt.internal.protocol.PortFile, builder: Builder[J]): Unit = { + builder.beginObject() + builder.addField("uri", obj.uri) + builder.addField("tokenfilePath", obj.tokenfilePath) + builder.addField("tokenfileUri", obj.tokenfileUri) + builder.endObject() + } +} +} diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/TokenFileFormats.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/TokenFileFormats.scala new file mode 100644 index 000000000..74a15423f --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/TokenFileFormats.scala @@ -0,0 +1,29 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.internal.protocol.codec +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait TokenFileFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val TokenFileFormat: JsonFormat[sbt.internal.protocol.TokenFile] = new JsonFormat[sbt.internal.protocol.TokenFile] { + override def read[J](jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.internal.protocol.TokenFile = { + jsOpt match { + case Some(js) => + unbuilder.beginObject(js) + val uri = unbuilder.readField[String]("uri") + val token = unbuilder.readField[String]("token") + unbuilder.endObject() + sbt.internal.protocol.TokenFile(uri, token) + case None => + deserializationError("Expected JsObject but found None") + } + } + override def write[J](obj: sbt.internal.protocol.TokenFile, builder: Builder[J]): Unit = { + builder.beginObject() + builder.addField("uri", obj.uri) + builder.addField("token", obj.token) + builder.endObject() + } +} +} diff --git a/protocol/src/main/contraband-scala/sbt/protocol/InitCommand.scala b/protocol/src/main/contraband-scala/sbt/protocol/InitCommand.scala new file mode 100644 index 000000000..e45b25c84 --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/protocol/InitCommand.scala @@ -0,0 +1,43 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.protocol +final class InitCommand private ( + val token: Option[String], + val execId: Option[String]) extends sbt.protocol.CommandMessage() with Serializable { + + + + override def equals(o: Any): Boolean = o match { + case x: InitCommand => (this.token == x.token) && (this.execId == x.execId) + case _ => false + } + override def hashCode: Int = { + 37 * (37 * (37 * (17 + "sbt.protocol.InitCommand".##) + token.##) + execId.##) + } + override def toString: String = { + "InitCommand(" + token + ", " + execId + ")" + } + protected[this] def copy(token: Option[String] = token, execId: Option[String] = execId): InitCommand = { + new InitCommand(token, execId) + } + def withToken(token: Option[String]): InitCommand = { + copy(token = token) + } + def withToken(token: String): InitCommand = { + copy(token = Option(token)) + } + def withExecId(execId: Option[String]): InitCommand = { + copy(execId = execId) + } + def withExecId(execId: String): InitCommand = { + copy(execId = Option(execId)) + } +} +object InitCommand { + + def apply(token: Option[String], execId: Option[String]): InitCommand = new InitCommand(token, execId) + def apply(token: String, execId: String): InitCommand = new InitCommand(Option(token), Option(execId)) +} diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala index 650bf14dd..c80c7aaed 100644 --- a/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala @@ -6,6 +6,6 @@ package sbt.protocol.codec import _root_.sjsonnew.JsonFormat -trait CommandMessageFormats { self: sjsonnew.BasicJsonProtocol with sbt.protocol.codec.ExecCommandFormats with sbt.protocol.codec.SettingQueryFormats => -implicit lazy val CommandMessageFormat: JsonFormat[sbt.protocol.CommandMessage] = flatUnionFormat2[sbt.protocol.CommandMessage, sbt.protocol.ExecCommand, sbt.protocol.SettingQuery]("type") +trait CommandMessageFormats { self: sjsonnew.BasicJsonProtocol with sbt.protocol.codec.InitCommandFormats with sbt.protocol.codec.ExecCommandFormats with sbt.protocol.codec.SettingQueryFormats => +implicit lazy val CommandMessageFormat: JsonFormat[sbt.protocol.CommandMessage] = flatUnionFormat3[sbt.protocol.CommandMessage, sbt.protocol.InitCommand, sbt.protocol.ExecCommand, sbt.protocol.SettingQuery]("type") } diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/InitCommandFormats.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/InitCommandFormats.scala new file mode 100644 index 000000000..8d3f50759 --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/InitCommandFormats.scala @@ -0,0 +1,29 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.protocol.codec +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait InitCommandFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val InitCommandFormat: JsonFormat[sbt.protocol.InitCommand] = new JsonFormat[sbt.protocol.InitCommand] { + override def read[J](jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.protocol.InitCommand = { + jsOpt match { + case Some(js) => + unbuilder.beginObject(js) + val token = unbuilder.readField[Option[String]]("token") + val execId = unbuilder.readField[Option[String]]("execId") + unbuilder.endObject() + sbt.protocol.InitCommand(token, execId) + case None => + deserializationError("Expected JsObject but found None") + } + } + override def write[J](obj: sbt.protocol.InitCommand, builder: Builder[J]): Unit = { + builder.beginObject() + builder.addField("token", obj.token) + builder.addField("execId", obj.execId) + builder.endObject() + } +} +} diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala index 3b92fc46c..cc9d0fa90 100644 --- a/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala @@ -5,6 +5,7 @@ // DO NOT EDIT MANUALLY package sbt.protocol.codec trait JsonProtocol extends sjsonnew.BasicJsonProtocol + with sbt.protocol.codec.InitCommandFormats with sbt.protocol.codec.ExecCommandFormats with sbt.protocol.codec.SettingQueryFormats with sbt.protocol.codec.CommandMessageFormats diff --git a/protocol/src/main/contraband/portfile.contra b/protocol/src/main/contraband/portfile.contra new file mode 100644 index 000000000..82f6567ef --- /dev/null +++ b/protocol/src/main/contraband/portfile.contra @@ -0,0 +1,17 @@ +package sbt.internal.protocol +@target(Scala) +@codecPackage("sbt.internal.protocol.codec") + +## This file should exist throughout the lifetime of the server. +## It can be used to find out the transport protocol (port number etc). +type PortFile { + ## URI of the sbt server. + uri: String! + tokenfilePath: String + tokenfileUri: String +} + +type TokenFile { + uri: String! + token: String! +} diff --git a/protocol/src/main/contraband/server.contra b/protocol/src/main/contraband/server.contra index 31160a694..2976be229 100644 --- a/protocol/src/main/contraband/server.contra +++ b/protocol/src/main/contraband/server.contra @@ -7,6 +7,11 @@ package sbt.protocol interface CommandMessage { } +type InitCommand implements CommandMessage { + token: String + execId: String +} + ## Command to execute sbt command. type ExecCommand implements CommandMessage { commandLine: String! diff --git a/sbt/src/sbt-test/server/handshake/Client.scala b/sbt/src/sbt-test/server/handshake/Client.scala new file mode 100644 index 000000000..2acbdb9bb --- /dev/null +++ b/sbt/src/sbt-test/server/handshake/Client.scala @@ -0,0 +1,84 @@ +package example + +import java.net.{ URI, Socket, InetAddress, SocketException } +import sbt.io._ +import sbt.io.syntax._ +import java.io.File +import sjsonnew.support.scalajson.unsafe.{ Parser, Converter, CompactPrinter } +import sjsonnew.shaded.scalajson.ast.unsafe.{ JValue, JObject, JString } + +object Client extends App { + val host = "127.0.0.1" + val delimiter: Byte = '\n'.toByte + + println("hello") + Thread.sleep(1000) + + val connection = getConnection + val out = connection.getOutputStream + val in = connection.getInputStream + + out.write(s"""{ "type": "InitCommand", "token": "$getToken" }""".getBytes("utf-8")) + out.write(delimiter.toInt) + out.flush + + out.write("""{ "type": "ExecCommand", "commandLine": "exit" }""".getBytes("utf-8")) + out.write(delimiter.toInt) + out.flush + + val baseDirectory = new File(args(0)) + IO.write(baseDirectory / "ok.txt", "ok") + + def getToken: String = { + val tokenfile = new File(getTokenFileUri) + val json: JValue = Parser.parseFromFile(tokenfile).get + json match { + case JObject(fields) => + (fields find { _.field == "token" } map { _.value }) match { + case Some(JString(value)) => value + case _ => + sys.error("json doesn't token field that is JString") + } + case _ => sys.error("json doesn't have token field") + } + } + + def getTokenFileUri: URI = { + val portfile = baseDirectory / "project" / "target" / "active.json" + val json: JValue = Parser.parseFromFile(portfile).get + json match { + case JObject(fields) => + (fields find { _.field == "tokenfileUri" } map { _.value }) match { + case Some(JString(value)) => new URI(value) + case _ => + sys.error("json doesn't tokenfile field that is JString") + } + case _ => sys.error("json doesn't have tokenfile field") + } + } + + def getPort: Int = { + val portfile = baseDirectory / "project" / "target" / "active.json" + val json: JValue = Parser.parseFromFile(portfile).get + json match { + case JObject(fields) => + (fields find { _.field == "uri" } map { _.value }) match { + case Some(JString(value)) => + val u = new URI(value) + u.getPort + case _ => + sys.error("json doesn't uri field that is JString") + } + case _ => sys.error("json doesn't have uri field") + } + } + + def getConnection: Socket = + try { + new Socket(InetAddress.getByName(host), getPort) + } catch { + case _ => + Thread.sleep(1000) + getConnection + } +} diff --git a/sbt/src/sbt-test/server/handshake/build.sbt b/sbt/src/sbt-test/server/handshake/build.sbt new file mode 100644 index 000000000..fd924f0e4 --- /dev/null +++ b/sbt/src/sbt-test/server/handshake/build.sbt @@ -0,0 +1,14 @@ +lazy val runClient = taskKey[Unit]("") + +lazy val root = (project in file(".")) + .settings( + scalaVersion := "2.12.3", + serverPort in Global := 5123, + libraryDependencies += "org.scala-sbt" %% "io" % "1.0.1", + libraryDependencies += "com.eed3si9n" %% "sjson-new-scalajson" % "0.8.0", + runClient := (Def.taskDyn { + val b = baseDirectory.value + (bgRun in Compile).toTask(s""" $b""") + }).value + ) + \ No newline at end of file diff --git a/sbt/src/sbt-test/server/handshake/test b/sbt/src/sbt-test/server/handshake/test new file mode 100644 index 000000000..9c2ba1cc1 --- /dev/null +++ b/sbt/src/sbt-test/server/handshake/test @@ -0,0 +1,6 @@ +> show serverPort +> runClient +-> shell + +$ sleep 1000 +$ exists ok.txt