diff --git a/main-settings/src/main/scala/sbt/Structure.scala b/main-settings/src/main/scala/sbt/Structure.scala index c8fb5995b..ff361d604 100644 --- a/main-settings/src/main/scala/sbt/Structure.scala +++ b/main-settings/src/main/scala/sbt/Structure.scala @@ -347,8 +347,6 @@ object Scoped: protected def onTask[A2](f: Task[A1] => Task[A2]): Initialize[Task[A2]] = init.apply(f) - def dependsOn(tasks: Initialize[? <: Task[?]]*): Initialize[Task[A1]] = - init.zipWith(tasks.asInstanceOf[Seq[Initialize[Task[?]]]].join)(_.dependsOn(_*)) def flatMapTaskValue[T](f: A1 => Task[T]): Initialize[Task[T]] = onTask(_.result flatMap (f compose successM)) def map[A2](f: A1 => A2): Initialize[Task[A2]] = @@ -363,6 +361,8 @@ object Scoped: def tagw(tags: (Tag, Int)*): Initialize[Task[A1]] = onTask(_.tagw(tags: _*)) // Task-specific extensions + def dependsOn(tasks: Initialize[? <: Task[?]]*): Initialize[Task[A1]] = + init.zipWith(tasks.asInstanceOf[Seq[Initialize[Task[?]]]].join)(_.dependsOn(_*)) def dependsOnTask[A2](task1: Initialize[Task[A2]]): Initialize[Task[A1]] = dependsOnSeq(Seq[AnyInitTask](task1.asInstanceOf[AnyInitTask])) def dependsOnSeq(tasks: Seq[AnyInitTask]): Initialize[Task[A1]] = @@ -408,6 +408,11 @@ object Scoped: def tagw(tags: (Tag, Int)*): Initialize[InputTask[A1]] = onTask(_.tagw(tags: _*)) // InputTask specific extensions + @targetName("dependsOnInitializeInputTask") + def dependsOn(tasks: Initialize[? <: Task[?]]*): Initialize[InputTask[A1]] = + init.zipWith(tasks.asInstanceOf[Seq[Initialize[Task[?]]]].join)((thisTask, deps) => + thisTask.mapTask(_.dependsOn(deps*)) + ) @targetName("dependsOnTaskInitializeInputTask") def dependsOnTask[B1](task1: Initialize[Task[B1]]): Initialize[InputTask[A1]] = dependsOnSeq(Seq[AnyInitTask](task1.asInstanceOf[AnyInitTask])) diff --git a/main-settings/src/main/scala/sbt/std/KeyMacro.scala b/main-settings/src/main/scala/sbt/std/KeyMacro.scala index b59a4c534..653388433 100644 --- a/main-settings/src/main/scala/sbt/std/KeyMacro.scala +++ b/main-settings/src/main/scala/sbt/std/KeyMacro.scala @@ -62,7 +62,11 @@ private[sbt] object KeyMacro: if term.isValDef then Expr(term.name) else errorAndAbort(errorMsg) - def enclosingTerm(using qctx: Quotes) = + private[sbt] def callerThis(using Quotes): Expr[Any] = + import quotes.reflect.* + This(enclosingClass).asExpr + + private def enclosingTerm(using qctx: Quotes) = import qctx.reflect._ def enclosingTerm0(sym: Symbol): Symbol = sym match @@ -70,4 +74,11 @@ private[sbt] object KeyMacro: case sym if !sym.isTerm => enclosingTerm0(sym.owner) case _ => sym enclosingTerm0(Symbol.spliceOwner) + + private def enclosingClass(using Quotes) = + import quotes.reflect.* + def rec(sym: Symbol): Symbol = + if sym.isClassDef then sym + else rec(sym.owner) + rec(Symbol.spliceOwner) end KeyMacro diff --git a/main/src/main/scala/sbt/ProjectMatrix.scala b/main/src/main/scala/sbt/ProjectMatrix.scala index 1b834ded7..dd7bb7d12 100644 --- a/main/src/main/scala/sbt/ProjectMatrix.scala +++ b/main/src/main/scala/sbt/ProjectMatrix.scala @@ -253,6 +253,7 @@ object ProjectMatrix { val plugins: Plugins, val transforms: Seq[Project => Project], val defAxes: Seq[VirtualAxis], + val pluginClassLoader: ClassLoader ) extends ProjectMatrix { self => lazy val resolvedMappings: ListMap[ProjectRow, Project] = resolveMappings private def resolveProjectIds: Map[ProjectRow, String] = { @@ -470,7 +471,7 @@ object ProjectMatrix { private def enableScalaJSPlugin(project: Project): Project = project.enablePlugins( - scalajsPlugin(this.getClass.getClassLoader).getOrElse( + scalajsPlugin.getOrElse( sys.error( """Scala.js plugin was not found. Add the sbt-scalajs plugin into project/plugins.sbt: | addSbtPlugin("org.scala-js" % "sbt-scalajs" % "x.y.z") @@ -517,9 +518,9 @@ object ProjectMatrix { override def defaultAxes(axes: VirtualAxis*): ProjectMatrix = copy(defAxes = axes.toSeq) - def scalajsPlugin(classLoader: ClassLoader): Try[AutoPlugin] = { + def scalajsPlugin: Try[AutoPlugin] = { import ReflectionUtil.* - withContextClassloader(classLoader) { loader => + withContextClassloader(pluginClassLoader) { loader => getSingletonObject[AutoPlugin](loader, "org.scalajs.sbtplugin.ScalaJSPlugin$") } } @@ -533,7 +534,7 @@ object ProjectMatrix { private def enableScalaNativePlugin(project: Project): Project = project.enablePlugins( - nativePlugin(this.getClass.getClassLoader).getOrElse( + nativePlugin.getOrElse( sys.error( """Scala Native plugin was not found. Add the sbt-scala-native plugin into project/plugins.sbt: | addSbtPlugin("org.scala-native" % "sbt-scala-native" % "x.y.z") @@ -577,9 +578,9 @@ object ProjectMatrix { project => configure(enableScalaNativePlugin(project)) ) - def nativePlugin(classLoader: ClassLoader): Try[AutoPlugin] = { + def nativePlugin: Try[AutoPlugin] = { import ReflectionUtil.* - withContextClassloader(classLoader) { loader => + withContextClassloader(pluginClassLoader) { loader => getSingletonObject[AutoPlugin](loader, "scala.scalanative.sbtplugin.ScalaNativePlugin$") } } @@ -678,6 +679,7 @@ object ProjectMatrix { plugins: Plugins = plugins, transforms: Seq[Project => Project] = transforms, defAxes: Seq[VirtualAxis] = defAxes, + pluginClassLoader: ClassLoader = pluginClassLoader, ): ProjectMatrix = { val matrix = unresolved( id, @@ -693,6 +695,7 @@ object ProjectMatrix { plugins, transforms, defAxes, + pluginClassLoader ) allMatrices(id) = matrix matrix @@ -700,7 +703,7 @@ object ProjectMatrix { } // called by macro - def apply(id: String, base: File): ProjectMatrix = { + def apply(id: String, base: File, pluginClassLoader: ClassLoader): ProjectMatrix = { val defaultDefAxes = Seq(VirtualAxis.jvm, VirtualAxis.scalaABIVersion("3.3.3")) val matrix = unresolved( id, @@ -715,7 +718,8 @@ object ProjectMatrix { Nil, Plugins.Empty, Nil, - defaultDefAxes + defaultDefAxes, + pluginClassLoader ) allMatrices(id) = matrix matrix @@ -735,6 +739,7 @@ object ProjectMatrix { plugins: Plugins, transforms: Seq[Project => Project], defAxes: Seq[VirtualAxis], + pluginClassLoader: ClassLoader ): ProjectMatrix = new ProjectMatrixDef( id, @@ -750,6 +755,7 @@ object ProjectMatrix { plugins, transforms, defAxes, + pluginClassLoader ) def lookupMatrix(local: LocalProjectMatrix): ProjectMatrix = { @@ -799,7 +805,8 @@ object ProjectMatrix { val name = std.KeyMacro.definingValName( "projectMatrix must be directly assigned to a val, such as `val x = projectMatrix`. Alternatively, you can use `sbt.ProjectMatrix.apply`" ) - '{ ProjectMatrix($name, new File($name)) } + val callerThis = std.KeyMacro.callerThis + '{ ProjectMatrix($name, new File($name), $callerThis.getClass.getClassLoader) } } }