From 6b8e716428fbebdef4d3cfbd57ef1673a7e71f08 Mon Sep 17 00:00:00 2001 From: Eugene Yokota Date: Mon, 11 Sep 2017 00:34:14 -0400 Subject: [PATCH 1/7] implement server handshake test --- .travis.yml | 2 +- .../sbt-test/server/handshake/Client.scala | 35 +++++++++++++++++++ sbt/src/sbt-test/server/handshake/build.sbt | 13 +++++++ sbt/src/sbt-test/server/handshake/test | 6 ++++ 4 files changed, 55 insertions(+), 1 deletion(-) create mode 100644 sbt/src/sbt-test/server/handshake/Client.scala create mode 100644 sbt/src/sbt-test/server/handshake/build.sbt create mode 100644 sbt/src/sbt-test/server/handshake/test diff --git a/.travis.yml b/.travis.yml index a33abc64e..32110046e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -26,7 +26,7 @@ env: - SBT_CMD="scripted dependency-management/*4of4" - SBT_CMD="scripted java/* package/* reporter/* run/* project-load/*" - SBT_CMD="scripted project/*1of2" - - SBT_CMD="scripted project/*2of2" + - SBT_CMD="scripted project/*2of2 server/*" - SBT_CMD="scripted source-dependencies/*1of3" - SBT_CMD="scripted source-dependencies/*2of3" - SBT_CMD="scripted source-dependencies/*3of3" diff --git a/sbt/src/sbt-test/server/handshake/Client.scala b/sbt/src/sbt-test/server/handshake/Client.scala new file mode 100644 index 000000000..ee0255abf --- /dev/null +++ b/sbt/src/sbt-test/server/handshake/Client.scala @@ -0,0 +1,35 @@ +package example + +import java.net.{ URI, Socket, InetAddress, SocketException } +import sbt.io._ +import sbt.io.syntax._ +import java.io.File + +object Client extends App { + val host = "127.0.0.1" + val port = 5123 + val delimiter: Byte = '\n'.toByte + + println("hello") + Thread.sleep(1000) + + val connection = getConnection + val out = connection.getOutputStream + val in = connection.getInputStream + + out.write("""{ "type": "ExecCommand", "commandLine": "exit" }""".getBytes("utf-8")) + out.write(delimiter.toInt) + out.flush + + val baseDirectory = new File(args(0)) + IO.write(baseDirectory / "ok.txt", "ok") + + def getConnection: Socket = + try { + new Socket(InetAddress.getByName(host), port) + } catch { + case _ => + Thread.sleep(1000) + getConnection + } +} diff --git a/sbt/src/sbt-test/server/handshake/build.sbt b/sbt/src/sbt-test/server/handshake/build.sbt new file mode 100644 index 000000000..9692ab267 --- /dev/null +++ b/sbt/src/sbt-test/server/handshake/build.sbt @@ -0,0 +1,13 @@ +lazy val runClient = taskKey[Unit]("") + +lazy val root = (project in file(".")) + .settings( + scalaVersion := "2.12.3", + serverPort in Global := 5123, + libraryDependencies += "org.scala-sbt" %% "io" % "1.0.1", + runClient := (Def.taskDyn { + val b = baseDirectory.value + (bgRun in Compile).toTask(s""" $b""") + }).value + ) + \ No newline at end of file diff --git a/sbt/src/sbt-test/server/handshake/test b/sbt/src/sbt-test/server/handshake/test new file mode 100644 index 000000000..9c2ba1cc1 --- /dev/null +++ b/sbt/src/sbt-test/server/handshake/test @@ -0,0 +1,6 @@ +> show serverPort +> runClient +-> shell + +$ sleep 1000 +$ exists ok.txt From 9d40404915f650579f787a88f1d2abc2c9ff17f2 Mon Sep 17 00:00:00 2001 From: Eugene Yokota Date: Sun, 17 Sep 2017 19:08:45 -0400 Subject: [PATCH 2/7] JSON port file This implements JSON-based port file. Thoughout the lifetime of the sbt server there will be `cwd / "project" / "target" / "active.json"`, which contains `url` field. Using this `url` the potential client, such as IDEs can find out which port number to hit. Ref #3508 --- build.sbt | 4 ++ .../scala/sbt/internal/server/Server.scala | 38 +++++++++++++--- main/src/main/scala/sbt/Main.scala | 4 +- .../scala/sbt/internal/CommandExchange.scala | 7 ++- .../sbt/internal/protocol/PortFile.scala | 45 +++++++++++++++++++ .../sbt/internal/protocol/TokenFile.scala | 36 +++++++++++++++ .../protocol/codec/PortFileFormats.scala | 29 ++++++++++++ .../protocol/codec/TokenFileFormats.scala | 29 ++++++++++++ protocol/src/main/contraband/portfile.contra | 16 +++++++ .../sbt-test/server/handshake/Client.scala | 21 ++++++++- sbt/src/sbt-test/server/handshake/build.sbt | 1 + 11 files changed, 219 insertions(+), 11 deletions(-) create mode 100644 protocol/src/main/contraband-scala/sbt/internal/protocol/PortFile.scala create mode 100644 protocol/src/main/contraband-scala/sbt/internal/protocol/TokenFile.scala create mode 100644 protocol/src/main/contraband-scala/sbt/internal/protocol/codec/PortFileFormats.scala create mode 100644 protocol/src/main/contraband-scala/sbt/internal/protocol/codec/TokenFileFormats.scala create mode 100644 protocol/src/main/contraband/portfile.contra diff --git a/build.sbt b/build.sbt index 7e4e6d751..a9e37dd37 100644 --- a/build.sbt +++ b/build.sbt @@ -290,6 +290,10 @@ lazy val commandProj = (project in file("main-command")) sourceManaged in (Compile, generateContrabands) := baseDirectory.value / "src" / "main" / "contraband-scala", contrabandFormatsForType in generateContrabands in Compile := ContrabandConfig.getFormats, mimaSettings, + mimaBinaryIssueFilters ++= Vector( + // Changed the signature of Server method. nacho cheese. + exclude[DirectMissingMethodProblem]("sbt.internal.server.Server.*") + ) ) .configure( addSbtIO, diff --git a/main-command/src/main/scala/sbt/internal/server/Server.scala b/main-command/src/main/scala/sbt/internal/server/Server.scala index 7b764225f..45ea5ad67 100644 --- a/main-command/src/main/scala/sbt/internal/server/Server.scala +++ b/main-command/src/main/scala/sbt/internal/server/Server.scala @@ -5,12 +5,17 @@ package sbt package internal package server +import java.io.File import java.net.{ SocketTimeoutException, InetAddress, ServerSocket, Socket } import java.util.concurrent.atomic.AtomicBoolean -import sbt.util.Logger -import sbt.internal.util.ErrorHandling import scala.concurrent.{ Future, Promise } import scala.util.{ Try, Success, Failure } +import sbt.internal.util.ErrorHandling +import sbt.internal.protocol.PortFile +import sbt.util.Logger +import sbt.io.IO +import sjsonnew.support.scalajson.unsafe.{ Converter, CompactPrinter } +import sbt.internal.protocol.codec._ private[sbt] sealed trait ServerInstance { def shutdown(): Unit @@ -18,14 +23,19 @@ private[sbt] sealed trait ServerInstance { } private[sbt] object Server { + sealed trait JsonProtocol + extends sjsonnew.BasicJsonProtocol + with PortFileFormats + with TokenFileFormats + object JsonProtocol extends JsonProtocol + def start(host: String, port: Int, onIncomingSocket: Socket => Unit, - /*onIncomingCommand: CommandMessage => Unit,*/ log: Logger): ServerInstance = + portfile: File, + tokenfile: File, + log: Logger): ServerInstance = new ServerInstance { - - // val lock = new AnyRef {} - // val clients: mutable.ListBuffer[ClientConnection] = mutable.ListBuffer.empty val running = new AtomicBoolean(false) val p: Promise[Unit] = Promise[Unit]() val ready: Future[Unit] = p.future @@ -41,6 +51,7 @@ private[sbt] object Server { case Success(serverSocket) => serverSocket.setSoTimeout(5000) log.info(s"sbt server started at $host:$port") + writePortfile() running.set(true) p.success(()) while (running.get()) { @@ -58,8 +69,21 @@ private[sbt] object Server { override def shutdown(): Unit = { log.info("shutting down server") + if (portfile.exists) { + IO.delete(portfile) + } + if (tokenfile.exists) { + IO.delete(tokenfile) + } running.set(false) } - } + // This file exists through the lifetime of the server. + def writePortfile(): Unit = { + import JsonProtocol._ + val p = PortFile(s"tcp://$host:$port", None) + val json = Converter.toJson(p).get + IO.write(portfile, CompactPrinter(json)) + } + } } diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index f4043f989..47a67a0fd 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -101,7 +101,9 @@ object StandardMain { val previous = TrapExit.installManager() try { try { - MainLoop.runLogged(s) + try { + MainLoop.runLogged(s) + } finally exchange.shutdown } finally DefaultBackgroundJobService.backgroundJobService.shutdown() } finally TrapExit.uninstallManager(previous) } diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index d56edb245..9cf8f6c62 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -15,6 +15,8 @@ import sjsonnew.JsonFormat import scala.concurrent.Await import scala.concurrent.duration.Duration import scala.util.{ Success, Failure } +import sbt.io.syntax._ +import sbt.io.Hash /** * The command exchange merges multiple command channels (e.g. network and console), @@ -87,7 +89,10 @@ private[sbt] final class CommandExchange { server match { case Some(x) => // do nothing case _ => - val x = Server.start("127.0.0.1", port, onIncomingSocket, s.log) + val portfile = (new File(".")).getAbsoluteFile / "project" / "target" / "active.json" + val h = Hash.halfHashString(portfile.toURL.toString) + val tokenfile = BuildPaths.getGlobalBase(s) / "server" / h / "token.json" + val x = Server.start("127.0.0.1", port, onIncomingSocket, portfile, tokenfile, s.log) Await.ready(x.ready, Duration("10s")) x.ready.value match { case Some(Success(_)) => diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/PortFile.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/PortFile.scala new file mode 100644 index 000000000..542ab368f --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/PortFile.scala @@ -0,0 +1,45 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.internal.protocol +/** + * This file should exist throughout the lifetime of the server. + * It can be used to find out the transport protocol (port number etc). + */ +final class PortFile private ( + /** URL of the sbt server. */ + val url: String, + val tokenfile: Option[String]) extends Serializable { + + + + override def equals(o: Any): Boolean = o match { + case x: PortFile => (this.url == x.url) && (this.tokenfile == x.tokenfile) + case _ => false + } + override def hashCode: Int = { + 37 * (37 * (37 * (17 + "sbt.internal.protocol.PortFile".##) + url.##) + tokenfile.##) + } + override def toString: String = { + "PortFile(" + url + ", " + tokenfile + ")" + } + protected[this] def copy(url: String = url, tokenfile: Option[String] = tokenfile): PortFile = { + new PortFile(url, tokenfile) + } + def withUrl(url: String): PortFile = { + copy(url = url) + } + def withTokenfile(tokenfile: Option[String]): PortFile = { + copy(tokenfile = tokenfile) + } + def withTokenfile(tokenfile: String): PortFile = { + copy(tokenfile = Option(tokenfile)) + } +} +object PortFile { + + def apply(url: String, tokenfile: Option[String]): PortFile = new PortFile(url, tokenfile) + def apply(url: String, tokenfile: String): PortFile = new PortFile(url, Option(tokenfile)) +} diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/TokenFile.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/TokenFile.scala new file mode 100644 index 000000000..12f1ac21c --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/TokenFile.scala @@ -0,0 +1,36 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.internal.protocol +final class TokenFile private ( + val url: String, + val token: String) extends Serializable { + + + + override def equals(o: Any): Boolean = o match { + case x: TokenFile => (this.url == x.url) && (this.token == x.token) + case _ => false + } + override def hashCode: Int = { + 37 * (37 * (37 * (17 + "sbt.internal.protocol.TokenFile".##) + url.##) + token.##) + } + override def toString: String = { + "TokenFile(" + url + ", " + token + ")" + } + protected[this] def copy(url: String = url, token: String = token): TokenFile = { + new TokenFile(url, token) + } + def withUrl(url: String): TokenFile = { + copy(url = url) + } + def withToken(token: String): TokenFile = { + copy(token = token) + } +} +object TokenFile { + + def apply(url: String, token: String): TokenFile = new TokenFile(url, token) +} diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/PortFileFormats.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/PortFileFormats.scala new file mode 100644 index 000000000..daac10706 --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/PortFileFormats.scala @@ -0,0 +1,29 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.internal.protocol.codec +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait PortFileFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val PortFileFormat: JsonFormat[sbt.internal.protocol.PortFile] = new JsonFormat[sbt.internal.protocol.PortFile] { + override def read[J](jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.internal.protocol.PortFile = { + jsOpt match { + case Some(js) => + unbuilder.beginObject(js) + val url = unbuilder.readField[String]("url") + val tokenfile = unbuilder.readField[Option[String]]("tokenfile") + unbuilder.endObject() + sbt.internal.protocol.PortFile(url, tokenfile) + case None => + deserializationError("Expected JsObject but found None") + } + } + override def write[J](obj: sbt.internal.protocol.PortFile, builder: Builder[J]): Unit = { + builder.beginObject() + builder.addField("url", obj.url) + builder.addField("tokenfile", obj.tokenfile) + builder.endObject() + } +} +} diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/TokenFileFormats.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/TokenFileFormats.scala new file mode 100644 index 000000000..d812593b2 --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/TokenFileFormats.scala @@ -0,0 +1,29 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.internal.protocol.codec +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait TokenFileFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val TokenFileFormat: JsonFormat[sbt.internal.protocol.TokenFile] = new JsonFormat[sbt.internal.protocol.TokenFile] { + override def read[J](jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.internal.protocol.TokenFile = { + jsOpt match { + case Some(js) => + unbuilder.beginObject(js) + val url = unbuilder.readField[String]("url") + val token = unbuilder.readField[String]("token") + unbuilder.endObject() + sbt.internal.protocol.TokenFile(url, token) + case None => + deserializationError("Expected JsObject but found None") + } + } + override def write[J](obj: sbt.internal.protocol.TokenFile, builder: Builder[J]): Unit = { + builder.beginObject() + builder.addField("url", obj.url) + builder.addField("token", obj.token) + builder.endObject() + } +} +} diff --git a/protocol/src/main/contraband/portfile.contra b/protocol/src/main/contraband/portfile.contra new file mode 100644 index 000000000..f10d9c170 --- /dev/null +++ b/protocol/src/main/contraband/portfile.contra @@ -0,0 +1,16 @@ +package sbt.internal.protocol +@target(Scala) +@codecPackage("sbt.internal.protocol.codec") + +## This file should exist throughout the lifetime of the server. +## It can be used to find out the transport protocol (port number etc). +type PortFile { + ## URL of the sbt server. + url: String! + tokenfile: String +} + +type TokenFile { + url: String! + token: String! +} diff --git a/sbt/src/sbt-test/server/handshake/Client.scala b/sbt/src/sbt-test/server/handshake/Client.scala index ee0255abf..54b7097f3 100644 --- a/sbt/src/sbt-test/server/handshake/Client.scala +++ b/sbt/src/sbt-test/server/handshake/Client.scala @@ -4,10 +4,11 @@ import java.net.{ URI, Socket, InetAddress, SocketException } import sbt.io._ import sbt.io.syntax._ import java.io.File +import sjsonnew.support.scalajson.unsafe.{ Parser, Converter, CompactPrinter } +import sjsonnew.shaded.scalajson.ast.unsafe.{ JValue, JObject, JString } object Client extends App { val host = "127.0.0.1" - val port = 5123 val delimiter: Byte = '\n'.toByte println("hello") @@ -24,9 +25,25 @@ object Client extends App { val baseDirectory = new File(args(0)) IO.write(baseDirectory / "ok.txt", "ok") + def getPort: Int = { + val portfile = baseDirectory / "project" / "target" / "active.json" + val json: JValue = Parser.parseFromFile(portfile).get + json match { + case JObject(fields) => + (fields find { _.field == "url" } map { _.value }) match { + case Some(JString(value)) => + val u = new URI(value) + u.getPort + case _ => + sys.error("json doesn't url field that is JString") + } + case _ => sys.error("json doesn't have url field") + } + } + def getConnection: Socket = try { - new Socket(InetAddress.getByName(host), port) + new Socket(InetAddress.getByName(host), getPort) } catch { case _ => Thread.sleep(1000) diff --git a/sbt/src/sbt-test/server/handshake/build.sbt b/sbt/src/sbt-test/server/handshake/build.sbt index 9692ab267..fd924f0e4 100644 --- a/sbt/src/sbt-test/server/handshake/build.sbt +++ b/sbt/src/sbt-test/server/handshake/build.sbt @@ -5,6 +5,7 @@ lazy val root = (project in file(".")) scalaVersion := "2.12.3", serverPort in Global := 5123, libraryDependencies += "org.scala-sbt" %% "io" % "1.0.1", + libraryDependencies += "com.eed3si9n" %% "sjson-new-scalajson" % "0.8.0", runClient := (Def.taskDyn { val b = baseDirectory.value (bgRun in Compile).toTask(s""" $b""") From c5bfc6775085594de58d5ff48fc33e14fd81e4bf Mon Sep 17 00:00:00 2001 From: Eugene Yokota Date: Sun, 17 Sep 2017 22:31:57 -0400 Subject: [PATCH 3/7] Fixes test --- main/src/test/scala/sbt/internal/server/SettingQueryTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main/src/test/scala/sbt/internal/server/SettingQueryTest.scala b/main/src/test/scala/sbt/internal/server/SettingQueryTest.scala index f254cad6e..3ef94c83b 100644 --- a/main/src/test/scala/sbt/internal/server/SettingQueryTest.scala +++ b/main/src/test/scala/sbt/internal/server/SettingQueryTest.scala @@ -172,7 +172,7 @@ object SettingQueryTest extends org.specs2.mutable.Specification { def query(setting: String): String = { import sbt.protocol._ - val req: SettingQuery = protocol.SettingQuery(setting) + val req: SettingQuery = sbt.protocol.SettingQuery(setting) val rsp: SettingQueryResponse = server.SettingQuery.handleSettingQuery(req, structure) val bytes: Array[Byte] = Serialization serializeEventMessage rsp val payload: String = new String(bytes, java.nio.charset.StandardCharsets.UTF_8) From 8a8215cf1b25970d701a1c331533d73b8a63b977 Mon Sep 17 00:00:00 2001 From: Eugene Yokota Date: Mon, 18 Sep 2017 23:07:29 -0400 Subject: [PATCH 4/7] Use uri instead of url --- .../scala/sbt/internal/CommandExchange.scala | 2 +- .../sbt/internal/protocol/PortFile.scala | 22 +++++++++---------- .../sbt/internal/protocol/TokenFile.scala | 18 +++++++-------- .../protocol/codec/PortFileFormats.scala | 6 ++--- .../protocol/codec/TokenFileFormats.scala | 6 ++--- protocol/src/main/contraband/portfile.contra | 6 ++--- .../sbt-test/server/handshake/Client.scala | 6 ++--- 7 files changed, 33 insertions(+), 33 deletions(-) diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 9cf8f6c62..0902665bb 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -90,7 +90,7 @@ private[sbt] final class CommandExchange { case Some(x) => // do nothing case _ => val portfile = (new File(".")).getAbsoluteFile / "project" / "target" / "active.json" - val h = Hash.halfHashString(portfile.toURL.toString) + val h = Hash.halfHashString(portfile.toURI.toString) val tokenfile = BuildPaths.getGlobalBase(s) / "server" / h / "token.json" val x = Server.start("127.0.0.1", port, onIncomingSocket, portfile, tokenfile, s.log) Await.ready(x.ready, Duration("10s")) diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/PortFile.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/PortFile.scala index 542ab368f..648ab25b5 100644 --- a/protocol/src/main/contraband-scala/sbt/internal/protocol/PortFile.scala +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/PortFile.scala @@ -9,27 +9,27 @@ package sbt.internal.protocol * It can be used to find out the transport protocol (port number etc). */ final class PortFile private ( - /** URL of the sbt server. */ - val url: String, + /** URI of the sbt server. */ + val uri: String, val tokenfile: Option[String]) extends Serializable { override def equals(o: Any): Boolean = o match { - case x: PortFile => (this.url == x.url) && (this.tokenfile == x.tokenfile) + case x: PortFile => (this.uri == x.uri) && (this.tokenfile == x.tokenfile) case _ => false } override def hashCode: Int = { - 37 * (37 * (37 * (17 + "sbt.internal.protocol.PortFile".##) + url.##) + tokenfile.##) + 37 * (37 * (37 * (17 + "sbt.internal.protocol.PortFile".##) + uri.##) + tokenfile.##) } override def toString: String = { - "PortFile(" + url + ", " + tokenfile + ")" + "PortFile(" + uri + ", " + tokenfile + ")" } - protected[this] def copy(url: String = url, tokenfile: Option[String] = tokenfile): PortFile = { - new PortFile(url, tokenfile) + protected[this] def copy(uri: String = uri, tokenfile: Option[String] = tokenfile): PortFile = { + new PortFile(uri, tokenfile) } - def withUrl(url: String): PortFile = { - copy(url = url) + def withUri(uri: String): PortFile = { + copy(uri = uri) } def withTokenfile(tokenfile: Option[String]): PortFile = { copy(tokenfile = tokenfile) @@ -40,6 +40,6 @@ final class PortFile private ( } object PortFile { - def apply(url: String, tokenfile: Option[String]): PortFile = new PortFile(url, tokenfile) - def apply(url: String, tokenfile: String): PortFile = new PortFile(url, Option(tokenfile)) + def apply(uri: String, tokenfile: Option[String]): PortFile = new PortFile(uri, tokenfile) + def apply(uri: String, tokenfile: String): PortFile = new PortFile(uri, Option(tokenfile)) } diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/TokenFile.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/TokenFile.scala index 12f1ac21c..e2019147e 100644 --- a/protocol/src/main/contraband-scala/sbt/internal/protocol/TokenFile.scala +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/TokenFile.scala @@ -5,26 +5,26 @@ // DO NOT EDIT MANUALLY package sbt.internal.protocol final class TokenFile private ( - val url: String, + val uri: String, val token: String) extends Serializable { override def equals(o: Any): Boolean = o match { - case x: TokenFile => (this.url == x.url) && (this.token == x.token) + case x: TokenFile => (this.uri == x.uri) && (this.token == x.token) case _ => false } override def hashCode: Int = { - 37 * (37 * (37 * (17 + "sbt.internal.protocol.TokenFile".##) + url.##) + token.##) + 37 * (37 * (37 * (17 + "sbt.internal.protocol.TokenFile".##) + uri.##) + token.##) } override def toString: String = { - "TokenFile(" + url + ", " + token + ")" + "TokenFile(" + uri + ", " + token + ")" } - protected[this] def copy(url: String = url, token: String = token): TokenFile = { - new TokenFile(url, token) + protected[this] def copy(uri: String = uri, token: String = token): TokenFile = { + new TokenFile(uri, token) } - def withUrl(url: String): TokenFile = { - copy(url = url) + def withUri(uri: String): TokenFile = { + copy(uri = uri) } def withToken(token: String): TokenFile = { copy(token = token) @@ -32,5 +32,5 @@ final class TokenFile private ( } object TokenFile { - def apply(url: String, token: String): TokenFile = new TokenFile(url, token) + def apply(uri: String, token: String): TokenFile = new TokenFile(uri, token) } diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/PortFileFormats.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/PortFileFormats.scala index daac10706..8bed2e6f9 100644 --- a/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/PortFileFormats.scala +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/PortFileFormats.scala @@ -11,17 +11,17 @@ implicit lazy val PortFileFormat: JsonFormat[sbt.internal.protocol.PortFile] = n jsOpt match { case Some(js) => unbuilder.beginObject(js) - val url = unbuilder.readField[String]("url") + val uri = unbuilder.readField[String]("uri") val tokenfile = unbuilder.readField[Option[String]]("tokenfile") unbuilder.endObject() - sbt.internal.protocol.PortFile(url, tokenfile) + sbt.internal.protocol.PortFile(uri, tokenfile) case None => deserializationError("Expected JsObject but found None") } } override def write[J](obj: sbt.internal.protocol.PortFile, builder: Builder[J]): Unit = { builder.beginObject() - builder.addField("url", obj.url) + builder.addField("uri", obj.uri) builder.addField("tokenfile", obj.tokenfile) builder.endObject() } diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/TokenFileFormats.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/TokenFileFormats.scala index d812593b2..74a15423f 100644 --- a/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/TokenFileFormats.scala +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/TokenFileFormats.scala @@ -11,17 +11,17 @@ implicit lazy val TokenFileFormat: JsonFormat[sbt.internal.protocol.TokenFile] = jsOpt match { case Some(js) => unbuilder.beginObject(js) - val url = unbuilder.readField[String]("url") + val uri = unbuilder.readField[String]("uri") val token = unbuilder.readField[String]("token") unbuilder.endObject() - sbt.internal.protocol.TokenFile(url, token) + sbt.internal.protocol.TokenFile(uri, token) case None => deserializationError("Expected JsObject but found None") } } override def write[J](obj: sbt.internal.protocol.TokenFile, builder: Builder[J]): Unit = { builder.beginObject() - builder.addField("url", obj.url) + builder.addField("uri", obj.uri) builder.addField("token", obj.token) builder.endObject() } diff --git a/protocol/src/main/contraband/portfile.contra b/protocol/src/main/contraband/portfile.contra index f10d9c170..ccd3ec157 100644 --- a/protocol/src/main/contraband/portfile.contra +++ b/protocol/src/main/contraband/portfile.contra @@ -5,12 +5,12 @@ package sbt.internal.protocol ## This file should exist throughout the lifetime of the server. ## It can be used to find out the transport protocol (port number etc). type PortFile { - ## URL of the sbt server. - url: String! + ## URI of the sbt server. + uri: String! tokenfile: String } type TokenFile { - url: String! + uri: String! token: String! } diff --git a/sbt/src/sbt-test/server/handshake/Client.scala b/sbt/src/sbt-test/server/handshake/Client.scala index 54b7097f3..4be0d8aad 100644 --- a/sbt/src/sbt-test/server/handshake/Client.scala +++ b/sbt/src/sbt-test/server/handshake/Client.scala @@ -30,14 +30,14 @@ object Client extends App { val json: JValue = Parser.parseFromFile(portfile).get json match { case JObject(fields) => - (fields find { _.field == "url" } map { _.value }) match { + (fields find { _.field == "uri" } map { _.value }) match { case Some(JString(value)) => val u = new URI(value) u.getPort case _ => - sys.error("json doesn't url field that is JString") + sys.error("json doesn't uri field that is JString") } - case _ => sys.error("json doesn't have url field") + case _ => sys.error("json doesn't have uri field") } } From 348a07779715ba9cbd9da80c8dffa3f0b8eb52cd Mon Sep 17 00:00:00 2001 From: Eugene Yokota Date: Thu, 21 Sep 2017 23:05:48 -0400 Subject: [PATCH 5/7] implement tokenfile authentication --- build.sbt | 12 +++- .../scala/sbt/internal/util/Attributes.scala | 11 +++ .../ServerAuthenticationFormats.scala | 26 +++++++ .../sbt/ServerAuthentication.scala | 12 ++++ main-command/src/main/contraband/state.contra | 4 ++ .../src/main/scala/sbt/BasicKeys.scala | 9 +++ .../scala/sbt/internal/server/Server.scala | 69 +++++++++++++++++-- main/src/main/scala/sbt/Defaults.scala | 4 +- main/src/main/scala/sbt/Keys.scala | 2 + main/src/main/scala/sbt/Project.scala | 23 +++++-- .../scala/sbt/internal/CommandExchange.scala | 21 ++++-- .../sbt/internal/server/NetworkChannel.scala | 46 +++++++++++-- project/Dependencies.scala | 2 +- .../sbt/protocol/InitCommand.scala | 43 ++++++++++++ .../codec/CommandMessageFormats.scala | 4 +- .../protocol/codec/InitCommandFormats.scala | 29 ++++++++ .../sbt/protocol/codec/JsonProtocol.scala | 1 + protocol/src/main/contraband/server.contra | 5 ++ .../sbt-test/server/handshake/Client.scala | 32 +++++++++ 19 files changed, 322 insertions(+), 33 deletions(-) create mode 100644 main-command/src/main/contraband-scala/ServerAuthenticationFormats.scala create mode 100644 main-command/src/main/contraband-scala/sbt/ServerAuthentication.scala create mode 100644 protocol/src/main/contraband-scala/sbt/protocol/InitCommand.scala create mode 100644 protocol/src/main/contraband-scala/sbt/protocol/codec/InitCommandFormats.scala diff --git a/build.sbt b/build.sbt index a9e37dd37..6f2185958 100644 --- a/build.sbt +++ b/build.sbt @@ -132,6 +132,10 @@ val collectionProj = (project in file("internal") / "util-collection") name := "Collections", libraryDependencies ++= Seq(sjsonNewScalaJson.value), mimaSettings, + mimaBinaryIssueFilters ++= Seq( + // Added private[sbt] method to capture State attributes. + exclude[ReversedMissingMethodProblem]("sbt.internal.util.AttributeMap.setCond"), + ), ) .configure(addSbtUtilPosition) @@ -292,7 +296,9 @@ lazy val commandProj = (project in file("main-command")) mimaSettings, mimaBinaryIssueFilters ++= Vector( // Changed the signature of Server method. nacho cheese. - exclude[DirectMissingMethodProblem]("sbt.internal.server.Server.*") + exclude[DirectMissingMethodProblem]("sbt.internal.server.Server.*"), + // Added method to ServerInstance. This is also internal. + exclude[ReversedMissingMethodProblem]("sbt.internal.server.ServerInstance.*"), ) ) .configure( @@ -365,6 +371,10 @@ lazy val mainProj = (project in file("main")) baseDirectory.value / "src" / "main" / "contraband-scala", sourceManaged in (Compile, generateContrabands) := baseDirectory.value / "src" / "main" / "contraband-scala", mimaSettings, + mimaBinaryIssueFilters ++= Vector( + // Changed the signature of NetworkChannel ctor. internal. + exclude[DirectMissingMethodProblem]("sbt.internal.server.NetworkChannel.*"), + ) ) .configure( addSbtIO, diff --git a/internal/util-collection/src/main/scala/sbt/internal/util/Attributes.scala b/internal/util-collection/src/main/scala/sbt/internal/util/Attributes.scala index d70b3df8a..7c24cfd29 100644 --- a/internal/util-collection/src/main/scala/sbt/internal/util/Attributes.scala +++ b/internal/util-collection/src/main/scala/sbt/internal/util/Attributes.scala @@ -168,6 +168,11 @@ trait AttributeMap { /** `true` if there are no mappings in this map, `false` if there are. */ def isEmpty: Boolean + /** + * Adds the mapping `k -> opt.get` if opt is Some. + * Otherwise, it returns this map without the mapping for `k`. + */ + private[sbt] def setCond[T](k: AttributeKey[T], opt: Option[T]): AttributeMap } object AttributeMap { @@ -217,6 +222,12 @@ private class BasicAttributeMap(private val backing: Map[AttributeKey[_], Any]) def entries: Iterable[AttributeEntry[_]] = for ((k: AttributeKey[kt], v) <- backing) yield AttributeEntry(k, v.asInstanceOf[kt]) + private[sbt] def setCond[T](k: AttributeKey[T], opt: Option[T]): AttributeMap = + opt match { + case Some(v) => put(k, v) + case None => remove(k) + } + override def toString = entries.mkString("(", ", ", ")") } diff --git a/main-command/src/main/contraband-scala/ServerAuthenticationFormats.scala b/main-command/src/main/contraband-scala/ServerAuthenticationFormats.scala new file mode 100644 index 000000000..38bc9f9e1 --- /dev/null +++ b/main-command/src/main/contraband-scala/ServerAuthenticationFormats.scala @@ -0,0 +1,26 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait ServerAuthenticationFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val ServerAuthenticationFormat: JsonFormat[sbt.ServerAuthentication] = new JsonFormat[sbt.ServerAuthentication] { + override def read[J](jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.ServerAuthentication = { + jsOpt match { + case Some(js) => + unbuilder.readString(js) match { + case "Token" => sbt.ServerAuthentication.Token + } + case None => + deserializationError("Expected JsString but found None") + } + } + override def write[J](obj: sbt.ServerAuthentication, builder: Builder[J]): Unit = { + val str = obj match { + case sbt.ServerAuthentication.Token => "Token" + } + builder.writeString(str) + } +} +} diff --git a/main-command/src/main/contraband-scala/sbt/ServerAuthentication.scala b/main-command/src/main/contraband-scala/sbt/ServerAuthentication.scala new file mode 100644 index 000000000..b6f074b75 --- /dev/null +++ b/main-command/src/main/contraband-scala/sbt/ServerAuthentication.scala @@ -0,0 +1,12 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt +sealed abstract class ServerAuthentication extends Serializable +object ServerAuthentication { + + + case object Token extends ServerAuthentication +} diff --git a/main-command/src/main/contraband/state.contra b/main-command/src/main/contraband/state.contra index 929523505..79d0bcaab 100644 --- a/main-command/src/main/contraband/state.contra +++ b/main-command/src/main/contraband/state.contra @@ -12,3 +12,7 @@ type Exec { type CommandSource { channelName: String! } + +enum ServerAuthentication { + Token +} diff --git a/main-command/src/main/scala/sbt/BasicKeys.scala b/main-command/src/main/scala/sbt/BasicKeys.scala index 10f1215b1..f59f3f605 100644 --- a/main-command/src/main/scala/sbt/BasicKeys.scala +++ b/main-command/src/main/scala/sbt/BasicKeys.scala @@ -17,6 +17,15 @@ object BasicKeys { val watch = AttributeKey[Watched]("watch", "Continuous execution configuration.", 1000) val serverPort = AttributeKey[Int]("server-port", "The port number used by server command.", 10000) + + val serverHost = + AttributeKey[String]("serverHost", "The host used by server command.", 10000) + + val serverAuthentication = + AttributeKey[Set[ServerAuthentication]]("serverAuthentication", + "Method of authenticating server command.", + 10000) + private[sbt] val interactive = AttributeKey[Boolean]( "interactive", "True if commands are currently being entered from an interactive environment.", diff --git a/main-command/src/main/scala/sbt/internal/server/Server.scala b/main-command/src/main/scala/sbt/internal/server/Server.scala index 45ea5ad67..fce9d591a 100644 --- a/main-command/src/main/scala/sbt/internal/server/Server.scala +++ b/main-command/src/main/scala/sbt/internal/server/Server.scala @@ -7,19 +7,22 @@ package server import java.io.File import java.net.{ SocketTimeoutException, InetAddress, ServerSocket, Socket } -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.{ AtomicBoolean, AtomicLong } +import java.nio.file.attribute.{ UserPrincipal, AclEntry, AclEntryPermission, AclEntryType } import scala.concurrent.{ Future, Promise } -import scala.util.{ Try, Success, Failure } +import scala.util.{ Try, Success, Failure, Random } import sbt.internal.util.ErrorHandling -import sbt.internal.protocol.PortFile +import sbt.internal.protocol.{ PortFile, TokenFile } import sbt.util.Logger import sbt.io.IO +import sbt.io.syntax._ import sjsonnew.support.scalajson.unsafe.{ Converter, CompactPrinter } import sbt.internal.protocol.codec._ private[sbt] sealed trait ServerInstance { def shutdown(): Unit def ready: Future[Unit] + def authenticate(challenge: String): Boolean } private[sbt] object Server { @@ -31,14 +34,16 @@ private[sbt] object Server { def start(host: String, port: Int, - onIncomingSocket: Socket => Unit, + onIncomingSocket: (Socket, ServerInstance) => Unit, + auth: Set[ServerAuthentication], portfile: File, tokenfile: File, log: Logger): ServerInstance = - new ServerInstance { + new ServerInstance { self => val running = new AtomicBoolean(false) val p: Promise[Unit] = Promise[Unit]() val ready: Future[Unit] = p.future + val token = new AtomicLong(Random.nextLong) val serverThread = new Thread("sbt-socket-server") { override def run(): Unit = { @@ -57,7 +62,7 @@ private[sbt] object Server { while (running.get()) { try { val socket = serverSocket.accept() - onIncomingSocket(socket) + onIncomingSocket(socket, self) } catch { case _: SocketTimeoutException => // its ok } @@ -67,6 +72,15 @@ private[sbt] object Server { } serverThread.start() + override def authenticate(challenge: String): Boolean = { + try { + val l = challenge.toLong + token.compareAndSet(l, Random.nextLong) + } catch { + case _: NumberFormatException => false + } + } + override def shutdown(): Unit = { log.info("shutting down server") if (portfile.exists) { @@ -78,10 +92,51 @@ private[sbt] object Server { running.set(false) } + def writeTokenfile(): Unit = { + import JsonProtocol._ + + val uri = s"tcp://$host:$port" + val t = TokenFile(uri, token.get.toString) + val jsonToken = Converter.toJson(t).get + + if (tokenfile.exists) { + IO.delete(tokenfile) + } + IO.touch(tokenfile) + ownerOnly(tokenfile) + IO.write(tokenfile, CompactPrinter(jsonToken), IO.utf8, true) + } + + /** Set the persmission of the file such that the only the owner can read/write it. */ + def ownerOnly(file: File): Unit = { + def acl(owner: UserPrincipal) = { + val builder = AclEntry.newBuilder + builder.setPrincipal(owner) + builder.setPermissions(AclEntryPermission.values(): _*) + builder.setType(AclEntryType.ALLOW) + builder.build + } + file match { + case _ if IO.isPosix => + IO.chmod("rw-------", file) + case _ if IO.hasAclFileAttributeView => + val view = file.aclFileAttributeView + view.setAcl(java.util.Collections.singletonList(acl(view.getOwner))) + case _ => () + } + } + // This file exists through the lifetime of the server. def writePortfile(): Unit = { import JsonProtocol._ - val p = PortFile(s"tcp://$host:$port", None) + + val uri = s"tcp://$host:$port" + val tokenRef = + if (auth(ServerAuthentication.Token)) { + writeTokenfile() + Some(tokenfile.toURI.toString) + } else None + val p = PortFile(uri, tokenRef) val json = Converter.toJson(p).get IO.write(portfile, CompactPrinter(json)) } diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index 1936641bf..638ccbe21 100755 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -264,9 +264,11 @@ object Defaults extends BuildCommon { .getOrElse(GCUtil.defaultForceGarbageCollection), minForcegcInterval :== GCUtil.defaultMinForcegcInterval, interactionService :== CommandLineUIService, + serverHost := "127.0.0.1", serverPort := 5000 + (Hash .toHex(Hash(appConfiguration.value.baseDirectory.toString)) - .## % 1000) + .## % 1000), + serverAuthentication := Set(ServerAuthentication.Token), )) def defaultTestTasks(key: Scoped): Seq[Setting[_]] = diff --git a/main/src/main/scala/sbt/Keys.scala b/main/src/main/scala/sbt/Keys.scala index 07d6d4d92..66c06b82e 100644 --- a/main/src/main/scala/sbt/Keys.scala +++ b/main/src/main/scala/sbt/Keys.scala @@ -127,6 +127,8 @@ object Keys { val historyPath = SettingKey(BasicKeys.historyPath) val shellPrompt = SettingKey(BasicKeys.shellPrompt) val serverPort = SettingKey(BasicKeys.serverPort) + val serverHost = SettingKey(BasicKeys.serverHost) + val serverAuthentication = SettingKey(BasicKeys.serverAuthentication) val analysis = AttributeKey[CompileAnalysis]("analysis", "Analysis of compilation, including dependencies and generated outputs.", DSetting) val watch = SettingKey(BasicKeys.watch) val suppressSbtShellNotification = settingKey[Boolean]("""True to suppress the "Executing in batch mode.." message.""").withRank(CSetting) diff --git a/main/src/main/scala/sbt/Project.scala b/main/src/main/scala/sbt/Project.scala index b1aa71d06..6ad6e9d45 100755 --- a/main/src/main/scala/sbt/Project.scala +++ b/main/src/main/scala/sbt/Project.scala @@ -16,7 +16,9 @@ import Keys.{ sessionSettings, shellPrompt, templateResolverInfos, + serverHost, serverPort, + serverAuthentication, watch } import Scope.{ Global, ThisScope } @@ -509,23 +511,30 @@ object Project extends ProjectExtra { val prompt = get(shellPrompt) val trs = (templateResolverInfos in Global get structure.data).toList.flatten val watched = get(watch) + val host: Option[String] = get(serverHost) val port: Option[Int] = get(serverPort) + val authentication: Option[Set[ServerAuthentication]] = get(serverAuthentication) val commandDefs = allCommands.distinct.flatten[Command].map(_ tag (projectCommand, true)) val newDefinedCommands = commandDefs ++ BasicCommands.removeTagged(s.definedCommands, projectCommand) - val newAttrs0 = - setCond(Watched.Configuration, watched, s.attributes).put(historyPath.key, history) - val newAttrs = setCond(serverPort.key, port, newAttrs0) - .put(historyPath.key, history) - .put(templateResolverInfos.key, trs) + val newAttrs = + s.attributes + .setCond(Watched.Configuration, watched) + .put(historyPath.key, history) + .setCond(serverPort.key, port) + .setCond(serverHost.key, host) + .setCond(serverAuthentication.key, authentication) + .put(historyPath.key, history) + .put(templateResolverInfos.key, trs) + .setCond(shellPrompt.key, prompt) s.copy( - attributes = setCond(shellPrompt.key, prompt, newAttrs), + attributes = newAttrs, definedCommands = newDefinedCommands ) } def setCond[T](key: AttributeKey[T], vopt: Option[T], attributes: AttributeMap): AttributeMap = - vopt match { case Some(v) => attributes.put(key, v); case None => attributes.remove(key) } + attributes.setCond(key, vopt) private[sbt] def checkTargets(data: Settings[Scope]): Option[String] = { val dups = overlappingTargets(allTargets(data)) diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 0902665bb..574dfd8ce 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -6,10 +6,10 @@ import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger import sbt.internal.server._ import sbt.internal.util.StringEvent -import sbt.protocol.{ EventMessage, Serialization, ChannelAcceptedEvent } +import sbt.protocol.{ EventMessage, Serialization } import scala.collection.mutable.ListBuffer import scala.annotation.tailrec -import BasicKeys.serverPort +import BasicKeys.{ serverHost, serverPort, serverAuthentication } import java.net.Socket import sjsonnew.JsonFormat import scala.concurrent.Await @@ -76,15 +76,22 @@ private[sbt] final class CommandExchange { * Check if a server instance is running already, and start one if it isn't. */ private[sbt] def runServer(s: State): State = { - def port = (s get serverPort) match { + lazy val port = (s get serverPort) match { case Some(x) => x case None => 5001 } - def onIncomingSocket(socket: Socket): Unit = { + lazy val host = (s get serverHost) match { + case Some(x) => x + case None => "127.0.0.1" + } + lazy val auth: Set[ServerAuthentication] = (s get serverAuthentication) match { + case Some(xs) => xs + case None => Set(ServerAuthentication.Token) + } + def onIncomingSocket(socket: Socket, instance: ServerInstance): Unit = { s.log.info(s"new client connected from: ${socket.getPort}") - val channel = new NetworkChannel(newChannelName, socket, Project structure s) + val channel = new NetworkChannel(newChannelName, socket, Project structure s, auth, instance) subscribe(channel) - channel.publishEventMessage(ChannelAcceptedEvent(channel.name)) } server match { case Some(x) => // do nothing @@ -92,7 +99,7 @@ private[sbt] final class CommandExchange { val portfile = (new File(".")).getAbsoluteFile / "project" / "target" / "active.json" val h = Hash.halfHashString(portfile.toURI.toString) val tokenfile = BuildPaths.getGlobalBase(s) / "server" / h / "token.json" - val x = Server.start("127.0.0.1", port, onIncomingSocket, portfile, tokenfile, s.log) + val x = Server.start(host, port, onIncomingSocket, auth, portfile, tokenfile, s.log) Await.ready(x.ready, Duration("10s")) x.ready.value match { case Some(Success(_)) => diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index 22d51c5ce..eeaa8ed7c 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -10,11 +10,16 @@ import java.util.concurrent.atomic.AtomicBoolean import sbt.protocol._ import sjsonnew._ -final class NetworkChannel(val name: String, connection: Socket, structure: BuildStructure) +final class NetworkChannel(val name: String, + connection: Socket, + structure: BuildStructure, + auth: Set[ServerAuthentication], + instance: ServerInstance) extends CommandChannel { private val running = new AtomicBoolean(true) private val delimiter: Byte = '\n'.toByte private val out = connection.getOutputStream + private var initialized = false val thread = new Thread(s"sbt-networkchannel-${connection.getPort}") { override def run(): Unit = { @@ -42,12 +47,10 @@ final class NetworkChannel(val name: String, connection: Socket, structure: Buil ) delimPos = buffer.indexOf(delimiter) } - } catch { case _: SocketTimeoutException => // its ok } } - } finally { shutdown() } @@ -72,15 +75,44 @@ final class NetworkChannel(val name: String, connection: Socket, structure: Buil } def onCommand(command: CommandMessage): Unit = command match { + case x: InitCommand => onInitCommand(x) case x: ExecCommand => onExecCommand(x) case x: SettingQuery => onSettingQuery(x) } - private def onExecCommand(cmd: ExecCommand) = - append(Exec(cmd.commandLine, cmd.execId orElse Some(Exec.newExecId), Some(CommandSource(name)))) + private def onInitCommand(cmd: InitCommand): Unit = { + if (auth(ServerAuthentication.Token)) { + cmd.token match { + case Some(x) => + instance.authenticate(x) match { + case true => + initialized = true + publishEventMessage(ChannelAcceptedEvent(name)) + case _ => sys.error("invalid token") + } + case None => sys.error("init command but without token.") + } + } else { + initialized = true + } + } - private def onSettingQuery(req: SettingQuery) = - StandardMain.exchange publishEventMessage SettingQuery.handleSettingQuery(req, structure) + private def onExecCommand(cmd: ExecCommand) = { + if (initialized) { + append( + Exec(cmd.commandLine, cmd.execId orElse Some(Exec.newExecId), Some(CommandSource(name)))) + } else { + println(s"ignoring command $cmd before initialization") + } + } + + private def onSettingQuery(req: SettingQuery) = { + if (initialized) { + StandardMain.exchange publishEventMessage SettingQuery.handleSettingQuery(req, structure) + } else { + println(s"ignoring query $req before initialization") + } + } def shutdown(): Unit = { println("Shutting down client connection") diff --git a/project/Dependencies.scala b/project/Dependencies.scala index b1b31a0b4..8745f1dfb 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -12,7 +12,7 @@ object Dependencies { val baseScalaVersion = scala212 // sbt modules - private val ioVersion = "1.0.1" + private val ioVersion = "1.1.0" private val utilVersion = "1.0.1" private val lmVersion = "1.0.2" private val zincVersion = "1.0.1" diff --git a/protocol/src/main/contraband-scala/sbt/protocol/InitCommand.scala b/protocol/src/main/contraband-scala/sbt/protocol/InitCommand.scala new file mode 100644 index 000000000..e45b25c84 --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/protocol/InitCommand.scala @@ -0,0 +1,43 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.protocol +final class InitCommand private ( + val token: Option[String], + val execId: Option[String]) extends sbt.protocol.CommandMessage() with Serializable { + + + + override def equals(o: Any): Boolean = o match { + case x: InitCommand => (this.token == x.token) && (this.execId == x.execId) + case _ => false + } + override def hashCode: Int = { + 37 * (37 * (37 * (17 + "sbt.protocol.InitCommand".##) + token.##) + execId.##) + } + override def toString: String = { + "InitCommand(" + token + ", " + execId + ")" + } + protected[this] def copy(token: Option[String] = token, execId: Option[String] = execId): InitCommand = { + new InitCommand(token, execId) + } + def withToken(token: Option[String]): InitCommand = { + copy(token = token) + } + def withToken(token: String): InitCommand = { + copy(token = Option(token)) + } + def withExecId(execId: Option[String]): InitCommand = { + copy(execId = execId) + } + def withExecId(execId: String): InitCommand = { + copy(execId = Option(execId)) + } +} +object InitCommand { + + def apply(token: Option[String], execId: Option[String]): InitCommand = new InitCommand(token, execId) + def apply(token: String, execId: String): InitCommand = new InitCommand(Option(token), Option(execId)) +} diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala index 650bf14dd..c80c7aaed 100644 --- a/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/CommandMessageFormats.scala @@ -6,6 +6,6 @@ package sbt.protocol.codec import _root_.sjsonnew.JsonFormat -trait CommandMessageFormats { self: sjsonnew.BasicJsonProtocol with sbt.protocol.codec.ExecCommandFormats with sbt.protocol.codec.SettingQueryFormats => -implicit lazy val CommandMessageFormat: JsonFormat[sbt.protocol.CommandMessage] = flatUnionFormat2[sbt.protocol.CommandMessage, sbt.protocol.ExecCommand, sbt.protocol.SettingQuery]("type") +trait CommandMessageFormats { self: sjsonnew.BasicJsonProtocol with sbt.protocol.codec.InitCommandFormats with sbt.protocol.codec.ExecCommandFormats with sbt.protocol.codec.SettingQueryFormats => +implicit lazy val CommandMessageFormat: JsonFormat[sbt.protocol.CommandMessage] = flatUnionFormat3[sbt.protocol.CommandMessage, sbt.protocol.InitCommand, sbt.protocol.ExecCommand, sbt.protocol.SettingQuery]("type") } diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/InitCommandFormats.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/InitCommandFormats.scala new file mode 100644 index 000000000..8d3f50759 --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/InitCommandFormats.scala @@ -0,0 +1,29 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.protocol.codec +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait InitCommandFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val InitCommandFormat: JsonFormat[sbt.protocol.InitCommand] = new JsonFormat[sbt.protocol.InitCommand] { + override def read[J](jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.protocol.InitCommand = { + jsOpt match { + case Some(js) => + unbuilder.beginObject(js) + val token = unbuilder.readField[Option[String]]("token") + val execId = unbuilder.readField[Option[String]]("execId") + unbuilder.endObject() + sbt.protocol.InitCommand(token, execId) + case None => + deserializationError("Expected JsObject but found None") + } + } + override def write[J](obj: sbt.protocol.InitCommand, builder: Builder[J]): Unit = { + builder.beginObject() + builder.addField("token", obj.token) + builder.addField("execId", obj.execId) + builder.endObject() + } +} +} diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala index 3b92fc46c..cc9d0fa90 100644 --- a/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/JsonProtocol.scala @@ -5,6 +5,7 @@ // DO NOT EDIT MANUALLY package sbt.protocol.codec trait JsonProtocol extends sjsonnew.BasicJsonProtocol + with sbt.protocol.codec.InitCommandFormats with sbt.protocol.codec.ExecCommandFormats with sbt.protocol.codec.SettingQueryFormats with sbt.protocol.codec.CommandMessageFormats diff --git a/protocol/src/main/contraband/server.contra b/protocol/src/main/contraband/server.contra index 31160a694..2976be229 100644 --- a/protocol/src/main/contraband/server.contra +++ b/protocol/src/main/contraband/server.contra @@ -7,6 +7,11 @@ package sbt.protocol interface CommandMessage { } +type InitCommand implements CommandMessage { + token: String + execId: String +} + ## Command to execute sbt command. type ExecCommand implements CommandMessage { commandLine: String! diff --git a/sbt/src/sbt-test/server/handshake/Client.scala b/sbt/src/sbt-test/server/handshake/Client.scala index 4be0d8aad..6858ff238 100644 --- a/sbt/src/sbt-test/server/handshake/Client.scala +++ b/sbt/src/sbt-test/server/handshake/Client.scala @@ -18,6 +18,10 @@ object Client extends App { val out = connection.getOutputStream val in = connection.getInputStream + out.write(s"""{ "type": "InitCommand", "token": "$getToken" }""".getBytes("utf-8")) + out.write(delimiter.toInt) + out.flush + out.write("""{ "type": "ExecCommand", "commandLine": "exit" }""".getBytes("utf-8")) out.write(delimiter.toInt) out.flush @@ -25,6 +29,34 @@ object Client extends App { val baseDirectory = new File(args(0)) IO.write(baseDirectory / "ok.txt", "ok") + def getToken: String = { + val tokenfile = new File(getTokenFile) + val json: JValue = Parser.parseFromFile(tokenfile).get + json match { + case JObject(fields) => + (fields find { _.field == "token" } map { _.value }) match { + case Some(JString(value)) => value + case _ => + sys.error("json doesn't token field that is JString") + } + case _ => sys.error("json doesn't have token field") + } + } + + def getTokenFile: URI = { + val portfile = baseDirectory / "project" / "target" / "active.json" + val json: JValue = Parser.parseFromFile(portfile).get + json match { + case JObject(fields) => + (fields find { _.field == "tokenfile" } map { _.value }) match { + case Some(JString(value)) => new URI(value) + case _ => + sys.error("json doesn't tokenfile field that is JString") + } + case _ => sys.error("json doesn't have tokenfile field") + } + } + def getPort: Int = { val portfile = baseDirectory / "project" / "target" / "active.json" val json: JValue = Parser.parseFromFile(portfile).get From 252e803de8c89a437f67b7e1120f5f24c511ee61 Mon Sep 17 00:00:00 2001 From: Eugene Yokota Date: Fri, 22 Sep 2017 01:30:27 -0400 Subject: [PATCH 6/7] expand the token out to 128-bits --- .../scala/sbt/internal/server/Server.scala | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/main-command/src/main/scala/sbt/internal/server/Server.scala b/main-command/src/main/scala/sbt/internal/server/Server.scala index fce9d591a..3ae06ae27 100644 --- a/main-command/src/main/scala/sbt/internal/server/Server.scala +++ b/main-command/src/main/scala/sbt/internal/server/Server.scala @@ -7,10 +7,12 @@ package server import java.io.File import java.net.{ SocketTimeoutException, InetAddress, ServerSocket, Socket } -import java.util.concurrent.atomic.{ AtomicBoolean, AtomicLong } +import java.util.concurrent.atomic.AtomicBoolean import java.nio.file.attribute.{ UserPrincipal, AclEntry, AclEntryPermission, AclEntryType } +import java.security.SecureRandom +import java.math.BigInteger import scala.concurrent.{ Future, Promise } -import scala.util.{ Try, Success, Failure, Random } +import scala.util.{ Try, Success, Failure } import sbt.internal.util.ErrorHandling import sbt.internal.protocol.{ PortFile, TokenFile } import sbt.util.Logger @@ -43,7 +45,8 @@ private[sbt] object Server { val running = new AtomicBoolean(false) val p: Promise[Unit] = Promise[Unit]() val ready: Future[Unit] = p.future - val token = new AtomicLong(Random.nextLong) + private[this] val rand = new SecureRandom + private[this] var token: String = nextToken val serverThread = new Thread("sbt-socket-server") { override def run(): Unit = { @@ -72,13 +75,17 @@ private[sbt] object Server { } serverThread.start() - override def authenticate(challenge: String): Boolean = { - try { - val l = challenge.toLong - token.compareAndSet(l, Random.nextLong) - } catch { - case _: NumberFormatException => false - } + override def authenticate(challenge: String): Boolean = synchronized { + if (token == challenge) { + token = nextToken + writeTokenfile() + true + } else false + } + + /** Generates 128-bit non-negative integer, and represent it as decimal string. */ + private[this] def nextToken: String = { + new BigInteger(128, rand).toString } override def shutdown(): Unit = { @@ -92,11 +99,11 @@ private[sbt] object Server { running.set(false) } - def writeTokenfile(): Unit = { + private[this] def writeTokenfile(): Unit = { import JsonProtocol._ val uri = s"tcp://$host:$port" - val t = TokenFile(uri, token.get.toString) + val t = TokenFile(uri, token) val jsonToken = Converter.toJson(t).get if (tokenfile.exists) { @@ -108,7 +115,7 @@ private[sbt] object Server { } /** Set the persmission of the file such that the only the owner can read/write it. */ - def ownerOnly(file: File): Unit = { + private[this] def ownerOnly(file: File): Unit = { def acl(owner: UserPrincipal) = { val builder = AclEntry.newBuilder builder.setPrincipal(owner) @@ -127,7 +134,7 @@ private[sbt] object Server { } // This file exists through the lifetime of the server. - def writePortfile(): Unit = { + private[this] def writePortfile(): Unit = { import JsonProtocol._ val uri = s"tcp://$host:$port" From d5e24979bff318a1f9089a533ef44c7e6f3c0806 Mon Sep 17 00:00:00 2001 From: Eugene Yokota Date: Mon, 25 Sep 2017 01:35:34 -0400 Subject: [PATCH 7/7] Reference token file using URI and full file path Node didn't seem to like read URI out of the box, and I am not sure if File -> URI -> File conversion is universally accepted. Ref sbt/sbt#3088 --- .../scala/sbt/internal/server/Server.scala | 14 +++++---- .../sbt/internal/protocol/PortFile.scala | 31 ++++++++++++------- .../protocol/codec/PortFileFormats.scala | 8 +++-- protocol/src/main/contraband/portfile.contra | 3 +- .../sbt-test/server/handshake/Client.scala | 6 ++-- 5 files changed, 37 insertions(+), 25 deletions(-) diff --git a/main-command/src/main/scala/sbt/internal/server/Server.scala b/main-command/src/main/scala/sbt/internal/server/Server.scala index 3ae06ae27..9c8d3ee95 100644 --- a/main-command/src/main/scala/sbt/internal/server/Server.scala +++ b/main-command/src/main/scala/sbt/internal/server/Server.scala @@ -138,12 +138,14 @@ private[sbt] object Server { import JsonProtocol._ val uri = s"tcp://$host:$port" - val tokenRef = - if (auth(ServerAuthentication.Token)) { - writeTokenfile() - Some(tokenfile.toURI.toString) - } else None - val p = PortFile(uri, tokenRef) + val p = + auth match { + case _ if auth(ServerAuthentication.Token) => + writeTokenfile() + PortFile(uri, Option(tokenfile.toString), Option(tokenfile.toURI.toString)) + case _ => + PortFile(uri, None, None) + } val json = Converter.toJson(p).get IO.write(portfile, CompactPrinter(json)) } diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/PortFile.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/PortFile.scala index 648ab25b5..218aefdfb 100644 --- a/protocol/src/main/contraband-scala/sbt/internal/protocol/PortFile.scala +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/PortFile.scala @@ -11,35 +11,42 @@ package sbt.internal.protocol final class PortFile private ( /** URI of the sbt server. */ val uri: String, - val tokenfile: Option[String]) extends Serializable { + val tokenfilePath: Option[String], + val tokenfileUri: Option[String]) extends Serializable { override def equals(o: Any): Boolean = o match { - case x: PortFile => (this.uri == x.uri) && (this.tokenfile == x.tokenfile) + case x: PortFile => (this.uri == x.uri) && (this.tokenfilePath == x.tokenfilePath) && (this.tokenfileUri == x.tokenfileUri) case _ => false } override def hashCode: Int = { - 37 * (37 * (37 * (17 + "sbt.internal.protocol.PortFile".##) + uri.##) + tokenfile.##) + 37 * (37 * (37 * (37 * (17 + "sbt.internal.protocol.PortFile".##) + uri.##) + tokenfilePath.##) + tokenfileUri.##) } override def toString: String = { - "PortFile(" + uri + ", " + tokenfile + ")" + "PortFile(" + uri + ", " + tokenfilePath + ", " + tokenfileUri + ")" } - protected[this] def copy(uri: String = uri, tokenfile: Option[String] = tokenfile): PortFile = { - new PortFile(uri, tokenfile) + protected[this] def copy(uri: String = uri, tokenfilePath: Option[String] = tokenfilePath, tokenfileUri: Option[String] = tokenfileUri): PortFile = { + new PortFile(uri, tokenfilePath, tokenfileUri) } def withUri(uri: String): PortFile = { copy(uri = uri) } - def withTokenfile(tokenfile: Option[String]): PortFile = { - copy(tokenfile = tokenfile) + def withTokenfilePath(tokenfilePath: Option[String]): PortFile = { + copy(tokenfilePath = tokenfilePath) } - def withTokenfile(tokenfile: String): PortFile = { - copy(tokenfile = Option(tokenfile)) + def withTokenfilePath(tokenfilePath: String): PortFile = { + copy(tokenfilePath = Option(tokenfilePath)) + } + def withTokenfileUri(tokenfileUri: Option[String]): PortFile = { + copy(tokenfileUri = tokenfileUri) + } + def withTokenfileUri(tokenfileUri: String): PortFile = { + copy(tokenfileUri = Option(tokenfileUri)) } } object PortFile { - def apply(uri: String, tokenfile: Option[String]): PortFile = new PortFile(uri, tokenfile) - def apply(uri: String, tokenfile: String): PortFile = new PortFile(uri, Option(tokenfile)) + def apply(uri: String, tokenfilePath: Option[String], tokenfileUri: Option[String]): PortFile = new PortFile(uri, tokenfilePath, tokenfileUri) + def apply(uri: String, tokenfilePath: String, tokenfileUri: String): PortFile = new PortFile(uri, Option(tokenfilePath), Option(tokenfileUri)) } diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/PortFileFormats.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/PortFileFormats.scala index 8bed2e6f9..0a60963d2 100644 --- a/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/PortFileFormats.scala +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/PortFileFormats.scala @@ -12,9 +12,10 @@ implicit lazy val PortFileFormat: JsonFormat[sbt.internal.protocol.PortFile] = n case Some(js) => unbuilder.beginObject(js) val uri = unbuilder.readField[String]("uri") - val tokenfile = unbuilder.readField[Option[String]]("tokenfile") + val tokenfilePath = unbuilder.readField[Option[String]]("tokenfilePath") + val tokenfileUri = unbuilder.readField[Option[String]]("tokenfileUri") unbuilder.endObject() - sbt.internal.protocol.PortFile(uri, tokenfile) + sbt.internal.protocol.PortFile(uri, tokenfilePath, tokenfileUri) case None => deserializationError("Expected JsObject but found None") } @@ -22,7 +23,8 @@ implicit lazy val PortFileFormat: JsonFormat[sbt.internal.protocol.PortFile] = n override def write[J](obj: sbt.internal.protocol.PortFile, builder: Builder[J]): Unit = { builder.beginObject() builder.addField("uri", obj.uri) - builder.addField("tokenfile", obj.tokenfile) + builder.addField("tokenfilePath", obj.tokenfilePath) + builder.addField("tokenfileUri", obj.tokenfileUri) builder.endObject() } } diff --git a/protocol/src/main/contraband/portfile.contra b/protocol/src/main/contraband/portfile.contra index ccd3ec157..82f6567ef 100644 --- a/protocol/src/main/contraband/portfile.contra +++ b/protocol/src/main/contraband/portfile.contra @@ -7,7 +7,8 @@ package sbt.internal.protocol type PortFile { ## URI of the sbt server. uri: String! - tokenfile: String + tokenfilePath: String + tokenfileUri: String } type TokenFile { diff --git a/sbt/src/sbt-test/server/handshake/Client.scala b/sbt/src/sbt-test/server/handshake/Client.scala index 6858ff238..2acbdb9bb 100644 --- a/sbt/src/sbt-test/server/handshake/Client.scala +++ b/sbt/src/sbt-test/server/handshake/Client.scala @@ -30,7 +30,7 @@ object Client extends App { IO.write(baseDirectory / "ok.txt", "ok") def getToken: String = { - val tokenfile = new File(getTokenFile) + val tokenfile = new File(getTokenFileUri) val json: JValue = Parser.parseFromFile(tokenfile).get json match { case JObject(fields) => @@ -43,12 +43,12 @@ object Client extends App { } } - def getTokenFile: URI = { + def getTokenFileUri: URI = { val portfile = baseDirectory / "project" / "target" / "active.json" val json: JValue = Parser.parseFromFile(portfile).get json match { case JObject(fields) => - (fields find { _.field == "tokenfile" } map { _.value }) match { + (fields find { _.field == "tokenfileUri" } map { _.value }) match { case Some(JString(value)) => new URI(value) case _ => sys.error("json doesn't tokenfile field that is JString")