diff --git a/.travis.yml b/.travis.yml
index a4a22b8..f5184df 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,5 +1,4 @@
 language: scala
-scala: 2.11.12
 jdk: oraclejdk8
 script: ./.travis.sh
 sudo: required
@@ -19,7 +18,14 @@ stages:
 jobs:
   include:
     - env: MASTER=local
+      scala: 2.11.12
+    - env: MASTER=local
+      scala: 2.12.7
     - env: MASTER=standalone STANDALONE_CACHE=$HOME/standalone-stuff
+      scala: 2.11.12
+    - env: MASTER=yarn YARN_CACHE=$HOME/yarn-stuff
+      scala: 2.11.12
     - env: MASTER=yarn YARN_CACHE=$HOME/yarn-stuff
+      scala: 2.12.7
     - stage: release
       script: sbt ci-release
diff --git a/build.sbt b/build.sbt
index d64f6d4..df48640 100644
--- a/build.sbt
+++ b/build.sbt
@@ -15,23 +15,43 @@ inThisBuild(List(
   )
 ))
 
-lazy val `spark-stubs` = project
+lazy val `spark-stubs_20` = project
   .underModules
   .settings(
     shared,
-    libraryDependencies += Deps.sparkSql % "provided"
+    baseDirectory := {
+      val baseDir = baseDirectory.value
+
+      if (Settings.isAtLeast212.value)
+        baseDir / "target" / "dummy"
+      else
+        baseDir
+    },
+    libraryDependencies ++= {
+      if (Settings.isAtLeast212.value)
+        Nil
+      else
+        Seq(Deps.sparkSql20 % "provided")
+    },
+    publishArtifact := !Settings.isAtLeast212.value
+  )
+
+lazy val `spark-stubs_24` = project
+  .underModules
+  .settings(
+    shared,
+    libraryDependencies += Deps.sparkSql24 % "provided"
   )
 
 lazy val core = project
   .in(file("modules/core"))
-  .dependsOn(`spark-stubs`)
   .settings(
     shared,
     name := "ammonite-spark",
     generatePropertyFile("org/apache/spark/sql/ammonitesparkinternals/ammonite-spark.properties"),
     libraryDependencies ++= Seq(
       Deps.ammoniteRepl % "provided",
-      Deps.sparkSql % "provided",
+      Deps.sparkSql.value % "provided",
       Deps.jettyServer
     )
   )
@@ -72,7 +92,8 @@ lazy val `ammonite-spark` = project
   .in(file("."))
   .aggregate(
     core,
-    `spark-stubs`,
+    `spark-stubs_20`,
+    `spark-stubs_24`,
     tests
   )
   .settings(
diff --git a/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/AmmoniteSparkSessionBuilder.scala b/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/AmmoniteSparkSessionBuilder.scala
index 9d870f0..298b624 100644
--- a/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/AmmoniteSparkSessionBuilder.scala
+++ b/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/AmmoniteSparkSessionBuilder.scala
@@ -89,16 +89,25 @@ class AmmoniteSparkSessionBuilder
    replApi: ReplAPI
  ) extends SparkSession.Builder {
 
-  private val options0: scala.collection.Map[String, String] =
-    try {
-      val f = classOf[SparkSession.Builder].getDeclaredField("org$apache$spark$sql$SparkSession$Builder$$options")
-      f.setAccessible(true)
-      f.get(this).asInstanceOf[scala.collection.mutable.HashMap[String, String]]
-    } catch {
-      case t: Throwable =>
-        println(s"Warning: can't read SparkSession Builder options, caught $t")
+  private val options0: scala.collection.Map[String, String] = {
+
+    def fieldVia(name: String): Option[scala.collection.mutable.HashMap[String, String]] =
+      try {
+        val f = classOf[SparkSession.Builder].getDeclaredField(name)
+        f.setAccessible(true)
+        Some(f.get(this).asInstanceOf[scala.collection.mutable.HashMap[String, String]])
+      } catch {
+        case _: NoSuchFieldException =>
+          None
+      }
+
+    fieldVia("org$apache$spark$sql$SparkSession$Builder$$options")
+      .orElse(fieldVia("options"))
+      .getOrElse {
+        println("Warning: can't read SparkSession Builder options (options field not found)")
         Map.empty[String, String]
-    }
+      }
+  }
 
   private def init(): Unit = {
 
@@ -158,17 +167,28 @@ class AmmoniteSparkSessionBuilder
   private def bindAddress(): String =
     options0.getOrElse("spark.driver.bindAddress", host())
 
-  override def getOrCreate(): SparkSession = {
+  private def loadExtraDependencies(): Unit = {
 
-    if (isYarn() && !SparkDependencies.sparkYarnFound()) {
-      println("Loading spark-yarn")
-      interpApi.load.ivy(SparkDependencies.sparkYarnDependency)
-    }
+    var deps = List.empty[(String, coursier.Dependency)]
+
+    if (hiveSupport() && !SparkDependencies.sparkHiveFound())
+      deps = ("spark-hive", SparkDependencies.sparkHiveDependency) :: deps
+
+    if (!SparkDependencies.sparkExecutorClassLoaderFound())
+      deps = ("spark-stubs", SparkDependencies.stubsDependency) :: deps
+
+    if (isYarn() && !SparkDependencies.sparkYarnFound())
+      deps = ("spark-yarn", SparkDependencies.sparkYarnDependency) :: deps
 
-    if (hiveSupport() && !SparkDependencies.sparkHiveFound()) {
-      println("Loading spark-hive")
-      interpApi.load.ivy(SparkDependencies.sparkHiveDependency)
+    if (deps.nonEmpty) {
+      println(s"Loading ${deps.map(_._1).mkString(", ")}")
+      interpApi.load.ivy(deps.map(_._2): _*)
     }
+  }
+
+  override def getOrCreate(): SparkSession = {
+
+    loadExtraDependencies()
 
     val sessionJars =
       replApi
diff --git a/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/SparkDependencies.scala b/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/SparkDependencies.scala
index f032d9e..aba81f6 100644
--- a/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/SparkDependencies.scala
+++ b/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/SparkDependencies.scala
@@ -7,6 +7,7 @@ import scala.annotation.tailrec
 import scala.collection.mutable
 import scala.concurrent.ExecutionContext
 import scala.util.Properties.{versionNumberString => scalaVersion}
+import scala.util.Try
 
 object SparkDependencies {
 
@@ -24,6 +25,7 @@ object SparkDependencies {
   )
 
   private def sparkYarnClass = "org.apache.spark.deploy.yarn.Client"
+  private def sparkExecutorClassLoaderClass = "org.apache.spark.repl.ExecutorClassLoader"
 
   def sparkHiveFound(): Boolean =
     sparkHiveClasses.exists { className =>
@@ -45,6 +47,15 @@ object SparkDependencies {
         false
     }
 
+  def sparkExecutorClassLoaderFound(): Boolean =
+    try {
+      Thread.currentThread().getContextClassLoader.loadClass(sparkExecutorClassLoaderClass)
+      true
+    } catch {
+      case _: ClassNotFoundException =>
+        false
+    }
+
   private def sparkModules(): Seq[String] = {
 
     val b = new mutable.ListBuffer[String]
@@ -79,10 +90,17 @@ object SparkDependencies {
     b.result()
   }
 
-  def stubsDependency =
+  def stubsDependency = {
+    val suffix = org.apache.spark.SPARK_VERSION.split('.').take(2) match {
+      case Array("2", n) if Try(n.toInt).toOption.exists(_ <= 3) =>
+        "20"
+      case _ =>
+        "24"
+    }
     coursier.Dependency(
-      coursier.Module("sh.almond", s"spark-stubs_$sbv"), Properties.version
+      coursier.Module("sh.almond", s"spark-stubs_${suffix}_$sbv"), Properties.version
     )
+  }
 
   def sparkYarnDependency =
     coursier.Dependency(
diff --git a/modules/spark-stubs/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/modules/spark-stubs_20/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
similarity index 100%
rename from modules/spark-stubs/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
rename to modules/spark-stubs_20/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
diff --git a/modules/spark-stubs/src/main/scala/spark/repl/Main.scala b/modules/spark-stubs_20/src/main/scala/spark/repl/Main.scala
similarity index 100%
rename from modules/spark-stubs/src/main/scala/spark/repl/Main.scala
rename to modules/spark-stubs_20/src/main/scala/spark/repl/Main.scala
diff --git a/modules/spark-stubs_24/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/modules/spark-stubs_24/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
new file mode 100644
index 0000000..5748227
--- /dev/null
+++ b/modules/spark-stubs_24/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
@@ -0,0 +1,253 @@
+
+// Like https://github.com/apache/spark/blob/0e318acd0cc3b42e8be9cb2a53cccfdc4a0805f9/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
+// but keeping getClassFileInputStreamFromHttpServer from former versions (so that http works without hadoop stuff)
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.repl
+
+import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream, IOException, InputStream}
+import java.net.{HttpURLConnection, URI, URL, URLEncoder}
+import java.nio.channels.Channels
+
+import scala.util.control.NonFatal
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.xbean.asm6._
+import org.apache.xbean.asm6.Opcodes._
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.{ParentClassLoader, Utils}
+
+/**
+ * A ClassLoader that reads classes from a Hadoop FileSystem or Spark RPC endpoint, used to load
+ * classes defined by the interpreter when the REPL is used. Allows the user to specify if user
+ * class path should be first. This class loader delegates getting/finding resources to parent
+ * loader, which makes sense until REPL never provide resource dynamically.
+ *
+ * Note: [[ClassLoader]] will preferentially load class from parent. Only when parent is null or
+ * the load failed, that it will call the overridden `findClass` function. To avoid the potential
+ * issue caused by loading class using inappropriate class loader, we should set the parent of
+ * ClassLoader to null, so that we can fully control which class loader is used. For detailed
+ * discussion, see SPARK-18646.
+ */
+class ExecutorClassLoader(
+    conf: SparkConf,
+    env: SparkEnv,
+    classUri: String,
+    parent: ClassLoader,
+    userClassPathFirst: Boolean) extends ClassLoader(null) with Logging {
+  val uri = new URI(classUri)
+  val directory = uri.getPath
+
+  val parentLoader = new ParentClassLoader(parent)
+
+  // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes
+  private[repl] var httpUrlConnectionTimeoutMillis: Int = -1
+
+  private val fetchFn: (String) => InputStream = uri.getScheme() match {
+    case "spark" => getClassFileInputStreamFromSparkRPC
+    case "http" | "https" | "ftp" => getClassFileInputStreamFromHttpServer
+    case _ =>
+      val fileSystem = FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf))
+      getClassFileInputStreamFromFileSystem(fileSystem)
+  }
+
+  override def getResource(name: String): URL = {
+    parentLoader.getResource(name)
+  }
+
+  override def getResources(name: String): java.util.Enumeration[URL] = {
+    parentLoader.getResources(name)
+  }
+
+  override def findClass(name: String): Class[_] = {
+    if (userClassPathFirst) {
+      findClassLocally(name).getOrElse(parentLoader.loadClass(name))
+    } else {
+      try {
+        parentLoader.loadClass(name)
+      } catch {
+        case e: ClassNotFoundException =>
+          val classOption = findClassLocally(name)
+          classOption match {
+            case None => throw new ClassNotFoundException(name, e)
+            case Some(a) => a
+          }
+      }
+    }
+  }
+
+  private def getClassFileInputStreamFromSparkRPC(path: String): InputStream = {
+    val channel = env.rpcEnv.openChannel(s"$classUri/$path")
+    new FilterInputStream(Channels.newInputStream(channel)) {
+
+      override def read(): Int = toClassNotFound(super.read())
+
+      override def read(b: Array[Byte]): Int = toClassNotFound(super.read(b))
+
+      override def read(b: Array[Byte], offset: Int, len: Int) =
+        toClassNotFound(super.read(b, offset, len))
+
+      private def toClassNotFound(fn: => Int): Int = {
+        try {
+          fn
+        } catch {
+          case e: Exception =>
+            throw new ClassNotFoundException(path, e)
+        }
+      }
+    }
+  }
+
+  private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = {
+    val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
+      val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
+      val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
+      newuri.toURL
+    } else {
+      new URL(classUri + "/" + urlEncode(pathInDirectory))
+    }
+    val connection: HttpURLConnection = url.openConnection().asInstanceOf[HttpURLConnection]
+    // Set the connection timeouts (for testing purposes)
+    if (httpUrlConnectionTimeoutMillis != -1) {
+      connection.setConnectTimeout(httpUrlConnectionTimeoutMillis)
+      connection.setReadTimeout(httpUrlConnectionTimeoutMillis)
+    }
+    connection.connect()
+    try {
+      if (connection.getResponseCode != 200) {
+        // Close the error stream so that the connection is eligible for re-use
+        try {
+          connection.getErrorStream.close()
+        } catch {
+          case ioe: IOException =>
+            logError("Exception while closing error stream", ioe)
+        }
+        throw new ClassNotFoundException(s"Class file not found at URL $url")
+      } else {
+        connection.getInputStream
+      }
+    } catch {
+      case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] =>
+        connection.disconnect()
+        throw e
+    }
+  }
+
+  private def getClassFileInputStreamFromFileSystem(fileSystem: FileSystem)(
+      pathInDirectory: String): InputStream = {
+    val path = new Path(directory, pathInDirectory)
+    try {
+      fileSystem.open(path)
+    } catch {
+      case _: FileNotFoundException =>
+        throw new ClassNotFoundException(s"Class file not found at path $path")
+    }
+  }
+
+  def findClassLocally(name: String): Option[Class[_]] = {
+    val pathInDirectory = name.replace('.', '/') + ".class"
+    var inputStream: InputStream = null
+    try {
+      inputStream = fetchFn(pathInDirectory)
+      val bytes = readAndTransformClass(name, inputStream)
+      Some(defineClass(name, bytes, 0, bytes.length))
+    } catch {
+      case e: ClassNotFoundException =>
+        // We did not find the class
+        logDebug(s"Did not load class $name from REPL class server at $uri", e)
+        None
+      case e: Exception =>
+        // Something bad happened while checking if the class exists
+        logError(s"Failed to check existence of class $name on REPL class server at $uri", e)
+        None
+    } finally {
+      if (inputStream != null) {
+        try {
+          inputStream.close()
+        } catch {
+          case e: Exception =>
+            logError("Exception while closing inputStream", e)
+        }
+      }
+    }
+  }
+
+  def readAndTransformClass(name: String, in: InputStream): Array[Byte] = {
+    if (name.startsWith("line") && name.endsWith("$iw$")) {
+      // Class seems to be an interpreter "wrapper" object storing a val or var.
+      // Replace its constructor with a dummy one that does not run the
+      // initialization code placed there by the REPL. The val or var will
+      // be initialized later through reflection when it is used in a task.
+      val cr = new ClassReader(in)
+      val cw = new ClassWriter(
+        ClassWriter.COMPUTE_FRAMES + ClassWriter.COMPUTE_MAXS)
+      val cleaner = new ConstructorCleaner(name, cw)
+      cr.accept(cleaner, 0)
+      return cw.toByteArray
+    } else {
+      // Pass the class through unmodified
+      val bos = new ByteArrayOutputStream
+      val bytes = new Array[Byte](4096)
+      var done = false
+      while (!done) {
+        val num = in.read(bytes)
+        if (num >= 0) {
+          bos.write(bytes, 0, num)
+        } else {
+          done = true
+        }
+      }
+      return bos.toByteArray
+    }
+  }
+
+  /**
+   * URL-encode a string, preserving only slashes
+   */
+  def urlEncode(str: String): String = {
+    str.split('/').map(part => URLEncoder.encode(part, "UTF-8")).mkString("/")
+  }
+}
+
+class ConstructorCleaner(className: String, cv: ClassVisitor)
+extends ClassVisitor(ASM6, cv) {
+  override def visitMethod(access: Int, name: String, desc: String,
+      sig: String, exceptions: Array[String]): MethodVisitor = {
+    val mv = cv.visitMethod(access, name, desc, sig, exceptions)
+    if (name == "<init>" && (access & ACC_STATIC) == 0) {
+      // This is the constructor, time to clean it; just output some new
+      // instructions to mv that create the object and set the static MODULE$
+      // field in the class to point to it, but do nothing otherwise.
+      mv.visitCode()
+      mv.visitVarInsn(ALOAD, 0) // load this
+      mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V", false)
+      mv.visitVarInsn(ALOAD, 0) // load this
+      // val classType = className.replace('.', '/')
+      // mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";")
+      mv.visitInsn(RETURN)
+      mv.visitMaxs(-1, -1) // stack size and local vars will be auto-computed
+      mv.visitEnd()
+      return null
+    } else {
+      return mv
+    }
+  }
+}
diff --git a/modules/spark-stubs_24/src/main/scala/spark/repl/Main.scala b/modules/spark-stubs_24/src/main/scala/spark/repl/Main.scala
new file mode 100644
index 0000000..42e4f81
--- /dev/null
+++ b/modules/spark-stubs_24/src/main/scala/spark/repl/Main.scala
@@ -0,0 +1,6 @@
+package spark.repl
+
+object Main {
+  // May make spark ClosureCleaner a tiny bit happier
+  def interp = this
+}
diff --git a/modules/tests/src/main/scala/ammonite/spark/SparkReplTests.scala b/modules/tests/src/main/scala/ammonite/spark/SparkReplTests.scala
index 50f9e1d..cd3354d 100644
--- a/modules/tests/src/main/scala/ammonite/spark/SparkReplTests.scala
+++ b/modules/tests/src/main/scala/ammonite/spark/SparkReplTests.scala
@@ -506,42 +506,44 @@ class SparkReplTests(sparkVersion: String, master: String, conf: (String, String
   // tests below are custom ones
 
   "algebird" - {
-    sparkSession(
-      """
-        @ import $ivy.`com.twitter::algebird-spark:0.13.0`
-
-        @ AmmoniteSparkSession.sync()
-
-        @ import com.twitter.algebird.Semigroup
-        import com.twitter.algebird.Semigroup
-
-        @ import com.twitter.algebird.spark._
-        import com.twitter.algebird.spark._
-
-        @ case class Foo(n: Int, weight: Double)
-        defined class Foo
-
-        @ implicit val fooSemigroup: Semigroup[Foo] = new Semigroup[Foo] {
-        @   def plus(a: Foo, b: Foo): Foo =
-        @     Foo(a.n + b.n, a.weight + b.weight)
-        @ }
-
-        @ val rdd = sc.parallelize((1 to 100).map(n => n.toString.take(1) -> Foo(n, n % 10)), 10)
-
-        @ val res = rdd.algebird.sumByKey[String, Foo].collect().sortBy(_._1)
-        res: Array[(String, Foo)] = Array(
-          ("1", Foo(246, 46.0)),
-          ("2", Foo(247, 47.0)),
-          ("3", Foo(348, 48.0)),
-          ("4", Foo(449, 49.0)),
-          ("5", Foo(550, 50.0)),
-          ("6", Foo(651, 51.0)),
-          ("7", Foo(752, 52.0)),
-          ("8", Foo(853, 53.0)),
-          ("9", Foo(954, 54.0))
-        )
-      """
-    )
+    if (scala.util.Properties.versionNumberString.startsWith("2.11."))
+      // no algebird-spark in scala 2.12 yet
+      sparkSession(
+        """
+          @ import $ivy.`com.twitter::algebird-spark:0.13.0`
+
+          @ AmmoniteSparkSession.sync()
+
+          @ import com.twitter.algebird.Semigroup
+          import com.twitter.algebird.Semigroup
+
+          @ import com.twitter.algebird.spark._
+          import com.twitter.algebird.spark._
+
+          @ case class Foo(n: Int, weight: Double)
+          defined class Foo
+
+          @ implicit val fooSemigroup: Semigroup[Foo] = new Semigroup[Foo] {
+          @   def plus(a: Foo, b: Foo): Foo =
+          @     Foo(a.n + b.n, a.weight + b.weight)
+          @ }
+
+          @ val rdd = sc.parallelize((1 to 100).map(n => n.toString.take(1) -> Foo(n, n % 10)), 10)
+
+          @ val res = rdd.algebird.sumByKey[String, Foo].collect().sortBy(_._1)
+          res: Array[(String, Foo)] = Array(
+            ("1", Foo(246, 46.0)),
+            ("2", Foo(247, 47.0)),
+            ("3", Foo(348, 48.0)),
+            ("4", Foo(449, 49.0)),
+            ("5", Foo(550, 50.0)),
+            ("6", Foo(651, 51.0)),
+            ("7", Foo(752, 52.0)),
+            ("8", Foo(853, 53.0)),
+            ("9", Foo(954, 54.0))
+          )
+        """
+      )
   }
 }
diff --git a/modules/tests/src/main/scala/ammonite/spark/SparkVersions.scala b/modules/tests/src/main/scala/ammonite/spark/SparkVersions.scala
index 3cfd77a..8a49c62 100644
--- a/modules/tests/src/main/scala/ammonite/spark/SparkVersions.scala
+++ b/modules/tests/src/main/scala/ammonite/spark/SparkVersions.scala
@@ -4,6 +4,7 @@ object SparkVersions {
 
   def latest21 = "2.1.3"
   def latest22 = "2.2.2"
-  def latest23 = "2.3.1"
+  def latest23 = "2.3.2"
+  def latest24 = "2.4.0"
 
 }
diff --git a/modules/tests/src/test/scala/ammonite/spark/Local21Tests.scala b/modules/tests/src/test/scala-2.11/ammonite/spark/Local21Tests.scala
similarity index 100%
rename from modules/tests/src/test/scala/ammonite/spark/Local21Tests.scala
rename to modules/tests/src/test/scala-2.11/ammonite/spark/Local21Tests.scala
diff --git a/modules/tests/src/test/scala/ammonite/spark/Local22Tests.scala b/modules/tests/src/test/scala-2.11/ammonite/spark/Local22Tests.scala
similarity index 100%
rename from modules/tests/src/test/scala/ammonite/spark/Local22Tests.scala
rename to modules/tests/src/test/scala-2.11/ammonite/spark/Local22Tests.scala
diff --git a/modules/tests/src/test/scala/ammonite/spark/Local23Tests.scala b/modules/tests/src/test/scala-2.11/ammonite/spark/Local23Tests.scala
similarity index 100%
rename from modules/tests/src/test/scala/ammonite/spark/Local23Tests.scala
rename to modules/tests/src/test/scala-2.11/ammonite/spark/Local23Tests.scala
diff --git a/modules/tests/src/test/scala/ammonite/spark/ProgressBar21Tests.scala b/modules/tests/src/test/scala-2.11/ammonite/spark/ProgressBar21Tests.scala
similarity index 100%
rename from modules/tests/src/test/scala/ammonite/spark/ProgressBar21Tests.scala
rename to modules/tests/src/test/scala-2.11/ammonite/spark/ProgressBar21Tests.scala
diff --git a/modules/tests/src/test/scala/ammonite/spark/ProgressBar22Tests.scala b/modules/tests/src/test/scala-2.11/ammonite/spark/ProgressBar22Tests.scala
similarity index 100%
rename from modules/tests/src/test/scala/ammonite/spark/ProgressBar22Tests.scala
rename to modules/tests/src/test/scala-2.11/ammonite/spark/ProgressBar22Tests.scala
diff --git a/modules/tests/src/test/scala/ammonite/spark/ProgressBar23Tests.scala b/modules/tests/src/test/scala-2.11/ammonite/spark/ProgressBar23Tests.scala
similarity index 100%
rename from modules/tests/src/test/scala/ammonite/spark/ProgressBar23Tests.scala
rename to modules/tests/src/test/scala-2.11/ammonite/spark/ProgressBar23Tests.scala
diff --git a/modules/tests/src/test/scala/ammonite/spark/Local24Tests.scala b/modules/tests/src/test/scala/ammonite/spark/Local24Tests.scala
new file mode 100644
index 0000000..eb17adc
--- /dev/null
+++ b/modules/tests/src/test/scala/ammonite/spark/Local24Tests.scala
@@ -0,0 +1,6 @@
+package ammonite.spark
+
+object Local24Tests extends SparkReplTests(
+  SparkVersions.latest24,
+  Local.master
+)
diff --git a/modules/tests/src/test/scala/ammonite/spark/ProgressBar24Tests.scala b/modules/tests/src/test/scala/ammonite/spark/ProgressBar24Tests.scala
new file mode 100644
index 0000000..6595686
--- /dev/null
+++ b/modules/tests/src/test/scala/ammonite/spark/ProgressBar24Tests.scala
@@ -0,0 +1,6 @@
+package ammonite.spark
+
+object ProgressBar24Tests extends ProgressBarTests(
+  SparkVersions.latest24,
+  Local.master
+)
diff --git a/modules/yarn-tests/src/test/scala/ammonite/spark/Yarn23Tests.scala b/modules/yarn-tests/src/test/scala/ammonite/spark/Yarn24Tests.scala
similarity index 83%
rename from modules/yarn-tests/src/test/scala/ammonite/spark/Yarn23Tests.scala
rename to modules/yarn-tests/src/test/scala/ammonite/spark/Yarn24Tests.scala
index e8eed6d..145eb72 100644
--- a/modules/yarn-tests/src/test/scala/ammonite/spark/Yarn23Tests.scala
+++ b/modules/yarn-tests/src/test/scala/ammonite/spark/Yarn24Tests.scala
@@ -1,7 +1,7 @@
 package ammonite.spark
 
-object Yarn23Tests extends SparkReplTests(
-  SparkVersions.latest23,
+object Yarn24Tests extends SparkReplTests(
+  SparkVersions.latest24,
   "yarn",
   "spark.executor.instances" -> "1",
   "spark.executor.memory" -> "2g",
diff --git a/project/Deps.scala b/project/Deps.scala
index b7fc4b3..1fe7e77 100644
--- a/project/Deps.scala
+++ b/project/Deps.scala
@@ -1,11 +1,21 @@
 import sbt._
+import sbt.Def.setting
+import sbt.Keys._
 
 object Deps {
 
-  def ammoniteRepl = ("com.lihaoyi" % "ammonite-repl" % "1.1.2-25-5ef5ee1").cross(CrossVersion.full)
+  def ammoniteRepl = ("com.lihaoyi" % "ammonite-repl" % "1.3.2").cross(CrossVersion.full)
   def jettyServer = "org.eclipse.jetty" % "jetty-server" % "8.1.14.v20131031"
-  def sparkSql = "org.apache.spark" %% "spark-sql" % "2.0.2" // no need to bump that version much, to ensure we don't rely on too new stuff
   def utest = "com.lihaoyi" %% "utest" % "0.6.4"
 
+  def sparkSql20 = "org.apache.spark" %% "spark-sql" % "2.0.2" // no need to bump that version much, to ensure we don't rely on too new stuff
+  def sparkSql24 = "org.apache.spark" %% "spark-sql" % "2.4.0" // that version's required for scala 2.12
+  def sparkSql = setting {
+    if (Settings.isAtLeast212.value)
+      sparkSql24
+    else
+      sparkSql20
+  }
+
 }
diff --git a/project/Settings.scala b/project/Settings.scala
index b48e0d7..002834e 100644
--- a/project/Settings.scala
+++ b/project/Settings.scala
@@ -3,6 +3,7 @@ import java.nio.charset.StandardCharsets.UTF_8
 import java.nio.file.Files
 
 import sbt._
+import sbt.Def._
 import sbt.Keys._
 
 object Settings {
@@ -15,9 +16,18 @@ object Settings {
   }
 
   private val scala211 = "2.11.12"
+  private val scala212 = "2.12.7"
+
+  lazy val isAtLeast212 = setting {
+    CrossVersion.partialVersion(scalaVersion.value) match {
+      case Some((2, n)) if n >= 12 => true
+      case _ => false
+    }
+  }
 
   lazy val shared = Seq(
     scalaVersion := scala211,
+    crossScalaVersions := Seq(scala212, scala211),
     scalacOptions ++= Seq(
       "-deprecation",
       "-feature",
diff --git a/sbt-in-docker-with-yarn-cluster.sh b/sbt-in-docker-with-yarn-cluster.sh
index 44a34e3..7ce40ed 100755
--- a/sbt-in-docker-with-yarn-cluster.sh
+++ b/sbt-in-docker-with-yarn-cluster.sh
@@ -98,12 +98,32 @@ if [ ! -d "$CACHE/hadoop-conf" ]; then
   test "$TRANSIENT_DOCKER_YARN_CLUSTER" = 0 || rm -rf "$CACHE/docker-yarn-cluster"
 fi
 
+SPARK_VERSION="2.4.0"
+SCALA_VERSION="${TRAVIS_SCALA_VERSION:-"2.11.12"}"
+case "$SCALA_VERSION" in
+  2.11.*)
+    SBV="2.11"
+    ;;
+  2.12.*)
+    SBV="2.12"
+    ;;
+  *)
+    echo "Unrecognized scala version: $SCALA_VERSION"
+    exit 1
+    ;;
+esac
+
 cat > "$CACHE/run.sh" << EOF
 #!/usr/bin/env bash
 set -e
 
 # prefetch stuff
-for d in org.apache.spark:spark-sql_2.11:2.3.1 org.apache.spark:spark-yarn_2.11:2.3.1; do
+
+DEPS=()
+DEPS+=("org.apache.spark:spark-sql_$SBV:$SPARK_VERSION")
+DEPS+=("org.apache.spark:spark-yarn_$SBV:$SPARK_VERSION")
+
+for d in "\${DEPS[@]}"; do
   echo "Pre-fetching \$d"
   coursier fetch $(if [ "$INTERACTIVE" = 1 ]; then echo --progress; fi) "\$d" >/dev/null
 done
diff --git a/sbt-with-standalone-cluster.sh b/sbt-with-standalone-cluster.sh
index c068a60..41eab0e 100755
--- a/sbt-with-standalone-cluster.sh
+++ b/sbt-with-standalone-cluster.sh
@@ -1,7 +1,7 @@
 #!/usr/bin/env bash
 set -e
 
-SPARK_VERSION="2.3.1"
+SPARK_VERSION="2.3.2"
 HOST=localhost
 
 cd "$(dirname "${BASH_SOURCE[0]}")"
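
Note: for reference, a minimal Ammonite session exercising the cross-built artifacts might look like the sketch below. It is an illustration only: the sh.almond::ammonite-spark coordinate and the AmmoniteSparkSession.builder() entry point reflect the project's existing usage as assumed here, and the version number is a placeholder rather than something this diff pins.

  // Ammonite REPL started with Scala 2.12, sketch only
  import $ivy.`sh.almond::ammonite-spark:0.1.0`   // placeholder version assumed to include these changes
  import org.apache.spark.sql._

  // getOrCreate() now goes through loadExtraDependencies(), which resolves
  // spark-stubs_20_<sbv> for Spark 2.0-2.3 or spark-stubs_24_<sbv> for Spark 2.4+
  // (see SparkDependencies.stubsDependency), plus spark-yarn / spark-hive when needed.
  val spark = AmmoniteSparkSession.builder()
    .master("local[*]")
    .getOrCreate()

  spark.sparkContext.parallelize(1 to 10).sum()   // executors load REPL-defined classes via ExecutorClassLoader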