mirror of https://github.com/sbt/sbt.git
WIP - Generate spark assemblies on the fly
This commit is contained in:
parent
b7439cac50
commit
a50cb1bd85
|
|
@ -53,7 +53,7 @@ object Cache {
|
|||
}
|
||||
}
|
||||
|
||||
private def localFile(url: String, cache: File, user: Option[String]): File = {
|
||||
def localFile(url: String, cache: File, user: Option[String]): File = {
|
||||
val path =
|
||||
if (url.startsWith("file:///"))
|
||||
url.stripPrefix("file://")
|
||||
|
|
@ -155,7 +155,7 @@ object Cache {
|
|||
}
|
||||
}
|
||||
|
||||
private def withLockFor[T](cache: File, file: File)(f: => FileError \/ T): FileError \/ T = {
|
||||
def withLockFor[T](cache: File, file: File)(f: => FileError \/ T): FileError \/ T = {
|
||||
val lockFile = new File(file.getParentFile, s"${file.getName}.lock")
|
||||
|
||||
var out: FileOutputStream = null
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
package coursier
|
||||
package cli
|
||||
|
||||
import java.io.{ FileInputStream, ByteArrayInputStream, ByteArrayOutputStream, File, IOException }
|
||||
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileInputStream, IOException}
|
||||
import java.nio.file.Files
|
||||
import java.nio.file.attribute.PosixFilePermission
|
||||
import java.util.Properties
|
||||
import java.util.zip.{ ZipEntry, ZipOutputStream, ZipInputStream }
|
||||
import java.util.zip.{ZipEntry, ZipInputStream, ZipOutputStream}
|
||||
|
||||
import caseapp._
|
||||
import coursier.cli.util.Zip
|
||||
|
||||
case class Bootstrap(
|
||||
@Recurse
|
||||
|
|
@ -61,26 +62,6 @@ case class Bootstrap(
|
|||
sys.exit(1)
|
||||
}
|
||||
|
||||
def zipEntries(zipStream: ZipInputStream): Iterator[(ZipEntry, Array[Byte])] =
|
||||
new Iterator[(ZipEntry, Array[Byte])] {
|
||||
var nextEntry = Option.empty[ZipEntry]
|
||||
def update() =
|
||||
nextEntry = Option(zipStream.getNextEntry)
|
||||
|
||||
update()
|
||||
|
||||
def hasNext = nextEntry.nonEmpty
|
||||
def next() = {
|
||||
val ent = nextEntry.get
|
||||
val data = Platform.readFullySync(zipStream)
|
||||
|
||||
update()
|
||||
|
||||
(ent, data)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
val isolatedDeps = options.isolated.isolatedDeps(
|
||||
options.common.defaultArtifactType,
|
||||
options.common.scalaVersion
|
||||
|
|
@ -141,7 +122,7 @@ case class Bootstrap(
|
|||
val bootstrapZip = new ZipInputStream(new ByteArrayInputStream(bootstrapJar))
|
||||
val outputZip = new ZipOutputStream(buffer)
|
||||
|
||||
for ((ent, data) <- zipEntries(bootstrapZip)) {
|
||||
for ((ent, data) <- Zip.zipEntries(bootstrapZip)) {
|
||||
outputZip.putNextEntry(ent)
|
||||
outputZip.write(data)
|
||||
outputZip.closeEntry()
|
||||
|
|
|
|||
|
|
@ -94,7 +94,13 @@ case class SparkSubmit(
|
|||
}
|
||||
|
||||
if (errors.isEmpty)
|
||||
Assembly.spark(scalaVersion, sparkVersion, options.noDefaultAssemblyDependencies, deps)
|
||||
Assembly.spark(
|
||||
scalaVersion,
|
||||
sparkVersion,
|
||||
options.noDefaultAssemblyDependencies,
|
||||
options.assemblyDependencies,
|
||||
options.common
|
||||
)
|
||||
else
|
||||
Left(s"Cannot parse assembly dependencies:\n${errors.map(" " + _).mkString("\n")}")
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -1,36 +1,196 @@
|
|||
package coursier.cli.spark
|
||||
|
||||
import java.io.{File, FileInputStream, FileOutputStream}
|
||||
import java.util.zip.{ZipInputStream, ZipOutputStream}
|
||||
import java.math.BigInteger
|
||||
import java.nio.file.{Files, StandardCopyOption}
|
||||
import java.security.MessageDigest
|
||||
import java.util.regex.Pattern
|
||||
import java.util.zip.{ZipEntry, ZipInputStream, ZipOutputStream}
|
||||
|
||||
import coursier.Dependency
|
||||
import coursier.Cache
|
||||
import coursier.cli.{CommonOptions, Helper}
|
||||
import coursier.cli.util.Zip
|
||||
|
||||
import scala.collection.mutable
|
||||
import scalaz.\/-
|
||||
|
||||
object Assembly {
|
||||
|
||||
sealed abstract class Rule extends Product with Serializable
|
||||
|
||||
object Rule {
|
||||
case class Exclude(path: String) extends Rule
|
||||
case class Append(path: String) extends Rule
|
||||
sealed abstract class PathRule extends Rule {
|
||||
def path: String
|
||||
}
|
||||
|
||||
case class Exclude(path: String) extends PathRule
|
||||
case class Append(path: String) extends PathRule
|
||||
|
||||
case class ExcludePattern(path: Pattern) extends Rule
|
||||
|
||||
object ExcludePattern {
|
||||
def apply(s: String): ExcludePattern =
|
||||
ExcludePattern(Pattern.compile(s))
|
||||
}
|
||||
}
|
||||
|
||||
def make(jars: Seq[File], output: File, rules: Seq[Rule]): Unit = {
|
||||
|
||||
val zos = new ZipOutputStream(new FileOutputStream(output))
|
||||
val rulesMap = rules.collect { case r: Rule.PathRule => r.path -> r }.toMap
|
||||
val excludePatterns = rules.collect { case Rule.ExcludePattern(p) => p }
|
||||
|
||||
for (jar <- jars) {
|
||||
new ZipInputStream(new FileInputStream(jar))
|
||||
var fos: FileOutputStream = null
|
||||
var zos: ZipOutputStream = null
|
||||
|
||||
try {
|
||||
fos = new FileOutputStream(output)
|
||||
zos = new ZipOutputStream(fos)
|
||||
|
||||
val concatenedEntries = new mutable.HashMap[String, ::[(ZipEntry, Array[Byte])]]
|
||||
|
||||
for (jar <- jars) {
|
||||
var fis: FileInputStream = null
|
||||
var zis: ZipInputStream = null
|
||||
|
||||
try {
|
||||
fis = new FileInputStream(jar)
|
||||
zis = new ZipInputStream(fis)
|
||||
|
||||
for ((ent, content) <- Zip.zipEntries(zis))
|
||||
rulesMap.get(ent.getName) match {
|
||||
case Some(Rule.Exclude(_)) =>
|
||||
// ignored
|
||||
|
||||
case Some(Rule.Append(path)) =>
|
||||
concatenedEntries += path -> ::((ent, content), concatenedEntries.getOrElse(path, Nil))
|
||||
|
||||
case None =>
|
||||
if (!excludePatterns.exists(_.matcher(ent.getName).matches())) {
|
||||
zos.putNextEntry(ent)
|
||||
zos.write(content)
|
||||
zos.closeEntry()
|
||||
}
|
||||
}
|
||||
|
||||
} finally {
|
||||
if (zis != null)
|
||||
zis.close()
|
||||
if (fis != null)
|
||||
fis.close()
|
||||
}
|
||||
}
|
||||
|
||||
for ((path, entries) <- concatenedEntries) {
|
||||
val (ent, _) = entries.head
|
||||
zos.putNextEntry(ent)
|
||||
for ((_, b) <- entries.reverse)
|
||||
zos.write(b)
|
||||
zos.close()
|
||||
}
|
||||
} finally {
|
||||
if (zos != null)
|
||||
zos.close()
|
||||
if (fos != null)
|
||||
fos.close()
|
||||
}
|
||||
|
||||
???
|
||||
}
|
||||
|
||||
val assemblyRules = Seq[Rule](
|
||||
Rule.Append("META-INF/services/org.apache.hadoop.fs.FileSystem"),
|
||||
Rule.Append("reference.conf"),
|
||||
Rule.Exclude("log4j.properties"),
|
||||
Rule.ExcludePattern("META-INF/*.[sS][fF]"),
|
||||
Rule.ExcludePattern("META-INF/*.[dD][sS][aA]"),
|
||||
Rule.ExcludePattern("META-INF/*.[rR][sS][aA]")
|
||||
)
|
||||
|
||||
def sparkAssemblyDependencies(
|
||||
scalaVersion: String,
|
||||
sparkVersion: String
|
||||
) = Seq(
|
||||
"org.apache.spark:spark-core_$scalaVersion:$sparkVersion",
|
||||
"org.apache.spark:spark-bagel_$scalaVersion:$sparkVersion",
|
||||
"org.apache.spark:spark-mllib_$scalaVersion:$sparkVersion",
|
||||
"org.apache.spark:spark-streaming_$scalaVersion:$sparkVersion",
|
||||
"org.apache.spark:spark-graphx_$scalaVersion:$sparkVersion",
|
||||
"org.apache.spark:spark-sql_$scalaVersion:$sparkVersion",
|
||||
"org.apache.spark:spark-repl_$scalaVersion:$sparkVersion",
|
||||
"org.apache.spark:spark-yarn_$scalaVersion:$sparkVersion"
|
||||
)
|
||||
|
||||
def spark(
|
||||
scalaVersion: String,
|
||||
sparkVersion: String,
|
||||
noDefault: Boolean,
|
||||
extraDependencies: Seq[Dependency]
|
||||
): Either[String, (File, Seq[File])] =
|
||||
throw new Exception("Not implemented: automatic assembly generation")
|
||||
extraDependencies: Seq[String],
|
||||
options: CommonOptions
|
||||
): Either[String, (File, Seq[File])] = {
|
||||
|
||||
val base = if (noDefault) Seq() else sparkAssemblyDependencies(scalaVersion, sparkVersion)
|
||||
val helper = new Helper(options, extraDependencies ++ base)
|
||||
|
||||
val artifacts = helper.artifacts(sources = false, javadoc = false)
|
||||
val jars = helper.fetch(sources = false, javadoc = false)
|
||||
|
||||
val checksums = artifacts.map { a =>
|
||||
val f = a.checksumUrls.get("SHA-1") match {
|
||||
case Some(url) =>
|
||||
Cache.localFile(url, helper.cache, a.authentication.map(_.user))
|
||||
case None =>
|
||||
throw new Exception(s"SHA-1 file not found for ${a.url}")
|
||||
}
|
||||
|
||||
val sumOpt = Cache.parseChecksum(
|
||||
new String(Files.readAllBytes(f.toPath), "UTF-8")
|
||||
)
|
||||
|
||||
sumOpt match {
|
||||
case Some(sum) =>
|
||||
val s = sum.toString(16)
|
||||
"0" * (40 - s.length) + s
|
||||
case None =>
|
||||
throw new Exception(s"Cannot read SHA-1 sum from $f")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
val md = MessageDigest.getInstance("SHA-1")
|
||||
|
||||
for (c <- checksums.sorted) {
|
||||
val b = c.getBytes("UTF-8")
|
||||
md.update(b, 0, b.length)
|
||||
}
|
||||
|
||||
val digest = md.digest()
|
||||
val calculatedSum = new BigInteger(1, digest)
|
||||
val s = calculatedSum.toString(16)
|
||||
|
||||
val sum = "0" * (40 - s.length) + s
|
||||
|
||||
val destPath = Seq(
|
||||
sys.props("user.home"),
|
||||
".coursier",
|
||||
"spark-assemblies",
|
||||
s"scala_${scalaVersion}_spark_$sparkVersion",
|
||||
sum,
|
||||
"spark-assembly.jar"
|
||||
).mkString("/")
|
||||
|
||||
val dest = new File(destPath)
|
||||
|
||||
def success = Right((dest, jars))
|
||||
|
||||
if (dest.exists())
|
||||
success
|
||||
else
|
||||
Cache.withLockFor(helper.cache, dest) {
|
||||
dest.getParentFile.mkdirs()
|
||||
val tmpDest = new File(dest.getParentFile, s".${dest.getName}.part")
|
||||
// FIXME Acquire lock on tmpDest
|
||||
Assembly.make(jars, tmpDest, assemblyRules)
|
||||
Files.move(tmpDest.toPath, dest.toPath, StandardCopyOption.ATOMIC_MOVE)
|
||||
\/-((dest, jars))
|
||||
}.leftMap(_.describe).toEither
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
package coursier.cli.util
|
||||
|
||||
import java.util.zip.{ZipEntry, ZipInputStream}
|
||||
|
||||
import coursier.Platform
|
||||
|
||||
object Zip {
|
||||
|
||||
def zipEntries(zipStream: ZipInputStream): Iterator[(ZipEntry, Array[Byte])] =
|
||||
new Iterator[(ZipEntry, Array[Byte])] {
|
||||
var nextEntry = Option.empty[ZipEntry]
|
||||
def update() =
|
||||
nextEntry = Option(zipStream.getNextEntry)
|
||||
|
||||
update()
|
||||
|
||||
def hasNext = nextEntry.nonEmpty
|
||||
def next() = {
|
||||
val ent = nextEntry.get
|
||||
val data = Platform.readFullySync(zipStream)
|
||||
|
||||
update()
|
||||
|
||||
(ent, data)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
Loading…
Reference in New Issue