From cc8c66c66d2b67d4491f07cabf862b06f8381b7d Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Mon, 3 Jun 2019 09:42:33 -0700 Subject: [PATCH] Support java reflection with layered classloaders Jave reflection did not work with layered classloaders if a dependency attempted to load a class that was below the dependency layer in the layered classloader hierarchy. The underlying problem was (in general) a call to Class.forName somewhere. If the classloader parameter is not specified, then Class.forName locates the ClassLoader for the caller using reflection. It ultimately delegates to that ClassLoader's loadClass method. With the previous LayeredClassLoader class, there was no way for that classloader to access a URL that was below it in the class loading hierarchy. I reworked LayeredClassLoader so that if it fails to load the class, it will check the Thread's context classloader and see if there are other LayeredClassLoader instances below it. If so, it will then check if any of those classloaders would be able to load the class by using findResource. If the descendant loader can load the class, then we manually load it with findClass. --- main/src/main/scala/sbt/Defaults.scala | 12 ++++- .../scala/sbt/internal/ClassLoaders.scala | 1 + .../sbt/internal/LayeredClassLoader.scala | 54 +++++++++++++++++-- .../java-serialization/build.sbt | 4 ++ .../main/scala/reflection/Reflection.scala | 17 ++++++ .../descendant/src/test/scala/test/Foo.scala | 12 +++++ .../src/test/scala/test/ReflectionTest.scala | 12 +++++ .../classloader-cache/java-serialization/test | 12 +++++ 8 files changed, 118 insertions(+), 6 deletions(-) create mode 100644 sbt/src/sbt-test/classloader-cache/java-serialization/build.sbt create mode 100644 sbt/src/sbt-test/classloader-cache/java-serialization/dependency/src/main/scala/reflection/Reflection.scala create mode 100644 sbt/src/sbt-test/classloader-cache/java-serialization/descendant/src/test/scala/test/Foo.scala create mode 100644 sbt/src/sbt-test/classloader-cache/java-serialization/descendant/src/test/scala/test/ReflectionTest.scala create mode 100644 sbt/src/sbt-test/classloader-cache/java-serialization/test diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index 4b6e11db9..79ad043c7 100755 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -855,9 +855,12 @@ object Defaults extends BuildCommon { } val trl = (testResultLogger in (Test, test)).value val taskName = Project.showContextKey(state.value).show(resolvedScoped.value) + val currentLoader = Thread.currentThread.getContextClassLoader try { + Thread.currentThread.setContextClassLoader(testLoader.value) trl.run(streams.value.log, executeTests.value, taskName) } finally { + Thread.currentThread.setContextClassLoader(currentLoader) close.foreach(_.apply()) } }, @@ -1022,8 +1025,13 @@ object Defaults extends BuildCommon { ) val taskName = display.show(resolvedScoped.value) val trl = testResultLogger.value - val processed = output.map(out => trl.run(s.log, out, taskName)) - processed + val currentLoader = Thread.currentThread.getContextClassLoader + try { + Thread.currentThread.setContextClassLoader(testLoader.value) + output.map(out => trl.run(s.log, out, taskName)) + } finally { + Thread.currentThread.setContextClassLoader(currentLoader) + } } } diff --git a/main/src/main/scala/sbt/internal/ClassLoaders.scala b/main/src/main/scala/sbt/internal/ClassLoaders.scala index 937516a0c..54e3895ad 100644 --- a/main/src/main/scala/sbt/internal/ClassLoaders.scala +++ b/main/src/main/scala/sbt/internal/ClassLoaders.scala @@ -206,6 +206,7 @@ private[sbt] object ClassLoaders { parent: ClassLoader, resources: Map[String, String] ) extends LayeredClassLoader(classpath, parent, resources, new File("/dev/null")) { + override def findClass(name: String): Class[_] = throw new ClassNotFoundException(name) override def loadClass(name: String, resolve: Boolean): Class[_] = { val clazz = parent.loadClass(name) if (resolve) resolveClass(clazz) diff --git a/main/src/main/scala/sbt/internal/LayeredClassLoader.scala b/main/src/main/scala/sbt/internal/LayeredClassLoader.scala index 6ca12c6a7..095fb42f1 100644 --- a/main/src/main/scala/sbt/internal/LayeredClassLoader.scala +++ b/main/src/main/scala/sbt/internal/LayeredClassLoader.scala @@ -9,12 +9,14 @@ package sbt.internal import java.io.File import java.net.URLClassLoader -import java.{ util => jutil } -import scala.collection.JavaConverters._ +import java.util.concurrent.ConcurrentHashMap import sbt.internal.inc.classpath._ import sbt.io.IO +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer + private[sbt] class LayeredClassLoader( classpath: Seq[File], parent: ClassLoader, @@ -24,7 +26,7 @@ private[sbt] class LayeredClassLoader( with RawResources with NativeCopyLoader with AutoCloseable { - private[this] val nativeLibs = new jutil.HashSet[File]().asScala + private[this] val nativeLibs = new java.util.HashSet[File]().asScala override protected val config = new NativeCopyConfig( tempDir, classpath, @@ -38,6 +40,50 @@ private[sbt] class LayeredClassLoader( l } } + + private[this] val loaded = new ConcurrentHashMap[String, Class[_]] + /* + * Override findClass to memoize its result. We need to do this because in loadClass we will + * delegate to findClass if the current LayeredClassLoader cannot load a class but it is a + * descendant of the thread's context class loader and a class loader below it in the layering + * hierarchy is able to load the required class. Unlike loadClass, findClass does not cache + * the result which would make it possible to return multiple versions of the same class. + */ + override def findClass(name: String): Class[_] = loaded.get(name) match { + case null => + val res = super.findClass(name) + loaded.putIfAbsent(name, res) match { + case null => res + case clazz => clazz + } + case c => c + } + override def loadClass(name: String, resolve: Boolean): Class[_] = { + try super.loadClass(name, resolve) + catch { + case e: ClassNotFoundException => + val loaders = new ListBuffer[LayeredClassLoader] + var currentLoader: ClassLoader = Thread.currentThread.getContextClassLoader + do { + currentLoader match { + case cl: LayeredClassLoader if cl != this => loaders.prepend(cl) + case _ => + } + currentLoader = currentLoader.getParent + } while (currentLoader != null && currentLoader != this) + if (currentLoader == this) { + val resourceName = name.replace('.', '/').concat(".class") + loaders + .collectFirst { + case l if l.findResource(resourceName) != null => + val res = l.findClass(name) + if (resolve) l.resolveClass(res) + res + } + .getOrElse(throw e) + } else throw e + } + } override def close(): Unit = nativeLibs.foreach(NativeLibs.delete) override def toString: String = s"""LayeredClassLoader( | classpath = @@ -48,7 +94,7 @@ private[sbt] class LayeredClassLoader( } private[internal] object NativeLibs { - private[this] val nativeLibs = new jutil.HashSet[File].asScala + private[this] val nativeLibs = new java.util.HashSet[File].asScala ShutdownHooks.add(() => { nativeLibs.foreach(IO.delete) IO.deleteIfEmpty(nativeLibs.map(_.getParentFile).toSet) diff --git a/sbt/src/sbt-test/classloader-cache/java-serialization/build.sbt b/sbt/src/sbt-test/classloader-cache/java-serialization/build.sbt new file mode 100644 index 000000000..ecc5e2814 --- /dev/null +++ b/sbt/src/sbt-test/classloader-cache/java-serialization/build.sbt @@ -0,0 +1,4 @@ +val dependency = project.settings(exportJars := true) +val descendant = project.dependsOn(dependency).settings( + libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.5" % "test" +) diff --git a/sbt/src/sbt-test/classloader-cache/java-serialization/dependency/src/main/scala/reflection/Reflection.scala b/sbt/src/sbt-test/classloader-cache/java-serialization/dependency/src/main/scala/reflection/Reflection.scala new file mode 100644 index 000000000..09853ea44 --- /dev/null +++ b/sbt/src/sbt-test/classloader-cache/java-serialization/dependency/src/main/scala/reflection/Reflection.scala @@ -0,0 +1,17 @@ +package reflection + +import java.io._ +import scala.util.control.NonFatal + +object Reflection { + def roundTrip[A](a: A): A = { + val baos = new ByteArrayOutputStream() + val oos = new ObjectOutputStream(baos) + oos.writeObject(a) + oos.close() + val bais = new ByteArrayInputStream(baos.toByteArray()) + val ois = new ObjectInputStream(bais) + try ois.readObject().asInstanceOf[A] + finally ois.close() + } +} diff --git a/sbt/src/sbt-test/classloader-cache/java-serialization/descendant/src/test/scala/test/Foo.scala b/sbt/src/sbt-test/classloader-cache/java-serialization/descendant/src/test/scala/test/Foo.scala new file mode 100644 index 000000000..ea880687e --- /dev/null +++ b/sbt/src/sbt-test/classloader-cache/java-serialization/descendant/src/test/scala/test/Foo.scala @@ -0,0 +1,12 @@ +package test + +class Foo extends Serializable { + private[this] var value: Int = 0 + def getValue(): Int = value + def setValue(newValue: Int): Unit = value = newValue + override def equals(o: Any): Boolean = o match { + case that: Foo => this.getValue() == that.getValue() + case _ => false + } + override def hashCode: Int = value +} \ No newline at end of file diff --git a/sbt/src/sbt-test/classloader-cache/java-serialization/descendant/src/test/scala/test/ReflectionTest.scala b/sbt/src/sbt-test/classloader-cache/java-serialization/descendant/src/test/scala/test/ReflectionTest.scala new file mode 100644 index 000000000..a1988034d --- /dev/null +++ b/sbt/src/sbt-test/classloader-cache/java-serialization/descendant/src/test/scala/test/ReflectionTest.scala @@ -0,0 +1,12 @@ +package test + +import org.scalatest._ + +class ReflectionTest extends FlatSpec { + val foo = new Foo + foo.setValue(3) + val newFoo = reflection.Reflection.roundTrip(foo) + assert(newFoo == foo) + assert(System.identityHashCode(newFoo) != System.identityHashCode(foo)) +} + diff --git a/sbt/src/sbt-test/classloader-cache/java-serialization/test b/sbt/src/sbt-test/classloader-cache/java-serialization/test new file mode 100644 index 000000000..78ba31f24 --- /dev/null +++ b/sbt/src/sbt-test/classloader-cache/java-serialization/test @@ -0,0 +1,12 @@ +> set classLoaderLayeringStrategy := ClassLoaderLayeringStrategy.AllLibraryJars + +> test + +> testOnly + +> set classLoaderLayeringStrategy := ClassLoaderLayeringStrategy.ScalaLibrary + +> test + +> testOnly +