32 changes: 30 additions & 2 deletions python/pyspark/sql/pandas/serializers.py
@@ -1582,6 +1582,7 @@ def __init__(
safecheck,
assign_cols_by_name,
arrow_max_records_per_batch,
arrow_max_bytes_per_batch,
int_to_decimal_coercion_enabled,
):
super(TransformWithStateInPandasSerializer, self).__init__(
@@ -1592,7 +1593,11 @@ def __init__(
arrow_cast=True,
)
self.arrow_max_records_per_batch = arrow_max_records_per_batch
self.arrow_max_bytes_per_batch = arrow_max_bytes_per_batch
self.key_offsets = None
self.average_arrow_row_size = 0
self.total_bytes = 0
self.total_rows = 0

def load_stream(self, stream):
"""
@@ -1621,6 +1626,18 @@ def generate_data_batches(batches):

def row_stream():
for batch in batches:
# Short circuit batch size calculation if the batch size is
# unlimited as computing batch size is computationally expensive.
if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
batch_bytes = sum(
buf.size
for col in batch.columns
for buf in col.buffers()
if buf is not None
)
self.total_bytes += batch_bytes
self.total_rows += batch.num_rows
self.average_arrow_row_size = self.total_bytes / self.total_rows
data_pandas = [
self.arrow_to_pandas(c, i)
for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns())
@@ -1630,8 +1647,17 @@ def row_stream():
yield (batch_key, row)

for batch_key, group_rows in groupby(row_stream(), key=lambda x: x[0]):
df = pd.DataFrame([row for _, row in group_rows])
yield (batch_key, df)
rows = []
Contributor:
@HyukjinKwon @zhengruifeng
What do you think about the code? We limit the size of Arrow RecordBatch in task thread when sending to Python worker, and @zeruibao added this to re-align the size for Pandas DataFrame. Did we do this in other UDF? Is it beneficial or probably over-thinking?

Contributor:
> Did we do this in other UDF?

I don't think so.

> Is it beneficial or probably over-thinking?

I remember @HyukjinKwon discussed it before; it should be beneficial if the size is properly estimated.

for _, row in group_rows:
rows.append(row)
if (
len(rows) >= self.arrow_max_records_per_batch
or len(rows) * self.average_arrow_row_size >= self.arrow_max_bytes_per_batch
):
yield (batch_key, pd.DataFrame(rows))
rows = []
if rows:
yield (batch_key, pd.DataFrame(rows))

_batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
data_batches = generate_data_batches(_batches)
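
The thread above asks whether re-aligning the pandas DataFrame size on the Python side is worthwhile. As a reading aid, here is a minimal standalone sketch of the rule the new serializer code applies: estimate an average Arrow row size from the buffers of the incoming batches, then flush a key group into a DataFrame whenever either the record limit or the estimated byte limit is hit. The helper names and the literal numbers are assumptions for illustration, not part of the PR.

import pandas as pd

def arrow_batch_size_bytes(batch):
    # Sum the sizes of the underlying Arrow buffers, as the serializer does when
    # it estimates the average row size (total bytes / total rows).
    return sum(
        buf.size
        for col in batch.columns
        for buf in col.buffers()
        if buf is not None
    )

def chunk_group_rows(group_rows, avg_row_size, max_records, max_bytes):
    # Flush the buffered rows of one key group into a pandas DataFrame whenever
    # either the record-count limit or the estimated byte limit is reached.
    rows = []
    for row in group_rows:
        rows.append(row)
        if len(rows) >= max_records or len(rows) * avg_row_size >= max_bytes:
            yield pd.DataFrame(rows)
            rows = []
    if rows:
        yield pd.DataFrame(rows)  # flush the remainder of the group

# With an estimated 16 bytes per row and a 64-byte cap, at most 4 rows land in
# each DataFrame chunk:
chunks = list(chunk_group_rows([(i, "x") for i in range(10)], 16, 1000, 64))
assert [len(c) for c in chunks] == [4, 4, 2]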
@@ -1676,13 +1702,15 @@ def __init__(
safecheck,
assign_cols_by_name,
arrow_max_records_per_batch,
arrow_max_bytes_per_batch,
int_to_decimal_coercion_enabled,
):
super(TransformWithStateInPandasInitStateSerializer, self).__init__(
timezone,
safecheck,
assign_cols_by_name,
arrow_max_records_per_batch,
arrow_max_bytes_per_batch,
int_to_decimal_coercion_enabled,
)
self.init_key_offsets = None
@@ -237,6 +237,16 @@ def row(self):
return RowStatefulProcessorCompositeType()


class ChunkCountProcessorFactory(StatefulProcessorFactory):
def pandas(self):
return PandasChunkCountProcessor()


class ChunkCountProcessorWithInitialStateFactory(StatefulProcessorFactory):
def pandas(self):
return PandasChunkCountWithInitialStateProcessor()


# StatefulProcessor implementations


@@ -1822,3 +1832,36 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:

def close(self) -> None:
pass


class PandasChunkCountProcessor(StatefulProcessor):
def init(self, handle: StatefulProcessorHandle) -> None:
pass

def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
chunk_count = 0
for _ in rows:
chunk_count += 1
yield pd.DataFrame({"id": [key[0]], "chunkCount": [chunk_count]})

def close(self) -> None:
pass


class PandasChunkCountWithInitialStateProcessor(StatefulProcessor):
def init(self, handle: StatefulProcessorHandle) -> None:
state_schema = StructType([StructField("value", IntegerType(), True)])
self.value_state = handle.getValueState("value_state", state_schema)

def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
chunk_count = 0
for _ in rows:
chunk_count += 1
yield pd.DataFrame({"id": [key[0]], "chunkCount": [chunk_count]})

def handleInitialState(self, key, initialState, timerValues) -> None:
init_val = initialState.at[0, "initVal"]
self.value_state.update((init_val,))

def close(self) -> None:
pass
@@ -70,6 +70,8 @@
UpcastProcessorFactory,
MinEventTimeStatefulProcessorFactory,
StatefulProcessorCompositeTypeFactory,
ChunkCountProcessorFactory,
ChunkCountProcessorWithInitialStateFactory,
)


@@ -1864,6 +1866,99 @@ def close(self):
.collect()
)

def test_transform_with_state_with_bytes_limit(self):
if not self.use_pandas():
return

def make_check_results(expected_per_batch):
def check_results(batch_df, batch_id):
batch_df.collect()
if batch_id == 0:
assert set(batch_df.sort("id").collect()) == expected_per_batch[0]
else:
assert set(batch_df.sort("id").collect()) == expected_per_batch[1]

return check_results

result_with_small_limit = [
{
Row(id="0", chunkCount=2),
Row(id="1", chunkCount=2),
},
{
Row(id="0", chunkCount=3),
Row(id="1", chunkCount=2),
},
]

result_with_large_limit = [
{
Row(id="0", chunkCount=1),
Row(id="1", chunkCount=1),
},
{
Row(id="0", chunkCount=1),
Row(id="1", chunkCount=1),
},
]

data = [("0", 789), ("3", 987)]
initial_state = self.spark.createDataFrame(data, "id string, initVal int").groupBy("id")

with self.sql_conf(
# Set it to a very small number so that every row would be a separate pandas df
{"spark.sql.execution.arrow.maxBytesPerBatch": "2"}
):
self._test_transform_with_state_basic(
ChunkCountProcessorFactory(),
make_check_results(result_with_small_limit),
output_schema=StructType(
[
StructField("id", StringType(), True),
StructField("chunkCount", IntegerType(), True),
]
),
)

self._test_transform_with_state_basic(
ChunkCountProcessorWithInitialStateFactory(),
make_check_results(result_with_small_limit),
initial_state=initial_state,
output_schema=StructType(
[
StructField("id", StringType(), True),
StructField("chunkCount", IntegerType(), True),
]
),
)

with self.sql_conf(
# Set it to a very large number so that every row would be in the same pandas df
{"spark.sql.execution.arrow.maxBytesPerBatch": "100000"}
):
self._test_transform_with_state_basic(
ChunkCountProcessorFactory(),
make_check_results(result_with_large_limit),
output_schema=StructType(
[
StructField("id", StringType(), True),
StructField("chunkCount", IntegerType(), True),
]
),
)

self._test_transform_with_state_basic(
ChunkCountProcessorWithInitialStateFactory(),
make_check_results(result_with_large_limit),
initial_state=initial_state,
output_schema=StructType(
[
StructField("id", StringType(), True),
StructField("chunkCount", IntegerType(), True),
]
),
)


@unittest.skipIf(
not have_pyarrow or os.environ.get("PYTHON_GIL", "?") == "0",
12 changes: 12 additions & 0 deletions python/pyspark/worker.py
@@ -2577,11 +2577,17 @@ def read_udfs(pickleSer, infile, eval_type):
)
arrow_max_records_per_batch = int(arrow_max_records_per_batch)

arrow_max_bytes_per_batch = runner_conf.get(
"spark.sql.execution.arrow.maxBytesPerBatch", 2**31 - 1
)
arrow_max_bytes_per_batch = int(arrow_max_bytes_per_batch)

ser = TransformWithStateInPandasSerializer(
timezone,
safecheck,
_assign_cols_by_name,
arrow_max_records_per_batch,
arrow_max_bytes_per_batch,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)
elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF:
Member:
Doesn't SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF need it too?

Contributor Author:
SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF follows a separate execution path. Each Arrow batch contains rows with the same key, which allows us to directly convert the batch into a Pandas DataFrame and yield it. Since the Arrow batch size is already subject to a byte-size limit, the resulting Pandas DataFrame also inherently respects this constraint.
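
Illustrative only: a minimal sketch of the path described in the reply above, under the assumption that each Arrow batch arriving on the init-state path carries a single key and is already capped at maxBytesPerBatch by the JVM-side writer. The helper name below is made up for this sketch and is not part of the PR.

import pyarrow as pa

def init_state_batches_to_pandas(arrow_batches):
    for batch in arrow_batches:
        # Each batch already respects the byte limit upstream, so converting it
        # directly yields a pandas DataFrame that inherits the same cap; no
        # re-chunking or row-size estimation is needed here.
        yield batch.to_pandas()

# Example with a single-key batch:
batch = pa.RecordBatch.from_pydict({"id": ["0", "0"], "initVal": [789, 790]})
df = next(init_state_batches_to_pandas([batch]))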

@@ -2590,11 +2596,17 @@ def read_udfs(pickleSer, infile, eval_type):
)
arrow_max_records_per_batch = int(arrow_max_records_per_batch)

arrow_max_bytes_per_batch = runner_conf.get(
"spark.sql.execution.arrow.maxBytesPerBatch", 2**31 - 1
)
arrow_max_bytes_per_batch = int(arrow_max_bytes_per_batch)

ser = TransformWithStateInPandasInitStateSerializer(
timezone,
safecheck,
_assign_cols_by_name,
arrow_max_records_per_batch,
arrow_max_bytes_per_batch,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)
elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF:
@@ -106,12 +106,14 @@ class ApplyInPandasWithStatePythonRunner(
}

private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
private val arrowMaxBytesPerBatch = sqlConf.arrowMaxBytesPerBatch

// applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance.
// Configurations are both applied to executor and Python worker, set them to the worker conf
// to let Python worker read the config properly.
override protected val workerConf: Map[String, String] = initialWorkerConf +
(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString)
(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) +
(SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString)

private val stateRowDeserializer = stateEncoder.createDeserializer()

@@ -142,7 +144,11 @@
dataOut: DataOutputStream,
inputIterator: Iterator[InType]): Boolean = {
if (pandasWriter == null) {
pandasWriter = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch)
pandasWriter = new ApplyInPandasWithStateWriter(
root,
writer,
arrowMaxRecordsPerBatch,
arrowMaxBytesPerBatch)
}
if (inputIterator.hasNext) {
val startData = dataOut.size()
@@ -50,8 +50,9 @@ import org.apache.spark.unsafe.types.UTF8String
class ApplyInPandasWithStateWriter(
root: VectorSchemaRoot,
writer: ArrowStreamWriter,
arrowMaxRecordsPerBatch: Int)
extends BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch) {
arrowMaxRecordsPerBatch: Int,
arrowMaxBytesPerBatch: Long)
extends BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch, arrowMaxBytesPerBatch) {

import ApplyInPandasWithStateWriter._

@@ -144,7 +145,7 @@ class ApplyInPandasWithStateWriter(

// If the batch size limit is reached once all the data for the same group has been
// received, finalize and construct a new batch.
if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
if (isBatchSizeLimitReached) {
finalizeCurrentArrowBatch()
}
}
@@ -32,6 +32,7 @@ class BaseStreamingArrowWriter(
root: VectorSchemaRoot,
writer: ArrowStreamWriter,
arrowMaxRecordsPerBatch: Int,
arrowMaxBytesPerBatch: Long,
arrowWriterForTest: ArrowWriter = null) {
protected val arrowWriterForData: ArrowWriter = if (arrowWriterForTest == null) {
ArrowWriter.create(root)
@@ -54,7 +55,7 @@
// If the batch size limit is reached and there is more data for the same group,
// finalize and construct a new batch.

val isCurrentBatchFull = totalNumRowsForBatch >= arrowMaxRecordsPerBatch
val isCurrentBatchFull = isBatchSizeLimitReached
if (isCurrentBatchFull) {
finalizeCurrentChunk(isLastChunkForGroup = false)
finalizeCurrentArrowBatch()
@@ -84,4 +85,13 @@
protected def finalizeCurrentChunk(isLastChunkForGroup: Boolean): Unit = {
numRowsForCurrentChunk = 0
}

protected def isBatchSizeLimitReached: Boolean = {
// If we have either reached the records or bytes limit
totalNumRowsForBatch >= arrowMaxRecordsPerBatch ||
// Short circuit batch size calculation if the batch size is unlimited as computing batch
// size is computationally expensive.
((arrowMaxBytesPerBatch != Int.MaxValue)
&& (arrowWriterForData.sizeInBytes() >= arrowMaxBytesPerBatch))
}
}
@@ -75,7 +75,12 @@ class TransformWithStateInPySparkPythonRunner(
dataOut: DataOutputStream,
inputIterator: Iterator[InType]): Boolean = {
if (pandasWriter == null) {
pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
pandasWriter = new BaseStreamingArrowWriter(
root,
writer,
arrowMaxRecordsPerBatch,
arrowMaxBytesPerBatch
)
}

// If we don't have data left for the current group, move to the next group.
@@ -145,7 +150,12 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
dataOut: DataOutputStream,
inputIterator: Iterator[GroupedInType]): Boolean = {
if (pandasWriter == null) {
pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
pandasWriter = new BaseStreamingArrowWriter(
root,
writer,
arrowMaxRecordsPerBatch,
arrowMaxBytesPerBatch
)
}

if (inputIterator.hasNext) {
@@ -200,9 +210,11 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](

protected val sqlConf = SQLConf.get
protected val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
protected val arrowMaxBytesPerBatch = sqlConf.arrowMaxBytesPerBatch

override protected val workerConf: Map[String, String] = initialWorkerConf +
(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString)
(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) +
(SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString)

// Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s
// constructor.