From 06674f598108eded92387ed382cf11e53fe22cd4 Mon Sep 17 00:00:00 2001
From: Alexandre Archambault
Date: Wed, 21 Sep 2016 21:23:13 +0200
Subject: [PATCH] Fixes / enhancements

---
 .../scala-2.11/coursier/cli/SparkSubmit.scala | 53 +++++++------------
 .../coursier/cli/spark/Assembly.scala         | 51 ++++++++++++------
 2 files changed, 54 insertions(+), 50 deletions(-)

diff --git a/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala b/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala
index d34a36297..276a7da0b 100644
--- a/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala
+++ b/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala
@@ -55,10 +55,19 @@ case class SparkSubmit(
   options: SparkSubmitOptions
 ) extends App with ExtraArgsApp {
 
+  val rawExtraJars = options.extraJars.map(new File(_))
+
+  val extraDirs = rawExtraJars.filter(_.isDirectory)
+  if (extraDirs.nonEmpty) {
+    Console.err.println(s"Error: directories not allowed in extra job JARs.")
+    Console.err.println(extraDirs.map(" " + _).mkString("\n"))
+    sys.exit(1)
+  }
+
   val helper: Helper = new Helper(
     options.common,
     remainingArgs,
-    extraJars = options.extraJars.map(new File(_))
+    extraJars = rawExtraJars
   )
 
   val jars = helper.fetch(sources = false, javadoc = false) ++
@@ -79,37 +88,15 @@
     (options.common.scalaVersion, options.sparkVersion)
 
   val assemblyOrError =
-    if (options.sparkAssembly.isEmpty) {
-
-      // FIXME Also vaguely done in Helper and below
-
-      val (errors, modVers) = Parse.moduleVersionConfigs(
-        options.assemblyDependencies,
-        options.common.scalaVersion
+    if (options.sparkAssembly.isEmpty)
+      Assembly.spark(
+        scalaVersion,
+        sparkVersion,
+        options.noDefaultAssemblyDependencies,
+        options.assemblyDependencies.flatMap(_.split(",")).filter(_.nonEmpty),
+        options.common
       )
-
-      val deps = modVers.map {
-        case (module, version, configOpt) =>
-          Dependency(
-            module,
-            version,
-            attributes = Attributes(options.common.defaultArtifactType, ""),
-            configuration = configOpt.getOrElse(options.common.defaultConfiguration),
-            exclusions = helper.excludes
-          )
-      }
-
-      if (errors.isEmpty)
-        Assembly.spark(
-          scalaVersion,
-          sparkVersion,
-          options.noDefaultAssemblyDependencies,
-          options.assemblyDependencies,
-          options.common
-        )
-      else
-        Left(s"Cannot parse assembly dependencies:\n${errors.map(" " + _).mkString("\n")}")
-    } else {
+    else {
       val f = new File(options.sparkAssembly)
       if (f.isFile)
        Right((f, Nil))
@@ -182,7 +169,7 @@
     scalaVersion,
     sparkVersion,
     options.noDefaultSubmitDependencies,
-    options.submitDependencies,
+    options.submitDependencies.flatMap(_.split(",")).filter(_.nonEmpty),
     options.common
   )
 
@@ -333,4 +320,4 @@ object OutputHelper {
       threads.foreach(_.start())
     }
   }
-}
\ No newline at end of file
+}
diff --git a/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala b/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala
index 1d545d8f5..9dc4962ca 100644
--- a/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala
+++ b/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala
@@ -4,6 +4,7 @@ import java.io.{File, FileInputStream, FileOutputStream}
 import java.math.BigInteger
 import java.nio.file.{Files, StandardCopyOption}
 import java.security.MessageDigest
+import java.util.jar.{Attributes, JarFile, JarOutputStream, Manifest}
 import java.util.regex.Pattern
 import java.util.zip.{ZipEntry, ZipInputStream, ZipOutputStream}
 
