From 348a07779715ba9cbd9da80c8dffa3f0b8eb52cd Mon Sep 17 00:00:00 2001 From: Eugene Yokota Date: Thu, 21 Sep 2017 23:05:48 -0400 Subject: [PATCH] implement tokenfile authentication --- build.sbt | 12 +++- .../scala/sbt/internal/util/Attributes.scala | 11 +++ .../ServerAuthenticationFormats.scala | 26 +++++++ .../sbt/ServerAuthentication.scala | 12 ++++ main-command/src/main/contraband/state.contra | 4 ++ .../src/main/scala/sbt/BasicKeys.scala | 9 +++ .../scala/sbt/internal/server/Server.scala | 69 +++++++++++++++++-- main/src/main/scala/sbt/Defaults.scala | 4 +- main/src/main/scala/sbt/Keys.scala | 2 + main/src/main/scala/sbt/Project.scala | 23 +++++-- .../scala/sbt/internal/CommandExchange.scala | 21 ++++-- .../sbt/internal/server/NetworkChannel.scala | 46 +++++++++++-- project/Dependencies.scala | 2 +- .../sbt/protocol/InitCommand.scala | 43 ++++++++++++ .../codec/CommandMessageFormats.scala | 4 +- .../protocol/codec/InitCommandFormats.scala | 29 ++++++++ .../sbt/protocol/codec/JsonProtocol.scala | 1 + protocol/src/main/contraband/server.contra | 5 ++ .../sbt-test/server/handshake/Client.scala | 32 +++++++++ 19 files changed, 322 insertions(+), 33 deletions(-) create mode 100644 main-command/src/main/contraband-scala/ServerAuthenticationFormats.scala create mode 100644 main-command/src/main/contraband-scala/sbt/ServerAuthentication.scala create mode 100644 protocol/src/main/contraband-scala/sbt/protocol/InitCommand.scala create mode 100644 protocol/src/main/contraband-scala/sbt/protocol/codec/InitCommandFormats.scala diff --git a/build.sbt b/build.sbt index a9e37dd37..6f2185958 100644 --- a/build.sbt +++ b/build.sbt @@ -132,6 +132,10 @@ val collectionProj = (project in file("internal") / "util-collection") name := "Collections", libraryDependencies ++= Seq(sjsonNewScalaJson.value), mimaSettings, + mimaBinaryIssueFilters ++= Seq( + // Added private[sbt] method to capture State attributes. + exclude[ReversedMissingMethodProblem]("sbt.internal.util.AttributeMap.setCond"), + ), ) .configure(addSbtUtilPosition) @@ -292,7 +296,9 @@ lazy val commandProj = (project in file("main-command")) mimaSettings, mimaBinaryIssueFilters ++= Vector( // Changed the signature of Server method. nacho cheese. - exclude[DirectMissingMethodProblem]("sbt.internal.server.Server.*") + exclude[DirectMissingMethodProblem]("sbt.internal.server.Server.*"), + // Added method to ServerInstance. This is also internal. + exclude[ReversedMissingMethodProblem]("sbt.internal.server.ServerInstance.*"), ) ) .configure( @@ -365,6 +371,10 @@ lazy val mainProj = (project in file("main")) baseDirectory.value / "src" / "main" / "contraband-scala", sourceManaged in (Compile, generateContrabands) := baseDirectory.value / "src" / "main" / "contraband-scala", mimaSettings, + mimaBinaryIssueFilters ++= Vector( + // Changed the signature of NetworkChannel ctor. internal. + exclude[DirectMissingMethodProblem]("sbt.internal.server.NetworkChannel.*"), + ) ) .configure( addSbtIO, diff --git a/internal/util-collection/src/main/scala/sbt/internal/util/Attributes.scala b/internal/util-collection/src/main/scala/sbt/internal/util/Attributes.scala index d70b3df8a..7c24cfd29 100644 --- a/internal/util-collection/src/main/scala/sbt/internal/util/Attributes.scala +++ b/internal/util-collection/src/main/scala/sbt/internal/util/Attributes.scala @@ -168,6 +168,11 @@ trait AttributeMap { /** `true` if there are no mappings in this map, `false` if there are. */ def isEmpty: Boolean + /** + * Adds the mapping `k -> opt.get` if opt is Some. + * Otherwise, it returns this map without the mapping for `k`. + */ + private[sbt] def setCond[T](k: AttributeKey[T], opt: Option[T]): AttributeMap } object AttributeMap { @@ -217,6 +222,12 @@ private class BasicAttributeMap(private val backing: Map[AttributeKey[_], Any]) def entries: Iterable[AttributeEntry[_]] = for ((k: AttributeKey[kt], v) <- backing) yield AttributeEntry(k, v.asInstanceOf[kt]) + private[sbt] def setCond[T](k: AttributeKey[T], opt: Option[T]): AttributeMap = + opt match { + case Some(v) => put(k, v) + case None => remove(k) + } + override def toString = entries.mkString("(", ", ", ")") } diff --git a/main-command/src/main/contraband-scala/ServerAuthenticationFormats.scala b/main-command/src/main/contraband-scala/ServerAuthenticationFormats.scala new file mode 100644 index 000000000..38bc9f9e1 --- /dev/null +++ b/main-command/src/main/contraband-scala/ServerAuthenticationFormats.scala @@ -0,0 +1,26 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait ServerAuthenticationFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val ServerAuthenticationFormat: JsonFormat[sbt.ServerAuthentication] = new JsonFormat[sbt.ServerAuthentication] { + override def read[J](jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.ServerAuthentication = { + jsOpt match { + case Some(js) => + unbuilder.readString(js) match { + case "Token" => sbt.ServerAuthentication.Token + } + case None => + deserializationError("Expected JsString but found None") + } + } + override def write[J](obj: sbt.ServerAuthentication, builder: Builder[J]): Unit = { + val str = obj match { + case sbt.ServerAuthentication.Token => "Token" + } + builder.writeString(str) + } +} +} diff --git a/main-command/src/main/contraband-scala/sbt/ServerAuthentication.scala b/main-command/src/main/contraband-scala/sbt/ServerAuthentication.scala new file mode 100644 index 000000000..b6f074b75 --- /dev/null +++ b/main-command/src/main/contraband-scala/sbt/ServerAuthentication.scala @@ -0,0 +1,12 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt +sealed abstract class ServerAuthentication extends Serializable +object ServerAuthentication { + + + case object Token extends ServerAuthentication +} diff --git a/main-command/src/main/contraband/state.contra b/main-command/src/main/contraband/state.contra index 929523505..79d0bcaab 100644 --- a/main-command/src/main/contraband/state.contra +++ b/main-command/src/main/contraband/state.contra @@ -12,3 +12,7 @@ type Exec { type CommandSource { channelName: String! } + +enum ServerAuthentication { + Token +} diff --git a/main-command/src/main/scala/sbt/BasicKeys.scala b/main-command/src/main/scala/sbt/BasicKeys.scala index 10f1215b1..f59f3f605 100644 --- a/main-command/src/main/scala/sbt/BasicKeys.scala +++ b/main-command/src/main/scala/sbt/BasicKeys.scala @@ -17,6 +17,15 @@ object BasicKeys { val watch = AttributeKey[Watched]("watch", "Continuous execution configuration.", 1000) val serverPort = AttributeKey[Int]("server-port", "The port number used by server command.", 10000) + + val serverHost = + AttributeKey[String]("serverHost", "The host used by server command.", 10000) + + val serverAuthentication = + AttributeKey[Set[ServerAuthentication]]("serverAuthentication", + "Method of authenticating server command.", + 10000) + private[sbt] val interactive = AttributeKey[Boolean]( "interactive", "True if commands are currently being entered from an interactive environment.", diff --git a/main-command/src/main/scala/sbt/internal/server/Server.scala b/main-command/src/main/scala/sbt/internal/server/Server.scala index 45ea5ad67..fce9d591a 100644 --- a/main-command/src/main/scala/sbt/internal/server/Server.scala +++ b/main-command/src/main/scala/sbt/internal/server/Server.scala @@ -7,19 +7,22 @@ package server import java.io.File import java.net.{ SocketTimeoutException, InetAddress, ServerSocket, Socket } -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.{ AtomicBoolean, AtomicLong } +import java.nio.file.attribute.{ UserPrincipal, AclEntry, AclEntryPermission, AclEntryType } import scala.concurrent.{ Future, Promise } -import scala.util.{ Try, Success, Failure } +import scala.util.{ Try, Success, Failure, Random } import sbt.internal.util.ErrorHandling -import sbt.internal.protocol.PortFile +import sbt.internal.protocol.{ PortFile, TokenFile } import sbt.util.Logger import sbt.io.IO +import sbt.io.syntax._ import sjsonnew.support.scalajson.unsafe.{ Converter, CompactPrinter } import sbt.internal.protocol.codec._ private[sbt] sealed trait ServerInstance { def shutdown(): Unit def ready: Future[Unit] + def authenticate(challenge: String): Boolean } private[sbt] object Server { @@ -31,14 +34,16 @@ private[sbt] object Server { def start(host: String, port: Int, - onIncomingSocket: Socket => Unit, + onIncomingSocket: (Socket, ServerInstance) => Unit, + auth: Set[ServerAuthentication], portfile: File, tokenfile: File, log: Logger): ServerInstance = - new ServerInstance { + new ServerInstance { self => val running = new AtomicBoolean(false) val p: Promise[Unit] = Promise[Unit]() val ready: Future[Unit] = p.future + val token = new AtomicLong(Random.nextLong) val serverThread = new Thread("sbt-socket-server") { override def run(): Unit = { @@ -57,7 +62,7 @@ private[sbt] object Server { while (running.get()) { try { val socket = serverSocket.accept() - onIncomingSocket(socket) + onIncomingSocket(socket, self) } catch { case _: SocketTimeoutException => // its ok } @@ -67,6 +72,15 @@ private[sbt] object Server { } serverThread.start() + override def authenticate(challenge: String): Boolean = { + try { + val l = challenge.toLong + token.compareAndSet(l, Random.nextLong) + } catch { + case _: NumberFormatException => false + } + } + override def shutdown(): Unit = { log.info("shutting down server") if (portfile.exists) { @@ -78,10 +92,51 @@ private[sbt] object Server { running.set(false) } + def writeTokenfile(): Unit = { + import JsonProtocol._ + + val uri = s"tcp://$host:$port" + val t = TokenFile(uri, token.get.toString) + val jsonToken = Converter.toJson(t).get + + if (tokenfile.exists) { + IO.delete(tokenfile) + } + IO.touch(tokenfile) + ownerOnly(tokenfile) + IO.write(tokenfile, CompactPrinter(jsonToken), IO.utf8, true) + } + + /** Set the persmission of the file such that the only the owner can read/write it. */ + def ownerOnly(file: File): Unit = { + def acl(owner: UserPrincipal) = { + val builder = AclEntry.newBuilder + builder.setPrincipal(owner) + builder.setPermissions(AclEntryPermission.values(): _*) + builder.setType(AclEntryType.ALLOW) + builder.build + } + file match { + case _ if IO.isPosix => + IO.chmod("rw-------", file) + case _ if IO.hasAclFileAttributeView => + val view = file.aclFileAttributeView + view.setAcl(java.util.Collections.singletonList(acl(view.getOwner))) + case _ => () + } + } + // This file exists through the lifetime of the server. def writePortfile(): Unit = { import JsonProtocol._ - val p = PortFile(s"tcp://$host:$port", None) + + val uri = s"tcp://$host:$port" + val tokenRef = + if (auth(ServerAuthentication.Token)) { + writeTokenfile() + Some(tokenfile.toURI.toString) + } else None + val p = PortFile(uri, tokenRef) val json = Converter.toJson(p).get IO.write(portfile, CompactPrinter(json)) } diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index 1936641bf..638ccbe21 100755 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -264,9 +264,11 @@ object Defaults extends BuildCommon { .getOrElse(GCUtil.defaultForceGarbageCollection), minForcegcInterval :== GCUtil.defaultMinForcegcInterval, interactionService :== CommandLineUIService, + serverHost := "127.0.0.1", serverPort := 5000 + (Hash .toHex(Hash(appConfiguration.value.baseDirectory.toString)) - .## % 1000) + .## % 1000), + serverAuthentication := Set(ServerAuthentication.Token), )) def defaultTestTasks(key: Scoped): Seq[Setting[_]] = diff --git a/main/src/main/scala/sbt/Keys.scala b/main/src/main/scala/sbt/Keys.scala index 07d6d4d92..66c06b82e 100644 --- a/main/src/main/scala/sbt/Keys.scala +++ b/main/src/main/scala/sbt/Keys.scala @@ -127,6 +127,8 @@ object Keys { val historyPath = SettingKey(BasicKeys.historyPath) val shellPrompt = SettingKey(BasicKeys.shellPrompt) val serverPort = SettingKey(BasicKeys.serverPort) + val serverHost = SettingKey(BasicKeys.serverHost) + val serverAuthentication = SettingKey(BasicKeys.serverAuthentication) val analysis = AttributeKey[CompileAnalysis]("analysis", "Analysis of compilation, including dependencies and generated outputs.", DSetting) val watch = SettingKey(BasicKeys.watch) val suppressSbtShellNotification = settingKey[Boolean]("""True to suppress the "Executing in batch mode.." message.""").withRank(CSetting) diff --git a/main/src/main/scala/sbt/Project.scala b/main/src/main/scala/sbt/Project.scala index b1aa71d06..6ad6e9d45 100755 --- a/main/src/main/scala/sbt/Project.scala +++ b/main/src/main/scala/sbt/Project.scala @@ -16,7 +16,9 @@ import Keys.{ sessionSettings, shellPrompt, templateResolverInfos, + serverHost, serverPort, + serverAuthentication, watch } import Scope.{ Global, ThisScope } @@ -509,23 +511,30 @@ object Project extends ProjectExtra { val prompt = get(shellPrompt) val trs = (templateResolverInfos in Global get structure.data).toList.flatten val watched = get(watch) + val host: Option[String] = get(serverHost) val port: Option[Int] = get(serverPort) + val authentication: Option[Set[ServerAuthentication]] = get(serverAuthentication) val commandDefs = allCommands.distinct.flatten[Command].map(_ tag (projectCommand, true)) val newDefinedCommands = commandDefs ++ BasicCommands.removeTagged(s.definedCommands, projectCommand) - val newAttrs0 = - setCond(Watched.Configuration, watched, s.attributes).put(historyPath.key, history) - val newAttrs = setCond(serverPort.key, port, newAttrs0) - .put(historyPath.key, history) - .put(templateResolverInfos.key, trs) + val newAttrs = + s.attributes + .setCond(Watched.Configuration, watched) + .put(historyPath.key, history) + .setCond(serverPort.key, port) + .setCond(serverHost.key, host) + .setCond(serverAuthentication.key, authentication) + .put(historyPath.key, history) + .put(templateResolverInfos.key, trs) + .setCond(shellPrompt.key, prompt) s.copy( - attributes = setCond(shellPrompt.key, prompt, newAttrs), + attributes = newAttrs, definedCommands = newDefinedCommands ) } def setCond[T](key: AttributeKey[T], vopt: Option[T], attributes: AttributeMap): AttributeMap = - vopt match { case Some(v) => attributes.put(key, v); case None => attributes.remove(key) } + attributes.setCond(key, vopt) private[sbt] def checkTargets(data: Settings[Scope]): Option[String] = { val dups = overlappingTargets(allTargets(data)) diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 0902665bb..574dfd8ce 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -6,10 +6,10 @@ import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger import sbt.internal.server._ import sbt.internal.util.StringEvent -import sbt.protocol.{ EventMessage, Serialization, ChannelAcceptedEvent } +import sbt.protocol.{ EventMessage, Serialization } import scala.collection.mutable.ListBuffer import scala.annotation.tailrec -import BasicKeys.serverPort +import BasicKeys.{ serverHost, serverPort, serverAuthentication } import java.net.Socket import sjsonnew.JsonFormat import scala.concurrent.Await @@ -76,15 +76,22 @@ private[sbt] final class CommandExchange { * Check if a server instance is running already, and start one if it isn't. */ private[sbt] def runServer(s: State): State = { - def port = (s get serverPort) match { + lazy val port = (s get serverPort) match { case Some(x) => x case None => 5001 } - def onIncomingSocket(socket: Socket): Unit = { + lazy val host = (s get serverHost) match { + case Some(x) => x + case None => "127.0.0.1" + } + lazy val auth: Set[ServerAuthentication] = (s get serverAuthentication) match { + case Some(xs) => xs + case None => Set(ServerAuthentication.Token) + } + def onIncomingSocket(socket: Socket, instance: ServerInstance): Unit = { s.log.info(s"new client connected from: ${socket.getPort}") - val channel = new NetworkChannel(newChannelName, socket, Project structure s) + val channel = new NetworkChannel(newChannelName, socket, Project structure s, auth, instance) subscribe(channel) - channel.publishEventMessage(ChannelAcceptedEvent(channel.name)) } server match { case Some(x) => // do nothing @@ -92,7 +99,7 @@ private[sbt] final class CommandExchange { val portfile = (new File(".")).getAbsoluteFile / "project" / "target" / "active.json" val h = Hash.halfHashString(portfile.toURI.toString) val tokenfile = BuildPaths.getGlobalBase(s) / "server" / h / "token.json" - val x = Server.start("127.0.0.1", port, onIncomingSocket, portfile, tokenfile, s.log) + val x = Server.start(host, port, onIncomingSocket, auth, portfile, tokenfile, s.log) Await.ready(x.ready, Duration("10s")) x.ready.value match { case Some(Success(_)) => diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index 22d51c5ce..eeaa8ed7c 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -10,11 +10,16 @@ import java.util.concurrent.atomic.AtomicBoolean import sbt.protocol._ import sjsonnew._ -final class NetworkChannel(val name: String, connection: Socket, structure: BuildStructure) +final class NetworkChannel(val name: String, + connection: Socket, + structure: BuildStructure, + auth: Set[ServerAuthentication], + instance: ServerInstance) extends CommandChannel { private val running = new AtomicBoolean(true) private val delimiter: Byte = '\n'.toByte private val out = connection.getOutputStream + private var initialized = false val thread = new Thread(s"sbt-networkchannel-${connection.getPort}") { override def run(): Unit = { @@ -42,12 +47,10 @@ final class NetworkChannel(val name: String, connection: Socket, structure: Buil ) delimPos = buffer.indexOf(delimiter) } - } catch { case _: SocketTimeoutException => // its ok } } - } finally { shutdown() } @@ -72,15 +75,44 @@ final class NetworkChannel(val name: String, connection: Socket, structure: Buil } def onCommand(command: CommandMessage): Unit = command match { + case x: InitCommand => onInitCommand(x) case x: ExecCommand => onExecCommand(x) case x: SettingQuery => onSettingQuery(x) } - private def onExecCommand(cmd: ExecCommand) = - append(Exec(cmd.commandLine, cmd.execId orElse Some(Exec.newExecId), Some(CommandSource(name)))) + private def onInitCommand(cmd: InitCommand): Unit = { + if (auth(ServerAuthentication.Token)) { + cmd.token match { + case Some(x) => + instance.authenticate(x) match { + case true => + initialized = true + publishEventMessage(ChannelAcceptedEvent(name)) + case _ => sys.error("invalid token") + } + case None => sys.error("init command but without token.") + } + } else { + initialized = true + } + } - private def onSettingQuery(req: SettingQuery) = - StandardMain.exchange publishEventMessage SettingQuery.handleSettingQuery(req, structure) + private def onExecCommand(cmd: ExecCommand) = { + if (initialized) { + append( + Exec(cmd.commandLine, cmd.execId orElse Some(Exec.newExecId), Some(CommandSource(name)))) + } else { + println(s"ignoring command $cmd before initialization") + } + } + + private def onSettingQuery(req: SettingQuery) = { + if (initialized) { + StandardMain.exchange publishEventMessage SettingQuery.handleSettingQuery(req, structure) + } else { + println(s"ignoring query $req before initialization") + } + } def shutdown(): Unit = { println("Shutting down client connection") diff --git a/project/Dependencies.scala b/project/Dependencies.scala index b1b31a0b4..8745f1dfb 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -12,7 +12,7 @@ object Dependencies { val baseScalaVersion = scala212 // sbt modules - private val ioVersion = "1.0.1" + private val ioVersion = "1.1.0" private val utilVersion = "1.0.1" private val lmVersion = "1.0.2" private val zincVersion = "1.0.1" diff --git a/protocol/src/main/contraband-scala/sbt/protocol/InitCommand.scala b/protocol/src/main/contraband-scala/sbt/protocol/InitCommand.scala new file mode 100644 index 000000000..e45b25c84 --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/protocol/InitCommand.scala @@ -0,0 +1,43 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.protocol +final class InitCommand private ( + val token: Option[String], + val execId: Option[String]) extends sbt.protocol.CommandMessage() with Serializable { + + + + override def equals(o: Any): Boolean = o match { + case x: InitCommand => (this.token == x.token) && (this.execId == x.execId) + case _ => false + } + override def hashCode: Int = { + 37 * (37 * (37 * (17 + "sbt.protocol.InitCommand".##) + token.##) + execId.##) + } + override def toString: String = { + "InitCommand(" + token + ", " + execId + ")" + } + protected[this] def copy(token: Option[String] = token, execId: Option[String] = execId): InitCommand = { + new InitCommand(token, execId) + } + def withToken(token: Option[String]): InitCommand = { + copy(token = token) + } + def withToken(token: String): InitCommand = { + copy(token = Option(token)) + } + def withExecId(execId: Option[String]): InitCommand = { + copy(execId = execId) + } + def withExecId(execId: String): InitCommand = { + copy(execId = Option(execId)) + } +} +object InitCommand { + + def apply(token: Option[String], execId: Option[String]): InitCommand = new InitCommand(token, execId) + def apply(token: String, execId: String): InitCommand = new InitCommand(Option(token), Option(execId)) +} diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala index 650bf14dd..c80c7aaed 100644 --- a/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala @@ -6,6 +6,6 @@ package sbt.protocol.codec import _root_.sjsonnew.JsonFormat -trait CommandMessageFormats { self: sjsonnew.BasicJsonProtocol with sbt.protocol.codec.ExecCommandFormats with sbt.protocol.codec.SettingQueryFormats => -implicit lazy val CommandMessageFormat: JsonFormat[sbt.protocol.CommandMessage] = flatUnionFormat2[sbt.protocol.CommandMessage, sbt.protocol.ExecCommand, sbt.protocol.SettingQuery]("type") +trait CommandMessageFormats { self: sjsonnew.BasicJsonProtocol with sbt.protocol.codec.InitCommandFormats with sbt.protocol.codec.ExecCommandFormats with sbt.protocol.codec.SettingQueryFormats => +implicit lazy val CommandMessageFormat: JsonFormat[sbt.protocol.CommandMessage] = flatUnionFormat3[sbt.protocol.CommandMessage, sbt.protocol.InitCommand, sbt.protocol.ExecCommand, sbt.protocol.SettingQuery]("type") } diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/InitCommandFormats.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/InitCommandFormats.scala new file mode 100644 index 000000000..8d3f50759 --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/InitCommandFormats.scala @@ -0,0 +1,29 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.protocol.codec +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait InitCommandFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val InitCommandFormat: JsonFormat[sbt.protocol.InitCommand] = new JsonFormat[sbt.protocol.InitCommand] { + override def read[J](jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.protocol.InitCommand = { + jsOpt match { + case Some(js) => + unbuilder.beginObject(js) + val token = unbuilder.readField[Option[String]]("token") + val execId = unbuilder.readField[Option[String]]("execId") + unbuilder.endObject() + sbt.protocol.InitCommand(token, execId) + case None => + deserializationError("Expected JsObject but found None") + } + } + override def write[J](obj: sbt.protocol.InitCommand, builder: Builder[J]): Unit = { + builder.beginObject() + builder.addField("token", obj.token) + builder.addField("execId", obj.execId) + builder.endObject() + } +} +} diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala index 3b92fc46c..cc9d0fa90 100644 --- a/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala @@ -5,6 +5,7 @@ // DO NOT EDIT MANUALLY package sbt.protocol.codec trait JsonProtocol extends sjsonnew.BasicJsonProtocol + with sbt.protocol.codec.InitCommandFormats with sbt.protocol.codec.ExecCommandFormats with sbt.protocol.codec.SettingQueryFormats with sbt.protocol.codec.CommandMessageFormats diff --git a/protocol/src/main/contraband/server.contra b/protocol/src/main/contraband/server.contra index 31160a694..2976be229 100644 --- a/protocol/src/main/contraband/server.contra +++ b/protocol/src/main/contraband/server.contra @@ -7,6 +7,11 @@ package sbt.protocol interface CommandMessage { } +type InitCommand implements CommandMessage { + token: String + execId: String +} + ## Command to execute sbt command. type ExecCommand implements CommandMessage { commandLine: String! diff --git a/sbt/src/sbt-test/server/handshake/Client.scala b/sbt/src/sbt-test/server/handshake/Client.scala index 4be0d8aad..6858ff238 100644 --- a/sbt/src/sbt-test/server/handshake/Client.scala +++ b/sbt/src/sbt-test/server/handshake/Client.scala @@ -18,6 +18,10 @@ object Client extends App { val out = connection.getOutputStream val in = connection.getInputStream + out.write(s"""{ "type": "InitCommand", "token": "$getToken" }""".getBytes("utf-8")) + out.write(delimiter.toInt) + out.flush + out.write("""{ "type": "ExecCommand", "commandLine": "exit" }""".getBytes("utf-8")) out.write(delimiter.toInt) out.flush @@ -25,6 +29,34 @@ object Client extends App { val baseDirectory = new File(args(0)) IO.write(baseDirectory / "ok.txt", "ok") + def getToken: String = { + val tokenfile = new File(getTokenFile) + val json: JValue = Parser.parseFromFile(tokenfile).get + json match { + case JObject(fields) => + (fields find { _.field == "token" } map { _.value }) match { + case Some(JString(value)) => value + case _ => + sys.error("json doesn't token field that is JString") + } + case _ => sys.error("json doesn't have token field") + } + } + + def getTokenFile: URI = { + val portfile = baseDirectory / "project" / "target" / "active.json" + val json: JValue = Parser.parseFromFile(portfile).get + json match { + case JObject(fields) => + (fields find { _.field == "tokenfile" } map { _.value }) match { + case Some(JString(value)) => new URI(value) + case _ => + sys.error("json doesn't tokenfile field that is JString") + } + case _ => sys.error("json doesn't have tokenfile field") + } + } + def getPort: Int = { val portfile = baseDirectory / "project" / "target" / "active.json" val json: JValue = Parser.parseFromFile(portfile).get