diff --git a/main/src/main/scala/sbt/Keys.scala b/main/src/main/scala/sbt/Keys.scala index d5a11ccf6..855f5c467 100644 --- a/main/src/main/scala/sbt/Keys.scala +++ b/main/src/main/scala/sbt/Keys.scala @@ -716,6 +716,7 @@ object Keys { private[sbt] val taskCancelStrategy = settingKey[State => TaskCancellationStrategy]("Experimental task cancellation handler.").withRank(DTask) private[sbt] val cacheStoreFactoryFactory = AttributeKey[CacheStoreFactoryFactory]("cache-store-factory-factory") private[sbt] val bootServerSocket = AttributeKey[BootServerSocket]("boot-server-socket") + private[sbt] val channelProjectCursors = AttributeKey[Map[String, ProjectRef]]("channel-project-cursors", "Per-channel project cursors for multi-client server mode.") val fileCacheSize = settingKey[String]("The approximate maximum size in bytes of the cache used to store previous task results. For example, it could be set to \"256M\" to make the maximum size 256 megabytes.") // Experimental in sbt 0.13.2 to enable grabbing semantic compile failures. diff --git a/main/src/main/scala/sbt/MainLoop.scala b/main/src/main/scala/sbt/MainLoop.scala index a41aedb1f..4ca0b76b6 100644 --- a/main/src/main/scala/sbt/MainLoop.scala +++ b/main/src/main/scala/sbt/MainLoop.scala @@ -254,6 +254,7 @@ private[sbt] object MainLoop: * Dropping (FastTrackCommands.evaluate ... getOrElse) should be functionally identical * but slower. */ + val prevSessionCurrent = termState.get(Keys.sessionSettings).map(_.current) val newState = try { var errorMsg: Option[String] = None @@ -267,7 +268,7 @@ private[sbt] object MainLoop: ) case None => currentCmdProgress.foreach(_.afterCommand(exec.commandLine, Right(res))) } - res + syncChannelCursor(res, prevSessionCurrent, channelName) } catch { case _: RejectedExecutionException => val cancelled = new Cancelled(exec.commandLine) @@ -376,6 +377,21 @@ private[sbt] object MainLoop: then ExitCode(ErrorCodes.UnknownError) else ExitCode.Success + private def syncChannelCursor( + state: State, + prevSessionCurrent: Option[ProjectRef], + channelName: Option[String] + ): State = + channelName match + case Some(cn) => + val newSessionCurrent = state.get(Keys.sessionSettings).map(_.current) + (prevSessionCurrent, newSessionCurrent) match + case (prev, Some(curr)) if prev != Some(curr) => + val cursors = state.get(Keys.channelProjectCursors).getOrElse(Map.empty) + state.put(Keys.channelProjectCursors, cursors.updated(cn, curr)) + case _ => state + case None => state + end MainLoop // No stack trace since this is just to notify the user which command they cancelled diff --git a/main/src/main/scala/sbt/ProjectExtra.scala b/main/src/main/scala/sbt/ProjectExtra.scala index 63aa3eeec..d241a880b 100755 --- a/main/src/main/scala/sbt/ProjectExtra.scala +++ b/main/src/main/scala/sbt/ProjectExtra.scala @@ -237,8 +237,12 @@ trait ProjectExtra extends Scoped.Syntax: def isProjectLoaded(state: State): Boolean = (state has Keys.sessionSettings) && (state has Keys.stateBuildStructure) - def extract(state: State): Extracted = - Project.extract(Project.session(state), Project.structure(state)) + def extract(state: State): Extracted = { + val se = Project.session(state) + val st = Project.structure(state) + val currentRef = internal.ProjectNavigation.effectiveCurrentRef(state) + Extracted(st, se, currentRef)(using Project.showContextKey2(se)) + } private[sbt] def extract(se: SessionSettings, st: BuildStructure): Extracted = Extracted(st, se, se.current)(using Project.showContextKey2(se)) diff --git a/main/src/main/scala/sbt/internal/ProjectNavigation.scala b/main/src/main/scala/sbt/internal/ProjectNavigation.scala index 776a53fb9..e9c8cdefe 100644 --- a/main/src/main/scala/sbt/internal/ProjectNavigation.scala +++ b/main/src/main/scala/sbt/internal/ProjectNavigation.scala @@ -11,21 +11,49 @@ package internal import java.net.URI import sbt.internal.util.complete, complete.{ DefaultParsers, Parser }, DefaultParsers.* -import Keys.sessionSettings +import Keys.{ channelProjectCursors, sessionSettings } import sbt.ProjectExtra.{ extract, updateCurrent } object ProjectNavigation { def command(s: State): Parser[() => State] = if s.get(sessionSettings).isEmpty then failure("No project loaded") else (new ProjectNavigation(s)).command + + private[sbt] def getChannelName(s: State): Option[String] = + s.currentCommand + .flatMap(_.source) + .map(_.channelName) + .orElse(StandardMain.exchange.currentExec.flatMap(_.source).map(_.channelName)) + + def getChannelCursor(s: State): Option[ProjectRef] = + for { + channelName <- getChannelName(s) + cursors <- s.get(channelProjectCursors) + ref <- cursors.get(channelName) + } yield ref + + def setChannelCursor(s: State, ref: ProjectRef): State = + getChannelName(s) match { + case Some(channelName) => + val cursors = s.get(channelProjectCursors).getOrElse(Map.empty) + s.put(channelProjectCursors, cursors.updated(channelName, ref)) + case None => s + } + + def effectiveCurrentRef(s: State): ProjectRef = + getChannelCursor(s).getOrElse( + s.get(sessionSettings).map(_.current).getOrElse(sys.error("Session not initialized")) + ) } final class ProjectNavigation(s: State) { val extracted: Extracted = Project.extract(s) - import extracted.{ currentRef, structure, session } + import extracted.{ structure, session } + + def effectiveCurrentRef: ProjectRef = ProjectNavigation.effectiveCurrentRef(s) def setProject(nuri: URI, nid: String): State = { - val neval = if (currentRef.build == nuri) session.currentEval else mkEval(nuri) + val neval = if (effectiveCurrentRef.build == nuri) session.currentEval else mkEval(nuri) Project.updateCurrent(s.put(sessionSettings, session.setCurrent(nuri, nid, neval))) } @@ -38,11 +66,10 @@ final class ProjectNavigation(s: State) { show(); s case Some(BuildRef(uri)) => changeBuild(uri) case Some(ProjectRef(uri, id)) => selectProject(uri, id) - /* else if(to.forall(_ == '.')) - if(to.length > 1) gotoParent(to.length - 1, nav, s) else s */ // semantics currently undefined } - def show(): Unit = s.log.info(s"${currentRef.project} (in build ${currentRef.build})") + def show(): Unit = + s.log.info(s"${effectiveCurrentRef.project} (in build ${effectiveCurrentRef.build})") def selectProject(uri: URI, to: String): State = if (structure.units(uri).defined.contains(to)) @@ -63,12 +90,13 @@ final class ProjectNavigation(s: State) { import Parser.*, complete.Parsers.* val parser: Parser[Option[ResolvedReference]] = { - val reference = Act.resolvedReference(structure.index.keyIndex, currentRef.build, success(())) + val reference = + Act.resolvedReference(structure.index.keyIndex, effectiveCurrentRef.build, success(())) val root = token('/' ^^^ rootRef) success(None) | some(token(Space) ~> (root | reference)) } - def rootRef = ProjectRef(currentRef.build, getRoot(currentRef.build)) + def rootRef = ProjectRef(effectiveCurrentRef.build, getRoot(effectiveCurrentRef.build)) val command: Parser[() => State] = Command.applyEffect(parser)(apply) } diff --git a/server-test/src/server-test/channel-cursor/build.sbt b/server-test/src/server-test/channel-cursor/build.sbt new file mode 100644 index 000000000..b32bc101e --- /dev/null +++ b/server-test/src/server-test/channel-cursor/build.sbt @@ -0,0 +1,28 @@ +scalaVersion := "3.8.1" + +val printCurrentProject = taskKey[Unit]("Prints current project name") + +lazy val projectA = (project in file("projectA")) + .settings( + name := "project-a", + printCurrentProject := { + streams.value.log.info(s"CURRENT_PROJECT_IS:${name.value}") + } + ) + +lazy val projectB = (project in file("projectB")) + .settings( + name := "project-b", + printCurrentProject := { + streams.value.log.info(s"CURRENT_PROJECT_IS:${name.value}") + } + ) + +lazy val root = (project in file(".")) + .aggregate(projectA, projectB) + .settings( + name := "root", + printCurrentProject := { + streams.value.log.info(s"CURRENT_PROJECT_IS:${name.value}") + } + ) diff --git a/server-test/src/server-test/channel-cursor/projectA/.gitkeep b/server-test/src/server-test/channel-cursor/projectA/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/server-test/src/server-test/channel-cursor/projectB/.gitkeep b/server-test/src/server-test/channel-cursor/projectB/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/server-test/src/test/scala/testpkg/ChannelCursorTest.scala b/server-test/src/test/scala/testpkg/ChannelCursorTest.scala new file mode 100644 index 000000000..6656d80f8 --- /dev/null +++ b/server-test/src/test/scala/testpkg/ChannelCursorTest.scala @@ -0,0 +1,151 @@ +/* + * sbt + * Copyright 2023, Scala center + * Copyright 2011 - 2022, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package testpkg + +import java.io.IOException +import java.net.Socket +import java.util.concurrent.{ LinkedBlockingQueue, TimeUnit } +import java.util.concurrent.atomic.AtomicBoolean +import sbt.protocol.ClientSocket + +import scala.annotation.tailrec +import scala.concurrent.duration.* + +class ChannelCursorTest extends AbstractServerTest { + override val testDirectory: String = "channel-cursor" + + private def createSecondConnection() + : (Socket, java.io.OutputStream, LinkedBlockingQueue[String], AtomicBoolean) = { + val portfile = testPath.resolve("project/target/active.json").toFile + @tailrec + def connect(attempt: Int): Socket = { + val res = + try Some(ClientSocket.socket(portfile)._1) + catch { case _: IOException if attempt < 10 => None } + res match { + case Some(s) => s + case _ => + Thread.sleep(100) + connect(attempt + 1) + } + } + val sk = connect(0) + val out = sk.getOutputStream + val in = sk.getInputStream + val lines = new LinkedBlockingQueue[String] + val running = new AtomicBoolean(true) + new Thread( + () => { + while (running.get) { + try lines.put(sbt.ReadJson(in, running)) + catch { case _: Exception => running.set(false) } + } + }, + "sbt-server-test-read-thread-2" + ) { + setDaemon(true) + start() + } + (sk, out, lines, running) + } + + private def sendJsonRpc(out: java.io.OutputStream, message: String): Unit = { + def writeLine(s: String): Unit = { + val retByte: Byte = '\r'.toByte + val delimiter: Byte = '\n'.toByte + if (s != "") { + out.write(s.getBytes("UTF-8")) + } + out.write(retByte.toInt) + out.write(delimiter.toInt) + out.flush + } + writeLine(s"""Content-Length: ${message.size + 2}""") + writeLine("") + writeLine(message) + } + + private def waitForString(lines: LinkedBlockingQueue[String], duration: FiniteDuration)( + f: String => Boolean + ): Boolean = { + val deadline = duration.fromNow + @tailrec def impl(): Boolean = + lines.poll(deadline.timeLeft.toMillis, TimeUnit.MILLISECONDS) match { + case null => false + case s => if (!f(s) && !deadline.isOverdue) impl() else !deadline.isOverdue() + } + impl() + } + + test("channel cursor - independent project cursors") { + val (sk2, out2, lines2, running2) = createSecondConnection() + try { + sendJsonRpc( + out2, + """{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { "skipAnalysis": true } } }""" + ) + waitForString(lines2, 10.seconds)(_.contains(""""capabilities":{""")) + + svr.sendJsonRpc( + """{ "jsonrpc": "2.0", "id": 10, "method": "sbt/exec", "params": { "commandLine": "project projectA" } }""" + ) + assert( + svr.waitForString(10.seconds) { s => + println(s"[channel1] $s") + s.contains("projectA") || s.contains("\"execId\":10") + }, + "Channel 1 should switch to projectA" + ) + + sendJsonRpc( + out2, + """{ "jsonrpc": "2.0", "id": 20, "method": "sbt/exec", "params": { "commandLine": "project projectB" } }""" + ) + assert( + waitForString(lines2, 10.seconds) { s => + println(s"[channel2] $s") + s.contains("projectB") || s.contains("\"execId\":20") + }, + "Channel 2 should switch to projectB" + ) + + svr.sendJsonRpc( + """{ "jsonrpc": "2.0", "id": 11, "method": "sbt/exec", "params": { "commandLine": "printCurrentProject" } }""" + ) + var foundProjectA = false + assert( + svr.waitForString(30.seconds) { s => + println(s"[channel1 name] $s") + if (s.contains("CURRENT_PROJECT_IS:project-a")) foundProjectA = true + s.contains("\"execId\":11") && s.contains("\"status\":\"Done\"") + }, + "First channel printCurrentProject command should complete" + ) + assert(foundProjectA, "First channel should still be on projectA") + + sendJsonRpc( + out2, + """{ "jsonrpc": "2.0", "id": 21, "method": "sbt/exec", "params": { "commandLine": "printCurrentProject" } }""" + ) + var foundProjectB = false + assert( + waitForString(lines2, 30.seconds) { s => + println(s"[channel2 name] $s") + if (s.contains("CURRENT_PROJECT_IS:project-b")) foundProjectB = true + s.contains("\"execId\":21") && s.contains("\"status\":\"Done\"") + }, + "Second channel printCurrentProject command should complete" + ) + assert(foundProjectB, "Second channel should still be on projectB") + } finally { + running2.set(false) + sk2.close() + } + } +}