'test' task

This commit is contained in:
Mark Harrah 2010-11-24 14:03:26 -05:00
parent 9a8c62517f
commit 46a6a1af16
6 changed files with 248 additions and 93 deletions

View File

@ -12,6 +12,7 @@ package sbt
import Types._
import xsbti.api.Definition
import org.scalatools.testing.Framework
import java.io.File
class DefaultProject(val info: ProjectInfo) extends BasicProject
@ -73,6 +74,43 @@ abstract class BasicProject extends TestProject with MultiClasspathProject with
val classpath = cp.map(x => Path.fromFile(x.data))
r.run(mainClass, classpath, in.splitArgs, s.log) foreach error
}
lazy val testFrameworks: Task[Seq[TestFramework]] = task {
import TestFrameworks._
Seq(ScalaCheck, Specs, ScalaTest, ScalaCheckCompat, ScalaTestCompat, SpecsCompat, JUnit)
}
lazy val testLoader: Task[ClassLoader] =
fullClasspath(TestConfig) :^: buildScalaInstance :^: KNil map { case classpath :+: instance :+: HNil =>
TestFramework.createTestLoader(data(classpath), instance)
}
lazy val loadedTestFrameworks: Task[Map[TestFramework,Framework]] =
testFrameworks :^: streams :^: testLoader :^: KNil map { case frameworks :+: s :+: loader :+: HNil =>
frameworks.flatMap( f => f.create(loader, s.log).map( x => (f, x)).toIterable ).toMap
}
lazy val discoverTest: Task[(Seq[TestDefinition], Set[String])] =
loadedTestFrameworks :^: testCompile :^: KNil map { case frameworkMap :+: analysis :+: HNil =>
Test.discover(frameworkMap.values.toSeq, analysis)
}
lazy val definedTests: Task[Seq[TestDefinition]] = discoverTest.map(_._1)
lazy val testOptions: Task[Seq[TestOption]] = task { Nil }
lazy val test = (loadedTestFrameworks :^: testOptions :^: testLoader :^: definedTests :^: streams :^: KNil) flatMap {
case frameworkMap :+: options :+: loader :+: discovered :+: s :+: HNil =>
val toTask = (testTask: NamedTestTask) => task { testTask.run() } named(testTask.name)
def dependsOn(on: Iterable[Task[_]]): Task[Unit] = task { () } dependsOn(on.toSeq : _*)
val (begin, work, end) = Test(frameworkMap, loader, discovered, options, s.log)
val beginTasks = dependsOn( begin.map(toTask) ) // test setup tasks
val workTasks = work.map(w => toTask(w) dependsOn(beginTasks) ) // the actual tests
val endTasks = dependsOn( end.map(toTask) ) // tasks that perform test cleanup and are run regardless of success of tests
dependsOn( workTasks ) doFinally { endTasks }
}
lazy val clean = task {
IO.delete(outputDirectory)
}
@ -107,4 +145,7 @@ abstract class BasicProject extends TestProject with MultiClasspathProject with
lazy val compileInputs: Task[Compile.Inputs] = compileInputsTask(Configurations.Compile, "src" / "main", buildScalaInstance) named(name + "/compile-inputs")
lazy val compile: Task[Analysis] = compileTask(compileInputs) named(name + "/compile")
lazy val testCompileInputs: Task[Compile.Inputs] = compileInputsTask(Configurations.Test, "src" / "test", buildScalaInstance) named(name + "/test-inputs")
lazy val testCompile: Task[Analysis] = compileTask(testCompileInputs) named(name + "/test")
}

154
main/Test.scala Normal file
View File

