diff --git a/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala b/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala
index f30f09d0b..a3b30aef9 100644
--- a/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala
+++ b/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala
@@ -40,6 +40,7 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
| d14_varchar varchar,
| d15_varint varint,
 | d16_address frozen<address>,
+ | d17_vector frozen<vector<int, 3>>,
| PRIMARY KEY ((k1, k2, k3), c1, c2, c3)
|)
""".stripMargin)
@@ -111,12 +112,12 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
"allow to read regular column definitions" in {
val columns = table.regularColumns
- columns.size shouldBe 16
+ columns.size shouldBe 17
columns.map(_.columnName).toSet shouldBe Set(
"d1_blob", "d2_boolean", "d3_decimal", "d4_double", "d5_float",
"d6_inet", "d7_int", "d8_list", "d9_map", "d10_set",
"d11_timestamp", "d12_uuid", "d13_timeuuid", "d14_varchar",
- "d15_varint", "d16_address")
+ "d15_varint", "d16_address", "d17_vector")
}
"allow to read proper types of columns" in {
@@ -136,6 +137,7 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
table.columnByName("d14_varchar").columnType shouldBe VarCharType
table.columnByName("d15_varint").columnType shouldBe VarIntType
table.columnByName("d16_address").columnType shouldBe a [UserDefinedType]
+ table.columnByName("d17_vector").columnType shouldBe VectorType[Int](IntType, 3)
}
"allow to list fields of a user defined type" in {
diff --git a/connector/src/it/scala/com/datastax/spark/connector/rdd/typeTests/VectorTypeTest.scala b/connector/src/it/scala/com/datastax/spark/connector/rdd/typeTests/VectorTypeTest.scala
new file mode 100644
index 000000000..9d3452746
--- /dev/null
+++ b/connector/src/it/scala/com/datastax/spark/connector/rdd/typeTests/VectorTypeTest.scala
@@ -0,0 +1,55 @@
+package com.datastax.spark.connector.rdd.typeTests
+
+import com.datastax.oss.driver.api.core.CqlSession
+import com.datastax.spark.connector._
+import com.datastax.spark.connector.cluster.DefaultCluster
+import com.datastax.spark.connector.cql.CassandraConnector
+import org.apache.spark.sql.cassandra.DataFrameWriterWrapper
+
+class VectorTypeTest extends SparkCassandraITFlatSpecBase with DefaultCluster
+{
+ override lazy val conn = CassandraConnector(sparkConf)
+
+ val VectorTable = "vectors"
+
+ def createVectorTable(session: CqlSession): Unit = {
+ session.execute(
+ s"""CREATE TABLE IF NOT EXISTS $ks.$VectorTable
+ |(id int PRIMARY KEY, v vector<int, 5>);""".stripMargin)
+ }
+
+ override def beforeClass {
+ conn.withSessionDo { session =>
+ session.execute(
+ s"""CREATE KEYSPACE IF NOT EXISTS $ks
+ |WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }"""
+ .stripMargin)
+ createVectorTable(session)
+ }
+ }
+
+ "SparkSql" should "write case classes and tuples with vector elements" in {
+ spark.createDataFrame(Seq(VectorItem(1, Seq(1,2,3,4,5)), VectorItem(2, Seq(6,7,8,9,10))))
+ .write
+ .cassandraFormat(VectorTable, ks)
+ .mode("append")
+ .save()
+
+ val tupleRows = spark.sparkContext
+ .cassandraTable[(Int, Seq[Int])](ks, VectorTable)
+ .collect()
+ .toList
+
+ tupleRows should contain theSameElementsAs Seq((1, Seq(1,2,3,4,5)), (2, Seq(6,7,8,9,10)))
+
+ val rows = spark.sparkContext
+ .cassandraTable[VectorItem](ks, VectorTable)
+ .collect()
+ .toList
+
+ rows should contain theSameElementsAs Seq(VectorItem(1, Seq(1,2,3,4,5)), VectorItem(2, Seq(6,7,8,9,10)))
+ }
+
+}
+
+case class VectorItem(id: Int, v: Seq[Int])
diff --git a/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala b/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala
index 19764a0a6..657fab249 100644
--- a/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala
+++ b/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala
@@ -2,7 +2,7 @@ package com.datastax.spark.connector.datasource
import java.util.Locale
import com.datastax.oss.driver.api.core.ProtocolVersion
-import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType}
+import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType, VectorType}
import com.datastax.oss.driver.api.core.`type`.DataTypes._
import com.datastax.dse.driver.api.core.`type`.DseDataTypes._
import com.datastax.oss.driver.api.core.metadata.schema.{ColumnMetadata, RelationMetadata, TableMetadata}
@@ -167,6 +167,7 @@ object CassandraSourceUtil extends Logging {
case m: MapType => SparkSqlMapType(catalystDataType(m.getKeyType, nullable), catalystDataType(m.getValueType, nullable), nullable)
case udt: UserDefinedType => fromUdt(udt)
case t: TupleType => fromTuple(t)
+ case v: VectorType => ArrayType(catalystDataType(v.getElementType, nullable), nullable)
case VARINT =>
logWarning("VarIntType is mapped to catalystTypes.DecimalType with unlimited values.")
primitiveCatalystDataType(cassandraType)
diff --git a/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala b/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala
index 1287bf23d..0f279215d 100644
--- a/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala
+++ b/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala
@@ -59,6 +59,7 @@ object DataTypeConverter extends Logging {
cassandraType match {
case connector.types.SetType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
case connector.types.ListType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
+ case connector.types.VectorType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
case connector.types.MapType(kt, vt, _) => catalystTypes.MapType(catalystDataType(kt, nullable), catalystDataType(vt, nullable), nullable)
case connector.types.UserDefinedType(_, fields, _) => catalystTypes.StructType(fields.map(catalystStructField))
case connector.types.TupleType(fields @ _* ) => catalystTypes.StructType(fields.map(catalystStructFieldFromTuple))
diff --git a/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala b/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala
index b4afa5da6..3c68b5f52 100644
--- a/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala
+++ b/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala
@@ -93,6 +93,10 @@ class GettableDataToMappedTypeConverter[T : TypeTag : ColumnMapper](
val argConverter = converter(argColumnType, argScalaType)
TypeConverter.forType[U](Seq(argConverter))
+ case (VectorType(argColumnType, _), TypeRef(_, _, List(argScalaType))) =>
+ val argConverter = converter(argColumnType, argScalaType)
+ TypeConverter.forType[U](Seq(argConverter))
+
case (SetType(argColumnType, _), TypeRef(_, _, List(argScalaType))) =>
val argConverter = converter(argColumnType, argScalaType)
TypeConverter.forType[U](Seq(argConverter))
diff --git a/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala b/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala
index 3b69267ae..05f682f47 100644
--- a/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala
+++ b/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala
@@ -82,6 +82,10 @@ object MappedToGettableDataConverter extends Logging {
val valueConverter = converter(valueColumnType, valueScalaType)
TypeConverter.javaHashMapConverter(keyConverter, valueConverter)
+ case (VectorType(argColumnType, dimension), TypeRef(_, _, List(argScalaType))) =>
+ val argConverter = converter(argColumnType, argScalaType)
+ TypeConverter.cqlVectorConverter(dimension)(argConverter.asInstanceOf[TypeConverter[Number]])
+
case (tt @ TupleType(argColumnType1, argColumnType2),
TypeRef(_, Symbols.PairSymbol, List(argScalaType1, argScalaType2))) =>
val c1 = converter(argColumnType1.columnType, argScalaType1)
diff --git a/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala b/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala
index 99a3f9fb3..b5aaf57fb 100644
--- a/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala
+++ b/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala
@@ -7,7 +7,7 @@ import java.util.{Date, UUID}
import com.datastax.dse.driver.api.core.`type`.DseDataTypes
import com.datastax.oss.driver.api.core.DefaultProtocolVersion.V4
import com.datastax.oss.driver.api.core.ProtocolVersion
-import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes => DriverDataTypes, ListType => DriverListType, MapType => DriverMapType, SetType => DriverSetType, TupleType => DriverTupleType, UserDefinedType => DriverUserDefinedType}
+import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes => DriverDataTypes, ListType => DriverListType, MapType => DriverMapType, SetType => DriverSetType, TupleType => DriverTupleType, UserDefinedType => DriverUserDefinedType, VectorType => DriverVectorType}
import com.datastax.spark.connector.util._
@@ -77,6 +77,7 @@ object ColumnType {
case mapType: DriverMapType => MapType(fromDriverType(mapType.getKeyType), fromDriverType(mapType.getValueType), mapType.isFrozen)
case userType: DriverUserDefinedType => UserDefinedType(userType)
case tupleType: DriverTupleType => TupleType(tupleType)
+ case vectorType: DriverVectorType => VectorType(fromDriverType(vectorType.getElementType), vectorType.getDimensions)
case dataType => primitiveTypeMap(dataType)
}
@@ -153,6 +154,7 @@ object ColumnType {
val converter: TypeConverter[_] =
dataType match {
case list: DriverListType => TypeConverter.javaArrayListConverter(converterToCassandra(list.getElementType))
+ case vec: DriverVectorType => TypeConverter.cqlVectorConverter(vec.getDimensions)(converterToCassandra(vec.getElementType).asInstanceOf[TypeConverter[Number]])
case set: DriverSetType => TypeConverter.javaHashSetConverter(converterToCassandra(set.getElementType))
case map: DriverMapType => TypeConverter.javaHashMapConverter(converterToCassandra(map.getKeyType), converterToCassandra(map.getValueType))
case udt: DriverUserDefinedType => new UserDefinedType.DriverUDTValueConverter(udt)
diff --git a/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala b/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala
index 58615965b..ea13093e9 100644
--- a/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala
+++ b/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala
@@ -9,7 +9,7 @@ import java.util.{Calendar, Date, GregorianCalendar, TimeZone, UUID}
import com.datastax.dse.driver.api.core.data.geometry.{LineString, Point, Polygon}
import com.datastax.dse.driver.api.core.data.time.DateRange
-import com.datastax.oss.driver.api.core.data.CqlDuration
+import com.datastax.oss.driver.api.core.data.{CqlDuration, CqlVector}
import com.datastax.spark.connector.TupleValue
import com.datastax.spark.connector.UDTValue.UDTValueConverter
import com.datastax.spark.connector.util.ByteBufferUtil
@@ -700,6 +700,7 @@ object TypeConverter {
case x: java.util.List[_] => newCollection(x.asScala)
case x: java.util.Set[_] => newCollection(x.asScala)
case x: java.util.Map[_, _] => newCollection(x.asScala)
+ case x: CqlVector[_] => newCollection(x.asScala)
case x: Iterable[_] => newCollection(x)
}
}
@@ -768,6 +769,29 @@ object TypeConverter {
}
}
+ class CqlVectorConverter[T <: Number : TypeConverter](dimension: Int) extends TypeConverter[CqlVector[T]] {
+ val elemConverter = implicitly[TypeConverter[T]]
+
+ implicit def elemTypeTag: TypeTag[T] = elemConverter.targetTypeTag
+
+ @transient
+ lazy val targetTypeTag = {
+ implicitly[TypeTag[CqlVector[T]]]
+ }
+
+ def newCollection(items: Iterable[Any]): java.util.List[T] = {
+ val buf = new java.util.ArrayList[T](dimension)
+ for (item <- items) buf.add(elemConverter.convert(item))
+ buf
+ }
+
+ def convertPF = {
+ case x: CqlVector[_] => x.asInstanceOf[CqlVector[T]] // it is an optimization - should we skip converting the elements?
+ case x: java.lang.Iterable[_] => CqlVector.newInstance[T](newCollection(x.asScala))
+ case x: Iterable[_] => CqlVector.newInstance[T](newCollection(x))
+ }
+ }
+
class JavaArrayListConverter[T : TypeConverter] extends CollectionConverter[java.util.ArrayList[T], T] {
@transient
lazy val targetTypeTag = {
@@ -869,6 +893,9 @@ object TypeConverter {
implicit def javaArrayListConverter[T : TypeConverter]: JavaArrayListConverter[T] =
new JavaArrayListConverter[T]
+ implicit def cqlVectorConverter[T <: Number : TypeConverter](dimension: Int): CqlVectorConverter[T] =
+ new CqlVectorConverter[T](dimension)
+
implicit def javaSetConverter[T : TypeConverter]: JavaSetConverter[T] =
new JavaSetConverter[T]
diff --git a/driver/src/main/scala/com/datastax/spark/connector/types/VectorType.scala b/driver/src/main/scala/com/datastax/spark/connector/types/VectorType.scala
new file mode 100644
index 000000000..8060fe225
--- /dev/null
+++ b/driver/src/main/scala/com/datastax/spark/connector/types/VectorType.scala
@@ -0,0 +1,20 @@
+package com.datastax.spark.connector.types
+
+import scala.language.existentials
+import scala.reflect.runtime.universe._
+
+case class VectorType[T](elemType: ColumnType[T], dimension: Int) extends ColumnType[Seq[T]] {
+
+ override def isCollection: Boolean = false
+
+ @transient
+ lazy val scalaTypeTag = {
+ implicit val elemTypeTag = elemType.scalaTypeTag
+ implicitly[TypeTag[Seq[T]]]
+ }
+
+ def cqlTypeName = s"vector<${elemType.cqlTypeName}, ${dimension}>"
+
+ override def converterToCassandra: TypeConverter[_ <: AnyRef] =
+ new TypeConverter.OptionToNullConverter(TypeConverter.seqConverter(elemType.converterToCassandra))
+}