From 0ee3585b6d27aa5898cf3a25028d09773580a1c6 Mon Sep 17 00:00:00 2001 From: jvican Date: Tue, 2 May 2017 00:25:59 +0200 Subject: [PATCH] Add batch mode execution to scripted For that, we: * Change the existing infrastructure to recycle as much code as possible. * Use `BatchScriptRunner` since `ScriptRunner` is too restrictive to programmatically control the underlying sbt servers. * Unify `TestRunner` to the more general way for both batch and non-batch modes. --- .../scala/sbt/test/BatchScriptRunner.scala | 58 +++++ .../main/scala/sbt/test/ScriptedTests.scala | 219 +++++++++++++----- 2 files changed, 225 insertions(+), 52 deletions(-) create mode 100644 scripted/sbt/src/main/scala/sbt/test/BatchScriptRunner.scala diff --git a/scripted/sbt/src/main/scala/sbt/test/BatchScriptRunner.scala b/scripted/sbt/src/main/scala/sbt/test/BatchScriptRunner.scala new file mode 100644 index 000000000..d181a2ac1 --- /dev/null +++ b/scripted/sbt/src/main/scala/sbt/test/BatchScriptRunner.scala @@ -0,0 +1,58 @@ +package sbt +package test + +import sbt.internal.scripted._ +import sbt.test.BatchScriptRunner.States + +/** Defines an alternative script runner that allows batch execution. */ +private[sbt] class BatchScriptRunner extends ScriptRunner { + + /** Defines a method to run batched execution. + * + * @param statements The list of handlers and statements. + * @param states The states of the runner. In case it's empty, inherited apply is called. + */ + def apply(statements: List[(StatementHandler, Statement)], states: States): Unit = { + if (states.isEmpty) super.apply(statements) + else statements.foreach(st => processStatement(st._1, st._2, states)) + } + + def initStates(states: States, handlers: Seq[StatementHandler]): Unit = + handlers.foreach(handler => states(handler) = handler.initialState) + + def cleanUpHandlers(handlers: Seq[StatementHandler], states: States): Unit = { + for (handler <- handlers; state <- states.get(handler)) { + try handler.finish(state.asInstanceOf[handler.State]) + catch { case _: Exception => () } + } + } + + def processStatement(handler: StatementHandler, statement: Statement, states: States): Unit = { + val state = states(handler).asInstanceOf[handler.State] + val nextState = + try { Right(handler(statement.command, statement.arguments, state)) } catch { + case e: Exception => Left(e) + } + nextState match { + case Left(err) => + if (statement.successExpected) { + err match { + case t: TestFailed => + throw new TestException(statement, "Command failed: " + t.getMessage, null) + case _ => throw new TestException(statement, "Command failed", err) + } + } else + () + case Right(s) => + if (statement.successExpected) + states(handler) = s + else + throw new TestException(statement, "Command succeeded but failure was expected", null) + } + } +} + +private[sbt] object BatchScriptRunner { + import scala.collection.mutable + type States = mutable.HashMap[StatementHandler, Any] +} diff --git a/scripted/sbt/src/main/scala/sbt/test/ScriptedTests.scala b/scripted/sbt/src/main/scala/sbt/test/ScriptedTests.scala index f6b5e7f43..e32258347 100644 --- a/scripted/sbt/src/main/scala/sbt/test/ScriptedTests.scala +++ b/scripted/sbt/src/main/scala/sbt/test/ScriptedTests.scala @@ -8,67 +8,184 @@ package test import java.io.File import scala.util.control.NonFatal -import sbt.internal.scripted.{ - CommentHandler, - FileCommands, - ScriptRunner, - TestException, - TestScriptParser -} -import sbt.io.{ DirectoryFilter, HiddenFileFilter } +import sbt.internal.scripted._ +import sbt.io.{ DirectoryFilter, HiddenFileFilter, IO } import sbt.io.IO.wrapNull +import sbt.io.FileFilter._ import sbt.internal.io.Resources import sbt.internal.util.{ BufferedLogger, ConsoleLogger, FullLogger } import sbt.util.{ AbstractLogger, Logger } +import scala.collection.mutable import scala.collection.parallel.mutable.ParSeq final class ScriptedTests(resourceBaseDirectory: File, bufferLog: Boolean, launcher: File, launchOpts: Seq[String]) { + import sbt.io.syntax._ import ScriptedTests._ private val testResources = new Resources(resourceBaseDirectory) val ScriptFilename = "test" val PendingScriptFilename = "pending" - def scriptedTest(group: String, name: String, log: xsbti.Logger): Seq[() => Option[String]] = + def scriptedTest(group: String, name: String, log: xsbti.Logger): Seq[TestRunner] = scriptedTest(group, name, Logger.xlog2Log(log)) - def scriptedTest(group: String, name: String, log: Logger): Seq[() => Option[String]] = - scriptedTest(group, name, emptyCallback, log) + def scriptedTest(group: String, name: String, log: Logger): Seq[TestRunner] = + singleScriptedTest(group, name, emptyCallback, log) /** Returns a sequence of test runners that have to be applied in the call site. */ - def scriptedTest(group: String, - name: String, - prescripted: File => Unit, - log: Logger): Seq[TestRunner] = { - import sbt.io.syntax._ + def singleScriptedTest(group: String, + name: String, + prescripted: File => Unit, + log: Logger): Seq[TestRunner] = { + + // Test group and names may be file filters (like '*') for (groupDir <- (resourceBaseDirectory * group).get; nme <- (groupDir * name).get) yield { val g = groupDir.getName val n = nme.getName - val testLabel = s"$g / $n" + val label = s"$g / $n" () => { - println("Running " + testLabel) - testResources.readWriteResourceDirectory(g, n) { testDirectory => - val disabled = new File(testDirectory, "disabled").isFile - if (disabled) { - log.info("D " + testLabel + " [DISABLED]") - None - } else scriptedTest(testLabel, testDirectory, prescripted, log) + println(s"Running $label") + val result = testResources.readWriteResourceDirectory(g, n) { testDirectory => + val buffer = new BufferedLogger(new FullLogger(log)) + val singleTestRunner = () => { + val handlers = createScriptedHandlers(testDirectory, buffer) + val runner = new BatchScriptRunner + val states = new mutable.HashMap[StatementHandler, Any]() + commonRunTest(label, testDirectory, prescripted, handlers, runner, states, buffer) + } + runOrHandleDisabled(label, testDirectory, singleTestRunner, buffer) } + Seq(result) } } } + private def createScriptedHandlers(testDir: File, + buffered: Logger): Map[Char, StatementHandler] = { + val fileHandler = new FileCommands(testDir) + val sbtHandler = new SbtHandler(testDir, launcher, buffered, launchOpts) + Map('$' -> fileHandler, '>' -> sbtHandler, '#' -> CommentHandler) + } + + /** Returns a sequence of test runners that have to be applied in the call site. */ + def batchScriptedRunner( + testGroupAndNames: Seq[(String, String)], + prescripted: File => Unit, + log: Logger + ): Seq[TestRunner] = { + // Test group and names may be file filters (like '*') + val groupAndNameDirs = { + for { + (group, name) <- testGroupAndNames + groupDir <- resourceBaseDirectory.*(group).get + testDir <- groupDir.*(name).get + } yield (groupDir, testDir) + } + + val labelsAndDirs = groupAndNameDirs.map { + case (groupDir, nameDir) => + val groupName = groupDir.getName + val testName = nameDir.getName + val testLabel = s"$groupName / $testName" + val testDirectory = testResources.readOnlyResourceDirectory(groupName, testName) + testLabel -> testDirectory + } + + val batchSeed = labelsAndDirs.size / 4 + val batchSize = if (batchSeed == 0) labelsAndDirs.size else batchSeed + Seq(labelsAndDirs).map { batch => () => + IO.withTemporaryDirectory(runBatchedTests(batch, _, prescripted, log)) + }.toList + } + + /** Defines the batch execution of scripted tests. + * + * Scripted tests are run one after the other one recycling the handlers, under + * the assumption that handlers do not produce side effects that can change scripted + * tests' behaviours. + * + * In batch mode, the test runner performs these operations between executions: + * + * 1. Delete previous test files in the common test directory. + * 2. Copy over next test files to the common test directory. + * 3. Reload the sbt handler. + * + * @param groupedTests The labels and directories of the tests to run. + * @param tempTestDir The common test directory. + * @param preHook The hook to run before scripted execution. + * @param log The logger. + */ + private def runBatchedTests( + groupedTests: Seq[(String, File)], + tempTestDir: File, + preHook: File => Unit, + log: Logger + ): Seq[Option[String]] = { + + val runner = new BatchScriptRunner + val buffer = new BufferedLogger(new FullLogger(log)) + val handlers = createScriptedHandlers(tempTestDir, buffer) + val states = new BatchScriptRunner.States + val seqHandlers = handlers.values.toList + runner.initStates(states, seqHandlers) + + def runBatchTests = { + groupedTests.map { + case (label, originalDir) => + println(s"Running $label") + // Copy test's contents and reload the sbt instance to pick them up + IO.copyDirectory(originalDir, tempTestDir) + + // Reload and initialize (to reload contents of .sbtrc files) + val sbtHandler = handlers.getOrElse('>', sys.error("Missing sbt handler.")) + val statement = + Statement(";reload;initialize", Nil, successExpected = true, line = -1) + runner.processStatement(sbtHandler.asInstanceOf[SbtHandler], statement, states) + + val runTest = + () => commonRunTest(label, tempTestDir, preHook, handlers, runner, states, buffer) + val result = runOrHandleDisabled(label, tempTestDir, runTest, buffer) + + // Delete test's files and clear buffer if successful + IO.delete(tempTestDir.*("*" -- "global").get) + result + } + } + + try runBatchTests + finally runner.cleanUpHandlers(seqHandlers, states) + } + + private def runOrHandleDisabled( + label: String, + testDirectory: File, + runTest: () => Option[String], + log: Logger + ): Option[String] = { + val existsDisabled = new File(testDirectory, "disabled").isFile + if (!existsDisabled) runTest() + else { + log.info(s"D $label [DISABLED]") + None + } + } + private val PendingLabel = "[PENDING]" - private def scriptedTest(label: String, - testDirectory: File, - preScriptedHook: File => Unit, - log: Logger): Option[String] = { - val buffered = new BufferedLogger(new FullLogger(log)) - if (bufferLog) buffered.record() + + private def commonRunTest( + label: String, + testDirectory: File, + preScriptedHook: File => Unit, + createHandlers: Map[Char, StatementHandler], + runner: BatchScriptRunner, + states: BatchScriptRunner.States, + log: BufferedLogger + ): Option[String] = { + if (bufferLog) log.record() val (file, pending) = { val normal = new File(testDirectory, ScriptFilename) @@ -78,13 +195,13 @@ final class ScriptedTests(resourceBaseDirectory: File, val pendingMark = if (pending) PendingLabel else "" def testFailed(t: Throwable): Option[String] = { - if (pending) buffered.clear() else buffered.stop() - buffered.error(s"x $label $pendingMark") + if (pending) log.clear() else log.stop() + log.error(s"x $label $pendingMark") if (!NonFatal(t)) throw t // We make sure fatal errors are rethrown if (t.isInstanceOf[TestException]) { t.getCause match { case null | _: java.net.SocketException => - buffered.error(" Cause of test exception: " + t.getMessage) + log.error(" Cause of test exception: " + t.getMessage) case _ => t.printStackTrace() } } @@ -92,20 +209,18 @@ final class ScriptedTests(resourceBaseDirectory: File, } import scala.util.control.Exception.catching - catching(classOf[TestException]).withApply(testFailed).andFinally(buffered.clear).apply { + catching(classOf[TestException]).withApply(testFailed).andFinally(log.clear).apply { preScriptedHook(testDirectory) - val run = new ScriptRunner - val fileHandler = new FileCommands(testDirectory) - val sbtHandler = new SbtHandler(testDirectory, launcher, buffered, launchOpts) - val handlers = Map('$' -> fileHandler, '>' -> sbtHandler, '#' -> CommentHandler) + val handlers = createHandlers val parser = new TestScriptParser(handlers) - run(parser.parse(file)) + val handlersAndStatements = parser.parse(file) + runner.apply(handlersAndStatements, states) // Handle successful tests - buffered.info(s"+ $label $pendingMark") + log.info(s"+ $label $pendingMark") if (pending) { - buffered.clear() - buffered.error(" Pending test passed. Mark as passing to remove this failure.") + log.clear() + log.error(" Pending test passed. Mark as passing to remove this failure.") Some(label) } else None } @@ -114,8 +229,8 @@ final class ScriptedTests(resourceBaseDirectory: File, object ScriptedTests extends ScriptedRunner { - /** Represents the function that runs the scripted tests. */ - type TestRunner = () => Option[String] + /** Represents the function that runs the scripted tests, both in single or batch mode. */ + type TestRunner = () => Seq[Option[String]] val emptyCallback: File => Unit = _ => () def main(args: Array[String]): Unit = { @@ -155,7 +270,7 @@ class ScriptedRunner { val runner = new ScriptedTests(resourceBaseDirectory, bufferLog, bootProperties, launchOpts) val allTests = get(tests, resourceBaseDirectory, logger) flatMap { case ScriptedTest(group, name) => - runner.scriptedTest(group, name, prescripted, logger) + runner.singleScriptedTest(group, name, prescripted, logger) } runAll(allTests) } @@ -187,21 +302,21 @@ class ScriptedRunner { prescripted: File => Unit ): Unit = { val runner = new ScriptedTests(resourceBaseDirectory, bufferLog, bootProperties, launchOpts) - val scriptedTests = get(tests, resourceBaseDirectory, logger) - val scriptedTestRunners = scriptedTests - .flatMap(t => runner.scriptedTest(t.group, t.name, prescripted, logger)) - runAllInParallel(scriptedTestRunners.toParArray) + // The scripted tests mapped to the inputs that the user wrote after `scripted`. + val scriptedTests = get(tests, resourceBaseDirectory, logger).map(st => (st.group, st.name)) + val scriptedRunners = runner.batchScriptedRunner(scriptedTests, prescripted, logger) + runAll(scriptedRunners) } private def reportErrors(errors: Seq[String]): Unit = if (errors.nonEmpty) sys.error(errors.mkString("Failed tests:\n\t", "\n\t", "\n")) else () - def runAll(tests: Seq[ScriptedTests.TestRunner]): Unit = - reportErrors(tests.flatMap(test => test.apply().toSeq)) + def runAll(toRun: Seq[ScriptedTests.TestRunner]): Unit = + reportErrors(toRun.flatMap(test => test.apply().flatten.toSeq)) // We cannot reuse `runAll` because parallel collections != collections def runAllInParallel(tests: ParSeq[ScriptedTests.TestRunner]): Unit = - reportErrors(tests.flatMap(test => test.apply().toSeq).toList) + reportErrors(tests.flatMap(test => test.apply().flatten.toSeq).toList) def get(tests: Seq[String], baseDirectory: File, log: Logger): Seq[ScriptedTest] = if (tests.isEmpty) listTests(baseDirectory, log) else parseTests(tests)