// sbt/tasks/TaskScheduler.scala
//
// 239 lines
// 7.8 KiB
// Scala

package xsbt
import scala.collection.{immutable,mutable}
/** Records that processing the work item `work` failed with `exception`.
 *  @tparam D the type used to identify the failed work (for example, the task that failed)
 */
final case class WorkFailure[D](work: D, exception: Throwable) extends NotNull
{
	/** Replaces the failed work item with `f(work)`, keeping the original exception. */
	def map[C](f: D => C): WorkFailure[C] =
	{
		val mappedWork = f(work)
		new WorkFailure[C](mappedWork, exception)
	}
}
/** Schedules the work needed to compute `root`: the root task, its dependencies, and any tasks they call.
 *
 *  Each call to `run` creates fresh bookkeeping for one execution:
 *   - `reverseDeps(t)`: tasks waiting on `t`; present from when `t` is added until `t` is retired.
 *   - `forwardDeps(t)`: dependencies `t` is still waiting on; present only while that set is non-empty.
 *   - `calls`: for a task whose result was another task, maps the called task to the waiting caller.
 *   - `completed` / `failed`: tasks that finished successfully / unsuccessfully.
 *  `strategy` orders the ready work handed out by `next`. `newListener` is evaluated once per run and is
 *  notified as tasks are added, become runnable, run, call other tasks, complete, and fail.
 *
 *  @param root        the task whose value is the result of a successful run
 *  @param strategy    chooses which ready work to hand out
 *  @param newListener factory for the listener receiving progress callbacks
 */
