Merge pull request #350 from alexarchambault/topic/spark-submit

Tweak spark submit
Alexandre Archambault 2016-09-22 15:09:55 +02:00 committed by GitHub
commit 51f4fc93ab
8 changed files with 273 additions and 68 deletions

View File

@ -53,7 +53,7 @@ object Cache {
}
}
private def localFile(url: String, cache: File, user: Option[String]): File = {
def localFile(url: String, cache: File, user: Option[String]): File = {
val path =
if (url.startsWith("file:///"))
url.stripPrefix("file://")
@ -155,7 +155,7 @@ object Cache {
}
}
private def withLockFor[T](cache: File, file: File)(f: => FileError \/ T): FileError \/ T = {
def withLockFor[T](cache: File, file: File)(f: => FileError \/ T): FileError \/ T = {
val lockFile = new File(file.getParentFile, s"${file.getName}.lock")
var out: FileOutputStream = null
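
Note on the Cache changes above: localFile and withLockFor lose their private modifier so that the Spark assembly code added later in this PR can reuse them. A minimal sketch of calling withLockFor from CLI code, following the signature shown here (cacheDir and assemblyFile are illustrative values, not part of the PR):

import java.io.File

import coursier.Cache
import scalaz.\/-

object LockExample {
  val cacheDir = new File(sys.props("user.home") + "/.coursier/cache")
  val assemblyFile = new File(cacheDir, "spark-assemblies/spark-assembly.jar")

  // withLockFor serializes concurrent writers on assemblyFile through a sibling
  // .lock file, and only runs the body while that lock is held.
  val result = Cache.withLockFor(cacheDir, assemblyFile) {
    // ... produce assemblyFile here, then signal success ...
    \/-(assemblyFile)
  }
}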

View File

@ -1,13 +1,14 @@
package coursier
package cli
import java.io.{ FileInputStream, ByteArrayInputStream, ByteArrayOutputStream, File, IOException }
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileInputStream, IOException}
import java.nio.file.Files
import java.nio.file.attribute.PosixFilePermission
import java.util.Properties
import java.util.zip.{ ZipEntry, ZipOutputStream, ZipInputStream }
import java.util.zip.{ZipEntry, ZipInputStream, ZipOutputStream}
import caseapp._
import coursier.cli.util.Zip
case class Bootstrap(
@Recurse
@ -61,26 +62,6 @@ case class Bootstrap(
sys.exit(1)
}
def zipEntries(zipStream: ZipInputStream): Iterator[(ZipEntry, Array[Byte])] =
new Iterator[(ZipEntry, Array[Byte])] {
var nextEntry = Option.empty[ZipEntry]
def update() =
nextEntry = Option(zipStream.getNextEntry)
update()
def hasNext = nextEntry.nonEmpty
def next() = {
val ent = nextEntry.get
val data = Platform.readFullySync(zipStream)
update()
(ent, data)
}
}
val isolatedDeps = options.isolated.isolatedDeps(
options.common.defaultArtifactType,
options.common.scalaVersion
@ -141,7 +122,7 @@ case class Bootstrap(
val bootstrapZip = new ZipInputStream(new ByteArrayInputStream(bootstrapJar))
val outputZip = new ZipOutputStream(buffer)
for ((ent, data) <- zipEntries(bootstrapZip)) {
for ((ent, data) <- Zip.zipEntries(bootstrapZip)) {
outputZip.putNextEntry(ent)
outputZip.write(data)
outputZip.closeEntry()
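
The loop above now goes through the zipEntries helper that this PR moves out of Bootstrap into coursier.cli.util.Zip (last file below). A self-contained sketch of the same copy pattern, with illustrative names, plus the compressed-size reset that the new Assembly code further down also applies:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.util.zip.{ZipInputStream, ZipOutputStream}

import coursier.cli.util.Zip

object ZipCopyExample {
  // Copies every entry of a zip (given as bytes) into a fresh zip, unchanged.
  def copyZip(input: Array[Byte]): Array[Byte] = {
    val buffer = new ByteArrayOutputStream
    val zis = new ZipInputStream(new ByteArrayInputStream(input))
    val zos = new ZipOutputStream(buffer)
    for ((ent, data) <- Zip.zipEntries(zis)) {
      ent.setCompressedSize(-1L) // let ZipOutputStream recompute it on write
      zos.putNextEntry(ent)
      zos.write(data)
      zos.closeEntry()
    }
    zos.close()
    buffer.toByteArray
  }
}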

View File

@ -75,6 +75,7 @@ object Util {
class Helper(
common: CommonOptions,
rawDependencies: Seq[String],
extraJars: Seq[File] = Nil,
printResultStdout: Boolean = false,
ignoreErrors: Boolean = false,
isolated: IsolatedLoaderOptions = IsolatedLoaderOptions(),
@ -584,7 +585,7 @@ class Helper(
def contextLoader = Thread.currentThread().getContextClassLoader
// TODO Would ClassLoader.getSystemClassLoader be better here?
val baseLoader: ClassLoader =
lazy val baseLoader: ClassLoader =
Launch.mainClassLoader(contextLoader)
.flatMap(cl => Option(cl.getParent))
.getOrElse {
@ -648,7 +649,7 @@ class Helper(
}
lazy val loader = new URLClassLoader(
filteredFiles.map(_.toURI.toURL).toArray,
(filteredFiles ++ extraJars).map(_.toURI.toURL).toArray,
parentLoader
)
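
The Helper changes above add an extraJars parameter, make baseLoader lazy so it is only computed when a classloader is actually needed, and append the extra files to the URLs of the URLClassLoader that Helper builds. Condensed into a standalone sketch (names are placeholders for what Helper computes):

import java.io.File
import java.net.URLClassLoader

object LoaderExample {
  // resolvedJars stands for the artifacts coursier resolved (filteredFiles in Helper),
  // extraJars for the files passed explicitly by the caller.
  def loaderFor(resolvedJars: Seq[File], extraJars: Seq[File], parent: ClassLoader): URLClassLoader =
    new URLClassLoader(
      (resolvedJars ++ extraJars).map(_.toURI.toURL).toArray,
      parent
    )
}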

View File

@ -1,6 +1,7 @@
package coursier
package cli
import java.io.File
import java.net.{ URL, URLClassLoader }
import caseapp._
@ -114,6 +115,7 @@ case class Launch(
val helper = new Helper(
options.common,
remainingArgs ++ options.isolated.rawIsolated.map { case (_, dep) => dep },
extraJars = options.extraJars.map(new File(_)),
isolated = options.isolated
)
@ -123,8 +125,19 @@ case class Launch(
else
options.mainClass
val extraJars = options.extraJars.filter(_.nonEmpty)
val loader =
if (extraJars.isEmpty)
helper.loader
else
new URLClassLoader(
extraJars.map(new File(_).toURI.toURL).toArray,
helper.loader
)
Launch.run(
helper.loader,
loader,
mainClass,
userArgs,
options.common.verbosityLevel
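
Launch now forwards the extra JARs to Helper and, when any are given, additionally wraps helper.loader in a child URLClassLoader built from them. A sketch of that selection logic (names illustrative); on the command line the JARs come from the new -J option declared in the options file below:

import java.io.File
import java.net.URLClassLoader

object LaunchLoaderExample {
  // base is helper.loader; extraJars holds the raw -J values, already filtered
  // for non-empty strings.
  def launchLoader(base: ClassLoader, extraJars: Seq[String]): ClassLoader =
    if (extraJars.isEmpty)
      base
    else
      new URLClassLoader(extraJars.map(new File(_).toURI.toURL).toArray, base)
}

Per the help texts below, launch accepts directories here as well, while spark-submit rejects them (see the check added in SparkSubmit further down).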

View File

@ -188,6 +188,9 @@ case class LaunchOptions(
@Short("M")
@Short("main")
mainClass: String,
@Short("J")
@Help("Extra JARs to be added to the classpath of the launched application. Directories accepted too.")
extraJars: List[String],
@Recurse
isolated: IsolatedLoaderOptions,
@Recurse
@ -226,6 +229,9 @@ case class SparkSubmitOptions(
@Short("main")
@Help("Main class to be launched (optional if in manifest)")
mainClass: String,
@Short("J")
@Help("Extra JARs to be added in the classpath of the job")
extraJars: List[String],
@Help("If master is yarn-cluster, write YARN app ID to a file. (The ID is deduced from the spark-submit output.)")
@Value("file")
yarnIdFile: String,

View File

@ -55,8 +55,23 @@ case class SparkSubmit(
options: SparkSubmitOptions
) extends App with ExtraArgsApp {
val helper = new Helper(options.common, remainingArgs)
val jars = helper.fetch(sources = false, javadoc = false)
val rawExtraJars = options.extraJars.map(new File(_))
val extraDirs = rawExtraJars.filter(_.isDirectory)
if (extraDirs.nonEmpty) {
Console.err.println(s"Error: directories not allowed in extra job JARs.")
Console.err.println(extraDirs.map(" " + _).mkString("\n"))
sys.exit(1)
}
val helper: Helper = new Helper(
options.common,
remainingArgs,
extraJars = rawExtraJars
)
val jars =
helper.fetch(sources = false, javadoc = false) ++
options.extraJars.map(new File(_))
val (scalaVersion, sparkVersion) =
if (options.sparkVersion.isEmpty)
@ -73,31 +88,15 @@ case class SparkSubmit(
(options.common.scalaVersion, options.sparkVersion)
val assemblyOrError =
if (options.sparkAssembly.isEmpty) {
// FIXME Also vaguely done in Helper and below
val (errors, modVers) = Parse.moduleVersionConfigs(
options.assemblyDependencies,
options.common.scalaVersion
if (options.sparkAssembly.isEmpty)
Assembly.spark(
scalaVersion,
sparkVersion,
options.noDefaultAssemblyDependencies,
options.assemblyDependencies.flatMap(_.split(",")).filter(_.nonEmpty),
options.common
)
val deps = modVers.map {
case (module, version, configOpt) =>
Dependency(
module,
version,
attributes = Attributes(options.common.defaultArtifactType, ""),
configuration = configOpt.getOrElse(options.common.defaultConfiguration),
exclusions = helper.excludes
)
}
if (errors.isEmpty)
Assembly.spark(scalaVersion, sparkVersion, options.noDefaultAssemblyDependencies, deps)
else
Left(s"Cannot parse assembly dependencies:\n${errors.map(" " + _).mkString("\n")}")
} else {
else {
val f = new File(options.sparkAssembly)
if (f.isFile)
Right((f, Nil))
@ -170,7 +169,7 @@ case class SparkSubmit(
scalaVersion,
sparkVersion,
options.noDefaultSubmitDependencies,
options.submitDependencies,
options.submitDependencies.flatMap(_.split(",")).filter(_.nonEmpty),
options.common
)
@ -321,4 +320,4 @@ object OutputHelper {
threads.foreach(_.start())
}
}
}
}
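
Also worth noting in the SparkSubmit changes: assembly and submit dependencies are now normalized so that each option occurrence may itself carry a comma-separated list. A sketch of that normalization, matching the flatMap/filter calls above (the coordinates in the comment are made up):

object DepSplitExample {
  // Each raw value may hold several coordinates separated by commas; empty
  // fragments (e.g. from an empty flag value) are dropped.
  def splitDeps(raw: Seq[String]): Seq[String] =
    raw.flatMap(_.split(",")).filter(_.nonEmpty)

  // splitDeps(Seq("org:a_2.11:1.0,org:b_2.11:2.0", ""))
  //   == Seq("org:a_2.11:1.0", "org:b_2.11:2.0")
}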

View File

@ -1,36 +1,213 @@
package coursier.cli.spark
import java.io.{File, FileInputStream, FileOutputStream}
import java.util.zip.{ZipInputStream, ZipOutputStream}
import java.math.BigInteger
import java.nio.file.{Files, StandardCopyOption}
import java.security.MessageDigest
import java.util.jar.{Attributes, JarFile, JarOutputStream, Manifest}
import java.util.regex.Pattern
import java.util.zip.{ZipEntry, ZipInputStream, ZipOutputStream}
import coursier.Dependency
import coursier.Cache
import coursier.cli.{CommonOptions, Helper}
import coursier.cli.util.Zip
import scala.collection.mutable
import scalaz.\/-
object Assembly {
sealed abstract class Rule extends Product with Serializable
object Rule {
case class Exclude(path: String) extends Rule
case class Append(path: String) extends Rule
sealed abstract class PathRule extends Rule {
def path: String
}
case class Exclude(path: String) extends PathRule
case class Append(path: String) extends PathRule
case class ExcludePattern(path: Pattern) extends Rule
object ExcludePattern {
def apply(s: String): ExcludePattern =
ExcludePattern(Pattern.compile(s))
}
}
def make(jars: Seq[File], output: File, rules: Seq[Rule]): Unit = {
val zos = new ZipOutputStream(new FileOutputStream(output))
val rulesMap = rules.collect { case r: Rule.PathRule => r.path -> r }.toMap
val excludePatterns = rules.collect { case Rule.ExcludePattern(p) => p }
for (jar <- jars) {
new ZipInputStream(new FileInputStream(jar))
val manifest = new Manifest
manifest.getMainAttributes.put(Attributes.Name.MANIFEST_VERSION, "1.0")
var fos: FileOutputStream = null
var zos: ZipOutputStream = null
try {
fos = new FileOutputStream(output)
zos = new JarOutputStream(fos, manifest)
val concatenedEntries = new mutable.HashMap[String, ::[(ZipEntry, Array[Byte])]]
var ignore = Set.empty[String]
for (jar <- jars) {
var fis: FileInputStream = null
var zis: ZipInputStream = null
try {
fis = new FileInputStream(jar)
zis = new ZipInputStream(fis)
for ((ent, content) <- Zip.zipEntries(zis))
rulesMap.get(ent.getName) match {
case Some(Rule.Exclude(_)) =>
// ignored
case Some(Rule.Append(path)) =>
concatenedEntries += path -> ::((ent, content), concatenedEntries.getOrElse(path, Nil))
case None =>
if (!excludePatterns.exists(_.matcher(ent.getName).matches()) && !ignore(ent.getName)) {
ent.setCompressedSize(-1L)
zos.putNextEntry(ent)
zos.write(content)
zos.closeEntry()
ignore += ent.getName
}
}
} finally {
if (zis != null)
zis.close()
if (fis != null)
fis.close()
}
}
for ((_, entries) <- concatenedEntries) {
val (ent, _) = entries.head
ent.setCompressedSize(-1L)
if (entries.tail.nonEmpty)
ent.setSize(entries.map(_._2.length).sum)
zos.putNextEntry(ent)
// for ((_, b) <- entries.reverse)
// zos.write(b)
zos.write(entries.reverse.toArray.flatMap(_._2))
zos.closeEntry()
}
} finally {
if (zos != null)
zos.close()
if (fos != null)
fos.close()
}
???
}
val assemblyRules = Seq[Rule](
Rule.Append("META-INF/services/org.apache.hadoop.fs.FileSystem"),
Rule.Append("reference.conf"),
Rule.Exclude("log4j.properties"),
Rule.Exclude(JarFile.MANIFEST_NAME),
Rule.ExcludePattern("META-INF/.*\\.[sS][fF]"),
Rule.ExcludePattern("META-INF/.*\\.[dD][sS][aA]"),
Rule.ExcludePattern("META-INF/.*\\.[rR][sS][aA]")
)
def sparkAssemblyDependencies(
scalaVersion: String,
sparkVersion: String
) = Seq(
s"org.apache.spark:spark-core_$scalaVersion:$sparkVersion",
s"org.apache.spark:spark-bagel_$scalaVersion:$sparkVersion",
s"org.apache.spark:spark-mllib_$scalaVersion:$sparkVersion",
s"org.apache.spark:spark-streaming_$scalaVersion:$sparkVersion",
s"org.apache.spark:spark-graphx_$scalaVersion:$sparkVersion",
s"org.apache.spark:spark-sql_$scalaVersion:$sparkVersion",
s"org.apache.spark:spark-repl_$scalaVersion:$sparkVersion",
s"org.apache.spark:spark-yarn_$scalaVersion:$sparkVersion"
)
def spark(
scalaVersion: String,
sparkVersion: String,
noDefault: Boolean,
extraDependencies: Seq[Dependency]
): Either[String, (File, Seq[File])] =
throw new Exception("Not implemented: automatic assembly generation")
extraDependencies: Seq[String],
options: CommonOptions
): Either[String, (File, Seq[File])] = {
val base = if (noDefault) Seq() else sparkAssemblyDependencies(scalaVersion, sparkVersion)
val helper = new Helper(options, extraDependencies ++ base)
val artifacts = helper.artifacts(sources = false, javadoc = false)
val jars = helper.fetch(sources = false, javadoc = false)
val checksums = artifacts.map { a =>
val f = a.checksumUrls.get("SHA-1") match {
case Some(url) =>
Cache.localFile(url, helper.cache, a.authentication.map(_.user))
case None =>
throw new Exception(s"SHA-1 file not found for ${a.url}")
}
val sumOpt = Cache.parseChecksum(
new String(Files.readAllBytes(f.toPath), "UTF-8")
)
sumOpt match {
case Some(sum) =>
val s = sum.toString(16)
"0" * (40 - s.length) + s
case None =>
throw new Exception(s"Cannot read SHA-1 sum from $f")
}
}
val md = MessageDigest.getInstance("SHA-1")
for (c <- checksums.sorted) {
val b = c.getBytes("UTF-8")
md.update(b, 0, b.length)
}
val digest = md.digest()
val calculatedSum = new BigInteger(1, digest)
val s = calculatedSum.toString(16)
val sum = "0" * (40 - s.length) + s
val destPath = Seq(
sys.props("user.home"),
".coursier",
"spark-assemblies",
s"scala_${scalaVersion}_spark_$sparkVersion",
sum,
"spark-assembly.jar"
).mkString("/")
val dest = new File(destPath)
def success = Right((dest, jars))
if (dest.exists())
success
else
Cache.withLockFor(helper.cache, dest) {
dest.getParentFile.mkdirs()
val tmpDest = new File(dest.getParentFile, s".${dest.getName}.part")
// FIXME Acquire lock on tmpDest
Assembly.make(jars, tmpDest, assemblyRules)
Files.move(tmpDest.toPath, dest.toPath, StandardCopyOption.ATOMIC_MOVE)
\/-((dest, jars))
}.leftMap(_.describe).toEither
}
}
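
The rewritten Assembly above replaces the former "not implemented" stubs: make merges the input JARs following the rules (Exclude drops an entry, Append concatenates all occurrences, e.g. reference.conf and the Hadoop FileSystem service file, ExcludePattern drops signature files), while spark resolves the assembly dependencies, derives a deterministic cache key from the artifacts' SHA-1 sums, and builds the assembly under ~/.coursier/spark-assemblies behind Cache.withLockFor with an atomic move into place. The key derivation, extracted as a standalone sketch (assemblyKey is a hypothetical helper name):

import java.math.BigInteger
import java.security.MessageDigest

object AssemblyKeyExample {
  // Hash the per-artifact SHA-1 sums (sorted, so the key does not depend on
  // resolution order) and left-pad the hex digest to 40 characters.
  def assemblyKey(artifactSha1Sums: Seq[String]): String = {
    val md = MessageDigest.getInstance("SHA-1")
    for (c <- artifactSha1Sums.sorted) {
      val b = c.getBytes("UTF-8")
      md.update(b, 0, b.length)
    }
    val s = new BigInteger(1, md.digest()).toString(16)
    "0" * (40 - s.length) + s
  }
}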

View File

@ -0,0 +1,28 @@
package coursier.cli.util
import java.util.zip.{ZipEntry, ZipInputStream}
import coursier.Platform
object Zip {
def zipEntries(zipStream: ZipInputStream): Iterator[(ZipEntry, Array[Byte])] =
new Iterator[(ZipEntry, Array[Byte])] {
var nextEntry = Option.empty[ZipEntry]
def update() =
nextEntry = Option(zipStream.getNextEntry)
update()
def hasNext = nextEntry.nonEmpty
def next() = {
val ent = nextEntry.get
val data = Platform.readFullySync(zipStream)
update()
(ent, data)
}
}
}
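
Finally, a small usage example for the extracted helper: zipEntries walks the stream lazily and yields each entry together with its fully-read content, so the stream must stay open while the iterator is consumed (the JAR path below is illustrative).

import java.io.FileInputStream
import java.util.zip.ZipInputStream

import coursier.cli.util.Zip

object ZipListExample extends App {
  val zis = new ZipInputStream(new FileInputStream("some-app.jar"))
  try
    for ((ent, data) <- Zip.zipEntries(zis))
      println(s"${ent.getName}: ${data.length} bytes")
  finally
    zis.close()
}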