[SPARK-48755] State V2 base implementation and ValueState support #47133

Open · wants to merge 8 commits into master

Conversation

@bogao007 (Contributor) commented Jun 27, 2024

What changes were proposed in this pull request?

  • Base implementation for Python State V2
  • Implemented ValueState (see the sketch right after this list)
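
For reference, here is a minimal sketch of the ValueState surface exercised by this change, i.e. the object returned by handle.getValueState(state_name, schema) in the integration test below. The method names mirror the test; the class name, signatures, and type hints are illustrative assumptions rather than the PR's definitive API.

# Sketch only: method names mirror the integration test below; signatures are assumed.
class ValueStateSketch:
    def exists(self) -> bool:
        # Whether a value is currently stored for the implicit grouping key.
        ...
    def get(self):
        # Return the stored value for the implicit grouping key, if any.
        ...
    def update(self, new_value) -> None:
        # Overwrite the stored value for the implicit grouping key.
        ...
    def clear(self) -> None:
        # Remove the stored value for the implicit grouping key.
        ...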

Why are the changes needed?

To support the Python State V2 API.

Does this PR introduce any user-facing change?

Yes

How was this patch tested?

Ran a local integration test with the commands below:

import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType, StringType
from typing import Iterator
spark.conf.set("spark.sql.streaming.stateStore.providerClass","org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
spark.conf.set("spark.sql.shuffle.partitions","1")
output_schema = StructType([
    StructField("value", LongType(), True)
])
state_schema = StructType([
    StructField("value", StringType(), True)
])

class SimpleStatefulProcessor(StatefulProcessor):
  def init(self, handle: StatefulProcessorHandle) -> None:
    # Register a value state variable with the declared state schema.
    self.value_state = handle.getValueState("testValueState", state_schema)
  def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
    # Exercise update/exists/get/clear on the value state for the current grouping key.
    self.value_state.update("test_value")
    exists = self.value_state.exists()
    print(f"value state exists: {exists}")
    value = self.value_state.get()
    print(f"get value: {value}")
    print("clearing value state")
    self.value_state.clear()
    print("value state cleared")
    return rows
  def close(self) -> None:
    pass

q = (
    spark.readStream.format("rate")
    .option("rowsPerSecond", "1")
    .option("numPartitions", "1")
    .load()
    .groupBy("value")
    .transformWithStateInPandas(
        stateful_processor=SimpleStatefulProcessor(),
        outputStructType=output_schema,
        outputMode="Update",
        timeMode="None",
    )
    .writeStream.format("console")
    .option("checkpointLocation", "/tmp/streaming/temp_ckp")
    .outputMode("update")
    .start()
)
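
As a side note, q above is a regular StreamingQuery handle, so one possible way to let the test run for a few micro-batches and then shut it down cleanly is the following (the 30-second sleep is arbitrary and not part of the original test):

import time

time.sleep(30)  # let a few micro-batches run against the rate source
q.stop()        # stop the streaming query and release the checkpoint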

Verified from the logs that the value state methods work as expected for key 11:

handling input rows for key: 11
setting implicit key: 11
sending message -- len = 8 b'"\x06\n\x04\n\x0211'
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: version = 0
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: parsing a message of 8 bytes
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: read bytes = Array(34, 6, 10, 4, 10, 2, 49, 49)
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: read message = implicitGroupingKeyRequest {
  setImplicitKey {
    key: "11"
  }
}

24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: setting implicit key to 11 with type class java.lang.Long
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: flush output stream
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: reading the version
setImplicitKey status= 0
updating value state: testValueState
sending message -- len = 127 b'\x1a}\n{\x1ay\n\x0etestValueState\x12[{"fields":[{"metadata":{},"name":"value","nullable":true,"type":"string"}],"type":"struct"}\x1a\ntest_value'
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: version = 0
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: parsing a message of 127 bytes
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: read bytes = Array(26, 125, 10, 123, 26, 121, 10, 14, 116, 101, 115, 116, 86, 97, 108, 117, 101, 83, 116, 97, 116, 101, 18, 91, 123, 34, 102, 105, 101, 108, 100, 115, 34, 58, 91, 123, 34, 109, 101, 116, 97, 100, 97, 116, 97, 34, 58, 123, 125, 44, 34, 110, 97, 109, 101, 34, 58, 34, 118, 97, 108, 117, 101, 34, 44, 34, 110, 117, 108, 108, 97, 98, 108, 101, 34, 58, 116, 114, 117, 101, 44, 34, 116, 121, 112, 101, 34, 58, 34, 115, 116, 114, 105, 110, 103, 34, 125, 93, 44, 34, 116, 121, 112, 101, 34, 58, 34, 115, 116, 114, 117, 99, 116, 34, 125, 26, 10, 116, 101, 115, 116, 95, 118, 97, 108, 117, 101)
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: read message = stateVariableRequest {
  valueStateCall {
    update {
      stateName: "testValueState"
      schema: "{\"fields\":[{\"metadata\":{},\"name\":\"value\",\"nullable\":true,\"type\":\"string\"}],\"type\":\"struct\"}"
      value: "test_value"
    }
  }
}

24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: updating state testValueState with value test_value and type class java.lang.String
24/06/27 15:21:40 WARN StateTypesEncoder: Serializing grouping key: [11]
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: flush output stream
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: reading the version
valueStateUpdate status= 0
checking value state exists: testValueState
sending message -- len = 22 b'\x1a\x14\n\x12\n\x10\n\x0etestValueState'
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: version = 0
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: parsing a message of 22 bytes
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: read bytes = Array(26, 20, 10, 18, 10, 16, 10, 14, 116, 101, 115, 116, 86, 97, 108, 117, 101, 83, 116, 97, 116, 101)
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: read message = stateVariableRequest {
  valueStateCall {
    exists {
      stateName: "testValueState"
    }
  }
}

24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: state testValueState exists
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: flush output stream
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: reading the version
valueStateExists status= 0
value state exists: True
getting value state: testValueState
sending message -- len = 22 b'\x1a\x14\n\x12\x12\x10\n\x0etestValueState'
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: version = 0
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: parsing a message of 22 bytes
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: read bytes = Array(26, 20, 10, 18, 18, 16, 10, 14, 116, 101, 115, 116, 86, 97, 108, 117, 101, 83, 116, 97, 116, 101)
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: read message = stateVariableRequest {
  valueStateCall {
    get {
      stateName: "testValueState"
    }
  }
}

24/06/27 15:21:40 WARN StateTypesEncoder: Serializing grouping key: [11]
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: got state value test_value
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: writing value bytes of length 10
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: writing value bytes: Array(116, 101, 115, 116, 95, 118, 97, 108, 117, 101)
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: flush output stream
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: reading the version
valueStateGet status= 0
get value: test_value
clearing value state
clearing value state: testValueState
sending message -- len = 22 b'\x1a\x14\n\x12"\x10\n\x0etestValueState'
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: version = 0
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: parsing a message of 22 bytes
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: read bytes = Array(26, 20, 10, 18, 34, 16, 10, 14, 116, 101, 115, 116, 86, 97, 108, 117, 101, 83, 116, 97, 116, 101)
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: read message = stateVariableRequest {
  valueStateCall {
    clear {
      stateName: "testValueState"
    }
  }
}

24/06/27 15:21:40 WARN StateTypesEncoder: Serializing grouping key: [11]
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: flush output stream
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: reading the version
valueStateClear status= 0
value state cleared
removing implicit key
sending message -- len = 4 b'"\x02\x12\x00'
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: version = 0
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: parsing a message of 4 bytes
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: read bytes = Array(34, 2, 18, 0)
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: read message = implicitGroupingKeyRequest {
  removeImplicitKey {
  }
}

24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: removing implicit key
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: removed implicit key
24/06/27 15:21:40 WARN TransformWithStateInPandasStateServer: flush output stream
removeImplicitKey status= 0
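
Reading the log end to end, it corresponds to the following request/response round trips between the Python worker and TransformWithStateInPandasStateServer. The sketch below just writes that sequence out as plain Python data; the message and field names are copied from the log, the response summaries are paraphrased from the status lines, and observed_round_trips is only an illustrative name, not client code.

# Round trips reconstructed from the log above (illustrative data, not client code).
observed_round_trips = [
    ("implicitGroupingKeyRequest.setImplicitKey(key='11')", "status=0"),
    ("stateVariableRequest.valueStateCall.update(stateName='testValueState', value='test_value')", "status=0"),
    ("stateVariableRequest.valueStateCall.exists(stateName='testValueState')", "status=0 (state exists)"),
    ("stateVariableRequest.valueStateCall.get(stateName='testValueState')", "status=0, value bytes for 'test_value'"),
    ("stateVariableRequest.valueStateCall.clear(stateName='testValueState')", "status=0"),
    ("implicitGroupingKeyRequest.removeImplicitKey()", "status=0"),
]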

Will add unit tests.

Was this patch authored or co-authored using generative AI tooling?

No

val field = structType.fields(0)
val encoder = getEncoder(field.dataType)
val state = statefulProcessorHandle.getValueState[String](stateName, Encoders.STRING)
// val state = statefulProcessorHandle.getValueState(stateName, encoder)
@bogao007 (Contributor, Author) commented:

I'm not sure whether this is the correct way to build a value state with a schema, but it cannot support a class as a state type right now; we might need to find a different way to do so. cc @sahnib

val groupingKeyType = groupingKeySchema.fields(0).dataType
val castedData = castToType(key, groupingKeyType)
logWarning(s"setting implicit key to $castedData with type ${castedData.getClass}")
val row = Row(castedData)
@bogao007 (Contributor, Author) commented:

I'm setting the grouping key as a Row here because that is the only approach that works in my local testing; if I just set a String key, it throws a mismatch error like the one below:

24/06/26 13:13:13 WARN ExpressionEncoder$Serializer: inputRow: [test_key]
Exception in thread "stateConnectionListenerThread" org.apache.spark.SparkRuntimeException: [EXPRESSION_ENCODING_FAILED] Failed to encode a value of the expressions: if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 0, value), LongType, ObjectType(class java.lang.Long)).longValue AS value#129L to a row. SQLSTATE: 42846
	at org.apache.spark.sql.errors.QueryExecutionErrors$.expressionEncodingError(QueryExecutionErrors.scala:1309)
	at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$Serializer.apply(ExpressionEncoder.scala:227)
	at org.apache.spark.sql.execution.streaming.StateTypesEncoder.serializeGroupingKey(StateTypesEncoderUtils.scala:101)
	at org.apache.spark.sql.execution.streaming.StateTypesEncoder.encodeGroupingKey(StateTypesEncoderUtils.scala:80)
	at org.apache.spark.sql.execution.streaming.ValueStateImpl.get(ValueStateImpl.scala:64)
	at org.apache.spark.sql.execution.python.TransformWithStateInPandasStateServer.handleRequest(TransformWithStateInPandasStateServer.scala:133)
	at org.apache.spark.sql.execution.python.TransformWithStateInPandasStateServer.run(TransformWithStateInPandasStateServer.scala:78)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	at java.base/java.lang.Thread.run(Thread.java:840)
Caused by: java.lang.ClassCastException: class java.lang.String cannot be cast to class org.apache.spark.sql.Row (java.lang.String is in module java.base of loader 'bootstrap'; org.apache.spark.sql.Row is in unnamed module of loader 'app')
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.If_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
	at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$Serializer.apply(ExpressionEncoder.scala:222)
	... 8 more

import org.apache.spark.tags.SlowSQLTest

@SlowSQLTest
class TransformWithStateInPandasSuite extends StreamTest {
@bogao007 (Contributor, Author) commented Jun 27, 2024:

I will add more unit tests.

inputRows: Iterator["PandasDataFrameLike"]) -> Iterator["PandasDataFrameLike"]:
handle = StatefulProcessorHandle(state_api_client)

print(f"checking handle state: {state_api_client.handle_state}")
@bogao007 (Contributor, Author) commented:

I'm keeping all the prints and logWarnings for testing purposes; I will remove them in the final revision.

@HyukjinKwon (Member) commented:

Mind filing a JIRA?

@bogao007 (Contributor, Author) replied:

> Mind filing a JIRA?

Yeah, will do, thanks!

@bogao007 changed the title from "State V2 base implementation and ValueState support" to "[SPARK-48755] State V2 base implementation and ValueState support" on Jun 28, 2024