diff --git a/util/io/FileUtilities.scala b/util/io/FileUtilities.scala index db879d5ad..0f4064be3 100644 --- a/util/io/FileUtilities.scala +++ b/util/io/FileUtilities.scala @@ -407,4 +407,36 @@ object FileUtilities /** Splits a String around path separator characters. */ def pathSplit(s: String) = PathSeparatorPattern.split(s) + + /** Move the provided files to a temporary location. + * If 'f' returns normally, delete the files. + * If 'f' throws an Exception, return the files to their original location.*/ + def stash[T](files: Set[File])(f: => T): T = + withTemporaryDirectory { dir => + val stashed = stashLocations(dir, files.toArray) + move(stashed) + + try { f } catch { case e: Exception => + try { move(stashed.map(_.swap)); throw e } + catch { case _: Exception => throw e } + } + } + + private def stashLocations(dir: File, files: Array[File]) = + for( (file, index) <- files.zipWithIndex) yield + (file, new File(dir, index.toHexString)) + + def move(files: Iterable[(File, File)]): Unit = + files.foreach(Function.tupled(move)) + + def move(a: File, b: File): Unit = + { + if(b.exists) + delete(b) + if(!a.renameTo(b)) + { + copyFile(a, b) + delete(a) + } + } } diff --git a/util/io/src/test/scala/StashSpec.scala b/util/io/src/test/scala/StashSpec.scala new file mode 100644 index 000000000..df76165d2 --- /dev/null +++ b/util/io/src/test/scala/StashSpec.scala @@ -0,0 +1,81 @@ +/* sbt -- Simple Build Tool + * Copyright 2010 Mark Harrah */ + +package xsbt + +import org.specs._ + +import FileUtilities._ +import java.io.File +import Function.tupled + +object CheckStash extends Specification +{ + "stash" should { + "handle empty files" in { + stash(Set()) { } + } + + "move files during execution" in { + WithFiles(TestFiles : _*) ( checkMove ) + } + + "restore files on exceptions but not errors" in { + WithFiles(TestFiles : _*) ( checkRestore ) + } + } + + def checkRestore(seq: Seq[File]) + { + allCorrect(seq) + + stash0(seq, throw new TestRuntimeException) must beFalse + allCorrect(seq) + + stash0(seq, throw new TestException) must beFalse + allCorrect(seq) + + stash0(seq, throw new TestError) must beFalse + noneExist(seq) + } + def checkMove(seq: Seq[File]) + { + allCorrect(seq) + stash0(seq, ()) must beTrue + noneExist(seq) + } + def stash0(seq: Seq[File], post: => Unit): Boolean = + try + { + stash(Set() ++ seq) { + noneExist(seq) + post + } + true + } + catch { + case _: TestError | _: TestException | _: TestRuntimeException => false + } + + def allCorrect(s: Seq[File]) = (s.toList zip TestFiles.toList).forall(tupled(correct)) + def correct(check: File, ref: (File, String)) = + { + check.exists must beTrue + read(check) must beEqual(ref._2) + } + def noneExist(s: Seq[File]) = s.forall(!_.exists) must beTrue + + lazy val TestFiles = + Seq( + "a/b/c" -> "content1", + "a/b/e" -> "content1", + "c" -> "", + "e/g" -> "asdf", + "a/g/c" -> "other" + ) map { + case (f, c) => (new File(f), c) + } +} +class TestError extends Error +class TestRuntimeException extends RuntimeException +class TestException extends Exception \ No newline at end of file