Merge pull request #8170 from eed3si9n/wip/worker2

[2.x] Implement worker + forked test
eugene yokota 2025-07-06 13:31:26 -05:00 committed by GitHub
commit f284e809d4
24 changed files with 1133 additions and 328 deletions

View File

@ -471,7 +471,7 @@ lazy val utilScripted = (project in file("internal") / "util-scripted")
// Runner for uniform test interface
lazy val testingProj = (project in file("testing"))
.enablePlugins(ContrabandPlugin, JsonCodecPlugin)
.dependsOn(testAgentProj, utilLogging)
.dependsOn(workerProj, utilLogging)
.settings(
baseSettings,
name := "Testing",
@ -518,16 +518,26 @@ lazy val testingProj = (project in file("testing"))
)
.configure(addSbtIO, addSbtCompilerClasspath)
// Testing agent for running tests in a separate process.
lazy val testAgentProj = (project in file("testing") / "agent")
lazy val workerProj = (project in file("worker"))
.dependsOn(exampleWorkProj % Test)
.settings(
minimalSettings,
name := "worker",
testedBaseSettings,
Compile / doc / javacOptions := Nil,
crossPaths := false,
autoScalaLibrary := false,
Compile / doc / javacOptions := Nil,
name := "Test Agent",
libraryDependencies += testInterface,
mimaSettings,
libraryDependencies ++= Seq(gson, testInterface),
libraryDependencies += "org.scala-lang" %% "scala3-library" % scalaVersion.value % Test,
// run / fork := false,
Test / fork := true,
)
.configure(addSbtIOForTest)
lazy val exampleWorkProj = (project in file("internal") / "example-work")
.settings(
minimalSettings,
name := "example work",
publish / skip := true,
)
// Basic task engine
@ -656,6 +666,8 @@ lazy val actionsProj = (project in file("main-actions"))
utilLogging,
utilRelation,
utilTracking,
workerProj,
protocolProj,
)
.settings(
testedBaseSettings,
@ -666,18 +678,9 @@ lazy val actionsProj = (project in file("main-actions"))
baseDirectory.value / "src" / "main" / "contraband-scala",
Compile / generateContrabands / sourceManaged := baseDirectory.value / "src" / "main" / "contraband-scala",
Compile / generateContrabands / contrabandFormatsForType := ContrabandConfig.getFormats,
// Test / fork := true,
Test / classLoaderLayeringStrategy := ClassLoaderLayeringStrategy.Flat,
mimaSettings,
mimaBinaryIssueFilters ++= Seq(
// Removed unused private[sbt] nested class
exclude[MissingClassProblem]("sbt.Doc$Scaladoc"),
// Removed no longer used private[sbt] method
exclude[DirectMissingMethodProblem]("sbt.Doc.generate"),
exclude[DirectMissingMethodProblem]("sbt.compiler.Eval.filesModifiedBytes"),
exclude[DirectMissingMethodProblem]("sbt.compiler.Eval.fileModifiedBytes"),
exclude[DirectMissingMethodProblem]("sbt.Doc.$init$"),
// Added field in nested private[this] class
exclude[ReversedMissingMethodProblem]("sbt.compiler.Eval#EvalType.sourceName"),
),
)
.dependsOn(lmCore)
.configure(
@ -1204,7 +1207,6 @@ def allProjects =
logicProj,
completeProj,
testingProj,
testAgentProj,
taskProj,
stdTaskProj,
runProj,
@ -1231,6 +1233,7 @@ def allProjects =
lmCoursier,
lmCoursierShaded,
lmCoursierShadedPublishing,
workerProj,
) ++ lowerUtilProjects
// These need to be cross published to 2.12 and 2.13 for Zinc

View File

@ -0,0 +1,9 @@
package example
class Hello
object Hello:
def main(args: Array[String]): Unit =
if args.toList == List("boom") then sys.error("boom")
else println(s"${args.mkString}")
end Hello

View File

@ -8,21 +8,33 @@
package sbt
import scala.collection.mutable
import com.google.gson.{ JsonObject, JsonParser }
import testing.{ Logger as _, Task as _, * }
import scala.util.control.NonFatal
import java.net.ServerSocket
import java.io.*
import java.util.ArrayList
import Tests.{ Output as TestOutput, * }
import sbt.io.IO
import sbt.util.Logger
import sbt.ConcurrentRestrictions.Tag
import sbt.protocol.testing.*
import sbt.internal.{ WorkerExchange, WorkerResponseListener }
import sbt.internal.util.Util.*
import sbt.internal.util.{ Terminal as UTerminal }
import sbt.internal.util.{ MessageOnlyException, Terminal as UTerminal }
import sbt.internal.worker1.*
import xsbti.{ FileConverter, HashedVirtualFileRef }
import scala.collection.mutable
import scala.concurrent.{ Await, Promise }
import scala.concurrent.duration.Duration
import scala.util.Random
import scala.jdk.CollectionConverters.*
import scala.sys.process.Process
/**
* This implements forked testing, in cooperation with the worker CLI,
* which was previously called test-agent.jar.
*/
private[sbt] object ForkTests:
val r = Random()
private[sbt] object ForkTests {
def apply(
runners: Map[TestFramework, Runner],
opts: ProcessedOptions,
@ -87,147 +99,145 @@ private[sbt] object ForkTests {
parallel: Boolean
): Task[TestOutput] =
std.TaskExtra.task {
val server = new ServerSocket(0)
val testListeners = opts.testListeners flatMap {
val testListeners = opts.testListeners.flatMap:
case tl: TestsListener => tl.some
case _ => none[TestsListener]
}
object Acceptor extends Runnable {
val resultsAcc = mutable.Map.empty[String, SuiteResult]
lazy val result =
TestOutput(overall(resultsAcc.values.map(_.result)), resultsAcc.toMap, Iterable.empty)
def run(): Unit = {
val socket =
try {
server.accept()
} catch {
case e: java.net.SocketException =>
log.error(
"Could not accept connection from test agent: " + e.getClass + ": " + e.getMessage
)
log.trace(e)
server.close()
return
}
val os = new ObjectOutputStream(socket.getOutputStream)
// Must flush the header that the constructor writes, otherwise the ObjectInputStream on the other end may block indefinitely
os.flush()
val is = new ObjectInputStream(socket.getInputStream)
try {
val config = new ForkConfiguration(UTerminal.isAnsiSupported, parallel)
os.writeObject(config)
val taskdefs = opts.tests.map { t =>
new TaskDef(
t.name,
forkFingerprint(t.fingerprint),
t.explicitlySpecified,
t.selectors
)
}
os.writeObject(taskdefs.toArray)
os.writeInt(runners.size)
for ((testFramework, mainRunner) <- runners) {
os.writeObject(testFramework.implClassNames.toArray)
os.writeObject(mainRunner.args)
os.writeObject(mainRunner.remoteArgs)
}
os.flush()
new React(is, os, log, opts.testListeners, resultsAcc).react()
} catch {
case NonFatal(e) =>
def throwableToString(t: Throwable) = {
import java.io.*; val sw = new StringWriter; t.printStackTrace(new PrintWriter(sw));
sw.toString
}
resultsAcc("Forked test harness failed: " + throwableToString(e)) = SuiteResult.Error
} finally {
is.close(); os.close(); socket.close()
}
}
}
try {
testListeners.foreach(_.doInit())
val acceptorThread = new Thread(Acceptor)
acceptorThread.start()
val cpFiles = classpath.map(converter.toPath).map(_.toFile())
val fullCp = cpFiles ++ Seq(
IO.classLocationPath(classOf[ForkMain]).toFile,
IO.classLocationPath(classOf[Framework]).toFile,
val resultsAcc = mutable.Map.empty[String, SuiteResult]
val randomId = r.nextLong()
def testOutputResult =
TestOutput(
overall(resultsAcc.values.map(_.result)),
resultsAcc.toMap,
Iterable.empty
)
val options = Seq(
"-classpath",
fullCp mkString File.pathSeparator,
classOf[ForkMain].getCanonicalName,
server.getLocalPort.toString
val taskdefs = opts.tests.map: t =>
new TaskDef(
t.name,
forkFingerprint(t.fingerprint),
t.explicitlySpecified,
t.selectors
)
val ec = Fork.java(fork, options)
val result =
if (ec != 0)
TestOutput(
TestResult.Error,
Map(
"Running java with options " + options
.mkString(" ") + " failed with exit code " + ec -> SuiteResult.Error
),
Iterable.empty
)
else {
// Need to wait acceptor thread to finish its business
acceptorThread.join()
Acceptor.result
}
val testRunners = runners.toSeq.map: (testFramework, mainRunner) =>
TestInfo.TestRunner(
ArrayList(testFramework.implClassNames.asJava),
ArrayList(mainRunner.args().toList.asJava),
ArrayList(mainRunner.remoteArgs().toList.asJava)
)
val g = WorkerMain.mkGson()
// virtualize classloading: pass the test classpath to the worker's ClassLoader instead of the forked JVM's -classpath
val useClassLoader = true
val cpList = ArrayList[FilePath](
(classpath
.map: vf =>
FilePath(converter.toPath(vf).toUri(), vf.contentHashStr()))
.asJava
)
val cpFiles =
classpath.map: vf =>
converter.toPath(vf).toFile()
val param = TestInfo(
true, /* jvm */
RunInfo.JvmRunInfo(
ArrayList(),
if useClassLoader then cpList else ArrayList(),
"",
false /*connectInput*/,
),
null,
UTerminal.isAnsiSupported,
parallel,
ArrayList(taskdefs.asJava),
ArrayList(testRunners.asJava),
)
testListeners.foreach(_.doInit())
val result =
val w = WorkerExchange.startWorker(fork, if useClassLoader then Nil else cpFiles)
val wl = React(randomId, log, opts.testListeners, resultsAcc, w.process)
try
WorkerExchange.registerListener(wl)
val paramJson = g.toJson(param, param.getClass)
val json = jsonRpcRequest(randomId, "test", paramJson)
w.println(json)
if wl.blockForResponse() != 0 then
throw MessageOnlyException("Forked test harness failed")
testOutputResult
finally WorkerExchange.unregisterListener(wl)
testListeners.foreach(_.doComplete(result.overall))
result
} // end task
testListeners.foreach(_.doComplete(result.overall))
result
} finally {
server.close()
}
}
private def jsonRpcRequest(id: Long, method: String, params: String): String =
s"""{ "jsonrpc": "2.0", "method": "$method", "params": $params, "id": $id }"""
private def forkFingerprint(f: Fingerprint): Fingerprint & Serializable =
f match {
case s: SubclassFingerprint => new ForkMain.SubclassFingerscan(s)
case a: AnnotatedFingerprint => new ForkMain.AnnotatedFingerscan(a)
f match
case s: SubclassFingerprint => ForkTestMain.SubclassFingerscan(s)
case a: AnnotatedFingerprint => ForkTestMain.AnnotatedFingerscan(a)
case _ => sys.error("Unknown fingerprint type: " + f.getClass)
}
}
private final class React(
is: ObjectInputStream,
os: ObjectOutputStream,
end ForkTests
private class React(
id: Long,
log: Logger,
listeners: Seq[TestReportListener],
results: mutable.Map[String, SuiteResult]
) {
import ForkTags.*
@annotation.tailrec
def react(): Unit = is.readObject match {
case `Done` =>
os.writeObject(Done); os.flush()
case Array(`Error`, s: String) =>
log.error(s); react()
case Array(`Warn`, s: String) =>
log.warn(s); react()
case Array(`Info`, s: String) =>
log.info(s); react()
case Array(`Debug`, s: String) =>
log.debug(s); react()
case t: Throwable =>
log.trace(t); react()
case Array(group: String, tEvents: Array[Event]) =>
val events = tEvents.toSeq
listeners.foreach(_.startGroup(group))
val event = TestEvent(events)
listeners.foreach(_.testEvent(event))
val suiteResult = SuiteResult(events)
results += group -> suiteResult
listeners.foreach(_.endGroup(group, suiteResult.result))
react()
}
}
results: mutable.Map[String, SuiteResult],
process: Process
) extends WorkerResponseListener:
val g = WorkerMain.mkGson()
val promise: Promise[Int] = Promise()
override def apply(line: String): Unit =
// scala.Console.err.println(line)
val o = JsonParser.parseString(line).getAsJsonObject()
if o.has("id") then
val resId = o.getAsJsonPrimitive("id").getAsLong()
if resId == id then
if promise.isCompleted then ()
else if o.has("error") then promise.failure(new RuntimeException(line))
else promise.success(0)
else ()
else if o.has("method") then processNotification(o)
else ()
override def notifyExit(p: Process): Unit =
if !process.isAlive then promise.success(process.exitValue())
def processNotification(o: JsonObject): Unit =
val method = o.getAsJsonPrimitive("method").getAsString()
method match
case "testLog" =>
val params = o.getAsJsonObject("params")
val info = g.fromJson[TestLogInfo](params, classOf[TestLogInfo])
if info.id == id then
info.tag match
case ForkTags.Error => log.error(info.message)
case ForkTags.Warn => log.warn(info.message)
case ForkTags.Info => log.info(info.message)
case ForkTags.Debug => log.debug(info.message)
case _ => ()
else ()
case "testEvents" =>
val params = o.getAsJsonObject("params")
val info =
g.fromJson[ForkTestMain.ForkEventsInfo](params, classOf[ForkTestMain.ForkEventsInfo])
if info.id == id then
val events = info.events.asScala.toSeq
listeners.foreach(_.startGroup(info.group))
val event = TestEvent(events)
listeners.foreach(_.testEvent(event))
val suiteResult = SuiteResult(events)
results += info.group -> suiteResult
listeners.foreach(_.endGroup(info.group, suiteResult.result))
else ()
case "forkError" =>
val params = o.getAsJsonObject("params")
val info =
g.fromJson[ForkTestMain.ForkErrorInfo](params, classOf[ForkTestMain.ForkErrorInfo])
if info.id == id then
log.trace(info.error)
promise.failure(info.error)
else ()
case _ => ()
def blockForResponse(): Int =
Await.result(promise.future, Duration.Inf)
end React
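As a rough illustration of the wire format (the id is made up and the params abbreviated; the real params are the full gson-serialized TestInfo), jsonRpcRequest frames the call that ForkTests writes to the worker's stdin as a single JSON line, and React matches the worker's replies against that same id:

// Sketch only, with hypothetical values:
val line = jsonRpcRequest(42L, "test", """{ "jvm": true, "parallel": false }""")
// yields: { "jsonrpc": "2.0", "method": "test", "params": { "jvm": true, "parallel": false }, "id": 42 }
// The worker then answers on stdout with "testLog", "testEvents", and "forkError" notifications,
// followed by a response carrying "id": 42, which completes React's promise.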

View File

@ -0,0 +1,88 @@
/*
* sbt
* Copyright 2023, Scala center
* Copyright 2011 - 2022, Lightbend, Inc.
* Copyright 2008 - 2010, Mark Harrah
* Licensed under Apache License 2.0 (see LICENSE)
*/
package sbt
package internal
import com.google.gson.Gson
import java.io.*
import java.util.concurrent.atomic.AtomicReference
import sbt.io.IO
import sbt.internal.worker1.*
import sbt.testing.Framework
import scala.sys.process.{ BasicIO, Process, ProcessIO }
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
object WorkerExchange:
val listeners: mutable.ListBuffer[WorkerResponseListener] = ListBuffer.empty
/**
* Start a worker process.
*/
def startWorker(fo: ForkOptions, extraCp: Seq[File]): WorkerProxy =
val fullCp = Seq(
IO.classLocationPath(classOf[WorkerMain]).toFile,
IO.classLocationPath(classOf[Framework]).toFile,
IO.classLocationPath(classOf[Gson]).toFile,
) ++ extraCp
val options = Seq(
"-classpath",
fullCp.mkString(File.pathSeparator),
classOf[WorkerMain].getCanonicalName,
)
val inputRef = AtomicReference[OutputStream]()
val processIo = ProcessIO(
in = (input) => inputRef.set(input),
out = BasicIO.processFully(onStdoutLine),
err = BasicIO.processFully((line) => scala.Console.err.println(line)),
)
val forkWithIo = fo.withOutputStrategy(OutputStrategy.CustomInputOutput(processIo))
val p = Fork.java.fork(forkWithIo, options)
WorkerProxy(inputRef.get(), p, options)
def registerListener(listener: WorkerResponseListener): Unit =
synchronized:
listeners.append(listener)
def unregisterListener(listener: WorkerResponseListener): Unit =
synchronized:
if listeners.contains(listener) then listeners.remove(listeners.indexOf(listener))
else ()
/**
* Unified worker output handler.
*/
def onStdoutLine(line: String): Unit =
synchronized:
listeners.foreach: wl =>
wl(line)
end WorkerExchange
class WorkerProxy(input: OutputStream, val process: Process, val options: Seq[String])
extends AutoCloseable:
lazy val inputStream = PrintStream(input)
def close(): Unit = input.close()
def blockForExitCode(): Int =
if !process.isAlive then process.exitValue()
else Fork.blockForExitCode(process)
/** Prints a line to the stdin of the worker process. */
def println(str: String): Unit =
inputStream.println(str)
inputStream.flush()
val watch = Thread(() => {
while process.isAlive() do Thread.sleep(100)
WorkerExchange.listeners.foreach(_.notifyExit(process))
})
watch.start()
end WorkerProxy
abstract class WorkerResponseListener extends Function1[String, Unit]:
def notifyExit(p: Process): Unit
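A minimal usage sketch of the exchange above (assuming a ForkOptions value fork and a pre-framed JSON-RPC line json; the listener only echoes raw lines):

val proxy = WorkerExchange.startWorker(fork, extraCp = Nil)
val listener = new WorkerResponseListener:
  def apply(line: String): Unit = Console.err.println(s"worker> $line") // raw JSON-RPC lines from stdout
  def notifyExit(p: scala.sys.process.Process): Unit = Console.err.println("worker exited")
WorkerExchange.registerListener(listener)
try
  proxy.println(json)      // write one request to the worker's stdin
  proxy.blockForExitCode() // wait for the forked JVM to terminate
finally WorkerExchange.unregisterListener(listener)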

View File

@ -1493,9 +1493,8 @@ object Defaults extends BuildCommon {
}
}
val summaries =
runners map { (tf, r) =>
runners.map: (tf, r) =>
Tests.Summary(frameworks(tf).name, r.done())
}
out.copy(summaries = summaries)
}
// Def.value[Task[Tests.Output]] {

View File

@ -58,6 +58,7 @@ object Dependencies {
}
def addSbtIO = addSbtModule(sbtIoPath, "io", sbtIO)
def addSbtIOForTest = addSbtModule(sbtIoPath, "io", sbtIO, Some(Test))
def addSbtCompilerInterface = addSbtModule(sbtZincPath, "compilerInterface", compilerInterface)
def addSbtCompilerClasspath = addSbtModule(sbtZincPath, "zincClasspath", compilerClasspath)
@ -91,6 +92,7 @@ object Dependencies {
val templateResolverApi = "org.scala-sbt" % "template-resolver" % "0.1"
val remoteapis =
"com.eed3si9n.remoteapis.shaded" % "shaded-remoteapis-java" % "2.3.0-M1-52317e00d8d4c37fa778c628485d220fb68a8d08"
val gson = "com.google.code.gson" % "gson" % "2.13.1"
val scalaCompiler = "org.scala-lang" %% "scala3-compiler" % scala3
val scala3Library = "org.scala-lang" %% "scala3-library" % scala3

View File

@ -185,13 +185,13 @@ object Fork {
()
}
val process = Process(jpb)
outputStrategy.getOrElse(StdoutOutput: OutputStrategy) match {
outputStrategy.getOrElse(StdoutOutput: OutputStrategy) match
case StdoutOutput => process.run(connectInput = false)
case out: BufferedOutput =>
out.logger.buffer { process.run(out.logger, connectInput = false) }
case out: LoggedOutput => process.run(out.logger, connectInput = false)
case out: CustomOutput => (process #> out.output).run(connectInput = false)
}
case out: LoggedOutput => process.run(out.logger, connectInput = false)
case out: CustomOutput => (process #> out.output).run(connectInput = false)
case out: CustomInputOutput => process.run(out.processIO)
}
private[sbt] def blockForExitCode(p: Process): Int = {

View File

@ -8,6 +8,7 @@
package sbt
import scala.sys.process.ProcessIO
import sbt.util.Logger
import java.io.OutputStream
@ -102,4 +103,28 @@ object OutputStrategy {
object CustomOutput {
def apply(output: OutputStream): CustomOutput = new CustomOutput(output)
}
/**
* Configures the forked IO.
*/
final class CustomInputOutput private (val processIO: ProcessIO)
extends OutputStrategy
with Serializable:
override def equals(o: Any): Boolean = o match
case x: CustomInputOutput => (this.processIO == x.processIO)
case _ => false
override def hashCode: Int =
37 * (17 + processIO.##) + "CustomInputOutput".##
override def toString: String =
"CustomInputOutput(...)"
private def copy(processIO: ProcessIO = processIO): CustomInputOutput =
new CustomInputOutput(processIO)
def withProcessIO(processIO: ProcessIO): CustomInputOutput =
copy(processIO = processIO)
end CustomInputOutput
object CustomInputOutput:
def apply(processIO: ProcessIO): CustomInputOutput = new CustomInputOutput(processIO)
end CustomInputOutput
}
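A short sketch of how CustomInputOutput is meant to be wired (mirroring WorkerExchange.startWorker above; fo and options stand in for a ForkOptions and the forked JVM's arguments):

import java.io.OutputStream
import java.util.concurrent.atomic.AtomicReference
import scala.sys.process.{ BasicIO, ProcessIO }

// Keep a handle on the child's stdin and forward its stdout/stderr line by line.
val stdinRef = AtomicReference[OutputStream]()
val io = ProcessIO(
  in = input => stdinRef.set(input),
  out = BasicIO.processFully(line => println(s"child> $line")),
  err = BasicIO.processFully(line => Console.err.println(line)),
)
// val process = Fork.java.fork(fo.withOutputStrategy(OutputStrategy.CustomInputOutput(io)), options)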

View File

@ -33,6 +33,7 @@ lazy val root = (project in file("."))
"com.eed3si9n:sjson-new-scalajson_3",
"com.github.ben-manes.caffeine:caffeine",
"com.github.mwiede:jsch",
"com.google.code.gson:gson",
"com.google.errorprone:error_prone_annotations",
"com.lmax:disruptor",
"com.swoval:file-tree-views",
@ -82,7 +83,7 @@ lazy val root = (project in file("."))
)
def assertCollectionsEqual(message: String, expected: Seq[String], actual: Seq[String]): Unit =
// using the new line for a more readable comparison failure output
assert(expected.mkString("\n") == actual.mkString("\n"), message)
assert(expected.mkString("\n") == actual.mkString("\n"), message + ": " + actual)
assertCollectionsEqual(
"Unexpected module ids in updateSbtClassifiers",

View File

@ -11,6 +11,7 @@ val scalaxml = "org.scala-lang.modules" %% "scala-xml" % "1.1.1"
def groupId(idx: Int) = "group_" + (idx + 1)
def groupPrefix(idx: Int) = groupId(idx) + "_file_"
Global / localCacheDirectory := baseDirectory.value / "diskcache"
ThisBuild / scalaVersion := "2.12.20"
ThisBuild / organization := "org.example"
@ -19,7 +20,7 @@ lazy val root = (project in file("."))
Test / testGrouping := Def.uncached {
val tests = (Test / definedTests).value
assert(tests.size == 3)
for (idx <- 0 until groups) yield
for idx <- 0 until groups yield
new Group(
groupId(idx),
tests,
@ -28,11 +29,11 @@ lazy val root = (project in file("."))
},
check := Def.uncached {
val files =
for(i <- 0 until groups; j <- 1 to groupSize) yield
for i <- 0 until groups; j <- 1 to groupSize yield
file(groupPrefix(i) + j)
val (exist, absent) = files.partition(_.exists)
exist.foreach(_.delete())
if (absent.nonEmpty)
if absent.nonEmpty then
sys.error("Files were not created:\n\t" + absent.mkString("\n\t"))
},
concurrentRestrictions := Tags.limit(Tags.ForkedTestGroup, 2) :: Nil,

View File

@ -1,8 +1,8 @@
import org.scalatest.FlatSpec
class Test extends FlatSpec {
val v = sys.env.getOrElse("tests.max.value", Int.MaxValue)
"A simple equation" should "hold" in {
assert(Int.MaxValue == v)
}
val v = sys.env.getOrElse("tests.max.value", Int.MaxValue)
"A simple equation" should "hold" in {
assert(Int.MaxValue == v)
}
}

View File

@ -1,29 +0,0 @@
/*
* sbt
* Copyright 2023, Scala center
* Copyright 2011 - 2022, Lightbend, Inc.
* Copyright 2008 - 2010, Mark Harrah
* Licensed under Apache License 2.0 (see LICENSE)
*/
package sbt;
import java.io.Serializable;
public final class ForkConfiguration implements Serializable {
private final boolean ansiCodesSupported;
private final boolean parallel;
public ForkConfiguration(final boolean ansiCodesSupported, final boolean parallel) {
this.ansiCodesSupported = ansiCodesSupported;
this.parallel = parallel;
}
public boolean isAnsiCodesSupported() {
return ansiCodesSupported;
}
public boolean isParallel() {
return parallel;
}
}

View File

@ -67,8 +67,9 @@ final class TestFramework(val implClassNames: String*) extends Serializable {
case head :: tail =>
try {
Some(Class.forName(head, true, loader).getDeclaredConstructor().newInstance() match {
case newFramework: Framework => newFramework
case oldFramework: OldFramework => new FrameworkWrapper(oldFramework)
case newFramework: Framework => newFramework
case oldFramework: OldFramework =>
new sbt.internal.worker1.FrameworkWrapper(oldFramework)
})
} catch {
case e: NoClassDefFoundError =>

View File

@ -0,0 +1,344 @@
/*
* Copyright (C) 2011 Google Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.gson.typeadapters;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParseException;
import com.google.gson.JsonPrimitive;
import com.google.gson.TypeAdapter;
import com.google.gson.TypeAdapterFactory;
import com.google.gson.reflect.TypeToken;
import com.google.gson.stream.JsonReader;
import com.google.gson.stream.JsonWriter;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;
/**
* Adapts values whose runtime type may differ from their declaration type. This is necessary when a
* field's type is not the same type that GSON should create when deserializing that field. For
* example, consider these types:
*
* <pre>
* {
* &#64;code
* abstract class Shape {
* int x;
* int y;
* }
* class Circle extends Shape {
* int radius;
* }
* class Rectangle extends Shape {
* int width;
* int height;
* }
* class Diamond extends Shape {
* int width;
* int height;
* }
* class Drawing {
* Shape bottomShape;
* Shape topShape;
* }
* }
* </pre>
*
* <p>Without additional type information, the serialized JSON is ambiguous. Is the bottom shape in
* this drawing a rectangle or a diamond?
*
* <pre>{@code
* {
* "bottomShape": {
* "width": 10,
* "height": 5,
* "x": 0,
* "y": 0
* },
* "topShape": {
* "radius": 2,
* "x": 4,
* "y": 1
* }
* }
* }</pre>
*
* This class addresses this problem by adding type information to the serialized JSON and honoring
* that type information when the JSON is deserialized:
*
* <pre>{@code
* {
* "bottomShape": {
* "type": "Diamond",
* "width": 10,
* "height": 5,
* "x": 0,
* "y": 0
* },
* "topShape": {
* "type": "Circle",
* "radius": 2,
* "x": 4,
* "y": 1
* }
* }
* }</pre>
*
* Both the type field name ({@code "type"}) and the type labels ({@code "Rectangle"}) are
* configurable.
*
* <h2>Registering Types</h2>
*
* Create a {@code RuntimeTypeAdapterFactory} by passing the base type and type field name to the
* {@link #of} factory method. If you don't supply an explicit type field name, {@code "type"} will
* be used.
*
* <pre>
* {
* &#64;code
* RuntimeTypeAdapterFactory&lt;Shape&gt; shapeAdapterFactory = RuntimeTypeAdapterFactory.of(Shape.class, "type");
* }
* </pre>
*
* Next register all of your subtypes. Every subtype must be explicitly registered. This protects
* your application from injection attacks. If you don't supply an explicit type label, the type's
* simple name will be used.
*
* <pre>{@code
* shapeAdapterFactory.registerSubtype(Rectangle.class, "Rectangle");
* shapeAdapterFactory.registerSubtype(Circle.class, "Circle");
* shapeAdapterFactory.registerSubtype(Diamond.class, "Diamond");
* }</pre>
*
* Finally, register the type adapter factory in your application's GSON builder:
*
* <pre>
* {
* &#64;code
* Gson gson = new GsonBuilder().registerTypeAdapterFactory(shapeAdapterFactory).create();
* }
* </pre>
*
* Like {@code GsonBuilder}, this API supports chaining:
*
* <pre>
* {
* &#64;code
* RuntimeTypeAdapterFactory&lt;Shape&gt; shapeAdapterFactory = RuntimeTypeAdapterFactory.of(Shape.class)
* .registerSubtype(Rectangle.class).registerSubtype(Circle.class).registerSubtype(Diamond.class);
* }
* </pre>
*
* <h2>Serialization and deserialization</h2>
*
* In order to serialize and deserialize a polymorphic object, you must specify the base type
* explicitly.
*
* <pre>
* {
* &#64;code
* Diamond diamond = new Diamond();
* String json = gson.toJson(diamond, Shape.class);
* }
* </pre>
*
* And then:
*
* <pre>
* {
* &#64;code
* Shape shape = gson.fromJson(json, Shape.class);
* }
* </pre>
*/
public final class RuntimeTypeAdapterFactory<T> implements TypeAdapterFactory {
private final Class<?> baseType;
private final String typeFieldName;
private final Map<String, Class<?>> labelToSubtype = new LinkedHashMap<>();
private final Map<Class<?>, String> subtypeToLabel = new LinkedHashMap<>();
private final boolean maintainType;
private boolean recognizeSubtypes;
private RuntimeTypeAdapterFactory(Class<?> baseType, String typeFieldName, boolean maintainType) {
if (typeFieldName == null || baseType == null) {
throw new NullPointerException();
}
this.baseType = baseType;
this.typeFieldName = typeFieldName;
this.maintainType = maintainType;
}
/**
* Creates a new runtime type adapter for {@code baseType} using {@code typeFieldName} as the type
* field name. Type field names are case sensitive.
*
* @param maintainType true if the type field should be included in deserialized objects
*/
public static <T> RuntimeTypeAdapterFactory<T> of(
Class<T> baseType, String typeFieldName, boolean maintainType) {
return new RuntimeTypeAdapterFactory<>(baseType, typeFieldName, maintainType);
}
/**
* Creates a new runtime type adapter for {@code baseType} using {@code typeFieldName} as the type
* field name. Type field names are case sensitive.
*/
public static <T> RuntimeTypeAdapterFactory<T> of(Class<T> baseType, String typeFieldName) {
return new RuntimeTypeAdapterFactory<>(baseType, typeFieldName, false);
}
/**
* Creates a new runtime type adapter for {@code baseType} using {@code "type"} as the type field
* name.
*/
public static <T> RuntimeTypeAdapterFactory<T> of(Class<T> baseType) {
return new RuntimeTypeAdapterFactory<>(baseType, "type", false);
}
/**
* Ensures that this factory will handle not just the given {@code baseType}, but any subtype of
* that type.
*/
@CanIgnoreReturnValue
public RuntimeTypeAdapterFactory<T> recognizeSubtypes() {
this.recognizeSubtypes = true;
return this;
}
/**
* Registers {@code type} identified by {@code label}. Labels are case sensitive.
*
* @throws IllegalArgumentException if either {@code type} or {@code label} have already been
* registered on this type adapter.
*/
@CanIgnoreReturnValue
public RuntimeTypeAdapterFactory<T> registerSubtype(Class<? extends T> type, String label) {
if (type == null || label == null) {
throw new NullPointerException();
}
if (subtypeToLabel.containsKey(type) || labelToSubtype.containsKey(label)) {
throw new IllegalArgumentException("types and labels must be unique");
}
labelToSubtype.put(label, type);
subtypeToLabel.put(type, label);
return this;
}
/**
* Registers {@code type} identified by its {@link Class#getSimpleName simple name}. Labels are
* case sensitive.
*
* @throws IllegalArgumentException if either {@code type} or its simple name have already been
* registered on this type adapter.
*/
@CanIgnoreReturnValue
public RuntimeTypeAdapterFactory<T> registerSubtype(Class<? extends T> type) {
return registerSubtype(type, type.getSimpleName());
}
@Override
public <R> TypeAdapter<R> create(Gson gson, TypeToken<R> type) {
if (type == null) {
return null;
}
Class<?> rawType = type.getRawType();
boolean handle =
recognizeSubtypes ? baseType.isAssignableFrom(rawType) : baseType.equals(rawType);
if (!handle) {
return null;
}
TypeAdapter<JsonElement> jsonElementAdapter = gson.getAdapter(JsonElement.class);
Map<String, TypeAdapter<?>> labelToDelegate = new LinkedHashMap<>();
Map<Class<?>, TypeAdapter<?>> subtypeToDelegate = new LinkedHashMap<>();
for (Map.Entry<String, Class<?>> entry : labelToSubtype.entrySet()) {
TypeAdapter<?> delegate = gson.getDelegateAdapter(this, TypeToken.get(entry.getValue()));
labelToDelegate.put(entry.getKey(), delegate);
subtypeToDelegate.put(entry.getValue(), delegate);
}
return new TypeAdapter<R>() {
@Override
public R read(JsonReader in) throws IOException {
JsonElement jsonElement = jsonElementAdapter.read(in);
JsonElement labelJsonElement;
if (maintainType) {
labelJsonElement = jsonElement.getAsJsonObject().get(typeFieldName);
} else {
labelJsonElement = jsonElement.getAsJsonObject().remove(typeFieldName);
}
if (labelJsonElement == null) {
throw new JsonParseException(
"cannot deserialize "
+ baseType
+ " because it does not define a field named "
+ typeFieldName);
}
String label = labelJsonElement.getAsString();
@SuppressWarnings("unchecked") // registration requires that subtype extends T
TypeAdapter<R> delegate = (TypeAdapter<R>) labelToDelegate.get(label);
if (delegate == null) {
throw new JsonParseException(
"cannot deserialize "
+ baseType
+ " subtype named "
+ label
+ "; did you forget to register a subtype?");
}
return delegate.fromJsonTree(jsonElement);
}
@Override
public void write(JsonWriter out, R value) throws IOException {
Class<?> srcType = value.getClass();
String label = subtypeToLabel.get(srcType);
@SuppressWarnings("unchecked") // registration requires that subtype extends T
TypeAdapter<R> delegate = (TypeAdapter<R>) subtypeToDelegate.get(srcType);
if (delegate == null) {
throw new JsonParseException(
"cannot serialize " + srcType.getName() + "; did you forget to register a subtype?");
}
JsonObject jsonObject = delegate.toJsonTree(value).getAsJsonObject();
if (maintainType) {
jsonElementAdapter.write(out, jsonObject);
return;
}
JsonObject clone = new JsonObject();
if (jsonObject.has(typeFieldName)) {
throw new JsonParseException(
"cannot serialize "
+ srcType.getName()
+ " because it already defines a field named "
+ typeFieldName);
}
clone.add(typeFieldName, new JsonPrimitive(label));
for (Map.Entry<String, JsonElement> e : jsonObject.entrySet()) {
clone.add(e.getKey(), e.getValue());
}
jsonElementAdapter.write(out, clone);
}
}.nullSafe();
}
}
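In this change the factory backs WorkerMain.mkGson(), which registers the Fingerprint and Selector subtypes exchanged with the worker. A hedged round-trip sketch (the selector value is arbitrary; other subtypes carry their own field sets):

import sbt.testing.{ Selector, TestSelector }
val gson = WorkerMain.mkGson()
val json = gson.toJson(TestSelector("example.HelloTest"), classOf[Selector])
// json carries the registered label, e.g. { "type": "TestSelector", "testName": "example.HelloTest" }
val back = gson.fromJson(json, classOf[Selector]) // comes back as a TestSelector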

View File

@ -0,0 +1,13 @@
package sbt.internal.worker1;
import java.net.URI;
public class FilePath {
public URI path;
public String digest;
public FilePath(URI path, String digest) {
this.path = path;
this.digest = digest;
}
}

View File

@ -6,7 +6,7 @@
* Licensed under Apache License 2.0 (see LICENSE)
*/
package sbt;
package sbt.internal.worker1;
public enum ForkTags {
Error,

View File

@ -6,16 +6,15 @@
* Licensed under Apache License 2.0 (see LICENSE)
*/
package sbt;
package sbt.internal.worker1;
import com.google.gson.Gson;
import sbt.testing.*;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.io.Serializable;
import java.net.Socket;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@ -23,7 +22,7 @@ import java.util.List;
import java.util.LinkedHashSet;
import java.util.concurrent.*;
public final class ForkMain {
public class ForkTestMain {
// serializables
// -----------------------------------------------------------------------------
@ -70,7 +69,7 @@ public final class ForkMain {
}
}
static final class ForkEvent implements Event, Serializable {
public static final class ForkEvent implements Event, Serializable {
private final String fullyQualifiedName;
private final Fingerprint fingerprint;
private final Selector selector;
@ -79,23 +78,23 @@ public final class ForkMain {
private final long duration;
ForkEvent(final Event e) {
fullyQualifiedName = e.fullyQualifiedName();
this.fullyQualifiedName = e.fullyQualifiedName();
final Fingerprint rawFingerprint = e.fingerprint();
if (rawFingerprint instanceof SubclassFingerprint)
fingerprint = new SubclassFingerscan((SubclassFingerprint) rawFingerprint);
else fingerprint = new AnnotatedFingerscan((AnnotatedFingerprint) rawFingerprint);
this.fingerprint = new SubclassFingerscan((SubclassFingerprint) rawFingerprint);
else this.fingerprint = new AnnotatedFingerscan((AnnotatedFingerprint) rawFingerprint);
selector = e.selector();
this.selector = e.selector();
checkSerializableSelector(selector);
status = e.status();
this.status = e.status();
final OptionalThrowable originalThrowable = e.throwable();
if (originalThrowable.isDefined())
throwable = new OptionalThrowable(new ForkError(originalThrowable.get()));
else throwable = originalThrowable;
this.throwable = new OptionalThrowable(new ForkError(originalThrowable.get()));
else this.throwable = originalThrowable;
duration = e.duration();
this.duration = e.duration();
}
public String fullyQualifiedName() {
@ -132,18 +131,30 @@ public final class ForkMain {
}
}
public static class ForkEventsInfo implements Serializable {
public long id;
public String group;
public ArrayList<ForkEvent> events;
public ForkEventsInfo(long id, String group, ArrayList<ForkEvent> events) {
this.id = id;
this.group = group;
this.events = events;
}
}
// -----------------------------------------------------------------------------
static final class ForkError extends Exception {
public static final class ForkError extends Exception {
private final String originalMessage;
private final String originalName;
private ForkError cause;
private ForkError cause1;
ForkError(final Throwable t) {
originalMessage = t.getMessage();
originalName = t.getClass().getName();
setStackTrace(t.getStackTrace());
if (t.getCause() != null) cause = new ForkError(t.getCause());
if (t.getCause() != null) cause1 = new ForkError(t.getCause());
}
public String getMessage() {
@ -151,51 +162,50 @@ public final class ForkMain {
}
public Exception getCause() {
return cause;
return cause1;
}
}
public static class ForkErrorInfo implements Serializable {
public final long id;
public final ForkError error;
public ForkErrorInfo(long id, ForkError error) {
this.id = id;
this.error = error;
}
}
// main
// ----------------------------------------------------------------------------------------------------------------
public static void main(final String[] args) throws Exception {
ClassLoader classLoader = new Run().getClass().getClassLoader();
try {
main(args, classLoader);
} finally {
System.exit(0);
}
}
public static void main(final String[] args, ClassLoader classLoader) throws Exception {
final Socket socket = new Socket(InetAddress.getByName(null), Integer.valueOf(args[0]));
final ObjectInputStream is = new ObjectInputStream(socket.getInputStream());
final ObjectOutputStream os = new ObjectOutputStream(socket.getOutputStream());
// Must flush the header that the constructor writes, otherwise the ObjectInputStream on the
// other end may block indefinitely
os.flush();
try {
new Run().run(is, os, classLoader);
} finally {
is.close();
os.close();
}
public static void main(long id, TestInfo info, PrintStream originalOut, ClassLoader classLoader)
throws Exception {
new Run(originalOut, id).run(info, classLoader);
}
// ----------------------------------------------------------------------------------------------------------------
private static final class Run {
public static final class Run {
final PrintStream originalOut;
final long id;
final Gson gson;
private void run(
final ObjectInputStream is, final ObjectOutputStream os, ClassLoader classLoader) {
Run(PrintStream originalOut, long id) {
this.originalOut = originalOut;
this.id = id;
this.gson = WorkerMain.mkGson();
}
private void run(TestInfo info, ClassLoader classLoader) {
try {
runTests(is, os, classLoader);
runTests(info, classLoader);
} catch (final RunAborted e) {
internalError(e);
} catch (final Throwable t) {
try {
logError(os, "Uncaught exception when running tests: " + t.toString());
write(os, new ForkError(t));
logError("Uncaught exception when running tests: " + t.toString());
writeError(new ForkError(t));
} catch (final Throwable t2) {
internalError(t2);
}
@ -223,106 +233,118 @@ public final class ForkMain {
}
}
private synchronized void write(final ObjectOutputStream os, final Object obj) {
try {
os.writeObject(obj);
os.flush();
} catch (final IOException e) {
throw new RunAborted(e);
}
private void writeError(ForkError error) {
ForkErrorInfo info = new ForkErrorInfo(this.id, error);
String params = this.gson.toJson(info, ForkErrorInfo.class);
String notification =
String.format(
"{ \"jsonrpc\": \"2.0\", \"method\": \"forkError\", \"params\": %s }", params);
this.originalOut.println(notification);
this.originalOut.flush();
}
private void log(final ObjectOutputStream os, final String message, final ForkTags level) {
write(os, new Object[] {level, message});
private void log(final String message, final ForkTags level) {
TestLogInfo info = new TestLogInfo(this.id, level, message);
String params = this.gson.toJson(info, TestLogInfo.class);
String notification =
String.format(
"{ \"jsonrpc\": \"2.0\", \"method\": \"testLog\", \"params\": %s }", params);
this.originalOut.println(notification);
this.originalOut.flush();
}
private void logDebug(final ObjectOutputStream os, final String message) {
log(os, message, ForkTags.Debug);
private void logDebug(final String message) {
log(message, ForkTags.Debug);
}
private void logInfo(final ObjectOutputStream os, final String message) {
log(os, message, ForkTags.Info);
private void logInfo(final String message) {
log(message, ForkTags.Info);
}
private void logWarn(final ObjectOutputStream os, final String message) {
log(os, message, ForkTags.Warn);
private void logWarn(final String message) {
log(message, ForkTags.Warn);
}
private void logError(final ObjectOutputStream os, final String message) {
log(os, message, ForkTags.Error);
private void logError(final String message) {
log(message, ForkTags.Error);
}
private Logger remoteLogger(final boolean ansiCodesSupported, final ObjectOutputStream os) {
private Logger remoteLogger(final boolean ansiCodesSupported) {
return new Logger() {
public boolean ansiCodesSupported() {
return ansiCodesSupported;
}
public void error(final String s) {
logError(os, s);
logError(s);
}
public void warn(final String s) {
logWarn(os, s);
logWarn(s);
}
public void info(final String s) {
logInfo(os, s);
logInfo(s);
}
public void debug(final String s) {
logDebug(os, s);
logDebug(s);
}
public void trace(final Throwable t) {
write(os, new ForkError(t));
writeError(new ForkError(t));
}
};
}
private void writeEvents(
final ObjectOutputStream os, final TaskDef taskDef, final ForkEvent[] events) {
write(os, new Object[] {taskDef.fullyQualifiedName(), events});
private void writeEvents(final TaskDef taskDef, final ForkEvent[] events) {
ForkEventsInfo info =
new ForkEventsInfo(
this.id,
taskDef.fullyQualifiedName(),
new ArrayList<ForkEvent>(Arrays.asList(events)));
String params = this.gson.toJson(info, ForkEventsInfo.class);
String notification =
String.format(
"{ \"jsonrpc\": \"2.0\", \"method\": \"testEvents\", \"params\": %s }", params);
this.originalOut.println(notification);
this.originalOut.flush();
}
private ExecutorService executorService(
final ForkConfiguration config, final ObjectOutputStream os) {
if (config.isParallel()) {
private ExecutorService executorService(final boolean parallel) {
if (parallel) {
final int nbThreads = Runtime.getRuntime().availableProcessors();
logDebug(os, "Create a test executor with a thread pool of " + nbThreads + " threads.");
logDebug("Create a test executor with a thread pool of " + nbThreads + " threads.");
// more options later...
// TODO we might want to configure the blocking queue with size #proc
return Executors.newFixedThreadPool(nbThreads);
} else {
logDebug(os, "Create a single-thread test executor");
logDebug("Create a single-thread test executor");
return Executors.newSingleThreadExecutor();
}
}
private void runTests(
final ObjectInputStream is, final ObjectOutputStream os, ClassLoader classLoader)
throws Exception {
final ForkConfiguration config = (ForkConfiguration) is.readObject();
final ExecutorService executor = executorService(config, os);
final TaskDef[] tests = (TaskDef[]) is.readObject();
final int nFrameworks = is.readInt();
final Logger[] loggers = {remoteLogger(config.isAnsiCodesSupported(), os)};
private void runTests(TestInfo info, ClassLoader classLoader) throws Exception {
final ExecutorService executor = executorService(info.parallel);
final TaskDef[] tests = info.taskDefs.toArray(new TaskDef[] {});
final int nFrameworks = info.testRunners.size();
final Logger[] loggers = {remoteLogger(info.ansiCodesSupported)};
for (int i = 0; i < nFrameworks; i++) {
final String[] implClassNames = (String[]) is.readObject();
final String[] frameworkArgs = (String[]) is.readObject();
final String[] remoteFrameworkArgs = (String[]) is.readObject();
for (TestInfo.TestRunner testRunner : info.testRunners) {
final String[] frameworkArgs = testRunner.mainRunnerArgs.toArray(new String[] {});
final String[] remoteFrameworkArgs =
testRunner.mainRunnerRemoteArgs.toArray(new String[] {});
Framework framework = null;
for (final String implClassName : implClassNames) {
for (final String implClassName : testRunner.implClassNames) {
try {
final Object rawFramework =
Class.forName(implClassName).getDeclaredConstructor().newInstance();
classLoader.loadClass(implClassName).getDeclaredConstructor().newInstance();
if (rawFramework instanceof Framework) framework = (Framework) rawFramework;
else framework = new FrameworkWrapper((org.scalatools.testing.Framework) rawFramework);
break;
} catch (final ClassNotFoundException e) {
logDebug(os, "Framework implementation '" + implClassName + "' not present.");
logError("Framework implementation '" + implClassName + "' not present.");
}
}
@ -344,7 +366,6 @@ public final class ForkMain {
final Runner runner = framework.runner(frameworkArgs, remoteFrameworkArgs, classLoader);
final Task[] tasks = runner.tasks(filteredTests.toArray(new TaskDef[filteredTests.size()]));
logDebug(
os,
"Runner for "
+ framework.getClass().getName()
+ " produced "
@ -356,25 +377,20 @@ public final class ForkMain {
Thread callDoneOnShutdown = new Thread(() -> runner.done());
Runtime.getRuntime().addShutdownHook(callDoneOnShutdown);
runTestTasks(executor, tasks, loggers, os);
runTestTasks(executor, tasks, loggers);
runner.done();
Runtime.getRuntime().removeShutdownHook(callDoneOnShutdown);
}
write(os, ForkTags.Done);
is.readObject();
}
private void runTestTasks(
final ExecutorService executor,
final Task[] tasks,
final Logger[] loggers,
final ObjectOutputStream os) {
final ExecutorService executor, final Task[] tasks, final Logger[] loggers) {
if (tasks.length > 0) {
final List<Future<Task[]>> futureNestedTasks = new ArrayList<>();
for (final Task task : tasks) {
futureNestedTasks.add(runTest(executor, task, loggers, os));
futureNestedTasks.add(runTest(executor, task, loggers));
}
// Note: this could be optimized further, we could have a callback once a test finishes that
@ -385,18 +401,15 @@ public final class ForkMain {
try {
nestedTasks.addAll(Arrays.asList(futureNestedTask.get()));
} catch (final Exception e) {
logError(os, "Failed to execute task " + futureNestedTask);
logError("Failed to execute task " + futureNestedTask);
}
}
runTestTasks(executor, nestedTasks.toArray(new Task[nestedTasks.size()]), loggers, os);
runTestTasks(executor, nestedTasks.toArray(new Task[nestedTasks.size()]), loggers);
}
}
private Future<Task[]> runTest(
final ExecutorService executor,
final Task task,
final Logger[] loggers,
final ObjectOutputStream os) {
final ExecutorService executor, final Task task, final Logger[] loggers) {
return executor.submit(
() -> {
ForkEvent[] events;
@ -410,11 +423,10 @@ public final class ForkMain {
eventList.add(new ForkEvent(e));
}
};
logDebug(os, " Running " + taskDef);
logDebug(" Running " + taskDef);
nestedTasks = task.execute(handler, loggers);
if (nestedTasks.length > 0 || eventList.size() > 0)
logDebug(
os,
" Produced "
+ nestedTasks.length
+ " nested tasks and "
@ -426,7 +438,6 @@ public final class ForkMain {
events =
new ForkEvent[] {
testError(
os,
taskDef,
"Uncaught exception when running "
+ taskDef.fullyQualifiedName()
@ -435,7 +446,7 @@ public final class ForkMain {
t)
};
}
writeEvents(os, taskDef, events);
writeEvents(taskDef, events);
return nestedTasks;
});
}
@ -482,14 +493,10 @@ public final class ForkMain {
});
}
private ForkEvent testError(
final ObjectOutputStream os,
final TaskDef taskDef,
final String message,
final Throwable t) {
logError(os, message);
private ForkEvent testError(final TaskDef taskDef, final String message, final Throwable t) {
logError(message);
final ForkError fe = new ForkError(t);
write(os, fe);
writeError(fe);
return testEvent(
taskDef.fullyQualifiedName(),
taskDef.fingerprint(),

View File

@ -6,7 +6,7 @@
* Licensed under Apache License 2.0 (see LICENSE)
*/
package sbt;
package sbt.internal.worker1;
import sbt.testing.*;
@ -14,11 +14,11 @@ import sbt.testing.*;
* Adapts the old {@link org.scalatools.testing.Framework} interface into the new {@link
* sbt.testing.Framework}
*/
final class FrameworkWrapper implements Framework {
public final class FrameworkWrapper implements Framework {
private final org.scalatools.testing.Framework oldFramework;
FrameworkWrapper(final org.scalatools.testing.Framework oldFramework) {
public FrameworkWrapper(final org.scalatools.testing.Framework oldFramework) {
this.oldFramework = oldFramework;
}

View File

@ -0,0 +1,44 @@
/*
* sbt
* Copyright 2023, Scala center
* Copyright 2011 - 2022, Lightbend, Inc.
* Copyright 2008 - 2010, Mark Harrah
* Licensed under Apache License 2.0 (see LICENSE)
*/
package sbt.internal.worker1;
import java.io.Serializable;
import java.util.ArrayList;
public class RunInfo implements Serializable {
public static class JvmRunInfo implements Serializable {
public ArrayList<String> args;
public ArrayList<FilePath> classpath;
public String mainClass;
public boolean connectInput;
public JvmRunInfo(
ArrayList<String> args,
ArrayList<FilePath> classpath,
String mainClass,
boolean connectInput) {
this.args = args;
this.classpath = classpath;
this.mainClass = mainClass;
this.connectInput = connectInput;
}
}
public static class NativeRunInfo implements Serializable {}
public boolean jvm;
public JvmRunInfo jvmRunInfo;
public NativeRunInfo nativeRunInfo;
public RunInfo(boolean jvm, JvmRunInfo jvmRunInfo, NativeRunInfo nativeRunInfo) {
this.jvm = jvm;
this.jvmRunInfo = jvmRunInfo;
this.nativeRunInfo = nativeRunInfo;
}
}

View File

@ -0,0 +1,55 @@
/*
* sbt
* Copyright 2023, Scala center
* Copyright 2011 - 2022, Lightbend, Inc.
* Copyright 2008 - 2010, Mark Harrah
* Licensed under Apache License 2.0 (see LICENSE)
*/
package sbt.internal.worker1;
import java.io.Serializable;
import java.util.ArrayList;
import sbt.testing.TaskDef;
public class TestInfo implements Serializable {
public static class TestRunner implements Serializable {
public final ArrayList<String> implClassNames;
public final ArrayList<String> mainRunnerArgs;
public final ArrayList<String> mainRunnerRemoteArgs;
public TestRunner(
ArrayList<String> implClassNames,
ArrayList<String> mainRunnerArgs,
ArrayList<String> mainRunnerRemoteArgs) {
this.implClassNames = implClassNames;
this.mainRunnerArgs = mainRunnerArgs;
this.mainRunnerRemoteArgs = mainRunnerRemoteArgs;
}
}
public final boolean jvm;
public final RunInfo.JvmRunInfo jvmRunInfo;
public final RunInfo.NativeRunInfo nativeRunInfo;
public final boolean ansiCodesSupported;
public final boolean parallel;
public final ArrayList<TaskDef> taskDefs;
public final ArrayList<TestRunner> testRunners;
public TestInfo(
boolean jvm,
RunInfo.JvmRunInfo jvmRunInfo,
RunInfo.NativeRunInfo nativeRunInfo,
boolean ansiCodesSupported,
boolean parallel,
ArrayList<TaskDef> taskDefs,
ArrayList<TestRunner> testRunners) {
this.jvm = jvm;
this.jvmRunInfo = jvmRunInfo;
this.nativeRunInfo = nativeRunInfo;
this.ansiCodesSupported = ansiCodesSupported;
this.parallel = parallel;
this.taskDefs = taskDefs;
this.testRunners = testRunners;
}
}

View File

@ -0,0 +1,25 @@
/*
* sbt
* Copyright 2023, Scala center
* Copyright 2011 - 2022, Lightbend, Inc.
* Copyright 2008 - 2010, Mark Harrah
* Licensed under Apache License 2.0 (see LICENSE)
*/
package sbt.internal.worker1;
import java.io.Serializable;
import java.util.ArrayList;
import sbt.testing.TaskDef;
public class TestLogInfo implements Serializable {
public final long id;
public final ForkTags tag;
public final String message;
public TestLogInfo(long id, ForkTags tag, String message) {
this.id = id;
this.tag = tag;
this.message = message;
}
}
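For orientation, a sketch of the "testLog" notification that carries this payload (the id is hypothetical; the message is one of ForkTestMain's own debug logs, and log() above builds the envelope):

// One line on the worker's stdout, consumed by React in ForkTests:
val testLogLine =
  """{ "jsonrpc": "2.0", "method": "testLog", "params": { "id": 42, "tag": "Debug", "message": "Create a single-thread test executor" } }"""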

View File

@ -0,0 +1,13 @@
package sbt.internal.worker1;
import java.io.Serializable;
public final class WorkerError implements Serializable {
public final int code;
public final String message;
public WorkerError(int code, String message) {
this.code = code;
this.message = message;
}
}

View File

@ -0,0 +1,175 @@
/*
* sbt
* Copyright 2023, Scala center
* Copyright 2011 - 2022, Lightbend, Inc.
* Copyright 2008 - 2010, Mark Harrah
* Licensed under Apache License 2.0 (see LICENSE)
*/
package sbt.internal.worker1;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.gson.JsonPrimitive;
import com.google.gson.typeadapters.RuntimeTypeAdapterFactory;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.Scanner;
import sbt.testing.*;
/**
* WorkerMain communicates via stdin and stdout using JSON-RPC
* (https://www.jsonrpc.org/specification).
*/
public final class WorkerMain {
private PrintStream originalOut;
private InputStream originalIn;
private Scanner inScanner;
public static Gson mkGson() {
RuntimeTypeAdapterFactory<Fingerprint> fingerprintFac =
RuntimeTypeAdapterFactory.of(Fingerprint.class, "type");
fingerprintFac.registerSubtype(ForkTestMain.SubclassFingerscan.class, "SubclassFingerscan");
fingerprintFac.registerSubtype(ForkTestMain.AnnotatedFingerscan.class, "AnnotatedFingerscan");
RuntimeTypeAdapterFactory<Selector> selectorFac =
RuntimeTypeAdapterFactory.of(Selector.class, "type");
selectorFac.registerSubtype(SuiteSelector.class, "SuiteSelector");
selectorFac.registerSubtype(TestSelector.class, "TestSelector");
selectorFac.registerSubtype(NestedSuiteSelector.class, "NestedSuiteSelector");
selectorFac.registerSubtype(NestedTestSelector.class, "NestedTestSelector");
selectorFac.registerSubtype(TestWildcardSelector.class, "TestWildcardSelector");
return new GsonBuilder()
.registerTypeAdapterFactory(fingerprintFac)
.registerTypeAdapterFactory(selectorFac)
.create();
}
public static void main(final String[] args) throws Exception {
try {
if (args.length == 0) {
WorkerMain app = new WorkerMain();
app.consoleWork();
System.exit(0);
} else {
System.err.println("missing args");
System.exit(1);
}
} catch (Throwable e) {
e.printStackTrace();
System.exit(1);
}
}
WorkerMain() {
this.originalOut = System.out;
ByteArrayOutputStream baos = new ByteArrayOutputStream();
System.setOut(new PrintStream(baos));
this.originalIn = System.in;
this.inScanner = new Scanner(this.originalIn);
}
void consoleWork() throws Exception {
if (this.inScanner.hasNextLine()) {
String line = this.inScanner.nextLine();
process(line);
}
}
/** Processes a single request supplied as one JSON line. */
void process(String json) {
JsonElement elem = JsonParser.parseString(json);
JsonObject o = elem.getAsJsonObject();
if (!o.has("jsonrpc")) {
return;
}
Gson g = WorkerMain.mkGson();
long id = o.getAsJsonPrimitive("id").getAsLong();
try {
String method = o.getAsJsonPrimitive("method").getAsString();
JsonObject params = o.getAsJsonObject("params");
switch (method) {
case "run":
RunInfo info = g.fromJson(params, RunInfo.class);
run(info);
break;
case "test":
TestInfo testInfo = g.fromJson(params, TestInfo.class);
test(id, testInfo);
break;
}
String response = String.format("{ \"jsonrpc\": \"2.0\", \"result\": 0, \"id\": %d }", id);
this.originalOut.println(response);
this.originalOut.flush();
} catch (Throwable e) {
WorkerError err = new WorkerError(1, e.getMessage());
String errMessage = g.toJson(err, err.getClass());
String errJson =
String.format("{ \"jsonrpc\": \"2.0\", \"error\": %s, \"id\": %d }", errMessage, id);
this.originalOut.println(errJson);
this.originalOut.flush();
}
}
void run(RunInfo info) throws Exception {
if (info.jvm) {
if (info.jvmRunInfo == null) {
throw new RuntimeException("missing jvmRunInfo element");
}
RunInfo.JvmRunInfo jvmRunInfo = info.jvmRunInfo;
URLClassLoader cl = createClassLoader(jvmRunInfo, ClassLoader.getSystemClassLoader());
try {
Class<?> mainClass = cl.loadClass(jvmRunInfo.mainClass);
Method mainMethod = mainClass.getMethod("main", String[].class);
String[] mainArgs = jvmRunInfo.args.stream().toArray(String[]::new);
mainMethod.invoke(null, (Object) mainArgs);
} finally {
cl.close();
}
} else {
throw new RuntimeException("only jvm is supported");
}
}
void test(long id, TestInfo info) throws Exception {
if (info.jvm) {
RunInfo.JvmRunInfo jvmRunInfo = info.jvmRunInfo;
ClassLoader parent = new ForkTestMain().getClass().getClassLoader();
ClassLoader cl = createClassLoader(jvmRunInfo, parent);
try {
ForkTestMain.main(id, info, this.originalOut, cl);
} finally {
if (cl instanceof URLClassLoader) {
((URLClassLoader) cl).close();
}
}
} else {
throw new RuntimeException("only jvm is supported");
}
}
private URLClassLoader createClassLoader(RunInfo.JvmRunInfo info, ClassLoader parent) {
URL[] urls =
info.classpath
.stream()
.map(
filePath -> {
try {
return filePath.path.toURL();
} catch (MalformedURLException e) {
throw new RuntimeException(e);
}
})
.toArray(URL[]::new);
return new URLClassLoader(urls, parent);
}
}
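For orientation, a hedged sketch of the two reply shapes process() writes back on stdout (the id is hypothetical; the error message is one actually thrown by run/test above):

// Success: the request was handled; the result is currently always 0.
val success = """{ "jsonrpc": "2.0", "result": 0, "id": 1 }"""
// Failure: the WorkerError(code, message) is serialized into the "error" member.
val failure = """{ "jsonrpc": "2.0", "error": { "code": 1, "message": "only jvm is supported" }, "id": 1 }"""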

View File

@ -0,0 +1,19 @@
package sbt.internal.worker1
import sbt.io.IO
object WorkerTest extends verify.BasicTestSuite:
val main = WorkerMain()
test("process") {
val u0 = IO.classLocationPath(classOf[example.Hello]).toUri()
val u1 = IO.classLocationPath(classOf[scala.quoted.Quotes]).toUri()
val u2 = IO.classLocationPath(classOf[scala.collection.immutable.List[?]]).toUri()
val cp =
s"""[{ "path": "${u0}", "digest": "" }, { "path": "${u1}", "digest": "" }, { "path": "${u2}", "digest": "" }]"""
val runInfo =
s"""{ "jvm": true, "jvmRunInfo": { "args": ["hi"], "classpath": $cp, "mainClass": "example.Hello" } }"""
val json = s"""{ "jsonrpc": "2.0", "id": 1, "method": "run", "params": $runInfo }"""
main.process(json)
}
end WorkerTest