WIP - Generate spark assemblies on the fly

This commit is contained in:
Alexandre Archambault 2016-09-21 10:37:13 +02:00
parent b7439cac50
commit a50cb1bd85
No known key found for this signature in database
GPG Key ID: 14640A6839C263A9
5 changed files with 213 additions and 38 deletions

View File

@@ -53,7 +53,7 @@ object Cache {
}
}
private def localFile(url: String, cache: File, user: Option[String]): File = {
def localFile(url: String, cache: File, user: Option[String]): File = {
val path =
if (url.startsWith("file:///"))
url.stripPrefix("file://")
@@ -155,7 +155,7 @@ object Cache {
}
}
private def withLockFor[T](cache: File, file: File)(f: => FileError \/ T): FileError \/ T = {
def withLockFor[T](cache: File, file: File)(f: => FileError \/ T): FileError \/ T = {
val lockFile = new File(file.getParentFile, s"${file.getName}.lock")
var out: FileOutputStream = null

View File

@@ -1,13 +1,14 @@
package coursier
package cli
import java.io.{ FileInputStream, ByteArrayInputStream, ByteArrayOutputStream, File, IOException }
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileInputStream, IOException}
import java.nio.file.Files
import java.nio.file.attribute.PosixFilePermission
import java.util.Properties
import java.util.zip.{ ZipEntry, ZipOutputStream, ZipInputStream }
import java.util.zip.{ZipEntry, ZipInputStream, ZipOutputStream}
import caseapp._
import coursier.cli.util.Zip
case class Bootstrap(
@Recurse
@@ -61,26 +62,6 @@ case class Bootstrap(
sys.exit(1)
}
/** Iterates over the entries of `zipStream`, pairing each entry with its
  * fully-read byte content. The stream is advanced as the iterator is
  * consumed, so entries must be processed in order.
  */
def zipEntries(zipStream: ZipInputStream): Iterator[(ZipEntry, Array[Byte])] =
  new Iterator[(ZipEntry, Array[Byte])] {
    // One-entry look-ahead: the entry next() will return, None once exhausted.
    var pending = Option(zipStream.getNextEntry)
    def hasNext = pending.nonEmpty
    def next() = pending match {
      case Some(entry) =>
        // Read this entry's bytes before advancing to the next entry header.
        val bytes = Platform.readFullySync(zipStream)
        pending = Option(zipStream.getNextEntry)
        (entry, bytes)
      case None =>
        throw new NoSuchElementException("next on empty zip entry iterator")
    }
  }
val isolatedDeps = options.isolated.isolatedDeps(
options.common.defaultArtifactType,
options.common.scalaVersion
@@ -141,7 +122,7 @@ case class Bootstrap(
val bootstrapZip = new ZipInputStream(new ByteArrayInputStream(bootstrapJar))
val outputZip = new ZipOutputStream(buffer)
for ((ent, data) <- zipEntries(bootstrapZip)) {
for ((ent, data) <- Zip.zipEntries(bootstrapZip)) {
outputZip.putNextEntry(ent)
outputZip.write(data)
outputZip.closeEntry()

View File

@@ -94,7 +94,13 @@ case class SparkSubmit(
}
if (errors.isEmpty)
Assembly.spark(scalaVersion, sparkVersion, options.noDefaultAssemblyDependencies, deps)
Assembly.spark(
scalaVersion,
sparkVersion,
options.noDefaultAssemblyDependencies,
options.assemblyDependencies,
options.common
)
else
Left(s"Cannot parse assembly dependencies:\n${errors.map(" " + _).mkString("\n")}")
} else {

View File

@@ -1,36 +1,196 @@
package coursier.cli.spark
import java.io.{File, FileInputStream, FileOutputStream}
import java.util.zip.{ZipInputStream, ZipOutputStream}
import java.math.BigInteger
import java.nio.file.{Files, StandardCopyOption}
import java.security.MessageDigest
import java.util.regex.Pattern
import java.util.zip.{ZipEntry, ZipInputStream, ZipOutputStream}
import coursier.Dependency
import coursier.Cache
import coursier.cli.{CommonOptions, Helper}
import coursier.cli.util.Zip
import scala.collection.mutable
import scalaz.\/-
object Assembly {
/** A merge rule applied while building an assembly jar.
  *
  * FIX: the old top-level `Exclude`/`Append` case classes (extending `Rule`
  * directly) were left alongside the new `PathRule`-based ones, defining each
  * name twice — the duplicates are removed here.
  */
sealed abstract class Rule extends Product with Serializable

object Rule {

  /** A rule targeting one exact entry path in the assembly. */
  sealed abstract class PathRule extends Rule {
    def path: String
  }

  /** Drop the entry at `path`. */
  case class Exclude(path: String) extends PathRule

  /** Concatenate the content of all entries at `path` into a single entry. */
  case class Append(path: String) extends PathRule

  /** Drop any entry whose name matches the regex `path`. */
  case class ExcludePattern(path: Pattern) extends Rule

  object ExcludePattern {
    def apply(s: String): ExcludePattern =
      ExcludePattern(Pattern.compile(s))
  }
}
/** Builds an assembly jar at `output` by merging the entries of `jars`,
  * applying the supplied merge `rules`.
  *
  * Fixes in this revision:
  *  - the concatenated-entries loop called `zos.close()` after the first
  *    entry, closing the whole output stream mid-write — it must be
  *    `zos.closeEntry()`;
  *  - removed a leftover `???` that made the method throw after a
  *    successful build, and leftover duplicate stream-creation lines that
  *    leaked an unclosed `FileOutputStream`/`ZipInputStream`.
  *
  * NOTE(review): `ZipOutputStream.putNextEntry` throws `ZipException` on a
  * duplicate entry name; entries sharing a name across `jars` that are not
  * covered by an `Append`/`Exclude` rule will fail the build — confirm the
  * rule set is meant to cover all duplicates.
  */
def make(jars: Seq[File], output: File, rules: Seq[Rule]): Unit = {

  // Exact-path rules indexed by path; pattern rules checked separately below.
  val rulesMap = rules.collect { case r: Rule.PathRule => r.path -> r }.toMap
  val excludePatterns = rules.collect { case Rule.ExcludePattern(p) => p }

  var fos: FileOutputStream = null
  var zos: ZipOutputStream = null

  try {
    fos = new FileOutputStream(output)
    zos = new ZipOutputStream(fos)

    // Entries accumulated by Append rules, kept newest-first (prepend is
    // cheap); reversed before writing so content keeps the original order.
    val concatenatedEntries = new mutable.HashMap[String, ::[(ZipEntry, Array[Byte])]]

    for (jar <- jars) {
      var fis: FileInputStream = null
      var zis: ZipInputStream = null

      try {
        fis = new FileInputStream(jar)
        zis = new ZipInputStream(fis)

        for ((ent, content) <- Zip.zipEntries(zis))
          rulesMap.get(ent.getName) match {
            case Some(Rule.Exclude(_)) =>
              // explicitly excluded — dropped
            case Some(Rule.Append(path)) =>
              concatenatedEntries += path -> ::((ent, content), concatenatedEntries.getOrElse(path, Nil))
            case None =>
              if (!excludePatterns.exists(_.matcher(ent.getName).matches())) {
                zos.putNextEntry(ent)
                zos.write(content)
                zos.closeEntry()
              }
          }
      } finally {
        if (zis != null)
          zis.close()
        if (fis != null)
          fis.close()
      }
    }

    // Write each appended path once, concatenating all collected contents.
    for ((_, entries) <- concatenatedEntries) {
      val (ent, _) = entries.head
      zos.putNextEntry(ent)
      for ((_, b) <- entries.reverse)
        zos.write(b)
      // was zos.close(): closed the whole stream after the first entry
      zos.closeEntry()
    }
  } finally {
    if (zos != null)
      zos.close()
    if (fos != null)
      fos.close()
  }
}
/** Default merge rules for Spark assemblies: concatenate Hadoop FileSystem
  * service registrations and reference.conf files, drop log4j.properties
  * and jar signature files.
  *
  * NOTE(review): the exclude patterns are compiled as java.util.regex
  * patterns but are written glob-style ("META-INF/*...") — they happen to
  * match the intended signature entries; confirm this is deliberate.
  */
val assemblyRules: Seq[Rule] = {

  val appended = Seq(
    "META-INF/services/org.apache.hadoop.fs.FileSystem",
    "reference.conf"
  ).map(Rule.Append(_))

  val signatureFiles = Seq("[sS][fF]", "[dD][sS][aA]", "[rR][sS][aA]")
    .map(ext => Rule.ExcludePattern(s"META-INF/*.$ext"))

  appended ++ Seq(Rule.Exclude("log4j.properties")) ++ signatureFiles
}
/** Maven coordinates ("org:name:version" strings) of the Spark modules
  * bundled into the default Spark assembly, for the given Scala and Spark
  * versions.
  *
  * FIX: these were plain string literals, so `$scalaVersion` and
  * `$sparkVersion` were emitted verbatim instead of being substituted —
  * every coordinate was unresolvable. The `s` interpolator is now applied.
  */
def sparkAssemblyDependencies(
  scalaVersion: String,
  sparkVersion: String
): Seq[String] = {
  val modules = Seq(
    "spark-core",
    "spark-bagel",
    "spark-mllib",
    "spark-streaming",
    "spark-graphx",
    "spark-sql",
    "spark-repl",
    "spark-yarn"
  )
  modules.map(m => s"org.apache.spark:${m}_$scalaVersion:$sparkVersion")
}
/** Resolves the Spark assembly dependencies, builds (or reuses) the
  * corresponding assembly jar under `~/.coursier/spark-assemblies`, and
  * returns the assembly file along with the resolved jars.
  *
  * The assembly path embeds a SHA-1 fingerprint of the sorted artifact
  * checksums, so a cached assembly is reused only when the exact same jar
  * set is requested.
  *
  * FIX: removed old-signature remnants that were interleaved with the new
  * body (the `extraDependencies: Seq[Dependency]` parameter list, a
  * duplicate return-type line, and the
  * `throw new Exception("Not implemented: automatic assembly generation")`
  * stub), which made the method unparseable.
  *
  * @param noDefault when true, only `extraDependencies` go into the assembly
  */
def spark(
  scalaVersion: String,
  sparkVersion: String,
  noDefault: Boolean,
  extraDependencies: Seq[String],
  options: CommonOptions
): Either[String, (File, Seq[File])] = {

  val base =
    if (noDefault) Seq()
    else sparkAssemblyDependencies(scalaVersion, sparkVersion)

  val helper = new Helper(options, extraDependencies ++ base)

  val artifacts = helper.artifacts(sources = false, javadoc = false)
  val jars = helper.fetch(sources = false, javadoc = false)

  // One hex-encoded SHA-1 per artifact, read from the repository-provided
  // checksum files in the local cache.
  val checksums = artifacts.map { a =>
    val f = a.checksumUrls.get("SHA-1") match {
      case Some(url) =>
        Cache.localFile(url, helper.cache, a.authentication.map(_.user))
      case None =>
        throw new Exception(s"SHA-1 file not found for ${a.url}")
    }

    val sumOpt = Cache.parseChecksum(
      new String(Files.readAllBytes(f.toPath), "UTF-8")
    )

    sumOpt match {
      case Some(sum) =>
        val s = sum.toString(16)
        // left-pad to the full 40 hex digits of a SHA-1
        "0" * (40 - s.length) + s
      case None =>
        throw new Exception(s"Cannot read SHA-1 sum from $f")
    }
  }

  // Fingerprint of the whole jar set: SHA-1 over the sorted checksums, so
  // ordering of the resolution result doesn't change the assembly path.
  val md = MessageDigest.getInstance("SHA-1")
  for (c <- checksums.sorted) {
    val b = c.getBytes("UTF-8")
    md.update(b, 0, b.length)
  }

  val digest = md.digest()
  val calculatedSum = new BigInteger(1, digest)
  val s = calculatedSum.toString(16)
  val sum = "0" * (40 - s.length) + s

  val destPath = Seq(
    sys.props("user.home"),
    ".coursier",
    "spark-assemblies",
    s"scala_${scalaVersion}_spark_$sparkVersion",
    sum,
    "spark-assembly.jar"
  ).mkString("/")

  val dest = new File(destPath)

  def success = Right((dest, jars))

  if (dest.exists())
    success
  else
    // Build under the cache lock, into a temporary file moved atomically
    // into place, so concurrent invocations don't see a half-written jar.
    Cache.withLockFor(helper.cache, dest) {
      dest.getParentFile.mkdirs()
      val tmpDest = new File(dest.getParentFile, s".${dest.getName}.part")
      // FIXME Acquire lock on tmpDest
      Assembly.make(jars, tmpDest, assemblyRules)
      Files.move(tmpDest.toPath, dest.toPath, StandardCopyOption.ATOMIC_MOVE)
      \/-((dest, jars))
    }.leftMap(_.describe).toEither
}
}

View File

@@ -0,0 +1,28 @@
package coursier.cli.util
import java.util.zip.{ZipEntry, ZipInputStream}
import coursier.Platform
/** Helpers for walking zip streams. */
object Zip {

  /** Iterates over the entries of `zipStream`, pairing each `ZipEntry` with
    * its fully-read byte content.
    *
    * The underlying stream is advanced as the iterator is consumed, so
    * entries must be processed strictly in order and the stream must not be
    * read elsewhere while iterating.
    */
  def zipEntries(zipStream: ZipInputStream): Iterator[(ZipEntry, Array[Byte])] =
    new Iterator[(ZipEntry, Array[Byte])] {
      // One-entry look-ahead: the entry next() will return, None when done.
      var pending = Option(zipStream.getNextEntry)
      def hasNext = pending.nonEmpty
      def next() = pending match {
        case Some(entry) =>
          // Read the current entry's bytes before moving to the next header.
          val bytes = Platform.readFullySync(zipStream)
          pending = Option(zipStream.getNextEntry)
          (entry, bytes)
        case None =>
          throw new NoSuchElementException("next on empty zip entry iterator")
      }
    }
}