Add fully-fledged macro check for value inside if

`.value` inside the if of a regular task is unsafe. The wrapping task
will always execute the value, no matter what the if predicate yields.

This commit adds the infrastructure to lint code for every sbt DSL
macro. It also adds example of neg tests that check that the DSL checks
are in place.

The sbt checks yield error for this specific case because we may want to
explore changing this behaviour in the future. The solutions to this are
straightforward and explained in the error message, that looks like
this:

```
EXPECTED: The evaluation of `fooNeg` happens always inside a regular task.

PROBLEM: `fooNeg` is inside the if expression of a regular task.
  Regular tasks always evaluate task inside the bodies of if expressions.

SOLUTION:
  1. If you only want to evaluate it when the if predicate is true, use a dynamic task.
  2. Otherwise, make the static evaluation explicit by evaluating `fooNeg` outside the if expression.
```

Aside from those solutions, this commit also adds a way to disable any
DSL check by using the new `sbt.unchecked` annotation. This annotation,
similar to `scala.annotation.unchecked` disables compiler output. In our
case, it will disable any task dsl check, making it silent.

Examples of positive checks have also been added.

There have been only two places in `Defaults.scala` where this check has
made compilation fail.

The first one is inside `allDependencies`. To ensure that we still have
static dependencies for `allDependencies`, I have hoisted up the value
invocation outside the if expression. We may want to explore adding a
dynamic task in the future, though. We are doing unnecessary work there.

The second one is inside `update` and is not important because it's not
exposed to the user. We use a `taskDyn`.
This commit is contained in:
jvican 2017-05-25 10:54:51 +02:00
parent b4299e7f34
commit 2b12721a68
No known key found for this signature in database
GPG Key ID: 42DAFA0F112E8050
14 changed files with 234 additions and 67 deletions

View File

