mirror of https://github.com/sbt/sbt.git
Merge pull request #1163 from Duhemm/macro-arg-deps
Record dependencies on macro arguments
This commit is contained in:
commit
9533887797
|
|
@ -91,4 +91,42 @@ abstract class Compat
|
|||
private[this] def sourceCompatibilityOnly: Nothing = throw new RuntimeException("For source compatibility only: should not get here.")
|
||||
|
||||
private[this] final implicit def miscCompat(n: AnyRef): MiscCompat = new MiscCompat
|
||||
|
||||
object MacroExpansionOf {
|
||||
def unapply(tree: Tree): Option[Tree] = {
|
||||
|
||||
// MacroExpansionAttachment (MEA) compatibility for 2.8.x and 2.9.x
|
||||
object Compat {
|
||||
class MacroExpansionAttachment(val original: Tree)
|
||||
|
||||
// Trees have no attachments in 2.8.x and 2.9.x
|
||||
implicit def withAttachments(tree: Tree): WithAttachments = new WithAttachments(tree)
|
||||
class WithAttachments(val tree: Tree) {
|
||||
object EmptyAttachments {
|
||||
def all = Set.empty[Any]
|
||||
}
|
||||
val attachments = EmptyAttachments
|
||||
}
|
||||
}
|
||||
import Compat._
|
||||
|
||||
locally {
|
||||
// Wildcard imports are necessary since 2.8.x and 2.9.x don't have `MacroExpansionAttachment` at all
|
||||
import global._ // this is where MEA lives in 2.10.x
|
||||
|
||||
// `original` has been renamed to `expandee` in 2.11.x
|
||||
implicit def withExpandee(att: MacroExpansionAttachment): WithExpandee = new WithExpandee(att)
|
||||
class WithExpandee(att: MacroExpansionAttachment) {
|
||||
def expandee: Tree = att.original
|
||||
}
|
||||
|
||||
locally {
|
||||
import analyzer._ // this is where MEA lives in 2.11.x
|
||||
tree.attachments.all.collect {
|
||||
case att: MacroExpansionAttachment => att.expandee
|
||||
} headOption
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -146,6 +146,8 @@ final class Dependency(val global: CallbackGlobal) extends LocateClassFile
|
|||
deps.foreach(addDependency)
|
||||
case Template(parents, self, body) =>
|
||||
traverseTrees(body)
|
||||
case MacroExpansionOf(original) =>
|
||||
this.traverse(original)
|
||||
case other => ()
|
||||
}
|
||||
super.traverse(tree)
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ import scala.tools.nsc._
|
|||
* The tree walking algorithm walks into TypeTree.original explicitly.
|
||||
*
|
||||
*/
|
||||
class ExtractUsedNames[GlobalType <: CallbackGlobal](val global: GlobalType) {
|
||||
class ExtractUsedNames[GlobalType <: CallbackGlobal](val global: GlobalType) extends Compat {
|
||||
import global._
|
||||
|
||||
def extract(unit: CompilationUnit): Set[String] = {
|
||||
|
|
@ -53,30 +53,44 @@ class ExtractUsedNames[GlobalType <: CallbackGlobal](val global: GlobalType) {
|
|||
val symbolNameAsString = symbol.name.decode.trim
|
||||
namesBuffer += symbolNameAsString
|
||||
}
|
||||
def handleTreeNode(node: Tree): Unit = node match {
|
||||
case _: DefTree | _: Template => ()
|
||||
// turns out that Import node has a TermSymbol associated with it
|
||||
// I (Grzegorz) tried to understand why it's there and what does it represent but
|
||||
// that logic was introduced in 2005 without any justification I'll just ignore the
|
||||
// import node altogether and just process the selectors in the import node
|
||||
case Import(_, selectors: List[ImportSelector]) =>
|
||||
def usedNameInImportSelector(name: Name): Unit =
|
||||
if ((name != null) && (name != nme.WILDCARD)) namesBuffer += name.toString
|
||||
selectors foreach { selector =>
|
||||
usedNameInImportSelector(selector.name)
|
||||
usedNameInImportSelector(selector.rename)
|
||||
}
|
||||
// TODO: figure out whether we should process the original tree or walk the type
|
||||
// the argument for processing the original tree: we process what user wrote
|
||||
// the argument for processing the type: we catch all transformations that typer applies
|
||||
// to types but that might be a bad thing because it might expand aliases eagerly which
|
||||
// not what we need
|
||||
case t: TypeTree if t.original != null =>
|
||||
t.original.foreach(handleTreeNode)
|
||||
case t if t.hasSymbol && eligibleAsUsedName(t.symbol) =>
|
||||
addSymbol(t.symbol)
|
||||
case _ => ()
|
||||
|
||||
def handleTreeNode(node: Tree): Unit = {
|
||||
def handleMacroExpansion(original: Tree): Unit = original.foreach(handleTreeNode)
|
||||
|
||||
def handleClassicTreeNode(node: Tree): Unit = node match {
|
||||
case _: DefTree | _: Template => ()
|
||||
// turns out that Import node has a TermSymbol associated with it
|
||||
// I (Grzegorz) tried to understand why it's there and what does it represent but
|
||||
// that logic was introduced in 2005 without any justification I'll just ignore the
|
||||
// import node altogether and just process the selectors in the import node
|
||||
case Import(_, selectors: List[ImportSelector]) =>
|
||||
def usedNameInImportSelector(name: Name): Unit =
|
||||
if ((name != null) && (name != nme.WILDCARD)) namesBuffer += name.toString
|
||||
selectors foreach { selector =>
|
||||
usedNameInImportSelector(selector.name)
|
||||
usedNameInImportSelector(selector.rename)
|
||||
}
|
||||
// TODO: figure out whether we should process the original tree or walk the type
|
||||
// the argument for processing the original tree: we process what user wrote
|
||||
// the argument for processing the type: we catch all transformations that typer applies
|
||||
// to types but that might be a bad thing because it might expand aliases eagerly which
|
||||
// not what we need
|
||||
case t: TypeTree if t.original != null =>
|
||||
t.original.foreach(handleTreeNode)
|
||||
case t if t.hasSymbol && eligibleAsUsedName(t.symbol) =>
|
||||
addSymbol(t.symbol)
|
||||
case _ => ()
|
||||
}
|
||||
|
||||
node match {
|
||||
case MacroExpansionOf(original) =>
|
||||
handleClassicTreeNode(node)
|
||||
handleMacroExpansion(original)
|
||||
case _ =>
|
||||
handleClassicTreeNode(node)
|
||||
}
|
||||
}
|
||||
|
||||
tree.foreach(handleTreeNode)
|
||||
namesBuffer.toSet
|
||||
}
|
||||
|
|
|
|||
|
|
@ -65,6 +65,19 @@ class DependencySpecification extends Specification {
|
|||
inheritance('D) === Set('A, 'C)
|
||||
}
|
||||
|
||||
"Extracted source dependencies from macro arguments" in {
|
||||
val sourceDependencies = extractSourceDependenciesFromMacroArgument
|
||||
val memberRef = sourceDependencies.memberRef
|
||||
val inheritance = sourceDependencies.inheritance
|
||||
|
||||
memberRef('A) === Set('B, 'C)
|
||||
inheritance('A) === Set.empty
|
||||
memberRef('B) === Set.empty
|
||||
inheritance('B) === Set.empty
|
||||
memberRef('C) === Set.empty
|
||||
inheritance('C) === Set.empty
|
||||
}
|
||||
|
||||
private def extractSourceDependenciesPublic: ExtractedSourceDependencies = {
|
||||
val srcA = "class A"
|
||||
val srcB = "class B extends D[A]"
|
||||
|
|
@ -109,4 +122,25 @@ class DependencySpecification extends Specification {
|
|||
compilerForTesting.extractDependenciesFromSrcs('A -> srcA, 'B -> srcB, 'C -> srcC, 'D -> srcD)
|
||||
sourceDependencies
|
||||
}
|
||||
|
||||
private def extractSourceDependenciesFromMacroArgument: ExtractedSourceDependencies = {
|
||||
val srcA = "class A { println(B.printTree(C.foo)) }"
|
||||
val srcB = """
|
||||
|import scala.language.experimental.macros
|
||||
|import scala.reflect.macros._
|
||||
|object B {
|
||||
| def printTree(arg: Any) = macro printTreeImpl
|
||||
| def printTreeImpl(c: Context)(arg: c.Expr[Any]): c.Expr[String] = {
|
||||
| val argStr = arg.tree.toString
|
||||
| val literalStr = c.universe.Literal(c.universe.Constant(argStr))
|
||||
| c.Expr[String](literalStr)
|
||||
| }
|
||||
|}""".stripMargin
|
||||
val srcC = "object C { val foo = 1 }"
|
||||
|
||||
val compilerForTesting = new ScalaCompilerForUnitTesting(nameHashing = true)
|
||||
val sourceDependencies =
|
||||
compilerForTesting.extractDependenciesFromSrcs(List(Map('B -> srcB, 'C -> srcC), Map('A -> srcA)))
|
||||
sourceDependencies
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -53,15 +53,19 @@ class ScalaCompilerForUnitTesting(nameHashing: Boolean = false) {
|
|||
* dependencies between snippets. Source code snippets are identified by symbols. Each symbol should
|
||||
* be associated with one snippet only.
|
||||
*
|
||||
* Snippets can be grouped to be compiled together in the same compiler run. This is
|
||||
* useful to compile macros, which cannot be used in the same compilation run that
|
||||
* defines them.
|
||||
*
|
||||
* Symbols are used to express extracted dependencies between source code snippets. This way we have
|
||||
* file system-independent way of testing dependencies between source code "files".
|
||||
*/
|
||||
def extractDependenciesFromSrcs(srcs: (Symbol, String)*): ExtractedSourceDependencies = {
|
||||
val (symbolsForSrcs, rawSrcs) = srcs.unzip
|
||||
assert(symbolsForSrcs.distinct.size == symbolsForSrcs.size,
|
||||
s"Duplicate symbols for srcs detected: $symbolsForSrcs")
|
||||
val (tempSrcFiles, testCallback) = compileSrcs(rawSrcs: _*)
|
||||
val fileToSymbol = (tempSrcFiles zip symbolsForSrcs).toMap
|
||||
def extractDependenciesFromSrcs(srcs: List[Map[Symbol, String]]): ExtractedSourceDependencies = {
|
||||
val rawGroupedSrcs = srcs.map(_.values.toList).toList
|
||||
val symbols = srcs.map(_.keys).flatten
|
||||
val (tempSrcFiles, testCallback) = compileSrcs(rawGroupedSrcs)
|
||||
val fileToSymbol = (tempSrcFiles zip symbols).toMap
|
||||
|
||||
val memberRefFileDeps = testCallback.sourceDependencies collect {
|
||||
// false indicates that those dependencies are not introduced by inheritance
|
||||
case (target, src, false) => (src, target)
|
||||
|
|
@ -82,40 +86,64 @@ class ScalaCompilerForUnitTesting(nameHashing: Boolean = false) {
|
|||
// convert all collections to immutable variants
|
||||
multiMap.toMap.mapValues(_.toSet).withDefaultValue(Set.empty)
|
||||
}
|
||||
|
||||
ExtractedSourceDependencies(pairsToMultiMap(memberRefDeps), pairsToMultiMap(inheritanceDeps))
|
||||
}
|
||||
|
||||
def extractDependenciesFromSrcs(srcs: (Symbol, String)*): ExtractedSourceDependencies = {
|
||||
val symbols = srcs.map(_._1)
|
||||
assert(symbols.distinct.size == symbols.size,
|
||||
s"Duplicate symbols for srcs detected: $symbols")
|
||||
extractDependenciesFromSrcs(List(srcs.toMap))
|
||||
}
|
||||
|
||||
/**
|
||||
* Compiles given source code snippets written to a temporary files. Each snippet is
|
||||
* Compiles given source code snippets written to temporary files. Each snippet is
|
||||
* written to a separate temporary file.
|
||||
*
|
||||
* Snippets can be grouped to be compiled together in the same compiler run. This is
|
||||
* useful to compile macros, which cannot be used in the same compilation run that
|
||||
* defines them.
|
||||
*
|
||||
* The sequence of temporary files corresponding to passed snippets and analysis
|
||||
* callback is returned as a result.
|
||||
*/
|
||||
private def compileSrcs(srcs: String*): (Seq[File], TestCallback) = {
|
||||
private def compileSrcs(groupedSrcs: List[List[String]]): (Seq[File], TestCallback) = {
|
||||
withTemporaryDirectory { temp =>
|
||||
val analysisCallback = new TestCallback(nameHashing)
|
||||
val classesDir = new File(temp, "classes")
|
||||
classesDir.mkdir()
|
||||
val compiler = prepareCompiler(classesDir, analysisCallback)
|
||||
val run = new compiler.Run
|
||||
val srcFiles = srcs.toSeq.zipWithIndex map { case (src, i) =>
|
||||
val fileName = s"Test_$i.scala"
|
||||
prepareSrcFile(temp, fileName, src)
|
||||
|
||||
val compiler = prepareCompiler(classesDir, analysisCallback, classesDir.toString)
|
||||
|
||||
val files = for((compilationUnit, unitId) <- groupedSrcs.zipWithIndex) yield {
|
||||
val run = new compiler.Run
|
||||
val srcFiles = compilationUnit.toSeq.zipWithIndex map { case (src, i) =>
|
||||
val fileName = s"Test-$unitId-$i.scala"
|
||||
prepareSrcFile(temp, fileName, src)
|
||||
}
|
||||
val srcFilePaths = srcFiles.map(srcFile => srcFile.getAbsolutePath).toList
|
||||
|
||||
run.compile(srcFilePaths)
|
||||
|
||||
srcFilePaths.foreach(f => new File(f).delete)
|
||||
srcFiles
|
||||
}
|
||||
val srcFilePaths = srcFiles.map(srcFile => srcFile.getAbsolutePath).toList
|
||||
run.compile(srcFilePaths)
|
||||
(srcFiles, analysisCallback)
|
||||
(files.flatten.toSeq, analysisCallback)
|
||||
}
|
||||
}
|
||||
|
||||
private def compileSrcs(srcs: String*): (Seq[File], TestCallback) = {
|
||||
compileSrcs(List(srcs.toList))
|
||||
}
|
||||
|
||||
private def prepareSrcFile(baseDir: File, fileName: String, src: String): File = {
|
||||
val srcFile = new File(baseDir, fileName)
|
||||
sbt.IO.write(srcFile, src)
|
||||
srcFile
|
||||
}
|
||||
|
||||
private def prepareCompiler(outputDir: File, analysisCallback: AnalysisCallback): CachedCompiler0#Compiler = {
|
||||
private def prepareCompiler(outputDir: File, analysisCallback: AnalysisCallback, classpath: String = "."): CachedCompiler0#Compiler = {
|
||||
val args = Array.empty[String]
|
||||
object output extends SingleOutput {
|
||||
def outputDirectory: File = outputDir
|
||||
|
|
@ -123,6 +151,7 @@ class ScalaCompilerForUnitTesting(nameHashing: Boolean = false) {
|
|||
val weakLog = new WeakLog(ConsoleLogger(), ConsoleReporter)
|
||||
val cachedCompiler = new CachedCompiler0(args, output, weakLog, false)
|
||||
val settings = cachedCompiler.settings
|
||||
settings.classpath.value = classpath
|
||||
settings.usejavacp.value = true
|
||||
val scalaReporter = new ConsoleReporter(settings)
|
||||
val delegatingReporter = DelegatingReporter(settings, ConsoleReporter)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
package macro
|
||||
|
||||
object Client {
|
||||
Provider.printTree(Provider.printTree(Foo.str))
|
||||
}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
package macro
|
||||
|
||||
object Foo {
|
||||
def str: String = "abc"
|
||||
}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
package macro
|
||||
object Foo {
|
||||
}
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
package macro
|
||||
import scala.language.experimental.macros
|
||||
import scala.reflect.macros._
|
||||
|
||||
object Provider {
|
||||
def printTree(arg: Any) = macro printTreeImpl
|
||||
def printTreeImpl(c: Context)(arg: c.Expr[Any]): c.Expr[String] = {
|
||||
val argStr = arg.tree.toString
|
||||
val literalStr = c.universe.Literal(c.universe.Constant(argStr))
|
||||
c.Expr[String](literalStr)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
import sbt._
|
||||
import Keys._
|
||||
|
||||
object build extends Build {
|
||||
val defaultSettings = Seq(
|
||||
libraryDependencies <+= scalaVersion("org.scala-lang" % "scala-reflect" % _ ),
|
||||
incOptions := incOptions.value.withNameHashing(true)
|
||||
)
|
||||
|
||||
lazy val root = Project(
|
||||
base = file("."),
|
||||
id = "macro",
|
||||
aggregate = Seq(macroProvider, macroClient),
|
||||
settings = Defaults.defaultSettings ++ defaultSettings
|
||||
)
|
||||
|
||||
lazy val macroProvider = Project(
|
||||
base = file("macro-provider"),
|
||||
id = "macro-provider",
|
||||
settings = Defaults.defaultSettings ++ defaultSettings
|
||||
)
|
||||
|
||||
lazy val macroClient = Project(
|
||||
base = file("macro-client"),
|
||||
id = "macro-client",
|
||||
dependencies = Seq(macroProvider),
|
||||
settings = Defaults.defaultSettings ++ defaultSettings
|
||||
)
|
||||
}
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
> compile
|
||||
|
||||
# remove `Foo.str` which is an argument to a macro
|
||||
# (this macro itself that is an argument to another macro)
|
||||
$ copy-file macro-client/changes/Foo.scala macro-client/Foo.scala
|
||||
|
||||
# we should recompile Foo.scala first and then fail to compile Client.scala due to missing
|
||||
# `Foo.str`
|
||||
-> macro-client/compile
|
||||
|
||||
> clean
|
||||
|
||||
-> compile
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
package macro
|
||||
|
||||
object Client {
|
||||
Provider.printTree(Foo.str)
|
||||
}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
package macro
|
||||
|
||||
object Foo {
|
||||
def str: String = "abc"
|
||||
}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
package macro
|
||||
object Foo {
|
||||
}
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
package macro
|
||||
import scala.language.experimental.macros
|
||||
import scala.reflect.macros._
|
||||
|
||||
object Provider {
|
||||
def printTree(arg: Any) = macro printTreeImpl
|
||||
def printTreeImpl(c: Context)(arg: c.Expr[Any]): c.Expr[String] = {
|
||||
val argStr = arg.tree.toString
|
||||
val literalStr = c.universe.Literal(c.universe.Constant(argStr))
|
||||
c.Expr[String](literalStr)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
import sbt._
|
||||
import Keys._
|
||||
|
||||
object build extends Build {
|
||||
val defaultSettings = Seq(
|
||||
libraryDependencies <+= scalaVersion("org.scala-lang" % "scala-reflect" % _ ),
|
||||
incOptions := incOptions.value.withNameHashing(true)
|
||||
)
|
||||
|
||||
lazy val root = Project(
|
||||
base = file("."),
|
||||
id = "macro",
|
||||
aggregate = Seq(macroProvider, macroClient),
|
||||
settings = Defaults.defaultSettings ++ defaultSettings
|
||||
)
|
||||
|
||||
lazy val macroProvider = Project(
|
||||
base = file("macro-provider"),
|
||||
id = "macro-provider",
|
||||
settings = Defaults.defaultSettings ++ defaultSettings
|
||||
)
|
||||
|
||||
lazy val macroClient = Project(
|
||||
base = file("macro-client"),
|
||||
id = "macro-client",
|
||||
dependencies = Seq(macroProvider),
|
||||
settings = Defaults.defaultSettings ++ defaultSettings
|
||||
)
|
||||
}
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
> compile
|
||||
|
||||
# remove `Foo.str` which is an argument to a macro
|
||||
$ copy-file macro-client/changes/Foo.scala macro-client/Foo.scala
|
||||
|
||||
# we should recompile Foo.scala first and then fail to compile Client.scala due to missing
|
||||
# `Foo.str`
|
||||
-> macro-client/compile
|
||||
|
||||
> clean
|
||||
|
||||
-> compile
|
||||
Loading…
Reference in New Issue