From c20d83a8e2586e1d613274ad32651eb8205a43e4 Mon Sep 17 00:00:00 2001 From: Eugene Yokota Date: Sat, 3 Jan 2026 15:10:31 -0500 Subject: [PATCH] Use TCP for the fork test **Problem** Forked test currently uses stdio, but that println to be lost. **Solution** Use TCP for communication similar to sbt 1.x. --- build.sbt | 2 + .../src/main/scala/sbt/ForkTests.scala | 4 +- .../scala/sbt/internal/WorkerExchange.scala | 59 +++++++++++++++--- .../sbt/internal/WorkerExchangeTest.scala | 60 +++++++++++++++++++ .../java/sbt/internal/worker1/WorkerMain.java | 49 +++++++++++---- 5 files changed, 153 insertions(+), 21 deletions(-) create mode 100644 main-actions/src/test/scala/sbt/internal/WorkerExchangeTest.scala diff --git a/build.sbt b/build.sbt index fde047e09..3908e4928 100644 --- a/build.sbt +++ b/build.sbt @@ -546,6 +546,8 @@ lazy val actionsProj = (project in file("main-actions")) Test / classLoaderLayeringStrategy := ClassLoaderLayeringStrategy.Flat, mimaSettings, mimaBinaryIssueFilters ++= Vector( + exclude[DirectMissingMethodProblem]("sbt.internal.WorkerExchange.*"), + exclude[DirectMissingMethodProblem]("sbt.internal.WorkerProxy.*"), ), ) .dependsOn(lmCore) diff --git a/main-actions/src/main/scala/sbt/ForkTests.scala b/main-actions/src/main/scala/sbt/ForkTests.scala index aea61c893..81d5ac606 100755 --- a/main-actions/src/main/scala/sbt/ForkTests.scala +++ b/main-actions/src/main/scala/sbt/ForkTests.scala @@ -28,6 +28,7 @@ import scala.util.Random import scala.util.control.NonFatal import scala.jdk.CollectionConverters.* import scala.sys.process.Process +import sbt.internal.WorkerConnection /** * This implements forked testing, in cooperation with the worker CLI, @@ -152,7 +153,8 @@ private[sbt] object ForkTests: ) testListeners.foreach(_.doInit()) val result = - val w = WorkerExchange.startWorker(fork, if useClassLoader then Nil else cpFiles) + val ct = WorkerConnection.Tcp + val w = WorkerExchange.startWorker(fork, if useClassLoader then Nil else cpFiles, ct) val wl = React(randomId, log, opts.testListeners, resultsAcc, w.process) try WorkerExchange.registerListener(wl) diff --git a/main-actions/src/main/scala/sbt/internal/WorkerExchange.scala b/main-actions/src/main/scala/sbt/internal/WorkerExchange.scala index 0d9133beb..f4ce25e1c 100644 --- a/main-actions/src/main/scala/sbt/internal/WorkerExchange.scala +++ b/main-actions/src/main/scala/sbt/internal/WorkerExchange.scala @@ -11,7 +11,10 @@ package internal import org.scalasbt.shadedgson.com.google.gson.Gson import java.io.* +import java.net.{ InetAddress, ServerSocket } +import java.util.Scanner import sbt.io.IO +import sbt.internal.io.Retry import sbt.internal.worker1.* import sbt.testing.Framework import scala.sys.process.{ BasicIO, Process, ProcessIO } @@ -22,24 +25,52 @@ import scala.concurrent.duration.* object WorkerExchange: val listeners: mutable.ListBuffer[WorkerResponseListener] = ListBuffer.empty + private val loopback = InetAddress.getByName(null) /** * Start a worker process. */ - def startWorker(fo: ForkOptions, extraCp: Seq[File]): WorkerProxy = + def startWorker( + fo: ForkOptions, + extraCp: Seq[File], + connectionType: WorkerConnection, + ): WorkerProxy = val fullCp = Seq( IO.classLocationPath(classOf[WorkerMain]).toFile, IO.classLocationPath(classOf[Framework]).toFile, IO.classLocationPath(classOf[Gson]).toFile, ) ++ extraCp + val inputRef = Promise[OutputStream]() + val socketOpt = connectionType match + case WorkerConnection.Tcp => + val serverSocket = Retry(ServerSocket(0, 1, loopback)) + val accepter = Thread(() => { + val socket = serverSocket.accept() + inputRef.success(socket.getOutputStream()) + val scanner = Scanner(socket.getInputStream(), "UTF-8") + while scanner.hasNextLine() do notifyListeners(scanner.nextLine()) + }) + accepter.start() + Some(serverSocket) + case _ => None val options = Seq( "-classpath", fullCp.mkString(File.pathSeparator), classOf[WorkerMain].getCanonicalName, - ) - val inputRef = Promise[OutputStream]() + ) ++ + (socketOpt match + case Some(s) => Seq("--tcp", s.getLocalPort().toString()) + case _ => Nil + ) + val onStdoutLine: String => Unit = connectionType match + case WorkerConnection.Stdio => notifyListeners + case _ => (line) => scala.Console.out.println(line) val processIo = ProcessIO( - in = (input) => inputRef.success(input), + in = (input) => + (connectionType match + case WorkerConnection.Stdio => inputRef.success(input) + case _ => () + ), out = BasicIO.processFully(onStdoutLine), err = BasicIO.processFully((line) => scala.Console.err.println(line)), ) @@ -47,7 +78,7 @@ object WorkerExchange: val p = Fork.java.fork(forkWithIo, options) val forkTimeout = 30.seconds val input = Await.result(inputRef.future, forkTimeout) - WorkerProxy(input, p, options) + WorkerProxy(input, p, options, socketOpt) def registerListener(listener: WorkerResponseListener): Unit = synchronized: @@ -61,16 +92,22 @@ object WorkerExchange: /** * Unified worker output handler. */ - def onStdoutLine(line: String): Unit = + def notifyListeners(line: String): Unit = synchronized: listeners.foreach: wl => wl(line) end WorkerExchange -class WorkerProxy(input: OutputStream, val process: Process, val options: Seq[String]) - extends AutoCloseable: +class WorkerProxy( + input: OutputStream, + val process: Process, + val options: Seq[String], + serverSocket: Option[ServerSocket], +) extends AutoCloseable: lazy val inputStream = PrintStream(input) - def close(): Unit = input.close() + def close(): Unit = + input.close() + serverSocket.foreach(_.close()) def blockForExitCode(): Int = if !process.isAlive then process.exitValue() else Fork.blockForExitCode(process) @@ -89,3 +126,7 @@ end WorkerProxy abstract class WorkerResponseListener extends Function1[String, Unit]: def notifyExit(p: Process): Unit + +enum WorkerConnection: + case Stdio + case Tcp diff --git a/main-actions/src/test/scala/sbt/internal/WorkerExchangeTest.scala b/main-actions/src/test/scala/sbt/internal/WorkerExchangeTest.scala new file mode 100644 index 000000000..24eeb87c2 --- /dev/null +++ b/main-actions/src/test/scala/sbt/internal/WorkerExchangeTest.scala @@ -0,0 +1,60 @@ +package sbt +package internal + +import hedgehog.* +import hedgehog.runner.* +import hedgehog.core.{ ShrinkLimit, SuccessCount } +import hedgehog.core.Result +import scala.sys.process.Process + +object WorkerExchangeTest extends Properties: + given Gen[WorkerConnection] = + Gen.choice1(Gen.constant(WorkerConnection.Stdio), Gen.constant(WorkerConnection.Tcp)) + + def gen[A1: Gen]: Gen[A1] = summon[Gen[A1]] + + override lazy val tests: List[Test] = List( + propertyN("non-jsonrpc should return exit code 1", propBadInput, 10), + propertyN("bye should return response json with a result", propBye, 10), + ) + + def propertyN(name: String, result: => Property, n: Int): Test = + Test(name, result) + .config(_.copy(testLimit = SuccessCount(n), shrinkLimit = ShrinkLimit(n * 10))) + + def propBadInput: Property = + for + ct <- gen[WorkerConnection].forAll + w = WorkerExchange.startWorker(ForkOptions(), Nil, ct) + yield + w.println("{}") + val exitCode = w.blockForExitCode() + Result.assert(exitCode == 1) + + val intGen = Gen.int(Range.linear(1, 100)) + + def propBye: Property = + for + ct <- gen[WorkerConnection].forAll + i <- intGen.forAll + w = WorkerExchange.startWorker(ForkOptions(), Nil, ct) + yield withListener: l => + w.println(s"""{"jsonrpc": "2.0", "method": "bye", "params": {}, "id": $i}""") + val exitCode = w.blockForExitCode() + Result + .assert(exitCode == 0) + .and(Result.assert(l.sb.toString() == s"""{ "jsonrpc": "2.0", "result": 0, "id": $i }""")) + .log(s"\"${l.sb.toString()}\"") + + def withListener[A1](f: ConcreteListener => A1) = + val l = ConcreteListener() + try + WorkerExchange.registerListener(l) + f(l) + finally WorkerExchange.unregisterListener(l) + + class ConcreteListener extends WorkerResponseListener: + val sb = StringBuilder() + def notifyExit(p: Process): Unit = () + def apply(line: String): Unit = sb.append(line) +end WorkerExchangeTest diff --git a/worker/src/main/java/sbt/internal/worker1/WorkerMain.java b/worker/src/main/java/sbt/internal/worker1/WorkerMain.java index 995ada951..262371790 100644 --- a/worker/src/main/java/sbt/internal/worker1/WorkerMain.java +++ b/worker/src/main/java/sbt/internal/worker1/WorkerMain.java @@ -19,6 +19,8 @@ import java.io.ByteArrayOutputStream; import java.io.InputStream; import java.io.IOException; import java.io.PrintStream; +import java.net.InetAddress; +import java.net.Socket; import java.lang.reflect.Method; import java.net.MalformedURLException; import java.net.URL; @@ -28,12 +30,16 @@ import java.util.Scanner; import sbt.testing.*; /** - * WorkerMain that communicates via the stdin and stdout using JSON-RPC + * WorkerMain that communicates via the stdio or socket using JSON-RPC * (https://www.jsonrpc.org/specification). */ public final class WorkerMain { private PrintStream originalOut; private InputStream originalIn; + + // When using stdout, this is the original stdout + // When using tcp, this is going to be the socket out + private PrintStream jsonOut; private Scanner inScanner; public static Gson mkGson() { @@ -61,6 +67,11 @@ public final class WorkerMain { WorkerMain app = new WorkerMain(); app.consoleWork(); System.exit(0); + } else if (args.length == 2 && args[0].equals("--tcp")) { + WorkerMain app = new WorkerMain(); + int serverPort = Integer.parseInt(args[1]); + app.socketWork(serverPort); + System.exit(0); } else { System.err.println("missing args"); System.exit(1); @@ -73,13 +84,26 @@ public final class WorkerMain { WorkerMain() { this.originalOut = System.out; - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - System.setOut(new PrintStream(baos)); this.originalIn = System.in; - this.inScanner = new Scanner(this.originalIn); + this.jsonOut = this.originalOut; } void consoleWork() throws Exception { + this.jsonOut = this.originalOut; + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + System.setOut(new PrintStream(baos)); + this.inScanner = new Scanner(this.originalIn, "UTF-8"); + if (this.inScanner.hasNextLine()) { + String line = this.inScanner.nextLine(); + process(line); + } + } + + void socketWork(int serverPort) throws Exception { + InetAddress loopback = InetAddress.getByName(null); + Socket client = new Socket(loopback, serverPort); + this.jsonOut = new PrintStream(client.getOutputStream(), true, "UTF-8"); + this.inScanner = new Scanner(client.getInputStream(), "UTF-8"); if (this.inScanner.hasNextLine()) { String line = this.inScanner.nextLine(); process(line); @@ -87,11 +111,11 @@ public final class WorkerMain { } /** This processes single request of supposed JSON line. */ - void process(String json) { + void process(String json) throws Exception { JsonElement elem = JsonParser.parseString(json); JsonObject o = elem.getAsJsonObject(); if (!o.has("jsonrpc")) { - return; + throw new IllegalArgumentException("jsonrpc expected but got: " + json); } Gson g = WorkerMain.mkGson(); long id = o.getAsJsonPrimitive("id").getAsLong(); @@ -107,17 +131,20 @@ public final class WorkerMain { TestInfo testInfo = g.fromJson(params, TestInfo.class); test(id, testInfo); break; + case "bye": + break; } String response = String.format("{ \"jsonrpc\": \"2.0\", \"result\": 0, \"id\": %d }", id); - this.originalOut.println(response); - this.originalOut.flush(); + this.jsonOut.println(response); + this.jsonOut.flush(); } catch (Throwable e) { WorkerError err = new WorkerError(1, e.getMessage()); String errMessage = g.toJson(err, err.getClass()); String errJson = String.format("{ \"jsonrpc\": \"2.0\", \"error\": %s, \"id\": %d }", errMessage, id); - this.originalOut.println(errJson); - this.originalOut.flush(); + this.jsonOut.println(errJson); + this.jsonOut.flush(); + e.printStackTrace(); } } @@ -143,7 +170,7 @@ public final class WorkerMain { RunInfo.JvmRunInfo jvmRunInfo = info.jvmRunInfo; ClassLoader parent = new ForkTestMain().getClass().getClassLoader(); try (URLClassLoader cl = createClassLoader(jvmRunInfo, parent)) { - ForkTestMain.main(id, info, this.originalOut, cl); + ForkTestMain.main(id, info, this.jsonOut, cl); } } else { throw new RuntimeException("only jvm is supported");