diff --git a/cli/src/main/scala/coursier/cli/Helper.scala b/cli/src/main/scala/coursier/cli/Helper.scala index 8ec1f43ef..23d5cdf12 100644 --- a/cli/src/main/scala/coursier/cli/Helper.scala +++ b/cli/src/main/scala/coursier/cli/Helper.scala @@ -2,11 +2,12 @@ package coursier package cli import java.io.{ OutputStreamWriter, File } +import java.util.concurrent.Executors import coursier.ivy.IvyRepository import scalaz.{ \/-, -\/ } -import scalaz.concurrent.Task +import scalaz.concurrent.{ Task, Strategy } object Helper { def fileRepr(f: File) = f.toString @@ -57,16 +58,14 @@ class Helper( sys.exit(255) } - val files = - Files( - Seq( - "http://" -> new File(new File(cacheOptions.cache), "http"), - "https://" -> new File(new File(cacheOptions.cache), "https") - ), - () => ???, - concurrentDownloadCount = parallel + val caches = + Seq( + "http://" -> new File(new File(cacheOptions.cache), "http"), + "https://" -> new File(new File(cacheOptions.cache), "https") ) + val pool = Executors.newFixedThreadPool(parallel, Strategy.DefaultDaemonThreadFactory) + val central = MavenRepository("https://repo1.maven.org/maven2/") val ivy2Local = MavenRepository( new File(sys.props("user.home") + "/.ivy2/local/").toURI.toString, @@ -218,7 +217,7 @@ class Helper( logger.foreach(_.init()) val fetchs = cachePolicies.map(p => - files.fetch(logger = logger)(cachePolicy = p) + Files.fetch(caches, p, logger = logger, pool = pool) ) val fetchQuiet = coursier.Fetch( repositories, @@ -345,8 +344,8 @@ class Helper( None logger.foreach(_.init()) val tasks = artifacts.map(artifact => - (files.file(artifact, logger = logger)(cachePolicy = cachePolicies.head) /: cachePolicies.tail)( - _ orElse files.file(artifact, logger = logger)(_) + (Files.file(artifact, caches, cachePolicies.head, logger = logger, pool = pool) /: cachePolicies.tail)( + _ orElse Files.file(artifact, caches, _, logger = logger, pool = pool) ).run.map(artifact.->) ) def printTask = Task { diff --git a/files/src/main/scala/coursier/Files.scala b/files/src/main/scala/coursier/Files.scala index d6e5c9eb8..1d9f2f467 100644 --- a/files/src/main/scala/coursier/Files.scala +++ b/files/src/main/scala/coursier/Files.scala @@ -11,34 +11,28 @@ import scalaz.concurrent.{ Task, Strategy } import java.io.{ Serializable => _, _ } -case class Files( - cache: Seq[(String, File)], - tmp: () => File, - concurrentDownloadCount: Int = Files.defaultConcurrentDownloadCount -) { +object Files { - import Files.urlLocks - - lazy val defaultPool = - Executors.newFixedThreadPool(concurrentDownloadCount, Strategy.DefaultDaemonThreadFactory) - - def withLocal(artifact: Artifact): Artifact = { + def withLocal(artifact: Artifact, cache: Seq[(String, File)]): Artifact = { def local(url: String) = if (url.startsWith("file:///")) url.stripPrefix("file://") else if (url.startsWith("file:/")) url.stripPrefix("file:") - else - cache.find { case (base, _) => url.startsWith(base) } match { - case None => - // FIXME Means we were handed an artifact from repositories other than the known ones - println(cache.mkString("\n")) - println(url) - ??? - case Some((base, cacheDir)) => + else { + val localPathOpt = cache.collectFirst { + case (base, cacheDir) if url.startsWith(base) => cacheDir + "/" + url.stripPrefix(base) } + localPathOpt.getOrElse { + // FIXME Means we were handed an artifact from repositories other than the known ones + println(cache.mkString("\n")) + println(url) + ??? + } + } + if (artifact.extra.contains("local")) artifact else @@ -56,13 +50,16 @@ case class Files( def download( artifact: Artifact, + cache: Seq[(String, File)], checksums: Set[String], - logger: Option[Files.Logger] = None - )(implicit cachePolicy: CachePolicy, - pool: ExecutorService = defaultPool + pool: ExecutorService, + logger: Option[Files.Logger] = None ): Task[Seq[((File, String), FileError \/ Unit)]] = { - val artifact0 = withLocal(artifact) + + implicit val pool0 = pool + + val artifact0 = withLocal(artifact, cache) .extra .getOrElse("local", artifact) @@ -271,11 +268,14 @@ case class Files( def validateChecksum( artifact: Artifact, - sumType: String - )(implicit - pool: ExecutorService = defaultPool + sumType: String, + cache: Seq[(String, File)], + pool: ExecutorService ): EitherT[Task, FileError, Unit] = { - val artifact0 = withLocal(artifact) + + implicit val pool0 = pool + + val artifact0 = withLocal(artifact, cache) .extra .getOrElse("local", artifact) @@ -330,18 +330,24 @@ case class Files( def file( artifact: Artifact, - checksums: Seq[Option[String]] = Seq(Some("SHA-1")), - logger: Option[Files.Logger] = None - )(implicit + cache: Seq[(String, File)], cachePolicy: CachePolicy, - pool: ExecutorService = defaultPool + checksums: Seq[Option[String]] = Seq(Some("SHA-1")), + logger: Option[Files.Logger] = None, + pool: ExecutorService = Files.defaultPool ): EitherT[Task, FileError, File] = { + + implicit val pool0 = pool + val checksums0 = if (checksums.isEmpty) Seq(None) else checksums val res = EitherT { download( artifact, + cache, checksums = checksums0.collect { case Some(c) => c }.toSet, + cachePolicy, + pool, logger = logger ).map { results => val checksum = checksums0.find { @@ -370,28 +376,31 @@ case class Files( res.flatMap { case (f, None) => EitherT(Task.now[FileError \/ File](\/-(f))) case (f, Some(c)) => - validateChecksum(artifact, c).map(_ => f) + validateChecksum(artifact, c, cache, pool).map(_ => f) } } def fetch( - checksums: Seq[Option[String]] = Seq(Some("SHA-1")), - logger: Option[Files.Logger] = None - )(implicit + cache: Seq[(String, File)], cachePolicy: CachePolicy, - pool: ExecutorService = defaultPool + checksums: Seq[Option[String]] = Seq(Some("SHA-1")), + logger: Option[Files.Logger] = None, + pool: ExecutorService = Files.defaultPool ): Fetch.Content[Task] = { artifact => - file(artifact, checksums = checksums, logger = logger)(cachePolicy).leftMap(_.message).map { f => + file( + artifact, + cache, + cachePolicy, + checksums = checksums, + logger = logger, + pool = pool + ).leftMap(_.message).map { f => // FIXME Catch error here? scala.io.Source.fromFile(f)("UTF-8").mkString } } -} - -object Files { - lazy val ivy2Local = MavenRepository( new File(sys.props("user.home") + "/.ivy2/local/").toURI.toString, ivyLike = true @@ -399,6 +408,10 @@ object Files { val defaultConcurrentDownloadCount = 6 + lazy val defaultPool = + Executors.newFixedThreadPool(defaultConcurrentDownloadCount, Strategy.DefaultDaemonThreadFactory) + + private val urlLocks = new ConcurrentHashMap[String, Object] trait Logger { diff --git a/plugin/src/main/scala/coursier/Tasks.scala b/plugin/src/main/scala/coursier/Tasks.scala index 223909724..3854919b1 100644 --- a/plugin/src/main/scala/coursier/Tasks.scala +++ b/plugin/src/main/scala/coursier/Tasks.scala @@ -1,6 +1,7 @@ package coursier import java.io.{OutputStreamWriter, File} +import java.util.concurrent.Executors import coursier.cli.TermDisplay import coursier.ivy.IvyRepository @@ -10,7 +11,7 @@ import Keys._ import sbt.Keys._ import scalaz.{\/-, -\/} -import scalaz.concurrent.Task +import scalaz.concurrent.{ Task, Strategy } object Tasks { @@ -150,18 +151,19 @@ object Tasks { val interProjectRepo = InterProjectRepository(projects) val repositories = Seq(globalPluginsRepo, interProjectRepo) ++ resolvers.flatMap(FromSbt.repository(_, ivyProperties)) - val files = Files( - Seq("http://" -> new File(cacheDir, "http"), "https://" -> new File(cacheDir, "https")), - () => ???, - concurrentDownloadCount = parallelDownloads + val caches = Seq( + "http://" -> new File(cacheDir, "http"), + "https://" -> new File(cacheDir, "https") ) + val pool = Executors.newFixedThreadPool(parallelDownloads, Strategy.DefaultDaemonThreadFactory) + val logger = createLogger() logger.foreach(_.init()) val fetch = coursier.Fetch( repositories, - files.fetch(checksums = checksums, logger = logger)(cachePolicy = CachePolicy.LocalOnly), - files.fetch(checksums = checksums, logger = logger)(cachePolicy = cachePolicy) + Files.fetch(caches, CachePolicy.LocalOnly, checksums = checksums, logger = logger, pool = pool), + Files.fetch(caches, cachePolicy, checksums = checksums, logger = logger, pool = pool) ) def depsRepr = currentProject.dependencies.map { case (config, dep) => @@ -225,6 +227,7 @@ object Tasks { for ((dep, errs) <- errors) { println(s" ${dep.module}:${dep.version}:\n${errs.map(" " + _.replace("\n", " \n")).mkString("\n")}") } + throw new Exception(s"Encountered ${errors.length} error(s)") } val classifiers = @@ -245,7 +248,7 @@ object Tasks { } val artifactFileOrErrorTasks = allArtifacts.toVector.map { a => - files.file(a, checksums = checksums, logger = logger)(cachePolicy = cachePolicy).run.map((a, _)) + Files.file(a, caches, cachePolicy, checksums = checksums, logger = logger, pool = pool).run.map((a, _)) } if (verbosity >= 0)