[2.x] fix: ByteStream chunked upload/download (#9298)

**Problem**

BatchUpdateBlobsRequest suffers from gRPC's message size limitation.

**Solution**

For larger files, we switch to using the ByteStream API, chunked to 1MB at a time.
This commit is contained in:
eugene yokota 2026-06-07 15:17:20 -04:00 committed by GitHub
parent 3b097cc3cc
commit ac320c7fd0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 209 additions and 19 deletions

View File

@ -11,7 +11,7 @@ import com.eed3si9n.jarjarabrams.ModuleCoordinate
// ThisBuild settings take lower precedence,
// but can be shared across the multi projects.
ThisBuild / version := {
val v = "2.0.0-RC13-bin-SNAPSHOT"
val v = "2.0.0-RC15-bin-SNAPSHOT"
nightlyVersion.getOrElse(v)
}
// update sbt.sh at root
@ -558,7 +558,7 @@ lazy val remoteCacheProj = (project in file("sbt-remote-cache"))
pluginCrossBuild / sbtVersion := version.value,
publishMavenStyle := true,
mimaSettings,
libraryDependencies += remoteapis,
libraryDependencies ++= Seq(remoteapis, scalaVerify % Test),
)
// Implementation and support code for defining actions.

View File

@ -96,7 +96,7 @@ object Dependencies {
val scalaVerify = "com.eed3si9n.verify" %% "verify" % "1.0.0"
val templateResolverApi = "org.scala-sbt" % "template-resolver" % "0.1"
val remoteapis =
"com.eed3si9n.remoteapis.shaded" % "shaded-remoteapis-java" % "2.3.0-M1-52317e00d8d4c37fa778c628485d220fb68a8d08"
"com.eed3si9n.remoteapis.shaded" % "shaded-remoteapis-java" % "2.3.0-M1-4ee33449fff18243c019f799b636cdb8c8a18f6c"
val gson = "org.scala-sbt.gson" % "shaded-gson" % "2.13.1"
val scalaCompiler = "org.scala-lang" %% "scala3-compiler" % scala3

View File

@ -27,8 +27,16 @@ import com.eed3si9n.remoteapis.shaded.io.grpc.{
Status,
TlsChannelCredentials,
}
import com.eed3si9n.remoteapis.shaded.io.grpc.stub.StreamObserver
import com.eed3si9n.remoteapis.shaded.com.google.bytestream.ByteStreamGrpc
import com.eed3si9n.remoteapis.shaded.com.google.bytestream.ByteStreamProto
import ByteStreamProto.{ ReadRequest, WriteRequest }
import java.io.InputStream
import java.net.URI
import java.nio.file.Path
import java.util.UUID
import java.util.concurrent.{ Executors, TimeUnit }
import sbt.io.syntax.*
import sbt.util.{
AbstractActionCacheStore,
ActionResult,
@ -37,11 +45,19 @@ import sbt.util.{
GetActionResultRequest,
UpdateActionResultRequest,
}
import scala.concurrent.{ Await, ExecutionContext, Future, Promise, TimeoutException }
import scala.concurrent.duration.*
import scala.util.Using
import scala.util.control.NonFatal
import scala.jdk.CollectionConverters.*
import xsbti.{ HashedVirtualFileRef, VirtualFile }
object GrpcActionCacheStore:
// chunk uploads to 1MB
val chunkSizeBytes = 1024 * 1024
val remoteTimeoutInSec = 60
val remoteTimeout = (remoteTimeoutInSec + 2).second
def apply(
uri: URI,
rootCerts: Option[Path],
@ -102,10 +118,20 @@ object GrpcActionCacheStore:
applier.apply(headers)
catch case NonFatal(e) => applier.fail(Status.UNAUTHENTICATED.withCause(e))
end AuthCallCredentials
def chunkBytes(sizeBytes: Long): List[Long] =
if sizeBytes <= 0 then Nil
else
val full = sizeBytes / chunkSizeBytes
val remainder = sizeBytes % chunkSizeBytes
val xs = List.fill(full.toInt)(chunkSizeBytes.toLong)
if remainder > 0 then xs ::: List(remainder)
else xs
end GrpcActionCacheStore
/*
* https://github.com/bazelbuild/remote-apis/blob/main/build/bazel/remote/execution/v2/remote_execution.proto
* https://github.com/googleapis/googleapis/blob/ff15be54722218705740b9fc6223d264c4cdb6dd/google/bytestream/bytestream.proto
*/
class GrpcActionCacheStore(
channel: ManagedChannel,
@ -113,6 +139,8 @@ class GrpcActionCacheStore(
remoteHeaders: List[String],
disk: DiskActionCacheStore,
) extends AbstractActionCacheStore:
import GrpcActionCacheStore.*
lazy val creds = GrpcActionCacheStore.AuthCallCredentials(remoteHeaders)
lazy val acStub0 = ActionCacheGrpc.newBlockingStub(channel)
lazy val acStub = remoteHeaders match
@ -122,9 +150,20 @@ class GrpcActionCacheStore(
lazy val casStub = remoteHeaders match
case x :: xs => casStub0.withCallCredentials(creds)
case _ => casStub0
lazy val byteStreamStub0 = ByteStreamGrpc.newStub(channel)
lazy val byteStreamStub = remoteHeaders match
case x :: xs =>
byteStreamStub0
.withCallCredentials(creds)
.withDeadlineAfter(remoteTimeoutInSec, TimeUnit.SECONDS)
case _ =>
byteStreamStub0.withDeadlineAfter(remoteTimeoutInSec, TimeUnit.SECONDS)
override def storeName: String = "remote"
val fixedThreadPool = Executors.newFixedThreadPool(100)
given ExecutionContext = ExecutionContext.fromExecutor(fixedThreadPool)
/**
* https://github.com/bazelbuild/remote-apis/blob/9ff14cecffe5287ba337f857731ceadfc2d80de9/build/bazel/remote/execution/v2/remote_execution.proto#L170
*/
@ -157,10 +196,49 @@ class GrpcActionCacheStore(
Right(toActionResult(result))
catch case NonFatal(e) => Left(e)
/**
* https://github.com/bazelbuild/remote-apis/blob/9ff14cecffe5287ba337f857731ceadfc2d80de9/build/bazel/remote/execution/v2/remote_execution.proto#L403
*/
override def syncBlobs(refs: Seq[HashedVirtualFileRef], outputDirectory: Path): Seq[Path] =
val digests = refs.map(Digest(_))
val totalBytes = digests.map(_.sizeBytes).sum
if refs.isEmpty then Nil
else if totalBytes <= chunkSizeBytes then
val result = batchReadBlobs(refs)
val blobs = result.getResponsesList().asScala.toList
val allOk = blobs.forall(_.getStatus().getCode() == 0)
if allOk then
// do not assume the responses to come in order
val lookupResponse: Map[Digest, BatchReadBlobsResponse.Response] =
Map(blobs.map(res => toDigest(res.getDigest) -> res)*)
refs.map: r =>
val digest = Digest(r)
val blob = lookupResponse(digest)
val casFile = disk.putBlob(blob.getData().newInput(), digest)
disk.syncFile(r, casFile, outputDirectory)
else Nil
else
val paths = Await.result(downloadBlobs(digests, outputDirectory), remoteTimeout)
refs
.zip(digests)
.zip(paths)
.map { case ((r, digest), p) =>
val casFile = disk.putBlobInternal(p, digest)
disk.syncFile(r, casFile, outputDirectory)
}
/**
* https://github.com/bazelbuild/remote-apis/blob/9ff14cecffe5287ba337f857731ceadfc2d80de9/build/bazel/remote/execution/v2/remote_execution.proto#L379
*/
override def putBlobs(blobs: Seq[VirtualFile]): Seq[HashedVirtualFileRef] =
val totalBytes = blobs.map(_.sizeBytes).sum
if blobs.isEmpty then Nil
else if totalBytes <= chunkSizeBytes then batchUpdateBlobs(blobs)
else
try Await.result(uploadBlobs(blobs).recover(_ => Nil), remoteTimeout)
catch case _: TimeoutException => Nil
def batchUpdateBlobs(blobs: Seq[VirtualFile]): Seq[HashedVirtualFileRef] =
val b = BatchUpdateBlobsRequest.newBuilder()
b.setInstanceName(instanceName)
b.setDigestFunction(DigestFunction.Value.SHA256)
@ -182,23 +260,89 @@ class GrpcActionCacheStore(
Some(HashedVirtualFileRef.of(blob.id(), d.contentHashStr, d.sizeBytes))
else None
def uploadBlobs(blobs: Seq[VirtualFile]): Future[Seq[HashedVirtualFileRef]] =
Future.sequence(blobs.map(uploadBlob))
def uploadBlob(blob: VirtualFile): Future[HashedVirtualFileRef] =
val d = Digest(blob)
withSingleResponse[ByteStreamProto.WriteResponse, HashedVirtualFileRef]: (p, resObs) =>
val reqObs = byteStreamStub.write(resObs)
val un = uploadName(d, UUID.randomUUID())
var off: Long = 0L
try
Using.resource(blob.input()): input =>
val chunks = chunkBytes(d.sizeBytes)
chunks.zipWithIndex.foreach: (chunk, idx) =>
val b = WriteRequest.newBuilder()
if idx == 0 then b.setResourceName(un)
else ()
b.setWriteOffset(off)
b.setData(toByteString(input, chunk))
if idx == chunks.size - 1 then b.setFinishWrite(true)
else ()
val req = b.build()
off = off + chunk
reqObs.onNext(req)
catch
case NonFatal(e) =>
reqObs.onError(e)
p.failure(e)
reqObs.onCompleted()
p.future.map: _ =>
HashedVirtualFileRef.of(blob.id(), d.contentHashStr, d.sizeBytes)
private def downloadBlobs(digests: Seq[Digest], outputDirectory: Path): Future[Seq[Path]] =
Future.sequence(digests.map: x =>
downloadBlob(x, outputDirectory))
private def downloadBlob(digest: Digest, outputDirectory: Path): Future[Path] =
val p = Promise[Path]()
val uuid = UUID.randomUUID()
val tempFile = outputDirectory.toFile() / s"$uuid.part"
sbt.io.Using.fileOutputStream(false)(tempFile): out =>
val resObs = new StreamObserver[ByteStreamProto.ReadResponse]:
override def onCompleted(): Unit =
p.success(tempFile.toPath())
override def onError(e: Throwable): Unit = p.failure(e)
override def onNext(res: ByteStreamProto.ReadResponse): Unit =
IO.transfer(res.getData().newInput(), out)
val b = ReadRequest.newBuilder()
val dn = downloadName(digest)
b.setResourceName(dn)
b.setReadOffset(0L)
val req = b.build()
byteStreamStub.read(req, resObs)
p.future
// helper function for many-to-one gRPC streaming
// https://grpc.io/docs/languages/java/basics/#client-side-streaming-rpc-1
private def withSingleResponse[A1, A2](
f: (Promise[A1], StreamObserver[A1]) => Future[A2]
): Future[A2] =
val p = Promise[A1]()
val observer = new StreamObserver[A1]:
var o: Option[A1] = None
override def onCompleted(): Unit =
if o.isDefined then p.success(o.get)
else p.failure(new RuntimeException("unexpected onCompleted"))
override def onError(e: Throwable): Unit = p.failure(e)
override def onNext(res: A1): Unit =
o = Some(res)
f(p, observer)
/**
* https://github.com/bazelbuild/remote-apis/blob/9ff14cecffe5287ba337f857731ceadfc2d80de9/build/bazel/remote/execution/v2/remote_execution.proto#L403
* resource name is load-bearing.
* https://github.com/bazelbuild/remote-apis/blob/main/build/bazel/remote/execution/v2/remote_execution.proto#L219-L220
*/
override def syncBlobs(refs: Seq[HashedVirtualFileRef], outputDirectory: Path): Seq[Path] =
val result = doGetBlobs(refs)
val blobs = result.getResponsesList().asScala.toList
val allOk = blobs.forall(_.getStatus().getCode() == 0)
if allOk then
// do not assume the responses to come in order
val lookupResponse: Map[Digest, BatchReadBlobsResponse.Response] =
Map(blobs.map(res => toDigest(res.getDigest) -> res)*)
refs.map: r =>
val digest = Digest(r)
val blob = lookupResponse(digest)
val casFile = disk.putBlob(blob.getData().newInput(), digest)
disk.syncFile(r, casFile, outputDirectory)
else Nil
private def uploadName(d: Digest, uuid: UUID): String =
s"$instanceName/uploads/$uuid/blobs/${d.algo}/${d.hashHexString}/${d.sizeBytes}"
/**
* resource name is load-bearing.
* https://github.com/bazelbuild/remote-apis/blob/main/build/bazel/remote/execution/v2/remote_execution.proto#L294-L295
*/
private def downloadName(d: Digest): String =
s"$instanceName/blobs/${d.algo}/${d.hashHexString}/${d.sizeBytes}"
/**
* https://github.com/bazelbuild/remote-apis/blob/96942a2107c702ed3ca4a664f7eeb7c85ba8dc77/build/bazel/remote/execution/v2/remote_execution.proto#L1629
@ -216,7 +360,7 @@ class GrpcActionCacheStore(
if missing(Digest(r)) then None
else Some(r)
private def doGetBlobs(refs: Seq[HashedVirtualFileRef]): BatchReadBlobsResponse =
private def batchReadBlobs(refs: Seq[HashedVirtualFileRef]): BatchReadBlobsResponse =
val b = BatchReadBlobsRequest.newBuilder()
b.setInstanceName(instanceName)
refs.foreach: ref =>
@ -268,4 +412,24 @@ class GrpcActionCacheStore(
val out = ByteString.newOutput()
IO.transfer(blob.input(), out)
out.toByteString()
private def toByteString(input: InputStream, size: Long): ByteString =
val BufferSize = 8192
val out = ByteString.newOutput()
if size <= 0 then out.toByteString()
else
var buf = new Array[Byte](BufferSize)
var remaining = size
while
if remaining >= BufferSize then
if buf.size != BufferSize then buf = new Array[Byte](BufferSize)
else ()
else buf = new Array[Byte](remaining.toInt)
val readBytes = input.read(buf)
if readBytes > 0 then out.write(buf, 0, readBytes)
else ()
remaining = remaining - readBytes
readBytes > 0
do ()
out.toByteString()
end GrpcActionCacheStore

View File

@ -0,0 +1,18 @@
package sbt
package internal
object GrpcActionCacheStoreTest extends verify.BasicTestSuite:
test("chunkBytes"):
val actual = GrpcActionCacheStore.chunkBytes(0L)
assert(actual == Nil)
val actual2 = GrpcActionCacheStore.chunkBytes(1L)
assert(actual2 == List(1L))
val meg = 1024L * 1024L
val actual3 = GrpcActionCacheStore.chunkBytes(meg)
assert(actual3 == List(meg))
val actual4 = GrpcActionCacheStore.chunkBytes(meg + 1)
assert(actual4 == List(meg, 1L))
end GrpcActionCacheStoreTest

View File

@ -239,6 +239,14 @@ case class DiskActionCacheStore(base: Path, converter: FileConverter)
putBlob(in, digest)
}
/** Move blob directly to CAS. Internal use only. */
private[sbt] def putBlobInternal(blob: Path, digest: Digest): Path =
val casFile = toCasFile(digest)
if isCompleteBlob(casFile, digest) then casFile
else
IO.move(blob.toFile(), casFile.toFile())
casFile
def putBlob(input: InputStream, digest: Digest): Path =
val casFile = toCasFile(digest)
if isCompleteBlob(casFile, digest) then casFile