diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala index b2672ab00..342d652e4 100644 --- a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala +++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala @@ -317,6 +317,7 @@ class NetworkClient( token = tkn, skipAnalysis = Some(skipAnalysis), canWork = Some(true), + subscribeToAll = Some(false), ) val initCommand = InitCommand( token = tkn, // duplicated with opts for compatibility diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index a83d0ed94..9d0fbeb5a 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -366,12 +366,11 @@ private[sbt] final class CommandExchange { } // This is an interface to directly notify events. - private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = { - channels.foreach { - case c: NetworkChannel => tryTo(_.notifyEvent(method, params))(c) - case _ => - } - } + private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = + channels.foreach: + case c: NetworkChannel if c.subscribeToAll || isChannelOwner(c) => + tryTo(_.notifyEvent(method, params))(c) + case _ => private def tryTo(f: NetworkChannel => Unit)( channel: NetworkChannel @@ -418,12 +417,14 @@ private[sbt] final class CommandExchange { } def unprompt(event: ConsoleUnpromptEvent): Unit = channels.foreach(_.unprompt(event)) - def logMessage(event: LogEvent): Unit = { - channels.foreach { - case c: NetworkChannel => tryTo(_.notifyEvent(event))(c) - case _ => - } - } + def logMessage(event: LogEvent): Unit = + channels.foreach: + case c: NetworkChannel if c.subscribeToAll || isChannelOwner(c) => + tryTo(_.notifyEvent(event))(c) + case _ => + + private def isChannelOwner(c: NetworkChannel): Boolean = + currentExec.exists(_.source.exists(_.channelName == c.name)) def notifyStatus(event: ExecStatusEvent): Unit = { for { diff --git a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala index e5af5d4b6..4ea284708 100644 --- a/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala +++ b/main/src/main/scala/sbt/internal/server/LanguageServerProtocol.scala @@ -58,13 +58,13 @@ private[sbt] object LanguageServerProtocol { ) ) val opt = Converter.fromJson[InitializeOption](optionJson).get + setInitializeOption(opt) if (authOptions(ServerAuthentication.Token)) { val token = opt.token.getOrElse(sys.error("'token' is missing.")) if (authenticate(token)) () else throw LangServerError(ErrorCodes.InvalidRequest, "invalid token") } else () setInitialized(true) - setInitializeOption(opt) if (!opt.skipAnalysis.getOrElse(false)) appendExec("collectAnalyses", None) jsonRpcRespond(InitializeResult(serverCapabilities), Some(r.id)) diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index 7b1fe9ed4..c97a2c255 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -165,6 +165,10 @@ final class NetworkChannel( case _ => false } + /** True if this channel should receive broadcast events (logMessage, notifyEvent). Default true for backward compatibility. */ + private[sbt] def subscribeToAll: Boolean = + Option(initializeOption.get).flatMap(_.subscribeToAll).getOrElse(false) + protected def authenticate(token: String): Boolean = instance.authenticate(token) protected def setInitialized(value: Boolean): Unit = initialized = value diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/InitializeOption.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/InitializeOption.scala index 42e362496..2155d50bf 100644 --- a/protocol/src/main/contraband-scala/sbt/internal/protocol/InitializeOption.scala +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/InitializeOption.scala @@ -11,23 +11,25 @@ package sbt.internal.protocol final class InitializeOption private ( val token: Option[String], val skipAnalysis: Option[Boolean], - val canWork: Option[Boolean]) extends Serializable { + val canWork: Option[Boolean], + val subscribeToAll: Option[Boolean]) extends Serializable { - private def this(token: Option[String]) = this(token, None, None) - private def this(token: Option[String], skipAnalysis: Option[Boolean]) = this(token, skipAnalysis, None) + private def this(token: Option[String]) = this(token, None, None, None) + private def this(token: Option[String], skipAnalysis: Option[Boolean]) = this(token, skipAnalysis, None, None) + private def this(token: Option[String], skipAnalysis: Option[Boolean], canWork: Option[Boolean]) = this(token, skipAnalysis, canWork, None) override def equals(o: Any): Boolean = this.eq(o.asInstanceOf[AnyRef]) || (o match { - case x: InitializeOption => (this.token == x.token) && (this.skipAnalysis == x.skipAnalysis) && (this.canWork == x.canWork) + case x: InitializeOption => (this.token == x.token) && (this.skipAnalysis == x.skipAnalysis) && (this.canWork == x.canWork) && (this.subscribeToAll == x.subscribeToAll) case _ => false }) override def hashCode: Int = { - 37 * (37 * (37 * (37 * (17 + "sbt.internal.protocol.InitializeOption".##) + token.##) + skipAnalysis.##) + canWork.##) + 37 * (37 * (37 * (37 * (37 * (17 + "sbt.internal.protocol.InitializeOption".##) + token.##) + skipAnalysis.##) + canWork.##) + subscribeToAll.##) } override def toString: String = { - "InitializeOption(" + token + ", " + skipAnalysis + ", " + canWork + ")" + "InitializeOption(" + token + ", " + skipAnalysis + ", " + canWork + ", " + subscribeToAll + ")" } - private def copy(token: Option[String] = token, skipAnalysis: Option[Boolean] = skipAnalysis, canWork: Option[Boolean] = canWork): InitializeOption = { - new InitializeOption(token, skipAnalysis, canWork) + private def copy(token: Option[String] = token, skipAnalysis: Option[Boolean] = skipAnalysis, canWork: Option[Boolean] = canWork, subscribeToAll: Option[Boolean] = subscribeToAll): InitializeOption = { + new InitializeOption(token, skipAnalysis, canWork, subscribeToAll) } def withToken(token: Option[String]): InitializeOption = { copy(token = token) @@ -47,6 +49,12 @@ final class InitializeOption private ( def withCanWork(canWork: Boolean): InitializeOption = { copy(canWork = Option(canWork)) } + def withSubscribeToAll(subscribeToAll: Option[Boolean]): InitializeOption = { + copy(subscribeToAll = subscribeToAll) + } + def withSubscribeToAll(subscribeToAll: Boolean): InitializeOption = { + copy(subscribeToAll = Option(subscribeToAll)) + } } object InitializeOption { @@ -56,4 +64,6 @@ object InitializeOption { def apply(token: String, skipAnalysis: Boolean): InitializeOption = new InitializeOption(Option(token), Option(skipAnalysis)) def apply(token: Option[String], skipAnalysis: Option[Boolean], canWork: Option[Boolean]): InitializeOption = new InitializeOption(token, skipAnalysis, canWork) def apply(token: String, skipAnalysis: Boolean, canWork: Boolean): InitializeOption = new InitializeOption(Option(token), Option(skipAnalysis), Option(canWork)) + def apply(token: Option[String], skipAnalysis: Option[Boolean], canWork: Option[Boolean], subscribeToAll: Option[Boolean]): InitializeOption = new InitializeOption(token, skipAnalysis, canWork, subscribeToAll) + def apply(token: String, skipAnalysis: Boolean, canWork: Boolean, subscribeToAll: Boolean): InitializeOption = new InitializeOption(Option(token), Option(skipAnalysis), Option(canWork), Option(subscribeToAll)) } diff --git a/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/InitializeOptionFormats.scala b/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/InitializeOptionFormats.scala index 061d00af9..5f9a490dc 100644 --- a/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/InitializeOptionFormats.scala +++ b/protocol/src/main/contraband-scala/sbt/internal/protocol/codec/InitializeOptionFormats.scala @@ -14,8 +14,9 @@ given InitializeOptionFormat: JsonFormat[sbt.internal.protocol.InitializeOption] val token = unbuilder.readField[Option[String]]("token") val skipAnalysis = unbuilder.readField[Option[Boolean]]("skipAnalysis") val canWork = unbuilder.readField[Option[Boolean]]("canWork") + val subscribeToAll = unbuilder.readField[Option[Boolean]]("subscribeToAll") unbuilder.endObject() - sbt.internal.protocol.InitializeOption(token, skipAnalysis, canWork) + sbt.internal.protocol.InitializeOption(token, skipAnalysis, canWork, subscribeToAll) case None => deserializationError("Expected JsObject but found None") } @@ -25,6 +26,7 @@ given InitializeOptionFormat: JsonFormat[sbt.internal.protocol.InitializeOption] builder.addField("token", obj.token) builder.addField("skipAnalysis", obj.skipAnalysis) builder.addField("canWork", obj.canWork) + builder.addField("subscribeToAll", obj.subscribeToAll) builder.endObject() } } diff --git a/protocol/src/main/contraband/portfile.contra b/protocol/src/main/contraband/portfile.contra index ffd5dc6c9..5f396bf66 100644 --- a/protocol/src/main/contraband/portfile.contra +++ b/protocol/src/main/contraband/portfile.contra @@ -22,4 +22,5 @@ type InitializeOption { token: String skipAnalysis: Boolean @since("1.4.0") canWork: Boolean @since("1.10.8") + subscribeToAll: Boolean @since("2.0.0") } diff --git a/sbt-app/src/sbt-test/server/client-subscription/build.sbt b/sbt-app/src/sbt-test/server/client-subscription/build.sbt new file mode 100644 index 000000000..dc2ea62e5 --- /dev/null +++ b/sbt-app/src/sbt-test/server/client-subscription/build.sbt @@ -0,0 +1,2 @@ +name := "client-subscription" +scalaVersion := "2.12.19" diff --git a/sbt-app/src/sbt-test/server/client-subscription/test b/sbt-app/src/sbt-test/server/client-subscription/test new file mode 100644 index 000000000..fff477039 --- /dev/null +++ b/sbt-app/src/sbt-test/server/client-subscription/test @@ -0,0 +1,2 @@ +# Exercise server client path (subscribe-to-all by default). Closes #4399. +> show name diff --git a/server-test/src/test/scala/testpkg/ClientSubscriptionTest.scala b/server-test/src/test/scala/testpkg/ClientSubscriptionTest.scala new file mode 100644 index 000000000..dddde2cd1 --- /dev/null +++ b/server-test/src/test/scala/testpkg/ClientSubscriptionTest.scala @@ -0,0 +1,48 @@ +/* + * sbt + * Copyright 2023, Scala center + * Copyright 2011 - 2022, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package testpkg + +import scala.concurrent.duration.* + +class ClientSubscriptionTest extends AbstractServerTest { + override val testDirectory: String = "handshake" + + test("subscribe-to-all (default) client receives broadcast build/logMessage when command runs") { + svr.sendJsonRpc( + """{ "jsonrpc": "2.0", "id": 2, "method": "sbt/exec", "params": { "commandLine": "show name" } }""" + ) + def isLogMessageNotification(line: String): Boolean = + line.contains("\"method\":\"build/logMessage\"") || line.contains( + "\"method\": \"build/logMessage\"" + ) + assert( + svr.waitForString(10.seconds)(isLogMessageNotification), + "subscribe-to-all client must receive broadcast build/logMessage when a command produces log output" + ) + } +} + +class ClientNoSubscriptionTest extends AbstractServerTest { + override val testDirectory: String = "handshake" + override def subscribeToAllForTest: Boolean = false + + test("non-subscribed client receives build/logMessage for its own command") { + svr.sendJsonRpc( + """{ "jsonrpc": "2.0", "id": 2, "method": "sbt/exec", "params": { "commandLine": "show name" } }""" + ) + def isLogMessageNotification(line: String): Boolean = + line.contains("\"method\":\"build/logMessage\"") || line.contains( + "\"method\": \"build/logMessage\"" + ) + assert( + svr.waitForString(10.seconds)(isLogMessageNotification), + "non-subscribed client must still receive build/logMessage for its own command" + ) + } +} diff --git a/server-test/src/test/scala/testpkg/TestServer.scala b/server-test/src/test/scala/testpkg/TestServer.scala index a3b462cb9..fec7c578b 100644 --- a/server-test/src/test/scala/testpkg/TestServer.scala +++ b/server-test/src/test/scala/testpkg/TestServer.scala @@ -31,6 +31,7 @@ trait AbstractServerTest extends AnyFunSuite with BeforeAndAfterAll { var svr: TestServer = scala.compiletime.uninitialized def testDirectory: String def testPath: Path = temp.toPath.resolve(testDirectory) + def subscribeToAllForTest: Boolean = true private val targetDir: File = { val p0 = new File("..").getAbsoluteFile.getCanonicalFile / "target" @@ -48,7 +49,14 @@ trait AbstractServerTest extends AnyFunSuite with BeforeAndAfterAll { val classpath = TestProperties.classpath.split(File.pathSeparator).map(new File(_)) val sbtVersion = TestProperties.version val scalaVersion = TestProperties.scalaVersion - svr = TestServer.get(testDirectory, scalaVersion, sbtVersion, classpath.toSeq, temp) + svr = TestServer.get( + testDirectory, + scalaVersion, + sbtVersion, + classpath.toSeq, + temp, + subscribeToAllForTest + ) } override protected def afterAll(): Unit = { svr.bye() @@ -71,13 +79,14 @@ object TestServer { scalaVersion: String, sbtVersion: String, classpath: Seq[File], - temp: File + temp: File, + subscribeToAll: Boolean = true ): TestServer = { println(s"Starting test server $testBuild") IO.copyDirectory(serverTestBase / testBuild, temp / testBuild) - // Each test server instance will be executed in a Thread pool separated from the tests - val testServer = TestServer(temp / testBuild, scalaVersion, sbtVersion, classpath) + val testServer = + TestServer(temp / testBuild, scalaVersion, sbtVersion, classpath, subscribeToAll) // checking last log message after initialization // if something goes wrong here the communication streams are corrupted, restarting val init = @@ -155,7 +164,8 @@ case class TestServer( baseDirectory: File, scalaVersion: String, sbtVersion: String, - classpath: Seq[File] + classpath: Seq[File], + subscribeToAll: Boolean = true ) { import TestServer.hostLog @@ -227,8 +237,11 @@ case class TestServer( } // initiate handshake + private val initOptions = + if subscribeToAll then """{ "skipAnalysis": true, "canWork": true }""" + else """{ "skipAnalysis": true, "canWork": true, "subscribeToAll": false }""" sendJsonRpc( - s"""{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { "skipAnalysis": true, "canWork": true } } }""" + s"""{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": $initOptions } }""" ) def test(f: TestServer => Future[Unit]): Future[Unit] = f(this)