diff --git a/build.sbt b/build.sbt index c3ec08cc4..c90b04fe2 100644 --- a/build.sbt +++ b/build.sbt @@ -471,7 +471,7 @@ lazy val utilScripted = (project in file("internal") / "util-scripted") // Runner for uniform test interface lazy val testingProj = (project in file("testing")) .enablePlugins(ContrabandPlugin, JsonCodecPlugin) - .dependsOn(testAgentProj, utilLogging) + .dependsOn(workerProj, utilLogging) .settings( baseSettings, name := "Testing", @@ -518,26 +518,15 @@ lazy val testingProj = (project in file("testing")) ) .configure(addSbtIO, addSbtCompilerClasspath) -// Testing agent for running tests in a separate process. -lazy val testAgentProj = (project in file("testing") / "agent") - .settings( - minimalSettings, - crossPaths := false, - autoScalaLibrary := false, - Compile / doc / javacOptions := Nil, - name := "Test Agent", - libraryDependencies += testInterface, - mimaSettings, - ) - lazy val workerProj = (project in file("worker")) .dependsOn(exampleWorkProj % Test) .settings( name := "worker", testedBaseSettings, Compile / doc / javacOptions := Nil, + crossPaths := false, autoScalaLibrary := false, - libraryDependencies += gson, + libraryDependencies ++= Seq(gson, testInterface), libraryDependencies += "org.scala-lang" %% "scala3-library" % scalaVersion.value % Test, // run / fork := false, Test / fork := true, @@ -1218,7 +1207,6 @@ def allProjects = logicProj, completeProj, testingProj, - testAgentProj, taskProj, stdTaskProj, runProj, diff --git a/main-actions/src/main/scala/sbt/ForkTests.scala b/main-actions/src/main/scala/sbt/ForkTests.scala index 677be9e07..7fdf27043 100755 --- a/main-actions/src/main/scala/sbt/ForkTests.scala +++ b/main-actions/src/main/scala/sbt/ForkTests.scala @@ -8,21 +8,33 @@ package sbt -import scala.collection.mutable +import com.google.gson.{ JsonObject, JsonParser } import testing.{ Logger as _, Task as _, * } -import scala.util.control.NonFatal -import java.net.ServerSocket import java.io.* +import java.util.ArrayList import Tests.{ Output as TestOutput, * } -import sbt.io.IO import sbt.util.Logger import sbt.ConcurrentRestrictions.Tag import sbt.protocol.testing.* +import sbt.internal.{ WorkerExchange, WorkerResponseListener } import sbt.internal.util.Util.* -import sbt.internal.util.{ Terminal as UTerminal } +import sbt.internal.util.{ MessageOnlyException, Terminal as UTerminal } +import sbt.internal.worker1.* import xsbti.{ FileConverter, HashedVirtualFileRef } +import scala.collection.mutable +import scala.concurrent.{ Await, Promise } +import scala.concurrent.duration.Duration +import scala.util.Random +import scala.jdk.CollectionConverters.* +import scala.sys.process.Process + +/** + * This implements forked testing, in cooperation with the worker CLI, + * which was previously called test-agent.jar. + */ +private[sbt] object ForkTests: + val r = Random() -private[sbt] object ForkTests { def apply( runners: Map[TestFramework, Runner], opts: ProcessedOptions, @@ -87,147 +99,145 @@ private[sbt] object ForkTests { parallel: Boolean ): Task[TestOutput] = std.TaskExtra.task { - val server = new ServerSocket(0) - val testListeners = opts.testListeners flatMap { + val testListeners = opts.testListeners.flatMap: case tl: TestsListener => tl.some case _ => none[TestsListener] - } - - object Acceptor extends Runnable { - val resultsAcc = mutable.Map.empty[String, SuiteResult] - lazy val result = - TestOutput(overall(resultsAcc.values.map(_.result)), resultsAcc.toMap, Iterable.empty) - - def run(): Unit = { - val socket = - try { - server.accept() - } catch { - case e: java.net.SocketException => - log.error( - "Could not accept connection from test agent: " + e.getClass + ": " + e.getMessage - ) - log.trace(e) - server.close() - return - } - val os = new ObjectOutputStream(socket.getOutputStream) - // Must flush the header that the constructor writes, otherwise the ObjectInputStream on the other end may block indefinitely - os.flush() - val is = new ObjectInputStream(socket.getInputStream) - - try { - val config = new ForkConfiguration(UTerminal.isAnsiSupported, parallel) - os.writeObject(config) - - val taskdefs = opts.tests.map { t => - new TaskDef( - t.name, - forkFingerprint(t.fingerprint), - t.explicitlySpecified, - t.selectors - ) - } - os.writeObject(taskdefs.toArray) - - os.writeInt(runners.size) - for ((testFramework, mainRunner) <- runners) { - os.writeObject(testFramework.implClassNames.toArray) - os.writeObject(mainRunner.args) - os.writeObject(mainRunner.remoteArgs) - } - os.flush() - - new React(is, os, log, opts.testListeners, resultsAcc).react() - } catch { - case NonFatal(e) => - def throwableToString(t: Throwable) = { - import java.io.*; val sw = new StringWriter; t.printStackTrace(new PrintWriter(sw)); - sw.toString - } - resultsAcc("Forked test harness failed: " + throwableToString(e)) = SuiteResult.Error - } finally { - is.close(); os.close(); socket.close() - } - } - } - - try { - testListeners.foreach(_.doInit()) - val acceptorThread = new Thread(Acceptor) - acceptorThread.start() - val cpFiles = classpath.map(converter.toPath).map(_.toFile()) - val fullCp = cpFiles ++ Seq( - IO.classLocationPath(classOf[ForkMain]).toFile, - IO.classLocationPath(classOf[Framework]).toFile, + val resultsAcc = mutable.Map.empty[String, SuiteResult] + val randomId = r.nextLong() + def testOutputResult = + TestOutput( + overall(resultsAcc.values.map(_.result)), + resultsAcc.toMap, + Iterable.empty ) - val options = Seq( - "-classpath", - fullCp mkString File.pathSeparator, - classOf[ForkMain].getCanonicalName, - server.getLocalPort.toString + val taskdefs = opts.tests.map: t => + new TaskDef( + t.name, + forkFingerprint(t.fingerprint), + t.explicitlySpecified, + t.selectors ) - val ec = Fork.java(fork, options) - val result = - if (ec != 0) - TestOutput( - TestResult.Error, - Map( - "Running java with options " + options - .mkString(" ") + " failed with exit code " + ec -> SuiteResult.Error - ), - Iterable.empty - ) - else { - // Need to wait acceptor thread to finish its business - acceptorThread.join() - Acceptor.result - } + val testRunners = runners.toSeq.map: (testFramework, mainRunner) => + TestInfo.TestRunner( + ArrayList(testFramework.implClassNames.asJava), + ArrayList(mainRunner.args().toList.asJava), + ArrayList(mainRunner.remoteArgs().toList.asJava) + ) + val g = WorkerMain.mkGson() + // virtualize classloading by using ClassLoader + val useClassLoader = true + val cpList = ArrayList[FilePath]( + (classpath + .map: vf => + FilePath(converter.toPath(vf).toUri(), vf.contentHashStr())) + .asJava + ) + val cpFiles = + classpath.map: vf => + converter.toPath(vf).toFile() + val param = TestInfo( + true, /* jvm */ + RunInfo.JvmRunInfo( + ArrayList(), + if useClassLoader then cpList else ArrayList(), + "", + false /*connectInput*/, + ), + null, + UTerminal.isAnsiSupported, + parallel, + ArrayList(taskdefs.asJava), + ArrayList(testRunners.asJava), + ) + testListeners.foreach(_.doInit()) + val result = + val w = WorkerExchange.startWorker(fork, if useClassLoader then Nil else cpFiles) + val wl = React(randomId, log, opts.testListeners, resultsAcc, w.process) + try + WorkerExchange.registerListener(wl) + val paramJson = g.toJson(param, param.getClass) + val json = jsonRpcRequest(randomId, "test", paramJson) + w.println(json) + if wl.blockForResponse() != 0 then + throw MessageOnlyException("Forked test harness failed") + testOutputResult + finally WorkerExchange.unregisterListener(wl) + testListeners.foreach(_.doComplete(result.overall)) + result + } // end task - testListeners.foreach(_.doComplete(result.overall)) - result - } finally { - server.close() - } - } + private def jsonRpcRequest(id: Long, method: String, params: String): String = + s"""{ "jsonrpc": "2.0", "method": "$method", "params": $params, "id": $id }""" private def forkFingerprint(f: Fingerprint): Fingerprint & Serializable = - f match { - case s: SubclassFingerprint => new ForkMain.SubclassFingerscan(s) - case a: AnnotatedFingerprint => new ForkMain.AnnotatedFingerscan(a) + f match + case s: SubclassFingerprint => ForkTestMain.SubclassFingerscan(s) + case a: AnnotatedFingerprint => ForkTestMain.AnnotatedFingerscan(a) case _ => sys.error("Unknown fingerprint type: " + f.getClass) - } -} -private final class React( - is: ObjectInputStream, - os: ObjectOutputStream, +end ForkTests + +private class React( + id: Long, log: Logger, listeners: Seq[TestReportListener], - results: mutable.Map[String, SuiteResult] -) { - import ForkTags.* - @annotation.tailrec - def react(): Unit = is.readObject match { - case `Done` => - os.writeObject(Done); os.flush() - case Array(`Error`, s: String) => - log.error(s); react() - case Array(`Warn`, s: String) => - log.warn(s); react() - case Array(`Info`, s: String) => - log.info(s); react() - case Array(`Debug`, s: String) => - log.debug(s); react() - case t: Throwable => - log.trace(t); react() - case Array(group: String, tEvents: Array[Event]) => - val events = tEvents.toSeq - listeners.foreach(_.startGroup(group)) - val event = TestEvent(events) - listeners.foreach(_.testEvent(event)) - val suiteResult = SuiteResult(events) - results += group -> suiteResult - listeners.foreach(_.endGroup(group, suiteResult.result)) - react() - } -} + results: mutable.Map[String, SuiteResult], + process: Process +) extends WorkerResponseListener: + val g = WorkerMain.mkGson() + val promise: Promise[Int] = Promise() + override def apply(line: String): Unit = + // scala.Console.err.println(line) + val o = JsonParser.parseString(line).getAsJsonObject() + if o.has("id") then + val resId = o.getAsJsonPrimitive("id").getAsLong() + if resId == id then + if promise.isCompleted then () + else if o.has("error") then promise.failure(new RuntimeException(line)) + else promise.success(0) + else () + else if o.has("method") then processNotification(o) + else () + + override def notifyExit(p: Process): Unit = + if !process.isAlive then promise.success(process.exitValue()) + + def processNotification(o: JsonObject): Unit = + val method = o.getAsJsonPrimitive("method").getAsString() + method match + case "testLog" => + val params = o.getAsJsonObject("params") + val info = g.fromJson[TestLogInfo](params, classOf[TestLogInfo]) + if info.id == id then + info.tag match + case ForkTags.Error => log.error(info.message) + case ForkTags.Warn => log.warn(info.message) + case ForkTags.Info => log.info(info.message) + case ForkTags.Debug => log.debug(info.message) + case _ => () + else () + case "testEvents" => + val params = o.getAsJsonObject("params") + val info = + g.fromJson[ForkTestMain.ForkEventsInfo](params, classOf[ForkTestMain.ForkEventsInfo]) + if info.id == id then + val events = info.events.asScala.toSeq + listeners.foreach(_.startGroup(info.group)) + val event = TestEvent(events) + listeners.foreach(_.testEvent(event)) + val suiteResult = SuiteResult(events) + results += info.group -> suiteResult + listeners.foreach(_.endGroup(info.group, suiteResult.result)) + else () + case "forkError" => + val params = o.getAsJsonObject("params") + val info = + g.fromJson[ForkTestMain.ForkErrorInfo](params, classOf[ForkTestMain.ForkErrorInfo]) + if info.id == id then + log.trace(info.error) + promise.failure(info.error) + else () + case _ => () + + def blockForResponse(): Int = + Await.result(promise.future, Duration.Inf) +end React diff --git a/main-actions/src/main/scala/sbt/internal/WorkerExchange.scala b/main-actions/src/main/scala/sbt/internal/WorkerExchange.scala new file mode 100644 index 000000000..e57e57783 --- /dev/null +++ b/main-actions/src/main/scala/sbt/internal/WorkerExchange.scala @@ -0,0 +1,88 @@ +/* + * sbt + * Copyright 2023, Scala center + * Copyright 2011 - 2022, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt +package internal + +import com.google.gson.Gson +import java.io.* +import java.util.concurrent.atomic.AtomicReference +import sbt.io.IO +import sbt.internal.worker1.* +import sbt.testing.Framework +import scala.sys.process.{ BasicIO, Process, ProcessIO } +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + +object WorkerExchange: + val listeners: mutable.ListBuffer[WorkerResponseListener] = ListBuffer.empty + + /** + * Start a worker process. + */ + def startWorker(fo: ForkOptions, extraCp: Seq[File]): WorkerProxy = + val fullCp = Seq( + IO.classLocationPath(classOf[WorkerMain]).toFile, + IO.classLocationPath(classOf[Framework]).toFile, + IO.classLocationPath(classOf[Gson]).toFile, + ) ++ extraCp + val options = Seq( + "-classpath", + fullCp.mkString(File.pathSeparator), + classOf[WorkerMain].getCanonicalName, + ) + val inputRef = AtomicReference[OutputStream]() + val processIo = ProcessIO( + in = (input) => inputRef.set(input), + out = BasicIO.processFully(onStdoutLine), + err = BasicIO.processFully((line) => scala.Console.err.println(line)), + ) + val forkWithIo = fo.withOutputStrategy(OutputStrategy.CustomInputOutput(processIo)) + val p = Fork.java.fork(forkWithIo, options) + WorkerProxy(inputRef.get(), p, options) + + def registerListener(listener: WorkerResponseListener): Unit = + synchronized: + listeners.append(listener) + + def unregisterListener(listener: WorkerResponseListener): Unit = + synchronized: + if listeners.contains(listener) then listeners.remove(listeners.indexOf(listener)) + else () + + /** + * Unified worker output handler. + */ + def onStdoutLine(line: String): Unit = + synchronized: + listeners.foreach: wl => + wl(line) +end WorkerExchange + +class WorkerProxy(input: OutputStream, val process: Process, val options: Seq[String]) + extends AutoCloseable: + lazy val inputStream = PrintStream(input) + def close(): Unit = input.close() + def blockForExitCode(): Int = + if !process.isAlive then process.exitValue() + else Fork.blockForExitCode(process) + + /** print a line into stdin of the worker process. */ + def println(str: String): Unit = + inputStream.println(str) + inputStream.flush() + + val watch = Thread(() => { + while process.isAlive() do Thread.sleep(100) + WorkerExchange.listeners.foreach(_.notifyExit(process)) + }) + watch.start() +end WorkerProxy + +abstract class WorkerResponseListener extends Function1[String, Unit]: + def notifyExit(p: Process): Unit diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index 5545af24e..a5368b335 100644 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -1493,9 +1493,8 @@ object Defaults extends BuildCommon { } } val summaries = - runners map { (tf, r) => + runners.map: (tf, r) => Tests.Summary(frameworks(tf).name, r.done()) - } out.copy(summaries = summaries) } // Def.value[Task[Tests.Output]] { diff --git a/run/src/main/scala/sbt/Fork.scala b/run/src/main/scala/sbt/Fork.scala index dc53237c3..8b75e4fef 100644 --- a/run/src/main/scala/sbt/Fork.scala +++ b/run/src/main/scala/sbt/Fork.scala @@ -185,13 +185,13 @@ object Fork { () } val process = Process(jpb) - outputStrategy.getOrElse(StdoutOutput: OutputStrategy) match { + outputStrategy.getOrElse(StdoutOutput: OutputStrategy) match case StdoutOutput => process.run(connectInput = false) case out: BufferedOutput => out.logger.buffer { process.run(out.logger, connectInput = false) } - case out: LoggedOutput => process.run(out.logger, connectInput = false) - case out: CustomOutput => (process #> out.output).run(connectInput = false) - } + case out: LoggedOutput => process.run(out.logger, connectInput = false) + case out: CustomOutput => (process #> out.output).run(connectInput = false) + case out: CustomInputOutput => process.run(out.processIO) } private[sbt] def blockForExitCode(p: Process): Int = { diff --git a/run/src/main/scala/sbt/OutputStrategy.scala b/run/src/main/scala/sbt/OutputStrategy.scala index 26b7e710b..3066b5a13 100644 --- a/run/src/main/scala/sbt/OutputStrategy.scala +++ b/run/src/main/scala/sbt/OutputStrategy.scala @@ -8,6 +8,7 @@ package sbt +import scala.sys.process.ProcessIO import sbt.util.Logger import java.io.OutputStream @@ -102,4 +103,28 @@ object OutputStrategy { object CustomOutput { def apply(output: OutputStream): CustomOutput = new CustomOutput(output) } + + /** + * Configures the forked IO. + */ + final class CustomInputOutput private (val processIO: ProcessIO) + extends OutputStrategy + with Serializable: + override def equals(o: Any): Boolean = o match + case x: CustomInputOutput => (this.processIO == x.processIO) + case _ => false + override def hashCode: Int = + 37 * (17 + processIO.##) + "CustomInputOutput".## + override def toString: String = + "CustomInputOutput(...)" + private def copy(processIO: ProcessIO = processIO): CustomInputOutput = + new CustomInputOutput(processIO) + + def withProcessIO(processIO: ProcessIO): CustomInputOutput = + copy(processIO = processIO) + end CustomInputOutput + + object CustomInputOutput: + def apply(processIO: ProcessIO): CustomInputOutput = new CustomInputOutput(processIO) + end CustomInputOutput } diff --git a/sbt-app/src/sbt-test/dependency-management/update-sbt-classifiers/build.sbt b/sbt-app/src/sbt-test/dependency-management/update-sbt-classifiers/build.sbt index a43dedc9c..2bbba6379 100644 --- a/sbt-app/src/sbt-test/dependency-management/update-sbt-classifiers/build.sbt +++ b/sbt-app/src/sbt-test/dependency-management/update-sbt-classifiers/build.sbt @@ -33,6 +33,7 @@ lazy val root = (project in file(".")) "com.eed3si9n:sjson-new-scalajson_3", "com.github.ben-manes.caffeine:caffeine", "com.github.mwiede:jsch", + "com.google.code.gson:gson", "com.google.errorprone:error_prone_annotations", "com.lmax:disruptor", "com.swoval:file-tree-views", @@ -82,7 +83,7 @@ lazy val root = (project in file(".")) ) def assertCollectionsEqual(message: String, expected: Seq[String], actual: Seq[String]): Unit = // using the new line for a more readable comparison failure output - assert(expected.mkString("\n") == actual.mkString("\n"), message) + assert(expected.mkString("\n") == actual.mkString("\n"), message + ": " + actual) assertCollectionsEqual( "Unexpected module ids in updateSbtClassifiers", diff --git a/sbt-app/src/sbt-test/tests/fork/build.sbt b/sbt-app/src/sbt-test/tests/fork/build.sbt index 04553d69d..635492a86 100644 --- a/sbt-app/src/sbt-test/tests/fork/build.sbt +++ b/sbt-app/src/sbt-test/tests/fork/build.sbt @@ -11,6 +11,7 @@ val scalaxml = "org.scala-lang.modules" %% "scala-xml" % "1.1.1" def groupId(idx: Int) = "group_" + (idx + 1) def groupPrefix(idx: Int) = groupId(idx) + "_file_" +Global / localCacheDirectory := baseDirectory.value / "diskcache" ThisBuild / scalaVersion := "2.12.20" ThisBuild / organization := "org.example" @@ -19,7 +20,7 @@ lazy val root = (project in file(".")) Test / testGrouping := Def.uncached { val tests = (Test / definedTests).value assert(tests.size == 3) - for (idx <- 0 until groups) yield + for idx <- 0 until groups yield new Group( groupId(idx), tests, @@ -28,11 +29,11 @@ lazy val root = (project in file(".")) }, check := Def.uncached { val files = - for(i <- 0 until groups; j <- 1 to groupSize) yield + for i <- 0 until groups; j <- 1 to groupSize yield file(groupPrefix(i) + j) val (exist, absent) = files.partition(_.exists) exist.foreach(_.delete()) - if (absent.nonEmpty) + if absent.nonEmpty then sys.error("Files were not created:\n\t" + absent.mkString("\n\t")) }, concurrentRestrictions := Tags.limit(Tags.ForkedTestGroup, 2) :: Nil, diff --git a/sbt-app/src/sbt-test/tests/fork2/changes/Test.scala b/sbt-app/src/sbt-test/tests/fork2/changes/Test.scala index f2d45c34f..f45303e11 100644 --- a/sbt-app/src/sbt-test/tests/fork2/changes/Test.scala +++ b/sbt-app/src/sbt-test/tests/fork2/changes/Test.scala @@ -1,8 +1,8 @@ import org.scalatest.FlatSpec class Test extends FlatSpec { - val v = sys.env.getOrElse("tests.max.value", Int.MaxValue) - "A simple equation" should "hold" in { - assert(Int.MaxValue == v) - } + val v = sys.env.getOrElse("tests.max.value", Int.MaxValue) + "A simple equation" should "hold" in { + assert(Int.MaxValue == v) + } } diff --git a/testing/agent/src/main/java/sbt/ForkConfiguration.java b/testing/agent/src/main/java/sbt/ForkConfiguration.java deleted file mode 100644 index 993fe6908..000000000 --- a/testing/agent/src/main/java/sbt/ForkConfiguration.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * sbt - * Copyright 2023, Scala center - * Copyright 2011 - 2022, Lightbend, Inc. - * Copyright 2008 - 2010, Mark Harrah - * Licensed under Apache License 2.0 (see LICENSE) - */ - -package sbt; - -import java.io.Serializable; - -public final class ForkConfiguration implements Serializable { - private final boolean ansiCodesSupported; - private final boolean parallel; - - public ForkConfiguration(final boolean ansiCodesSupported, final boolean parallel) { - this.ansiCodesSupported = ansiCodesSupported; - this.parallel = parallel; - } - - public boolean isAnsiCodesSupported() { - return ansiCodesSupported; - } - - public boolean isParallel() { - return parallel; - } -} diff --git a/testing/src/main/scala/sbt/TestFramework.scala b/testing/src/main/scala/sbt/TestFramework.scala index 6ef27739c..3521230c0 100644 --- a/testing/src/main/scala/sbt/TestFramework.scala +++ b/testing/src/main/scala/sbt/TestFramework.scala @@ -67,8 +67,9 @@ final class TestFramework(val implClassNames: String*) extends Serializable { case head :: tail => try { Some(Class.forName(head, true, loader).getDeclaredConstructor().newInstance() match { - case newFramework: Framework => newFramework - case oldFramework: OldFramework => new FrameworkWrapper(oldFramework) + case newFramework: Framework => newFramework + case oldFramework: OldFramework => + new sbt.internal.worker1.FrameworkWrapper(oldFramework) }) } catch { case e: NoClassDefFoundError => diff --git a/worker/src/main/java/com/google/gson/typeadapters/RuntimeTypeAdapterFactory.java b/worker/src/main/java/com/google/gson/typeadapters/RuntimeTypeAdapterFactory.java new file mode 100644 index 000000000..00e9b3dd4 --- /dev/null +++ b/worker/src/main/java/com/google/gson/typeadapters/RuntimeTypeAdapterFactory.java @@ -0,0 +1,344 @@ +/* + * Copyright (C) 2011 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.gson.typeadapters; + +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.gson.Gson; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParseException; +import com.google.gson.JsonPrimitive; +import com.google.gson.TypeAdapter; +import com.google.gson.TypeAdapterFactory; +import com.google.gson.reflect.TypeToken; +import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonWriter; +import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Adapts values whose runtime type may differ from their declaration type. This is necessary when a + * field's type is not the same type that GSON should create when deserializing that field. For + * example, consider these types: + * + *
+ * {
+ * 	@code
+ * 	abstract class Shape {
+ * 		int x;
+ * 		int y;
+ * 	}
+ * 	class Circle extends Shape {
+ * 		int radius;
+ * 	}
+ * 	class Rectangle extends Shape {
+ * 		int width;
+ * 		int height;
+ * 	}
+ * 	class Diamond extends Shape {
+ * 		int width;
+ * 		int height;
+ * 	}
+ * 	class Drawing {
+ * 		Shape bottomShape;
+ * 		Shape topShape;
+ * 	}
+ * }
+ * 
+ * + *

