diff --git a/cache/src/main/scala/coursier/Cache.scala b/cache/src/main/scala/coursier/Cache.scala
index d733f2486..3e2089474 100644
--- a/cache/src/main/scala/coursier/Cache.scala
+++ b/cache/src/main/scala/coursier/Cache.scala
@@ -53,7 +53,7 @@ object Cache {
     }
   }
 
-  private def localFile(url: String, cache: File, user: Option[String]): File = {
+  def localFile(url: String, cache: File, user: Option[String]): File = {
     val path =
       if (url.startsWith("file:///"))
         url.stripPrefix("file://")
@@ -155,7 +155,7 @@ object Cache {
     }
   }
 
-  private def withLockFor[T](cache: File, file: File)(f: => FileError \/ T): FileError \/ T = {
+  def withLockFor[T](cache: File, file: File)(f: => FileError \/ T): FileError \/ T = {
     val lockFile = new File(file.getParentFile, s"${file.getName}.lock")
 
     var out: FileOutputStream = null
diff --git a/cli/src/main/scala-2.11/coursier/cli/Bootstrap.scala b/cli/src/main/scala-2.11/coursier/cli/Bootstrap.scala
index 46ff68ee0..9d4e904cc 100644
--- a/cli/src/main/scala-2.11/coursier/cli/Bootstrap.scala
+++ b/cli/src/main/scala-2.11/coursier/cli/Bootstrap.scala
@@ -1,13 +1,14 @@
 package coursier
 package cli
 
-import java.io.{ FileInputStream, ByteArrayInputStream, ByteArrayOutputStream, File, IOException }
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileInputStream, IOException}
 import java.nio.file.Files
 import java.nio.file.attribute.PosixFilePermission
 import java.util.Properties
-import java.util.zip.{ ZipEntry, ZipOutputStream, ZipInputStream }
+import java.util.zip.{ZipEntry, ZipInputStream, ZipOutputStream}
 
 import caseapp._
+import coursier.cli.util.Zip
 
 case class Bootstrap(
   @Recurse
@@ -61,26 +62,6 @@ case class Bootstrap(
       sys.exit(1)
     }
 
-  def zipEntries(zipStream: ZipInputStream): Iterator[(ZipEntry, Array[Byte])] =
-    new Iterator[(ZipEntry, Array[Byte])] {
-      var nextEntry = Option.empty[ZipEntry]
-      def update() =
-        nextEntry = Option(zipStream.getNextEntry)
-
-      update()
-
-      def hasNext = nextEntry.nonEmpty
-      def next() = {
-        val ent = nextEntry.get
-        val data = Platform.readFullySync(zipStream)
-
-        update()
-
-        (ent, data)
-      }
-    }
-
-
   val isolatedDeps = options.isolated.isolatedDeps(
     options.common.defaultArtifactType,
     options.common.scalaVersion
@@ -141,7 +122,7 @@ case class Bootstrap(
     val bootstrapZip = new ZipInputStream(new ByteArrayInputStream(bootstrapJar))
     val outputZip = new ZipOutputStream(buffer)
 
-    for ((ent, data) <- zipEntries(bootstrapZip)) {
+    for ((ent, data) <- Zip.zipEntries(bootstrapZip)) {
       outputZip.putNextEntry(ent)
       outputZip.write(data)
       outputZip.closeEntry()
diff --git a/cli/src/main/scala-2.11/coursier/cli/Helper.scala b/cli/src/main/scala-2.11/coursier/cli/Helper.scala
index 461b8eef8..9b7c35bd5 100644
--- a/cli/src/main/scala-2.11/coursier/cli/Helper.scala
+++ b/cli/src/main/scala-2.11/coursier/cli/Helper.scala
@@ -75,6 +75,7 @@ object Util {
 class Helper(
   common: CommonOptions,
   rawDependencies: Seq[String],
+  extraJars: Seq[File] = Nil,
   printResultStdout: Boolean = false,
   ignoreErrors: Boolean = false,
   isolated: IsolatedLoaderOptions = IsolatedLoaderOptions(),
@@ -584,7 +585,7 @@ class Helper(
   def contextLoader = Thread.currentThread().getContextClassLoader
 
   // TODO Would ClassLoader.getSystemClassLoader be better here?
-  val baseLoader: ClassLoader =
+  lazy val baseLoader: ClassLoader =
     Launch.mainClassLoader(contextLoader)
       .flatMap(cl => Option(cl.getParent))
       .getOrElse {
@@ -648,7 +649,7 @@ class Helper(
   }
 
   lazy val loader = new URLClassLoader(
-    filteredFiles.map(_.toURI.toURL).toArray,
+    (filteredFiles ++ extraJars).map(_.toURI.toURL).toArray,
     parentLoader
   )
diff --git a/cli/src/main/scala-2.11/coursier/cli/Launch.scala b/cli/src/main/scala-2.11/coursier/cli/Launch.scala
index b7cc971e4..aa0ac1648 100644
--- a/cli/src/main/scala-2.11/coursier/cli/Launch.scala
+++ b/cli/src/main/scala-2.11/coursier/cli/Launch.scala
@@ -1,6 +1,7 @@
 package coursier
 package cli
 
+import java.io.File
 import java.net.{ URL, URLClassLoader }
 
 import caseapp._
@@ -114,6 +115,7 @@ case class Launch(
   val helper = new Helper(
     options.common,
     remainingArgs ++ options.isolated.rawIsolated.map { case (_, dep) => dep },
+    extraJars = options.extraJars.map(new File(_)),
     isolated = options.isolated
   )
 
@@ -123,8 +125,19 @@ case class Launch(
     else
       options.mainClass
 
+  val extraJars = options.extraJars.filter(_.nonEmpty)
+
+  val loader =
+    if (extraJars.isEmpty)
+      helper.loader
+    else
+      new URLClassLoader(
+        extraJars.map(new File(_).toURI.toURL).toArray,
+        helper.loader
+      )
+
   Launch.run(
-    helper.loader,
+    loader,
     mainClass,
     userArgs,
     options.common.verbosityLevel
diff --git a/cli/src/main/scala-2.11/coursier/cli/Options.scala b/cli/src/main/scala-2.11/coursier/cli/Options.scala
index 7359826f3..4d6c803ad 100644
--- a/cli/src/main/scala-2.11/coursier/cli/Options.scala
+++ b/cli/src/main/scala-2.11/coursier/cli/Options.scala
@@ -188,6 +188,9 @@ case class LaunchOptions(
   @Short("M")
   @Short("main")
     mainClass: String,
+  @Short("J")
+  @Help("Extra JARs to be added to the classpath of the launched application. Directories accepted too.")
+    extraJars: List[String],
   @Recurse
     isolated: IsolatedLoaderOptions,
   @Recurse
@@ -226,6 +229,9 @@ case class SparkSubmitOptions(
   @Short("main")
   @Help("Main class to be launched (optional if in manifest)")
     mainClass: String,
+  @Short("J")
+  @Help("Extra JARs to be added to the classpath of the job")
+    extraJars: List[String],
   @Help("If master is yarn-cluster, write YARN app ID to a file. (The ID is deduced from the spark-submit output.)")
   @Value("file")
     yarnIdFile: String,
diff --git a/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala b/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala
index 3d6a79584..276a7da0b 100644
--- a/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala
+++ b/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala
@@ -55,8 +55,23 @@ case class SparkSubmit(
   options: SparkSubmitOptions
 ) extends App with ExtraArgsApp {
 
-  val helper = new Helper(options.common, remainingArgs)
-  val jars = helper.fetch(sources = false, javadoc = false)
+  val rawExtraJars = options.extraJars.map(new File(_))
+
+  val extraDirs = rawExtraJars.filter(_.isDirectory)
+  if (extraDirs.nonEmpty) {
+    Console.err.println(s"Error: directories not allowed in extra job JARs.")
+    Console.err.println(extraDirs.map("  " + _).mkString("\n"))
+    sys.exit(1)
+  }
+
+  val helper: Helper = new Helper(
+    options.common,
+    remainingArgs,
+    extraJars = rawExtraJars
+  )
+  val jars =
+    helper.fetch(sources = false, javadoc = false) ++
+      options.extraJars.map(new File(_))
 
   val (scalaVersion, sparkVersion) =
     if (options.sparkVersion.isEmpty)
@@ -73,31 +88,15 @@ case class SparkSubmit(
       (options.common.scalaVersion, options.sparkVersion)
 
   val assemblyOrError =
-    if (options.sparkAssembly.isEmpty) {
-
-      // FIXME Also vaguely done in Helper and below
-
-      val (errors, modVers) = Parse.moduleVersionConfigs(
-        options.assemblyDependencies,
-        options.common.scalaVersion
+    if (options.sparkAssembly.isEmpty)
+      Assembly.spark(
+        scalaVersion,
+        sparkVersion,
+        options.noDefaultAssemblyDependencies,
+        options.assemblyDependencies.flatMap(_.split(",")).filter(_.nonEmpty),
+        options.common
       )
-
-      val deps = modVers.map {
-        case (module, version, configOpt) =>
-          Dependency(
-            module,
-            version,
-            attributes = Attributes(options.common.defaultArtifactType, ""),
-            configuration = configOpt.getOrElse(options.common.defaultConfiguration),
-            exclusions = helper.excludes
-          )
-      }
-
-      if (errors.isEmpty)
-        Assembly.spark(scalaVersion, sparkVersion, options.noDefaultAssemblyDependencies, deps)
-      else
-        Left(s"Cannot parse assembly dependencies:\n${errors.map("  " + _).mkString("\n")}")
-    } else {
+    else {
       val f = new File(options.sparkAssembly)
       if (f.isFile)
        Right((f, Nil))
@@ -170,7 +169,7 @@ case class SparkSubmit(
     scalaVersion,
     sparkVersion,
     options.noDefaultSubmitDependencies,
-    options.submitDependencies,
+    options.submitDependencies.flatMap(_.split(",")).filter(_.nonEmpty),
     options.common
   )
 
@@ -321,4 +320,4 @@ object OutputHelper {
       threads.foreach(_.start())
     }
   }
-}
\ No newline at end of file
+}
diff --git a/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala b/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala
index 18d6cafd9..9dc4962ca 100644
--- a/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala
+++ b/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala
@@ -1,36 +1,213 @@
 package coursier.cli.spark
 
 import java.io.{File, FileInputStream, FileOutputStream}
-import java.util.zip.{ZipInputStream, ZipOutputStream}
+import java.math.BigInteger
+import java.nio.file.{Files, StandardCopyOption}
+import java.security.MessageDigest
+import java.util.jar.{Attributes, JarFile, JarOutputStream, Manifest}
+import java.util.regex.Pattern
+import java.util.zip.{ZipEntry, ZipInputStream, ZipOutputStream}
 
-import coursier.Dependency
+import coursier.Cache
+import coursier.cli.{CommonOptions, Helper}
+import coursier.cli.util.Zip
+
+import scala.collection.mutable
+import scalaz.\/-
 
 object Assembly {
 
   sealed abstract class Rule extends Product with Serializable
 
   object Rule {
-    case class Exclude(path: String) extends Rule
-    case class Append(path: String) extends Rule
+    sealed abstract class PathRule extends Rule {
+      def path: String
+    }
+
+    case class Exclude(path: String) extends PathRule
+    case class Append(path: String) extends PathRule
+
+    case class ExcludePattern(path: Pattern) extends Rule
+
+    object ExcludePattern {
+      def apply(s: String): ExcludePattern =
+        ExcludePattern(Pattern.compile(s))
+    }
   }
 
   def make(jars: Seq[File], output: File, rules: Seq[Rule]): Unit = {
-    val zos = new ZipOutputStream(new FileOutputStream(output))
+    val rulesMap = rules.collect { case r: Rule.PathRule => r.path -> r }.toMap
+    val excludePatterns = rules.collect { case Rule.ExcludePattern(p) => p }
 
-    for (jar <- jars) {
-      new ZipInputStream(new FileInputStream(jar))
+    val manifest = new Manifest
+    manifest.getMainAttributes.put(Attributes.Name.MANIFEST_VERSION, "1.0")
+
+    var fos: FileOutputStream = null
+    var zos: ZipOutputStream = null
+
+    try {
+      fos = new FileOutputStream(output)
+      zos = new JarOutputStream(fos, manifest)
+
+      val concatenedEntries = new mutable.HashMap[String, ::[(ZipEntry, Array[Byte])]]
+
+      var ignore = Set.empty[String]
+
+      for (jar <- jars) {
+        var fis: FileInputStream = null
+        var zis: ZipInputStream = null
+
+        try {
+          fis = new FileInputStream(jar)
+          zis = new ZipInputStream(fis)
+
+          for ((ent, content) <- Zip.zipEntries(zis))
+            rulesMap.get(ent.getName) match {
+              case Some(Rule.Exclude(_)) =>
+                // ignored
+
+              case Some(Rule.Append(path)) =>
+                concatenedEntries += path -> ::((ent, content), concatenedEntries.getOrElse(path, Nil))
+
+              case None =>
+                if (!excludePatterns.exists(_.matcher(ent.getName).matches()) && !ignore(ent.getName)) {
+                  ent.setCompressedSize(-1L)
+                  zos.putNextEntry(ent)
+                  zos.write(content)
+                  zos.closeEntry()
+
+                  ignore += ent.getName
+                }
+            }
+
+        } finally {
+          if (zis != null)
+            zis.close()
+          if (fis != null)
+            fis.close()
+        }
+      }
+
+      for ((_, entries) <- concatenedEntries) {
+        val (ent, _) = entries.head
+
+        ent.setCompressedSize(-1L)
+
+        if (entries.tail.nonEmpty)
+          ent.setSize(entries.map(_._2.length).sum)
+
+        zos.putNextEntry(ent)
+        // for ((_, b) <- entries.reverse)
+        //   zos.write(b)
+        zos.write(entries.reverse.toArray.flatMap(_._2))
+        zos.closeEntry()
+      }
+    } finally {
+      if (zos != null)
+        zos.close()
+      if (fos != null)
+        fos.close()
     }
-
-    ???
   }
 
+  val assemblyRules = Seq[Rule](
+    Rule.Append("META-INF/services/org.apache.hadoop.fs.FileSystem"),
+    Rule.Append("reference.conf"),
+    Rule.Exclude("log4j.properties"),
+    Rule.Exclude(JarFile.MANIFEST_NAME),
+    Rule.ExcludePattern("META-INF/.*\\.[sS][fF]"),
+    Rule.ExcludePattern("META-INF/.*\\.[dD][sS][aA]"),
+    Rule.ExcludePattern("META-INF/.*\\.[rR][sS][aA]")
+  )
+
+  def sparkAssemblyDependencies(
+    scalaVersion: String,
+    sparkVersion: String
+  ) = Seq(
+    s"org.apache.spark:spark-core_$scalaVersion:$sparkVersion",
+    s"org.apache.spark:spark-bagel_$scalaVersion:$sparkVersion",
+    s"org.apache.spark:spark-mllib_$scalaVersion:$sparkVersion",
+    s"org.apache.spark:spark-streaming_$scalaVersion:$sparkVersion",
+    s"org.apache.spark:spark-graphx_$scalaVersion:$sparkVersion",
+    s"org.apache.spark:spark-sql_$scalaVersion:$sparkVersion",
+    s"org.apache.spark:spark-repl_$scalaVersion:$sparkVersion",
+    s"org.apache.spark:spark-yarn_$scalaVersion:$sparkVersion"
+  )
+
   def spark(
     scalaVersion: String,
     sparkVersion: String,
     noDefault: Boolean,
-    extraDependencies: Seq[Dependency]
-  ): Either[String, (File, Seq[File])] =
-    throw new Exception("Not implemented: automatic assembly generation")
+    extraDependencies: Seq[String],
+    options: CommonOptions
+  ): Either[String, (File, Seq[File])] = {
+
+    val base = if (noDefault) Seq() else sparkAssemblyDependencies(scalaVersion, sparkVersion)
+    val helper = new Helper(options, extraDependencies ++ base)
+
+    val artifacts = helper.artifacts(sources = false, javadoc = false)
+    val jars = helper.fetch(sources = false, javadoc = false)
+
+    val checksums = artifacts.map { a =>
+      val f = a.checksumUrls.get("SHA-1") match {
+        case Some(url) =>
+          Cache.localFile(url, helper.cache, a.authentication.map(_.user))
+        case None =>
+          throw new Exception(s"SHA-1 file not found for ${a.url}")
+      }
+
+      val sumOpt = Cache.parseChecksum(
+        new String(Files.readAllBytes(f.toPath), "UTF-8")
+      )
+
+      sumOpt match {
+        case Some(sum) =>
+          val s = sum.toString(16)
+          "0" * (40 - s.length) + s
+        case None =>
+          throw new Exception(s"Cannot read SHA-1 sum from $f")
+      }
+    }
+
+
+    val md = MessageDigest.getInstance("SHA-1")
+
+    for (c <- checksums.sorted) {
+      val b = c.getBytes("UTF-8")
+      md.update(b, 0, b.length)
+    }
+
+    val digest = md.digest()
+    val calculatedSum = new BigInteger(1, digest)
+    val s = calculatedSum.toString(16)
+
+    val sum = "0" * (40 - s.length) + s
+
+    val destPath = Seq(
+      sys.props("user.home"),
+      ".coursier",
+      "spark-assemblies",
+      s"scala_${scalaVersion}_spark_$sparkVersion",
+      sum,
+      "spark-assembly.jar"
+    ).mkString("/")
+
+    val dest = new File(destPath)
+
+    def success = Right((dest, jars))
+
+    if (dest.exists())
+      success
+    else
+      Cache.withLockFor(helper.cache, dest) {
+        dest.getParentFile.mkdirs()
+        val tmpDest = new File(dest.getParentFile, s".${dest.getName}.part")
+        // FIXME Acquire lock on tmpDest
+        Assembly.make(jars, tmpDest, assemblyRules)
+        Files.move(tmpDest.toPath, dest.toPath, StandardCopyOption.ATOMIC_MOVE)
+        \/-((dest, jars))
+      }.leftMap(_.describe).toEither
+  }
 }
diff --git a/cli/src/main/scala-2.11/coursier/cli/util/Zip.scala b/cli/src/main/scala-2.11/coursier/cli/util/Zip.scala
new file mode 100644
index 000000000..23ed3263b
--- /dev/null
+++ b/cli/src/main/scala-2.11/coursier/cli/util/Zip.scala
@@ -0,0 +1,28 @@
+package coursier.cli.util
+
+import java.util.zip.{ZipEntry, ZipInputStream}
+
+import coursier.Platform
+
+object Zip {
+
+  def zipEntries(zipStream: ZipInputStream): Iterator[(ZipEntry, Array[Byte])] =
+    new Iterator[(ZipEntry, Array[Byte])] {
+      var nextEntry = Option.empty[ZipEntry]
+      def update() =
+        nextEntry = Option(zipStream.getNextEntry)
+
+      update()
+
+      def hasNext = nextEntry.nonEmpty
+      def next() = {
+        val ent = nextEntry.get
+        val data = Platform.readFullySync(zipStream)
+
+        update()
+
+        (ent, data)
+      }
+    }
+
+}
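
Note (not part of the patch): the zipEntries iterator previously local to Bootstrap is now shared as coursier.cli.util.Zip.zipEntries, and both launch and spark-submit gain a -J option backed by the new extraJars parameter of Helper. Below is a minimal sketch of how the extracted helper could be exercised on its own, assuming the cli module (which pulls in coursier.Platform through its dependencies) is on the classpath; the object name and the jar path argument are illustrative only.

    import java.io.FileInputStream
    import java.util.zip.ZipInputStream

    import coursier.cli.util.Zip

    object ListJarEntries {
      def main(args: Array[String]): Unit = {
        // Stream over a JAR/zip given as first argument and print each entry
        // name with its uncompressed size, draining the stream entry by entry.
        val zis = new ZipInputStream(new FileInputStream(args(0)))
        try
          for ((entry, data) <- Zip.zipEntries(zis))
            println(s"${entry.getName}: ${data.length} bytes")
        finally
          zis.close()
      }
    }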