From 1e30cb820363b0876eb72c50626dfe24ce5ade99 Mon Sep 17 00:00:00 2001 From: Mark Harrah Date: Mon, 9 Nov 2009 23:29:19 -0500 Subject: [PATCH] Use correct ClassLoader for tests (build Scala version not definition Scala version) --- src/main/scala/sbt/DefaultProject.scala | 2 +- src/main/scala/sbt/ScalaProject.scala | 2 +- src/main/scala/sbt/TestFramework.scala | 17 ++++++++++------- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/main/scala/sbt/DefaultProject.scala b/src/main/scala/sbt/DefaultProject.scala index 21d7be00f..d0f9b9e93 100644 --- a/src/main/scala/sbt/DefaultProject.scala +++ b/src/main/scala/sbt/DefaultProject.scala @@ -193,7 +193,7 @@ abstract class BasicScalaProject extends ScalaProject with BasicDependencyProjec def compileOrder = BasicScalaProject.this.compileOrder protected def testClassNames(frameworks: Seq[TestFramework]) = { - val loader = new URLClassLoader(classpath.get.map(_.asURL).toSeq.toArray, getClass.getClassLoader) + val loader = TestFramework.createTestLoader(classpath.get, buildScalaInstance.loader) def getTestNames(framework: TestFramework): Seq[String] = framework.create(loader, log).toList.flatMap(_.tests.map(_.superClassName)) frameworks.flatMap(getTestNames) diff --git a/src/main/scala/sbt/ScalaProject.scala b/src/main/scala/sbt/ScalaProject.scala index dee6d3215..31d7101bb 100644 --- a/src/main/scala/sbt/ScalaProject.scala +++ b/src/main/scala/sbt/ScalaProject.scala @@ -276,7 +276,7 @@ trait ScalaProject extends SimpleScalaProject with FileTasks with MultiTaskProje } def includeTest(test: TestDefinition) = !excludeTestsSet.contains(test.testClassName) && testFilters.forall(filter => filter(test.testClassName)) val tests = HashSet.empty[TestDefinition] ++ analysis.allTests.filter(includeTest) - TestFramework.testTasks(frameworks, classpath.get, tests.toSeq, log, testListeners.readOnly, false, setup.readOnly, cleanup.readOnly) + TestFramework.testTasks(frameworks, classpath.get, buildScalaInstance.loader, tests.toSeq, log, testListeners.readOnly, false, setup.readOnly, cleanup.readOnly) } private def flatten[T](i: Iterable[Iterable[T]]) = i.flatMap(x => x) diff --git a/src/main/scala/sbt/TestFramework.scala b/src/main/scala/sbt/TestFramework.scala index 84368661f..8c36acc8a 100644 --- a/src/main/scala/sbt/TestFramework.scala +++ b/src/main/scala/sbt/TestFramework.scala @@ -61,10 +61,10 @@ final class TestRunner(framework: Framework, loader: ClassLoader, listeners: Seq final class NamedTestTask(val name: String, action: => Option[String]) extends NotNull { def run() = action } object TestFramework { - def runTests(frameworks: Seq[TestFramework], classpath: Iterable[Path], tests: Seq[TestDefinition], log: Logger, + def runTests(frameworks: Seq[TestFramework], classpath: Iterable[Path], scalaLoader: ClassLoader, tests: Seq[TestDefinition], log: Logger, listeners: Seq[TestReportListener]) = { - val (start, runTests, end) = testTasks(frameworks, classpath, tests, log, listeners, true, Nil, Nil) + val (start, runTests, end) = testTasks(frameworks, classpath, scalaLoader, tests, log, listeners, true, Nil, Nil) def run(tasks: Iterable[NamedTestTask]) = tasks.foreach(_.run()) run(start) run(runTests) @@ -81,11 +81,11 @@ object TestFramework import scala.collection.{Map, Set} - def testTasks(frameworks: Seq[TestFramework], classpath: Iterable[Path], tests: Seq[TestDefinition], log: Logger, + def testTasks(frameworks: Seq[TestFramework], classpath: Iterable[Path], scalaLoader: ClassLoader, tests: Seq[TestDefinition], log: Logger, listeners: Seq[TestReportListener], endErrorsEnabled: Boolean, setup: Iterable[() => Option[String]], cleanup: Iterable[() => Option[String]]): (Iterable[NamedTestTask], Iterable[NamedTestTask], Iterable[NamedTestTask]) = { - val loader = createTestLoader(classpath) + val loader = createTestLoader(classpath, scalaLoader) val rawFrameworks = frameworks.flatMap(_.create(loader, log)) val mappedTests = testMap(rawFrameworks, tests) if(mappedTests.isEmpty) @@ -175,9 +175,12 @@ object TestFramework val endTask = new NamedTestTask(TestFinishName, end() ) :: createTasks(cleanup, "Test cleanup") (startTask, testTasks, endTask) } - private def createTestLoader(classpath: Iterable[Path]): ClassLoader = + def createTestLoader(classpath: Iterable[Path], scalaLoader: ClassLoader): ClassLoader = { - val filterCompilerLoader = new FilteredLoader(getClass.getClassLoader, ScalaCompilerJarPackages) - new URLClassLoader(classpath.map(_.asURL).toSeq.toArray, filterCompilerLoader) + val filterCompilerLoader = new FilteredLoader(scalaLoader, ScalaCompilerJarPackages) + val interfaceFilter = (name: String) => name.startsWith("org.scalatools.testing.") + val notInterfaceFilter = (name: String) => !interfaceFilter(name) + val dual = new xsbt.DualLoader(filterCompilerLoader, notInterfaceFilter, x => true, getClass.getClassLoader, interfaceFilter, x => false) + new URLClassLoader(classpath.map(_.asURL).toSeq.toArray, dual) } }