From 1a122d380f045ee3877a2f5f243c4db300c73ac3 Mon Sep 17 00:00:00 2001 From: Mark Harrah Date: Mon, 13 Sep 2010 19:43:37 -0400 Subject: [PATCH] allow setup, cleanup functions to access ClassLoader used for testing --- sbt/src/main/scala/sbt/ScalaProject.scala | 10 +++++++--- testing/TestFramework.scala | 18 +++++++++--------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/sbt/src/main/scala/sbt/ScalaProject.scala b/sbt/src/main/scala/sbt/ScalaProject.scala index e2eacad59..9a3287aa9 100644 --- a/sbt/src/main/scala/sbt/ScalaProject.scala +++ b/sbt/src/main/scala/sbt/ScalaProject.scala @@ -69,8 +69,12 @@ trait ScalaProject extends SimpleScalaProject with FileTasks with MultiTaskProje trait PackageOption extends ActionOption trait TestOption extends ActionOption - case class TestSetup(setup: () => Option[String]) extends TestOption - case class TestCleanup(cleanup: () => Option[String]) extends TestOption + case class TestSetup(setup: ClassLoader => Unit) extends TestOption { + def this(setup: () => Unit) = this(_ => setup()) + } + case class TestCleanup(cleanup: ClassLoader => Unit) extends TestOption { + def this(setup: () => Unit) = this(_ => setup()) + } case class ExcludeTests(tests: Iterable[String]) extends TestOption case class TestListeners(listeners: Iterable[TestReportListener]) extends TestOption case class TestFilter(filterTest: String => Boolean) extends TestOption @@ -275,7 +279,7 @@ trait ScalaProject extends SimpleScalaProject with FileTasks with MultiTaskProje val testFilters = new ListBuffer[String => Boolean] val excludeTestsSet = new HashSet[String] - val setup, cleanup = new ListBuffer[() => Option[String]] + val setup, cleanup = new ListBuffer[ClassLoader => Unit] val testListeners = new ListBuffer[TestReportListener] val testArgsByFramework = Map[TestFramework, ListBuffer[String]]() def frameworkArgs(framework: TestFramework): ListBuffer[String] = diff --git a/testing/TestFramework.scala b/testing/TestFramework.scala index 5e0932e1a..e313f2001 100644 --- a/testing/TestFramework.scala +++ b/testing/TestFramework.scala @@ -126,8 +126,8 @@ object TestFramework log: Logger, listeners: Seq[TestReportListener], endErrorsEnabled: Boolean, - setup: Iterable[() => Unit], - cleanup: Iterable[() => Unit], + setup: Iterable[ClassLoader => Unit], + cleanup: Iterable[ClassLoader => Unit], testArgsByFramework: Map[TestFramework, Seq[String]]): (Iterable[NamedTestTask], Iterable[NamedTestTask], Iterable[NamedTestTask]) = { @@ -135,7 +135,7 @@ object TestFramework val arguments = immutable.Map() ++ ( for(framework <- frameworks; created <- framework.create(loader, log)) yield (created, testArgsByFramework.getOrElse(framework, Nil)) ) - val cleanTmp = () => IO.delete(tempDir) + val cleanTmp = (_: ClassLoader) => IO.delete(tempDir) val mappedTests = testMap(arguments.keys.toList, tests, arguments) if(mappedTests.isEmpty) @@ -162,12 +162,12 @@ object TestFramework assignTests() (immutable.Map() ++ map) transform { (framework, tests) => (tests, args(framework)) } } - private def createTasks(work: Iterable[() => Unit], baseName: String) = - work.toList.zipWithIndex.map{ case (work, index) => new NamedTestTask(baseName + " " + (index+1), work()) } + private def createTasks[T](work: Iterable[T => Unit], baseName: String, input: T) = + work.toList.zipWithIndex.map{ case (work, index) => new NamedTestTask(baseName + " " + (index+1), work(input)) } private def createTestTasks(loader: ClassLoader, tests: Map[Framework, (Set[TestDefinition], Seq[String])], log: Logger, - listeners: Seq[TestReportListener], endErrorsEnabled: Boolean, setup: Iterable[() => Unit], - cleanup: Iterable[() => Unit]) = + listeners: Seq[TestReportListener], endErrorsEnabled: Boolean, setup: Iterable[ClassLoader => Unit], + cleanup: Iterable[ClassLoader => Unit]) = { val testsListeners = listeners.filter(_.isInstanceOf[TestsListener]).map(_.asInstanceOf[TestsListener]) def foreachListenerSafe(f: TestsListener => Unit): Unit = safeForeach(testsListeners, log)(f) @@ -179,7 +179,7 @@ object TestFramework def apply() = synchronized { value } def update(v: Result.Value): Unit = synchronized { if(value != Error) value = v } } - val startTask = new NamedTestTask(TestStartName, {foreachListenerSafe(_.doInit); None}) :: createTasks(setup, "Test setup") + val startTask = new NamedTestTask(TestStartName, {foreachListenerSafe(_.doInit); None}) :: createTasks(setup, "Test setup", loader) val testTasks = tests flatMap { case (framework, (testDefinitions, testArgs)) => @@ -220,7 +220,7 @@ object TestFramework } } } - val endTask = new NamedTestTask(TestFinishName, end() ) :: createTasks(cleanup, "Test cleanup") + val endTask = new NamedTestTask(TestFinishName, end() ) :: createTasks(cleanup, "Test cleanup", loader) (startTask, testTasks, endTask) } def createTestLoader(classpath: Iterable[Path], scalaInstance: ScalaInstance): (ClassLoader, Path) =