diff --git a/main/Defaults.scala b/main/Defaults.scala index 7aa2cdb96..ce6cb96b1 100644 --- a/main/Defaults.scala +++ b/main/Defaults.scala @@ -51,6 +51,7 @@ object Defaults extends BuildCommon buildDependencies <<= buildDependencies or Classpaths.constructBuildDependencies, taskTemporaryDirectory := IO.createTemporaryDirectory, onComplete <<= taskTemporaryDirectory { dir => () => IO.delete(dir); IO.createDirectory(dir) }, + concurrentRestrictions <<= concurrentRestrictions or defaultRestrictions, parallelExecution :== true, sbtVersion in GlobalScope <<= appConfiguration { _.provider.id.version }, sbtResolver in GlobalScope <<= sbtVersion { sbtV => if(sbtV endsWith "-SNAPSHOT") Classpaths.typesafeSnapshots else Classpaths.typesafeResolver }, @@ -58,8 +59,11 @@ object Defaults extends BuildCommon logBuffered :== false, connectInput :== false, cancelable :== false, + cancelable :== false, autoScalaLibrary :== true, onLoad <<= onLoad ?? idFun[State], + tags in test := Seq(Tags.Test -> 1), + tags in testOnly <<= tags in test, onUnload <<= (onUnload ?? idFun[State]), onUnload <<= (onUnload, taskTemporaryDirectory) { (f, dir) => s => { try f(s) finally IO.delete(dir) } }, watchingMessage <<= watchingMessage ?? Watched.defaultWatchingMessage, @@ -196,7 +200,7 @@ object Defaults extends BuildCommon lazy val configTasks = docSetting(doc) ++ compileInputsSettings ++ Seq( initialCommands in GlobalScope :== "", cleanupCommands in GlobalScope :== "", - compile <<= compileTask, + compile <<= compileTask tag(Tags.Compile, Tags.CPU), compileIncSetup <<= compileIncSetupTask, console <<= consoleTask, consoleQuick <<= consoleQuickTask, @@ -270,10 +274,12 @@ object Defaults extends BuildCommon definedTestNames <<= definedTests map ( _.map(_.name).distinct) storeAs definedTestNames triggeredBy compile, testListeners in GlobalScope :== Nil, testOptions in GlobalScope :== Nil, - executeTests <<= (streams in test, loadedTestFrameworks, parallelExecution in test, testOptions in test, testLoader, definedTests, resolvedScoped, state) flatMap { - (s, frameworkMap, par, options, loader, discovered, scoped, st) => + testExecution in test <<= testExecutionTask(test), + testExecution in testOnly <<= testExecutionTask(testOnly), + executeTests <<= (streams in test, loadedTestFrameworks, testExecution in test, testLoader, definedTests, resolvedScoped, state) flatMap { + (s, frameworkMap, config, loader, discovered, scoped, st) => implicit val display = Project.showContextKey(st) - Tests(frameworkMap, loader, discovered, options, par, noTestsMessage(ScopedKey(scoped.scope, test.key)), s.log) + Tests(frameworkMap, loader, discovered, config, noTestsMessage(ScopedKey(scoped.scope, test.key)), s.log) }, test <<= (executeTests, streams) map { (results, s) => Tests.showResults(s.log, results) }, testOnly <<= testOnlyTask @@ -303,14 +309,18 @@ object Defaults extends BuildCommon extra.put(name.key, tdef.name).put(isModule, mod) } + def testExecutionTask(task: Scoped): Initialize[Task[Tests.Execution]] = + (testOptions in task, parallelExecution in task, tags in task) map { (opts, par, ts) => new Tests.Execution(opts, par, ts) } + def testOnlyTask = - InputTask( loadForParser(definedTestNames)( (s, i) => testOnlyParser(s, i getOrElse Nil) ) ) { result => - (streams, loadedTestFrameworks, parallelExecution in testOnly, testOptions in testOnly, testLoader, definedTests, resolvedScoped, result, state) flatMap { - case (s, frameworks, par, opts, loader, discovered, scoped, (tests, frameworkOptions), st) => + InputTask( loadForParser(definedTestNames)( (s, i) => testOnlyParser(s, i getOrElse Nil) ) ) { result => + (streams, loadedTestFrameworks, testExecution in testOnly, testLoader, definedTests, resolvedScoped, result, state) flatMap { + case (s, frameworks, config, loader, discovered, scoped, (tests, frameworkOptions), st) => val filter = selectedFilter(tests) - val modifiedOpts = Tests.Filter(filter) +: Tests.Argument(frameworkOptions : _*) +: opts + val modifiedOpts = Tests.Filter(filter) +: Tests.Argument(frameworkOptions : _*) +: config.options + val newConfig = new Tests.Execution(modifiedOpts, config.parallel, config.tags) implicit val display = Project.showContextKey(st) - Tests(frameworks, loader, discovered, modifiedOpts, par, noTestsMessage(scoped), s.log) map { results => + Tests(frameworks, loader, discovered, newConfig, noTestsMessage(scoped), s.log) map { results => Tests.showResults(s.log, results) } } @@ -323,6 +333,10 @@ object Defaults extends BuildCommon def detectTests: Initialize[Task[Seq[TestDefinition]]] = (loadedTestFrameworks, compile, streams) map { (frameworkMap, analysis, s) => Tests.discover(frameworkMap.values.toSeq, analysis, s.log)._1 } + def defaultRestrictions: Initialize[Seq[Tags.Rule]] = parallelExecution { par => + val max = EvaluateTask.SystemProcessors + Tags.limitAll(if(par) max else 1) :: Nil + } lazy val packageBase: Seq[Setting[_]] = Seq( artifact <<= moduleName(n => Artifact(n)), @@ -739,7 +753,7 @@ object Classpaths update <<= (ivyModule, thisProjectRef, updateConfiguration, cacheDirectory, scalaInstance, transitiveUpdate, streams) map { (module, ref, config, cacheDirectory, si, reports, s) => val depsUpdated = reports.exists(!_.stats.cached) cachedUpdate(cacheDirectory / "update", Project.display(ref), module, config, Some(si), depsUpdated, s.log) - }, + } tag(Tags.Update, Tags.Network), update <<= (conflictWarning, update, streams) map { (config, report, s) => ConflictWarning(config, report, s.log); report }, transitiveClassifiers in GlobalScope :== Seq(SourceClassifier, DocClassifier), classifiersModule in updateClassifiers <<= (projectID, update, transitiveClassifiers in updateClassifiers, ivyConfigurations in updateClassifiers) map { ( pid, up, classifiers, confs) => @@ -749,7 +763,7 @@ object Classpaths withExcludes(out, mod.classifiers, lock(app)) { excludes => IvyActions.updateClassifiers(is, GetClassifiersConfiguration(mod, excludes, c, ivyScala), s.log) } - }, + } tag(Tags.Update, Tags.Network), sbtDependency in GlobalScope <<= appConfiguration { app => val id = app.provider.id val base = ModuleID(id.groupID, id.name, id.version, crossVersion = id.crossVersioned) @@ -787,7 +801,7 @@ object Classpaths withExcludes(out, mod.classifiers, lock(app)) { excludes => IvyActions.transitiveScratch(is, "sbt", GetClassifiersConfiguration(mod, excludes, c, ivyScala), s.log) } - } + } tag(Tags.Update, Tags.Network) )) def deliverTask(config: TaskKey[DeliverConfiguration]): Initialize[Task[File]] = @@ -795,7 +809,7 @@ object Classpaths def publishTask(config: TaskKey[PublishConfiguration], deliverKey: TaskKey[_]): Initialize[Task[Unit]] = (ivyModule, config, streams) map { (module, config, s) => IvyActions.publish(module, config, s.log) - } + } tag(Tags.Publish, Tags.Network) import Cache._ import CacheIvy.{classpathFormat, /*publishIC,*/ updateIC, updateReportF, excludeMap} diff --git a/main/EvaluateTask.scala b/main/EvaluateTask.scala index 01078302e..9510787ee 100644 --- a/main/EvaluateTask.scala +++ b/main/EvaluateTask.scala @@ -11,7 +11,7 @@ package sbt import Types.const import scala.Console.{RED, RESET} -final case class EvaluateConfig(cancelable: Boolean, checkCycles: Boolean = false, maxWorkers: Int = EvaluateTask.SystemProcessors) +final case class EvaluateConfig(cancelable: Boolean, restrictions: Seq[Tags.Rule], checkCycles: Boolean = false) object EvaluateTask { import Load.BuildStructure @@ -21,23 +21,38 @@ object EvaluateTask import Keys.state val SystemProcessors = Runtime.getRuntime.availableProcessors - def defaultConfig = EvaluateConfig(false) + def defaultConfig(state: State): EvaluateConfig = + EvaluateConfig(false, restrictions(state)) + def defaultConfig(extracted: Extracted, structure: Load.BuildStructure) = + EvaluateConfig(false, restrictions(extracted, structure)) + def extractedConfig(extracted: Extracted, structure: BuildStructure): EvaluateConfig = { - val workers = maxWorkers(extracted, structure) + val workers = restrictions(extracted, structure) val canCancel = cancelable(extracted, structure) - EvaluateConfig(cancelable = canCancel, maxWorkers = workers) + EvaluateConfig(cancelable = canCancel, restrictions = workers) } + def defaultRestrictions(maxWorkers: Int) = Tags.limitAll(maxWorkers) :: Nil + def defaultRestrictions(extracted: Extracted, structure: Load.BuildStructure): Seq[Tags.Rule] = + Tags.limitAll(maxWorkers(extracted, structure)) :: Nil + + def restrictions(state: State): Seq[Tags.Rule] = + { + val extracted = Project.extract(state) + restrictions(extracted, extracted.structure) + } + def restrictions(extracted: Extracted, structure: Load.BuildStructure): Seq[Tags.Rule] = + getSetting(Keys.concurrentRestrictions, defaultRestrictions(extracted, structure), extracted, structure) def maxWorkers(extracted: Extracted, structure: Load.BuildStructure): Int = - if(getBoolean(Keys.parallelExecution, true, extracted, structure)) - EvaluateTask.SystemProcessors + if(getSetting(Keys.parallelExecution, true, extracted, structure)) + SystemProcessors else 1 def cancelable(extracted: Extracted, structure: Load.BuildStructure): Boolean = - getBoolean(Keys.cancelable, false, extracted, structure) - def getBoolean(key: SettingKey[Boolean], default: Boolean, extracted: Extracted, structure: Load.BuildStructure): Boolean = - (key in extracted.currentRef get structure.data) getOrElse default + getSetting(Keys.cancelable, false, extracted, structure) + def getSetting[T](key: SettingKey[T], default: T, extracted: Extracted, structure: Load.BuildStructure): T = + key in extracted.currentRef get structure.data getOrElse default def injectSettings: Seq[Setting[_]] = Seq( (state in GlobalScope) ::= dummyState, @@ -48,16 +63,20 @@ object EvaluateTask { val root = ProjectRef(pluginDef.root, Load.getRootProject(pluginDef.units)(pluginDef.root)) val pluginKey = Keys.fullClasspath in Configurations.Runtime - val evaluated = apply(pluginDef, ScopedKey(pluginKey.scope, pluginKey.key), state, root, defaultConfig) + val config = defaultConfig(Project.extract(state), pluginDef) + val evaluated = apply(pluginDef, ScopedKey(pluginKey.scope, pluginKey.key), state, root, config) val (newS, result) = evaluated getOrElse error("Plugin classpath does not exist for plugin definition at " + pluginDef.root) Project.runUnloadHooks(newS) // discard states processResult(result, log) } - @deprecated("This method does not apply state changes requested during task execution. Use 'apply' instead, which does.", "0.11.1") + @deprecated("This method does not apply state changes requested during task execution and does not honor concurrent execution restrictions. Use 'apply' instead.", "0.11.1") def evaluateTask[T](structure: BuildStructure, taskKey: ScopedKey[Task[T]], state: State, ref: ProjectRef, checkCycles: Boolean = false, maxWorkers: Int = SystemProcessors): Option[Result[T]] = - apply(structure, taskKey, state, ref, EvaluateConfig(false, checkCycles, maxWorkers)).map(_._2) - def apply[T](structure: BuildStructure, taskKey: ScopedKey[Task[T]], state: State, ref: ProjectRef, config: EvaluateConfig = defaultConfig): Option[(State, Result[T])] = + apply(structure, taskKey, state, ref, EvaluateConfig(false, defaultRestrictions(maxWorkers), checkCycles)).map(_._2) + + def apply[T](structure: BuildStructure, taskKey: ScopedKey[Task[T]], state: State, ref: ProjectRef): Option[(State, Result[T])] = + apply[T](structure, taskKey, state, ref, defaultConfig(Project.extract(state), structure)) + def apply[T](structure: BuildStructure, taskKey: ScopedKey[Task[T]], state: State, ref: ProjectRef, config: EvaluateConfig): Option[(State, Result[T])] = withStreams(structure, state) { str => for( (task, toNode) <- getTask(structure, taskKey, state, str, ref) ) yield runTask(task, state, str, structure.index.triggers, config)(toNode) @@ -97,11 +116,14 @@ object EvaluateTask def nodeView[HL <: HList](state: State, streams: Streams, extraDummies: KList[Task, HL] = KNil, extraValues: HL = HNil): Execute.NodeView[Task] = Transform(dummyStreamsManager :^: KCons(dummyState, extraDummies), streams :+: HCons(state, extraValues)) - def runTask[T](root: Task[T], state: State, streams: Streams, triggers: Triggers[Task], config: EvaluateConfig = defaultConfig)(implicit taskToNode: Execute.NodeView[Task]): (State, Result[T]) = + def runTask[T](root: Task[T], state: State, streams: Streams, triggers: Triggers[Task], config: EvaluateConfig)(implicit taskToNode: Execute.NodeView[Task]): (State, Result[T]) = { + import ConcurrentRestrictions.{completionService, TagMap, Tag, tagged, tagsKey} + val log = state.log - log.debug("Running task... Cancelable: " + config.cancelable + ", max worker threads: " + config.maxWorkers + ", check cycles: " + config.checkCycles) - val (service, shutdown) = CompletionService[Task[_], Completed](config.maxWorkers) + log.debug("Running task... Cancelable: " + config.cancelable + ", check cycles: " + config.checkCycles) + val tags = tagged[Task[_]](_.info get tagsKey getOrElse Map.empty, Tags.predicate(config.restrictions)) + val (service, shutdown) = completionService[Task[_], Completed](tags, (s: String) => log.warn(s)) def run() = { val x = new Execute[Task](config.checkCycles, triggers)(taskToNode) diff --git a/main/GlobalPlugin.scala b/main/GlobalPlugin.scala index 66f3b8371..7cc554c69 100644 --- a/main/GlobalPlugin.scala +++ b/main/GlobalPlugin.scala @@ -53,7 +53,7 @@ object GlobalPlugin import EvaluateTask._ withStreams(structure, state) { str => val nv = nodeView(state, str) - val config = EvaluateTask.defaultConfig + val config = EvaluateTask.defaultConfig(Project.extract(state), structure) val (newS, result) = runTask(t, state, str, structure.index.triggers, config)(nv) (newS, processResult(result, newS.log)) } diff --git a/main/Keys.scala b/main/Keys.scala index 91c7db153..9cd4d3c09 100644 --- a/main/Keys.scala +++ b/main/Keys.scala @@ -191,6 +191,7 @@ object Keys val testOptions = TaskKey[Seq[TestOption]]("test-options", "Options for running tests.") val testFrameworks = SettingKey[Seq[TestFramework]]("test-frameworks", "Registered, although not necessarily present, test frameworks.") val testListeners = TaskKey[Seq[TestReportListener]]("test-listeners", "Defines test listeners.") + val testExecution = TaskKey[Tests.Execution]("test-execution", "Settings controlling test execution") val isModule = AttributeKey[Boolean]("is-module", "True if the target is a module.") // Classpath/Dependency Management Keys @@ -300,6 +301,8 @@ object Keys // special val sessionVars = AttributeKey[SessionVar.Map]("session-vars", "Bindings that exist for the duration of the session.") val parallelExecution = SettingKey[Boolean]("parallel-execution", "Enables (true) or disables (false) parallel execution of tasks.") + val tags = SettingKey[Seq[(Tags.Tag,Int)]]("tags", ConcurrentRestrictions.tagsKey.label) + val concurrentRestrictions = SettingKey[Seq[Tags.Rule]]("concurrent-restrictions", "Rules describing restrictions on concurrent task execution.") val cancelable = SettingKey[Boolean]("cancelable", "Enables (true) or disables (false) the ability to interrupt task execution with CTRL+C.") val settings = TaskKey[Settings[Scope]]("settings", "Provides access to the project data for the build.") val streams = TaskKey[TaskStreams]("streams", "Provides streams for logging and persisting data.") diff --git a/main/Project.scala b/main/Project.scala index 4122c9cfd..4cbb2685f 100644 --- a/main/Project.scala +++ b/main/Project.scala @@ -343,11 +343,12 @@ object Project extends Init[Scope] with ProjectExtra } @deprecated("This method does not apply state changes requested during task execution. Use 'runTask' instead, which does.", "0.11.1") def evaluateTask[T](taskKey: ScopedKey[Task[T]], state: State, checkCycles: Boolean = false, maxWorkers: Int = EvaluateTask.SystemProcessors): Option[Result[T]] = - runTask(taskKey, state, checkCycles, maxWorkers).map(_._2) - def runTask[T](taskKey: ScopedKey[Task[T]], state: State, checkCycles: Boolean = false, maxWorkers: Int = EvaluateTask.SystemProcessors): Option[(State, Result[T])] = + runTask(taskKey, state, EvaluateConfig(true, EvaluateTask.defaultRestrictions(maxWorkers), checkCycles)).map(_._2) + def runTask[T](taskKey: ScopedKey[Task[T]], state: State, checkCycles: Boolean = false): Option[(State, Result[T])] = + runTask(taskKey, state, EvaluateConfig(true, EvaluateTask.restrictions(state), checkCycles)) + def runTask[T](taskKey: ScopedKey[Task[T]], state: State, config: EvaluateConfig): Option[(State, Result[T])] = { val extracted = Project.extract(state) - val config = EvaluateConfig(true, checkCycles, maxWorkers) EvaluateTask(extracted.structure, taskKey, state, extracted.currentRef, config) } // this is here instead of Scoped so that it is considered without need for import (because of Project.Initialize) diff --git a/main/Structure.scala b/main/Structure.scala index dde869d86..c8f3b6ba4 100644 --- a/main/Structure.scala +++ b/main/Structure.scala @@ -205,6 +205,9 @@ object Scoped def dependsOn(tasks: AnyInitTask*): Initialize[Task[S]] = (i, Initialize.joinAny(tasks)) { (thisTask, deps) => thisTask.dependsOn(deps : _*) } + def tag(tags: Tags.Tag*): Initialize[Task[S]] = i(_.tag(tags: _*)) + def tagw(tags: (Tags.Tag, Int)*): Initialize[Task[S]] = i(_.tagw(tags : _*)) + import SessionVar.{persistAndSet, resolveContext, set, transform} def updateState(f: (State, S) => State): Initialize[Task[S]] = i(t => transform(t, f)) diff --git a/main/Tags.scala b/main/Tags.scala new file mode 100644 index 000000000..346f09275 --- /dev/null +++ b/main/Tags.scala @@ -0,0 +1,58 @@ +package sbt + + import annotation.tailrec + +object Tags +{ + type Tag = ConcurrentRestrictions.Tag + type TagMap = ConcurrentRestrictions.TagMap + def Tag(s: String): Tag = ConcurrentRestrictions.Tag(s) + + val All = ConcurrentRestrictions.All + val Untagged = ConcurrentRestrictions.Untagged + val Compile = Tag("compile") + val Test = Tag("test") + val Update = Tag("update") + val Publish = Tag("publish") + + val CPU = Tag("cpu") + val Network = Tag("network") + val Disk = Tag("disk") + + /** Describes a restriction on concurrently executing tasks. + * A Rule is constructed using one of the Tags.limit* methods. */ + sealed trait Rule { + def apply(m: TagMap): Boolean + } + private[this] final class Custom(f: TagMap => Boolean) extends Rule { + def apply(m: TagMap) = f(m) + } + private[this] final class Single(tag: Tag, max: Int) extends Rule { + checkMax(max) + def apply(m: TagMap) = getInt(m, tag) <= max + } + private[this] final class Sum(tags: Seq[Tag], max: Int) extends Rule { + checkMax(max) + def apply(m: TagMap) = (0 /: tags)((sum, t) => sum + getInt(m, t)) <= max + } + private[this] def checkMax(max: Int): Unit = assert(max >= 1, "Limit must be at least 1.") + + /** Converts a sequence of rules into a function that identifies whether a set of tasks are allowed to execute concurrently based on their merged tags. */ + def predicate(rules: Seq[Rule]): TagMap => Boolean = m => { + @tailrec def loop(rules: List[Rule]): Boolean = + rules match + { + case x :: xs => x(m) && loop(xs) + case Nil => true + } + loop(rules.toList) + } + + def getInt(m: TagMap, tag: Tag): Int = m.getOrElse(tag, 0) + + def customLimit(f: TagMap => Boolean): Rule = new Custom(f) + def limitAll(max: Int): Rule = limit(All, max) + def limitUntagged(max: Int): Rule = limit(Untagged, max) + def limit(tag: Tag, max: Int): Rule = new Single(tag, max) + def limitSum(max: Int, tags: Tag*): Rule = new Sum(tags, max) +} \ No newline at end of file diff --git a/main/actions/Tests.scala b/main/actions/Tests.scala index 56c3e31dd..c6e19fec0 100644 --- a/main/actions/Tests.scala +++ b/main/actions/Tests.scala @@ -9,6 +9,7 @@ package sbt import TaskExtra._ import Types._ import xsbti.api.Definition + import ConcurrentRestrictions.Tag import org.scalatools.testing.{AnnotatedFingerprint, Fingerprint, Framework, SubclassFingerprint} @@ -39,8 +40,12 @@ object Tests // None means apply to all, Some(tf) means apply to a particular framework only. final case class Argument(framework: Option[TestFramework], args: List[String]) extends TestOption - - def apply(frameworks: Map[TestFramework, Framework], testLoader: ClassLoader, discovered: Seq[TestDefinition], options: Seq[TestOption], parallel: Boolean, noTestsMessage: => String, log: Logger): Task[Output] = + final class Execution(val options: Seq[TestOption], val parallel: Boolean, val tags: Seq[(Tag, Int)]) + + def apply(frameworks: Map[TestFramework, Framework], testLoader: ClassLoader, discovered: Seq[TestDefinition], options: Seq[TestOption], parallel: Boolean, noTestsMessage: => String, log: Logger): Task[Output] = + apply(frameworks, testLoader, discovered, new Execution(options, parallel, Nil), noTestsMessage, log) + + def apply(frameworks: Map[TestFramework, Framework], testLoader: ClassLoader, discovered: Seq[TestDefinition], config: Execution, noTestsMessage: => String, log: Logger): Task[Output] = { import mutable.{HashSet, ListBuffer, Map, Set} val testFilters = new ListBuffer[String => Boolean] @@ -57,7 +62,7 @@ object Tests case None => undefinedFrameworks += framework.implClassName } - for(option <- options) + for(option <- config.options) { option match { @@ -88,12 +93,12 @@ object Tests def includeTest(test: TestDefinition) = !excludeTestsSet.contains(test.name) && testFilters.forall(filter => filter(test.name)) val tests = discovered.filter(includeTest).toSet.toSeq val arguments = testArgsByFramework.map { case (k,v) => (k, v.toList) } toMap; - testTask(frameworks.values.toSeq, testLoader, tests, noTestsMessage, setup.readOnly, cleanup.readOnly, log, testListeners.readOnly, arguments, parallel) + testTask(frameworks.values.toSeq, testLoader, tests, noTestsMessage, setup.readOnly, cleanup.readOnly, log, testListeners.readOnly, arguments, config) } def testTask(frameworks: Seq[Framework], loader: ClassLoader, tests: Seq[TestDefinition], noTestsMessage: => String, userSetup: Iterable[ClassLoader => Unit], userCleanup: Iterable[ClassLoader => Unit], - log: Logger, testListeners: Seq[TestReportListener], arguments: Map[Framework, Seq[String]], parallel: Boolean): Task[Output] = + log: Logger, testListeners: Seq[TestReportListener], arguments: Map[Framework, Seq[String]], config: Execution): Task[Output] = { def fj(actions: Iterable[() => Unit]): Task[Unit] = nop.dependsOn( actions.toSeq.fork( _() ) : _*) def partApp(actions: Iterable[ClassLoader => Unit]) = actions.toSeq map {a => () => a(loader) } @@ -102,16 +107,21 @@ object Tests TestFramework.testTasks(frameworks, loader, tests, noTestsMessage, log, testListeners, arguments) val setupTasks = fj(partApp(userSetup) :+ frameworkSetup) - val mainTasks = if(parallel) makeParallel(runnables, setupTasks).toSeq.join else makeSerial(runnables, setupTasks) - mainTasks map processResults flatMap { results => + val mainTasks = + if(config.parallel) + makeParallel(runnables, setupTasks, config.tags).toSeq.join + else + makeSerial(runnables, setupTasks, config.tags) + val taggedMainTasks = mainTasks.tagw(config.tags : _*) + taggedMainTasks map processResults flatMap { results => val cleanupTasks = fj(partApp(userCleanup) :+ frameworkCleanup(results._1)) cleanupTasks map { _ => results } } } type TestRunnable = (String, () => TestResult.Value) - def makeParallel(runnables: Iterable[TestRunnable], setupTasks: Task[Unit]) = - runnables map { case (name, test) => task { (name, test()) } dependsOn setupTasks named name } - def makeSerial(runnables: Iterable[TestRunnable], setupTasks: Task[Unit]) = + def makeParallel(runnables: Iterable[TestRunnable], setupTasks: Task[Unit], tags: Seq[(Tag,Int)]) = + runnables map { case (name, test) => task { (name, test()) } dependsOn setupTasks named name tagw(tags : _*) } + def makeSerial(runnables: Iterable[TestRunnable], setupTasks: Task[Unit], tags: Seq[(Tag,Int)]) = task { runnables map { case (name, test) => (name, test()) } } dependsOn(setupTasks) def processResults(results: Iterable[(String, TestResult.Value)]): (TestResult.Value, Map[String, TestResult.Value]) = diff --git a/project/Sbt.scala b/project/Sbt.scala index 21ee3f43a..b7ed34e87 100644 --- a/project/Sbt.scala +++ b/project/Sbt.scala @@ -20,6 +20,7 @@ object Sbt extends Build scalaVersion := "2.9.1", publishMavenStyle := false, componentID := None, + testOptions += Tests.Argument(TestFrameworks.ScalaCheck, "-w", "1"), javacOptions in Compile ++= Seq("-target", "6", "-source", "6") ) diff --git a/project/Util.scala b/project/Util.scala index 9c0bffd8f..7b90e0626 100644 --- a/project/Util.scala +++ b/project/Util.scala @@ -28,8 +28,8 @@ object Util lazy val base: Seq[Setting[_]] = Seq(scalacOptions ++= Seq("-Xelide-below", "0"), projectComponent) ++ Licensed.settings def testDependencies = libraryDependencies ++= Seq( - "org.scala-tools.testing" % "scalacheck_2.9.0-1" % "1.9" % "test", - "org.scala-tools.testing" % "specs_2.9.0-1" % "1.6.8" % "test" + "org.scala-tools.testing" %% "scalacheck" % "1.9" % "test", + "org.scala-tools.testing" %% "specs" % "1.6.9" % "test" ) lazy val minimalSettings: Seq[Setting[_]] = Defaults.paths ++ Seq[Setting[_]](crossTarget <<= target.identity, name <<= thisProject(_.id)) diff --git a/tasks/CompletionService.scala b/tasks/CompletionService.scala index 799245ef1..5f1e95758 100644 --- a/tasks/CompletionService.scala +++ b/tasks/CompletionService.scala @@ -5,7 +5,7 @@ package sbt trait CompletionService[A, R] { - def submit(node: A, work: () => R): () => R + def submit(node: A, work: () => R): Unit def take(): R } @@ -22,13 +22,14 @@ object CompletionService apply(new ExecutorCompletionService[T](x)) def apply[A, T](completion: JCompletionService[T]): CompletionService[A,T] = new CompletionService[A, T] { - def submit(node: A, work: () => T) = { - val future = completion.submit { new Callable[T] { def call = work() } } - () => future.get() - } + def submit(node: A, work: () => T) = CompletionService.submit(work, completion) def take() = completion.take().get() } - + def submit[T](work: () => T, completion: JCompletionService[T]): () => T = + { + val future = completion.submit { new Callable[T] { def call = work() } } + () => future.get() + } def manage[A, T](service: CompletionService[A,T])(setup: A => Unit, cleanup: A => Unit): CompletionService[A,T] = wrap(service) { (node, work) => () => setup(node) diff --git a/tasks/ConcurrentRestrictions.scala b/tasks/ConcurrentRestrictions.scala new file mode 100644 index 000000000..678740598 --- /dev/null +++ b/tasks/ConcurrentRestrictions.scala @@ -0,0 +1,198 @@ +package sbt + +/** Describes restrictions on concurrent execution for a set of tasks. +* +* @tparam A the type of a task +*/ +trait ConcurrentRestrictions[A] +{ + /** Internal state type used to describe a set of tasks. */ + type G + + /** Representation of zero tasks.*/ + def empty: G + + /** Updates the description `g` to include a new task `a`.*/ + def add(g: G, a: A): G + + /** Updates the description `g` to remove a previously added task `a`.*/ + def remove(g: G, a: A): G + + /** + * Returns true if the tasks described by `g` are allowed to execute concurrently. + * The methods in this class must obey the following laws: + * + * 1. forall g: G, a: A; valid(g) => valid(remove(g,a)) + * 2. forall a: A; valid(add(empty, a)) + * 3. forall g: G, a: A; valid(g) <=> valid(remove(add(g, a), a)) + * 4. (implied by 1,2,3) valid(empty) + * 5. forall g: G, a: A, b: A; !valid(add(g,a)) => !valid(add(add(g,b), a)) + */ + def valid(g: G): Boolean +} + + import java.util.{LinkedList,Queue} + import java.util.concurrent.{Executor, Executors, ExecutorCompletionService} + import annotation.tailrec + +object ConcurrentRestrictions +{ + /** A ConcurrentRestrictions instance that places no restrictions on concurrently executing tasks. + * @param zero the constant placeholder used for t */ + def unrestricted[A]: ConcurrentRestrictions[A] = + new ConcurrentRestrictions[A] + { + type G = Unit + def empty = () + def add(g: G, a: A) = () + def remove(g: G, a: A) = () + def valid(g: G) = true + } + + def limitTotal[A](i: Int): ConcurrentRestrictions[A] = + { + assert(i >= 1, "Maximum must be at least 1 (was " + i + ")") + new ConcurrentRestrictions[A] + { + type G = Int + def empty = 0 + def add(g: Int, a: A) = g + 1 + def remove(g: Int, a: A) = g - 1 + def valid(g: Int) = g <= i + } + } + + /** A key object used for associating information with a task.*/ + final case class Tag(name: String) + + val tagsKey = AttributeKey[TagMap]("tags", "Attributes restricting concurrent execution of tasks.") + + /** A standard tag describing the number of tasks that do not otherwise have any tags.*/ + val Untagged = Tag("untagged") + + /** A standard tag describing the total number of tasks. */ + val All = Tag("all") + + type TagMap = Map[Tag, Int] + + /** Implements concurrency restrictions on tasks based on Tags. + * @tparma A type of a task + * @param get extracts tags from a task + * @param validF defines whether a set of tasks are allowed to execute concurrently based on their merged tags*/ + def tagged[A](get: A => TagMap, validF: TagMap => Boolean): ConcurrentRestrictions[A] = + new ConcurrentRestrictions[A] + { + type G = TagMap + def empty = Map.empty + def add(g: TagMap, a: A) = merge(g, a, get)(_ + _) + def remove(g: TagMap, a: A) = merge(g, a, get)(_ - _) + def valid(g: TagMap) = validF(g) + } + + private[this] def merge[A](m: TagMap, a: A, get: A => TagMap)(f: (Int,Int) => Int): TagMap = + { + val base = merge(m, get(a))(f) + val un = if(base.isEmpty) update(base, Untagged, 1)(f) else base + update(un, All, 1)(f) + } + + private[this] def update[A,B](m: Map[A,B], a: A, b: B)(f: (B,B) => B): Map[A,B] = + { + val newb = + (m get a) match { + case Some(bv) => f(bv,b) + case None => b + } + m.updated(a,newb) + } + private[this] def merge[A,B](m: Map[A,B], n: Map[A,B])(f: (B,B) => B): Map[A,B] = + (m /: n) { case (acc, (a,b)) => update(acc, a, b)(f) } + + /** Constructs a CompletionService suitable for backing task execution based on the provided restrictions on concurrent task execution. + * @return a pair, with _1 being the CompletionService and _2 a function to shutdown the service. + * @tparam A the task type + * @tparam G describes a set of tasks + * @tparam R the type of data that will be computed by the CompletionService. */ + def completionService[A,R](tags: ConcurrentRestrictions[A], warn: String => Unit): (CompletionService[A,R], () => Unit) = + { + val pool = Executors.newCachedThreadPool() + (completionService[A,R](pool, tags, warn), () => pool.shutdownNow() ) + } + + /** Constructs a CompletionService suitable for backing task execution based on the provided restrictions on concurrent task execution + * and using the provided Executor to manage execution on threads. */ + def completionService[A,R](backing: Executor, tags: ConcurrentRestrictions[A], warn: String => Unit): CompletionService[A,R] = + { + /** Represents submitted work for a task.*/ + final class Enqueue(val node: A, val work: () => R) + + new CompletionService[A,R] + { + /** Backing service used to manage execution on threads once all constraints are satisfied. */ + private[this] val jservice = new ExecutorCompletionService[R](backing) + /** The description of the currently running tasks, used by `tags` to manage restrictions.*/ + private[this] var tagState = tags.empty + /** The number of running tasks. */ + private[this] var running = 0 + /** Tasks that cannot be run yet because they cannot execute concurrently with the currently running tasks.*/ + private[this] val pending = new LinkedList[Enqueue] + + def submit(node: A, work: () => R): Unit = synchronized + { + val newState = tags.add(tagState, node) + // if the new task is allowed to run concurrently with the currently running tasks, + // submit it to be run by the backing j.u.c.CompletionService + if(tags valid newState) + { + tagState = newState + submitValid( node, work ) + } + else + { + if(running == 0) errorAddingToIdle() + pending.add( new Enqueue(node, work) ) + } + } + private[this] def submitValid(node: A, work: () => R) = + { + running += 1 + val wrappedWork = () => try work() finally cleanup(node) + CompletionService.submit(wrappedWork, jservice) + } + private[this] def cleanup(node: A): Unit = synchronized + { + running -= 1 + tagState = tags.remove(tagState, node) + if(!tags.valid(tagState)) warn("Invalid restriction: removing a completed node from a valid system must result in a valid system.") + submitValid(new LinkedList) + } + private[this] def errorAddingToIdle() = warn("Invalid restriction: adding a node to an idle system must be allowed.") + + /** Submits pending tasks that are now allowed to executed. */ + @tailrec private[this] def submitValid(tried: Queue[Enqueue]): Unit = + if(pending.isEmpty) + { + if(!tried.isEmpty) + { + if(running == 0) errorAddingToIdle() + pending.addAll(tried) + } + } + else + { + val next = pending.remove() + val newState = tags.add(tagState, next.node) + if(tags.valid(newState)) + { + tagState = newState + submitValid(next.node, next.work) + } + else + tried.add(next) + submitValid(tried) + } + + def take(): R = jservice.take().get() + } + } +} \ No newline at end of file diff --git a/tasks/standard/Action.scala b/tasks/standard/Action.scala index f0dd91253..c79186838 100644 --- a/tasks/standard/Action.scala +++ b/tasks/standard/Action.scala @@ -5,6 +5,7 @@ package sbt import Types._ import Task._ + import ConcurrentRestrictions.{Tag, TagMap, tagsKey} // Action, Task, and Info are intentionally invariant in their type parameter. // Various natural transformations used, such as PMap, require invariant type constructors for correctness @@ -34,6 +35,10 @@ final case class Task[T](info: Info[T], work: Action[T]) { override def toString = info.name getOrElse ("Task(" + info + ")") override def hashCode = info.hashCode + + def tag(tags: Tag*): Task[T] = tagw(tags.map(t => (t, 1)) : _*) + def tagw(tags: (Tag, Int)*): Task[T] = copy(info = info.set(tagsKey, info.get(tagsKey).getOrElse(Map.empty) ++ tags )) + def tags: TagMap = info get tagsKey getOrElse Map.empty } /** Used to provide information about a task, such as the name, description, and tags for controlling concurrent execution. * @param attributes Arbitrary user-defined key/value pairs describing this task @@ -46,6 +51,7 @@ final case class Info[T](attributes: AttributeMap = AttributeMap.empty, post: T def setName(n: String) = set(Name, n) def setDescription(d: String) = set(Description, d) def set[T](key: AttributeKey[T], value: T) = copy(attributes = this.attributes.put(key, value)) + def get[T](key: AttributeKey[T]) = attributes.get(key) def postTransform[A](f: (T, AttributeMap) => AttributeMap) = copy(post = (t: T) => f(t, post(t)) ) override def toString = if(attributes.isEmpty) "_" else attributes.toString diff --git a/tasks/standard/src/test/scala/TaskSerial.scala b/tasks/standard/src/test/scala/TaskSerial.scala index db1564e54..668c26241 100644 --- a/tasks/standard/src/test/scala/TaskSerial.scala +++ b/tasks/standard/src/test/scala/TaskSerial.scala @@ -4,17 +4,21 @@ package std import Types._ import TaskExtra._ import TaskTest.tryRun + import TaskGen.{MaxWorkers,MaxWorkersGen} import org.scalacheck._ import Prop.forAll import Transform.taskToNode + import ConcurrentRestrictions.{All, completionService, limitTotal, tagged => tagged0, TagMap, unrestricted} + + import java.util.concurrent.{CountDownLatch, TimeUnit} object TaskSerial extends Properties("task serial") { val checkCycles = true - val maxWorkers = 100 + val Timeout = 100 // in milliseconds - def eval[T](t: Task[T]): T = tryRun(t, checkCycles, maxWorkers) + def eval[T](t: Task[T]): T = tryRun(t, checkCycles, limitTotal(MaxWorkers)) property("Evaluates basic") = forAll { (i: Int) => checkResult( eval( task(i) ), i ) @@ -24,20 +28,50 @@ object TaskSerial extends Properties("task serial") checkResult( eval( () => i ), i ) } - + // verifies that all tasks get scheduled simultaneously (1-3) or do not (4) + property("Allows arbitrary task limit") = forAll(MaxWorkersGen) { (sze: Int) => + val size = math.max(1, sze) + val halfSize = size / 2 + 1 + val all = + checkArbitrary(size, tagged(_ => true), true ) && + checkArbitrary(size, unrestricted[Task[_]], true ) && + checkArbitrary(size, limitTotal[Task[_]](size), true ) && + checkArbitrary(size, limitTotal[Task[_]](halfSize), size <= halfSize ) + all :| ("Size: " + size) :| ("Half size: " + halfSize) + } + + def checkArbitrary(size: Int, restrictions: ConcurrentRestrictions[Task[_]], shouldSucceed: Boolean) = + { + val latch = task { new CountDownLatch(size) } + def mktask = latch map { l => + l.countDown() + l.await(Timeout, TimeUnit.MILLISECONDS) + } + val tasks = (0 until size).map(_ => mktask).toList.join.map { results => + val success = results.forall(idFun[Boolean]) + assert( success == shouldSucceed, if(shouldSucceed) unschedulableMsg else scheduledMsg) + } + checkResult( evalRestricted( tasks )( restrictions ), () ) + } + def unschedulableMsg = "Some tasks were unschedulable: verify this is an actual failure by extending the timeout to several seconds." + def scheduledMsg = "All tasks were unexpectedly scheduled." + + def tagged(f: TagMap => Boolean) = tagged0[Task[_]](_.tags, f) + def evalRestricted[T](t: Task[T])(restrictions: ConcurrentRestrictions[Task[_]]): T = + tryRun[T](t, checkCycles, restrictions) } object TaskTest { - def run[T](root: Task[T], checkCycles: Boolean, maxWorkers: Int): Result[T] = + def run[T](root: Task[T], checkCycles: Boolean, restrictions: ConcurrentRestrictions[Task[_]]): Result[T] = { - val (service, shutdown) = CompletionService[Task[_], Completed](maxWorkers) + val (service, shutdown) = completionService[Task[_],Completed](restrictions, (x: String) => System.err.println(x)) val x = new Execute[Task](checkCycles, Execute.noTriggers)(taskToNode) try { x.run(root)(service) } finally { shutdown() } } - def tryRun[T](root: Task[T], checkCycles: Boolean, maxWorkers: Int): T = - run(root, checkCycles, maxWorkers) match { + def tryRun[T](root: Task[T], checkCycles: Boolean, restrictions: ConcurrentRestrictions[Task[_]]): T = + run(root, checkCycles, restrictions) match { case Value(v) => v case Inc(i) => throw i }