@@ -20,18 +20,20 @@ import java.util.UUID
2020
2121import scala .collection .MapView
2222import scala .collection .immutable .HashMap
23- import scala .collection .mutable .HashSet
2423
2524import org .apache .hadoop .conf .Configuration
2625import org .apache .hadoop .fs .Path
2726
2827import org .apache .spark .sql .catalyst .InternalRow
2928import org .apache .spark .sql .catalyst .expressions .UnsafeRow
29+ import org .apache .spark .sql .execution .streaming .operators .stateful .transformwithstate .StateStoreColumnFamilySchemaUtils
3030import org .apache .spark .sql .execution .streaming .runtime .StreamingCheckpointConstants .DIR_NAME_STATE
31- import org .apache .spark .sql .execution .streaming .state .{StateStore , StateStoreColFamilySchema , StateStoreConf , StateStoreId , StateStoreProvider , StateStoreProviderId }
31+ import org .apache .spark .sql .execution .streaming .state .{NoPrefixKeyStateEncoderSpec , StateStore , StateStoreColFamilySchema , StateStoreConf , StateStoreId , StateStoreProvider , StateStoreProviderId }
32+ import org .apache .spark .sql .types .{NullType , StructField , StructType }
3233
3334case class StatePartitionWriterColumnFamilyInfo (
3435 schema : StateStoreColFamilySchema ,
36+ // set this to true if state variable is ListType in TransformWithState
3537 useMultipleValuesPerKey : Boolean = false )
3638/**
3739 * A writer that can directly write binary data to the streaming state store.
@@ -53,24 +55,34 @@ class StatePartitionAllColumnFamiliesWriter(
5355 storeName : String ,
5456 currentBatchId : Long ,
5557 columnFamilyToSchemaMap : HashMap [String , StatePartitionWriterColumnFamilyInfo ]) {
58+ private val dummySchema : StructType =
59+ StructType (Array (StructField (" __dummy__" , NullType )))
5660 private val defaultSchema = {
57- columnFamilyToSchemaMap.getOrElse(
58- StateStore .DEFAULT_COL_FAMILY_NAME ,
59- columnFamilyToSchemaMap.head._2 // join V3 doesn't have default col family
60- ).schema
61+ columnFamilyToSchemaMap.get(StateStore .DEFAULT_COL_FAMILY_NAME ) match {
62+ case Some (info) => info.schema
63+ case None =>
64+ // Return a dummy StateStoreColFamilySchema if not found
65+ StateStoreColFamilySchema (
66+ colFamilyName = " __dummy__" ,
67+ keySchemaId = 0 ,
68+ keySchema = dummySchema,
69+ valueSchemaId = 0 ,
70+ valueSchema = dummySchema,
71+ keyStateEncoderSpec = Option (NoPrefixKeyStateEncoderSpec (dummySchema)))
72+ }
6173 }
6274
6375 private val columnFamilyToKeySchemaLenMap : MapView [String , Int ] =
6476 columnFamilyToSchemaMap.view.mapValues(_.schema.keySchema.length)
6577 private val columnFamilyToValueSchemaLenMap : MapView [String , Int ] =
6678 columnFamilyToSchemaMap.view.mapValues(_.schema.valueSchema.length)
67- private val colFamilyHasWritten : HashSet [String ] = HashSet [String ]()
6879
6980 protected lazy val provider : StateStoreProvider = {
7081 val stateCheckpointLocation = new Path (targetCpLocation, DIR_NAME_STATE ).toString
7182 val stateStoreId = StateStoreId (stateCheckpointLocation,
7283 operatorId, partitionId, storeName)
7384 val stateStoreProviderId = StateStoreProviderId (stateStoreId, UUID .randomUUID())
85+
7486 val useColumnFamilies = columnFamilyToSchemaMap.size > 1
7587 val provider = StateStoreProvider .createAndInit(
7688 stateStoreProviderId, defaultSchema.keySchema, defaultSchema.valueSchema,
@@ -99,7 +111,7 @@ class StatePartitionAllColumnFamiliesWriter(
99111 colFamilyName match {
100112 case StateStore .DEFAULT_COL_FAMILY_NAME => // createAndInit has registered default
101113 case _ =>
102- val isInternal = colFamilyName.startsWith( " $ " )
114+ val isInternal = StateStoreColumnFamilySchemaUtils .isInternalColFamily(colFamilyName )
103115
104116 require(cfSchema.keyStateEncoderSpec.isDefined,
105117 s " keyStateEncoderSpec must be defined for column family ${cfSchema.colFamilyName}" )
@@ -125,7 +137,6 @@ class StatePartitionAllColumnFamiliesWriter(
125137 try {
126138 rows.foreach(row => writeRow(row))
127139 stateStore.commit()
128- colFamilyHasWritten.clear()
129140 } finally {
130141 if (! stateStore.hasCommitted) {
131142 stateStore.abort()
@@ -151,12 +162,12 @@ class StatePartitionAllColumnFamiliesWriter(
151162 val valueRow = new UnsafeRow (columnFamilyToValueSchemaLenMap(colFamilyName))
152163 valueRow.pointTo(valueBytes, valueBytes.length)
153164
154- if (columnFamilyToSchemaMap(colFamilyName).useMultipleValuesPerKey
155- && colFamilyHasWritten(colFamilyName)) {
165+ if (columnFamilyToSchemaMap(colFamilyName).useMultipleValuesPerKey) {
166+ // if a column family useMultipleValuesPerKey (e.g. ListType), we will
167+ // write with 1 put followed by merge
156168 stateStore.merge(keyRow, valueRow, colFamilyName)
157169 } else {
158170 stateStore.put(keyRow, valueRow, colFamilyName)
159- colFamilyHasWritten.add(colFamilyName)
160171 }
161172 }
162173}
0 commit comments