From fbd1fb8398d9ca4a1bbff264a043a448f446b294 Mon Sep 17 00:00:00 2001 From: Eugene Yokota Date: Sun, 11 Aug 2024 15:47:51 -0400 Subject: [PATCH] Check the digest during sync **Problem** Currently `syncBlobs` deletes the existing files in the out directory when the remote cache kicks in. **Solution** 1. This refactors `Digest(...)` and adds support for `Digest.apply(Path)` and `Digest.sameDigest(...)` 2. This uses `sameDigest` to compare the digest and replaces the existing out files only when it needs to --- .../scala/sbt/util/ActionCacheStore.scala | 26 ++++++--- .../src/main/scala/sbt/util/Digest.scala | 56 ++++++++++++++----- .../src/test/scala/sbt/util/DigestTest.scala | 37 ++++++++---- 3 files changed, 88 insertions(+), 31 deletions(-) diff --git a/util-cache/src/main/scala/sbt/util/ActionCacheStore.scala b/util-cache/src/main/scala/sbt/util/ActionCacheStore.scala index 95c56eb41..a693ffa0c 100644 --- a/util-cache/src/main/scala/sbt/util/ActionCacheStore.scala +++ b/util-cache/src/main/scala/sbt/util/ActionCacheStore.scala @@ -243,16 +243,26 @@ class DiskActionCacheStore(base: Path) extends AbstractActionCacheStore: override def syncBlobs(refs: Seq[HashedVirtualFileRef], outputDirectory: Path): Seq[Path] = refs.flatMap: r => val casFile = toCasFile(Digest(r)) - if casFile.toFile().exists then - val shortPath = - if r.id.startsWith("${OUT}/") then r.id.drop(7) - else r.id - val outPath = outputDirectory.resolve(shortPath) - Files.createDirectories(outPath.getParent()) - if outPath.toFile().exists() then IO.delete(outPath.toFile()) - Some(Files.createSymbolicLink(outPath, casFile)) + if casFile.toFile().exists then Some(syncFile(r, casFile, outputDirectory)) else None + def syncFile(ref: HashedVirtualFileRef, casFile: Path, outputDirectory: Path): Path = + val shortPath = + if ref.id.startsWith("${OUT}/") then ref.id.drop(7) + else ref.id + val d = Digest(ref) + def symlinkAndNotify(outPath: Path): Path = + Files.createDirectories(outPath.getParent()) + val
result = Files.createSymbolicLink(outPath, casFile) + // after(result) + result + outputDirectory.resolve(shortPath) match + case p if !p.toFile().exists() => symlinkAndNotify(p) + case p if Digest.sameDigest(p, d) => p + case p => + IO.delete(p.toFile()) + symlinkAndNotify(p) + override def findBlobs(refs: Seq[HashedVirtualFileRef]): Seq[HashedVirtualFileRef] = refs.flatMap: r => val casFile = toCasFile(Digest(r)) diff --git a/util-cache/src/main/scala/sbt/util/Digest.scala b/util-cache/src/main/scala/sbt/util/Digest.scala index 1a04b5536..a6fbc8b06 100644 --- a/util-cache/src/main/scala/sbt/util/Digest.scala +++ b/util-cache/src/main/scala/sbt/util/Digest.scala @@ -5,17 +5,24 @@ import sbt.io.Hash import xsbti.HashedVirtualFileRef import java.io.{ BufferedInputStream, InputStream } import java.nio.ByteBuffer +import java.nio.file.{ Files, Path } import java.security.{ DigestInputStream, MessageDigest } opaque type Digest = String object Digest: - private val sha256_upper = "SHA-256" + private[sbt] val Murmur3 = "murmur3" + private[sbt] val Md5 = "md5" + private[sbt] val Sha1 = "sha1" + private[sbt] val Sha256 = "sha256" + private[sbt] val Sha384 = "sha384" + private[sbt] val Sha512 = "sha512" extension (d: Digest) def contentHashStr: String = val tokens = parse(d) s"${tokens._1}-${tokens._2}" + def algo: String = parse(d)._1 def toBytes: Array[Byte] = parse(d)._4 def sizeBytes: Long = parse(d)._3 @@ -29,25 +36,39 @@ object Digest: def apply(ref: HashedVirtualFileRef): Digest = apply(ref.contentHashStr() + "/" + ref.sizeBytes.toString) + def apply(algo: String, path: Path): Digest = + val input = Files.newInputStream(path) + try + apply(algo, hashBytes(algo, input), Files.size(path)) + finally + input.close() + // used to wrap a Long value as a fake Digest, which will // later be hashed using sha256 anyway. 
def dummy(value: Long): Digest = - apply("murmur3", longsToBytes(Array(0L, value)), 0) + apply(Murmur3, longsToBytes(Array(0L, value)), 0) lazy val zero: Digest = dummy(0L) + def sha256Hash(path: Path): Digest = apply(Sha256, path) + def sha256Hash(bytes: Array[Byte]): Digest = - apply("sha256", hashBytes(sha256_upper, bytes), bytes.length) + apply(Sha256, hashBytes(Sha256, bytes), bytes.length) def sha256Hash(longs: Array[Long]): Digest = - val bytes = hashBytes(sha256_upper, longs) - apply("sha256", bytes, bytes.length) + val bytes = hashBytes(Sha256, longs) + apply(Sha256, bytes, bytes.length) def sha256Hash(digests: Digest*): Digest = sha256Hash(digests.toSeq.map(_.toBytes).flatten.toArray[Byte]) + // first check the file size, then the hash + def sameDigest(path: Path, digest: Digest): Boolean = + if Files.size(path) != digest.sizeBytes then false + else Digest(digest.algo, path) == digest + private def hashBytes(algo: String, bytes: Array[Byte]): Array[Byte] = - val digest = MessageDigest.getInstance(algo) + val digest = MessageDigest.getInstance(jvmAlgo(algo)) digest.digest(bytes) private def hashBytes(algo: String, longs: Array[Long]): Array[Byte] = @@ -56,7 +77,7 @@ object Digest: private def hashBytes(algo: String, input: InputStream): Array[Byte] = val BufferSize = 8192 val bis = BufferedInputStream(input) - val digest = MessageDigest.getInstance(algo) + val digest = MessageDigest.getInstance(jvmAlgo(algo)) try val dis = DigestInputStream(bis, digest) val buffer = new Array[Byte](BufferSize) @@ -75,21 +96,30 @@ object Digest: case head :: rest :: Nil => val subtokens = head :: rest.split("/").toList subtokens match - case (a @ "murmur3") :: value :: sizeBytes :: Nil => + case (a @ Murmur3) :: value :: sizeBytes :: Nil => (a, value, sizeBytes.toLong, parseHex(value, 128)) - case (a @ "md5") :: value :: sizeBytes :: Nil => + case (a @ Md5) :: value :: sizeBytes :: Nil => (a, value, sizeBytes.toLong, parseHex(value, 128)) - case (a @ "sha1") :: value :: 
sizeBytes :: Nil => + case (a @ Sha1) :: value :: sizeBytes :: Nil => (a, value, sizeBytes.toLong, parseHex(value, 160)) - case (a @ "sha256") :: value :: sizeBytes :: Nil => + case (a @ Sha256) :: value :: sizeBytes :: Nil => (a, value, sizeBytes.toLong, parseHex(value, 256)) - case (a @ "sha384") :: value :: sizeBytes :: Nil => + case (a @ Sha384) :: value :: sizeBytes :: Nil => (a, value, sizeBytes.toLong, parseHex(value, 384)) - case (a @ "sha512") :: value :: sizeBytes :: Nil => + case (a @ Sha512) :: value :: sizeBytes :: Nil => (a, value, sizeBytes.toLong, parseHex(value, 512)) case _ => throw IllegalArgumentException(s"unexpected digest: $s") case _ => throw IllegalArgumentException(s"unexpected digest: $s") + private def jvmAlgo(algo: String): String = + algo match + case Md5 => "MD5" + case Sha1 => "SHA-1" + case Sha256 => "SHA-256" + case Sha384 => "SHA-384" + case Sha512 => "SHA-512" + case a => a + private def parseHex(value: String, expectedBytes: Int): Array[Byte] = val bs = Hash.fromHex(value) require(bs.length == expectedBytes / 8, s"expected $expectedBytes, but found a digest $value") diff --git a/util-cache/src/test/scala/sbt/util/DigestTest.scala b/util-cache/src/test/scala/sbt/util/DigestTest.scala index 4843136c8..15c5de7d2 100644 --- a/util-cache/src/test/scala/sbt/util/DigestTest.scala +++ b/util-cache/src/test/scala/sbt/util/DigestTest.scala @@ -1,36 +1,45 @@ package sbt.util +import sbt.io.IO +import sbt.io.syntax.* + object DigestTest extends verify.BasicTestSuite: - test("murmur3") { + test("parse murmur3") { val d = Digest("murmur3-00000000000000000000000000000000/0") val dummy = Digest.dummy(0L) assert(d == dummy) } - test("md5") { - val d = Digest("md5-d41d8cd98f00b204e9800998ecf8427e/0") + test("parse md5") { + val expected = Digest("md5-d41d8cd98f00b204e9800998ecf8427e/0") + testEmptyFile("md5", expected) } - test("sha1") { - val d = Digest("sha1-da39a3ee5e6b4b0d3255bfef95601890afd80709/0") + test("parse sha1") { + val expected = 
Digest("sha1-da39a3ee5e6b4b0d3255bfef95601890afd80709/0") + testEmptyFile("sha1", expected) } test("sha256") { val hashOfNull = Digest.sha256Hash(Array[Byte]()) - val d = Digest("sha256-e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855/0") - assert(hashOfNull == d) + val expected = + Digest("sha256-e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855/0") + assert(hashOfNull == expected) + testEmptyFile("sha256", expected) } - test("sha384") { - val d = Digest( + test("parse sha384") { + val expected = Digest( "sha384-38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b/0" ) + testEmptyFile("sha384", expected) } test("sha512") { - val d = Digest( + val expected = Digest( "sha512-cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e/0" ) + testEmptyFile("sha512", expected) } test("digest composition") { @@ -40,4 +49,12 @@ object DigestTest extends verify.BasicTestSuite: Digest("sha256-66687aadf862bd776c8fc18b8e9f8e20089714856ee233b3902a591d0d5f2925/32") assert(Digest.sha256Hash(dummy1, dummy2) == expected) } + + def testEmptyFile(algo: String, expected: Digest): Unit = + IO.withTemporaryDirectory: tempDir => + val empty = tempDir / "empty.txt" + IO.touch(empty) + val d_sha1 = Digest(algo, empty.toPath()) + assert(d_sha1 == expected) + end DigestTest