diff --git a/main/Defaults.scala b/main/Defaults.scala index 7d835ddd2..185aa3c9c 100755 --- a/main/Defaults.scala +++ b/main/Defaults.scala @@ -117,10 +117,10 @@ object Defaults extends BuildCommon excludeFilter :== (".*" - ".") || HiddenFileFilter, pomIncludeRepository :== Classpaths.defaultRepositoryFilter )) - def defaultTestTasks(key: Scoped): Seq[Setting[_]] = Seq( - tags in key := Seq(Tags.Test -> 1), - logBuffered in key := true - ) + def defaultTestTasks(key: Scoped): Seq[Setting[_]] = inTask(key)(Seq( + tags := Seq(Tags.Test -> 1), + logBuffered := true + )) def projectCore: Seq[Setting[_]] = Seq( name <<= thisProject(_.id), logManager <<= extraLoggers(LogManager.defaults), @@ -287,12 +287,24 @@ object Defaults extends BuildCommon testOptions in GlobalScope :== Nil, testFilter in testOnly :== (selectedFilter _), testFilter in testQuick <<= testQuickFilter, - executeTests <<= (streams in test, loadedTestFrameworks, testExecution in test, testLoader, definedTests, resolvedScoped, state) flatMap { - (s, frameworkMap, config, loader, discovered, scoped, st) => - implicit val display = Project.showContextKey(st) - Tests(frameworkMap, loader, discovered, config, noTestsMessage(ScopedKey(scoped.scope, test.key)), s.log) + executeTests <<= (streams in test, loadedTestFrameworks, testLoader, testGrouping in test, testExecution in test, fullClasspath in test, javaHome in test) flatMap { + (s, frameworkMap, loader, groups, config, cp, javaHome) => + val tasks = groups map { + case Tests.Group(name, tests, runPolicy) => + runPolicy match { + case Tests.SubProcess(javaOpts) => + ForkTests(frameworkMap.keys.toSeq, tests.toList, config, cp.files, javaHome, javaOpts, s.log) tag Tags.ForkedTestGroup + case Tests.InProcess => + Tests(frameworkMap, loader, tests, config, s.log) + } + } + Tests.foldTasks(tasks) + }, + test <<= (executeTests, streams, resolvedScoped, state) map { + (results, s, scoped, st) => + implicit val display = Project.showContextKey(st) + Tests.showResults(s.log, results, noTestsMessage(scoped)) }, - test <<= (executeTests, streams) map { (results, s) => Tests.showResults(s.log, results) }, testOnly <<= inputTests(testOnly), testQuick <<= inputTests(testQuick) ) @@ -306,7 +318,8 @@ object Defaults extends BuildCommon TestLogger(s.log, testLogger(sm, test in sco.scope), buff) +: new TestStatusReporter(succeededFile(dir)) +: ls }, testOptions <<= (testOptions in TaskGlobal, testListeners) map { (options, ls) => Tests.Listeners(ls) +: options }, - testExecution <<= testExecutionTask(key) + testExecution <<= testExecutionTask(key), + testGrouping <<= testGrouping or singleTestGroup(key) ) ) def testLogger(manager: Streams, baseKey: Scoped)(tdef: TestDefinition): Logger = { @@ -321,9 +334,16 @@ object Defaults extends BuildCommon val mod = tdef.fingerprint match { case f: SubclassFingerprint => f.isModule; case f: AnnotatedFingerprint => f.isModule; case _ => false } extra.put(name.key, tdef.name).put(isModule, mod) } + def singleTestGroup(key: Scoped): Initialize[Task[Seq[Tests.Group]]] = + ((definedTests in key, fork in key, javaOptions in key) map { + (tests, fork, javaOpts) => Seq(new Tests.Group("", tests, if (fork) Tests.SubProcess(javaOpts) else Tests.InProcess)) + }) def testExecutionTask(task: Scoped): Initialize[Task[Tests.Execution]] = - (testOptions in task, parallelExecution in task, tags in task) map { (opts, par, ts) => new Tests.Execution(opts, par, ts) } + (testOptions in task, parallelExecution in task, fork in task, tags in task) map { + (opts, par, fork, ts) => + new Tests.Execution(opts, par, ts) + } def testQuickFilter: Initialize[Task[Seq[String] => String => Boolean]] = (fullClasspath in test, cacheDirectory) map { @@ -355,18 +375,24 @@ object Defaults extends BuildCommon def inputTests(key: InputKey[_]): Initialize[InputTask[Unit]] = InputTask( loadForParser(definedTestNames)( (s, i) => testOnlyParser(s, i getOrElse Nil) ) ) { result => - (streams, loadedTestFrameworks, testFilter in key, testExecution in key, testLoader, definedTests, resolvedScoped, result, state) flatMap { - case (s, frameworks, filter, config, loader, discovered, scoped, (tests, frameworkOptions), st) => - val modifiedOpts = Tests.Filter(filter(tests)) +: Tests.Argument(frameworkOptions : _*) +: config.options - val newConfig = new Tests.Execution(modifiedOpts, config.parallel, config.tags) + (streams, loadedTestFrameworks, testFilter in key, testGrouping in key, testExecution in key, testLoader, resolvedScoped, result, fullClasspath in key, javaHome in key, state) flatMap { + case (s, frameworks, filter, groups, config, loader, scoped, (selected, frameworkOptions), cp, javaHome, st) => implicit val display = Project.showContextKey(st) - Tests(frameworks, loader, discovered, newConfig, noTestsMessage(scoped), s.log) map { results => - Tests.showResults(s.log, results) - } + val tasks = groups map { + case Tests.Group(name, tests, runPolicy) => + val modifiedOpts = Tests.Filter(filter(selected)) +: Tests.Argument(frameworkOptions : _*) +: config.options + val newConfig = config.copy(options = modifiedOpts) + runPolicy match { + case Tests.SubProcess(javaOpts) => + ForkTests(frameworks.keys.toSeq, tests.toList, newConfig, cp.files, javaHome, javaOpts, s.log) tag Tags.ForkedTestGroup + case Tests.InProcess => + Tests(frameworks, loader, tests, newConfig, s.log) + } + } + Tests.foldTasks(tasks) map (Tests.showResults(s.log, _, noTestsMessage(scoped))) } } - def selectedFilter(args: Seq[String]): String => Boolean = { val filters = args map GlobFilter.apply @@ -377,7 +403,7 @@ object Defaults extends BuildCommon } def defaultRestrictions: Initialize[Seq[Tags.Rule]] = parallelExecution { par => val max = EvaluateTask.SystemProcessors - Tags.limitAll(if(par) max else 1) :: Nil + Tags.limitAll(if(par) max else 1) :: Tags.limit(Tags.ForkedTestGroup, 1) :: Nil } lazy val packageBase: Seq[Setting[_]] = Seq( @@ -504,12 +530,11 @@ object Defaults extends BuildCommon def runnerTask = runner <<= runnerInit def runnerInit: Initialize[Task[ScalaRun]] = (taskTemporaryDirectory, scalaInstance, baseDirectory, javaOptions, outputStrategy, fork, javaHome, trapExit, connectInput) map { - (tmp, si, base, options, strategy, forkRun, javaHomeDir, trap, connectIn) => - if(forkRun) { - new ForkRun( ForkOptions(scalaJars = si.jars, javaHome = javaHomeDir, connectInput = connectIn, outputStrategy = strategy, - runJVMOptions = options, workingDirectory = Some(base)) ) - } else - new Run(si, trap, tmp) + (tmp, si, base, options, strategy, forkRun, javaHomeDir, trap, connectIn) => + if(forkRun) { + new ForkRun( ForkOptions(scalaJars = si.jars, javaHome = javaHomeDir, connectInput = connectIn, outputStrategy = strategy, runJVMOptions = options, workingDirectory = Some(base)) ) + } else + new Run(si, trap, tmp) } @deprecated("Use `docTaskSettings` instead", "0.12.0") diff --git a/main/Keys.scala b/main/Keys.scala index 52a79757c..1e3fc5374 100644 --- a/main/Keys.scala +++ b/main/Keys.scala @@ -200,6 +200,7 @@ object Keys val testListeners = TaskKey[Seq[TestReportListener]]("test-listeners", "Defines test listeners.", DTask) val testExecution = TaskKey[Tests.Execution]("test-execution", "Settings controlling test execution", DTask) val testFilter = TaskKey[Seq[String] => String => Boolean]("test-filter", "Filter controlling whether the test is executed", DTask) + val testGrouping = TaskKey[Seq[Tests.Group]]("test-grouping", "Collects discovered tests into groups. Whether to fork and the options for forking are configurable on a per-group basis.", BMinusTask) val isModule = AttributeKey[Boolean]("is-module", "True if the target is a module.", DSetting) // Classpath/Dependency Management Keys diff --git a/main/Tags.scala b/main/Tags.scala index 346f09275..9d32dde25 100644 --- a/main/Tags.scala +++ b/main/Tags.scala @@ -19,6 +19,8 @@ object Tags val Network = Tag("network") val Disk = Tag("disk") + val ForkedTestGroup = Tag("forked-test-group") + /** Describes a restriction on concurrently executing tasks. * A Rule is constructed using one of the Tags.limit* methods. */ sealed trait Rule { @@ -55,4 +57,4 @@ object Tags def limitUntagged(max: Int): Rule = limit(Untagged, max) def limit(tag: Tag, max: Int): Rule = new Single(tag, max) def limitSum(max: Int, tags: Tag*): Rule = new Sum(tags, max) -} \ No newline at end of file +} diff --git a/main/actions/ForkTests.scala b/main/actions/ForkTests.scala new file mode 100755 index 000000000..cf29f4e34 --- /dev/null +++ b/main/actions/ForkTests.scala @@ -0,0 +1,102 @@ +/* sbt -- Simple Build Tool + * Copyright 2012 Eugene Vigdorchik + */ +package sbt + +import org.scalatools.testing._ +import java.net.ServerSocket +import java.io._ +import Tests._ +import ForkMain._ + +private[sbt] object ForkTests { + def apply(frameworks: Seq[TestFramework], tests: List[TestDefinition], config: Execution, classpath: Seq[File], javaHome: Option[File], javaOpts: Seq[String], log: Logger): Task[Output] = { + val opts = config.options.toList + val listeners = opts flatMap { + case Listeners(ls) => ls + case _ => Nil + } + val testListeners = listeners flatMap { + case tl: TestsListener => Some(tl) + case _ => None + } + val filters = opts flatMap { + case Filter(f) => Some(f) + case _ => None + } + val argMap = frameworks.map { + f => f.implClassName -> opts.flatMap { + case Argument(None, args) => args + case Argument(Some(`f`), args) => args + case _ => Nil + } + }.toMap + + std.TaskExtra.task { + val server = new ServerSocket(0) + object Acceptor extends Runnable { + val results = collection.mutable.Map.empty[String, TestResult.Value] + def output = (overall(results.values), results.toMap) + def run = { + val socket = server.accept() + val os = new ObjectOutputStream(socket.getOutputStream) + val is = new ObjectInputStream(socket.getInputStream) + + import Tags._ + @annotation.tailrec def react: Unit = is.readObject match { + case `Done` => os.writeObject(Done); + case Array(`Error`, s: String) => log.error(s); react + case Array(`Warn`, s: String) => log.warn(s); react + case Array(`Info`, s: String) => log.info(s); react + case Array(`Debug`, s: String) => log.debug(s); react + case t: Throwable => log.trace(t); react + case tEvents: Array[Event] => + for (first <- tEvents.headOption) listeners.foreach(_ startGroup first.testName) + val event = TestEvent(tEvents) + listeners.foreach(_ testEvent event) + for (first <- tEvents.headOption) { + val result = event.result getOrElse TestResult.Passed + results += first.testName -> result + listeners.foreach(_ endGroup (first.testName, result)) + } + react + } + + try { + os.writeBoolean(log.ansiCodesSupported) + + val testsFiltered = tests.filter(test => filters.forall(_(test.name))).map{ + t => new ForkTestDefinition(t.name, t.fingerprint) + }.toArray + os.writeObject(testsFiltered) + + os.writeInt(frameworks.size) + for ((clazz, args) <- argMap) { + os.writeObject(clazz) + os.writeObject(args.toArray) + } + + react + } finally { + is.close(); os.close(); socket.close() + } + } + } + + try { + testListeners.foreach(_.doInit()) + new Thread(Acceptor).start() + + val fullCp = classpath ++: Seq(IO.classLocationFile[ForkMain], IO.classLocationFile[Framework]) + val options = javaOpts ++: Seq("-classpath", fullCp mkString File.pathSeparator, classOf[ForkMain].getCanonicalName, server.getLocalPort.toString) + val ec = Fork.java(javaHome, options, LoggedOutput(log)) + if (ec != 0) log.error("Running java with options " + options.mkString(" ") + " failed with exit code " + ec) + } finally { + server.close() + } + val result = Acceptor.output + testListeners.foreach(_.doComplete(result._1)) + result + } tagw (config.tags: _*) + } +} diff --git a/main/actions/Tests.scala b/main/actions/Tests.scala index c6e19fec0..4487f7139 100644 --- a/main/actions/Tests.scala +++ b/main/actions/Tests.scala @@ -13,7 +13,6 @@ package sbt import org.scalatools.testing.{AnnotatedFingerprint, Fingerprint, Framework, SubclassFingerprint} - import collection.mutable import java.io.File sealed trait TestOption @@ -40,14 +39,11 @@ object Tests // None means apply to all, Some(tf) means apply to a particular framework only. final case class Argument(framework: Option[TestFramework], args: List[String]) extends TestOption - final class Execution(val options: Seq[TestOption], val parallel: Boolean, val tags: Seq[(Tag, Int)]) + final case class Execution(options: Seq[TestOption], parallel: Boolean, tags: Seq[(Tag, Int)]) - def apply(frameworks: Map[TestFramework, Framework], testLoader: ClassLoader, discovered: Seq[TestDefinition], options: Seq[TestOption], parallel: Boolean, noTestsMessage: => String, log: Logger): Task[Output] = - apply(frameworks, testLoader, discovered, new Execution(options, parallel, Nil), noTestsMessage, log) - - def apply(frameworks: Map[TestFramework, Framework], testLoader: ClassLoader, discovered: Seq[TestDefinition], config: Execution, noTestsMessage: => String, log: Logger): Task[Output] = + def apply(frameworks: Map[TestFramework, Framework], testLoader: ClassLoader, discovered: Seq[TestDefinition], config: Execution, log: Logger): Task[Output] = { - import mutable.{HashSet, ListBuffer, Map, Set} + import collection.mutable.{HashSet, ListBuffer, Map, Set} val testFilters = new ListBuffer[String => Boolean] val excludeTestsSet = new HashSet[String] val setup, cleanup = new ListBuffer[ClassLoader => Unit] @@ -93,10 +89,10 @@ object Tests def includeTest(test: TestDefinition) = !excludeTestsSet.contains(test.name) && testFilters.forall(filter => filter(test.name)) val tests = discovered.filter(includeTest).toSet.toSeq val arguments = testArgsByFramework.map { case (k,v) => (k, v.toList) } toMap; - testTask(frameworks.values.toSeq, testLoader, tests, noTestsMessage, setup.readOnly, cleanup.readOnly, log, testListeners.readOnly, arguments, config) + testTask(frameworks.values.toSeq, testLoader, tests, setup.readOnly, cleanup.readOnly, log, testListeners.readOnly, arguments, config) } - def testTask(frameworks: Seq[Framework], loader: ClassLoader, tests: Seq[TestDefinition], noTestsMessage: => String, + def testTask(frameworks: Seq[Framework], loader: ClassLoader, tests: Seq[TestDefinition], userSetup: Iterable[ClassLoader => Unit], userCleanup: Iterable[ClassLoader => Unit], log: Logger, testListeners: Seq[TestReportListener], arguments: Map[Framework, Seq[String]], config: Execution): Task[Output] = { @@ -104,7 +100,7 @@ object Tests def partApp(actions: Iterable[ClassLoader => Unit]) = actions.toSeq map {a => () => a(loader) } val (frameworkSetup, runnables, frameworkCleanup) = - TestFramework.testTasks(frameworks, loader, tests, noTestsMessage, log, testListeners, arguments) + TestFramework.testTasks(frameworks, loader, tests, log, testListeners, arguments) val setupTasks = fj(partApp(userSetup) :+ frameworkSetup) val mainTasks = @@ -126,6 +122,10 @@ object Tests def processResults(results: Iterable[(String, TestResult.Value)]): (TestResult.Value, Map[String, TestResult.Value]) = (overall(results.map(_._2)), results.toMap) + def foldTasks(results: Seq[Task[Output]]): Task[Output] = + reduced(results.toIndexedSeq, { + case ((v1, m1), (v2, m2)) => (if (v1.id < v2.id) v2 else v1, m1 ++ m2) + }) def overall(results: Iterable[TestResult.Value]): TestResult.Value = (TestResult.Passed /: results) { (acc, result) => if(acc.id < result.id) result else acc } def discover(frameworks: Seq[Framework], analysis: Analysis, log: Logger): (Seq[TestDefinition], Set[String]) = @@ -153,28 +153,39 @@ object Tests (tests, mains.toSet) } - def showResults(log: Logger, results: (TestResult.Value, Map[String, TestResult.Value])): Unit = + def showResults(log: Logger, results: (TestResult.Value, Map[String, TestResult.Value]), noTestsMessage: =>String): Unit = { + if (results._2.isEmpty) + log.info(noTestsMessage) + else { import TestResult.{Error, Failed, Passed} - def select(Tpe: TestResult.Value) = results._2 collect { case (name, Tpe) => name } + def select(Tpe: TestResult.Value) = results._2 collect { case (name, Tpe) => name } - val failures = select(Failed) - val errors = select(Error) - val passed = select(Passed) + val failures = select(Failed) + val errors = select(Error) + val passed = select(Passed) - def show(label: String, level: Level.Value, tests: Iterable[String]): Unit = - if(!tests.isEmpty) - { - log.log(level, label) - log.log(level, tests.mkString("\t", "\n\t", "")) - } + def show(label: String, level: Level.Value, tests: Iterable[String]): Unit = + if(!tests.isEmpty) + { + log.log(level, label) + log.log(level, tests.mkString("\t", "\n\t", "")) + } - show("Passed tests:", Level.Debug, passed ) - show("Failed tests:", Level.Error, failures) - show("Error during tests:", Level.Error, errors) + show("Passed tests:", Level.Debug, passed ) + show("Failed tests:", Level.Error, failures) + show("Error during tests:", Level.Error, errors) - if(!failures.isEmpty || !errors.isEmpty) - error("Tests unsuccessful") + if(!failures.isEmpty || !errors.isEmpty) + error("Tests unsuccessful") + } } -} \ No newline at end of file + + sealed trait TestRunPolicy + case object InProcess extends TestRunPolicy + final case class SubProcess(javaOptions: Seq[String]) extends TestRunPolicy + + final case class Group(name: String, tests: Seq[TestDefinition], runPolicy: TestRunPolicy) +} + diff --git a/project/Sbt.scala b/project/Sbt.scala index 3cee103b0..a428a09e8 100644 --- a/project/Sbt.scala +++ b/project/Sbt.scala @@ -72,8 +72,12 @@ object Sbt extends Build // Apache Ivy integration lazy val ivySub = baseProject(file("ivy"), "Ivy") dependsOn(interfaceSub, launchInterfaceSub, logSub % "compile;test->test", ioSub % "compile;test->test", launchSub % "test->test") settings(ivy, jsch, httpclient) - // Runner for uniform test interface - lazy val testingSub = baseProject(file("testing"), "Testing") dependsOn(ioSub, classpathSub, logSub) settings(libraryDependencies += "org.scala-tools.testing" % "test-interface" % "0.5") + // Runner for uniform test interface + lazy val testingSub = baseProject(file("testing"), "Testing") dependsOn(ioSub, classpathSub, logSub, launchInterfaceSub, testAgentSub) settings(libraryDependencies += "org.scala-tools.testing" % "test-interface" % "0.5") + // Testing agent for running tests in a separate process. + lazy val testAgentSub = project(file("testing/agent"), "Test Agent") settings( + libraryDependencies += "org.scala-tools.testing" % "test-interface" % "0.5" + ) // Basic task engine lazy val taskSub = testedBaseProject(tasksPath, "Tasks") dependsOn(controlSub, collectionSub) diff --git a/sbt/src/sbt-test/tests/fork/project/ForkTestsTest.scala b/sbt/src/sbt-test/tests/fork/project/ForkTestsTest.scala new file mode 100755 index 000000000..4cfa42fef --- /dev/null +++ b/sbt/src/sbt-test/tests/fork/project/ForkTestsTest.scala @@ -0,0 +1,36 @@ +import sbt._ +import Keys._ +import Tests._ +import Defaults._ + +object ForkTestsTest extends Build { + val totalFiles = 9 + val groupSize = 3 + + val check = TaskKey[Unit]("check", "Check all files were created and remove them.") + + def groupId(idx: Int) = "group_" + (idx + 1) + def groupPrefix(idx: Int) = groupId(idx) + "_file_" + + lazy val root = Project("root", file("."), settings = defaultSettings ++ Seq( + testGrouping <<= definedTests in Test map { tests => + assert(tests.size == 1) + val groups = Stream const tests(0) take totalFiles grouped groupSize + for ((ts, idx) <- groups.toSeq.zipWithIndex) yield { + new Group(groupId(idx), ts, SubProcess(Seq("-Dgroup.prefix=" + groupPrefix(idx), "-Dgroup.size=" + ts.size))) + } + }, + check := { + for (i <- 0 until totalFiles/groupSize) + for (j <- 1 to groupSize) { + val f = file(groupPrefix(i) + j) + if (!f.exists) + error("File " + f.getName + " was not created.") + else + f.delete() + } + }, + concurrentRestrictions := Tags.limit(Tags.ForkedTestGroup, 2) :: Nil, + libraryDependencies += "org.scalatest" % "scalatest_2.9.0" % "1.6.1" % "test" + )) +} diff --git a/sbt/src/sbt-test/tests/fork/src/test/scala/Ensemble.scala b/sbt/src/sbt-test/tests/fork/src/test/scala/Ensemble.scala new file mode 100755 index 000000000..fc009691f --- /dev/null +++ b/sbt/src/sbt-test/tests/fork/src/test/scala/Ensemble.scala @@ -0,0 +1,20 @@ +import org.scalatest.FlatSpec +import org.scalatest.matchers.MustMatchers +import java.io.File + +class Ensemble extends FlatSpec with MustMatchers { + val prefix = System.getProperty("group.prefix") + val countTo = System.getProperty("group.size").toInt + + "an ensemble" must "create all files" in { + @annotation.tailrec + def step(i: Int): Unit = { + val f = new File(prefix + i) + if (!f.createNewFile) + step(if (f.exists) i + 1 else i) + else + i must be <= (countTo) + } + step(1) + } +} diff --git a/sbt/src/sbt-test/tests/fork/test b/sbt/src/sbt-test/tests/fork/test new file mode 100755 index 000000000..de5041ba0 --- /dev/null +++ b/sbt/src/sbt-test/tests/fork/test @@ -0,0 +1,2 @@ +> test +> check \ No newline at end of file diff --git a/testing/TestFramework.scala b/testing/TestFramework.scala index 8c3dd6d1a..2261f497e 100644 --- a/testing/TestFramework.scala +++ b/testing/TestFramework.scala @@ -23,14 +23,13 @@ object TestFrameworks val JUnit = new TestFramework("com.novocode.junit.JUnitFramework") } -class TestFramework(val implClassName: String) +case class TestFramework(val implClassName: String) { def create(loader: ClassLoader, log: Logger): Option[Framework] = { try { Some(Class.forName(implClassName, true, loader).newInstance.asInstanceOf[Framework]) } catch { case e: ClassNotFoundException => log.debug("Framework implementation '" + implClassName + "' not present."); None } } - override def toString = "TestFramework(" + implClassName + ")" } final class TestDefinition(val name: String, val fingerprint: Fingerprint) { @@ -130,7 +129,6 @@ object TestFramework def testTasks(frameworks: Seq[Framework], testLoader: ClassLoader, tests: Seq[TestDefinition], - noTestsMessage: => String, log: Logger, listeners: Seq[TestReportListener], testArgsByFramework: Map[Framework, Seq[String]]): @@ -139,7 +137,7 @@ object TestFramework val arguments = testArgsByFramework withDefaultValue Nil val mappedTests = testMap(frameworks, tests, arguments) if(mappedTests.isEmpty) - (() => (), Nil, _ => () => log.info(noTestsMessage) ) + (() => (), Nil, _ => () => () ) else createTestTasks(testLoader, mappedTests, log, listeners) } diff --git a/testing/TestReportListener.scala b/testing/TestReportListener.scala index 542648cd8..d20dbbb76 100644 --- a/testing/TestReportListener.scala +++ b/testing/TestReportListener.scala @@ -23,7 +23,7 @@ trait TestReportListener trait TestsListener extends TestReportListener { /** called once, at beginning. */ - def doInit + def doInit() /** called once, at end. */ def doComplete(finalResult: TestResult.Value) } diff --git a/testing/agent/src/main/java/sbt/ForkMain.java b/testing/agent/src/main/java/sbt/ForkMain.java new file mode 100755 index 000000000..416592b46 --- /dev/null +++ b/testing/agent/src/main/java/sbt/ForkMain.java @@ -0,0 +1,151 @@ +/* sbt -- Simple Build Tool + * Copyright 2012 Eugene Vigdorchik + */ +package sbt; + +import org.scalatools.testing.*; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.net.Socket; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.List; + +public class ForkMain { + public static enum Tags { + Error, Warn, Info, Debug, Done; + } + + static class SubclassFingerscan implements TestFingerprint, Serializable { + private boolean isModule; + private String superClassName; + SubclassFingerscan(SubclassFingerprint print) { + isModule = print.isModule(); + superClassName = print.superClassName(); + } + public boolean isModule() { return isModule; } + public String superClassName() { return superClassName; } + } + static class AnnotatedFingerscan implements AnnotatedFingerprint, Serializable { + private boolean isModule; + private String annotationName; + AnnotatedFingerscan(AnnotatedFingerprint print) { + isModule = print.isModule(); + annotationName = print.annotationName(); + } + public boolean isModule() { return isModule; } + public String annotationName() { return annotationName; } + } + public static class ForkTestDefinition implements Serializable { + public String name; + public Fingerprint fingerprint; + + public ForkTestDefinition(String name, Fingerprint fingerprint) { + this.name = name; + if (fingerprint instanceof SubclassFingerprint) { + this.fingerprint = new SubclassFingerscan((SubclassFingerprint) fingerprint); + } else { + this.fingerprint = new AnnotatedFingerscan((AnnotatedFingerprint) fingerprint); + } + } + } + static class ForkEvent implements Event, Serializable { + private String testName; + private String description; + private Result result; + ForkEvent(Event e) { + testName = e.testName(); + description = e.description(); + result = e.result(); + } + public String testName() { return testName; } + public String description() { return description; } + public Result result() { return result;} + public Throwable error() { return null; } + } + public static void main(String[] args) throws Exception { + Socket socket = new Socket(InetAddress.getByName(null), Integer.valueOf(args[0])); + final ObjectInputStream is = new ObjectInputStream(socket.getInputStream()); + final ObjectOutputStream os = new ObjectOutputStream(socket.getOutputStream()); + try { + new Run().run(is, os); + } finally { + is.close(); + os.close(); + } + } + private static class Run { + boolean matches(Fingerprint f1, Fingerprint f2) { + if (f1 instanceof SubclassFingerprint && f2 instanceof SubclassFingerprint) { + final SubclassFingerprint sf1 = (SubclassFingerprint) f1; + final SubclassFingerprint sf2 = (SubclassFingerprint) f2; + return sf1.isModule() == sf2.isModule() && sf1.superClassName().equals(sf2.superClassName()); + } else if (f1 instanceof AnnotatedFingerprint && f2 instanceof AnnotatedFingerprint) { + AnnotatedFingerprint af1 = (AnnotatedFingerprint) f1; + AnnotatedFingerprint af2 = (AnnotatedFingerprint) f2; + return af1.isModule() == af2.isModule() && af1.annotationName().equals(af2.annotationName()); + } + return false; + } + void write(ObjectOutputStream os, Object obj) { + try { + os.writeObject(obj); + } catch (IOException e) { + System.err.println("Cannot write to socket"); + } + } + void run(ObjectInputStream is, final ObjectOutputStream os) throws Exception { + final boolean ansiCodesSupported = is.readBoolean(); + Logger[] loggers = { + new Logger() { + public boolean ansiCodesSupported() { return ansiCodesSupported; } + public void error(String s) { write(os, new Object[]{Tags.Error, s}); } + public void warn(String s) { write(os, new Object[]{Tags.Warn, s}); } + public void info(String s) { write(os, new Object[]{Tags.Info, s}); } + public void debug(String s) { write(os, new Object[]{Tags.Debug, s}); } + public void trace(Throwable t) { write(os, t); } + } + }; + + final ForkTestDefinition[] tests = (ForkTestDefinition[]) is.readObject(); + int nFrameworks = is.readInt(); + for (int i = 0; i < nFrameworks; i++) { + final Framework framework; + final String implClassName = (String) is.readObject(); + try { + framework = (Framework) Class.forName(implClassName).newInstance(); + } catch (ClassNotFoundException e) { + write(os, new Object[]{Tags.Error, "Framework implementation '" + implClassName + "' not present."}); + continue; + } + + final String[] frameworkArgs = (String[]) is.readObject(); + + ArrayList filteredTests = new ArrayList(); + for (Fingerprint testFingerprint : framework.tests()) { + for (ForkTestDefinition test : tests) { + if (matches(testFingerprint, test.fingerprint)) filteredTests.add(test); + } + } + final org.scalatools.testing.Runner runner = framework.testRunner(getClass().getClassLoader(), loggers); + for (ForkTestDefinition test : filteredTests) { + final List events = new ArrayList(); + EventHandler handler = new EventHandler() { public void handle(Event e){ events.add(new ForkEvent(e)); } }; + if (runner instanceof Runner2) { + ((Runner2) runner).run(test.name, test.fingerprint, handler, frameworkArgs); + } else if (test.fingerprint instanceof TestFingerprint) { + runner.run(test.name, (TestFingerprint) test.fingerprint, handler, frameworkArgs); + } else { + write(os, new Object[]{Tags.Error, "Framework '" + framework + "' does not support test '" + test.name + "'"}); + } + write(os, events.toArray(new ForkEvent[events.size()])); + } + } + write(os, Tags.Done); + is.readObject(); + } + } +} diff --git a/util/io/IPC.scala b/util/io/IPC.scala index d0ccf7b03..2cad6106f 100644 --- a/util/io/IPC.scala +++ b/util/io/IPC.scala @@ -69,4 +69,4 @@ final class IPC private(s: Socket) extends NotNull def send(s: String) = { out.write(s); out.newLine(); out.flush() } def receive: String = in.readLine() -} \ No newline at end of file +}