From 88a9ed47895bc837d3d5217c9f8efaa4bc852053 Mon Sep 17 00:00:00 2001 From: zeruibao Date: Thu, 18 Sep 2025 10:52:13 -0700 Subject: [PATCH 01/12] save --- python/pyspark/sql/pandas/serializers.py | 19 +++++++++++--- python/pyspark/worker.py | 6 +++++ .../streaming/BaseStreamingArrowWriter.scala | 12 ++++++++- ...nsformWithStateInPySparkPythonRunner.scala | 18 ++++++++++--- .../BaseStreamingArrowWriterSuite.scala | 26 +++++++++++++++++-- 5 files changed, 72 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index d1bdfa9e8d01e..010fb2d50de9a 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -20,7 +20,7 @@ """ from decimal import Decimal -from itertools import groupby +from itertools import groupby, accumulate from typing import TYPE_CHECKING, Optional import pyspark @@ -1582,6 +1582,7 @@ def __init__( safecheck, assign_cols_by_name, arrow_max_records_per_batch, + arrow_max_bytes_per_batch, int_to_decimal_coercion_enabled, ): super(TransformWithStateInPandasSerializer, self).__init__( @@ -1592,6 +1593,7 @@ def __init__( arrow_cast=True, ) self.arrow_max_records_per_batch = arrow_max_records_per_batch + self.arrow_max_bytes_per_batch = arrow_max_bytes_per_batch self.key_offsets = None def load_stream(self, stream): @@ -1607,6 +1609,7 @@ def load_stream(self, stream): from pyspark.sql.streaming.stateful_processor_util import ( TransformWithStateInPandasFuncMode, ) + import sys def generate_data_batches(batches): """ @@ -1630,8 +1633,18 @@ def row_stream(): yield (batch_key, row) for batch_key, group_rows in groupby(row_stream(), key=lambda x: x[0]): - df = pd.DataFrame([row for _, row in group_rows]) - yield (batch_key, df) + rows = [] + accumulate_size = 0 + for _, row in group_rows: + rows.append(row) + accumulate_size += sum(sys.getsizeof(x) for x in row) + if (len(rows) >= self.arrow_max_records_per_batch or + accumulate_size >= self.arrow_max_bytes_per_batch): + yield (batch_key, pd.DataFrame(rows)) + rows = [] + accumulate_size = 0 + if rows: + yield (batch_key, pd.DataFrame(rows)) _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) data_batches = generate_data_batches(_batches) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 466bb79246587..cf0fd76a95054 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -2577,11 +2577,17 @@ def read_udfs(pickleSer, infile, eval_type): ) arrow_max_records_per_batch = int(arrow_max_records_per_batch) + arrow_max_bytes_per_batch = runner_conf.get( + "spark.sql.execution.arrow.maxBytesPerBatch", 64*1024*1024 + ) + arrow_max_bytes_per_batch = int(arrow_max_bytes_per_batch) + ser = TransformWithStateInPandasSerializer( timezone, safecheck, _assign_cols_by_name, arrow_max_records_per_batch, + arrow_max_bytes_per_batch, int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, ) elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala index ba8b2c3ac7daa..f0371cafb72a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala @@ -32,6 +32,7 @@ class 
BaseStreamingArrowWriter( root: VectorSchemaRoot, writer: ArrowStreamWriter, arrowMaxRecordsPerBatch: Int, + arrowMaxBytesPerBatch: Long, arrowWriterForTest: ArrowWriter = null) { protected val arrowWriterForData: ArrowWriter = if (arrowWriterForTest == null) { ArrowWriter.create(root) @@ -54,7 +55,7 @@ class BaseStreamingArrowWriter( // If it exceeds the condition of batch (number of records) and there is more data for the // same group, finalize and construct a new batch. - val isCurrentBatchFull = totalNumRowsForBatch >= arrowMaxRecordsPerBatch + val isCurrentBatchFull = isBatchSizeLimitReached if (isCurrentBatchFull) { finalizeCurrentChunk(isLastChunkForGroup = false) finalizeCurrentArrowBatch() @@ -84,4 +85,13 @@ class BaseStreamingArrowWriter( protected def finalizeCurrentChunk(isLastChunkForGroup: Boolean): Unit = { numRowsForCurrentChunk = 0 } + + protected def isBatchSizeLimitReached: Boolean = { + // If we have either reached the records or bytes limit + totalNumRowsForBatch >= arrowMaxRecordsPerBatch || + // Short circuit batch size calculation if the batch size is unlimited as computing batch + // size is computationally expensive. + ((arrowMaxBytesPerBatch != Int.MaxValue) + && (arrowWriterForData.sizeInBytes() >= arrowMaxBytesPerBatch)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala index f0df3e1f7d15c..42d4ad68c29a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala @@ -75,7 +75,12 @@ class TransformWithStateInPySparkPythonRunner( dataOut: DataOutputStream, inputIterator: Iterator[InType]): Boolean = { if (pandasWriter == null) { - pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch) + pandasWriter = new BaseStreamingArrowWriter( + root, + writer, + arrowMaxRecordsPerBatch, + arrowMaxBytesPerBatch + ) } // If we don't have data left for the current group, move to the next group. @@ -145,7 +150,12 @@ class TransformWithStateInPySparkPythonInitialStateRunner( dataOut: DataOutputStream, inputIterator: Iterator[GroupedInType]): Boolean = { if (pandasWriter == null) { - pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch) + pandasWriter = new BaseStreamingArrowWriter( + root, + writer, + arrowMaxRecordsPerBatch, + arrowMaxBytesPerBatch + ) } if (inputIterator.hasNext) { @@ -200,9 +210,11 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I]( protected val sqlConf = SQLConf.get protected val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch + protected val arrowMaxBytesPerBatch = sqlConf.arrowMaxBytesPerBatch override protected val workerConf: Map[String, String] = initialWorkerConf + - (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) + (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) + + (SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString) // Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s // constructor. 
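
The Scala-side change above hinges on a single predicate: the current Arrow batch is finalized once either the record-count limit or the byte-size limit is reached, and the byte-size computation is skipped when the byte limit is effectively unlimited (Int.MaxValue), since ArrowWriter.sizeInBytes() is comparatively expensive. A minimal standalone sketch of that check, with a by-name sizeInBytes parameter standing in for arrowWriterForData.sizeInBytes() (a sketch for illustration, not part of the patch):

    // Illustrative sketch of the combined records/bytes limit check.
    object BatchSizeLimitSketch {
      def isBatchSizeLimitReached(
          totalNumRowsForBatch: Int,
          sizeInBytes: => Long,
          arrowMaxRecordsPerBatch: Int,
          arrowMaxBytesPerBatch: Long): Boolean = {
        // Check the record-count limit first; the bytes comparison short-circuits when
        // the byte limit is unlimited, so sizeInBytes is only evaluated when a real
        // byte limit has been configured.
        totalNumRowsForBatch >= arrowMaxRecordsPerBatch ||
          (arrowMaxBytesPerBatch != Int.MaxValue && sizeInBytes >= arrowMaxBytesPerBatch)
      }
    }
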
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala index f0fee2b9b0d9c..188376b9fe26d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.python.streaming import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.ipc.ArrowStreamWriter -import org.mockito.Mockito.{mock, never, times, verify} +import org.mockito.Mockito.{mock, never, times, verify, when} import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite @@ -28,6 +28,7 @@ import org.apache.spark.sql.execution.arrow.ArrowWriter class BaseStreamingArrowWriterSuite extends SparkFunSuite with BeforeAndAfterEach { // Setting the maximum number of records per batch to 2 to make test easier. val arrowMaxRecordsPerBatch = 2 + val arrowMaxBytesPerBatch = Int.MaxValue var transformWithStateInPySparkWriter: BaseStreamingArrowWriter = _ var arrowWriter: ArrowWriter = _ var writer: ArrowStreamWriter = _ @@ -37,7 +38,7 @@ class BaseStreamingArrowWriterSuite extends SparkFunSuite with BeforeAndAfterEac writer = mock(classOf[ArrowStreamWriter]) arrowWriter = mock(classOf[ArrowWriter]) transformWithStateInPySparkWriter = new BaseStreamingArrowWriter( - root, writer, arrowMaxRecordsPerBatch, arrowWriter) + root, writer, arrowMaxRecordsPerBatch, arrowMaxBytesPerBatch, arrowWriter) } test("test writeRow") { @@ -64,4 +65,25 @@ class BaseStreamingArrowWriterSuite extends SparkFunSuite with BeforeAndAfterEac verify(writer).writeBatch() verify(arrowWriter).reset() } + + test("test maxBytesPerBatch can work") { + val root: VectorSchemaRoot = mock(classOf[VectorSchemaRoot]) + when(arrowWriter.sizeInBytes()).thenReturn(2) + // Set arrowMaxBytesPerBatch to 1 + transformWithStateInPySparkWriter = new BaseStreamingArrowWriter( + root, writer, arrowMaxRecordsPerBatch, 1, arrowWriter) + val dataRow = mock(classOf[InternalRow]) + transformWithStateInPySparkWriter.writeRow(dataRow) + verify(arrowWriter).write(dataRow) + verify(writer, never()).writeBatch() + transformWithStateInPySparkWriter.writeRow(dataRow) + verify(arrowWriter, times(2)).write(dataRow) + // Write batch is called since we reach arrowMaxBytesPerBatch + verify(writer).writeBatch() + transformWithStateInPySparkWriter.finalizeCurrentArrowBatch() + verify(arrowWriter, times(2)).finish() + // The second record would be written + verify(writer, times(2)).writeBatch() + verify(arrowWriter, times(2)).reset() + } } From c526618674554d6a72853c752cb9492a8e79c431 Mon Sep 17 00:00:00 2001 From: zeruibao Date: Thu, 18 Sep 2025 10:53:42 -0700 Subject: [PATCH 02/12] save --- python/pyspark/sql/pandas/serializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 010fb2d50de9a..9c7cfaca6143c 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -20,7 +20,7 @@ """ from decimal import Decimal -from itertools import groupby, accumulate +from itertools import groupby from typing import TYPE_CHECKING, Optional import pyspark From d6846af4dfd073746df001560f24b0ffa03daf3a Mon Sep 17 00:00:00 2001 From: zeruibao Date: 
Thu, 18 Sep 2025 11:48:39 -0700 Subject: [PATCH 03/12] save --- .../helper_pandas_transform_with_state.py | 17 +++++ .../test_pandas_transform_with_state.py | 62 +++++++++++++++++++ .../ApplyInPandasWithStatePythonRunner.scala | 10 ++- .../ApplyInPandasWithStateWriter.scala | 5 +- 4 files changed, 90 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py index d258f693ccb88..490dd96766646 100644 --- a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py @@ -236,6 +236,9 @@ def pandas(self): def row(self): return RowStatefulProcessorCompositeType() +class ChunkCountProcessorFactory(StatefulProcessorFactory): + def pandas(self): + return PandasChunkCountProcessor() # StatefulProcessor implementations @@ -1822,3 +1825,17 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]: def close(self) -> None: pass + +class PandasChunkCountProcessor(StatefulProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + pass + + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: + chunk_count = 0 + for _ in rows: + chunk_count += 1 + yield pd.DataFrame({'id': [key[0]], 'chunkCount': [chunk_count]}) + + + def close(self) -> None: + pass \ No newline at end of file diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 6d79a8c267531..ca433c96aeddf 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -70,6 +70,7 @@ UpcastProcessorFactory, MinEventTimeStatefulProcessorFactory, StatefulProcessorCompositeTypeFactory, + ChunkCountProcessorFactory ) @@ -1864,6 +1865,67 @@ def close(self): .collect() ) + def test_transform_with_state_with_bytes_limit(self): + if not self.use_pandas(): + return + + def make_check_results(expected_per_batch): + def check_results(batch_df, batch_id): + batch_df.collect() + if batch_id == 0: + assert set(batch_df.sort("id").collect()) == expected_per_batch[0] + else: + assert set(batch_df.sort("id").collect()) == expected_per_batch[1] + return check_results + + with self.sql_conf( + # Set it to a very small number so that every row would be a separate pandas df + {"spark.sql.execution.arrow.maxBytesPerBatch": "2"} + ): + self._test_transform_with_state_basic( + ChunkCountProcessorFactory(), + make_check_results( + [ + { + Row(id="0", chunkCount=2), + Row(id="1", chunkCount=2), + }, + { + Row(id="0", chunkCount=3), + Row(id="1", chunkCount=2), + } + ] + ), + output_schema=StructType([ + StructField("id", StringType(), True), + StructField("chunkCount", IntegerType(), True) + ]) + ) + + with self.sql_conf( + # Set it to a very large number so that every row would be in the same pandas df + {"spark.sql.execution.arrow.maxBytesPerBatch": "100000"} + ): + self._test_transform_with_state_basic( + ChunkCountProcessorFactory(), + make_check_results( + [ + { + Row(id="0", chunkCount=1), + Row(id="1", chunkCount=1), + }, + { + Row(id="0", chunkCount=1), + Row(id="1", chunkCount=1), + } + ] + ), + output_schema=StructType([ + StructField("id", StringType(), True), + StructField("chunkCount", IntegerType(), True) + ]) + ) + @unittest.skipIf( not have_pyarrow or os.environ.get("PYTHON_GIL", "?") 
== "0", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala index b6f6a4cbc30b6..477592ae31536 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala @@ -106,12 +106,14 @@ class ApplyInPandasWithStatePythonRunner( } private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch + private val arrowMaxBytesPerBatch = sqlConf.arrowMaxBytesPerBatch // applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance. // Configurations are both applied to executor and Python worker, set them to the worker conf // to let Python worker read the config properly. override protected val workerConf: Map[String, String] = initialWorkerConf + - (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) + (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) + + (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxBytesPerBatch.toString) private val stateRowDeserializer = stateEncoder.createDeserializer() @@ -142,7 +144,11 @@ class ApplyInPandasWithStatePythonRunner( dataOut: DataOutputStream, inputIterator: Iterator[InType]): Boolean = { if (pandasWriter == null) { - pandasWriter = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch) + pandasWriter = new ApplyInPandasWithStateWriter( + root, + writer, + arrowMaxRecordsPerBatch, + arrowMaxBytesPerBatch) } if (inputIterator.hasNext) { val startData = dataOut.size() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStateWriter.scala index f55ca749112fb..3b8fdfe910d51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStateWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStateWriter.scala @@ -50,8 +50,9 @@ import org.apache.spark.unsafe.types.UTF8String class ApplyInPandasWithStateWriter( root: VectorSchemaRoot, writer: ArrowStreamWriter, - arrowMaxRecordsPerBatch: Int) - extends BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch) { + arrowMaxRecordsPerBatch: Int, + arrowMaxBytesPerBatch: Long) + extends BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch, arrowMaxRecordsPerBatch) { import ApplyInPandasWithStateWriter._ From 294bda53fdba3678e4d7e7ed594af923a841d7f2 Mon Sep 17 00:00:00 2001 From: zeruibao Date: Thu, 18 Sep 2025 12:19:21 -0700 Subject: [PATCH 04/12] save --- python/pyspark/sql/pandas/serializers.py | 2 + .../helper_pandas_transform_with_state.py | 22 ++++++ .../test_pandas_transform_with_state.py | 73 ++++++++++++------- python/pyspark/worker.py | 6 ++ 4 files changed, 78 insertions(+), 25 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 9c7cfaca6143c..520c0b9f03b59 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1689,6 +1689,7 @@ def __init__( safecheck, assign_cols_by_name, arrow_max_records_per_batch, + arrow_max_bytes_per_batch, 
int_to_decimal_coercion_enabled, ): super(TransformWithStateInPandasInitStateSerializer, self).__init__( @@ -1696,6 +1697,7 @@ def __init__( safecheck, assign_cols_by_name, arrow_max_records_per_batch, + arrow_max_bytes_per_batch, int_to_decimal_coercion_enabled, ) self.init_key_offsets = None diff --git a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py index 490dd96766646..0243c7bbc7165 100644 --- a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py @@ -240,6 +240,10 @@ class ChunkCountProcessorFactory(StatefulProcessorFactory): def pandas(self): return PandasChunkCountProcessor() +class ChunkCountProcessorWithInitialStateFactory(StatefulProcessorFactory): + def pandas(self): + return PandasChunkCountWithInitialStateProcessor() + # StatefulProcessor implementations @@ -1837,5 +1841,23 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: yield pd.DataFrame({'id': [key[0]], 'chunkCount': [chunk_count]}) + def close(self) -> None: + pass + +class PandasChunkCountWithInitialStateProcessor(StatefulProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + state_schema = StructType([StructField("value", IntegerType(), True)]) + self.value_state = handle.getValueState("value_state", state_schema) + + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: + chunk_count = 0 + for _ in rows: + chunk_count += 1 + yield pd.DataFrame({'id': [key[0]], 'chunkCount': [chunk_count]}) + + def handleInitialState(self, key, initialState, timerValues) -> None: + init_val = initialState.at[0, "initVal"] + self.value_state.update((init_val,)) + def close(self) -> None: pass \ No newline at end of file diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index ca433c96aeddf..58de0fa0102f2 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -70,7 +70,8 @@ UpcastProcessorFactory, MinEventTimeStatefulProcessorFactory, StatefulProcessorCompositeTypeFactory, - ChunkCountProcessorFactory + ChunkCountProcessorFactory, + ChunkCountProcessorWithInitialStateFactory ) @@ -1878,24 +1879,48 @@ def check_results(batch_df, batch_id): assert set(batch_df.sort("id").collect()) == expected_per_batch[1] return check_results + result_with_small_limit = [ + { + Row(id="0", chunkCount=2), + Row(id="1", chunkCount=2), + }, + { + Row(id="0", chunkCount=3), + Row(id="1", chunkCount=2), + } + ] + + result_with_large_limit = [ + { + Row(id="0", chunkCount=1), + Row(id="1", chunkCount=1), + }, + { + Row(id="0", chunkCount=1), + Row(id="1", chunkCount=1), + } + ] + + data = [("0", 789), ("3", 987)] + initial_state = self.spark.createDataFrame(data, "id string, initVal int").groupBy("id") + with self.sql_conf( # Set it to a very small number so that every row would be a separate pandas df {"spark.sql.execution.arrow.maxBytesPerBatch": "2"} ): self._test_transform_with_state_basic( ChunkCountProcessorFactory(), - make_check_results( - [ - { - Row(id="0", chunkCount=2), - Row(id="1", chunkCount=2), - }, - { - Row(id="0", chunkCount=3), - Row(id="1", chunkCount=2), - } - ] - ), + make_check_results(result_with_small_limit), + output_schema=StructType([ + 
StructField("id", StringType(), True), + StructField("chunkCount", IntegerType(), True) + ]) + ) + + self._test_transform_with_state_basic( + ChunkCountProcessorWithInitialStateFactory(), + make_check_results(result_with_small_limit), + initial_state=initial_state, output_schema=StructType([ StructField("id", StringType(), True), StructField("chunkCount", IntegerType(), True) @@ -1908,24 +1933,22 @@ def check_results(batch_df, batch_id): ): self._test_transform_with_state_basic( ChunkCountProcessorFactory(), - make_check_results( - [ - { - Row(id="0", chunkCount=1), - Row(id="1", chunkCount=1), - }, - { - Row(id="0", chunkCount=1), - Row(id="1", chunkCount=1), - } - ] - ), + make_check_results(result_with_large_limit), output_schema=StructType([ StructField("id", StringType(), True), StructField("chunkCount", IntegerType(), True) ]) ) + self._test_transform_with_state_basic( + ChunkCountProcessorWithInitialStateFactory(), + make_check_results(result_with_large_limit), + initial_state=initial_state, + output_schema=StructType([ + StructField("id", StringType(), True), + StructField("chunkCount", IntegerType(), True) + ]) + ) @unittest.skipIf( not have_pyarrow or os.environ.get("PYTHON_GIL", "?") == "0", diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index cf0fd76a95054..20a3f72961a35 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -2596,11 +2596,17 @@ def read_udfs(pickleSer, infile, eval_type): ) arrow_max_records_per_batch = int(arrow_max_records_per_batch) + arrow_max_bytes_per_batch = runner_conf.get( + "spark.sql.execution.arrow.maxBytesPerBatch", 64*1024*1024 + ) + arrow_max_bytes_per_batch = int(arrow_max_bytes_per_batch) + ser = TransformWithStateInPandasInitStateSerializer( timezone, safecheck, _assign_cols_by_name, arrow_max_records_per_batch, + arrow_max_bytes_per_batch, int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, ) elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF: From 4ddeec6ee13e7da86ad9a5a02dce5f4fbfc9c55d Mon Sep 17 00:00:00 2001 From: zeruibao Date: Thu, 18 Sep 2025 13:06:47 -0700 Subject: [PATCH 05/12] save --- python/pyspark/sql/pandas/serializers.py | 5 ++++- python/pyspark/worker.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 520c0b9f03b59..d02f1c5717d54 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1637,7 +1637,10 @@ def row_stream(): accumulate_size = 0 for _, row in group_rows: rows.append(row) - accumulate_size += sum(sys.getsizeof(x) for x in row) + # Short circuit batch size calculation if the batch size is + # unlimited as computing batch size is computationally expensive. 
+ if self.arrow_max_bytes_per_batch != 2**31-1: + accumulate_size += sum(sys.getsizeof(x) for x in row) if (len(rows) >= self.arrow_max_records_per_batch or accumulate_size >= self.arrow_max_bytes_per_batch): yield (batch_key, pd.DataFrame(rows)) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 20a3f72961a35..b1cf3bf978212 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -2578,7 +2578,7 @@ def read_udfs(pickleSer, infile, eval_type): arrow_max_records_per_batch = int(arrow_max_records_per_batch) arrow_max_bytes_per_batch = runner_conf.get( - "spark.sql.execution.arrow.maxBytesPerBatch", 64*1024*1024 + "spark.sql.execution.arrow.maxBytesPerBatch", 2**31 - 1 ) arrow_max_bytes_per_batch = int(arrow_max_bytes_per_batch) @@ -2597,7 +2597,7 @@ def read_udfs(pickleSer, infile, eval_type): arrow_max_records_per_batch = int(arrow_max_records_per_batch) arrow_max_bytes_per_batch = runner_conf.get( - "spark.sql.execution.arrow.maxBytesPerBatch", 64*1024*1024 + "spark.sql.execution.arrow.maxBytesPerBatch", 2**31 - 1 ) arrow_max_bytes_per_batch = int(arrow_max_bytes_per_batch) From 88255859f0b007c1edf9093c2f5ed3d7691b1fd1 Mon Sep 17 00:00:00 2001 From: zeruibao Date: Thu, 18 Sep 2025 17:06:53 -0700 Subject: [PATCH 06/12] save --- python/pyspark/sql/pandas/serializers.py | 8 ++- .../helper_pandas_transform_with_state.py | 12 ++-- .../test_pandas_transform_with_state.py | 56 +++++++++++-------- 3 files changed, 46 insertions(+), 30 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index d02f1c5717d54..221509987001c 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1639,10 +1639,12 @@ def row_stream(): rows.append(row) # Short circuit batch size calculation if the batch size is # unlimited as computing batch size is computationally expensive. 
- if self.arrow_max_bytes_per_batch != 2**31-1: + if self.arrow_max_bytes_per_batch != 2**31 - 1: accumulate_size += sum(sys.getsizeof(x) for x in row) - if (len(rows) >= self.arrow_max_records_per_batch or - accumulate_size >= self.arrow_max_bytes_per_batch): + if ( + len(rows) >= self.arrow_max_records_per_batch + or accumulate_size >= self.arrow_max_bytes_per_batch + ): yield (batch_key, pd.DataFrame(rows)) rows = [] accumulate_size = 0 diff --git a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py index 0243c7bbc7165..5b65a6c3b98ae 100644 --- a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py @@ -236,14 +236,17 @@ def pandas(self): def row(self): return RowStatefulProcessorCompositeType() + class ChunkCountProcessorFactory(StatefulProcessorFactory): def pandas(self): return PandasChunkCountProcessor() + class ChunkCountProcessorWithInitialStateFactory(StatefulProcessorFactory): def pandas(self): return PandasChunkCountWithInitialStateProcessor() + # StatefulProcessor implementations @@ -1830,6 +1833,7 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]: def close(self) -> None: pass + class PandasChunkCountProcessor(StatefulProcessor): def init(self, handle: StatefulProcessorHandle) -> None: pass @@ -1838,12 +1842,12 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: chunk_count = 0 for _ in rows: chunk_count += 1 - yield pd.DataFrame({'id': [key[0]], 'chunkCount': [chunk_count]}) - + yield pd.DataFrame({"id": [key[0]], "chunkCount": [chunk_count]}) def close(self) -> None: pass + class PandasChunkCountWithInitialStateProcessor(StatefulProcessor): def init(self, handle: StatefulProcessorHandle) -> None: state_schema = StructType([StructField("value", IntegerType(), True)]) @@ -1853,11 +1857,11 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: chunk_count = 0 for _ in rows: chunk_count += 1 - yield pd.DataFrame({'id': [key[0]], 'chunkCount': [chunk_count]}) + yield pd.DataFrame({"id": [key[0]], "chunkCount": [chunk_count]}) def handleInitialState(self, key, initialState, timerValues) -> None: init_val = initialState.at[0, "initVal"] self.value_state.update((init_val,)) def close(self) -> None: - pass \ No newline at end of file + pass diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 58de0fa0102f2..679ced1c134d1 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -71,7 +71,7 @@ MinEventTimeStatefulProcessorFactory, StatefulProcessorCompositeTypeFactory, ChunkCountProcessorFactory, - ChunkCountProcessorWithInitialStateFactory + ChunkCountProcessorWithInitialStateFactory, ) @@ -1877,6 +1877,7 @@ def check_results(batch_df, batch_id): assert set(batch_df.sort("id").collect()) == expected_per_batch[0] else: assert set(batch_df.sort("id").collect()) == expected_per_batch[1] + return check_results result_with_small_limit = [ @@ -1887,7 +1888,7 @@ def check_results(batch_df, batch_id): { Row(id="0", chunkCount=3), Row(id="1", chunkCount=2), - } + }, ] result_with_large_limit = [ @@ -1898,58 +1899,67 @@ def check_results(batch_df, batch_id): { Row(id="0", chunkCount=1), Row(id="1", chunkCount=1), 
- } + }, ] data = [("0", 789), ("3", 987)] initial_state = self.spark.createDataFrame(data, "id string, initVal int").groupBy("id") with self.sql_conf( - # Set it to a very small number so that every row would be a separate pandas df - {"spark.sql.execution.arrow.maxBytesPerBatch": "2"} + # Set it to a very small number so that every row would be a separate pandas df + {"spark.sql.execution.arrow.maxBytesPerBatch": "2"} ): self._test_transform_with_state_basic( ChunkCountProcessorFactory(), make_check_results(result_with_small_limit), - output_schema=StructType([ - StructField("id", StringType(), True), - StructField("chunkCount", IntegerType(), True) - ]) + output_schema=StructType( + [ + StructField("id", StringType(), True), + StructField("chunkCount", IntegerType(), True), + ] + ), ) self._test_transform_with_state_basic( ChunkCountProcessorWithInitialStateFactory(), make_check_results(result_with_small_limit), initial_state=initial_state, - output_schema=StructType([ - StructField("id", StringType(), True), - StructField("chunkCount", IntegerType(), True) - ]) + output_schema=StructType( + [ + StructField("id", StringType(), True), + StructField("chunkCount", IntegerType(), True), + ] + ), ) with self.sql_conf( - # Set it to a very large number so that every row would be in the same pandas df - {"spark.sql.execution.arrow.maxBytesPerBatch": "100000"} + # Set it to a very large number so that every row would be in the same pandas df + {"spark.sql.execution.arrow.maxBytesPerBatch": "100000"} ): self._test_transform_with_state_basic( ChunkCountProcessorFactory(), make_check_results(result_with_large_limit), - output_schema=StructType([ - StructField("id", StringType(), True), - StructField("chunkCount", IntegerType(), True) - ]) + output_schema=StructType( + [ + StructField("id", StringType(), True), + StructField("chunkCount", IntegerType(), True), + ] + ), ) self._test_transform_with_state_basic( ChunkCountProcessorWithInitialStateFactory(), make_check_results(result_with_large_limit), initial_state=initial_state, - output_schema=StructType([ - StructField("id", StringType(), True), - StructField("chunkCount", IntegerType(), True) - ]) + output_schema=StructType( + [ + StructField("id", StringType(), True), + StructField("chunkCount", IntegerType(), True), + ] + ), ) + @unittest.skipIf( not have_pyarrow or os.environ.get("PYTHON_GIL", "?") == "0", cast(str, pyarrow_requirement_message or "Not supported in no-GIL mode"), From deeb98892475b7b5c1551bae4b4f86299e323329 Mon Sep 17 00:00:00 2001 From: zeruibao Date: Thu, 18 Sep 2025 17:27:49 -0700 Subject: [PATCH 07/12] save --- .../streaming/BaseStreamingArrowWriterSuite.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala index 188376b9fe26d..fc10a102b4f55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.python.streaming import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.ipc.ArrowStreamWriter +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{mock, never, times, verify, when} import 
org.scalatest.BeforeAndAfterEach @@ -68,7 +69,15 @@ class BaseStreamingArrowWriterSuite extends SparkFunSuite with BeforeAndAfterEac test("test maxBytesPerBatch can work") { val root: VectorSchemaRoot = mock(classOf[VectorSchemaRoot]) - when(arrowWriter.sizeInBytes()).thenReturn(2) + + var sizeCounter = 0 + when(arrowWriter.write(any[InternalRow])).thenAnswer { _ => + sizeCounter += 1 + () + } + + when(arrowWriter.sizeInBytes()).thenAnswer { _ => sizeCounter } + // Set arrowMaxBytesPerBatch to 1 transformWithStateInPySparkWriter = new BaseStreamingArrowWriter( root, writer, arrowMaxRecordsPerBatch, 1, arrowWriter) From 7b11d9051cc7272cb6837add1b69254dbe7c3fd8 Mon Sep 17 00:00:00 2001 From: zeruibao Date: Fri, 19 Sep 2025 11:42:59 -0700 Subject: [PATCH 08/12] save --- .../python/streaming/ApplyInPandasWithStateWriter.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStateWriter.scala index 3b8fdfe910d51..cd83270bb4c4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStateWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStateWriter.scala @@ -52,7 +52,7 @@ class ApplyInPandasWithStateWriter( writer: ArrowStreamWriter, arrowMaxRecordsPerBatch: Int, arrowMaxBytesPerBatch: Long) - extends BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch, arrowMaxRecordsPerBatch) { + extends BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch, arrowMaxBytesPerBatch) { import ApplyInPandasWithStateWriter._ @@ -145,7 +145,7 @@ class ApplyInPandasWithStateWriter( // If it exceeds the condition of batch (number of records) once the all data is received for // same group, finalize and construct a new batch. - if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) { + if (isBatchSizeLimitReached) { finalizeCurrentArrowBatch() } } From a1768a30e1e7d6841a5e6d165d33590227473b3d Mon Sep 17 00:00:00 2001 From: Zerui Bao <125398515+zeruibao@users.noreply.github.com> Date: Wed, 24 Sep 2025 15:32:51 -0700 Subject: [PATCH 09/12] Update ApplyInPandasWithStatePythonRunner.scala --- .../python/streaming/ApplyInPandasWithStatePythonRunner.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala index 477592ae31536..51d9f6f523a23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala @@ -113,7 +113,7 @@ class ApplyInPandasWithStatePythonRunner( // to let Python worker read the config properly. 
override protected val workerConf: Map[String, String] = initialWorkerConf + (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) + - (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxBytesPerBatch.toString) + (SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString) private val stateRowDeserializer = stateEncoder.createDeserializer() From 3962eee5677f34f98d7dd04b2acfc56ba44bd126 Mon Sep 17 00:00:00 2001 From: zeruibao Date: Wed, 1 Oct 2025 20:36:55 -0700 Subject: [PATCH 10/12] save --- python/pyspark/sql/pandas/serializers.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 221509987001c..ec00e7fb68b8d 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1611,6 +1611,14 @@ def load_stream(self, stream): ) import sys + def average_row_size(record_batch: pa.RecordBatch) -> float: + total_bytes = 0 + for col in record_batch.columns: + for buf in col.buffers(): + if buf is not None: + total_bytes += buf.size + return total_bytes / record_batch.num_rows if record_batch.num_rows > 0 else 0.0 + def generate_data_batches(batches): """ Deserialize ArrowRecordBatches and return a generator of Rows. @@ -1622,8 +1630,13 @@ def generate_data_batches(batches): same time. And data chunks from the same grouping key should appear sequentially. """ + average_arrow_row_size = 0 def row_stream(): for batch in batches: + # Short circuit batch size calculation if the batch size is + # unlimited as computing batch size is computationally expensive. + if self.arrow_max_bytes_per_batch != 2**31 - 1: + average_arrow_row_size = average_row_size(batch) data_pandas = [ self.arrow_to_pandas(c, i) for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns()) @@ -1634,20 +1647,14 @@ def row_stream(): for batch_key, group_rows in groupby(row_stream(), key=lambda x: x[0]): rows = [] - accumulate_size = 0 for _, row in group_rows: rows.append(row) - # Short circuit batch size calculation if the batch size is - # unlimited as computing batch size is computationally expensive. 
- if self.arrow_max_bytes_per_batch != 2**31 - 1: - accumulate_size += sum(sys.getsizeof(x) for x in row) if ( len(rows) >= self.arrow_max_records_per_batch - or accumulate_size >= self.arrow_max_bytes_per_batch + or len(rows) * average_arrow_row_size >= self.arrow_max_bytes_per_batch ): yield (batch_key, pd.DataFrame(rows)) rows = [] - accumulate_size = 0 if rows: yield (batch_key, pd.DataFrame(rows)) From 6105de8fc799d3cf02677248536a04894bf5708c Mon Sep 17 00:00:00 2001 From: zeruibao Date: Wed, 1 Oct 2025 22:25:16 -0700 Subject: [PATCH 11/12] save --- python/pyspark/sql/pandas/serializers.py | 26 +++++++++++++----------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index ec00e7fb68b8d..ec71d8d8b43f5 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1595,6 +1595,9 @@ def __init__( self.arrow_max_records_per_batch = arrow_max_records_per_batch self.arrow_max_bytes_per_batch = arrow_max_bytes_per_batch self.key_offsets = None + self.average_arrow_row_size = 0 + self.total_bytes = 0 + self.total_rows = 0 def load_stream(self, stream): """ @@ -1611,14 +1614,6 @@ def load_stream(self, stream): ) import sys - def average_row_size(record_batch: pa.RecordBatch) -> float: - total_bytes = 0 - for col in record_batch.columns: - for buf in col.buffers(): - if buf is not None: - total_bytes += buf.size - return total_bytes / record_batch.num_rows if record_batch.num_rows > 0 else 0.0 - def generate_data_batches(batches): """ Deserialize ArrowRecordBatches and return a generator of Rows. @@ -1630,13 +1625,20 @@ def generate_data_batches(batches): same time. And data chunks from the same grouping key should appear sequentially. """ - average_arrow_row_size = 0 def row_stream(): for batch in batches: # Short circuit batch size calculation if the batch size is # unlimited as computing batch size is computationally expensive. - if self.arrow_max_bytes_per_batch != 2**31 - 1: - average_arrow_row_size = average_row_size(batch) + if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0: + batch_bytes = sum( + buf.size + for col in batch.columns + for buf in col.buffers() + if buf is not None + ) + self.total_bytes += batch_bytes + self.total_rows += batch.num_rows + self.average_arrow_row_size = self.total_bytes / self.total_rows data_pandas = [ self.arrow_to_pandas(c, i) for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns()) @@ -1651,7 +1653,7 @@ def row_stream(): rows.append(row) if ( len(rows) >= self.arrow_max_records_per_batch - or len(rows) * average_arrow_row_size >= self.arrow_max_bytes_per_batch + or len(rows) * self.average_arrow_row_size >= self.arrow_max_bytes_per_batch ): yield (batch_key, pd.DataFrame(rows)) rows = [] From 266207c2c19422dc2ddd34239b9ca607e0fcd586 Mon Sep 17 00:00:00 2001 From: zeruibao Date: Thu, 2 Oct 2025 00:06:26 -0700 Subject: [PATCH 12/12] save --- python/pyspark/sql/pandas/serializers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index ec71d8d8b43f5..565621509fc3d 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -1612,7 +1612,6 @@ def load_stream(self, stream): from pyspark.sql.streaming.stateful_processor_util import ( TransformWithStateInPandasFuncMode, ) - import sys def generate_data_batches(batches): """