7 changes: 7 additions & 0 deletions sql/connect/client/jdbc/pom.xml
@@ -111,6 +111,13 @@
<classifier>tests</classifier>
<scope>test</scope>
</dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-connect-client-jvm_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <classifier>tests</classifier>
+      <scope>test</scope>
+    </dependency>
<!-- Use mima to perform the compatibility check -->
<dependency>
<groupId>com.typesafe</groupId>
sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/NonRegisteringSparkConnectDriver.scala
@@ -17,7 +17,7 @@

package org.apache.spark.sql.connect.client.jdbc

-import java.sql.{Connection, Driver, DriverPropertyInfo, SQLFeatureNotSupportedException}
+import java.sql.{Connection, Driver, DriverPropertyInfo, SQLException, SQLFeatureNotSupportedException}
import java.util.Properties
import java.util.logging.Logger

@@ -29,7 +29,11 @@ class NonRegisteringSparkConnectDriver extends Driver {
override def acceptsURL(url: String): Boolean = url.startsWith("jdbc:sc://")

override def connect(url: String, info: Properties): Connection = {
-    throw new UnsupportedOperationException("TODO(SPARK-53934)")
+    if (url == null) {
+      throw new SQLException("url must not be null")
+    }
+
+    if (this.acceptsURL(url)) new SparkConnectConnection(url, info) else null
}

override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] =
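For context, a minimal sketch of the java.sql.Driver contract this hunk implements: connect must throw SQLException on a null URL and return null for URLs it does not accept, so DriverManager can fall through to other registered drivers. The endpoint sc://localhost:15002, the explicit registration, and the DriverContractSketch object are illustrative assumptions, not part of this change.

import java.sql.{Connection, DriverManager}

import org.apache.spark.sql.connect.client.jdbc.NonRegisteringSparkConnectDriver

// Hypothetical smoke test; the host and port are assumptions.
object DriverContractSketch {
  def main(args: Array[String]): Unit = {
    val driver = new NonRegisteringSparkConnectDriver
    // Recognized scheme: the driver accepts "jdbc:sc://" URLs.
    assert(driver.acceptsURL("jdbc:sc://localhost:15002"))
    // Unrecognized scheme: connect returns null instead of throwing,
    // per the Driver contract, so other drivers get a chance.
    assert(driver.connect("jdbc:mysql://example.com/db", new java.util.Properties) == null)

    // This driver does not self-register with DriverManager,
    // so register it explicitly for this sketch.
    DriverManager.registerDriver(driver)
    val conn: Connection = DriverManager.getConnection("jdbc:sc://localhost:15002")
    try {
      assert(conn.isValid(0))
    } finally {
      conn.close()
    }
  }
}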
sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectConnection.scala
@@ -0,0 +1,280 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.client.jdbc

import java.sql.{Array => JdbcArray, _}
import java.util
import java.util.Properties
import java.util.concurrent.Executor

import org.apache.spark.sql.connect.SparkSession
import org.apache.spark.sql.connect.client.SparkConnectClient
import org.apache.spark.sql.connect.client.jdbc.util.JdbcErrorUtils._

class SparkConnectConnection(val url: String, val info: Properties) extends Connection {

private[jdbc] val client = SparkConnectClient
.builder()
.loadFromEnvironment()
.userAgent("Spark Connect JDBC")
.connectionString(url.stripPrefix("jdbc:"))
.build()

private[jdbc] val spark = SparkSession.builder().client(client).create()

@volatile private var closed: Boolean = false

override def isClosed: Boolean = closed

override def close(): Unit = synchronized {
if (!closed) {
spark.close()
closed = true
}
}

private[jdbc] def checkOpen(): Unit = {
if (closed) {
throw new SQLException("JDBC Connection is closed.")
}
if (!client.isSessionValid) {
throw new SQLException(s"Spark Connect Session ${client.sessionId} is invalid.")
}
}

override def isValid(timeout: Int): Boolean = !closed && client.isSessionValid

override def setCatalog(catalog: String): Unit = {
checkOpen()
spark.catalog.setCurrentCatalog(catalog)
}

override def getCatalog: String = {
checkOpen()
spark.catalog.currentCatalog()
}

override def setSchema(schema: String): Unit = {
checkOpen()
spark.catalog.setCurrentDatabase(schema)
}

override def getSchema: String = {
checkOpen()
spark.catalog.currentDatabase
}

override def getMetaData: DatabaseMetaData = {
checkOpen()
new SparkConnectDatabaseMetaData(this)
}

override def createStatement(): Statement = {
checkOpen()
new SparkConnectStatement(this)
}

override def prepareStatement(sql: String): PreparedStatement =
throw new SQLFeatureNotSupportedException

override def prepareCall(sql: String): CallableStatement =
throw new SQLFeatureNotSupportedException

override def createStatement(
resultSetType: Int,
resultSetConcurrency: Int,
resultSetHoldability: Int): Statement =
throw new SQLFeatureNotSupportedException

override def prepareStatement(
sql: String,
resultSetType: Int,
resultSetConcurrency: Int,
resultSetHoldability: Int): PreparedStatement =
throw new SQLFeatureNotSupportedException

override def prepareCall(
sql: String,
resultSetType: Int,
resultSetConcurrency: Int,
resultSetHoldability: Int): CallableStatement =
throw new SQLFeatureNotSupportedException

override def prepareStatement(
sql: String, autoGeneratedKeys: Int): PreparedStatement =
throw new SQLFeatureNotSupportedException

override def prepareStatement(
sql: String, columnIndexes: Array[Int]): PreparedStatement =
throw new SQLFeatureNotSupportedException

override def prepareStatement(
sql: String, columnNames: Array[String]): PreparedStatement =
throw new SQLFeatureNotSupportedException

override def createStatement(
resultSetType: Int, resultSetConcurrency: Int): Statement =
throw new SQLFeatureNotSupportedException

override def prepareStatement(
sql: String,
resultSetType: Int,
resultSetConcurrency: Int): PreparedStatement =
throw new SQLFeatureNotSupportedException

override def prepareCall(
sql: String,
resultSetType: Int,
resultSetConcurrency: Int): CallableStatement =
throw new SQLFeatureNotSupportedException

override def nativeSQL(sql: String): String =
throw new SQLFeatureNotSupportedException

override def setAutoCommit(autoCommit: Boolean): Unit = {
checkOpen()
if (!autoCommit) {
throw new SQLFeatureNotSupportedException("Only auto-commit mode is supported")
}
}

override def getAutoCommit: Boolean = {
checkOpen()
true
}

override def commit(): Unit = {
checkOpen()
throw new SQLException("Connection is in auto-commit mode")
}

override def rollback(): Unit = {
checkOpen()
throw new SQLException("Connection is in auto-commit mode")
}

override def setReadOnly(readOnly: Boolean): Unit = {
checkOpen()
if (readOnly) {
throw new SQLFeatureNotSupportedException("Read-only mode is not supported")
}
}

override def isReadOnly: Boolean = {
checkOpen()
false
}

override def setTransactionIsolation(level: Int): Unit = {
checkOpen()
if (level != Connection.TRANSACTION_NONE) {
throw new SQLFeatureNotSupportedException(
"Requested transaction isolation level " +
s"${stringfiyTransactionIsolationLevel(level)} is not supported")
}
}

override def getTransactionIsolation: Int = {
checkOpen()
Connection.TRANSACTION_NONE
}

override def getWarnings: SQLWarning = null

override def clearWarnings(): Unit = {}

override def getTypeMap: util.Map[String, Class[_]] =
throw new SQLFeatureNotSupportedException

override def setTypeMap(map: util.Map[String, Class[_]]): Unit =
throw new SQLFeatureNotSupportedException

override def setHoldability(holdability: Int): Unit = {
if (holdability != ResultSet.HOLD_CURSORS_OVER_COMMIT) {
throw new SQLFeatureNotSupportedException(
s"Holdability ${stringfiyHoldability(holdability)} is not supported")
}
}

override def getHoldability: Int = ResultSet.HOLD_CURSORS_OVER_COMMIT

override def setSavepoint(): Savepoint =
throw new SQLFeatureNotSupportedException

override def setSavepoint(name: String): Savepoint =
throw new SQLFeatureNotSupportedException

override def rollback(savepoint: Savepoint): Unit =
throw new SQLFeatureNotSupportedException

override def releaseSavepoint(savepoint: Savepoint): Unit =
throw new SQLFeatureNotSupportedException

override def createClob(): Clob =
throw new SQLFeatureNotSupportedException

override def createBlob(): Blob =
throw new SQLFeatureNotSupportedException

override def createNClob(): NClob =
throw new SQLFeatureNotSupportedException

override def createSQLXML(): SQLXML =
throw new SQLFeatureNotSupportedException

override def setClientInfo(name: String, value: String): Unit =
throw new SQLFeatureNotSupportedException

override def setClientInfo(properties: Properties): Unit =
throw new SQLFeatureNotSupportedException

override def getClientInfo(name: String): String =
throw new SQLFeatureNotSupportedException

override def getClientInfo: Properties =
throw new SQLFeatureNotSupportedException

override def createArrayOf(typeName: String, elements: Array[AnyRef]): JdbcArray =
throw new SQLFeatureNotSupportedException

override def createStruct(typeName: String, attributes: Array[AnyRef]): Struct =
throw new SQLFeatureNotSupportedException

override def abort(executor: Executor): Unit = {
if (executor == null) {
throw new SQLException("executor can not be null")
}
if (!closed) {
executor.execute { () => this.close() }
}
}

override def setNetworkTimeout(executor: Executor, milliseconds: Int): Unit =
throw new SQLFeatureNotSupportedException

override def getNetworkTimeout: Int =
throw new SQLFeatureNotSupportedException

override def unwrap[T](iface: Class[T]): T = if (isWrapperFor(iface)) {
  // Return the connection itself, not the Class object, per the
  // java.sql.Wrapper contract.
  this.asInstanceOf[T]
} else {
  throw new SQLException(s"${this.getClass.getName} is not a wrapper for ${iface.getName}")
}

override def isWrapperFor(iface: Class[_]): Boolean = iface.isInstance(this)
}
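A minimal sketch of the connection semantics defined above, assuming a reachable Spark Connect endpoint (the sc://localhost:15002 address and the ConnectionSemanticsSketch object are illustrative): catalog and schema delegate to the Spark session's current catalog and database, only auto-commit mode is supported, and the only transaction isolation level is TRANSACTION_NONE.

import java.sql.{Connection, DriverManager, SQLException}

import org.apache.spark.sql.connect.client.jdbc.NonRegisteringSparkConnectDriver

// Hypothetical usage sketch; the endpoint address is an assumption.
object ConnectionSemanticsSketch {
  def main(args: Array[String]): Unit = {
    DriverManager.registerDriver(new NonRegisteringSparkConnectDriver)
    val conn: Connection = DriverManager.getConnection("jdbc:sc://localhost:15002")
    try {
      // Catalog and schema map onto the Spark session state.
      conn.setSchema("default")
      println(s"catalog=${conn.getCatalog}, schema=${conn.getSchema}")

      // Only auto-commit is supported: getAutoCommit is always true and
      // commit()/rollback() throw SQLException.
      assert(conn.getAutoCommit)
      try conn.rollback() catch { case e: SQLException => println(e.getMessage) }

      // TRANSACTION_NONE is the only isolation level this connection reports.
      assert(conn.getTransactionIsolation == Connection.TRANSACTION_NONE)

      // unwrap returns the connection itself for interfaces it implements.
      assert(conn.unwrap(classOf[Connection]) eq conn)
    } finally {
      conn.close()
    }
    assert(conn.isClosed)
  }
}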