private final class TaskScheduler[O](root: Task[O], strategy: ScheduleStrategy[Work[_]], newListener: => TaskListener)
	extends Scheduler[ Either[ List[WorkFailure[Task[_]]], O ], Work, Result]
{
	def run = new Run
	{
		// `newListener` is by-name; it is evaluated once here, per run.
		val listener = newListener
		/** The outcome of the run: `Right` with the root's value if no failure was recorded,
		 *  otherwise `Left` with every recorded failure.
		 *  The `assume`s check that all bookkeeping has drained, i.e. the run really is finished. */
		def result =
		{
			assume(reverseDeps.isEmpty)
			assume(forwardDeps.isEmpty)
			assume(calls.isEmpty)
			assume(!strategyRun.hasReady)
			if(failureReports.isEmpty)
				Right(completed(root))
			else
				Left(failureReports.toList)
		}
		/** Takes the next batch of ready work from the strategy (`max` is passed through unchanged)
		 *  and tells the listener that each returned item is now running. */
		def next(max: Int) =
		{
			val running = strategyRun.next(max)
			running.foreach(r => listener.running(r.source))
			running
		}
		/** True once every task added to the graph has been retired (completed or failed). */
		def isComplete = reverseDeps.isEmpty
		/** True while the strategy holds ready work or some task is still waiting on dependencies. */
		def hasPending = strategyRun.hasReady || !forwardDeps.isEmpty
		/** Records the outcome of executing `work`.
		 *  `Left(err)` records a failure for the task and fails it together with everything waiting on it.
		 *  `Right(value)` is handled by `success`: a plain value completes the task, a returned task is called. */
		def complete[A](work: Work[A], result: Either[Throwable,Result[A]]): Unit =
		{
			val task = work.source
			result match
			{
				case Left(err) =>
					failureReports += WorkFailure(task, err)
					listener.failed(task, err)
					retire(task, None)
					assert(failed.contains(task))
				case Right(value) =>
					success(task, value)
					// the task either completed, failed, or is now waiting on a task it called
					assert(completed.contains(task) || (calls.isCalling(task) && !reverseDeps.isEmpty) || failed.contains(task))
			}
			assert(calls.isCalling(task) || !reverseDeps.contains(task))
			assert(!forwardDeps.contains(task))
		}
		private def newDepMap = new mutable.HashMap[Task[_], mutable.Set[Task[_]]]
		// task -> tasks that depend on it (an entry exists for every added, not-yet-retired task)
		private val reverseDeps = newDepMap
		// task -> dependencies it is still waiting for (an entry exists only while that set is non-empty)
		private val forwardDeps = newDepMap
		// called task -> the task that returned it and is waiting for its result
		private val calls = new CalledByMap
		// results of successfully completed tasks; also handed to each Work so it can read dependency results
		private val completed = new ResultMap
		private val strategyRun = strategy.run
		// tasks that failed, directly or because a dependency or called task failed
		private val failed = new mutable.HashSet[Task[_]]
		// failures recorded at their cause; dependents that fail as a consequence are not reported separately
		private val failureReports = new mutable.ArrayBuffer[WorkFailure[Task[_]]]

		// Seed the graph with `root`. The second argument is the call context; nothing is calling yet.
		// This is deliberately a plain statement rather than a `{ ... }` block on the line after the
		// `new mutable.ArrayBuffer[...]` above. Scala reads a brace on the line after `new C` as the body of an
		// anonymous subclass of C, which would run this code inside that subclass' constructor, before
		// `failureReports` is assigned. The result is kept in a val so the graph is still seeded even if
		// assertions are elided.
		// TODO: replace second root with something better? (it is ignored here anyway)
		private[this] val initialized = addGraph(root, root)
		assert(initialized)

		/** Hands `m`, all of whose dependencies have completed, to the strategy as runnable work. */
		private def addReady[O](m: Task[O])
		{
			assert(!forwardDeps.contains(m), m)
			assert(reverseDeps.contains(m), m)
			assert(!completed.contains(m), m)
			assert(!calls.isCalling(m), m)
			assert(m.dependencies.forall(completed.contains), "Could not find result for dependency of ready task " + m)
			strategyRun.workReady(new Work(m, completed))
			listener.runnable(m)
		}
		/** Ensures `node` is part of the graph, adding it and its unfinished dependencies if needed.
		 *  `context` is the task on whose behalf `node` is added (the caller, for a called task). It is used to
		 *  detect circular calls.
		 *  @return false if `node` cannot be scheduled: it already failed, one of its dependencies cannot be
		 *          added, or adding it would create a circular call (recorded in `failureReports`). */
		private def addGraph(node: Task[_], context: Task[_]): Boolean =
		{
			if(failed(node)) // node already failed
				false
			else if(calls.isCalling(node)) // node is waiting for a called task to complete, so we need to check for circular dependencies
			{
				if(calls.isCallerOf(node, context)) // if node called context, this is a circular dependency and is invalid
				{
					failureReports += WorkFailure(node, CircularDependency(node, context))
					false
				}
				else
					true
			}
			else if(reverseDeps.contains(node) || completed.contains(node)) // node is either already added and is waiting for dependencies to complete or it has completed
				true
			else // node has never been added
				newAdd(node, context)
		}
		/** Adds a task that has never been added before.
		 *  If all of its dependencies have completed, it is made ready immediately. Otherwise its unfinished
		 *  dependencies are added first and the edges between them are recorded. If a dependency cannot be added,
		 *  `node` is marked failed; no separate failure report is recorded for it.
		 *  NOTE(review): `reverseDeps(node)` is only set after the dependencies were added, so a cycle among the
		 *  `dependencies` themselves would recurse without bound here; presumably Task construction rules this
		 *  out — confirm. */
		private def newAdd(node: Task[_], context: Task[_]): Boolean =
		{
			val deps = node.dependencies.filter(dep => !completed.contains(dep))
			def finishAdding() =
			{
				listener.added(node)
				true
			}
			if(deps.isEmpty) // node is ready to be run
			{
				reverseDeps(node) = new mutable.HashSet[Task[_]]
				addReady(node)
				finishAdding()
			}
			else if(deps.forall(dep => addGraph(dep,context))) // node requires dependencies to be added successfully and will then wait for them to complete before running
			{
				for(dep <- node.dependencies if !(completed.contains(dep) || reverseDeps.contains(dep) || calls.isCalling(dep)))
					error("Invalid dependency state: (completed=" + completed.contains(dep) + ", reverse=" + reverseDeps.contains(dep) + ", calling=" + calls.isCalling(dep) + ") for " + dep)
				reverseDeps(node) = new mutable.HashSet[Task[_]]
				deps.foreach(dep => reverseDeps(dep) += node) // mark this node as depending on its dependencies
				forwardDeps(node) = mutable.HashSet(deps.toSeq : _*)
				finishAdding()
			}
			else // a dependency could not be added, so this node will fail as well.
			{
				// `forall` stops at the first failure; dependencies added before it stay in the graph and may still run
				failed += node
				false
			}
		}
		/** Finishes `m`: `Some(v)` records `v` as its result and `None` marks it failed.
		 *  The tasks waiting on `m` are then updated (made ready or failed) and the listener is notified.
		 *  If `m` was called by another task, that caller finishes with the same outcome. */
		private def retire[O](m: Task[O], value: Option[O])
		{
			value match
			{
				case Some(v) => completed(m) = v // map the task to its value
				case None => failed += m // mark the task as failed. complete has already recorded the error message for the original cause
			}
			updateCurrentGraph(m, value.isDefined) // update forward and reverse dependency maps and propagate the change to depending tasks
			listener.completed(m, value)
			calls.remove(m) match // unwind the call stack
			{
				case Some(c) =>
					listener.called(c, m)
					retire(c, value)
				case None => ()
			}
		}
		/** Removes the finished task `m` from the dependency maps and propagates its outcome to the tasks waiting on it.
		 *  On success, each dependent drops `m` from its pending set and becomes ready once that set is empty.
		 *  On failure, each dependent is retired as failed. */
		private def updateCurrentGraph[O](m: Task[O], success: Boolean)
		{
			if(!success)
			{
				// clear m from the forward dependency map
				// for each dependency d of m, remove m from the set of tasks that depend on d
				for(depSet <- forwardDeps.removeKey(m); dep <- depSet; reverseSet <- reverseDeps.get(dep))
					reverseSet -= m
			}
			// m is complete, so remove its entry from reverseDeps and update all tasks that depend on m
			for(mReverseDeps <- reverseDeps.removeKey(m); dependsOnM <- mReverseDeps)
			{
				if(success)
				{
					val on = forwardDeps(dependsOnM)
					on -= m // m has completed, so remove it from the set of tasks that must complete before 'on' can run
					if(on.isEmpty) // m was the last dependency of on, so make it runnable
					{
						forwardDeps.removeKey(dependsOnM)
						addReady(dependsOnM)
					}
				}
				else // cancel dependsOnM because dependency (m) failed
					retire(dependsOnM, None)
			}
		}
		/** Handles a successful execution of `task`.
		 *  `Value(v)` completes the task with `v`. A returned task `t` is "called": `task` finishes with `t`'s
		 *  result, immediately if `t` has already completed and otherwise once `t` is retired. A task that returns
		 *  itself, or returns a task that cannot be added, fails.
		 *  NOTE(review): `calls(t) = task` replaces any earlier caller of `t` (CalledByMap.update overwrites its
		 *  entry). If two tasks return the same pending task, only the last caller is ever retired.
		 *  NOTE: `case t: Task[O]` is unchecked because of erasure; only the Task/Value distinction matters here. */
		private def success[O](task: Task[O], value: Result[O]): Unit =
			value match
			{
				case t: Task[O] =>
					if(t eq task)
					{
						failureReports += WorkFailure(t, CircularDependency(t, task))
						retire(task, None)
					}
					else if(addGraph(t, task))
					{
						if(completed.contains(t))
							retire(task, Some(completed(t)))
						else
						{
							calls(t) = task
							listener.calling(task, t)
						}
					}
					else
						retire(task, None)
				case Value(v) => retire(task, Some(v))
			}
	}
}
/** Failure recorded when a task provides a task that is already on its own calling stack.
 *  @param node    the task that was provided
 *  @param context the task that provided `node`, directly or through its dependencies
 */
