Merge pull request #346 from alexarchambault/topic/standalone-spark-submit

Rework spark submit command
Alexandre Archambault 2016-09-09 12:28:24 +02:00 committed by GitHub
commit 12f0cd0b82
8 changed files with 410 additions and 194 deletions

View File

@@ -581,35 +581,36 @@ class Helper(
files0
}
lazy val (parentLoader, filteredFiles) = {
def contextLoader = Thread.currentThread().getContextClassLoader
val contextLoader = Thread.currentThread().getContextClassLoader
// TODO Would ClassLoader.getSystemClassLoader be better here?
val baseLoader: ClassLoader =
Launch.mainClassLoader(contextLoader)
.flatMap(cl => Option(cl.getParent))
.getOrElse {
// proguarded -> no risk of conflicts, no absolute need to find a specific ClassLoader
val isProguarded = Try(contextLoader.loadClass("coursier.cli.Launch")).isFailure
if (warnBaseLoaderNotFound && !isProguarded && common.verbosityLevel >= 0)
Console.err.println(
"Warning: cannot find the main ClassLoader that launched coursier.\n" +
"Was coursier launched by its main launcher? " +
"The ClassLoader of the application that is about to be launched will be intertwined " +
"with the one of coursier, which may be a problem if their dependencies conflict."
)
contextLoader
}
lazy val (parentLoader, filteredFiles) = {
val files0 = fetch(sources = false, javadoc = false)
val parentLoader0: ClassLoader =
Launch.mainClassLoader(contextLoader)
.flatMap(cl => Option(cl.getParent))
.getOrElse {
// proguarded -> no risk of conflicts, no absolute need to find a specific ClassLoader
val isProguarded = Try(contextLoader.loadClass("coursier.cli.Launch")).isFailure
if (warnBaseLoaderNotFound && !isProguarded && common.verbosityLevel >= 0)
Console.err.println(
"Warning: cannot find the main ClassLoader that launched coursier.\n" +
"Was coursier launched by its main launcher? " +
"The ClassLoader of the application that is about to be launched will be intertwined " +
"with the one of coursier, which may be a problem if their dependencies conflict."
)
contextLoader
}
if (isolated.isolated.isEmpty)
(parentLoader0, files0)
(baseLoader, files0)
else {
val isolatedDeps = isolated.isolatedDeps(common.defaultArtifactType, common.scalaVersion)
val (isolatedLoader, filteredFiles0) = isolated.targets.foldLeft((parentLoader0, files0)) {
val (isolatedLoader, filteredFiles0) = isolated.targets.foldLeft((baseLoader, files0)) {
case ((parent, files0), target) =>
// FIXME These were already fetched above
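The fold above layers one loader per isolation target on top of the base loader. A minimal sketch of that layering idea, with a hypothetical helper (the actual code uses IsolatedClassLoader, which also keeps track of the target names):

import java.io.File
import java.net.URLClassLoader

object IsolationSketch {
  // Each isolation target gets its own loader, parented to the previous one,
  // so its files are visible to the loaders stacked above it but kept out of
  // the application loader's own URLs.
  def layered(base: ClassLoader, targets: Seq[(String, Seq[File])]): ClassLoader =
    targets.foldLeft(base) {
      case (parent, (_, files)) =>
        new URLClassLoader(files.map(_.toURI.toURL).toArray, parent)
    }
}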

View File

@@ -33,6 +33,43 @@ object Launch {
mainClassLoader(cl.getParent)
}
def run(
loader: ClassLoader,
mainClass: String,
args: Seq[String],
verbosity: Int,
beforeMain: => Unit = ()
): Unit = {
val cls =
try loader.loadClass(mainClass)
catch { case e: ClassNotFoundException =>
Helper.errPrintln(s"Error: class $mainClass not found")
sys.exit(255)
}
val method =
try cls.getMethod("main", classOf[Array[String]])
catch { case e: NoSuchMethodException =>
Helper.errPrintln(s"Error: method main not found in $mainClass")
sys.exit(255)
}
method.setAccessible(true)
if (verbosity >= 2)
Helper.errPrintln(s"Launching $mainClass ${args.mkString(" ")}")
else if (verbosity == 1)
Helper.errPrintln(s"Launching")
beforeMain
Thread.currentThread().setContextClassLoader(loader)
try method.invoke(null, args.toArray)
catch {
case e: java.lang.reflect.InvocationTargetException =>
throw Option(e.getCause).getOrElse(e)
}
}
}
class IsolatedClassLoader(
@@ -86,29 +123,10 @@ case class Launch(
else
options.mainClass
val cls =
try helper.loader.loadClass(mainClass)
catch { case e: ClassNotFoundException =>
Helper.errPrintln(s"Error: class $mainClass not found")
sys.exit(255)
}
val method =
try cls.getMethod("main", classOf[Array[String]])
catch { case e: NoSuchMethodException =>
Helper.errPrintln(s"Error: method main not found in $mainClass")
sys.exit(255)
}
method.setAccessible(true)
if (options.common.verbosityLevel >= 2)
Helper.errPrintln(s"Launching $mainClass ${userArgs.mkString(" ")}")
else if (options.common.verbosityLevel == 1)
Helper.errPrintln(s"Launching")
Thread.currentThread().setContextClassLoader(helper.loader)
try method.invoke(null, userArgs.toArray)
catch {
case e: java.lang.reflect.InvocationTargetException =>
throw Option(e.getCause).getOrElse(e)
}
Launch.run(
helper.loader,
mainClass,
userArgs,
options.common.verbosityLevel
)
}
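A minimal usage sketch of the Launch.run helper extracted above; the jars and main class are purely illustrative:

import java.io.File
import java.net.URLClassLoader

import coursier.cli.Launch

object RunSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical application jars, resolved elsewhere.
    val jars = Seq(new File("app.jar"), new File("lib.jar"))
    val loader = new URLClassLoader(jars.map(_.toURI.toURL).toArray, null)
    // Loads com.example.Main from the loader and reflectively invokes its main method.
    Launch.run(loader, "com.example.Main", Seq("--foo", "bar"), verbosity = 1)
  }
}

This is the same code path the spark-submit command now reuses instead of duplicating the reflection logic.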

View File

@@ -194,22 +194,6 @@ case class LaunchOptions(
common: CommonOptions
)
case class SparkSubmitOptions(
@Short("M")
@Short("main")
mainClass: String,
@Help("If master is yarn-cluster, write YARN app ID to a file. (The ID is deduced from the spark-submit output.)")
@Value("file")
yarnIdFile: String,
@Help("Spark home (default: SPARK_HOME from the environment)")
sparkHome: String,
@Help("Maximum idle time of spark-submit (time with no output). Exit early if no output from spark-submit for more than this duration. Set to 0 for unlimited. (Default: 0)")
@Value("seconds")
maxIdleTime: Int,
@Recurse
common: CommonOptions
)
case class BootstrapOptions(
@Short("M")
@Short("main")
@@ -236,3 +220,25 @@ case class BootstrapOptions(
@Recurse
common: CommonOptions
)
case class SparkSubmitOptions(
@Short("M")
@Short("main")
@Help("Main class to be launched (optional if in manifest)")
mainClass: String,
@Help("If master is yarn-cluster, write YARN app ID to a file. (The ID is deduced from the spark-submit output.)")
@Value("file")
yarnIdFile: String,
@Help("Spark assembly. If empty, automatically generate (default: empty)")
sparkAssembly: String,
noDefaultAssemblyDependencies: Boolean,
assemblyDependencies: List[String],
noDefaultSubmitDependencies: Boolean,
submitDependencies: List[String],
sparkVersion: String,
@Help("Maximum idle time of spark-submit (time with no output). Exit early if no output from spark-submit for more than this duration. Set to 0 for unlimited. (Default: 0)")
@Value("seconds")
maxIdleTime: Int,
@Recurse
common: CommonOptions
)

View File

@@ -1,12 +1,54 @@
package coursier.cli
import java.io.{ BufferedReader, File, InputStream, InputStreamReader, OutputStream }
import java.nio.file.{ Files, Paths }
import java.io.{PrintStream, BufferedReader, File, PipedInputStream, PipedOutputStream, InputStream, InputStreamReader}
import java.net.URLClassLoader
import java.nio.file.Files
import caseapp._
import coursier.{ Attributes, Dependency }
import coursier.cli.spark.{ Assembly, Submit }
import coursier.util.Parse
import scala.util.control.NonFatal
object SparkSubmit {
def scalaSparkVersions(dependencies: Iterable[Dependency]): Either[String, (String, String)] = {
val sparkCoreMods = dependencies.collect {
case dep if dep.module.organization == "org.apache.spark" &&
(dep.module.name == "spark-core_2.10" || dep.module.name == "spark-core_2.11") =>
(dep.module, dep.version)
}
if (sparkCoreMods.isEmpty)
Left("Cannot find spark among dependencies")
else if (sparkCoreMods.size == 1) {
val scalaVersion = sparkCoreMods.head._1.name match {
case "spark-core_2.10" => "2.10"
case "spark-core_2.11" => "2.11"
case _ => throw new Exception("Cannot happen")
}
val sparkVersion = sparkCoreMods.head._2
Right((scalaVersion, sparkVersion))
} else
Left(s"Found several spark code modules among dependencies (${sparkCoreMods.mkString(", ")})")
}
}
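A small illustration of the version deduction above, assuming the plain Module(organization, name) / Dependency(module, version) constructors of this coursier version; the coordinates are hypothetical:

import coursier.{Dependency, Module}
import coursier.cli.SparkSubmit

object DeduceSketch {
  def main(args: Array[String]): Unit = {
    // A dependency set containing spark-core built for Scala 2.11.
    val deps = Seq(Dependency(Module("org.apache.spark", "spark-core_2.11"), "1.6.2"))
    // Prints Right((2.11,1.6.2)): Scala version from the module suffix, Spark version from the dependency.
    println(SparkSubmit.scalaSparkVersions(deps))
  }
}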
/**
* Submits Spark applications.
*
* Can be run without any Spark distribution around.
*
* @author Alexandre Archambault
* @author Han Ju
*/
@CommandName("spark-submit")
case class SparkSubmit(
@Recurse
@@ -14,56 +56,73 @@ case class SparkSubmit(
) extends App with ExtraArgsApp {
val helper = new Helper(options.common, remainingArgs)
val jars = helper.fetch(sources = false, javadoc = false)
val sparkHome =
if (options.sparkHome.isEmpty)
sys.env.getOrElse(
"SPARK_HOME", {
Console.err.println("Error: SPARK_HOME not set and the --spark-home option not given a value.")
val (scalaVersion, sparkVersion) =
if (options.sparkVersion.isEmpty)
SparkSubmit.scalaSparkVersions(helper.res.dependencies) match {
case Left(err) =>
Console.err.println(
s"Cannot get spark / scala versions from dependencies: $err\n" +
"Set them via --scala-version or --spark-version"
)
sys.exit(1)
}
)
case Right(versions) => versions
}
else
options.sparkHome
(options.common.scalaVersion, options.sparkVersion)
val sparkAssembly = {
// TODO Make this more reliable (assemblies can be found in other directories I think, this
// must be fine with spark 2.10 too, ...)
val dir = new File(sparkHome + "/assembly/target/scala-2.11")
Option(dir.listFiles()).getOrElse(Array.empty).filter { f =>
f.isFile && f.getName.endsWith(".jar")
} match {
case Array(assemblyFile) =>
assemblyFile.getAbsolutePath
case Array() =>
throw new Exception(s"No spark assembly found under $dir")
case jars =>
throw new Exception(s"Found several JARs under $dir")
val assemblyOrError =
if (options.sparkAssembly.isEmpty) {
// FIXME Also vaguely done in Helper and below
val (errors, modVers) = Parse.moduleVersionConfigs(
options.assemblyDependencies,
options.common.scalaVersion
)
val deps = modVers.map {
case (module, version, configOpt) =>
Dependency(
module,
version,
attributes = Attributes(options.common.defaultArtifactType, ""),
configuration = configOpt.getOrElse(options.common.defaultConfiguration),
exclusions = helper.excludes
)
}
if (errors.isEmpty)
Assembly.spark(scalaVersion, sparkVersion, options.noDefaultAssemblyDependencies, deps)
else
Left(s"Cannot parse assembly dependencies:\n${errors.map(" " + _).mkString("\n")}")
} else {
val f = new File(options.sparkAssembly)
if (f.isFile)
Right((f, Nil))
else if (f.exists())
Left(s"${options.sparkAssembly} is not a file")
else
Left(s"${options.sparkAssembly} not found")
}
val (assembly, assemblyJars) = assemblyOrError match {
case Left(err) =>
Console.err.println(s"Cannot get spark assembly: $err")
sys.exit(1)
case Right(res) => res
}
val libManaged = {
val dir = new File(sparkHome + "/lib_managed/jars")
if (dir.isDirectory) {
dir.listFiles().toSeq.map(_.getAbsolutePath)
} else
Nil
val idx = {
val idx0 = extraArgs.indexOf("--")
if (idx0 < 0)
extraArgs.length
else
idx0
}
val yarnConfOpt = sys.env.get("YARN_CONF_DIR").filter(_.nonEmpty)
for (yarnConf <- yarnConfOpt if !new File(yarnConf).isDirectory)
throw new Exception(s"Error: YARN conf path ($yarnConf) is not a directory or doesn't exist.")
val cp = Seq(
sparkHome + "/conf",
sparkAssembly
) ++ libManaged ++ yarnConfOpt.toSeq
val idx = extraArgs.indexOf("--")
assert(idx >= 0)
val sparkOpts = extraArgs.take(idx)
@@ -83,13 +142,19 @@ case class SparkSubmit(
.getLocation
.getPath // TODO Safety check: protocol must be file
val (check, extraJars) = jars.partition(_.getAbsolutePath == mainJar)
val (check, extraJars0) = jars.partition(_.getAbsolutePath == mainJar)
val extraJars = extraJars0.filterNot(assemblyJars.toSet)
if (check.isEmpty)
Console.err.println(
s"Warning: cannot find back $mainJar among the dependencies JARs (likely a coursier bug)"
)
val extraSparkOpts = Seq(
"--conf", "spark.yarn.jar=" + assembly.getAbsolutePath
)
val extraJarsOptions =
if (extraJars.isEmpty)
Nil
@@ -98,87 +163,52 @@ case class SparkSubmit(
val mainClassOptions = Seq("--class", mainClass)
val sparkSubmitOptions = sparkOpts ++ extraJarsOptions ++ mainClassOptions ++
val sparkSubmitOptions = sparkOpts ++ extraSparkOpts ++ extraJarsOptions ++ mainClassOptions ++
Seq(mainJar) ++ jobArgs
val cmd = Seq(
"java",
"-cp",
cp.mkString(File.pathSeparator),
"org.apache.spark.deploy.SparkSubmit"
) ++ sparkSubmitOptions
val submitCp = Submit.cp(
scalaVersion,
sparkVersion,
options.noDefaultSubmitDependencies,
options.submitDependencies,
options.common
)
object YarnAppId {
val Pattern = ".*Application report for ([^ ]+) .*".r
val submitLoader = new URLClassLoader(
submitCp.map(_.toURI.toURL).toArray,
helper.baseLoader
)
val fileOpt = Some(options.yarnIdFile).filter(_.nonEmpty)
Launch.run(
submitLoader,
Submit.mainClassName,
sparkSubmitOptions,
options.common.verbosityLevel,
{
if (options.common.verbosityLevel >= 1)
Console.err.println(
s"Launching spark-submit with arguments:\n" +
sparkSubmitOptions.map(" " + _).mkString("\n")
)
@volatile var written = false
val lock = new AnyRef
def handleMessage(s: String): Unit =
if (!written)
s match {
case Pattern(id) =>
lock.synchronized {
if (!written) {
println(s"Detected YARN app ID $id")
for (writeAppIdTo <- fileOpt) {
val path = Paths.get(writeAppIdTo)
Option(path.getParent).foreach(_.toFile.mkdirs())
Files.write(path, id.getBytes("UTF-8"))
}
written = true
}
}
case _ =>
}
}
object IdleChecker {
@volatile var lastMessageTs = -1L
def updateLastMessageTs() = {
lastMessageTs = System.currentTimeMillis()
OutputHelper.handleOutput(
Some(options.yarnIdFile).filter(_.nonEmpty).map(new File(_)),
Some(options.maxIdleTime).filter(_ > 0)
)
}
)
}
val checkThreadOpt =
if (options.maxIdleTime > 0) {
val checkThread = new Thread {
override def run() =
try {
while (true) {
lastMessageTs = -1L
Thread.sleep(options.maxIdleTime * 1000L)
if (lastMessageTs < 0) {
Console.err.println(s"No output from spark-submit for more than ${options.maxIdleTime} s, exiting")
sys.exit(1)
}
}
} catch {
case t: Throwable =>
Console.err.println(s"Caught $t in check spark-submit output thread!")
throw t
}
}
object OutputHelper {
checkThread.setName("check-spark-submit-output")
checkThread.setDaemon(true)
def outputInspectThread(
name: String,
from: InputStream,
to: PrintStream,
handlers: Seq[String => Unit]
) = {
Some(checkThread)
} else
None
}
Console.err.println(s"Running command:\n${cmd.map(" "+_).mkString("\n")}\n")
val process = new ProcessBuilder()
.command(cmd: _*)
.redirectErrorStream(true) // merges error stream into output stream
.start()
def pipeThread(from: InputStream, to: OutputStream) = {
val t = new Thread {
override def run() = {
val in = new BufferedReader(new InputStreamReader(from))
@@ -187,35 +217,108 @@ case class SparkSubmit(
line = in.readLine()
line != null
}) {
if (options.maxIdleTime > 0)
IdleChecker.updateLastMessageTs()
to.write((line + "\n").getBytes("UTF-8"))
if (YarnAppId.fileOpt.nonEmpty)
try YarnAppId.handleMessage(line)
catch {
case NonFatal(_) =>
}
to.println(line)
handlers.foreach(_(line))
}
}
}
t.setName("pipe-output")
t.setName(name)
t.setDaemon(true)
t
}
val is = process.getInputStream
val isPipeThread = pipeThread(is, System.out)
def handleOutput(yarnAppFileOpt: Option[File], maxIdleTimeOpt: Option[Int]): Unit = {
IdleChecker.checkThreadOpt.foreach(_.start())
isPipeThread.start()
var handlers = Seq.empty[String => Unit]
var threads = Seq.empty[Thread]
val exitValue = process.waitFor()
for (yarnAppFile <- yarnAppFileOpt) {
sys.exit(exitValue)
val Pattern = ".*Application report for ([^ ]+) .*".r
}
@volatile var written = false
val lock = new AnyRef
def handleMessage(s: String): Unit =
if (!written)
s match {
case Pattern(id) =>
lock.synchronized {
if (!written) {
println(s"Detected YARN app ID $id")
val path = yarnAppFile.toPath
Option(path.getParent).foreach(_.toFile.mkdirs())
Files.write(path, id.getBytes("UTF-8"))
written = true
}
}
case _ =>
}
val f = { line: String =>
try handleMessage(line)
catch {
case NonFatal(_) =>
}
}
handlers = handlers :+ f
}
for (maxIdleTime <- maxIdleTimeOpt if maxIdleTime > 0) {
@volatile var lastMessageTs = -1L
def updateLastMessageTs() = {
lastMessageTs = System.currentTimeMillis()
}
val checkThread = new Thread {
override def run() =
try {
while (true) {
lastMessageTs = -1L
Thread.sleep(maxIdleTime * 1000L)
if (lastMessageTs < 0) {
Console.err.println(s"No output from spark-submit for more than $maxIdleTime s, exiting")
sys.exit(1)
}
}
} catch {
case t: Throwable =>
Console.err.println(s"Caught $t in check spark-submit output thread!")
throw t
}
}
checkThread.setName("check-spark-submit-output")
checkThread.setDaemon(true)
threads = threads :+ checkThread
val f = { line: String =>
updateLastMessageTs()
}
handlers = handlers :+ f
}
def createThread(name: String, replaces: PrintStream, install: PrintStream => Unit): Thread = {
val in = new PipedInputStream
val out = new PipedOutputStream(in)
install(new PrintStream(out))
outputInspectThread(name, in, replaces, handlers)
}
if (handlers.nonEmpty) {
threads = threads ++ Seq(
createThread("inspect-out", System.out, System.setOut),
createThread("inspect-err", System.err, System.setErr)
)
threads.foreach(_.start())
}
}
}
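A quick, self-contained check of the app-ID detection used by OutputHelper.handleOutput above: the regex pulls the YARN application ID out of spark-submit's "Application report" lines (the sample log line below is made up for illustration):

object YarnIdSketch {
  val Pattern = ".*Application report for ([^ ]+) .*".r

  def main(args: Array[String]): Unit = {
    val line =
      "16/09/09 12:28:24 INFO yarn.Client: Application report for application_1473416000000_0001 (state: RUNNING)"
    line match {
      case Pattern(id) => println(s"Detected YARN app ID $id") // application_1473416000000_0001
      case _           => println("no app ID on this line")
    }
  }
}

In handleOutput, the same per-line handlers also reset the idle-check timestamp, so the watchdog thread only exits the process when spark-submit stays silent for longer than maxIdleTime seconds.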

View File

@@ -0,0 +1,36 @@
package coursier.cli.spark
import java.io.{File, FileInputStream, FileOutputStream}
import java.util.zip.{ZipInputStream, ZipOutputStream}
import coursier.Dependency
object Assembly {
sealed abstract class Rule extends Product with Serializable
object Rule {
case class Exclude(path: String) extends Rule
case class Append(path: String) extends Rule
}
def make(jars: Seq[File], output: File, rules: Seq[Rule]): Unit = {
val zos = new ZipOutputStream(new FileOutputStream(output))
for (jar <- jars) {
new ZipInputStream(new FileInputStream(jar))
}
???
}
def spark(
scalaVersion: String,
sparkVersion: String,
noDefault: Boolean,
extraDependencies: Seq[Dependency]
): Either[String, (File, Seq[File])] =
throw new Exception("Not implemented: automatic assembly generation")
}
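Assembly.make is left as a stub in this commit (automatic assembly generation is not implemented yet). As a rough sketch only, here is one way the Exclude/Append rules could be applied when merging jar entries; this is not the actual implementation:

import java.io.{ByteArrayOutputStream, File, FileInputStream, FileOutputStream}
import java.util.zip.{ZipEntry, ZipInputStream, ZipOutputStream}

import scala.collection.mutable

import coursier.cli.spark.Assembly.Rule

object AssemblySketch {
  def make(jars: Seq[File], output: File, rules: Seq[Rule]): Unit = {
    val excluded = rules.collect { case Rule.Exclude(path) => path }.toSet
    val appended = rules.collect { case Rule.Append(path) => path }.toSet

    val appendContent = mutable.Map.empty[String, ByteArrayOutputStream]
    val seen = mutable.Set.empty[String]
    val zos = new ZipOutputStream(new FileOutputStream(output))

    def readAll(zis: ZipInputStream): Array[Byte] = {
      val buf = new ByteArrayOutputStream
      val chunk = new Array[Byte](16384)
      var read = zis.read(chunk)
      while (read >= 0) { buf.write(chunk, 0, read); read = zis.read(chunk) }
      buf.toByteArray
    }

    for (jar <- jars) {
      val zis = new ZipInputStream(new FileInputStream(jar))
      var entry = zis.getNextEntry
      while (entry != null) {
        val name = entry.getName
        if (!entry.isDirectory && !excluded(name)) {
          if (appended(name))
            // Append rule: concatenate the content found across all jars.
            appendContent.getOrElseUpdate(name, new ByteArrayOutputStream).write(readAll(zis))
          else if (!seen(name)) {
            // Default rule: first occurrence wins.
            seen += name
            zos.putNextEntry(new ZipEntry(name))
            zos.write(readAll(zis))
            zos.closeEntry()
          }
        }
        entry = zis.getNextEntry
      }
      zis.close()
    }

    for ((name, content) <- appendContent) {
      zos.putNextEntry(new ZipEntry(name))
      zos.write(content.toByteArray)
      zos.closeEntry()
    }
    zos.close()
  }
}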

View File

@@ -0,0 +1,51 @@
package coursier.cli.spark
import java.io.File
import coursier.cli.{ CommonOptions, Helper }
object Submit {
def cp(
scalaVersion: String,
sparkVersion: String,
noDefault: Boolean,
extraDependencies: Seq[String],
common: CommonOptions
): Seq[File] = {
var extraCp = Seq.empty[File]
for (yarnConf <- sys.env.get("YARN_CONF_DIR") if yarnConf.nonEmpty) {
val f = new File(yarnConf)
if (!f.isDirectory) {
Console.err.println(s"Error: YARN conf path ($yarnConf) is not a directory or doesn't exist.")
sys.exit(1)
}
extraCp = extraCp :+ f
}
def defaultDependencies = Seq(
// FIXME We should be able to pass these as (parsed) Dependency instances to Helper
s"org.apache.spark::spark-core:$sparkVersion",
s"org.apache.spark::spark-yarn:$sparkVersion"
)
val helper = new Helper(
common.copy(
intransitive = Nil,
classifier = Nil,
scalaVersion = scalaVersion
),
// FIXME We should be able to pass these as (parsed) Dependency instances to Helper
(if (noDefault) Nil else defaultDependencies) ++ extraDependencies
)
helper.fetch(sources = false, javadoc = false) ++ extraCp
}
def mainClassName = "org.apache.spark.deploy.SparkSubmit"
}
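The "::" in the default dependencies above is coursier's shorthand for a Scala module (like sbt's %%): the Scala binary version is appended to the module name before resolution. A hand-rolled illustration of that expansion, not coursier's actual parser:

object ScalaDepSketch {
  // "org::name:version" -> "org:name_<scalaBinaryVersion>:version"
  def expand(dep: String, scalaBinaryVersion: String): String = {
    val Array(org, nameAndVersion) = dep.split("::", 2)
    val Array(name, version) = nameAndVersion.split(":", 2)
    s"$org:${name}_$scalaBinaryVersion:$version"
  }

  def main(args: Array[String]): Unit =
    // Prints org.apache.spark:spark-yarn_2.11:1.6.2
    println(expand("org.apache.spark::spark-yarn:1.6.2", "2.11"))
}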

View File

@@ -3,7 +3,7 @@ addSbtPlugin("org.scala-js" % "sbt-scalajs" % "0.6.11")
addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0")
addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.1.0")
addSbtPlugin("org.tpolecat" % "tut-plugin" % "0.4.2")
addSbtPlugin("io.get-coursier" % "sbt-coursier" % "1.0.0-M13")
addSbtPlugin("io.get-coursier" % "sbt-coursier" % "1.0.0-M14")
addSbtPlugin("com.typesafe.sbt" % "sbt-proguard" % "0.2.2")
addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.9")
libraryDependencies += "org.scala-sbt" % "scripted-plugin" % sbtVersion.value

View File

@@ -0,0 +1 @@
addSbtPlugin("io.get-coursier" % "sbt-coursier" % "1.0.0-M14")