// Patch overview and review notes. This text precedes the first `diff --git` header and is not part of any hunk.
//
// Purpose: move checksum validation and typed errors into coursier.Files, and adapt two callers.
//
// cli/.../Coursier.scala
//   - The call `files.file(artifact, cachePolicy)` becomes `files.file(artifact)`; in the new `Files.file`
//     signature `cachePolicy` is an implicit parameter.
//   - NOTE(review): this hunk does not show an implicit CachePolicy at the call site - confirm one is in scope.
//
// core/.../Repository.scala
//   - `withDefaultChecksums` keys the checksum URLs by "MD5" / "SHA-1" instead of "md5" / "sha1".
//     `Files.validateChecksum` passes such a key straight to `MessageDigest.getInstance(sumType)`.
//
// files/.../Files.scala
//   - withLocal(artifact): returns the artifact unchanged when `extra` already contains "local", or when its
//     url and every checksum URL start with "file://". Otherwise it stores under extra("local") a copy whose
//     url / checksumUrls are rewritten by `local`: the "file://" prefix is stripped, or the matching `cache`
//     base is replaced by its cache directory. A URL with no matching cache entry hits `???`.
//   - download(artifact, withChecksums): pairs each local path with its remote URL (the artifact itself plus,
//     when `withChecksums`, every checksum type present on both the local and the original artifact), skips
//     pairs where url == "file://" + local path, runs `cachePolicy(locally(file))(remote(file, url))` for the
//     rest and collects the results with `Nondeterminism[Task].gather`.
//   - remote(file, url): creates the parent directories, copies the URL stream to `file` through a buffer of
//     `Files.bufferSize` bytes, and maps any `Exception` to `FileError.DownloadError(e.getMessage)`.
//     The FIXME carried by the patch already flags that an interrupted transfer may leave a partial file.
//   - validateChecksum(artifact, sumType): takes the first line of the checksum file up to the first space
//     character, digests the local artifact file via `Files.withContent`, formats the digest with `%040x`,
//     and returns `WrongChecksum` on mismatch or `ChecksumNoFound` when no URL is registered for `sumType`.
//     NOTE(review): `%040x` pads to 40 hex digits. That width corresponds to a 20-byte digest; a 16-byte
//       digest (presumably the "MD5" case) would be left-padded with zeros and would not equal a 32-character
//       sum file - confirm before validating with "MD5".
//     NOTE(review): `scala.io.Source.fromFile(sumFile)` has no visible close, and `sum == calculatedSum` is a
//       case-sensitive comparison against lowercase `%x` output - confirm the sum files are lowercase hex.
//   - file(artifact, checksum = Some("SHA-1")): runs `download(artifact)`, keeps only the first result
//     (`results.head`), then validates the requested checksum when `checksum` is defined.
//     NOTE(review): `results.head` throws if `download` yields an empty Seq (every pair filtered out) -
//       confirm that cannot happen.
//   - Files.withContent(is, f): reads `is` in 16384-byte chunks and calls `f(data, nRead)` for each chunk until
//     `read` returns -1. The caller is responsible for closing the stream (validateChecksum does so).
//   - FileError: new sealed error type with DownloadError, NotFound, Locked, ChecksumNoFound, WrongChecksum.
//     NOTE(review): `Locked` is declared but never constructed anywhere in this patch.
//
diff --git a/cli/src/main/scala/coursier/cli/Coursier.scala b/cli/src/main/scala/coursier/cli/Coursier.scala index 18a9b8199..e6db55133 100644 --- a/cli/src/main/scala/coursier/cli/Coursier.scala +++ b/cli/src/main/scala/coursier/cli/Coursier.scala @@ -256,7 +256,7 @@ case class Coursier( files0 } - val tasks = artifacts.map(artifact => files.file(artifact, cachePolicy).run.map(artifact.->)) + val tasks = artifacts.map(artifact => files.file(artifact).run.map(artifact.->)) def printTask = Task{ if (verbose0 >= 0 && artifacts.nonEmpty) println(s"Found ${artifacts.length} artifacts") diff --git a/core/src/main/scala/coursier/core/Repository.scala b/core/src/main/scala/coursier/core/Repository.scala index a1263f75c..eb8ac7350 100644 --- a/core/src/main/scala/coursier/core/Repository.scala +++ b/core/src/main/scala/coursier/core/Repository.scala @@ -76,8 +76,8 @@ object Repository { implicit class ArtifactExtensions(val underlying: Artifact) extends AnyVal { def withDefaultChecksums: Artifact = underlying.copy(checksumUrls = underlying.checksumUrls ++ Seq( - "md5" -> (underlying.url + ".md5"), - "sha1" -> (underlying.url + ".sha1") + "MD5" -> (underlying.url + ".md5"), + "SHA-1" -> (underlying.url + ".sha1") )) def withDefaultSignature: Artifact = underlying.copy(extra = underlying.extra ++ Seq( diff --git a/files/src/main/scala/coursier/Files.scala b/files/src/main/scala/coursier/Files.scala index 8a6613c72..7c09f4533 100644 --- a/files/src/main/scala/coursier/Files.scala +++ b/files/src/main/scala/coursier/Files.scala @@ -1,10 +1,11 @@ package coursier -import java.net.{ URI, URL } +import java.net.URL +import java.security.MessageDigest import java.util.concurrent.{ Executors, ExecutorService } import scala.annotation.tailrec -import scalaz.{ -\/, \/-, \/, EitherT } +import scalaz._ import scalaz.concurrent.{ Task, Strategy } import java.io._ @@ -19,85 +20,182 @@ case class Files( lazy val defaultPool = Executors.newFixedThreadPool(concurrentDownloadCount, 
Strategy.DefaultDaemonThreadFactory) - def file( + def withLocal(artifact: Artifact): Artifact = { + val isLocal = + artifact.url.startsWith("file://") && + artifact.checksumUrls.values.forall(_.startsWith("file://")) + + def local(url: String) = + if (url.startsWith("file://")) + url.stripPrefix("file://") + else + cache.find{case (base, _) => url.startsWith(base)} match { + case None => ??? + case Some((base, cacheDir)) => + cacheDir + "/" + url.stripPrefix(base) + } + + if (artifact.extra.contains("local") || isLocal) + artifact + else + artifact.copy(extra = artifact.extra + ("local" -> + artifact.copy( + url = local(artifact.url), + checksumUrls = artifact.checksumUrls + .mapValues(local) + .toVector + .toMap, + extra = Map.empty + ) + )) + } + + def download( artifact: Artifact, - cachePolicy: CachePolicy + withChecksums: Boolean = true + )(implicit + cachePolicy: CachePolicy, + pool: ExecutorService = defaultPool + ): Task[Seq[((File, String), FileError \/ Unit)]] = { + val artifact0 = withLocal(artifact) + .extra + .getOrElse("local", artifact) + + val pairs = + Seq(artifact0.url -> artifact.url) ++ { + if (withChecksums) + (artifact0.checksumUrls.keySet intersect artifact.checksumUrls.keySet) + .toList + .map(sumType => artifact0.checksumUrls(sumType) -> artifact.checksumUrls(sumType)) + else + Nil + } + + + def locally(file: File) = + Task { + if (file.exists()) { + logger.foreach(_.foundLocally(file)) + \/-(file) + } else + -\/(FileError.NotFound(file.toString): FileError) + } + + // FIXME Things can go wrong here and are not properly handled, + // e.g. what if the connection gets closed during the transfer? + // (partial file on disk?) 
+ def remote(file: File, url: String) = + Task { + try { + file.getParentFile.mkdirs() + + logger.foreach(_.downloadingArtifact(url)) + + val url0 = new URL(url) + val b = Array.fill[Byte](Files.bufferSize)(0) + val in = new BufferedInputStream(url0.openStream(), Files.bufferSize) + + try { + val out = new FileOutputStream(file) + try { + @tailrec + def helper(): Unit = { + val read = in.read(b) + if (read >= 0) { + out.write(b, 0, read) + helper() + } + } + + helper() + } finally out.close() + } finally in.close() + + logger.foreach(_.downloadedArtifact(url, success = true)) + \/-(file) + } + catch { case e: Exception => + logger.foreach(_.downloadedArtifact(url, success = false)) + -\/(FileError.DownloadError(e.getMessage)) + } + } + + + val tasks = + for ((f, url) <- pairs if url != ("file://" + f)) yield { + val file = new File(f) + cachePolicy(locally(file))(remote(file, url)) + .map(e => (file, url) -> e.map(_ => ())) + } + + Nondeterminism[Task].gather(tasks) + } + + def validateChecksum( + artifact: Artifact, + sumType: String )(implicit pool: ExecutorService = defaultPool - ): EitherT[Task, String, File] = { + ): Task[FileError \/ Unit] = { + val artifact0 = withLocal(artifact) + .extra + .getOrElse("local", artifact) - if (artifact.url.startsWith("file:///")) { - val f = new File(new URI(artifact.url) .getPath) - EitherT(Task.now( - if (f.exists()) { - logger.foreach(_.foundLocally(f)) - \/-(f) - } else -\/("Not found") - )) - } else { - cache.find{case (base, _) => artifact.url.startsWith(base)} match { - case None => ??? 
- case Some((base, cacheDir)) => - val file = new File(cacheDir, artifact.url.stripPrefix(base)) - def locally = { - Task { - if (file.exists()) { - logger.foreach(_.foundLocally(file)) - \/-(file) - } - else -\/("Not found in cache") - } - } + artifact0.checksumUrls.get(sumType) match { + case Some(sumFile) => + Task { + val sum = scala.io.Source.fromFile(sumFile) + .getLines() + .toStream + .headOption + .mkString + .takeWhile(!_.isSpaceChar) - def remote = { - // FIXME A lot of things can go wrong here and are not properly handled: - // - checksums should be validated - // - what if the connection gets closed during the transfer (partial file on disk)? - // - what if someone is trying to write this file at the same time? (no locking of any kind yet) - // - ... + val md = MessageDigest.getInstance(sumType) + val is = new FileInputStream(new File(artifact0.url)) + try Files.withContent(is, md.update(_, 0, _)) + finally is.close() - Task { - try { - file.getParentFile.mkdirs() + val digest = md.digest() + val calculatedSum = f"${BigInt(1, digest)}%040x" - logger.foreach(_.downloadingArtifact(artifact.url)) + if (sum == calculatedSum) + \/-(()) + else + -\/(FileError.WrongChecksum(sumType, calculatedSum, sum, artifact0.url, sumFile)) + } - val url = new URL(artifact.url) - val b = Array.fill[Byte](Files.bufferSize)(0) - val in = new BufferedInputStream(url.openStream(), Files.bufferSize) - - try { - val out = new FileOutputStream(file) - try { - @tailrec - def helper(): Unit = { - val read = in.read(b) - if (read >= 0) { - out.write(b, 0, read) - helper() - } - } - - helper() - } finally out.close() - } finally in.close() - - logger.foreach(_.downloadedArtifact(artifact.url, success = true)) - \/-(file) - } - catch { case e: Exception => - logger.foreach(_.downloadedArtifact(artifact.url, success = false)) - -\/(e.getMessage) - } - } - } - - EitherT(cachePolicy(locally)(remote)) - } + case None => + Task.now(-\/(FileError.ChecksumNoFound(sumType, artifact0.url))) } 
} + def file( + artifact: Artifact, + checksum: Option[String] = Some("SHA-1") + )(implicit + cachePolicy: CachePolicy, + pool: ExecutorService = defaultPool + ): EitherT[Task, FileError, File] = + EitherT{ + val res = + download(artifact) + .map(results => + results.head._2.map(_ => results.head._1._1) + ) + + checksum.fold(res) { sumType => + res + .flatMap{ + case err @ -\/(_) => Task.now(err) + case \/-(f) => + validateChecksum(artifact, sumType) + .map(_.map(_ => f)) + } + } + } + } object Files { @@ -139,4 +237,26 @@ object Files { } .leftMap(_.getMessage) } + def withContent(is: InputStream, f: (Array[Byte], Int) => Unit): Unit = { + val data = Array.ofDim[Byte](16384) + + var nRead = is.read(data, 0, data.length) + while (nRead != -1) { + f(data, nRead) + nRead = is.read(data, 0, data.length) + } + } + +} + +sealed trait FileError + +object FileError { + + case class DownloadError(message: String) extends FileError + case class NotFound(file: String) extends FileError + case class Locked(file: String) extends FileError + case class ChecksumNoFound(sumType: String, file: String) extends FileError + case class WrongChecksum(sumType: String, got: String, expected: String, file: String, sumFile: String) extends FileError + }