From a50cb1bd85ef5ec36a83984c167e13011b4c2498 Mon Sep 17 00:00:00 2001 From: Alexandre Archambault Date: Wed, 21 Sep 2016 10:37:13 +0200 Subject: [PATCH] WIP - Generate spark assemblies on the fly --- cache/src/main/scala/coursier/Cache.scala | 4 +- .../scala-2.11/coursier/cli/Bootstrap.scala | 27 +-- .../scala-2.11/coursier/cli/SparkSubmit.scala | 8 +- .../coursier/cli/spark/Assembly.scala | 184 ++++++++++++++++-- .../scala-2.11/coursier/cli/util/Zip.scala | 28 +++ 5 files changed, 213 insertions(+), 38 deletions(-) create mode 100644 cli/src/main/scala-2.11/coursier/cli/util/Zip.scala diff --git a/cache/src/main/scala/coursier/Cache.scala b/cache/src/main/scala/coursier/Cache.scala index d733f2486..3e2089474 100644 --- a/cache/src/main/scala/coursier/Cache.scala +++ b/cache/src/main/scala/coursier/Cache.scala @@ -53,7 +53,7 @@ object Cache { } } - private def localFile(url: String, cache: File, user: Option[String]): File = { + def localFile(url: String, cache: File, user: Option[String]): File = { val path = if (url.startsWith("file:///")) url.stripPrefix("file://") @@ -155,7 +155,7 @@ object Cache { } } - private def withLockFor[T](cache: File, file: File)(f: => FileError \/ T): FileError \/ T = { + def withLockFor[T](cache: File, file: File)(f: => FileError \/ T): FileError \/ T = { val lockFile = new File(file.getParentFile, s"${file.getName}.lock") var out: FileOutputStream = null diff --git a/cli/src/main/scala-2.11/coursier/cli/Bootstrap.scala b/cli/src/main/scala-2.11/coursier/cli/Bootstrap.scala index 46ff68ee0..9d4e904cc 100644 --- a/cli/src/main/scala-2.11/coursier/cli/Bootstrap.scala +++ b/cli/src/main/scala-2.11/coursier/cli/Bootstrap.scala @@ -1,13 +1,14 @@ package coursier package cli -import java.io.{ FileInputStream, ByteArrayInputStream, ByteArrayOutputStream, File, IOException } +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileInputStream, IOException} import java.nio.file.Files import 
java.nio.file.attribute.PosixFilePermission import java.util.Properties -import java.util.zip.{ ZipEntry, ZipOutputStream, ZipInputStream } +import java.util.zip.{ZipEntry, ZipInputStream, ZipOutputStream} import caseapp._ +import coursier.cli.util.Zip case class Bootstrap( @Recurse @@ -61,26 +62,6 @@ case class Bootstrap( sys.exit(1) } - def zipEntries(zipStream: ZipInputStream): Iterator[(ZipEntry, Array[Byte])] = - new Iterator[(ZipEntry, Array[Byte])] { - var nextEntry = Option.empty[ZipEntry] - def update() = - nextEntry = Option(zipStream.getNextEntry) - - update() - - def hasNext = nextEntry.nonEmpty - def next() = { - val ent = nextEntry.get - val data = Platform.readFullySync(zipStream) - - update() - - (ent, data) - } - } - - val isolatedDeps = options.isolated.isolatedDeps( options.common.defaultArtifactType, options.common.scalaVersion @@ -141,7 +122,7 @@ case class Bootstrap( val bootstrapZip = new ZipInputStream(new ByteArrayInputStream(bootstrapJar)) val outputZip = new ZipOutputStream(buffer) - for ((ent, data) <- zipEntries(bootstrapZip)) { + for ((ent, data) <- Zip.zipEntries(bootstrapZip)) { outputZip.putNextEntry(ent) outputZip.write(data) outputZip.closeEntry() diff --git a/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala b/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala index 3d6a79584..78573408d 100644 --- a/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala +++ b/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala @@ -94,7 +94,13 @@ case class SparkSubmit( } if (errors.isEmpty) - Assembly.spark(scalaVersion, sparkVersion, options.noDefaultAssemblyDependencies, deps) + Assembly.spark( + scalaVersion, + sparkVersion, + options.noDefaultAssemblyDependencies, + options.assemblyDependencies, + options.common + ) else Left(s"Cannot parse assembly dependencies:\n${errors.map(" " + _).mkString("\n")}") } else { diff --git a/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala 
b/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala index 18d6cafd9..1d545d8f5 100644 --- a/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala +++ b/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala @@ -1,36 +1,196 @@ package coursier.cli.spark import java.io.{File, FileInputStream, FileOutputStream} -import java.util.zip.{ZipInputStream, ZipOutputStream} +import java.math.BigInteger +import java.nio.file.{Files, StandardCopyOption} +import java.security.MessageDigest +import java.util.regex.Pattern +import java.util.zip.{ZipEntry, ZipInputStream, ZipOutputStream} -import coursier.Dependency +import coursier.Cache +import coursier.cli.{CommonOptions, Helper} +import coursier.cli.util.Zip + +import scala.collection.mutable +import scalaz.\/- object Assembly { sealed abstract class Rule extends Product with Serializable object Rule { - case class Exclude(path: String) extends Rule - case class Append(path: String) extends Rule + sealed abstract class PathRule extends Rule { + def path: String + } + + case class Exclude(path: String) extends PathRule + case class Append(path: String) extends PathRule + + case class ExcludePattern(path: Pattern) extends Rule + + object ExcludePattern { + def apply(s: String): ExcludePattern = + ExcludePattern(Pattern.compile(s)) + } } def make(jars: Seq[File], output: File, rules: Seq[Rule]): Unit = { - val zos = new ZipOutputStream(new FileOutputStream(output)) + val rulesMap = rules.collect { case r: Rule.PathRule => r.path -> r }.toMap + val excludePatterns = rules.collect { case Rule.ExcludePattern(p) => p } - for (jar <- jars) { - new ZipInputStream(new FileInputStream(jar)) + var fos: FileOutputStream = null + var zos: ZipOutputStream = null + + try { + fos = new FileOutputStream(output) + zos = new ZipOutputStream(fos) + + val concatenedEntries = new mutable.HashMap[String, ::[(ZipEntry, Array[Byte])]] + + for (jar <- jars) { + var fis: FileInputStream = null + var zis: ZipInputStream = null + + try { + 
fis = new FileInputStream(jar) + zis = new ZipInputStream(fis) + + for ((ent, content) <- Zip.zipEntries(zis)) + rulesMap.get(ent.getName) match { + case Some(Rule.Exclude(_)) => + // ignored + + case Some(Rule.Append(path)) => + concatenedEntries += path -> ::((ent, content), concatenedEntries.getOrElse(path, Nil)) + + case None => + if (!excludePatterns.exists(_.matcher(ent.getName).matches())) { + zos.putNextEntry(ent) + zos.write(content) + zos.closeEntry() + } + } + + } finally { + if (zis != null) + zis.close() + if (fis != null) + fis.close() + } + } + + for ((path, entries) <- concatenedEntries) { + val (ent, _) = entries.head + zos.putNextEntry(ent) + for ((_, b) <- entries.reverse) + zos.write(b) + zos.closeEntry() + } + } finally { + if (zos != null) + zos.close() + if (fos != null) + fos.close() } - - ??? } + val assemblyRules = Seq[Rule]( + Rule.Append("META-INF/services/org.apache.hadoop.fs.FileSystem"), + Rule.Append("reference.conf"), + Rule.Exclude("log4j.properties"), + Rule.ExcludePattern("META-INF/*.[sS][fF]"), + Rule.ExcludePattern("META-INF/*.[dD][sS][aA]"), + Rule.ExcludePattern("META-INF/*.[rR][sS][aA]") + ) + + def sparkAssemblyDependencies( + scalaVersion: String, + sparkVersion: String + ) = Seq( + s"org.apache.spark:spark-core_$scalaVersion:$sparkVersion", + s"org.apache.spark:spark-bagel_$scalaVersion:$sparkVersion", + s"org.apache.spark:spark-mllib_$scalaVersion:$sparkVersion", + s"org.apache.spark:spark-streaming_$scalaVersion:$sparkVersion", + s"org.apache.spark:spark-graphx_$scalaVersion:$sparkVersion", + s"org.apache.spark:spark-sql_$scalaVersion:$sparkVersion", + s"org.apache.spark:spark-repl_$scalaVersion:$sparkVersion", + s"org.apache.spark:spark-yarn_$scalaVersion:$sparkVersion" + ) + def spark( scalaVersion: String, sparkVersion: String, noDefault: Boolean, - extraDependencies: Seq[Dependency] - ): Either[String, (File, Seq[File])] = - throw new Exception("Not implemented: automatic assembly generation") + extraDependencies: Seq[String], + 
options: CommonOptions + ): Either[String, (File, Seq[File])] = { + + val base = if (noDefault) Seq() else sparkAssemblyDependencies(scalaVersion, sparkVersion) + val helper = new Helper(options, extraDependencies ++ base) + + val artifacts = helper.artifacts(sources = false, javadoc = false) + val jars = helper.fetch(sources = false, javadoc = false) + + val checksums = artifacts.map { a => + val f = a.checksumUrls.get("SHA-1") match { + case Some(url) => + Cache.localFile(url, helper.cache, a.authentication.map(_.user)) + case None => + throw new Exception(s"SHA-1 file not found for ${a.url}") + } + + val sumOpt = Cache.parseChecksum( + new String(Files.readAllBytes(f.toPath), "UTF-8") + ) + + sumOpt match { + case Some(sum) => + val s = sum.toString(16) + "0" * (40 - s.length) + s + case None => + throw new Exception(s"Cannot read SHA-1 sum from $f") + } + } + + + val md = MessageDigest.getInstance("SHA-1") + + for (c <- checksums.sorted) { + val b = c.getBytes("UTF-8") + md.update(b, 0, b.length) + } + + val digest = md.digest() + val calculatedSum = new BigInteger(1, digest) + val s = calculatedSum.toString(16) + + val sum = "0" * (40 - s.length) + s + + val destPath = Seq( + sys.props("user.home"), + ".coursier", + "spark-assemblies", + s"scala_${scalaVersion}_spark_$sparkVersion", + sum, + "spark-assembly.jar" + ).mkString("/") + + val dest = new File(destPath) + + def success = Right((dest, jars)) + + if (dest.exists()) + success + else + Cache.withLockFor(helper.cache, dest) { + dest.getParentFile.mkdirs() + val tmpDest = new File(dest.getParentFile, s".${dest.getName}.part") + // FIXME Acquire lock on tmpDest + Assembly.make(jars, tmpDest, assemblyRules) + Files.move(tmpDest.toPath, dest.toPath, StandardCopyOption.ATOMIC_MOVE) + \/-((dest, jars)) + }.leftMap(_.describe).toEither + } } diff --git a/cli/src/main/scala-2.11/coursier/cli/util/Zip.scala b/cli/src/main/scala-2.11/coursier/cli/util/Zip.scala new file mode 100644 index 000000000..23ed3263b --- 
/dev/null +++ b/cli/src/main/scala-2.11/coursier/cli/util/Zip.scala @@ -0,0 +1,28 @@ +package coursier.cli.util + +import java.util.zip.{ZipEntry, ZipInputStream} + +import coursier.Platform + +object Zip { + + def zipEntries(zipStream: ZipInputStream): Iterator[(ZipEntry, Array[Byte])] = + new Iterator[(ZipEntry, Array[Byte])] { + var nextEntry = Option.empty[ZipEntry] + def update() = + nextEntry = Option(zipStream.getNextEntry) + + update() + + def hasNext = nextEntry.nonEmpty + def next() = { + val ent = nextEntry.get + val data = Platform.readFullySync(zipStream) + + update() + + (ent, data) + } + } + +}