diff --git a/build.sbt b/build.sbt index 80fad059e..c1c894578 100644 --- a/build.sbt +++ b/build.sbt @@ -11,7 +11,7 @@ import com.eed3si9n.jarjarabrams.ModuleCoordinate // ThisBuild settings take lower precedence, // but can be shared across the multi projects. ThisBuild / version := { - val v = "2.0.0-RC13-bin-SNAPSHOT" + val v = "2.0.0-RC15-bin-SNAPSHOT" nightlyVersion.getOrElse(v) } // update sbt.sh at root @@ -558,7 +558,7 @@ lazy val remoteCacheProj = (project in file("sbt-remote-cache")) pluginCrossBuild / sbtVersion := version.value, publishMavenStyle := true, mimaSettings, - libraryDependencies += remoteapis, + libraryDependencies ++= Seq(remoteapis, scalaVerify % Test), ) // Implementation and support code for defining actions. diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 6a7450b12..c89309736 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -96,7 +96,7 @@ object Dependencies { val scalaVerify = "com.eed3si9n.verify" %% "verify" % "1.0.0" val templateResolverApi = "org.scala-sbt" % "template-resolver" % "0.1" val remoteapis = - "com.eed3si9n.remoteapis.shaded" % "shaded-remoteapis-java" % "2.3.0-M1-52317e00d8d4c37fa778c628485d220fb68a8d08" + "com.eed3si9n.remoteapis.shaded" % "shaded-remoteapis-java" % "2.3.0-M1-4ee33449fff18243c019f799b636cdb8c8a18f6c" val gson = "org.scala-sbt.gson" % "shaded-gson" % "2.13.1" val scalaCompiler = "org.scala-lang" %% "scala3-compiler" % scala3 diff --git a/sbt-remote-cache/src/main/scala/sbt/internal/GrpcActionCacheStore.scala b/sbt-remote-cache/src/main/scala/sbt/internal/GrpcActionCacheStore.scala index 53a47412b..e5338abe7 100644 --- a/sbt-remote-cache/src/main/scala/sbt/internal/GrpcActionCacheStore.scala +++ b/sbt-remote-cache/src/main/scala/sbt/internal/GrpcActionCacheStore.scala @@ -27,8 +27,16 @@ import com.eed3si9n.remoteapis.shaded.io.grpc.{ Status, TlsChannelCredentials, } +import com.eed3si9n.remoteapis.shaded.io.grpc.stub.StreamObserver +import com.eed3si9n.remoteapis.shaded.com.google.bytestream.ByteStreamGrpc +import com.eed3si9n.remoteapis.shaded.com.google.bytestream.ByteStreamProto +import ByteStreamProto.{ ReadRequest, WriteRequest } +import java.io.InputStream import java.net.URI import java.nio.file.Path +import java.util.UUID +import java.util.concurrent.{ Executors, TimeUnit } +import sbt.io.syntax.* import sbt.util.{ AbstractActionCacheStore, ActionResult, @@ -37,11 +45,19 @@ import sbt.util.{ GetActionResultRequest, UpdateActionResultRequest, } +import scala.concurrent.{ Await, ExecutionContext, Future, Promise, TimeoutException } +import scala.concurrent.duration.* +import scala.util.Using import scala.util.control.NonFatal import scala.jdk.CollectionConverters.* import xsbti.{ HashedVirtualFileRef, VirtualFile } object GrpcActionCacheStore: + // chunk uploads to 1MB + val chunkSizeBytes = 1024 * 1024 + val remoteTimeoutInSec = 60 + val remoteTimeout = (remoteTimeoutInSec + 2).second + def apply( uri: URI, rootCerts: Option[Path], @@ -102,10 +118,20 @@ object GrpcActionCacheStore: applier.apply(headers) catch case NonFatal(e) => applier.fail(Status.UNAUTHENTICATED.withCause(e)) end AuthCallCredentials + + def chunkBytes(sizeBytes: Long): List[Long] = + if sizeBytes <= 0 then Nil + else + val full = sizeBytes / chunkSizeBytes + val remainder = sizeBytes % chunkSizeBytes + val xs = List.fill(full.toInt)(chunkSizeBytes.toLong) + if remainder > 0 then xs ::: List(remainder) + else xs end GrpcActionCacheStore /* * https://github.com/bazelbuild/remote-apis/blob/main/build/bazel/remote/execution/v2/remote_execution.proto + * https://github.com/googleapis/googleapis/blob/ff15be54722218705740b9fc6223d264c4cdb6dd/google/bytestream/bytestream.proto */ class GrpcActionCacheStore( channel: ManagedChannel, @@ -113,6 +139,8 @@ class GrpcActionCacheStore( remoteHeaders: List[String], disk: DiskActionCacheStore, ) extends AbstractActionCacheStore: + import GrpcActionCacheStore.* + lazy val creds = GrpcActionCacheStore.AuthCallCredentials(remoteHeaders) lazy val acStub0 = ActionCacheGrpc.newBlockingStub(channel) lazy val acStub = remoteHeaders match @@ -122,9 +150,20 @@ class GrpcActionCacheStore( lazy val casStub = remoteHeaders match case x :: xs => casStub0.withCallCredentials(creds) case _ => casStub0 + lazy val byteStreamStub0 = ByteStreamGrpc.newStub(channel) + lazy val byteStreamStub = remoteHeaders match + case x :: xs => + byteStreamStub0 + .withCallCredentials(creds) + .withDeadlineAfter(remoteTimeoutInSec, TimeUnit.SECONDS) + case _ => + byteStreamStub0.withDeadlineAfter(remoteTimeoutInSec, TimeUnit.SECONDS) override def storeName: String = "remote" + val fixedThreadPool = Executors.newFixedThreadPool(100) + given ExecutionContext = ExecutionContext.fromExecutor(fixedThreadPool) + /** * https://github.com/bazelbuild/remote-apis/blob/9ff14cecffe5287ba337f857731ceadfc2d80de9/build/bazel/remote/execution/v2/remote_execution.proto#L170 */ @@ -157,10 +196,49 @@ class GrpcActionCacheStore( Right(toActionResult(result)) catch case NonFatal(e) => Left(e) + /** + * https://github.com/bazelbuild/remote-apis/blob/9ff14cecffe5287ba337f857731ceadfc2d80de9/build/bazel/remote/execution/v2/remote_execution.proto#L403 + */ + override def syncBlobs(refs: Seq[HashedVirtualFileRef], outputDirectory: Path): Seq[Path] = + val digests = refs.map(Digest(_)) + val totalBytes = digests.map(_.sizeBytes).sum + if refs.isEmpty then Nil + else if totalBytes <= chunkSizeBytes then + val result = batchReadBlobs(refs) + val blobs = result.getResponsesList().asScala.toList + val allOk = blobs.forall(_.getStatus().getCode() == 0) + if allOk then + // do not assume the responses to come in order + val lookupResponse: Map[Digest, BatchReadBlobsResponse.Response] = + Map(blobs.map(res => toDigest(res.getDigest) -> res)*) + refs.map: r => + val digest = Digest(r) + val blob = lookupResponse(digest) + val casFile = disk.putBlob(blob.getData().newInput(), digest) + disk.syncFile(r, casFile, outputDirectory) + else Nil + else + val paths = Await.result(downloadBlobs(digests, outputDirectory), remoteTimeout) + refs + .zip(digests) + .zip(paths) + .map { case ((r, digest), p) => + val casFile = disk.putBlobInternal(p, digest) + disk.syncFile(r, casFile, outputDirectory) + } + /** * https://github.com/bazelbuild/remote-apis/blob/9ff14cecffe5287ba337f857731ceadfc2d80de9/build/bazel/remote/execution/v2/remote_execution.proto#L379 */ override def putBlobs(blobs: Seq[VirtualFile]): Seq[HashedVirtualFileRef] = + val totalBytes = blobs.map(_.sizeBytes).sum + if blobs.isEmpty then Nil + else if totalBytes <= chunkSizeBytes then batchUpdateBlobs(blobs) + else + try Await.result(uploadBlobs(blobs).recover(_ => Nil), remoteTimeout) + catch case _: TimeoutException => Nil + + def batchUpdateBlobs(blobs: Seq[VirtualFile]): Seq[HashedVirtualFileRef] = val b = BatchUpdateBlobsRequest.newBuilder() b.setInstanceName(instanceName) b.setDigestFunction(DigestFunction.Value.SHA256) @@ -182,23 +260,89 @@ class GrpcActionCacheStore( Some(HashedVirtualFileRef.of(blob.id(), d.contentHashStr, d.sizeBytes)) else None + def uploadBlobs(blobs: Seq[VirtualFile]): Future[Seq[HashedVirtualFileRef]] = + Future.sequence(blobs.map(uploadBlob)) + + def uploadBlob(blob: VirtualFile): Future[HashedVirtualFileRef] = + val d = Digest(blob) + withSingleResponse[ByteStreamProto.WriteResponse, HashedVirtualFileRef]: (p, resObs) => + val reqObs = byteStreamStub.write(resObs) + val un = uploadName(d, UUID.randomUUID()) + var off: Long = 0L + try + Using.resource(blob.input()): input => + val chunks = chunkBytes(d.sizeBytes) + chunks.zipWithIndex.foreach: (chunk, idx) => + val b = WriteRequest.newBuilder() + if idx == 0 then b.setResourceName(un) + else () + b.setWriteOffset(off) + b.setData(toByteString(input, chunk)) + if idx == chunks.size - 1 then b.setFinishWrite(true) + else () + val req = b.build() + off = off + chunk + reqObs.onNext(req) + catch + case NonFatal(e) => + reqObs.onError(e) + p.failure(e) + reqObs.onCompleted() + p.future.map: _ => + HashedVirtualFileRef.of(blob.id(), d.contentHashStr, d.sizeBytes) + + private def downloadBlobs(digests: Seq[Digest], outputDirectory: Path): Future[Seq[Path]] = + Future.sequence(digests.map: x => + downloadBlob(x, outputDirectory)) + + private def downloadBlob(digest: Digest, outputDirectory: Path): Future[Path] = + val p = Promise[Path]() + val uuid = UUID.randomUUID() + val tempFile = outputDirectory.toFile() / s"$uuid.part" + sbt.io.Using.fileOutputStream(false)(tempFile): out => + val resObs = new StreamObserver[ByteStreamProto.ReadResponse]: + override def onCompleted(): Unit = + p.success(tempFile.toPath()) + override def onError(e: Throwable): Unit = p.failure(e) + override def onNext(res: ByteStreamProto.ReadResponse): Unit = + IO.transfer(res.getData().newInput(), out) + val b = ReadRequest.newBuilder() + val dn = downloadName(digest) + b.setResourceName(dn) + b.setReadOffset(0L) + val req = b.build() + byteStreamStub.read(req, resObs) + p.future + + // helper function for many-to-one gRPC streaming + // https://grpc.io/docs/languages/java/basics/#client-side-streaming-rpc-1 + private def withSingleResponse[A1, A2]( + f: (Promise[A1], StreamObserver[A1]) => Future[A2] + ): Future[A2] = + val p = Promise[A1]() + val observer = new StreamObserver[A1]: + var o: Option[A1] = None + override def onCompleted(): Unit = + if o.isDefined then p.success(o.get) + else p.failure(new RuntimeException("unexpected onCompleted")) + override def onError(e: Throwable): Unit = p.failure(e) + override def onNext(res: A1): Unit = + o = Some(res) + f(p, observer) + /** - * https://github.com/bazelbuild/remote-apis/blob/9ff14cecffe5287ba337f857731ceadfc2d80de9/build/bazel/remote/execution/v2/remote_execution.proto#L403 + * resource name is load-bearing. + * https://github.com/bazelbuild/remote-apis/blob/main/build/bazel/remote/execution/v2/remote_execution.proto#L219-L220 */ - override def syncBlobs(refs: Seq[HashedVirtualFileRef], outputDirectory: Path): Seq[Path] = - val result = doGetBlobs(refs) - val blobs = result.getResponsesList().asScala.toList - val allOk = blobs.forall(_.getStatus().getCode() == 0) - if allOk then - // do not assume the responses to come in order - val lookupResponse: Map[Digest, BatchReadBlobsResponse.Response] = - Map(blobs.map(res => toDigest(res.getDigest) -> res)*) - refs.map: r => - val digest = Digest(r) - val blob = lookupResponse(digest) - val casFile = disk.putBlob(blob.getData().newInput(), digest) - disk.syncFile(r, casFile, outputDirectory) - else Nil + private def uploadName(d: Digest, uuid: UUID): String = + s"$instanceName/uploads/$uuid/blobs/${d.algo}/${d.hashHexString}/${d.sizeBytes}" + + /** + * resource name is load-bearing. + * https://github.com/bazelbuild/remote-apis/blob/main/build/bazel/remote/execution/v2/remote_execution.proto#L294-L295 + */ + private def downloadName(d: Digest): String = + s"$instanceName/blobs/${d.algo}/${d.hashHexString}/${d.sizeBytes}" /** * https://github.com/bazelbuild/remote-apis/blob/96942a2107c702ed3ca4a664f7eeb7c85ba8dc77/build/bazel/remote/execution/v2/remote_execution.proto#L1629 @@ -216,7 +360,7 @@ class GrpcActionCacheStore( if missing(Digest(r)) then None else Some(r) - private def doGetBlobs(refs: Seq[HashedVirtualFileRef]): BatchReadBlobsResponse = + private def batchReadBlobs(refs: Seq[HashedVirtualFileRef]): BatchReadBlobsResponse = val b = BatchReadBlobsRequest.newBuilder() b.setInstanceName(instanceName) refs.foreach: ref => @@ -268,4 +412,24 @@ class GrpcActionCacheStore( val out = ByteString.newOutput() IO.transfer(blob.input(), out) out.toByteString() + + private def toByteString(input: InputStream, size: Long): ByteString = + val BufferSize = 8192 + val out = ByteString.newOutput() + if size <= 0 then out.toByteString() + else + var buf = new Array[Byte](BufferSize) + var remaining = size + while + if remaining >= BufferSize then + if buf.size != BufferSize then buf = new Array[Byte](BufferSize) + else () + else buf = new Array[Byte](remaining.toInt) + val readBytes = input.read(buf) + if readBytes > 0 then out.write(buf, 0, readBytes) + else () + remaining = remaining - readBytes + readBytes > 0 + do () + out.toByteString() end GrpcActionCacheStore diff --git a/sbt-remote-cache/src/test/scala/sbt/internal/GrpcActionCacheStoreTest.scala b/sbt-remote-cache/src/test/scala/sbt/internal/GrpcActionCacheStoreTest.scala new file mode 100644 index 000000000..241ccabca --- /dev/null +++ b/sbt-remote-cache/src/test/scala/sbt/internal/GrpcActionCacheStoreTest.scala @@ -0,0 +1,18 @@ +package sbt +package internal + +object GrpcActionCacheStoreTest extends verify.BasicTestSuite: + test("chunkBytes"): + val actual = GrpcActionCacheStore.chunkBytes(0L) + assert(actual == Nil) + + val actual2 = GrpcActionCacheStore.chunkBytes(1L) + assert(actual2 == List(1L)) + + val meg = 1024L * 1024L + val actual3 = GrpcActionCacheStore.chunkBytes(meg) + assert(actual3 == List(meg)) + + val actual4 = GrpcActionCacheStore.chunkBytes(meg + 1) + assert(actual4 == List(meg, 1L)) +end GrpcActionCacheStoreTest diff --git a/util-cache/src/main/scala/sbt/util/ActionCacheStore.scala b/util-cache/src/main/scala/sbt/util/ActionCacheStore.scala index 9c8475ddc..4eaf41961 100644 --- a/util-cache/src/main/scala/sbt/util/ActionCacheStore.scala +++ b/util-cache/src/main/scala/sbt/util/ActionCacheStore.scala @@ -239,6 +239,14 @@ case class DiskActionCacheStore(base: Path, converter: FileConverter) putBlob(in, digest) } + /** Move blob directly to CAS. Internal use only. */ + private[sbt] def putBlobInternal(blob: Path, digest: Digest): Path = + val casFile = toCasFile(digest) + if isCompleteBlob(casFile, digest) then casFile + else + IO.move(blob.toFile(), casFile.toFile()) + casFile + def putBlob(input: InputStream, digest: Digest): Path = val casFile = toCasFile(digest) if isCompleteBlob(casFile, digest) then casFile