Fixes / enhancements

This commit is contained in:
Alexandre Archambault 2016-09-21 21:23:13 +02:00
parent 7c921257ae
commit 06674f5981
No known key found for this signature in database
GPG Key ID: 14640A6839C263A9
2 changed files with 54 additions and 50 deletions

View File

@ -55,10 +55,19 @@ case class SparkSubmit(
options: SparkSubmitOptions options: SparkSubmitOptions
) extends App with ExtraArgsApp { ) extends App with ExtraArgsApp {
val rawExtraJars = options.extraJars.map(new File(_))
val extraDirs = rawExtraJars.filter(_.isDirectory)
if (extraDirs.nonEmpty) {
Console.err.println(s"Error: directories not allowed in extra job JARs.")
Console.err.println(extraDirs.map(" " + _).mkString("\n"))
sys.exit(1)
}
val helper: Helper = new Helper( val helper: Helper = new Helper(
options.common, options.common,
remainingArgs, remainingArgs,
extraJars = options.extraJars.map(new File(_)) extraJars = rawExtraJars
) )
val jars = val jars =
helper.fetch(sources = false, javadoc = false) ++ helper.fetch(sources = false, javadoc = false) ++
@ -79,37 +88,15 @@ case class SparkSubmit(
(options.common.scalaVersion, options.sparkVersion) (options.common.scalaVersion, options.sparkVersion)
val assemblyOrError = val assemblyOrError =
if (options.sparkAssembly.isEmpty) { if (options.sparkAssembly.isEmpty)
Assembly.spark(
// FIXME Also vaguely done in Helper and below scalaVersion,
sparkVersion,
val (errors, modVers) = Parse.moduleVersionConfigs( options.noDefaultAssemblyDependencies,
options.assemblyDependencies, options.assemblyDependencies.flatMap(_.split(",")).filter(_.nonEmpty),
options.common.scalaVersion options.common
) )
else {
val deps = modVers.map {
case (module, version, configOpt) =>
Dependency(
module,
version,
attributes = Attributes(options.common.defaultArtifactType, ""),
configuration = configOpt.getOrElse(options.common.defaultConfiguration),
exclusions = helper.excludes
)
}
if (errors.isEmpty)
Assembly.spark(
scalaVersion,
sparkVersion,
options.noDefaultAssemblyDependencies,
options.assemblyDependencies,
options.common
)
else
Left(s"Cannot parse assembly dependencies:\n${errors.map(" " + _).mkString("\n")}")
} else {
val f = new File(options.sparkAssembly) val f = new File(options.sparkAssembly)
if (f.isFile) if (f.isFile)
Right((f, Nil)) Right((f, Nil))
@ -182,7 +169,7 @@ case class SparkSubmit(
scalaVersion, scalaVersion,
sparkVersion, sparkVersion,
options.noDefaultSubmitDependencies, options.noDefaultSubmitDependencies,
options.submitDependencies, options.submitDependencies.flatMap(_.split(",")).filter(_.nonEmpty),
options.common options.common
) )
@ -333,4 +320,4 @@ object OutputHelper {
threads.foreach(_.start()) threads.foreach(_.start())
} }
} }
} }

View File

