From be04ad552d44fe5565326d98d661a6f7458d5d7e Mon Sep 17 00:00:00 2001 From: Adrien Piquerez Date: Mon, 30 Sep 2024 14:01:46 +0200 Subject: [PATCH] Use projectMatrix caller to resolve plugins --- .../src/main/scala/sbt/std/KeyMacro.scala | 13 +++++++++- main/src/main/scala/sbt/ProjectMatrix.scala | 25 ++++++++++++------- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/main-settings/src/main/scala/sbt/std/KeyMacro.scala b/main-settings/src/main/scala/sbt/std/KeyMacro.scala index b59a4c534..653388433 100644 --- a/main-settings/src/main/scala/sbt/std/KeyMacro.scala +++ b/main-settings/src/main/scala/sbt/std/KeyMacro.scala @@ -62,7 +62,11 @@ private[sbt] object KeyMacro: if term.isValDef then Expr(term.name) else errorAndAbort(errorMsg) - def enclosingTerm(using qctx: Quotes) = + private[sbt] def callerThis(using Quotes): Expr[Any] = + import quotes.reflect.* + This(enclosingClass).asExpr + + private def enclosingTerm(using qctx: Quotes) = import qctx.reflect._ def enclosingTerm0(sym: Symbol): Symbol = sym match @@ -70,4 +74,11 @@ private[sbt] object KeyMacro: case sym if !sym.isTerm => enclosingTerm0(sym.owner) case _ => sym enclosingTerm0(Symbol.spliceOwner) + + private def enclosingClass(using Quotes) = + import quotes.reflect.* + def rec(sym: Symbol): Symbol = + if sym.isClassDef then sym + else rec(sym.owner) + rec(Symbol.spliceOwner) end KeyMacro diff --git a/main/src/main/scala/sbt/ProjectMatrix.scala b/main/src/main/scala/sbt/ProjectMatrix.scala index 1b834ded7..dd7bb7d12 100644 --- a/main/src/main/scala/sbt/ProjectMatrix.scala +++ b/main/src/main/scala/sbt/ProjectMatrix.scala @@ -253,6 +253,7 @@ object ProjectMatrix { val plugins: Plugins, val transforms: Seq[Project => Project], val defAxes: Seq[VirtualAxis], + val pluginClassLoader: ClassLoader ) extends ProjectMatrix { self => lazy val resolvedMappings: ListMap[ProjectRow, Project] = resolveMappings private def resolveProjectIds: Map[ProjectRow, String] = { @@ -470,7 +471,7 @@ object ProjectMatrix { private def enableScalaJSPlugin(project: Project): Project = project.enablePlugins( - scalajsPlugin(this.getClass.getClassLoader).getOrElse( + scalajsPlugin.getOrElse( sys.error( """Scala.js plugin was not found. Add the sbt-scalajs plugin into project/plugins.sbt: | addSbtPlugin("org.scala-js" % "sbt-scalajs" % "x.y.z") @@ -517,9 +518,9 @@ object ProjectMatrix { override def defaultAxes(axes: VirtualAxis*): ProjectMatrix = copy(defAxes = axes.toSeq) - def scalajsPlugin(classLoader: ClassLoader): Try[AutoPlugin] = { + def scalajsPlugin: Try[AutoPlugin] = { import ReflectionUtil.* - withContextClassloader(classLoader) { loader => + withContextClassloader(pluginClassLoader) { loader => getSingletonObject[AutoPlugin](loader, "org.scalajs.sbtplugin.ScalaJSPlugin$") } } @@ -533,7 +534,7 @@ object ProjectMatrix { private def enableScalaNativePlugin(project: Project): Project = project.enablePlugins( - nativePlugin(this.getClass.getClassLoader).getOrElse( + nativePlugin.getOrElse( sys.error( """Scala Native plugin was not found. Add the sbt-scala-native plugin into project/plugins.sbt: | addSbtPlugin("org.scala-native" % "sbt-scala-native" % "x.y.z") @@ -577,9 +578,9 @@ object ProjectMatrix { project => configure(enableScalaNativePlugin(project)) ) - def nativePlugin(classLoader: ClassLoader): Try[AutoPlugin] = { + def nativePlugin: Try[AutoPlugin] = { import ReflectionUtil.* - withContextClassloader(classLoader) { loader => + withContextClassloader(pluginClassLoader) { loader => getSingletonObject[AutoPlugin](loader, "scala.scalanative.sbtplugin.ScalaNativePlugin$") } } @@ -678,6 +679,7 @@ object ProjectMatrix { plugins: Plugins = plugins, transforms: Seq[Project => Project] = transforms, defAxes: Seq[VirtualAxis] = defAxes, + pluginClassLoader: ClassLoader = pluginClassLoader, ): ProjectMatrix = { val matrix = unresolved( id, @@ -693,6 +695,7 @@ object ProjectMatrix { plugins, transforms, defAxes, + pluginClassLoader ) allMatrices(id) = matrix matrix @@ -700,7 +703,7 @@ object ProjectMatrix { } // called by macro - def apply(id: String, base: File): ProjectMatrix = { + def apply(id: String, base: File, pluginClassLoader: ClassLoader): ProjectMatrix = { val defaultDefAxes = Seq(VirtualAxis.jvm, VirtualAxis.scalaABIVersion("3.3.3")) val matrix = unresolved( id, @@ -715,7 +718,8 @@ object ProjectMatrix { Nil, Plugins.Empty, Nil, - defaultDefAxes + defaultDefAxes, + pluginClassLoader ) allMatrices(id) = matrix matrix @@ -735,6 +739,7 @@ object ProjectMatrix { plugins: Plugins, transforms: Seq[Project => Project], defAxes: Seq[VirtualAxis], + pluginClassLoader: ClassLoader ): ProjectMatrix = new ProjectMatrixDef( id, @@ -750,6 +755,7 @@ object ProjectMatrix { plugins, transforms, defAxes, + pluginClassLoader ) def lookupMatrix(local: LocalProjectMatrix): ProjectMatrix = { @@ -799,7 +805,8 @@ object ProjectMatrix { val name = std.KeyMacro.definingValName( "projectMatrix must be directly assigned to a val, such as `val x = projectMatrix`. Alternatively, you can use `sbt.ProjectMatrix.apply`" ) - '{ ProjectMatrix($name, new File($name)) } + val callerThis = std.KeyMacro.callerThis + '{ ProjectMatrix($name, new File($name), $callerThis.getClass.getClassLoader) } } }