Concatenate services files in generated spark-assemblies

Alexandre Archambault 2016-11-04 18:00:43 +01:00
parent 23bbfcc552
commit 9d0526d1bf
No known key found for this signature in database
GPG Key ID: 14640A6839C263A9
1 changed file with 34 additions and 12 deletions


@@ -25,20 +25,29 @@ object Assembly {
     }
 
     case class Exclude(path: String) extends PathRule
-    case class Append(path: String) extends PathRule
     case class ExcludePattern(path: Pattern) extends Rule
     object ExcludePattern {
       def apply(s: String): ExcludePattern =
         ExcludePattern(Pattern.compile(s))
     }
+
+    // TODO Accept a separator: Array[Byte] argument in these
+    // (to separate content with a line return in particular)
+    case class Append(path: String) extends PathRule
+    case class AppendPattern(path: Pattern) extends Rule
+    object AppendPattern {
+      def apply(s: String): AppendPattern =
+        AppendPattern(Pattern.compile(s))
+    }
   }
 
   def make(jars: Seq[File], output: File, rules: Seq[Rule]): Unit = {
 
     val rulesMap = rules.collect { case r: Rule.PathRule => r.path -> r }.toMap
     val excludePatterns = rules.collect { case Rule.ExcludePattern(p) => p }
+    val appendPatterns = rules.collect { case Rule.AppendPattern(p) => p }
 
     val manifest = new Manifest
     manifest.getMainAttributes.put(Attributes.Name.MANIFEST_VERSION, "1.0")
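
The split between the exact-path rules (PathRule) and the pattern-based ones determines how each archive entry is dispatched in make(): exact paths are looked up in a map, while patterns are tried with matcher(...).matches() against the full entry name. A minimal sketch of that precedence, assuming exact rules win over patterns and exclusion is checked before appending (decide and its string results are hypothetical; only the rule shapes come from the diff):

    import java.util.regex.Pattern

    // Hypothetical dispatcher mirroring the lookup order in make():
    // exact-path rules first, then exclude patterns, then append patterns.
    def decide(
      name: String,
      exactRules: Map[String, String], // entry name -> "exclude" | "append"
      excludePatterns: Seq[Pattern],
      appendPatterns: Seq[Pattern]
    ): String =
      exactRules.getOrElse(
        name,
        if (excludePatterns.exists(_.matcher(name).matches())) "exclude"
        else if (appendPatterns.exists(_.matcher(name).matches())) "append"
        else "copy"
      )

    decide(
      "META-INF/services/java.sql.Driver",
      Map.empty,
      Seq(Pattern.compile("META-INF/.*\\.[sS][fF]")),
      Seq(Pattern.compile("META-INF/services/.*"))
    ) // "append": the services pattern matches and no exclusion applies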
@@ -62,24 +71,33 @@ object Assembly {
         fis = new FileInputStream(jar)
         zis = new ZipInputStream(fis)
 
-        for ((ent, content) <- Zip.zipEntries(zis))
+        for ((ent, content) <- Zip.zipEntries(zis)) {
+
+          def append() =
+            concatenedEntries += ent.getName -> ::((ent, content), concatenedEntries.getOrElse(ent.getName, Nil))
+
           rulesMap.get(ent.getName) match {
             case Some(Rule.Exclude(_)) =>
               // ignored
-            case Some(Rule.Append(path)) =>
-              concatenedEntries += path -> ::((ent, content), concatenedEntries.getOrElse(path, Nil))
+            case Some(Rule.Append(_)) =>
+              append()
             case None =>
-              if (!excludePatterns.exists(_.matcher(ent.getName).matches()) && !ignore(ent.getName)) {
-                ent.setCompressedSize(-1L)
-                zos.putNextEntry(ent)
-                zos.write(content)
-                zos.closeEntry()
-                ignore += ent.getName
+              if (!excludePatterns.exists(_.matcher(ent.getName).matches())) {
+                if (appendPatterns.exists(_.matcher(ent.getName).matches()))
+                  append()
+                else if (!ignore(ent.getName)) {
+                  ent.setCompressedSize(-1L)
+                  zos.putNextEntry(ent)
+                  zos.write(content)
+                  zos.closeEntry()
+                  ignore += ent.getName
+                }
              }
           }
+        }
       } finally {
         if (zis != null)
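
append() prepends each (entry, content) pair onto the list already collected under that name (via scala.collection.immutable.::), so a group's contents accumulate in reverse jar order, and per the TODO above nothing separates the concatenated byte arrays yet. A rough sketch of the write-out step such a map implies, assuming the order is reversed back and a newline separator as the TODO suggests (writeMerged is hypothetical, not this commit's code):

    import java.util.zip.{ZipEntry, ZipOutputStream}

    // Hypothetical flush of one group of same-named entries: restore the
    // original jar order, join the contents with newlines so line-oriented
    // services files stay valid, and emit a single merged entry.
    def writeMerged(
      zos: ZipOutputStream,
      name: String,
      group: List[(ZipEntry, Array[Byte])]
    ): Unit = {
      val sep = "\n".getBytes("UTF-8")
      val merged = group.reverse.map(_._2).reduceLeft((a, b) => a ++ sep ++ b)
      zos.putNextEntry(new ZipEntry(name))
      zos.write(merged)
      zos.closeEntry()
    }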
@@ -114,6 +132,7 @@ object Assembly {
     val assemblyRules = Seq[Rule](
       Rule.Append("META-INF/services/org.apache.hadoop.fs.FileSystem"),
       Rule.Append("reference.conf"),
+      Rule.AppendPattern("META-INF/services/.*"),
       Rule.Exclude("log4j.properties"),
       Rule.Exclude(JarFile.MANIFEST_NAME),
       Rule.ExcludePattern("META-INF/.*\\.[sS][fF]"),
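
The new Rule.AppendPattern("META-INF/services/.*") generalizes what the FileSystem rule above already did for one file: java.util.ServiceLoader reads every META-INF/services/<interface> file it can see on the classpath, so an assembly that keeps only the copy from the last jar silently drops the providers registered by all the others. A quick way to check what an assembly actually exposes, using only the JDK API (java.sql.Driver is just a convenient example interface):

    import java.util.ServiceLoader
    import scala.collection.JavaConverters._

    // Lists every provider contributed by META-INF/services/java.sql.Driver
    // across the classpath; with a last-jar-wins assembly, the providers of
    // the overwritten jars would be missing from this output.
    val drivers = ServiceLoader.load(classOf[java.sql.Driver]).asScala.toList
    drivers.foreach(d => println(d.getClass.getName))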
@@ -141,7 +160,8 @@ object Assembly {
       noDefault: Boolean,
       extraDependencies: Seq[String],
       options: CommonOptions,
-      artifactTypes: Set[String] = Set("jar")
+      artifactTypes: Set[String] = Set("jar"),
+      checksumSeed: Array[Byte] = "v1".getBytes("UTF-8")
   ): Either[String, (File, Seq[File])] = {
 
     val base = if (noDefault) Seq() else sparkAssemblyDependencies(scalaVersion, sparkVersion)
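
The new checksumSeed parameter feeds the SHA-1 fingerprint computed further down, so bumping the default from "v1" changes every assembly checksum at once and forces regeneration of assemblies cached before services files were concatenated. A minimal illustration of that property (digest is a hypothetical helper shaped like the code in the last hunk):

    import java.security.MessageDigest

    // Same artifact checksums, different seeds: different fingerprints,
    // so a seed bump invalidates assemblies cached under the old format.
    def digest(seed: Array[Byte], checksums: Seq[String]): Seq[Byte] = {
      val md = MessageDigest.getInstance("SHA-1")
      md.update(seed)
      checksums.sorted.foreach(c => md.update(c.getBytes("UTF-8")))
      md.digest().toSeq
    }

    assert(digest("v1".getBytes("UTF-8"), Seq("abc")) != digest("v2".getBytes("UTF-8"), Seq("abc")))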
@@ -174,6 +194,8 @@ object Assembly {
 
     val md = MessageDigest.getInstance("SHA-1")
 
+    md.update(checksumSeed)
+
     for (c <- checksums.sorted) {
       val b = c.getBytes("UTF-8")
       md.update(b, 0, b.length)
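
Two details of this last hunk matter: the seed is digested before any checksum, and checksums.sorted makes the fingerprint independent of the order in which the artifacts were resolved. A short sketch of that order-independence; the hex rendering is an assumption about how the digest is ultimately used, and fp is hypothetical:

    import java.security.MessageDigest

    // Sorting first means any permutation of the same checksum set
    // produces the same fingerprint.
    def fp(checksums: Seq[String]): String = {
      val md = MessageDigest.getInstance("SHA-1")
      md.update("v1".getBytes("UTF-8"))
      checksums.sorted.foreach(c => md.update(c.getBytes("UTF-8")))
      md.digest().map("%02x".format(_)).mkString
    }

    assert(fp(Seq("sha1:aa", "sha1:bb")) == fp(Seq("sha1:bb", "sha1:aa")))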