From 5a88bd2302cde72aaac2f1a9dc3868f246903e1e Mon Sep 17 00:00:00 2001 From: Bruno Bieth Date: Thu, 7 Nov 2013 10:39:58 +0100 Subject: [PATCH] Third draft to execute the forked tests in parallel. This feature is not activated by default. To enable it set `testForkedParallel` to `true`. The test-agent then executes the tests in a thread pool. For now it has a fixed size set to the number of available processors. The concurrent restrictions configuration should be used. --- .../src/main/scala/sbt/ForkTests.scala | 15 +- main/src/main/scala/sbt/Defaults.scala | 15 +- main/src/main/scala/sbt/Keys.scala | 1 + .../project/ForkParallelTest.scala | 19 ++ .../fork-parallel/src/test/scala/tests.scala | 53 +++ sbt/src/sbt-test/tests/fork-parallel/test | 7 + .../tests/fork/project/ForkTestsTest.scala | 2 +- .../src/main/java/sbt/ForkConfiguration.java | 21 ++ testing/agent/src/main/java/sbt/ForkMain.java | 307 ++++++++++-------- .../src/main/java/sbt/FrameworkWrapper.java | 5 +- .../src/main/scala/sbt/TestFramework.scala | 14 +- 11 files changed, 305 insertions(+), 154 deletions(-) create mode 100644 sbt/src/sbt-test/tests/fork-parallel/project/ForkParallelTest.scala create mode 100644 sbt/src/sbt-test/tests/fork-parallel/src/test/scala/tests.scala create mode 100644 sbt/src/sbt-test/tests/fork-parallel/test create mode 100644 testing/agent/src/main/java/sbt/ForkConfiguration.java diff --git a/main/actions/src/main/scala/sbt/ForkTests.scala b/main/actions/src/main/scala/sbt/ForkTests.scala index c6d318ee8..63b8da1d4 100755 --- a/main/actions/src/main/scala/sbt/ForkTests.scala +++ b/main/actions/src/main/scala/sbt/ForkTests.scala @@ -23,13 +23,13 @@ private[sbt] object ForkTests if(opts.tests.isEmpty) constant( TestOutput(TestResult.Passed, Map.empty[String, SuiteResult], Iterable.empty) ) else - mainTestTask(runners, opts, classpath, fork, log).tagw(config.tags: _*) + mainTestTask(runners, opts, classpath, fork, log, config.parallel).tagw(config.tags: _*) main.dependsOn( all(opts.setup) : _*) flatMap { results => all(opts.cleanup).join.map( _ => results) } } - private[this] def mainTestTask(runners: Map[TestFramework, Runner], opts: ProcessedOptions, classpath: Seq[File], fork: ForkOptions, log: Logger): Task[TestOutput] = + private[this] def mainTestTask(runners: Map[TestFramework, Runner], opts: ProcessedOptions, classpath: Seq[File], fork: ForkOptions, log: Logger, parallel: Boolean): Task[TestOutput] = std.TaskExtra.task { val server = new ServerSocket(0) @@ -41,7 +41,8 @@ private[sbt] object ForkTests object Acceptor extends Runnable { val resultsAcc = mutable.Map.empty[String, SuiteResult] lazy val result = TestOutput(overall(resultsAcc.values.map(_.result)), resultsAcc.toMap, Iterable.empty) - def run: Unit = { + + def run() { val socket = try { server.accept() @@ -58,21 +59,21 @@ private[sbt] object ForkTests val is = new ObjectInputStream(socket.getInputStream) try { - os.writeBoolean(log.ansiCodesSupported) + val config = new ForkConfiguration(log.ansiCodesSupported, parallel) + os.writeObject(config) val taskdefs = opts.tests.map(t => new TaskDef(t.name, forkFingerprint(t.fingerprint), t.explicitlySpecified, t.selectors)) os.writeObject(taskdefs.toArray) os.writeInt(runners.size) for ((testFramework, mainRunner) <- runners) { - val remoteArgs = mainRunner.remoteArgs() os.writeObject(testFramework.implClassNames.toArray) os.writeObject(mainRunner.args) - os.writeObject(remoteArgs) + os.writeObject(mainRunner.remoteArgs) } os.flush() - (new React(is, os, log, opts.testListeners, resultsAcc)).react() + new React(is, os, log, opts.testListeners, resultsAcc).react() } finally { is.close(); os.close(); socket.close() } diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index ef3867384..5a4b74a3a 100755 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -103,6 +103,7 @@ object Defaults extends BuildCommon outputStrategy :== None, exportJars :== false, fork :== false, + testForkedParallel :== false, javaOptions :== Nil, sbtPlugin :== false, crossPaths :== true, @@ -358,7 +359,7 @@ object Defaults extends BuildCommon definedTests <<= detectTests, definedTestNames <<= definedTests map ( _.map(_.name).distinct) storeAs definedTestNames triggeredBy compile, testFilter in testQuick <<= testQuickFilter, - executeTests <<= (streams in test, loadedTestFrameworks, testLoader, testGrouping in test, testExecution in test, fullClasspath in test, javaHome in test) flatMap allTestGroupsTask, + executeTests <<= (streams in test, loadedTestFrameworks, testLoader, testGrouping in test, testExecution in test, fullClasspath in test, javaHome in test, testForkedParallel) flatMap allTestGroupsTask, test := { implicit val display = Project.showContextKey(state.value) Tests.showResults(streams.value.log, executeTests.value, noTestsMessage(resolvedScoped.value)) @@ -468,7 +469,7 @@ object Defaults extends BuildCommon implicit val display = Project.showContextKey(state.value) val modifiedOpts = Tests.Filters(filter(selected)) +: Tests.Argument(frameworkOptions : _*) +: config.options val newConfig = config.copy(options = modifiedOpts) - val output = allTestGroupsTask(s, loadedTestFrameworks.value, testLoader.value, testGrouping.value, newConfig, fullClasspath.value, javaHome.value) + val output = allTestGroupsTask(s, loadedTestFrameworks.value, testLoader.value, testGrouping.value, newConfig, fullClasspath.value, javaHome.value, testForkedParallel.value) val processed = for(out <- output) yield Tests.showResults(s.log, out, noTestsMessage(resolvedScoped.value)) @@ -476,7 +477,7 @@ object Defaults extends BuildCommon } } - def createTestRunners(frameworks: Map[TestFramework,Framework], loader: ClassLoader, config: Tests.Execution) = { + def createTestRunners(frameworks: Map[TestFramework,Framework], loader: ClassLoader, config: Tests.Execution) : Map[TestFramework, Runner] = { import Tests.Argument val opts = config.options.toList frameworks.map { case (tf, f) => @@ -490,12 +491,18 @@ object Defaults extends BuildCommon } def allTestGroupsTask(s: TaskStreams, frameworks: Map[TestFramework,Framework], loader: ClassLoader, groups: Seq[Tests.Group], config: Tests.Execution, cp: Classpath, javaHome: Option[File]): Task[Tests.Output] = { + allTestGroupsTask(s,frameworks,loader, groups, config, cp, javaHome, forkedParallelExecution = false) + } + + def allTestGroupsTask(s: TaskStreams, frameworks: Map[TestFramework,Framework], loader: ClassLoader, groups: Seq[Tests.Group], config: Tests.Execution, cp: Classpath, javaHome: Option[File], forkedParallelExecution: Boolean): Task[Tests.Output] = { val runners = createTestRunners(frameworks, loader, config) val groupTasks = groups map { case Tests.Group(name, tests, runPolicy) => runPolicy match { case Tests.SubProcess(opts) => - ForkTests(runners, tests.toList, config, cp.files, opts, s.log) tag Tags.ForkedTestGroup + val forkedConfig = config.copy(parallel = config.parallel && forkedParallelExecution) + s.log.debug(s"Forking tests - parallelism = ${forkedConfig.parallel}") + ForkTests(runners, tests.toList, forkedConfig, cp.files, opts, s.log) tag Tags.ForkedTestGroup case Tests.InProcess => Tests(frameworks, loader, runners, tests, config, s.log) } diff --git a/main/src/main/scala/sbt/Keys.scala b/main/src/main/scala/sbt/Keys.scala index 7d2369351..99902d8fa 100644 --- a/main/src/main/scala/sbt/Keys.scala +++ b/main/src/main/scala/sbt/Keys.scala @@ -195,6 +195,7 @@ object Keys val testOptions = TaskKey[Seq[TestOption]]("test-options", "Options for running tests.", BPlusTask) val testFrameworks = SettingKey[Seq[TestFramework]]("test-frameworks", "Registered, although not necessarily present, test frameworks.", CTask) val testListeners = TaskKey[Seq[TestReportListener]]("test-listeners", "Defines test listeners.", DTask) + val testForkedParallel = SettingKey[Boolean]("test-forked-parallel", "Whether forked tests should be executed in parallel", CTask) val testExecution = TaskKey[Tests.Execution]("test-execution", "Settings controlling test execution", DTask) val testFilter = TaskKey[Seq[String] => Seq[String => Boolean]]("test-filter", "Filter controlling whether the test is executed", DTask) val testGrouping = TaskKey[Seq[Tests.Group]]("test-grouping", "Collects discovered tests into groups. Whether to fork and the options for forking are configurable on a per-group basis.", BMinusTask) diff --git a/sbt/src/sbt-test/tests/fork-parallel/project/ForkParallelTest.scala b/sbt/src/sbt-test/tests/fork-parallel/project/ForkParallelTest.scala new file mode 100644 index 000000000..c54748b65 --- /dev/null +++ b/sbt/src/sbt-test/tests/fork-parallel/project/ForkParallelTest.scala @@ -0,0 +1,19 @@ +import sbt._ +import Keys._ +import Tests._ +import Defaults._ + +object ForkParallelTest extends Build { + val check = taskKey[Unit]("Check that tests are executed in parallel") + + lazy val root = Project("root", file("."), settings = defaultSettings ++ Seq( + scalaVersion := "2.9.2", + libraryDependencies += "com.novocode" % "junit-interface" % "0.10" % "test", + fork in Test := true, + check := { + if( ! (file("max-concurrent-tests_3").exists() || file("max-concurrent-tests_4").exists() )) { + sys.error("Forked tests were not executed in parallel!") + } + } + )) +} \ No newline at end of file diff --git a/sbt/src/sbt-test/tests/fork-parallel/src/test/scala/tests.scala b/sbt/src/sbt-test/tests/fork-parallel/src/test/scala/tests.scala new file mode 100644 index 000000000..4ebd2ff25 --- /dev/null +++ b/sbt/src/sbt-test/tests/fork-parallel/src/test/scala/tests.scala @@ -0,0 +1,53 @@ + +import java.io.File +import java.util.concurrent.atomic.AtomicInteger +import org.junit.Test +import scala.annotation.tailrec + +object ParallelTest { + val nbConcurrentTests = new AtomicInteger(0) + val maxConcurrentTests = new AtomicInteger(0) + + private def updateMaxConcurrentTests(currentMax: Int, newMax: Int) : Boolean = { + if( maxConcurrentTests.compareAndSet(currentMax, newMax) ) { + val f = new File("max-concurrent-tests_" + newMax) + f.createNewFile + true + } else { + false + } + } + + @tailrec + def execute(f : => Unit) { + val nb = nbConcurrentTests.incrementAndGet() + val max = maxConcurrentTests.get() + if( nb <= max || updateMaxConcurrentTests(max, nb)) { + f + nbConcurrentTests.getAndDecrement + } else { + nbConcurrentTests.getAndDecrement + execute(f) + } + } +} + +class Test1 { + @Test + def slow() { ParallelTest.execute { Thread.sleep(1000) } } +} + +class Test2 { + @Test + def slow() { ParallelTest.execute { Thread.sleep(1000) } } +} + +class Test3 { + @Test + def slow() { ParallelTest.execute { Thread.sleep(1000) } } +} + +class Test4 { + @Test + def slow() { ParallelTest.execute { Thread.sleep(1000) } } +} \ No newline at end of file diff --git a/sbt/src/sbt-test/tests/fork-parallel/test b/sbt/src/sbt-test/tests/fork-parallel/test new file mode 100644 index 000000000..5226713d3 --- /dev/null +++ b/sbt/src/sbt-test/tests/fork-parallel/test @@ -0,0 +1,7 @@ +> test +-> check + +> clean +> set testForkedParallel := true +> test +> check \ No newline at end of file diff --git a/sbt/src/sbt-test/tests/fork/project/ForkTestsTest.scala b/sbt/src/sbt-test/tests/fork/project/ForkTestsTest.scala index f8ae9ec08..d8624f7e9 100755 --- a/sbt/src/sbt-test/tests/fork/project/ForkTestsTest.scala +++ b/sbt/src/sbt-test/tests/fork/project/ForkTestsTest.scala @@ -26,7 +26,7 @@ object ForkTestsTest extends Build { val (exist, absent) = files.partition(_.exists) exist.foreach(_.delete()) if(absent.nonEmpty) - error("Files were not created:\n\t" + absent.mkString("\n\t")) + sys.error("Files were not created:\n\t" + absent.mkString("\n\t")) }, concurrentRestrictions := Tags.limit(Tags.ForkedTestGroup, 2) :: Nil, libraryDependencies += "org.scalatest" %% "scalatest" % "1.8" % "test" diff --git a/testing/agent/src/main/java/sbt/ForkConfiguration.java b/testing/agent/src/main/java/sbt/ForkConfiguration.java new file mode 100644 index 000000000..397d1aaa9 --- /dev/null +++ b/testing/agent/src/main/java/sbt/ForkConfiguration.java @@ -0,0 +1,21 @@ +package sbt; + +import java.io.Serializable; + +public final class ForkConfiguration implements Serializable { + private boolean ansiCodesSupported; + private boolean parallel; + + public ForkConfiguration(boolean ansiCodesSupported, boolean parallel) { + this.ansiCodesSupported = ansiCodesSupported; + this.parallel = parallel; + } + + public boolean isAnsiCodesSupported() { + return ansiCodesSupported; + } + + public boolean isParallel() { + return parallel; + } +} diff --git a/testing/agent/src/main/java/sbt/ForkMain.java b/testing/agent/src/main/java/sbt/ForkMain.java index a494adbc4..a56783fcd 100755 --- a/testing/agent/src/main/java/sbt/ForkMain.java +++ b/testing/agent/src/main/java/sbt/ForkMain.java @@ -12,10 +12,15 @@ import java.io.Serializable; import java.net.Socket; import java.net.InetAddress; import java.util.ArrayList; -import java.util.List; import java.util.Arrays; +import java.util.List; +import java.util.concurrent.*; public class ForkMain { + + // serializables + // ----------------------------------------------------------------------------- + static class SubclassFingerscan implements SubclassFingerprint, Serializable { private boolean isModule; private String superclassName; @@ -29,6 +34,7 @@ public class ForkMain { public String superclassName() { return superclassName; } public boolean requireNoArgConstructor() { return requireNoArgConstructor; } } + static class AnnotatedFingerscan implements AnnotatedFingerprint, Serializable { private boolean isModule; private String annotationName; @@ -39,6 +45,54 @@ public class ForkMain { public boolean isModule() { return isModule; } public String annotationName() { return annotationName; } } + + static class ForkEvent implements Event, Serializable { + private String fullyQualifiedName; + private Fingerprint fingerprint; + private Selector selector; + private Status status; + private OptionalThrowable throwable; + private long duration; + + ForkEvent(Event e) { + fullyQualifiedName = e.fullyQualifiedName(); + Fingerprint rawFingerprint = e.fingerprint(); + + if (rawFingerprint instanceof SubclassFingerprint) + this.fingerprint = new SubclassFingerscan((SubclassFingerprint) rawFingerprint); + else + this.fingerprint = new AnnotatedFingerscan((AnnotatedFingerprint) rawFingerprint); + + selector = e.selector(); + checkSerializableSelector(selector); + status = e.status(); + OptionalThrowable originalThrowable = e.throwable(); + + if (originalThrowable.isDefined()) + this.throwable = new OptionalThrowable(new ForkError(originalThrowable.get())); + else + this.throwable = originalThrowable; + + this.duration = e.duration(); + } + + public String fullyQualifiedName() { return fullyQualifiedName; } + public Fingerprint fingerprint() { return fingerprint; } + public Selector selector() { return selector; } + public Status status() { return status; } + public OptionalThrowable throwable() { return throwable; } + public long duration() { return duration; } + + static void checkSerializableSelector(Selector selector) { + if (! (selector instanceof Serializable)) { + throw new UnsupportedOperationException("Selector implementation must be Serializable, but " + selector.getClass().getName() + " is not."); + } + } + } + + // ----------------------------------------------------------------------------- + + static class ForkError extends Exception { private String originalMessage; private ForkError cause; @@ -50,62 +104,50 @@ public class ForkMain { public String getMessage() { return originalMessage; } public Exception getCause() { return cause; } } - - static Selector forkSelector(Selector selector) { - if (selector instanceof Serializable) - return selector; - else - throw new UnsupportedOperationException("Selector implementation must be Serializable, but " + selector.getClass().getName() + " is not."); - } - - static class ForkEvent implements Event, Serializable { - private String fullyQualifiedName; - private Fingerprint fingerprint; - private Selector selector; - private Status status; - private OptionalThrowable throwable; - private long duration; - ForkEvent(Event e) { - fullyQualifiedName = e.fullyQualifiedName(); - Fingerprint rawFingerprint = e.fingerprint(); - if (rawFingerprint instanceof SubclassFingerprint) - this.fingerprint = new SubclassFingerscan((SubclassFingerprint) rawFingerprint); - else - this.fingerprint = new AnnotatedFingerscan((AnnotatedFingerprint) rawFingerprint); - selector = forkSelector(e.selector()); - status = e.status(); - OptionalThrowable originalThrowable = e.throwable(); - if (originalThrowable.isDefined()) - this.throwable = new OptionalThrowable(new ForkError(originalThrowable.get())); - else - this.throwable = originalThrowable; - this.duration = e.duration(); - } - public String fullyQualifiedName() { return fullyQualifiedName; } - public Fingerprint fingerprint() { return fingerprint; } - public Selector selector() { return selector; } - public Status status() { return status; } - public OptionalThrowable throwable() { return throwable; } - public long duration() { return duration; } - } + + + // main + // ---------------------------------------------------------------------------------------------------------------- + public static void main(String[] args) throws Exception { Socket socket = new Socket(InetAddress.getByName(null), Integer.valueOf(args[0])); final ObjectInputStream is = new ObjectInputStream(socket.getInputStream()); final ObjectOutputStream os = new ObjectOutputStream(socket.getOutputStream()); // Must flush the header that the constructor writes, otherwise the ObjectInputStream on the other end may block indefinitely os.flush(); + try { + new Run().run(is, os); + } finally { try { - new Run().run(is, os); - } finally { is.close(); os.close(); + } finally { + System.exit(0); } - } finally { - System.exit(0); } } + + // ---------------------------------------------------------------------------------------------------------------- + + private static class Run { + + void run(ObjectInputStream is, ObjectOutputStream os) throws Exception { + try { + runTests(is, os); + } catch (RunAborted e) { + internalError(e); + } catch (Throwable t) { + try { + logError(os, "Uncaught exception when running tests: " + t.toString()); + write(os, new ForkError(t)); + } catch (Throwable t2) { + internalError(t2); + } + } + } + boolean matches(Fingerprint f1, Fingerprint f2) { if (f1 instanceof SubclassFingerprint && f2 instanceof SubclassFingerprint) { final SubclassFingerprint sf1 = (SubclassFingerprint) f1; @@ -118,9 +160,11 @@ public class ForkMain { } return false; } + class RunAborted extends RuntimeException { RunAborted(Exception e) { super(e); } } + synchronized void write(ObjectOutputStream os, Object obj) { try { os.writeObject(obj); @@ -129,29 +173,50 @@ public class ForkMain { throw new RunAborted(e); } } - void logError(ObjectOutputStream os, String message) { - write(os, new Object[]{ForkTags.Error, message}); + + void log(ObjectOutputStream os, String message, ForkTags level) { + write(os, new Object[]{level, message}); } - void logDebug(ObjectOutputStream os, String message) { - write(os, new Object[]{ForkTags.Debug, message}); + + void logDebug(ObjectOutputStream os, String message) { log(os, message, ForkTags.Debug); } + void logInfo(ObjectOutputStream os, String message) { log(os, message, ForkTags.Info); } + void logWarn(ObjectOutputStream os, String message) { log(os, message, ForkTags.Warn); } + void logError(ObjectOutputStream os, String message) { log(os, message, ForkTags.Error); } + + Logger remoteLogger(final boolean ansiCodesSupported, final ObjectOutputStream os) { + return new Logger() { + public boolean ansiCodesSupported() { return ansiCodesSupported; } + public void error(String s) { logError(os, s); } + public void warn(String s) { logWarn(os, s); } + public void info(String s) { logInfo(os, s); } + public void debug(String s) { logDebug(os, s); } + public void trace(Throwable t) { write(os, new ForkError(t)); } + }; } + void writeEvents(ObjectOutputStream os, TaskDef taskDef, ForkEvent[] events) { write(os, new Object[]{taskDef.fullyQualifiedName(), events}); } + + ExecutorService executorService(ForkConfiguration config, ObjectOutputStream os) { + if(config.isParallel()) { + int nbThreads = Runtime.getRuntime().availableProcessors(); + logDebug(os, "Create a test executor with a thread pool of " + nbThreads + " threads."); + // more options later... + // TODO we might want to configure the blocking queue with size #proc + return Executors.newFixedThreadPool(nbThreads); + } else { + logDebug(os, "Create a single-thread test executor"); + return Executors.newSingleThreadExecutor(); + } + } + void runTests(ObjectInputStream is, final ObjectOutputStream os) throws Exception { - final boolean ansiCodesSupported = is.readBoolean(); + final ForkConfiguration config = (ForkConfiguration) is.readObject(); + ExecutorService executor = executorService(config, os); final TaskDef[] tests = (TaskDef[]) is.readObject(); int nFrameworks = is.readInt(); - Logger[] loggers = { - new Logger() { - public boolean ansiCodesSupported() { return ansiCodesSupported; } - public void error(String s) { logError(os, s); } - public void warn(String s) { write(os, new Object[]{ForkTags.Warn, s}); } - public void info(String s) { write(os, new Object[]{ForkTags.Info, s}); } - public void debug(String s) { write(os, new Object[]{ForkTags.Debug, s}); } - public void trace(Throwable t) { write(os, new ForkError(t)); } - } - }; + Logger[] loggers = { remoteLogger(config.isAnsiCodesSupported(), os) }; for (int i = 0; i < nFrameworks; i++) { final String[] implClassNames = (String[]) is.readObject(); @@ -186,89 +251,66 @@ public class ForkMain { final Runner runner = framework.runner(frameworkArgs, remoteFrameworkArgs, getClass().getClassLoader()); Task[] tasks = runner.tasks(filteredTests.toArray(new TaskDef[filteredTests.size()])); logDebug(os, "Runner for " + framework.getClass().getName() + " produced " + tasks.length + " initial tasks for " + filteredTests.size() + " tests."); - for (Task task : tasks) - runTestSafe(task, runner, loggers, os); + + runTestTasks(executor, tasks, loggers, os); + runner.done(); } write(os, ForkTags.Done); is.readObject(); } - class NestedTask { - private String parentName; - private Task task; - NestedTask(String parentName, Task task) { - this.parentName = parentName; - this.task = task; - } - public String getParentName() { - return parentName; - } - public Task getTask() { - return task; - } - } - void runTestSafe(Task task, Runner runner, Logger[] loggers, ObjectOutputStream os) { - TaskDef taskDef = task.taskDef(); - try { - List nestedTasks = new ArrayList(); - for (Task nt : runTest(taskDef, task, loggers, os)) - nestedTasks.add(new NestedTask(taskDef.fullyQualifiedName(), nt)); - while (true) { - List newNestedTasks = new ArrayList(); - int nestedTasksLength = nestedTasks.size(); - for (int i = 0; i < nestedTasksLength; i++) { - NestedTask nestedTask = nestedTasks.get(i); - String nestedParentName = nestedTask.getParentName() + "-" + i; - for (Task nt : runTest(nestedTask.getTask().taskDef(), nestedTask.getTask(), loggers, os)) { - newNestedTasks.add(new NestedTask(nestedParentName, nt)); - } - } - if (newNestedTasks.size() == 0) - break; - else { - nestedTasks = newNestedTasks; + + void runTestTasks(ExecutorService executor, Task[] tasks, Logger[] loggers, ObjectOutputStream os) { + if( tasks.length > 0 ) { + List> futureNestedTasks = new ArrayList>(); + for( Task task : tasks ) { + futureNestedTasks.add(runTest(executor, task, loggers, os)); + } + + // Note: this could be optimized further, we could have a callback once a test finishes that executes immediately the nested tasks + // At the moment, I'm especially interested in JUnit, which doesn't have nested tasks. + List nestedTasks = new ArrayList(); + for( Future futureNestedTask : futureNestedTasks ) { + try { + nestedTasks.addAll( Arrays.asList(futureNestedTask.get())); + } catch (Exception e) { + logError(os, "Failed to execute task " + futureNestedTask); } } - } catch (Throwable t) { - writeEvents(os, taskDef, new ForkEvent[] { testError(os, taskDef, "Uncaught exception when running " + taskDef.fullyQualifiedName() + ": " + t.toString(), t) }); + runTestTasks(executor, nestedTasks.toArray(new Task[nestedTasks.size()]), loggers, os); } } - Task[] runTest(TaskDef taskDef, Task task, Logger[] loggers, ObjectOutputStream os) { - ForkEvent[] events; - Task[] nestedTasks; - try { - final List eventList = new ArrayList(); - EventHandler handler = new EventHandler() { public void handle(Event e){ eventList.add(new ForkEvent(e)); } }; - logDebug(os, " Running " + taskDef); - nestedTasks = task.execute(handler, loggers); - if(nestedTasks.length > 0 || eventList.size() > 0) - logDebug(os, " Produced " + nestedTasks.length + " nested tasks and " + eventList.size() + " events."); - events = eventList.toArray(new ForkEvent[eventList.size()]); - } - catch (Throwable t) { - nestedTasks = new Task[0]; - events = new ForkEvent[] { testError(os, taskDef, "Uncaught exception when running " + taskDef.fullyQualifiedName() + ": " + t.toString(), t) }; - } - writeEvents(os, taskDef, events); - return nestedTasks; - } - void run(ObjectInputStream is, ObjectOutputStream os) throws Exception { - try { - runTests(is, os); - } catch (RunAborted e) { - internalError(e); - } catch (Throwable t) { - try { - logError(os, "Uncaught exception when running tests: " + t.toString()); - write(os, new ForkError(t)); - } catch (Throwable t2) { - internalError(t2); + + Future runTest(ExecutorService executor, final Task task, final Logger[] loggers, final ObjectOutputStream os) { + return executor.submit(new Callable() { + @Override + public Task[] call() { + ForkEvent[] events; + Task[] nestedTasks; + TaskDef taskDef = task.taskDef(); + try { + final List eventList = new ArrayList(); + EventHandler handler = new EventHandler() { public void handle(Event e){ eventList.add(new ForkEvent(e)); } }; + logDebug(os, " Running " + taskDef); + nestedTasks = task.execute(handler, loggers); + if(nestedTasks.length > 0 || eventList.size() > 0) + logDebug(os, " Produced " + nestedTasks.length + " nested tasks and " + eventList.size() + " events."); + events = eventList.toArray(new ForkEvent[eventList.size()]); + } + catch (Throwable t) { + nestedTasks = new Task[0]; + events = new ForkEvent[] { testError(os, taskDef, "Uncaught exception when running " + taskDef.fullyQualifiedName() + ": " + t.toString(), t) }; + } + writeEvents(os, taskDef, events); + return nestedTasks; } - } + }); } + void internalError(Throwable t) { System.err.println("Internal error when running tests: " + t.toString()); } + ForkEvent testEvent(final String fullyQualifiedName, final Fingerprint fingerprint, final Selector selector, final Status r, final ForkError err, final long duration) { final OptionalThrowable throwable; if (err == null) @@ -280,21 +322,18 @@ public class ForkMain { public Fingerprint fingerprint() { return fingerprint; } public Selector selector() { return selector; } public Status status() { return r; } - public OptionalThrowable throwable() { - return throwable; + public OptionalThrowable throwable() { + return throwable; } public long duration() { return duration; } }); } - ForkEvent testError(ObjectOutputStream os, TaskDef taskDef, String message) { - logError(os, message); - return testEvent(taskDef.fullyQualifiedName(), taskDef.fingerprint(), new SuiteSelector(), Status.Error, null, 0); - } + ForkEvent testError(ObjectOutputStream os, TaskDef taskDef, String message, Throwable t) { logError(os, message); - ForkError fe = new ForkError(t); + ForkError fe = new ForkError(t); write(os, fe); return testEvent(taskDef.fullyQualifiedName(), taskDef.fingerprint(), new SuiteSelector(), Status.Error, fe, 0); } diff --git a/testing/agent/src/main/java/sbt/FrameworkWrapper.java b/testing/agent/src/main/java/sbt/FrameworkWrapper.java index e7b75eaa0..1ad5d0e1f 100644 --- a/testing/agent/src/main/java/sbt/FrameworkWrapper.java +++ b/testing/agent/src/main/java/sbt/FrameworkWrapper.java @@ -1,8 +1,11 @@ package sbt; import sbt.testing.*; -import java.io.Serializable; +/** + * Adapts the old {@link org.scalatools.testing.Framework} interface into the new + * {@link sbt.testing.Framework} + */ final class FrameworkWrapper implements Framework { private org.scalatools.testing.Framework oldFramework; diff --git a/testing/src/main/scala/sbt/TestFramework.scala b/testing/src/main/scala/sbt/TestFramework.scala index 2307ef7ae..47b1d69e3 100644 --- a/testing/src/main/scala/sbt/TestFramework.scala +++ b/testing/src/main/scala/sbt/TestFramework.scala @@ -65,15 +65,15 @@ final class TestDefinition(val name: String, val fingerprint: Fingerprint, val e final class TestRunner(delegate: Runner, listeners: Seq[TestReportListener], log: Logger) { - final def tasks(testDefs: Set[TestDefinition]): Array[TestTask] = - delegate.tasks(testDefs.map(df => new TaskDef(df.name, df.fingerprint, df.explicitlySpecified, df.selectors)).toArray) + final def tasks(testDefs: Set[TestDefinition]): Array[TestTask] = + delegate.tasks(testDefs.map(df => new TaskDef(df.name, df.fingerprint, df.explicitlySpecified, df.selectors)).toArray) final def run(taskDef: TaskDef, testTask: TestTask): (SuiteResult, Seq[TestTask]) = { - val testDefinition = new TestDefinition(taskDef.fullyQualifiedName, taskDef.fingerprint, taskDef.explicitlySpecified, taskDef.selectors) + val testDefinition = new TestDefinition(taskDef.fullyQualifiedName, taskDef.fingerprint, taskDef.explicitlySpecified, taskDef.selectors) log.debug("Running " + taskDef) val name = testDefinition.name - + def runTest() = { // here we get the results! here is where we'd pass in the event listener @@ -167,7 +167,7 @@ object TestFramework { import scala.collection.mutable.{HashMap, HashSet, Set} val map = new HashMap[Framework, Set[TestDefinition]] - def assignTest(test: TestDefinition) + def assignTest(test: TestDefinition) { def isTestForFramework(framework: Framework) = getFingerprints(framework).exists {t => matches(t, test.fingerprint) } for(framework <- frameworks.find(isTestForFramework)) @@ -192,8 +192,8 @@ object TestFramework val runner = runners(framework) val testTasks = withContextLoader(loader) { runner.tasks(testDefinitions) } for (testTask <- testTasks) yield { - val taskDef = testTask.taskDef - (taskDef.fullyQualifiedName, createTestFunction(loader, taskDef, runner, testTask)) + val taskDef = testTask.taskDef + (taskDef.fullyQualifiedName, createTestFunction(loader, taskDef, runner, testTask)) } }