mirror of https://github.com/sbt/sbt.git
Add check of task invocation inside `if`
This commit adds the first version of the checker that will tell users if they are doing something wrong. The first version warns any user that write `.value` inside an if expression. As the test integration is not yet working correctly and messages are swallowed, we have to println to get info from the test.
This commit is contained in:
parent
5b7180cfa7
commit
b4299e7f34
|
|
@ -109,12 +109,14 @@ final class ContextUtil[C <: blackbox.Context](val ctx: C) {
|
|||
def illegalReference(defs: collection.Set[Symbol], sym: Symbol): Boolean =
|
||||
sym != null && sym != NoSymbol && defs.contains(sym)
|
||||
|
||||
type PropertyChecker = (String, Type, Tree) => Boolean
|
||||
|
||||
/**
|
||||
* A function that checks the provided tree for illegal references to M instances defined in the
|
||||
* expression passed to the macro and for illegal dereferencing of M instances.
|
||||
*/
|
||||
def checkReferences(defs: collection.Set[Symbol],
|
||||
isWrapper: (String, Type, Tree) => Boolean): Tree => Unit = {
|
||||
isWrapper: PropertyChecker): Tree => Unit = {
|
||||
case s @ ApplyTree(TypeApply(Select(_, nme), tpe :: Nil), qual :: Nil) =>
|
||||
if (isWrapper(nme.decodedName.toString, tpe.tpe, qual))
|
||||
ctx.error(s.pos, DynamicDependencyError)
|
||||
|
|
@ -123,6 +125,28 @@ final class ContextUtil[C <: blackbox.Context](val ctx: C) {
|
|||
case _ => ()
|
||||
}
|
||||
|
||||
class TaskMacroViolationDiscovery(isInvalid: PropertyChecker) extends Traverser {
|
||||
var insideIf: Boolean = false
|
||||
override def traverse(tree: ctx.universe.Tree): Unit = {
|
||||
tree match {
|
||||
case If(condition, thenp, elsep) =>
|
||||
super.traverse(condition)
|
||||
insideIf = true
|
||||
super.traverse(thenp)
|
||||
super.traverse(elsep)
|
||||
insideIf = false
|
||||
case s @ ApplyTree(TypeApply(Select(_, nme), tpe :: Nil), qual :: Nil) =>
|
||||
if (insideIf && isInvalid(nme.decodedName.toString, tpe.tpe, qual)) {
|
||||
ctx.error(s.pos, "DSL error: a task value cannot be obtained inside if.")
|
||||
}
|
||||
case _ => super.traverse(tree)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def checkMacroViolations(tree: Tree, isInvalid: PropertyChecker) =
|
||||
new TaskMacroViolationDiscovery(isInvalid).traverse(tree)
|
||||
|
||||
/** Constructs a ValDef with a parameter modifier, a unique name, with the provided Type and with an empty rhs. */
|
||||
def freshMethodParameter(tpe: Type): ValDef =
|
||||
ValDef(parameterModifiers, freshTermName("p"), TypeTree(tpe), EmptyTree)
|
||||
|
|
|
|||
|
|
@ -86,6 +86,8 @@ object Instance {
|
|||
nt: c.WeakTypeTag[N[T]],
|
||||
it: c.TypeTag[i.type]
|
||||
): c.Expr[i.M[N[T]]] = {
|
||||
println("STARTING CONTIMPL")
|
||||
println("=================")
|
||||
import c.universe.{ Apply => ApplyTree, _ }
|
||||
|
||||
val util = ContextUtil[c.type](c)
|
||||
|
|
@ -97,6 +99,8 @@ object Instance {
|
|||
case Left(l) => (l.tree, nt.tpe.dealias)
|
||||
case Right(r) => (r.tree, mttpe)
|
||||
}
|
||||
println("TREE TYPE")
|
||||
println(treeType)
|
||||
// the Symbol for the anonymous function passed to the appropriate Instance.map/flatMap/pure method
|
||||
// this Symbol needs to be known up front so that it can be used as the owner of synthetic vals
|
||||
val functionSym = util.functionSymbol(tree.pos)
|
||||
|
|
@ -183,9 +187,17 @@ object Instance {
|
|||
}
|
||||
|
||||
// applies the transformation
|
||||
println("TREE")
|
||||
println(tree)
|
||||
import util.PropertyChecker
|
||||
val isInvalid: PropertyChecker =
|
||||
(s: String, tpe: Type, tree: Tree) => t.isLeft && convert.asPredicate(c).apply(s, tpe, tree)
|
||||
util.checkMacroViolations(tree, isInvalid)
|
||||
val tx = util.transformWrappers(tree, (n, tpe, t, replace) => sub(n, tpe, t, replace))
|
||||
// resetting attributes must be: a) local b) done here and not wider or else there are obscure errors
|
||||
val tr = makeApp(inner(tx))
|
||||
println("FINAL TREE")
|
||||
println(tr)
|
||||
c.Expr[i.M[N[T]]](tr)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,29 @@
|
|||
package sbt.std
|
||||
|
||||
class TaskPosSpec {
|
||||
// Dynamic tasks can have task invocations inside if branches
|
||||
locally {
|
||||
import sbt._
|
||||
import sbt.Def._
|
||||
val foo = taskKey[String]("")
|
||||
val bar = taskKey[String]("")
|
||||
var condition = true
|
||||
val baz = Def.taskDyn[String] {
|
||||
if (condition) foo
|
||||
else bar
|
||||
}
|
||||
}
|
||||
|
||||
// Dynamic settings can have setting invocations inside if branches
|
||||
locally {
|
||||
import sbt._
|
||||
import sbt.Def._
|
||||
val foo = settingKey[String]("")
|
||||
val bar = settingKey[String]("")
|
||||
var condition = true
|
||||
val baz = Def.settingDyn[String] {
|
||||
if (condition) foo
|
||||
else bar
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -8,18 +8,7 @@ package sbt.std
|
|||
import scala.reflect._
|
||||
|
||||
object TestUtil {
|
||||
import tools.reflect.{ ToolBox, ToolBoxError }
|
||||
|
||||
def intercept[T <: Throwable: ClassTag](test: => Any): T = {
|
||||
try {
|
||||
test
|
||||
throw new Exception(s"Expected exception ${classTag[T]}")
|
||||
} catch {
|
||||
case t: Throwable =>
|
||||
if (classTag[T].runtimeClass != t.getClass) throw t
|
||||
else t.asInstanceOf[T]
|
||||
}
|
||||
}
|
||||
import tools.reflect.ToolBox
|
||||
|
||||
def eval(code: String, compileOptions: String = ""): Any = {
|
||||
val tb = mkToolbox(compileOptions)
|
||||
|
|
@ -38,18 +27,4 @@ object TestUtil {
|
|||
val completeSporesCoreClasspath = classpathFile.getLines.mkString
|
||||
completeSporesCoreClasspath
|
||||
}
|
||||
|
||||
def expectError(errorSnippet: String,
|
||||
compileOptions: String = "-Xmacro-settings:debug-spores",
|
||||
baseCompileOptions: String = s"-cp $toolboxClasspath")(code: String): Unit = {
|
||||
val errorMessage = intercept[ToolBoxError] {
|
||||
eval(code, s"$compileOptions $baseCompileOptions")
|
||||
}.getMessage
|
||||
val userMessage =
|
||||
s"""
|
||||
|FOUND: $errorMessage
|
||||
|EXPECTED: $errorSnippet
|
||||
""".stripMargin
|
||||
assert(errorMessage.contains(errorSnippet), userMessage)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,42 @@
|
|||
package sbt.std.neg
|
||||
|
||||
import org.scalatest.FunSuite
|
||||
import sbt.std.TestUtil._
|
||||
|
||||
class TaskNegSpec extends FunSuite {
|
||||
import tools.reflect.ToolBoxError
|
||||
def expectError(errorSnippet: String,
|
||||
compileOptions: String = "-Xmacro-settings:debug-spores",
|
||||
baseCompileOptions: String = s"-cp $toolboxClasspath")(code: String) = {
|
||||
val errorMessage = intercept[ToolBoxError] {
|
||||
eval(code, s"$compileOptions $baseCompileOptions")
|
||||
println("ERROR: The snippet compiled successfully.")
|
||||
}.getMessage
|
||||
val userMessage =
|
||||
s"""
|
||||
|FOUND: $errorMessage
|
||||
|EXPECTED: $errorSnippet
|
||||
""".stripMargin
|
||||
println(userMessage)
|
||||
assert(errorMessage.contains(errorSnippet), userMessage)
|
||||
}
|
||||
|
||||
test("Fail on task invocation inside if of regular task") {
|
||||
|
||||
expectError("DSL error: a task value cannot be obtained inside if.") {
|
||||
"""
|
||||
|import sbt._
|
||||
|import sbt.Def._
|
||||
|
|
||||
|val fooNeg = taskKey[String]("")
|
||||
|val barNeg = taskKey[String]("")
|
||||
|var condition = true
|
||||
|
|
||||
|val bazNeg: Initialize[Task[String]] = Def.task[String] {
|
||||
| if (condition) fooNeg.value
|
||||
| else barNeg.value
|
||||
|}
|
||||
""".stripMargin
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -110,6 +110,7 @@ object Dependencies {
|
|||
|
||||
val sjsonNewScalaJson = "com.eed3si9n" %% "sjson-new-scalajson" % "0.7.0"
|
||||
|
||||
val scalatest = "org.scalatest" %% "scalatest" % "3.0.1"
|
||||
val scalaCheck = "org.scalacheck" %% "scalacheck" % "1.13.4"
|
||||
val specs2 = "org.specs2" %% "specs2" % "2.4.17"
|
||||
val junit = "junit" % "junit" % "4.11"
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ object NightlyPlugin extends AutoPlugin {
|
|||
val includeTestDependencies = settingKey[Boolean]("Doesn't declare test dependencies.")
|
||||
|
||||
def testDependencies = libraryDependencies ++= (
|
||||
if (includeTestDependencies.value) Seq(scalaCheck % Test, specs2 % Test, junit % Test)
|
||||
if (includeTestDependencies.value) Seq(scalaCheck % Test, specs2 % Test, junit % Test, scalatest % Test)
|
||||
else Seq()
|
||||
)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue