Skip to content

Commit cc1321b

Browse files
committed
self-review
1 parent 766ce6f commit cc1321b

File tree

3 files changed

+174
-204
lines changed

3 files changed

+174
-204
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionWriter.scala

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,20 @@ import java.util.UUID
2020

2121
import scala.collection.MapView
2222
import scala.collection.immutable.HashMap
23-
import scala.collection.mutable.HashSet
2423

2524
import org.apache.hadoop.conf.Configuration
2625
import org.apache.hadoop.fs.Path
2726

2827
import org.apache.spark.sql.catalyst.InternalRow
2928
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
29+
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.StateStoreColumnFamilySchemaUtils
3030
import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE
31-
import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProvider, StateStoreProviderId}
31+
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProvider, StateStoreProviderId}
32+
import org.apache.spark.sql.types.{NullType, StructField, StructType}
3233

3334
case class StatePartitionWriterColumnFamilyInfo(
3435
schema: StateStoreColFamilySchema,
36+
// set this to true if state variable is ListType in TransformWithState
3537
useMultipleValuesPerKey: Boolean = false)
3638
/**
3739
* A writer that can directly write binary data to the streaming state store.
@@ -53,24 +55,34 @@ class StatePartitionAllColumnFamiliesWriter(
5355
storeName: String,
5456
currentBatchId: Long,
5557
columnFamilyToSchemaMap: HashMap[String, StatePartitionWriterColumnFamilyInfo]) {
58+
private val dummySchema: StructType =
59+
StructType(Array(StructField("__dummy__", NullType)))
5660
private val defaultSchema = {
57-
columnFamilyToSchemaMap.getOrElse(
58-
StateStore.DEFAULT_COL_FAMILY_NAME,
59-
columnFamilyToSchemaMap.head._2 // join V3 doesn't have default col family
60-
).schema
61+
columnFamilyToSchemaMap.get(StateStore.DEFAULT_COL_FAMILY_NAME) match {
62+
case Some(info) => info.schema
63+
case None =>
64+
// Return a dummy StateStoreColFamilySchema if not found
65+
StateStoreColFamilySchema(
66+
colFamilyName = "__dummy__",
67+
keySchemaId = 0,
68+
keySchema = dummySchema,
69+
valueSchemaId = 0,
70+
valueSchema = dummySchema,
71+
keyStateEncoderSpec = Option(NoPrefixKeyStateEncoderSpec(dummySchema)))
72+
}
6173
}
6274

6375
private val columnFamilyToKeySchemaLenMap: MapView[String, Int] =
6476
columnFamilyToSchemaMap.view.mapValues(_.schema.keySchema.length)
6577
private val columnFamilyToValueSchemaLenMap: MapView[String, Int] =
6678
columnFamilyToSchemaMap.view.mapValues(_.schema.valueSchema.length)
67-
private val colFamilyHasWritten: HashSet[String] = HashSet[String]()
6879

6980
protected lazy val provider: StateStoreProvider = {
7081
val stateCheckpointLocation = new Path(targetCpLocation, DIR_NAME_STATE).toString
7182
val stateStoreId = StateStoreId(stateCheckpointLocation,
7283
operatorId, partitionId, storeName)
7384
val stateStoreProviderId = StateStoreProviderId(stateStoreId, UUID.randomUUID())
85+
7486
val useColumnFamilies = columnFamilyToSchemaMap.size > 1
7587
val provider = StateStoreProvider.createAndInit(
7688
stateStoreProviderId, defaultSchema.keySchema, defaultSchema.valueSchema,
@@ -99,7 +111,7 @@ class StatePartitionAllColumnFamiliesWriter(
99111
colFamilyName match {
100112
case StateStore.DEFAULT_COL_FAMILY_NAME => // createAndInit has registered default
101113
case _ =>
102-
val isInternal = colFamilyName.startsWith("$")
114+
val isInternal = StateStoreColumnFamilySchemaUtils.isInternalColFamily(colFamilyName)
103115

104116
require(cfSchema.keyStateEncoderSpec.isDefined,
105117
s"keyStateEncoderSpec must be defined for column family ${cfSchema.colFamilyName}")
@@ -125,7 +137,6 @@ class StatePartitionAllColumnFamiliesWriter(
125137
try {
126138
rows.foreach(row => writeRow(row))
127139
stateStore.commit()
128-
colFamilyHasWritten.clear()
129140
} finally {
130141
if (!stateStore.hasCommitted) {
131142
stateStore.abort()
@@ -151,12 +162,12 @@ class StatePartitionAllColumnFamiliesWriter(
151162
val valueRow = new UnsafeRow(columnFamilyToValueSchemaLenMap(colFamilyName))
152163
valueRow.pointTo(valueBytes, valueBytes.length)
153164

154-
if (columnFamilyToSchemaMap(colFamilyName).useMultipleValuesPerKey
155-
&& colFamilyHasWritten(colFamilyName)) {
165+
if (columnFamilyToSchemaMap(colFamilyName).useMultipleValuesPerKey) {
166+
// if a column family useMultipleValuesPerKey (e.g. ListType), we will
167+
// write with 1 put followed by merge
156168
stateStore.merge(keyRow, valueRow, colFamilyName)
157169
} else {
158170
stateStore.put(keyRow, valueRow, colFamilyName)
159-
colFamilyHasWritten.add(colFamilyName)
160171
}
161172
}
162173
}

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateStoreColumnFamilySchemaUtils.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,4 +216,8 @@ object StateStoreColumnFamilySchemaUtils {
216216
valSchema,
217217
Some(RangeKeyScanStateEncoderSpec(keySchema, Seq(0))))
218218
}
219+
220+
def isInternalColFamily(name: String): Boolean = {
221+
name.startsWith("$")
222+
}
219223
}

0 commit comments

Comments
 (0)