Use TCP for the fork test

**Problem**
Forked test currently uses stdio, but that println to be lost.

**Solution**
Use TCP for communication similar to sbt 1.x.
This commit is contained in:
Eugene Yokota 2026-01-03 15:10:31 -05:00
parent 2b404d6a60
commit c20d83a8e2
5 changed files with 153 additions and 21 deletions

View File

@ -546,6 +546,8 @@ lazy val actionsProj = (project in file("main-actions"))
Test / classLoaderLayeringStrategy := ClassLoaderLayeringStrategy.Flat,
mimaSettings,
mimaBinaryIssueFilters ++= Vector(
exclude[DirectMissingMethodProblem]("sbt.internal.WorkerExchange.*"),
exclude[DirectMissingMethodProblem]("sbt.internal.WorkerProxy.*"),
),
)
.dependsOn(lmCore)

View File

@ -28,6 +28,7 @@ import scala.util.Random
import scala.util.control.NonFatal
import scala.jdk.CollectionConverters.*
import scala.sys.process.Process
import sbt.internal.WorkerConnection
/**
* This implements forked testing, in cooperation with the worker CLI,
@ -152,7 +153,8 @@ private[sbt] object ForkTests:
)
testListeners.foreach(_.doInit())
val result =
val w = WorkerExchange.startWorker(fork, if useClassLoader then Nil else cpFiles)
val ct = WorkerConnection.Tcp
val w = WorkerExchange.startWorker(fork, if useClassLoader then Nil else cpFiles, ct)
val wl = React(randomId, log, opts.testListeners, resultsAcc, w.process)
try
WorkerExchange.registerListener(wl)

View File

@ -11,7 +11,10 @@ package internal
import org.scalasbt.shadedgson.com.google.gson.Gson
import java.io.*
import java.net.{ InetAddress, ServerSocket }
import java.util.Scanner
import sbt.io.IO
import sbt.internal.io.Retry
import sbt.internal.worker1.*
import sbt.testing.Framework
import scala.sys.process.{ BasicIO, Process, ProcessIO }
@ -22,24 +25,52 @@ import scala.concurrent.duration.*
object WorkerExchange:
val listeners: mutable.ListBuffer[WorkerResponseListener] = ListBuffer.empty
private val loopback = InetAddress.getByName(null)
/**
* Start a worker process.
*/
def startWorker(fo: ForkOptions, extraCp: Seq[File]): WorkerProxy =
def startWorker(
fo: ForkOptions,
extraCp: Seq[File],
connectionType: WorkerConnection,
): WorkerProxy =
val fullCp = Seq(
IO.classLocationPath(classOf[WorkerMain]).toFile,
IO.classLocationPath(classOf[Framework]).toFile,
IO.classLocationPath(classOf[Gson]).toFile,
) ++ extraCp
val inputRef = Promise[OutputStream]()
val socketOpt = connectionType match
case WorkerConnection.Tcp =>
val serverSocket = Retry(ServerSocket(0, 1, loopback))
val accepter = Thread(() => {
val socket = serverSocket.accept()
inputRef.success(socket.getOutputStream())
val scanner = Scanner(socket.getInputStream(), "UTF-8")
while scanner.hasNextLine() do notifyListeners(scanner.nextLine())
})
accepter.start()
Some(serverSocket)
case _ => None
val options = Seq(
"-classpath",
fullCp.mkString(File.pathSeparator),
classOf[WorkerMain].getCanonicalName,
)
val inputRef = Promise[OutputStream]()
) ++
(socketOpt match
case Some(s) => Seq("--tcp", s.getLocalPort().toString())
case _ => Nil
)
val onStdoutLine: String => Unit = connectionType match
case WorkerConnection.Stdio => notifyListeners
case _ => (line) => scala.Console.out.println(line)
val processIo = ProcessIO(
in = (input) => inputRef.success(input),
in = (input) =>
(connectionType match
case WorkerConnection.Stdio => inputRef.success(input)
case _ => ()
),
out = BasicIO.processFully(onStdoutLine),
err = BasicIO.processFully((line) => scala.Console.err.println(line)),
)
@ -47,7 +78,7 @@ object WorkerExchange:
val p = Fork.java.fork(forkWithIo, options)
val forkTimeout = 30.seconds
val input = Await.result(inputRef.future, forkTimeout)
WorkerProxy(input, p, options)
WorkerProxy(input, p, options, socketOpt)
def registerListener(listener: WorkerResponseListener): Unit =
synchronized:
@ -61,16 +92,22 @@ object WorkerExchange:
/**
* Unified worker output handler.
*/
def onStdoutLine(line: String): Unit =
def notifyListeners(line: String): Unit =
synchronized:
listeners.foreach: wl =>
wl(line)
end WorkerExchange
class WorkerProxy(input: OutputStream, val process: Process, val options: Seq[String])
extends AutoCloseable:
class WorkerProxy(
input: OutputStream,
val process: Process,
val options: Seq[String],
serverSocket: Option[ServerSocket],
) extends AutoCloseable:
lazy val inputStream = PrintStream(input)
def close(): Unit = input.close()
def close(): Unit =
input.close()
serverSocket.foreach(_.close())
def blockForExitCode(): Int =
if !process.isAlive then process.exitValue()
else Fork.blockForExitCode(process)
@ -89,3 +126,7 @@ end WorkerProxy
abstract class WorkerResponseListener extends Function1[String, Unit]:
def notifyExit(p: Process): Unit
enum WorkerConnection:
case Stdio
case Tcp