@ -4,6 +4,7 @@ import java.io.{File, FileInputStream, FileOutputStream}
import java.math.BigInteger import java.math.BigInteger
import java.nio.file.{Files, StandardCopyOption} import java.nio.file.{Files, StandardCopyOption}
import java.security.MessageDigest import java.security.MessageDigest
import java.util.jar.{Attributes, JarFile, JarOutputStream, Manifest}
import java.util.regex.Pattern import java.util.regex.Pattern
import java.util.zip.{ZipEntry, ZipInputStream, ZipOutputStream} import java.util.zip.{ZipEntry, ZipInputStream, ZipOutputStream}
@ -39,15 +40,20 @@ object Assembly {
val rulesMap = rules.collect { case r: Rule.PathRule => r.path -> r }.toMap val rulesMap = rules.collect { case r: Rule.PathRule => r.path -> r }.toMap
val excludePatterns = rules.collect { case Rule.ExcludePattern(p) => p } val excludePatterns = rules.collect { case Rule.ExcludePattern(p) => p }
val manifest = new Manifest
manifest.getMainAttributes.put(Attributes.Name.MANIFEST_VERSION, "1.0")
var fos: FileOutputStream = null var fos: FileOutputStream = null
var zos: ZipOutputStream = null var zos: ZipOutputStream = null
try { try {
fos = new FileOutputStream(output) fos = new FileOutputStream(output)
zos = new ZipOutputStream(fos) zos = new JarOutputStream(fos, manifest)
val concatenedEntries = new mutable.HashMap[String, ::[(ZipEntry, Array[Byte])]] val concatenedEntries = new mutable.HashMap[String, ::[(ZipEntry, Array[Byte])]]
var ignore = Set.empty[String]
for (jar <- jars) { for (jar <- jars) {
var fis: FileInputStream = null var fis: FileInputStream = null
var zis: ZipInputStream = null var zis: ZipInputStream = null
@ -65,10 +71,13 @@ object Assembly {
concatenedEntries += path -> ::((ent, content), concatenedEntries.getOrElse(path, Nil)) concatenedEntries += path -> ::((ent, content), concatenedEntries.getOrElse(path, Nil))
case None => case None =>
if (!excludePatterns.exists(_.matcher(ent.getName).matches())) { if (!excludePatterns.exists(_.matcher(ent.getName).matches()) && !ignore(ent.getName)) {
ent.setCompressedSize(-1L)
zos.putNextEntry(ent) zos.putNextEntry(ent)
zos.write(content) zos.write(content)
zos.closeEntry() zos.closeEntry()
ignore += ent.getName
} }
} }
@ -80,12 +89,19 @@ object Assembly {
} }
} }
for ((path, entries) <- concatenedEntries) { for ((_, entries) <- concatenedEntries) {
val (ent, _) = entries.head val (ent, _) = entries.head
ent.setCompressedSize(-1L)
if (entries.tail.nonEmpty)
ent.setSize(entries.map(_._2.length).sum)
zos.putNextEntry(ent) zos.putNextEntry(ent)
for ((_, b) <- entries.reverse) // for ((_, b) <- entries.reverse)
zos.write(b) // zos.write(b)
zos.close() zos.write(entries.reverse.toArray.flatMap(_._2))
zos.closeEntry()
} }
} finally { } finally {
if (zos != null) if (zos != null)
@ -99,23 +115,24 @@ object Assembly {
Rule.Append("META-INF/services/org.apache.hadoop.fs.FileSystem"), Rule.Append("META-INF/services/org.apache.hadoop.fs.FileSystem"),
Rule.Append("reference.conf"), Rule.Append("reference.conf"),
Rule.Exclude("log4j.properties"), Rule.Exclude("log4j.properties"),
Rule.ExcludePattern("META-INF/*.[sS][fF]"), Rule.Exclude(JarFile.MANIFEST_NAME),
Rule.ExcludePattern("META-INF/*.[dD][sS][aA]"), Rule.ExcludePattern("META-INF/.*\\.[sS][fF]"),
Rule.ExcludePattern("META-INF/*.[rR][sS][aA]") Rule.ExcludePattern("META-INF/.*\\.[dD][sS][aA]"),
Rule.ExcludePattern("META-INF/.*\\.[rR][sS][aA]")
) )
def sparkAssemblyDependencies( def sparkAssemblyDependencies(
scalaVersion: String, scalaVersion: String,
sparkVersion: String sparkVersion: String
) = Seq( ) = Seq(
"org.apache.spark:spark-core_$scalaVersion:$sparkVersion", s"org.apache.spark:spark-core_$scalaVersion:$sparkVersion",
"org.apache.spark:spark-bagel_$scalaVersion:$sparkVersion", s"org.apache.spark:spark-bagel_$scalaVersion:$sparkVersion",
"org.apache.spark:spark-mllib_$scalaVersion:$sparkVersion", s"org.apache.spark:spark-mllib_$scalaVersion:$sparkVersion",
"org.apache.spark:spark-streaming_$scalaVersion:$sparkVersion", s"org.apache.spark:spark-streaming_$scalaVersion:$sparkVersion",
"org.apache.spark:spark-graphx_$scalaVersion:$sparkVersion", s"org.apache.spark:spark-graphx_$scalaVersion:$sparkVersion",
"org.apache.spark:spark-sql_$scalaVersion:$sparkVersion", s"org.apache.spark:spark-sql_$scalaVersion:$sparkVersion",
"org.apache.spark:spark-repl_$scalaVersion:$sparkVersion", s"org.apache.spark:spark-repl_$scalaVersion:$sparkVersion",
"org.apache.spark:spark-yarn_$scalaVersion:$sparkVersion" s"org.apache.spark:spark-yarn_$scalaVersion:$sparkVersion"
) )
def spark( def spark(