diff --git a/build.sbt b/build.sbt
index e1dbb3e..dff73ef 100644
--- a/build.sbt
+++ b/build.sbt
@@ -19,7 +19,14 @@ lazy val `spark-stubs_24` = project
   .underModules
   .settings(
     shared,
-    libraryDependencies += Deps.sparkSql % "provided"
+    libraryDependencies += Deps.sparkSql % Provided
+  )
+
+lazy val `spark-stubs_30` = project
+  .underModules
+  .settings(
+    shared,
+    libraryDependencies += Deps.sparkSql3 % Provided
   )
 
 lazy val core = project
@@ -29,8 +36,8 @@ lazy val core = project
     name := "ammonite-spark",
     generatePropertyFile("org/apache/spark/sql/ammonitesparkinternals/ammonite-spark.properties"),
     libraryDependencies ++= Seq(
-      Deps.ammoniteReplApi % "provided",
-      Deps.sparkSql % "provided",
+      Deps.ammoniteReplApi % Provided,
+      Deps.sparkSql % Provided,
       Deps.jettyServer
     )
   )
@@ -90,6 +97,7 @@ lazy val `ammonite-spark` = project
   .aggregate(
     core,
     `spark-stubs_24`,
+    `spark-stubs_30`,
     tests
   )
   .settings(
diff --git a/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/SparkDependencies.scala b/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/SparkDependencies.scala
index b4dee11..e656d3e 100644
--- a/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/SparkDependencies.scala
+++ b/modules/core/src/main/scala/org/apache/spark/sql/ammonitesparkinternals/SparkDependencies.scala
@@ -92,10 +92,16 @@ object SparkDependencies {
   }
 
   def stubsDependency = {
-    val suffix = org.apache.spark.SPARK_VERSION.split('.').take(2) match {
+    val sv = org.apache.spark.SPARK_VERSION
+    val suffix = sv.split('.').take(2) match {
       case Array("2", n) if Try(n.toInt).toOption.exists(_ <= 3) =>
         "20"
+      case Array("2", n) if Try(n.toInt).toOption.exists(_ >= 4) =>
+        "24"
+      case Array("3", n) =>
+        "30"
       case _ =>
+        System.err.println(s"Warning: unrecognized Spark version ($sv), assuming 2.4.x")
         "24"
     }
     Dependency.of(
diff --git a/modules/spark-stubs_30/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/modules/spark-stubs_30/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
new file mode 100644
index 0000000..5c86cc4
--- /dev/null
+++ b/modules/spark-stubs_30/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
@@ -0,0 +1,253 @@
+
+// Like https://github.com/apache/spark/blob/0e318acd0cc3b42e8be9cb2a53cccfdc4a0805f9/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
+// but keeping getClassFileInputStreamFromHttpServer from former versions (so that http works without hadoop stuff)
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.repl
+
+import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream, IOException, InputStream}
+import java.net.{HttpURLConnection, URI, URL, URLEncoder}
+import java.nio.channels.Channels
+
+import scala.util.control.NonFatal
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.xbean.asm7._
+import org.apache.xbean.asm7.Opcodes._
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.{ParentClassLoader, Utils}
+
+/**
+ * A ClassLoader that reads classes from a Hadoop FileSystem or Spark RPC endpoint, used to load
+ * classes defined by the interpreter when the REPL is used. Allows the user to specify if user
+ * class path should be first. This class loader delegates getting/finding resources to parent
+ * loader, which makes sense until REPL never provide resource dynamically.
+ *
+ * Note: [[ClassLoader]] will preferentially load class from parent. Only when parent is null or
+ * the load failed, that it will call the overridden `findClass` function. To avoid the potential
+ * issue caused by loading class using inappropriate class loader, we should set the parent of
+ * ClassLoader to null, so that we can fully control which class loader is used. For detailed
+ * discussion, see SPARK-18646.
+ */
+class ExecutorClassLoader(
+    conf: SparkConf,
+    env: SparkEnv,
+    classUri: String,
+    parent: ClassLoader,
+    userClassPathFirst: Boolean) extends ClassLoader(null) with Logging {
+  val uri = new URI(classUri)
+  val directory = uri.getPath
+
+  val parentLoader = new ParentClassLoader(parent)
+
+  // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes
+  private[repl] var httpUrlConnectionTimeoutMillis: Int = -1
+
+  private val fetchFn: (String) => InputStream = uri.getScheme() match {
+    case "spark" => getClassFileInputStreamFromSparkRPC
+    case "http" | "https" | "ftp" => getClassFileInputStreamFromHttpServer
+    case _ =>
+      val fileSystem = FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf))
+      getClassFileInputStreamFromFileSystem(fileSystem)
+  }
+
+  override def getResource(name: String): URL = {
+    parentLoader.getResource(name)
+  }
+
+  override def getResources(name: String): java.util.Enumeration[URL] = {
+    parentLoader.getResources(name)
+  }
+
+  override def findClass(name: String): Class[_] = {
+    if (userClassPathFirst) {
+      findClassLocally(name).getOrElse(parentLoader.loadClass(name))
+    } else {
+      try {
+        parentLoader.loadClass(name)
+      } catch {
+        case e: ClassNotFoundException =>
+          val classOption = findClassLocally(name)
+          classOption match {
+            case None => throw new ClassNotFoundException(name, e)
+            case Some(a) => a
+          }
+      }
+    }
+  }
+
+  private def getClassFileInputStreamFromSparkRPC(path: String): InputStream = {
+    val channel = env.rpcEnv.openChannel(s"$classUri/$path")
+    new FilterInputStream(Channels.newInputStream(channel)) {
+
+      override def read(): Int = toClassNotFound(super.read())
+
+      override def read(b: Array[Byte]): Int = toClassNotFound(super.read(b))
+
+      override def read(b: Array[Byte], offset: Int, len: Int) =
+        toClassNotFound(super.read(b, offset, len))
+
+      private def toClassNotFound(fn: => Int): Int = {
+        try {
+          fn
+        } catch {
+          case e: Exception =>
+            throw new ClassNotFoundException(path, e)
+        }
+      }
+    }
+  }
+
+  private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = {
+    val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
+      val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
+      // val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
+      uri.toURL
+    } else {
+      new URL(classUri + "/" + urlEncode(pathInDirectory))
+    }
+    val connection: HttpURLConnection = url.openConnection().asInstanceOf[HttpURLConnection]
+    // Set the connection timeouts (for testing purposes)
+    if (httpUrlConnectionTimeoutMillis != -1) {
+      connection.setConnectTimeout(httpUrlConnectionTimeoutMillis)
+      connection.setReadTimeout(httpUrlConnectionTimeoutMillis)
+    }
+    connection.connect()
+    try {
+      if (connection.getResponseCode != 200) {
+        // Close the error stream so that the connection is eligible for re-use
+        try {
+          connection.getErrorStream.close()
+        } catch {
+          case ioe: IOException =>
+            logError("Exception while closing error stream", ioe)
+        }
+        throw new ClassNotFoundException(s"Class file not found at URL $url")
+      } else {
+        connection.getInputStream
+      }
+    } catch {
+      case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] =>
+        connection.disconnect()
+        throw e
+    }
+  }
+
+  private def getClassFileInputStreamFromFileSystem(fileSystem: FileSystem)(
+      pathInDirectory: String): InputStream = {
+    val path = new Path(directory, pathInDirectory)
+    try {
+      fileSystem.open(path)
+    } catch {
+      case _: FileNotFoundException =>
+        throw new ClassNotFoundException(s"Class file not found at path $path")
+    }
+  }
+
+  def findClassLocally(name: String): Option[Class[_]] = {
+    val pathInDirectory = name.replace('.', '/') + ".class"
+    var inputStream: InputStream = null
+    try {
+      inputStream = fetchFn(pathInDirectory)
+      val bytes = readAndTransformClass(name, inputStream)
+      Some(defineClass(name, bytes, 0, bytes.length))
+    } catch {
+      case e: ClassNotFoundException =>
+        // We did not find the class
+        logDebug(s"Did not load class $name from REPL class server at $uri", e)
+        None
+      case e: Exception =>
+        // Something bad happened while checking if the class exists
+        logError(s"Failed to check existence of class $name on REPL class server at $uri", e)
+        None
+    } finally {
+      if (inputStream != null) {
+        try {
+          inputStream.close()
+        } catch {
+          case e: Exception =>
+            logError("Exception while closing inputStream", e)
+        }
+      }
+    }
+  }
+
+  def readAndTransformClass(name: String, in: InputStream): Array[Byte] = {
+    if (name.startsWith("line") && name.endsWith("$iw$")) {
+      // Class seems to be an interpreter "wrapper" object storing a val or var.
+      // Replace its constructor with a dummy one that does not run the
+      // initialization code placed there by the REPL. The val or var will
+      // be initialized later through reflection when it is used in a task.
+      val cr = new ClassReader(in)
+      val cw = new ClassWriter(
+        ClassWriter.COMPUTE_FRAMES + ClassWriter.COMPUTE_MAXS)
+      val cleaner = new ConstructorCleaner(name, cw)
+      cr.accept(cleaner, 0)
+      return cw.toByteArray
+    } else {
+      // Pass the class through unmodified
+      val bos = new ByteArrayOutputStream
+      val bytes = new Array[Byte](4096)
+      var done = false
+      while (!done) {
+        val num = in.read(bytes)
+        if (num >= 0) {
+          bos.write(bytes, 0, num)
+        } else {
+          done = true
+        }
+      }
+      return bos.toByteArray
+    }
+  }
+
+  /**
+   * URL-encode a string, preserving only slashes
+   */
+  def urlEncode(str: String): String = {
+    str.split('/').map(part => URLEncoder.encode(part, "UTF-8")).mkString("/")
+  }
+}
+
+class ConstructorCleaner(className: String, cv: ClassVisitor)
+extends ClassVisitor(ASM6, cv) {
+  override def visitMethod(access: Int, name: String, desc: String,
+      sig: String, exceptions: Array[String]): MethodVisitor = {
+    val mv = cv.visitMethod(access, name, desc, sig, exceptions)
+    if (name == "<init>" && (access & ACC_STATIC) == 0) {
+      // This is the constructor, time to clean it; just output some new
+      // instructions to mv that create the object and set the static MODULE$
+      // field in the class to point to it, but do nothing otherwise.
+      mv.visitCode()
+      mv.visitVarInsn(ALOAD, 0) // load this
+      mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V", false)
+      mv.visitVarInsn(ALOAD, 0) // load this
+      // val classType = className.replace('.', '/')
+      // mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";")
+      mv.visitInsn(RETURN)
+      mv.visitMaxs(-1, -1) // stack size and local vars will be auto-computed
+      mv.visitEnd()
+      return null
+    } else {
+      return mv
+    }
+  }
+}
diff --git a/modules/spark-stubs_30/src/main/scala/spark/repl/Main.scala b/modules/spark-stubs_30/src/main/scala/spark/repl/Main.scala
new file mode 100644
index 0000000..42e4f81
--- /dev/null
+++ b/modules/spark-stubs_30/src/main/scala/spark/repl/Main.scala
@@ -0,0 +1,6 @@
+package spark.repl
+
+object Main {
+  // May make spark ClosureCleaner a tiny bit happier
+  def interp = this
+}
diff --git a/modules/tests/src/main/scala/ammonite/spark/SparkVersions.scala b/modules/tests/src/main/scala/ammonite/spark/SparkVersions.scala
index 8a49c62..244c00f 100644
--- a/modules/tests/src/main/scala/ammonite/spark/SparkVersions.scala
+++ b/modules/tests/src/main/scala/ammonite/spark/SparkVersions.scala
@@ -5,6 +5,7 @@ object SparkVersions {
   def latest21 = "2.1.3"
   def latest22 = "2.2.2"
   def latest23 = "2.3.2"
-  def latest24 = "2.4.0"
+  def latest24 = "2.4.4"
+  def latest30 = "3.0.0-preview"
 
 }
diff --git a/modules/tests/src/test/scala/ammonite/spark/Local30Tests.scala b/modules/tests/src/test/scala/ammonite/spark/Local30Tests.scala
new file mode 100644
index 0000000..6bb73a4
--- /dev/null
+++ b/modules/tests/src/test/scala/ammonite/spark/Local30Tests.scala
@@ -0,0 +1,6 @@
+package ammonite.spark
+
+object Local30Tests extends SparkReplTests(
+  SparkVersions.latest30,
+  Local.master
+)
diff --git a/modules/tests/src/test/scala/ammonite/spark/ProgressBar30Tests.scala b/modules/tests/src/test/scala/ammonite/spark/ProgressBar30Tests.scala
new file mode 100644
index 0000000..dc1c5dc
--- /dev/null
+++ b/modules/tests/src/test/scala/ammonite/spark/ProgressBar30Tests.scala
@@ -0,0 +1,6 @@
+package ammonite.spark
+
+object ProgressBar30Tests extends ProgressBarTests(
+  SparkVersions.latest30,
+  Local.master
+)
diff --git a/modules/yarn-tests/src/test/scala/ammonite/spark/Yarn30Tests.scala b/modules/yarn-tests/src/test/scala/ammonite/spark/Yarn30Tests.scala
new file mode 100644
index 0000000..756340c
--- /dev/null
+++ b/modules/yarn-tests/src/test/scala/ammonite/spark/Yarn30Tests.scala
@@ -0,0 +1,18 @@
+package ammonite.spark
+
+object Yarn30Tests extends SparkReplTests(
+  SparkVersions.latest30,
+  "yarn",
+  "spark.executor.instances" -> "1",
+  "spark.executor.memory" -> "2g",
+  "spark.yarn.executor.memoryOverhead" -> "1g",
+  "spark.yarn.am.memory" -> "2g"
+) {
+  override def inputUrlOpt =
+    Some(
+      sys.env.getOrElse(
+        "INPUT_TXT_URL",
+        sys.error("INPUT_TXT_URL not set")
+      )
+    )
+}
diff --git a/project/Deps.scala b/project/Deps.scala
index 562ecc3..c81e82c 100644
--- a/project/Deps.scala
+++ b/project/Deps.scala
@@ -13,5 +13,6 @@ object Deps {
   def utest = "com.lihaoyi" %% "utest" % "0.7.1"
 
   def sparkSql = "org.apache.spark" %% "spark-sql" % "2.4.0"
+  def sparkSql3 = "org.apache.spark" %% "spark-sql" % "3.0.0-preview"
 
 }
diff --git a/sbt-in-docker-with-yarn-cluster.sh b/sbt-in-docker-with-yarn-cluster.sh
index be51674..1b131ea 100755
--- a/sbt-in-docker-with-yarn-cluster.sh
+++ b/sbt-in-docker-with-yarn-cluster.sh
@@ -118,15 +118,16 @@ set -e
 
 if [ "\$SPARK_HOME" = "" ]; then
   # prefetch stuff
-  export SPARK_VERSION="2.4.0"
+  for SPARK_VERSION in "2.4.4" "3.0.0-preview"; do
 
-  DEPS=()
-  DEPS+=("org.apache.spark:spark-sql_$SBV:\$SPARK_VERSION")
-  DEPS+=("org.apache.spark:spark-yarn_$SBV:\$SPARK_VERSION")
+    DEPS=()
+    DEPS+=("org.apache.spark:spark-sql_$SBV:\$SPARK_VERSION")
+    DEPS+=("org.apache.spark:spark-yarn_$SBV:\$SPARK_VERSION")
 
-  for d in "\${DEPS[@]}"; do
-    echo "Pre-fetching \$d"
-    coursier fetch $(if [ "$INTERACTIVE" = 1 ]; then echo --progress; fi) "\$d" >/dev/null
+    for d in "\${DEPS[@]}"; do
+      echo "Pre-fetching \$d"
+      coursier fetch "\$d" $(if [ "$INTERACTIVE" = 1 ]; then echo --progress; fi) >/dev/null
+    done
   done
 fi
 
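Side note (not part of the patch): the stub-suffix selection added to SparkDependencies.stubsDependency above is what routes Spark 3.x sessions to the new spark-stubs_30 module. The minimal Scala sketch below mirrors that match expression in a free-standing helper so the mapping can be checked in isolation; the helper name stubsSuffix is invented for illustration, and the sample version strings are the ones referenced elsewhere in this diff.

import scala.util.Try

// Illustrative helper only: mirrors the match in SparkDependencies.stubsDependency.
def stubsSuffix(sparkVersion: String): String =
  sparkVersion.split('.').take(2) match {
    case Array("2", n) if Try(n.toInt).toOption.exists(_ <= 3) => "20" // Spark 2.0 to 2.3
    case Array("2", n) if Try(n.toInt).toOption.exists(_ >= 4) => "24" // Spark 2.4
    case Array("3", _)                                          => "30" // Spark 3.x
    case _                                                      => "24" // fallback (the real code also prints a warning)
  }

// Versions appearing in this diff resolve to the expected stub suffixes:
assert(stubsSuffix("2.3.2") == "20")
assert(stubsSuffix("2.4.4") == "24")
assert(stubsSuffix("3.0.0-preview") == "30")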