mirror of https://github.com/sbt/sbt.git
Support java reflection with layered classloaders
Jave reflection did not work with layered classloaders if a dependency attempted to load a class that was below the dependency layer in the layered classloader hierarchy. The underlying problem was (in general) a call to Class.forName somewhere. If the classloader parameter is not specified, then Class.forName locates the ClassLoader for the caller using reflection. It ultimately delegates to that ClassLoader's loadClass method. With the previous LayeredClassLoader class, there was no way for that classloader to access a URL that was below it in the class loading hierarchy. I reworked LayeredClassLoader so that if it fails to load the class, it will check the Thread's context classloader and see if there are other LayeredClassLoader instances below it. If so, it will then check if any of those classloaders would be able to load the class by using findResource. If the descendant loader can load the class, then we manually load it with findClass.
This commit is contained in:
parent
a3cde88db4
commit
cc8c66c66d
|
|
@ -855,9 +855,12 @@ object Defaults extends BuildCommon {
|
|||
}
|
||||
val trl = (testResultLogger in (Test, test)).value
|
||||
val taskName = Project.showContextKey(state.value).show(resolvedScoped.value)
|
||||
val currentLoader = Thread.currentThread.getContextClassLoader
|
||||
try {
|
||||
Thread.currentThread.setContextClassLoader(testLoader.value)
|
||||
trl.run(streams.value.log, executeTests.value, taskName)
|
||||
} finally {
|
||||
Thread.currentThread.setContextClassLoader(currentLoader)
|
||||
close.foreach(_.apply())
|
||||
}
|
||||
},
|
||||
|
|
@ -1022,8 +1025,13 @@ object Defaults extends BuildCommon {
|
|||
)
|
||||
val taskName = display.show(resolvedScoped.value)
|
||||
val trl = testResultLogger.value
|
||||
val processed = output.map(out => trl.run(s.log, out, taskName))
|
||||
processed
|
||||
val currentLoader = Thread.currentThread.getContextClassLoader
|
||||
try {
|
||||
Thread.currentThread.setContextClassLoader(testLoader.value)
|
||||
output.map(out => trl.run(s.log, out, taskName))
|
||||
} finally {
|
||||
Thread.currentThread.setContextClassLoader(currentLoader)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -206,6 +206,7 @@ private[sbt] object ClassLoaders {
|
|||
parent: ClassLoader,
|
||||
resources: Map[String, String]
|
||||
) extends LayeredClassLoader(classpath, parent, resources, new File("/dev/null")) {
|
||||
override def findClass(name: String): Class[_] = throw new ClassNotFoundException(name)
|
||||
override def loadClass(name: String, resolve: Boolean): Class[_] = {
|
||||
val clazz = parent.loadClass(name)
|
||||
if (resolve) resolveClass(clazz)
|
||||
|
|
|
|||
|
|
@ -9,12 +9,14 @@ package sbt.internal
|
|||
|
||||
import java.io.File
|
||||
import java.net.URLClassLoader
|
||||
import java.{ util => jutil }
|
||||
import scala.collection.JavaConverters._
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
|
||||
import sbt.internal.inc.classpath._
|
||||
import sbt.io.IO
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable.ListBuffer
|
||||
|
||||
private[sbt] class LayeredClassLoader(
|
||||
classpath: Seq[File],
|
||||
parent: ClassLoader,
|
||||
|
|
@ -24,7 +26,7 @@ private[sbt] class LayeredClassLoader(
|
|||
with RawResources
|
||||
with NativeCopyLoader
|
||||
with AutoCloseable {
|
||||
private[this] val nativeLibs = new jutil.HashSet[File]().asScala
|
||||
private[this] val nativeLibs = new java.util.HashSet[File]().asScala
|
||||
override protected val config = new NativeCopyConfig(
|
||||
tempDir,
|
||||
classpath,
|
||||
|
|
@ -38,6 +40,50 @@ private[sbt] class LayeredClassLoader(
|
|||
l
|
||||
}
|
||||
}
|
||||
|
||||
private[this] val loaded = new ConcurrentHashMap[String, Class[_]]
|
||||
/*
|
||||
* Override findClass to memoize its result. We need to do this because in loadClass we will
|
||||
* delegate to findClass if the current LayeredClassLoader cannot load a class but it is a
|
||||
* descendant of the thread's context class loader and a class loader below it in the layering
|
||||
* hierarchy is able to load the required class. Unlike loadClass, findClass does not cache
|
||||
* the result which would make it possible to return multiple versions of the same class.
|
||||
*/
|
||||
override def findClass(name: String): Class[_] = loaded.get(name) match {
|
||||
case null =>
|
||||
val res = super.findClass(name)
|
||||
loaded.putIfAbsent(name, res) match {
|
||||
case null => res
|
||||
case clazz => clazz
|
||||
}
|
||||
case c => c
|
||||
}
|
||||
override def loadClass(name: String, resolve: Boolean): Class[_] = {
|
||||
try super.loadClass(name, resolve)
|
||||
catch {
|
||||
case e: ClassNotFoundException =>
|
||||
val loaders = new ListBuffer[LayeredClassLoader]
|
||||
var currentLoader: ClassLoader = Thread.currentThread.getContextClassLoader
|
||||
do {
|
||||
currentLoader match {
|
||||
case cl: LayeredClassLoader if cl != this => loaders.prepend(cl)
|
||||
case _ =>
|
||||
}
|
||||
currentLoader = currentLoader.getParent
|
||||
} while (currentLoader != null && currentLoader != this)
|
||||
if (currentLoader == this) {
|
||||
val resourceName = name.replace('.', '/').concat(".class")
|
||||
loaders
|
||||
.collectFirst {
|
||||
case l if l.findResource(resourceName) != null =>
|
||||
val res = l.findClass(name)
|
||||
if (resolve) l.resolveClass(res)
|
||||
res
|
||||
}
|
||||
.getOrElse(throw e)
|
||||
} else throw e
|
||||
}
|
||||
}
|
||||
override def close(): Unit = nativeLibs.foreach(NativeLibs.delete)
|
||||
override def toString: String = s"""LayeredClassLoader(
|
||||
| classpath =
|
||||
|
|
@ -48,7 +94,7 @@ private[sbt] class LayeredClassLoader(
|
|||
}
|
||||
|
||||
private[internal] object NativeLibs {
|
||||
private[this] val nativeLibs = new jutil.HashSet[File].asScala
|
||||
private[this] val nativeLibs = new java.util.HashSet[File].asScala
|
||||
ShutdownHooks.add(() => {
|
||||
nativeLibs.foreach(IO.delete)
|
||||
IO.deleteIfEmpty(nativeLibs.map(_.getParentFile).toSet)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,4 @@
|
|||
val dependency = project.settings(exportJars := true)
|
||||
val descendant = project.dependsOn(dependency).settings(
|
||||
libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.5" % "test"
|
||||
)
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
package reflection
|
||||
|
||||
import java.io._
|
||||
import scala.util.control.NonFatal
|
||||
|
||||
object Reflection {
|
||||
def roundTrip[A](a: A): A = {
|
||||
val baos = new ByteArrayOutputStream()
|
||||
val oos = new ObjectOutputStream(baos)
|
||||
oos.writeObject(a)
|
||||
oos.close()
|
||||
val bais = new ByteArrayInputStream(baos.toByteArray())
|
||||
val ois = new ObjectInputStream(bais)
|
||||
try ois.readObject().asInstanceOf[A]
|
||||
finally ois.close()
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
package test
|
||||
|
||||
class Foo extends Serializable {
|
||||
private[this] var value: Int = 0
|
||||
def getValue(): Int = value
|
||||
def setValue(newValue: Int): Unit = value = newValue
|
||||
override def equals(o: Any): Boolean = o match {
|
||||
case that: Foo => this.getValue() == that.getValue()
|
||||
case _ => false
|
||||
}
|
||||
override def hashCode: Int = value
|
||||
}
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
package test
|
||||
|
||||
import org.scalatest._
|
||||
|
||||
class ReflectionTest extends FlatSpec {
|
||||
val foo = new Foo
|
||||
foo.setValue(3)
|
||||
val newFoo = reflection.Reflection.roundTrip(foo)
|
||||
assert(newFoo == foo)
|
||||
assert(System.identityHashCode(newFoo) != System.identityHashCode(foo))
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
> set classLoaderLayeringStrategy := ClassLoaderLayeringStrategy.AllLibraryJars
|
||||
|
||||
> test
|
||||
|
||||
> testOnly
|
||||
|
||||
> set classLoaderLayeringStrategy := ClassLoaderLayeringStrategy.ScalaLibrary
|
||||
|
||||
> test
|
||||
|
||||
> testOnly
|
||||
|
||||
Loading…
Reference in New Issue