1 change: 1 addition & 0 deletions dev/requirements.txt
@@ -17,6 +17,7 @@ pyyaml>=3.11
# PySpark test dependencies
unittest-xml-reporting
openpyxl
duckdb

# PySpark test dependencies (optional)
coverage
2 changes: 1 addition & 1 deletion dev/spark-test-image/python-311/Dockerfile
@@ -68,7 +68,7 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*


ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2 duckdb"
# Python deps for Spark Connect
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"

1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -589,6 +589,7 @@ def __hash__(self):
"pyspark.sql.tests.test_connect_compatibility",
"pyspark.sql.tests.udf_type_tests.test_udf_input_types",
"pyspark.sql.tests.udf_type_tests.test_udf_return_types",
"pyspark.sql.tests.test_interchange",
],
)

19 changes: 19 additions & 0 deletions python/pyspark/pandas/frame.py
@@ -13824,6 +13824,25 @@ def __class_getitem__(cls, params: Any) -> object:
# we always wraps the given type hints by a tuple to mimic the variadic generic.
return create_tuple_for_frame_type(params)

def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object:
"""
Export to a C PyCapsule stream object.

Parameters
----------
requested_schema : PyCapsule, optional
The schema to attempt to use for the output stream. This is a best-effort request and may be ignored by the producer.

Returns
-------
A C PyCapsule stream object.
"""
from pyspark.sql.interchange import SparkArrowCStreamer

return SparkArrowCStreamer(self._internal.to_internal_spark_frame).__arrow_c_stream__(
requested_schema
)
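Note: for context, this hook makes a pandas-on-Spark DataFrame readable by any consumer of the Arrow PyCapsule interface. A minimal sketch, assuming a running SparkSession and a pyarrow version that provides `RecordBatchReader.from_stream` (the data values here are illustrative):

```python
import pyarrow as pa
import pyspark.pandas as ps

# A small pandas-on-Spark DataFrame; the number of partitions does not matter.
psdf = ps.DataFrame({"id": [1, 2, 3], "value": ["a", "b", "c"]})

# from_stream() detects __arrow_c_stream__ on the object and wraps it in a
# RecordBatchReader; batches are pulled lazily, partition by partition.
reader = pa.RecordBatchReader.from_stream(psdf)
table = reader.read_all()
print(table.schema)
print(table.num_rows)  # 3
```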


def _reduce_spark_multi(sdf: PySparkDataFrame, aggs: List[PySparkColumn]) -> Any:
"""
17 changes: 17 additions & 0 deletions python/pyspark/sql/dataframe.py
@@ -7064,6 +7064,23 @@ def replace(

replace.__doc__ = DataFrame.replace.__doc__

def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object:
"""
Export to a C PyCapsule stream object.

Parameters
----------
requested_schema : PyCapsule, optional
The schema to attempt to use for the output stream. This is a best-effort request and may be ignored by the producer.

Returns
-------
A C PyCapsule stream object.
"""
from pyspark.sql.interchange import SparkArrowCStreamer

return SparkArrowCStreamer(self).__arrow_c_stream__(requested_schema)
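Note: as a sketch of downstream consumption, an engine like DuckDB can scan the stream without an intermediate pandas conversion. This assumes duckdb is installed and uses an illustrative local SparkSession; DuckDB's replacement scan resolves the local variable `reader` by name:

```python
import duckdb
import pyarrow as pa
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], ["id", "value"])

# __arrow_c_stream__ lets pyarrow wrap the Spark DataFrame in a reader;
# DuckDB's replacement scan then finds the `reader` variable by name.
reader = pa.RecordBatchReader.from_stream(df)
print(duckdb.sql("SELECT count(*) AS n, min(id) AS lo FROM reader").fetchall())
```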


class DataFrameStatFunctions:
"""Functionality for statistic functions with :class:`DataFrame`.
89 changes: 89 additions & 0 deletions python/pyspark/sql/interchange.py
@@ -0,0 +1,89 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Iterator

import pyarrow as pa

import pyspark.sql
from pyspark.sql.types import StructType, StructField, BinaryType
from pyspark.sql.pandas.types import to_arrow_schema


def _get_arrow_array_partition_stream(df: pyspark.sql.DataFrame) -> Iterator[pa.RecordBatch]:
"""Return all the partitions as Arrow arrays in an Iterator."""
# We will be using mapInArrow to convert each partition to Arrow RecordBatches.
# The return type of the function will be a single binary column containing
# the serialized RecordBatch in Arrow IPC format.
binary_schema = StructType([StructField("arrow_ipc_bytes", BinaryType(), nullable=False)])

def batch_to_bytes_iter(batch_iter):
"""
A generator function that converts RecordBatches to serialized Arrow IPC format.

Spark sends each partition as an iterator of RecordBatches. In order to return
the entire partition as a stream of Arrow RecordBatches, we need to serialize
each RecordBatch to Arrow IPC format and yield it as a single binary blob.
"""
# The size of the batch can be controlled by the Spark config
# `spark.sql.execution.arrow.maxRecordsPerBatch`.
for arrow_batch in batch_iter:
# We create an in-memory byte stream to hold the serialized batch
sink = pa.BufferOutputStream()
# Write the batch to the stream using Arrow IPC format
with pa.ipc.new_stream(sink, arrow_batch.schema) as writer:
writer.write_batch(arrow_batch)
buf = sink.getvalue()
# Manually build the int32 offsets buffer ([0, len(buf)]) for a one-value binary array.
offset_buf = pa.array([0, len(buf)], type=pa.int32()).buffers()[1]
null_bitmap = None
# Wrap the bytes in a new 1-row, 1-column RecordBatch to satisfy the mapInArrow
# return signature; the entire source batch ends up in a single binary cell.
storage_arr = pa.Array.from_buffers(
type=pa.binary(), length=1, buffers=[null_bitmap, offset_buf, buf]
)
yield pa.RecordBatch.from_arrays([storage_arr], names=["arrow_ipc_bytes"])

# Convert all partitions to Arrow RecordBatches and map to binary blobs.
byte_df = df.mapInArrow(batch_to_bytes_iter, binary_schema)
# Each row holds one RecordBatch serialized in Arrow IPC format; fetch and deserialize them one by one.
for row in byte_df.toLocalIterator():
with pa.ipc.open_stream(row.arrow_ipc_bytes) as reader:
for batch in reader:
# Each batch corresponds to a chunk of data in the partition.
yield batch


class SparkArrowCStreamer:
"""
A class that implements the __arrow_c_stream__ protocol for Spark DataFrames.

Consumers can read the partitions one at a time, so the driver never has to
materialize the whole DataFrame at once.
"""

def __init__(self, df: pyspark.sql.DataFrame):
self._df = df
self._schema = to_arrow_schema(df.schema)

def __arrow_c_stream__(self, requested_schema=None):
"""
Return the Arrow C stream for the dataframe partitions.
"""
reader: pa.RecordBatchReader = pa.RecordBatchReader.from_batches(
self._schema, _get_arrow_array_partition_stream(self._df)
)
return reader.__arrow_c_stream__(requested_schema=requested_schema)
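Note: the core idea above is to ship each executor-side RecordBatch as raw Arrow IPC bytes inside a one-cell binary column and re-open those bytes on the driver. A minimal local round trip of that serialization step (no Spark involved, values are illustrative):

```python
import pyarrow as pa

batch = pa.RecordBatch.from_pydict({"id": [1, 2, 3], "value": ["a", "b", "c"]})

# "Executor" side: serialize the batch to the Arrow IPC stream format.
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, batch.schema) as writer:
    writer.write_batch(batch)
ipc_bytes = sink.getvalue()

# "Driver" side: re-open the IPC bytes and recover the original batch.
with pa.ipc.open_stream(ipc_bytes) as reader:
    restored = list(reader)

assert restored[0].equals(batch)
```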
57 changes: 57 additions & 0 deletions python/pyspark/sql/tests/test_interchange.py
@@ -0,0 +1,57 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
import pyarrow as pa
import pandas as pd
import pyspark.pandas as ps

try:
import duckdb

DUCKDB_TESTS = True
except ImportError:
DUCKDB_TESTS = False
Reviewer comment (Contributor): DUCKDB_TESTS should be in pyspark/testing/utils.py and conform to the other library checkers.



@unittest.skipIf(not DUCKDB_TESTS, "duckdb is not installed")
class TestSparkArrowCStreamer(unittest.TestCase):
def test_spark_arrow_c_streamer(self):
if not DUCKDB_TESTS:
self.skipTest("duckdb is not installed")

pdf = pd.DataFrame([[1, "a"], [2, "b"], [3, "c"], [4, "d"]], columns=["id", "value"])
psdf = ps.from_pandas(pdf)
# Wrap the pandas-on-Spark DataFrame in a RecordBatchReader via the Arrow C stream protocol
stream = pa.RecordBatchReader.from_stream(psdf)
self.assertIsInstance(stream, pa.RecordBatchReader)

# Query the reader through DuckDB's replacement scan and verify the contents
result = duckdb.execute("SELECT id, value FROM stream").fetchall()
expected = [(1, "a"), (2, "b"), (3, "c"), (4, "d")]
self.assertEqual(result, expected)


if __name__ == "__main__":
from pyspark.sql.tests.test_interchange import * # noqa: F401

try:
import xmlrunner # type: ignore

test_runner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
test_runner = None
unittest.main(testRunner=test_runner, verbosity=2)