Merge pull request #1163 from Duhemm/macro-arg-deps

Record dependencies on macro arguments
This commit is contained in:
eugene yokota 2014-03-21 17:58:06 -04:00
commit 9533887797
17 changed files with 291 additions and 41 deletions

View File

@ -91,4 +91,42 @@ abstract class Compat
private[this] def sourceCompatibilityOnly: Nothing = throw new RuntimeException("For source compatibility only: should not get here.")
private[this] final implicit def miscCompat(n: AnyRef): MiscCompat = new MiscCompat
object MacroExpansionOf {
def unapply(tree: Tree): Option[Tree] = {
// MacroExpansionAttachment (MEA) compatibility for 2.8.x and 2.9.x
object Compat {
class MacroExpansionAttachment(val original: Tree)
// Trees have no attachments in 2.8.x and 2.9.x
implicit def withAttachments(tree: Tree): WithAttachments = new WithAttachments(tree)
class WithAttachments(val tree: Tree) {
object EmptyAttachments {
def all = Set.empty[Any]
}
val attachments = EmptyAttachments
}
}
import Compat._
locally {
// Wildcard imports are necessary since 2.8.x and 2.9.x don't have `MacroExpansionAttachment` at all
import global._ // this is where MEA lives in 2.10.x
// `original` has been renamed to `expandee` in 2.11.x
implicit def withExpandee(att: MacroExpansionAttachment): WithExpandee = new WithExpandee(att)
class WithExpandee(att: MacroExpansionAttachment) {
def expandee: Tree = att.original
}
locally {
import analyzer._ // this is where MEA lives in 2.11.x
tree.attachments.all.collect {
case att: MacroExpansionAttachment => att.expandee
} headOption
}
}
}
}
}

View File

@ -146,6 +146,8 @@ final class Dependency(val global: CallbackGlobal) extends LocateClassFile
deps.foreach(addDependency)
case Template(parents, self, body) =>
traverseTrees(body)
case MacroExpansionOf(original) =>
this.traverse(original)
case other => ()
}
super.traverse(tree)

View File

