Support java reflection with layered classloaders

Jave reflection did not work with layered classloaders if a dependency
attempted to load a class that was below the dependency layer in the
layered classloader hierarchy. The underlying problem was (in general) a
call to Class.forName somewhere. If the classloader parameter is not
specified, then Class.forName locates the ClassLoader for the caller
using reflection. It ultimately delegates to that ClassLoader's
loadClass method. With the previous LayeredClassLoader class, there was
no way for that classloader to access a URL that was below it in the
class loading hierarchy. I reworked LayeredClassLoader so that if it
fails to load the class, it will check the Thread's context classloader
and see if there are other LayeredClassLoader instances below it. If so,
it will then check if any of those classloaders would be able to load
the class by using findResource. If the descendant loader can load the
class, then we manually load it with findClass.
This commit is contained in:
Ethan Atkins 2019-06-03 09:42:33 -07:00
parent a3cde88db4
commit cc8c66c66d
8 changed files with 118 additions and 6 deletions

View File

@ -855,9 +855,12 @@ object Defaults extends BuildCommon {
}
val trl = (testResultLogger in (Test, test)).value
val taskName = Project.showContextKey(state.value).show(resolvedScoped.value)
val currentLoader = Thread.currentThread.getContextClassLoader
try {
Thread.currentThread.setContextClassLoader(testLoader.value)
trl.run(streams.value.log, executeTests.value, taskName)
} finally {
Thread.currentThread.setContextClassLoader(currentLoader)
close.foreach(_.apply())
}
},
@ -1022,8 +1025,13 @@ object Defaults extends BuildCommon {
)
val taskName = display.show(resolvedScoped.value)
val trl = testResultLogger.value
val processed = output.map(out => trl.run(s.log, out, taskName))
processed
val currentLoader = Thread.currentThread.getContextClassLoader
try {
Thread.currentThread.setContextClassLoader(testLoader.value)
output.map(out => trl.run(s.log, out, taskName))
} finally {
Thread.currentThread.setContextClassLoader(currentLoader)
}
}
}

View File

@ -206,6 +206,7 @@ private[sbt] object ClassLoaders {
parent: ClassLoader,
resources: Map[String, String]
) extends LayeredClassLoader(classpath, parent, resources, new File("/dev/null")) {
override def findClass(name: String): Class[_] = throw new ClassNotFoundException(name)
override def loadClass(name: String, resolve: Boolean): Class[_] = {
val clazz = parent.loadClass(name)
if (resolve) resolveClass(clazz)

View File

@ -9,12 +9,14 @@ package sbt.internal
import java.io.File
import java.net.URLClassLoader
import java.{ util => jutil }
import scala.collection.JavaConverters._
import java.util.concurrent.ConcurrentHashMap
import sbt.internal.inc.classpath._
import sbt.io.IO
import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
private[sbt] class LayeredClassLoader(
classpath: Seq[File],
parent: ClassLoader,
@ -24,7 +26,7 @@ private[sbt] class LayeredClassLoader(
with RawResources
with NativeCopyLoader
with AutoCloseable {
private[this] val nativeLibs = new jutil.HashSet[File]().asScala
private[this] val nativeLibs = new java.util.HashSet[File]().asScala
override protected val config = new NativeCopyConfig(
tempDir,
classpath,
@ -38,6 +40,50 @@ private[sbt] class LayeredClassLoader(
l
}
}
private[this] val loaded = new ConcurrentHashMap[String, Class[_]]
/*
* Override findClass to memoize its result. We need to do this because in loadClass we will
* delegate to findClass if the current LayeredClassLoader cannot load a class but it is a
* descendant of the thread's context class loader and a class loader below it in the layering
* hierarchy is able to load the required class. Unlike loadClass, findClass does not cache
* the result which would make it possible to return multiple versions of the same class.
*/
override def findClass(name: String): Class[_] = loaded.get(name) match {
case null =>
val res = super.findClass(name)
loaded.putIfAbsent(name, res) match {
case null => res
case clazz => clazz
}
case c => c
}
override def loadClass(name: String, resolve: Boolean): Class[_] = {
try super.loadClass(name, resolve)
catch {
case e: ClassNotFoundException =>
val loaders = new ListBuffer[LayeredClassLoader]
var currentLoader: ClassLoader = Thread.currentThread.getContextClassLoader
do {
currentLoader match {
case cl: LayeredClassLoader if cl != this => loaders.prepend(cl)
case _ =>
}
currentLoader = currentLoader.getParent
} while (currentLoader != null && currentLoader != this)
if (currentLoader == this) {
val resourceName = name.replace('.', '/').concat(".class")
loaders
.collectFirst {
case l if l.findResource(resourceName) != null =>
val res = l.findClass(name)
if (resolve) l.resolveClass(res)
res
}
.getOrElse(throw e)
} else throw e
}
}
override def close(): Unit = nativeLibs.foreach(NativeLibs.delete)
override def toString: String = s"""LayeredClassLoader(
| classpath =
@ -48,7 +94,7 @@ private[sbt] class LayeredClassLoader(
}
private[internal] object NativeLibs {
private[this] val nativeLibs = new jutil.HashSet[File].asScala
private[this] val nativeLibs = new java.util.HashSet[File].asScala
ShutdownHooks.add(() => {
nativeLibs.foreach(IO.delete)
IO.deleteIfEmpty(nativeLibs.map(_.getParentFile).toSet)

View File

@ -0,0 +1,4 @@
val dependency = project.settings(exportJars := true)
val descendant = project.dependsOn(dependency).settings(
libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.5" % "test"
)

View File

@ -0,0 +1,17 @@
package reflection
import java.io._
import scala.util.control.NonFatal
object Reflection {
def roundTrip[A](a: A): A = {
val baos = new ByteArrayOutputStream()
val oos = new ObjectOutputStream(baos)
oos.writeObject(a)
oos.close()
val bais = new ByteArrayInputStream(baos.toByteArray())
val ois = new ObjectInputStream(bais)
try ois.readObject().asInstanceOf[A]
finally ois.close()
}
}

View File

@ -0,0 +1,12 @@
package test
class Foo extends Serializable {
private[this] var value: Int = 0
def getValue(): Int = value
def setValue(newValue: Int): Unit = value = newValue
override def equals(o: Any): Boolean = o match {
case that: Foo => this.getValue() == that.getValue()
case _ => false
}
override def hashCode: Int = value
}

View File

@ -0,0 +1,12 @@
package test
import org.scalatest._
class ReflectionTest extends FlatSpec {
val foo = new Foo
foo.setValue(3)
val newFoo = reflection.Reflection.roundTrip(foo)
assert(newFoo == foo)
assert(System.identityHashCode(newFoo) != System.identityHashCode(foo))
}

View File

@ -0,0 +1,12 @@
> set classLoaderLayeringStrategy := ClassLoaderLayeringStrategy.AllLibraryJars
> test
> testOnly
> set classLoaderLayeringStrategy := ClassLoaderLayeringStrategy.ScalaLibrary
> test
> testOnly