implement tokenfile authentication

This commit is contained in:
Eugene Yokota 2017-09-21 23:05:48 -04:00
parent 8a8215cf1b
commit 348a077797
19 changed files with 322 additions and 33 deletions

View File

@ -132,6 +132,10 @@ val collectionProj = (project in file("internal") / "util-collection")
name := "Collections",
libraryDependencies ++= Seq(sjsonNewScalaJson.value),
mimaSettings,
mimaBinaryIssueFilters ++= Seq(
// Added private[sbt] method to capture State attributes.
exclude[ReversedMissingMethodProblem]("sbt.internal.util.AttributeMap.setCond"),
),
)
.configure(addSbtUtilPosition)
@ -292,7 +296,9 @@ lazy val commandProj = (project in file("main-command"))
mimaSettings,
mimaBinaryIssueFilters ++= Vector(
// Changed the signature of Server method. nacho cheese.
exclude[DirectMissingMethodProblem]("sbt.internal.server.Server.*")
exclude[DirectMissingMethodProblem]("sbt.internal.server.Server.*"),
// Added method to ServerInstance. This is also internal.
exclude[ReversedMissingMethodProblem]("sbt.internal.server.ServerInstance.*"),
)
)
.configure(
@ -365,6 +371,10 @@ lazy val mainProj = (project in file("main"))
baseDirectory.value / "src" / "main" / "contraband-scala",
sourceManaged in (Compile, generateContrabands) := baseDirectory.value / "src" / "main" / "contraband-scala",
mimaSettings,
mimaBinaryIssueFilters ++= Vector(
// Changed the signature of NetworkChannel ctor. internal.
exclude[DirectMissingMethodProblem]("sbt.internal.server.NetworkChannel.*"),
)
)
.configure(
addSbtIO,

View File

@ -168,6 +168,11 @@ trait AttributeMap {
/** `true` if there are no mappings in this map, `false` if there are. */
def isEmpty: Boolean
/**
* Adds the mapping `k -> opt.get` if opt is Some.
* Otherwise, it returns this map without the mapping for `k`.
*/
private[sbt] def setCond[T](k: AttributeKey[T], opt: Option[T]): AttributeMap
}
object AttributeMap {
@ -217,6 +222,12 @@ private class BasicAttributeMap(private val backing: Map[AttributeKey[_], Any])
def entries: Iterable[AttributeEntry[_]] =
for ((k: AttributeKey[kt], v) <- backing) yield AttributeEntry(k, v.asInstanceOf[kt])
private[sbt] def setCond[T](k: AttributeKey[T], opt: Option[T]): AttributeMap =
opt match {
case Some(v) => put(k, v)
case None => remove(k)
}
override def toString = entries.mkString("(", ", ", ")")
}

View File

@ -0,0 +1,26 @@
/**
* This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]].
*/
// DO NOT EDIT MANUALLY
import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError }
trait ServerAuthenticationFormats { self: sjsonnew.BasicJsonProtocol =>
implicit lazy val ServerAuthenticationFormat: JsonFormat[sbt.ServerAuthentication] = new JsonFormat[sbt.ServerAuthentication] {
override def read[J](jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.ServerAuthentication = {
jsOpt match {
case Some(js) =>
unbuilder.readString(js) match {
case "Token" => sbt.ServerAuthentication.Token
}
case None =>
deserializationError("Expected JsString but found None")
}
}
override def write[J](obj: sbt.ServerAuthentication, builder: Builder[J]): Unit = {
val str = obj match {
case sbt.ServerAuthentication.Token => "Token"
}
builder.writeString(str)
}
}
}

View File

@ -0,0 +1,12 @@
/**
* This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]].
*/
// DO NOT EDIT MANUALLY
package sbt
sealed abstract class ServerAuthentication extends Serializable
object ServerAuthentication {
case object Token extends ServerAuthentication
}

View File

@ -12,3 +12,7 @@ type Exec {
type CommandSource {
channelName: String!
}
enum ServerAuthentication {
Token
}

View File

@ -17,6 +17,15 @@ object BasicKeys {
val watch = AttributeKey[Watched]("watch", "Continuous execution configuration.", 1000)
val serverPort =
AttributeKey[Int]("server-port", "The port number used by server command.", 10000)
val serverHost =
AttributeKey[String]("serverHost", "The host used by server command.", 10000)
val serverAuthentication =
AttributeKey[Set[ServerAuthentication]]("serverAuthentication",
"Method of authenticating server command.",
10000)
private[sbt] val interactive = AttributeKey[Boolean](
"interactive",
"True if commands are currently being entered from an interactive environment.",

View File

@ -7,19 +7,22 @@ package server
import java.io.File
import java.net.{ SocketTimeoutException, InetAddress, ServerSocket, Socket }
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.{ AtomicBoolean, AtomicLong }
import java.nio.file.attribute.{ UserPrincipal, AclEntry, AclEntryPermission, AclEntryType }
import scala.concurrent.{ Future, Promise }
import scala.util.{ Try, Success, Failure }
import scala.util.{ Try, Success, Failure, Random }
import sbt.internal.util.ErrorHandling
import sbt.internal.protocol.PortFile
import sbt.internal.protocol.{ PortFile, TokenFile }
import sbt.util.Logger
import sbt.io.IO
import sbt.io.syntax._
import sjsonnew.support.scalajson.unsafe.{ Converter, CompactPrinter }
import sbt.internal.protocol.codec._
private[sbt] sealed trait ServerInstance {
def shutdown(): Unit
def ready: Future[Unit]
def authenticate(challenge: String): Boolean
}
private[sbt] object Server {
@ -31,14 +34,16 @@ private[sbt] object Server {
def start(host: String,
port: Int,
onIncomingSocket: Socket => Unit,
onIncomingSocket: (Socket, ServerInstance) => Unit,
auth: Set[ServerAuthentication],
portfile: File,
tokenfile: File,
log: Logger): ServerInstance =
new ServerInstance {
new ServerInstance { self =>
val running = new AtomicBoolean(false)
val p: Promise[Unit] = Promise[Unit]()
val ready: Future[Unit] = p.future
val token = new AtomicLong(Random.nextLong)
val serverThread = new Thread("sbt-socket-server") {
override def run(): Unit = {
@ -57,7 +62,7 @@ private[sbt] object Server {
while (running.get()) {
try {
val socket = serverSocket.accept()
onIncomingSocket(socket)
onIncomingSocket(socket, self)
} catch {
case _: SocketTimeoutException => // its ok
}
@ -67,6 +72,15 @@ private[sbt] object Server {
}
serverThread.start()
override def authenticate(challenge: String): Boolean = {
try {
val l = challenge.toLong
token.compareAndSet(l, Random.nextLong)
} catch {
case _: NumberFormatException => false
}
}
override def shutdown(): Unit = {
log.info("shutting down server")
if (portfile.exists) {
@ -78,10 +92,51 @@ private[sbt] object Server {
running.set(false)
}
def writeTokenfile(): Unit = {
import JsonProtocol._
val uri = s"tcp://$host:$port"
val t = TokenFile(uri, token.get.toString)
val jsonToken = Converter.toJson(t).get
if (tokenfile.exists) {
IO.delete(tokenfile)
}
IO.touch(tokenfile)
ownerOnly(tokenfile)
IO.write(tokenfile, CompactPrinter(jsonToken), IO.utf8, true)
}
/** Set the persmission of the file such that the only the owner can read/write it. */
def ownerOnly(file: File): Unit = {
def acl(owner: UserPrincipal) = {
val builder = AclEntry.newBuilder
builder.setPrincipal(owner)
builder.setPermissions(AclEntryPermission.values(): _*)
builder.setType(AclEntryType.ALLOW)
builder.build
}
file match {
case _ if IO.isPosix =>
IO.chmod("rw-------", file)
case _ if IO.hasAclFileAttributeView =>
val view = file.aclFileAttributeView
view.setAcl(java.util.Collections.singletonList(acl(view.getOwner)))
case _ => ()
}
}
// This file exists through the lifetime of the server.
def writePortfile(): Unit = {
import JsonProtocol._
val p = PortFile(s"tcp://$host:$port", None)
val uri = s"tcp://$host:$port"
val tokenRef =
if (auth(ServerAuthentication.Token)) {
writeTokenfile()
Some(tokenfile.toURI.toString)
} else None
val p = PortFile(uri, tokenRef)
val json = Converter.toJson(p).get
IO.write(portfile, CompactPrinter(json))
}

View File

@ -264,9 +264,11 @@ object Defaults extends BuildCommon {
.getOrElse(GCUtil.defaultForceGarbageCollection),
minForcegcInterval :== GCUtil.defaultMinForcegcInterval,
interactionService :== CommandLineUIService,
serverHost := "127.0.0.1",
serverPort := 5000 + (Hash
.toHex(Hash(appConfiguration.value.baseDirectory.toString))
.## % 1000)
.## % 1000),
serverAuthentication := Set(ServerAuthentication.Token),
))
def defaultTestTasks(key: Scoped): Seq[Setting[_]] =

View File

@ -127,6 +127,8 @@ object Keys {
val historyPath = SettingKey(BasicKeys.historyPath)
val shellPrompt = SettingKey(BasicKeys.shellPrompt)
val serverPort = SettingKey(BasicKeys.serverPort)
val serverHost = SettingKey(BasicKeys.serverHost)
val serverAuthentication = SettingKey(BasicKeys.serverAuthentication)
val analysis = AttributeKey[CompileAnalysis]("analysis", "Analysis of compilation, including dependencies and generated outputs.", DSetting)
val watch = SettingKey(BasicKeys.watch)
val suppressSbtShellNotification = settingKey[Boolean]("""True to suppress the "Executing in batch mode.." message.""").withRank(CSetting)

View File

@ -16,7 +16,9 @@ import Keys.{
sessionSettings,
shellPrompt,
templateResolverInfos,
serverHost,
serverPort,
serverAuthentication,
watch
}
import Scope.{ Global, ThisScope }
@ -509,23 +511,30 @@ object Project extends ProjectExtra {
val prompt = get(shellPrompt)
val trs = (templateResolverInfos in Global get structure.data).toList.flatten
val watched = get(watch)
val host: Option[String] = get(serverHost)
val port: Option[Int] = get(serverPort)
val authentication: Option[Set[ServerAuthentication]] = get(serverAuthentication)
val commandDefs = allCommands.distinct.flatten[Command].map(_ tag (projectCommand, true))
val newDefinedCommands = commandDefs ++ BasicCommands.removeTagged(s.definedCommands,
projectCommand)
val newAttrs0 =
setCond(Watched.Configuration, watched, s.attributes).put(historyPath.key, history)
val newAttrs = setCond(serverPort.key, port, newAttrs0)
.put(historyPath.key, history)
.put(templateResolverInfos.key, trs)
val newAttrs =
s.attributes
.setCond(Watched.Configuration, watched)
.put(historyPath.key, history)
.setCond(serverPort.key, port)
.setCond(serverHost.key, host)
.setCond(serverAuthentication.key, authentication)
.put(historyPath.key, history)
.put(templateResolverInfos.key, trs)
.setCond(shellPrompt.key, prompt)
s.copy(
attributes = setCond(shellPrompt.key, prompt, newAttrs),
attributes = newAttrs,
definedCommands = newDefinedCommands
)
}
def setCond[T](key: AttributeKey[T], vopt: Option[T], attributes: AttributeMap): AttributeMap =
vopt match { case Some(v) => attributes.put(key, v); case None => attributes.remove(key) }
attributes.setCond(key, vopt)
private[sbt] def checkTargets(data: Settings[Scope]): Option[String] = {
val dups = overlappingTargets(allTargets(data))

View File

@ -6,10 +6,10 @@ import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicInteger
import sbt.internal.server._
import sbt.internal.util.StringEvent
import sbt.protocol.{ EventMessage, Serialization, ChannelAcceptedEvent }
import sbt.protocol.{ EventMessage, Serialization }
import scala.collection.mutable.ListBuffer
import scala.annotation.tailrec
import BasicKeys.serverPort
import BasicKeys.{ serverHost, serverPort, serverAuthentication }
import java.net.Socket
import sjsonnew.JsonFormat
import scala.concurrent.Await
@ -76,15 +76,22 @@ private[sbt] final class CommandExchange {
* Check if a server instance is running already, and start one if it isn't.
*/
private[sbt] def runServer(s: State): State = {
def port = (s get serverPort) match {
lazy val port = (s get serverPort) match {
case Some(x) => x
case None => 5001
}
def onIncomingSocket(socket: Socket): Unit = {
lazy val host = (s get serverHost) match {
case Some(x) => x
case None => "127.0.0.1"
}
lazy val auth: Set[ServerAuthentication] = (s get serverAuthentication) match {
case Some(xs) => xs
case None => Set(ServerAuthentication.Token)
}
def onIncomingSocket(socket: Socket, instance: ServerInstance): Unit = {
s.log.info(s"new client connected from: ${socket.getPort}")
val channel = new NetworkChannel(newChannelName, socket, Project structure s)
val channel = new NetworkChannel(newChannelName, socket, Project structure s, auth, instance)
subscribe(channel)
channel.publishEventMessage(ChannelAcceptedEvent(channel.name))
}
server match {
case Some(x) => // do nothing
@ -92,7 +99,7 @@ private[sbt] final class CommandExchange {
val portfile = (new File(".")).getAbsoluteFile / "project" / "target" / "active.json"
val h = Hash.halfHashString(portfile.toURI.toString)
val tokenfile = BuildPaths.getGlobalBase(s) / "server" / h / "token.json"
val x = Server.start("127.0.0.1", port, onIncomingSocket, portfile, tokenfile, s.log)
val x = Server.start(host, port, onIncomingSocket, auth, portfile, tokenfile, s.log)
Await.ready(x.ready, Duration("10s"))
x.ready.value match {
case Some(Success(_)) =>

View File

@ -10,11 +10,16 @@ import java.util.concurrent.atomic.AtomicBoolean
import sbt.protocol._
import sjsonnew._
final class NetworkChannel(val name: String, connection: Socket, structure: BuildStructure)
final class NetworkChannel(val name: String,
connection: Socket,
structure: BuildStructure,
auth: Set[ServerAuthentication],
instance: ServerInstance)
extends CommandChannel {
private val running = new AtomicBoolean(true)
private val delimiter: Byte = '\n'.toByte
private val out = connection.getOutputStream
private var initialized = false
val thread = new Thread(s"sbt-networkchannel-${connection.getPort}") {
override def run(): Unit = {
@ -42,12 +47,10 @@ final class NetworkChannel(val name: String, connection: Socket, structure: Buil
)
delimPos = buffer.indexOf(delimiter)
}
} catch {
case _: SocketTimeoutException => // its ok
}
}
} finally {
shutdown()
}
@ -72,15 +75,44 @@ final class NetworkChannel(val name: String, connection: Socket, structure: Buil
}
def onCommand(command: CommandMessage): Unit = command match {
case x: InitCommand => onInitCommand(x)
case x: ExecCommand => onExecCommand(x)
case x: SettingQuery => onSettingQuery(x)
}
private def onExecCommand(cmd: ExecCommand) =
append(Exec(cmd.commandLine, cmd.execId orElse Some(Exec.newExecId), Some(CommandSource(name))))
private def onInitCommand(cmd: InitCommand): Unit = {
if (auth(ServerAuthentication.Token)) {
cmd.token match {
case Some(x) =>
instance.authenticate(x) match {
case true =>
initialized = true
publishEventMessage(ChannelAcceptedEvent(name))
case _ => sys.error("invalid token")
}
case None => sys.error("init command but without token.")
}
} else {
initialized = true
}
}
private def onSettingQuery(req: SettingQuery) =
StandardMain.exchange publishEventMessage SettingQuery.handleSettingQuery(req, structure)
private def onExecCommand(cmd: ExecCommand) = {
if (initialized) {
append(
Exec(cmd.commandLine, cmd.execId orElse Some(Exec.newExecId), Some(CommandSource(name))))
} else {
println(s"ignoring command $cmd before initialization")
}
}
private def onSettingQuery(req: SettingQuery) = {
if (initialized) {
StandardMain.exchange publishEventMessage SettingQuery.handleSettingQuery(req, structure)
} else {
println(s"ignoring query $req before initialization")
}
}
def shutdown(): Unit = {
println("Shutting down client connection")

View File

@ -12,7 +12,7 @@ object Dependencies {
val baseScalaVersion = scala212
// sbt modules
private val ioVersion = "1.0.1"
private val ioVersion = "1.1.0"
private val utilVersion = "1.0.1"
private val lmVersion = "1.0.2"
private val zincVersion = "1.0.1"

View File

@ -0,0 +1,43 @@
/**
* This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]].
*/
// DO NOT EDIT MANUALLY
package sbt.protocol
final class InitCommand private (
val token: Option[String],
val execId: Option[String]) extends sbt.protocol.CommandMessage() with Serializable {
override def equals(o: Any): Boolean = o match {
case x: InitCommand => (this.token == x.token) && (this.execId == x.execId)
case _ => false
}
override def hashCode: Int = {
37 * (37 * (37 * (17 + "sbt.protocol.InitCommand".##) + token.##) + execId.##)
}
override def toString: String = {
"InitCommand(" + token + ", " + execId + ")"
}
protected[this] def copy(token: Option[String] = token, execId: Option[String] = execId): InitCommand = {
new InitCommand(token, execId)
}
def withToken(token: Option[String]): InitCommand = {
copy(token = token)
}
def withToken(token: String): InitCommand = {
copy(token = Option(token))
}
def withExecId(execId: Option[String]): InitCommand = {
copy(execId = execId)
}
def withExecId(execId: String): InitCommand = {
copy(execId = Option(execId))
}
}
object InitCommand {
def apply(token: Option[String], execId: Option[String]): InitCommand = new InitCommand(token, execId)
def apply(token: String, execId: String): InitCommand = new InitCommand(Option(token), Option(execId))
}

View File

@ -6,6 +6,6 @@
package sbt.protocol.codec
import _root_.sjsonnew.JsonFormat
trait CommandMessageFormats { self: sjsonnew.BasicJsonProtocol with sbt.protocol.codec.ExecCommandFormats with sbt.protocol.codec.SettingQueryFormats =>
implicit lazy val CommandMessageFormat: JsonFormat[sbt.protocol.CommandMessage] = flatUnionFormat2[sbt.protocol.CommandMessage, sbt.protocol.ExecCommand, sbt.protocol.SettingQuery]("type")
trait CommandMessageFormats { self: sjsonnew.BasicJsonProtocol with sbt.protocol.codec.InitCommandFormats with sbt.protocol.codec.ExecCommandFormats with sbt.protocol.codec.SettingQueryFormats =>
implicit lazy val CommandMessageFormat: JsonFormat[sbt.protocol.CommandMessage] = flatUnionFormat3[sbt.protocol.CommandMessage, sbt.protocol.InitCommand, sbt.protocol.ExecCommand, sbt.protocol.SettingQuery]("type")
}

View File

@ -0,0 +1,29 @@
/**
* This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]].
*/
// DO NOT EDIT MANUALLY
package sbt.protocol.codec
import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError }
trait InitCommandFormats { self: sjsonnew.BasicJsonProtocol =>
implicit lazy val InitCommandFormat: JsonFormat[sbt.protocol.InitCommand] = new JsonFormat[sbt.protocol.InitCommand] {
override def read[J](jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.protocol.InitCommand = {
jsOpt match {
case Some(js) =>
unbuilder.beginObject(js)
val token = unbuilder.readField[Option[String]]("token")
val execId = unbuilder.readField[Option[String]]("execId")
unbuilder.endObject()
sbt.protocol.InitCommand(token, execId)
case None =>
deserializationError("Expected JsObject but found None")
}
}
override def write[J](obj: sbt.protocol.InitCommand, builder: Builder[J]): Unit = {
builder.beginObject()
builder.addField("token", obj.token)
builder.addField("execId", obj.execId)
builder.endObject()
}
}
}

View File

@ -5,6 +5,7 @@
// DO NOT EDIT MANUALLY
package sbt.protocol.codec
trait JsonProtocol extends sjsonnew.BasicJsonProtocol
with sbt.protocol.codec.InitCommandFormats
with sbt.protocol.codec.ExecCommandFormats
with sbt.protocol.codec.SettingQueryFormats
with sbt.protocol.codec.CommandMessageFormats

View File

@ -7,6 +7,11 @@ package sbt.protocol
interface CommandMessage {
}
type InitCommand implements CommandMessage {
token: String
execId: String
}
## Command to execute sbt command.
type ExecCommand implements CommandMessage {
commandLine: String!

View File

@ -18,6 +18,10 @@ object Client extends App {
val out = connection.getOutputStream
val in = connection.getInputStream
out.write(s"""{ "type": "InitCommand", "token": "$getToken" }""".getBytes("utf-8"))
out.write(delimiter.toInt)
out.flush
out.write("""{ "type": "ExecCommand", "commandLine": "exit" }""".getBytes("utf-8"))
out.write(delimiter.toInt)
out.flush
@ -25,6 +29,34 @@ object Client extends App {
val baseDirectory = new File(args(0))
IO.write(baseDirectory / "ok.txt", "ok")
def getToken: String = {
val tokenfile = new File(getTokenFile)
val json: JValue = Parser.parseFromFile(tokenfile).get
json match {
case JObject(fields) =>
(fields find { _.field == "token" } map { _.value }) match {
case Some(JString(value)) => value
case _ =>
sys.error("json doesn't token field that is JString")
}
case _ => sys.error("json doesn't have token field")
}
}
def getTokenFile: URI = {
val portfile = baseDirectory / "project" / "target" / "active.json"
val json: JValue = Parser.parseFromFile(portfile).get
json match {
case JObject(fields) =>
(fields find { _.field == "tokenfile" } map { _.value }) match {
case Some(JString(value)) => new URI(value)
case _ =>
sys.error("json doesn't tokenfile field that is JString")
}
case _ => sys.error("json doesn't have tokenfile field")
}
}
def getPort: Int = {
val portfile = baseDirectory / "project" / "target" / "active.json"
val json: JValue = Parser.parseFromFile(portfile).get