@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeqMetadata
import org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils
import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{TransformWithStateOperatorProperties, TransformWithStateVariableInfo}
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateStoreColumnFamilySchemaUtils, StateVariableType, TransformWithStateOperatorProperties, TransformWithStateVariableInfo}
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils
import org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE
import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata
@@ -90,7 +90,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
stateStoreReaderInfo.transformWithStateVariableInfoOpt,
stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
stateStoreReaderInfo.stateSchemaProviderOpt,
stateStoreReaderInfo.joinColFamilyOpt)
stateStoreReaderInfo.joinColFamilyOpt,
Option(stateStoreReaderInfo.allColumnFamiliesReaderInfo))
}

override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
@@ -131,6 +132,21 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging

override def supportsExternalMetadata(): Boolean = false

/**
* Returns true if this is a read-all-column-families request for a stream-stream join
* that uses virtual column families (state format version 3).
*/
private def isReadAllColFamiliesOnJoinV3(
sourceOptions: StateSourceOptions,
storeMetadata: Array[StateMetadataTableEntry]): Boolean = {
sourceOptions.internalOnlyReadAllColumnFamilies &&
storeMetadata.head.operatorName == StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME &&
StreamStreamJoinStateHelper.usesVirtualColumnFamilies(
hadoopConf,
sourceOptions.stateCheckpointLocation.toString,
sourceOptions.operatorId)
}

private def buildStateStoreConf(checkpointLocation: String, batchId: Long): StateStoreConf = {
val offsetLog = new StreamingQueryCheckpointMetadata(session, checkpointLocation).offsetLog
offsetLog.get(batchId) match {
@@ -177,7 +193,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging

val stateVars = twsOperatorProperties.stateVariables
val stateVarInfo = stateVars.filter(stateVar => stateVar.stateName == stateVarName)
if (stateVarInfo.size != 1) {
if (stateVarInfo.size != 1 &&
!StateStoreColumnFamilySchemaUtils.isInternalColFamilyTestOnly(stateVarName)) {
throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME,
s"State variable $stateVarName is not defined for the transformWithState operator.")
}
@@ -242,14 +259,20 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
private def getStoreMetadataAndRunChecks(sourceOptions: StateSourceOptions):
StateStoreReaderInfo = {
val storeMetadata = StateDataSource.getStateStoreMetadata(sourceOptions, hadoopConf)
runStateVarChecks(sourceOptions, storeMetadata)
if (!sourceOptions.internalOnlyReadAllColumnFamilies) {
// Skip runStateVarChecks for StatePartitionAllColumnFamiliesReader because no state
// variables are specified when querying a TWS operator in this mode.
runStateVarChecks(sourceOptions, storeMetadata)
}

var keyStateEncoderSpecOpt: Option[KeyStateEncoderSpec] = None
var stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema] = None
var transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo] = None
var stateSchemaProvider: Option[StateSchemaProvider] = None
var joinColFamilyOpt: Option[String] = None
var timeMode: String = TimeMode.None.toString
var stateStoreColFamilySchemas: List[StateStoreColFamilySchema] = List.empty
var stateVariableInfos: List[TransformWithStateVariableInfo] = List.empty

if (sourceOptions.joinSide == JoinSideValues.none) {
var stateVarName = sourceOptions.stateVarName
@@ -268,13 +291,23 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
if (sourceOptions.readRegisteredTimers) {
stateVarName = TimerStateUtils.getTimerStateVarNames(timeMode)._1
}

val stateVarInfoList = operatorProperties.stateVariables
.filter(stateVar => stateVar.stateName == stateVarName)
require(stateVarInfoList.size == 1, s"Failed to find unique state variable info " +
s"for state variable $stateVarName in operator ${sourceOptions.operatorId}")
val stateVarInfo = stateVarInfoList.head
transformWithStateVariableInfoOpt = Some(stateVarInfo)
if (sourceOptions.internalOnlyReadAllColumnFamilies) {
stateVariableInfos = operatorProperties.stateVariables
} else {
var stateVarInfoList = operatorProperties.stateVariables
.filter(stateVar => stateVar.stateName == stateVarName)
if (stateVarInfoList.isEmpty &&
StateStoreColumnFamilySchemaUtils.isInternalColFamilyTestOnly(stateVarName)) {
// Use a dummy TransformWithStateVariableInfo for a TWS internal column family,
// which is only readable during testing.
stateVarInfoList = List(TransformWithStateVariableInfo(
stateVarName, StateVariableType.ValueState, false
))
}
require(stateVarInfoList.size == 1, s"Failed to find unique state variable info " +
s"for state variable $stateVarName in operator ${sourceOptions.operatorId}")
val stateVarInfo = stateVarInfoList.head
transformWithStateVariableInfoOpt = Some(stateVarInfo)
}
val schemaFilePaths = storeMetadataEntry.stateSchemaFilePaths
val stateSchemaMetadata = StateSchemaMetadata.createStateSchemaMetadata(
sourceOptions.stateCheckpointLocation.toString,
@@ -305,9 +338,22 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
oldSchemaFilePaths = oldSchemaFilePaths)
val stateSchema = manager.readSchemaFile()

if (sourceOptions.internalOnlyReadAllColumnFamilies) {
// Store all column family schemas for multi-CF reading
stateStoreColFamilySchemas = stateSchema
}
// When reading all column families for Join V3, no specific state variable is targeted,
// so stateVarName defaults to DEFAULT_COL_FAMILY_NAME.
// However, Join V3 does not have a "default" column family, so we pick the first
// schema as resultSchema, which serves as the placeholder for the default schema
// in StatePartitionAllColumnFamiliesReader.
val resultSchema = if (isReadAllColFamiliesOnJoinV3(sourceOptions, storeMetadata)) {
stateSchema.head
} else {
stateSchema.filter(_.colFamilyName == stateVarName).head
}
// Based on the version and read schema, populate the keyStateEncoderSpec used for
// reading the column families
val resultSchema = stateSchema.filter(_.colFamilyName == stateVarName).head
keyStateEncoderSpecOpt = Some(getKeyStateEncoderSpec(resultSchema, storeMetadata))
stateStoreColFamilySchemaOpt = Some(resultSchema)
} catch {
@@ -321,7 +367,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
stateStoreColFamilySchemaOpt,
transformWithStateVariableInfoOpt,
stateSchemaProvider,
joinColFamilyOpt
joinColFamilyOpt,
AllColumnFamiliesReaderInfo(stateStoreColFamilySchemas, stateVariableInfos)
)
}

@@ -708,7 +755,9 @@ case class StateStoreReaderInfo(
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String] // Only used for join op with state format v3
joinColFamilyOpt: Option[String], // Only used for join op with state format v3
// Schemas and state variable infos for all column families - only used when
// internalOnlyReadAllColumnFamilies=true
allColumnFamiliesReaderInfo: AllColumnFamiliesReaderInfo
)

object StateDataSource {
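For context on how the StateDataSource changes above would be exercised: a minimal sketch of driving the read-all-column-families path from the DataFrame reader, assuming an active SparkSession `spark` and assuming the internal flag in StateSourceOptions is surfaced as a reader option of the same name. The option key, checkpoint path, and output column names below are illustrative, not taken from this change.

// Hypothetical usage sketch: read every column family of operator 0 as raw bytes.
// "internalOnlyReadAllColumnFamilies" as an option key and the output column names
// are assumptions for illustration; only the flag itself appears in this change.
val allCfRows = spark.read
  .format("statestore")
  .option("path", "/checkpoints/my-query")               // streaming checkpoint location
  .option("operatorId", 0)                               // TWS or stream-stream join v3 operator
  .option("internalOnlyReadAllColumnFamilies", "true")   // hypothetical option key
  .load()

// Each row carries a column family name plus raw key/value bytes
// (see SchemaUtil.unifyStateRowPairAsRawBytes), so a consumer can regroup per family.
allCfRows.groupBy("colFamilyName").count().show()        // "colFamilyName" is an assumed column name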
@@ -22,13 +22,17 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo}
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateStoreColumnFamilySchemaUtils, StateVariableType, TransformWithStateVariableInfo}
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType}
import org.apache.spark.sql.types.{NullType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{NextIterator, SerializableConfiguration}

case class AllColumnFamiliesReaderInfo(
colFamilySchemas: List[StateStoreColFamilySchema] = List.empty,
stateVariableInfos: List[TransformWithStateVariableInfo] = List.empty)

/**
* An implementation of [[PartitionReaderFactory]] for State data source. This is used to support
* general read from a state store instance, rather than specific to the operator.
@@ -44,14 +48,17 @@ class StatePartitionReaderFactory(
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String])
joinColFamilyOpt: Option[String],
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo])
extends PartitionReaderFactory {

override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition]
if (stateStoreInputPartition.sourceOptions.internalOnlyReadAllColumnFamilies) {
require(allColumnFamiliesReaderInfo.isDefined)
new StatePartitionAllColumnFamiliesReader(storeConf, hadoopConf,
stateStoreInputPartition, schema, keyStateEncoderSpec, stateStoreColFamilySchemaOpt)
stateStoreInputPartition, schema, keyStateEncoderSpec, stateStoreColFamilySchemaOpt,
stateSchemaProviderOpt, allColumnFamiliesReaderInfo.get)
} else if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
new StateStoreChangeDataPartitionReader(storeConf, hadoopConf,
stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt,
@@ -136,14 +143,15 @@ abstract class StatePartitionReaderBase(
useColumnFamilies = useColFamilies, storeConf, hadoopConf.value,
useMultipleValuesPerKey = useMultipleValuesPerKey, stateSchemaProviderOpt)

val isInternal = partition.sourceOptions.readRegisteredTimers

if (useColFamilies) {
val store = provider.getStore(
partition.sourceOptions.batchId + 1,
getEndStoreUniqueId)
require(stateStoreColFamilySchemaOpt.isDefined)
val stateStoreColFamilySchema = stateStoreColFamilySchemaOpt.get
val isInternal = partition.sourceOptions.readRegisteredTimers ||
StateStoreColumnFamilySchemaUtils.isInternalColFamilyTestOnly(
stateStoreColFamilySchema.colFamilyName)
require(stateStoreColFamilySchema.keyStateEncoderSpec.isDefined)
store.createColFamilyIfAbsent(
stateStoreColFamilySchema.colFamilyName,
@@ -258,32 +266,116 @@ class StatePartitionAllColumnFamiliesReader(
partition: StateStoreInputPartition,
schema: StructType,
keyStateEncoderSpec: KeyStateEncoderSpec,
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
defaultStateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
stateSchemaProviderOpt: Option[StateSchemaProvider],
allColumnFamiliesReaderInfo: AllColumnFamiliesReaderInfo)
extends StatePartitionReaderBase(
storeConf,
hadoopConf, partition, schema,
keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) {
keyStateEncoderSpec, None,
defaultStateStoreColFamilySchemaOpt,
stateSchemaProviderOpt, None) {

private lazy val store: ReadStateStore = {
private val stateStoreColFamilySchemas = allColumnFamiliesReaderInfo.colFamilySchemas
private val stateVariableInfos = allColumnFamiliesReaderInfo.stateVariableInfos

private def isListType(colFamilyName: String): Boolean = {
SchemaUtil.checkVariableType(
stateVariableInfos.find(info => info.stateName == colFamilyName),
StateVariableType.ListState)
}

override protected lazy val provider: StateStoreProvider = {
val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)
val useColumnFamilies = stateStoreColFamilySchemas.length > 1
StateStoreProvider.createAndInit(
stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
useColumnFamilies, storeConf, hadoopConf.value,
useMultipleValuesPerKey = false, stateSchemaProviderOpt)
}


private def checkAllColFamiliesExist(
colFamilyNames: List[String], stateStore: StateStore
): Unit = {
// Filter out the DEFAULT column family from validation for two reasons:
// 1. Some operators (e.g., stream-stream join v3) don't include DEFAULT in their schema
// because the underlying RocksDB creates the "default" column family automatically.
// 2. The default column family schema is handled separately via
// defaultStateStoreColFamilySchemaOpt, so there is no need to verify it here.
val schemaCFs = colFamilyNames.toSet.filter(_ != StateStore.DEFAULT_COL_FAMILY_NAME)
val checkpointCFs = stateStore.allColumnFamilyNames
.filter(_ != StateStore.DEFAULT_COL_FAMILY_NAME)

// Validation: every column family found in the checkpoint must be declared in the schema.
// It is acceptable for some schema CFs to be missing from the checkpoint - that just means
// those column families have no data yet (they will be created during registration).
// However, if the checkpoint contains CFs not in the schema, it indicates a mismatch.
require(checkpointCFs.subsetOf(schemaCFs),
s"Checkpoint contains unexpected column families. " +
s"Column families in checkpoint but not in schema: ${checkpointCFs.diff(schemaCFs)}")
}

// Use a single store instance for both registering column families and iteration.
// We cannot abort and then get a read store because abort() invalidates the loaded version,
// causing getReadStore() to reload from checkpoint and clear the column family registrations.
private lazy val store: StateStore = {
assert(getStartStoreUniqueId == getEndStoreUniqueId,
"Start and end store unique IDs must be the same when reading all column families")
provider.getReadStore(
val stateStore = provider.getStore(
partition.sourceOptions.batchId + 1,
getStartStoreUniqueId
)

// Register all column families from the schema
if (stateStoreColFamilySchemas.length > 1) {
checkAllColFamiliesExist(stateStoreColFamilySchemas.map(_.colFamilyName), stateStore)
stateStoreColFamilySchemas.foreach { cfSchema =>
cfSchema.colFamilyName match {
case StateStore.DEFAULT_COL_FAMILY_NAME => // createAndInit has registered default
case _ =>
val isInternal = cfSchema.colFamilyName.startsWith("$")
val useMultipleValuesPerKey = isListType(cfSchema.colFamilyName)
require(cfSchema.keyStateEncoderSpec.isDefined,
s"keyStateEncoderSpec must be defined for column family ${cfSchema.colFamilyName}")
stateStore.createColFamilyIfAbsent(
cfSchema.colFamilyName,
cfSchema.keySchema,
cfSchema.valueSchema,
cfSchema.keyStateEncoderSpec.get,
useMultipleValuesPerKey,
isInternal)
}
}
}
stateStore
}

override lazy val iter: Iterator[InternalRow] = {
store
.iterator()
.map { pair =>
SchemaUtil.unifyStateRowPairAsRawBytes(
(pair.key, pair.value), StateStore.DEFAULT_COL_FAMILY_NAME)
// Iterate all column families and concatenate results
stateStoreColFamilySchemas.iterator.flatMap { cfSchema =>
if (isListType(cfSchema.colFamilyName)) {
store.iterator(cfSchema.colFamilyName).flatMap(
pair =>
store.valuesIterator(pair.key, cfSchema.colFamilyName).map {
value =>
SchemaUtil.unifyStateRowPairAsRawBytes((pair.key, value), cfSchema.colFamilyName)
}
)
} else {
store.iterator(cfSchema.colFamilyName).map { pair =>
SchemaUtil.unifyStateRowPairAsRawBytes(
(pair.key, pair.value), cfSchema.colFamilyName)
}
}
}
}

override def close(): Unit = {
store.release()
store.abort()
super.close()
}
}
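To make the validation in checkAllColFamiliesExist above concrete, a small self-contained example of which schema/checkpoint combinations pass the subset check (column family names are made up):

// Schema-declared column families (DEFAULT is filtered out before the check).
val schemaCFs = Set("valueState", "listState", "$rowCounter_listState")

// A checkpoint that only has data for some declared families still passes:
// families with no data yet are simply absent from the checkpoint.
val sparseCheckpoint = Set("listState")
assert(sparseCheckpoint.subsetOf(schemaCFs))

// A checkpoint holding a family the schema never declared fails the require,
// surfacing the mismatch instead of silently skipping that family's data.
val mismatchedCheckpoint = Set("listState", "mapState")
assert(!mismatchedCheckpoint.subsetOf(schemaCFs))
assert(mismatchedCheckpoint.diff(schemaCFs) == Set("mapState"))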
@@ -46,10 +46,11 @@ class StateScanBuilder(
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String]) extends ScanBuilder {
joinColFamilyOpt: Option[String],
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo]) extends ScanBuilder {
override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf,
keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
joinColFamilyOpt)
joinColFamilyOpt, allColumnFamiliesReaderInfo)
}

/** An implementation of [[InputPartition]] for State Store data source. */
@@ -68,7 +69,8 @@ class StateScan(
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String])
joinColFamilyOpt: Option[String],
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo])
extends Scan with Batch {

// A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
@@ -144,7 +146,7 @@ class StateScan(
case JoinSideValues.none =>
new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema,
keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt,
stateSchemaProviderOpt, joinColFamilyOpt)
stateSchemaProviderOpt, joinColFamilyOpt, allColumnFamiliesReaderInfo)
}

override def toBatch: Batch = this
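A closing note on the list-state handling in StatePartitionAllColumnFamiliesReader: for list-type column families the reader emits one output row per stored value (iterator over keys, then valuesIterator per key), not one row per key. A tiny illustration of that flattening, with made-up data:

// Illustration only: two keys, one holding two list values, yield three output rows,
// mirroring store.iterator(cf).flatMap(pair => store.valuesIterator(pair.key, cf)).
val listState = Map("user-1" -> Seq("a", "b"), "user-2" -> Seq("c"))
val rows = listState.toSeq.flatMap { case (key, values) =>
  values.map(value => (key, value, "listState"))   // (key, value, colFamilyName)
}
assert(rows.size == 3)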