From 25a79d0ac63b9373cc2b232d859642ba6f95c80c Mon Sep 17 00:00:00 2001 From: Eugene Yokota Date: Sun, 20 Oct 2019 22:33:02 -0400 Subject: [PATCH] Add State extension Fixes https://github.com/sbt/sbt/issues/3112 This unpacks Extracted as State's extension methods. In addition this provides a way of responding via LSP. --- build.sbt | 1 + main/src/main/scala/sbt/BuildSyntax.scala | 3 + main/src/main/scala/sbt/MainLoop.scala | 7 +- main/src/main/scala/sbt/UpperStateOps.scala | 142 ++++++++++++++++++ .../scala/sbt/internal/CommandExchange.scala | 111 ++++++++++---- .../server/LanguageServerProtocol.scala | 9 ++ .../sbt/internal/server/NetworkChannel.scala | 32 +++- .../protocol/JsonRpcResponseError.scala | 4 +- .../sbt/protocol/ExecStatusEvent.scala | 24 ++- .../codec/ExecStatusEventFormats.scala | 4 +- protocol/src/main/contraband/jsonrpc.contra | 5 + protocol/src/main/contraband/server.contra | 1 + .../src/server-test/response/build.sbt | 59 ++++++++ .../src/test/scala/testpkg/ResponseTest.scala | 67 +++++++++ 14 files changed, 422 insertions(+), 47 deletions(-) create mode 100644 main/src/main/scala/sbt/UpperStateOps.scala create mode 100644 server-test/src/server-test/response/build.sbt create mode 100644 server-test/src/test/scala/testpkg/ResponseTest.scala diff --git a/build.sbt b/build.sbt index cd96f8a51..241d0c385 100644 --- a/build.sbt +++ b/build.sbt @@ -675,6 +675,7 @@ lazy val protocolProj = (project in file("protocol")) exclude[DirectMissingMethodProblem]("sbt.protocol.SettingQuerySuccess.copy$default$*"), // ignore missing methods in sbt.internal exclude[DirectMissingMethodProblem]("sbt.internal.*"), + exclude[MissingTypesProblem]("sbt.internal.protocol.JsonRpcResponseError"), ) ) diff --git a/main/src/main/scala/sbt/BuildSyntax.scala b/main/src/main/scala/sbt/BuildSyntax.scala index ef755be7d..b2038c4a3 100644 --- a/main/src/main/scala/sbt/BuildSyntax.scala +++ b/main/src/main/scala/sbt/BuildSyntax.scala @@ -22,5 +22,8 @@ private[sbt] trait BuildSyntax { def dependsOn(deps: ClasspathDep[ProjectReference]*): DslEntry = DslEntry.DslDependsOn(deps) // avoid conflict with `sbt.Keys.aggregate` def aggregateProjects(refs: ProjectReference*): DslEntry = DslEntry.DslAggregate(refs) + + implicit def sbtStateToUpperStateOps(s: State): UpperStateOps = + new UpperStateOps.UpperStateOpsImpl(s) } private[sbt] object BuildSyntax extends BuildSyntax diff --git a/main/src/main/scala/sbt/MainLoop.scala b/main/src/main/scala/sbt/MainLoop.scala index c4351a210..7f0f8d890 100644 --- a/main/src/main/scala/sbt/MainLoop.scala +++ b/main/src/main/scala/sbt/MainLoop.scala @@ -13,6 +13,7 @@ import java.util.Properties import jline.TerminalFactory import sbt.internal.{ Aggregation, ShutdownHooks } import sbt.internal.langserver.ErrorCodes +import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.util.complete.Parser import sbt.internal.util.{ ErrorHandling, GlobalLogBacking } import sbt.io.{ IO, Using } @@ -230,6 +231,9 @@ object MainLoop { } } } catch { + case err: JsonRpcResponseError => + StandardMain.exchange.respondError(err, exec.execId, channelName.map(CommandSource(_))) + throw err case err: Throwable => val errorEvent = ExecStatusEvent( "Error", @@ -237,9 +241,10 @@ object MainLoop { exec.execId, Vector(), ExitCode(ErrorCodes.UnknownError), + Option(err.getMessage), ) import sbt.protocol.codec.JsonProtocol._ - StandardMain.exchange publishEvent errorEvent + StandardMain.exchange.publishEvent(errorEvent) throw err } } diff --git a/main/src/main/scala/sbt/UpperStateOps.scala b/main/src/main/scala/sbt/UpperStateOps.scala new file mode 100644 index 000000000..897083e38 --- /dev/null +++ b/main/src/main/scala/sbt/UpperStateOps.scala @@ -0,0 +1,142 @@ +/* + * sbt + * Copyright 2011 - 2018, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt + +import sjsonnew.JsonFormat +import Def.Setting +import sbt.internal.{ BuildStructure, LoadedBuildUnit, SessionSettings } + +/** + * Extends State with setting-level knowledge. + */ +trait UpperStateOps extends Any { + + /** + * ProjectRef to the current project of the state session that can be change using + * `project` commmand. + */ + def currentRef: ProjectRef + + /** + * Current project of the state session that can be change using `project` commmand. + */ + def currentProject: ResolvedProject + + private[sbt] def structure: BuildStructure + + private[sbt] def session: SessionSettings + + private[sbt] def currentUnit: LoadedBuildUnit + + /** + * Gets the value assigned to `key` in the computed settings map. + * If the project axis is not explicitly specified, it is resolved to be the current project according to the extracted `session`. + * Other axes are resolved to be `Zero` if they are not specified. + */ + def setting[A](key: SettingKey[A]): A + + /** + * Gets the value assigned to `key` in the computed settings map. + * If the project axis is not explicitly specified, it is resolved to be the current project according to the extracted `session`. + * Other axes are resolved to be `Zero` if they are not specified. + */ + def taskValue[A](key: TaskKey[A]): Task[A] + + /** + * Runs the task specified by `key` and returns the transformed State and the resulting value of the task. + * If the project axis is not defined for the key, it is resolved to be the current project. + * Other axes are resolved to `Zero` if unspecified. + * + * This method requests execution of only the given task and does not aggregate execution. + * See `runAggregated` for that. + * + * To avoid race conditions, this should NOT be called from a task. + */ + def unsafeRunTask[A](key: TaskKey[A]): (State, A) + + /** + * Runs the input task specified by `key`, using the `input` as the input to it, and returns the transformed State + * and the resulting value of the input task. + * + * If the project axis is not defined for the key, it is resolved to be the current project. + * Other axes are resolved to `Zero` if unspecified. + * + * This method requests execution of only the given task and does not aggregate execution. + * To avoid race conditions, this should NOT be called from a task. + */ + def unsafeRunInputTask[A](key: InputKey[A], input: String): (State, A) + + /** + * Runs the tasks selected by aggregating `key` and returns the transformed State. + * If the project axis is not defined for the key, it is resolved to be the current project. + * The project axis is what determines where aggregation starts, so ensure this is set to what you want. + * Other axes are resolved to `Zero` if unspecified. + * + * To avoid race conditions, this should NOT be called from a task. + */ + def unsafeRunAggregated[A](key: TaskKey[A]): State + + /** Appends the given settings to all the build state settings, including session settings. */ + def appendWithSession(settings: Seq[Setting[_]]): State + + /** + * Appends the given settings to the original build state settings, discarding any settings + * appended to the session in the process. + */ + def appendWithoutSession(settings: Seq[Setting[_]], state: State): State + + def respondEvent[A: JsonFormat](event: A): Unit + def respondError(code: Long, message: String): Unit + def notifyEvent[A: JsonFormat](method: String, params: A): Unit +} + +object UpperStateOps { + lazy val exchange = StandardMain.exchange + + implicit class UpperStateOpsImpl(val s: State) extends AnyVal with UpperStateOps { + def extract: Extracted = Project.extract(s) + + def currentRef = extract.currentRef + + def currentProject: ResolvedProject = extract.currentProject + + private[sbt] def structure: BuildStructure = Project.structure(s) + + private[sbt] def currentUnit: LoadedBuildUnit = extract.currentUnit + + private[sbt] def session: SessionSettings = Project.session(s) + + def setting[A](key: SettingKey[A]): A = extract.get(key) + + def taskValue[A](key: TaskKey[A]): Task[A] = extract.get(key) + + def unsafeRunTask[A](key: TaskKey[A]): (State, A) = extract.runTask(key, s) + + def unsafeRunInputTask[A](key: InputKey[A], input: String): (State, A) = + extract.runInputTask(key, input, s) + + def unsafeRunAggregated[A](key: TaskKey[A]): State = + extract.runAggregated(key, s) + + def appendWithSession(settings: Seq[Setting[_]]): State = + extract.appendWithSession(settings, s) + + def appendWithoutSession(settings: Seq[Setting[_]], state: State): State = + extract.appendWithoutSession(settings, s) + + def respondEvent[A: JsonFormat](event: A): Unit = { + exchange.respondEvent(event, s.currentCommand.flatMap(_.execId), s.source) + } + def respondError(code: Long, message: String): Unit = { + exchange.respondError(code, message, s.currentCommand.flatMap(_.execId), s.source) + } + def notifyEvent[A: JsonFormat](method: String, params: A): Unit = { + exchange.notifyEvent(method, params) + } + } +} diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 25a11c153..461b538e5 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -15,6 +15,7 @@ import java.util.concurrent.atomic._ import sbt.BasicKeys._ import sbt.nio.Watch.NullLogger +import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.langserver.{ LogMessageParams, MessageType } import sbt.internal.server._ import sbt.internal.util.codec.JValueFormats @@ -52,6 +53,17 @@ private[sbt] final class CommandExchange { private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {} def channels: List[CommandChannel] = channelBuffer.toList + private[this] def removeChannels(toDel: List[CommandChannel]): Unit = { + toDel match { + case Nil => // do nothing + case xs => + channelBufferLock.synchronized { + channelBuffer --= xs + () + } + } + } + def subscribe(c: CommandChannel): Unit = channelBufferLock.synchronized { channelBuffer.append(c) c.register(commandChannelQueue) @@ -181,6 +193,69 @@ private[sbt] final class CommandExchange { server = None } + // This is an interface to directly respond events. + private[sbt] def respondError( + code: Long, + message: String, + execId: Option[String], + source: Option[CommandSource] + ): Unit = { + val toDel: ListBuffer[CommandChannel] = ListBuffer.empty + channels.foreach { + case _: ConsoleChannel => + case c: NetworkChannel => + try { + // broadcast to all network channels + c.respondError(code, message, execId, source) + } catch { + case _: IOException => + toDel += c + } + } + removeChannels(toDel.toList) + } + + private[sbt] def respondError( + err: JsonRpcResponseError, + execId: Option[String], + source: Option[CommandSource] + ): Unit = { + val toDel: ListBuffer[CommandChannel] = ListBuffer.empty + channels.foreach { + case _: ConsoleChannel => + case c: NetworkChannel => + try { + // broadcast to all network channels + c.respondError(err, execId, source) + } catch { + case _: IOException => + toDel += c + } + } + removeChannels(toDel.toList) + } + + // This is an interface to directly respond events. + private[sbt] def respondEvent[A: JsonFormat]( + event: A, + execId: Option[String], + source: Option[CommandSource] + ): Unit = { + val toDel: ListBuffer[CommandChannel] = ListBuffer.empty + channels.foreach { + case _: ConsoleChannel => + case c: NetworkChannel => + try { + // broadcast to all network channels + c.respondEvent(event, execId, source) + } catch { + case _: IOException => + toDel += c + } + } + removeChannels(toDel.toList) + } + // This is an interface to directly notify events. private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = { val toDel: ListBuffer[CommandChannel] = ListBuffer.empty @@ -195,14 +270,7 @@ private[sbt] final class CommandExchange { toDel += c } } - toDel.toList match { - case Nil => // do nothing - case xs => - channelBufferLock.synchronized { - channelBuffer --= xs - () - } - } + removeChannels(toDel.toList) } private def tryTo(x: => Unit, c: CommandChannel, toDel: ListBuffer[CommandChannel]): Unit = @@ -248,14 +316,7 @@ private[sbt] final class CommandExchange { tryTo(c.publishEvent(event), c, toDel) } } - toDel.toList match { - case Nil => // do nothing - case xs => - channelBufferLock.synchronized { - channelBuffer --= xs - () - } - } + removeChannels(toDel.toList) } private[sbt] def toLogMessageParams(event: StringEvent): LogMessageParams = { @@ -290,14 +351,7 @@ private[sbt] final class CommandExchange { toDel += c } } - toDel.toList match { - case Nil => // do nothing - case xs => - channelBufferLock.synchronized { - channelBuffer --= xs - () - } - } + removeChannels(toDel.toList) } // fanout publishEvent @@ -328,13 +382,6 @@ private[sbt] final class CommandExchange { } } - toDel.toList match { - case Nil => // do nothing - case xs => - channelBufferLock.synchronized { - channelBuffer --= xs - () - } - } + removeChannels(toDel.toList) } } diff --git a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala index 9a369acb8..98298f1db 100644 --- a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala +++ b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala @@ -181,6 +181,15 @@ private[sbt] trait LanguageServerProtocol extends CommandChannel { self => ): Unit = jsonRpcRespondErrorImpl(execId, code, message, Option(Converter.toJson[A](data).get)) + private[sbt] def jsonRpcRespondError( + execId: Option[String], + err: JsonRpcResponseError + ): Unit = { + val m = JsonRpcResponseMessage("2.0", execId, None, Option(err)) + val bytes = Serialization.serializeResponseMessage(m) + publishBytes(bytes) + } + private[this] def jsonRpcRespondErrorImpl( execId: Option[String], code: Long, diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index f7613541e..66a5e01ac 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -19,7 +19,11 @@ import sbt.internal.langserver.{ ErrorCodes, CancelRequestParams } import sbt.internal.util.{ ObjectEvent, StringEvent } import sbt.internal.util.complete.Parser import sbt.internal.util.codec.JValueFormats -import sbt.internal.protocol.{ JsonRpcRequestMessage, JsonRpcNotificationMessage } +import sbt.internal.protocol.{ + JsonRpcResponseError, + JsonRpcRequestMessage, + JsonRpcNotificationMessage +} import sbt.util.Logger import scala.util.Try import scala.util.control.NonFatal @@ -241,6 +245,25 @@ final class NetworkChannel( } } + private[sbt] def respondError( + err: JsonRpcResponseError, + execId: Option[String], + source: Option[CommandSource] + ): Unit = jsonRpcRespondError(execId, err) + + private[sbt] def respondError( + code: Long, + message: String, + execId: Option[String], + source: Option[CommandSource] + ): Unit = jsonRpcRespondError(execId, code, message) + + private[sbt] def respondEvent[A: JsonFormat]( + event: A, + execId: Option[String], + source: Option[CommandSource] + ): Unit = jsonRpcRespond(event, execId) + private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = { if (isLanguageServerProtocol) { jsonRpcNotify(method, params) @@ -255,9 +278,10 @@ final class NetworkChannel( case entry: StringEvent => logMessage(entry.level, entry.message) case entry: ExecStatusEvent => entry.exitCode match { - case None => jsonRpcRespond(event, entry.execId) - case Some(0) => jsonRpcRespond(event, entry.execId) - case Some(exitCode) => jsonRpcRespondError(entry.execId, exitCode, "") + case None => jsonRpcRespond(event, entry.execId) + case Some(0) => jsonRpcRespond(event, entry.execId) + case Some(exitCode) => + jsonRpcRespondError(entry.execId, exitCode, entry.message.getOrElse("")) } case _ => jsonRpcRespond(event, execId) } diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/JsonRpcResponseError.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/JsonRpcResponseError.scala index 41f939f61..61dd823b5 100644 --- a/protocol/src/main/contraband-scala/sbt/internal/protocol/JsonRpcResponseError.scala +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/JsonRpcResponseError.scala @@ -13,7 +13,7 @@ package sbt.internal.protocol final class JsonRpcResponseError private ( val code: Long, val message: String, - val data: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue]) extends Serializable { + val data: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue]) extends RuntimeException(message) with Serializable { @@ -44,7 +44,7 @@ final class JsonRpcResponseError private ( } } object JsonRpcResponseError { - + def apply(code: Long, message: String): JsonRpcResponseError = new JsonRpcResponseError(code, message, None) def apply(code: Long, message: String, data: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue]): JsonRpcResponseError = new JsonRpcResponseError(code, message, data) def apply(code: Long, message: String, data: sjsonnew.shaded.scalajson.ast.unsafe.JValue): JsonRpcResponseError = new JsonRpcResponseError(code, message, Option(data)) } diff --git a/protocol/src/main/contraband-scala/sbt/protocol/ExecStatusEvent.scala b/protocol/src/main/contraband-scala/sbt/protocol/ExecStatusEvent.scala index 3f42e0316..3a02692d3 100644 --- a/protocol/src/main/contraband-scala/sbt/protocol/ExecStatusEvent.scala +++ b/protocol/src/main/contraband-scala/sbt/protocol/ExecStatusEvent.scala @@ -10,22 +10,24 @@ final class ExecStatusEvent private ( val channelName: Option[String], val execId: Option[String], val commandQueue: Vector[String], - val exitCode: Option[Long]) extends sbt.protocol.EventMessage() with Serializable { + val exitCode: Option[Long], + val message: Option[String]) extends sbt.protocol.EventMessage() with Serializable { - private def this(status: String, channelName: Option[String], execId: Option[String], commandQueue: Vector[String]) = this(status, channelName, execId, commandQueue, None) + private def this(status: String, channelName: Option[String], execId: Option[String], commandQueue: Vector[String]) = this(status, channelName, execId, commandQueue, None, None) + private def this(status: String, channelName: Option[String], execId: Option[String], commandQueue: Vector[String], exitCode: Option[Long]) = this(status, channelName, execId, commandQueue, exitCode, None) override def equals(o: Any): Boolean = o match { - case x: ExecStatusEvent => (this.status == x.status) && (this.channelName == x.channelName) && (this.execId == x.execId) && (this.commandQueue == x.commandQueue) && (this.exitCode == x.exitCode) + case x: ExecStatusEvent => (this.status == x.status) && (this.channelName == x.channelName) && (this.execId == x.execId) && (this.commandQueue == x.commandQueue) && (this.exitCode == x.exitCode) && (this.message == x.message) case _ => false } override def hashCode: Int = { - 37 * (37 * (37 * (37 * (37 * (37 * (17 + "sbt.protocol.ExecStatusEvent".##) + status.##) + channelName.##) + execId.##) + commandQueue.##) + exitCode.##) + 37 * (37 * (37 * (37 * (37 * (37 * (37 * (17 + "sbt.protocol.ExecStatusEvent".##) + status.##) + channelName.##) + execId.##) + commandQueue.##) + exitCode.##) + message.##) } override def toString: String = { - "ExecStatusEvent(" + status + ", " + channelName + ", " + execId + ", " + commandQueue + ", " + exitCode + ")" + "ExecStatusEvent(" + status + ", " + channelName + ", " + execId + ", " + commandQueue + ", " + exitCode + ", " + message + ")" } - private[this] def copy(status: String = status, channelName: Option[String] = channelName, execId: Option[String] = execId, commandQueue: Vector[String] = commandQueue, exitCode: Option[Long] = exitCode): ExecStatusEvent = { - new ExecStatusEvent(status, channelName, execId, commandQueue, exitCode) + private[this] def copy(status: String = status, channelName: Option[String] = channelName, execId: Option[String] = execId, commandQueue: Vector[String] = commandQueue, exitCode: Option[Long] = exitCode, message: Option[String] = message): ExecStatusEvent = { + new ExecStatusEvent(status, channelName, execId, commandQueue, exitCode, message) } def withStatus(status: String): ExecStatusEvent = { copy(status = status) @@ -51,6 +53,12 @@ final class ExecStatusEvent private ( def withExitCode(exitCode: Long): ExecStatusEvent = { copy(exitCode = Option(exitCode)) } + def withMessage(message: Option[String]): ExecStatusEvent = { + copy(message = message) + } + def withMessage(message: String): ExecStatusEvent = { + copy(message = Option(message)) + } } object ExecStatusEvent { @@ -58,4 +66,6 @@ object ExecStatusEvent { def apply(status: String, channelName: String, execId: String, commandQueue: Vector[String]): ExecStatusEvent = new ExecStatusEvent(status, Option(channelName), Option(execId), commandQueue) def apply(status: String, channelName: Option[String], execId: Option[String], commandQueue: Vector[String], exitCode: Option[Long]): ExecStatusEvent = new ExecStatusEvent(status, channelName, execId, commandQueue, exitCode) def apply(status: String, channelName: String, execId: String, commandQueue: Vector[String], exitCode: Long): ExecStatusEvent = new ExecStatusEvent(status, Option(channelName), Option(execId), commandQueue, Option(exitCode)) + def apply(status: String, channelName: Option[String], execId: Option[String], commandQueue: Vector[String], exitCode: Option[Long], message: Option[String]): ExecStatusEvent = new ExecStatusEvent(status, channelName, execId, commandQueue, exitCode, message) + def apply(status: String, channelName: String, execId: String, commandQueue: Vector[String], exitCode: Long, message: String): ExecStatusEvent = new ExecStatusEvent(status, Option(channelName), Option(execId), commandQueue, Option(exitCode), Option(message)) } diff --git a/protocol/src/main/contraband-scala/sbt/protocol/codec/ExecStatusEventFormats.scala b/protocol/src/main/contraband-scala/sbt/protocol/codec/ExecStatusEventFormats.scala index a83741506..d8f0849c6 100644 --- a/protocol/src/main/contraband-scala/sbt/protocol/codec/ExecStatusEventFormats.scala +++ b/protocol/src/main/contraband-scala/sbt/protocol/codec/ExecStatusEventFormats.scala @@ -16,8 +16,9 @@ implicit lazy val ExecStatusEventFormat: JsonFormat[sbt.protocol.ExecStatusEvent val execId = unbuilder.readField[Option[String]]("execId") val commandQueue = unbuilder.readField[Vector[String]]("commandQueue") val exitCode = unbuilder.readField[Option[Long]]("exitCode") + val message = unbuilder.readField[Option[String]]("message") unbuilder.endObject() - sbt.protocol.ExecStatusEvent(status, channelName, execId, commandQueue, exitCode) + sbt.protocol.ExecStatusEvent(status, channelName, execId, commandQueue, exitCode, message) case None => deserializationError("Expected JsObject but found None") } @@ -29,6 +30,7 @@ implicit lazy val ExecStatusEventFormat: JsonFormat[sbt.protocol.ExecStatusEvent builder.addField("execId", obj.execId) builder.addField("commandQueue", obj.commandQueue) builder.addField("exitCode", obj.exitCode) + builder.addField("message", obj.message) builder.endObject() } } diff --git a/protocol/src/main/contraband/jsonrpc.contra b/protocol/src/main/contraband/jsonrpc.contra index 9495b4a7a..54a134b40 100644 --- a/protocol/src/main/contraband/jsonrpc.contra +++ b/protocol/src/main/contraband/jsonrpc.contra @@ -56,7 +56,12 @@ type JsonRpcResponseError ## information about the error. Can be omitted. data: sjsonnew.shaded.scalajson.ast.unsafe.JValue + #xinterface RuntimeException(message) + #xtostring s"""JsonRpcResponseError($code, $message, ${sbt.protocol.Serialization.compactPrintJsonOpt(data)})""" + + #xcompanion def apply(code: Long, message: String): JsonRpcResponseError = new JsonRpcResponseError(code, message, None) + } type JsonRpcNotificationMessage implements JsonRpcMessage diff --git a/protocol/src/main/contraband/server.contra b/protocol/src/main/contraband/server.contra index fc39411b0..3f45eaa5e 100644 --- a/protocol/src/main/contraband/server.contra +++ b/protocol/src/main/contraband/server.contra @@ -47,6 +47,7 @@ type ExecStatusEvent implements EventMessage { execId: String commandQueue: [String] exitCode: Long @since("1.1.2") + message: String @since("1.4.0") } interface SettingQueryResponse implements EventMessage {} diff --git a/server-test/src/server-test/response/build.sbt b/server-test/src/server-test/response/build.sbt new file mode 100644 index 000000000..6b804d8e9 --- /dev/null +++ b/server-test/src/server-test/response/build.sbt @@ -0,0 +1,59 @@ +import sbt.internal.server.{ ServerHandler, ServerIntent } + +ThisBuild / scalaVersion := "2.12.10" + +Global / serverLog / logLevel := Level.Debug +// custom handler +Global / serverHandlers += ServerHandler({ callback => + import callback._ + import sjsonnew.BasicJsonProtocol._ + import sbt.internal.protocol.JsonRpcRequestMessage + ServerIntent( + { + case r: JsonRpcRequestMessage if r.method == "foo/export" => + appendExec(Exec("fooExport", Some(r.id), Some(CommandSource(callback.name)))) + () + case r: JsonRpcRequestMessage if r.method == "foo/fail" => + appendExec(Exec("fooFail", Some(r.id), Some(CommandSource(callback.name)))) + () + case r: JsonRpcRequestMessage if r.method == "foo/customfail" => + appendExec(Exec("fooCustomFail", Some(r.id), Some(CommandSource(callback.name)))) + () + case r: JsonRpcRequestMessage if r.method == "foo/notification" => + appendExec(Exec("fooNotification", Some(r.id), Some(CommandSource(callback.name)))) + () + case r: JsonRpcRequestMessage if r.method == "foo/rootClasspath" => + appendExec(Exec("fooClasspath", Some(r.id), Some(CommandSource(callback.name)))) + () + }, + PartialFunction.empty + ) +}) + +lazy val fooClasspath = taskKey[Unit]("") +lazy val root = (project in file(".")) + .settings( + name := "response", + commands += Command.command("fooExport") { s0: State => + val (s1, cp) = s0.unsafeRunTask(Compile / fullClasspath) + s0.respondEvent(cp.map(_.data)) + s1 + }, + commands += Command.command("fooFail") { s0: State => + sys.error("fail message") + }, + commands += Command.command("fooCustomFail") { s0: State => + import sbt.internal.protocol.JsonRpcResponseError + throw JsonRpcResponseError(500, "some error") + }, + commands += Command.command("fooNotification") { s0: State => + import CacheImplicits._ + s0.notifyEvent("foo/something", "something") + s0 + }, + fooClasspath := { + val s = state.value + val cp = (Compile / fullClasspath).value + s.respondEvent(cp.map(_.data)) + }, + ) diff --git a/server-test/src/test/scala/testpkg/ResponseTest.scala b/server-test/src/test/scala/testpkg/ResponseTest.scala new file mode 100644 index 000000000..8f8ee3459 --- /dev/null +++ b/server-test/src/test/scala/testpkg/ResponseTest.scala @@ -0,0 +1,67 @@ +/* + * sbt + * Copyright 2011 - 2018, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package testpkg + +import scala.concurrent.duration._ + +// starts svr using server-test/response and perform custom server tests +object ResponseTest extends AbstractServerTest { + override val testDirectory: String = "response" + + test("response from a command") { _ => + svr.sendJsonRpc( + """{ "jsonrpc": "2.0", "id": "10", "method": "foo/export", "params": {} }""" + ) + assert(svr.waitForString(10.seconds) { s => + println(s) + (s contains """"id":"10"""") && + (s contains "scala-library.jar") + }) + } + + test("response from a task") { _ => + svr.sendJsonRpc( + """{ "jsonrpc": "2.0", "id": "11", "method": "foo/rootClasspath", "params": {} }""" + ) + assert(svr.waitForString(10.seconds) { s => + println(s) + (s contains """"id":"11"""") && + (s contains "scala-library.jar") + }) + } + + test("a command failure") { _ => + svr.sendJsonRpc( + """{ "jsonrpc": "2.0", "id": "12", "method": "foo/fail", "params": {} }""" + ) + assert(svr.waitForString(10.seconds) { s => + println(s) + (s contains """"error":{"code":-33000,"message":"fail message"""") + }) + } + + test("a command failure with custom code") { _ => + svr.sendJsonRpc( + """{ "jsonrpc": "2.0", "id": "13", "method": "foo/customfail", "params": {} }""" + ) + assert(svr.waitForString(10.seconds) { s => + println(s) + (s contains """"error":{"code":500,"message":"some error"""") + }) + } + + test("a command with a notification") { _ => + svr.sendJsonRpc( + """{ "jsonrpc": "2.0", "id": "14", "method": "foo/notification", "params": {} }""" + ) + assert(svr.waitForString(10.seconds) { s => + println(s) + (s contains """{"jsonrpc":"2.0","method":"foo/something","params":"something"}""") + }) + } +}