From 9439a737b8594b0479bf5135bc4e48ec76c4c168 Mon Sep 17 00:00:00 2001 From: Mark Harrah Date: Sat, 21 May 2011 13:51:13 -0400 Subject: [PATCH] make parallel execution configurable, fixes #22 --- main/Aggregation.scala | 12 +++++++++--- main/Defaults.scala | 11 ++++++----- main/Keys.scala | 1 + main/actions/Tests.scala | 16 +++++++++++----- tasks/CompletionService.scala | 2 +- 5 files changed, 28 insertions(+), 14 deletions(-) diff --git a/main/Aggregation.scala b/main/Aggregation.scala index 06c7b4b8a..ccba7ad3d 100644 --- a/main/Aggregation.scala +++ b/main/Aggregation.scala @@ -78,17 +78,23 @@ final object Aggregation { import EvaluateTask._ import std.TaskExtra._ + val extracted = Project extract s val toRun = ts map { case KeyValue(k,t) => t.map(v => KeyValue(k,v)) } join; + val workers = maxWorkers(extracted, structure) val start = System.currentTimeMillis - val result = withStreams(structure){ str => runTask(toRun, str, structure.index.triggers)(nodeView(s, str, extra.tasks, extra.values)) } + val result = withStreams(structure){ str => runTask(toRun, str, structure.index.triggers, maxWorkers = workers)(nodeView(s, str, extra.tasks, extra.values)) } val stop = System.currentTimeMillis val log = logger(s) - lazy val extracted = Project.extract(s) val success = result match { case Value(_) => true; case Inc(_) => false } try { onResult(result, log) { results => if(show) printSettings(results, log) } } - finally { printSuccess(start, stop, Project.extract(s), success, log) } + finally { printSuccess(start, stop, extracted, success, log) } } + def maxWorkers(extracted: Extracted, structure: Load.BuildStructure): Int = + (Keys.parallelExecution in extracted.currentRef get structure.data) match { + case Some(true) | None => EvaluateTask.SystemProcessors + case Some(false) => 1 + } def printSuccess(start: Long, stop: Long, extracted: Extracted, success: Boolean, log: Logger) { import extracted._ diff --git a/main/Defaults.scala b/main/Defaults.scala index 2083e0652..29147b9a3 100644 --- a/main/Defaults.scala +++ b/main/Defaults.scala @@ -43,6 +43,7 @@ object Defaults extends BuildCommon managedDirectory <<= baseDirectory(_ / "lib_managed") )) def globalCore: Seq[Setting[_]] = inScope(GlobalScope)(Seq( + parallelExecution :== true, pollInterval :== 500, logBuffered :== false, trapExit :== false, @@ -227,8 +228,8 @@ object Defaults extends BuildCommon definedTests <<= TaskData.writeRelated(detectTests)(_.map(_.name).distinct) triggeredBy compile, testListeners :== Nil, testOptions :== Nil, - executeTests <<= (streams in test, loadedTestFrameworks, testOptions in test, testLoader, definedTests) flatMap { - (s, frameworkMap, options, loader, discovered) => Tests(frameworkMap, loader, discovered, options, s.log) + executeTests <<= (streams in test, loadedTestFrameworks, parallelExecution in test, testOptions in test, testLoader, definedTests) flatMap { + (s, frameworkMap, par, options, loader, discovered) => Tests(frameworkMap, loader, discovered, options, par, s.log) }, test <<= (executeTests, streams) map { (results, s) => Tests.showResults(s.log, results) }, testOnly <<= testOnlyTask @@ -256,10 +257,10 @@ object Defaults extends BuildCommon def testOnlyTask = InputTask( TaskData(definedTests)(testOnlyParser)(Nil) ) { result => - (streams, loadedTestFrameworks, testOptions in testOnly, testLoader, definedTests, result) flatMap { - case (s, frameworks, opts, loader, discovered, (tests, frameworkOptions)) => + (streams, loadedTestFrameworks, parallelExecution in testOnly, testOptions in testOnly, testLoader, definedTests, result) flatMap { + case (s, frameworks, par, opts, loader, discovered, (tests, frameworkOptions)) => val modifiedOpts = Tests.Filter(if(tests.isEmpty) _ => true else tests.toSet ) +: Tests.Argument(frameworkOptions : _*) +: opts - Tests(frameworks, loader, discovered, modifiedOpts, s.log) map { results => + Tests(frameworks, loader, discovered, modifiedOpts, par, s.log) map { results => Tests.showResults(s.log, results) } } diff --git a/main/Keys.scala b/main/Keys.scala index 997027707..3f9b77206 100644 --- a/main/Keys.scala +++ b/main/Keys.scala @@ -249,6 +249,7 @@ object Keys val sbtDependency = SettingKey[ModuleID]("sbt-dependency", "Provides a definition for declaring the current version of sbt.") // special + val parallelExecution = SettingKey[Boolean]("parallel-execution", "Enables (true) or disables (false) parallel execution of tasks.") val settings = TaskKey[Settings[Scope]]("settings", "Provides access to the project data for the build.") val streams = TaskKey[TaskStreams]("streams", "Provides streams for logging and persisting data.") val isDummyTask = AttributeKey[Boolean]("is-dummy-task", "Internal: used to identify dummy tasks. sbt injects values for these tasks at the start of task execution.") diff --git a/main/actions/Tests.scala b/main/actions/Tests.scala index 53a5739df..74245326a 100644 --- a/main/actions/Tests.scala +++ b/main/actions/Tests.scala @@ -40,7 +40,7 @@ object Tests final case class Argument(framework: Option[TestFramework], args: List[String]) extends TestOption - def apply(frameworks: Map[TestFramework, Framework], testLoader: ClassLoader, discovered: Seq[TestDefinition], options: Seq[TestOption], log: Logger): Task[Output] = + def apply(frameworks: Map[TestFramework, Framework], testLoader: ClassLoader, discovered: Seq[TestDefinition], options: Seq[TestOption], parallel: Boolean, log: Logger): Task[Output] = { import mutable.{HashSet, ListBuffer, Map, Set} val testFilters = new ListBuffer[String => Boolean] @@ -80,12 +80,12 @@ object Tests def includeTest(test: TestDefinition) = !excludeTestsSet.contains(test.name) && testFilters.forall(filter => filter(test.name)) val tests = discovered.filter(includeTest).toSet.toSeq val arguments = testArgsByFramework.map { case (k,v) => (k, v.toList) } toMap; - testTask(frameworks.values.toSeq, testLoader, tests, setup.readOnly, cleanup.readOnly, log, testListeners.readOnly, arguments) + testTask(frameworks.values.toSeq, testLoader, tests, setup.readOnly, cleanup.readOnly, log, testListeners.readOnly, arguments, parallel) } def testTask(frameworks: Seq[Framework], loader: ClassLoader, tests: Seq[TestDefinition], userSetup: Iterable[ClassLoader => Unit], userCleanup: Iterable[ClassLoader => Unit], - log: Logger, testListeners: Seq[TestReportListener], arguments: Map[Framework, Seq[String]]): Task[Output] = + log: Logger, testListeners: Seq[TestReportListener], arguments: Map[Framework, Seq[String]], parallel: Boolean): Task[Output] = { def fj(actions: Iterable[() => Unit]): Task[Unit] = nop.dependsOn( actions.toSeq.fork( _() ) : _*) def partApp(actions: Iterable[ClassLoader => Unit]) = actions.toSeq map {a => () => a(loader) } @@ -94,12 +94,18 @@ object Tests TestFramework.testTasks(frameworks, loader, tests, log, testListeners, arguments) val setupTasks = fj(partApp(userSetup) :+ frameworkSetup) - val mainTasks = runnables map { case (name, test) => task { (name, test()) } dependsOn setupTasks named name } - mainTasks.toSeq.join map processResults flatMap { results => + val mainTasks = if(parallel) makeParallel(runnables, setupTasks).toSeq.join else makeSerial(runnables, setupTasks) + mainTasks map processResults flatMap { results => val cleanupTasks = fj(partApp(userCleanup) :+ frameworkCleanup(results._1)) cleanupTasks map { _ => results } } } + type TestRunnable = (String, () => TestResult.Value) + def makeParallel(runnables: Iterable[TestRunnable], setupTasks: Task[Unit]) = + runnables map { case (name, test) => task { (name, test()) } dependsOn setupTasks named name } + def makeSerial(runnables: Iterable[TestRunnable], setupTasks: Task[Unit]) = + task { runnables map { case (name, test) => (name, test()) } } dependsOn(setupTasks) + def processResults(results: Iterable[(String, TestResult.Value)]): (TestResult.Value, Map[String, TestResult.Value]) = (overall(results.map(_._2)), results.toMap) def overall(results: Iterable[TestResult.Value]): TestResult.Value = diff --git a/tasks/CompletionService.scala b/tasks/CompletionService.scala index 9dc924df4..799245ef1 100644 --- a/tasks/CompletionService.scala +++ b/tasks/CompletionService.scala @@ -15,7 +15,7 @@ object CompletionService { def apply[A, T](poolSize: Int): (CompletionService[A,T], () => Unit) = { - val pool = Executors.newFixedThreadPool(2) + val pool = Executors.newFixedThreadPool(poolSize) (apply[A,T]( pool ), () => pool.shutdownNow() ) } def apply[A, T](x: Executor): CompletionService[A,T] =