diff --git a/build.sbt b/build.sbt
index c3ec08cc4..c90b04fe2 100644
--- a/build.sbt
+++ b/build.sbt
@@ -471,7 +471,7 @@ lazy val utilScripted = (project in file("internal") / "util-scripted")
// Runner for uniform test interface
lazy val testingProj = (project in file("testing"))
.enablePlugins(ContrabandPlugin, JsonCodecPlugin)
- .dependsOn(testAgentProj, utilLogging)
+ .dependsOn(workerProj, utilLogging)
.settings(
baseSettings,
name := "Testing",
@@ -518,26 +518,15 @@ lazy val testingProj = (project in file("testing"))
)
.configure(addSbtIO, addSbtCompilerClasspath)
-// Testing agent for running tests in a separate process.
-lazy val testAgentProj = (project in file("testing") / "agent")
- .settings(
- minimalSettings,
- crossPaths := false,
- autoScalaLibrary := false,
- Compile / doc / javacOptions := Nil,
- name := "Test Agent",
- libraryDependencies += testInterface,
- mimaSettings,
- )
-
lazy val workerProj = (project in file("worker"))
.dependsOn(exampleWorkProj % Test)
.settings(
name := "worker",
testedBaseSettings,
Compile / doc / javacOptions := Nil,
+ crossPaths := false,
autoScalaLibrary := false,
- libraryDependencies += gson,
+ libraryDependencies ++= Seq(gson, testInterface),
libraryDependencies += "org.scala-lang" %% "scala3-library" % scalaVersion.value % Test,
// run / fork := false,
Test / fork := true,
@@ -1218,7 +1207,6 @@ def allProjects =
logicProj,
completeProj,
testingProj,
- testAgentProj,
taskProj,
stdTaskProj,
runProj,
diff --git a/main-actions/src/main/scala/sbt/ForkTests.scala b/main-actions/src/main/scala/sbt/ForkTests.scala
index 677be9e07..7fdf27043 100755
--- a/main-actions/src/main/scala/sbt/ForkTests.scala
+++ b/main-actions/src/main/scala/sbt/ForkTests.scala
@@ -8,21 +8,33 @@
package sbt
-import scala.collection.mutable
+import com.google.gson.{ JsonObject, JsonParser }
import testing.{ Logger as _, Task as _, * }
-import scala.util.control.NonFatal
-import java.net.ServerSocket
import java.io.*
+import java.util.ArrayList
import Tests.{ Output as TestOutput, * }
-import sbt.io.IO
import sbt.util.Logger
import sbt.ConcurrentRestrictions.Tag
import sbt.protocol.testing.*
+import sbt.internal.{ WorkerExchange, WorkerResponseListener }
import sbt.internal.util.Util.*
-import sbt.internal.util.{ Terminal as UTerminal }
+import sbt.internal.util.{ MessageOnlyException, Terminal as UTerminal }
+import sbt.internal.worker1.*
import xsbti.{ FileConverter, HashedVirtualFileRef }
+import scala.collection.mutable
+import scala.concurrent.{ Await, Promise }
+import scala.concurrent.duration.Duration
+import scala.util.Random
+import scala.jdk.CollectionConverters.*
+import scala.sys.process.Process
+
+/**
+ * This implements forked testing, in cooperation with the worker CLI,
+ * which was previously called test-agent.jar.
+ */
+private[sbt] object ForkTests:
+ val r = Random()
-private[sbt] object ForkTests {
def apply(
runners: Map[TestFramework, Runner],
opts: ProcessedOptions,
@@ -87,147 +99,145 @@ private[sbt] object ForkTests {
parallel: Boolean
): Task[TestOutput] =
std.TaskExtra.task {
- val server = new ServerSocket(0)
- val testListeners = opts.testListeners flatMap {
+ val testListeners = opts.testListeners.flatMap:
case tl: TestsListener => tl.some
case _ => none[TestsListener]
- }
-
- object Acceptor extends Runnable {
- val resultsAcc = mutable.Map.empty[String, SuiteResult]
- lazy val result =
- TestOutput(overall(resultsAcc.values.map(_.result)), resultsAcc.toMap, Iterable.empty)
-
- def run(): Unit = {
- val socket =
- try {
- server.accept()
- } catch {
- case e: java.net.SocketException =>
- log.error(
- "Could not accept connection from test agent: " + e.getClass + ": " + e.getMessage
- )
- log.trace(e)
- server.close()
- return
- }
- val os = new ObjectOutputStream(socket.getOutputStream)
- // Must flush the header that the constructor writes, otherwise the ObjectInputStream on the other end may block indefinitely
- os.flush()
- val is = new ObjectInputStream(socket.getInputStream)
-
- try {
- val config = new ForkConfiguration(UTerminal.isAnsiSupported, parallel)
- os.writeObject(config)
-
- val taskdefs = opts.tests.map { t =>
- new TaskDef(
- t.name,
- forkFingerprint(t.fingerprint),
- t.explicitlySpecified,
- t.selectors
- )
- }
- os.writeObject(taskdefs.toArray)
-
- os.writeInt(runners.size)
- for ((testFramework, mainRunner) <- runners) {
- os.writeObject(testFramework.implClassNames.toArray)
- os.writeObject(mainRunner.args)
- os.writeObject(mainRunner.remoteArgs)
- }
- os.flush()
-
- new React(is, os, log, opts.testListeners, resultsAcc).react()
- } catch {
- case NonFatal(e) =>
- def throwableToString(t: Throwable) = {
- import java.io.*; val sw = new StringWriter; t.printStackTrace(new PrintWriter(sw));
- sw.toString
- }
- resultsAcc("Forked test harness failed: " + throwableToString(e)) = SuiteResult.Error
- } finally {
- is.close(); os.close(); socket.close()
- }
- }
- }
-
- try {
- testListeners.foreach(_.doInit())
- val acceptorThread = new Thread(Acceptor)
- acceptorThread.start()
- val cpFiles = classpath.map(converter.toPath).map(_.toFile())
- val fullCp = cpFiles ++ Seq(
- IO.classLocationPath(classOf[ForkMain]).toFile,
- IO.classLocationPath(classOf[Framework]).toFile,
+ val resultsAcc = mutable.Map.empty[String, SuiteResult]
+ val randomId = r.nextLong()
+ def testOutputResult =
+ TestOutput(
+ overall(resultsAcc.values.map(_.result)),
+ resultsAcc.toMap,
+ Iterable.empty
)
- val options = Seq(
- "-classpath",
- fullCp mkString File.pathSeparator,
- classOf[ForkMain].getCanonicalName,
- server.getLocalPort.toString
+ val taskdefs = opts.tests.map: t =>
+ new TaskDef(
+ t.name,
+ forkFingerprint(t.fingerprint),
+ t.explicitlySpecified,
+ t.selectors
)
- val ec = Fork.java(fork, options)
- val result =
- if (ec != 0)
- TestOutput(
- TestResult.Error,
- Map(
- "Running java with options " + options
- .mkString(" ") + " failed with exit code " + ec -> SuiteResult.Error
- ),
- Iterable.empty
- )
- else {
- // Need to wait acceptor thread to finish its business
- acceptorThread.join()
- Acceptor.result
- }
+ val testRunners = runners.toSeq.map: (testFramework, mainRunner) =>
+ TestInfo.TestRunner(
+ ArrayList(testFramework.implClassNames.asJava),
+ ArrayList(mainRunner.args().toList.asJava),
+ ArrayList(mainRunner.remoteArgs().toList.asJava)
+ )
+ val g = WorkerMain.mkGson()
+ // virtualize classloading by using ClassLoader
+ val useClassLoader = true
+ val cpList = ArrayList[FilePath](
+ (classpath
+ .map: vf =>
+ FilePath(converter.toPath(vf).toUri(), vf.contentHashStr()))
+ .asJava
+ )
+ val cpFiles =
+ classpath.map: vf =>
+ converter.toPath(vf).toFile()
+ val param = TestInfo(
+ true, /* jvm */
+ RunInfo.JvmRunInfo(
+ ArrayList(),
+ if useClassLoader then cpList else ArrayList(),
+ "",
+ false /*connectInput*/,
+ ),
+ null,
+ UTerminal.isAnsiSupported,
+ parallel,
+ ArrayList(taskdefs.asJava),
+ ArrayList(testRunners.asJava),
+ )
+ testListeners.foreach(_.doInit())
+ val result =
+ val w = WorkerExchange.startWorker(fork, if useClassLoader then Nil else cpFiles)
+ val wl = React(randomId, log, opts.testListeners, resultsAcc, w.process)
+ try
+ WorkerExchange.registerListener(wl)
+ val paramJson = g.toJson(param, param.getClass)
+ val json = jsonRpcRequest(randomId, "test", paramJson)
+ w.println(json)
+ if wl.blockForResponse() != 0 then
+ throw MessageOnlyException("Forked test harness failed")
+ testOutputResult
+ finally WorkerExchange.unregisterListener(wl)
+ testListeners.foreach(_.doComplete(result.overall))
+ result
+ } // end task
- testListeners.foreach(_.doComplete(result.overall))
- result
- } finally {
- server.close()
- }
- }
+ private def jsonRpcRequest(id: Long, method: String, params: String): String =
+ s"""{ "jsonrpc": "2.0", "method": "$method", "params": $params, "id": $id }"""
private def forkFingerprint(f: Fingerprint): Fingerprint & Serializable =
- f match {
- case s: SubclassFingerprint => new ForkMain.SubclassFingerscan(s)
- case a: AnnotatedFingerprint => new ForkMain.AnnotatedFingerscan(a)
+ f match
+ case s: SubclassFingerprint => ForkTestMain.SubclassFingerscan(s)
+ case a: AnnotatedFingerprint => ForkTestMain.AnnotatedFingerscan(a)
case _ => sys.error("Unknown fingerprint type: " + f.getClass)
- }
-}
-private final class React(
- is: ObjectInputStream,
- os: ObjectOutputStream,
+end ForkTests
+
+private class React(
+ id: Long,
log: Logger,
listeners: Seq[TestReportListener],
- results: mutable.Map[String, SuiteResult]
-) {
- import ForkTags.*
- @annotation.tailrec
- def react(): Unit = is.readObject match {
- case `Done` =>
- os.writeObject(Done); os.flush()
- case Array(`Error`, s: String) =>
- log.error(s); react()
- case Array(`Warn`, s: String) =>
- log.warn(s); react()
- case Array(`Info`, s: String) =>
- log.info(s); react()
- case Array(`Debug`, s: String) =>
- log.debug(s); react()
- case t: Throwable =>
- log.trace(t); react()
- case Array(group: String, tEvents: Array[Event]) =>
- val events = tEvents.toSeq
- listeners.foreach(_.startGroup(group))
- val event = TestEvent(events)
- listeners.foreach(_.testEvent(event))
- val suiteResult = SuiteResult(events)
- results += group -> suiteResult
- listeners.foreach(_.endGroup(group, suiteResult.result))
- react()
- }
-}
+ results: mutable.Map[String, SuiteResult],
+ process: Process
+) extends WorkerResponseListener:
+ val g = WorkerMain.mkGson()
+ val promise: Promise[Int] = Promise()
+ override def apply(line: String): Unit =
+ // scala.Console.err.println(line)
+ val o = JsonParser.parseString(line).getAsJsonObject()
+ if o.has("id") then
+ val resId = o.getAsJsonPrimitive("id").getAsLong()
+ if resId == id then
+ if promise.isCompleted then ()
+ else if o.has("error") then promise.failure(new RuntimeException(line))
+ else promise.success(0)
+ else ()
+ else if o.has("method") then processNotification(o)
+ else ()
+
+ override def notifyExit(p: Process): Unit =
+ if !process.isAlive then promise.success(process.exitValue())
+
+ def processNotification(o: JsonObject): Unit =
+ val method = o.getAsJsonPrimitive("method").getAsString()
+ method match
+ case "testLog" =>
+ val params = o.getAsJsonObject("params")
+ val info = g.fromJson[TestLogInfo](params, classOf[TestLogInfo])
+ if info.id == id then
+ info.tag match
+ case ForkTags.Error => log.error(info.message)
+ case ForkTags.Warn => log.warn(info.message)
+ case ForkTags.Info => log.info(info.message)
+ case ForkTags.Debug => log.debug(info.message)
+ case _ => ()
+ else ()
+ case "testEvents" =>
+ val params = o.getAsJsonObject("params")
+ val info =
+ g.fromJson[ForkTestMain.ForkEventsInfo](params, classOf[ForkTestMain.ForkEventsInfo])
+ if info.id == id then
+ val events = info.events.asScala.toSeq
+ listeners.foreach(_.startGroup(info.group))
+ val event = TestEvent(events)
+ listeners.foreach(_.testEvent(event))
+ val suiteResult = SuiteResult(events)
+ results += info.group -> suiteResult
+ listeners.foreach(_.endGroup(info.group, suiteResult.result))
+ else ()
+ case "forkError" =>
+ val params = o.getAsJsonObject("params")
+ val info =
+ g.fromJson[ForkTestMain.ForkErrorInfo](params, classOf[ForkTestMain.ForkErrorInfo])
+ if info.id == id then
+ log.trace(info.error)
+ promise.failure(info.error)
+ else ()
+ case _ => ()
+
+ def blockForResponse(): Int =
+ Await.result(promise.future, Duration.Inf)
+end React
diff --git a/main-actions/src/main/scala/sbt/internal/WorkerExchange.scala b/main-actions/src/main/scala/sbt/internal/WorkerExchange.scala
new file mode 100644
index 000000000..e57e57783
--- /dev/null
+++ b/main-actions/src/main/scala/sbt/internal/WorkerExchange.scala
@@ -0,0 +1,88 @@
+/*
+ * sbt
+ * Copyright 2023, Scala center
+ * Copyright 2011 - 2022, Lightbend, Inc.
+ * Copyright 2008 - 2010, Mark Harrah
+ * Licensed under Apache License 2.0 (see LICENSE)
+ */
+
+package sbt
+package internal
+
+import com.google.gson.Gson
+import java.io.*
+import java.util.concurrent.atomic.AtomicReference
+import sbt.io.IO
+import sbt.internal.worker1.*
+import sbt.testing.Framework
+import scala.sys.process.{ BasicIO, Process, ProcessIO }
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
+
+object WorkerExchange:
+ val listeners: mutable.ListBuffer[WorkerResponseListener] = ListBuffer.empty
+
+ /**
+ * Start a worker process.
+ */
+ def startWorker(fo: ForkOptions, extraCp: Seq[File]): WorkerProxy =
+ val fullCp = Seq(
+ IO.classLocationPath(classOf[WorkerMain]).toFile,
+ IO.classLocationPath(classOf[Framework]).toFile,
+ IO.classLocationPath(classOf[Gson]).toFile,
+ ) ++ extraCp
+ val options = Seq(
+ "-classpath",
+ fullCp.mkString(File.pathSeparator),
+ classOf[WorkerMain].getCanonicalName,
+ )
+ val inputRef = AtomicReference[OutputStream]()
+ val processIo = ProcessIO(
+ in = (input) => inputRef.set(input),
+ out = BasicIO.processFully(onStdoutLine),
+ err = BasicIO.processFully((line) => scala.Console.err.println(line)),
+ )
+ val forkWithIo = fo.withOutputStrategy(OutputStrategy.CustomInputOutput(processIo))
+ val p = Fork.java.fork(forkWithIo, options)
+ WorkerProxy(inputRef.get(), p, options)
+
+ def registerListener(listener: WorkerResponseListener): Unit =
+ synchronized:
+ listeners.append(listener)
+
+ def unregisterListener(listener: WorkerResponseListener): Unit =
+ synchronized:
+ if listeners.contains(listener) then listeners.remove(listeners.indexOf(listener))
+ else ()
+
+ /**
+ * Unified worker output handler.
+ */
+ def onStdoutLine(line: String): Unit =
+ synchronized:
+ listeners.foreach: wl =>
+ wl(line)
+end WorkerExchange
+
+class WorkerProxy(input: OutputStream, val process: Process, val options: Seq[String])
+ extends AutoCloseable:
+ lazy val inputStream = PrintStream(input)
+ def close(): Unit = input.close()
+ def blockForExitCode(): Int =
+ if !process.isAlive then process.exitValue()
+ else Fork.blockForExitCode(process)
+
+ /** print a line into stdin of the worker process. */
+ def println(str: String): Unit =
+ inputStream.println(str)
+ inputStream.flush()
+
+ val watch = Thread(() => {
+ while process.isAlive() do Thread.sleep(100)
+ WorkerExchange.listeners.foreach(_.notifyExit(process))
+ })
+ watch.start()
+end WorkerProxy
+
+abstract class WorkerResponseListener extends Function1[String, Unit]:
+ def notifyExit(p: Process): Unit
diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala
index 5545af24e..a5368b335 100644
--- a/main/src/main/scala/sbt/Defaults.scala
+++ b/main/src/main/scala/sbt/Defaults.scala
@@ -1493,9 +1493,8 @@ object Defaults extends BuildCommon {
}
}
val summaries =
- runners map { (tf, r) =>
+ runners.map: (tf, r) =>
Tests.Summary(frameworks(tf).name, r.done())
- }
out.copy(summaries = summaries)
}
// Def.value[Task[Tests.Output]] {
diff --git a/run/src/main/scala/sbt/Fork.scala b/run/src/main/scala/sbt/Fork.scala
index dc53237c3..8b75e4fef 100644
--- a/run/src/main/scala/sbt/Fork.scala
+++ b/run/src/main/scala/sbt/Fork.scala
@@ -185,13 +185,13 @@ object Fork {
()
}
val process = Process(jpb)
- outputStrategy.getOrElse(StdoutOutput: OutputStrategy) match {
+ outputStrategy.getOrElse(StdoutOutput: OutputStrategy) match
case StdoutOutput => process.run(connectInput = false)
case out: BufferedOutput =>
out.logger.buffer { process.run(out.logger, connectInput = false) }
- case out: LoggedOutput => process.run(out.logger, connectInput = false)
- case out: CustomOutput => (process #> out.output).run(connectInput = false)
- }
+ case out: LoggedOutput => process.run(out.logger, connectInput = false)
+ case out: CustomOutput => (process #> out.output).run(connectInput = false)
+ case out: CustomInputOutput => process.run(out.processIO)
}
private[sbt] def blockForExitCode(p: Process): Int = {
diff --git a/run/src/main/scala/sbt/OutputStrategy.scala b/run/src/main/scala/sbt/OutputStrategy.scala
index 26b7e710b..3066b5a13 100644
--- a/run/src/main/scala/sbt/OutputStrategy.scala
+++ b/run/src/main/scala/sbt/OutputStrategy.scala
@@ -8,6 +8,7 @@
package sbt
+import scala.sys.process.ProcessIO
import sbt.util.Logger
import java.io.OutputStream
@@ -102,4 +103,28 @@ object OutputStrategy {
object CustomOutput {
def apply(output: OutputStream): CustomOutput = new CustomOutput(output)
}
+
+ /**
+ * Configures the forked IO.
+ */
+ final class CustomInputOutput private (val processIO: ProcessIO)
+ extends OutputStrategy
+ with Serializable:
+ override def equals(o: Any): Boolean = o match
+ case x: CustomInputOutput => (this.processIO == x.processIO)
+ case _ => false
+ override def hashCode: Int =
+ 37 * (17 + processIO.##) + "CustomInputOutput".##
+ override def toString: String =
+ "CustomInputOutput(...)"
+ private def copy(processIO: ProcessIO = processIO): CustomInputOutput =
+ new CustomInputOutput(processIO)
+
+ def withProcessIO(processIO: ProcessIO): CustomInputOutput =
+ copy(processIO = processIO)
+ end CustomInputOutput
+
+ object CustomInputOutput:
+ def apply(processIO: ProcessIO): CustomInputOutput = new CustomInputOutput(processIO)
+ end CustomInputOutput
}
diff --git a/sbt-app/src/sbt-test/dependency-management/update-sbt-classifiers/build.sbt b/sbt-app/src/sbt-test/dependency-management/update-sbt-classifiers/build.sbt
index a43dedc9c..2bbba6379 100644
--- a/sbt-app/src/sbt-test/dependency-management/update-sbt-classifiers/build.sbt
+++ b/sbt-app/src/sbt-test/dependency-management/update-sbt-classifiers/build.sbt
@@ -33,6 +33,7 @@ lazy val root = (project in file("."))
"com.eed3si9n:sjson-new-scalajson_3",
"com.github.ben-manes.caffeine:caffeine",
"com.github.mwiede:jsch",
+ "com.google.code.gson:gson",
"com.google.errorprone:error_prone_annotations",
"com.lmax:disruptor",
"com.swoval:file-tree-views",
@@ -82,7 +83,7 @@ lazy val root = (project in file("."))
)
def assertCollectionsEqual(message: String, expected: Seq[String], actual: Seq[String]): Unit =
// using the new line for a more readable comparison failure output
- assert(expected.mkString("\n") == actual.mkString("\n"), message)
+ assert(expected.mkString("\n") == actual.mkString("\n"), message + ": " + actual)
assertCollectionsEqual(
"Unexpected module ids in updateSbtClassifiers",
diff --git a/sbt-app/src/sbt-test/tests/fork/build.sbt b/sbt-app/src/sbt-test/tests/fork/build.sbt
index 04553d69d..635492a86 100644
--- a/sbt-app/src/sbt-test/tests/fork/build.sbt
+++ b/sbt-app/src/sbt-test/tests/fork/build.sbt
@@ -11,6 +11,7 @@ val scalaxml = "org.scala-lang.modules" %% "scala-xml" % "1.1.1"
def groupId(idx: Int) = "group_" + (idx + 1)
def groupPrefix(idx: Int) = groupId(idx) + "_file_"
+Global / localCacheDirectory := baseDirectory.value / "diskcache"
ThisBuild / scalaVersion := "2.12.20"
ThisBuild / organization := "org.example"
@@ -19,7 +20,7 @@ lazy val root = (project in file("."))
Test / testGrouping := Def.uncached {
val tests = (Test / definedTests).value
assert(tests.size == 3)
- for (idx <- 0 until groups) yield
+ for idx <- 0 until groups yield
new Group(
groupId(idx),
tests,
@@ -28,11 +29,11 @@ lazy val root = (project in file("."))
},
check := Def.uncached {
val files =
- for(i <- 0 until groups; j <- 1 to groupSize) yield
+ for i <- 0 until groups; j <- 1 to groupSize yield
file(groupPrefix(i) + j)
val (exist, absent) = files.partition(_.exists)
exist.foreach(_.delete())
- if (absent.nonEmpty)
+ if absent.nonEmpty then
sys.error("Files were not created:\n\t" + absent.mkString("\n\t"))
},
concurrentRestrictions := Tags.limit(Tags.ForkedTestGroup, 2) :: Nil,
diff --git a/sbt-app/src/sbt-test/tests/fork2/changes/Test.scala b/sbt-app/src/sbt-test/tests/fork2/changes/Test.scala
index f2d45c34f..f45303e11 100644
--- a/sbt-app/src/sbt-test/tests/fork2/changes/Test.scala
+++ b/sbt-app/src/sbt-test/tests/fork2/changes/Test.scala
@@ -1,8 +1,8 @@
import org.scalatest.FlatSpec
class Test extends FlatSpec {
- val v = sys.env.getOrElse("tests.max.value", Int.MaxValue)
- "A simple equation" should "hold" in {
- assert(Int.MaxValue == v)
- }
+ val v = sys.env.getOrElse("tests.max.value", Int.MaxValue)
+ "A simple equation" should "hold" in {
+ assert(Int.MaxValue == v)
+ }
}
diff --git a/testing/agent/src/main/java/sbt/ForkConfiguration.java b/testing/agent/src/main/java/sbt/ForkConfiguration.java
deleted file mode 100644
index 993fe6908..000000000
--- a/testing/agent/src/main/java/sbt/ForkConfiguration.java
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * sbt
- * Copyright 2023, Scala center
- * Copyright 2011 - 2022, Lightbend, Inc.
- * Copyright 2008 - 2010, Mark Harrah
- * Licensed under Apache License 2.0 (see LICENSE)
- */
-
-package sbt;
-
-import java.io.Serializable;
-
-public final class ForkConfiguration implements Serializable {
- private final boolean ansiCodesSupported;
- private final boolean parallel;
-
- public ForkConfiguration(final boolean ansiCodesSupported, final boolean parallel) {
- this.ansiCodesSupported = ansiCodesSupported;
- this.parallel = parallel;
- }
-
- public boolean isAnsiCodesSupported() {
- return ansiCodesSupported;
- }
-
- public boolean isParallel() {
- return parallel;
- }
-}
diff --git a/testing/src/main/scala/sbt/TestFramework.scala b/testing/src/main/scala/sbt/TestFramework.scala
index 6ef27739c..3521230c0 100644
--- a/testing/src/main/scala/sbt/TestFramework.scala
+++ b/testing/src/main/scala/sbt/TestFramework.scala
@@ -67,8 +67,9 @@ final class TestFramework(val implClassNames: String*) extends Serializable {
case head :: tail =>
try {
Some(Class.forName(head, true, loader).getDeclaredConstructor().newInstance() match {
- case newFramework: Framework => newFramework
- case oldFramework: OldFramework => new FrameworkWrapper(oldFramework)
+ case newFramework: Framework => newFramework
+ case oldFramework: OldFramework =>
+ new sbt.internal.worker1.FrameworkWrapper(oldFramework)
})
} catch {
case e: NoClassDefFoundError =>
diff --git a/worker/src/main/java/com/google/gson/typeadapters/RuntimeTypeAdapterFactory.java b/worker/src/main/java/com/google/gson/typeadapters/RuntimeTypeAdapterFactory.java
new file mode 100644
index 000000000..00e9b3dd4
--- /dev/null
+++ b/worker/src/main/java/com/google/gson/typeadapters/RuntimeTypeAdapterFactory.java
@@ -0,0 +1,344 @@
+/*
+ * Copyright (C) 2011 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.gson.typeadapters;
+
+import com.google.errorprone.annotations.CanIgnoreReturnValue;
+import com.google.gson.Gson;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParseException;
+import com.google.gson.JsonPrimitive;
+import com.google.gson.TypeAdapter;
+import com.google.gson.TypeAdapterFactory;
+import com.google.gson.reflect.TypeToken;
+import com.google.gson.stream.JsonReader;
+import com.google.gson.stream.JsonWriter;
+import java.io.IOException;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+/**
+ * Adapts values whose runtime type may differ from their declaration type. This is necessary when a
+ * field's type is not the same type that GSON should create when deserializing that field. For
+ * example, consider these types:
+ *
+ *
+ * {
+ * @code
+ * abstract class Shape {
+ * int x;
+ * int y;
+ * }
+ * class Circle extends Shape {
+ * int radius;
+ * }
+ * class Rectangle extends Shape {
+ * int width;
+ * int height;
+ * }
+ * class Diamond extends Shape {
+ * int width;
+ * int height;
+ * }
+ * class Drawing {
+ * Shape bottomShape;
+ * Shape topShape;
+ * }
+ * }
+ *
+ *
+ * Without additional type information, the serialized JSON is ambiguous. Is the bottom shape in
+ * this drawing a rectangle or a diamond?
+ *
+ *
{@code
+ * {
+ * "bottomShape": {
+ * "width": 10,
+ * "height": 5,
+ * "x": 0,
+ * "y": 0
+ * },
+ * "topShape": {
+ * "radius": 2,
+ * "x": 4,
+ * "y": 1
+ * }
+ * }
+ * }
+ *
+ * This class addresses this problem by adding type information to the serialized JSON and honoring
+ * that type information when the JSON is deserialized:
+ *
+ * {@code
+ * {
+ * "bottomShape": {
+ * "type": "Diamond",
+ * "width": 10,
+ * "height": 5,
+ * "x": 0,
+ * "y": 0
+ * },
+ * "topShape": {
+ * "type": "Circle",
+ * "radius": 2,
+ * "x": 4,
+ * "y": 1
+ * }
+ * }
+ * }
+ *
+ * Both the type field name ({@code "type"}) and the type labels ({@code "Rectangle"}) are
+ * configurable.
+ *
+ * Registering Types
+ *
+ * Create a {@code RuntimeTypeAdapterFactory} by passing the base type and type field name to the
+ * {@link #of} factory method. If you don't supply an explicit type field name, {@code "type"} will
+ * be used.
+ *
+ *
+ * {
+ * @code
+ * RuntimeTypeAdapterFactory<Shape> shapeAdapterFactory = RuntimeTypeAdapterFactory.of(Shape.class, "type");
+ * }
+ *
+ *
+ * Next register all of your subtypes. Every subtype must be explicitly registered. This protects
+ * your application from injection attacks. If you don't supply an explicit type label, the type's
+ * simple name will be used.
+ *
+ * {@code
+ * shapeAdapterFactory.registerSubtype(Rectangle.class, "Rectangle");
+ * shapeAdapterFactory.registerSubtype(Circle.class, "Circle");
+ * shapeAdapterFactory.registerSubtype(Diamond.class, "Diamond");
+ * }
+ *
+ * Finally, register the type adapter factory in your application's GSON builder:
+ *
+ *
+ * {
+ * @code
+ * Gson gson = new GsonBuilder().registerTypeAdapterFactory(shapeAdapterFactory).create();
+ * }
+ *
+ *
+ * Like {@code GsonBuilder}, this API supports chaining:
+ *
+ *
+ * {
+ * @code
+ * RuntimeTypeAdapterFactory<Shape> shapeAdapterFactory = RuntimeTypeAdapterFactory.of(Shape.class)
+ * .registerSubtype(Rectangle.class).registerSubtype(Circle.class).registerSubtype(Diamond.class);
+ * }
+ *
+ *
+ * Serialization and deserialization
+ *
+ * In order to serialize and deserialize a polymorphic object, you must specify the base type
+ * explicitly.
+ *
+ *
+ * {
+ * @code
+ * Diamond diamond = new Diamond();
+ * String json = gson.toJson(diamond, Shape.class);
+ * }
+ *
+ *
+ * And then:
+ *
+ *
+ * {
+ * @code
+ * Shape shape = gson.fromJson(json, Shape.class);
+ * }
+ *
+ */
+public final class RuntimeTypeAdapterFactory implements TypeAdapterFactory {
+ private final Class> baseType;
+ private final String typeFieldName;
+ private final Map> labelToSubtype = new LinkedHashMap<>();
+ private final Map, String> subtypeToLabel = new LinkedHashMap<>();
+ private final boolean maintainType;
+ private boolean recognizeSubtypes;
+
+ private RuntimeTypeAdapterFactory(Class> baseType, String typeFieldName, boolean maintainType) {
+ if (typeFieldName == null || baseType == null) {
+ throw new NullPointerException();
+ }
+ this.baseType = baseType;
+ this.typeFieldName = typeFieldName;
+ this.maintainType = maintainType;
+ }
+
+ /**
+ * Creates a new runtime type adapter for {@code baseType} using {@code typeFieldName} as the type
+ * field name. Type field names are case sensitive.
+ *
+ * @param maintainType true if the type field should be included in deserialized objects
+ */
+ public static RuntimeTypeAdapterFactory of(
+ Class baseType, String typeFieldName, boolean maintainType) {
+ return new RuntimeTypeAdapterFactory<>(baseType, typeFieldName, maintainType);
+ }
+
+ /**
+ * Creates a new runtime type adapter for {@code baseType} using {@code typeFieldName} as the type
+ * field name. Type field names are case sensitive.
+ */
+ public static RuntimeTypeAdapterFactory of(Class baseType, String typeFieldName) {
+ return new RuntimeTypeAdapterFactory<>(baseType, typeFieldName, false);
+ }
+
+ /**
+ * Creates a new runtime type adapter for {@code baseType} using {@code "type"} as the type field
+ * name.
+ */
+ public static RuntimeTypeAdapterFactory of(Class baseType) {
+ return new RuntimeTypeAdapterFactory<>(baseType, "type", false);
+ }
+
+ /**
+ * Ensures that this factory will handle not just the given {@code baseType}, but any subtype of
+ * that type.
+ */
+ @CanIgnoreReturnValue
+ public RuntimeTypeAdapterFactory recognizeSubtypes() {
+ this.recognizeSubtypes = true;
+ return this;
+ }
+
+ /**
+ * Registers {@code type} identified by {@code label}. Labels are case sensitive.
+ *
+ * @throws IllegalArgumentException if either {@code type} or {@code label} have already been
+ * registered on this type adapter.
+ */
+ @CanIgnoreReturnValue
+ public RuntimeTypeAdapterFactory registerSubtype(Class extends T> type, String label) {
+ if (type == null || label == null) {
+ throw new NullPointerException();
+ }
+ if (subtypeToLabel.containsKey(type) || labelToSubtype.containsKey(label)) {
+ throw new IllegalArgumentException("types and labels must be unique");
+ }
+ labelToSubtype.put(label, type);
+ subtypeToLabel.put(type, label);
+ return this;
+ }
+
+ /**
+ * Registers {@code type} identified by its {@link Class#getSimpleName simple name}. Labels are
+ * case sensitive.
+ *
+ * @throws IllegalArgumentException if either {@code type} or its simple name have already been
+ * registered on this type adapter.
+ */
+ @CanIgnoreReturnValue
+ public RuntimeTypeAdapterFactory registerSubtype(Class extends T> type) {
+ return registerSubtype(type, type.getSimpleName());
+ }
+
+ @Override
+ public TypeAdapter create(Gson gson, TypeToken type) {
+ if (type == null) {
+ return null;
+ }
+ Class> rawType = type.getRawType();
+ boolean handle =
+ recognizeSubtypes ? baseType.isAssignableFrom(rawType) : baseType.equals(rawType);
+ if (!handle) {
+ return null;
+ }
+
+ TypeAdapter jsonElementAdapter = gson.getAdapter(JsonElement.class);
+ Map> labelToDelegate = new LinkedHashMap<>();
+ Map, TypeAdapter>> subtypeToDelegate = new LinkedHashMap<>();
+ for (Map.Entry> entry : labelToSubtype.entrySet()) {
+ TypeAdapter> delegate = gson.getDelegateAdapter(this, TypeToken.get(entry.getValue()));
+ labelToDelegate.put(entry.getKey(), delegate);
+ subtypeToDelegate.put(entry.getValue(), delegate);
+ }
+
+ return new TypeAdapter() {
+ @Override
+ public R read(JsonReader in) throws IOException {
+ JsonElement jsonElement = jsonElementAdapter.read(in);
+ JsonElement labelJsonElement;
+ if (maintainType) {
+ labelJsonElement = jsonElement.getAsJsonObject().get(typeFieldName);
+ } else {
+ labelJsonElement = jsonElement.getAsJsonObject().remove(typeFieldName);
+ }
+
+ if (labelJsonElement == null) {
+ throw new JsonParseException(
+ "cannot deserialize "
+ + baseType
+ + " because it does not define a field named "
+ + typeFieldName);
+ }
+ String label = labelJsonElement.getAsString();
+ @SuppressWarnings("unchecked") // registration requires that subtype extends T
+ TypeAdapter delegate = (TypeAdapter) labelToDelegate.get(label);
+ if (delegate == null) {
+ throw new JsonParseException(
+ "cannot deserialize "
+ + baseType
+ + " subtype named "
+ + label
+ + "; did you forget to register a subtype?");
+ }
+ return delegate.fromJsonTree(jsonElement);
+ }
+
+ @Override
+ public void write(JsonWriter out, R value) throws IOException {
+ Class> srcType = value.getClass();
+ String label = subtypeToLabel.get(srcType);
+ @SuppressWarnings("unchecked") // registration requires that subtype extends T
+ TypeAdapter delegate = (TypeAdapter) subtypeToDelegate.get(srcType);
+ if (delegate == null) {
+ throw new JsonParseException(
+ "cannot serialize " + srcType.getName() + "; did you forget to register a subtype?");
+ }
+ JsonObject jsonObject = delegate.toJsonTree(value).getAsJsonObject();
+
+ if (maintainType) {
+ jsonElementAdapter.write(out, jsonObject);
+ return;
+ }
+
+ JsonObject clone = new JsonObject();
+
+ if (jsonObject.has(typeFieldName)) {
+ throw new JsonParseException(
+ "cannot serialize "
+ + srcType.getName()
+ + " because it already defines a field named "
+ + typeFieldName);
+ }
+ clone.add(typeFieldName, new JsonPrimitive(label));
+
+ for (Map.Entry e : jsonObject.entrySet()) {
+ clone.add(e.getKey(), e.getValue());
+ }
+ jsonElementAdapter.write(out, clone);
+ }
+ }.nullSafe();
+ }
+}
diff --git a/testing/agent/src/main/java/sbt/ForkTags.java b/worker/src/main/java/sbt/internal/worker1/ForkTags.java
similarity index 89%
rename from testing/agent/src/main/java/sbt/ForkTags.java
rename to worker/src/main/java/sbt/internal/worker1/ForkTags.java
index 1b0660119..d5daabec2 100644
--- a/testing/agent/src/main/java/sbt/ForkTags.java
+++ b/worker/src/main/java/sbt/internal/worker1/ForkTags.java
@@ -6,7 +6,7 @@
* Licensed under Apache License 2.0 (see LICENSE)
*/
-package sbt;
+package sbt.internal.worker1;
public enum ForkTags {
Error,
diff --git a/testing/agent/src/main/java/sbt/ForkMain.java b/worker/src/main/java/sbt/internal/worker1/ForkTestMain.java
similarity index 66%
rename from testing/agent/src/main/java/sbt/ForkMain.java
rename to worker/src/main/java/sbt/internal/worker1/ForkTestMain.java
index fdc79a7d8..607cf3e0d 100644
--- a/testing/agent/src/main/java/sbt/ForkMain.java
+++ b/worker/src/main/java/sbt/internal/worker1/ForkTestMain.java
@@ -6,16 +6,15 @@
* Licensed under Apache License 2.0 (see LICENSE)
*/
-package sbt;
+package sbt.internal.worker1;
+
+import com.google.gson.Gson;
import sbt.testing.*;
import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
+import java.io.PrintStream;
import java.io.Serializable;
-import java.net.Socket;
-import java.net.InetAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@@ -23,7 +22,7 @@ import java.util.List;
import java.util.LinkedHashSet;
import java.util.concurrent.*;
-public final class ForkMain {
+public class ForkTestMain {
// serializables
// -----------------------------------------------------------------------------
@@ -70,7 +69,7 @@ public final class ForkMain {
}
}
- static final class ForkEvent implements Event, Serializable {
+ public static final class ForkEvent implements Event, Serializable {
private final String fullyQualifiedName;
private final Fingerprint fingerprint;
private final Selector selector;
@@ -79,23 +78,23 @@ public final class ForkMain {
private final long duration;
ForkEvent(final Event e) {
- fullyQualifiedName = e.fullyQualifiedName();
+ this.fullyQualifiedName = e.fullyQualifiedName();
final Fingerprint rawFingerprint = e.fingerprint();
if (rawFingerprint instanceof SubclassFingerprint)
- fingerprint = new SubclassFingerscan((SubclassFingerprint) rawFingerprint);
- else fingerprint = new AnnotatedFingerscan((AnnotatedFingerprint) rawFingerprint);
+ this.fingerprint = new SubclassFingerscan((SubclassFingerprint) rawFingerprint);
+ else this.fingerprint = new AnnotatedFingerscan((AnnotatedFingerprint) rawFingerprint);
- selector = e.selector();
+ this.selector = e.selector();
checkSerializableSelector(selector);
- status = e.status();
+ this.status = e.status();
final OptionalThrowable originalThrowable = e.throwable();
if (originalThrowable.isDefined())
- throwable = new OptionalThrowable(new ForkError(originalThrowable.get()));
- else throwable = originalThrowable;
+ this.throwable = new OptionalThrowable(new ForkError(originalThrowable.get()));
+ else this.throwable = originalThrowable;
- duration = e.duration();
+ this.duration = e.duration();
}
public String fullyQualifiedName() {
@@ -132,18 +131,30 @@ public final class ForkMain {
}
}
+ public static class ForkEventsInfo implements Serializable {
+ public long id;
+ public String group;
+ public ArrayList events;
+
+ public ForkEventsInfo(long id, String group, ArrayList events) {
+ this.id = id;
+ this.group = group;
+ this.events = events;
+ }
+ }
+
// -----------------------------------------------------------------------------
- static final class ForkError extends Exception {
+ public static final class ForkError extends Exception {
private final String originalMessage;
private final String originalName;
- private ForkError cause;
+ private ForkError cause1;
ForkError(final Throwable t) {
originalMessage = t.getMessage();
originalName = t.getClass().getName();
setStackTrace(t.getStackTrace());
- if (t.getCause() != null) cause = new ForkError(t.getCause());
+ if (t.getCause() != null) cause1 = new ForkError(t.getCause());
}
public String getMessage() {
@@ -151,51 +162,50 @@ public final class ForkMain {
}
public Exception getCause() {
- return cause;
+ return cause1;
+ }
+ }
+
+ public static class ForkErrorInfo implements Serializable {
+ public final long id;
+ public final ForkError error;
+
+ public ForkErrorInfo(long id, ForkError error) {
+ this.id = id;
+ this.error = error;
}
}
// main
// ----------------------------------------------------------------------------------------------------------------
- public static void main(final String[] args) throws Exception {
- ClassLoader classLoader = new Run().getClass().getClassLoader();
- try {
- main(args, classLoader);
- } finally {
- System.exit(0);
- }
- }
-
- public static void main(final String[] args, ClassLoader classLoader) throws Exception {
- final Socket socket = new Socket(InetAddress.getByName(null), Integer.valueOf(args[0]));
- final ObjectInputStream is = new ObjectInputStream(socket.getInputStream());
- final ObjectOutputStream os = new ObjectOutputStream(socket.getOutputStream());
- // Must flush the header that the constructor writes, otherwise the ObjectInputStream on the
- // other end may block indefinitely
- os.flush();
- try {
- new Run().run(is, os, classLoader);
- } finally {
- is.close();
- os.close();
- }
+ public static void main(long id, TestInfo info, PrintStream originalOut, ClassLoader classLoader)
+ throws Exception {
+ new Run(originalOut, id).run(info, classLoader);
}
// ----------------------------------------------------------------------------------------------------------------
- private static final class Run {
+ public static final class Run {
+ final PrintStream originalOut;
+ final long id;
+ final Gson gson;
- private void run(
- final ObjectInputStream is, final ObjectOutputStream os, ClassLoader classLoader) {
+ Run(PrintStream originalOut, long id) {
+ this.originalOut = originalOut;
+ this.id = id;
+ this.gson = WorkerMain.mkGson();
+ }
+
+ private void run(TestInfo info, ClassLoader classLoader) {
try {
- runTests(is, os, classLoader);
+ runTests(info, classLoader);
} catch (final RunAborted e) {
internalError(e);
} catch (final Throwable t) {
try {
- logError(os, "Uncaught exception when running tests: " + t.toString());
- write(os, new ForkError(t));
+ logError("Uncaught exception when running tests: " + t.toString());
+ writeError(new ForkError(t));
} catch (final Throwable t2) {
internalError(t2);
}
@@ -223,106 +233,118 @@ public final class ForkMain {
}
}
- private synchronized void write(final ObjectOutputStream os, final Object obj) {
- try {
- os.writeObject(obj);
- os.flush();
- } catch (final IOException e) {
- throw new RunAborted(e);
- }
+ private void writeError(ForkError error) {
+ ForkErrorInfo info = new ForkErrorInfo(this.id, error);
+ String params = this.gson.toJson(info, ForkErrorInfo.class);
+ String notification =
+ String.format(
+ "{ \"jsonrpc\": \"2.0\", \"method\": \"forkError\", \"params\": %s }", params);
+ this.originalOut.println(notification);
+ this.originalOut.flush();
}
- private void log(final ObjectOutputStream os, final String message, final ForkTags level) {
- write(os, new Object[] {level, message});
+ private void log(final String message, final ForkTags level) {
+ TestLogInfo info = new TestLogInfo(this.id, level, message);
+ String params = this.gson.toJson(info, TestLogInfo.class);
+ String notification =
+ String.format(
+ "{ \"jsonrpc\": \"2.0\", \"method\": \"testLog\", \"params\": %s }", params);
+ this.originalOut.println(notification);
+ this.originalOut.flush();
}
- private void logDebug(final ObjectOutputStream os, final String message) {
- log(os, message, ForkTags.Debug);
+ private void logDebug(final String message) {
+ log(message, ForkTags.Debug);
}
- private void logInfo(final ObjectOutputStream os, final String message) {
- log(os, message, ForkTags.Info);
+ private void logInfo(final String message) {
+ log(message, ForkTags.Info);
}
- private void logWarn(final ObjectOutputStream os, final String message) {
- log(os, message, ForkTags.Warn);
+ private void logWarn(final String message) {
+ log(message, ForkTags.Warn);
}
- private void logError(final ObjectOutputStream os, final String message) {
- log(os, message, ForkTags.Error);
+ private void logError(final String message) {
+ log(message, ForkTags.Error);
}
- private Logger remoteLogger(final boolean ansiCodesSupported, final ObjectOutputStream os) {
+ private Logger remoteLogger(final boolean ansiCodesSupported) {
return new Logger() {
public boolean ansiCodesSupported() {
return ansiCodesSupported;
}
public void error(final String s) {
- logError(os, s);
+ logError(s);
}
public void warn(final String s) {
- logWarn(os, s);
+ logWarn(s);
}
public void info(final String s) {
- logInfo(os, s);
+ logInfo(s);
}
public void debug(final String s) {
- logDebug(os, s);
+ logDebug(s);
}
public void trace(final Throwable t) {
- write(os, new ForkError(t));
+ writeError(new ForkError(t));
}
};
}
- private void writeEvents(
- final ObjectOutputStream os, final TaskDef taskDef, final ForkEvent[] events) {
- write(os, new Object[] {taskDef.fullyQualifiedName(), events});
+ private void writeEvents(final TaskDef taskDef, final ForkEvent[] events) {
+ ForkEventsInfo info =
+ new ForkEventsInfo(
+ this.id,
+ taskDef.fullyQualifiedName(),
+ new ArrayList(Arrays.asList(events)));
+ String params = this.gson.toJson(info, ForkEventsInfo.class);
+ String notification =
+ String.format(
+ "{ \"jsonrpc\": \"2.0\", \"method\": \"testEvents\", \"params\": %s }", params);
+ this.originalOut.println(notification);
+ this.originalOut.flush();
}
- private ExecutorService executorService(
- final ForkConfiguration config, final ObjectOutputStream os) {
- if (config.isParallel()) {
+ private ExecutorService executorService(final boolean parallel) {
+ if (parallel) {
final int nbThreads = Runtime.getRuntime().availableProcessors();
- logDebug(os, "Create a test executor with a thread pool of " + nbThreads + " threads.");
+ logDebug("Create a test executor with a thread pool of " + nbThreads + " threads.");
// more options later...
// TODO we might want to configure the blocking queue with size #proc
return Executors.newFixedThreadPool(nbThreads);
} else {
- logDebug(os, "Create a single-thread test executor");
+ logDebug("Create a single-thread test executor");
return Executors.newSingleThreadExecutor();
}
}
- private void runTests(
- final ObjectInputStream is, final ObjectOutputStream os, ClassLoader classLoader)
- throws Exception {
- final ForkConfiguration config = (ForkConfiguration) is.readObject();
- final ExecutorService executor = executorService(config, os);
- final TaskDef[] tests = (TaskDef[]) is.readObject();
- final int nFrameworks = is.readInt();
- final Logger[] loggers = {remoteLogger(config.isAnsiCodesSupported(), os)};
+ private void runTests(TestInfo info, ClassLoader classLoader) throws Exception {
+ final ExecutorService executor = executorService(info.parallel);
+ final TaskDef[] tests = info.taskDefs.toArray(new TaskDef[] {});
+ final int nFrameworks = info.testRunners.size();
+ final Logger[] loggers = {remoteLogger(info.ansiCodesSupported)};
- for (int i = 0; i < nFrameworks; i++) {
- final String[] implClassNames = (String[]) is.readObject();
- final String[] frameworkArgs = (String[]) is.readObject();
- final String[] remoteFrameworkArgs = (String[]) is.readObject();
+ for (TestInfo.TestRunner testRunner : info.testRunners) {
+ final String[] frameworkArgs = testRunner.mainRunnerArgs.toArray(new String[] {});
+ final String[] remoteFrameworkArgs =
+ testRunner.mainRunnerRemoteArgs.toArray(new String[] {});
Framework framework = null;
- for (final String implClassName : implClassNames) {
+ for (final String implClassName : testRunner.implClassNames) {
try {
final Object rawFramework =
- Class.forName(implClassName).getDeclaredConstructor().newInstance();
+ classLoader.loadClass(implClassName).getDeclaredConstructor().newInstance();
if (rawFramework instanceof Framework) framework = (Framework) rawFramework;
else framework = new FrameworkWrapper((org.scalatools.testing.Framework) rawFramework);
break;
} catch (final ClassNotFoundException e) {
- logDebug(os, "Framework implementation '" + implClassName + "' not present.");
+ logError("Framework implementation '" + implClassName + "' not present.");
}
}
@@ -344,7 +366,6 @@ public final class ForkMain {
final Runner runner = framework.runner(frameworkArgs, remoteFrameworkArgs, classLoader);
final Task[] tasks = runner.tasks(filteredTests.toArray(new TaskDef[filteredTests.size()]));
logDebug(
- os,
"Runner for "
+ framework.getClass().getName()
+ " produced "
@@ -356,25 +377,20 @@ public final class ForkMain {
Thread callDoneOnShutdown = new Thread(() -> runner.done());
Runtime.getRuntime().addShutdownHook(callDoneOnShutdown);
- runTestTasks(executor, tasks, loggers, os);
+ runTestTasks(executor, tasks, loggers);
runner.done();
Runtime.getRuntime().removeShutdownHook(callDoneOnShutdown);
}
- write(os, ForkTags.Done);
- is.readObject();
}
private void runTestTasks(
- final ExecutorService executor,
- final Task[] tasks,
- final Logger[] loggers,
- final ObjectOutputStream os) {
+ final ExecutorService executor, final Task[] tasks, final Logger[] loggers) {
if (tasks.length > 0) {
final List> futureNestedTasks = new ArrayList<>();
for (final Task task : tasks) {
- futureNestedTasks.add(runTest(executor, task, loggers, os));
+ futureNestedTasks.add(runTest(executor, task, loggers));
}
// Note: this could be optimized further, we could have a callback once a test finishes that
@@ -385,18 +401,15 @@ public final class ForkMain {
try {
nestedTasks.addAll(Arrays.asList(futureNestedTask.get()));
} catch (final Exception e) {
- logError(os, "Failed to execute task " + futureNestedTask);
+ logError("Failed to execute task " + futureNestedTask);
}
}
- runTestTasks(executor, nestedTasks.toArray(new Task[nestedTasks.size()]), loggers, os);
+ runTestTasks(executor, nestedTasks.toArray(new Task[nestedTasks.size()]), loggers);
}
}
private Future runTest(
- final ExecutorService executor,
- final Task task,
- final Logger[] loggers,
- final ObjectOutputStream os) {
+ final ExecutorService executor, final Task task, final Logger[] loggers) {
return executor.submit(
() -> {
ForkEvent[] events;
@@ -410,11 +423,10 @@ public final class ForkMain {
eventList.add(new ForkEvent(e));
}
};
- logDebug(os, " Running " + taskDef);
+ logDebug(" Running " + taskDef);
nestedTasks = task.execute(handler, loggers);
if (nestedTasks.length > 0 || eventList.size() > 0)
logDebug(
- os,
" Produced "
+ nestedTasks.length
+ " nested tasks and "
@@ -426,7 +438,6 @@ public final class ForkMain {
events =
new ForkEvent[] {
testError(
- os,
taskDef,
"Uncaught exception when running "
+ taskDef.fullyQualifiedName()
@@ -435,7 +446,7 @@ public final class ForkMain {
t)
};
}
- writeEvents(os, taskDef, events);
+ writeEvents(taskDef, events);
return nestedTasks;
});
}
@@ -482,14 +493,10 @@ public final class ForkMain {
});
}
- private ForkEvent testError(
- final ObjectOutputStream os,
- final TaskDef taskDef,
- final String message,
- final Throwable t) {
- logError(os, message);
+ private ForkEvent testError(final TaskDef taskDef, final String message, final Throwable t) {
+ logError(message);
final ForkError fe = new ForkError(t);
- write(os, fe);
+ writeError(fe);
return testEvent(
taskDef.fullyQualifiedName(),
taskDef.fingerprint(),
diff --git a/testing/agent/src/main/java/sbt/FrameworkWrapper.java b/worker/src/main/java/sbt/internal/worker1/FrameworkWrapper.java
similarity index 98%
rename from testing/agent/src/main/java/sbt/FrameworkWrapper.java
rename to worker/src/main/java/sbt/internal/worker1/FrameworkWrapper.java
index 4d67edf8f..10cdcb499 100644
--- a/testing/agent/src/main/java/sbt/FrameworkWrapper.java
+++ b/worker/src/main/java/sbt/internal/worker1/FrameworkWrapper.java
@@ -6,7 +6,7 @@
* Licensed under Apache License 2.0 (see LICENSE)
*/
-package sbt;
+package sbt.internal.worker1;
import sbt.testing.*;
@@ -14,11 +14,11 @@ import sbt.testing.*;
* Adapts the old {@link org.scalatools.testing.Framework} interface into the new {@link
* sbt.testing.Framework}
*/
-final class FrameworkWrapper implements Framework {
+public final class FrameworkWrapper implements Framework {
private final org.scalatools.testing.Framework oldFramework;
- FrameworkWrapper(final org.scalatools.testing.Framework oldFramework) {
+ public FrameworkWrapper(final org.scalatools.testing.Framework oldFramework) {
this.oldFramework = oldFramework;
}
diff --git a/worker/src/main/java/sbt/internal/worker1/RunInfo.java b/worker/src/main/java/sbt/internal/worker1/RunInfo.java
index 39744a7b7..079e3e1e2 100644
--- a/worker/src/main/java/sbt/internal/worker1/RunInfo.java
+++ b/worker/src/main/java/sbt/internal/worker1/RunInfo.java
@@ -8,10 +8,11 @@
package sbt.internal.worker1;
+import java.io.Serializable;
import java.util.ArrayList;
-public class RunInfo {
- public class JvmRunInfo {
+public class RunInfo implements Serializable {
+ public static class JvmRunInfo implements Serializable {
public ArrayList args;
public ArrayList classpath;
public String mainClass;
@@ -29,7 +30,7 @@ public class RunInfo {
}
}
- public class NativeRunInfo {}
+ public static class NativeRunInfo implements Serializable {}
public boolean jvm;
public JvmRunInfo jvmRunInfo;
diff --git a/worker/src/main/java/sbt/internal/worker1/TestInfo.java b/worker/src/main/java/sbt/internal/worker1/TestInfo.java
new file mode 100644
index 000000000..a1990e0d7
--- /dev/null
+++ b/worker/src/main/java/sbt/internal/worker1/TestInfo.java
@@ -0,0 +1,55 @@
+/*
+ * sbt
+ * Copyright 2023, Scala center
+ * Copyright 2011 - 2022, Lightbend, Inc.
+ * Copyright 2008 - 2010, Mark Harrah
+ * Licensed under Apache License 2.0 (see LICENSE)
+ */
+
+package sbt.internal.worker1;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import sbt.testing.TaskDef;
+
+public class TestInfo implements Serializable {
+ public static class TestRunner implements Serializable {
+ public final ArrayList implClassNames;
+ public final ArrayList mainRunnerArgs;
+ public final ArrayList mainRunnerRemoteArgs;
+
+ public TestRunner(
+ ArrayList implClassNames,
+ ArrayList mainRunnerArgs,
+ ArrayList mainRunnerRemoteArgs) {
+ this.implClassNames = implClassNames;
+ this.mainRunnerArgs = mainRunnerArgs;
+ this.mainRunnerRemoteArgs = mainRunnerRemoteArgs;
+ }
+ }
+
+ public final boolean jvm;
+ public final RunInfo.JvmRunInfo jvmRunInfo;
+ public final RunInfo.NativeRunInfo nativeRunInfo;
+ public final boolean ansiCodesSupported;
+ public final boolean parallel;
+ public final ArrayList taskDefs;
+ public final ArrayList testRunners;
+
+ public TestInfo(
+ boolean jvm,
+ RunInfo.JvmRunInfo jvmRunInfo,
+ RunInfo.NativeRunInfo nativeRunInfo,
+ boolean ansiCodesSupported,
+ boolean parallel,
+ ArrayList taskDefs,
+ ArrayList testRunners) {
+ this.jvm = jvm;
+ this.jvmRunInfo = jvmRunInfo;
+ this.nativeRunInfo = nativeRunInfo;
+ this.ansiCodesSupported = ansiCodesSupported;
+ this.parallel = parallel;
+ this.taskDefs = taskDefs;
+ this.testRunners = testRunners;
+ }
+}
diff --git a/worker/src/main/java/sbt/internal/worker1/TestLogInfo.java b/worker/src/main/java/sbt/internal/worker1/TestLogInfo.java
new file mode 100644
index 000000000..58cbf602a
--- /dev/null
+++ b/worker/src/main/java/sbt/internal/worker1/TestLogInfo.java
@@ -0,0 +1,25 @@
+/*
+ * sbt
+ * Copyright 2023, Scala center
+ * Copyright 2011 - 2022, Lightbend, Inc.
+ * Copyright 2008 - 2010, Mark Harrah
+ * Licensed under Apache License 2.0 (see LICENSE)
+ */
+
+package sbt.internal.worker1;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import sbt.testing.TaskDef;
+
+public class TestLogInfo implements Serializable {
+ public final long id;
+ public final ForkTags tag;
+ public final String message;
+
+ public TestLogInfo(long id, ForkTags tag, String message) {
+ this.id = id;
+ this.tag = tag;
+ this.message = message;
+ }
+}
diff --git a/worker/src/main/java/sbt/internal/worker1/WorkerError.java b/worker/src/main/java/sbt/internal/worker1/WorkerError.java
new file mode 100644
index 000000000..13ddec30f
--- /dev/null
+++ b/worker/src/main/java/sbt/internal/worker1/WorkerError.java
@@ -0,0 +1,13 @@
+package sbt.internal.worker1;
+
+import java.io.Serializable;
+
+public final class WorkerError implements Serializable {
+ public final int code;
+ public final String message;
+
+ public WorkerError(int code, String message) {
+ this.code = code;
+ this.message = message;
+ }
+}
diff --git a/worker/src/main/java/sbt/internal/worker1/WorkerMain.java b/worker/src/main/java/sbt/internal/worker1/WorkerMain.java
index 1f6ccc058..0533de423 100644
--- a/worker/src/main/java/sbt/internal/worker1/WorkerMain.java
+++ b/worker/src/main/java/sbt/internal/worker1/WorkerMain.java
@@ -9,11 +9,14 @@
package sbt.internal.worker1;
import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.gson.JsonPrimitive;
+import com.google.gson.typeadapters.RuntimeTypeAdapterFactory;
import java.io.ByteArrayOutputStream;
+import java.io.InputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.lang.reflect.Method;
@@ -22,15 +25,41 @@ import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.Scanner;
+import sbt.testing.*;
+/**
+ * WorkerMain that communicates via the stdin and stdout using JSON-RPC
+ * (https://www.jsonrpc.org/specification).
+ */
public final class WorkerMain {
private PrintStream originalOut;
+ private InputStream originalIn;
+ private Scanner inScanner;
+
+ public static Gson mkGson() {
+ RuntimeTypeAdapterFactory fingerprintFac =
+ RuntimeTypeAdapterFactory.of(Fingerprint.class, "type");
+ fingerprintFac.registerSubtype(ForkTestMain.SubclassFingerscan.class, "SubclassFingerscan");
+ fingerprintFac.registerSubtype(ForkTestMain.AnnotatedFingerscan.class, "AnnotatedFingerscan");
+ RuntimeTypeAdapterFactory selectorFac =
+ RuntimeTypeAdapterFactory.of(Selector.class, "type");
+ selectorFac.registerSubtype(SuiteSelector.class, "SuiteSelector");
+ selectorFac.registerSubtype(TestSelector.class, "TestSelector");
+ selectorFac.registerSubtype(NestedSuiteSelector.class, "NestedSuiteSelector");
+ selectorFac.registerSubtype(NestedTestSelector.class, "NestedTestSelector");
+ selectorFac.registerSubtype(TestWildcardSelector.class, "TestWildcardSelector");
+ return new GsonBuilder()
+ .registerTypeAdapterFactory(fingerprintFac)
+ .registerTypeAdapterFactory(selectorFac)
+ .create();
+ }
public static void main(final String[] args) throws Exception {
try {
if (args.length == 0) {
WorkerMain app = new WorkerMain();
app.consoleWork();
+ System.exit(0);
} else {
System.err.println("missing args");
System.exit(1);
@@ -45,31 +74,49 @@ public final class WorkerMain {
this.originalOut = System.out;
ByteArrayOutputStream baos = new ByteArrayOutputStream();
System.setOut(new PrintStream(baos));
+ this.originalIn = System.in;
+ this.inScanner = new Scanner(this.originalIn);
}
void consoleWork() throws Exception {
- Scanner input = new Scanner(System.in);
- if (input.hasNextLine()) {
- String line = input.nextLine();
+ if (this.inScanner.hasNextLine()) {
+ String line = this.inScanner.nextLine();
process(line);
}
}
- void process(String json) throws Exception {
+ /** This processes single request of supposed JSON line. */
+ void process(String json) {
JsonElement elem = JsonParser.parseString(json);
JsonObject o = elem.getAsJsonObject();
if (!o.has("jsonrpc")) {
- throw new RuntimeException("missing jsonprc element");
+ return;
}
+ Gson g = WorkerMain.mkGson();
long id = o.getAsJsonPrimitive("id").getAsLong();
- String method = o.getAsJsonPrimitive("method").getAsString();
- JsonObject params = o.getAsJsonObject("params");
- switch (method) {
- case "run":
- Gson g = new Gson();
- RunInfo info = g.fromJson(params, RunInfo.class);
- run(info);
- break;
+ try {
+ String method = o.getAsJsonPrimitive("method").getAsString();
+ JsonObject params = o.getAsJsonObject("params");
+ switch (method) {
+ case "run":
+ RunInfo info = g.fromJson(params, RunInfo.class);
+ run(info);
+ break;
+ case "test":
+ TestInfo testInfo = g.fromJson(params, TestInfo.class);
+ test(id, testInfo);
+ break;
+ }
+ String response = String.format("{ \"jsonrpc\": \"2.0\", \"result\": 0, \"id\": %d }", id);
+ this.originalOut.println(response);
+ this.originalOut.flush();
+ } catch (Throwable e) {
+ WorkerError err = new WorkerError(1, e.getMessage());
+ String errMessage = g.toJson(err, err.getClass());
+ String errJson =
+ String.format("{ \"jsonrpc\": \"2.0\", \"error\": %s, \"id\": %d }", errMessage, id);
+ this.originalOut.println(errJson);
+ this.originalOut.flush();
}
}
@@ -79,20 +126,7 @@ public final class WorkerMain {
throw new RuntimeException("missing jvmRunInfo element");
}
RunInfo.JvmRunInfo jvmRunInfo = info.jvmRunInfo;
- URL[] urls =
- jvmRunInfo
- .classpath
- .stream()
- .map(
- filePath -> {
- try {
- return filePath.path.toURL();
- } catch (MalformedURLException e) {
- throw new RuntimeException(e);
- }
- })
- .toArray(URL[]::new);
- URLClassLoader cl = new URLClassLoader(urls, ClassLoader.getSystemClassLoader());
+ URLClassLoader cl = createClassLoader(jvmRunInfo, ClassLoader.getSystemClassLoader());
try {
Class> mainClass = cl.loadClass(jvmRunInfo.mainClass);
Method mainMethod = mainClass.getMethod("main", String[].class);
@@ -105,4 +139,37 @@ public final class WorkerMain {
throw new RuntimeException("only jvm is supported");
}
}
+
+ void test(long id, TestInfo info) throws Exception {
+ if (info.jvm) {
+ RunInfo.JvmRunInfo jvmRunInfo = info.jvmRunInfo;
+ ClassLoader parent = new ForkTestMain().getClass().getClassLoader();
+ ClassLoader cl = createClassLoader(jvmRunInfo, parent);
+ try {
+ ForkTestMain.main(id, info, this.originalOut, cl);
+ } finally {
+ if (cl instanceof URLClassLoader) {
+ ((URLClassLoader) cl).close();
+ }
+ }
+ } else {
+ throw new RuntimeException("only jvm is supported");
+ }
+ }
+
+ private URLClassLoader createClassLoader(RunInfo.JvmRunInfo info, ClassLoader parent) {
+ URL[] urls =
+ info.classpath
+ .stream()
+ .map(
+ filePath -> {
+ try {
+ return filePath.path.toURL();
+ } catch (MalformedURLException e) {
+ throw new RuntimeException(e);
+ }
+ })
+ .toArray(URL[]::new);
+ return new URLClassLoader(urls, parent);
+ }
}
diff --git a/worker/src/test/scala/sbt/internal/worker1/WorkerTest.scala b/worker/src/test/scala/sbt/internal/worker1/WorkerTest.scala
index f516ab45c..8285f3997 100644
--- a/worker/src/test/scala/sbt/internal/worker1/WorkerTest.scala
+++ b/worker/src/test/scala/sbt/internal/worker1/WorkerTest.scala
@@ -8,7 +8,7 @@ object WorkerTest extends verify.BasicTestSuite:
test("process") {
val u0 = IO.classLocationPath(classOf[example.Hello]).toUri()
val u1 = IO.classLocationPath(classOf[scala.quoted.Quotes]).toUri()
- val u2 = IO.classLocationPath(classOf[scala.AnyVal]).toUri()
+ val u2 = IO.classLocationPath(classOf[scala.collection.immutable.List[?]]).toUri()
val cp =
s"""[{ "path": "${u0}", "digest": "" }, { "path": "${u1}", "digest": "" }, { "path": "${u2}", "digest": "" }]"""
val runInfo =