@ -115,8 +115,7 @@ final class ContextUtil[C <: blackbox.Context](val ctx: C) {
* A function that checks the provided tree for illegal references to M instances defined in the
* expression passed to the macro and for illegal dereferencing of M instances.
*/
def checkReferences(defs: collection.Set[Symbol],
isWrapper: PropertyChecker): Tree => Unit = {
def checkReferences(defs: collection.Set[Symbol], isWrapper: PropertyChecker): Tree => Unit = {
case s @ ApplyTree(TypeApply(Select(_, nme), tpe :: Nil), qual :: Nil) =>
if (isWrapper(nme.decodedName.toString, tpe.tpe, qual))
ctx.error(s.pos, DynamicDependencyError)
@ -125,28 +124,6 @@ final class ContextUtil[C <: blackbox.Context](val ctx: C) {
case _ => ()
}
class TaskMacroViolationDiscovery(isInvalid: PropertyChecker) extends Traverser {
var insideIf: Boolean = false
override def traverse(tree: ctx.universe.Tree): Unit = {
tree match {
case If(condition, thenp, elsep) =>
super.traverse(condition)
insideIf = true
super.traverse(thenp)
super.traverse(elsep)
insideIf = false
case s @ ApplyTree(TypeApply(Select(_, nme), tpe :: Nil), qual :: Nil) =>
if (insideIf && isInvalid(nme.decodedName.toString, tpe.tpe, qual)) {
ctx.error(s.pos, "DSL error: a task value cannot be obtained inside if.")
}
case _ => super.traverse(tree)
}
}
}
def checkMacroViolations(tree: Tree, isInvalid: PropertyChecker) =
new TaskMacroViolationDiscovery(isInvalid).traverse(tree)
/** Constructs a ValDef with a parameter modifier, a unique name, with the provided Type and with an empty rhs. */
def freshMethodParameter(tpe: Type): ValDef =
ValDef(parameterModifiers, freshTermName("p"), TypeTree(tpe), EmptyTree)

View File

@ -81,13 +81,16 @@ object Instance {
c: blackbox.Context,
i: Instance with Singleton,
convert: Convert,
builder: TupleBuilder)(t: Either[c.Expr[T], c.Expr[i.M[T]]], inner: Transform[c.type, N])(
builder: TupleBuilder,
linter: LinterDSL
)(
t: Either[c.Expr[T], c.Expr[i.M[T]]],
inner: Transform[c.type, N]
)(
implicit tt: c.WeakTypeTag[T],
nt: c.WeakTypeTag[N[T]],
it: c.TypeTag[i.type]
): c.Expr[i.M[N[T]]] = {
println("STARTING CONTIMPL")
println("=================")
import c.universe.{ Apply => ApplyTree, _ }
val util = ContextUtil[c.type](c)
@ -99,8 +102,6 @@ object Instance {
case Left(l) => (l.tree, nt.tpe.dealias)
case Right(r) => (r.tree, mttpe)
}
println("TREE TYPE")
println(treeType)
// the Symbol for the anonymous function passed to the appropriate Instance.map/flatMap/pure method
// this Symbol needs to be known up front so that it can be used as the owner of synthetic vals
val functionSym = util.functionSymbol(tree.pos)
@ -187,17 +188,10 @@ object Instance {
}
// applies the transformation
println("TREE")
println(tree)
import util.PropertyChecker
val isInvalid: PropertyChecker =
(s: String, tpe: Type, tree: Tree) => t.isLeft && convert.asPredicate(c).apply(s, tpe, tree)
util.checkMacroViolations(tree, isInvalid)
linter.runLinter(c)(tree, t.isLeft)
val tx = util.transformWrappers(tree, (n, tpe, t, replace) => sub(n, tpe, t, replace))
// resetting attributes must be: a) local b) done here and not wider or else there are obscure errors
val tr = makeApp(inner(tx))
println("FINAL TREE")
println(tr)
c.Expr[i.M[N[T]]](tr)
}

View File

@ -0,0 +1,16 @@
package sbt.internal.util.appmacro
import scala.reflect.macros.blackbox
trait LinterDSL {
def runLinter(ctx: blackbox.Context)(tree: ctx.Tree, isApplicableFromContext: => Boolean): Unit
}
object LinterDSL {
object Empty extends LinterDSL {
override def runLinter(ctx: blackbox.Context)(
tree: ctx.Tree,
isApplicableFromContext: => Boolean
): Unit = ()
}
}

View File

@ -2,9 +2,16 @@ package sbt
package std
import Def.Initialize
import sbt.internal.util.Types.{ idFun, Id }
import sbt.internal.util.Types.{ Id, idFun }
import sbt.internal.util.AList
import sbt.internal.util.appmacro.{ Convert, Converted, Instance, MixedBuilder, MonadInstance }
import sbt.internal.util.appmacro.{
Convert,
Converted,
Instance,
LinterDSL,
MixedBuilder,
MonadInstance
}
object InitializeInstance extends MonadInstance {
type M[x] = Initialize[x]
@ -41,15 +48,16 @@ object InitializeConvert extends Convert {
}
object SettingMacro {
import LinterDSL.{ Empty => EmptyLinter }
def settingMacroImpl[T: c.WeakTypeTag](c: blackbox.Context)(
t: c.Expr[T]): c.Expr[Initialize[T]] =
Instance.contImpl[T, Id](c, InitializeInstance, InitializeConvert, MixedBuilder)(
Instance.contImpl[T, Id](c, InitializeInstance, InitializeConvert, MixedBuilder, EmptyLinter)(
Left(t),
Instance.idTransform[c.type])
def settingDynMacroImpl[T: c.WeakTypeTag](c: blackbox.Context)(
t: c.Expr[Initialize[T]]): c.Expr[Initialize[T]] =
Instance.contImpl[T, Id](c, InitializeInstance, InitializeConvert, MixedBuilder)(
Instance.contImpl[T, Id](c, InitializeInstance, InitializeConvert, MixedBuilder, EmptyLinter)(
Right(t),
Instance.idTransform[c.type])
}

View File

@ -0,0 +1,83 @@
package sbt.std
import sbt.internal.util.ConsoleAppender
import sbt.internal.util.appmacro.LinterDSL
import scala.collection.mutable.{ HashSet => MutableSet }
import scala.io.AnsiColor
import scala.reflect.macros.blackbox
object TaskLinterDSL extends LinterDSL {
override def runLinter(ctx: blackbox.Context)(
tree: ctx.Tree,
isApplicableFromContext: => Boolean
): Unit = {
import ctx.universe._
val isTask = FullConvert.asPredicate(ctx)
object traverser extends Traverser {
private val unchecked = symbolOf[sbt.unchecked].asClass
private val uncheckedWrappers = MutableSet.empty[Tree]
var insideIf: Boolean = false
override def traverse(tree: ctx.universe.Tree): Unit = {
tree match {
case If(condition, thenp, elsep) =>
super.traverse(condition)
insideIf = true
super.traverse(thenp)
super.traverse(elsep)
insideIf = false
case s @ Apply(TypeApply(Select(_, nme), tpe :: Nil), qual :: Nil) =>
val shouldIgnore = uncheckedWrappers.contains(s)
val wrapperName = nme.decodedName.toString
if (insideIf && !shouldIgnore && isTask(wrapperName, tpe.tpe, qual)) {
val qualName = qual.symbol.name.decodedName.toString
ctx.error(s.pos, TaskLinterDSLFeedback.useOfValueInsideIfExpression(qualName))
}
// Don't remove, makes the whole analysis work
case Typed(expr, tt) =>
super.traverse(expr)
super.traverse(tt)
case Typed(expr, tt: TypeTree) if tt.original != null =>
tt.original match {
case Annotated(annot, arg) =>
annot.tpe match {
case AnnotatedType(annotations, _) =>
val symAnnotations = annotations.map(_.tree.tpe.typeSymbol)
val isUnchecked = symAnnotations.contains(unchecked)
if (isUnchecked) {
val toAdd = arg match {
case Typed(`expr`, _) => `expr`
case `tree` => `tree`
}
uncheckedWrappers.add(toAdd)
}
case _ =>
}
}
super.traverse(expr)
case tree => super.traverse(tree)
}
}
}
if (!isApplicableFromContext) ()
else traverser.traverse(tree)
}
}
object TaskLinterDSLFeedback {
private final val startBold = if (ConsoleAppender.formatEnabled) AnsiColor.BOLD else ""
private final val startRed = if (ConsoleAppender.formatEnabled) AnsiColor.RED else ""
private final val startGreen = if (ConsoleAppender.formatEnabled) AnsiColor.GREEN else ""
private final val reset = if (ConsoleAppender.formatEnabled) AnsiColor.RESET else ""
def useOfValueInsideIfExpression(offendingValue: String) =
s"""${startBold}The evaluation of `${offendingValue}` happens always inside a regular task.$reset
|
|${startRed}PROBLEM${reset}: `${offendingValue}` is inside the if expression of a regular task.
| Regular tasks always evaluate task inside the bodies of if expressions.
|
|${startGreen}SOLUTION${reset}:
| 1. If you only want to evaluate it when the if predicate is true, use a dynamic task.
| 2. Otherwise, make the static evaluation explicit by evaluating `${offendingValue}` outside the if expression.
""".stripMargin
}

View File

@ -2,8 +2,15 @@ package sbt
package std
import Def.{ Initialize, Setting }
import sbt.internal.util.Types.{ const, idFun, Id }
import sbt.internal.util.appmacro.{ ContextUtil, Converted, Instance, MixedBuilder, MonadInstance }
import sbt.internal.util.Types.{ Id, const, idFun }
import sbt.internal.util.appmacro.{
ContextUtil,
Converted,
Instance,
LinterDSL,
MixedBuilder,
MonadInstance
}
import Instance.Transform
import sbt.internal.util.complete.{ DefaultParsers, Parser }
import sbt.internal.util.{ AList, LinePosition, NoPosition, SourcePosition }
@ -82,15 +89,17 @@ object TaskMacro {
"""`<<=` operator is deprecated. Use `key := { x.value }` or `key ~= (old => { newValue })`.
|See http://www.scala-sbt.org/0.13/docs/Migrating-from-sbt-012x.html""".stripMargin
import LinterDSL.{ Empty => EmptyLinter }
def taskMacroImpl[T: c.WeakTypeTag](c: blackbox.Context)(
t: c.Expr[T]): c.Expr[Initialize[Task[T]]] =
Instance.contImpl[T, Id](c, FullInstance, FullConvert, MixedBuilder)(
Instance.contImpl[T, Id](c, FullInstance, FullConvert, MixedBuilder, TaskLinterDSL)(
Left(t),
Instance.idTransform[c.type])
def taskDynMacroImpl[T: c.WeakTypeTag](c: blackbox.Context)(
t: c.Expr[Initialize[Task[T]]]): c.Expr[Initialize[Task[T]]] =
Instance.contImpl[T, Id](c, FullInstance, FullConvert, MixedBuilder)(
Instance.contImpl[T, Id](c, FullInstance, FullConvert, MixedBuilder, EmptyLinter)(
Right(t),
Instance.idTransform[c.type])
@ -356,7 +365,9 @@ object TaskMacro {
}
val cond = c.Expr[T](conditionInputTaskTree(c)(t.tree))
Instance
.contImpl[T, M](c, InitializeInstance, InputInitConvert, MixedBuilder)(Left(cond), inner)
.contImpl[T, M](c, InitializeInstance, InputInitConvert, MixedBuilder, EmptyLinter)(
Left(cond),
inner)
}
private[this] def conditionInputTaskTree(c: blackbox.Context)(t: c.Tree): c.Tree = {
@ -397,13 +408,17 @@ object TaskMacro {
val inner: Transform[c.type, M] = new Transform[c.type, M] {
def apply(in: c.Tree): c.Tree = f(c.Expr[T](in)).tree
}
Instance.contImpl[T, M](c, ParserInstance, ParserConvert, MixedBuilder)(Left(t), inner)
Instance.contImpl[T, M](c, ParserInstance, ParserConvert, MixedBuilder, LinterDSL.Empty)(
Left(t),
inner)
}
private[this] def iTaskMacro[T: c.WeakTypeTag](c: blackbox.Context)(
t: c.Expr[T]): c.Expr[Task[T]] =
Instance
.contImpl[T, Id](c, TaskInstance, TaskConvert, MixedBuilder)(Left(t), Instance.idTransform)
.contImpl[T, Id](c, TaskInstance, TaskConvert, MixedBuilder, EmptyLinter)(
Left(t),
Instance.idTransform)
private[this] def inputTaskDynMacro0[T: c.WeakTypeTag](c: blackbox.Context)(
t: c.Expr[Initialize[Task[T]]]): c.Expr[Initialize[InputTask[T]]] = {
@ -488,13 +503,13 @@ object TaskMacro {
object PlainTaskMacro {
def task[T](t: T): Task[T] = macro taskImpl[T]
def taskImpl[T: c.WeakTypeTag](c: blackbox.Context)(t: c.Expr[T]): c.Expr[Task[T]] =
Instance.contImpl[T, Id](c, TaskInstance, TaskConvert, MixedBuilder)(
Instance.contImpl[T, Id](c, TaskInstance, TaskConvert, MixedBuilder, TaskLinterDSL)(
Left(t),
Instance.idTransform[c.type])
def taskDyn[T](t: Task[T]): Task[T] = macro taskDynImpl[T]
def taskDynImpl[T: c.WeakTypeTag](c: blackbox.Context)(t: c.Expr[Task[T]]): c.Expr[Task[T]] =
Instance.contImpl[T, Id](c, TaskInstance, TaskConvert, MixedBuilder)(
Instance.contImpl[T, Id](c, TaskInstance, TaskConvert, MixedBuilder, LinterDSL.Empty)(
Right(t),
Instance.idTransform[c.type])
}

View File

@ -0,0 +1,13 @@
package sbt
import scala.annotation.Annotation
/** An annotation to designate that the annotated entity
* should not be considered for additional sbt compiler checks.
* These checks ensure that the DSL is predictable and prevents
* users from doing dangerous things at the cost of a stricter
* code structure.
*
* @since 1.0.0
*/
class unchecked extends Annotation

View File

@ -26,4 +26,16 @@ class TaskPosSpec {
else bar
}
}
locally {
import sbt._
import sbt.Def._
val foo = taskKey[String]("")
val bar = taskKey[String]("")
var condition = true
val baz = Def.task[String] {
if (condition) foo.value: @unchecked
else bar.value: @unchecked
}
}
}

View File

@ -2,7 +2,7 @@ package sbt.std
import sbt.internal.util.complete
import sbt.internal.util.complete.DefaultParsers
import sbt.{Def, InputTask, Task}
import sbt.{ Def, InputTask, Task }
/*object UseTask
{
@ -21,7 +21,7 @@ import sbt.{Def, InputTask, Task}
object Assign {
import java.io.File
import Def.{Initialize, inputKey, macroValueT, parserToInput, settingKey, taskKey}
import Def.{ Initialize, inputKey, macroValueT, parserToInput, settingKey, taskKey }
// import UseTask.{x,y,z,a,set,plain}
val ak = taskKey[Int]("a")

View File

@ -1,29 +1,29 @@
package sbt.std.neg
import org.scalatest.FunSuite
import sbt.std.TaskLinterDSLFeedback
import sbt.std.TestUtil._
class TaskNegSpec extends FunSuite {
import tools.reflect.ToolBoxError
def expectError(errorSnippet: String,
compileOptions: String = "-Xmacro-settings:debug-spores",
compileOptions: String = "-Xfatal-warnings",
baseCompileOptions: String = s"-cp $toolboxClasspath")(code: String) = {
val errorMessage = intercept[ToolBoxError] {
eval(code, s"$compileOptions $baseCompileOptions")
println("ERROR: The snippet compiled successfully.")
}.getMessage
val userMessage =
s"""
|FOUND: $errorMessage
|EXPECTED: $errorSnippet
""".stripMargin
println(userMessage)
assert(errorMessage.contains(errorSnippet), userMessage)
}
test("Fail on task invocation inside if of regular task") {
expectError("DSL error: a task value cannot be obtained inside if.") {
test("Fail on task invocation inside if it is used inside a regular task") {
val fooNegError = TaskLinterDSLFeedback.useOfValueInsideIfExpression("fooNeg")
val barNegError = TaskLinterDSLFeedback.useOfValueInsideIfExpression("barNeg")
expectError(List(fooNegError, barNegError).mkString("\n")) {
"""
|import sbt._
|import sbt.Def._
@ -32,11 +32,56 @@ class TaskNegSpec extends FunSuite {
|val barNeg = taskKey[String]("")
|var condition = true
|
|val bazNeg: Initialize[Task[String]] = Def.task[String] {
|val bazNeg = Def.task[String] {
| if (condition) fooNeg.value
| else barNeg.value
|}
""".stripMargin
}
}
test("Fail on task invocation inside inside if of task returned by dynamic task") {
expectError(TaskLinterDSLFeedback.useOfValueInsideIfExpression("fooNeg")) {
"""
|import sbt._
|import sbt.Def._
|
|val fooNeg = taskKey[String]("")
|val barNeg = taskKey[String]("")
|var condition = true
|
|val bazNeg = Def.taskDyn[String] {
| if (condition) {
| Def.task {
| if (condition) {
| fooNeg.value
| } else ""
| }
| } else Def.task("")
|}
""".stripMargin
}
}
test("Fail on task invocation inside inside else of task returned by dynamic task") {
expectError(TaskLinterDSLFeedback.useOfValueInsideIfExpression("barNeg")) {
"""
|import sbt._
|import sbt.Def._
|
|val fooNeg = taskKey[String]("")
|val barNeg = taskKey[String]("")
|var condition = true
|
|val bazNeg = Def.taskDyn[String] {
| if (condition) {
| Def.task {
| if (condition) ""
| else barNeg.value
| }
| } else Def.task("")
|}
""".stripMargin
}
}
}

View File

@ -1867,8 +1867,9 @@ object Classpaths {
// Override the default to handle mixing in the sbtPlugin + scala dependencies.
allDependencies := {
val base = projectDependencies.value ++ libraryDependencies.value
val dependency = sbtDependency.value
val pluginAdjust =
if (sbtPlugin.value) sbtDependency.value.withConfigurations(Some(Provided.name)) +: base
if (sbtPlugin.value) dependency.withConfigurations(Some(Provided.name)) +: base
else base
if (scalaHome.value.isDefined || ivyScala.value.isEmpty || !managedScalaInstance.value)
pluginAdjust
@ -2124,11 +2125,11 @@ object Classpaths {
}
}
val evictionOptions = {
val evictionOptions = Def.taskDyn {
if (executionRoots.value.exists(_.key == evicted.key))
EvictionWarningOptions.empty
else (evictionWarningOptions in update).value
}
Def.task(EvictionWarningOptions.empty)
else Def.task((evictionWarningOptions in update).value)
}.value
LibraryManagement.cachedUpdate(
s.cacheStoreFactory.sub(updateCacheName.value),

View File

@ -10,7 +10,8 @@ object NightlyPlugin extends AutoPlugin {
val includeTestDependencies = settingKey[Boolean]("Doesn't declare test dependencies.")
def testDependencies = libraryDependencies ++= (
if (includeTestDependencies.value) Seq(scalaCheck % Test, specs2 % Test, junit % Test, scalatest % Test)
if (includeTestDependencies.value)
Seq(scalaCheck % Test, specs2 % Test, junit % Test, scalatest % Test)
else Seq()
)
}

View File

@ -3,11 +3,12 @@ import xsbti.Maybe
import xsbti.compile.{PreviousResult, CompileAnalysis, MiniSetup}
previousCompile in Compile := {
val previous = (previousCompile in Compile).value
if (!CompileState.isNew) {
val res = new PreviousResult(Maybe.nothing[CompileAnalysis], Maybe.nothing[MiniSetup])
CompileState.isNew = true
res
} else (previousCompile in Compile).value
} else previous
}
/* Performs checks related to compilations:

View File

@ -6,11 +6,12 @@ logLevel := Level.Debug
// Reset compile status because scripted tests are run in batch mode
previousCompile in Compile := {
val previous = (previousCompile in Compile).value
if (!CompileState.isNew) {
val res = new PreviousResult(Maybe.nothing[CompileAnalysis], Maybe.nothing[MiniSetup])
CompileState.isNew = true
res
} else (previousCompile in Compile).value
} else previous
}
// disable sbt's heuristic which recompiles everything in case