@@ -59,21 +59,21 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
5959 * @param columnFamilyToSchemaMap Map of column family names to their schemas
6060 * @param storeName Optional store name (for stream-stream join which has multiple stores)
6161 * @param columnFamilyToSelectExprs Map of column family names to custom selectExprs
62- * @param columnFamilyToReaderOptions Map of column family names to reader options
62+ * @param columnFamilyToStateSourceOptions Map of column family names to state source options
6363 */
6464 private def performRoundTripTest (
6565 sourceDir : String ,
6666 targetDir : String ,
6767 columnFamilyToSchemaMap : HashMap [String , StatePartitionWriterColumnFamilyInfo ],
6868 storeName : Option [String ] = None ,
6969 columnFamilyToSelectExprs : Map [String , Seq [String ]] = Map .empty,
70- columnFamilyToReaderOptions : Map [String , Map [String , String ]] = Map .empty): Unit = {
70+ columnFamilyToStateSourceOptions : Map [String , Map [String , String ]] = Map .empty): Unit = {
7171
7272 // Determine column families to validate based on storeName and map size
7373 val columnFamiliesToValidate : Seq [String ] = storeName match {
7474 case Some (name) => Seq (name)
7575 case None if columnFamilyToSchemaMap.size > 1 => columnFamilyToSchemaMap.keys.toSeq
76- case None => Seq .empty // Will use default reader without store name filter
76+ case None => Seq ( StateStoreId . DEFAULT_STORE_NAME )
7777 }
7878
7979 // Step 1: Read from source using AllColumnFamiliesReader (raw bytes)
@@ -119,7 +119,8 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
119119
120120 // Use per-column-family converters when there are multiple column families
121121 if (columnFamilyToSchemaMap.size > 1 ) {
122- // can remove this once we have extractKeySchema landed
122+ // TODO: Remove the logic of getting colNameToRowConverter once allColumnFamiliesReader is
123+ // returning actual partitionKeySchema instead of the entire key
123124 val colNameToRowConverter = columnFamilyToSchemaMap.view.mapValues { colInfo =>
124125 val cfSchema = SchemaUtil .getScanAllColumnFamiliesSchema(colInfo.schema.keySchema)
125126 CatalystTypeConverters .createToCatalystConverter(cfSchema)
@@ -134,7 +135,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
134135 }
135136 }
136137
137- // Write raw bytes to target using foreachPartition"
138+ // Write raw bytes to target using foreachPartition
138139 sourceBytesData.foreachPartition(putPartitionFunc)
139140
140141 // Commit to commitLog
@@ -151,50 +152,33 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
151152
152153 def shouldCheckColumnFamilyName : String => Boolean = name => {
153154 (! name.startsWith(" $" )
154- || (columnFamilyToReaderOptions .contains(name) &&
155- columnFamilyToReaderOptions (name).contains(StateSourceOptions .READ_REGISTERED_TIMERS )))
155+ || (columnFamilyToStateSourceOptions .contains(name) &&
156+ columnFamilyToStateSourceOptions (name).contains(StateSourceOptions .READ_REGISTERED_TIMERS )))
156157 }
157- if (columnFamiliesToValidate.nonEmpty) {
158- // Validate each column family separately (skip internal column families starting with $)
159- columnFamiliesToValidate
160- .filter(shouldCheckColumnFamilyName)
161- .foreach { cfName =>
162- val selectExprs = columnFamilyToSelectExprs.getOrElse(cfName, defaultSelectExprs)
163- val readerOptions = columnFamilyToReaderOptions.getOrElse(cfName, Map .empty)
164- def readNormalData (dir : String ): Array [Row ] = {
165- var reader = spark.read
166- .format(" statestore" )
167- .option(StateSourceOptions .PATH , dir)
168- .option(StateSourceOptions .STORE_NAME , storeName.orNull)
169- readerOptions.foreach { case (k, v) => reader = reader.option(k, v) }
170- reader.load()
171- .selectExpr(selectExprs : _* )
172- .collect()
158+ // Validate each column family separately (skip internal column families starting with $)
159+ columnFamiliesToValidate
160+ // TODO: How to validate that internal columns are written correctly?
161+ .filter(shouldCheckColumnFamilyName)
162+ .foreach { cfName =>
163+ val selectExprs = columnFamilyToSelectExprs.getOrElse(cfName, defaultSelectExprs)
164+ val readerOptions = columnFamilyToStateSourceOptions.getOrElse(cfName, Map .empty)
165+
166+ def readNormalData (dir : String ): Array [Row ] = {
167+ var reader = spark.read
168+ .format(" statestore" )
169+ .option(StateSourceOptions .PATH , dir)
170+ .option(StateSourceOptions .STORE_NAME , storeName.orNull)
171+ readerOptions.foreach { case (k, v) => reader = reader.option(k, v) }
172+ reader.load()
173+ .selectExpr(selectExprs : _* )
174+ .collect()
173175 }
174176
175177 val sourceNormalData = readNormalData(sourceDir)
176178 val targetNormalData = readNormalData(targetDir)
177179
178180 validateDataMatches(sourceNormalData, targetNormalData)
179181 }
180- } else {
181- // Default validation without store name filter
182- val sourceNormalData = spark.read
183- .format(" statestore" )
184- .option(StateSourceOptions .PATH , sourceDir)
185- .load()
186- .selectExpr(" key" , " value" , " partition_id" )
187- .collect()
188-
189- val targetNormalData = spark.read
190- .format(" statestore" )
191- .option(StateSourceOptions .PATH , targetDir)
192- .load()
193- .selectExpr(" key" , " value" , " partition_id" )
194- .collect()
195-
196- validateDataMatches(sourceNormalData, targetNormalData)
197- }
198182 }
199183
200184 /**
@@ -207,7 +191,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
207191 s " Row count mismatch: source= ${sourceNormalData.length}, " +
208192 s " target= ${targetNormalData.length}" )
209193
210- // Sort and compare row by row"
194+ // Sort and compare row by row
211195 val sourceSorted = sourceNormalData.sortBy(_.toString)
212196 val targetSorted = targetNormalData.sortBy(_.toString)
213197
@@ -410,6 +394,35 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
410394 }
411395 }
412396
397+ private val keyToNumValuesColFamilyNames = Seq (" left-keyToNumValues" , " right-keyToNumValues" )
398+ private val keyWithIndexToValueColFamilyNames = Seq (
399+ " left-keyWithIndexToValue" , " right-keyWithIndexToValue" )
400+
401+ private def getJoinV3ColumnSchemaMap (): HashMap [String , StatePartitionWriterColumnFamilyInfo ] = {
402+ val keyToNumValuesKeySchema = StructType (Array (StructField (" key" , IntegerType )))
403+ val keyToNumValuesValueSchema = StructType (Array (StructField (" value" , LongType )))
404+ val keyToNumValuesEncoderSpec = NoPrefixKeyStateEncoderSpec (keyToNumValuesKeySchema)
405+
406+ val keyWithIndexKeySchema = StructType (Array (
407+ StructField (" key" , IntegerType , nullable = false ),
408+ StructField (" index" , LongType )
409+ ))
410+ val keyWithIndexValueSchema = StructType (Array (
411+ StructField (" value" , IntegerType , nullable = false ),
412+ StructField (" time" , TimestampType , nullable = false ),
413+ StructField (" matched" , BooleanType )
414+ ))
415+ val keyWithIndexEncoderSpec = NoPrefixKeyStateEncoderSpec (keyWithIndexKeySchema)
416+
417+ // Build column family to schema map for all 4 join stores
418+ keyToNumValuesColFamilyNames.map { name =>
419+ createSingleColumnFamilySchemaMap(
420+ keyToNumValuesKeySchema, keyToNumValuesValueSchema, keyToNumValuesEncoderSpec, name)
421+ }.reduce(_ ++ _) ++ keyWithIndexToValueColFamilyNames.map { name =>
422+ createSingleColumnFamilySchemaMap(
423+ keyWithIndexKeySchema, keyWithIndexValueSchema, keyWithIndexEncoderSpec, name)
424+ }.reduce(_ ++ _)
425+ }
413426 /**
414427 * Helper method to test round-trip for stream-stream join with different versions.
415428 */
@@ -505,48 +518,20 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
505518 )
506519 }
507520
508- runQuery(sourceDir.getAbsolutePath, 2 )
509- runQuery(targetDir.getAbsolutePath, 1 )
510-
511- // Define schemas for V3 join state stores
512- val keyToNumValuesKeySchema = StructType (Array (StructField (" key" , IntegerType )))
513- val keyToNumValuesValueSchema = StructType (Array (StructField (" value" , LongType )))
514- val keyToNumValuesEncoderSpec = NoPrefixKeyStateEncoderSpec (keyToNumValuesKeySchema)
521+ // varying the roundsOfData so that sourceDir and targetDir have different state
522+ runQuery(sourceDir.getAbsolutePath, roundsOfData = 2 )
523+ runQuery(targetDir.getAbsolutePath, roundsOfData = 1 )
515524
516- val keyWithIndexKeySchema = StructType (Array (
517- StructField (" key" , IntegerType , nullable = false ),
518- StructField (" index" , LongType )
519- ))
520- val keyWithIndexValueSchema = StructType (Array (
521- StructField (" value" , IntegerType , nullable = false ),
522- StructField (" time" , TimestampType , nullable = false ),
523- StructField (" matched" , BooleanType )
524- ))
525- val keyWithIndexEncoderSpec = NoPrefixKeyStateEncoderSpec (keyWithIndexKeySchema)
526-
527- val keyToNumValuesColFamilyNames = Seq (" left-keyToNumValues" , " right-keyToNumValues" )
528- val keyWithIndexToValueColFamilyNames = Seq (
529- " left-keyWithIndexToValue" , " right-keyWithIndexToValue" )
530- // Build column family to schema map for all 4 join stores
531- val columnFamilyToSchemaMap =
532- keyToNumValuesColFamilyNames.map { name =>
533- createSingleColumnFamilySchemaMap(
534- keyToNumValuesKeySchema, keyToNumValuesValueSchema, keyToNumValuesEncoderSpec, name)
535- }.reduce(_ ++ _) ++ keyWithIndexToValueColFamilyNames.map { name =>
536- createSingleColumnFamilySchemaMap(
537- keyWithIndexKeySchema, keyWithIndexValueSchema, keyWithIndexEncoderSpec, name)
538- }.reduce(_ ++ _)
539- val columnFamilyToReaderOptions =
540- (keyToNumValuesColFamilyNames ++ keyWithIndexToValueColFamilyNames).map {
541- colName =>
542- colName -> Map (StateSourceOptions .STORE_NAME -> colName)
543- }.toMap
544525 // Perform round-trip test using common helper
545526 performRoundTripTest(
546527 sourceDir.getAbsolutePath,
547528 targetDir.getAbsolutePath,
548- columnFamilyToSchemaMap,
549- columnFamilyToReaderOptions = columnFamilyToReaderOptions
529+ getJoinV3ColumnSchemaMap(),
530+ columnFamilyToStateSourceOptions =
531+ (keyToNumValuesColFamilyNames ++ keyWithIndexToValueColFamilyNames).map {
532+ colName =>
533+ colName -> Map (StateSourceOptions .STORE_NAME -> colName)
534+ }.toMap
550535 )
551536 }
552537 }
@@ -690,7 +675,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
690675 )
691676
692677 // Define reader options for column families that need them
693- val columnFamilyToReaderOptions = Map (
678+ val columnFamilyToStateSourceOptions = Map (
694679 " itemsList" -> Map (StateSourceOptions .FLATTEN_COLLECTION_TYPES -> " true" ,
695680 StateSourceOptions .STATE_VAR_NAME -> " itemsList" ),
696681 " itemsMap" -> Map (StateSourceOptions .STATE_VAR_NAME -> " itemsMap" ),
@@ -703,7 +688,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
703688 targetDir.getAbsolutePath,
704689 columnFamilyToSchemaMap,
705690 columnFamilyToSelectExprs = columnFamilyToSelectExprs,
706- columnFamilyToReaderOptions = columnFamilyToReaderOptions
691+ columnFamilyToStateSourceOptions = columnFamilyToStateSourceOptions
707692 )
708693 }
709694 }
@@ -787,7 +772,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
787772 )
788773
789774 // Timer column families need READ_REGISTERED_TIMERS option
790- val columnFamilyToReaderOptions = Map (
775+ val columnFamilyToStateSourceOptions = Map (
791776 " countState" -> Map (StateSourceOptions .STATE_VAR_NAME -> " countState" ),
792777 " $eventTimers_keyToTimestamp" -> Map (
793778 StateSourceOptions .READ_REGISTERED_TIMERS -> " true" ),
@@ -800,7 +785,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
800785 targetDir.getAbsolutePath,
801786 columnFamilyToSchemaMap,
802787 columnFamilyToSelectExprs = columnFamilyToSelectExprs,
803- columnFamilyToReaderOptions = columnFamilyToReaderOptions
788+ columnFamilyToStateSourceOptions = columnFamilyToStateSourceOptions
804789 )
805790 }
806791 }
@@ -887,7 +872,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
887872 )
888873
889874 // Timer column families need READ_REGISTERED_TIMERS option
890- val columnFamilyToReaderOptions = Map (
875+ val columnFamilyToStateSourceOptions = Map (
891876 " countState" -> Map (StateSourceOptions .STATE_VAR_NAME -> " countState" ),
892877 " $procTimers_keyToTimestamp" -> Map (
893878 StateSourceOptions .READ_REGISTERED_TIMERS -> " true" ),
@@ -900,7 +885,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
900885 targetDir.getAbsolutePath,
901886 columnFamilyToSchemaMap,
902887 columnFamilyToSelectExprs = columnFamilyToSelectExprs,
903- columnFamilyToReaderOptions = columnFamilyToReaderOptions
888+ columnFamilyToStateSourceOptions = columnFamilyToStateSourceOptions
904889 )
905890 }
906891 }
@@ -987,7 +972,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
987972 " listState" -> Seq (" key" , " list_element AS value" , " partition_id" )
988973 )
989974
990- val columnFamilyToReaderOptions = Map (
975+ val columnFamilyToStateSourceOptions = Map (
991976 " listState" -> Map (
992977 StateSourceOptions .STATE_VAR_NAME -> " listState" ,
993978 StateSourceOptions .FLATTEN_COLLECTION_TYPES -> " true" )
@@ -998,7 +983,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
998983 targetDir.getAbsolutePath,
999984 columnFamilyToSchemaMap,
1000985 columnFamilyToSelectExprs = columnFamilyToSelectExprs,
1001- columnFamilyToReaderOptions = columnFamilyToReaderOptions
986+ columnFamilyToStateSourceOptions = columnFamilyToStateSourceOptions
1002987 )
1003988 }
1004989 }
0 commit comments