diff --git a/core-macros/src/main/scala/sbt/internal/util/appmacro/ContextUtil.scala b/core-macros/src/main/scala/sbt/internal/util/appmacro/ContextUtil.scala index f6dbf5894..993072ab9 100644 --- a/core-macros/src/main/scala/sbt/internal/util/appmacro/ContextUtil.scala +++ b/core-macros/src/main/scala/sbt/internal/util/appmacro/ContextUtil.scala @@ -109,12 +109,14 @@ final class ContextUtil[C <: blackbox.Context](val ctx: C) { def illegalReference(defs: collection.Set[Symbol], sym: Symbol): Boolean = sym != null && sym != NoSymbol && defs.contains(sym) + type PropertyChecker = (String, Type, Tree) => Boolean + /** * A function that checks the provided tree for illegal references to M instances defined in the * expression passed to the macro and for illegal dereferencing of M instances. */ def checkReferences(defs: collection.Set[Symbol], - isWrapper: (String, Type, Tree) => Boolean): Tree => Unit = { + isWrapper: PropertyChecker): Tree => Unit = { case s @ ApplyTree(TypeApply(Select(_, nme), tpe :: Nil), qual :: Nil) => if (isWrapper(nme.decodedName.toString, tpe.tpe, qual)) ctx.error(s.pos, DynamicDependencyError) @@ -123,6 +125,28 @@ final class ContextUtil[C <: blackbox.Context](val ctx: C) { case _ => () } + class TaskMacroViolationDiscovery(isInvalid: PropertyChecker) extends Traverser { + var insideIf: Boolean = false + override def traverse(tree: ctx.universe.Tree): Unit = { + tree match { + case If(condition, thenp, elsep) => + super.traverse(condition) + insideIf = true + super.traverse(thenp) + super.traverse(elsep) + insideIf = false + case s @ ApplyTree(TypeApply(Select(_, nme), tpe :: Nil), qual :: Nil) => + if (insideIf && isInvalid(nme.decodedName.toString, tpe.tpe, qual)) { + ctx.error(s.pos, "DSL error: a task value cannot be obtained inside if.") + } + case _ => super.traverse(tree) + } + } + } + + def checkMacroViolations(tree: Tree, isInvalid: PropertyChecker) = + new TaskMacroViolationDiscovery(isInvalid).traverse(tree) + /** Constructs a ValDef with a parameter modifier, a unique name, with the provided Type and with an empty rhs. */ def freshMethodParameter(tpe: Type): ValDef = ValDef(parameterModifiers, freshTermName("p"), TypeTree(tpe), EmptyTree) diff --git a/core-macros/src/main/scala/sbt/internal/util/appmacro/Instance.scala b/core-macros/src/main/scala/sbt/internal/util/appmacro/Instance.scala index e5a2fcbf3..9a36e3ff8 100644 --- a/core-macros/src/main/scala/sbt/internal/util/appmacro/Instance.scala +++ b/core-macros/src/main/scala/sbt/internal/util/appmacro/Instance.scala @@ -86,6 +86,8 @@ object Instance { nt: c.WeakTypeTag[N[T]], it: c.TypeTag[i.type] ): c.Expr[i.M[N[T]]] = { + println("STARTING CONTIMPL") + println("=================") import c.universe.{ Apply => ApplyTree, _ } val util = ContextUtil[c.type](c) @@ -97,6 +99,8 @@ object Instance { case Left(l) => (l.tree, nt.tpe.dealias) case Right(r) => (r.tree, mttpe) } + println("TREE TYPE") + println(treeType) // the Symbol for the anonymous function passed to the appropriate Instance.map/flatMap/pure method // this Symbol needs to be known up front so that it can be used as the owner of synthetic vals val functionSym = util.functionSymbol(tree.pos) @@ -183,9 +187,17 @@ object Instance { } // applies the transformation + println("TREE") + println(tree) + import util.PropertyChecker + val isInvalid: PropertyChecker = + (s: String, tpe: Type, tree: Tree) => t.isLeft && convert.asPredicate(c).apply(s, tpe, tree) + util.checkMacroViolations(tree, isInvalid) val tx = util.transformWrappers(tree, (n, tpe, t, replace) => sub(n, tpe, t, replace)) // resetting attributes must be: a) local b) done here and not wider or else there are obscure errors val tr = makeApp(inner(tx)) + println("FINAL TREE") + println(tr) c.Expr[i.M[N[T]]](tr) } diff --git a/main-settings/src/test/scala/sbt/std/TaskPosSpec.scala b/main-settings/src/test/scala/sbt/std/TaskPosSpec.scala new file mode 100644 index 000000000..c97ce92ba --- /dev/null +++ b/main-settings/src/test/scala/sbt/std/TaskPosSpec.scala @@ -0,0 +1,29 @@ +package sbt.std + +class TaskPosSpec { + // Dynamic tasks can have task invocations inside if branches + locally { + import sbt._ + import sbt.Def._ + val foo = taskKey[String]("") + val bar = taskKey[String]("") + var condition = true + val baz = Def.taskDyn[String] { + if (condition) foo + else bar + } + } + + // Dynamic settings can have setting invocations inside if branches + locally { + import sbt._ + import sbt.Def._ + val foo = settingKey[String]("") + val bar = settingKey[String]("") + var condition = true + val baz = Def.settingDyn[String] { + if (condition) foo + else bar + } + } +} diff --git a/main-settings/src/test/scala/sbt/std/TestUtil.scala b/main-settings/src/test/scala/sbt/std/TestUtil.scala index 09365ba52..34b6a1b3b 100644 --- a/main-settings/src/test/scala/sbt/std/TestUtil.scala +++ b/main-settings/src/test/scala/sbt/std/TestUtil.scala @@ -8,18 +8,7 @@ package sbt.std import scala.reflect._ object TestUtil { - import tools.reflect.{ ToolBox, ToolBoxError } - - def intercept[T <: Throwable: ClassTag](test: => Any): T = { - try { - test - throw new Exception(s"Expected exception ${classTag[T]}") - } catch { - case t: Throwable => - if (classTag[T].runtimeClass != t.getClass) throw t - else t.asInstanceOf[T] - } - } + import tools.reflect.ToolBox def eval(code: String, compileOptions: String = ""): Any = { val tb = mkToolbox(compileOptions) @@ -38,18 +27,4 @@ object TestUtil { val completeSporesCoreClasspath = classpathFile.getLines.mkString completeSporesCoreClasspath } - - def expectError(errorSnippet: String, - compileOptions: String = "-Xmacro-settings:debug-spores", - baseCompileOptions: String = s"-cp $toolboxClasspath")(code: String): Unit = { - val errorMessage = intercept[ToolBoxError] { - eval(code, s"$compileOptions $baseCompileOptions") - }.getMessage - val userMessage = - s""" - |FOUND: $errorMessage - |EXPECTED: $errorSnippet - """.stripMargin - assert(errorMessage.contains(errorSnippet), userMessage) - } } diff --git a/main-settings/src/test/scala/sbt/std/neg/TaskNegSpec.scala b/main-settings/src/test/scala/sbt/std/neg/TaskNegSpec.scala new file mode 100644 index 000000000..17f16d3f8 --- /dev/null +++ b/main-settings/src/test/scala/sbt/std/neg/TaskNegSpec.scala @@ -0,0 +1,42 @@ +package sbt.std.neg + +import org.scalatest.FunSuite +import sbt.std.TestUtil._ + +class TaskNegSpec extends FunSuite { + import tools.reflect.ToolBoxError + def expectError(errorSnippet: String, + compileOptions: String = "-Xmacro-settings:debug-spores", + baseCompileOptions: String = s"-cp $toolboxClasspath")(code: String) = { + val errorMessage = intercept[ToolBoxError] { + eval(code, s"$compileOptions $baseCompileOptions") + println("ERROR: The snippet compiled successfully.") + }.getMessage + val userMessage = + s""" + |FOUND: $errorMessage + |EXPECTED: $errorSnippet + """.stripMargin + println(userMessage) + assert(errorMessage.contains(errorSnippet), userMessage) + } + + test("Fail on task invocation inside if of regular task") { + + expectError("DSL error: a task value cannot be obtained inside if.") { + """ + |import sbt._ + |import sbt.Def._ + | + |val fooNeg = taskKey[String]("") + |val barNeg = taskKey[String]("") + |var condition = true + | + |val bazNeg: Initialize[Task[String]] = Def.task[String] { + | if (condition) fooNeg.value + | else barNeg.value + |} + """.stripMargin + } + } +} diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 7d918eaaf..468402eb1 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -110,6 +110,7 @@ object Dependencies { val sjsonNewScalaJson = "com.eed3si9n" %% "sjson-new-scalajson" % "0.7.0" + val scalatest = "org.scalatest" %% "scalatest" % "3.0.1" val scalaCheck = "org.scalacheck" %% "scalacheck" % "1.13.4" val specs2 = "org.specs2" %% "specs2" % "2.4.17" val junit = "junit" % "junit" % "4.11" diff --git a/project/NightlyPlugin.scala b/project/NightlyPlugin.scala index 893410d0b..58c1917dc 100644 --- a/project/NightlyPlugin.scala +++ b/project/NightlyPlugin.scala @@ -10,7 +10,7 @@ object NightlyPlugin extends AutoPlugin { val includeTestDependencies = settingKey[Boolean]("Doesn't declare test dependencies.") def testDependencies = libraryDependencies ++= ( - if (includeTestDependencies.value) Seq(scalaCheck % Test, specs2 % Test, junit % Test) + if (includeTestDependencies.value) Seq(scalaCheck % Test, specs2 % Test, junit % Test, scalatest % Test) else Seq() ) }