diff --git a/main/actions/src/main/scala/sbt/ForkTests.scala b/main/actions/src/main/scala/sbt/ForkTests.scala index dcab26ade..0748797dc 100755 --- a/main/actions/src/main/scala/sbt/ForkTests.scala +++ b/main/actions/src/main/scala/sbt/ForkTests.scala @@ -53,7 +53,7 @@ private[sbt] object ForkTests { os.writeInt(runners.size) for ((testFramework, mainRunner) <- runners) { val remoteArgs = mainRunner.remoteArgs() - os.writeObject(testFramework.implClassName) + os.writeObject(testFramework.implClassNames.toArray) os.writeObject(mainRunner.args) os.writeObject(remoteArgs) } diff --git a/main/actions/src/main/scala/sbt/Tests.scala b/main/actions/src/main/scala/sbt/Tests.scala index 23c0c7d63..a368876c6 100644 --- a/main/actions/src/main/scala/sbt/Tests.scala +++ b/main/actions/src/main/scala/sbt/Tests.scala @@ -58,7 +58,7 @@ object Tests def frameworkArguments(framework: TestFramework, args: Seq[String]): Unit = (frameworks get framework) match { case Some(f) => frameworkArgs(f, args) - case None => undefinedFrameworks += framework.implClassName + case None => undefinedFrameworks ++= framework.implClassNames } for(option <- config.options) diff --git a/testing/agent/src/main/java/sbt/ForkMain.java b/testing/agent/src/main/java/sbt/ForkMain.java index 9cdf1e744..43950dafd 100755 --- a/testing/agent/src/main/java/sbt/ForkMain.java +++ b/testing/agent/src/main/java/sbt/ForkMain.java @@ -168,6 +168,9 @@ public class ForkMain { void logError(ObjectOutputStream os, String message) { write(os, new Object[]{ForkTags.Error, message}); } + void logDebug(ObjectOutputStream os, String message) { + write(os, new Object[]{ForkTags.Debug, message}); + } void writeEvents(ObjectOutputStream os, ForkTestDefinition test, ForkEvent[] events) { write(os, new Object[]{test.name, events}); } @@ -187,22 +190,27 @@ public class ForkMain { }; for (int i = 0; i < nFrameworks; i++) { - final String implClassName = (String) is.readObject(); + final String[] implClassNames = (String[]) is.readObject(); final String[] frameworkArgs = (String[]) is.readObject(); final String[] remoteFrameworkArgs = (String[]) is.readObject(); - final Framework framework; - try { - Object rawFramework = Class.forName(implClassName).newInstance(); - if (rawFramework instanceof Framework) - framework = (Framework) rawFramework; - else - framework = new FrameworkWrapper((org.scalatools.testing.Framework) rawFramework); - } catch (ClassNotFoundException e) { - logError(os, "Framework implementation '" + implClassName + "' not present."); - continue; + Framework framework = null; + for (String implClassName : implClassNames) { + try { + Object rawFramework = Class.forName(implClassName).newInstance(); + if (rawFramework instanceof Framework) + framework = (Framework) rawFramework; + else + framework = new FrameworkWrapper((org.scalatools.testing.Framework) rawFramework); + break; + } catch (ClassNotFoundException e) { + logDebug(os, "Framework implementation '" + implClassName + "' not present."); + } } + if (framework == null) + continue; + ArrayList filteredTests = new ArrayList(); for (Fingerprint testFingerprint : framework.fingerprints()) { for (ForkTestDefinition test : tests) { diff --git a/testing/src/main/scala/sbt/TestFramework.scala b/testing/src/main/scala/sbt/TestFramework.scala index bd85f5741..b59475983 100644 --- a/testing/src/main/scala/sbt/TestFramework.scala +++ b/testing/src/main/scala/sbt/TestFramework.scala @@ -8,6 +8,7 @@ package sbt import testing.{Logger=>TLogger, _} import org.scalatools.testing.{Framework => OldFramework} import classpath.{ClasspathUtilities, DualLoader, FilteredLoader} + import scala.annotation.tailrec object TestResult extends Enumeration { @@ -17,30 +18,38 @@ object TestResult extends Enumeration object TestFrameworks { val ScalaCheck = new TestFramework("org.scalacheck.ScalaCheckFramework") - val ScalaTest = new TestFramework("org.scalatest.tools.Framework") + val ScalaTest = new TestFramework("org.scalatest.tools.Framework", "org.scalatest.tools.ScalaTestFramework") val Specs = new TestFramework("org.specs.runner.SpecsFramework") val Specs2 = new TestFramework("org.specs2.runner.SpecsFramework") val JUnit = new TestFramework("com.novocode.junit.JUnitFramework") } -case class TestFramework(val implClassName: String) +case class TestFramework(val implClassNames: String*) { - def create(loader: ClassLoader, log: Logger): Option[Framework] = - { - try - { - Some( - Class.forName(implClassName, true, loader).newInstance match { - case newFramework: Framework => newFramework - case oldFramework: OldFramework => new FrameworkWrapper(oldFramework) - } - ) - } - catch - { - case e: ClassNotFoundException => log.debug("Framework implementation '" + implClassName + "' not present."); None + @tailrec + private def createFramework(loader: ClassLoader, log: Logger, frameworkClassNames: List[String]): Option[Framework] = { + frameworkClassNames match { + case head :: tail => + try + { + Some(Class.forName(head, true, loader).newInstance match { + case newFramework: Framework => newFramework + case oldFramework: OldFramework => new FrameworkWrapper(oldFramework) + }) + } + catch + { + case e: ClassNotFoundException => + log.debug("Framework implementation '" + head + "' not present."); + createFramework(loader, log, tail) + } + case Nil => + None } } + + def create(loader: ClassLoader, log: Logger): Option[Framework] = + createFramework(loader, log, implClassNames.toList) } final class TestDefinition(val name: String, val fingerprint: Fingerprint) {