@ -0,0 +1,154 @@
/* sbt -- Simple Build Tool
* Copyright 2010 Mark Harrah
*/
package sbt
import std._
import compile.{Discovered,Discovery}
import inc.Analysis
import TaskExtra._
import Types._
import xsbti.api.Definition
import org.scalatools.testing.{AnnotatedFingerprint, Fingerprint, Framework, SubclassFingerprint}
import collection.mutable
import java.io.File
sealed trait TestOption
object Test
{
type Output = (TestResult.Value, Map[String,TestResult.Value])
final case class Setup(setup: ClassLoader => Unit) extends TestOption
def Setup(setup: () => Unit) = new Setup(_ => setup())
final case class Cleanup(cleanup: ClassLoader => Unit) extends TestOption
def Cleanup(setup: () => Unit) = new Cleanup(_ => setup())
final case class Exclude(tests: Iterable[String]) extends TestOption
final case class Listeners(listeners: Iterable[TestReportListener]) extends TestOption
final case class Filter(filterTest: String => Boolean) extends TestOption
// args for all frameworks
def Argument(args: String*): Argument = Argument(None, args.toList)
// args for a particular test framework
def Argument(tf: TestFramework, args: String*): Argument = Argument(Some(tf), args.toList)
// None means apply to all, Some(tf) means apply to a particular framework only.
final case class Argument(framework: Option[TestFramework], args: List[String]) extends TestOption
def apply(frameworks: Map[TestFramework, Framework], testLoader: ClassLoader, discovered: Seq[TestDefinition], options: Seq[TestOption], log: Logger) =
{
import mutable.{HashSet, ListBuffer, Map, Set}
val testFilters = new ListBuffer[String => Boolean]
val excludeTestsSet = new HashSet[String]
val setup, cleanup = new ListBuffer[ClassLoader => Unit]
val testListeners = new ListBuffer[TestReportListener]
val testArgsByFramework = Map[Framework, ListBuffer[String]]()
def frameworkArgs(framework: Framework, args: Seq[String]): Unit =
testArgsByFramework.getOrElseUpdate(framework, new ListBuffer[String]) ++= args
for(option <- options)
{
option match
{
case Filter(include) => testFilters += include
case Exclude(exclude) => excludeTestsSet ++= exclude
case Listeners(listeners) => testListeners ++= listeners
case Setup(setupFunction) => setup += setupFunction
case Cleanup(cleanupFunction) => cleanup += cleanupFunction
/**
* There are two cases here.
* The first handles TestArguments in the project file, which
* might have a TestFramework specified.
* The second handles arguments to be applied to all test frameworks.
* -- arguments from the project file that didnt have a framework specified
* -- command line arguments (ex: test-only someClass -- someArg)
* (currently, command line args must be passed to all frameworks)
*/
case Argument(Some(framework), args) => frameworkArgs(frameworks(framework), args)
case Argument(None, args) => frameworks.values.foreach { f => frameworkArgs(f, args) }
}
}
if(excludeTestsSet.size > 0)
log.debug(excludeTestsSet.mkString("Excluding tests: \n\t", "\n\t", ""))
def includeTest(test: TestDefinition) = !excludeTestsSet.contains(test.name) && testFilters.forall(filter => filter(test.name))
val tests = discovered.filter(includeTest).toSet.toSeq
val arguments = testArgsByFramework.map { case (k,v) => (k, v.toList) } toMap;
testTask(frameworks.values.toSeq, testLoader, tests, setup.readOnly, cleanup.readOnly, log, testListeners.readOnly, arguments)
}
// (overall result, individual results)
def testTask(frameworks: Seq[Framework], loader: ClassLoader, tests: Seq[TestDefinition],
userSetup: Iterable[ClassLoader => Unit], userCleanup: Iterable[ClassLoader => Unit],
log: Logger, testListeners: Seq[TestReportListener], arguments: Map[Framework, Seq[String]]): Task[Output] =
{
def fj(actions: Iterable[() => Unit]): Task[Unit] = nop.dependsOn( actions.toSeq.fork( _() ) : _*)
def partApp(actions: Iterable[ClassLoader => Unit]) = actions.toSeq map {a => () => a(loader) }
val (frameworkSetup, runnables, frameworkCleanup) =
TestFramework.testTasks(frameworks, loader, tests, log, testListeners, arguments)
val setupTasks = fj(partApp(userSetup) :+ frameworkSetup)
val mainTasks = runnables map { case (name, test) => task { (name, test()) } dependsOn setupTasks named name }
mainTasks.toSeq.join map processResults flatMap { results =>
val cleanupTasks = fj(partApp(userCleanup) :+ frameworkCleanup(results._1))
cleanupTasks map { _ => results }
}
}
def processResults(results: Iterable[(String, TestResult.Value)]): (TestResult.Value, Map[String, TestResult.Value]) =
(overall(results.map(_._2)), results.toMap)
def overall(results: Iterable[TestResult.Value]): TestResult.Value =
(TestResult.Passed /: results) { (acc, result) => if(acc.id < result.id) result else acc }
def discover(frameworks: Seq[Framework], analysis: Analysis): (Seq[TestDefinition], Set[String]) =
discover(frameworks.flatMap(_.tests), allDefs(analysis))
def allDefs(analysis: Analysis) = analysis.apis.internal.values.flatMap(_.definitions).toSeq
def discover(fingerprints: Seq[Fingerprint], definitions: Seq[Definition]): (Seq[TestDefinition], Set[String]) =
{
val subclasses = fingerprints collect { case sub: SubclassFingerprint => (sub.superClassName, sub.isModule, sub) };
val annotations = fingerprints collect { case ann: AnnotatedFingerprint => (ann.annotationName, ann.isModule, ann) };
def firsts[A,B,C](s: Seq[(A,B,C)]): Set[A] = s.map(_._1).toSet
def defined(in: Seq[(String,Boolean,Fingerprint)], names: Set[String], isModule: Boolean): Seq[Fingerprint] =
in collect { case (name, mod, print) if names(name) && isModule == mod => print }
def toFingerprints(d: Discovered): Seq[Fingerprint] =
defined(subclasses, d.baseClasses, d.isModule) ++
defined(annotations, d.annotations, d.isModule)
val discovered = Discovery(firsts(subclasses), firsts(annotations))(definitions)
val tests = for( (df, di) <- discovered; fingerprint <- toFingerprints(di) ) yield new TestDefinition(df.name, fingerprint)
val mains = discovered collect { case (df, di) if di.hasMain => df.name }
(tests, mains.toSet)
}
def showResults(log: Logger, results: (TestResult.Value, Map[String, TestResult.Value])) =
{
import TestResult.{Error, Failed, Passed}
def select(Tpe: TestResult.Value) = results._2 collect { case (name, Tpe) => name }
val failures = select(Failed)
val errors = select(Error)
val passed = select(Passed)
def show(label: String, level: Level.Value, tests: Iterable[String]): Unit =
if(!tests.isEmpty)
{
log.log(level, label)
log.log(level, tests.mkString("\t", "\n\t", ""))
}
show("Passed tests:", Level.Debug, passed )
show("Failed tests:", Level.Error, failures)
show("Error during tests:", Level.Error, errors)
if(!failures.isEmpty || !errors.isEmpty)
error("Tests unsuccessful")
}
}

