Merge pull request #3975 from eed3si9n/wip/serverext

make sbt server extensible
This commit is contained in:
Dale Wijnand 2018-03-15 01:58:07 +00:00 committed by GitHub
commit 087f21741d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 256 additions and 86 deletions

View File

@ -498,6 +498,7 @@ lazy val sbtProj = (project in file("sbt"))
connectInput in run in Test := true,
outputStrategy in run in Test := Some(StdoutOutput),
fork in Test := true,
parallelExecution in Test := false,
)
.configure(addSbtCompilerBridge)

View File

@ -10,6 +10,7 @@ package sbt
import java.io.File
import sbt.internal.util.AttributeKey
import sbt.internal.inc.classpath.ClassLoaderCache
import sbt.internal.server.ServerHandler
import sbt.librarymanagement.ModuleID
import sbt.util.Level
@ -39,6 +40,11 @@ object BasicKeys {
"The wire protocol for the server command.",
10000)
val fullServerHandlers =
AttributeKey[Seq[ServerHandler]]("fullServerHandlers",
"Combines default server handlers and user-defined handlers.",
10000)
val autoStartServer =
AttributeKey[Boolean](
"autoStartServer",

View File

@ -0,0 +1,69 @@
/*
* sbt
* Copyright 2011 - 2017, Lightbend, Inc.
* Copyright 2008 - 2010, Mark Harrah
* Licensed under BSD-3-Clause license (see LICENSE)
*/
package sbt
package internal
package server
import sjsonnew.JsonFormat
import sbt.internal.protocol._
import sbt.util.Logger
import sbt.protocol.{ SettingQuery => Q }
/**
* ServerHandler allows plugins to extend sbt server.
* It's a wrapper around curried function ServerCallback => JsonRpcRequestMessage => Unit.
*/
final class ServerHandler(val handler: ServerCallback => ServerIntent) {
override def toString: String = s"Serverhandler(...)"
}
object ServerHandler {
def apply(handler: ServerCallback => ServerIntent): ServerHandler =
new ServerHandler(handler)
lazy val fallback: ServerHandler = ServerHandler({ handler =>
ServerIntent(
{ case x => handler.log.debug(s"Unhandled notification received: ${x.method}: $x") },
{ case x => handler.log.debug(s"Unhandled request received: ${x.method}: $x") }
)
})
}
final class ServerIntent(val onRequest: PartialFunction[JsonRpcRequestMessage, Unit],
val onNotification: PartialFunction[JsonRpcNotificationMessage, Unit]) {
override def toString: String = s"ServerIntent(...)"
}
object ServerIntent {
def apply(onRequest: PartialFunction[JsonRpcRequestMessage, Unit],
onNotification: PartialFunction[JsonRpcNotificationMessage, Unit]): ServerIntent =
new ServerIntent(onRequest, onNotification)
def request(onRequest: PartialFunction[JsonRpcRequestMessage, Unit]): ServerIntent =
new ServerIntent(onRequest, PartialFunction.empty)
def notify(onNotification: PartialFunction[JsonRpcNotificationMessage, Unit]): ServerIntent =
new ServerIntent(PartialFunction.empty, onNotification)
}
/**
* Interface to invoke JSON-RPC response.
*/
trait ServerCallback {
def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit
def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit
def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit
def appendExec(exec: Exec): Boolean
def log: Logger
def name: String
private[sbt] def authOptions: Set[ServerAuthentication]
private[sbt] def authenticate(token: String): Boolean
private[sbt] def setInitialized(value: Boolean): Unit
private[sbt] def onSettingQuery(execId: Option[String], req: Q): Unit
}

View File

@ -26,7 +26,12 @@ import sbt.internal.librarymanagement.mavenint.{
PomExtraDependencyAttributes,
SbtPomExtraProperties
}
import sbt.internal.server.{ LanguageServerReporter, Definition }
import sbt.internal.server.{
LanguageServerReporter,
Definition,
LanguageServerProtocol,
ServerHandler
}
import sbt.internal.testing.TestLogger
import sbt.internal.util._
import sbt.internal.util.Attributed.data
@ -278,6 +283,12 @@ object Defaults extends BuildCommon {
if (serverConnectionType.value == ConnectionType.Tcp) Set(ServerAuthentication.Token)
else Set()
},
serverHandlers :== Nil,
fullServerHandlers := {
(Vector(LanguageServerProtocol.handler)
++ serverHandlers.value
++ Vector(ServerHandler.fallback))
},
insideCI :== sys.env.contains("BUILD_NUMBER") || sys.env.contains("CI"),
))

