[SPARK-50856][SS][PYTHON][CONNECT] Spark Connect Support for TransformWithStateInPandas In Python #49560

Closed · 25 commits
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -1095,6 +1095,7 @@ def __hash__(self):
"pyspark.sql.tests.connect.pandas.test_parity_pandas_udf_scalar",
"pyspark.sql.tests.connect.pandas.test_parity_pandas_udf_grouped_agg",
"pyspark.sql.tests.connect.pandas.test_parity_pandas_udf_window",
"pyspark.sql.tests.connect.pandas.test_parity_pandas_transform_with_state",
Contributor:
Have we verified that this test actually runs in CI?

Contributor Author (@jingz-db, Feb 1, 2025):
Yeah, I think so. I got several failed test cases from this suite in a previous CI run (https://github.com/jingz-db/spark/actions/runs/13039529632/job/36378113583#step:12:4144); those are now fixed, and it confirms the suite actually runs in CI.

Contributor:
Great, thanks for confirming!

],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
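As a quick local check (outside CI), the newly registered module can be loaded with the standard unittest loader. This is only a sketch and assumes a dev install of PySpark with the pandas, pyarrow, and Spark Connect dependencies available; the module name comes from the modules.py entry above.

import unittest

# Hypothetical local run of the parity suite registered above.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "pyspark.sql.tests.connect.pandas.test_parity_pandas_transform_with_state"
)
unittest.TextTestRunner(verbosity=2).run(suite)
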
51 changes: 51 additions & 0 deletions python/pyspark/sql/connect/group.py
@@ -41,6 +41,7 @@
from pyspark.sql.column import Column
from pyspark.sql.connect.functions import builtin as F
from pyspark.errors import PySparkNotImplementedError, PySparkTypeError
from pyspark.sql.streaming.stateful_processor import StatefulProcessor

if TYPE_CHECKING:
from pyspark.sql.connect._typing import (
@@ -361,6 +362,56 @@ def applyInPandasWithState(

    applyInPandasWithState.__doc__ = PySparkGroupedData.applyInPandasWithState.__doc__

    def transformWithStateInPandas(
        self,
        statefulProcessor: StatefulProcessor,
        outputStructType: Union[StructType, str],
        outputMode: str,
        timeMode: str,
        initialState: Optional["GroupedData"] = None,
        eventTimeColumnName: str = "",
    ) -> "DataFrame":
        from pyspark.sql.connect.udf import UserDefinedFunction
        from pyspark.sql.connect.dataframe import DataFrame
        from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasUdfUtils

        udf_util = TransformWithStateInPandasUdfUtils(statefulProcessor, timeMode)
        if initialState is None:
            udf_obj = UserDefinedFunction(
                udf_util.transformWithStateUDF,
                returnType=outputStructType,
                evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
            )
            initial_state_plan = None
            initial_state_grouping_cols = None
        else:
            self._df._check_same_session(initialState._df)
            udf_obj = UserDefinedFunction(
                udf_util.transformWithStateWithInitStateUDF,
                returnType=outputStructType,
                evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
            )
            initial_state_plan = initialState._df._plan
            initial_state_grouping_cols = initialState._grouping_cols

        return DataFrame(
            plan.TransformWithStateInPandas(
                child=self._df._plan,
                grouping_cols=self._grouping_cols,
                function=udf_obj,
                output_schema=outputStructType,
                output_mode=outputMode,
                time_mode=timeMode,
                event_time_col_name=eventTimeColumnName,
                cols=self._df.columns,
                initial_state_plan=initial_state_plan,
                initial_state_grouping_cols=initial_state_grouping_cols,
            ),
            session=self._df._session,
        )

    transformWithStateInPandas.__doc__ = PySparkGroupedData.transformWithStateInPandas.__doc__

    def applyInArrow(
        self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str]
    ) -> "DataFrame":
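For context, a minimal, hypothetical usage sketch of the Connect method added above: the StatefulProcessor and StatefulProcessorHandle interfaces come from pyspark.sql.streaming.stateful_processor, while the RunningCountProcessor class, the column names, and the streaming DataFrame `df` are illustrative assumptions rather than part of this change.

from typing import Iterator

import pandas as pd

from pyspark.sql.streaming.stateful_processor import (
    StatefulProcessor,
    StatefulProcessorHandle,
)
from pyspark.sql.types import IntegerType, StructField, StructType


class RunningCountProcessor(StatefulProcessor):
    """Keeps a per-key running count in value state (illustrative only)."""

    def init(self, handle: StatefulProcessorHandle) -> None:
        schema = StructType([StructField("count", IntegerType(), True)])
        self._count = handle.getValueState("count", schema)

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        previous = self._count.get()[0] if self._count.exists() else 0
        new_count = previous + sum(len(batch) for batch in rows)
        self._count.update((new_count,))
        yield pd.DataFrame({"id": [key[0]], "count": [new_count]})

    def close(self) -> None:
        pass


# `df` is assumed to be a streaming DataFrame with an "id" column (e.g. from
# spark.readStream); a real run also needs a state store provider that supports
# transformWithState, such as the RocksDB provider.
result = df.groupBy("id").transformWithStateInPandas(
    statefulProcessor=RunningCountProcessor(),
    outputStructType="id string, count int",
    outputMode="Update",
    timeMode="None",  # no timers used in this sketch
)

The same call shape works on a Spark Connect session because the method above builds the TransformWithStateInPandas logical plan added in plan.py below.
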
68 changes: 68 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -2546,6 +2546,74 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
        return self._with_relations(plan, session)


class TransformWithStateInPandas(LogicalPlan):
Contributor:
Could we add some comments here?

Contributor Author:
Do you mean the class docstring? I have a one-line docstring on line 2550.

"""Logical plan object for a TransformWithStateInPandas."""

def __init__(
self,
child: Optional["LogicalPlan"],
grouping_cols: Sequence[Column],
function: "UserDefinedFunction",
output_schema: Union[DataType, str],
output_mode: str,
time_mode: str,
event_time_col_name: str,
cols: List[str],
initial_state_plan: Optional["LogicalPlan"],
initial_state_grouping_cols: Optional[Sequence[Column]],
):
assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols)
if initial_state_plan is not None:
assert isinstance(initial_state_grouping_cols, list) and all(
isinstance(c, Column) for c in initial_state_grouping_cols
)
super().__init__(
child, self._collect_references(grouping_cols + initial_state_grouping_cols)
)
else:
super().__init__(child, self._collect_references(grouping_cols))
self._grouping_cols = grouping_cols
self._output_schema: DataType = (
UnparsedDataType(output_schema) if isinstance(output_schema, str) else output_schema
)
self._output_mode = output_mode
self._time_mode = time_mode
self._event_time_col_name = event_time_col_name
self._function = function._build_common_inline_user_defined_function(*cols)
self._initial_state_plan = initial_state_plan
self._initial_state_grouping_cols = initial_state_grouping_cols

def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.group_map.input.CopyFrom(self._child.plan(session))
plan.group_map.grouping_expressions.extend(
[c.to_plan(session) for c in self._grouping_cols]
)
plan.group_map.output_mode = self._output_mode

# fill in initial state related fields
if self._initial_state_plan is not None:
plan.group_map.initial_input.CopyFrom(self._initial_state_plan.plan(session))
assert self._initial_state_grouping_cols is not None
plan.group_map.initial_grouping_expressions.extend(
[c.to_plan(session) for c in self._initial_state_grouping_cols]
)

# fill in transformWithStateInPandas related fields
tws_info = proto.TransformWithStateInfo()
tws_info.time_mode = self._time_mode
tws_info.event_time_column_name = self._event_time_col_name
tws_info.output_schema.CopyFrom(pyspark_types_to_proto_types(self._output_schema))

plan.group_map.transform_with_state_info.CopyFrom(tws_info)

# wrap transformWithStateInPandasUdf in a function
plan.group_map.func.CopyFrom(self._function.to_plan_udf(session))

return self._with_relations(plan, session)


class PythonUDTF:
"""Represents a Python user-defined table function."""

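As a rough illustration of what plan() above ends up producing, the following sketch builds the corresponding GroupMap fields by hand with the generated classes from pyspark.sql.connect.proto; the values are made up, and the input relation, grouping expressions, UDF, and output-schema plumbing are elided.

from pyspark.sql.connect import proto

rel = proto.Relation()
rel.group_map.output_mode = "Update"  # taken directly from the user's outputMode argument

tws_info = proto.TransformWithStateInfo()
tws_info.time_mode = "ProcessingTime"   # the user's timeMode string
tws_info.event_time_column_name = ""    # only meaningful with an event-time column
# tws_info.output_schema is set from pyspark_types_to_proto_types(output_schema) above
rel.group_map.transform_with_state_info.CopyFrom(tws_info)

# group_map.input, grouping_expressions, func, and (optionally) initial_input /
# initial_grouping_expressions are copied from the child plan and the wrapped UDF.
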
50 changes: 26 additions & 24 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.

88 changes: 88 additions & 0 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -3613,6 +3613,7 @@ class GroupMap(google.protobuf.message.Message):
OUTPUT_MODE_FIELD_NUMBER: builtins.int
TIMEOUT_CONF_FIELD_NUMBER: builtins.int
STATE_SCHEMA_FIELD_NUMBER: builtins.int
TRANSFORM_WITH_STATE_INFO_FIELD_NUMBER: builtins.int
@property
def input(self) -> global___Relation:
"""(Required) Input relation for Group Map API: apply, applyInPandas."""
@@ -3654,6 +3655,11 @@
@property
def state_schema(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
"""(Optional) The schema for the grouped state."""
@property
def transform_with_state_info(self) -> global___TransformWithStateInfo:
"""Below fields are used by TransformWithState and TransformWithStateInPandas
(Optional) TransformWithState related parameters.
"""
def __init__(
self,
*,
output_mode: builtins.str | None = ...,
timeout_conf: builtins.str | None = ...,
state_schema: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
transform_with_state_info: global___TransformWithStateInfo | None = ...,
) -> None: ...
def HasField(
self,
b"_state_schema",
"_timeout_conf",
b"_timeout_conf",
"_transform_with_state_info",
b"_transform_with_state_info",
"func",
b"func",
"initial_input",
b"state_schema",
"timeout_conf",
b"timeout_conf",
"transform_with_state_info",
b"transform_with_state_info",
],
) -> builtins.bool: ...
def ClearField(
b"_state_schema",
"_timeout_conf",
b"_timeout_conf",
"_transform_with_state_info",
b"_transform_with_state_info",
"func",
b"func",
"grouping_expressions",
b"state_schema",
"timeout_conf",
b"timeout_conf",
"transform_with_state_info",
b"transform_with_state_info",
],
) -> None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_timeout_conf", b"_timeout_conf"]
) -> typing_extensions.Literal["timeout_conf"] | None: ...
@typing.overload
def WhichOneof(
self,
oneof_group: typing_extensions.Literal[
"_transform_with_state_info", b"_transform_with_state_info"
],
) -> typing_extensions.Literal["transform_with_state_info"] | None: ...

global___GroupMap = GroupMap

class TransformWithStateInfo(google.protobuf.message.Message):
"""Additional input parameters used for TransformWithState operator."""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

TIME_MODE_FIELD_NUMBER: builtins.int
EVENT_TIME_COLUMN_NAME_FIELD_NUMBER: builtins.int
OUTPUT_SCHEMA_FIELD_NUMBER: builtins.int
time_mode: builtins.str
"""(Required) Time mode string for transformWithState."""
event_time_column_name: builtins.str
"""(Optional) Event time column name."""
@property
def output_schema(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
"""(Optional) Schema for the output DataFrame.
Only required for TransformWithStateInPandas.
"""
def __init__(
self,
*,
time_mode: builtins.str = ...,
event_time_column_name: builtins.str | None = ...,
output_schema: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_event_time_column_name",
b"_event_time_column_name",
"_output_schema",
b"_output_schema",
"event_time_column_name",
b"event_time_column_name",
"output_schema",
b"output_schema",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"_event_time_column_name",
b"_event_time_column_name",
"_output_schema",
b"_output_schema",
"event_time_column_name",
b"event_time_column_name",
"output_schema",
b"output_schema",
"time_mode",
b"time_mode",
],
) -> None: ...
@typing.overload
def WhichOneof(
self,
oneof_group: typing_extensions.Literal[
"_event_time_column_name", b"_event_time_column_name"
],
) -> typing_extensions.Literal["event_time_column_name"] | None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_output_schema", b"_output_schema"]
) -> typing_extensions.Literal["output_schema"] | None: ...

global___TransformWithStateInfo = TransformWithStateInfo

class CoGroupMap(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

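Since transform_with_state_info is declared as an optional sub-message on GroupMap, presence can be checked explicitly. A small illustrative snippet against the generated classes follows; the message and field names come from the stubs above, while the concrete values are made up.

from pyspark.sql.connect import proto

gm = proto.GroupMap()
assert not gm.HasField("transform_with_state_info")

info = proto.TransformWithStateInfo(
    time_mode="EventTime",
    event_time_column_name="eventTime",
)
gm.transform_with_state_info.CopyFrom(info)

assert gm.HasField("transform_with_state_info")
assert gm.transform_with_state_info.HasField("event_time_column_name")
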