View File

@ -87,7 +87,8 @@ class XSbt(info: ProjectInfo) extends ParentProject(info) with NoCrossPaths
val stdTaskSub = testedBase(tasksPath / "standard", "Task System", taskSub, collectionSub, logSub, ioSub, processSub)
// The main integration project for sbt. It brings all of the subsystems together, configures them, and provides for overriding conventions.
val mainSub = baseProject("main", "Main",
buildSub, compileIncrementalSub, compilerSub, completeSub, discoverySub, ioSub, logSub, processSub, taskSub, stdTaskSub, runSub, trackingSub)
buildSub, compileIncrementalSub, compilerSub, completeSub, discoverySub,
ioSub, logSub, processSub, taskSub, stdTaskSub, runSub, trackingSub, testingSub)
// Strictly for bringing implicits and aliases from subsystems into the top-level sbt namespace through a single package object
val sbtSub = project(sbtPath, "Simple Build Tool", new Sbt(_), mainSub) // technically, we need a dependency on all of mainSub's dependencies, but we don't do that since this is strictly an integration project

View File

@ -81,6 +81,9 @@ sealed trait ProcessPipe
trait TaskExtra
{
final def nop: Task[Unit] = const( () )
final def const[T](t: T): Task[T] = task(t)
final def cross[T](key: AttributeKey[T])(values: T*): Task[T] =
CrossAction( for(v <- values) yield ( AttributeMap.empty put (key, v), task(v) ) )

View File

@ -3,14 +3,15 @@
*/
package sbt
import java.io.File
import java.net.URLClassLoader
import org.scalatools.testing.{AnnotatedFingerprint, Fingerprint, SubclassFingerprint, TestFingerprint}
import org.scalatools.testing.{Event, EventHandler, Framework, Runner, Runner2, Logger=>TLogger}
import classpath.{ClasspathUtilities, DualLoader, FilteredLoader}
object Result extends Enumeration
object TestResult extends Enumeration
{
val Error, Passed, Failed = Value
val Passed, Failed, Error = Value
}
object TestFrameworks
@ -55,7 +56,7 @@ final class TestRunner(framework: Framework, loader: ClassLoader, listeners: Seq
case _ => error("Framework '" + framework + "' does not support test '" + testDefinition + "'")
}
final def run(testDefinition: TestDefinition, args: Seq[String]): Result.Value =
final def run(testDefinition: TestDefinition, args: Seq[String]): TestResult.Value =
{
log.debug("Running " + testDefinition + " with arguments " + args.mkString(", "))
val name = testDefinition.name
@ -73,7 +74,7 @@ final class TestRunner(framework: Framework, loader: ClassLoader, listeners: Seq
safeListenersCall(_.startGroup(name))
try
{
val result = runTest().getOrElse(Result.Passed)
val result = runTest().getOrElse(TestResult.Passed)
safeListenersCall(_.endGroup(name, result))
result
}
@ -81,7 +82,7 @@ final class TestRunner(framework: Framework, loader: ClassLoader, listeners: Seq
{
case e =>
safeListenersCall(_.endGroup(name, e))
Result.Error
TestResult.Error
}
}
@ -89,8 +90,6 @@ final class TestRunner(framework: Framework, loader: ClassLoader, listeners: Seq
TestFramework.safeForeach(listeners, log)(call)
}
final class NamedTestTask(val name: String, action: => Unit) { def run() = action }
object TestFramework
{
def getTests(framework: Framework): Seq[Fingerprint] =
@ -117,39 +116,28 @@ object TestFramework
case _ => false
}
import scala.collection.{immutable, Map, Set}
def testTasks(frameworks: Seq[TestFramework],
classpath: Iterable[Path],
scalaInstance: ScalaInstance,
def testTasks(frameworks: Seq[Framework],
testLoader: ClassLoader,
tests: Seq[TestDefinition],
log: Logger,
listeners: Seq[TestReportListener],
endErrorsEnabled: Boolean,
setup: Iterable[ClassLoader => Unit],
cleanup: Iterable[ClassLoader => Unit],
testArgsByFramework: Map[TestFramework, Seq[String]]):
(Iterable[NamedTestTask], Iterable[NamedTestTask], Iterable[NamedTestTask]) =
testArgsByFramework: Map[Framework, Seq[String]]):
(() => Unit, Iterable[(String, () => TestResult.Value)], TestResult.Value => () => Unit) =
{
val (loader, tempDir) = createTestLoader(classpath, scalaInstance)
val arguments = immutable.Map() ++
( for(framework <- frameworks; created <- framework.create(loader, log)) yield
(created, testArgsByFramework.getOrElse(framework, Nil)) )
val cleanTmp = (_: ClassLoader) => IO.delete(tempDir)
val mappedTests = testMap(arguments.keys.toList, tests, arguments)
val arguments = testArgsByFramework withDefaultValue Nil
val mappedTests = testMap(frameworks, tests, arguments)
if(mappedTests.isEmpty)
(new NamedTestTask(TestStartName, None) :: Nil, Nil, new NamedTestTask(TestFinishName, { log.info("No tests to run."); cleanTmp(loader) }) :: Nil )
(() => (), Nil, _ => () => log.info("No tests to run.") )
else
createTestTasks(loader, mappedTests, log, listeners, endErrorsEnabled, setup, Seq(cleanTmp) ++ cleanup)
createTestTasks(testLoader, mappedTests, log, listeners)
}
private def testMap(frameworks: Seq[Framework], tests: Seq[TestDefinition], args: Map[Framework, Seq[String]]):
immutable.Map[Framework, (Set[TestDefinition], Seq[String])] =
Map[Framework, (Set[TestDefinition], Seq[String])] =
{
import scala.collection.mutable.{HashMap, HashSet, Set}
val map = new HashMap[Framework, Set[TestDefinition]]
def assignTests(): Unit =
def assignTests()
{
for(test <- tests if !map.values.exists(_.contains(test)))
{
@ -160,70 +148,38 @@ object TestFramework
}
if(!frameworks.isEmpty)
assignTests()
(immutable.Map() ++ map) transform { (framework, tests) => (tests, args(framework)) }
map.toMap transform { (framework, tests) => (tests.toSet, args(framework)) };
}
private def createTasks[T](work: Iterable[T => Unit], baseName: String, input: T) =
work.toList.zipWithIndex.map{ case (work, index) => new NamedTestTask(baseName + " " + (index+1), work(input)) }
private def createTestTasks(loader: ClassLoader, tests: Map[Framework, (Set[TestDefinition], Seq[String])], log: Logger,
listeners: Seq[TestReportListener], endErrorsEnabled: Boolean, setup: Iterable[ClassLoader => Unit],
cleanup: Iterable[ClassLoader => Unit]) =
private def createTestTasks(loader: ClassLoader, tests: Map[Framework, (Set[TestDefinition], Seq[String])], log: Logger, listeners: Seq[TestReportListener]) =
{
val testsListeners = listeners.filter(_.isInstanceOf[TestsListener]).map(_.asInstanceOf[TestsListener])
def foreachListenerSafe(f: TestsListener => Unit): Unit = safeForeach(testsListeners, log)(f)
val testsListeners = listeners collect { case tl: TestsListener => tl }
def foreachListenerSafe(f: TestsListener => Unit): () => Unit = () => safeForeach(testsListeners, log)(f)
import Result.{Error,Passed,Failed}
object result
{
private[this] var value: Result.Value = Passed
def apply() = synchronized { value }
def update(v: Result.Value): Unit = synchronized { if(value != Error) value = v }
}
val startTask = new NamedTestTask(TestStartName, {foreachListenerSafe(_.doInit); None}) :: createTasks(setup, "Test setup", loader)
import TestResult.{Error,Passed,Failed}
val startTask = foreachListenerSafe(_.doInit)
val testTasks =
tests flatMap { case (framework, (testDefinitions, testArgs)) =>
tests.view flatMap { case (framework, (testDefinitions, testArgs)) =>
val runner = new TestRunner(framework, loader, listeners, log)
for(testDefinition <- testDefinitions) yield
{
def runTest() =
{
val oldLoader = Thread.currentThread.getContextClassLoader
Thread.currentThread.setContextClassLoader(loader)
try {
runner.run(testDefinition, testArgs) match
{
case Error => result() = Error; Some("ERROR occurred during testing.")
case Failed => result() = Failed; Some("Test FAILED")
case _ => None
}
}
finally {
Thread.currentThread.setContextClassLoader(oldLoader)
}
}
new NamedTestTask(testDefinition.name, runTest())
val runTest = () => withContextLoader(loader) { runner.run(testDefinition, testArgs) }
(testDefinition.name, runTest)
}
}
def end() =
{
foreachListenerSafe(_.doComplete(result()))
result() match
{
case Error => if(endErrorsEnabled) Some("ERROR occurred during testing.") else None
case Failed => if(endErrorsEnabled) Some("One or more tests FAILED.") else None
case Passed =>
{
log.info(" ")
log.info("All tests PASSED.")
None
}
}
}
val endTask = new NamedTestTask(TestFinishName, end() ) :: createTasks(cleanup, "Test cleanup", loader)
(startTask, testTasks, endTask)
val endTask = (result: TestResult.Value) => foreachListenerSafe(_.doComplete(result))
(startTask, testTasks.toList, endTask)
}
def createTestLoader(classpath: Iterable[Path], scalaInstance: ScalaInstance): (ClassLoader, Path) =
private[this] def withContextLoader[T](loader: ClassLoader)(eval: => T): T =
{
val oldLoader = Thread.currentThread.getContextClassLoader
Thread.currentThread.setContextClassLoader(loader)
try { eval } finally { Thread.currentThread.setContextClassLoader(oldLoader) }
}
def createTestLoader(classpath: Seq[File], scalaInstance: ScalaInstance): ClassLoader =
{
val filterCompilerLoader = new FilteredLoader(scalaInstance.loader, ScalaCompilerJarPackages)
val interfaceFilter = (name: String) => name.startsWith("org.scalatools.testing.")

View File

@ -15,7 +15,7 @@ trait TestReportListener
/** called if there was an error during test */
def endGroup(name: String, t: Throwable)
/** called if test completed */
def endGroup(name: String, result: Result.Value)
def endGroup(name: String, result: TestResult.Value)
/** Used by the test framework for logging test results*/
def contentLogger: Option[TLogger] = None
}
@ -25,23 +25,23 @@ trait TestsListener extends TestReportListener
/** called once, at beginning. */
def doInit
/** called once, at end. */
def doComplete(finalResult: Result.Value)
def doComplete(finalResult: TestResult.Value)
}
abstract class TestEvent extends NotNull
{
def result: Option[Result.Value]
def result: Option[TestResult.Value]
def detail: Seq[TEvent] = Nil
}
object TestEvent
{
def apply(events: Seq[TEvent]): TestEvent =
{
val overallResult = (Result.Passed /: events) { (sum, event) =>
val overallResult = (TestResult.Passed /: events) { (sum, event) =>
val result = event.result
if(sum == Result.Error || result == TResult.Error) Result.Error
else if(sum == Result.Failed || result == TResult.Failure) Result.Failed
else Result.Passed
if(sum == TestResult.Error || result == TResult.Error) TestResult.Error
else if(sum == TestResult.Failed || result == TResult.Failure) TestResult.Failed
else TestResult.Passed
}
new TestEvent {
val result = Some(overallResult)
@ -76,7 +76,7 @@ class TestLogger(val log: TLogger) extends TestsListener
log.trace(t)
log.error("Could not run test " + name + ": " + t.toString)
}
def endGroup(name: String, result: Result.Value) {}
def endGroup(name: String, result: TestResult.Value) {}
protected def count(event: TEvent): Unit =
{
event.result match
@ -95,15 +95,15 @@ class TestLogger(val log: TLogger) extends TestsListener
skipped = 0
}
/** called once, at end. */
def doComplete(finalResult: Result.Value): Unit =
def doComplete(finalResult: TestResult.Value): Unit =
{
val totalCount = failures + errors + skipped + passed
val postfix = ": Total " + totalCount + ", Failed " + failures + ", Errors " + errors + ", Passed " + passed + ", Skipped " + skipped
finalResult match
{
case Result.Error => log.error("Error" + postfix)
case Result.Passed => log.info("Passed: " + postfix)
case Result.Failed => log.error("Failed: " + postfix)
case TestResult.Error => log.error("Error" + postfix)
case TestResult.Passed => log.info("Passed: " + postfix)
case TestResult.Failed => log.error("Failed: " + postfix)
}
}
override def contentLogger: Option[TLogger] = Some(log)