SPARKC-577, round two #1250

Open
wants to merge 17 commits into b3.0
SchemaSpec.scala
@@ -1,16 +1,24 @@
package com.datastax.spark.connector.cql

import java.io.IOException

import com.datastax.spark.connector.SparkCassandraITWordSpecBase
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.types._
import com.datastax.spark.connector.util.schemaFromCassandra
import org.apache.commons.lang3.SerializationUtils
import org.scalatest.Inspectors._
import org.scalatest.OptionValues._

class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {

override lazy val conn = CassandraConnector(defaultConf)

val altKeyspaceName = "another_keyspace"

conn.withSessionDo { session =>
createKeyspace(session)
createKeyspace(session, altKeyspaceName)

session.execute(
s"""CREATE TYPE $ks.address (street varchar, city varchar, zip int)""")
@@ -45,6 +53,10 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
s"""CREATE INDEX test_d9_map_idx ON $ks.test (keys(d9_map))""")
session.execute(
s"""CREATE INDEX test_d7_int_idx ON $ks.test (d7_int)""")
session.execute(
s"""CREATE TABLE $ks.another_test(k1 int, PRIMARY KEY (k1))""")
session.execute(
s"""CREATE TABLE $ks.yet_another_test(k1 int, PRIMARY KEY (k1))""")
}

val schema = schemaFromCassandra(conn)
@@ -53,22 +65,63 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
"allow to get a list of keyspaces" in {
schema.keyspaces.map(_.keyspaceName) should contain(ks)
}

"allow to look up a keyspace by name" in {
val keyspace = schema.keyspaceByName(ks)
keyspace.keyspaceName shouldBe ks
}

"find the correct table using Schema.tableFromCassandra" in {
conn.withSessionDo(s => {
Schema.tableFromCassandra(s, ks, "test").tableName shouldBe "test"
Schema.tableFromCassandra(s, ks, "another_test").tableName shouldBe "another_test"
Schema.tableFromCassandra(s, ks, "yet_another_test").tableName shouldBe "yet_another_test"
})
}

"enforce constraints in fromCassandra" in {
conn.withSessionDo(s => {
val selectedTableName = "yet_another_test"
Schema.fromCassandra(s, None, None).keyspaceByName(ks).keyspaceName shouldBe ks
Schema.fromCassandra(s, None, None).keyspaceByName(altKeyspaceName).keyspaceName shouldBe altKeyspaceName
val schema1 = Schema.fromCassandra(s, Some(altKeyspaceName), None)
schema1.keyspaces.size shouldBe 1
schema1.keyspaces.head.keyspaceName shouldBe altKeyspaceName
val schema2 = Schema.fromCassandra(s, Some(ks), Some(selectedTableName))
schema2.keyspaces.size shouldBe 1
schema2.keyspaces.head.keyspaceName shouldBe ks
schema2.keyspaceByName(ks).tableByName.size shouldBe 1
schema2.keyspaceByName(ks).tableByName(selectedTableName).tableName shouldBe selectedTableName
})
}

"throw IOException for tableFromCassandra call with unknown table" in {
assertThrows[IOException] {
conn.withSessionDo(s => {
Schema.tableFromCassandra(s, ks, "unknown_table")
})
}
}
}

"A KeyspaceDef" should {

"be serializable" in {
SerializationUtils.roundtrip(schema.keyspaceByName(ks))
}

"allow to get a list of tables in the given keyspace" in {
val keyspace = schema.keyspaceByName(ks)
keyspace.tables.map(_.tableName) shouldBe Set("test")
keyspace.tableByName.values.map(_.tableName).toSet shouldBe Set("another_test", "yet_another_test", "test")
}

"allow to look up a table by name" in {
val keyspace = schema.keyspaceByName(ks)
val table = keyspace.tableByName("test")
table.tableName shouldBe "test"
keyspace.tableByName("test").tableName shouldBe "test"
keyspace.tableByName("another_test").tableName shouldBe "another_test"
keyspace.tableByName("yet_another_test").tableName shouldBe "yet_another_test"
}

"allow to look up user type by name" in {
val keyspace = schema.keyspaceByName(ks)
val userType = keyspace.userTypeByName("address")
@@ -80,17 +133,36 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
val keyspace = schema.keyspaceByName(ks)
val table = keyspace.tableByName("test")

"be serializable" in {
SerializationUtils.roundtrip(table)
}

"list all columns" in {
val colNames = table.columns.map(_.columnName)
colNames.size shouldBe 22

// Spot checks of a few column values only here
colNames should contain("k2")
colNames should contain("c3")
colNames should contain("d12_uuid")
}

"allow to read column definitions by name" in {
table.columnByName("k1").columnName shouldBe "k1"
}

"allow to read column definitions by index" in {
table.columnByIndex(0).columnName shouldBe "k1"
table.columnByIndex(4).columnName shouldBe "c2"
}

"allow to read primary key column definitions" in {
table.primaryKey.size shouldBe 6
table.primaryKey.map(_.columnName) shouldBe Seq(
"k1", "k2", "k3", "c1", "c2", "c3")
table.primaryKey.map(_.columnType) shouldBe Seq(
IntType, VarCharType, TimestampType, BigIntType, VarCharType, UUIDType)
forAll(table.primaryKey) { c => c.isPrimaryKeyColumn shouldBe true }
table.primaryKey.map(_.columnType) shouldBe
Seq(IntType, VarCharType, TimestampType, BigIntType, VarCharType, UUIDType)
table.primaryKey.forall(_.isPrimaryKeyColumn) shouldBe true
}

"allow to read partitioning key column definitions" in {
@@ -101,9 +173,9 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
}

"allow to read regular column definitions" in {
val columns = table.regularColumns
columns.size shouldBe 16
columns.map(_.columnName).toSet shouldBe Set(
val regularColumns = table.regularColumns
regularColumns.size shouldBe 16
regularColumns.map(_.columnName).toSet shouldBe Set(
"d1_blob", "d2_boolean", "d3_decimal", "d4_double", "d5_float",
"d6_inet", "d7_int", "d8_list", "d9_map", "d10_set",
"d11_timestamp", "d12_uuid", "d13_timeuuid", "d14_varchar",
@@ -143,6 +215,48 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
"should hold all indices retrieved from cassandra" in {
table.indexes.size shouldBe 2
}

"have a sane check for missing columns" in {
val missing1 = table.missingColumns(Seq("k1", "c2", "d12_uuid"))
missing1.size shouldBe 0
val missing2 = table.missingColumns(Seq("k1", "c2", "d12_uuid", "made_up_column_name"))
missing2.size shouldBe 1
missing2.head.columnName shouldBe "made_up_column_name"
val missing3 = table.missingColumns(Seq("k1", "another_made_up_column_name", "c2", "d12_uuid", "made_up_column_name"))
missing3.size shouldBe 2
missing3.head.columnName shouldBe "another_made_up_column_name"
missing3.tail.head.columnName shouldBe "made_up_column_name"
}

"support generating a DefaultTableDef" in {
val defaultDef = DefaultTableDef.fromDriverDef(table.asInstanceOf[DriverTableDef])
defaultDef.keyspaceName shouldBe table.keyspaceName
defaultDef.tableName shouldBe table.tableName

defaultDef.partitionKey.map(_.columnName) shouldBe table.partitionKey.map(_.columnName)
defaultDef.clusteringColumns.map(_.columnName) shouldBe table.clusteringColumns.map(_.columnName)
defaultDef.regularColumns.map(_.columnName) shouldBe table.regularColumns.map(_.columnName)
defaultDef.primaryKey.map(_.columnName) shouldBe table.primaryKey.map(_.columnName)
defaultDef.indexes.map(_.indexName) shouldBe table.indexes.map(_.indexName)
}
}

"A ColumnDef" should {

val keyspace = schema.keyspaceByName(ks)
val table = keyspace.tableByName("test")
val column = table.columnByName("c2")

"be serializable" in {
SerializationUtils.roundtrip(column)
}

"correctly find it's index if it's a clustering column" in {
column.componentIndex.value shouldBe 1
}

"return None if it's not a clustering column" in {
table.columnByName("k1").componentIndex shouldBe None
}
}
}
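
Not part of the diff: a minimal sketch of the usage pattern that the SchemaSpec additions above and the TableWriter changes further down imply, where Schema.tableFromCassandra yields a driver-backed DriverTableDef and DefaultTableDef.fromDriverDef converts it into a copyable definition. The connector setup, keyspace and table names, and the appended column are placeholders; the member types are assumed from the signatures visible in this PR.

import com.datastax.spark.connector.cql.{CassandraConnector, DefaultColumnDef, DefaultTableDef, RegularColumn, Schema}
import com.datastax.spark.connector.types.TextType
import org.apache.spark.SparkConf

// Placeholder connector; a real job would reuse its SparkContext's SparkConf.
val connector = CassandraConnector(new SparkConf().set("spark.cassandra.connection.host", "127.0.0.1"))

val tableDef: DefaultTableDef = connector.withSessionDo { session =>
  // tableFromCassandra now returns a driver-backed DriverTableDef ...
  val driverDef = Schema.tableFromCassandra(session, "some_keyspace", "some_table")
  // ... which is converted whenever a copyable definition is needed.
  DefaultTableDef.fromDriverDef(driverDef)
}

// Extra regular columns can then be appended via copy, mirroring TableWriter.apply below.
val extended = tableDef.copy(
  regularColumns = tableDef.regularColumns :+ DefaultColumnDef("ttl_placeholder", RegularColumn, TextType))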
TableWriterSpec.scala
@@ -8,7 +8,7 @@ import scala.collection.JavaConversions._
import scala.concurrent.Future
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.{SomeColumns, _}
import com.datastax.spark.connector.cql._
import com.datastax.spark.connector.cql.{KeyValueWithConversion => _, _}
import com.datastax.spark.connector.mapper.DefaultColumnMapper
import com.datastax.spark.connector.types._
import org.apache.spark.SparkException
@@ -131,10 +131,10 @@ class TableWriterSpec extends SparkCassandraITFlatSpecBase with DefaultCluster {
}

it should "write RDD of tuples to a new table" in {
val pkey = ColumnDef("key", PartitionKeyColumn, IntType)
val group = ColumnDef("group", ClusteringColumn(0), BigIntType)
val value = ColumnDef("value", RegularColumn, TextType)
val table = TableDef(ks, "new_kv_table", Seq(pkey), Seq(group), Seq(value))
val pkey = DefaultColumnDef("key", PartitionKeyColumn, IntType)
val group = DefaultColumnDef("group", ClusteringColumn(0), BigIntType)
val value = DefaultColumnDef("value", RegularColumn, TextType)
val table = DefaultTableDef(ks, "new_kv_table", Seq(pkey), Seq(group), Seq(value))
val rows = Seq((1, 1L, "value1"), (2, 2L, "value2"), (3, 3L, "value3"))
sc.parallelize(rows).saveAsCassandraTableEx(table, SomeColumns("key", "group", "value"))
verifyKeyValueTable("new_kv_table")
DatasetFunctions.scala
@@ -1,6 +1,5 @@
package com.datastax.spark.connector

import com.datastax.oss.driver.api.core.ProtocolVersion
import com.datastax.spark.connector.cql._
import com.datastax.spark.connector.mapper.DataFrameColumnMapper
import org.apache.spark.SparkContext
@@ -94,7 +93,7 @@ class DatasetFunctions[K: Encoder](dataset: Dataset[K]) extends Serializable {
,
ifNotExists = ifNotExists
,
tableOptions = tableOptions
options = tableOptions
)

connector.withSessionDo(session => session.execute(table.cql))
RDDFunctions.scala
@@ -53,7 +53,7 @@ class RDDFunctions[T](rdd: RDD[T]) extends WritableToCassandra[T] with Serializa
* from items of the [[org.apache.spark.rdd.RDD RDD]]
*/
def saveAsCassandraTableEx(
table: TableDef,
table: DefaultTableDef,
columns: ColumnSelector = AllColumns,
writeConf: WriteConf = WriteConf.fromSparkConf(sparkContext.getConf))(
implicit
@@ -92,7 +92,7 @@ class RDDFunctions[T](rdd: RDD[T]) extends WritableToCassandra[T] with Serializa

val protocolVersion = connector.withSessionDo(_.getContext.getProtocolVersion)

val table = TableDef.fromType[T](keyspaceName, tableName, protocolVersion)
val table = DefaultTableDef.fromType[T](keyspaceName, tableName, protocolVersion)
saveAsCassandraTableEx(table, columns, writeConf)
}

DataFrameColumnMapper.scala
@@ -19,13 +19,16 @@ class DataFrameColumnMapper[T](structType: StructType) extends ColumnMapper[T] {
override def newTable(
keyspaceName: String,
tableName: String,
protocolVersion: ProtocolVersion = ProtocolVersion.DEFAULT): TableDef = {
protocolVersion: ProtocolVersion = ProtocolVersion.DEFAULT): DefaultTableDef = {

val columns = structType.zipWithIndex.map { case (field, i) => {
val columnRole = if (i == 0) PartitionKeyColumn else RegularColumn
ColumnDef(field.name, columnRole, ColumnType.fromDriverType(CassandraSourceUtil.sparkSqlToJavaDriverType(field.dataType, protocolVersion)))
DefaultColumnDef(
field.name,
columnRole,
ColumnType.fromDriverType(CassandraSourceUtil.sparkSqlToJavaDriverType(field.dataType, protocolVersion)))
}}

TableDef(keyspaceName, tableName, Seq(columns.head), Seq.empty, columns.tail)
DefaultTableDef(keyspaceName, tableName, Seq(columns.head), Seq.empty, columns.tail)
}
}
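
Also not part of the diff: under the mapping above (the first struct field becomes the sole partition key, the remaining fields become regular columns), newTable for a hypothetical two-field schema would produce a definition shaped roughly like the hand-built sketch below. The keyspace, table, and field names are placeholders, and the concrete ColumnType chosen for each Spark SQL type is assumed rather than taken from this PR.

import com.datastax.spark.connector.cql.{DefaultColumnDef, DefaultTableDef, PartitionKeyColumn, RegularColumn}
import com.datastax.spark.connector.types.{IntType, VarCharType}

// Expected shape for a StructType with an int "id" field followed by a text "name" field:
val expectedShape = DefaultTableDef(
  "some_keyspace",
  "some_table",
  Seq(DefaultColumnDef("id", PartitionKeyColumn, IntType)),   // first field -> partition key
  Seq.empty,                                                   // no clustering columns are generated here
  Seq(DefaultColumnDef("name", RegularColumn, VarCharType)))   // remaining fields -> regular columns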
util/package.scala
@@ -2,7 +2,7 @@ package com.datastax.spark.connector

import com.datastax.dse.driver.api.core.auth.ProxyAuthentication
import com.datastax.oss.driver.api.core.cql.Statement
import com.datastax.spark.connector.cql.{CassandraConnector, Schema, TableDef}
import com.datastax.spark.connector.cql.{CassandraConnector, DriverTableDef, Schema, TableDef}

/** Useful stuff that didn't fit elsewhere. */
package object util {
@@ -26,7 +26,7 @@ package object util {
def tableFromCassandra(
connector: CassandraConnector,
keyspaceName: String,
tableName: String): TableDef = {
tableName: String): DriverTableDef = {
connector.withSessionDo(Schema.tableFromCassandra(_, keyspaceName, tableName))
}

TableWriter.scala
@@ -33,7 +33,7 @@ class TableWriter[T] private (
val keyspaceName = tableDef.keyspaceName
val tableName = tableDef.tableName
val columnNames = rowWriter.columnNames diff writeConf.optionPlaceholders
val columns = columnNames.map(tableDef.columnByName)
val columns = columnNames.map(tableDef.columnByName.asInstanceOf[Map[String,ColumnDef]])

private[connector] lazy val queryTemplateUsingInsert: String = {
val quotedColumnNames: Seq[String] = columnNames.map(quote)
@@ -318,9 +318,9 @@ object TableWriter {
private def onlyPartitionKeyAndStatic(table: TableDef, columnNames: Seq[String]): Boolean = {
val nonPartitionKeyColumnNames = columnNames.toSet -- table.partitionKey.map(_.columnName)
val nonPartitionKeyColumnRefs = table
.allColumns
.columns
.filter(columnDef => nonPartitionKeyColumnNames.contains(columnDef.columnName))
nonPartitionKeyColumnRefs.forall( columnDef => columnDef.columnRole == StaticColumn)
nonPartitionKeyColumnRefs.forall( columnDef => columnDef.isStatic)
}

/**
@@ -413,9 +413,9 @@ object TableWriter {
writeConf: WriteConf,
checkPartitionKey: Boolean = false): TableWriter[T] = {

val tableDef = tableFromCassandra(connector, keyspaceName, tableName)
val tableDef = DefaultTableDef.fromDriverDef(tableFromCassandra(connector, keyspaceName, tableName))
val optionColumns = writeConf.optionsAsColumns(keyspaceName, tableName)
val tablDefWithMeta = tableDef.copy(regularColumns = tableDef.regularColumns ++ optionColumns)
val tablDefWithMeta = tableDef.copy(regularColumns = (tableDef.regularColumns ++ optionColumns))

val selectedColumns = columnNames
.selectFrom(tablDefWithMeta)
WriteConf.scala
@@ -2,7 +2,7 @@ package com.datastax.spark.connector.writer

import com.datastax.oss.driver.api.core.{ConsistencyLevel, DefaultConsistencyLevel}
import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes}
import com.datastax.spark.connector.cql.{ColumnDef, RegularColumn}
import com.datastax.spark.connector.cql.{ColumnDef, DefaultColumnDef, RegularColumn}
import com.datastax.spark.connector.types.ColumnType
import com.datastax.spark.connector.util.ConfigCheck.ConnectorConfigurationException
import com.datastax.spark.connector.util.{ConfigCheck, ConfigParameter, DeprecatedConfigParameter}
@@ -42,10 +42,10 @@ case class WriteConf(
case WriteOption(PerRowWriteOptionValue(placeholder)) => placeholder
}

private[writer] val optionsAsColumns: (String, String) => Seq[ColumnDef] = { (keyspace, table) =>
private[writer] val optionsAsColumns: (String, String) => Seq[DefaultColumnDef] = { (keyspace, table) =>
def toRegularColDef(opt: WriteOption[_], dataType: DataType) = opt match {
case WriteOption(PerRowWriteOptionValue(placeholder)) =>
Some(ColumnDef(placeholder, RegularColumn, ColumnType.fromDriverType(dataType)))
Some(DefaultColumnDef(placeholder, RegularColumn, ColumnType.fromDriverType(dataType)))
case _ => None
}

ColumnSelectorSpec.scala
@@ -7,14 +7,14 @@ import com.datastax.spark.connector.types.{TimestampType, VarCharType, IntType}

class ColumnSelectorSpec extends WordSpec with Matchers {
"A ColumnSelector#selectFrom method" should {
val column1 = ColumnDef("c1", PartitionKeyColumn, IntType)
val column2 = ColumnDef("c2", PartitionKeyColumn, VarCharType)
val column3 = ColumnDef("c3", ClusteringColumn(0), VarCharType)
val column4 = ColumnDef("c4", ClusteringColumn(1), VarCharType)
val column5 = ColumnDef("c5", RegularColumn, VarCharType)
val column6 = ColumnDef("c6", RegularColumn, TimestampType)

val tableDef = TableDef("keyspace", "table", Seq(column1, column2), Seq(column3, column4), Seq(column5, column6))
val column1 = DefaultColumnDef("c1", PartitionKeyColumn, IntType)
val column2 = DefaultColumnDef("c2", PartitionKeyColumn, VarCharType)
val column3 = DefaultColumnDef("c3", ClusteringColumn(0), VarCharType)
val column4 = DefaultColumnDef("c4", ClusteringColumn(1), VarCharType)
val column5 = DefaultColumnDef("c5", RegularColumn, VarCharType)
val column6 = DefaultColumnDef("c6", RegularColumn, TimestampType)

val tableDef = DefaultTableDef("keyspace", "table", Seq(column1, column2), Seq(column3, column4), Seq(column5, column6))

"return all columns" in {
val columns = AllColumns.selectFrom(tableDef)