mirror of https://github.com/sbt/sbt.git
[2.x] fix: ByteStream chunked upload/download (#9298)
**Problem** BatchUpdateBlobsRequest suffers from gRPC's message size limitation. **Solution** For larger files, we switch to using the ByteStream API, chunked to 1MB at a time.
This commit is contained in:
parent
3b097cc3cc
commit
ac320c7fd0
|
|
@ -11,7 +11,7 @@ import com.eed3si9n.jarjarabrams.ModuleCoordinate
|
|||
// ThisBuild settings take lower precedence,
|
||||
// but can be shared across the multi projects.
|
||||
ThisBuild / version := {
|
||||
val v = "2.0.0-RC13-bin-SNAPSHOT"
|
||||
val v = "2.0.0-RC15-bin-SNAPSHOT"
|
||||
nightlyVersion.getOrElse(v)
|
||||
}
|
||||
// update sbt.sh at root
|
||||
|
|
@ -558,7 +558,7 @@ lazy val remoteCacheProj = (project in file("sbt-remote-cache"))
|
|||
pluginCrossBuild / sbtVersion := version.value,
|
||||
publishMavenStyle := true,
|
||||
mimaSettings,
|
||||
libraryDependencies += remoteapis,
|
||||
libraryDependencies ++= Seq(remoteapis, scalaVerify % Test),
|
||||
)
|
||||
|
||||
// Implementation and support code for defining actions.
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ object Dependencies {
|
|||
val scalaVerify = "com.eed3si9n.verify" %% "verify" % "1.0.0"
|
||||
val templateResolverApi = "org.scala-sbt" % "template-resolver" % "0.1"
|
||||
val remoteapis =
|
||||
"com.eed3si9n.remoteapis.shaded" % "shaded-remoteapis-java" % "2.3.0-M1-52317e00d8d4c37fa778c628485d220fb68a8d08"
|
||||
"com.eed3si9n.remoteapis.shaded" % "shaded-remoteapis-java" % "2.3.0-M1-4ee33449fff18243c019f799b636cdb8c8a18f6c"
|
||||
val gson = "org.scala-sbt.gson" % "shaded-gson" % "2.13.1"
|
||||
|
||||
val scalaCompiler = "org.scala-lang" %% "scala3-compiler" % scala3
|
||||
|
|
|
|||
|
|
@ -27,8 +27,16 @@ import com.eed3si9n.remoteapis.shaded.io.grpc.{
|
|||
Status,
|
||||
TlsChannelCredentials,
|
||||
}
|
||||
import com.eed3si9n.remoteapis.shaded.io.grpc.stub.StreamObserver
|
||||
import com.eed3si9n.remoteapis.shaded.com.google.bytestream.ByteStreamGrpc
|
||||
import com.eed3si9n.remoteapis.shaded.com.google.bytestream.ByteStreamProto
|
||||
import ByteStreamProto.{ ReadRequest, WriteRequest }
|
||||
import java.io.InputStream
|
||||
import java.net.URI
|
||||
import java.nio.file.Path
|
||||
import java.util.UUID
|
||||
import java.util.concurrent.{ Executors, TimeUnit }
|
||||
import sbt.io.syntax.*
|
||||
import sbt.util.{
|
||||
AbstractActionCacheStore,
|
||||
ActionResult,
|
||||
|
|
@ -37,11 +45,19 @@ import sbt.util.{
|
|||
GetActionResultRequest,
|
||||
UpdateActionResultRequest,
|
||||
}
|
||||
import scala.concurrent.{ Await, ExecutionContext, Future, Promise, TimeoutException }
|
||||
import scala.concurrent.duration.*
|
||||
import scala.util.Using
|
||||
import scala.util.control.NonFatal
|
||||
import scala.jdk.CollectionConverters.*
|
||||
import xsbti.{ HashedVirtualFileRef, VirtualFile }
|
||||
|
||||
object GrpcActionCacheStore:
|
||||
// chunk uploads to 1MB
|
||||
val chunkSizeBytes = 1024 * 1024
|
||||
val remoteTimeoutInSec = 60
|
||||
val remoteTimeout = (remoteTimeoutInSec + 2).second
|
||||
|
||||
def apply(
|
||||
uri: URI,
|
||||
rootCerts: Option[Path],
|
||||
|
|
@ -102,10 +118,20 @@ object GrpcActionCacheStore:
|
|||
applier.apply(headers)
|
||||
catch case NonFatal(e) => applier.fail(Status.UNAUTHENTICATED.withCause(e))
|
||||
end AuthCallCredentials
|
||||
|
||||
def chunkBytes(sizeBytes: Long): List[Long] =
|
||||
if sizeBytes <= 0 then Nil
|
||||
else
|
||||
val full = sizeBytes / chunkSizeBytes
|
||||
val remainder = sizeBytes % chunkSizeBytes
|
||||
val xs = List.fill(full.toInt)(chunkSizeBytes.toLong)
|
||||
if remainder > 0 then xs ::: List(remainder)
|
||||
else xs
|
||||
end GrpcActionCacheStore
|
||||
|
||||
/*
|
||||
* https://github.com/bazelbuild/remote-apis/blob/main/build/bazel/remote/execution/v2/remote_execution.proto
|
||||
* https://github.com/googleapis/googleapis/blob/ff15be54722218705740b9fc6223d264c4cdb6dd/google/bytestream/bytestream.proto
|
||||
*/
|
||||
class GrpcActionCacheStore(
|
||||
channel: ManagedChannel,
|
||||
|
|
@ -113,6 +139,8 @@ class GrpcActionCacheStore(
|
|||
remoteHeaders: List[String],
|
||||
disk: DiskActionCacheStore,
|
||||
) extends AbstractActionCacheStore:
|
||||
import GrpcActionCacheStore.*
|
||||
|
||||
lazy val creds = GrpcActionCacheStore.AuthCallCredentials(remoteHeaders)
|
||||
lazy val acStub0 = ActionCacheGrpc.newBlockingStub(channel)
|
||||
lazy val acStub = remoteHeaders match
|
||||
|
|
@ -122,9 +150,20 @@ class GrpcActionCacheStore(
|
|||
lazy val casStub = remoteHeaders match
|
||||
case x :: xs => casStub0.withCallCredentials(creds)
|
||||
case _ => casStub0
|
||||
lazy val byteStreamStub0 = ByteStreamGrpc.newStub(channel)
|
||||
lazy val byteStreamStub = remoteHeaders match
|
||||
case x :: xs =>
|
||||
byteStreamStub0
|
||||
.withCallCredentials(creds)
|
||||
.withDeadlineAfter(remoteTimeoutInSec, TimeUnit.SECONDS)
|
||||
case _ =>
|
||||
byteStreamStub0.withDeadlineAfter(remoteTimeoutInSec, TimeUnit.SECONDS)
|
||||
|
||||
override def storeName: String = "remote"
|
||||
|
||||
val fixedThreadPool = Executors.newFixedThreadPool(100)
|
||||
given ExecutionContext = ExecutionContext.fromExecutor(fixedThreadPool)
|
||||
|
||||
/**
|
||||
* https://github.com/bazelbuild/remote-apis/blob/9ff14cecffe5287ba337f857731ceadfc2d80de9/build/bazel/remote/execution/v2/remote_execution.proto#L170
|
||||
*/
|
||||
|
|
@ -157,10 +196,49 @@ class GrpcActionCacheStore(
|
|||
Right(toActionResult(result))
|
||||
catch case NonFatal(e) => Left(e)
|
||||
|
||||
/**
|
||||
* https://github.com/bazelbuild/remote-apis/blob/9ff14cecffe5287ba337f857731ceadfc2d80de9/build/bazel/remote/execution/v2/remote_execution.proto#L403
|
||||
*/
|
||||
override def syncBlobs(refs: Seq[HashedVirtualFileRef], outputDirectory: Path): Seq[Path] =
|
||||
val digests = refs.map(Digest(_))
|
||||
val totalBytes = digests.map(_.sizeBytes).sum
|
||||
if refs.isEmpty then Nil
|
||||
else if totalBytes <= chunkSizeBytes then
|
||||
val result = batchReadBlobs(refs)
|
||||
val blobs = result.getResponsesList().asScala.toList
|
||||
val allOk = blobs.forall(_.getStatus().getCode() == 0)
|
||||
if allOk then
|
||||
// do not assume the responses to come in order
|
||||
val lookupResponse: Map[Digest, BatchReadBlobsResponse.Response] =
|
||||
Map(blobs.map(res => toDigest(res.getDigest) -> res)*)
|
||||
refs.map: r =>
|
||||
val digest = Digest(r)
|
||||
val blob = lookupResponse(digest)
|
||||
val casFile = disk.putBlob(blob.getData().newInput(), digest)
|
||||
disk.syncFile(r, casFile, outputDirectory)
|
||||
else Nil
|
||||
else
|
||||
val paths = Await.result(downloadBlobs(digests, outputDirectory), remoteTimeout)
|
||||
refs
|
||||
.zip(digests)
|
||||
.zip(paths)
|
||||
.map { case ((r, digest), p) =>
|
||||
val casFile = disk.putBlobInternal(p, digest)
|
||||
disk.syncFile(r, casFile, outputDirectory)
|
||||
}
|
||||
|
||||
/**
|
||||
* https://github.com/bazelbuild/remote-apis/blob/9ff14cecffe5287ba337f857731ceadfc2d80de9/build/bazel/remote/execution/v2/remote_execution.proto#L379
|
||||
*/
|
||||
override def putBlobs(blobs: Seq[VirtualFile]): Seq[HashedVirtualFileRef] =
|
||||
val totalBytes = blobs.map(_.sizeBytes).sum
|
||||
if blobs.isEmpty then Nil
|
||||
else if totalBytes <= chunkSizeBytes then batchUpdateBlobs(blobs)
|
||||
else
|
||||
try Await.result(uploadBlobs(blobs).recover(_ => Nil), remoteTimeout)
|
||||
catch case _: TimeoutException => Nil
|
||||
|
||||
def batchUpdateBlobs(blobs: Seq[VirtualFile]): Seq[HashedVirtualFileRef] =
|
||||
val b = BatchUpdateBlobsRequest.newBuilder()
|
||||
b.setInstanceName(instanceName)
|
||||
b.setDigestFunction(DigestFunction.Value.SHA256)
|
||||
|
|
@ -182,23 +260,89 @@ class GrpcActionCacheStore(
|
|||
Some(HashedVirtualFileRef.of(blob.id(), d.contentHashStr, d.sizeBytes))
|
||||
else None
|
||||
|
||||
def uploadBlobs(blobs: Seq[VirtualFile]): Future[Seq[HashedVirtualFileRef]] =
|
||||
Future.sequence(blobs.map(uploadBlob))
|
||||
|
||||
def uploadBlob(blob: VirtualFile): Future[HashedVirtualFileRef] =
|
||||
val d = Digest(blob)
|
||||
withSingleResponse[ByteStreamProto.WriteResponse, HashedVirtualFileRef]: (p, resObs) =>
|
||||
val reqObs = byteStreamStub.write(resObs)
|
||||
val un = uploadName(d, UUID.randomUUID())
|
||||
var off: Long = 0L
|
||||
try
|
||||
Using.resource(blob.input()): input =>
|
||||
val chunks = chunkBytes(d.sizeBytes)
|
||||
chunks.zipWithIndex.foreach: (chunk, idx) =>
|
||||
val b = WriteRequest.newBuilder()
|
||||
if idx == 0 then b.setResourceName(un)
|
||||
else ()
|
||||
b.setWriteOffset(off)
|
||||
b.setData(toByteString(input, chunk))
|
||||
if idx == chunks.size - 1 then b.setFinishWrite(true)
|
||||
else ()
|
||||
val req = b.build()
|
||||
off = off + chunk
|
||||
reqObs.onNext(req)
|
||||
catch
|
||||
case NonFatal(e) =>
|
||||
reqObs.onError(e)
|
||||
p.failure(e)
|
||||
reqObs.onCompleted()
|
||||
p.future.map: _ =>
|
||||
HashedVirtualFileRef.of(blob.id(), d.contentHashStr, d.sizeBytes)
|
||||
|
||||
private def downloadBlobs(digests: Seq[Digest], outputDirectory: Path): Future[Seq[Path]] =
|
||||
Future.sequence(digests.map: x =>
|
||||
downloadBlob(x, outputDirectory))
|
||||
|
||||
private def downloadBlob(digest: Digest, outputDirectory: Path): Future[Path] =
|
||||
val p = Promise[Path]()
|
||||
val uuid = UUID.randomUUID()
|
||||
val tempFile = outputDirectory.toFile() / s"$uuid.part"
|
||||
sbt.io.Using.fileOutputStream(false)(tempFile): out =>
|
||||
val resObs = new StreamObserver[ByteStreamProto.ReadResponse]:
|
||||
override def onCompleted(): Unit =
|
||||
p.success(tempFile.toPath())
|
||||
override def onError(e: Throwable): Unit = p.failure(e)
|
||||
override def onNext(res: ByteStreamProto.ReadResponse): Unit =
|
||||
IO.transfer(res.getData().newInput(), out)
|
||||
val b = ReadRequest.newBuilder()
|
||||
val dn = downloadName(digest)
|
||||
b.setResourceName(dn)
|
||||
b.setReadOffset(0L)
|
||||
val req = b.build()
|
||||
byteStreamStub.read(req, resObs)
|
||||
p.future
|
||||
|
||||
// helper function for many-to-one gRPC streaming
|
||||
// https://grpc.io/docs/languages/java/basics/#client-side-streaming-rpc-1
|
||||
private def withSingleResponse[A1, A2](
|
||||
f: (Promise[A1], StreamObserver[A1]) => Future[A2]
|
||||
): Future[A2] =
|
||||
val p = Promise[A1]()
|
||||
val observer = new StreamObserver[A1]:
|
||||
var o: Option[A1] = None
|
||||
override def onCompleted(): Unit =
|
||||
if o.isDefined then p.success(o.get)
|
||||
else p.failure(new RuntimeException("unexpected onCompleted"))
|
||||
override def onError(e: Throwable): Unit = p.failure(e)
|
||||
override def onNext(res: A1): Unit =
|
||||
o = Some(res)
|
||||
f(p, observer)
|
||||
|
||||
/**
|
||||
* https://github.com/bazelbuild/remote-apis/blob/9ff14cecffe5287ba337f857731ceadfc2d80de9/build/bazel/remote/execution/v2/remote_execution.proto#L403
|
||||
* resource name is load-bearing.
|
||||
* https://github.com/bazelbuild/remote-apis/blob/main/build/bazel/remote/execution/v2/remote_execution.proto#L219-L220
|
||||
*/
|
||||
override def syncBlobs(refs: Seq[HashedVirtualFileRef], outputDirectory: Path): Seq[Path] =
|
||||
val result = doGetBlobs(refs)
|
||||
val blobs = result.getResponsesList().asScala.toList
|
||||
val allOk = blobs.forall(_.getStatus().getCode() == 0)
|
||||
if allOk then
|
||||
// do not assume the responses to come in order
|
||||
val lookupResponse: Map[Digest, BatchReadBlobsResponse.Response] =
|
||||
Map(blobs.map(res => toDigest(res.getDigest) -> res)*)
|
||||
refs.map: r =>
|
||||
val digest = Digest(r)
|
||||
val blob = lookupResponse(digest)
|
||||
val casFile = disk.putBlob(blob.getData().newInput(), digest)
|
||||
disk.syncFile(r, casFile, outputDirectory)
|
||||
else Nil
|
||||
private def uploadName(d: Digest, uuid: UUID): String =
|
||||
s"$instanceName/uploads/$uuid/blobs/${d.algo}/${d.hashHexString}/${d.sizeBytes}"
|
||||
|
||||
/**
|
||||
* resource name is load-bearing.
|
||||
* https://github.com/bazelbuild/remote-apis/blob/main/build/bazel/remote/execution/v2/remote_execution.proto#L294-L295
|
||||
*/
|
||||
private def downloadName(d: Digest): String =
|
||||
s"$instanceName/blobs/${d.algo}/${d.hashHexString}/${d.sizeBytes}"
|
||||
|
||||
/**
|
||||
* https://github.com/bazelbuild/remote-apis/blob/96942a2107c702ed3ca4a664f7eeb7c85ba8dc77/build/bazel/remote/execution/v2/remote_execution.proto#L1629
|
||||
|
|
@ -216,7 +360,7 @@ class GrpcActionCacheStore(
|
|||
if missing(Digest(r)) then None
|
||||
else Some(r)
|
||||
|
||||
private def doGetBlobs(refs: Seq[HashedVirtualFileRef]): BatchReadBlobsResponse =
|
||||
private def batchReadBlobs(refs: Seq[HashedVirtualFileRef]): BatchReadBlobsResponse =
|
||||
val b = BatchReadBlobsRequest.newBuilder()
|
||||
b.setInstanceName(instanceName)
|
||||
refs.foreach: ref =>
|
||||
|
|
@ -268,4 +412,24 @@ class GrpcActionCacheStore(
|
|||
val out = ByteString.newOutput()
|
||||
IO.transfer(blob.input(), out)
|
||||
out.toByteString()
|
||||
|
||||
private def toByteString(input: InputStream, size: Long): ByteString =
|
||||
val BufferSize = 8192
|
||||
val out = ByteString.newOutput()
|
||||
if size <= 0 then out.toByteString()
|
||||
else
|
||||
var buf = new Array[Byte](BufferSize)
|
||||
var remaining = size
|
||||
while
|
||||
if remaining >= BufferSize then
|
||||
if buf.size != BufferSize then buf = new Array[Byte](BufferSize)
|
||||
else ()
|
||||
else buf = new Array[Byte](remaining.toInt)
|
||||
val readBytes = input.read(buf)
|
||||
if readBytes > 0 then out.write(buf, 0, readBytes)
|
||||
else ()
|
||||
remaining = remaining - readBytes
|
||||
readBytes > 0
|
||||
do ()
|
||||
out.toByteString()
|
||||
end GrpcActionCacheStore
|
||||
|
|
|
|||
|
|
@ -0,0 +1,18 @@
|
|||
package sbt
|
||||
package internal
|
||||
|
||||
object GrpcActionCacheStoreTest extends verify.BasicTestSuite:
|
||||
test("chunkBytes"):
|
||||
val actual = GrpcActionCacheStore.chunkBytes(0L)
|
||||
assert(actual == Nil)
|
||||
|
||||
val actual2 = GrpcActionCacheStore.chunkBytes(1L)
|
||||
assert(actual2 == List(1L))
|
||||
|
||||
val meg = 1024L * 1024L
|
||||
val actual3 = GrpcActionCacheStore.chunkBytes(meg)
|
||||
assert(actual3 == List(meg))
|
||||
|
||||
val actual4 = GrpcActionCacheStore.chunkBytes(meg + 1)
|
||||
assert(actual4 == List(meg, 1L))
|
||||
end GrpcActionCacheStoreTest
|
||||
|
|
@ -239,6 +239,14 @@ case class DiskActionCacheStore(base: Path, converter: FileConverter)
|
|||
putBlob(in, digest)
|
||||
}
|
||||
|
||||
/** Move blob directly to CAS. Internal use only. */
|
||||
private[sbt] def putBlobInternal(blob: Path, digest: Digest): Path =
|
||||
val casFile = toCasFile(digest)
|
||||
if isCompleteBlob(casFile, digest) then casFile
|
||||
else
|
||||
IO.move(blob.toFile(), casFile.toFile())
|
||||
casFile
|
||||
|
||||
def putBlob(input: InputStream, digest: Digest): Path =
|
||||
val casFile = toCasFile(digest)
|
||||
if isCompleteBlob(casFile, digest) then casFile
|
||||
|
|
|
|||
Loading…
Reference in New Issue