From 965b2dcbdc3ae56fd5c24ee7b6240fc8f6b17238 Mon Sep 17 00:00:00 2001 From: Jacek Lewandowski <6516951+jacek-lewandowski@users.noreply.github.com> Date: Fri, 21 Jun 2024 12:08:13 +0200 Subject: [PATCH] SPARKC-706: Add basic support for Cassandra vectors (#1366) --- CHANGES.txt | 3 + README.md | 21 +- .../SparkCassandraITFlatSpecBase.scala | 32 ++- .../spark/connector/cql/SchemaSpec.scala | 11 +- .../connector/rdd/CassandraRDDSpec.scala | 4 +- .../spark/connector/rdd/RDDSpec.scala | 4 +- .../rdd/typeTests/VectorTypeTest.scala | 230 ++++++++++++++++++ .../datasource/CassandraSourceUtil.scala | 3 +- .../sql/cassandra/DataTypeConverter.scala | 1 + doc/14_data_frames.md | 15 ++ doc/2_loading.md | 47 +++- doc/4_mapper.md | 3 + doc/5_saving.md | 33 +++ doc/6_advanced_mapper.md | 1 + .../GettableDataToMappedTypeConverter.scala | 4 + .../MappedToGettableDataConverter.scala | 4 + .../spark/connector/types/ColumnType.scala | 4 +- .../spark/connector/types/TypeConverter.scala | 29 ++- .../spark/connector/types/VectorType.scala | 20 ++ project/Versions.scala | 2 +- .../spark/connector/ccm/CcmConfig.scala | 2 +- 21 files changed, 445 insertions(+), 28 deletions(-) create mode 100644 connector/src/it/scala/com/datastax/spark/connector/rdd/typeTests/VectorTypeTest.scala create mode 100644 driver/src/main/scala/com/datastax/spark/connector/types/VectorType.scala diff --git a/CHANGES.txt b/CHANGES.txt index da7e5c6c2..39e3dd0e5 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,3 +1,6 @@ +3.5.1 + * Support for Vector type available in Cassandra 5.0 (SPARKC-706) + * Upgrade Cassandra Java Driver to 4.18.1, support Cassandra 5.0 in test framework (SPARKC-710) 3.5.0 * Support for Apache Spark 3.5 (SPARKC-704) diff --git a/README.md b/README.md index 2ec1a3491..172139a33 100644 --- a/README.md +++ b/README.md @@ -6,11 +6,15 @@ ## Quick Links -| What | Where | -| ---------- |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Community | Chat with us at [Apache Cassandra](https://cassandra.apache.org/_/community.html#discussions) | -| Scala Docs | Most Recent Release (3.5.0): [Connector API docs](https://datastax.github.io/spark-cassandra-connector/ApiDocs/3.5.0/connector/com/datastax/spark/connector/index.html), [Connector Driver docs](https://datastax.github.io/spark-cassandra-connector/ApiDocs/3.5.0/driver/com/datastax/spark/connector/index.html) | -| Latest Production Release | [3.5.0](https://search.maven.org/artifact/com.datastax.spark/spark-cassandra-connector_2.12/3.5.0/jar) | +| What | Where | +| ---------- |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Community | Chat with us at [Apache Cassandra](https://cassandra.apache.org/_/community.html#discussions) | +| Scala Docs | Most Recent Release (3.5.1): [Connector API docs](https://datastax.github.io/spark-cassandra-connector/ApiDocs/3.5.1/connector/com/datastax/spark/connector/index.html), [Connector Driver 
docs](https://datastax.github.io/spark-cassandra-connector/ApiDocs/3.5.1/driver/com/datastax/spark/connector/index.html) | +| Latest Production Release | [3.5.1](https://search.maven.org/artifact/com.datastax.spark/spark-cassandra-connector_2.12/3.5.1/jar) | + +## News +### 3.5.1 + - The latest release of the Spark-Cassandra-Connector introduces support for vector types, greatly enhancing its capabilities. This new feature allows developers to seamlessly integrate and work with Cassandra 5.0 and Astra vectors within the Spark ecosystem. By supporting vector types, the connector now provides insights into AI and Retrieval-Augmented Generation (RAG) data, enabling more advanced and efficient data processing and analysis. ## Features @@ -55,7 +59,7 @@ Currently, the following branches are actively supported: | Connector | Spark | Cassandra | Cassandra Java Driver | Minimum Java Version | Supported Scala Versions | |-----------|---------------|----------------------------|-----------------------|----------------------|--------------------------| -| 3.5.1 | 3.5 | 2.1.5*, 2.2, 3.x, 4.x, 5.0 | 4.18 | 8 | 2.12, 2.13 | +| 3.5.1 | 3.5 | 2.1.5*, 2.2, 3.x, 4.x, 5.0 | 4.18.1 | 8 | 2.12, 2.13 | | 3.5 | 3.5 | 2.1.5*, 2.2, 3.x, 4.x | 4.13 | 8 | 2.12, 2.13 | | 3.4 | 3.4 | 2.1.5*, 2.2, 3.x, 4.x | 4.13 | 8 | 2.12, 2.13 | | 3.3 | 3.3 | 2.1.5*, 2.2, 3.x, 4.x | 4.13 | 8 | 2.12 | @@ -80,6 +84,9 @@ Currently, the following branches are actively supported: ## Hosted API Docs API documentation for the Scala and Java interfaces are available online: +### 3.5.1 +* [Spark-Cassandra-Connector](https://datastax.github.io/spark-cassandra-connector/ApiDocs/3.5.1/connector/com/datastax/spark/connector/index.html) + ### 3.5.0 * [Spark-Cassandra-Connector](https://datastax.github.io/spark-cassandra-connector/ApiDocs/3.5.0/connector/com/datastax/spark/connector/index.html) @@ -111,7 +118,7 @@ This project is available on the Maven Central Repository. For SBT to download the connector binaries, sources and javadoc, put this in your project SBT config: - libraryDependencies += "com.datastax.spark" %% "spark-cassandra-connector" % "3.5.0" + libraryDependencies += "com.datastax.spark" %% "spark-cassandra-connector" % "3.5.1" * The default Scala version for Spark 3.0+ is 2.12 please choose the appropriate build. See the [FAQ](doc/FAQ.md) for more information. diff --git a/connector/src/it/scala/com/datastax/spark/connector/SparkCassandraITFlatSpecBase.scala b/connector/src/it/scala/com/datastax/spark/connector/SparkCassandraITFlatSpecBase.scala index b1c77a0ba..41985b56b 100644 --- a/connector/src/it/scala/com/datastax/spark/connector/SparkCassandraITFlatSpecBase.scala +++ b/connector/src/it/scala/com/datastax/spark/connector/SparkCassandraITFlatSpecBase.scala @@ -98,7 +98,7 @@ trait SparkCassandraITSpecBase } override def withFixture(test: NoArgTest): Outcome = wrapUnserializableExceptions { - super.withFixture(test) + super.withFixture(test) } def getKsName = { @@ -145,18 +145,32 @@ trait SparkCassandraITSpecBase else report(s"Skipped Because ProtocolVersion $pv < $protocolVersion") } - /** Skips the given test if the Cluster Version is lower or equal to the given `cassandra` Version or `dse` Version - * (if this is a DSE cluster) */ - def from(cassandra: Version, dse: Version)(f: => Unit): Unit = { + /** Runs the given test only if the cluster type and version matches. 
+ * + * @param cassandra run the test if the cluster is Cassandra >= the given version; + * if `None`, never run the test for Cassandra clusters + * @param dse run the test if the cluster is DSE >= the given version; + * if `None`, never run the test for DSE clusters + * @param f the test to run + */ + def from(cassandra: Version, dse: Version)(f: => Unit): Unit = from(Option(cassandra), Option(dse))(f) + + def from(cassandra: Option[Version] = None, dse: Option[Version] = None)(f: => Unit): Unit = { if (isDse(conn)) { - from(dse)(f) + dse match { + case Some(dseVersion) => from(dseVersion)(f) + case None => report(s"Skipped because not DSE") + } } else { - from(cassandra)(f) + cassandra match { + case Some(cassandraVersion) => from(cassandraVersion)(f) + case None => report(s"Skipped because not Cassandra") + } } } - /** Skips the given test if the Cluster Version is lower or equal to the given version */ - def from(version: Version)(f: => Unit): Unit = { + /** Skips the given test if the Cluster Version is lower than the given version */ + private def from(version: Version)(f: => Unit): Unit = { skip(cluster.getCassandraVersion, version) { f } } @@ -172,7 +186,7 @@ trait SparkCassandraITSpecBase else f } - /** Skips the given test if the Cluster Version is lower or equal to the given version or the cluster is not DSE */ + /** Skips the given test if the Cluster Version is lower than the given version or the cluster is not DSE */ def dseFrom(version: Version)(f: => Any): Unit = { dseOnly { skip(cluster.getDseVersion.get, version) { f } diff --git a/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala b/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala index f30f09d0b..b784ce6bd 100644 --- a/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala +++ b/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala @@ -1,6 +1,7 @@ package com.datastax.spark.connector.cql import com.datastax.spark.connector.SparkCassandraITWordSpecBase +import com.datastax.spark.connector.ccm.CcmConfig import com.datastax.spark.connector.cluster.DefaultCluster import com.datastax.spark.connector.types._ import com.datastax.spark.connector.util.schemaFromCassandra @@ -49,6 +50,9 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster { s"""CREATE INDEX test_d9_m23423ap_idx ON $ks.test (full(d10_set))""") session.execute( s"""CREATE INDEX test_d7_int_idx ON $ks.test (d7_int)""") + from(Some(CcmConfig.V5_0_0), None) { + session.execute(s"ALTER TABLE $ks.test ADD d17_vector frozen>") + } for (i <- 0 to 9) { session.execute(s"insert into $ks.test (k1,k2,k3,c1,c2,c3,d10_set) " + @@ -111,8 +115,8 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster { "allow to read regular column definitions" in { val columns = table.regularColumns - columns.size shouldBe 16 - columns.map(_.columnName).toSet shouldBe Set( + columns.size should be >= 16 + columns.map(_.columnName).toSet should contain allElementsOf Set( "d1_blob", "d2_boolean", "d3_decimal", "d4_double", "d5_float", "d6_inet", "d7_int", "d8_list", "d9_map", "d10_set", "d11_timestamp", "d12_uuid", "d13_timeuuid", "d14_varchar", @@ -136,6 +140,9 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster { table.columnByName("d14_varchar").columnType shouldBe VarCharType table.columnByName("d15_varint").columnType shouldBe VarIntType table.columnByName("d16_address").columnType shouldBe a [UserDefinedType] + from(Some(CcmConfig.V5_0_0), 
None) { + table.columnByName("d17_vector").columnType shouldBe VectorType(IntType, 3) + } } "allow to list fields of a user defined type" in { diff --git a/connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraRDDSpec.scala b/connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraRDDSpec.scala index 3a9ac7e90..7bf60fe28 100644 --- a/connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraRDDSpec.scala +++ b/connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraRDDSpec.scala @@ -9,7 +9,7 @@ import com.datastax.oss.driver.api.core.config.DefaultDriverOption import com.datastax.oss.driver.api.core.cql.SimpleStatement import com.datastax.oss.driver.api.core.cql.SimpleStatement._ import com.datastax.spark.connector._ -import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V6_7_0, V3_6_0} +import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V5_1_0, DSE_V6_7_0, V3_6_0} import com.datastax.spark.connector.cluster.DefaultCluster import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf} import com.datastax.spark.connector.mapper.{DefaultColumnMapper, JavaBeanColumnMapper, JavaTestBean, JavaTestUDTBean} @@ -794,7 +794,7 @@ class CassandraRDDSpec extends SparkCassandraITFlatSpecBase with DefaultCluster results should contain ((KeyGroup(3, 300), (3, 300, "0003"))) } - it should "allow the use of PER PARTITION LIMITs " in from(V3_6_0) { + it should "allow the use of PER PARTITION LIMITs " in from(cassandra = V3_6_0, dse = DSE_V5_1_0) { val result = sc.cassandraTable(ks, "clustering_time").perPartitionLimit(1).collect result.length should be (1) } diff --git a/connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala b/connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala index d882cbbd6..5175173d3 100644 --- a/connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala +++ b/connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala @@ -5,7 +5,7 @@ import com.datastax.oss.driver.api.core.config.DefaultDriverOption._ import com.datastax.oss.driver.api.core.cql.{AsyncResultSet, BoundStatement} import com.datastax.oss.driver.api.core.{DefaultConsistencyLevel, DefaultProtocolVersion} import com.datastax.spark.connector._ -import com.datastax.spark.connector.ccm.CcmConfig.V3_6_0 +import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V5_1_0, V3_6_0} import com.datastax.spark.connector.cluster.DefaultCluster import com.datastax.spark.connector.cql.CassandraConnector import com.datastax.spark.connector.embedded.SparkTemplate._ @@ -425,7 +425,7 @@ class RDDSpec extends SparkCassandraITFlatSpecBase with DefaultCluster { } - it should "should be joinable with a PER PARTITION LIMIT limit" in from(V3_6_0){ + it should "should be joinable with a PER PARTITION LIMIT limit" in from(cassandra = V3_6_0, dse = DSE_V5_1_0){ val source = sc.parallelize(keys).map(x => (x, x * 100)) val someCass = source .joinWithCassandraTable(ks, wideTable, joinColumns = SomeColumns("key", "group")) diff --git a/connector/src/it/scala/com/datastax/spark/connector/rdd/typeTests/VectorTypeTest.scala b/connector/src/it/scala/com/datastax/spark/connector/rdd/typeTests/VectorTypeTest.scala new file mode 100644 index 000000000..d06335033 --- /dev/null +++ b/connector/src/it/scala/com/datastax/spark/connector/rdd/typeTests/VectorTypeTest.scala @@ -0,0 +1,230 @@ +package com.datastax.spark.connector.rdd.typeTests + +import com.datastax.oss.driver.api.core.cql.Row +import com.datastax.oss.driver.api.core.{CqlSession, Version} 
+import com.datastax.spark.connector._ +import com.datastax.spark.connector.cluster.DefaultCluster +import com.datastax.spark.connector.cql.CassandraConnector +import com.datastax.spark.connector.mapper.ColumnMapper +import com.datastax.spark.connector.rdd.ValidRDDType +import com.datastax.spark.connector.rdd.reader.RowReaderFactory +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.cassandra.{DataFrameReaderWrapper, DataFrameWriterWrapper} + +import scala.collection.convert.ImplicitConversionsToScala._ +import scala.reflect.ClassTag +import scala.reflect.runtime.universe._ + + +abstract class VectorTypeTest[ + ScalaType: ClassTag : TypeTag, + DriverType <: Number : ClassTag, + CaseClassType <: Product : ClassTag : TypeTag : ColumnMapper: RowReaderFactory : ValidRDDType](typeName: String) extends SparkCassandraITFlatSpecBase with DefaultCluster +{ + override lazy val conn = CassandraConnector(sparkConf) + + val VectorTable = "vectors" + + def createVectorTable(session: CqlSession, table: String): Unit = { + session.execute( + s"""CREATE TABLE IF NOT EXISTS $ks.$table ( + | id INT PRIMARY KEY, + | v VECTOR<$typeName, 3> + |)""".stripMargin) + } + + def minCassandraVersion: Option[Version] = Some(Version.parse("5.0-beta1")) + + def minDSEVersion: Option[Version] = None + + def vectorFromInts(ints: Int*): Seq[ScalaType] + + def vectorItem(id: Int, v: Seq[ScalaType]): CaseClassType + + override def beforeClass() { + conn.withSessionDo { session => + session.execute( + s"""CREATE KEYSPACE IF NOT EXISTS $ks + |WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }""" + .stripMargin) + } + } + + private def assertVectors(rows: List[Row], expectedVectors: Seq[Seq[ScalaType]]): Unit = { + val returnedVectors = for (i <- expectedVectors.indices) yield { + rows.find(_.getInt("id") == i + 1).get.getVector("v", implicitly[ClassTag[DriverType]].runtimeClass.asInstanceOf[Class[Number]]).iterator().toSeq + } + + returnedVectors should contain theSameElementsInOrderAs expectedVectors + } + + "SCC" should s"write case class instances with $typeName vector using DataFrame API" in from(minCassandraVersion, minDSEVersion) { + val table = s"${typeName.toLowerCase}_write_caseclass_to_df" + conn.withSessionDo { session => + createVectorTable(session, table) + + spark.createDataFrame(Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6)))) + .write + .cassandraFormat(table, ks) + .mode(SaveMode.Append) + .save() + assertVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList, + Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6))) + + spark.createDataFrame(Seq(vectorItem(2, vectorFromInts(6, 5, 4)), vectorItem(3, vectorFromInts(7, 8, 9)))) + .write + .cassandraFormat(table, ks) + .mode(SaveMode.Append) + .save() + assertVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList, + Seq(vectorFromInts(1, 2, 3), vectorFromInts(6, 5, 4), vectorFromInts(7, 8, 9))) + + spark.createDataFrame(Seq(vectorItem(1, vectorFromInts(9, 8, 7)), vectorItem(2, vectorFromInts(10, 11, 12)))) + .write + .cassandraFormat(table, ks) + .mode(SaveMode.Overwrite) + .option("confirm.truncate", value = true) + .save() + assertVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList, + Seq(vectorFromInts(9, 8, 7), vectorFromInts(10, 11, 12))) + } + } + + it should s"write tuples with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) { + val table = 
s"${typeName.toLowerCase}_write_tuple_to_df" + conn.withSessionDo { session => + createVectorTable(session, table) + + spark.createDataFrame(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))) + .toDF("id", "v") + .write + .cassandraFormat(table, ks) + .mode(SaveMode.Append) + .save() + assertVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList, + Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6))) + } + } + + it should s"write case class instances with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) { + val table = s"${typeName.toLowerCase}_write_caseclass_to_rdd" + conn.withSessionDo { session => + createVectorTable(session, table) + + spark.sparkContext.parallelize(Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6)))) + .saveToCassandra(ks, table) + assertVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList, + Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6))) + } + } + + it should s"write tuples with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) { + val table = s"${typeName.toLowerCase}_write_tuple_to_rdd" + conn.withSessionDo { session => + createVectorTable(session, table) + + spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))) + .saveToCassandra(ks, table) + assertVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList, + Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6))) + } + } + + it should s"read case class instances with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) { + val table = s"${typeName.toLowerCase}_read_caseclass_from_df" + conn.withSessionDo { session => + createVectorTable(session, table) + } + spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))) + .saveToCassandra(ks, table) + + import spark.implicits._ + spark.read.cassandraFormat(table, ks).load().as[CaseClassType].collect() should contain theSameElementsAs + Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6))) + } + + it should s"read tuples with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) { + val table = s"${typeName.toLowerCase}_read_tuple_from_df" + conn.withSessionDo { session => + createVectorTable(session, table) + } + spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))) + .saveToCassandra(ks, table) + + import spark.implicits._ + spark.read.cassandraFormat(table, ks).load().as[(Int, Seq[ScalaType])].collect() should contain theSameElementsAs + Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))) + } + + it should s"read case class instances with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) { + val table = s"${typeName.toLowerCase}_read_caseclass_from_rdd" + conn.withSessionDo { session => + createVectorTable(session, table) + } + spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))) + .saveToCassandra(ks, table) + + spark.sparkContext.cassandraTable[CaseClassType](ks, table).collect() should contain theSameElementsAs + Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6))) + } + + it should s"read tuples with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) { + val table = s"${typeName.toLowerCase}_read_tuple_from_rdd" + conn.withSessionDo { session => + 
createVectorTable(session, table) + } + spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))) + .saveToCassandra(ks, table) + + spark.sparkContext.cassandraTable[(Int, Seq[ScalaType])](ks, table).collect() should contain theSameElementsAs + Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))) + } + + it should s"read rows with $typeName vectors using SQL API" in from(minCassandraVersion, minDSEVersion) { + val table = s"${typeName.toLowerCase}_read_rows_from_sql" + conn.withSessionDo { session => + createVectorTable(session, table) + } + spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))) + .saveToCassandra(ks, table) + + import spark.implicits._ + spark.conf.set("spark.sql.catalog.casscatalog", "com.datastax.spark.connector.datasource.CassandraCatalog") + spark.sql(s"SELECT * FROM casscatalog.$ks.$table").as[(Int, Seq[ScalaType])].collect() should contain theSameElementsAs + Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))) + } + +} + +class IntVectorTypeTest extends VectorTypeTest[Int, Integer, IntVectorItem]("INT") { + override def vectorFromInts(ints: Int*): Seq[Int] = ints + + override def vectorItem(id: Int, v: Seq[Int]): IntVectorItem = IntVectorItem(id, v) +} + +case class IntVectorItem(id: Int, v: Seq[Int]) + +class LongVectorTypeTest extends VectorTypeTest[Long, java.lang.Long, LongVectorItem]("BIGINT") { + override def vectorFromInts(ints: Int*): Seq[Long] = ints.map(_.toLong) + + override def vectorItem(id: Int, v: Seq[Long]): LongVectorItem = LongVectorItem(id, v) +} + +case class LongVectorItem(id: Int, v: Seq[Long]) + +class FloatVectorTypeTest extends VectorTypeTest[Float, java.lang.Float, FloatVectorItem]("FLOAT") { + override def vectorFromInts(ints: Int*): Seq[Float] = ints.map(_.toFloat + 0.1f) + + override def vectorItem(id: Int, v: Seq[Float]): FloatVectorItem = FloatVectorItem(id, v) +} + +case class FloatVectorItem(id: Int, v: Seq[Float]) + +class DoubleVectorTypeTest extends VectorTypeTest[Double, java.lang.Double, DoubleVectorItem]("DOUBLE") { + override def vectorFromInts(ints: Int*): Seq[Double] = ints.map(_.toDouble + 0.1d) + + override def vectorItem(id: Int, v: Seq[Double]): DoubleVectorItem = DoubleVectorItem(id, v) +} + +case class DoubleVectorItem(id: Int, v: Seq[Double]) + diff --git a/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala b/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala index 19764a0a6..657fab249 100644 --- a/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala +++ b/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala @@ -2,7 +2,7 @@ package com.datastax.spark.connector.datasource import java.util.Locale import com.datastax.oss.driver.api.core.ProtocolVersion -import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType} +import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType, VectorType} import com.datastax.oss.driver.api.core.`type`.DataTypes._ import com.datastax.dse.driver.api.core.`type`.DseDataTypes._ import com.datastax.oss.driver.api.core.metadata.schema.{ColumnMetadata, RelationMetadata, TableMetadata} @@ -167,6 +167,7 @@ object CassandraSourceUtil extends Logging { case m: MapType => SparkSqlMapType(catalystDataType(m.getKeyType, 
nullable), catalystDataType(m.getValueType, nullable), nullable) case udt: UserDefinedType => fromUdt(udt) case t: TupleType => fromTuple(t) + case v: VectorType => ArrayType(catalystDataType(v.getElementType, nullable), nullable) case VARINT => logWarning("VarIntType is mapped to catalystTypes.DecimalType with unlimited values.") primitiveCatalystDataType(cassandraType) diff --git a/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala b/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala index 1287bf23d..0f279215d 100644 --- a/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala +++ b/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala @@ -59,6 +59,7 @@ object DataTypeConverter extends Logging { cassandraType match { case connector.types.SetType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable) case connector.types.ListType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable) + case connector.types.VectorType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable) case connector.types.MapType(kt, vt, _) => catalystTypes.MapType(catalystDataType(kt, nullable), catalystDataType(vt, nullable), nullable) case connector.types.UserDefinedType(_, fields, _) => catalystTypes.StructType(fields.map(catalystStructField)) case connector.types.TupleType(fields @ _* ) => catalystTypes.StructType(fields.map(catalystStructFieldFromTuple)) diff --git a/doc/14_data_frames.md b/doc/14_data_frames.md index 291dd74f1..aa97ccecd 100644 --- a/doc/14_data_frames.md +++ b/doc/14_data_frames.md @@ -122,6 +122,9 @@ CREATE TABLE casscatalog.ksname.testTable (key_1 Int, key_2 Int, key_3 Int, cc1 Any statements that involve creating a Table are also supported like `CREATE TABLE AS SELECT` +Note that creating columns of Cassandra vector type is not supported yet. Such columns have to +be created manually with CQL. + #### Altering Tables All table properties can be changed and normal columns can be added and removed @@ -157,6 +160,12 @@ Reading with Scala spark.read.table("casscatalog.ksname.testTable") ``` +Reading vectors, specifying a predicate on vector column +```sql +SELECT name, features FROM casscatalog.test.things WHERE features = array(float(1), float(1.5), float(4)) +``` + + #### Writing Examples Writing with Sql @@ -164,11 +173,17 @@ Writing with Sql INSERT INTO casscatalog.ksname.testTable SELECT * from casscatalog.ksname.testTable2 ``` + Writing with Scala ```scala df.writeTo("casscatalog.ksname.testTable") ``` +Writing vectors +```sql +INSERT INTO casscatalog.test.things (id, name, features) VALUES (9, 'x', array(2, 3, 4)) +``` + #### Predicate Pushdown and Column Pruning The connector will automatically pushdown all valid predicates to Cassandra. The diff --git a/doc/2_loading.md b/doc/2_loading.md index 19e32ee3f..a0ada6018 100644 --- a/doc/2_loading.md +++ b/doc/2_loading.md @@ -30,7 +30,7 @@ CREATE TABLE test.words (word text PRIMARY KEY, count int); Load data into the table: -```scala +```sql INSERT INTO test.words (word, count) VALUES ('foo', 20); INSERT INTO test.words (word, count) VALUES ('bar', 20); ``` @@ -184,6 +184,50 @@ val street = address.getString("street") val number = address.getInt("number") ``` +### Reading vectors + +You can read vector columns in a Cassandra table similarly +to reading lists using `getList` or generic `get` methods of the +`CassandraRow` object. 
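+
+Vector columns are also visible through the DataFrame API, where they surface as Spark SQL
+array columns (see [Data Frames](14_data_frames.md)). A minimal sketch, assuming the
+`test.things` table created below and a `SparkSession` available as `spark`:
+
+```scala
+import org.apache.spark.sql.cassandra._
+
+// the vector column is read back as an ordinary Spark array column
+val df = spark.read.cassandraFormat("things", "test").load()
+df.select("id", "features").show()
+```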
+ +Assuming you set up the test keyspace earlier, follow these steps +to access a Cassandra collection. + +In the test keyspace, set up a collection set using cqlsh: + +```sql +CREATE TABLE test.things (id int PRIMARY KEY, name text, features vector); +INSERT INTO test.things (id, name, features) VALUES (1, 'a', [1.0, 2.0, 3.0]); +INSERT INTO test.things (id, name, features) VALUES (2, 'b', [2.2, 2.1, 2.0]); +INSERT INTO test.things (id, name, features) VALUES (3, 'c', [1.0, 1.5, 4.0]); +``` + +Then in your application, retrieve the first row: + +```scala +val row = sc.cassandraTable("test", "things").first +// row: com.datastax.spark.connector.CassandraRow = CassandraRow{id: 2, features: [2.2, 2.1, 2.0], name: b} +``` + +Query the vectors in Cassandra from Spark: + +```scala +row.getList[Float]("features") // Vector[Float] = Vector(2.2, 2.1, 2.0) +row.get[List[Float]]("features") // List[Float] = List(2.2, 2.1, 2.0) +row.get[Seq[Double]]("features") // Seq[Double] = List(2.200000047683716, 2.0999999046325684, 2.0) +row.get[IndexedSeq[Int]]("features") // IndexedSeq[Int] = Vector(2, 2, 2) +row.get[Set[Long]]("features") // Set[Long] = Set(2) +``` + +It is also possible to convert a vector to CQL `String` representation: + +```scala +scala> row.get[String]("features") // String = [2.2, 2.1, 2.0] +``` + +A `null` vector is equivalent to an empty list. You can also use +`get[Option[List[...]]]` to get `None` in case of `null`. + ### Data type conversions The following table shows recommended Scala types corresponding to Cassandra column types. @@ -213,6 +257,7 @@ The following table shows recommended Scala types corresponding to Cassandra col | `uuid` | `java.util.UUID` | `varchar` | `String` | `varint` | `BigInt`, `java.math.BigInteger` +| `vector` | `Vector`, `List`, `Iterable`, `Seq`, `IndexedSeq`, `java.util.List` | `frozen>` | `TupleValue`, `scala.Product`, `org.apache.commons.lang3.tuple.Pair`, `org.apache.commons.lang3.tuple.Triple` | user defined | `UDTValue` diff --git a/doc/4_mapper.md b/doc/4_mapper.md index d66a10d75..63b536ddd 100644 --- a/doc/4_mapper.md +++ b/doc/4_mapper.md @@ -14,6 +14,9 @@ sc.cassandraTable[(String, Int)]("test", "words").select("word", "count").toArra sc.cassandraTable[(Int, String)]("test", "words").select("count", "word").toArray // Array((20,bar), (10,foo)) + +scala> sc.cassandraTable[(String, List[Float])]("test", "things").select("name", "features").collect +// Array[(String, List[Float])] = Array((c,List(1.0, 1.5, 4.0)), (d,List()), (b,List(2.2, 2.1, 2.0)), (a,List(1.0, 2.0, 3.0))) ``` ### Mapping rows to (case) objects diff --git a/doc/5_saving.md b/doc/5_saving.md index 8a9f91717..f2f368776 100644 --- a/doc/5_saving.md +++ b/doc/5_saving.md @@ -170,6 +170,36 @@ cqlsh> Select * from ks.collections_mod where key = 1 (1 rows) ``` +## Saving Cassandra vectors + +```sql +CREATE TABLE test.things ( + id int PRIMARY KEY, + name text, + features vector +); +``` + +```scala +val newData = sc.parallelize(Seq((5, "e", List(5, 6, 7)), (6, "f", List(6, 7, 8)))) +// newData: org.apache.spark.rdd.RDD[(Int, String, List[Int])] = ParallelCollectionRDD[...] + +newData.saveToCassandra("test", "things", SomeColumns("id", "name", "features")) +``` + +```sql +cqlsh> select * from test.things ; + +id | features | name +---+---------------+------ + 5 | [5, 6, 7] | e + 6 | [6, 7, 8] | f + +(2 rows) +``` +Note that Cassandra vectors are fixed size and are not capable of adding or removing +elements from them. 
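+
+Vectors can also be written with the DataFrame API. A minimal sketch, assuming the same
+`test.things` table and a `SparkSession` available as `spark`:
+
+```scala
+import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.cassandra._
+
+// each row carries exactly 3 elements, matching the declared vector dimension
+val df = spark.createDataFrame(Seq((7, "g", Seq(7, 8, 9)), (8, "h", Seq(8, 9, 10))))
+  .toDF("id", "name", "features")
+
+df.write
+  .cassandraFormat("things", "test")
+  .mode(SaveMode.Append)
+  .save()
+```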
+ ## Saving objects of Cassandra User Defined Types To save structures consisting of many fields, use a [Case Class](4_mapper.md#Mapping-User-Defined-Types) or a `com.datastax.spark.connector.UDTValue` class. An instance of this class @@ -481,6 +511,9 @@ val rddOut = rdd.map(s => outData(s._1, s._2(0), s._2(1), s._3)) rddOut.saveAsCassandraTableEx(table, SomeColumns("col1", "col2", "col3", "col4")) ``` +Note that creating columns of Cassandra vector type is not supported yet and each +time you want to save vectors you need to create the table manually with CQL. + ## Deleting Rows and Columns `RDD.deleteFromCassandra(keyspaceName, tableName)` deletes specific rows from the specified Cassandra table. The values in the RDD are diff --git a/doc/6_advanced_mapper.md b/doc/6_advanced_mapper.md index 5879d27da..d6b882651 100644 --- a/doc/6_advanced_mapper.md +++ b/doc/6_advanced_mapper.md @@ -122,6 +122,7 @@ Cassandra column type | Object type to convert from / to `uuid` | `java.util.UUID` `varchar` | `java.lang.String` `varint` | `java.math.BigInteger` + `vector` | `com.datastax.oss.driver.api.core.data.CqlVector` user defined | `com.datastax.spark.connector.UDTValue` Custom converters for collections are not supported. diff --git a/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala b/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala index b4afa5da6..3c68b5f52 100644 --- a/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala +++ b/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala @@ -93,6 +93,10 @@ class GettableDataToMappedTypeConverter[T : TypeTag : ColumnMapper]( val argConverter = converter(argColumnType, argScalaType) TypeConverter.forType[U](Seq(argConverter)) + case (VectorType(argColumnType, _), TypeRef(_, _, List(argScalaType))) => + val argConverter = converter(argColumnType, argScalaType) + TypeConverter.forType[U](Seq(argConverter)) + case (SetType(argColumnType, _), TypeRef(_, _, List(argScalaType))) => val argConverter = converter(argColumnType, argScalaType) TypeConverter.forType[U](Seq(argConverter)) diff --git a/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala b/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala index 3b69267ae..05f682f47 100644 --- a/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala +++ b/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala @@ -82,6 +82,10 @@ object MappedToGettableDataConverter extends Logging { val valueConverter = converter(valueColumnType, valueScalaType) TypeConverter.javaHashMapConverter(keyConverter, valueConverter) + case (VectorType(argColumnType, dimension), TypeRef(_, _, List(argScalaType))) => + val argConverter = converter(argColumnType, argScalaType) + TypeConverter.cqlVectorConverter(dimension)(argConverter.asInstanceOf[TypeConverter[Number]]) + case (tt @ TupleType(argColumnType1, argColumnType2), TypeRef(_, Symbols.PairSymbol, List(argScalaType1, argScalaType2))) => val c1 = converter(argColumnType1.columnType, argScalaType1) diff --git a/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala b/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala index 99a3f9fb3..b5aaf57fb 100644 --- 
a/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala +++ b/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala @@ -7,7 +7,7 @@ import java.util.{Date, UUID} import com.datastax.dse.driver.api.core.`type`.DseDataTypes import com.datastax.oss.driver.api.core.DefaultProtocolVersion.V4 import com.datastax.oss.driver.api.core.ProtocolVersion -import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes => DriverDataTypes, ListType => DriverListType, MapType => DriverMapType, SetType => DriverSetType, TupleType => DriverTupleType, UserDefinedType => DriverUserDefinedType} +import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes => DriverDataTypes, ListType => DriverListType, MapType => DriverMapType, SetType => DriverSetType, TupleType => DriverTupleType, UserDefinedType => DriverUserDefinedType, VectorType => DriverVectorType} import com.datastax.spark.connector.util._ @@ -77,6 +77,7 @@ object ColumnType { case mapType: DriverMapType => MapType(fromDriverType(mapType.getKeyType), fromDriverType(mapType.getValueType), mapType.isFrozen) case userType: DriverUserDefinedType => UserDefinedType(userType) case tupleType: DriverTupleType => TupleType(tupleType) + case vectorType: DriverVectorType => VectorType(fromDriverType(vectorType.getElementType), vectorType.getDimensions) case dataType => primitiveTypeMap(dataType) } @@ -153,6 +154,7 @@ object ColumnType { val converter: TypeConverter[_] = dataType match { case list: DriverListType => TypeConverter.javaArrayListConverter(converterToCassandra(list.getElementType)) + case vec: DriverVectorType => TypeConverter.cqlVectorConverter(vec.getDimensions)(converterToCassandra(vec.getElementType).asInstanceOf[TypeConverter[Number]]) case set: DriverSetType => TypeConverter.javaHashSetConverter(converterToCassandra(set.getElementType)) case map: DriverMapType => TypeConverter.javaHashMapConverter(converterToCassandra(map.getKeyType), converterToCassandra(map.getValueType)) case udt: DriverUserDefinedType => new UserDefinedType.DriverUDTValueConverter(udt) diff --git a/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala b/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala index 58615965b..1e123cf3f 100644 --- a/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala +++ b/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala @@ -9,7 +9,7 @@ import java.util.{Calendar, Date, GregorianCalendar, TimeZone, UUID} import com.datastax.dse.driver.api.core.data.geometry.{LineString, Point, Polygon} import com.datastax.dse.driver.api.core.data.time.DateRange -import com.datastax.oss.driver.api.core.data.CqlDuration +import com.datastax.oss.driver.api.core.data.{CqlDuration, CqlVector} import com.datastax.spark.connector.TupleValue import com.datastax.spark.connector.UDTValue.UDTValueConverter import com.datastax.spark.connector.util.ByteBufferUtil @@ -700,6 +700,7 @@ object TypeConverter { case x: java.util.List[_] => newCollection(x.asScala) case x: java.util.Set[_] => newCollection(x.asScala) case x: java.util.Map[_, _] => newCollection(x.asScala) + case x: CqlVector[_] => newCollection(x.asScala) case x: Iterable[_] => newCollection(x) } } @@ -768,6 +769,29 @@ object TypeConverter { } } + class CqlVectorConverter[T <: Number : TypeConverter](dimension: Int) extends TypeConverter[CqlVector[T]] { + val elemConverter = implicitly[TypeConverter[T]] + + implicit def elemTypeTag: TypeTag[T] = 
elemConverter.targetTypeTag + + @transient + lazy val targetTypeTag = { + implicitly[TypeTag[CqlVector[T]]] + } + + private def newCollection(items: Iterable[Any]): java.util.List[T] = { + val buf = new java.util.ArrayList[T](dimension) + for (item <- items) buf.add(elemConverter.convert(item)) + buf + } + + def convertPF = { + case x: CqlVector[_] => x.asInstanceOf[CqlVector[T]] // it is an optimization - should we skip converting the elements? + case x: java.lang.Iterable[_] => CqlVector.newInstance[T](newCollection(x.asScala)) + case x: Iterable[_] => CqlVector.newInstance[T](newCollection(x)) + } + } + class JavaArrayListConverter[T : TypeConverter] extends CollectionConverter[java.util.ArrayList[T], T] { @transient lazy val targetTypeTag = { @@ -869,6 +893,9 @@ object TypeConverter { implicit def javaArrayListConverter[T : TypeConverter]: JavaArrayListConverter[T] = new JavaArrayListConverter[T] + implicit def cqlVectorConverter[T <: Number : TypeConverter](dimension: Int): CqlVectorConverter[T] = + new CqlVectorConverter[T](dimension) + implicit def javaSetConverter[T : TypeConverter]: JavaSetConverter[T] = new JavaSetConverter[T] diff --git a/driver/src/main/scala/com/datastax/spark/connector/types/VectorType.scala b/driver/src/main/scala/com/datastax/spark/connector/types/VectorType.scala new file mode 100644 index 000000000..8060fe225 --- /dev/null +++ b/driver/src/main/scala/com/datastax/spark/connector/types/VectorType.scala @@ -0,0 +1,20 @@ +package com.datastax.spark.connector.types + +import scala.language.existentials +import scala.reflect.runtime.universe._ + +case class VectorType[T](elemType: ColumnType[T], dimension: Int) extends ColumnType[Seq[T]] { + + override def isCollection: Boolean = false + + @transient + lazy val scalaTypeTag = { + implicit val elemTypeTag = elemType.scalaTypeTag + implicitly[TypeTag[Seq[T]]] + } + + def cqlTypeName = s"vector<${elemType.cqlTypeName}, ${dimension}>" + + override def converterToCassandra: TypeConverter[_ <: AnyRef] = + new TypeConverter.OptionToNullConverter(TypeConverter.seqConverter(elemType.converterToCassandra)) +} diff --git a/project/Versions.scala b/project/Versions.scala index 1afbd6ff0..203450f8e 100644 --- a/project/Versions.scala +++ b/project/Versions.scala @@ -5,7 +5,7 @@ object Versions { val CommonsLang3 = "3.10" val Paranamer = "2.8" - val CassandraJavaDriver = "4.18.0" + val CassandraJavaDriver = "4.18.1" val EsriGeometry = "2.2.4" val ScalaCheck = "1.14.0" diff --git a/test-support/src/main/scala/com/datastax/spark/connector/ccm/CcmConfig.scala b/test-support/src/main/scala/com/datastax/spark/connector/ccm/CcmConfig.scala index ef208d29f..e6238b297 100644 --- a/test-support/src/main/scala/com/datastax/spark/connector/ccm/CcmConfig.scala +++ b/test-support/src/main/scala/com/datastax/spark/connector/ccm/CcmConfig.scala @@ -19,7 +19,7 @@ case class CcmConfig( createOptions: List[String] = List(), dseWorkloads: List[String] = List(), jmxPortOffset: Int = 0, - version: Version = Version.parse(System.getProperty("ccm.version", "4.1.4")), + version: Version = Version.parse(System.getProperty("ccm.version", "5.0-beta1")), installDirectory: Option[String] = Option(System.getProperty("ccm.directory")), installBranch: Option[String] = Option(System.getProperty("ccm.branch")), dseEnabled: Boolean = Option(System.getProperty("ccm.dse")).exists(_.toLowerCase == "true"),