View File

@ -0,0 +1,60 @@
package sbt
package internal
import hedgehog.*
import hedgehog.runner.*
import hedgehog.core.{ ShrinkLimit, SuccessCount }
import hedgehog.core.Result
import scala.sys.process.Process
object WorkerExchangeTest extends Properties:
given Gen[WorkerConnection] =
Gen.choice1(Gen.constant(WorkerConnection.Stdio), Gen.constant(WorkerConnection.Tcp))
def gen[A1: Gen]: Gen[A1] = summon[Gen[A1]]
override lazy val tests: List[Test] = List(
propertyN("non-jsonrpc should return exit code 1", propBadInput, 10),
propertyN("bye should return response json with a result", propBye, 10),
)
def propertyN(name: String, result: => Property, n: Int): Test =
Test(name, result)
.config(_.copy(testLimit = SuccessCount(n), shrinkLimit = ShrinkLimit(n * 10)))
def propBadInput: Property =
for
ct <- gen[WorkerConnection].forAll
w = WorkerExchange.startWorker(ForkOptions(), Nil, ct)
yield
w.println("{}")
val exitCode = w.blockForExitCode()
Result.assert(exitCode == 1)
val intGen = Gen.int(Range.linear(1, 100))
def propBye: Property =
for
ct <- gen[WorkerConnection].forAll
i <- intGen.forAll
w = WorkerExchange.startWorker(ForkOptions(), Nil, ct)
yield withListener: l =>
w.println(s"""{"jsonrpc": "2.0", "method": "bye", "params": {}, "id": $i}""")
val exitCode = w.blockForExitCode()
Result
.assert(exitCode == 0)
.and(Result.assert(l.sb.toString() == s"""{ "jsonrpc": "2.0", "result": 0, "id": $i }"""))
.log(s"\"${l.sb.toString()}\"")
def withListener[A1](f: ConcreteListener => A1) =
val l = ConcreteListener()
try
WorkerExchange.registerListener(l)
f(l)
finally WorkerExchange.unregisterListener(l)
class ConcreteListener extends WorkerResponseListener:
val sb = StringBuilder()
def notifyExit(p: Process): Unit = ()
def apply(line: String): Unit = sb.append(line)
end WorkerExchangeTest

View File

