Commit 67d8600

committed
self-review
1 parent 766ce6f

File tree: 3 files changed, +99 additions, -99 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionWriter.scala

Lines changed: 23 additions & 12 deletions
@@ -20,18 +20,20 @@ import java.util.UUID
 
 import scala.collection.MapView
 import scala.collection.immutable.HashMap
-import scala.collection.mutable.HashSet
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.StateStoreColumnFamilySchemaUtils
 import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE
-import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProvider, StateStoreProviderId}
+import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProvider, StateStoreProviderId}
+import org.apache.spark.sql.types.{NullType, StructField, StructType}
 
 case class StatePartitionWriterColumnFamilyInfo(
     schema: StateStoreColFamilySchema,
+    // set this to true if state variable is ListType in TransformWithState
     useMultipleValuesPerKey: Boolean = false)
 /**
  * A writer that can directly write binary data to the streaming state store.
@@ -53,24 +55,34 @@ class StatePartitionAllColumnFamiliesWriter(
     storeName: String,
     currentBatchId: Long,
     columnFamilyToSchemaMap: HashMap[String, StatePartitionWriterColumnFamilyInfo]) {
+  private val dummySchema: StructType =
+    StructType(Array(StructField("__dummy__", NullType)))
   private val defaultSchema = {
-    columnFamilyToSchemaMap.getOrElse(
-      StateStore.DEFAULT_COL_FAMILY_NAME,
-      columnFamilyToSchemaMap.head._2 // join V3 doesn't have default col family
-    ).schema
+    columnFamilyToSchemaMap.get(StateStore.DEFAULT_COL_FAMILY_NAME) match {
+      case Some(info) => info.schema
+      case None =>
+        // Return a dummy StateStoreColFamilySchema if not found
+        StateStoreColFamilySchema(
+          colFamilyName = "__dummy__",
+          keySchemaId = 0,
+          keySchema = dummySchema,
+          valueSchemaId = 0,
+          valueSchema = dummySchema,
+          keyStateEncoderSpec = Option(NoPrefixKeyStateEncoderSpec(dummySchema)))
+    }
   }
 
   private val columnFamilyToKeySchemaLenMap: MapView[String, Int] =
     columnFamilyToSchemaMap.view.mapValues(_.schema.keySchema.length)
   private val columnFamilyToValueSchemaLenMap: MapView[String, Int] =
     columnFamilyToSchemaMap.view.mapValues(_.schema.valueSchema.length)
-  private val colFamilyHasWritten: HashSet[String] = HashSet[String]()
 
   protected lazy val provider: StateStoreProvider = {
     val stateCheckpointLocation = new Path(targetCpLocation, DIR_NAME_STATE).toString
     val stateStoreId = StateStoreId(stateCheckpointLocation,
       operatorId, partitionId, storeName)
     val stateStoreProviderId = StateStoreProviderId(stateStoreId, UUID.randomUUID())
+
     val useColumnFamilies = columnFamilyToSchemaMap.size > 1
     val provider = StateStoreProvider.createAndInit(
       stateStoreProviderId, defaultSchema.keySchema, defaultSchema.valueSchema,
@@ -99,7 +111,7 @@ class StatePartitionAllColumnFamiliesWriter(
     colFamilyName match {
       case StateStore.DEFAULT_COL_FAMILY_NAME => // createAndInit has registered default
       case _ =>
-        val isInternal = colFamilyName.startsWith("$")
+        val isInternal = StateStoreColumnFamilySchemaUtils.isInternalColFamily(colFamilyName)
 
         require(cfSchema.keyStateEncoderSpec.isDefined,
           s"keyStateEncoderSpec must be defined for column family ${cfSchema.colFamilyName}")
@@ -125,7 +137,6 @@ class StatePartitionAllColumnFamiliesWriter(
     try {
       rows.foreach(row => writeRow(row))
      stateStore.commit()
-      colFamilyHasWritten.clear()
    } finally {
      if (!stateStore.hasCommitted) {
        stateStore.abort()
@@ -151,12 +162,12 @@ class StatePartitionAllColumnFamiliesWriter(
    val valueRow = new UnsafeRow(columnFamilyToValueSchemaLenMap(colFamilyName))
    valueRow.pointTo(valueBytes, valueBytes.length)
 
-    if (columnFamilyToSchemaMap(colFamilyName).useMultipleValuesPerKey
-      && colFamilyHasWritten(colFamilyName)) {
+    if (columnFamilyToSchemaMap(colFamilyName).useMultipleValuesPerKey) {
+      // if a column family useMultipleValuesPerKey (e.g. ListType), we will
+      // write with 1 put followed by merge
      stateStore.merge(keyRow, valueRow, colFamilyName)
    } else {
      stateStore.put(keyRow, valueRow, colFamilyName)
-      colFamilyHasWritten.add(colFamilyName)
    }
  }
}
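
Note on the last hunk: put overwrites the value stored under a key, while merge appends to it, so a column family registered with useMultipleValuesPerKey = true (e.g. list state in TransformWithState) can be rebuilt by replaying every (key, value) pair through merge, with no need for the removed colFamilyHasWritten bookkeeping. A minimal sketch of that dispatch, using only the StateStore.put/merge calls visible in the diff (the method name and parameter list here are illustrative, not part of the commit):

// Illustrative sketch only: mirrors the writer's per-row dispatch above.
def writeRowSketch(
    store: StateStore,
    colFamilyName: String,
    useMultipleValuesPerKey: Boolean,
    keyRow: UnsafeRow,
    valueRow: UnsafeRow): Unit = {
  if (useMultipleValuesPerKey) {
    // merge appends valueRow under keyRow instead of overwriting, so
    // replaying all (key, value) pairs reproduces multi-value state.
    store.merge(keyRow, valueRow, colFamilyName)
  } else {
    // put overwrites: correct when each key holds exactly one value.
    store.put(keyRow, valueRow, colFamilyName)
  }
}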

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateStoreColumnFamilySchemaUtils.scala

Lines changed: 4 additions & 0 deletions
@@ -216,4 +216,8 @@ object StateStoreColumnFamilySchemaUtils {
       valSchema,
       Some(RangeKeyScanStateEncoderSpec(keySchema, Seq(0))))
   }
+
+  def isInternalColFamily(name: String): Boolean = {
+    name.startsWith("$")
+  }
 }
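
The new helper centralizes the convention, previously hard-coded in the writer, that internal column family names start with "$". For example (column family names taken from the test suite below):

StateStoreColumnFamilySchemaUtils.isInternalColFamily("$procTimers_keyToTimestamp") // true
StateStoreColumnFamilySchemaUtils.isInternalColFamily("countState") // false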

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesWriterSuite.scala

Lines changed: 72 additions & 87 deletions
@@ -59,21 +59,21 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
    * @param columnFamilyToSchemaMap Map of column family names to their schemas
    * @param storeName Optional store name (for stream-stream join which has multiple stores)
    * @param columnFamilyToSelectExprs Map of column family names to custom selectExprs
-   * @param columnFamilyToReaderOptions Map of column family names to reader options
+   * @param columnFamilyToStateSourceOptions Map of column family names to state source options
    */
   private def performRoundTripTest(
       sourceDir: String,
       targetDir: String,
       columnFamilyToSchemaMap: HashMap[String, StatePartitionWriterColumnFamilyInfo],
       storeName: Option[String] = None,
       columnFamilyToSelectExprs: Map[String, Seq[String]] = Map.empty,
-      columnFamilyToReaderOptions: Map[String, Map[String, String]] = Map.empty): Unit = {
+      columnFamilyToStateSourceOptions: Map[String, Map[String, String]] = Map.empty): Unit = {
 
     // Determine column families to validate based on storeName and map size
     val columnFamiliesToValidate: Seq[String] = storeName match {
       case Some(name) => Seq(name)
       case None if columnFamilyToSchemaMap.size > 1 => columnFamilyToSchemaMap.keys.toSeq
-      case None => Seq.empty // Will use default reader without store name filter
+      case None => Seq(StateStoreId.DEFAULT_STORE_NAME)
     }
 
     // Step 1: Read from source using AllColumnFamiliesReader (raw bytes)
@@ -119,7 +119,8 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
 
     // Use per-column-family converters when there are multiple column families
     if (columnFamilyToSchemaMap.size > 1) {
-      // can remove this once we have extractKeySchema landed
+      // TODO: Remove the logic of getting colNameToRowConverter once allColumnFamiliesReader is
+      // returning actual partitionKeySchema instead of the entire key
       val colNameToRowConverter = columnFamilyToSchemaMap.view.mapValues { colInfo =>
         val cfSchema = SchemaUtil.getScanAllColumnFamiliesSchema(colInfo.schema.keySchema)
         CatalystTypeConverters.createToCatalystConverter(cfSchema)
@@ -134,7 +135,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
       }
     }
 
-    // Write raw bytes to target using foreachPartition"
+    // Write raw bytes to target using foreachPartition
     sourceBytesData.foreachPartition(putPartitionFunc)
 
     // Commit to commitLog
@@ -151,50 +152,33 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
 
     def shouldCheckColumnFamilyName: String => Boolean = name => {
       (!name.startsWith("$")
-        || (columnFamilyToReaderOptions.contains(name) &&
-          columnFamilyToReaderOptions(name).contains(StateSourceOptions.READ_REGISTERED_TIMERS)))
+        || (columnFamilyToStateSourceOptions.contains(name) &&
+          columnFamilyToStateSourceOptions(name).contains(StateSourceOptions.READ_REGISTERED_TIMERS)))
     }
-    if (columnFamiliesToValidate.nonEmpty) {
-      // Validate each column family separately (skip internal column families starting with $)
-      columnFamiliesToValidate
-        .filter(shouldCheckColumnFamilyName)
-        .foreach { cfName =>
-          val selectExprs = columnFamilyToSelectExprs.getOrElse(cfName, defaultSelectExprs)
-          val readerOptions = columnFamilyToReaderOptions.getOrElse(cfName, Map.empty)
-          def readNormalData(dir: String): Array[Row] = {
-            var reader = spark.read
-              .format("statestore")
-              .option(StateSourceOptions.PATH, dir)
-              .option(StateSourceOptions.STORE_NAME, storeName.orNull)
-            readerOptions.foreach { case (k, v) => reader = reader.option(k, v) }
-            reader.load()
-              .selectExpr(selectExprs: _*)
-              .collect()
+    // Validate each column family separately (skip internal column families starting with $)
+    columnFamiliesToValidate
+      // TODO: How to validate that internal columns are written correctly?
+      .filter(shouldCheckColumnFamilyName)
+      .foreach { cfName =>
+        val selectExprs = columnFamilyToSelectExprs.getOrElse(cfName, defaultSelectExprs)
+        val readerOptions = columnFamilyToStateSourceOptions.getOrElse(cfName, Map.empty)
+
+        def readNormalData(dir: String): Array[Row] = {
+          var reader = spark.read
+            .format("statestore")
+            .option(StateSourceOptions.PATH, dir)
+            .option(StateSourceOptions.STORE_NAME, storeName.orNull)
+          readerOptions.foreach { case (k, v) => reader = reader.option(k, v) }
+          reader.load()
+            .selectExpr(selectExprs: _*)
+            .collect()
         }
 
         val sourceNormalData = readNormalData(sourceDir)
         val targetNormalData = readNormalData(targetDir)
 
         validateDataMatches(sourceNormalData, targetNormalData)
       }
-    } else {
-      // Default validation without store name filter
-      val sourceNormalData = spark.read
-        .format("statestore")
-        .option(StateSourceOptions.PATH, sourceDir)
-        .load()
-        .selectExpr("key", "value", "partition_id")
-        .collect()
-
-      val targetNormalData = spark.read
-        .format("statestore")
-        .option(StateSourceOptions.PATH, targetDir)
-        .load()
-        .selectExpr("key", "value", "partition_id")
-        .collect()
-
-      validateDataMatches(sourceNormalData, targetNormalData)
-    }
   }
 
   /**
@@ -207,7 +191,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
       s"Row count mismatch: source=${sourceNormalData.length}, " +
         s"target=${targetNormalData.length}")
 
-    // Sort and compare row by row"
+    // Sort and compare row by row
     val sourceSorted = sourceNormalData.sortBy(_.toString)
     val targetSorted = targetNormalData.sortBy(_.toString)
 
@@ -410,6 +394,35 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
     }
   }
 
+  private val keyToNumValuesColFamilyNames = Seq("left-keyToNumValues", "right-keyToNumValues")
+  private val keyWithIndexToValueColFamilyNames = Seq(
+    "left-keyWithIndexToValue", "right-keyWithIndexToValue")
+
+  private def getJoinV3ColumnSchemaMap(): HashMap[String, StatePartitionWriterColumnFamilyInfo] = {
+    val keyToNumValuesKeySchema = StructType(Array(StructField("key", IntegerType)))
+    val keyToNumValuesValueSchema = StructType(Array(StructField("value", LongType)))
+    val keyToNumValuesEncoderSpec = NoPrefixKeyStateEncoderSpec(keyToNumValuesKeySchema)
+
+    val keyWithIndexKeySchema = StructType(Array(
+      StructField("key", IntegerType, nullable = false),
+      StructField("index", LongType)
+    ))
+    val keyWithIndexValueSchema = StructType(Array(
+      StructField("value", IntegerType, nullable = false),
+      StructField("time", TimestampType, nullable = false),
+      StructField("matched", BooleanType)
+    ))
+    val keyWithIndexEncoderSpec = NoPrefixKeyStateEncoderSpec(keyWithIndexKeySchema)
+
+    // Build column family to schema map for all 4 join stores
+    keyToNumValuesColFamilyNames.map { name =>
+      createSingleColumnFamilySchemaMap(
+        keyToNumValuesKeySchema, keyToNumValuesValueSchema, keyToNumValuesEncoderSpec, name)
+    }.reduce(_ ++ _) ++ keyWithIndexToValueColFamilyNames.map { name =>
+      createSingleColumnFamilySchemaMap(
+        keyWithIndexKeySchema, keyWithIndexValueSchema, keyWithIndexEncoderSpec, name)
+    }.reduce(_ ++ _)
+  }
   /**
    * Helper method to test round-trip for stream-stream join with different versions.
    */
@@ -505,48 +518,20 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
       )
     }
 
-    runQuery(sourceDir.getAbsolutePath, 2)
-    runQuery(targetDir.getAbsolutePath, 1)
-
-    // Define schemas for V3 join state stores
-    val keyToNumValuesKeySchema = StructType(Array(StructField("key", IntegerType)))
-    val keyToNumValuesValueSchema = StructType(Array(StructField("value", LongType)))
-    val keyToNumValuesEncoderSpec = NoPrefixKeyStateEncoderSpec(keyToNumValuesKeySchema)
+    // varying the roundsOfData so that sourceDir and targetDir have different state
+    runQuery(sourceDir.getAbsolutePath, roundsOfData = 2)
+    runQuery(targetDir.getAbsolutePath, roundsOfData = 1)
 
-    val keyWithIndexKeySchema = StructType(Array(
-      StructField("key", IntegerType, nullable = false),
-      StructField("index", LongType)
-    ))
-    val keyWithIndexValueSchema = StructType(Array(
-      StructField("value", IntegerType, nullable = false),
-      StructField("time", TimestampType, nullable = false),
-      StructField("matched", BooleanType)
-    ))
-    val keyWithIndexEncoderSpec = NoPrefixKeyStateEncoderSpec(keyWithIndexKeySchema)
-
-    val keyToNumValuesColFamilyNames = Seq("left-keyToNumValues", "right-keyToNumValues")
-    val keyWithIndexToValueColFamilyNames = Seq(
-      "left-keyWithIndexToValue", "right-keyWithIndexToValue")
-    // Build column family to schema map for all 4 join stores
-    val columnFamilyToSchemaMap =
-      keyToNumValuesColFamilyNames.map { name =>
-        createSingleColumnFamilySchemaMap(
-          keyToNumValuesKeySchema, keyToNumValuesValueSchema, keyToNumValuesEncoderSpec, name)
-      }.reduce(_ ++ _) ++ keyWithIndexToValueColFamilyNames.map { name =>
-        createSingleColumnFamilySchemaMap(
-          keyWithIndexKeySchema, keyWithIndexValueSchema, keyWithIndexEncoderSpec, name)
-      }.reduce(_ ++ _)
-    val columnFamilyToReaderOptions =
-      (keyToNumValuesColFamilyNames ++ keyWithIndexToValueColFamilyNames).map {
-        colName =>
-          colName -> Map(StateSourceOptions.STORE_NAME -> colName)
-      }.toMap
     // Perform round-trip test using common helper
     performRoundTripTest(
       sourceDir.getAbsolutePath,
       targetDir.getAbsolutePath,
-      columnFamilyToSchemaMap,
-      columnFamilyToReaderOptions = columnFamilyToReaderOptions
+      getJoinV3ColumnSchemaMap(),
+      columnFamilyToStateSourceOptions =
+        (keyToNumValuesColFamilyNames ++ keyWithIndexToValueColFamilyNames).map {
+          colName =>
+            colName -> Map(StateSourceOptions.STORE_NAME -> colName)
        }.toMap
     )
  }
}
@@ -690,7 +675,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
     )
 
     // Define reader options for column families that need them
-    val columnFamilyToReaderOptions = Map(
+    val columnFamilyToStateSourceOptions = Map(
       "itemsList" -> Map(StateSourceOptions.FLATTEN_COLLECTION_TYPES -> "true",
         StateSourceOptions.STATE_VAR_NAME -> "itemsList"),
       "itemsMap" -> Map(StateSourceOptions.STATE_VAR_NAME -> "itemsMap"),
@@ -703,7 +688,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
       targetDir.getAbsolutePath,
       columnFamilyToSchemaMap,
       columnFamilyToSelectExprs = columnFamilyToSelectExprs,
-      columnFamilyToReaderOptions = columnFamilyToReaderOptions
+      columnFamilyToStateSourceOptions = columnFamilyToStateSourceOptions
     )
   }
 }
@@ -787,7 +772,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
     )
 
     // Timer column families need READ_REGISTERED_TIMERS option
-    val columnFamilyToReaderOptions = Map(
+    val columnFamilyToStateSourceOptions = Map(
       "countState" -> Map(StateSourceOptions.STATE_VAR_NAME -> "countState"),
       "$eventTimers_keyToTimestamp" -> Map(
         StateSourceOptions.READ_REGISTERED_TIMERS -> "true"),
@@ -800,7 +785,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
       targetDir.getAbsolutePath,
       columnFamilyToSchemaMap,
       columnFamilyToSelectExprs = columnFamilyToSelectExprs,
-      columnFamilyToReaderOptions = columnFamilyToReaderOptions
+      columnFamilyToStateSourceOptions = columnFamilyToStateSourceOptions
     )
   }
 }
@@ -887,7 +872,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
     )
 
     // Timer column families need READ_REGISTERED_TIMERS option
-    val columnFamilyToReaderOptions = Map(
+    val columnFamilyToStateSourceOptions = Map(
       "countState" -> Map(StateSourceOptions.STATE_VAR_NAME -> "countState"),
       "$procTimers_keyToTimestamp" -> Map(
         StateSourceOptions.READ_REGISTERED_TIMERS -> "true"),
@@ -900,7 +885,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
       targetDir.getAbsolutePath,
       columnFamilyToSchemaMap,
       columnFamilyToSelectExprs = columnFamilyToSelectExprs,
-      columnFamilyToReaderOptions = columnFamilyToReaderOptions
+      columnFamilyToStateSourceOptions = columnFamilyToStateSourceOptions
     )
   }
 }
@@ -987,7 +972,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
       "listState" -> Seq("key", "list_element AS value", "partition_id")
     )
 
-    val columnFamilyToReaderOptions = Map(
+    val columnFamilyToStateSourceOptions = Map(
      "listState" -> Map(
        StateSourceOptions.STATE_VAR_NAME -> "listState",
        StateSourceOptions.FLATTEN_COLLECTION_TYPES -> "true")
@@ -998,7 +983,7 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
      targetDir.getAbsolutePath,
      columnFamilyToSchemaMap,
      columnFamilyToSelectExprs = columnFamilyToSelectExprs,
-      columnFamilyToReaderOptions = columnFamilyToReaderOptions
+      columnFamilyToStateSourceOptions = columnFamilyToStateSourceOptions
    )
  }
}
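
A note on the map { ... }.reduce(_ ++ _) and .map(...).toMap idioms used in getJoinV3ColumnSchemaMap and the join test above: each column family name becomes a single-entry map, and the maps are then folded together. A self-contained illustration with stand-in values (the Int payload and the "storeName" key are placeholders for the schema info and StateSourceOptions.STORE_NAME):

import scala.collection.immutable.HashMap

val names = Seq("left-keyToNumValues", "right-keyToNumValues")

// One single-entry HashMap per name, folded into a single map with ++.
val merged: HashMap[String, Int] =
  names.map(n => HashMap(n -> n.length)).reduce(_ ++ _)

// Same shape as the per-store options map built in the join test.
val storeNameOptions: Map[String, Map[String, String]] =
  names.map(n => n -> Map("storeName" -> n)).toMap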