View File

@ -42,6 +42,7 @@ import sbt.internal.{
}
import sbt.io.{ FileFilter, WatchService }
import sbt.internal.io.WatchState
import sbt.internal.server.ServerHandler
import sbt.internal.util.{ AttributeKey, SourcePosition }
import sbt.librarymanagement.Configurations.CompilerPlugin
@ -136,6 +137,8 @@ object Keys {
val serverHost = SettingKey(BasicKeys.serverHost)
val serverAuthentication = SettingKey(BasicKeys.serverAuthentication)
val serverConnectionType = SettingKey(BasicKeys.serverConnectionType)
val fullServerHandlers = SettingKey(BasicKeys.fullServerHandlers)
val serverHandlers = settingKey[Seq[ServerHandler]]("User-defined server handlers.")
val analysis = AttributeKey[CompileAnalysis]("analysis", "Analysis of compilation, including dependencies and generated outputs.", DSetting)
val watch = SettingKey(BasicKeys.watch)

View File

@ -27,6 +27,7 @@ import Keys.{
serverPort,
serverAuthentication,
serverConnectionType,
fullServerHandlers,
logLevel,
watch
}
@ -44,6 +45,7 @@ import sbt.internal.{
import sbt.internal.util.{ AttributeKey, AttributeMap, Dag, Relation, Settings, ~> }
import sbt.internal.util.Types.{ const, idFun }
import sbt.internal.util.complete.DefaultParsers
import sbt.internal.server.ServerHandler
import sbt.librarymanagement.Configuration
import sbt.util.{ Show, Level }
import sjsonnew.JsonFormat
@ -475,6 +477,7 @@ object Project extends ProjectExtra {
val authentication: Option[Set[ServerAuthentication]] = get(serverAuthentication)
val connectionType: Option[ConnectionType] = get(serverConnectionType)
val srvLogLevel: Option[Level.Value] = (logLevel in (ref, serverLog)).get(structure.data)
val hs: Option[Seq[ServerHandler]] = get(fullServerHandlers)
val commandDefs = allCommands.distinct.flatten[Command].map(_ tag (projectCommand, true))
val newDefinedCommands = commandDefs ++ BasicCommands.removeTagged(s.definedCommands,
projectCommand)
@ -491,6 +494,7 @@ object Project extends ProjectExtra {
.put(templateResolverInfos.key, trs)
.setCond(shellPrompt.key, prompt)
.setCond(serverLogLevel, srvLogLevel)
.setCond(fullServerHandlers.key, hs)
s.copy(
attributes = newAttrs,
definedCommands = newDefinedCommands

View File

@ -20,6 +20,7 @@ import BasicKeys.{
serverAuthentication,
serverConnectionType,
serverLogLevel,
fullServerHandlers,
logLevel
}
import java.net.Socket
@ -102,6 +103,7 @@ private[sbt] final class CommandExchange {
s.get(serverAuthentication).getOrElse(Set(ServerAuthentication.Token))
lazy val connectionType = s.get(serverConnectionType).getOrElse(ConnectionType.Tcp)
lazy val level = s.get(serverLogLevel).orElse(s.get(logLevel)).getOrElse(Level.Warn)
lazy val handlers = s.get(fullServerHandlers).getOrElse(Nil)
def onIncomingSocket(socket: Socket, instance: ServerInstance): Unit = {
val name = newNetworkName
@ -114,7 +116,7 @@ private[sbt] final class CommandExchange {
log
}
val channel =
new NetworkChannel(name, socket, Project structure s, auth, instance, logger)
new NetworkChannel(name, socket, Project structure s, auth, instance, handlers, logger)
subscribe(channel)
}
if (server.isEmpty && firstInstance.get) {
@ -210,7 +212,7 @@ private[sbt] final class CommandExchange {
// in case we have a better client that can utilize the knowledge.
import sbt.internal.langserver.codec.JsonProtocol._
if (broadcastStringMessage || (entry.channelName contains c.name))
c.langNotify("window/logMessage", params)
c.jsonRpcNotify("window/logMessage", params)
} catch { case _: IOException => toDel += c }
}
case _ =>

View File

@ -21,10 +21,70 @@ import sbt.util.Logger
private[sbt] case class LangServerError(code: Long, message: String) extends Throwable(message)
private[sbt] object LanguageServerProtocol {
lazy val internalJsonProtocol = new InitializeOptionFormats with sjsonnew.BasicJsonProtocol {}
lazy val serverCapabilities: ServerCapabilities = {
ServerCapabilities(textDocumentSync =
TextDocumentSyncOptions(true, 0, false, false, SaveOptions(false)),
hoverProvider = false,
definitionProvider = true)
}
lazy val handler: ServerHandler = ServerHandler({
case callback: ServerCallback =>
import callback._
ServerIntent(
{
import sbt.internal.langserver.codec.JsonProtocol._
import internalJsonProtocol._
def json(r: JsonRpcRequestMessage) =
r.params.getOrElse(
throw LangServerError(ErrorCodes.InvalidParams,
s"param is expected on '${r.method}' method."))
{
case r: JsonRpcRequestMessage if r.method == "initialize" =>
if (authOptions(ServerAuthentication.Token)) {
val param = Converter.fromJson[InitializeParams](json(r)).get
val optionJson = param.initializationOptions.getOrElse(
throw LangServerError(ErrorCodes.InvalidParams,
"initializationOptions is expected on 'initialize' param."))
val opt = Converter.fromJson[InitializeOption](optionJson).get
val token = opt.token.getOrElse(sys.error("'token' is missing."))
if (authenticate(token)) ()
else throw LangServerError(ErrorCodes.InvalidRequest, "invalid token")
} else ()
setInitialized(true)
appendExec(Exec(s"collectAnalyses", Some(r.id), Some(CommandSource(name))))
jsonRpcRespond(InitializeResult(serverCapabilities), Option(r.id))
case r: JsonRpcRequestMessage if r.method == "textDocument/definition" =>
import scala.concurrent.ExecutionContext.Implicits.global
Definition.lspDefinition(json(r), r.id, CommandSource(name), log)
()
case r: JsonRpcRequestMessage if r.method == "sbt/exec" =>
val param = Converter.fromJson[SbtExecParams](json(r)).get
appendExec(Exec(param.commandLine, Some(r.id), Some(CommandSource(name))))
()
case r: JsonRpcRequestMessage if r.method == "sbt/setting" =>
import sbt.protocol.codec.JsonProtocol._
val param = Converter.fromJson[Q](json(r)).get
onSettingQuery(Option(r.id), param)
}
}, {
case n: JsonRpcNotificationMessage if n.method == "textDocument/didSave" =>
appendExec(Exec(";compile; collectAnalyses", None, Some(CommandSource(name))))
()
}
)
})
}
/**
* Implements Language Server Protocol <https://github.com/Microsoft/language-server-protocol>.
*/
private[sbt] trait LanguageServerProtocol extends CommandChannel {
private[sbt] trait LanguageServerProtocol extends CommandChannel { self =>
lazy val internalJsonProtocol = new InitializeOptionFormats with sjsonnew.BasicJsonProtocol {}
@ -34,54 +94,24 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel {
protected def log: Logger
protected def onSettingQuery(execId: Option[String], req: Q): Unit
protected def onNotification(notification: JsonRpcNotificationMessage): Unit = {
log.debug(s"onNotification: $notification")
notification.method match {
case "textDocument/didSave" =>
append(Exec(";compile; collectAnalyses", None, Some(CommandSource(name))))
()
case u => log.debug(s"Unhandled notification received: $u")
}
}
protected lazy val callbackImpl: ServerCallback = new ServerCallback {
def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit =
self.jsonRpcRespond(event, execId)
protected def onRequestMessage(request: JsonRpcRequestMessage): Unit = {
import sbt.internal.langserver.codec.JsonProtocol._
import internalJsonProtocol._
def json =
request.params.getOrElse(
throw LangServerError(ErrorCodes.InvalidParams,
s"param is expected on '${request.method}' method."))
log.debug(s"onRequestMessage: $request")
request.method match {
case "initialize" =>
if (authOptions(ServerAuthentication.Token)) {
val param = Converter.fromJson[InitializeParams](json).get
val optionJson = param.initializationOptions.getOrElse(
throw LangServerError(ErrorCodes.InvalidParams,
"initializationOptions is expected on 'initialize' param."))
val opt = Converter.fromJson[InitializeOption](optionJson).get
val token = opt.token.getOrElse(sys.error("'token' is missing."))
if (authenticate(token)) ()
else throw LangServerError(ErrorCodes.InvalidRequest, "invalid token")
} else ()
setInitialized(true)
append(Exec(s"collectAnalyses", Some(request.id), Some(CommandSource(name))))
langRespond(InitializeResult(serverCapabilities), Option(request.id))
case "textDocument/definition" =>
import scala.concurrent.ExecutionContext.Implicits.global
Definition.lspDefinition(json, request.id, CommandSource(name), log)
()
case "sbt/exec" =>
val param = Converter.fromJson[SbtExecParams](json).get
append(Exec(param.commandLine, Some(request.id), Some(CommandSource(name))))
()
case "sbt/setting" => {
import sbt.protocol.codec.JsonProtocol._
val param = Converter.fromJson[Q](json).get
onSettingQuery(Option(request.id), param)
}
case unhandledRequest => log.debug(s"Unhandled request received: $unhandledRequest")
}
def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit =
self.jsonRpcRespondError(execId, code, message)
def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit =
self.jsonRpcNotify(method, params)
def appendExec(exec: Exec): Boolean = self.append(exec)
def log: Logger = self.log
def name: String = self.name
private[sbt] def authOptions: Set[ServerAuthentication] = self.authOptions
private[sbt] def authenticate(token: String): Boolean = self.authenticate(token)
private[sbt] def setInitialized(value: Boolean): Unit = self.setInitialized(value)
private[sbt] def onSettingQuery(execId: Option[String], req: Q): Unit =
self.onSettingQuery(execId, req)
}
/**
@ -97,7 +127,7 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel {
// LanguageServerReporter sends PublishDiagnosticsParams
case "sbt.internal.langserver.PublishDiagnosticsParams" =>
// val p = event.message.asInstanceOf[PublishDiagnosticsParams]
// langNotify("textDocument/publishDiagnostics", p)
// jsonRpcNotify("textDocument/publishDiagnostics", p)
case "xsbti.Problem" =>
() // ignore
case _ =>
@ -109,7 +139,7 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel {
/**
* Respond back to Language Server's client.
*/
private[sbt] def langRespond[A: JsonFormat](event: A, execId: Option[String]): Unit = {
private[sbt] def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit = {
val m =
JsonRpcResponseMessage("2.0", execId, Option(Converter.toJson[A](event).get), None)
val bytes = Serialization.serializeResponseMessage(m)
@ -119,7 +149,9 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel {
/**
* Respond back to Language Server's client.
*/
private[sbt] def langError(execId: Option[String], code: Long, message: String): Unit = {
private[sbt] def jsonRpcRespondError(execId: Option[String],
code: Long,
message: String): Unit = {
val e = JsonRpcResponseError(code, message, None)
val m = JsonRpcResponseMessage("2.0", execId, None, Option(e))
val bytes = Serialization.serializeResponseMessage(m)
@ -129,10 +161,10 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel {
/**
* Respond back to Language Server's client.
*/
private[sbt] def langError[A: JsonFormat](execId: Option[String],
code: Long,
message: String,
data: A): Unit = {
private[sbt] def jsonRpcRespondError[A: JsonFormat](execId: Option[String],
code: Long,
message: String,
data: A): Unit = {
val e = JsonRpcResponseError(code, message, Option(Converter.toJson[A](data).get))
val m = JsonRpcResponseMessage("2.0", execId, None, Option(e))
val bytes = Serialization.serializeResponseMessage(m)
@ -142,26 +174,19 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel {
/**
* Notify to Language Server's client.
*/
private[sbt] def langNotify[A: JsonFormat](method: String, params: A): Unit = {
private[sbt] def jsonRpcNotify[A: JsonFormat](method: String, params: A): Unit = {
val m =
JsonRpcNotificationMessage("2.0", method, Option(Converter.toJson[A](params).get))
log.debug(s"langNotify: $m")
log.debug(s"jsonRpcNotify: $m")
val bytes = Serialization.serializeNotificationMessage(m)
publishBytes(bytes)
}
def logMessage(level: String, message: String): Unit = {
import sbt.internal.langserver.codec.JsonProtocol._
langNotify(
jsonRpcNotify(
"window/logMessage",
LogMessageParams(MessageType.fromLevelString(level), message)
)
}
private[sbt] lazy val serverCapabilities: ServerCapabilities = {
ServerCapabilities(textDocumentSync =
TextDocumentSyncOptions(true, 0, false, false, SaveOptions(false)),
hoverProvider = false,
definitionProvider = true)
}
}

View File

@ -26,6 +26,7 @@ final class NetworkChannel(val name: String,
structure: BuildStructure,
auth: Set[ServerAuthentication],
instance: ServerInstance,
handlers: Seq[ServerHandler],
val log: Logger)
extends CommandChannel
with LanguageServerProtocol {
@ -76,7 +77,6 @@ final class NetworkChannel(val name: String,
// contentType = ""
state = SingleLine
}
def tillEndOfLine: Option[Vector[Byte]] = {
val delimPos = buffer.indexOf(delimiter)
if (delimPos > 0) {
@ -165,6 +165,21 @@ final class NetworkChannel(val name: String,
}
}
private lazy val intents = {
val cb = callbackImpl
handlers.toVector map { h =>
h.handler(cb)
}
}
lazy val onRequestMessage: PartialFunction[JsonRpcRequestMessage, Unit] =
intents.foldLeft(PartialFunction.empty[JsonRpcRequestMessage, Unit]) {
case (f, i) => f orElse i.onRequest
}
lazy val onNotification: PartialFunction[JsonRpcNotificationMessage, Unit] =
intents.foldLeft(PartialFunction.empty[JsonRpcNotificationMessage, Unit]) {
case (f, i) => f orElse i.onNotification
}
def handleBody(chunk: Vector[Byte]): Unit = {
if (isLanguageServerProtocol) {
Serialization.deserializeJsonMessage(chunk) match {
@ -174,7 +189,7 @@ final class NetworkChannel(val name: String,
} catch {
case LangServerError(code, message) =>
log.debug(s"sending error: $code: $message")
langError(Option(req.id), code, message)
jsonRpcRespondError(Option(req.id), code, message)
}
case Right(ntf: JsonRpcNotificationMessage) =>
try {
@ -182,13 +197,13 @@ final class NetworkChannel(val name: String,
} catch {
case LangServerError(code, message) =>
log.debug(s"sending error: $code: $message")
langError(None, code, message) // new id?
jsonRpcRespondError(None, code, message) // new id?
}
case Right(msg) =>
log.debug(s"Unhandled message: $msg")
case Left(errorDesc) =>
val msg = s"Got invalid chunk from client (${new String(chunk.toArray, "UTF-8")}): " + errorDesc
langError(None, ErrorCodes.ParseError, msg)
jsonRpcRespondError(None, ErrorCodes.ParseError, msg)
}
} else {
contentType match {
@ -230,7 +245,7 @@ final class NetworkChannel(val name: String,
private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = {
if (isLanguageServerProtocol) {
langNotify(method, params)
jsonRpcNotify(method, params)
} else {
()
}
@ -240,7 +255,7 @@ final class NetworkChannel(val name: String,
if (isLanguageServerProtocol) {
event match {
case entry: StringEvent => logMessage(entry.level, entry.message)
case _ => langRespond(event, execId)
case _ => jsonRpcRespond(event, execId)
}
} else {
contentType match {
@ -341,8 +356,8 @@ final class NetworkChannel(val name: String,
if (initialized) {
import sbt.protocol.codec.JsonProtocol._
SettingQuery.handleSettingQueryEither(req, structure) match {
case Right(x) => langRespond(x, execId)
case Left(s) => langError(execId, ErrorCodes.InvalidParams, s)
case Right(x) => jsonRpcRespond(x, execId)
case Left(s) => jsonRpcRespondError(execId, ErrorCodes.InvalidParams, s)
}
} else {
log.warn(s"ignoring query $req before initialization")

View File

@ -1,6 +1,23 @@
import sbt.internal.ServerHandler
lazy val root = (project in file("."))
.settings(
Global / serverLog / logLevel := Level.Debug,
// custom handler
Global / serverHandlers += ServerHandler({ callback =>
import callback._
import sjsonnew.BasicJsonProtocol._
import sbt.internal.protocol.JsonRpcRequestMessage
ServerIntent(
{
case r: JsonRpcRequestMessage if r.method == "lunar/helo" =>
jsonRpcNotify("lunar/oleh", "")
()
},
PartialFunction.empty
)
}),
name := "handshake",
scalaVersion := "2.12.3",
)

View File

@ -20,15 +20,19 @@ class ServerSpec extends AsyncFlatSpec with Matchers {
"server" should "start" in {
withBuildSocket("handshake") { (out, in, tkn) =>
writeLine(
"""{ "jsonrpc": "2.0", "id": 3, "method": "sbt/setting", "params": { "setting": "root/name" } }""",
"""{ "jsonrpc": "2.0", "id": 3, "method": "sbt/setting", "params": { "setting": "handshake/name" } }""",
out)
Thread.sleep(100)
val l2 = contentLength(in)
println(l2)
readLine(in)
readLine(in)
val x2 = readContentLength(in, l2)
println(x2)
println(readFrame(in))
println(readFrame(in))
println(readFrame(in))
println(readFrame(in))
// println(readFrame(in))
writeLine("""{ "jsonrpc": "2.0", "id": 10, "method": "lunar/helo", "params": {} }""", out)
Thread.sleep(100)
assert(1 == 1)
}
}
@ -90,6 +94,14 @@ object ServerSpec {
writeLine(message, out)
}
def readFrame(in: InputStream): Option[String] = {
val l = contentLength(in)
println(l)
readLine(in)
readLine(in)
readContentLength(in, l)
}
def contentLength(in: InputStream): Int = {
readLine(in) map { line =>
line.drop(16).toInt
@ -98,7 +110,11 @@ object ServerSpec {
def readLine(in: InputStream): Option[String] = {
if (buffer.isEmpty) {
val bytesRead = in.read(readBuffer)
val bytesRead = try {
in.read(readBuffer)
} catch {
case _: java.io.IOException => 0
}
if (bytesRead > 0) {
buffer = buffer ++ readBuffer.toVector.take(bytesRead)
}
@ -153,11 +169,12 @@ object ServerSpec {
else {
if (n <= 0) sys.error(s"Timeout. $portfile is not found.")
else {
println(s" waiting for $portfile...")
Thread.sleep(1000)
waitForPortfile(n - 1)
}
}
waitForPortfile(10)
waitForPortfile(20)
val (sk, tkn) = ClientSocket.socket(portfile)
val out = sk.getOutputStream
val in = sk.getInputStream
@ -172,7 +189,7 @@ object ServerSpec {
sendJsonRpc(
"""{ "jsonrpc": "2.0", "id": 9, "method": "sbt/exec", "params": { "commandLine": "exit" } }""",
out)
shutdown()
// shutdown()
}
}
}