diff --git a/cli/src/main/scala-2.11/coursier/cli/Helper.scala b/cli/src/main/scala-2.11/coursier/cli/Helper.scala
index 6f898c3e7..461b8eef8 100644
--- a/cli/src/main/scala-2.11/coursier/cli/Helper.scala
+++ b/cli/src/main/scala-2.11/coursier/cli/Helper.scala
@@ -581,35 +581,36 @@ class Helper(
     files0
   }
 
-  lazy val (parentLoader, filteredFiles) = {
+  def contextLoader = Thread.currentThread().getContextClassLoader
 
-    val contextLoader = Thread.currentThread().getContextClassLoader
+  // TODO Would ClassLoader.getSystemClassLoader be better here?
+  val baseLoader: ClassLoader =
+    Launch.mainClassLoader(contextLoader)
+      .flatMap(cl => Option(cl.getParent))
+      .getOrElse {
+        // proguarded -> no risk of conflicts, no absolute need to find a specific ClassLoader
+        val isProguarded = Try(contextLoader.loadClass("coursier.cli.Launch")).isFailure
+        if (warnBaseLoaderNotFound && !isProguarded && common.verbosityLevel >= 0)
+          Console.err.println(
+            "Warning: cannot find the main ClassLoader that launched coursier.\n" +
+            "Was coursier launched by its main launcher? " +
+            "The ClassLoader of the application that is about to be launched will be intertwined " +
+            "with the one of coursier, which may be a problem if their dependencies conflict."
+          )
+        contextLoader
+      }
+
+  lazy val (parentLoader, filteredFiles) = {
 
     val files0 = fetch(sources = false, javadoc = false)
 
-    val parentLoader0: ClassLoader =
-      Launch.mainClassLoader(contextLoader)
-        .flatMap(cl => Option(cl.getParent))
-        .getOrElse {
-          // proguarded -> no risk of conflicts, no absolute need to find a specific ClassLoader
-          val isProguarded = Try(contextLoader.loadClass("coursier.cli.Launch")).isFailure
-          if (warnBaseLoaderNotFound && !isProguarded && common.verbosityLevel >= 0)
-            Console.err.println(
-              "Warning: cannot find the main ClassLoader that launched coursier.\n" +
-              "Was coursier launched by its main launcher? " +
-              "The ClassLoader of the application that is about to be launched will be intertwined " +
-              "with the one of coursier, which may be a problem if their dependencies conflict."
-            )
-          contextLoader
-        }
-
     if (isolated.isolated.isEmpty)
-      (parentLoader0, files0)
+      (baseLoader, files0)
     else {
       val isolatedDeps = isolated.isolatedDeps(common.defaultArtifactType, common.scalaVersion)
 
-      val (isolatedLoader, filteredFiles0) = isolated.targets.foldLeft((parentLoader0, files0)) {
+      val (isolatedLoader, filteredFiles0) = isolated.targets.foldLeft((baseLoader, files0)) {
         case ((parent, files0), target) =>
 
           // FIXME These were already fetched above
diff --git a/cli/src/main/scala-2.11/coursier/cli/Launch.scala b/cli/src/main/scala-2.11/coursier/cli/Launch.scala
index b4731215a..b7cc971e4 100644
--- a/cli/src/main/scala-2.11/coursier/cli/Launch.scala
+++ b/cli/src/main/scala-2.11/coursier/cli/Launch.scala
@@ -33,6 +33,43 @@ object Launch {
       mainClassLoader(cl.getParent)
   }
 
+  def run(
+    loader: ClassLoader,
+    mainClass: String,
+    args: Seq[String],
+    verbosity: Int,
+    beforeMain: => Unit = ()
+  ): Unit = {
+
+    val cls =
+      try loader.loadClass(mainClass)
+      catch { case e: ClassNotFoundException =>
+        Helper.errPrintln(s"Error: class $mainClass not found")
+        sys.exit(255)
+      }
+    val method =
+      try cls.getMethod("main", classOf[Array[String]])
+      catch { case e: NoSuchMethodException =>
+        Helper.errPrintln(s"Error: method main not found in $mainClass")
+        sys.exit(255)
+      }
+    method.setAccessible(true)
+
+    if (verbosity >= 2)
+      Helper.errPrintln(s"Launching $mainClass ${args.mkString(" ")}")
+    else if (verbosity == 1)
+      Helper.errPrintln(s"Launching")
+
+    beforeMain
+
+    Thread.currentThread().setContextClassLoader(loader)
+    try method.invoke(null, args.toArray)
+    catch {
+      case e: java.lang.reflect.InvocationTargetException =>
+        throw Option(e.getCause).getOrElse(e)
+    }
+  }
+
 }
 
 class IsolatedClassLoader(
@@ -86,29 +123,10 @@ case class Launch(
     else
       options.mainClass
 
-  val cls =
-    try helper.loader.loadClass(mainClass)
-    catch { case e: ClassNotFoundException =>
-      Helper.errPrintln(s"Error: class $mainClass not found")
-      sys.exit(255)
-    }
-  val method =
-    try cls.getMethod("main", classOf[Array[String]])
-    catch { case e: NoSuchMethodException =>
-      Helper.errPrintln(s"Error: method main not found in $mainClass")
-      sys.exit(255)
-    }
-  method.setAccessible(true)
-
-  if (options.common.verbosityLevel >= 2)
-    Helper.errPrintln(s"Launching $mainClass ${userArgs.mkString(" ")}")
-  else if (options.common.verbosityLevel == 1)
-    Helper.errPrintln(s"Launching")
-
-  Thread.currentThread().setContextClassLoader(helper.loader)
-  try method.invoke(null, userArgs.toArray)
-  catch {
-    case e: java.lang.reflect.InvocationTargetException =>
-      throw Option(e.getCause).getOrElse(e)
-  }
+  Launch.run(
+    helper.loader,
+    mainClass,
+    userArgs,
+    options.common.verbosityLevel
+  )
 }
\ No newline at end of file
diff --git a/cli/src/main/scala-2.11/coursier/cli/Options.scala b/cli/src/main/scala-2.11/coursier/cli/Options.scala
index 36a539856..7359826f3 100644
--- a/cli/src/main/scala-2.11/coursier/cli/Options.scala
+++ b/cli/src/main/scala-2.11/coursier/cli/Options.scala
@@ -194,22 +194,6 @@ case class LaunchOptions(
     common: CommonOptions
 )
 
-case class SparkSubmitOptions(
-  @Short("M")
-  @Short("main")
-    mainClass: String,
-  @Help("If master is yarn-cluster, write YARN app ID to a file. (The ID is deduced from the spark-submit output.)")
-  @Value("file")
-    yarnIdFile: String,
-  @Help("Spark home (default: SPARK_HOME from the environment)")
-    sparkHome: String,
-  @Help("Maximum idle time of spark-submit (time with no output). Exit early if no output from spark-submit for more than this duration. Set to 0 for unlimited. (Default: 0)")
-  @Value("seconds")
-    maxIdleTime: Int,
-  @Recurse
-    common: CommonOptions
-)
-
 case class BootstrapOptions(
   @Short("M")
   @Short("main")
@@ -236,3 +220,25 @@ case class BootstrapOptions(
   @Recurse
     common: CommonOptions
 )
+
+case class SparkSubmitOptions(
+  @Short("M")
+  @Short("main")
+  @Help("Main class to be launched (optional if in manifest)")
+    mainClass: String,
+  @Help("If master is yarn-cluster, write YARN app ID to a file. (The ID is deduced from the spark-submit output.)")
+  @Value("file")
+    yarnIdFile: String,
+  @Help("Spark assembly. If empty, automatically generated (default: empty)")
+    sparkAssembly: String,
+    noDefaultAssemblyDependencies: Boolean,
+    assemblyDependencies: List[String],
+    noDefaultSubmitDependencies: Boolean,
+    submitDependencies: List[String],
+    sparkVersion: String,
+  @Help("Maximum idle time of spark-submit (time with no output). Exit early if no output from spark-submit for more than this duration. Set to 0 for unlimited. (Default: 0)")
+  @Value("seconds")
+    maxIdleTime: Int,
+  @Recurse
+    common: CommonOptions
+)
\ No newline at end of file
diff --git a/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala b/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala
index 0e02eb8c0..3d6a79584 100644
--- a/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala
+++ b/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala
@@ -1,12 +1,54 @@
 package coursier.cli
 
-import java.io.{ BufferedReader, File, InputStream, InputStreamReader, OutputStream }
-import java.nio.file.{ Files, Paths }
+import java.io.{PrintStream, BufferedReader, File, PipedInputStream, PipedOutputStream, InputStream, InputStreamReader}
+import java.net.URLClassLoader
+import java.nio.file.Files
 
 import caseapp._
 
+import coursier.{ Attributes, Dependency }
+import coursier.cli.spark.{ Assembly, Submit }
+import coursier.util.Parse
+
 import scala.util.control.NonFatal
 
+object SparkSubmit {
+
+  def scalaSparkVersions(dependencies: Iterable[Dependency]): Either[String, (String, String)] = {
+
+    val sparkCoreMods = dependencies.collect {
+      case dep if dep.module.organization == "org.apache.spark" &&
+        (dep.module.name == "spark-core_2.10" || dep.module.name == "spark-core_2.11") =>
+        (dep.module, dep.version)
+    }
+
+    if (sparkCoreMods.isEmpty)
+      Left("Cannot find spark among dependencies")
+    else if (sparkCoreMods.size == 1) {
+      val scalaVersion = sparkCoreMods.head._1.name match {
+        case "spark-core_2.10" => "2.10"
+        case "spark-core_2.11" => "2.11"
+        case _ => throw new Exception("Cannot happen")
+      }
+
+      val sparkVersion = sparkCoreMods.head._2
+
+      Right((scalaVersion, sparkVersion))
+    } else
+      Left(s"Found several spark-core modules among dependencies (${sparkCoreMods.mkString(", ")})")
+
+  }
+
+}
+
+/**
+  * Submits spark applications.
+  *
+  * Can be run with no spark distributions around.
+  *
+  * @author Alexandre Archambault
+  * @author Han Ju
+  */
 @CommandName("spark-submit")
 case class SparkSubmit(
   @Recurse
@@ -14,56 +56,73 @@
 ) extends App with ExtraArgsApp {
 
   val helper = new Helper(options.common, remainingArgs)
-  val jars = helper.fetch(sources = false, javadoc = false)
-
-  val sparkHome =
-    if (options.sparkHome.isEmpty)
-      sys.env.getOrElse(
-        "SPARK_HOME", {
-          Console.err.println("Error: SPARK_HOME not set and the --spark-home option not given a value.")
+
+  val (scalaVersion, sparkVersion) =
+    if (options.sparkVersion.isEmpty)
+      SparkSubmit.scalaSparkVersions(helper.res.dependencies) match {
+        case Left(err) =>
+          Console.err.println(
+            s"Cannot get spark / scala versions from dependencies: $err\n" +
+            "Set them via --scala-version or --spark-version"
+          )
           sys.exit(1)
-        }
-      )
+        case Right(versions) => versions
+      }
     else
-      options.sparkHome
+      (options.common.scalaVersion, options.sparkVersion)
 
-  val sparkAssembly = {
-    // TODO Make this more reliable (assemblies can be found in other directories I think, this
-    // must be fine with spark 2.10 too, ...)
-    val dir = new File(sparkHome + "/assembly/target/scala-2.11")
-    Option(dir.listFiles()).getOrElse(Array.empty).filter { f =>
-      f.isFile && f.getName.endsWith(".jar")
-    } match {
-      case Array(assemblyFile) =>
-        assemblyFile.getAbsolutePath
-      case Array() =>
-        throw new Exception(s"No spark assembly found under $dir")
-      case jars =>
-        throw new Exception(s"Found several JARs under $dir")
+  val assemblyOrError =
+    if (options.sparkAssembly.isEmpty) {
+
+      // FIXME Also vaguely done in Helper and below
+
+      val (errors, modVers) = Parse.moduleVersionConfigs(
+        options.assemblyDependencies,
+        options.common.scalaVersion
+      )
+
+      val deps = modVers.map {
+        case (module, version, configOpt) =>
+          Dependency(
+            module,
+            version,
+            attributes = Attributes(options.common.defaultArtifactType, ""),
+            configuration = configOpt.getOrElse(options.common.defaultConfiguration),
+            exclusions = helper.excludes
+          )
+      }
+
+      if (errors.isEmpty)
+        Assembly.spark(scalaVersion, sparkVersion, options.noDefaultAssemblyDependencies, deps)
+      else
+        Left(s"Cannot parse assembly dependencies:\n${errors.map("  " + _).mkString("\n")}")
+    } else {
+      val f = new File(options.sparkAssembly)
+      if (f.isFile)
+        Right((f, Nil))
+      else if (f.exists())
+        Left(s"${options.sparkAssembly} is not a file")
+      else
+        Left(s"${options.sparkAssembly} not found")
     }
+
+  val (assembly, assemblyJars) = assemblyOrError match {
+    case Left(err) =>
+      Console.err.println(s"Cannot get spark assembly: $err")
+      sys.exit(1)
+    case Right(res) => res
   }
 
-  val libManaged = {
-    val dir = new File(sparkHome + "/lib_managed/jars")
-    if (dir.isDirectory) {
-      dir.listFiles().toSeq.map(_.getAbsolutePath)
-    } else
-      Nil
+
+  val idx = {
+    val idx0 = extraArgs.indexOf("--")
+    if (idx0 < 0)
+      extraArgs.length
+    else
+      idx0
   }
 
-  val yarnConfOpt = sys.env.get("YARN_CONF_DIR").filter(_.nonEmpty)
-
-  for (yarnConf <- yarnConfOpt if !new File(yarnConf).isDirectory)
-    throw new Exception(s"Error: YARN conf path ($yarnConf) is not a directory or doesn't exist.")
-
-  val cp = Seq(
-    sparkHome + "/conf",
-    sparkAssembly
-  ) ++ libManaged ++ yarnConfOpt.toSeq
-
-  val idx = extraArgs.indexOf("--")
   assert(idx >= 0)
 
   val sparkOpts = extraArgs.take(idx)
@@ -83,13 +142,19 @@
       .getLocation
       .getPath // TODO Safety check: protocol must be file
 
-  val (check, extraJars) = jars.partition(_.getAbsolutePath == mainJar)
+  val (check, extraJars0) = jars.partition(_.getAbsolutePath == mainJar)
+
+  val extraJars = extraJars0.filterNot(assemblyJars.toSet)
 
   if (check.isEmpty)
     Console.err.println(
      s"Warning: cannot find back $mainJar among the dependencies JARs (likely a coursier bug)"
    )
 
+  val extraSparkOpts = Seq(
+    "--conf", "spark.yarn.jar=" + assembly.getAbsolutePath
+  )
+
   val extraJarsOptions =
     if (extraJars.isEmpty)
       Nil
@@ -98,87 +163,52 @@
   val mainClassOptions = Seq("--class", mainClass)
 
-  val sparkSubmitOptions = sparkOpts ++ extraJarsOptions ++ mainClassOptions ++
+  val sparkSubmitOptions = sparkOpts ++ extraSparkOpts ++ extraJarsOptions ++ mainClassOptions ++
     Seq(mainJar) ++ jobArgs
 
-  val cmd = Seq(
-    "java",
-    "-cp",
-    cp.mkString(File.pathSeparator),
-    "org.apache.spark.deploy.SparkSubmit"
-  ) ++ sparkSubmitOptions
+  val submitCp = Submit.cp(
+    scalaVersion,
+    sparkVersion,
+    options.noDefaultSubmitDependencies,
+    options.submitDependencies,
+    options.common
+  )
 
-  object YarnAppId {
-    val Pattern = ".*Application report for ([^ ]+) .*".r
+  val submitLoader = new URLClassLoader(
+    submitCp.map(_.toURI.toURL).toArray,
+    helper.baseLoader
+  )
 
-    val fileOpt = Some(options.yarnIdFile).filter(_.nonEmpty)
+  Launch.run(
+    submitLoader,
+    Submit.mainClassName,
+    sparkSubmitOptions,
+    options.common.verbosityLevel,
+    {
+      if (options.common.verbosityLevel >= 1)
+        Console.err.println(
+          s"Launching spark-submit with arguments:\n" +
+            sparkSubmitOptions.map("  " + _).mkString("\n")
+        )
 
-    @volatile var written = false
-    val lock = new AnyRef
-    def handleMessage(s: String): Unit =
-      if (!written)
-        s match {
-          case Pattern(id) =>
-            lock.synchronized {
-              if (!written) {
-                println(s"Detected YARN app ID $id")
-                for (writeAppIdTo <- fileOpt) {
-                  val path = Paths.get(writeAppIdTo)
-                  Option(path.getParent).foreach(_.toFile.mkdirs())
-                  Files.write(path, id.getBytes("UTF-8"))
-                }
-                written = true
-              }
-            }
-          case _ =>
-        }
-  }
-
-  object IdleChecker {
-
-    @volatile var lastMessageTs = -1L
-
-    def updateLastMessageTs() = {
-      lastMessageTs = System.currentTimeMillis()
+      OutputHelper.handleOutput(
+        Some(options.yarnIdFile).filter(_.nonEmpty).map(new File(_)),
+        Some(options.maxIdleTime).filter(_ > 0)
+      )
     }
+  )
+}
 
-    val checkThreadOpt =
-      if (options.maxIdleTime > 0) {
-        val checkThread = new Thread {
-          override def run() =
-            try {
-              while (true) {
-                lastMessageTs = -1L
-                Thread.sleep(options.maxIdleTime * 1000L)
-                if (lastMessageTs < 0) {
-                  Console.err.println(s"No output from spark-submit for more than ${options.maxIdleTime} s, exiting")
-                  sys.exit(1)
-                }
-              }
-            } catch {
-              case t: Throwable =>
-                Console.err.println(s"Caught $t in check spark-submit output thread!")
-                throw t
-            }
-        }
+object OutputHelper {
 
-        checkThread.setName("check-spark-submit-output")
-        checkThread.setDaemon(true)
+  def outputInspectThread(
+    name: String,
+    from: InputStream,
+    to: PrintStream,
+    handlers: Seq[String => Unit]
+  ) = {
 
-        Some(checkThread)
-      } else
-        None
-  }
-
-  Console.err.println(s"Running command:\n${cmd.map("  "+_).mkString("\n")}\n")
-
-  val process = new ProcessBuilder()
-    .command(cmd: _*)
-    .redirectErrorStream(true) // merges error stream into output stream
-    .start()
-
-  def pipeThread(from: InputStream, to: OutputStream) = {
     val t = new Thread {
       override def run() = {
         val in = new BufferedReader(new InputStreamReader(from))
@@ -187,35 +217,108 @@
          line = in.readLine()
          line != null
        }) {
-          if (options.maxIdleTime > 0)
-            IdleChecker.updateLastMessageTs()
-
-          to.write((line + "\n").getBytes("UTF-8"))
-
-          if (YarnAppId.fileOpt.nonEmpty)
-            try YarnAppId.handleMessage(line)
-            catch {
-              case NonFatal(_) =>
-            }
+          to.println(line)
+          handlers.foreach(_(line))
        }
      }
    }
 
-    t.setName("pipe-output")
+    t.setName(name)
     t.setDaemon(true)
     t
   }
 
-  val is = process.getInputStream
-  val isPipeThread = pipeThread(is, System.out)
+  def handleOutput(yarnAppFileOpt: Option[File], maxIdleTimeOpt: Option[Int]): Unit = {
 
-  IdleChecker.checkThreadOpt.foreach(_.start())
-  isPipeThread.start()
+    var handlers = Seq.empty[String => Unit]
+    var threads = Seq.empty[Thread]
 
-  val exitValue = process.waitFor()
+    for (yarnAppFile <- yarnAppFileOpt) {
 
-  sys.exit(exitValue)
+      val Pattern = ".*Application report for ([^ ]+) .*".r
 
-}
+      @volatile var written = false
+      val lock = new AnyRef
+      def handleMessage(s: String): Unit =
+        if (!written)
+          s match {
+            case Pattern(id) =>
+              lock.synchronized {
+                if (!written) {
+                  println(s"Detected YARN app ID $id")
+                  val path = yarnAppFile.toPath
+                  Option(path.getParent).foreach(_.toFile.mkdirs())
+                  Files.write(path, id.getBytes("UTF-8"))
+                  written = true
+                }
+              }
+            case _ =>
+          }
+
+      val f = { line: String =>
+        try handleMessage(line)
+        catch {
+          case NonFatal(_) =>
+        }
+      }
+
+      handlers = handlers :+ f
+    }
+
+    for (maxIdleTime <- maxIdleTimeOpt if maxIdleTime > 0) {
+
+      @volatile var lastMessageTs = -1L
+
+      def updateLastMessageTs() = {
+        lastMessageTs = System.currentTimeMillis()
+      }
+
+      val checkThread = new Thread {
+        override def run() =
+          try {
+            while (true) {
+              lastMessageTs = -1L
+              Thread.sleep(maxIdleTime * 1000L)
+              if (lastMessageTs < 0) {
+                Console.err.println(s"No output from spark-submit for more than $maxIdleTime s, exiting")
+                sys.exit(1)
+              }
+            }
+          } catch {
+            case t: Throwable =>
+              Console.err.println(s"Caught $t in check spark-submit output thread!")
+              throw t
+          }
+      }
+
+      checkThread.setName("check-spark-submit-output")
+      checkThread.setDaemon(true)
+
+      threads = threads :+ checkThread
+
+      val f = { line: String =>
+        updateLastMessageTs()
+      }
+
+      handlers = handlers :+ f
+    }
+
+    def createThread(name: String, replaces: PrintStream, install: PrintStream => Unit): Thread = {
+      val in = new PipedInputStream
+      val out = new PipedOutputStream(in)
+      install(new PrintStream(out))
+      outputInspectThread(name, in, replaces, handlers)
+    }
+
+    if (handlers.nonEmpty) {
+      threads = threads ++ Seq(
+        createThread("inspect-out", System.out, System.setOut),
+        createThread("inspect-err", System.err, System.setErr)
+      )
+
+      threads.foreach(_.start())
+    }
+  }
+}
\ No newline at end of file
diff --git a/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala b/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala
new file mode 100644
index 000000000..18d6cafd9
--- /dev/null
+++ b/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala
@@ -0,0 +1,36 @@
+package coursier.cli.spark
+
+import java.io.{File, FileInputStream, FileOutputStream}
+import java.util.zip.{ZipInputStream, ZipOutputStream}
+
+import coursier.Dependency
+
+object Assembly {
+
+  sealed abstract class Rule extends Product with Serializable
+
+  object Rule {
+    case class Exclude(path: String) extends Rule
+    case class Append(path: String) extends Rule
+  }
+
+  def make(jars: Seq[File], output: File, rules: Seq[Rule]): Unit = {
+
+    val zos = new ZipOutputStream(new FileOutputStream(output))
+
+    for (jar <- jars) {
+      new ZipInputStream(new FileInputStream(jar))
+    }
+
+    ???
+  }
+
+  def spark(
+    scalaVersion: String,
+    sparkVersion: String,
+    noDefault: Boolean,
+    extraDependencies: Seq[Dependency]
+  ): Either[String, (File, Seq[File])] =
+    throw new Exception("Not implemented: automatic assembly generation")
+
+}
diff --git a/cli/src/main/scala-2.11/coursier/cli/spark/Submit.scala b/cli/src/main/scala-2.11/coursier/cli/spark/Submit.scala
new file mode 100644
index 000000000..887dec1af
--- /dev/null
+++ b/cli/src/main/scala-2.11/coursier/cli/spark/Submit.scala
@@ -0,0 +1,51 @@
+package coursier.cli.spark
+
+import java.io.File
+
+import coursier.cli.{ CommonOptions, Helper }
+
+object Submit {
+
+  def cp(
+    scalaVersion: String,
+    sparkVersion: String,
+    noDefault: Boolean,
+    extraDependencies: Seq[String],
+    common: CommonOptions
+  ): Seq[File] = {
+
+    var extraCp = Seq.empty[File]
+
+    for (yarnConf <- sys.env.get("YARN_CONF_DIR") if yarnConf.nonEmpty) {
+      val f = new File(yarnConf)
+
+      if (!f.isDirectory) {
+        Console.err.println(s"Error: YARN conf path ($yarnConf) is not a directory or doesn't exist.")
+        sys.exit(1)
+      }
+
+      extraCp = extraCp :+ f
+    }
+
+    def defaultDependencies = Seq(
+      // FIXME We should be able to pass these as (parsed) Dependency instances to Helper
+      s"org.apache.spark::spark-core:$sparkVersion",
+      s"org.apache.spark::spark-yarn:$sparkVersion"
+    )
+
+    val helper = new Helper(
+      common.copy(
+        intransitive = Nil,
+        classifier = Nil,
+        scalaVersion = scalaVersion
+      ),
+      // FIXME We should be able to pass these as (parsed) Dependency instances to Helper
+      (if (noDefault) Nil else defaultDependencies) ++ extraDependencies
+    )
+
+    helper.fetch(sources = false, javadoc = false) ++ extraCp
+  }
+
+  def mainClassName = "org.apache.spark.deploy.SparkSubmit"
+
+}
diff --git a/project/plugins.sbt b/project/plugins.sbt
index e22b0abe5..b5d5bb769 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -3,7 +3,7 @@ addSbtPlugin("org.scala-js" % "sbt-scalajs" % "0.6.11")
 addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0")
 addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.1.0")
 addSbtPlugin("org.tpolecat" % "tut-plugin" % "0.4.2")
-addSbtPlugin("io.get-coursier" % "sbt-coursier" % "1.0.0-M13")
+addSbtPlugin("io.get-coursier" % "sbt-coursier" % "1.0.0-M14")
addSbtPlugin("com.typesafe.sbt" % "sbt-proguard" % "0.2.2")
 addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.9")
 libraryDependencies += "org.scala-sbt" % "scripted-plugin" % sbtVersion.value
diff --git a/project/project/plugins.sbt b/project/project/plugins.sbt
new file mode 100644
index 000000000..2617da214
--- /dev/null
+++ b/project/project/plugins.sbt
@@ -0,0 +1 @@
+addSbtPlugin("io.get-coursier" % "sbt-coursier" % "1.0.0-M14")
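Note, not part of the patch above: Assembly.make is left as ??? and Assembly.spark still throws "Not implemented: automatic assembly generation". For reference only, a minimal sketch of what the merge loop could look like, assuming first entry wins for ordinary paths, Rule.Exclude drops matching entry names outright, and Rule.Append concatenates the contents of matching entries. The object name AssemblySketch, the exact-match rule semantics, and the in-memory buffering are illustrative assumptions, not coursier's eventual implementation.

package coursier.cli.spark

import java.io.{ ByteArrayOutputStream, File, FileInputStream, FileOutputStream }
import java.util.zip.{ ZipEntry, ZipInputStream, ZipOutputStream }

// Hypothetical sketch of Assembly.make - not shipped by coursier.
object AssemblySketch {

  import Assembly.Rule

  def make(jars: Seq[File], output: File, rules: Seq[Rule]): Unit = {

    // Assumption: rules match entry names exactly (no globs).
    val excluded = rules.collect { case Rule.Exclude(path) => path }.toSet
    val appended = rules.collect { case Rule.Append(path) => path }.toSet

    // Read the remaining bytes of the current zip entry.
    def readBytes(zis: ZipInputStream): Array[Byte] = {
      val buf = new ByteArrayOutputStream
      val arr = new Array[Byte](16384)
      var n = zis.read(arr)
      while (n >= 0) {
        buf.write(arr, 0, n)
        n = zis.read(arr)
      }
      buf.toByteArray
    }

    val zos = new ZipOutputStream(new FileOutputStream(output))
    var seen = Set.empty[String]                  // first entry wins for regular paths
    var pending = Map.empty[String, Array[Byte]]  // Append rules: concatenated, written last

    for (jar <- jars) {
      val zis = new ZipInputStream(new FileInputStream(jar))
      var entry: ZipEntry = zis.getNextEntry
      while (entry != null) {
        val name = entry.getName
        if (!entry.isDirectory && !excluded(name)) {
          if (appended(name))
            pending += name -> (pending.getOrElse(name, Array.emptyByteArray) ++ readBytes(zis))
          else if (!seen(name)) {
            seen += name
            zos.putNextEntry(new ZipEntry(name))
            zos.write(readBytes(zis))
            zos.closeEntry()
          }
        }
        zis.closeEntry()
        entry = zis.getNextEntry
      }
      zis.close()
    }

    for ((name, content) <- pending) {
      zos.putNextEntry(new ZipEntry(name))
      zos.write(content)
      zos.closeEntry()
    }

    zos.close()
  }
}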