From 2f99797bac4891f7347fbaab068522898078245e Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Sat, 11 Jan 2020 15:26:51 -0800 Subject: [PATCH] Fix RunFromSourceMain sbt.Package$ bug The main reason for having both the RunFromSourceMain and LauncherBased scripted tests was that RunFromSourceMain would fail for any test that ended up accessing the sbt.Package$ object. This commit fixes this bug by reworking the classloader generated by RunFromSourceMain to invoke sbt, switching from the classpath to jar classpath (by setting exportJars = true) and entering sbt by calling `new xMain().run` rather than `xMain.run`. The reason for switching to the jar classpath is that the jvm seems to have issues when there are two classes provided in different directories that have the same case insensitive name, e.g. `sbt.package$` and `sbt.Package$`. If those classes are instead provided in different jars, the jvm seems to be able to handle it. Exporting the jars is not enough though, I had to rework the ClassLoader created in the launch method to have a layout that was recognized by xMainConfiguration. I reimplemented the AppConfiguration in java so that it could bootstrap itself in a single jar classloader (the only needed jar is the Scripted. If we export the jars in the build, then the NoClassDefErrors for `sbt.Package$` go away during scripted tests using RunSourceFromMain. This might make running tests in subprojects slightly slower but I think its a worthy tradeoff. --- .../java/sbt/internal/ClassLoaderClose.java | 14 + .../sbt/internal/XMainConfiguration.scala | 11 +- project/Scripted.scala | 2 - .../test/scala/sbt/RunFromSourceMain.scala | 128 +---- .../scriptedtest/ScriptedLauncher.java | 464 ++++++++++++++++++ 5 files changed, 504 insertions(+), 115 deletions(-) create mode 100644 main/src/main/java/sbt/internal/ClassLoaderClose.java create mode 100644 sbt/src/test/scala/sbt/internal/scriptedtest/ScriptedLauncher.java diff --git a/main/src/main/java/sbt/internal/ClassLoaderClose.java b/main/src/main/java/sbt/internal/ClassLoaderClose.java new file mode 100644 index 000000000..70c7cf16c --- /dev/null +++ b/main/src/main/java/sbt/internal/ClassLoaderClose.java @@ -0,0 +1,14 @@ +/* + * sbt + * Copyright 2011 - 2018, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal; + +public class ClassLoaderClose { + public static void close(ClassLoader classLoader) throws Exception { + if (classLoader instanceof AutoCloseable) ((AutoCloseable) classLoader).close(); + } +} diff --git a/main/src/main/scala/sbt/internal/XMainConfiguration.scala b/main/src/main/scala/sbt/internal/XMainConfiguration.scala index 86a11ecd4..b0eb06751 100644 --- a/main/src/main/scala/sbt/internal/XMainConfiguration.scala +++ b/main/src/main/scala/sbt/internal/XMainConfiguration.scala @@ -11,6 +11,7 @@ import java.io.File import java.lang.reflect.InvocationTargetException import java.net.URL import java.util.concurrent.{ ExecutorService, Executors } +import ClassLoaderClose.close import sbt.plugins.{ CorePlugin, IvyPlugin, JvmPlugin } import sbt.util.LogExchange @@ -56,10 +57,6 @@ private[internal] object ClassLoaderWarmup { * in this file. */ private[sbt] class XMainConfiguration { - private def close(classLoader: ClassLoader): Unit = classLoader match { - case a: AutoCloseable => a.close() - case _ => - } def run(moduleName: String, configuration: xsbti.AppConfiguration): xsbti.MainResult = { val updatedConfiguration = if (configuration.provider.scalaProvider.launcher.topLoader.getClass.getCanonicalName @@ -86,9 +83,11 @@ private[sbt] class XMainConfiguration { private def makeConfiguration(configuration: xsbti.AppConfiguration): xsbti.AppConfiguration = { val baseLoader = classOf[XMainConfiguration].getClassLoader - val url = baseLoader.getResource("sbt/internal/XMainConfiguration.class") + val className = "sbt/internal/XMainConfiguration.class" + val url = baseLoader.getResource(className) + val path = url.toString.replaceAll(s"$className$$", "") val urlArray = new Array[URL](1) - urlArray(0) = new URL(url.getPath.replaceAll("[!][^!]*class", "")) + urlArray(0) = new URL(path) val topLoader = configuration.provider.scalaProvider.launcher.topLoader // This loader doesn't have the scala library in it so it's critical that none of the code // in this file use the scala library. diff --git a/project/Scripted.scala b/project/Scripted.scala index 04acff63e..e55949f4c 100644 --- a/project/Scripted.scala +++ b/project/Scripted.scala @@ -5,8 +5,6 @@ import java.lang.reflect.InvocationTargetException import sbt._ import sbt.internal.inc.ScalaInstance import sbt.internal.inc.classpath.{ ClasspathUtilities, FilteredLoader } -import sbt.ScriptedPlugin.autoImport._ -import sbt.util.Level object LocalScriptedPlugin extends AutoPlugin { override def requires = plugins.JvmPlugin diff --git a/sbt/src/test/scala/sbt/RunFromSourceMain.scala b/sbt/src/test/scala/sbt/RunFromSourceMain.scala index ec7b168aa..ba3e68746 100644 --- a/sbt/src/test/scala/sbt/RunFromSourceMain.scala +++ b/sbt/src/test/scala/sbt/RunFromSourceMain.scala @@ -7,15 +7,16 @@ package sbt -import sbt.util.LogExchange -import scala.annotation.tailrec import buildinfo.TestBuildInfo -import xsbti._ +import sbt.internal.scriptedtest.ScriptedLauncher +import sbt.util.LogExchange + +import scala.annotation.tailrec import scala.sys.process.Process object RunFromSourceMain { private val sbtVersion = TestBuildInfo.version - private val scalaVersion = "2.12.6" + private val scalaVersion = "2.12.10" def fork(workingDirectory: File): Process = { val fo = ForkOptions() @@ -53,22 +54,28 @@ object RunFromSourceMain { } private def runImpl(baseDir: File, args: Seq[String]): Option[(File, Seq[String])] = - try launch(getConf(baseDir, args)) map exit + try launch(baseDir, args) map exit catch { case r: xsbti.FullReload => Some((baseDir, r.arguments())) case scala.util.control.NonFatal(e) => e.printStackTrace(); errorAndExit(e.toString) } - @tailrec private def launch(conf: AppConfiguration): Option[Int] = - xMain.run(conf) match { - case e: xsbti.Exit => Some(e.code) - case _: xsbti.Continue => None - case r: xsbti.Reboot => launch(getConf(conf.baseDirectory(), r.arguments())) - case x => handleUnknownMainResult(x) + private def launch(baseDirectory: File, arguments: Seq[String]): Option[Int] = { + ScriptedLauncher + .launch( + scalaHome, + sbtVersion, + scalaVersion, + bootDirectory, + baseDirectory, + buildinfo.TestBuildInfo.fullClasspath.toArray, + arguments.toArray + ) + .orElse(null) match { + case null => None + case i if i == Int.MaxValue => None + case i => Some(i) } - - private val noGlobalLock = new GlobalLock { - def apply[T](lockFile: File, run: java.util.concurrent.Callable[T]) = run.call() } private lazy val bootDirectory: File = file(sys.props("user.home")) / ".sbt" / "boot" @@ -106,99 +113,6 @@ object RunFromSourceMain { } } - private def getConf(baseDir: File, args: Seq[String]): AppConfiguration = new AppConfiguration { - def baseDirectory = baseDir - def arguments = args.toArray - - def provider = new AppProvider { appProvider => - def scalaProvider = new ScalaProvider { scalaProvider => - def scalaOrg = "org.scala-lang" - def launcher = new Launcher { - def getScala(version: String) = getScala(version, "") - def getScala(version: String, reason: String) = getScala(version, reason, scalaOrg) - def getScala(version: String, reason: String, scalaOrg: String) = scalaProvider - def app(id: xsbti.ApplicationID, version: String) = appProvider - def topLoader = new java.net.URLClassLoader(Array(), null) - def globalLock = noGlobalLock - def bootDirectory = RunFromSourceMain.bootDirectory - def ivyHome: File = sys.props.get("sbt.ivy.home") match { - case Some(home) => file(home) - case _ => file(sys.props("user.home")) / ".ivy2" - } - case class PredefRepo(id: Predefined) extends PredefinedRepository - import Predefined._ - def ivyRepositories = Array(PredefRepo(Local), PredefRepo(MavenCentral)) - def appRepositories = Array(PredefRepo(Local), PredefRepo(MavenCentral)) - def isOverrideRepositories = false - def checksums = Array("sha1", "md5") - } - def version = scalaVersion - lazy val libDir: File = RunFromSourceMain.scalaHome / "lib" - def jar(name: String): File = libDir / s"$name.jar" - lazy val libraryJar = jar("scala-library") - lazy val compilerJar = jar("scala-compiler") - lazy val jars = { - assert(libDir.exists) - libDir.listFiles(f => !f.isDirectory && f.getName.endsWith(".jar")) - } - def loader = new java.net.URLClassLoader(jars map (_.toURI.toURL), null) - def app(id: xsbti.ApplicationID) = appProvider - } - - def id = ApplicationID( - "org.scala-sbt", - "sbt", - sbtVersion, - "sbt.xMain", - Seq("xsbti", "extra"), - CrossValue.Disabled, - Nil - ) - def appHome: File = scalaHome / id.groupID / id.name / id.version - - def mainClasspath = buildinfo.TestBuildInfo.fullClasspath.toArray - def loader = new java.net.URLClassLoader(mainClasspath map (_.toURI.toURL), null) - def entryPoint = classOf[xMain] - def mainClass = classOf[xMain] - def newMain = new xMain - - def components = new ComponentProvider { - def componentLocation(id: String) = appHome / id - def component(id: String) = IO.listFiles(componentLocation(id), _.isFile) - - def defineComponent(id: String, files: Array[File]) = { - val location = componentLocation(id) - if (location.exists) - sys error s"Cannot redefine component. ID: $id, files: ${files mkString ","}" - else { - copy(files.toList, location) - () - } - } - - def addToComponent(id: String, files: Array[File]) = - copy(files.toList, componentLocation(id)) - - def lockFile = appHome / "sbt.components.lock" - - private def copy(files: List[File], toDirectory: File): Boolean = - files exists (copy(_, toDirectory)) - - private def copy(file: File, toDirectory: File): Boolean = { - val to = toDirectory / file.getName - val missing = !to.exists - IO.copyFile(file, to) - missing - } - } - } - } - - private def handleUnknownMainResult(x: MainResult): Nothing = { - val clazz = if (x eq null) "" else " (class: " + x.getClass + ")" - errorAndExit("Invalid main result: " + x + clazz) - } - private def errorAndExit(msg: String): Nothing = { System.err.println(msg); exit(1) } private def exit(code: Int): Nothing = System.exit(code).asInstanceOf[Nothing] } diff --git a/sbt/src/test/scala/sbt/internal/scriptedtest/ScriptedLauncher.java b/sbt/src/test/scala/sbt/internal/scriptedtest/ScriptedLauncher.java new file mode 100644 index 000000000..f46c40ea3 --- /dev/null +++ b/sbt/src/test/scala/sbt/internal/scriptedtest/ScriptedLauncher.java @@ -0,0 +1,464 @@ +/* + * sbt + * Copyright 2011 - 2018, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal.scriptedtest; + +import java.io.File; +import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.net.MalformedURLException; +import java.net.URL; +import java.net.URLClassLoader; +import java.nio.file.Files; +import java.util.Optional; +import java.util.concurrent.Callable; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import xsbti.AppConfiguration; +import xsbti.AppMain; +import xsbti.AppProvider; +import xsbti.ApplicationID; +import xsbti.ComponentProvider; +import xsbti.CrossValue; +import xsbti.Exit; +import xsbti.GlobalLock; +import xsbti.Launcher; +import xsbti.MainResult; +import xsbti.Predefined; +import xsbti.PredefinedRepository; +import xsbti.Reboot; +import xsbti.Repository; +import xsbti.ScalaProvider; + +public class ScriptedLauncher { + private static URL URLForClass(final Class clazz) + throws MalformedURLException, ClassNotFoundException { + final String path = clazz.getCanonicalName().replace('.', '/') + ".class"; + final URL url = clazz.getClassLoader().getResource(path); + if (url == null) throw new ClassNotFoundException(clazz.getCanonicalName()); + return new URL(url.toString().replaceAll(path + "$", "")); + } + + public static Optional launch( + final File scalaHome, + final String sbtVersion, + final String scalaVersion, + final File bootDirectory, + final File baseDir, + final File[] classpath, + String[] args) + throws MalformedURLException, InvocationTargetException, ClassNotFoundException, + IllegalAccessException { + while (true) { + final URL configURL = URLForClass(xsbti.AppConfiguration.class); + final URL mainURL = URLForClass(sbt.xMain.class); + final URL scriptedURL = URLForClass(ScriptedLauncher.class); + final ClassLoader topLoader = + new URLClassLoader(new URL[] {configURL}, ClassLoader.getSystemClassLoader().getParent()); + final URLClassLoader loader = new URLClassLoader(new URL[] {mainURL, scriptedURL}, topLoader); + final ClassLoader previous = Thread.currentThread().getContextClassLoader(); + try { + Thread.currentThread().setContextClassLoader(loader); + final AtomicInteger result = new AtomicInteger(-1); + final AtomicReference newArguments = new AtomicReference<>(); + final Class clazz = loader.loadClass("sbt.internal.scriptedtest.ScriptedLauncher"); + Method method = null; + for (final Method m : clazz.getDeclaredMethods()) { + if (m.getName().equals("launchImpl")) method = m; + } + method.invoke( + null, + topLoader, + loader, + scalaHome, + sbtVersion, + scalaVersion, + bootDirectory, + baseDir, + classpath, + args, + result, + newArguments); + final int res = result.get(); + if (res >= 0) return res == Integer.MAX_VALUE ? Optional.empty() : Optional.of(res); + else args = newArguments.get(); + } finally { + try { + loader.close(); + } catch (final Exception e) { + } + Thread.currentThread().setContextClassLoader(previous); + } + } + } + + private static void copy(final File[] files, final File toDirectory) { + for (final File file : files) { + try { + Files.createDirectories(toDirectory.toPath()); + Files.copy(file.toPath(), toDirectory.toPath().resolve(file.getName())); + } catch (final IOException e) { + e.printStackTrace(System.err); + } + } + } + + @SuppressWarnings("unused") + public static void launchImpl( + final ClassLoader topLoader, + final ClassLoader loader, + final File scalaHome, + final String sbtVersion, + final String scalaVersion, + final File bootDirectory, + final File baseDir, + final File[] classpath, + final String[] args, + final AtomicInteger result, + final AtomicReference newArguments) + throws ClassNotFoundException, InvocationTargetException, IllegalAccessException, + InstantiationException { + final AppConfiguration conf = + getConf( + topLoader, + scalaHome, + sbtVersion, + scalaVersion, + bootDirectory, + baseDir, + classpath, + args); + final Class clazz = loader.loadClass("sbt.xMain"); + final Object instance = clazz.newInstance(); + Method run = null; + for (final Method m : clazz.getDeclaredMethods()) { + if (m.getName().equals("run")) run = m; + } + final Object runResult = run.invoke(instance, conf); + if (runResult instanceof xsbti.Reboot) newArguments.set(((Reboot) runResult).arguments()); + else { + if (runResult instanceof xsbti.Exit) { + result.set(((Exit) runResult).code()); + } else if (runResult instanceof xsbti.Continue) { + result.set(Integer.MAX_VALUE); + } else { + handleUnknownMainResult((MainResult) runResult); + } + } + } + + private static void handleUnknownMainResult(MainResult x) { + final String clazz = x == null ? "" : " (class: " + x.getClass() + ")"; + System.err.println("Invalid main result: " + x + clazz); + System.exit(1); + } + + public static AppConfiguration getConf( + final ClassLoader topLoader, + final File scalaHome, + final String sbtVersion, + final String scalaVersion, + final File bootDirectory, + final File baseDir, + final File[] classpath, + String[] args) { + + final File libDir = new File(scalaHome, "lib"); + final ApplicationID id = + new ApplicationID() { + @Override + public String groupID() { + return "org.scala-sbt"; + } + + @Override + public String name() { + return "sbt"; + } + + @Override + public String version() { + return sbtVersion; + } + + @Override + public String mainClass() { + return "sbt.xMain"; + } + + @Override + public String[] mainComponents() { + return new String[] {"xsbti", "extra"}; + } + + @Deprecated + @Override + public boolean crossVersioned() { + return false; + } + + @Override + public CrossValue crossVersionedValue() { + return CrossValue.Disabled; + } + + @Override + public File[] classpathExtra() { + return new File[0]; + } + }; + final File appHome = + scalaHome.toPath().resolve(id.groupID()).resolve(id.name()).resolve(id.version()).toFile(); + assert (libDir.exists()); + final File[] jars = libDir.listFiles(f -> f.isFile() && f.getName().endsWith(".jar")); + final URL[] urls = new URL[jars.length]; + for (int i = 0; i < jars.length; ++i) { + try { + urls[i] = jars[i].toURI().toURL(); + } catch (final IOException e) { + throw new RuntimeException(e); + } + } + return new AppConfiguration() { + @Override + public String[] arguments() { + return args; + } + + @Override + public File baseDirectory() { + return baseDir; + } + + @Override + public AppProvider provider() { + return new AppProvider() { + final AppProvider self = this; + final ScalaProvider scalaProvider = + new ScalaProvider() { + private final ScalaProvider sp = this; + private final String scalaOrg = "org.scala-lang"; + private final Repository[] repos = + new PredefinedRepository[] { + () -> Predefined.Local, () -> Predefined.MavenCentral + }; + private final Launcher launcher = + new Launcher() { + @Override + public ScalaProvider getScala(String version) { + return getScala(version, ""); + } + + @Override + public ScalaProvider getScala(String version, String reason) { + return getScala(version, reason, scalaOrg); + } + + @Override + public ScalaProvider getScala( + String version, String reason, String scalaOrg) { + return sp; + } + + @Override + public AppProvider app(ApplicationID id, String version) { + return self; + } + + @Override + public ClassLoader topLoader() { + return topLoader; + } + + class foo extends Throwable { + foo(final Exception e) { + super(e.getMessage(), null, true, false); + } + } + + @Override + public GlobalLock globalLock() { + return new GlobalLock() { + @Override + public T apply(File lockFile, Callable run) { + try { + return run.call(); + } catch (final Exception e) { + throw new RuntimeException(new foo(e)) { + @Override + public StackTraceElement[] getStackTrace() { + return new StackTraceElement[0]; + } + }; + } + } + }; + } + + @Override + public File bootDirectory() { + return bootDirectory; + } + + @Override + public Repository[] ivyRepositories() { + return repos; + } + + @Override + public Repository[] appRepositories() { + return repos; + } + + @Override + public boolean isOverrideRepositories() { + return false; + } + + @Override + public File ivyHome() { + final String home = System.getProperty("sbt.ivy.home"); + return home == null + ? new File(System.getProperty("user.home"), ".ivy2") + : new File(home); + } + + @Override + public String[] checksums() { + return new String[] {"sha1", "md5"}; + } + }; + + @Override + public Launcher launcher() { + return launcher; + } + + @Override + public String version() { + return scalaVersion; + } + + @Override + public ClassLoader loader() { + return new URLClassLoader(urls, topLoader); + } + + @Override + public File[] jars() { + return jars; + } + + @Deprecated + @Override + public File libraryJar() { + return new File(libDir, "scala-library.jar"); + } + + @Deprecated + @Override + public File compilerJar() { + return new File(libDir, "scala-compiler.jar"); + } + + @Override + public AppProvider app(ApplicationID id) { + return self; + } + }; + + @Override + public ScalaProvider scalaProvider() { + return scalaProvider; + } + + @Override + public ApplicationID id() { + return id; + } + + @Override + public ClassLoader loader() { + return new URLClassLoader(urls, topLoader); + } + + @Deprecated + @Override + public Class mainClass() { + return AppMain.class; + } + + @Override + public Class entryPoint() { + return AppMain.class; + } + + @Override + public AppMain newMain() { + try { + return (AppMain) loader().loadClass("sbt.xMain").newInstance(); + } catch (final Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public File[] mainClasspath() { + return classpath; + } + + @Override + public ComponentProvider components() { + return new ComponentProvider() { + @Override + public File componentLocation(String id) { + return new File(appHome, id); + } + + @Override + public File[] component(String componentID) { + final File dir = componentLocation(componentID); + final File[] files = dir.listFiles(File::isFile); + return files == null ? new File[0] : files; + } + + @Override + public void defineComponent(String componentID, File[] components) { + final File dir = componentLocation(componentID); + if (dir.exists()) { + final StringBuilder files = new StringBuilder(); + for (final File file : components) { + if (files.length() > 0) { + files.append(','); + } + files.append(file.toString()); + } + throw new RuntimeException( + "Cannot redefine component. ID: " + id + ", files: " + files); + } else { + copy(components, dir); + } + } + + @Override + public boolean addToComponent(String componentID, File[] components) { + copy(components, componentLocation(componentID)); + return false; + } + + @Override + public File lockFile() { + return new File(appHome, "sbt.components.lock"); + } + }; + } + }; + } + }; + } +}