@ -38,7 +38,7 @@ import scala.tools.nsc._
* The tree walking algorithm walks into TypeTree.original explicitly.
*
*/
class ExtractUsedNames[GlobalType <: CallbackGlobal](val global: GlobalType) {
class ExtractUsedNames[GlobalType <: CallbackGlobal](val global: GlobalType) extends Compat {
import global._
def extract(unit: CompilationUnit): Set[String] = {
@ -53,30 +53,44 @@ class ExtractUsedNames[GlobalType <: CallbackGlobal](val global: GlobalType) {
val symbolNameAsString = symbol.name.decode.trim
namesBuffer += symbolNameAsString
}
def handleTreeNode(node: Tree): Unit = node match {
case _: DefTree | _: Template => ()
// turns out that Import node has a TermSymbol associated with it
// I (Grzegorz) tried to understand why it's there and what does it represent but
// that logic was introduced in 2005 without any justification I'll just ignore the
// import node altogether and just process the selectors in the import node
case Import(_, selectors: List[ImportSelector]) =>
def usedNameInImportSelector(name: Name): Unit =
if ((name != null) && (name != nme.WILDCARD)) namesBuffer += name.toString
selectors foreach { selector =>
usedNameInImportSelector(selector.name)
usedNameInImportSelector(selector.rename)
}
// TODO: figure out whether we should process the original tree or walk the type
// the argument for processing the original tree: we process what user wrote
// the argument for processing the type: we catch all transformations that typer applies
// to types but that might be a bad thing because it might expand aliases eagerly which
// not what we need
case t: TypeTree if t.original != null =>
t.original.foreach(handleTreeNode)
case t if t.hasSymbol && eligibleAsUsedName(t.symbol) =>
addSymbol(t.symbol)
case _ => ()
def handleTreeNode(node: Tree): Unit = {
def handleMacroExpansion(original: Tree): Unit = original.foreach(handleTreeNode)
def handleClassicTreeNode(node: Tree): Unit = node match {
case _: DefTree | _: Template => ()
// turns out that Import node has a TermSymbol associated with it
// I (Grzegorz) tried to understand why it's there and what does it represent but
// that logic was introduced in 2005 without any justification I'll just ignore the
// import node altogether and just process the selectors in the import node
case Import(_, selectors: List[ImportSelector]) =>
def usedNameInImportSelector(name: Name): Unit =
if ((name != null) && (name != nme.WILDCARD)) namesBuffer += name.toString
selectors foreach { selector =>
usedNameInImportSelector(selector.name)
usedNameInImportSelector(selector.rename)
}
// TODO: figure out whether we should process the original tree or walk the type
// the argument for processing the original tree: we process what user wrote
// the argument for processing the type: we catch all transformations that typer applies
// to types but that might be a bad thing because it might expand aliases eagerly which
// not what we need
case t: TypeTree if t.original != null =>
t.original.foreach(handleTreeNode)
case t if t.hasSymbol && eligibleAsUsedName(t.symbol) =>
addSymbol(t.symbol)
case _ => ()
}
node match {
case MacroExpansionOf(original) =>
handleClassicTreeNode(node)
handleMacroExpansion(original)
case _ =>
handleClassicTreeNode(node)
}
}
tree.foreach(handleTreeNode)
namesBuffer.toSet
}

View File

@ -65,6 +65,19 @@ class DependencySpecification extends Specification {
inheritance('D) === Set('A, 'C)
}
"Extracted source dependencies from macro arguments" in {
val sourceDependencies = extractSourceDependenciesFromMacroArgument
val memberRef = sourceDependencies.memberRef
val inheritance = sourceDependencies.inheritance
memberRef('A) === Set('B, 'C)
inheritance('A) === Set.empty
memberRef('B) === Set.empty
inheritance('B) === Set.empty
memberRef('C) === Set.empty
inheritance('C) === Set.empty
}
private def extractSourceDependenciesPublic: ExtractedSourceDependencies = {
val srcA = "class A"
val srcB = "class B extends D[A]"
@ -109,4 +122,25 @@ class DependencySpecification extends Specification {
compilerForTesting.extractDependenciesFromSrcs('A -> srcA, 'B -> srcB, 'C -> srcC, 'D -> srcD)
sourceDependencies
}
private def extractSourceDependenciesFromMacroArgument: ExtractedSourceDependencies = {
val srcA = "class A { println(B.printTree(C.foo)) }"
val srcB = """
|import scala.language.experimental.macros
|import scala.reflect.macros._
|object B {
| def printTree(arg: Any) = macro printTreeImpl
| def printTreeImpl(c: Context)(arg: c.Expr[Any]): c.Expr[String] = {
| val argStr = arg.tree.toString
| val literalStr = c.universe.Literal(c.universe.Constant(argStr))
| c.Expr[String](literalStr)
| }
|}""".stripMargin
val srcC = "object C { val foo = 1 }"
val compilerForTesting = new ScalaCompilerForUnitTesting(nameHashing = true)
val sourceDependencies =
compilerForTesting.extractDependenciesFromSrcs(List(Map('B -> srcB, 'C -> srcC), Map('A -> srcA)))
sourceDependencies
}
}

View File

@ -53,15 +53,19 @@ class ScalaCompilerForUnitTesting(nameHashing: Boolean = false) {
* dependencies between snippets. Source code snippets are identified by symbols. Each symbol should
* be associated with one snippet only.
*
* Snippets can be grouped to be compiled together in the same compiler run. This is
* useful to compile macros, which cannot be used in the same compilation run that
* defines them.
*
* Symbols are used to express extracted dependencies between source code snippets. This way we have
* file system-independent way of testing dependencies between source code "files".
*/
def extractDependenciesFromSrcs(srcs: (Symbol, String)*): ExtractedSourceDependencies = {
val (symbolsForSrcs, rawSrcs) = srcs.unzip
assert(symbolsForSrcs.distinct.size == symbolsForSrcs.size,
s"Duplicate symbols for srcs detected: $symbolsForSrcs")
val (tempSrcFiles, testCallback) = compileSrcs(rawSrcs: _*)
val fileToSymbol = (tempSrcFiles zip symbolsForSrcs).toMap
def extractDependenciesFromSrcs(srcs: List[Map[Symbol, String]]): ExtractedSourceDependencies = {
val rawGroupedSrcs = srcs.map(_.values.toList).toList
val symbols = srcs.map(_.keys).flatten
val (tempSrcFiles, testCallback) = compileSrcs(rawGroupedSrcs)
val fileToSymbol = (tempSrcFiles zip symbols).toMap
val memberRefFileDeps = testCallback.sourceDependencies collect {
// false indicates that those dependencies are not introduced by inheritance
case (target, src, false) => (src, target)
@ -82,40 +86,64 @@ class ScalaCompilerForUnitTesting(nameHashing: Boolean = false) {
// convert all collections to immutable variants
multiMap.toMap.mapValues(_.toSet).withDefaultValue(Set.empty)
}
ExtractedSourceDependencies(pairsToMultiMap(memberRefDeps), pairsToMultiMap(inheritanceDeps))
}
def extractDependenciesFromSrcs(srcs: (Symbol, String)*): ExtractedSourceDependencies = {
val symbols = srcs.map(_._1)
assert(symbols.distinct.size == symbols.size,
s"Duplicate symbols for srcs detected: $symbols")
extractDependenciesFromSrcs(List(srcs.toMap))
}
/**
* Compiles given source code snippets written to a temporary files. Each snippet is
* Compiles given source code snippets written to temporary files. Each snippet is
* written to a separate temporary file.
*
* Snippets can be grouped to be compiled together in the same compiler run. This is
* useful to compile macros, which cannot be used in the same compilation run that
* defines them.
*
* The sequence of temporary files corresponding to passed snippets and analysis
* callback is returned as a result.
*/
private def compileSrcs(srcs: String*): (Seq[File], TestCallback) = {
private def compileSrcs(groupedSrcs: List[List[String]]): (Seq[File], TestCallback) = {
withTemporaryDirectory { temp =>
val analysisCallback = new TestCallback(nameHashing)
val classesDir = new File(temp, "classes")
classesDir.mkdir()
val compiler = prepareCompiler(classesDir, analysisCallback)
val run = new compiler.Run
val srcFiles = srcs.toSeq.zipWithIndex map { case (src, i) =>
val fileName = s"Test_$i.scala"
prepareSrcFile(temp, fileName, src)
val compiler = prepareCompiler(classesDir, analysisCallback, classesDir.toString)
val files = for((compilationUnit, unitId) <- groupedSrcs.zipWithIndex) yield {
val run = new compiler.Run
val srcFiles = compilationUnit.toSeq.zipWithIndex map { case (src, i) =>
val fileName = s"Test-$unitId-$i.scala"
prepareSrcFile(temp, fileName, src)
}
val srcFilePaths = srcFiles.map(srcFile => srcFile.getAbsolutePath).toList
run.compile(srcFilePaths)
srcFilePaths.foreach(f => new File(f).delete)
srcFiles
}
val srcFilePaths = srcFiles.map(srcFile => srcFile.getAbsolutePath).toList
run.compile(srcFilePaths)
(srcFiles, analysisCallback)
(files.flatten.toSeq, analysisCallback)
}
}
private def compileSrcs(srcs: String*): (Seq[File], TestCallback) = {
compileSrcs(List(srcs.toList))
}
private def prepareSrcFile(baseDir: File, fileName: String, src: String): File = {
val srcFile = new File(baseDir, fileName)
sbt.IO.write(srcFile, src)
srcFile
}
private def prepareCompiler(outputDir: File, analysisCallback: AnalysisCallback): CachedCompiler0#Compiler = {
private def prepareCompiler(outputDir: File, analysisCallback: AnalysisCallback, classpath: String = "."): CachedCompiler0#Compiler = {
val args = Array.empty[String]
object output extends SingleOutput {
def outputDirectory: File = outputDir
@ -123,6 +151,7 @@ class ScalaCompilerForUnitTesting(nameHashing: Boolean = false) {
val weakLog = new WeakLog(ConsoleLogger(), ConsoleReporter)
val cachedCompiler = new CachedCompiler0(args, output, weakLog, false)
val settings = cachedCompiler.settings
settings.classpath.value = classpath
settings.usejavacp.value = true
val scalaReporter = new ConsoleReporter(settings)
val delegatingReporter = DelegatingReporter(settings, ConsoleReporter)

View File

@ -0,0 +1,5 @@
package macro
object Client {
Provider.printTree(Provider.printTree(Foo.str))
}

View File

@ -0,0 +1,5 @@
package macro
object Foo {
def str: String = "abc"
}

View File

@ -0,0 +1,3 @@
package macro
object Foo {
}

View File

@ -0,0 +1,12 @@
package macro
import scala.language.experimental.macros
import scala.reflect.macros._
object Provider {
def printTree(arg: Any) = macro printTreeImpl
def printTreeImpl(c: Context)(arg: c.Expr[Any]): c.Expr[String] = {
val argStr = arg.tree.toString
val literalStr = c.universe.Literal(c.universe.Constant(argStr))
c.Expr[String](literalStr)
}
}

View File

@ -0,0 +1,29 @@
import sbt._
import Keys._
object build extends Build {
val defaultSettings = Seq(
libraryDependencies <+= scalaVersion("org.scala-lang" % "scala-reflect" % _ ),
incOptions := incOptions.value.withNameHashing(true)
)
lazy val root = Project(
base = file("."),
id = "macro",
aggregate = Seq(macroProvider, macroClient),
settings = Defaults.defaultSettings ++ defaultSettings
)
lazy val macroProvider = Project(
base = file("macro-provider"),
id = "macro-provider",
settings = Defaults.defaultSettings ++ defaultSettings
)
lazy val macroClient = Project(
base = file("macro-client"),
id = "macro-client",
dependencies = Seq(macroProvider),
settings = Defaults.defaultSettings ++ defaultSettings
)
}

View File

@ -0,0 +1,13 @@
> compile
# remove `Foo.str` which is an argument to a macro
# (this macro itself that is an argument to another macro)
$ copy-file macro-client/changes/Foo.scala macro-client/Foo.scala
# we should recompile Foo.scala first and then fail to compile Client.scala due to missing
# `Foo.str`
-> macro-client/compile
> clean
-> compile

View File

@ -0,0 +1,5 @@
package macro
object Client {
Provider.printTree(Foo.str)
}

View File

@ -0,0 +1,5 @@
package macro
object Foo {
def str: String = "abc"
}

View File

@ -0,0 +1,3 @@
package macro
object Foo {
}

View File

@ -0,0 +1,12 @@
package macro
import scala.language.experimental.macros
import scala.reflect.macros._
object Provider {
def printTree(arg: Any) = macro printTreeImpl
def printTreeImpl(c: Context)(arg: c.Expr[Any]): c.Expr[String] = {
val argStr = arg.tree.toString
val literalStr = c.universe.Literal(c.universe.Constant(argStr))
c.Expr[String](literalStr)
}
}

View File

@ -0,0 +1,29 @@
import sbt._
import Keys._
object build extends Build {
val defaultSettings = Seq(
libraryDependencies <+= scalaVersion("org.scala-lang" % "scala-reflect" % _ ),
incOptions := incOptions.value.withNameHashing(true)
)
lazy val root = Project(
base = file("."),
id = "macro",
aggregate = Seq(macroProvider, macroClient),
settings = Defaults.defaultSettings ++ defaultSettings
)
lazy val macroProvider = Project(
base = file("macro-provider"),
id = "macro-provider",
settings = Defaults.defaultSettings ++ defaultSettings
)
lazy val macroClient = Project(
base = file("macro-client"),
id = "macro-client",
dependencies = Seq(macroProvider),
settings = Defaults.defaultSettings ++ defaultSettings
)
}

View File

@ -0,0 +1,12 @@
> compile
# remove `Foo.str` which is an argument to a macro
$ copy-file macro-client/changes/Foo.scala macro-client/Foo.scala
# we should recompile Foo.scala first and then fail to compile Client.scala due to missing
# `Foo.str`
-> macro-client/compile
> clean
-> compile