Merge pull request #5552 from eed3si9n/wip/promise

implement Def.promise
This commit is contained in:
eugene yokota 2020-06-10 17:38:23 -04:00 committed by GitHub
commit a83be809ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 156 additions and 13 deletions

View File

@ -235,6 +235,12 @@ object Def extends Init[Scope] with TaskMacroExtra with InitializeImplicits {
def inputTaskDyn[T](t: Def.Initialize[Task[T]]): Def.Initialize[InputTask[T]] =
macro inputTaskDynMacroImpl[T]
/** Returns `PromiseWrap[A]`, which is a wrapper around `scala.concurrent.Promise`.
* When a task is typed promise (e.g. `Def.Initialize[Task[PromiseWrap[A]]]`),an implicit
* method called `await` is injected which will run in a thread outside of concurrent restriction budget.
*/
def promise[A]: PromiseWrap[A] = new PromiseWrap[A]()
// The following conversions enable the types Initialize[T], Initialize[Task[T]], and Task[T] to
// be used in task and setting macros as inputs with an ultimate result of type T

View File

@ -0,0 +1,21 @@
/*
* sbt
* Copyright 2011 - 2018, Lightbend, Inc.
* Copyright 2008 - 2010, Mark Harrah
* Licensed under Apache License 2.0 (see LICENSE)
*/
package sbt
import scala.concurrent.{ Promise => XPromise }
final class PromiseWrap[A] {
private[sbt] val underlying: XPromise[A] = XPromise()
def complete(result: Result[A]): Unit =
result match {
case Inc(cause) => underlying.failure(cause)
case Value(value) => underlying.success(value)
}
def success(value: A): Unit = underlying.success(value)
def failure(cause: Throwable): Unit = underlying.failure(cause)
}

View File

@ -444,10 +444,16 @@ object EvaluateTask {
log.debug(
s"Running task... Cancel: ${config.cancelStrategy}, check cycles: ${config.checkCycles}, forcegc: ${config.forceGarbageCollection}"
)
def tagMap(t: Task[_]): Tags.TagMap =
t.info.get(tagsKey).getOrElse(Map.empty)
val tags =
tagged[Task[_]](_.info get tagsKey getOrElse Map.empty, Tags.predicate(config.restrictions))
tagged[Task[_]](tagMap, Tags.predicate(config.restrictions))
val (service, shutdownThreads) =
completionService[Task[_], Completed](tags, (s: String) => log.warn(s))
completionService[Task[_], Completed](
tags,
(s: String) => log.warn(s),
(t: Task[_]) => tagMap(t).contains(Tags.Sentinel)
)
def shutdown(): Unit = {
// First ensure that all threads are stopped for task execution.

View File

@ -865,6 +865,24 @@ object Project extends ProjectExtra {
}
}
/** implicitly injected to tasks that return PromiseWrap.
*/
final class RichTaskPromise[A](i: Def.Initialize[Task[PromiseWrap[A]]]) {
import scala.concurrent.Await
import scala.concurrent.duration._
def await: Def.Initialize[Task[A]] = await(Duration.Inf)
def await(atMost: Duration): Def.Initialize[Task[A]] =
(Def
.task {
val p = i.value
val result: A = Await.result(p.underlying.future, atMost)
result
})
.tag(Tags.Sentinel)
}
import scala.reflect.macros._
def projectMacroImpl(c: blackbox.Context): c.Expr[Project] = {
@ -907,6 +925,11 @@ trait ProjectExtra {
implicit def richTaskSessionVar[T](init: Initialize[Task[T]]): Project.RichTaskSessionVar[T] =
new Project.RichTaskSessionVar(init)
implicit def sbtRichTaskPromise[A](
i: Initialize[Task[PromiseWrap[A]]]
): Project.RichTaskPromise[A] =
new Project.RichTaskPromise(i)
def inThisBuild(ss: Seq[Setting[_]]): Seq[Setting[_]] =
inScope(ThisScope.copy(project = Select(ThisBuild)))(ss)

View File

@ -21,6 +21,8 @@ object Tags {
val Update = Tag("update")
val Publish = Tag("publish")
val Clean = Tag("clean")
// special tag for waiting on a promise
val Sentinel = Tag("sentinel")
val CPU = Tag("cpu")
val Network = Tag("network")

View File

@ -0,0 +1,40 @@
val midpoint = taskKey[PromiseWrap[Int]]("")
val longRunning = taskKey[Unit]("")
val midTask = taskKey[Unit]("")
val joinTwo = taskKey[Unit]("")
val output = settingKey[File]("")
lazy val root = (project in file("."))
.settings(
name := "promise",
output := baseDirectory.value / "output.txt",
midpoint := Def.promise[Int],
longRunning := {
val p = midpoint.value
val st = streams.value
IO.write(output.value, "start\n", append = true)
Thread.sleep(100)
p.success(5)
Thread.sleep(100)
IO.write(output.value, "end\n", append = true)
},
midTask := {
val st = streams.value
val x = midpoint.await.value
IO.write(output.value, s"$x in the middle\n", append = true)
},
joinTwo := {
val x = longRunning.value
val y = midTask.value
},
TaskKey[Unit]("check") := {
val lines = IO.read(output.value).linesIterator.toList
assert(lines == List("start", "5 in the middle", "end"))
()
},
TaskKey[Unit]("check2") := {
val lines = IO.read(output.value).linesIterator.toList
assert(lines == List("start", "end", "5 in the middle"))
()
},
)

View File

@ -0,0 +1,8 @@
> joinTwo
> check
$ delete output.txt
# check that we won't thread starve
> set Global / concurrentRestrictions := Seq(Tags.limitAll(1))
> joinTwo
> check2

View File

@ -8,7 +8,18 @@
package sbt
trait CompletionService[A, R] {
/**
* Submits a work node A with work that returns R.
* In Execute this is used for tasks returning sbt.Completed.
*/
def submit(node: A, work: () => R): Unit
/**
* Retrieves and removes the result from the next completed task,
* waiting if none are yet present.
* In Execute this is used for tasks returning sbt.Completed.
*/
def take(): R
}

View File

@ -136,6 +136,26 @@ object ConcurrentRestrictions {
})
}
def completionService[A, R](
tags: ConcurrentRestrictions[A],
warn: String => Unit,
isSentinel: A => Boolean
): (CompletionService[A, R], () => Unit) = {
val pool = Executors.newCachedThreadPool()
(completionService[A, R](pool, tags, warn, isSentinel), () => {
pool.shutdownNow()
()
})
}
def completionService[A, R](
backing: Executor,
tags: ConcurrentRestrictions[A],
warn: String => Unit
): CompletionService[A, R] = {
completionService[A, R](backing, tags, warn, (_: A) => false)
}
/**
* Constructs a CompletionService suitable for backing task execution based on the provided restrictions on concurrent task execution
* and using the provided Executor to manage execution on threads.
@ -143,7 +163,8 @@ object ConcurrentRestrictions {
def completionService[A, R](
backing: Executor,
tags: ConcurrentRestrictions[A],
warn: String => Unit
warn: String => Unit,
isSentinel: A => Boolean,
): CompletionService[A, R] = {
// Represents submitted work for a task.
@ -164,17 +185,22 @@ object ConcurrentRestrictions {
private[this] val pending = new LinkedList[Enqueue]
def submit(node: A, work: () => R): Unit = synchronized {
val newState = tags.add(tagState, node)
// if the new task is allowed to run concurrently with the currently running tasks,
// submit it to be run by the backing j.u.c.CompletionService
if (tags valid newState) {
tagState = newState
submitValid(node, work)
()
if (isSentinel(node)) {
// skip all checks for sentinels
CompletionService.submit(work, jservice)
} else {
if (running == 0) errorAddingToIdle()
pending.add(new Enqueue(node, work))
()
val newState = tags.add(tagState, node)
// if the new task is allowed to run concurrently with the currently running tasks,
// submit it to be run by the backing j.u.c.CompletionService
if (tags valid newState) {
tagState = newState
submitValid(node, work)
()
} else {
if (running == 0) errorAddingToIdle()
pending.add(new Enqueue(node, work))
()
}
}
()
}