diff --git a/src/main/scala/net/snowflake/spark/snowflake/catalog/SfCatalog.scala b/src/main/scala/net/snowflake/spark/snowflake/catalog/SfCatalog.scala
new file mode 100644
index 00000000..3b736351
--- /dev/null
+++ b/src/main/scala/net/snowflake/spark/snowflake/catalog/SfCatalog.scala
@@ -0,0 +1,219 @@
+package net.snowflake.spark.snowflake.catalog
+
+import net.snowflake.spark.snowflake.DefaultJDBCWrapper.DataBaseOperations
+import net.snowflake.spark.snowflake.Parameters.{
+  MergedParameters,
+  PARAM_SF_DATABASE,
+  PARAM_SF_DBTABLE,
+  PARAM_SF_SCHEMA
+}
+import net.snowflake.spark.snowflake.{DefaultJDBCWrapper, Parameters}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.analysis.{
+  NoSuchNamespaceException,
+  NoSuchTableException
+}
+import org.apache.spark.sql.connector.catalog._
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+import java.sql.SQLException
+import scala.collection.convert.ImplicitConversions.`map AsScala`
+import scala.collection.mutable.ArrayBuilder
+
+class SfCatalog extends TableCatalog with Logging with SupportsNamespaces {
+  var catalogName: String = null
+  var params: MergedParameters = _
+  val jdbcWrapper = DefaultJDBCWrapper
+
+  override def name(): String = {
+    require(catalogName != null, "SfCatalog is not initialized")
+    catalogName
+  }
+
+  override def initialize(
+      name: String,
+      options: CaseInsensitiveStringMap
+  ): Unit = {
+    val map = options.asCaseSensitiveMap().toMap
+    // placeholders so that parameter validation passes; loadTable() sets the real values
+    params = Parameters.mergeParameters(
+      map +
+        (PARAM_SF_DATABASE -> "__invalid_database") +
+        (PARAM_SF_SCHEMA -> "__invalid_schema") +
+        (PARAM_SF_DBTABLE -> "__invalid_dbtable")
+    )
+    catalogName = name
+  }
+
+  override def listTables(namespace: Array[String]): Array[Identifier] = {
+    checkNamespace(namespace)
+    val catalog = if (namespace.length == 2) namespace(0) else null
+    val schemaPattern = if (namespace.length == 2) namespace(1) else null
+    val rs = DefaultJDBCWrapper
+      .getConnector(params)
+      .getMetaData()
+      .getTables(catalog, schemaPattern, "%", Array("TABLE"))
+    new Iterator[Identifier] {
+      def hasNext = rs.next()
+      def next() = Identifier.of(namespace, rs.getString("TABLE_NAME"))
+    }.toArray
+
+  }
+
+  override def tableExists(ident: Identifier): Boolean = {
+    checkNamespace(ident.namespace())
+    DefaultJDBCWrapper.tableExists(params, getFullTableName(ident))
+  }
+
+  override def dropTable(ident: Identifier): Boolean = {
+    checkNamespace(ident.namespace())
+    val conn = DefaultJDBCWrapper.getConnector(params)
+    conn.dropTable(getFullTableName(ident))
+  }
+
+  override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = {
+    checkNamespace(oldIdent.namespace())
+    val conn = DefaultJDBCWrapper.getConnector(params)
+    conn.renameTable(getFullTableName(newIdent), getFullTableName(oldIdent))
+  }
+
+  override def loadTable(ident: Identifier): Table = {
+    checkNamespace(ident.namespace())
+    val map = params.parameters
+    params = Parameters.mergeParameters(
+      map +
+        (PARAM_SF_DBTABLE -> getTableName(ident)) +
+        (PARAM_SF_DATABASE -> getDatabase(ident)) +
+        (PARAM_SF_SCHEMA -> getSchema(ident))
+    )
+    try {
+      SfTable(ident, jdbcWrapper, params)
+    } catch {
+      case _: SQLException =>
+        throw new NoSuchTableException(ident)
+
+    }
+  }
+
+  override def alterTable(ident: Identifier, changes: TableChange*): Table = {
+    throw new UnsupportedOperationException(
+      "SfCatalog does not support altering tables"
table operation" + ) + } + + override def namespaceExists(namespace: Array[String]): Boolean = + namespace match { + case Array(catalog, schema) => + val rs = DefaultJDBCWrapper + .getConnector(params) + .getMetaData() + .getSchemas(catalog, schema) + + while (rs.next()) { + val tableSchema = rs.getString("TABLE_SCHEM") + if (tableSchema == schema) return true + } + false + case _ => false + } + + override def listNamespaces(): Array[Array[String]] = { + val schemaBuilder = ArrayBuilder.make[Array[String]] + val rs = DefaultJDBCWrapper.getConnector(params).getMetaData().getSchemas() + while (rs.next()) { + schemaBuilder += Array(rs.getString("TABLE_SCHEM")) + } + schemaBuilder.result + } + + override def listNamespaces( + namespace: Array[String] + ): Array[Array[String]] = { + namespace match { + case Array() => + listNamespaces() + case Array(_, _) if namespaceExists(namespace) => + Array() + case _ => + throw new NoSuchNamespaceException(namespace) + } + } + + override def loadNamespaceMetadata( + namespace: Array[String] + ): java.util.Map[String, String] = { + namespace match { + case Array(catalog, schema) => + if (!namespaceExists(namespace)) { + throw new NoSuchNamespaceException( + Array(catalog, schema) + ) + } + new java.util.HashMap[String, String]() + case _ => + throw new NoSuchNamespaceException(namespace) + } + } + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: java.util.Map[String, String] + ): Table = { + throw new UnsupportedOperationException( + "SfCatalog does not support creating table operation" + ) + } + + override def alterNamespace( + namespace: Array[String], + changes: NamespaceChange* + ): Unit = { + throw new UnsupportedOperationException( + "SfCatalog does not support altering namespace operation" + ) + + } + + override def dropNamespace( + namespace: Array[String], + cascade: Boolean + ): Boolean = { + throw new UnsupportedOperationException( + "SfCatalog does not support dropping namespace operation" + ) + } + + private def checkNamespace(namespace: Array[String]): Unit = { + // a database and schema comprise a namespace in Snowflake + if (namespace.length != 2) { + throw new NoSuchNamespaceException(namespace) + } + } + + override def createNamespace( + namespace: Array[String], + metadata: java.util.Map[String, String] + ): Unit = { + throw new UnsupportedOperationException( + "SfCatalog does not support creating namespace operation" + ) + } + + private def getTableName(ident: Identifier): String = { + (ident.name()) + } + private def getDatabase(ident: Identifier): String = { + (ident.namespace())(0) + } + private def getSchema(ident: Identifier): String = { + (ident.namespace())(1) + } + private def getFullTableName(ident: Identifier): String = { + (ident.namespace() :+ ident.name()).mkString(".") + + } +} diff --git a/src/main/scala/net/snowflake/spark/snowflake/catalog/SfScan.scala b/src/main/scala/net/snowflake/spark/snowflake/catalog/SfScan.scala new file mode 100644 index 00000000..ad054a86 --- /dev/null +++ b/src/main/scala/net/snowflake/spark/snowflake/catalog/SfScan.scala @@ -0,0 +1,38 @@ +package net.snowflake.spark.snowflake.catalog + +import net.snowflake.spark.snowflake.SnowflakeRelation +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.connector.read.V1Scan +import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{Row, SQLContext} + +case class SfScan( + relation: 
+    relation: SnowflakeRelation,
+    prunedSchema: StructType,
+    pushedFilters: Array[Filter]
+) extends V1Scan {
+
+  override def readSchema(): StructType = prunedSchema
+
+  override def toV1TableScan[T <: BaseRelation with TableScan](
+      context: SQLContext
+  ): T = {
+    new BaseRelation with TableScan {
+      override def sqlContext: SQLContext = context
+      override def schema: StructType = prunedSchema
+      override def needConversion: Boolean = relation.needConversion
+      override def buildScan(): RDD[Row] = {
+        val columnList = prunedSchema.map(_.name).toArray
+        relation.buildScan(columnList, pushedFilters)
+      }
+    }.asInstanceOf[T]
+  }
+
+  override def description(): String = {
+    super.description() + ", prunedSchema: " + seqToString(prunedSchema) +
+      ", PushedFilters: " + seqToString(pushedFilters)
+  }
+
+  private def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
+}
diff --git a/src/main/scala/net/snowflake/spark/snowflake/catalog/SfScanBuilder.scala b/src/main/scala/net/snowflake/spark/snowflake/catalog/SfScanBuilder.scala
new file mode 100644
index 00000000..e87ab7ea
--- /dev/null
+++ b/src/main/scala/net/snowflake/spark/snowflake/catalog/SfScanBuilder.scala
@@ -0,0 +1,72 @@
+package net.snowflake.spark.snowflake.catalog
+
+import net.snowflake.spark.snowflake.Parameters.MergedParameters
+import net.snowflake.spark.snowflake.{
+  FilterPushdown,
+  JDBCWrapper,
+  SnowflakeRelation
+}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.read.{
+  Scan,
+  ScanBuilder,
+  SupportsPushDownFilters,
+  SupportsPushDownRequiredColumns
+}
+import org.apache.spark.sql.execution.datasources.PartitioningUtils
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+
+case class SfScanBuilder(
+    session: SparkSession,
+    schema: StructType,
+    params: MergedParameters,
+    jdbcWrapper: JDBCWrapper
+) extends ScanBuilder
+    with SupportsPushDownFilters
+    with SupportsPushDownRequiredColumns
+    with Logging {
+  private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis
+
+  private var pushedFilter = Array.empty[Filter]
+
+  private var finalSchema = schema
+
+  override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+    val (pushed, unSupported) = filters.partition(filter =>
+      FilterPushdown
+        .buildFilterStatement(
+          schema,
+          filter,
+          true
+        )
+        .isDefined
+    )
+    this.pushedFilter = pushed
+    unSupported
+  }
+
+  override def pushedFilters(): Array[Filter] = pushedFilter
+
+  override def pruneColumns(requiredSchema: StructType): Unit = {
+    val requiredCols = requiredSchema.fields
+      .map(PartitioningUtils.getColName(_, isCaseSensitive))
+      .toSet
+    val fields = schema.fields.filter { field =>
+      val colName = PartitioningUtils.getColName(field, isCaseSensitive)
+      requiredCols.contains(colName)
+    }
+    finalSchema = StructType(fields)
+  }
+
+  override def build(): Scan = {
+    SfScan(
+      SnowflakeRelation(jdbcWrapper, params, Option(schema))(
+        session.sqlContext
+      ),
+      finalSchema,
+      pushedFilters
+    )
+  }
+}
diff --git a/src/main/scala/net/snowflake/spark/snowflake/catalog/SfTable.scala b/src/main/scala/net/snowflake/spark/snowflake/catalog/SfTable.scala
new file mode 100644
index 00000000..e2d52c5e
--- /dev/null
+++ b/src/main/scala/net/snowflake/spark/snowflake/catalog/SfTable.scala
@@ -0,0 +1,56 @@
+package net.snowflake.spark.snowflake.catalog
+
+import net.snowflake.spark.snowflake.DefaultJDBCWrapper.DataBaseOperations
+import net.snowflake.spark.snowflake.Parameters.MergedParameters
+import net.snowflake.spark.snowflake.{DefaultJDBCWrapper, JDBCWrapper}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.catalog._
+import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+import java.sql.Connection
+import java.util
+import scala.collection.JavaConverters._
+
+case class SfTable(
+    ident: Identifier,
+    jdbcWrapper: JDBCWrapper,
+    params: MergedParameters
+) extends Table
+    with SupportsRead
+    with SupportsWrite
+    with Logging {
+
+  override def name(): String =
+    (ident.namespace() :+ ident.name()).mkString(".")
+
+  override def schema(): StructType = {
+    val conn: Connection = DefaultJDBCWrapper.getConnector(params)
+    try {
+      conn.tableSchema(name, params)
+    } finally {
+      conn.close()
+    }
+  }
+
+  override def capabilities(): util.Set[TableCapability] = {
+    Set(
+      TableCapability.BATCH_READ,
+      TableCapability.V1_BATCH_WRITE,
+      TableCapability.TRUNCATE
+    ).asJava
+  }
+
+  override def newScanBuilder(
+      options: CaseInsensitiveStringMap
+  ): ScanBuilder = {
+    SfScanBuilder(SparkSession.active, schema, params, jdbcWrapper)
+  }
+
+  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+    SfWriterBuilder(jdbcWrapper, params)
+  }
+}
diff --git a/src/main/scala/net/snowflake/spark/snowflake/catalog/SfWriterBuilder.scala b/src/main/scala/net/snowflake/spark/snowflake/catalog/SfWriterBuilder.scala
new file mode 100644
index 00000000..68c689d4
--- /dev/null
+++ b/src/main/scala/net/snowflake/spark/snowflake/catalog/SfWriterBuilder.scala
@@ -0,0 +1,32 @@
+package net.snowflake.spark.snowflake.catalog
+
+import net.snowflake.spark.snowflake.JDBCWrapper
+import net.snowflake.spark.snowflake.Parameters.MergedParameters
+import org.apache.spark.sql._
+import org.apache.spark.sql.connector.write._
+import org.apache.spark.sql.sources.InsertableRelation
+
+case class SfWriterBuilder(jdbcWrapper: JDBCWrapper, params: MergedParameters)
+    extends WriteBuilder
+    with SupportsTruncate {
+  private var isTruncate = false
+
+  override def truncate(): WriteBuilder = {
+    isTruncate = true
+    this
+  }
+
+  override def build(): V1Write = new V1Write {
+    override def toInsertableRelation: InsertableRelation =
+      (data: DataFrame, _: Boolean) => {
+        val saveMode = if (isTruncate) {
+          SaveMode.Overwrite
+        } else {
+          SaveMode.Append
+        }
+        val writer =
+          new net.snowflake.spark.snowflake.SnowflakeWriter(jdbcWrapper)
+        writer.save(data.sqlContext, data, saveMode, params)
+      }
+  }
+}
diff --git a/src/test/scala/net/snowflake/spark/snowflake/BaseTest.scala b/src/test/scala/net/snowflake/spark/snowflake/BaseTest.scala
index d2087e1b..0be3819d 100644
--- a/src/test/scala/net/snowflake/spark/snowflake/BaseTest.scala
+++ b/src/test/scala/net/snowflake/spark/snowflake/BaseTest.scala
@@ -32,23 +32,22 @@ import org.mockito.Mockito
 import org.mockito.Mockito._
 import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
 
-/**
-  * Created by mzukowski on 8/9/16.
+/** Created by mzukowski on 8/9/16.
   */
 private class TestContext extends SparkContext("local", "SnowflakeBaseTest") {
 
-  /**
-    * A text file containing fake unloaded Snowflake data of all supported types
+  /** A text file containing fake unloaded Snowflake data of all supported types
     */
   val testData: String = new File(
-    "src/test/resources/snowflake_unload_data.txt").toURI.toString
+    "src/test/resources/snowflake_unload_data.txt"
+  ).toURI.toString
 
   override def newAPIHadoopFile[K, V, F <: InputFormat[K, V]](
-    path: String,
-    fClass: Class[F],
-    kClass: Class[K],
-    vClass: Class[V],
-    conf: Configuration = hadoopConfiguration
+      path: String,
+      fClass: Class[F],
+      kClass: Class[K],
+      vClass: Class[V],
+      conf: Configuration = hadoopConfiguration
   ): RDD[(K, V)] = {
     super.newAPIHadoopFile[K, V, F](testData, fClass, kClass, vClass, conf)
   }
@@ -60,8 +59,7 @@ class BaseTest
     with BeforeAndAfterAll
     with BeforeAndAfterEach {
 
-  /**
-    * Spark Context with Hadoop file overridden to point at our local test data file for this suite,
+  /** Spark Context with Hadoop file overridden to point at our local test data file for this suite,
     * no matter what temp directory was generated and requested.
     */
   protected var sc: SparkContext = _
@@ -106,6 +104,7 @@ class BaseTest
     sc.hadoopConfiguration.set("fs.s3.awsSecretAccessKey", "test2")
     sc.hadoopConfiguration.set("fs.s3n.awsAccessKeyId", "test1")
    sc.hadoopConfiguration.set("fs.s3n.awsSecretAccessKey", "test2")
+    // Configure a mock S3 client so that we don't hit errors when trying to access AWS in tests.
     mockS3Client =
       Mockito.mock(classOf[AmazonS3Client], Mockito.RETURNS_SMART_NULLS)
 
diff --git a/src/test/scala/net/snowflake/spark/snowflake/SfCatalogSuite.scala b/src/test/scala/net/snowflake/spark/snowflake/SfCatalogSuite.scala
new file mode 100644
index 00000000..ab9d4044
--- /dev/null
+++ b/src/test/scala/net/snowflake/spark/snowflake/SfCatalogSuite.scala
@@ -0,0 +1,48 @@
+package net.snowflake.spark.snowflake
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.scalatest.FunSuite
+
+class SfCatalogSuite extends FunSuite {
+
+  protected def defaultSfCatalogParams: Map[String, String] = Map(
+    "spark.sql.catalog.snowflake" -> "net.snowflake.spark.snowflake.catalog.SfCatalog",
+    "spark.sql.catalog.snowflake.sfURL" -> "account.snowflakecomputing.com:443",
+    "spark.sql.catalog.snowflake.sfUser" -> "username",
+    "spark.sql.catalog.snowflake.sfPassword" -> "password"
+  )
+
+  test("SfCatalog Params") {
+    val conf = new SparkConf()
+      .setMaster("local")
+      .setAppName("SfCatalogSuite")
+    conf.setAll(defaultSfCatalogParams)
+
+    val sc = SparkContext.getOrCreate(conf)
+
+    assert(
+      sc.getConf
+        .get("spark.sql.catalog.snowflake")
+        .equals("net.snowflake.spark.snowflake.catalog.SfCatalog")
+    )
+
+    assert(
+      sc.getConf
+        .get("spark.sql.catalog.snowflake.sfURL")
+        .equals("account.snowflakecomputing.com:443")
+    )
+
+    assert(
+      sc.getConf
+        .get("spark.sql.catalog.snowflake.sfUser")
+        .equals("username")
+    )
+
+    assert(
+      sc.getConf
+        .get("spark.sql.catalog.snowflake.sfPassword")
+        .equals("password")
+    )
+  }
+
+}
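
Usage sketch (not part of the patch): SfCatalog is registered through Spark's standard spark.sql.catalog.* configuration keys, as exercised in SfCatalogSuite above. The snippet below is a minimal, illustrative wiring of the catalog into a SparkSession; the account URL, credentials, and the MY_DATABASE / MY_SCHEMA / MY_TABLE identifiers are placeholders, not values taken from this change. Because checkNamespace requires a two-part namespace (database.schema), tables are addressed as <catalog>.<database>.<schema>.<table>.

import org.apache.spark.sql.SparkSession

object SfCatalogUsageSketch {
  def main(args: Array[String]): Unit = {
    // Register the catalog under the name "snowflake", using the same keys
    // that SfCatalogSuite sets on the SparkConf. All connection values here
    // are placeholders.
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("SfCatalogUsageSketch")
      .config(
        "spark.sql.catalog.snowflake",
        "net.snowflake.spark.snowflake.catalog.SfCatalog"
      )
      .config(
        "spark.sql.catalog.snowflake.sfURL",
        "account.snowflakecomputing.com:443"
      )
      .config("spark.sql.catalog.snowflake.sfUser", "username")
      .config("spark.sql.catalog.snowflake.sfPassword", "password")
      .getOrCreate()

    // A Snowflake namespace is <database>.<schema>, so a table is addressed
    // as <catalog>.<database>.<schema>.<table>. Reads go through SfScanBuilder
    // (with column pruning and filter pushdown); writes go through SfWriterBuilder.
    spark.sql("SHOW TABLES IN snowflake.MY_DATABASE.MY_SCHEMA").show()
    spark
      .sql("SELECT * FROM snowflake.MY_DATABASE.MY_SCHEMA.MY_TABLE LIMIT 10")
      .show()
  }
}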