From 9d0526d1bf8bcadd5e6d2c16046bc873fee0077c Mon Sep 17 00:00:00 2001
From: Alexandre Archambault
Date: Fri, 4 Nov 2016 18:00:43 +0100
Subject: [PATCH] Concatenate services files in generated spark-assemblies

---
 .../coursier/cli/spark/Assembly.scala | 46 ++++++++++++++----
 1 file changed, 34 insertions(+), 12 deletions(-)

diff --git a/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala b/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala
index d7d53d087..4e130ec31 100644
--- a/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala
+++ b/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala
@@ -25,20 +25,29 @@ object Assembly {
     }
 
     case class Exclude(path: String) extends PathRule
-    case class Append(path: String) extends PathRule
-
     case class ExcludePattern(path: Pattern) extends Rule
 
     object ExcludePattern {
       def apply(s: String): ExcludePattern =
         ExcludePattern(Pattern.compile(s))
     }
+
+    // TODO Accept a separator: Array[Byte] argument in these
+    // (to separate content with a line return in particular)
+    case class Append(path: String) extends PathRule
+    case class AppendPattern(path: Pattern) extends Rule
+
+    object AppendPattern {
+      def apply(s: String): AppendPattern =
+        AppendPattern(Pattern.compile(s))
+    }
   }
 
   def make(jars: Seq[File], output: File, rules: Seq[Rule]): Unit = {
 
     val rulesMap = rules.collect { case r: Rule.PathRule => r.path -> r }.toMap
     val excludePatterns = rules.collect { case Rule.ExcludePattern(p) => p }
+    val appendPatterns = rules.collect { case Rule.AppendPattern(p) => p }
 
     val manifest = new Manifest
     manifest.getMainAttributes.put(Attributes.Name.MANIFEST_VERSION, "1.0")
@@ -62,24 +71,33 @@ object Assembly {
       fis = new FileInputStream(jar)
       zis = new ZipInputStream(fis)
 
-      for ((ent, content) <- Zip.zipEntries(zis))
+      for ((ent, content) <- Zip.zipEntries(zis)) {
+
+        def append() =
+          concatenedEntries += ent.getName -> ::((ent, content), concatenedEntries.getOrElse(ent.getName, Nil))
+
         rulesMap.get(ent.getName) match {
           case Some(Rule.Exclude(_)) =>
           // ignored
-          case Some(Rule.Append(path)) =>
-            concatenedEntries += path -> ::((ent, content), concatenedEntries.getOrElse(path, Nil))
+          case Some(Rule.Append(_)) =>
+            append()
           case None =>
-            if (!excludePatterns.exists(_.matcher(ent.getName).matches()) && !ignore(ent.getName)) {
-              ent.setCompressedSize(-1L)
-              zos.putNextEntry(ent)
-              zos.write(content)
-              zos.closeEntry()
+            if (!excludePatterns.exists(_.matcher(ent.getName).matches())) {
+              if (appendPatterns.exists(_.matcher(ent.getName).matches()))
+                append()
+              else if (!ignore(ent.getName)) {
+                ent.setCompressedSize(-1L)
+                zos.putNextEntry(ent)
+                zos.write(content)
+                zos.closeEntry()
 
-              ignore += ent.getName
+                ignore += ent.getName
+              }
             }
         }
+      }
 
     } finally {
       if (zis != null)
@@ -114,6 +132,7 @@ object Assembly {
     val assemblyRules = Seq[Rule](
       Rule.Append("META-INF/services/org.apache.hadoop.fs.FileSystem"),
       Rule.Append("reference.conf"),
+      Rule.AppendPattern("META-INF/services/.*"),
      Rule.Exclude("log4j.properties"),
       Rule.Exclude(JarFile.MANIFEST_NAME),
       Rule.ExcludePattern("META-INF/.*\\.[sS][fF]"),
@@ -141,7 +160,8 @@
     noDefault: Boolean,
     extraDependencies: Seq[String],
     options: CommonOptions,
-    artifactTypes: Set[String] = Set("jar")
+    artifactTypes: Set[String] = Set("jar"),
+    checksumSeed: Array[Byte] = "v1".getBytes("UTF-8")
   ): Either[String, (File, Seq[File])] = {
 
     val base = if (noDefault) Seq() else sparkAssemblyDependencies(scalaVersion, sparkVersion)
@@ -174,6 +194,8 @@
 
     val md = MessageDigest.getInstance("SHA-1")
 
+    md.update(checksumSeed)
+
     for (c <- checksums.sorted) {
       val b = c.getBytes("UTF-8")
       md.update(b, 0, b.length)
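
Not part of the patch itself: the sketch below illustrates, under stated assumptions, the merge behavior the new Rule.AppendPattern("META-INF/services/.*") rule is after. When several input jars ship the same META-INF/services file, a plain first-occurrence-wins merge keeps only one provider registration, so java.util.ServiceLoader (which Hadoop's FileSystem lookup relies on) loses implementations; collecting the matching entries and writing their concatenated content once keeps them all. MergeSketch, mergeJars and readEntry are hypothetical names using only java.util.zip, standing in for the real code's Zip.zipEntries helper and Rule machinery:

// A minimal, self-contained sketch, assuming nothing but java.util.zip.
// MergeSketch, mergeJars and readEntry are illustrative names, not coursier API.
import java.io.{ByteArrayOutputStream, File, FileInputStream, FileOutputStream}
import java.util.regex.Pattern
import java.util.zip.{ZipEntry, ZipInputStream, ZipOutputStream}

import scala.collection.mutable

object MergeSketch {

  // Same pattern the patch registers via Rule.AppendPattern("META-INF/services/.*")
  val appendPattern = Pattern.compile("META-INF/services/.*")

  // Drain the bytes of the zip entry the stream is currently positioned on
  def readEntry(zis: ZipInputStream): Array[Byte] = {
    val buf = new Array[Byte](16384)
    val out = new ByteArrayOutputStream
    var n = zis.read(buf)
    while (n >= 0) {
      out.write(buf, 0, n)
      n = zis.read(buf)
    }
    out.toByteArray
  }

  def mergeJars(jars: Seq[File], output: File): Unit = {
    val appended = mutable.LinkedHashMap.empty[String, Array[Byte]]
    val written = mutable.Set.empty[String]
    val zos = new ZipOutputStream(new FileOutputStream(output))

    for (jar <- jars) {
      val zis = new ZipInputStream(new FileInputStream(jar))
      var ent = zis.getNextEntry
      while (ent != null) {
        val name = ent.getName
        val content = readEntry(zis)
        if (appendPattern.matcher(name).matches()) {
          // Service files: concatenate across jars, one registration per line.
          // The "\n" separator is hard-coded here; the patch's TODO comment
          // proposes passing it in as an Array[Byte] argument instead.
          val prev = appended.getOrElse(name, Array.empty[Byte])
          val sep = if (prev.isEmpty) Array.empty[Byte] else "\n".getBytes("UTF-8")
          appended(name) = prev ++ sep ++ content
        } else if (!written(name)) {
          // Ordinary entries: first occurrence wins, later duplicates are dropped
          zos.putNextEntry(new ZipEntry(name))
          zos.write(content)
          zos.closeEntry()
          written += name
        }
        ent = zis.getNextEntry
      }
      zis.close()
    }

    // Each service file comes out once, with every provider line preserved
    for ((name, content) <- appended) {
      zos.putNextEntry(new ZipEntry(name))
      zos.write(content)
      zos.closeEntry()
    }

    zos.close()
  }
}

// Example: MergeSketch.mergeJars(Seq(new File("a.jar"), new File("b.jar")), new File("out.jar"))

Deferring the concatenated entries until all jars have been read mirrors the patch's concatenedEntries accumulator, which is presumably flushed to the output jar once the loop over jars finishes (that write-out is outside the hunks shown above).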