wip
jacek-lewandowski committed May 8, 2024
1 parent 41ca19f commit 43073c7
Showing 9 changed files with 121 additions and 5 deletions.
@@ -40,6 +40,7 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
| d14_varchar varchar,
| d15_varint varint,
| d16_address frozen<address>,
| d17_vector frozen<vector<int,3>>,
| PRIMARY KEY ((k1, k2, k3), c1, c2, c3)
|)
""".stripMargin)
@@ -111,12 +112,12 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {

"allow to read regular column definitions" in {
val columns = table.regularColumns
columns.size shouldBe 16
columns.size shouldBe 17
columns.map(_.columnName).toSet shouldBe Set(
"d1_blob", "d2_boolean", "d3_decimal", "d4_double", "d5_float",
"d6_inet", "d7_int", "d8_list", "d9_map", "d10_set",
"d11_timestamp", "d12_uuid", "d13_timeuuid", "d14_varchar",
"d15_varint", "d16_address")
"d15_varint", "d16_address", "d17_vector")
}

"allow to read proper types of columns" in {
@@ -136,6 +137,7 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
table.columnByName("d14_varchar").columnType shouldBe VarCharType
table.columnByName("d15_varint").columnType shouldBe VarIntType
table.columnByName("d16_address").columnType shouldBe a [UserDefinedType]
table.columnByName("d17_vector").columnType shouldBe VectorType[Int](IntType, 3)
}

"allow to list fields of a user defined type" in {
connector/src/it/scala/com/datastax/spark/connector/rdd/typeTests/VectorTypeTest.scala
@@ -0,0 +1,55 @@
package com.datastax.spark.connector.rdd.typeTests

import com.datastax.oss.driver.api.core.CqlSession
import com.datastax.spark.connector._
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.cql.CassandraConnector
import org.apache.spark.sql.cassandra.DataFrameWriterWrapper

class VectorTypeTest extends SparkCassandraITFlatSpecBase with DefaultCluster

Check failure on line 9 (all eight build-matrix jobs: Scala 2.12.19 and 2.13.13 against 3.11.17, 4.0.12, 4.1.4 and dse-6.8.44):
VectorTypeTest.(It is not a test it is a sbt.testing.SuiteSelector)
com.datastax.oss.driver.api.core.servererrors.SyntaxError: line 2:29 mismatched input '<' expecting ')' (...int PRIMARY KEY, v vector[<]...)

{
override lazy val conn = CassandraConnector(sparkConf)

val VectorTable = "vectors"

def createVectorTable(session: CqlSession): Unit = {
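// vector<int, 5> is the fixed-dimension CQL vector type (introduced in Cassandra 5.0);
// older servers reject the syntax, as the CI failures above show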
session.execute(
s"""CREATE TABLE IF NOT EXISTS $ks.$VectorTable
|(id int PRIMARY KEY, v vector<int, 5>);""".stripMargin)
}

override def beforeClass: Unit = {
conn.withSessionDo { session =>
session.execute(
s"""CREATE KEYSPACE IF NOT EXISTS $ks
|WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }"""
.stripMargin)
createVectorTable(session)
}
}

"SparkSql" should "write tuples with BLOB elements" in {
spark.createDataFrame(Seq(VectorItem(1, Seq(1,2,3,4,5)), VectorItem(2, Seq(6,7,8,9,10))))
.write
.cassandraFormat(VectorTable, ks)
.mode("append")
.save()

val tupleRows = spark.sparkContext
.cassandraTable[(Int, Seq[Int])](ks, VectorTable)
.collect()
.toList

tupleRows should contain theSameElementsAs Seq((1, Seq(1,2,3,4,5)), (2, Seq(6,7,8,9,10)))

val rows = spark.sparkContext
.cassandraTable[VectorItem](ks, VectorTable)
.collect()
.toList

rows should contain theSameElementsAs Seq(VectorItem(1, Seq(1,2,3,4,5)), VectorItem(2, Seq(6,7,8,9,10)))
}

}

case class VectorItem(id: Int, v: Seq[Int])
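
For reference, the DataFrame read path should surface the vector as an array column, per the catalystDataType mapping below. A minimal sketch (not part of this commit), assuming the connector's standard cassandraFormat reader wrapper and the spark/ks values provided by the test base:

import org.apache.spark.sql.cassandra.DataFrameReaderWrapper

val df = spark.read.cassandraFormat("vectors", ks).load()
df.printSchema()  // v should appear as array<int>
df.collect()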
@@ -2,7 +2,7 @@ package com.datastax.spark.connector.datasource

import java.util.Locale
import com.datastax.oss.driver.api.core.ProtocolVersion
import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType}
import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType, VectorType}
import com.datastax.oss.driver.api.core.`type`.DataTypes._
import com.datastax.dse.driver.api.core.`type`.DseDataTypes._
import com.datastax.oss.driver.api.core.metadata.schema.{ColumnMetadata, RelationMetadata, TableMetadata}
@@ -167,6 +167,7 @@ object CassandraSourceUtil extends Logging {
case m: MapType => SparkSqlMapType(catalystDataType(m.getKeyType, nullable), catalystDataType(m.getValueType, nullable), nullable)
case udt: UserDefinedType => fromUdt(udt)
case t: TupleType => fromTuple(t)
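// vectors surface in Spark SQL as variable-length arrays of the element type;
// the fixed dimension has no Catalyst representation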
case v: VectorType => ArrayType(catalystDataType(v.getElementType, nullable), nullable)
case VARINT =>
logWarning("VarIntType is mapped to catalystTypes.DecimalType with unlimited values.")
primitiveCatalystDataType(cassandraType)
@@ -59,6 +59,7 @@ object DataTypeConverter extends Logging {
cassandraType match {
case connector.types.SetType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
case connector.types.ListType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
case connector.types.VectorType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
case connector.types.MapType(kt, vt, _) => catalystTypes.MapType(catalystDataType(kt, nullable), catalystDataType(vt, nullable), nullable)
case connector.types.UserDefinedType(_, fields, _) => catalystTypes.StructType(fields.map(catalystStructField))
case connector.types.TupleType(fields @ _* ) => catalystTypes.StructType(fields.map(catalystStructFieldFromTuple))
@@ -93,6 +93,10 @@ class GettableDataToMappedTypeConverter[T : TypeTag : ColumnMapper](
val argConverter = converter(argColumnType, argScalaType)
TypeConverter.forType[U](Seq(argConverter))

case (VectorType(argColumnType, _), TypeRef(_, _, List(argScalaType))) =>
val argConverter = converter(argColumnType, argScalaType)
TypeConverter.forType[U](Seq(argConverter))

case (SetType(argColumnType, _), TypeRef(_, _, List(argScalaType))) =>
val argConverter = converter(argColumnType, argScalaType)
TypeConverter.forType[U](Seq(argConverter))
@@ -82,6 +82,10 @@ object MappedToGettableDataConverter extends Logging {
val valueConverter = converter(valueColumnType, valueScalaType)
TypeConverter.javaHashMapConverter(keyConverter, valueConverter)

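// vectors: thread the declared dimension into the CqlVector converter;
// the element converter must produce java.lang.Number values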
case (VectorType(argColumnType, dimension), TypeRef(_, _, List(argScalaType))) =>
val argConverter = converter(argColumnType, argScalaType)
TypeConverter.cqlVectorConverter(dimension)(argConverter.asInstanceOf[TypeConverter[Number]])

case (tt @ TupleType(argColumnType1, argColumnType2),
TypeRef(_, Symbols.PairSymbol, List(argScalaType1, argScalaType2))) =>
val c1 = converter(argColumnType1.columnType, argScalaType1)
@@ -7,7 +7,7 @@ import java.util.{Date, UUID}
import com.datastax.dse.driver.api.core.`type`.DseDataTypes
import com.datastax.oss.driver.api.core.DefaultProtocolVersion.V4
import com.datastax.oss.driver.api.core.ProtocolVersion
import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes => DriverDataTypes, ListType => DriverListType, MapType => DriverMapType, SetType => DriverSetType, TupleType => DriverTupleType, UserDefinedType => DriverUserDefinedType}
import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes => DriverDataTypes, ListType => DriverListType, MapType => DriverMapType, SetType => DriverSetType, TupleType => DriverTupleType, UserDefinedType => DriverUserDefinedType, VectorType => DriverVectorType}
import com.datastax.spark.connector.util._


@@ -77,6 +77,7 @@ object ColumnType {
case mapType: DriverMapType => MapType(fromDriverType(mapType.getKeyType), fromDriverType(mapType.getValueType), mapType.isFrozen)
case userType: DriverUserDefinedType => UserDefinedType(userType)
case tupleType: DriverTupleType => TupleType(tupleType)
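// driver vector columns map to the connector's VectorType, preserving element type and dimension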
case vectorType: DriverVectorType => VectorType(fromDriverType(vectorType.getElementType), vectorType.getDimensions)
case dataType => primitiveTypeMap(dataType)
}

@@ -153,6 +154,7 @@ val converter: TypeConverter[_] =
val converter: TypeConverter[_] =
dataType match {
case list: DriverListType => TypeConverter.javaArrayListConverter(converterToCassandra(list.getElementType))
case vec: DriverVectorType => TypeConverter.cqlVectorConverter(vec.getDimensions)(converterToCassandra(vec.getElementType).asInstanceOf[TypeConverter[Number]])
case set: DriverSetType => TypeConverter.javaHashSetConverter(converterToCassandra(set.getElementType))
case map: DriverMapType => TypeConverter.javaHashMapConverter(converterToCassandra(map.getKeyType), converterToCassandra(map.getValueType))
case udt: DriverUserDefinedType => new UserDefinedType.DriverUDTValueConverter(udt)
@@ -9,7 +9,7 @@ import java.util.{Calendar, Date, GregorianCalendar, TimeZone, UUID}

import com.datastax.dse.driver.api.core.data.geometry.{LineString, Point, Polygon}
import com.datastax.dse.driver.api.core.data.time.DateRange
import com.datastax.oss.driver.api.core.data.CqlDuration
import com.datastax.oss.driver.api.core.data.{CqlDuration, CqlVector}
import com.datastax.spark.connector.TupleValue
import com.datastax.spark.connector.UDTValue.UDTValueConverter
import com.datastax.spark.connector.util.ByteBufferUtil
@@ -700,6 +700,7 @@ object TypeConverter {
case x: java.util.List[_] => newCollection(x.asScala)
case x: java.util.Set[_] => newCollection(x.asScala)
case x: java.util.Map[_, _] => newCollection(x.asScala)
case x: CqlVector[_] => newCollection(x.asScala)
case x: Iterable[_] => newCollection(x)
}
}
@@ -768,6 +769,29 @@ }
}
}

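// Converts CqlVector values and arbitrary iterables into CqlVector[T] of the given dimension,
// converting each element with the implicit element converter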
class CqlVectorConverter[T <: Number : TypeConverter](dimension: Int) extends TypeConverter[CqlVector[T]] {
val elemConverter = implicitly[TypeConverter[T]]

implicit def elemTypeTag: TypeTag[T] = elemConverter.targetTypeTag

@transient
lazy val targetTypeTag = {
implicitly[TypeTag[CqlVector[T]]]
}

def newCollection(items: Iterable[Any]): java.util.List[T] = {
val buf = new java.util.ArrayList[T](dimension)
for (item <- items) buf.add(elemConverter.convert(item))
buf
}

def convertPF = {
case x: CqlVector[_] => x.asInstanceOf[CqlVector[T]] // optimization: pass the vector through as-is instead of re-converting each element (assumes the element type already matches)
case x: java.lang.Iterable[_] => CqlVector.newInstance[T](newCollection(x.asScala))
case x: Iterable[_] => CqlVector.newInstance[T](newCollection(x))
}
}

class JavaArrayListConverter[T : TypeConverter] extends CollectionConverter[java.util.ArrayList[T], T] {
@transient
lazy val targetTypeTag = {
@@ -869,6 +893,9 @@ object TypeConverter {
implicit def javaArrayListConverter[T : TypeConverter]: JavaArrayListConverter[T] =
new JavaArrayListConverter[T]

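// note: the dimension argument must be supplied explicitly, so this factory is invoked
// directly at call sites rather than resolved implicitly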
implicit def cqlVectorConverter[T <: Number : TypeConverter](dimension: Int): CqlVectorConverter[T] =
new CqlVectorConverter[T](dimension)

implicit def javaSetConverter[T : TypeConverter]: JavaSetConverter[T] =
new JavaSetConverter[T]

@@ -0,0 +1,20 @@
package com.datastax.spark.connector.types

import scala.language.existentials
import scala.reflect.runtime.universe._

case class VectorType[T](elemType: ColumnType[T], dimension: Int) extends ColumnType[Seq[T]] {

override def isCollection: Boolean = false

@transient
lazy val scalaTypeTag = {
implicit val elemTypeTag = elemType.scalaTypeTag
implicitly[TypeTag[Seq[T]]]
}

def cqlTypeName = s"vector<${elemType.cqlTypeName}, ${dimension}>"

override def converterToCassandra: TypeConverter[_ <: AnyRef] =
new TypeConverter.OptionToNullConverter(TypeConverter.seqConverter(elemType.converterToCassandra))
}
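
A quick illustration (mine, not part of the commit) of how the new type composes, using the IntType element type exercised in SchemaSpec above:

val vt = VectorType(IntType, 3)
vt.cqlTypeName      // "vector<int, 3>"
vt.scalaTypeTag     // TypeTag[Seq[Int]]
vt.isCollection     // false, as defined above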