Without additional type information, the serialized JSON is ambiguous. Is the bottom shape in + * this drawing a rectangle or a diamond? + * + *

{@code
+ * {
+ *   "bottomShape": {
+ *     "width": 10,
+ *     "height": 5,
+ *     "x": 0,
+ *     "y": 0
+ *   },
+ *   "topShape": {
+ *     "radius": 2,
+ *     "x": 4,
+ *     "y": 1
+ *   }
+ * }
+ * }
+ * + * This class addresses this problem by adding type information to the serialized JSON and honoring + * that type information when the JSON is deserialized: + * + *
{@code
+ * {
+ *   "bottomShape": {
+ *     "type": "Diamond",
+ *     "width": 10,
+ *     "height": 5,
+ *     "x": 0,
+ *     "y": 0
+ *   },
+ *   "topShape": {
+ *     "type": "Circle",
+ *     "radius": 2,
+ *     "x": 4,
+ *     "y": 1
+ *   }
+ * }
+ * }
+ * + * Both the type field name ({@code "type"}) and the type labels ({@code "Rectangle"}) are + * configurable. + * + *

Registering Types

+ * + * Create a {@code RuntimeTypeAdapterFactory} by passing the base type and type field name to the + * {@link #of} factory method. If you don't supply an explicit type field name, {@code "type"} will + * be used. + * + *
+ * {
+ * 	@code
+ * 	RuntimeTypeAdapterFactory<Shape> shapeAdapterFactory = RuntimeTypeAdapterFactory.of(Shape.class, "type");
+ * }
+ * 
+ * + * Next register all of your subtypes. Every subtype must be explicitly registered. This protects + * your application from injection attacks. If you don't supply an explicit type label, the type's + * simple name will be used. + * + *
{@code
+ * shapeAdapterFactory.registerSubtype(Rectangle.class, "Rectangle");
+ * shapeAdapterFactory.registerSubtype(Circle.class, "Circle");
+ * shapeAdapterFactory.registerSubtype(Diamond.class, "Diamond");
+ * }
+ * + * Finally, register the type adapter factory in your application's GSON builder: + * + *
+ * {
+ * 	@code
+ * 	Gson gson = new GsonBuilder().registerTypeAdapterFactory(shapeAdapterFactory).create();
+ * }
+ * 
+ * + * Like {@code GsonBuilder}, this API supports chaining: + * + *
+ * {
+ * 	@code
+ * 	RuntimeTypeAdapterFactory<Shape> shapeAdapterFactory = RuntimeTypeAdapterFactory.of(Shape.class)
+ * 			.registerSubtype(Rectangle.class).registerSubtype(Circle.class).registerSubtype(Diamond.class);
+ * }
+ * 
+ * + *

Serialization and deserialization

+ * + * In order to serialize and deserialize a polymorphic object, you must specify the base type + * explicitly. + * + *
+ * {
+ * 	@code
+ * 	Diamond diamond = new Diamond();
+ * 	String json = gson.toJson(diamond, Shape.class);
+ * }
+ * 
+ * + * And then: + * + *
+ * {
+ * 	@code
+ * 	Shape shape = gson.fromJson(json, Shape.class);
+ * }
+ * 
+ */ +public final class RuntimeTypeAdapterFactory implements TypeAdapterFactory { + private final Class baseType; + private final String typeFieldName; + private final Map> labelToSubtype = new LinkedHashMap<>(); + private final Map, String> subtypeToLabel = new LinkedHashMap<>(); + private final boolean maintainType; + private boolean recognizeSubtypes; + + private RuntimeTypeAdapterFactory(Class baseType, String typeFieldName, boolean maintainType) { + if (typeFieldName == null || baseType == null) { + throw new NullPointerException(); + } + this.baseType = baseType; + this.typeFieldName = typeFieldName; + this.maintainType = maintainType; + } + + /** + * Creates a new runtime type adapter for {@code baseType} using {@code typeFieldName} as the type + * field name. Type field names are case sensitive. + * + * @param maintainType true if the type field should be included in deserialized objects + */ + public static RuntimeTypeAdapterFactory of( + Class baseType, String typeFieldName, boolean maintainType) { + return new RuntimeTypeAdapterFactory<>(baseType, typeFieldName, maintainType); + } + + /** + * Creates a new runtime type adapter for {@code baseType} using {@code typeFieldName} as the type + * field name. Type field names are case sensitive. + */ + public static RuntimeTypeAdapterFactory of(Class baseType, String typeFieldName) { + return new RuntimeTypeAdapterFactory<>(baseType, typeFieldName, false); + } + + /** + * Creates a new runtime type adapter for {@code baseType} using {@code "type"} as the type field + * name. + */ + public static RuntimeTypeAdapterFactory of(Class baseType) { + return new RuntimeTypeAdapterFactory<>(baseType, "type", false); + } + + /** + * Ensures that this factory will handle not just the given {@code baseType}, but any subtype of + * that type. + */ + @CanIgnoreReturnValue + public RuntimeTypeAdapterFactory recognizeSubtypes() { + this.recognizeSubtypes = true; + return this; + } + + /** + * Registers {@code type} identified by {@code label}. Labels are case sensitive. + * + * @throws IllegalArgumentException if either {@code type} or {@code label} have already been + * registered on this type adapter. + */ + @CanIgnoreReturnValue + public RuntimeTypeAdapterFactory registerSubtype(Class type, String label) { + if (type == null || label == null) { + throw new NullPointerException(); + } + if (subtypeToLabel.containsKey(type) || labelToSubtype.containsKey(label)) { + throw new IllegalArgumentException("types and labels must be unique"); + } + labelToSubtype.put(label, type); + subtypeToLabel.put(type, label); + return this; + } + + /** + * Registers {@code type} identified by its {@link Class#getSimpleName simple name}. Labels are + * case sensitive. + * + * @throws IllegalArgumentException if either {@code type} or its simple name have already been + * registered on this type adapter. + */ + @CanIgnoreReturnValue + public RuntimeTypeAdapterFactory registerSubtype(Class type) { + return registerSubtype(type, type.getSimpleName()); + } + + @Override + public TypeAdapter create(Gson gson, TypeToken type) { + if (type == null) { + return null; + } + Class rawType = type.getRawType(); + boolean handle = + recognizeSubtypes ? baseType.isAssignableFrom(rawType) : baseType.equals(rawType); + if (!handle) { + return null; + } + + TypeAdapter jsonElementAdapter = gson.getAdapter(JsonElement.class); + Map> labelToDelegate = new LinkedHashMap<>(); + Map, TypeAdapter> subtypeToDelegate = new LinkedHashMap<>(); + for (Map.Entry> entry : labelToSubtype.entrySet()) { + TypeAdapter delegate = gson.getDelegateAdapter(this, TypeToken.get(entry.getValue())); + labelToDelegate.put(entry.getKey(), delegate); + subtypeToDelegate.put(entry.getValue(), delegate); + } + + return new TypeAdapter() { + @Override + public R read(JsonReader in) throws IOException { + JsonElement jsonElement = jsonElementAdapter.read(in); + JsonElement labelJsonElement; + if (maintainType) { + labelJsonElement = jsonElement.getAsJsonObject().get(typeFieldName); + } else { + labelJsonElement = jsonElement.getAsJsonObject().remove(typeFieldName); + } + + if (labelJsonElement == null) { + throw new JsonParseException( + "cannot deserialize " + + baseType + + " because it does not define a field named " + + typeFieldName); + } + String label = labelJsonElement.getAsString(); + @SuppressWarnings("unchecked") // registration requires that subtype extends T + TypeAdapter delegate = (TypeAdapter) labelToDelegate.get(label); + if (delegate == null) { + throw new JsonParseException( + "cannot deserialize " + + baseType + + " subtype named " + + label + + "; did you forget to register a subtype?"); + } + return delegate.fromJsonTree(jsonElement); + } + + @Override + public void write(JsonWriter out, R value) throws IOException { + Class srcType = value.getClass(); + String label = subtypeToLabel.get(srcType); + @SuppressWarnings("unchecked") // registration requires that subtype extends T + TypeAdapter delegate = (TypeAdapter) subtypeToDelegate.get(srcType); + if (delegate == null) { + throw new JsonParseException( + "cannot serialize " + srcType.getName() + "; did you forget to register a subtype?"); + } + JsonObject jsonObject = delegate.toJsonTree(value).getAsJsonObject(); + + if (maintainType) { + jsonElementAdapter.write(out, jsonObject); + return; + } + + JsonObject clone = new JsonObject(); + + if (jsonObject.has(typeFieldName)) { + throw new JsonParseException( + "cannot serialize " + + srcType.getName() + + " because it already defines a field named " + + typeFieldName); + } + clone.add(typeFieldName, new JsonPrimitive(label)); + + for (Map.Entry e : jsonObject.entrySet()) { + clone.add(e.getKey(), e.getValue()); + } + jsonElementAdapter.write(out, clone); + } + }.nullSafe(); + } +} diff --git a/testing/agent/src/main/java/sbt/ForkTags.java b/worker/src/main/java/sbt/internal/worker1/ForkTags.java similarity index 89% rename from testing/agent/src/main/java/sbt/ForkTags.java rename to worker/src/main/java/sbt/internal/worker1/ForkTags.java index 1b0660119..d5daabec2 100644 --- a/testing/agent/src/main/java/sbt/ForkTags.java +++ b/worker/src/main/java/sbt/internal/worker1/ForkTags.java @@ -6,7 +6,7 @@ * Licensed under Apache License 2.0 (see LICENSE) */ -package sbt; +package sbt.internal.worker1; public enum ForkTags { Error, diff --git a/testing/agent/src/main/java/sbt/ForkMain.java b/worker/src/main/java/sbt/internal/worker1/ForkTestMain.java similarity index 66% rename from testing/agent/src/main/java/sbt/ForkMain.java rename to worker/src/main/java/sbt/internal/worker1/ForkTestMain.java index fdc79a7d8..607cf3e0d 100644 --- a/testing/agent/src/main/java/sbt/ForkMain.java +++ b/worker/src/main/java/sbt/internal/worker1/ForkTestMain.java @@ -6,16 +6,15 @@ * Licensed under Apache License 2.0 (see LICENSE) */ -package sbt; +package sbt.internal.worker1; + +import com.google.gson.Gson; import sbt.testing.*; import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; +import java.io.PrintStream; import java.io.Serializable; -import java.net.Socket; -import java.net.InetAddress; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -23,7 +22,7 @@ import java.util.List; import java.util.LinkedHashSet; import java.util.concurrent.*; -public final class ForkMain { +public class ForkTestMain { // serializables // ----------------------------------------------------------------------------- @@ -70,7 +69,7 @@ public final class ForkMain { } } - static final class ForkEvent implements Event, Serializable { + public static final class ForkEvent implements Event, Serializable { private final String fullyQualifiedName; private final Fingerprint fingerprint; private final Selector selector; @@ -79,23 +78,23 @@ public final class ForkMain { private final long duration; ForkEvent(final Event e) { - fullyQualifiedName = e.fullyQualifiedName(); + this.fullyQualifiedName = e.fullyQualifiedName(); final Fingerprint rawFingerprint = e.fingerprint(); if (rawFingerprint instanceof SubclassFingerprint) - fingerprint = new SubclassFingerscan((SubclassFingerprint) rawFingerprint); - else fingerprint = new AnnotatedFingerscan((AnnotatedFingerprint) rawFingerprint); + this.fingerprint = new SubclassFingerscan((SubclassFingerprint) rawFingerprint); + else this.fingerprint = new AnnotatedFingerscan((AnnotatedFingerprint) rawFingerprint); - selector = e.selector(); + this.selector = e.selector(); checkSerializableSelector(selector); - status = e.status(); + this.status = e.status(); final OptionalThrowable originalThrowable = e.throwable(); if (originalThrowable.isDefined()) - throwable = new OptionalThrowable(new ForkError(originalThrowable.get())); - else throwable = originalThrowable; + this.throwable = new OptionalThrowable(new ForkError(originalThrowable.get())); + else this.throwable = originalThrowable; - duration = e.duration(); + this.duration = e.duration(); } public String fullyQualifiedName() { @@ -132,18 +131,30 @@ public final class ForkMain { } } + public static class ForkEventsInfo implements Serializable { + public long id; + public String group; + public ArrayList events; + + public ForkEventsInfo(long id, String group, ArrayList events) { + this.id = id; + this.group = group; + this.events = events; + } + } + // ----------------------------------------------------------------------------- - static final class ForkError extends Exception { + public static final class ForkError extends Exception { private final String originalMessage; private final String originalName; - private ForkError cause; + private ForkError cause1; ForkError(final Throwable t) { originalMessage = t.getMessage(); originalName = t.getClass().getName(); setStackTrace(t.getStackTrace()); - if (t.getCause() != null) cause = new ForkError(t.getCause()); + if (t.getCause() != null) cause1 = new ForkError(t.getCause()); } public String getMessage() { @@ -151,51 +162,50 @@ public final class ForkMain { } public Exception getCause() { - return cause; + return cause1; + } + } + + public static class ForkErrorInfo implements Serializable { + public final long id; + public final ForkError error; + + public ForkErrorInfo(long id, ForkError error) { + this.id = id; + this.error = error; } } // main // ---------------------------------------------------------------------------------------------------------------- - public static void main(final String[] args) throws Exception { - ClassLoader classLoader = new Run().getClass().getClassLoader(); - try { - main(args, classLoader); - } finally { - System.exit(0); - } - } - - public static void main(final String[] args, ClassLoader classLoader) throws Exception { - final Socket socket = new Socket(InetAddress.getByName(null), Integer.valueOf(args[0])); - final ObjectInputStream is = new ObjectInputStream(socket.getInputStream()); - final ObjectOutputStream os = new ObjectOutputStream(socket.getOutputStream()); - // Must flush the header that the constructor writes, otherwise the ObjectInputStream on the - // other end may block indefinitely - os.flush(); - try { - new Run().run(is, os, classLoader); - } finally { - is.close(); - os.close(); - } + public static void main(long id, TestInfo info, PrintStream originalOut, ClassLoader classLoader) + throws Exception { + new Run(originalOut, id).run(info, classLoader); } // ---------------------------------------------------------------------------------------------------------------- - private static final class Run { + public static final class Run { + final PrintStream originalOut; + final long id; + final Gson gson; - private void run( - final ObjectInputStream is, final ObjectOutputStream os, ClassLoader classLoader) { + Run(PrintStream originalOut, long id) { + this.originalOut = originalOut; + this.id = id; + this.gson = WorkerMain.mkGson(); + } + + private void run(TestInfo info, ClassLoader classLoader) { try { - runTests(is, os, classLoader); + runTests(info, classLoader); } catch (final RunAborted e) { internalError(e); } catch (final Throwable t) { try { - logError(os, "Uncaught exception when running tests: " + t.toString()); - write(os, new ForkError(t)); + logError("Uncaught exception when running tests: " + t.toString()); + writeError(new ForkError(t)); } catch (final Throwable t2) { internalError(t2); } @@ -223,106 +233,118 @@ public final class ForkMain { } } - private synchronized void write(final ObjectOutputStream os, final Object obj) { - try { - os.writeObject(obj); - os.flush(); - } catch (final IOException e) { - throw new RunAborted(e); - } + private void writeError(ForkError error) { + ForkErrorInfo info = new ForkErrorInfo(this.id, error); + String params = this.gson.toJson(info, ForkErrorInfo.class); + String notification = + String.format( + "{ \"jsonrpc\": \"2.0\", \"method\": \"forkError\", \"params\": %s }", params); + this.originalOut.println(notification); + this.originalOut.flush(); } - private void log(final ObjectOutputStream os, final String message, final ForkTags level) { - write(os, new Object[] {level, message}); + private void log(final String message, final ForkTags level) { + TestLogInfo info = new TestLogInfo(this.id, level, message); + String params = this.gson.toJson(info, TestLogInfo.class); + String notification = + String.format( + "{ \"jsonrpc\": \"2.0\", \"method\": \"testLog\", \"params\": %s }", params); + this.originalOut.println(notification); + this.originalOut.flush(); } - private void logDebug(final ObjectOutputStream os, final String message) { - log(os, message, ForkTags.Debug); + private void logDebug(final String message) { + log(message, ForkTags.Debug); } - private void logInfo(final ObjectOutputStream os, final String message) { - log(os, message, ForkTags.Info); + private void logInfo(final String message) { + log(message, ForkTags.Info); } - private void logWarn(final ObjectOutputStream os, final String message) { - log(os, message, ForkTags.Warn); + private void logWarn(final String message) { + log(message, ForkTags.Warn); } - private void logError(final ObjectOutputStream os, final String message) { - log(os, message, ForkTags.Error); + private void logError(final String message) { + log(message, ForkTags.Error); } - private Logger remoteLogger(final boolean ansiCodesSupported, final ObjectOutputStream os) { + private Logger remoteLogger(final boolean ansiCodesSupported) { return new Logger() { public boolean ansiCodesSupported() { return ansiCodesSupported; } public void error(final String s) { - logError(os, s); + logError(s); } public void warn(final String s) { - logWarn(os, s); + logWarn(s); } public void info(final String s) { - logInfo(os, s); + logInfo(s); } public void debug(final String s) { - logDebug(os, s); + logDebug(s); } public void trace(final Throwable t) { - write(os, new ForkError(t)); + writeError(new ForkError(t)); } }; } - private void writeEvents( - final ObjectOutputStream os, final TaskDef taskDef, final ForkEvent[] events) { - write(os, new Object[] {taskDef.fullyQualifiedName(), events}); + private void writeEvents(final TaskDef taskDef, final ForkEvent[] events) { + ForkEventsInfo info = + new ForkEventsInfo( + this.id, + taskDef.fullyQualifiedName(), + new ArrayList(Arrays.asList(events))); + String params = this.gson.toJson(info, ForkEventsInfo.class); + String notification = + String.format( + "{ \"jsonrpc\": \"2.0\", \"method\": \"testEvents\", \"params\": %s }", params); + this.originalOut.println(notification); + this.originalOut.flush(); } - private ExecutorService executorService( - final ForkConfiguration config, final ObjectOutputStream os) { - if (config.isParallel()) { + private ExecutorService executorService(final boolean parallel) { + if (parallel) { final int nbThreads = Runtime.getRuntime().availableProcessors(); - logDebug(os, "Create a test executor with a thread pool of " + nbThreads + " threads."); + logDebug("Create a test executor with a thread pool of " + nbThreads + " threads."); // more options later... // TODO we might want to configure the blocking queue with size #proc return Executors.newFixedThreadPool(nbThreads); } else { - logDebug(os, "Create a single-thread test executor"); + logDebug("Create a single-thread test executor"); return Executors.newSingleThreadExecutor(); } } - private void runTests( - final ObjectInputStream is, final ObjectOutputStream os, ClassLoader classLoader) - throws Exception { - final ForkConfiguration config = (ForkConfiguration) is.readObject(); - final ExecutorService executor = executorService(config, os); - final TaskDef[] tests = (TaskDef[]) is.readObject(); - final int nFrameworks = is.readInt(); - final Logger[] loggers = {remoteLogger(config.isAnsiCodesSupported(), os)}; + private void runTests(TestInfo info, ClassLoader classLoader) throws Exception { + final ExecutorService executor = executorService(info.parallel); + final TaskDef[] tests = info.taskDefs.toArray(new TaskDef[] {}); + final int nFrameworks = info.testRunners.size(); + final Logger[] loggers = {remoteLogger(info.ansiCodesSupported)}; - for (int i = 0; i < nFrameworks; i++) { - final String[] implClassNames = (String[]) is.readObject(); - final String[] frameworkArgs = (String[]) is.readObject(); - final String[] remoteFrameworkArgs = (String[]) is.readObject(); + for (TestInfo.TestRunner testRunner : info.testRunners) { + final String[] frameworkArgs = testRunner.mainRunnerArgs.toArray(new String[] {}); + final String[] remoteFrameworkArgs = + testRunner.mainRunnerRemoteArgs.toArray(new String[] {}); Framework framework = null; - for (final String implClassName : implClassNames) { + for (final String implClassName : testRunner.implClassNames) { try { final Object rawFramework = - Class.forName(implClassName).getDeclaredConstructor().newInstance(); + classLoader.loadClass(implClassName).getDeclaredConstructor().newInstance(); if (rawFramework instanceof Framework) framework = (Framework) rawFramework; else framework = new FrameworkWrapper((org.scalatools.testing.Framework) rawFramework); break; } catch (final ClassNotFoundException e) { - logDebug(os, "Framework implementation '" + implClassName + "' not present."); + logError("Framework implementation '" + implClassName + "' not present."); } } @@ -344,7 +366,6 @@ public final class ForkMain { final Runner runner = framework.runner(frameworkArgs, remoteFrameworkArgs, classLoader); final Task[] tasks = runner.tasks(filteredTests.toArray(new TaskDef[filteredTests.size()])); logDebug( - os, "Runner for " + framework.getClass().getName() + " produced " @@ -356,25 +377,20 @@ public final class ForkMain { Thread callDoneOnShutdown = new Thread(() -> runner.done()); Runtime.getRuntime().addShutdownHook(callDoneOnShutdown); - runTestTasks(executor, tasks, loggers, os); + runTestTasks(executor, tasks, loggers); runner.done(); Runtime.getRuntime().removeShutdownHook(callDoneOnShutdown); } - write(os, ForkTags.Done); - is.readObject(); } private void runTestTasks( - final ExecutorService executor, - final Task[] tasks, - final Logger[] loggers, - final ObjectOutputStream os) { + final ExecutorService executor, final Task[] tasks, final Logger[] loggers) { if (tasks.length > 0) { final List> futureNestedTasks = new ArrayList<>(); for (final Task task : tasks) { - futureNestedTasks.add(runTest(executor, task, loggers, os)); + futureNestedTasks.add(runTest(executor, task, loggers)); } // Note: this could be optimized further, we could have a callback once a test finishes that @@ -385,18 +401,15 @@ public final class ForkMain { try { nestedTasks.addAll(Arrays.asList(futureNestedTask.get())); } catch (final Exception e) { - logError(os, "Failed to execute task " + futureNestedTask); + logError("Failed to execute task " + futureNestedTask); } } - runTestTasks(executor, nestedTasks.toArray(new Task[nestedTasks.size()]), loggers, os); + runTestTasks(executor, nestedTasks.toArray(new Task[nestedTasks.size()]), loggers); } } private Future runTest( - final ExecutorService executor, - final Task task, - final Logger[] loggers, - final ObjectOutputStream os) { + final ExecutorService executor, final Task task, final Logger[] loggers) { return executor.submit( () -> { ForkEvent[] events; @@ -410,11 +423,10 @@ public final class ForkMain { eventList.add(new ForkEvent(e)); } }; - logDebug(os, " Running " + taskDef); + logDebug(" Running " + taskDef); nestedTasks = task.execute(handler, loggers); if (nestedTasks.length > 0 || eventList.size() > 0) logDebug( - os, " Produced " + nestedTasks.length + " nested tasks and " @@ -426,7 +438,6 @@ public final class ForkMain { events = new ForkEvent[] { testError( - os, taskDef, "Uncaught exception when running " + taskDef.fullyQualifiedName() @@ -435,7 +446,7 @@ public final class ForkMain { t) }; } - writeEvents(os, taskDef, events); + writeEvents(taskDef, events); return nestedTasks; }); } @@ -482,14 +493,10 @@ public final class ForkMain { }); } - private ForkEvent testError( - final ObjectOutputStream os, - final TaskDef taskDef, - final String message, - final Throwable t) { - logError(os, message); + private ForkEvent testError(final TaskDef taskDef, final String message, final Throwable t) { + logError(message); final ForkError fe = new ForkError(t); - write(os, fe); + writeError(fe); return testEvent( taskDef.fullyQualifiedName(), taskDef.fingerprint(), diff --git a/testing/agent/src/main/java/sbt/FrameworkWrapper.java b/worker/src/main/java/sbt/internal/worker1/FrameworkWrapper.java similarity index 98% rename from testing/agent/src/main/java/sbt/FrameworkWrapper.java rename to worker/src/main/java/sbt/internal/worker1/FrameworkWrapper.java index 4d67edf8f..10cdcb499 100644 --- a/testing/agent/src/main/java/sbt/FrameworkWrapper.java +++ b/worker/src/main/java/sbt/internal/worker1/FrameworkWrapper.java @@ -6,7 +6,7 @@ * Licensed under Apache License 2.0 (see LICENSE) */ -package sbt; +package sbt.internal.worker1; import sbt.testing.*; @@ -14,11 +14,11 @@ import sbt.testing.*; * Adapts the old {@link org.scalatools.testing.Framework} interface into the new {@link * sbt.testing.Framework} */ -final class FrameworkWrapper implements Framework { +public final class FrameworkWrapper implements Framework { private final org.scalatools.testing.Framework oldFramework; - FrameworkWrapper(final org.scalatools.testing.Framework oldFramework) { + public FrameworkWrapper(final org.scalatools.testing.Framework oldFramework) { this.oldFramework = oldFramework; } diff --git a/worker/src/main/java/sbt/internal/worker1/RunInfo.java b/worker/src/main/java/sbt/internal/worker1/RunInfo.java index 39744a7b7..079e3e1e2 100644 --- a/worker/src/main/java/sbt/internal/worker1/RunInfo.java +++ b/worker/src/main/java/sbt/internal/worker1/RunInfo.java @@ -8,10 +8,11 @@ package sbt.internal.worker1; +import java.io.Serializable; import java.util.ArrayList; -public class RunInfo { - public class JvmRunInfo { +public class RunInfo implements Serializable { + public static class JvmRunInfo implements Serializable { public ArrayList args; public ArrayList classpath; public String mainClass; @@ -29,7 +30,7 @@ public class RunInfo { } } - public class NativeRunInfo {} + public static class NativeRunInfo implements Serializable {} public boolean jvm; public JvmRunInfo jvmRunInfo; diff --git a/worker/src/main/java/sbt/internal/worker1/TestInfo.java b/worker/src/main/java/sbt/internal/worker1/TestInfo.java new file mode 100644 index 000000000..a1990e0d7 --- /dev/null +++ b/worker/src/main/java/sbt/internal/worker1/TestInfo.java @@ -0,0 +1,55 @@ +/* + * sbt + * Copyright 2023, Scala center + * Copyright 2011 - 2022, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal.worker1; + +import java.io.Serializable; +import java.util.ArrayList; +import sbt.testing.TaskDef; + +public class TestInfo implements Serializable { + public static class TestRunner implements Serializable { + public final ArrayList implClassNames; + public final ArrayList mainRunnerArgs; + public final ArrayList mainRunnerRemoteArgs; + + public TestRunner( + ArrayList implClassNames, + ArrayList mainRunnerArgs, + ArrayList mainRunnerRemoteArgs) { + this.implClassNames = implClassNames; + this.mainRunnerArgs = mainRunnerArgs; + this.mainRunnerRemoteArgs = mainRunnerRemoteArgs; + } + } + + public final boolean jvm; + public final RunInfo.JvmRunInfo jvmRunInfo; + public final RunInfo.NativeRunInfo nativeRunInfo; + public final boolean ansiCodesSupported; + public final boolean parallel; + public final ArrayList taskDefs; + public final ArrayList testRunners; + + public TestInfo( + boolean jvm, + RunInfo.JvmRunInfo jvmRunInfo, + RunInfo.NativeRunInfo nativeRunInfo, + boolean ansiCodesSupported, + boolean parallel, + ArrayList taskDefs, + ArrayList testRunners) { + this.jvm = jvm; + this.jvmRunInfo = jvmRunInfo; + this.nativeRunInfo = nativeRunInfo; + this.ansiCodesSupported = ansiCodesSupported; + this.parallel = parallel; + this.taskDefs = taskDefs; + this.testRunners = testRunners; + } +} diff --git a/worker/src/main/java/sbt/internal/worker1/TestLogInfo.java b/worker/src/main/java/sbt/internal/worker1/TestLogInfo.java new file mode 100644 index 000000000..58cbf602a --- /dev/null +++ b/worker/src/main/java/sbt/internal/worker1/TestLogInfo.java @@ -0,0 +1,25 @@ +/* + * sbt + * Copyright 2023, Scala center + * Copyright 2011 - 2022, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal.worker1; + +import java.io.Serializable; +import java.util.ArrayList; +import sbt.testing.TaskDef; + +public class TestLogInfo implements Serializable { + public final long id; + public final ForkTags tag; + public final String message; + + public TestLogInfo(long id, ForkTags tag, String message) { + this.id = id; + this.tag = tag; + this.message = message; + } +} diff --git a/worker/src/main/java/sbt/internal/worker1/WorkerError.java b/worker/src/main/java/sbt/internal/worker1/WorkerError.java new file mode 100644 index 000000000..13ddec30f --- /dev/null +++ b/worker/src/main/java/sbt/internal/worker1/WorkerError.java @@ -0,0 +1,13 @@ +package sbt.internal.worker1; + +import java.io.Serializable; + +public final class WorkerError implements Serializable { + public final int code; + public final String message; + + public WorkerError(int code, String message) { + this.code = code; + this.message = message; + } +} diff --git a/worker/src/main/java/sbt/internal/worker1/WorkerMain.java b/worker/src/main/java/sbt/internal/worker1/WorkerMain.java index 1f6ccc058..0533de423 100644 --- a/worker/src/main/java/sbt/internal/worker1/WorkerMain.java +++ b/worker/src/main/java/sbt/internal/worker1/WorkerMain.java @@ -9,11 +9,14 @@ package sbt.internal.worker1; import com.google.gson.Gson; +import com.google.gson.GsonBuilder; import com.google.gson.JsonElement; import com.google.gson.JsonObject; import com.google.gson.JsonParser; import com.google.gson.JsonPrimitive; +import com.google.gson.typeadapters.RuntimeTypeAdapterFactory; import java.io.ByteArrayOutputStream; +import java.io.InputStream; import java.io.IOException; import java.io.PrintStream; import java.lang.reflect.Method; @@ -22,15 +25,41 @@ import java.net.URL; import java.net.URLClassLoader; import java.util.ArrayList; import java.util.Scanner; +import sbt.testing.*; +/** + * WorkerMain that communicates via the stdin and stdout using JSON-RPC + * (https://www.jsonrpc.org/specification). + */ public final class WorkerMain { private PrintStream originalOut; + private InputStream originalIn; + private Scanner inScanner; + + public static Gson mkGson() { + RuntimeTypeAdapterFactory fingerprintFac = + RuntimeTypeAdapterFactory.of(Fingerprint.class, "type"); + fingerprintFac.registerSubtype(ForkTestMain.SubclassFingerscan.class, "SubclassFingerscan"); + fingerprintFac.registerSubtype(ForkTestMain.AnnotatedFingerscan.class, "AnnotatedFingerscan"); + RuntimeTypeAdapterFactory selectorFac = + RuntimeTypeAdapterFactory.of(Selector.class, "type"); + selectorFac.registerSubtype(SuiteSelector.class, "SuiteSelector"); + selectorFac.registerSubtype(TestSelector.class, "TestSelector"); + selectorFac.registerSubtype(NestedSuiteSelector.class, "NestedSuiteSelector"); + selectorFac.registerSubtype(NestedTestSelector.class, "NestedTestSelector"); + selectorFac.registerSubtype(TestWildcardSelector.class, "TestWildcardSelector"); + return new GsonBuilder() + .registerTypeAdapterFactory(fingerprintFac) + .registerTypeAdapterFactory(selectorFac) + .create(); + } public static void main(final String[] args) throws Exception { try { if (args.length == 0) { WorkerMain app = new WorkerMain(); app.consoleWork(); + System.exit(0); } else { System.err.println("missing args"); System.exit(1); @@ -45,31 +74,49 @@ public final class WorkerMain { this.originalOut = System.out; ByteArrayOutputStream baos = new ByteArrayOutputStream(); System.setOut(new PrintStream(baos)); + this.originalIn = System.in; + this.inScanner = new Scanner(this.originalIn); } void consoleWork() throws Exception { - Scanner input = new Scanner(System.in); - if (input.hasNextLine()) { - String line = input.nextLine(); + if (this.inScanner.hasNextLine()) { + String line = this.inScanner.nextLine(); process(line); } } - void process(String json) throws Exception { + /** This processes single request of supposed JSON line. */ + void process(String json) { JsonElement elem = JsonParser.parseString(json); JsonObject o = elem.getAsJsonObject(); if (!o.has("jsonrpc")) { - throw new RuntimeException("missing jsonprc element"); + return; } + Gson g = WorkerMain.mkGson(); long id = o.getAsJsonPrimitive("id").getAsLong(); - String method = o.getAsJsonPrimitive("method").getAsString(); - JsonObject params = o.getAsJsonObject("params"); - switch (method) { - case "run": - Gson g = new Gson(); - RunInfo info = g.fromJson(params, RunInfo.class); - run(info); - break; + try { + String method = o.getAsJsonPrimitive("method").getAsString(); + JsonObject params = o.getAsJsonObject("params"); + switch (method) { + case "run": + RunInfo info = g.fromJson(params, RunInfo.class); + run(info); + break; + case "test": + TestInfo testInfo = g.fromJson(params, TestInfo.class); + test(id, testInfo); + break; + } + String response = String.format("{ \"jsonrpc\": \"2.0\", \"result\": 0, \"id\": %d }", id); + this.originalOut.println(response); + this.originalOut.flush(); + } catch (Throwable e) { + WorkerError err = new WorkerError(1, e.getMessage()); + String errMessage = g.toJson(err, err.getClass()); + String errJson = + String.format("{ \"jsonrpc\": \"2.0\", \"error\": %s, \"id\": %d }", errMessage, id); + this.originalOut.println(errJson); + this.originalOut.flush(); } } @@ -79,20 +126,7 @@ public final class WorkerMain { throw new RuntimeException("missing jvmRunInfo element"); } RunInfo.JvmRunInfo jvmRunInfo = info.jvmRunInfo; - URL[] urls = - jvmRunInfo - .classpath - .stream() - .map( - filePath -> { - try { - return filePath.path.toURL(); - } catch (MalformedURLException e) { - throw new RuntimeException(e); - } - }) - .toArray(URL[]::new); - URLClassLoader cl = new URLClassLoader(urls, ClassLoader.getSystemClassLoader()); + URLClassLoader cl = createClassLoader(jvmRunInfo, ClassLoader.getSystemClassLoader()); try { Class mainClass = cl.loadClass(jvmRunInfo.mainClass); Method mainMethod = mainClass.getMethod("main", String[].class); @@ -105,4 +139,37 @@ public final class WorkerMain { throw new RuntimeException("only jvm is supported"); } } + + void test(long id, TestInfo info) throws Exception { + if (info.jvm) { + RunInfo.JvmRunInfo jvmRunInfo = info.jvmRunInfo; + ClassLoader parent = new ForkTestMain().getClass().getClassLoader(); + ClassLoader cl = createClassLoader(jvmRunInfo, parent); + try { + ForkTestMain.main(id, info, this.originalOut, cl); + } finally { + if (cl instanceof URLClassLoader) { + ((URLClassLoader) cl).close(); + } + } + } else { + throw new RuntimeException("only jvm is supported"); + } + } + + private URLClassLoader createClassLoader(RunInfo.JvmRunInfo info, ClassLoader parent) { + URL[] urls = + info.classpath + .stream() + .map( + filePath -> { + try { + return filePath.path.toURL(); + } catch (MalformedURLException e) { + throw new RuntimeException(e); + } + }) + .toArray(URL[]::new); + return new URLClassLoader(urls, parent); + } } diff --git a/worker/src/test/scala/sbt/internal/worker1/WorkerTest.scala b/worker/src/test/scala/sbt/internal/worker1/WorkerTest.scala index f516ab45c..8285f3997 100644 --- a/worker/src/test/scala/sbt/internal/worker1/WorkerTest.scala +++ b/worker/src/test/scala/sbt/internal/worker1/WorkerTest.scala @@ -8,7 +8,7 @@ object WorkerTest extends verify.BasicTestSuite: test("process") { val u0 = IO.classLocationPath(classOf[example.Hello]).toUri() val u1 = IO.classLocationPath(classOf[scala.quoted.Quotes]).toUri() - val u2 = IO.classLocationPath(classOf[scala.AnyVal]).toUri() + val u2 = IO.classLocationPath(classOf[scala.collection.immutable.List[?]]).toUri() val cp = s"""[{ "path": "${u0}", "digest": "" }, { "path": "${u1}", "digest": "" }, { "path": "${u2}", "digest": "" }]""" val runInfo =