diff --git a/main-settings/src/main/scala/sbt/Def.scala b/main-settings/src/main/scala/sbt/Def.scala index b63e58d3b..d67d8dc5b 100644 --- a/main-settings/src/main/scala/sbt/Def.scala +++ b/main-settings/src/main/scala/sbt/Def.scala @@ -235,6 +235,12 @@ object Def extends Init[Scope] with TaskMacroExtra with InitializeImplicits { def inputTaskDyn[T](t: Def.Initialize[Task[T]]): Def.Initialize[InputTask[T]] = macro inputTaskDynMacroImpl[T] + /** Returns `PromiseWrap[A]`, which is a wrapper around `scala.concurrent.Promise`. + * When a task is typed promise (e.g. `Def.Initialize[Task[PromiseWrap[A]]]`),an implicit + * method called `await` is injected which will run in a thread outside of concurrent restriction budget. + */ + def promise[A]: PromiseWrap[A] = new PromiseWrap[A]() + // The following conversions enable the types Initialize[T], Initialize[Task[T]], and Task[T] to // be used in task and setting macros as inputs with an ultimate result of type T diff --git a/main-settings/src/main/scala/sbt/PromiseWrap.scala b/main-settings/src/main/scala/sbt/PromiseWrap.scala new file mode 100644 index 000000000..2706ac5bc --- /dev/null +++ b/main-settings/src/main/scala/sbt/PromiseWrap.scala @@ -0,0 +1,21 @@ +/* + * sbt + * Copyright 2011 - 2018, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt + +import scala.concurrent.{ Promise => XPromise } + +final class PromiseWrap[A] { + private[sbt] val underlying: XPromise[A] = XPromise() + def complete(result: Result[A]): Unit = + result match { + case Inc(cause) => underlying.failure(cause) + case Value(value) => underlying.success(value) + } + def success(value: A): Unit = underlying.success(value) + def failure(cause: Throwable): Unit = underlying.failure(cause) +} diff --git a/main/src/main/scala/sbt/EvaluateTask.scala b/main/src/main/scala/sbt/EvaluateTask.scala index 2816acfe9..c16ef6b1c 100644 --- a/main/src/main/scala/sbt/EvaluateTask.scala +++ b/main/src/main/scala/sbt/EvaluateTask.scala @@ -444,10 +444,16 @@ object EvaluateTask { log.debug( s"Running task... Cancel: ${config.cancelStrategy}, check cycles: ${config.checkCycles}, forcegc: ${config.forceGarbageCollection}" ) + def tagMap(t: Task[_]): Tags.TagMap = + t.info.get(tagsKey).getOrElse(Map.empty) val tags = - tagged[Task[_]](_.info get tagsKey getOrElse Map.empty, Tags.predicate(config.restrictions)) + tagged[Task[_]](tagMap, Tags.predicate(config.restrictions)) val (service, shutdownThreads) = - completionService[Task[_], Completed](tags, (s: String) => log.warn(s)) + completionService[Task[_], Completed]( + tags, + (s: String) => log.warn(s), + (t: Task[_]) => tagMap(t).contains(Tags.Sentinel) + ) def shutdown(): Unit = { // First ensure that all threads are stopped for task execution. diff --git a/main/src/main/scala/sbt/Project.scala b/main/src/main/scala/sbt/Project.scala index 7c4c1ce85..555737dd5 100755 --- a/main/src/main/scala/sbt/Project.scala +++ b/main/src/main/scala/sbt/Project.scala @@ -865,6 +865,24 @@ object Project extends ProjectExtra { } } + /** implicitly injected to tasks that return PromiseWrap. + */ + final class RichTaskPromise[A](i: Def.Initialize[Task[PromiseWrap[A]]]) { + import scala.concurrent.Await + import scala.concurrent.duration._ + + def await: Def.Initialize[Task[A]] = await(Duration.Inf) + + def await(atMost: Duration): Def.Initialize[Task[A]] = + (Def + .task { + val p = i.value + val result: A = Await.result(p.underlying.future, atMost) + result + }) + .tag(Tags.Sentinel) + } + import scala.reflect.macros._ def projectMacroImpl(c: blackbox.Context): c.Expr[Project] = { @@ -907,6 +925,11 @@ trait ProjectExtra { implicit def richTaskSessionVar[T](init: Initialize[Task[T]]): Project.RichTaskSessionVar[T] = new Project.RichTaskSessionVar(init) + implicit def sbtRichTaskPromise[A]( + i: Initialize[Task[PromiseWrap[A]]] + ): Project.RichTaskPromise[A] = + new Project.RichTaskPromise(i) + def inThisBuild(ss: Seq[Setting[_]]): Seq[Setting[_]] = inScope(ThisScope.copy(project = Select(ThisBuild)))(ss) diff --git a/main/src/main/scala/sbt/Tags.scala b/main/src/main/scala/sbt/Tags.scala index 7b30bbdd2..ec37b8807 100644 --- a/main/src/main/scala/sbt/Tags.scala +++ b/main/src/main/scala/sbt/Tags.scala @@ -21,6 +21,8 @@ object Tags { val Update = Tag("update") val Publish = Tag("publish") val Clean = Tag("clean") + // special tag for waiting on a promise + val Sentinel = Tag("sentinel") val CPU = Tag("cpu") val Network = Tag("network") diff --git a/sbt/src/sbt-test/actions/promise/build.sbt b/sbt/src/sbt-test/actions/promise/build.sbt new file mode 100644 index 000000000..977218fb9 --- /dev/null +++ b/sbt/src/sbt-test/actions/promise/build.sbt @@ -0,0 +1,40 @@ +val midpoint = taskKey[PromiseWrap[Int]]("") +val longRunning = taskKey[Unit]("") +val midTask = taskKey[Unit]("") +val joinTwo = taskKey[Unit]("") +val output = settingKey[File]("") + +lazy val root = (project in file(".")) + .settings( + name := "promise", + output := baseDirectory.value / "output.txt", + midpoint := Def.promise[Int], + longRunning := { + val p = midpoint.value + val st = streams.value + IO.write(output.value, "start\n", append = true) + Thread.sleep(100) + p.success(5) + Thread.sleep(100) + IO.write(output.value, "end\n", append = true) + }, + midTask := { + val st = streams.value + val x = midpoint.await.value + IO.write(output.value, s"$x in the middle\n", append = true) + }, + joinTwo := { + val x = longRunning.value + val y = midTask.value + }, + TaskKey[Unit]("check") := { + val lines = IO.read(output.value).linesIterator.toList + assert(lines == List("start", "5 in the middle", "end")) + () + }, + TaskKey[Unit]("check2") := { + val lines = IO.read(output.value).linesIterator.toList + assert(lines == List("start", "end", "5 in the middle")) + () + }, + ) diff --git a/sbt/src/sbt-test/actions/promise/test b/sbt/src/sbt-test/actions/promise/test new file mode 100644 index 000000000..24416f93e --- /dev/null +++ b/sbt/src/sbt-test/actions/promise/test @@ -0,0 +1,8 @@ +> joinTwo +> check + +$ delete output.txt +# check that we won't thread starve +> set Global / concurrentRestrictions := Seq(Tags.limitAll(1)) +> joinTwo +> check2 diff --git a/tasks/src/main/scala/sbt/CompletionService.scala b/tasks/src/main/scala/sbt/CompletionService.scala index 2338ca885..54d619d09 100644 --- a/tasks/src/main/scala/sbt/CompletionService.scala +++ b/tasks/src/main/scala/sbt/CompletionService.scala @@ -8,7 +8,18 @@ package sbt trait CompletionService[A, R] { + + /** + * Submits a work node A with work that returns R. + * In Execute this is used for tasks returning sbt.Completed. + */ def submit(node: A, work: () => R): Unit + + /** + * Retrieves and removes the result from the next completed task, + * waiting if none are yet present. + * In Execute this is used for tasks returning sbt.Completed. + */ def take(): R } diff --git a/tasks/src/main/scala/sbt/ConcurrentRestrictions.scala b/tasks/src/main/scala/sbt/ConcurrentRestrictions.scala index e4db63e2b..50fba84dd 100644 --- a/tasks/src/main/scala/sbt/ConcurrentRestrictions.scala +++ b/tasks/src/main/scala/sbt/ConcurrentRestrictions.scala @@ -136,6 +136,26 @@ object ConcurrentRestrictions { }) } + def completionService[A, R]( + tags: ConcurrentRestrictions[A], + warn: String => Unit, + isSentinel: A => Boolean + ): (CompletionService[A, R], () => Unit) = { + val pool = Executors.newCachedThreadPool() + (completionService[A, R](pool, tags, warn, isSentinel), () => { + pool.shutdownNow() + () + }) + } + + def completionService[A, R]( + backing: Executor, + tags: ConcurrentRestrictions[A], + warn: String => Unit + ): CompletionService[A, R] = { + completionService[A, R](backing, tags, warn, (_: A) => false) + } + /** * Constructs a CompletionService suitable for backing task execution based on the provided restrictions on concurrent task execution * and using the provided Executor to manage execution on threads. @@ -143,7 +163,8 @@ object ConcurrentRestrictions { def completionService[A, R]( backing: Executor, tags: ConcurrentRestrictions[A], - warn: String => Unit + warn: String => Unit, + isSentinel: A => Boolean, ): CompletionService[A, R] = { // Represents submitted work for a task. @@ -164,17 +185,22 @@ object ConcurrentRestrictions { private[this] val pending = new LinkedList[Enqueue] def submit(node: A, work: () => R): Unit = synchronized { - val newState = tags.add(tagState, node) - // if the new task is allowed to run concurrently with the currently running tasks, - // submit it to be run by the backing j.u.c.CompletionService - if (tags valid newState) { - tagState = newState - submitValid(node, work) - () + if (isSentinel(node)) { + // skip all checks for sentinels + CompletionService.submit(work, jservice) } else { - if (running == 0) errorAddingToIdle() - pending.add(new Enqueue(node, work)) - () + val newState = tags.add(tagState, node) + // if the new task is allowed to run concurrently with the currently running tasks, + // submit it to be run by the backing j.u.c.CompletionService + if (tags valid newState) { + tagState = newState + submitValid(node, work) + () + } else { + if (running == 0) errorAddingToIdle() + pending.add(new Enqueue(node, work)) + () + } } () }