diff --git a/sbt/src/sbt-test/tests/one-class-multi-framework/build.sbt b/sbt/src/sbt-test/tests/one-class-multi-framework/build.sbt new file mode 100644 index 000000000..80a62d3ce --- /dev/null +++ b/sbt/src/sbt-test/tests/one-class-multi-framework/build.sbt @@ -0,0 +1,6 @@ +scalaVersion := "2.10.2" + +libraryDependencies += "com.novocode" % "junit-interface" % "0.10-M4" % "test" + +libraryDependencies += "org.specs2" %% "specs2" % "2.1.1" % "test" + diff --git a/sbt/src/sbt-test/tests/one-class-multi-framework/src/test/scala/Test.scala b/sbt/src/sbt-test/tests/one-class-multi-framework/src/test/scala/Test.scala new file mode 100644 index 000000000..27620987f --- /dev/null +++ b/sbt/src/sbt-test/tests/one-class-multi-framework/src/test/scala/Test.scala @@ -0,0 +1,30 @@ +import org.junit.runner.RunWith +import org.specs2._ + +@RunWith(classOf[org.specs2.runner.JUnitRunner]) +class B extends Specification +{ + // sequential=true is necessary to get junit in the call stack + // otherwise, junit calls specs, which then runs tests on separate threads + def is = args(sequential=true) ^ s2""" + + This should + fail if 'succeed' file is missing $succeedNeeded + not run via JUnit $noJUnit + """ + + def succeedNeeded = { + val f = new java.io.File("succeed") + f.exists must_== true + } + def noJUnit = { + println("Trace: " + RunBy.trace.mkString("\n\t", "\n\t", "")) + RunBy.junit must_== false + } +} + +object RunBy +{ + def trace = (new Exception).getStackTrace.map(_.getClassName) + def junit = trace.exists(_.contains("org.junit")) +} diff --git a/sbt/src/sbt-test/tests/one-class-multi-framework/test b/sbt/src/sbt-test/tests/one-class-multi-framework/test new file mode 100644 index 000000000..1ae08f6fb --- /dev/null +++ b/sbt/src/sbt-test/tests/one-class-multi-framework/test @@ -0,0 +1,8 @@ +# verify the test runs at all by having it fail +$ absent succeed +-> test + +# indicate to the test it should succeed +# it will fail if run via junit and succeed if run via specs +$ touch succeed +> test \ No newline at end of file diff --git a/testing/src/main/scala/sbt/TestFramework.scala b/testing/src/main/scala/sbt/TestFramework.scala index 4638e6199..dc1a29405 100644 --- a/testing/src/main/scala/sbt/TestFramework.scala +++ b/testing/src/main/scala/sbt/TestFramework.scala @@ -152,13 +152,19 @@ object TestFramework listeners: Seq[TestReportListener]): (() => Unit, Seq[(String, TestFunction)], TestResult.Value => () => Unit) = { - val mappedTests = testMap(frameworks.values.toSeq, tests) + val unique = distinctBy(tests)(_.name) + val mappedTests = testMap(frameworks.values.toSeq, unique) if(mappedTests.isEmpty) (() => (), Nil, _ => () => () ) else - createTestTasks(testLoader, runners.map { case (tf, r) => (frameworks(tf), new TestRunner(r, listeners, log))}, mappedTests, tests, log, listeners) + createTestTasks(testLoader, runners.map { case (tf, r) => (frameworks(tf), new TestRunner(r, listeners, log))}, mappedTests, unique, log, listeners) } + private[this] def distinctBy[T, K](in: Seq[T])(f: T => K): Seq[T] = + { + val seen = new collection.mutable.HashSet[K] + in.filter(t => seen.add(f(t))) + } private[this] def order(mapped: Map[String, TestFunction], inputs: Seq[TestDefinition]): Seq[(String, TestFunction)] = for( d <- inputs; act <- mapped.get(d.name) ) yield (d.name, act) @@ -166,28 +172,15 @@ object TestFramework { import scala.collection.mutable.{HashMap, HashSet, Set} val map = new HashMap[Framework, Set[TestDefinition]] - def assignTests() + def assignTest(test: TestDefinition) { - for(test <- tests if !map.values.exists(_.contains(test))) - { - def isTestForFramework(framework: Framework) = getFingerprints(framework).exists {t => matches(t, test.fingerprint) } - for(framework <- frameworks.find(isTestForFramework)) - map.getOrElseUpdate(framework, new HashSet[TestDefinition]) += test - } + def isTestForFramework(framework: Framework) = getFingerprints(framework).exists {t => matches(t, test.fingerprint) } + for(framework <- frameworks.find(isTestForFramework)) + map.getOrElseUpdate(framework, new HashSet[TestDefinition]) += test } if(!frameworks.isEmpty) - assignTests() - map.toMap transform { (framework, tests) => mergeDuplicates(framework, tests.toSeq) } - } - private[this] def mergeDuplicates(framework: Framework, tests: Seq[TestDefinition]): Set[TestDefinition] = - { - val frameworkPrints = framework.fingerprints.reverse - def pickOne(prints: Seq[Fingerprint]): Fingerprint = - frameworkPrints.find(prints.toSet) getOrElse prints.head - val uniqueDefs = - for( ((name, explicitlySpecified, selectors), defs) <- tests.groupBy(t => (t.name, t.explicitlySpecified, t.selectors)) ) yield - new TestDefinition(name, pickOne(defs.map(_.fingerprint)), explicitlySpecified, selectors) - uniqueDefs.toSet + for(test <- tests) assignTest(test) + map.toMap.mapValues(_.toSet) } private def createTestTasks(loader: ClassLoader, runners: Map[Framework, TestRunner], tests: Map[Framework, Set[TestDefinition]], ordered: Seq[TestDefinition], log: Logger, listeners: Seq[TestReportListener]) =