diff --git a/.github/scripts/test.sh b/.github/scripts/test.sh index ccd5fa3..11760a1 100755 --- a/.github/scripts/test.sh +++ b/.github/scripts/test.sh @@ -3,15 +3,15 @@ set -e case "${MASTER:-"local"}" in local) - ./sbt publishLocal test mimaReportBinaryIssues ;; + ./sbt +publishLocal +test +mimaReportBinaryIssues ;; local-distrib) - ./with-spark-home.sh ./sbt publishLocal local-spark-distrib-tests/test ;; + ./with-spark-home.sh ./sbt +publishLocal +local-spark-distrib-tests/test ;; standalone) - ./with-spark-home.sh ./sbt-with-standalone-cluster.sh publishLocal standalone-tests/test ;; + ./with-spark-home.sh ./sbt-with-standalone-cluster.sh +publishLocal +standalone-tests/test ;; yarn) - ./sbt-in-docker-with-yarn-cluster.sh -batch publishLocal yarn-tests/test ;; + ./sbt-in-docker-with-yarn-cluster.sh -batch +publishLocal +yarn-tests/test ;; yarn-distrib) - ./with-spark-home.sh ./sbt-in-docker-with-yarn-cluster.sh -batch publishLocal yarn-spark-distrib-tests/test ;; + ./with-spark-home.sh ./sbt-in-docker-with-yarn-cluster.sh -batch +publishLocal +yarn-spark-distrib-tests/test ;; *) echo "Unrecognized master type $MASTER" exit 1 diff --git a/README.md b/README.md index aefeff0..a809fbd 100644 --- a/README.md +++ b/README.md @@ -148,6 +148,7 @@ with for sure. | `0.10.1` | `2.1.4` | `0.10.1` | | `0.11.0` | `2.3.8-36-1cce53f3` | `0.11.0` | | `0.12.0` | `2.3.8-122-9be39deb` | `0.12.0` | +| `0.13.0` | `2.5.4-8-30448e49` | `0.13.0` | ## Missing diff --git a/build.sbt b/build.sbt index 39d1dfc..93361e2 100644 --- a/build.sbt +++ b/build.sbt @@ -20,7 +20,7 @@ lazy val `spark-stubs_24` = project .underModules .settings( shared, - libraryDependencies += Deps.sparkSql % Provided + libraryDependencies += Deps.sparkSql.value % Provided ) lazy val `spark-stubs_30` = project @@ -31,16 +31,26 @@ lazy val `spark-stubs_30` = project libraryDependencies += Deps.sparkSql3 % Provided ) +lazy val `spark-stubs_32` = project + .disablePlugins(MimaPlugin) + .underModules + .settings( + shared, + crossScalaVersions += Deps.Scala.scala213, + libraryDependencies += Deps.sparkSql32 % Provided + ) + lazy val core = project .in(file("modules/core")) .settings( shared, + crossScalaVersions += Deps.Scala.scala213, name := "ammonite-spark", Mima.settings, generatePropertyFile("org/apache/spark/sql/ammonitesparkinternals/ammonite-spark.properties"), libraryDependencies ++= Seq( Deps.ammoniteReplApi % Provided, - Deps.sparkSql % Provided, + Deps.sparkSql.value % Provided, Deps.jettyServer ) ) @@ -50,6 +60,7 @@ lazy val tests = project .underModules .settings( shared, + crossScalaVersions += Deps.Scala.scala213, skip.in(publish) := true, generatePropertyFile("ammonite/ammonite-spark.properties"), generateDependenciesFile, @@ -87,6 +98,7 @@ lazy val `yarn-tests` = project .underModules .settings( shared, + crossScalaVersions += Deps.Scala.scala213, skip.in(publish) := true, testSettings ) @@ -108,6 +120,7 @@ lazy val `ammonite-spark` = project core, `spark-stubs_24`, `spark-stubs_30`, + `spark-stubs_32`, tests ) .settings( diff --git a/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/AmmoniteSparkSessionBuilder.scala b/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/AmmoniteSparkSessionBuilder.scala index d28da01..ca5b73d 100644 --- a/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/AmmoniteSparkSessionBuilder.scala +++ b/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/AmmoniteSparkSessionBuilder.scala @@ -298,7 +298,7 @@ class 
AmmoniteSparkSessionBuilder ) case Some(dir) => println(s"Adding Hadoop conf dir ${AmmoniteSparkSessionBuilder.prettyDir(dir)} to classpath") - interpApi.load.cp(ammonite.ops.Path(dir)) + interpApi.load.cp(os.Path(dir)) } } @@ -311,7 +311,7 @@ class AmmoniteSparkSessionBuilder println("Warning: hive-site.xml not found in the classpath, and no Hive conf found via HIVE_CONF_DIR") case Some(dir) => println(s"Adding Hive conf dir ${AmmoniteSparkSessionBuilder.prettyDir(dir)} to classpath") - interpApi.load.cp(ammonite.ops.Path(dir)) + interpApi.load.cp(os.Path(dir)) } } diff --git a/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/SparkDependencies.scala b/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/SparkDependencies.scala index e656d3e..16b58d5 100644 --- a/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/SparkDependencies.scala +++ b/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/SparkDependencies.scala @@ -98,8 +98,10 @@ object SparkDependencies { "20" case Array("2", n) if Try(n.toInt).toOption.exists(_ >= 4) => "24" - case Array("3", n) => + case Array("3", n) if Try(n.toInt).toOption.exists(_ <= 1) => "30" + case Array("3", n) => + "32" case _ => System.err.println(s"Warning: unrecognized Spark version ($sv), assuming 2.4.x") "24" diff --git a/modules/spark-stubs_32/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/modules/spark-stubs_32/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala new file mode 100644 index 0000000..8ca16f0 --- /dev/null +++ b/modules/spark-stubs_32/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -0,0 +1,253 @@ + +// Like https://github.com/apache/spark/blob/0e318acd0cc3b42e8be9cb2a53cccfdc4a0805f9/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +// but keeping getClassFileInputStreamFromHttpServer from former versions (so that http works without hadoop stuff) + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.repl + +import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream, IOException, InputStream} +import java.net.{HttpURLConnection, URI, URL, URLEncoder} +import java.nio.channels.Channels + +import scala.util.control.NonFatal +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.xbean.asm9._ +import org.apache.xbean.asm9.Opcodes._ + +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging +import org.apache.spark.util.{ParentClassLoader, Utils} + +/** + * A ClassLoader that reads classes from a Hadoop FileSystem or Spark RPC endpoint, used to load + * classes defined by the interpreter when the REPL is used. 
Allows the user to specify if user + * class path should be first. This class loader delegates getting/finding resources to parent + * loader, which makes sense until REPL never provide resource dynamically. + * + * Note: [[ClassLoader]] will preferentially load class from parent. Only when parent is null or + * the load failed, that it will call the overridden `findClass` function. To avoid the potential + * issue caused by loading class using inappropriate class loader, we should set the parent of + * ClassLoader to null, so that we can fully control which class loader is used. For detailed + * discussion, see SPARK-18646. + */ +class ExecutorClassLoader( + conf: SparkConf, + env: SparkEnv, + classUri: String, + parent: ClassLoader, + userClassPathFirst: Boolean) extends ClassLoader(null) with Logging { + val uri = new URI(classUri) + val directory = uri.getPath + + val parentLoader = new ParentClassLoader(parent) + + // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes + private[repl] var httpUrlConnectionTimeoutMillis: Int = -1 + + private val fetchFn: (String) => InputStream = uri.getScheme() match { + case "spark" => getClassFileInputStreamFromSparkRPC + case "http" | "https" | "ftp" => getClassFileInputStreamFromHttpServer + case _ => + val fileSystem = FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf)) + getClassFileInputStreamFromFileSystem(fileSystem) + } + + override def getResource(name: String): URL = { + parentLoader.getResource(name) + } + + override def getResources(name: String): java.util.Enumeration[URL] = { + parentLoader.getResources(name) + } + + override def findClass(name: String): Class[_] = { + if (userClassPathFirst) { + findClassLocally(name).getOrElse(parentLoader.loadClass(name)) + } else { + try { + parentLoader.loadClass(name) + } catch { + case e: ClassNotFoundException => + val classOption = findClassLocally(name) + classOption match { + case None => throw new ClassNotFoundException(name, e) + case Some(a) => a + } + } + } + } + + private def getClassFileInputStreamFromSparkRPC(path: String): InputStream = { + val channel = env.rpcEnv.openChannel(s"$classUri/$path") + new FilterInputStream(Channels.newInputStream(channel)) { + + override def read(): Int = toClassNotFound(super.read()) + + override def read(b: Array[Byte]): Int = toClassNotFound(super.read(b)) + + override def read(b: Array[Byte], offset: Int, len: Int) = + toClassNotFound(super.read(b, offset, len)) + + private def toClassNotFound(fn: => Int): Int = { + try { + fn + } catch { + case e: Exception => + throw new ClassNotFoundException(path, e) + } + } + } + } + + private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = { + val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { + val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) + // val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) + uri.toURL + } else { + new URL(classUri + "/" + urlEncode(pathInDirectory)) + } + val connection: HttpURLConnection = url.openConnection().asInstanceOf[HttpURLConnection] + // Set the connection timeouts (for testing purposes) + if (httpUrlConnectionTimeoutMillis != -1) { + connection.setConnectTimeout(httpUrlConnectionTimeoutMillis) + connection.setReadTimeout(httpUrlConnectionTimeoutMillis) + } + connection.connect() + try { + if (connection.getResponseCode != 200) { + // Close the error stream so that the connection is eligible for re-use + try { + 
+          connection.getErrorStream.close()
+        } catch {
+          case ioe: IOException =>
+            logError("Exception while closing error stream", ioe)
+        }
+        throw new ClassNotFoundException(s"Class file not found at URL $url")
+      } else {
+        connection.getInputStream
+      }
+    } catch {
+      case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] =>
+        connection.disconnect()
+        throw e
+    }
+  }
+
+  private def getClassFileInputStreamFromFileSystem(fileSystem: FileSystem)(
+      pathInDirectory: String): InputStream = {
+    val path = new Path(directory, pathInDirectory)
+    try {
+      fileSystem.open(path)
+    } catch {
+      case _: FileNotFoundException =>
+        throw new ClassNotFoundException(s"Class file not found at path $path")
+    }
+  }
+
+  def findClassLocally(name: String): Option[Class[_]] = {
+    val pathInDirectory = name.replace('.', '/') + ".class"
+    var inputStream: InputStream = null
+    try {
+      inputStream = fetchFn(pathInDirectory)
+      val bytes = readAndTransformClass(name, inputStream)
+      Some(defineClass(name, bytes, 0, bytes.length))
+    } catch {
+      case e: ClassNotFoundException =>
+        // We did not find the class
+        logDebug(s"Did not load class $name from REPL class server at $uri", e)
+        None
+      case e: Exception =>
+        // Something bad happened while checking if the class exists
+        logError(s"Failed to check existence of class $name on REPL class server at $uri", e)
+        None
+    } finally {
+      if (inputStream != null) {
+        try {
+          inputStream.close()
+        } catch {
+          case e: Exception =>
+            logError("Exception while closing inputStream", e)
+        }
+      }
+    }
+  }
+
+  def readAndTransformClass(name: String, in: InputStream): Array[Byte] = {
+    if (name.startsWith("line") && name.endsWith("$iw$")) {
+      // Class seems to be an interpreter "wrapper" object storing a val or var.
+      // Replace its constructor with a dummy one that does not run the
+      // initialization code placed there by the REPL. The val or var will
+      // be initialized later through reflection when it is used in a task.
+      val cr = new ClassReader(in)
+      val cw = new ClassWriter(
+        ClassWriter.COMPUTE_FRAMES + ClassWriter.COMPUTE_MAXS)
+      val cleaner = new ConstructorCleaner(name, cw)
+      cr.accept(cleaner, 0)
+      return cw.toByteArray
+    } else {
+      // Pass the class through unmodified
+      val bos = new ByteArrayOutputStream
+      val bytes = new Array[Byte](4096)
+      var done = false
+      while (!done) {
+        val num = in.read(bytes)
+        if (num >= 0) {
+          bos.write(bytes, 0, num)
+        } else {
+          done = true
+        }
+      }
+      return bos.toByteArray
+    }
+  }
+
+  /**
+   * URL-encode a string, preserving only slashes
+   */
+  def urlEncode(str: String): String = {
+    str.split('/').map(part => URLEncoder.encode(part, "UTF-8")).mkString("/")
+  }
+}
+
+class ConstructorCleaner(className: String, cv: ClassVisitor)
+extends ClassVisitor(ASM6, cv) {
+  override def visitMethod(access: Int, name: String, desc: String,
+      sig: String, exceptions: Array[String]): MethodVisitor = {
+    val mv = cv.visitMethod(access, name, desc, sig, exceptions)
+    if (name == "<init>" && (access & ACC_STATIC) == 0) {
+      // This is the constructor, time to clean it; just output some new
+      // instructions to mv that create the object and set the static MODULE$
+      // field in the class to point to it, but do nothing otherwise.
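+      // (The cleaned constructor emitted below just calls java.lang.Object's
+      // no-arg constructor on `this` and returns, so the REPL-generated
+      // initialization code never runs on the executor side.)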
+      mv.visitCode()
+      mv.visitVarInsn(ALOAD, 0) // load this
+      mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V", false)
+      mv.visitVarInsn(ALOAD, 0) // load this
+      // val classType = className.replace('.', '/')
+      // mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";")
+      mv.visitInsn(RETURN)
+      mv.visitMaxs(-1, -1) // stack size and local vars will be auto-computed
+      mv.visitEnd()
+      return null
+    } else {
+      return mv
+    }
+  }
+}
diff --git a/modules/spark-stubs_32/src/main/scala/spark/repl/Main.scala b/modules/spark-stubs_32/src/main/scala/spark/repl/Main.scala
new file mode 100644
index 0000000..42e4f81
--- /dev/null
+++ b/modules/spark-stubs_32/src/main/scala/spark/repl/Main.scala
@@ -0,0 +1,6 @@
+package spark.repl
+
+object Main {
+  // May make spark ClosureCleaner a tiny bit happier
+  def interp = this
+}
diff --git a/modules/tests/src/main/scala/ammonite/spark/Init.scala b/modules/tests/src/main/scala/ammonite/spark/Init.scala
index d89c06d..d8a8803 100644
--- a/modules/tests/src/main/scala/ammonite/spark/Init.scala
+++ b/modules/tests/src/main/scala/ammonite/spark/Init.scala
@@ -47,7 +47,7 @@ object Init {
          @ .toVector
          @ .filter(f => !f.getFileName.toString.startsWith("scala-compiler") && !f.getFileName.toString.startsWith("scala-reflect") && !f.getFileName.toString.startsWith("scala-library") && !f.getFileName.toString.startsWith("spark-repl_"))
          @ .sortBy(_.getFileName.toString)
-         @ .map(ammonite.ops.Path(_))
+         @ .map(os.Path(_))
          @ }
       """ ++ init(master, sparkVersion, conf, loadSparkSql = false)
diff --git a/modules/tests/src/main/scala/ammonite/spark/SparkReplTests.scala b/modules/tests/src/main/scala/ammonite/spark/SparkReplTests.scala
index 5f02208..0f40bb3 100644
--- a/modules/tests/src/main/scala/ammonite/spark/SparkReplTests.scala
+++ b/modules/tests/src/main/scala/ammonite/spark/SparkReplTests.scala
@@ -3,12 +3,16 @@ package ammonite.spark
 import ammonite.spark.fromammonite.TestRepl
 import utest._
 
+import scala.util.Properties.versionNumberString
+
 class SparkReplTests(
   val sparkVersion: String,
   val master: String,
   val conf: (String, String)*
 ) extends TestSuite {
 
+  private def is212 = versionNumberString.startsWith("2.12.")
+
   // Most of the tests here were adapted from https://github.com/apache/spark/blob/ab18b02e66fd04bc8f1a4fb7b6a7f2773902a494/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala
 
   Init.setupLog4j()
@@ -181,13 +185,15 @@ class SparkReplTests(
     }
 
     "SPARK-1199 two instances of same class don't type check" - {
+      val expFieldNamePart = if (is212) "" else "exp = "
+      val exp2FieldNamePart = if (is212) "" else "exp2 = "
       sparkSession(
-        """
+        s"""
           @ case class Sum(exp: String, exp2: String)
           defined class Sum
 
           @ val a = Sum("A", "B")
-          a: Sum = Sum("A", "B")
+          a: Sum = Sum(${expFieldNamePart}"A", ${exp2FieldNamePart}"B")
 
           @ def b(a: Sum): String = a match { case Sum(_, _) => "OK" }
           defined function b
@@ -212,9 +218,10 @@ class SparkReplTests(
     }
 
     "SPARK-2576 importing implicits" - {
+      val fieldNamePart = if (is212) "" else "value = "
       // FIXME The addOuterScope should be automatically added. (Tweak CodeClassWrapper for that?)
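+      // (On Scala 2.13 the REPL pretty-printer includes case class field
+      // names, e.g. TestCaseClass(value = 1) instead of TestCaseClass(1),
+      // hence the fieldNamePart shims in the expected output below.)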
sparkSession( - """ + s""" @ import spark.implicits._ import spark.implicits._ @@ -225,7 +232,7 @@ class SparkReplTests( res: Array[Row] = Array([1], [2], [3], [4], [5], [6], [7], [8], [9], [10]) @ val foo = Seq(TestCaseClass(1)).toDS().collect() - foo: Array[TestCaseClass] = Array(TestCaseClass(1)) + foo: Array[TestCaseClass] = Array(TestCaseClass(${fieldNamePart}1)) """ ) } @@ -267,8 +274,9 @@ class SparkReplTests( } "SPARK-2632 importing a method from non serializable class and not using it" - { + val fieldNamePart = if (is212) "" else "value = " sparkSession( - """ + s""" @ class TestClass() { def testMethod = 3; override def toString = "TestClass" } defined class TestClass @@ -283,55 +291,57 @@ class SparkReplTests( @ val res = sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect() res: Array[TestCaseClass] = Array( - TestCaseClass(1), - TestCaseClass(2), - TestCaseClass(3), - TestCaseClass(4), - TestCaseClass(5), - TestCaseClass(6), - TestCaseClass(7), - TestCaseClass(8), - TestCaseClass(9), - TestCaseClass(10) + TestCaseClass(${fieldNamePart}1), + TestCaseClass(${fieldNamePart}2), + TestCaseClass(${fieldNamePart}3), + TestCaseClass(${fieldNamePart}4), + TestCaseClass(${fieldNamePart}5), + TestCaseClass(${fieldNamePart}6), + TestCaseClass(${fieldNamePart}7), + TestCaseClass(${fieldNamePart}8), + TestCaseClass(${fieldNamePart}9), + TestCaseClass(${fieldNamePart}10) ) """ ) } "collecting objects of class defined in repl" - { + val fieldNamePart = if (is212) "" else "i = " sparkSession( - """ + s""" @ case class Foo(i: Int) defined class Foo @ val res = sc.parallelize((1 to 100).map(Foo), 10).collect() res: Array[Foo] = Array( - Foo(1), - Foo(2), - Foo(3), - Foo(4), - Foo(5), - Foo(6), - Foo(7), - Foo(8), - Foo(9), - Foo(10), + Foo(${fieldNamePart}1), + Foo(${fieldNamePart}2), + Foo(${fieldNamePart}3), + Foo(${fieldNamePart}4), + Foo(${fieldNamePart}5), + Foo(${fieldNamePart}6), + Foo(${fieldNamePart}7), + Foo(${fieldNamePart}8), + Foo(${fieldNamePart}9), + Foo(${fieldNamePart}10), ... 
""" ) } "collecting objects of class defined in repl - shuffling" - { + val fieldNamePart = if (is212) "" else "i = " sparkSession( - """ + s""" @ case class Foo(i: Int) defined class Foo @ val list = List((1, Foo(1)), (1, Foo(2))) - list: List[(Int, Foo)] = List((1, Foo(1)), (1, Foo(2))) + list: List[(Int, Foo)] = List((1, Foo(${fieldNamePart}1)), (1, Foo(${fieldNamePart}2))) @ val res = sc.parallelize(list).groupByKey().collect().map { case (k, v) => k -> v.toList } - res: Array[(Int, List[Foo])] = Array((1, List(Foo(1), Foo(2)))) + res: Array[(Int, List[Foo])] = Array((1, List(Foo(${fieldNamePart}1), Foo(${fieldNamePart}2)))) """ ) } @@ -420,8 +430,9 @@ class SparkReplTests( // Adapted from https://github.com/apache/spark/blob/3d5c61e5fd24f07302e39b5d61294da79aa0c2f9/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala#L193-L208 "line wrapper only initialized once when used as encoder outer scope" - { + val fieldNamePart = if (is212) "" else "value = " sparkSession( - """ + s""" @ val fileName = "repl-test-" + java.util.UUID.randomUUID() @ val tmpDir = System.getProperty("java.io.tmpdir") @@ -444,16 +455,16 @@ class SparkReplTests( @ val res = sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect() res: Array[TestCaseClass] = Array( - TestCaseClass(1), - TestCaseClass(2), - TestCaseClass(3), - TestCaseClass(4), - TestCaseClass(5), - TestCaseClass(6), - TestCaseClass(7), - TestCaseClass(8), - TestCaseClass(9), - TestCaseClass(10) + TestCaseClass(${fieldNamePart}1), + TestCaseClass(${fieldNamePart}2), + TestCaseClass(${fieldNamePart}3), + TestCaseClass(${fieldNamePart}4), + TestCaseClass(${fieldNamePart}5), + TestCaseClass(${fieldNamePart}6), + TestCaseClass(${fieldNamePart}7), + TestCaseClass(${fieldNamePart}8), + TestCaseClass(${fieldNamePart}9), + TestCaseClass(${fieldNamePart}10) ) @ val exists2 = file.exists() diff --git a/modules/tests/src/main/scala/ammonite/spark/SparkVersions.scala b/modules/tests/src/main/scala/ammonite/spark/SparkVersions.scala index 1786ffc..e5d6721 100644 --- a/modules/tests/src/main/scala/ammonite/spark/SparkVersions.scala +++ b/modules/tests/src/main/scala/ammonite/spark/SparkVersions.scala @@ -7,5 +7,6 @@ object SparkVersions { def latest23 = "2.3.2" def latest24 = "2.4.4" def latest30 = "3.0.0" + def latest32 = "3.2.0" } diff --git a/modules/tests/src/main/scala/ammonite/spark/fromammonite/TestRepl.scala b/modules/tests/src/main/scala/ammonite/spark/fromammonite/TestRepl.scala index ce09a7c..482cc81 100644 --- a/modules/tests/src/main/scala/ammonite/spark/fromammonite/TestRepl.scala +++ b/modules/tests/src/main/scala/ammonite/spark/fromammonite/TestRepl.scala @@ -4,7 +4,6 @@ import ammonite.compiler.CodeClassWrapper import ammonite.compiler.iface.CodeWrapper import ammonite.interp.Interpreter import ammonite.main.Defaults -import ammonite.ops.{Path, read} import ammonite.repl._ import ammonite.repl.api.{FrontEnd, History, ReplLoad} import ammonite.runtime.{Frame, ImportHook, Storage} @@ -20,10 +19,10 @@ import scala.collection.mutable */ class TestRepl { var allOutput = "" - def predef: (String, Option[ammonite.ops.Path]) = ("", None) + def predef: (String, Option[os.Path]) = ("", None) def codeWrapper: CodeWrapper = CodeClassWrapper - val tempDir = ammonite.ops.Path( + val tempDir = os.Path( java.nio.file.Files.createTempDirectory("ammonite-tester") ) @@ -61,7 +60,7 @@ class TestRepl { parser = ammonite.compiler.Parsers, printer = printer0, storage = storage, - wd = ammonite.ops.pwd, + wd = os.pwd, colors = Ref(Colors.BlackWhite), verboseOutput 
= true, getFrame = () => frames().head, @@ -112,9 +111,9 @@ class TestRepl { } } - def exec(file: Path): Unit = { + def exec(file: os.Path): Unit = { interp.watch(file) - apply(normalizeNewlines(read(file))) + apply(normalizeNewlines(os.read(file))) } } } diff --git a/modules/tests/src/test/scala/ammonite/spark/Local24Tests.scala b/modules/tests/src/test/scala-2.12/ammonite/spark/Local24Tests.scala similarity index 100% rename from modules/tests/src/test/scala/ammonite/spark/Local24Tests.scala rename to modules/tests/src/test/scala-2.12/ammonite/spark/Local24Tests.scala diff --git a/modules/tests/src/test/scala/ammonite/spark/Local30Tests.scala b/modules/tests/src/test/scala-2.12/ammonite/spark/Local30Tests.scala similarity index 100% rename from modules/tests/src/test/scala/ammonite/spark/Local30Tests.scala rename to modules/tests/src/test/scala-2.12/ammonite/spark/Local30Tests.scala diff --git a/modules/tests/src/test/scala/ammonite/spark/ProgressBar24Tests.scala b/modules/tests/src/test/scala-2.12/ammonite/spark/ProgressBar24Tests.scala similarity index 100% rename from modules/tests/src/test/scala/ammonite/spark/ProgressBar24Tests.scala rename to modules/tests/src/test/scala-2.12/ammonite/spark/ProgressBar24Tests.scala diff --git a/modules/tests/src/test/scala/ammonite/spark/ProgressBar30Tests.scala b/modules/tests/src/test/scala-2.12/ammonite/spark/ProgressBar30Tests.scala similarity index 100% rename from modules/tests/src/test/scala/ammonite/spark/ProgressBar30Tests.scala rename to modules/tests/src/test/scala-2.12/ammonite/spark/ProgressBar30Tests.scala diff --git a/modules/tests/src/test/scala/ammonite/spark/Local32Tests.scala b/modules/tests/src/test/scala/ammonite/spark/Local32Tests.scala new file mode 100644 index 0000000..37dfc26 --- /dev/null +++ b/modules/tests/src/test/scala/ammonite/spark/Local32Tests.scala @@ -0,0 +1,6 @@ +package ammonite.spark + +object Local32Tests extends SparkReplTests( + SparkVersions.latest32, + Local.master +) diff --git a/modules/tests/src/test/scala/ammonite/spark/ProgressBar32Tests.scala b/modules/tests/src/test/scala/ammonite/spark/ProgressBar32Tests.scala new file mode 100644 index 0000000..800d108 --- /dev/null +++ b/modules/tests/src/test/scala/ammonite/spark/ProgressBar32Tests.scala @@ -0,0 +1,6 @@ +package ammonite.spark + +object ProgressBar32Tests extends ProgressBarTests( + SparkVersions.latest32, + Local.master +) diff --git a/modules/yarn-tests/src/test/scala/ammonite/spark/Yarn24Tests.scala b/modules/yarn-tests/src/test/scala-2.12/ammonite/spark/Yarn24Tests.scala similarity index 100% rename from modules/yarn-tests/src/test/scala/ammonite/spark/Yarn24Tests.scala rename to modules/yarn-tests/src/test/scala-2.12/ammonite/spark/Yarn24Tests.scala diff --git a/modules/yarn-tests/src/test/scala/ammonite/spark/Yarn30Tests.scala b/modules/yarn-tests/src/test/scala-2.12/ammonite/spark/Yarn30Tests.scala similarity index 100% rename from modules/yarn-tests/src/test/scala/ammonite/spark/Yarn30Tests.scala rename to modules/yarn-tests/src/test/scala-2.12/ammonite/spark/Yarn30Tests.scala diff --git a/modules/yarn-tests/src/test/scala/ammonite/spark/Yarn32Tests.scala b/modules/yarn-tests/src/test/scala/ammonite/spark/Yarn32Tests.scala new file mode 100644 index 0000000..48f7f60 --- /dev/null +++ b/modules/yarn-tests/src/test/scala/ammonite/spark/Yarn32Tests.scala @@ -0,0 +1,22 @@ +package ammonite.spark + +// Temporarily disabled, until we can (either or both) +// - enable the hadoop-2.7 profile (and excluding the remaining hadoop 3 dependencies?) 
when fetching Spark dependencies
+// - update the docker-based hadoop setup in the tests to hadoop 3
+
+// object Yarn32Tests extends SparkReplTests(
+//   SparkVersions.latest32,
+//   "yarn",
+//   "spark.executor.instances" -> "1",
+//   "spark.executor.memory" -> "2g",
+//   "spark.yarn.executor.memoryOverhead" -> "1g",
+//   "spark.yarn.am.memory" -> "2g"
+// ) {
+//   override def inputUrlOpt =
+//     Some(
+//       sys.env.getOrElse(
+//         "INPUT_TXT_URL",
+//         sys.error("INPUT_TXT_URL not set")
+//       )
+//     )
+// }
diff --git a/project/Deps.scala b/project/Deps.scala
index 79c9ddd..fd44bd0 100644
--- a/project/Deps.scala
+++ b/project/Deps.scala
@@ -5,7 +5,12 @@ import sbt.Keys._
 
 object Deps {
 
-  private def ammoniteVersion = "2.3.8-125-f6bb1cf9"
+  object Scala {
+    def scala212 = "2.12.11"
+    def scala213 = "2.13.8"
+  }
+
+  private def ammoniteVersion = "2.5.4-8-30448e49"
   def ammoniteCompiler = ("com.lihaoyi" % "ammonite-compiler" % ammoniteVersion).cross(CrossVersion.full)
   def ammoniteReplApi = ("com.lihaoyi" % "ammonite-repl-api" % ammoniteVersion).cross(CrossVersion.full)
   def ammoniteRepl = ("com.lihaoyi" % "ammonite-repl" % ammoniteVersion).cross(CrossVersion.full)
@@ -13,7 +18,14 @@ object Deps {
   def jettyServer = "org.eclipse.jetty" % "jetty-server" % "9.4.46.v20220331"
   def utest = "com.lihaoyi" %% "utest" % "0.7.11"
 
-  def sparkSql = "org.apache.spark" %% "spark-sql" % "2.4.0"
+  def sparkSql = Def.setting {
+    val sv = scalaVersion.value
+    val ver =
+      if (sv.startsWith("2.12.")) "2.4.0"
+      else "3.2.0"
+    "org.apache.spark" %% "spark-sql" % ver
+  }
   def sparkSql3 = "org.apache.spark" %% "spark-sql" % "3.0.0"
+  def sparkSql32 = "org.apache.spark" %% "spark-sql" % "3.2.0"
 
 }
diff --git a/project/Mima.scala b/project/Mima.scala
index 38713ba..da59070 100644
--- a/project/Mima.scala
+++ b/project/Mima.scala
@@ -19,7 +19,17 @@ object Mima {
 
   def settings = Def.settings(
     MimaPlugin.autoImport.mimaPreviousArtifacts := {
-      binaryCompatibilityVersions
+      val sv = scalaVersion.value
+      val binaryCompatibilityVersions0 =
+        if (sv.startsWith("2.12.")) binaryCompatibilityVersions
+        else
+          binaryCompatibilityVersions.filter { v =>
+            !v.startsWith("0.9.") &&
+              !v.startsWith("0.10.") &&
+              !v.startsWith("0.11.") &&
+              !v.startsWith("0.12.")
+          }
+      binaryCompatibilityVersions0
         .map { ver =>
           (organization.value % moduleName.value % ver)
             .cross(crossVersion.value)
diff --git a/project/Settings.scala b/project/Settings.scala
index ea7efc8..cfe6729 100644
--- a/project/Settings.scala
+++ b/project/Settings.scala
@@ -15,11 +15,9 @@ object Settings {
     }
   }
 
-  private val scala212 = "2.12.11"
-
   lazy val shared = Seq(
-    scalaVersion := scala212,
-    crossScalaVersions := Seq(scala212),
+    scalaVersion := Deps.Scala.scala212,
+    crossScalaVersions := Seq(Deps.Scala.scala212),
     scalacOptions ++= Seq(
       "-deprecation",
      "-feature",