@@ -39,15 +40,20 @@ object Assembly {
     val rulesMap = rules.collect { case r: Rule.PathRule => r.path -> r }.toMap
     val excludePatterns =
       rules.collect { case Rule.ExcludePattern(p) => p }
 
+    val manifest = new Manifest
+    manifest.getMainAttributes.put(Attributes.Name.MANIFEST_VERSION, "1.0")
+
     var fos: FileOutputStream = null
     var zos: ZipOutputStream = null
     try {
       fos = new FileOutputStream(output)
-      zos = new ZipOutputStream(fos)
+      zos = new JarOutputStream(fos, manifest)
 
       val concatenedEntries = new mutable.HashMap[String, ::[(ZipEntry, Array[Byte])]]
 
+      var ignore = Set.empty[String]
+
       for (jar <- jars) {
         var fis: FileInputStream = null
         var zis: ZipInputStream = null
@@ -65,10 +71,13 @@
                 concatenedEntries += path -> ::((ent, content), concatenedEntries.getOrElse(path, Nil))
 
               case None =>
-                if (!excludePatterns.exists(_.matcher(ent.getName).matches())) {
+                if (!excludePatterns.exists(_.matcher(ent.getName).matches()) && !ignore(ent.getName)) {
+                  ent.setCompressedSize(-1L)
                   zos.putNextEntry(ent)
                   zos.write(content)
                   zos.closeEntry()
+
+                  ignore += ent.getName
                 }
             }
 
@@ -80,12 +89,19 @@
         }
       }
 
-      for ((path, entries) <- concatenedEntries) {
+      for ((_, entries) <- concatenedEntries) {
         val (ent, _) = entries.head
+
+        ent.setCompressedSize(-1L)
+
+        if (entries.tail.nonEmpty)
+          ent.setSize(entries.map(_._2.length).sum)
+
         zos.putNextEntry(ent)
-        for ((_, b) <- entries.reverse)
-          zos.write(b)
-        zos.close()
+        // for ((_, b) <- entries.reverse)
+        //   zos.write(b)
+        zos.write(entries.reverse.toArray.flatMap(_._2))
+        zos.closeEntry()
       }
     } finally {
       if (zos != null)
@@ -99,23 +115,24 @@
     Rule.Append("META-INF/services/org.apache.hadoop.fs.FileSystem"),
     Rule.Append("reference.conf"),
     Rule.Exclude("log4j.properties"),
-    Rule.ExcludePattern("META-INF/*.[sS][fF]"),
-    Rule.ExcludePattern("META-INF/*.[dD][sS][aA]"),
-    Rule.ExcludePattern("META-INF/*.[rR][sS][aA]")
+    Rule.Exclude(JarFile.MANIFEST_NAME),
+    Rule.ExcludePattern("META-INF/.*\\.[sS][fF]"),
+    Rule.ExcludePattern("META-INF/.*\\.[dD][sS][aA]"),
+    Rule.ExcludePattern("META-INF/.*\\.[rR][sS][aA]")
   )
 
   def sparkAssemblyDependencies(
    scalaVersion: String,
    sparkVersion: String
  ) = Seq(
-    "org.apache.spark:spark-core_$scalaVersion:$sparkVersion",
-    "org.apache.spark:spark-bagel_$scalaVersion:$sparkVersion",
-    "org.apache.spark:spark-mllib_$scalaVersion:$sparkVersion",
-    "org.apache.spark:spark-streaming_$scalaVersion:$sparkVersion",
-    "org.apache.spark:spark-graphx_$scalaVersion:$sparkVersion",
-    "org.apache.spark:spark-sql_$scalaVersion:$sparkVersion",
-    "org.apache.spark:spark-repl_$scalaVersion:$sparkVersion",
-    "org.apache.spark:spark-yarn_$scalaVersion:$sparkVersion"
+    s"org.apache.spark:spark-core_$scalaVersion:$sparkVersion",
+    s"org.apache.spark:spark-bagel_$scalaVersion:$sparkVersion",
+    s"org.apache.spark:spark-mllib_$scalaVersion:$sparkVersion",
+    s"org.apache.spark:spark-streaming_$scalaVersion:$sparkVersion",
+    s"org.apache.spark:spark-graphx_$scalaVersion:$sparkVersion",
+    s"org.apache.spark:spark-sql_$scalaVersion:$sparkVersion",
+    s"org.apache.spark:spark-repl_$scalaVersion:$sparkVersion",
+    s"org.apache.spark:spark-yarn_$scalaVersion:$sparkVersion"
   )
 
   def spark(