diff --git a/main/DefaultProject.scala b/main/DefaultProject.scala index 86317c396..706419ddf 100644 --- a/main/DefaultProject.scala +++ b/main/DefaultProject.scala @@ -12,6 +12,7 @@ package sbt import Types._ import xsbti.api.Definition + import org.scalatools.testing.Framework import java.io.File class DefaultProject(val info: ProjectInfo) extends BasicProject @@ -73,6 +74,43 @@ abstract class BasicProject extends TestProject with MultiClasspathProject with val classpath = cp.map(x => Path.fromFile(x.data)) r.run(mainClass, classpath, in.splitArgs, s.log) foreach error } + + lazy val testFrameworks: Task[Seq[TestFramework]] = task { + import TestFrameworks._ + Seq(ScalaCheck, Specs, ScalaTest, ScalaCheckCompat, ScalaTestCompat, SpecsCompat, JUnit) + } + + lazy val testLoader: Task[ClassLoader] = + fullClasspath(TestConfig) :^: buildScalaInstance :^: KNil map { case classpath :+: instance :+: HNil => + TestFramework.createTestLoader(data(classpath), instance) + } + + lazy val loadedTestFrameworks: Task[Map[TestFramework,Framework]] = + testFrameworks :^: streams :^: testLoader :^: KNil map { case frameworks :+: s :+: loader :+: HNil => + frameworks.flatMap( f => f.create(loader, s.log).map( x => (f, x)).toIterable ).toMap + } + + lazy val discoverTest: Task[(Seq[TestDefinition], Set[String])] = + loadedTestFrameworks :^: testCompile :^: KNil map { case frameworkMap :+: analysis :+: HNil => + Test.discover(frameworkMap.values.toSeq, analysis) + } + lazy val definedTests: Task[Seq[TestDefinition]] = discoverTest.map(_._1) + + lazy val testOptions: Task[Seq[TestOption]] = task { Nil } + + lazy val test = (loadedTestFrameworks :^: testOptions :^: testLoader :^: definedTests :^: streams :^: KNil) flatMap { + case frameworkMap :+: options :+: loader :+: discovered :+: s :+: HNil => + + val toTask = (testTask: NamedTestTask) => task { testTask.run() } named(testTask.name) + def dependsOn(on: Iterable[Task[_]]): Task[Unit] = task { () } dependsOn(on.toSeq : _*) + + val (begin, work, end) = Test(frameworkMap, loader, discovered, options, s.log) + val beginTasks = dependsOn( begin.map(toTask) ) // test setup tasks + val workTasks = work.map(w => toTask(w) dependsOn(beginTasks) ) // the actual tests + val endTasks = dependsOn( end.map(toTask) ) // tasks that perform test cleanup and are run regardless of success of tests + dependsOn( workTasks ) doFinally { endTasks } + } + lazy val clean = task { IO.delete(outputDirectory) } @@ -107,4 +145,7 @@ abstract class BasicProject extends TestProject with MultiClasspathProject with lazy val compileInputs: Task[Compile.Inputs] = compileInputsTask(Configurations.Compile, "src" / "main", buildScalaInstance) named(name + "/compile-inputs") lazy val compile: Task[Analysis] = compileTask(compileInputs) named(name + "/compile") + + lazy val testCompileInputs: Task[Compile.Inputs] = compileInputsTask(Configurations.Test, "src" / "test", buildScalaInstance) named(name + "/test-inputs") + lazy val testCompile: Task[Analysis] = compileTask(testCompileInputs) named(name + "/test") } diff --git a/main/Test.scala b/main/Test.scala new file mode 100644 index 000000000..3ffbe668e --- /dev/null +++ b/main/Test.scala @@ -0,0 +1,154 @@ +/* sbt -- Simple Build Tool + * Copyright 2010 Mark Harrah + */ +package sbt + + import std._ + import compile.{Discovered,Discovery} + import inc.Analysis + import TaskExtra._ + import Types._ + import xsbti.api.Definition + + import org.scalatools.testing.{AnnotatedFingerprint, Fingerprint, Framework, SubclassFingerprint} + + import collection.mutable + import java.io.File + +sealed trait TestOption +object Test +{ + type Output = (TestResult.Value, Map[String,TestResult.Value]) + + final case class Setup(setup: ClassLoader => Unit) extends TestOption + def Setup(setup: () => Unit) = new Setup(_ => setup()) + + final case class Cleanup(cleanup: ClassLoader => Unit) extends TestOption + def Cleanup(setup: () => Unit) = new Cleanup(_ => setup()) + + final case class Exclude(tests: Iterable[String]) extends TestOption + final case class Listeners(listeners: Iterable[TestReportListener]) extends TestOption + final case class Filter(filterTest: String => Boolean) extends TestOption + + // args for all frameworks + def Argument(args: String*): Argument = Argument(None, args.toList) + // args for a particular test framework + def Argument(tf: TestFramework, args: String*): Argument = Argument(Some(tf), args.toList) + + // None means apply to all, Some(tf) means apply to a particular framework only. + final case class Argument(framework: Option[TestFramework], args: List[String]) extends TestOption + + + def apply(frameworks: Map[TestFramework, Framework], testLoader: ClassLoader, discovered: Seq[TestDefinition], options: Seq[TestOption], log: Logger) = + { + import mutable.{HashSet, ListBuffer, Map, Set} + val testFilters = new ListBuffer[String => Boolean] + val excludeTestsSet = new HashSet[String] + val setup, cleanup = new ListBuffer[ClassLoader => Unit] + val testListeners = new ListBuffer[TestReportListener] + val testArgsByFramework = Map[Framework, ListBuffer[String]]() + def frameworkArgs(framework: Framework, args: Seq[String]): Unit = + testArgsByFramework.getOrElseUpdate(framework, new ListBuffer[String]) ++= args + + for(option <- options) + { + option match + { + case Filter(include) => testFilters += include + case Exclude(exclude) => excludeTestsSet ++= exclude + case Listeners(listeners) => testListeners ++= listeners + case Setup(setupFunction) => setup += setupFunction + case Cleanup(cleanupFunction) => cleanup += cleanupFunction + /** + * There are two cases here. + * The first handles TestArguments in the project file, which + * might have a TestFramework specified. + * The second handles arguments to be applied to all test frameworks. + * -- arguments from the project file that didnt have a framework specified + * -- command line arguments (ex: test-only someClass -- someArg) + * (currently, command line args must be passed to all frameworks) + */ + case Argument(Some(framework), args) => frameworkArgs(frameworks(framework), args) + case Argument(None, args) => frameworks.values.foreach { f => frameworkArgs(f, args) } + } + } + + if(excludeTestsSet.size > 0) + log.debug(excludeTestsSet.mkString("Excluding tests: \n\t", "\n\t", "")) + + def includeTest(test: TestDefinition) = !excludeTestsSet.contains(test.name) && testFilters.forall(filter => filter(test.name)) + val tests = discovered.filter(includeTest).toSet.toSeq + val arguments = testArgsByFramework.map { case (k,v) => (k, v.toList) } toMap; + testTask(frameworks.values.toSeq, testLoader, tests, setup.readOnly, cleanup.readOnly, log, testListeners.readOnly, arguments) + } + + // (overall result, individual results) + def testTask(frameworks: Seq[Framework], loader: ClassLoader, tests: Seq[TestDefinition], + userSetup: Iterable[ClassLoader => Unit], userCleanup: Iterable[ClassLoader => Unit], + log: Logger, testListeners: Seq[TestReportListener], arguments: Map[Framework, Seq[String]]): Task[Output] = + { + def fj(actions: Iterable[() => Unit]): Task[Unit] = nop.dependsOn( actions.toSeq.fork( _() ) : _*) + def partApp(actions: Iterable[ClassLoader => Unit]) = actions.toSeq map {a => () => a(loader) } + + val (frameworkSetup, runnables, frameworkCleanup) = + TestFramework.testTasks(frameworks, loader, tests, log, testListeners, arguments) + + val setupTasks = fj(partApp(userSetup) :+ frameworkSetup) + val mainTasks = runnables map { case (name, test) => task { (name, test()) } dependsOn setupTasks named name } + mainTasks.toSeq.join map processResults flatMap { results => + val cleanupTasks = fj(partApp(userCleanup) :+ frameworkCleanup(results._1)) + cleanupTasks map { _ => results } + } + } + def processResults(results: Iterable[(String, TestResult.Value)]): (TestResult.Value, Map[String, TestResult.Value]) = + (overall(results.map(_._2)), results.toMap) + def overall(results: Iterable[TestResult.Value]): TestResult.Value = + (TestResult.Passed /: results) { (acc, result) => if(acc.id < result.id) result else acc } + def discover(frameworks: Seq[Framework], analysis: Analysis): (Seq[TestDefinition], Set[String]) = + discover(frameworks.flatMap(_.tests), allDefs(analysis)) + + def allDefs(analysis: Analysis) = analysis.apis.internal.values.flatMap(_.definitions).toSeq + def discover(fingerprints: Seq[Fingerprint], definitions: Seq[Definition]): (Seq[TestDefinition], Set[String]) = + { + val subclasses = fingerprints collect { case sub: SubclassFingerprint => (sub.superClassName, sub.isModule, sub) }; + val annotations = fingerprints collect { case ann: AnnotatedFingerprint => (ann.annotationName, ann.isModule, ann) }; + + def firsts[A,B,C](s: Seq[(A,B,C)]): Set[A] = s.map(_._1).toSet + def defined(in: Seq[(String,Boolean,Fingerprint)], names: Set[String], isModule: Boolean): Seq[Fingerprint] = + in collect { case (name, mod, print) if names(name) && isModule == mod => print } + + def toFingerprints(d: Discovered): Seq[Fingerprint] = + defined(subclasses, d.baseClasses, d.isModule) ++ + defined(annotations, d.annotations, d.isModule) + + val discovered = Discovery(firsts(subclasses), firsts(annotations))(definitions) + val tests = for( (df, di) <- discovered; fingerprint <- toFingerprints(di) ) yield new TestDefinition(df.name, fingerprint) + val mains = discovered collect { case (df, di) if di.hasMain => df.name } + (tests, mains.toSet) + } + + def showResults(log: Logger, results: (TestResult.Value, Map[String, TestResult.Value])) = + { + import TestResult.{Error, Failed, Passed} + + def select(Tpe: TestResult.Value) = results._2 collect { case (name, Tpe) => name } + + val failures = select(Failed) + val errors = select(Error) + val passed = select(Passed) + + def show(label: String, level: Level.Value, tests: Iterable[String]): Unit = + if(!tests.isEmpty) + { + log.log(level, label) + log.log(level, tests.mkString("\t", "\n\t", "")) + } + + show("Passed tests:", Level.Debug, passed ) + show("Failed tests:", Level.Error, failures) + show("Error during tests:", Level.Error, errors) + + if(!failures.isEmpty || !errors.isEmpty) + error("Tests unsuccessful") + } +} \ No newline at end of file diff --git a/project/build/XSbt.scala b/project/build/XSbt.scala index 264accc9b..41288310e 100644 --- a/project/build/XSbt.scala +++ b/project/build/XSbt.scala @@ -87,7 +87,8 @@ class XSbt(info: ProjectInfo) extends ParentProject(info) with NoCrossPaths val stdTaskSub = testedBase(tasksPath / "standard", "Task System", taskSub, collectionSub, logSub, ioSub, processSub) // The main integration project for sbt. It brings all of the subsystems together, configures them, and provides for overriding conventions. val mainSub = baseProject("main", "Main", - buildSub, compileIncrementalSub, compilerSub, completeSub, discoverySub, ioSub, logSub, processSub, taskSub, stdTaskSub, runSub, trackingSub) + buildSub, compileIncrementalSub, compilerSub, completeSub, discoverySub, + ioSub, logSub, processSub, taskSub, stdTaskSub, runSub, trackingSub, testingSub) // Strictly for bringing implicits and aliases from subsystems into the top-level sbt namespace through a single package object val sbtSub = project(sbtPath, "Simple Build Tool", new Sbt(_), mainSub) // technically, we need a dependency on all of mainSub's dependencies, but we don't do that since this is strictly an integration project diff --git a/tasks/standard/TaskExtra.scala b/tasks/standard/TaskExtra.scala index 5cff9b427..10b0bf861 100644 --- a/tasks/standard/TaskExtra.scala +++ b/tasks/standard/TaskExtra.scala @@ -81,6 +81,9 @@ sealed trait ProcessPipe trait TaskExtra { + final def nop: Task[Unit] = const( () ) + final def const[T](t: T): Task[T] = task(t) + final def cross[T](key: AttributeKey[T])(values: T*): Task[T] = CrossAction( for(v <- values) yield ( AttributeMap.empty put (key, v), task(v) ) ) diff --git a/testing/TestFramework.scala b/testing/TestFramework.scala index e4e692acd..d8a993dbc 100644 --- a/testing/TestFramework.scala +++ b/testing/TestFramework.scala @@ -3,14 +3,15 @@ */ package sbt + import java.io.File import java.net.URLClassLoader import org.scalatools.testing.{AnnotatedFingerprint, Fingerprint, SubclassFingerprint, TestFingerprint} import org.scalatools.testing.{Event, EventHandler, Framework, Runner, Runner2, Logger=>TLogger} import classpath.{ClasspathUtilities, DualLoader, FilteredLoader} -object Result extends Enumeration +object TestResult extends Enumeration { - val Error, Passed, Failed = Value + val Passed, Failed, Error = Value } object TestFrameworks @@ -55,7 +56,7 @@ final class TestRunner(framework: Framework, loader: ClassLoader, listeners: Seq case _ => error("Framework '" + framework + "' does not support test '" + testDefinition + "'") } - final def run(testDefinition: TestDefinition, args: Seq[String]): Result.Value = + final def run(testDefinition: TestDefinition, args: Seq[String]): TestResult.Value = { log.debug("Running " + testDefinition + " with arguments " + args.mkString(", ")) val name = testDefinition.name @@ -73,7 +74,7 @@ final class TestRunner(framework: Framework, loader: ClassLoader, listeners: Seq safeListenersCall(_.startGroup(name)) try { - val result = runTest().getOrElse(Result.Passed) + val result = runTest().getOrElse(TestResult.Passed) safeListenersCall(_.endGroup(name, result)) result } @@ -81,7 +82,7 @@ final class TestRunner(framework: Framework, loader: ClassLoader, listeners: Seq { case e => safeListenersCall(_.endGroup(name, e)) - Result.Error + TestResult.Error } } @@ -89,8 +90,6 @@ final class TestRunner(framework: Framework, loader: ClassLoader, listeners: Seq TestFramework.safeForeach(listeners, log)(call) } -final class NamedTestTask(val name: String, action: => Unit) { def run() = action } - object TestFramework { def getTests(framework: Framework): Seq[Fingerprint] = @@ -117,39 +116,28 @@ object TestFramework case _ => false } - import scala.collection.{immutable, Map, Set} - - def testTasks(frameworks: Seq[TestFramework], - classpath: Iterable[Path], - scalaInstance: ScalaInstance, + def testTasks(frameworks: Seq[Framework], + testLoader: ClassLoader, tests: Seq[TestDefinition], log: Logger, listeners: Seq[TestReportListener], - endErrorsEnabled: Boolean, - setup: Iterable[ClassLoader => Unit], - cleanup: Iterable[ClassLoader => Unit], - testArgsByFramework: Map[TestFramework, Seq[String]]): - (Iterable[NamedTestTask], Iterable[NamedTestTask], Iterable[NamedTestTask]) = + testArgsByFramework: Map[Framework, Seq[String]]): + (() => Unit, Iterable[(String, () => TestResult.Value)], TestResult.Value => () => Unit) = { - val (loader, tempDir) = createTestLoader(classpath, scalaInstance) - val arguments = immutable.Map() ++ - ( for(framework <- frameworks; created <- framework.create(loader, log)) yield - (created, testArgsByFramework.getOrElse(framework, Nil)) ) - val cleanTmp = (_: ClassLoader) => IO.delete(tempDir) - - val mappedTests = testMap(arguments.keys.toList, tests, arguments) + val arguments = testArgsByFramework withDefaultValue Nil + val mappedTests = testMap(frameworks, tests, arguments) if(mappedTests.isEmpty) - (new NamedTestTask(TestStartName, None) :: Nil, Nil, new NamedTestTask(TestFinishName, { log.info("No tests to run."); cleanTmp(loader) }) :: Nil ) + (() => (), Nil, _ => () => log.info("No tests to run.") ) else - createTestTasks(loader, mappedTests, log, listeners, endErrorsEnabled, setup, Seq(cleanTmp) ++ cleanup) + createTestTasks(testLoader, mappedTests, log, listeners) } private def testMap(frameworks: Seq[Framework], tests: Seq[TestDefinition], args: Map[Framework, Seq[String]]): - immutable.Map[Framework, (Set[TestDefinition], Seq[String])] = + Map[Framework, (Set[TestDefinition], Seq[String])] = { import scala.collection.mutable.{HashMap, HashSet, Set} val map = new HashMap[Framework, Set[TestDefinition]] - def assignTests(): Unit = + def assignTests() { for(test <- tests if !map.values.exists(_.contains(test))) { @@ -160,70 +148,38 @@ object TestFramework } if(!frameworks.isEmpty) assignTests() - (immutable.Map() ++ map) transform { (framework, tests) => (tests, args(framework)) } + map.toMap transform { (framework, tests) => (tests.toSet, args(framework)) }; } - private def createTasks[T](work: Iterable[T => Unit], baseName: String, input: T) = - work.toList.zipWithIndex.map{ case (work, index) => new NamedTestTask(baseName + " " + (index+1), work(input)) } - private def createTestTasks(loader: ClassLoader, tests: Map[Framework, (Set[TestDefinition], Seq[String])], log: Logger, - listeners: Seq[TestReportListener], endErrorsEnabled: Boolean, setup: Iterable[ClassLoader => Unit], - cleanup: Iterable[ClassLoader => Unit]) = + private def createTestTasks(loader: ClassLoader, tests: Map[Framework, (Set[TestDefinition], Seq[String])], log: Logger, listeners: Seq[TestReportListener]) = { - val testsListeners = listeners.filter(_.isInstanceOf[TestsListener]).map(_.asInstanceOf[TestsListener]) - def foreachListenerSafe(f: TestsListener => Unit): Unit = safeForeach(testsListeners, log)(f) + val testsListeners = listeners collect { case tl: TestsListener => tl } + def foreachListenerSafe(f: TestsListener => Unit): () => Unit = () => safeForeach(testsListeners, log)(f) - import Result.{Error,Passed,Failed} - object result - { - private[this] var value: Result.Value = Passed - def apply() = synchronized { value } - def update(v: Result.Value): Unit = synchronized { if(value != Error) value = v } - } - val startTask = new NamedTestTask(TestStartName, {foreachListenerSafe(_.doInit); None}) :: createTasks(setup, "Test setup", loader) + import TestResult.{Error,Passed,Failed} + + val startTask = foreachListenerSafe(_.doInit) val testTasks = - tests flatMap { case (framework, (testDefinitions, testArgs)) => + tests.view flatMap { case (framework, (testDefinitions, testArgs)) => val runner = new TestRunner(framework, loader, listeners, log) for(testDefinition <- testDefinitions) yield { - def runTest() = - { - val oldLoader = Thread.currentThread.getContextClassLoader - Thread.currentThread.setContextClassLoader(loader) - try { - runner.run(testDefinition, testArgs) match - { - case Error => result() = Error; Some("ERROR occurred during testing.") - case Failed => result() = Failed; Some("Test FAILED") - case _ => None - } - } - finally { - Thread.currentThread.setContextClassLoader(oldLoader) - } - } - new NamedTestTask(testDefinition.name, runTest()) + val runTest = () => withContextLoader(loader) { runner.run(testDefinition, testArgs) } + (testDefinition.name, runTest) } } - def end() = - { - foreachListenerSafe(_.doComplete(result())) - result() match - { - case Error => if(endErrorsEnabled) Some("ERROR occurred during testing.") else None - case Failed => if(endErrorsEnabled) Some("One or more tests FAILED.") else None - case Passed => - { - log.info(" ") - log.info("All tests PASSED.") - None - } - } - } - val endTask = new NamedTestTask(TestFinishName, end() ) :: createTasks(cleanup, "Test cleanup", loader) - (startTask, testTasks, endTask) + + val endTask = (result: TestResult.Value) => foreachListenerSafe(_.doComplete(result)) + (startTask, testTasks.toList, endTask) } - def createTestLoader(classpath: Iterable[Path], scalaInstance: ScalaInstance): (ClassLoader, Path) = + private[this] def withContextLoader[T](loader: ClassLoader)(eval: => T): T = + { + val oldLoader = Thread.currentThread.getContextClassLoader + Thread.currentThread.setContextClassLoader(loader) + try { eval } finally { Thread.currentThread.setContextClassLoader(oldLoader) } + } + def createTestLoader(classpath: Seq[File], scalaInstance: ScalaInstance): ClassLoader = { val filterCompilerLoader = new FilteredLoader(scalaInstance.loader, ScalaCompilerJarPackages) val interfaceFilter = (name: String) => name.startsWith("org.scalatools.testing.") diff --git a/testing/TestReportListener.scala b/testing/TestReportListener.scala index bab4a0cda..0bf36b7b7 100644 --- a/testing/TestReportListener.scala +++ b/testing/TestReportListener.scala @@ -15,7 +15,7 @@ trait TestReportListener /** called if there was an error during test */ def endGroup(name: String, t: Throwable) /** called if test completed */ - def endGroup(name: String, result: Result.Value) + def endGroup(name: String, result: TestResult.Value) /** Used by the test framework for logging test results*/ def contentLogger: Option[TLogger] = None } @@ -25,23 +25,23 @@ trait TestsListener extends TestReportListener /** called once, at beginning. */ def doInit /** called once, at end. */ - def doComplete(finalResult: Result.Value) + def doComplete(finalResult: TestResult.Value) } abstract class TestEvent extends NotNull { - def result: Option[Result.Value] + def result: Option[TestResult.Value] def detail: Seq[TEvent] = Nil } object TestEvent { def apply(events: Seq[TEvent]): TestEvent = { - val overallResult = (Result.Passed /: events) { (sum, event) => + val overallResult = (TestResult.Passed /: events) { (sum, event) => val result = event.result - if(sum == Result.Error || result == TResult.Error) Result.Error - else if(sum == Result.Failed || result == TResult.Failure) Result.Failed - else Result.Passed + if(sum == TestResult.Error || result == TResult.Error) TestResult.Error + else if(sum == TestResult.Failed || result == TResult.Failure) TestResult.Failed + else TestResult.Passed } new TestEvent { val result = Some(overallResult) @@ -76,7 +76,7 @@ class TestLogger(val log: TLogger) extends TestsListener log.trace(t) log.error("Could not run test " + name + ": " + t.toString) } - def endGroup(name: String, result: Result.Value) {} + def endGroup(name: String, result: TestResult.Value) {} protected def count(event: TEvent): Unit = { event.result match @@ -95,15 +95,15 @@ class TestLogger(val log: TLogger) extends TestsListener skipped = 0 } /** called once, at end. */ - def doComplete(finalResult: Result.Value): Unit = + def doComplete(finalResult: TestResult.Value): Unit = { val totalCount = failures + errors + skipped + passed val postfix = ": Total " + totalCount + ", Failed " + failures + ", Errors " + errors + ", Passed " + passed + ", Skipped " + skipped finalResult match { - case Result.Error => log.error("Error" + postfix) - case Result.Passed => log.info("Passed: " + postfix) - case Result.Failed => log.error("Failed: " + postfix) + case TestResult.Error => log.error("Error" + postfix) + case TestResult.Passed => log.info("Passed: " + postfix) + case TestResult.Failed => log.error("Failed: " + postfix) } } override def contentLogger: Option[TLogger] = Some(log)