diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index 7f955bb2a..b68e9e2c3 100755 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -647,7 +647,7 @@ object Defaults extends BuildCommon def consoleTask(classpath: TaskKey[Classpath], task: TaskKey[_]): Initialize[Task[Unit]] = (compilers in task, classpath in task, scalacOptions in task, initialCommands in task, cleanupCommands in task, taskTemporaryDirectory in task, scalaInstance in task, streams) map { (cs, cp, options, initCommands, cleanup, temp, si, s) => - val loader = sbt.classpath.ClasspathUtilities.makeLoader(data(cp), si.loader, si, IO.createUniqueDirectory(temp)) + val loader = sbt.classpath.ClasspathUtilities.makeLoader(data(cp), si, IO.createUniqueDirectory(temp)) (new Console(cs.scalac))(data(cp), options, loader, initCommands, cleanup)()(s.log).foreach(msg => error(msg)) println() } diff --git a/run/src/main/scala/sbt/Run.scala b/run/src/main/scala/sbt/Run.scala index a11c4cc00..29de50e7b 100644 --- a/run/src/main/scala/sbt/Run.scala +++ b/run/src/main/scala/sbt/Run.scala @@ -59,7 +59,7 @@ class Run(instance: ScalaInstance, trapExit: Boolean, nativeTmp: File) extends S private def run0(mainClassName: String, classpath: Seq[File], options: Seq[String], log: Logger) { log.debug(" Classpath:\n\t" + classpath.mkString("\n\t")) - val loader = ClasspathUtilities.makeLoader(classpath, instance.loader, instance, nativeTmp) + val loader = ClasspathUtilities.makeLoader(classpath, instance, nativeTmp) val main = getMainMethod(mainClassName, loader) invokeMain(loader, main, options) } diff --git a/testing/src/main/scala/sbt/TestFramework.scala b/testing/src/main/scala/sbt/TestFramework.scala index 13c268b40..7831e8f24 100644 --- a/testing/src/main/scala/sbt/TestFramework.scala +++ b/testing/src/main/scala/sbt/TestFramework.scala @@ -204,11 +204,11 @@ object TestFramework } def createTestLoader(classpath: Seq[File], scalaInstance: ScalaInstance, tempDir: File): ClassLoader = { - val declaresCompiler = classpath.exists(_.getName contains "scala-compiler") - val filterCompilerLoader = if(declaresCompiler) scalaInstance.loader else new FilteredLoader(scalaInstance.loader, ScalaCompilerJarPackages) + val interfaceJar = IO.classLocationFile(classOf[org.scalatools.testing.Framework]) val interfaceFilter = (name: String) => name.startsWith("org.scalatools.testing.") val notInterfaceFilter = (name: String) => !interfaceFilter(name) - val dual = new DualLoader(filterCompilerLoader, notInterfaceFilter, x => true, getClass.getClassLoader, interfaceFilter, x => false) - ClasspathUtilities.makeLoader(classpath, dual, scalaInstance, tempDir) + val dual = new DualLoader(scalaInstance.loader, notInterfaceFilter, x => true, getClass.getClassLoader, interfaceFilter, x => false) + val main = ClasspathUtilities.makeLoader(classpath, dual, scalaInstance, tempDir) + ClasspathUtilities.filterByClasspath(interfaceJar +: classpath, main) } } diff --git a/util/classpath/src/main/scala/sbt/classpath/ClassLoaders.scala b/util/classpath/src/main/scala/sbt/classpath/ClassLoaders.scala index 427250153..507b506a2 100644 --- a/util/classpath/src/main/scala/sbt/classpath/ClassLoaders.scala +++ b/util/classpath/src/main/scala/sbt/classpath/ClassLoaders.scala @@ -5,7 +5,8 @@ package sbt package classpath import java.io.File -import java.net.{URI, URL, URLClassLoader} +import java.net.{URL, URLClassLoader} +import annotation.tailrec /** This is a starting point for defining a custom ClassLoader. Override 'doLoadClass' to define * loading a class that has not yet been loaded.*/ @@ -33,7 +34,7 @@ abstract class LoaderBase(urls: Seq[URL], parent: ClassLoader) extends URLClassL } /** Searches self first before delegating to the parent.*/ -class SelfFirstLoader(classpath: Seq[URL], parent: ClassLoader) extends LoaderBase(classpath, parent) +final class SelfFirstLoader(classpath: Seq[URL], parent: ClassLoader) extends LoaderBase(classpath, parent) { @throws(classOf[ClassNotFoundException]) override final def doLoadClass(className: String): Class[_] = @@ -43,10 +44,51 @@ class SelfFirstLoader(classpath: Seq[URL], parent: ClassLoader) extends LoaderBa } } +/** Doesn't load any classes itself, but instead verifies that all classes loaded through `parent` either come from `root` or `classpath`.*/ +final class ClasspathFilter(parent: ClassLoader, root: ClassLoader, classpath: Set[File]) extends ClassLoader(parent) +{ + override def loadClass(className: String, resolve: Boolean): Class[_] = + { + val c = super.loadClass(className, resolve) + if(includeLoader(c.getClassLoader, root) || fromClasspath(c)) + c + else + throw new ClassNotFoundException(className) + } + private[this] def fromClasspath(c: Class[_]): Boolean = + try { onClasspath(IO.classLocation(c)) } + catch { case e: RuntimeException => false } + + private[this] def onClasspath(src: URL): Boolean = + (src eq null) || ( + IO.urlAsFile(src) match { + case Some(f) => classpath(f) + case None => false + } + ) + + override def getResource(name: String): URL = { + val u = super.getResource(name) + if(onClasspath(u)) u else null + } + + override def getResources(name: String): java.util.Enumeration[URL] = + { + import collection.convert.WrapAsScala.{enumerationAsScalaIterator => asIt} + import collection.convert.WrapAsJava.{asJavaEnumeration => asEn} + val us = super.getResources(name) + if(us ne null) asEn(asIt(us).filter(onClasspath)) else null + } + + @tailrec private[this] def includeLoader(c: ClassLoader, base: ClassLoader): Boolean = + (base ne null) && ( + (c eq base) || includeLoader(c, base.getParent) + ) +} /** Delegates class loading to `parent` for all classes included by `filter`. An attempt to load classes excluded by `filter` * results in a `ClassNotFoundException`.*/ -class FilteredLoader(parent: ClassLoader, filter: ClassFilter) extends ClassLoader(parent) +final class FilteredLoader(parent: ClassLoader, filter: ClassFilter) extends ClassLoader(parent) { require(parent != null) // included because a null parent is legitimate in Java def this(parent: ClassLoader, excludePackages: Iterable[String]) = this(parent, new ExcludePackagesFilter(excludePackages)) @@ -69,11 +111,11 @@ abstract class PackageFilter(packages: Iterable[String]) extends ClassFilter require(packages.forall(_.endsWith("."))) protected final def matches(className: String): Boolean = packages.exists(className.startsWith) } -class ExcludePackagesFilter(exclude: Iterable[String]) extends PackageFilter(exclude) +final class ExcludePackagesFilter(exclude: Iterable[String]) extends PackageFilter(exclude) { def include(className: String): Boolean = !matches(className) } -class IncludePackagesFilter(include: Iterable[String]) extends PackageFilter(include) +final class IncludePackagesFilter(include: Iterable[String]) extends PackageFilter(include) { def include(className: String): Boolean = matches(className) } diff --git a/util/classpath/src/main/scala/sbt/classpath/ClasspathUtilities.scala b/util/classpath/src/main/scala/sbt/classpath/ClasspathUtilities.scala index 2c055da21..c23b9f414 100644 --- a/util/classpath/src/main/scala/sbt/classpath/ClasspathUtilities.scala +++ b/util/classpath/src/main/scala/sbt/classpath/ClasspathUtilities.scala @@ -42,12 +42,13 @@ object ClasspathUtilities if (systemLoader ne null) parent(systemLoader) else parent(getClass.getClassLoader) } + lazy val xsbtiLoader = classOf[xsbti.Launcher].getClassLoader final val AppClassPath = "app.class.path" final val BootClassPath = "boot.class.path" def createClasspathResources(classpath: Seq[File], instance: ScalaInstance): Map[String,String] = - createClasspathResources(classpath ++ instance.jars, instance.jars) + createClasspathResources(classpath, instance.jars) def createClasspathResources(appPaths: Seq[File], bootPaths: Seq[File]): Map[String, String] = { @@ -55,13 +56,19 @@ object ClasspathUtilities Map( make(AppClassPath, appPaths), make(BootClassPath, bootPaths) ) } - def makeLoader[T](classpath: Seq[File], instance: ScalaInstance): ClassLoader = - makeLoader(classpath, instance.loader, instance) + private[sbt] def filterByClasspath(classpath: Seq[File], loader: ClassLoader): ClassLoader = + new ClasspathFilter(loader, xsbtiLoader, classpath.toSet) - def makeLoader[T](classpath: Seq[File], parent: ClassLoader, instance: ScalaInstance): ClassLoader = + def makeLoader(classpath: Seq[File], instance: ScalaInstance): ClassLoader = + filterByClasspath(classpath, makeLoader(classpath, instance.loader, instance)) + + def makeLoader(classpath: Seq[File], instance: ScalaInstance, nativeTemp: File): ClassLoader = + filterByClasspath(classpath, makeLoader(classpath, instance.loader, instance, nativeTemp)) + + def makeLoader(classpath: Seq[File], parent: ClassLoader, instance: ScalaInstance): ClassLoader = toLoader(classpath, parent, createClasspathResources(classpath, instance)) - def makeLoader[T](classpath: Seq[File], parent: ClassLoader, instance: ScalaInstance, nativeTemp: File): ClassLoader = + def makeLoader(classpath: Seq[File], parent: ClassLoader, instance: ScalaInstance, nativeTemp: File): ClassLoader = toLoader(classpath, parent, createClasspathResources(classpath, instance), nativeTemp) private[sbt] def printSource(c: Class[_]) =