From 9b1d329d0b556dd7b32cf18134dc6c4bc7b9921d Mon Sep 17 00:00:00 2001 From: Alexandre Archambault Date: Mon, 3 Jul 2017 12:59:26 +0200 Subject: [PATCH] Prevent downloading the same artifact concurrently multiple times --- .../coursier/core/ResolutionProcess.scala | 54 ++++++++++++++----- 1 file changed, 41 insertions(+), 13 deletions(-) diff --git a/core/shared/src/main/scala/coursier/core/ResolutionProcess.scala b/core/shared/src/main/scala/coursier/core/ResolutionProcess.scala index 195b237ec..2e2657abc 100644 --- a/core/shared/src/main/scala/coursier/core/ResolutionProcess.scala +++ b/core/shared/src/main/scala/coursier/core/ResolutionProcess.scala @@ -1,10 +1,12 @@ package coursier package core -import scalaz._ import scala.annotation.tailrec import scala.language.higherKinds +import scalaz.{Monad, -\/, \/-} +import scalaz.Scalaz.{ToFunctorOps, ToTraverseOps, vectorInstance} + sealed abstract class ResolutionProcess { def run[F[_]]( @@ -12,8 +14,7 @@ sealed abstract class ResolutionProcess { maxIterations: Int = 50 )(implicit F: Monad[F] - ): F[Resolution] = { - + ): F[Resolution] = if (maxIterations == 0) F.point(current) else { val maxIterations0 = @@ -23,7 +24,7 @@ sealed abstract class ResolutionProcess { case Done(res) => F.point(res) case missing0 @ Missing(missing, _, _) => - F.bind(fetch(missing))(result => + F.bind(ResolutionProcess.fetchAll(missing, fetch))(result => missing0.next(result).run(fetch, maxIterations0) ) case cont @ Continue(_, _) => @@ -32,7 +33,6 @@ sealed abstract class ResolutionProcess { .run(fetch, maxIterations0) } } - } @tailrec final def next[F[_]]( @@ -40,20 +40,18 @@ sealed abstract class ResolutionProcess { fastForward: Boolean = true )(implicit F: Monad[F] - ): F[ResolutionProcess] = { - + ): F[ResolutionProcess] = this match { - case Done(res) => + case Done(_) => F.point(this) case missing0 @ Missing(missing, _, _) => - F.map(fetch(missing))(result => missing0.next(result)) + F.map(ResolutionProcess.fetchAll(missing, fetch))(result => missing0.next(result)) case cont @ Continue(_, _) => if (fastForward) cont.nextNoCont.next(fetch, fastForward = fastForward) else F.point(cont.next) } - } def current: Resolution } @@ -110,10 +108,10 @@ final case class Missing( val orderedSuccesses = order(depMgmtMissing0.map { case (k, v) => k -> v.intersect(modVer) }.toMap, Nil) val res0 = orderedSuccesses.foldLeft(res) { - case (acc, (modVer, (source, proj))) => + case (acc, (modVer0, (source, proj))) => acc.copyWithCache( projectCache = acc.projectCache + ( - modVer -> (source, acc.withDependencyManagement(proj)) + modVer0 -> (source, acc.withDependencyManagement(proj)) ) ) } @@ -139,7 +137,7 @@ final case class Continue( def next: ResolutionProcess = cont(current) - @tailrec final def nextNoCont: ResolutionProcess = + @tailrec def nextNoCont: ResolutionProcess = next match { case nextCont: Continue => nextCont.nextNoCont case other => other @@ -160,5 +158,35 @@ object ResolutionProcess { else Missing(resolution0.missingFromCache.toSeq, resolution0, apply) } + + private def fetchAll[F[_]](modVers: Seq[(Module, String)], fetch: Fetch.Metadata[F])(implicit F: Monad[F]) = { + + def uniqueModules(modVers: Seq[(Module, String)]): Stream[Seq[(Module, String)]] = { + + val res = modVers.groupBy(_._1).toSeq.map(_._2).map { + case Seq(v) => (v, Nil) + case Seq() => sys.error("Cannot happen") + case v => + // there might be version intervals in there, but that shouldn't matter... + val res = v.maxBy { case (_, v0) => Version(v0) } + (res, v.filter(_ != res)) + } + + val other = res.flatMap(_._2) + + if (other.isEmpty) + Stream(modVers) + else { + val missing0 = res.map(_._1) + missing0 #:: uniqueModules(other) + } + } + + uniqueModules(modVers) + .toVector + .traverse(fetch) + .map(_.flatten) + } + }