Merge pull request #222 from alexarchambault/develop
Updates
alexarchambault committed Jun 10, 2022
2 parents db047a5 + b7e1a3e commit 5187f04
Showing 23 changed files with 404 additions and 64 deletions.
10 changes: 5 additions & 5 deletions .github/scripts/test.sh
@@ -3,15 +3,15 @@ set -e

 case "${MASTER:-"local"}" in
   local)
-    ./sbt publishLocal test mimaReportBinaryIssues ;;
+    ./sbt +publishLocal +test +mimaReportBinaryIssues ;;
   local-distrib)
-    ./with-spark-home.sh ./sbt publishLocal local-spark-distrib-tests/test ;;
+    ./with-spark-home.sh ./sbt +publishLocal +local-spark-distrib-tests/test ;;
   standalone)
-    ./with-spark-home.sh ./sbt-with-standalone-cluster.sh publishLocal standalone-tests/test ;;
+    ./with-spark-home.sh ./sbt-with-standalone-cluster.sh +publishLocal +standalone-tests/test ;;
   yarn)
-    ./sbt-in-docker-with-yarn-cluster.sh -batch publishLocal yarn-tests/test ;;
+    ./sbt-in-docker-with-yarn-cluster.sh -batch +publishLocal +yarn-tests/test ;;
   yarn-distrib)
-    ./with-spark-home.sh ./sbt-in-docker-with-yarn-cluster.sh -batch publishLocal yarn-spark-distrib-tests/test ;;
+    ./with-spark-home.sh ./sbt-in-docker-with-yarn-cluster.sh -batch +publishLocal +yarn-spark-distrib-tests/test ;;
   *)
     echo "Unrecognized master type $MASTER"
     exit 1
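The only change in this script is the `+` prefix on each sbt task: `+task` tells sbt to run that task for every entry in `crossScalaVersions`, which is how the Scala 2.13 builds added later in this PR get exercised by CI. A minimal sketch of the sbt setting involved (project name and version numbers are illustrative, not this repository's actual values):

```scala
// build.sbt — illustrative sketch only
lazy val example = project
  .settings(
    scalaVersion       := "2.12.15",                // version used by a plain `sbt test`
    crossScalaVersions := Seq("2.12.15", "2.13.8")  // versions covered by `sbt +test`
  )
```

With such a setting, `./sbt +publishLocal +test` publishes and tests the modules once per cross Scala version instead of only the default one.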
1 change: 1 addition & 0 deletions README.md
@@ -148,6 +148,7 @@ with for sure.
 | `0.10.1` | `2.1.4` | `0.10.1` |
 | `0.11.0` | `2.3.8-36-1cce53f3` | `0.11.0` |
 | `0.12.0` | `2.3.8-122-9be39deb` | `0.12.0` |
+| `0.13.0` | `2.5.4-8-30448e49` | `0.13.0` |

 ## Missing

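The added row extends the compatibility table with the 0.13.0 release, built against Ammonite 2.5.4-8-30448e49. As a usage sketch, assuming the library is resolved through its `sh.almond::ammonite-spark` coordinates, picking the release that matches your Ammonite version then looks like:

```scala
// In an Ammonite session — coordinates assumed, version taken from the new row above
import $ivy.`sh.almond::ammonite-spark:0.13.0`
```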
17 changes: 15 additions & 2 deletions build.sbt
@@ -20,7 +20,7 @@ lazy val `spark-stubs_24` = project
   .underModules
   .settings(
     shared,
-    libraryDependencies += Deps.sparkSql % Provided
+    libraryDependencies += Deps.sparkSql.value % Provided
   )

 lazy val `spark-stubs_30` = project
@@ -31,16 +31,26 @@ lazy val `spark-stubs_30` = project
     libraryDependencies += Deps.sparkSql3 % Provided
   )

+lazy val `spark-stubs_32` = project
+  .disablePlugins(MimaPlugin)
+  .underModules
+  .settings(
+    shared,
+    crossScalaVersions += Deps.Scala.scala213,
+    libraryDependencies += Deps.sparkSql32 % Provided
+  )
+
 lazy val core = project
   .in(file("modules/core"))
   .settings(
     shared,
+    crossScalaVersions += Deps.Scala.scala213,
     name := "ammonite-spark",
     Mima.settings,
     generatePropertyFile("org/apache/spark/sql/ammonitesparkinternals/ammonite-spark.properties"),
     libraryDependencies ++= Seq(
       Deps.ammoniteReplApi % Provided,
-      Deps.sparkSql % Provided,
+      Deps.sparkSql.value % Provided,
       Deps.jettyServer
     )
   )
@@ -50,6 +60,7 @@ lazy val tests = project
   .underModules
   .settings(
     shared,
+    crossScalaVersions += Deps.Scala.scala213,
     skip.in(publish) := true,
     generatePropertyFile("ammonite/ammonite-spark.properties"),
     generateDependenciesFile,
@@ -87,6 +98,7 @@ lazy val `yarn-tests` = project
   .underModules
   .settings(
     shared,
+    crossScalaVersions += Deps.Scala.scala213,
     skip.in(publish) := true,
     testSettings
   )
@@ -108,6 +120,7 @@ lazy val `ammonite-spark` = project
     core,
     `spark-stubs_24`,
     `spark-stubs_30`,
+    `spark-stubs_32`,
     tests
   )
   .settings(
@@ -298,7 +298,7 @@ class AmmoniteSparkSessionBuilder
         )
       case Some(dir) =>
         println(s"Adding Hadoop conf dir ${AmmoniteSparkSessionBuilder.prettyDir(dir)} to classpath")
-        interpApi.load.cp(ammonite.ops.Path(dir))
+        interpApi.load.cp(os.Path(dir))
     }
   }

@@ -311,7 +311,7 @@ class AmmoniteSparkSessionBuilder
         println("Warning: hive-site.xml not found in the classpath, and no Hive conf found via HIVE_CONF_DIR")
       case Some(dir) =>
         println(s"Adding Hive conf dir ${AmmoniteSparkSessionBuilder.prettyDir(dir)} to classpath")
-        interpApi.load.cp(ammonite.ops.Path(dir))
+        interpApi.load.cp(os.Path(dir))
     }
   }

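These hunks replace the old `ammonite.ops.Path` with os-lib's `os.Path`, the path type used by the newer Ammonite versions this PR targets. A small hypothetical illustration of the call (the directory value is made up; `os.Path` only accepts absolute paths):

```scala
// Hypothetical illustration of the os-lib call used above
val hadoopConfDir = new java.io.File("/etc/hadoop/conf")  // made-up location
val dirPath = os.Path(hadoopConfDir)                      // accepts java.io.File, java.nio.file.Path or an absolute String
// dirPath can then be passed to interpApi.load.cp(...) as in the diff above
```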
@@ -98,8 +98,10 @@ object SparkDependencies {
         "20"
       case Array("2", n) if Try(n.toInt).toOption.exists(_ >= 4) =>
         "24"
-      case Array("3", n) =>
+      case Array("3", n) if Try(n.toInt).toOption.exists(_ <= 1) =>
         "30"
+      case Array("3", n) =>
+        "32"
       case _ =>
         System.err.println(s"Warning: unrecognized Spark version ($sv), assuming 2.4.x")
         "24"
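The tightened match keeps the `30` stubs for Spark 3.0 and 3.1 and routes every later 3.x release to the new `32` stubs. A self-contained sketch of the resulting mapping, using a hypothetical `stubsSuffix` helper (the `20` branch for older 2.x versions visible in the context above is omitted):

```scala
import scala.util.Try

// Hypothetical, simplified version of the mapping in SparkDependencies above
def stubsSuffix(sparkVersion: String): String =
  sparkVersion.split('.').take(2) match {
    case Array("2", n) if Try(n.toInt).toOption.exists(_ >= 4) => "24"
    case Array("3", n) if Try(n.toInt).toOption.exists(_ <= 1) => "30"
    case Array("3", _)                                         => "32"
    case _                                                     => "24" // the real code also prints a warning
  }

assert(stubsSuffix("2.4.8") == "24")
assert(stubsSuffix("3.1.3") == "30")
assert(stubsSuffix("3.2.1") == "32")
```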
@@ -0,0 +1,253 @@

// Like https://github.com/apache/spark/blob/0e318acd0cc3b42e8be9cb2a53cccfdc4a0805f9/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
// but keeping getClassFileInputStreamFromHttpServer from former versions (so that http works without hadoop stuff)

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.repl

import java.io.{ByteArrayOutputStream, FileNotFoundException, FilterInputStream, IOException, InputStream}
import java.net.{HttpURLConnection, URI, URL, URLEncoder}
import java.nio.channels.Channels

import scala.util.control.NonFatal
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.xbean.asm9._
import org.apache.xbean.asm9.Opcodes._

import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.util.{ParentClassLoader, Utils}

/**
 * A ClassLoader that reads classes from a Hadoop FileSystem or Spark RPC endpoint, used to load
 * classes defined by the interpreter when the REPL is used. Allows the user to specify whether
 * the user class path should be searched first. This class loader delegates getting/finding
 * resources to the parent loader, which makes sense because the REPL never provides resources
 * dynamically.
 *
 * Note: a [[ClassLoader]] preferentially loads classes from its parent. Only when the parent is
 * null or that load fails does it call the overridden `findClass` function. To avoid potential
 * issues caused by loading classes with an inappropriate class loader, we set the parent of this
 * ClassLoader to null, so that we can fully control which class loader is used. For a detailed
 * discussion, see SPARK-18646.
 */
class ExecutorClassLoader(
conf: SparkConf,
env: SparkEnv,
classUri: String,
parent: ClassLoader,
userClassPathFirst: Boolean) extends ClassLoader(null) with Logging {
val uri = new URI(classUri)
val directory = uri.getPath

val parentLoader = new ParentClassLoader(parent)

// Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes
private[repl] var httpUrlConnectionTimeoutMillis: Int = -1

private val fetchFn: (String) => InputStream = uri.getScheme() match {
case "spark" => getClassFileInputStreamFromSparkRPC
case "http" | "https" | "ftp" => getClassFileInputStreamFromHttpServer
case _ =>
val fileSystem = FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf))
getClassFileInputStreamFromFileSystem(fileSystem)
}

override def getResource(name: String): URL = {
parentLoader.getResource(name)
}

override def getResources(name: String): java.util.Enumeration[URL] = {
parentLoader.getResources(name)
}

override def findClass(name: String): Class[_] = {
if (userClassPathFirst) {
findClassLocally(name).getOrElse(parentLoader.loadClass(name))
} else {
try {
parentLoader.loadClass(name)
} catch {
case e: ClassNotFoundException =>
val classOption = findClassLocally(name)
classOption match {
case None => throw new ClassNotFoundException(name, e)
case Some(a) => a
}
}
}
}

private def getClassFileInputStreamFromSparkRPC(path: String): InputStream = {
val channel = env.rpcEnv.openChannel(s"$classUri/$path")
new FilterInputStream(Channels.newInputStream(channel)) {

override def read(): Int = toClassNotFound(super.read())

override def read(b: Array[Byte]): Int = toClassNotFound(super.read(b))

override def read(b: Array[Byte], offset: Int, len: Int) =
toClassNotFound(super.read(b, offset, len))

private def toClassNotFound(fn: => Int): Int = {
try {
fn
} catch {
case e: Exception =>
throw new ClassNotFoundException(path, e)
}
}
}
}

private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = {
val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
// val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
uri.toURL
} else {
new URL(classUri + "/" + urlEncode(pathInDirectory))
}
val connection: HttpURLConnection = url.openConnection().asInstanceOf[HttpURLConnection]
// Set the connection timeouts (for testing purposes)
if (httpUrlConnectionTimeoutMillis != -1) {
connection.setConnectTimeout(httpUrlConnectionTimeoutMillis)
connection.setReadTimeout(httpUrlConnectionTimeoutMillis)
}
connection.connect()
try {
if (connection.getResponseCode != 200) {
// Close the error stream so that the connection is eligible for re-use
try {
connection.getErrorStream.close()
} catch {
case ioe: IOException =>
logError("Exception while closing error stream", ioe)
}
throw new ClassNotFoundException(s"Class file not found at URL $url")
} else {
connection.getInputStream
}
} catch {
case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] =>
connection.disconnect()
throw e
}
}

private def getClassFileInputStreamFromFileSystem(fileSystem: FileSystem)(
pathInDirectory: String): InputStream = {
val path = new Path(directory, pathInDirectory)
try {
fileSystem.open(path)
} catch {
case _: FileNotFoundException =>
throw new ClassNotFoundException(s"Class file not found at path $path")
}
}

def findClassLocally(name: String): Option[Class[_]] = {
val pathInDirectory = name.replace('.', '/') + ".class"
var inputStream: InputStream = null
try {
inputStream = fetchFn(pathInDirectory)
val bytes = readAndTransformClass(name, inputStream)
Some(defineClass(name, bytes, 0, bytes.length))
} catch {
case e: ClassNotFoundException =>
// We did not find the class
logDebug(s"Did not load class $name from REPL class server at $uri", e)
None
case e: Exception =>
// Something bad happened while checking if the class exists
logError(s"Failed to check existence of class $name on REPL class server at $uri", e)
None
} finally {
if (inputStream != null) {
try {
inputStream.close()
} catch {
case e: Exception =>
logError("Exception while closing inputStream", e)
}
}
}
}

def readAndTransformClass(name: String, in: InputStream): Array[Byte] = {
if (name.startsWith("line") && name.endsWith("$iw$")) {
// Class seems to be an interpreter "wrapper" object storing a val or var.
// Replace its constructor with a dummy one that does not run the
// initialization code placed there by the REPL. The val or var will
// be initialized later through reflection when it is used in a task.
val cr = new ClassReader(in)
val cw = new ClassWriter(
ClassWriter.COMPUTE_FRAMES + ClassWriter.COMPUTE_MAXS)
val cleaner = new ConstructorCleaner(name, cw)
cr.accept(cleaner, 0)
return cw.toByteArray
} else {
// Pass the class through unmodified
val bos = new ByteArrayOutputStream
val bytes = new Array[Byte](4096)
var done = false
while (!done) {
val num = in.read(bytes)
if (num >= 0) {
bos.write(bytes, 0, num)
} else {
done = true
}
}
return bos.toByteArray
}
}

/**
* URL-encode a string, preserving only slashes
*/
def urlEncode(str: String): String = {
str.split('/').map(part => URLEncoder.encode(part, "UTF-8")).mkString("/")
}
}

class ConstructorCleaner(className: String, cv: ClassVisitor)
extends ClassVisitor(ASM6, cv) {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
val mv = cv.visitMethod(access, name, desc, sig, exceptions)
if (name == "<init>" && (access & ACC_STATIC) == 0) {
// This is the constructor, time to clean it; just output some new
// instructions to mv that create the object and set the static MODULE$
// field in the class to point to it, but do nothing otherwise.
mv.visitCode()
mv.visitVarInsn(ALOAD, 0) // load this
mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V", false)
mv.visitVarInsn(ALOAD, 0) // load this
// val classType = className.replace('.', '/')
// mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";")
mv.visitInsn(RETURN)
mv.visitMaxs(-1, -1) // stack size and local vars will be auto-computed
mv.visitEnd()
return null
} else {
return mv
}
}
}
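In the new loader above, `userClassPathFirst` decides whether REPL-generated classes or the regular class path win during `findClass`. A minimal standalone sketch of that delegation order, independent of Spark (function and parameter names here are illustrative only):

```scala
// Illustrative only: mirrors the delegation order of ExecutorClassLoader.findClass above
def resolveClass(
  name: String,
  userClassPathFirst: Boolean,
  loadFromParent: String => Class[_],             // stands in for parentLoader.loadClass
  loadFromReplServer: String => Option[Class[_]]  // stands in for findClassLocally
): Class[_] =
  if (userClassPathFirst)
    // REPL-defined classes take precedence over the parent class path
    loadFromReplServer(name).getOrElse(loadFromParent(name))
  else
    // parent first; fall back to the REPL class server only when the parent cannot resolve the name
    try loadFromParent(name)
    catch {
      case e: ClassNotFoundException =>
        loadFromReplServer(name).getOrElse(throw new ClassNotFoundException(name, e))
    }
```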
6 changes: 6 additions & 0 deletions modules/spark-stubs_32/src/main/scala/spark/repl/Main.scala
@@ -0,0 +1,6 @@
package spark.repl

object Main {
  // May make spark ClosureCleaner a tiny bit happier
  def interp = this
}
2 changes: 1 addition & 1 deletion modules/tests/src/main/scala/ammonite/spark/Init.scala
@@ -47,7 +47,7 @@ object Init {
 @ .toVector
 @ .filter(f => !f.getFileName.toString.startsWith("scala-compiler") && !f.getFileName.toString.startsWith("scala-reflect") && !f.getFileName.toString.startsWith("scala-library") && !f.getFileName.toString.startsWith("spark-repl_"))
 @ .sortBy(_.getFileName.toString)
-@ .map(ammonite.ops.Path(_))
+@ .map(os.Path(_))
 @ }
 """ ++ init(master, sparkVersion, conf, loadSparkSql = false)
