diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index 4b6e11db9..79ad043c7 100755 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -855,9 +855,12 @@ object Defaults extends BuildCommon { } val trl = (testResultLogger in (Test, test)).value val taskName = Project.showContextKey(state.value).show(resolvedScoped.value) + val currentLoader = Thread.currentThread.getContextClassLoader try { + Thread.currentThread.setContextClassLoader(testLoader.value) trl.run(streams.value.log, executeTests.value, taskName) } finally { + Thread.currentThread.setContextClassLoader(currentLoader) close.foreach(_.apply()) } }, @@ -1022,8 +1025,13 @@ object Defaults extends BuildCommon { ) val taskName = display.show(resolvedScoped.value) val trl = testResultLogger.value - val processed = output.map(out => trl.run(s.log, out, taskName)) - processed + val currentLoader = Thread.currentThread.getContextClassLoader + try { + Thread.currentThread.setContextClassLoader(testLoader.value) + output.map(out => trl.run(s.log, out, taskName)) + } finally { + Thread.currentThread.setContextClassLoader(currentLoader) + } } } diff --git a/main/src/main/scala/sbt/internal/ClassLoaders.scala b/main/src/main/scala/sbt/internal/ClassLoaders.scala index 937516a0c..54e3895ad 100644 --- a/main/src/main/scala/sbt/internal/ClassLoaders.scala +++ b/main/src/main/scala/sbt/internal/ClassLoaders.scala @@ -206,6 +206,7 @@ private[sbt] object ClassLoaders { parent: ClassLoader, resources: Map[String, String] ) extends LayeredClassLoader(classpath, parent, resources, new File("/dev/null")) { + override def findClass(name: String): Class[_] = throw new ClassNotFoundException(name) override def loadClass(name: String, resolve: Boolean): Class[_] = { val clazz = parent.loadClass(name) if (resolve) resolveClass(clazz) diff --git a/main/src/main/scala/sbt/internal/LayeredClassLoader.scala b/main/src/main/scala/sbt/internal/LayeredClassLoader.scala index 6ca12c6a7..095fb42f1 100644 --- a/main/src/main/scala/sbt/internal/LayeredClassLoader.scala +++ b/main/src/main/scala/sbt/internal/LayeredClassLoader.scala @@ -9,12 +9,14 @@ package sbt.internal import java.io.File import java.net.URLClassLoader -import java.{ util => jutil } -import scala.collection.JavaConverters._ +import java.util.concurrent.ConcurrentHashMap import sbt.internal.inc.classpath._ import sbt.io.IO +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer + private[sbt] class LayeredClassLoader( classpath: Seq[File], parent: ClassLoader, @@ -24,7 +26,7 @@ private[sbt] class LayeredClassLoader( with RawResources with NativeCopyLoader with AutoCloseable { - private[this] val nativeLibs = new jutil.HashSet[File]().asScala + private[this] val nativeLibs = new java.util.HashSet[File]().asScala override protected val config = new NativeCopyConfig( tempDir, classpath, @@ -38,6 +40,50 @@ private[sbt] class LayeredClassLoader( l } } + + private[this] val loaded = new ConcurrentHashMap[String, Class[_]] + /* + * Override findClass to memoize its result. We need to do this because in loadClass we will + * delegate to findClass if the current LayeredClassLoader cannot load a class but it is a + * descendant of the thread's context class loader and a class loader below it in the layering + * hierarchy is able to load the required class. Unlike loadClass, findClass does not cache + * the result which would make it possible to return multiple versions of the same class. + */ + override def findClass(name: String): Class[_] = loaded.get(name) match { + case null => + val res = super.findClass(name) + loaded.putIfAbsent(name, res) match { + case null => res + case clazz => clazz + } + case c => c + } + override def loadClass(name: String, resolve: Boolean): Class[_] = { + try super.loadClass(name, resolve) + catch { + case e: ClassNotFoundException => + val loaders = new ListBuffer[LayeredClassLoader] + var currentLoader: ClassLoader = Thread.currentThread.getContextClassLoader + do { + currentLoader match { + case cl: LayeredClassLoader if cl != this => loaders.prepend(cl) + case _ => + } + currentLoader = currentLoader.getParent + } while (currentLoader != null && currentLoader != this) + if (currentLoader == this) { + val resourceName = name.replace('.', '/').concat(".class") + loaders + .collectFirst { + case l if l.findResource(resourceName) != null => + val res = l.findClass(name) + if (resolve) l.resolveClass(res) + res + } + .getOrElse(throw e) + } else throw e + } + } override def close(): Unit = nativeLibs.foreach(NativeLibs.delete) override def toString: String = s"""LayeredClassLoader( | classpath = @@ -48,7 +94,7 @@ private[sbt] class LayeredClassLoader( } private[internal] object NativeLibs { - private[this] val nativeLibs = new jutil.HashSet[File].asScala + private[this] val nativeLibs = new java.util.HashSet[File].asScala ShutdownHooks.add(() => { nativeLibs.foreach(IO.delete) IO.deleteIfEmpty(nativeLibs.map(_.getParentFile).toSet) diff --git a/sbt/src/sbt-test/classloader-cache/java-serialization/build.sbt b/sbt/src/sbt-test/classloader-cache/java-serialization/build.sbt new file mode 100644 index 000000000..ecc5e2814 --- /dev/null +++ b/sbt/src/sbt-test/classloader-cache/java-serialization/build.sbt @@ -0,0 +1,4 @@ +val dependency = project.settings(exportJars := true) +val descendant = project.dependsOn(dependency).settings( + libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.5" % "test" +) diff --git a/sbt/src/sbt-test/classloader-cache/java-serialization/dependency/src/main/scala/reflection/Reflection.scala b/sbt/src/sbt-test/classloader-cache/java-serialization/dependency/src/main/scala/reflection/Reflection.scala new file mode 100644 index 000000000..09853ea44 --- /dev/null +++ b/sbt/src/sbt-test/classloader-cache/java-serialization/dependency/src/main/scala/reflection/Reflection.scala @@ -0,0 +1,17 @@ +package reflection + +import java.io._ +import scala.util.control.NonFatal + +object Reflection { + def roundTrip[A](a: A): A = { + val baos = new ByteArrayOutputStream() + val oos = new ObjectOutputStream(baos) + oos.writeObject(a) + oos.close() + val bais = new ByteArrayInputStream(baos.toByteArray()) + val ois = new ObjectInputStream(bais) + try ois.readObject().asInstanceOf[A] + finally ois.close() + } +} diff --git a/sbt/src/sbt-test/classloader-cache/java-serialization/descendant/src/test/scala/test/Foo.scala b/sbt/src/sbt-test/classloader-cache/java-serialization/descendant/src/test/scala/test/Foo.scala new file mode 100644 index 000000000..ea880687e --- /dev/null +++ b/sbt/src/sbt-test/classloader-cache/java-serialization/descendant/src/test/scala/test/Foo.scala @@ -0,0 +1,12 @@ +package test + +class Foo extends Serializable { + private[this] var value: Int = 0 + def getValue(): Int = value + def setValue(newValue: Int): Unit = value = newValue + override def equals(o: Any): Boolean = o match { + case that: Foo => this.getValue() == that.getValue() + case _ => false + } + override def hashCode: Int = value +} \ No newline at end of file diff --git a/sbt/src/sbt-test/classloader-cache/java-serialization/descendant/src/test/scala/test/ReflectionTest.scala b/sbt/src/sbt-test/classloader-cache/java-serialization/descendant/src/test/scala/test/ReflectionTest.scala new file mode 100644 index 000000000..a1988034d --- /dev/null +++ b/sbt/src/sbt-test/classloader-cache/java-serialization/descendant/src/test/scala/test/ReflectionTest.scala @@ -0,0 +1,12 @@ +package test + +import org.scalatest._ + +class ReflectionTest extends FlatSpec { + val foo = new Foo + foo.setValue(3) + val newFoo = reflection.Reflection.roundTrip(foo) + assert(newFoo == foo) + assert(System.identityHashCode(newFoo) != System.identityHashCode(foo)) +} + diff --git a/sbt/src/sbt-test/classloader-cache/java-serialization/test b/sbt/src/sbt-test/classloader-cache/java-serialization/test new file mode 100644 index 000000000..78ba31f24 --- /dev/null +++ b/sbt/src/sbt-test/classloader-cache/java-serialization/test @@ -0,0 +1,12 @@ +> set classLoaderLayeringStrategy := ClassLoaderLayeringStrategy.AllLibraryJars + +> test + +> testOnly + +> set classLoaderLayeringStrategy := ClassLoaderLayeringStrategy.ScalaLibrary + +> test + +> testOnly +