diff --git a/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala b/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala
index f30f09d0b..a3b30aef9 100644
--- a/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala
+++ b/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala
@@ -40,6 +40,7 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
          | d14_varchar varchar,
          | d15_varint varint,
          | d16_address frozen<address>,
+         | d17_vector frozen<vector<int, 3>>,
          | PRIMARY KEY ((k1, k2, k3), c1, c2, c3)
          |)
        """.stripMargin)
@@ -111,12 +112,12 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
 
    "allow to read regular column definitions" in {
      val columns = table.regularColumns
-      columns.size shouldBe 16
+      columns.size shouldBe 17
      columns.map(_.columnName).toSet shouldBe Set(
        "d1_blob", "d2_boolean", "d3_decimal", "d4_double", "d5_float",
        "d6_inet", "d7_int", "d8_list", "d9_map", "d10_set",
        "d11_timestamp", "d12_uuid", "d13_timeuuid", "d14_varchar",
-        "d15_varint", "d16_address")
+        "d15_varint", "d16_address", "d17_vector")
    }
 
    "allow to read proper types of columns" in {
@@ -136,6 +137,7 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
      table.columnByName("d14_varchar").columnType shouldBe VarCharType
      table.columnByName("d15_varint").columnType shouldBe VarIntType
      table.columnByName("d16_address").columnType shouldBe a [UserDefinedType]
+      table.columnByName("d17_vector").columnType shouldBe VectorType[Int](IntType, 3)
    }
 
    "allow to list fields of a user defined type" in {
diff --git a/connector/src/it/scala/com/datastax/spark/connector/rdd/typeTests/VectorTypeTest.scala b/connector/src/it/scala/com/datastax/spark/connector/rdd/typeTests/VectorTypeTest.scala
new file mode 100644
index 000000000..9d3452746
--- /dev/null
+++ b/connector/src/it/scala/com/datastax/spark/connector/rdd/typeTests/VectorTypeTest.scala
@@ -0,0 +1,55 @@
+package com.datastax.spark.connector.rdd.typeTests
+
+import com.datastax.oss.driver.api.core.CqlSession
+import com.datastax.spark.connector._
+import com.datastax.spark.connector.cluster.DefaultCluster
+import com.datastax.spark.connector.cql.CassandraConnector
+import org.apache.spark.sql.cassandra.DataFrameWriterWrapper
+
+class VectorTypeTest extends SparkCassandraITFlatSpecBase with DefaultCluster
+{
+  override lazy val conn = CassandraConnector(sparkConf)
+
+  val VectorTable = "vectors"
+
+  def createVectorTable(session: CqlSession): Unit = {
+    session.execute(
+      s"""CREATE TABLE IF NOT EXISTS $ks.$VectorTable
+         |(id int PRIMARY KEY, v vector<int, 5>);""".stripMargin)
+  }
+
+  override def beforeClass {
+    conn.withSessionDo { session =>
+      session.execute(
+        s"""CREATE KEYSPACE IF NOT EXISTS $ks
+           |WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }"""
+          .stripMargin)
+      createVectorTable(session)
+    }
+  }
+
+  "SparkSql" should "write and read rows with vector columns" in {
+    spark.createDataFrame(Seq(VectorItem(1, Seq(1,2,3,4,5)), VectorItem(2, Seq(6,7,8,9,10))))
+      .write
+      .cassandraFormat(VectorTable, ks)
+      .mode("append")
+      .save()
+
+    val tupleRows = spark.sparkContext
+      .cassandraTable[(Int, Seq[Int])](ks, VectorTable)
+      .collect()
+      .toList
+
+    tupleRows should contain theSameElementsAs Seq((1, Seq(1,2,3,4,5)), (2, Seq(6,7,8,9,10)))
+
+    val rows = spark.sparkContext
+      .cassandraTable[VectorItem](ks, VectorTable)
+      .collect()
+      .toList
+
+    rows should contain theSameElementsAs Seq(VectorItem(1, Seq(1,2,3,4,5)), VectorItem(2, Seq(6,7,8,9,10)))
+  }
+
+}
+
+case class VectorItem(id: Int, v: Seq[Int])
diff --git a/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala b/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala
index 19764a0a6..657fab249 100644
--- a/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala
+++ b/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala
@@ -2,7 +2,7 @@ package com.datastax.spark.connector.datasource
 import java.util.Locale
 
 import com.datastax.oss.driver.api.core.ProtocolVersion
-import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType}
+import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType, VectorType}
 import com.datastax.oss.driver.api.core.`type`.DataTypes._
 import com.datastax.dse.driver.api.core.`type`.DseDataTypes._
 import com.datastax.oss.driver.api.core.metadata.schema.{ColumnMetadata, RelationMetadata, TableMetadata}
@@ -167,6 +167,7 @@ object CassandraSourceUtil extends Logging {
       case m: MapType => SparkSqlMapType(catalystDataType(m.getKeyType, nullable), catalystDataType(m.getValueType, nullable), nullable)
       case udt: UserDefinedType => fromUdt(udt)
       case t: TupleType => fromTuple(t)
+      case v: VectorType => ArrayType(catalystDataType(v.getElementType, nullable), nullable)
       case VARINT =>
         logWarning("VarIntType is mapped to catalystTypes.DecimalType with unlimited values.")
         primitiveCatalystDataType(cassandraType)
diff --git a/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala b/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala
index 1287bf23d..0f279215d 100644
--- a/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala
+++ b/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala
@@ -59,6 +59,7 @@ object DataTypeConverter extends Logging {
     cassandraType match {
       case connector.types.SetType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
       case connector.types.ListType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
+      case connector.types.VectorType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
       case connector.types.MapType(kt, vt, _) => catalystTypes.MapType(catalystDataType(kt, nullable), catalystDataType(vt, nullable), nullable)
       case connector.types.UserDefinedType(_, fields, _) => catalystTypes.StructType(fields.map(catalystStructField))
       case connector.types.TupleType(fields @ _* ) => catalystTypes.StructType(fields.map(catalystStructFieldFromTuple))
diff --git a/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala b/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala
index b4afa5da6..3c68b5f52 100644
--- a/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala
+++ b/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala
@@ -93,6 +93,10 @@ class GettableDataToMappedTypeConverter[T : TypeTag : ColumnMapper](
         val argConverter = converter(argColumnType, argScalaType)
         TypeConverter.forType[U](Seq(argConverter))
 
+      case (VectorType(argColumnType, _), TypeRef(_, _, List(argScalaType))) =>
+        val argConverter = converter(argColumnType, argScalaType)
+        TypeConverter.forType[U](Seq(argConverter))
+
       case (SetType(argColumnType, _), TypeRef(_, _, List(argScalaType))) =>
         val argConverter = converter(argColumnType, argScalaType)
         TypeConverter.forType[U](Seq(argConverter))
diff --git a/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala b/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala
index 3b69267ae..05f682f47 100644
--- a/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala
+++ b/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala
@@ -82,6 +82,10 @@ object MappedToGettableDataConverter extends Logging {
         val valueConverter = converter(valueColumnType, valueScalaType)
         TypeConverter.javaHashMapConverter(keyConverter, valueConverter)
 
+      case (VectorType(argColumnType, dimension), TypeRef(_, _, List(argScalaType))) =>
+        val argConverter = converter(argColumnType, argScalaType)
+        TypeConverter.cqlVectorConverter(dimension)(argConverter.asInstanceOf[TypeConverter[Number]])
+
       case (tt @ TupleType(argColumnType1, argColumnType2),
             TypeRef(_, Symbols.PairSymbol, List(argScalaType1, argScalaType2))) =>
         val c1 = converter(argColumnType1.columnType, argScalaType1)
diff --git a/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala b/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala
index 99a3f9fb3..b5aaf57fb 100644
--- a/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala
+++ b/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala
@@ -7,7 +7,7 @@ import java.util.{Date, UUID}
 import com.datastax.dse.driver.api.core.`type`.DseDataTypes
 import com.datastax.oss.driver.api.core.DefaultProtocolVersion.V4
 import com.datastax.oss.driver.api.core.ProtocolVersion
-import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes => DriverDataTypes, ListType => DriverListType, MapType => DriverMapType, SetType => DriverSetType, TupleType => DriverTupleType, UserDefinedType => DriverUserDefinedType}
+import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes => DriverDataTypes, ListType => DriverListType, MapType => DriverMapType, SetType => DriverSetType, TupleType => DriverTupleType, UserDefinedType => DriverUserDefinedType, VectorType => DriverVectorType}
 import com.datastax.spark.connector.util._
 
@@ -77,6 +77,7 @@ object ColumnType {
     case mapType: DriverMapType => MapType(fromDriverType(mapType.getKeyType), fromDriverType(mapType.getValueType), mapType.isFrozen)
     case userType: DriverUserDefinedType => UserDefinedType(userType)
     case tupleType: DriverTupleType => TupleType(tupleType)
+    case vectorType: DriverVectorType => VectorType(fromDriverType(vectorType.getElementType), vectorType.getDimensions)
     case dataType => primitiveTypeMap(dataType)
   }
 
@@ -153,6 +154,7 @@ object ColumnType {
     val converter: TypeConverter[_] = dataType match {
       case list: DriverListType => TypeConverter.javaArrayListConverter(converterToCassandra(list.getElementType))
+      case vec: DriverVectorType => TypeConverter.cqlVectorConverter(vec.getDimensions)(converterToCassandra(vec.getElementType).asInstanceOf[TypeConverter[Number]])
      case set: DriverSetType => TypeConverter.javaHashSetConverter(converterToCassandra(set.getElementType))
      case map: DriverMapType => TypeConverter.javaHashMapConverter(converterToCassandra(map.getKeyType), converterToCassandra(map.getValueType))
      case udt: DriverUserDefinedType => new UserDefinedType.DriverUDTValueConverter(udt)
diff --git a/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala b/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala
index 58615965b..ea13093e9 100644
--- a/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala
+++ b/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala
@@ -9,7 +9,7 @@ import java.util.{Calendar, Date, GregorianCalendar, TimeZone, UUID}
 import com.datastax.dse.driver.api.core.data.geometry.{LineString, Point, Polygon}
 import com.datastax.dse.driver.api.core.data.time.DateRange
-import com.datastax.oss.driver.api.core.data.CqlDuration
+import com.datastax.oss.driver.api.core.data.{CqlDuration, CqlVector}
 import com.datastax.spark.connector.TupleValue
 import com.datastax.spark.connector.UDTValue.UDTValueConverter
 import com.datastax.spark.connector.util.ByteBufferUtil
@@ -700,6 +700,7 @@ object TypeConverter {
       case x: java.util.List[_] => newCollection(x.asScala)
       case x: java.util.Set[_] => newCollection(x.asScala)
       case x: java.util.Map[_, _] => newCollection(x.asScala)
+      case x: CqlVector[_] => newCollection(x.asScala)
       case x: Iterable[_] => newCollection(x)
     }
   }
@@ -768,6 +769,29 @@ object TypeConverter {
     }
   }
 
+  class CqlVectorConverter[T <: Number : TypeConverter](dimension: Int) extends TypeConverter[CqlVector[T]] {
+    val elemConverter = implicitly[TypeConverter[T]]
+
+    implicit def elemTypeTag: TypeTag[T] = elemConverter.targetTypeTag
+
+    @transient
+    lazy val targetTypeTag = {
+      implicitly[TypeTag[CqlVector[T]]]
+    }
+
+    def newCollection(items: Iterable[Any]): java.util.List[T] = {
+      val buf = new java.util.ArrayList[T](dimension)
+      for (item <- items) buf.add(elemConverter.convert(item))
+      buf
+    }
+
+    def convertPF = {
+      case x: CqlVector[_] => x.asInstanceOf[CqlVector[T]] // optimization: pass the vector through as-is (should we convert the elements instead?)
+      case x: java.lang.Iterable[_] => CqlVector.newInstance[T](newCollection(x.asScala))
+      case x: Iterable[_] => CqlVector.newInstance[T](newCollection(x))
+    }
+  }
+
   class JavaArrayListConverter[T : TypeConverter] extends CollectionConverter[java.util.ArrayList[T], T] {
     @transient
     lazy val targetTypeTag = {
@@ -869,6 +893,9 @@ object TypeConverter {
   implicit def javaArrayListConverter[T : TypeConverter]: JavaArrayListConverter[T] =
     new JavaArrayListConverter[T]
 
+  implicit def cqlVectorConverter[T <: Number : TypeConverter](dimension: Int): CqlVectorConverter[T] =
+    new CqlVectorConverter[T](dimension)
+
   implicit def javaSetConverter[T : TypeConverter]: JavaSetConverter[T] =
     new JavaSetConverter[T]
diff --git a/driver/src/main/scala/com/datastax/spark/connector/types/VectorType.scala b/driver/src/main/scala/com/datastax/spark/connector/types/VectorType.scala
new file mode 100644
index 000000000..8060fe225
--- /dev/null
+++ b/driver/src/main/scala/com/datastax/spark/connector/types/VectorType.scala
@@ -0,0 +1,20 @@
+package com.datastax.spark.connector.types
+
+import scala.language.existentials
+import scala.reflect.runtime.universe._
+
+case class VectorType[T](elemType: ColumnType[T], dimension: Int) extends ColumnType[Seq[T]] {
+
+  override def isCollection: Boolean = false
+
+  @transient
+  lazy val scalaTypeTag = {
+    implicit val elemTypeTag = elemType.scalaTypeTag
+    implicitly[TypeTag[Seq[T]]]
+  }
+
+  def cqlTypeName = s"vector<${elemType.cqlTypeName}, ${dimension}>"
+
+  override def converterToCassandra: TypeConverter[_ <: AnyRef] =
+    new TypeConverter.OptionToNullConverter(TypeConverter.seqConverter(elemType.converterToCassandra))
+}
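
Not part of the patch: a minimal usage sketch of what the new CassandraSourceUtil/DataTypeConverter cases imply for Spark SQL reads. The keyspace and table ("ks"."embeddings", column emb of type vector<float, 2>) and the active SparkSession named spark are assumptions for illustration only; the patch maps a CQL vector column to a Catalyst array of its element type, so the following should hold.

// Illustrative only - assumes CREATE TABLE ks.embeddings (id int PRIMARY KEY, emb vector<float, 2>)
// and a SparkSession `spark` with the connector on the classpath.
import org.apache.spark.sql.types.{ArrayType, FloatType}

val df = spark.read
  .format("org.apache.spark.sql.cassandra")
  .options(Map("keyspace" -> "ks", "table" -> "embeddings"))
  .load()

// The vector column is exposed as an array of its element type.
val embType = df.schema("emb").dataType.asInstanceOf[ArrayType]
assert(embType.elementType == FloatType)

// Row values come back as Scala sequences on the Dataset/RDD side.
val firstVector: Seq[Float] = df.select("emb").head.getSeq[Float](0)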