From 5f8784b8008f0487bb4fc570181f227c30853cbf Mon Sep 17 00:00:00 2001
From: Alexandre Archambault
Date: Fri, 18 Nov 2016 11:17:24 +0100
Subject: [PATCH] Add support for Spark 2 in spark-submit command

---
 .../scala-2.11/coursier/cli/Options.scala     | 12 ++-
 .../scala-2.11/coursier/cli/SparkSubmit.scala | 68 +++++++++++------
 .../coursier/cli/spark/Assembly.scala         | 75 +++++++++++++++----
 3 files changed, 113 insertions(+), 42 deletions(-)

diff --git a/cli/src/main/scala-2.11/coursier/cli/Options.scala b/cli/src/main/scala-2.11/coursier/cli/Options.scala
index 9482d9f8f..72b4100bb 100644
--- a/cli/src/main/scala-2.11/coursier/cli/Options.scala
+++ b/cli/src/main/scala-2.11/coursier/cli/Options.scala
@@ -264,13 +264,17 @@ case class SparkSubmitOptions(
   @Help("If master is yarn-cluster, write YARN app ID to a file. (The ID is deduced from the spark-submit output.)")
   @Value("file")
     yarnIdFile: String = "",
-  @Help("Spark assembly. If empty, automatically generate (default: empty)")
-    sparkAssembly: String = "",
-  noDefaultAssemblyDependencies: Boolean = false,
+  @Help("Generate Spark Yarn assembly (Spark 1.x) or fetch Spark Yarn jars (Spark 2.x), and supply those to Spark via conf. (Default: true)")
+    autoAssembly: Boolean = true,
+  @Help("Include default dependencies in Spark Yarn assembly or jars (see --auto-assembly). If --auto-assembly is false, the corresponding dependencies will still be shunted from the job classpath if this option is true. (Default: same as --auto-assembly)")
+    defaultAssemblyDependencies: Option[Boolean] = None,
   assemblyDependencies: List[String] = Nil,
   noDefaultSubmitDependencies: Boolean = false,
   submitDependencies: List[String] = Nil,
-  sparkVersion: String = "",
+  @Help("Spark version - if empty, deduced from the job classpath. (Default: empty)")
+    sparkVersion: String = "",
+  @Help("YARN version - only used with Spark 2. (Default: 2.7.3)")
+    yarnVersion: String = "2.7.3",
   @Help("Maximum idle time of spark-submit (time with no output). Exit early if no output from spark-submit for more than this duration. Set to 0 for unlimited. (Default: 0)")
   @Value("seconds")
     maxIdleTime: Int = 0,
diff --git a/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala b/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala
index 48cb3376b..9bb408c21 100644
--- a/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala
+++ b/cli/src/main/scala-2.11/coursier/cli/SparkSubmit.scala
@@ -90,31 +90,51 @@ case class SparkSubmit(
       else
         (options.common.scalaVersion, options.sparkVersion)

-    val assemblyOrError =
-      if (options.sparkAssembly.isEmpty)
-        Assembly.spark(
+    val (sparkYarnExtraConf, sparkBaseJars) =
+      if (!options.autoAssembly || sparkVersion.startsWith("2.")) {
+
+        val assemblyJars = Assembly.sparkJars(
           scalaVersion,
           sparkVersion,
-          options.noDefaultAssemblyDependencies,
+          options.yarnVersion,
+          options.defaultAssemblyDependencies.getOrElse(options.autoAssembly),
           options.assemblyDependencies.flatMap(_.split(",")).filter(_.nonEmpty),
           options.common
         )
-      else {
-        val f = new File(options.sparkAssembly)
-        if (f.isFile)
-          Right((f, Nil))
-        else if (f.exists())
-          Left(s"${options.sparkAssembly} is not a file")
-        else
-          Left(s"${options.sparkAssembly} not found")
-      }
-
-    val (assembly, assemblyJars) = assemblyOrError match {
-      case Left(err) =>
-        Console.err.println(s"Cannot get spark assembly: $err")
-        sys.exit(1)
-      case Right(res) => res
-    }
+
+        val extraConf =
+          if (options.autoAssembly && sparkVersion.startsWith("2."))
+            Seq(
+              "spark.yarn.jars" -> assemblyJars.map(_.getAbsolutePath).mkString(",")
+            )
+          else
+            Nil
+
+        (extraConf, assemblyJars)
+      } else {
+
+        val assemblyAndJarsOrError = Assembly.spark(
+          scalaVersion,
+          sparkVersion,
+          options.yarnVersion,
+          options.defaultAssemblyDependencies.getOrElse(true),
+          options.assemblyDependencies.flatMap(_.split(",")).filter(_.nonEmpty),
+          options.common
+        )
+
+        val (assembly, assemblyJars) = assemblyAndJarsOrError match {
+          case Left(err) =>
+            Console.err.println(s"Cannot get spark assembly: $err")
+            sys.exit(1)
+          case Right(res) => res
+        }
+
+        val extraConf = Seq(
+          "spark.yarn.jar" -> assembly.getAbsolutePath
+        )
+
+        (extraConf, assemblyJars)
+      }

     val idx = {
@@ -146,16 +166,18 @@ case class SparkSubmit(

     val (check, extraJars0) = jars.partition(_.getAbsolutePath == mainJar)

-    val extraJars = extraJars0.filterNot(assemblyJars.toSet)
+    val extraJars = extraJars0.filterNot(sparkBaseJars.toSet)

     if (check.isEmpty)
       Console.err.println(
         s"Warning: cannot find back $mainJar among the dependencies JARs (likely a coursier bug)"
       )

-    val extraSparkOpts = Seq(
-      "--conf", "spark.yarn.jar=" + assembly.getAbsolutePath
-    )
+    val extraSparkOpts = sparkYarnExtraConf.flatMap {
+      case (k, v) => Seq(
+        "--conf", s"$k=$v"
+      )
+    }

     val extraJarsOptions =
       if (extraJars.isEmpty)
diff --git a/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala b/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala
index 4e130ec31..9f6131d07 100644
--- a/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala
+++ b/cli/src/main/scala-2.11/coursier/cli/spark/Assembly.scala
@@ -140,32 +140,77 @@ object Assembly {
     Rule.ExcludePattern("META-INF/.*\\.[rR][sS][aA]")
   )

-  def sparkAssemblyDependencies(
+  def sparkBaseDependencies(
     scalaVersion: String,
-    sparkVersion: String
-  ) = Seq(
-    s"org.apache.spark:spark-core_$scalaVersion:$sparkVersion",
-    s"org.apache.spark:spark-bagel_$scalaVersion:$sparkVersion",
-    s"org.apache.spark:spark-mllib_$scalaVersion:$sparkVersion",
-    s"org.apache.spark:spark-streaming_$scalaVersion:$sparkVersion",
-    s"org.apache.spark:spark-graphx_$scalaVersion:$sparkVersion",
-    s"org.apache.spark:spark-sql_$scalaVersion:$sparkVersion",
-    s"org.apache.spark:spark-repl_$scalaVersion:$sparkVersion",
-    s"org.apache.spark:spark-yarn_$scalaVersion:$sparkVersion"
-  )
+    sparkVersion: String,
+    yarnVersion: String
+  ) =
+    if (sparkVersion.startsWith("2."))
+      Seq(
+        s"org.apache.spark::spark-hive-thriftserver:$sparkVersion",
+        s"org.apache.spark::spark-repl:$sparkVersion",
+        s"org.apache.spark::spark-hive:$sparkVersion",
+        s"org.apache.spark::spark-graphx:$sparkVersion",
+        s"org.apache.spark::spark-mllib:$sparkVersion",
+        s"org.apache.spark::spark-streaming:$sparkVersion",
+        s"org.apache.spark::spark-yarn:$sparkVersion",
+        s"org.apache.spark::spark-sql:$sparkVersion",
+        s"org.apache.hadoop:hadoop-client:$yarnVersion",
+        s"org.apache.hadoop:hadoop-yarn-server-web-proxy:$yarnVersion",
+        s"org.apache.hadoop:hadoop-yarn-server-nodemanager:$yarnVersion"
+      )
+    else
+      Seq(
+        s"org.apache.spark:spark-core_$scalaVersion:$sparkVersion",
+        s"org.apache.spark:spark-bagel_$scalaVersion:$sparkVersion",
+        s"org.apache.spark:spark-mllib_$scalaVersion:$sparkVersion",
+        s"org.apache.spark:spark-streaming_$scalaVersion:$sparkVersion",
+        s"org.apache.spark:spark-graphx_$scalaVersion:$sparkVersion",
+        s"org.apache.spark:spark-sql_$scalaVersion:$sparkVersion",
+        s"org.apache.spark:spark-repl_$scalaVersion:$sparkVersion",
+        s"org.apache.spark:spark-yarn_$scalaVersion:$sparkVersion"
+      )
+
+  def sparkJarsHelper(
+    scalaVersion: String,
+    sparkVersion: String,
+    yarnVersion: String,
+    default: Boolean,
+    extraDependencies: Seq[String],
+    options: CommonOptions
+  ): Helper = {
+
+    val base = if (default) sparkBaseDependencies(scalaVersion, sparkVersion, yarnVersion) else Seq()
+    new Helper(options, extraDependencies ++ base)
+  }
+
+  def sparkJars(
+    scalaVersion: String,
+    sparkVersion: String,
+    yarnVersion: String,
+    default: Boolean,
+    extraDependencies: Seq[String],
+    options: CommonOptions,
+    artifactTypes: Set[String] = Set("jar")
+  ): Seq[File] = {
+
+    val helper = sparkJarsHelper(scalaVersion, sparkVersion, yarnVersion, default, extraDependencies, options)
+
+    helper.fetch(sources = false, javadoc = false, artifactTypes = artifactTypes)
+  }

   def spark(
     scalaVersion: String,
     sparkVersion: String,
-    noDefault: Boolean,
+    yarnVersion: String,
+    default: Boolean,
     extraDependencies: Seq[String],
     options: CommonOptions,
     artifactTypes: Set[String] = Set("jar"),
     checksumSeed: Array[Byte] = "v1".getBytes("UTF-8")
   ): Either[String, (File, Seq[File])] = {

-    val base = if (noDefault) Seq() else sparkAssemblyDependencies(scalaVersion, sparkVersion)
-    val helper = new Helper(options, extraDependencies ++ base)
+    val helper = sparkJarsHelper(scalaVersion, sparkVersion, yarnVersion, default, extraDependencies, options)

     val artifacts = helper.artifacts(sources = false, javadoc = false, artifactTypes = artifactTypes)
     val jars = helper.fetch(sources = false, javadoc = false, artifactTypes = artifactTypes)