final case class CircularDependency(node: Task[_], context: Task[_]) extends RuntimeException(
	"Task " + context + " provided task " + node + " already in calling stack"
)
/** Tracks "call" relationships.
 *  When a task's result is another task, the caller waits for that called task to finish.
 *  `callMap` maps each called task to its waiting caller, and `calling` holds the set of waiting callers.
 *  NOTE(review): `update` overwrites any existing caller of `called`, so at most one caller per called task is kept.
 */
private final class CalledByMap extends NotNull
{
	private[this] val calling = new mutable.HashSet[Task[_]]
	private[this] val callMap = new mutable.HashMap[Task[_], Task[_]]
	/** Records that `by` is waiting for the result of `called`. */
	def update[O](called: Task[O], by: Task[O])
	{
		callMap(called) = by
		calling += by
	}
	/** True if `check` is `frame` itself or appears anywhere in the chain of callers above `frame`. */
	final def isCallerOf(check: Task[_], frame: Task[_]): Boolean =
		if(check ne frame)
			callMap.get(frame) match
			{
				case None => false
				case Some(parent) => isCallerOf(check, parent)
			}
		else
			true
	/** True when no call is outstanding. */
	def isEmpty = callMap.isEmpty && calling.isEmpty
	/** True if some task is waiting for `task`'s result. */
	def isCalled(task: Task[_]): Boolean = callMap.contains(task)
	/** True if `caller` is waiting for a task it called. */
	def isCalling(caller: Task[_]): Boolean = calling.contains(caller)
	/** Forgets the call to `called`, returning the task that was waiting on it (if any). */
	def remove[O](called: Task[O]): Option[Task[O]] =
		callMap.removeKey(called).map { waiter =>
			calling -= waiter
			waiter.asInstanceOf[Task[O]]
		}
}
import java.util.concurrent.{ConcurrentHashMap => HashMap}
/** Store of task results, backed by a `java.util.concurrent.ConcurrentHashMap`.
 *
 *  `ConcurrentHashMap` rejects `null` values with a `NullPointerException`. A task whose result is `null` is
 *  therefore stored under a private sentinel, which `apply` translates back to `null`.
 */
private final class ResultMap(private val map: HashMap[Task[_], Any]) extends Results
{
	def this() = this(new HashMap)
	// Stand-in for a `null` result; never handed out by this class.
	private[this] val NullResult = new AnyRef
	/** Records `value` as the result of `task`, replacing any earlier result. `null` is allowed. */
	def update[O](task: Task[O], value: O) { map.put(task, if(value.asInstanceOf[AnyRef] eq null) NullResult else value) }
	/** The result recorded for `task`. As before, this is `null` cast to `O` when nothing was recorded. */
	def apply[O](task: Task[O]): O =
	{
		val stored = map.get(task)
		(if(stored.asInstanceOf[AnyRef] eq NullResult) null else stored).asInstanceOf[O]
	}
	/** Whether a result (possibly `null`) has been recorded for `task`. */
	def contains(task: Task[_]) = map.containsKey(task)
}
/** A runnable unit of work: computing `source`, with `results` supplying the results already computed
 *  (the scheduler passes its completed-results store here).
 */
private final class Work[O](val source: Task[O], results: Results) extends Identity with NotNull
{
	/** Computes `source` by delegating to `Task.compute`. */
	def apply = { Task.compute(source, results) }
}