@ -19,6 +19,8 @@ import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.net.InetAddress;
import java.net.Socket;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.net.URL;
@ -28,12 +30,16 @@ import java.util.Scanner;
import sbt.testing.*;
/**
* WorkerMain that communicates via the stdin and stdout using JSON-RPC
* WorkerMain that communicates via the stdio or socket using JSON-RPC
* (https://www.jsonrpc.org/specification).
*/
public final class WorkerMain {
private PrintStream originalOut;
private InputStream originalIn;
// When using stdout, this is the original stdout
// When using tcp, this is going to be the socket out
private PrintStream jsonOut;
private Scanner inScanner;
public static Gson mkGson() {
@ -61,6 +67,11 @@ public final class WorkerMain {
WorkerMain app = new WorkerMain();
app.consoleWork();
System.exit(0);
} else if (args.length == 2 && args[0].equals("--tcp")) {
WorkerMain app = new WorkerMain();
int serverPort = Integer.parseInt(args[1]);
app.socketWork(serverPort);
System.exit(0);
} else {
System.err.println("missing args");
System.exit(1);
@ -73,13 +84,26 @@ public final class WorkerMain {
WorkerMain() {
this.originalOut = System.out;
ByteArrayOutputStream baos = new ByteArrayOutputStream();
System.setOut(new PrintStream(baos));
this.originalIn = System.in;
this.inScanner = new Scanner(this.originalIn);
this.jsonOut = this.originalOut;
}
void consoleWork() throws Exception {
this.jsonOut = this.originalOut;
ByteArrayOutputStream baos = new ByteArrayOutputStream();
System.setOut(new PrintStream(baos));
this.inScanner = new Scanner(this.originalIn, "UTF-8");
if (this.inScanner.hasNextLine()) {
String line = this.inScanner.nextLine();
process(line);
}
}
void socketWork(int serverPort) throws Exception {
InetAddress loopback = InetAddress.getByName(null);
Socket client = new Socket(loopback, serverPort);
this.jsonOut = new PrintStream(client.getOutputStream(), true, "UTF-8");
this.inScanner = new Scanner(client.getInputStream(), "UTF-8");
if (this.inScanner.hasNextLine()) {
String line = this.inScanner.nextLine();
process(line);
@ -87,11 +111,11 @@ public final class WorkerMain {
}
/** This processes single request of supposed JSON line. */
void process(String json) {
void process(String json) throws Exception {
JsonElement elem = JsonParser.parseString(json);
JsonObject o = elem.getAsJsonObject();
if (!o.has("jsonrpc")) {
return;
throw new IllegalArgumentException("jsonrpc expected but got: " + json);
}
Gson g = WorkerMain.mkGson();
long id = o.getAsJsonPrimitive("id").getAsLong();
@ -107,17 +131,20 @@ public final class WorkerMain {
TestInfo testInfo = g.fromJson(params, TestInfo.class);
test(id, testInfo);
break;
case "bye":
break;
}
String response = String.format("{ \"jsonrpc\": \"2.0\", \"result\": 0, \"id\": %d }", id);
this.originalOut.println(response);
this.originalOut.flush();
this.jsonOut.println(response);
this.jsonOut.flush();
} catch (Throwable e) {
WorkerError err = new WorkerError(1, e.getMessage());
String errMessage = g.toJson(err, err.getClass());
String errJson =
String.format("{ \"jsonrpc\": \"2.0\", \"error\": %s, \"id\": %d }", errMessage, id);
this.originalOut.println(errJson);
this.originalOut.flush();
this.jsonOut.println(errJson);
this.jsonOut.flush();
e.printStackTrace();
}
}
@ -143,7 +170,7 @@ public final class WorkerMain {
RunInfo.JvmRunInfo jvmRunInfo = info.jvmRunInfo;
ClassLoader parent = new ForkTestMain().getClass().getClassLoader();
try (URLClassLoader cl = createClassLoader(jvmRunInfo, parent)) {
ForkTestMain.main(id, info, this.originalOut, cl);
ForkTestMain.main(id, info, this.jsonOut, cl);
}
} else {
throw new RuntimeException("only jvm is supported");