Hermetic incremental test

**Problem**
The current implementation of testQuick depends on timestamps,
which likely won't work well with the new consistent analysis store or
with the idea of remote caching.

**Solution**
This is a step towards cached testing: it makes the incrementality hermetic
(no longer dependent on timestamps) by instead calculating an aggregated
SHA-256 digest of the class files involved in each test.
This commit is contained in:
Eugene Yokota 2024-09-04 23:28:48 -04:00
parent 4dd59a0b52
commit 721f202ae5
7 changed files with 247 additions and 141 deletions

View File

@ -101,7 +101,6 @@ import sbt.SlashSyntax0._
import sbt.internal.inc.{
Analysis,
AnalyzingCompiler,
FileAnalysisStore,
ManagedLoggedReporter,
MixedAnalyzingCompiler,
ScalaInstance
@ -140,6 +139,7 @@ import xsbti.compile.{
TastyFiles,
TransactionalManagerType
}
import sbt.internal.IncrementalTest
object Defaults extends BuildCommon {
final val CacheDirectoryName = "cache"
@ -153,18 +153,6 @@ object Defaults extends BuildCommon {
def lock(app: xsbti.AppConfiguration): xsbti.GlobalLock = LibraryManagement.lock(app)
// (Removed in this commit; relocated to BuildDef.)
// Loads the persisted CompileAnalysis referenced by a classpath entry's
// metadata, trying the binary analysis-store format first and falling back
// to the text format.
private[sbt] def extractAnalysis(
metadata: StringAttributeMap,
converter: FileConverter
): Option[CompileAnalysis] =
def asBinary(file: File) = FileAnalysisStore.binary(file).get.asScala
def asText(file: File) = FileAnalysisStore.text(file).get.asScala
for
// `Keys.analysis` metadata holds a VirtualFileRef string pointing at the store file.
ref <- metadata.get(Keys.analysis)
file = converter.toPath(VirtualFileRef.of(ref)).toFile
content <- asBinary(file).orElse(asText(file))
yield content.getAnalysis
private[sbt] def globalDefaults(ss: Seq[Setting[_]]): Seq[Setting[_]] =
Def.defaultSettings(inScope(GlobalScope)(ss))
@ -1322,7 +1310,7 @@ object Defaults extends BuildCommon {
testListeners :== Nil,
testOptions :== Nil,
testResultLogger :== TestResultLogger.Default,
testOnly / testFilter :== (selectedFilter _)
testOnly / testFilter :== (IncrementalTest.selectedFilter _)
)
)
lazy val testTasks: Seq[Setting[_]] =
@ -1341,7 +1329,7 @@ object Defaults extends BuildCommon {
.storeAs(definedTestNames)
.triggeredBy(compile)
.value,
testQuick / testFilter := testQuickFilter.value,
testQuick / testFilter := IncrementalTest.filterTask.value,
executeTests := {
import sbt.TupleSyntax.*
(
@ -1421,7 +1409,11 @@ object Defaults extends BuildCommon {
),
Keys.logLevel.?.value.getOrElse(stateLogLevel),
) +:
new TestStatusReporter(succeededFile((test / streams).value.cacheDirectory)) +:
TestStatusReporter(
IncrementalTest.succeededFile((test / streams).value.cacheDirectory),
(Keys.test / fullClasspath).value,
fileConverter.value,
) +:
(TaskZero / testListeners).value
},
testOptions := Tests.Listeners(testListeners.value) +: (TaskZero / testOptions).value,
@ -1490,46 +1482,6 @@ object Defaults extends BuildCommon {
)
}
// (Removed in this commit.) Timestamp-based testQuick filter: a test is re-run
// when any of its transitive dependencies has a compilation timestamp newer
// than the recorded success time. Superseded by the hermetic, digest-based
// IncrementalTest.filterTask.
def testQuickFilter: Initialize[Task[Seq[String] => Seq[String => Boolean]]] =
Def.task {
val cp = (test / fullClasspath).value
val s = (test / streams).value
val converter = fileConverter.value
// Incremental-compile analyses for every classpath entry that carries one.
val analyses = cp
.flatMap(a => extractAnalysis(a.metadata, converter))
.collect { case analysis: Analysis => analysis }
// Previously recorded className -> last-success-time map.
val succeeded = TestStatus.read(succeededFile(s.cacheDirectory))
// Memoized max compilation timestamp per class.
val stamps = collection.mutable.Map.empty[String, Long]
// Maximum stamp of `dep` across all analyses (None if unknown everywhere).
def stamp(dep: String): Option[Long] =
analyses.flatMap(internalStamp(dep, _, Set.empty)).maxOption
// Recursively computes the newest compilation timestamp reachable from `c`;
// `alreadySeen` breaks dependency cycles.
def internalStamp(c: String, analysis: Analysis, alreadySeen: Set[String]): Option[Long] = {
if (alreadySeen.contains(c)) None
else
def computeAndStoreStamp: Option[Long] = {
import analysis.{ apis, relations }
val internalDeps = relations
.internalClassDeps(c)
.flatMap(internalStamp(_, analysis, alreadySeen + c))
val externalDeps = relations.externalDeps(c).flatMap(stamp)
val classStamps = relations.productClassName.reverse(c).flatMap { pc =>
apis.internal.get(pc).map(_.compilationTimestamp)
}
val maxStamp = (internalDeps ++ externalDeps ++ classStamps).maxOption
// Only successful computations are cached.
maxStamp.foreach(maxStamp => stamps(c) = maxStamp)
maxStamp
}
stamps.get(c).orElse(computeAndStoreStamp)
}
// True when the test never passed, or something it depends on was recompiled
// after the recorded success time.
def noSuccessYet(test: String) = succeeded.get(test) match {
case None => true
case Some(ts) => stamps.synchronized(stamp(test)).exists(_ > ts)
}
args =>
for (filter <- selectedFilter(args))
yield (test: String) => filter(test) && noSuccessYet(test)
}
// Status file location; note the digest-based replacement uses "succeeded_tests.txt".
def succeededFile(dir: File) = dir / "succeeded_tests"
@nowarn
def inputTests(key: InputKey[_]): Initialize[InputTask[Unit]] =
inputTests0.mapReferenced(Def.mapScope(_ in key.key))
@ -1746,21 +1698,6 @@ object Defaults extends BuildCommon {
result
}
// (Removed in this commit; relocated to IncrementalTest.) Translates test
// selection args into name predicates: args starting with "-" are glob
// exclusions, all others are glob inclusions; no include args means
// everything matches, subject to exclusions.
def selectedFilter(args: Seq[String]): Seq[String => Boolean] = {
def matches(nfs: Seq[NameFilter], s: String) = nfs.exists(_.accept(s))
val (excludeArgs, includeArgs) = args.partition(_.startsWith("-"))
val includeFilters = includeArgs map GlobFilter.apply
val excludeFilters = excludeArgs.map(_.substring(1)).map(GlobFilter.apply)
// Matches on excludeArgs rather than excludeFilters; both have the same emptiness.
(includeFilters, excludeArgs) match {
case (Nil, Nil) => Seq(const(true))
case (Nil, _) => Seq((s: String) => !matches(excludeFilters, s))
case _ =>
includeFilters.map(f => (s: String) => (f.accept(s) && !matches(excludeFilters, s)))
}
}
def detectTests: Initialize[Task[Seq[TestDefinition]]] =
Def.task {
Tests.discover(loadedTestFrameworks.value.values.toList, compile.value, streams.value.log)._1
@ -2624,7 +2561,7 @@ object Defaults extends BuildCommon {
val cachedAnalysisMap: Map[VirtualFile, CompileAnalysis] = (
for
attributed <- cp
analysis <- extractAnalysis(attributed.metadata, converter)
analysis <- BuildDef.extractAnalysis(attributed.metadata, converter)
yield (converter.toVirtualFile(attributed.data), analysis)
).toMap
val cachedPerEntryDefinesClassLookup: VirtualFile => DefinesClass =

View File

@ -392,7 +392,7 @@ object RemoteCache {
configuration / packageCache,
(configuration / classDirectory).value,
(configuration / compileAnalysisFile).value,
Defaults.succeededFile((configuration / test / streams).value.cacheDirectory)
IncrementalTest.succeededFile((configuration / test / streams).value.cacheDirectory)
)
}

View File

@ -14,9 +14,10 @@ import Keys.{ organization, thisProject, autoGeneratedProject }
import Def.Setting
// import sbt.ProjectExtra.apply
import sbt.io.Hash
import sbt.internal.util.Attributed
import sbt.internal.inc.ReflectUtilities
import xsbti.FileConverter
import sbt.internal.util.{ Attributed, StringAttributeMap }
import sbt.internal.inc.{ FileAnalysisStore, ReflectUtilities }
import xsbti.{ FileConverter, VirtualFileRef }
import xsbti.compile.CompileAnalysis
trait BuildDef {
def projectDefinitions(@deprecated("unused", "") baseDirectory: File): Seq[Project] = projects
@ -33,7 +34,7 @@ trait BuildDef {
def rootProject: Option[Project] = None
}
private[sbt] object BuildDef {
private[sbt] object BuildDef:
val defaultEmpty: BuildDef = new BuildDef { override def projects = Nil }
val default: BuildDef = new BuildDef {
@ -78,5 +79,19 @@ private[sbt] object BuildDef {
in: Seq[Attributed[_]],
converter: FileConverter
): Seq[xsbti.compile.CompileAnalysis] =
in.flatMap(a => Defaults.extractAnalysis(a.metadata, converter))
}
in.flatMap(a => extractAnalysis(a.metadata, converter))
// Loads the persisted CompileAnalysis referenced by a classpath entry's
// metadata, trying the binary analysis-store format first and falling back
// to the text format. (Moved here from Defaults in this commit.)
private[sbt] def extractAnalysis(
metadata: StringAttributeMap,
converter: FileConverter
): Option[CompileAnalysis] =
// OptionSyntax supplies `.asScala` on the stores' optional results.
import sbt.OptionSyntax.*
def asBinary(file: File) = FileAnalysisStore.binary(file).get.asScala
def asText(file: File) = FileAnalysisStore.text(file).get.asScala
for
// `Keys.analysis` metadata holds a VirtualFileRef string pointing at the store file.
ref <- metadata.get(Keys.analysis)
file = converter.toPath(VirtualFileRef.of(ref)).toFile
content <- asBinary(file).orElse(asText(file))
yield content.getAnalysis
end BuildDef

View File

@ -0,0 +1,183 @@
/*
* sbt
* Copyright 2023, Scala center
* Copyright 2011 - 2022, Lightbend, Inc.
* Copyright 2008 - 2010, Mark Harrah
* Licensed under Apache License 2.0 (see LICENSE)
*/
package sbt
package internal
import java.io.File
import java.util.concurrent.ConcurrentHashMap
import Keys.{ test, compileInputs, fileConverter, fullClasspath, streams }
import sbt.Def.Initialize
import sbt.internal.inc.Analysis
import sbt.internal.util.Attributed
import sbt.internal.util.Types.const
import sbt.io.syntax.*
import sbt.io.{ GlobFilter, IO, NameFilter }
import sbt.protocol.testing.TestResult
import sbt.SlashSyntax0.*
import sbt.util.Digest
import sbt.util.CacheImplicits.given
import scala.collection.concurrent
import scala.collection.mutable
import scala.collection.SortedSet
import xsbti.{ FileConverter, HashedVirtualFileRef, VirtualFileRef }
// Hermetic incremental-test support: replaces the old timestamp-based testQuick
// filtering with digest-based invalidation (see ClassStamper).
object IncrementalTest:
// Task producing the testQuick filter: for the given user args, returns
// predicates that accept a test only if it matches the arg filters AND it has
// not already passed with an identical transitive class-file digest.
def filterTask: Initialize[Task[Seq[String] => Seq[String => Boolean]]] =
Def.task {
val cp = (Keys.test / fullClasspath).value
val s = (Keys.test / streams).value
val converter = fileConverter.value
val stamper = ClassStamper(cp, converter)
// Previously recorded className -> digest of passing tests.
val succeeded = TestStatus.read(succeededFile(s.cacheDirectory))
// A test counts as succeeded only when the stored digest equals the freshly
// computed transitive stamp: any relevant class-file change invalidates it.
def hasSucceeded(className: String): Boolean = succeeded.get(className) match
case None => false
case Some(ts) => ts == stamper.transitiveStamp(className)
args =>
for filter <- selectedFilter(args)
yield (test: String) => filter(test) && !hasSucceeded(test)
}
// Location of the persisted test-status file under the task's cache directory.
def succeededFile(dir: File): File = dir / "succeeded_tests.txt"
// Translates test selection args into name predicates: args starting with "-"
// are glob exclusions, all others are glob inclusions; no include args means
// everything matches, subject to exclusions.
def selectedFilter(args: Seq[String]): Seq[String => Boolean] =
def matches(nfs: Seq[NameFilter], s: String) = nfs.exists(_.accept(s))
val (excludeArgs, includeArgs) = args.partition(_.startsWith("-"))
val includeFilters = includeArgs.map(GlobFilter.apply)
val excludeFilters = excludeArgs.map(_.substring(1)).map(GlobFilter.apply)
// Matches on excludeArgs rather than excludeFilters; both have the same emptiness.
(includeFilters, excludeArgs) match
case (Nil, Nil) => Seq(const(true))
case (Nil, _) => Seq((s: String) => !matches(excludeFilters, s))
case _ =>
includeFilters.map(f => (s: String) => (f.accept(s) && !matches(excludeFilters, s)))
end IncrementalTest
// Assumes exclusive ownership of the file.
// Listener that persists which tests passed, keyed by their class-file digest,
// so testQuick can skip them until the digest changes.
// NOTE(review): the construction site in Defaults passes (file, classpath,
// converter) but this constructor takes (f, digests) — confirm which side of
// the diff is current.
private[sbt] class TestStatusReporter(
f: File,
digests: Map[String, Digest],
) extends TestsListener:
// Lazily loaded previous status: className -> digest at last success.
private lazy val succeeded: concurrent.Map[String, Digest] =
TestStatus.read(f)
def doInit(): Unit = ()
// Starting a group drops any stale success record for it.
def startGroup(name: String): Unit =
succeeded.remove(name)
()
def testEvent(event: TestEvent): Unit = ()
def endGroup(name: String, t: Throwable): Unit = ()
/**
 * If the test has succeeded, record the fact that it has
 * using its unique digest, so we can skip the test later.
 */
def endGroup(name: String, result: TestResult): Unit =
if result == TestResult.Passed then
digests.get(name) match
case Some(ts) => succeeded(name) = ts
// No digest known for this test: record a sentinel zero digest.
case None => succeeded(name) = Digest.zero
else ()
// Flushes the accumulated status map back to the file.
def doComplete(finalResult: TestResult): Unit =
TestStatus.write(succeeded, "Successful Tests", f)
end TestStatusReporter
// Persists the test success records as sorted "name=digest" lines in a
// properties-compatible text file.
private[sbt] object TestStatus:
import java.util.Properties
// Reads the status file into a thread-safe map via java.util.Properties.
def read(f: File): concurrent.Map[String, Digest] =
import scala.jdk.CollectionConverters.*
val props = Properties()
IO.load(props, f)
val result = ConcurrentHashMap[String, Digest]()
props.asScala.iterator.foreach { case (k, v) => result.put(k, Digest(v)) }
result.asScala
// Writes entries sorted by test name for deterministic output (unlike
// Properties.store, which is unordered and adds a timestamp comment).
// NOTE(review): values are written unescaped while `read` parses with
// Properties — assumes digest strings and class names need no properties
// escaping; confirm Digest's string form.
def write(map: collection.Map[String, Digest], label: String, f: File): Unit =
IO.writeLines(
f,
s"# $label" ::
map.toList.sortBy(_._1).map { case (k, v) =>
s"$k=$v"
}
)
end TestStatus
/**
 * ClassStamper provides the `transitiveStamp` method to calculate a unique
 * fingerprint (an aggregated SHA-256 over the class files a class transitively
 * depends on), which is used for runtime test invalidation.
 */
class ClassStamper(
classpath: Seq[Attributed[HashedVirtualFileRef]],
converter: FileConverter,
):
// Memoized per-class digest sets; only non-empty results are cached (see below).
private val stamps = mutable.Map.empty[String, SortedSet[Digest]]
// Memoized digests of individual (non pre-hashed) virtual files.
private val vfStamps = mutable.Map.empty[VirtualFileRef, Digest]
// Incremental-compile analyses extracted from classpath entry metadata.
private lazy val analyses = classpath
.flatMap(a => BuildDef.extractAnalysis(a.metadata, converter))
.collect { case analysis: Analysis => analysis }
/**
 * Given a classpath and a class name, this tries to create a SHA-256 digest.
 * Note: if no analysis knows the class, this hashes an empty digest set —
 * callers compare for equality, so that still yields a stable value.
 */
def transitiveStamp(className: String): Digest =
val digests = SortedSet(analyses.flatMap(internalStamp(className, _, Set.empty)): _*)
Digest.sha256Hash(digests.toSeq: _*)
// Recursively collects the digests reachable from `className` within one
// analysis; `alreadySeen` breaks dependency cycles.
private def internalStamp(
className: String,
analysis: Analysis,
alreadySeen: Set[String],
): SortedSet[Digest] =
if alreadySeen.contains(className) then SortedSet.empty
else
stamps.get(className) match
case Some(xs) => xs
case _ =>
import analysis.relations
// Digests of internal (same-analysis) class dependencies.
val internalDeps = relations
.internalClassDeps(className)
.flatMap: otherCN =>
internalStamp(otherCN, analysis, alreadySeen + className)
// Digests of external dependencies resolved through other analyses.
// NOTE(review): relations.externalDeps(className) is traversed twice
// (here and just below) — confirm both contributions are intended.
val internalJarDeps = relations
.externalDeps(className)
.map: libClassName =>
transitiveStamp(libClassName)
// Digests of the library files providing the external dependencies.
val externalDeps = relations
.externalDeps(className)
.flatMap: libClassName =>
relations.libraryClassName
.reverse(libClassName)
.map(stampVf)
// Digests of this class's own product class files.
val classDigests = relations.productClassName
.reverse(className)
.flatMap: prodClassName =>
relations
.definesClass(prodClassName)
.flatMap: sourceFile =>
relations
.products(sourceFile)
.map(stampVf)
// TODO: substitute the above with
// val classDigests = relations.productClassName
// .reverse(className)
// .flatMap: prodClassName =>
// analysis.apis.internal
// .get(prodClassName)
// .map: analyzed =>
// 0L // analyzed.??? we need a hash here
val xs = SortedSet(
(internalDeps union internalJarDeps union externalDeps union classDigests).toSeq: _*
)
// Empty results are not cached — presumably so cycle-truncated lookups
// don't poison the memo; TODO confirm.
if xs.nonEmpty then stamps(className) = xs
else ()
xs
// Digest of a single virtual file: reuse the precomputed hash when available,
// otherwise hash the file contents (memoized).
def stampVf(vf: VirtualFileRef): Digest =
vf match
case h: HashedVirtualFileRef => Digest(h)
case _ =>
vfStamps.getOrElseUpdate(vf, Digest.sha256Hash(converter.toPath(vf)))
end ClassStamper

View File

@ -8,13 +8,17 @@
package sbt
import sbt.internal.IncrementalTest
object DefaultsTest extends verify.BasicTestSuite {
test("`selectedFilter` should return all tests for an empty list") {
val expected = Map("Test1" -> true, "Test2" -> true)
val filter = List.empty[String]
assert(
expected.map(t => (t._1, Defaults.selectedFilter(filter).exists(fn => fn(t._1)))) == expected
expected.map(t =>
(t._1, IncrementalTest.selectedFilter(filter).exists(fn => fn(t._1)))
) == expected
)
}
@ -22,7 +26,9 @@ object DefaultsTest extends verify.BasicTestSuite {
val expected = Map("Test1" -> true, "Test2" -> false, "Foo" -> false)
val filter = List("Test1", "foo")
assert(
expected.map(t => (t._1, Defaults.selectedFilter(filter).exists(fn => fn(t._1)))) == expected
expected.map(t =>
(t._1, IncrementalTest.selectedFilter(filter).exists(fn => fn(t._1)))
) == expected
)
}
@ -30,7 +36,9 @@ object DefaultsTest extends verify.BasicTestSuite {
val expected = Map("Test1" -> true, "Test2" -> true, "Foo" -> false)
val filter = List("Test*")
assert(
expected.map(t => (t._1, Defaults.selectedFilter(filter).exists(fn => fn(t._1)))) == expected
expected.map(t =>
(t._1, IncrementalTest.selectedFilter(filter).exists(fn => fn(t._1)))
) == expected
)
}
@ -38,7 +46,9 @@ object DefaultsTest extends verify.BasicTestSuite {
val expected = Map("Test1" -> true, "Test2" -> false, "Foo" -> false)
val filter = List("Test*", "-Test2")
assert(
expected.map(t => (t._1, Defaults.selectedFilter(filter).exists(fn => fn(t._1)))) == expected
expected.map(t =>
(t._1, IncrementalTest.selectedFilter(filter).exists(fn => fn(t._1)))
) == expected
)
}
@ -46,7 +56,9 @@ object DefaultsTest extends verify.BasicTestSuite {
val expected = Map("Test1" -> true, "Test2" -> false, "Foo" -> true)
val filter = List("-Test2")
assert(
expected.map(t => (t._1, Defaults.selectedFilter(filter).exists(fn => fn(t._1)))) == expected
expected.map(t =>
(t._1, IncrementalTest.selectedFilter(filter).exists(fn => fn(t._1)))
) == expected
)
}
@ -54,7 +66,9 @@ object DefaultsTest extends verify.BasicTestSuite {
val expected = Map("Test1" -> true, "Test2" -> true, "Foo" -> false)
val filter = List("Test*", "-F*")
assert(
expected.map(t => (t._1, Defaults.selectedFilter(filter).exists(fn => fn(t._1)))) == expected
expected.map(t =>
(t._1, IncrementalTest.selectedFilter(filter).exists(fn => fn(t._1)))
) == expected
)
}
@ -62,7 +76,9 @@ object DefaultsTest extends verify.BasicTestSuite {
val expected = Map("Test1" -> true, "Test2" -> true, "Foo" -> false)
val filter = List("T*1", "T*2", "-F*")
assert(
expected.map(t => (t._1, Defaults.selectedFilter(filter).exists(fn => fn(t._1)))) == expected
expected.map(t =>
(t._1, IncrementalTest.selectedFilter(filter).exists(fn => fn(t._1)))
) == expected
)
}
@ -70,7 +86,9 @@ object DefaultsTest extends verify.BasicTestSuite {
val expected = Map("Test1" -> true, "Test2" -> true, "AAA" -> false, "Foo" -> false)
val filter = List("-A*", "-F*")
assert(
expected.map(t => (t._1, Defaults.selectedFilter(filter).exists(fn => fn(t._1)))) == expected
expected.map(t =>
(t._1, IncrementalTest.selectedFilter(filter).exists(fn => fn(t._1)))
) == expected
)
}
@ -78,7 +96,9 @@ object DefaultsTest extends verify.BasicTestSuite {
val expected = Map("Test1" -> false, "Test2" -> false, "Test3" -> true)
val filter = List("T*", "-T*1", "-T*2")
assert(
expected.map(t => (t._1, Defaults.selectedFilter(filter).exists(fn => fn(t._1)))) == expected
expected.map(t =>
(t._1, IncrementalTest.selectedFilter(filter).exists(fn => fn(t._1)))
) == expected
)
}
}

View File

@ -1,53 +0,0 @@
/*
* sbt
* Copyright 2023, Scala center
* Copyright 2011 - 2022, Lightbend, Inc.
* Copyright 2008 - 2010, Mark Harrah
* Licensed under Apache License 2.0 (see LICENSE)
*/
package sbt
import java.io.File
import sbt.io.IO
import sbt.protocol.testing.TestResult
import java.util.concurrent.ConcurrentHashMap
import scala.collection.concurrent
// Assumes exclusive ownership of the file.
// (Removed in this commit.) Timestamp-based listener: records the wall-clock
// time of each passing test group and persists the map on completion.
// Superseded by the digest-based TestStatusReporter.
private[sbt] class TestStatusReporter(f: File) extends TestsListener {
// Lazily loaded previous status: className -> last success time (millis).
private lazy val succeeded: concurrent.Map[String, Long] = TestStatus.read(f)
def doInit(): Unit = ()
// Starting a group drops any stale success record for it.
def startGroup(name: String): Unit = { succeeded remove name; () }
def testEvent(event: TestEvent): Unit = ()
def endGroup(name: String, t: Throwable): Unit = ()
def endGroup(name: String, result: TestResult): Unit = {
if (result == TestResult.Passed)
succeeded(name) = System.currentTimeMillis
}
// Flushes the accumulated status map back to the file.
def doComplete(finalResult: TestResult): Unit = {
TestStatus.write(succeeded, "Successful Tests", f)
}
}
// (Removed in this commit.) Persists className -> last-success-time records
// via java.util.Properties; both read and write use the properties format,
// unlike the digest-based replacement which hand-writes sorted lines.
private[sbt] object TestStatus {
import java.util.Properties
// Reads the status file into a thread-safe map.
def read(f: File): concurrent.Map[String, Long] = {
import scala.jdk.CollectionConverters.*
val properties = new Properties
IO.load(properties, f)
val result = new ConcurrentHashMap[String, Long]()
properties.asScala.iterator.foreach { case (k, v) => result.put(k, v.toLong) }
result.asScala
}
// Writes the map back with `label` as the properties header comment.
def write(map: collection.Map[String, Long], label: String, f: File): Unit = {
val properties = new Properties
for ((test, lastSuccessTime) <- map)
properties.setProperty(test, lastSuccessTime.toString)
IO.write(properties, label, f)
}
}

View File

@ -26,6 +26,10 @@ object Digest:
def toBytes: Array[Byte] = parse(d)._4
def sizeBytes: Long = parse(d)._3
given digestOrd(using ord: Ordering[String]): Ordering[Digest] with
def compare(x: Digest, y: Digest) =
ord.compare(x, y)
def apply(s: String): Digest =
validateString(s)
s