From 8bd06bfe66dae467fe8da4debfa33aae2795b3d1 Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Mon, 8 Dec 2025 14:39:45 -0600 Subject: [PATCH 01/12] [SPARK-54337][PS] Add support for PyCapsule to Pyspark Signed-off-by: Devin Petersohn Co-authored-by: Devin Petersohn --- python/pyspark/interchange.py | 248 ++++++++++++++++++++++++++++++++ python/pyspark/pandas/frame.py | 38 +++++ python/pyspark/sql/dataframe.py | 35 +++++ 3 files changed, 321 insertions(+) create mode 100644 python/pyspark/interchange.py diff --git a/python/pyspark/interchange.py b/python/pyspark/interchange.py new file mode 100644 index 000000000000..406e19184dfb --- /dev/null +++ b/python/pyspark/interchange.py @@ -0,0 +1,248 @@ +import dataclasses +from typing import Iterable, Optional, Iterator, Any, Tuple +import pyarrow +from pyarrow.interchange.column import DtypeKind, _PyArrowColumn, ColumnBuffers, ColumnNullType, CategoricalDescription +from pyarrow.interchange.dataframe import _PyArrowDataFrame + +import pyspark.sql +from pyspark.sql.types import StructType, StructField, BinaryType +from pyspark.sql.pandas.types import to_arrow_schema + + +def _get_arrow_array_partition_stream(df: pyspark.sql.DataFrame) -> Iterator[pyarrow.RecordBatch]: + """Return all the partitions as Arrow arrays in an Iterator.""" + # We will be using mapInArrow to convert each partition to Arrow RecordBatches. + # The return type of the function will be a single binary column containing + # the serialized RecordBatch in Arrow IPC format. + binary_schema = StructType([StructField("interchange_arrow_bytes", BinaryType(), nullable=False)]) + + def batch_to_bytes_iter(batch_iter): + """ + A generator function that converts RecordBatches to serialized Arrow IPC format. + + Spark sends each partition as an iterator of RecordBatches. In order to return + the entire partition as a stream of Arrow RecordBatches, we need to serialize + each RecordBatch to Arrow IPC format and yield it as a single binary blob. + """ + # The size of the batch can be controlled by the Spark config + # `spark.sql.execution.arrow.maxRecordsPerBatch`. + for arrow_batch in batch_iter: + # We create an in-memory byte stream to hold the serialized batch + sink = pyarrow.BufferOutputStream() + # Write the batch to the stream using Arrow IPC format + with pyarrow.ipc.new_stream(sink, arrow_batch.schema) as writer: + writer.write_batch(arrow_batch) + buf = sink.getvalue().to_pybytes() + # Wrap the bytes in a new 1-row, 1-column RecordBatch to satisfy mapInArrow return signature + # This serializes the whole batch into a single pyarrow serialized cell. + storage_arr = pyarrow.array([buf]) + yield pyarrow.RecordBatch.from_arrays([storage_arr], names=["interchange_arrow_bytes"]) + + # Convert all partitions to Arrow RecordBatches and map to binary blobs. + byte_df = df.mapInArrow(batch_to_bytes_iter, binary_schema) + + # A row is actually a batch of data in Arrow IPC format. Fetch the batches one by one. + for row in byte_df.toLocalIterator(): + with pyarrow.ipc.open_stream(row.interchange_arrow_bytes) as reader: + for batch in reader: + # Each batch corresponds to a chunk of data in the partition. + yield batch + + +class SparkArrowCStreamer: + """ + A class that implements that __arrow_c_stream__ protocol for Spark partitions. + + This class is implemented in a way that allows consumers to consume each partition + one at a time without materializing all partitions at once on the driver side. 
+ """ + def __init__(self, df: pyspark.sql.DataFrame): + self._df = df + self._schema = to_arrow_schema(df.schema) + + def __arrow_c_stream__(self, requested_schema=None): + """ + Return the Arrow C stream for the dataframe partitions. + """ + reader: pyarrow.RecordBatchReader = pyarrow.RecordBatchReader.from_batches( + self._schema, _get_arrow_array_partition_stream(self._df) + ) + return reader.__arrow_c_stream__(requested_schema=requested_schema) + + +@dataclasses.dataclass(frozen=True) +class SparkInterchangeColumn(_PyArrowColumn): + """ + A class that conforms to the dataframe interchange protocol column interface. + + This class leverages the Arrow-based dataframe interchange protocol by returning + Spark partitions (chunks) in Arrow's dataframe interchange format. + """ + _spark_dataframe: "pyspark.sql.DataFrame" + _spark_column: "pyspark.sql.Column" + _allow_copy: bool + + def size(self) -> Optional[int]: + """ + The number of values in the column. + + This would trigger computation to get the size, so we return None. + """ + return None + + def offset(self) -> int: + """ + Return the offset of the first element, which is always 0 in Spark. + + The only case where the offset would not be 0 would be when this object + represents a chunk. Since we have a separate class for the column chunks, + we can safely return 0 here. + """ + return 0 + + @property + def dtype(self) -> Tuple[DtypeKind, int, str, str]: + """Return the Dtype of the column.""" + return self._dtype_from_arrowdtype( + to_arrow_schema(self._spark_dataframe.select(self._spark_column).schema).field(0).type, + bit_width=8, + ) + + @property + def describe_categorical(self) -> CategoricalDescription: + """Return the categorical description of the column, if applicable.""" + raise NotImplementedError("Categorical description is not implemented for Spark columns.") + + @property + def describe_null(self) -> Tuple[ColumnNullType, Any]: + """Return the null description of the column.""" + raise NotImplementedError("Null description is not implemented for Spark columns.") + + @property + def null_count(self) -> Optional[int]: + """Return the number of nulls in the column, or None if not known.""" + # Always return None to avoid triggering computation + return None + + @property + def num_chunks(self) -> int: + """Return the number of chunks in the column (partitions in this case).""" + return self._spark_dataframe.rdd.getNumPartitions() + + def get_chunks(self, n_chunks: Optional[int] = None) -> Iterable[_PyArrowColumn]: + """ + Return an iterator yielding the chunks of the column. See + SparkInterchangeDataframe.get_chunks for details. + """ + if n_chunks is not None: + raise NotImplementedError("n_chunks would require repartitioning, which is not implemented.") + arrow_array_partitions = _get_arrow_array_partition_stream(self._spark_dataframe.select(self._spark_column)) + for part in arrow_array_partitions: + yield _PyArrowColumn( + column=part.column(0), + allow_copy=self._allow_copy, + ) + + def get_buffers(self) -> ColumnBuffers: + """Return a dictionary of buffers for the column.""" + raise NotImplementedError("get_buffers would force materialization, so it is not implemented.") + + +@dataclasses.dataclass(frozen=True) +class SparkInterchangeDataframe(_PyArrowDataFrame): + """ + A class that conforms to the dataframe interchange protocol. + + This class leverages the Arrow-based dataframe interchange protocol by returning + Spark partitions (chunks) in Arrow's dataframe interchange format. 
This + implementation attempts to avoid materializing all the data on the driver side at + once. + """ + _spark_dataframe: "pyspark.sql.DataFrame" + _allow_copy: bool + _nan_as_null: bool + + def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True) -> "SparkInterchangeDataframe": + """Construct a new interchange dataframe, potentially changing the options.""" + return SparkInterchangeDataframe( + _spark_dataframe=self._spark_dataframe, + _allow_copy=allow_copy, + _nan_as_null=nan_as_null, + ) + + @property + def metadata(self) -> dict[str, Any]: + """ + The metadata for the dataframe. + + In Spark's case, there is no additional metadata to provide. + """ + return {} + + def num_columns(self) -> int: + return len(self._spark_dataframe.columns) + + def num_chunks(self) -> int: + """Return the number of chunks in the dataframe (partitions in this case).""" + return self._spark_dataframe.rdd.getNumPartitions() + + def column_names(self) -> Iterable[str]: + return self._spark_dataframe.columns + + def get_column(self, i: int) -> "SparkInterchangeColumn": + """Get a column by the 0-based index.""" + col_name = self._spark_dataframe.columns[i] + return SparkInterchangeColumn( + _spark_dataframe=self._spark_dataframe, + _spark_column=self._spark_dataframe[col_name], + _allow_copy=self._allow_copy, + ) + + def get_column_by_name(self, name: str) -> "SparkInterchangeColumn": + """Get a column by name.""" + return SparkInterchangeColumn( + _spark_dataframe=self._spark_dataframe, + _spark_column=self._spark_dataframe[name], + _allow_copy=self._allow_copy, + ) + + def get_columns(self) -> Iterable[_PyArrowColumn]: + """Return an iterator yielding the columns.""" + for col_name in self._spark_dataframe.columns: + yield SparkInterchangeColumn( + _spark_dataframe=self._spark_dataframe, + _spark_column=self._spark_dataframe[col_name], + _allow_copy=self._allow_copy, + ) + + def select_columns(self, indices: Iterable[int]) -> "SparkInterchangeDataframe": + """Create a new DataFrame by selecting a subset of columns by index.""" + selected_column_names = [self._spark_dataframe.columns[i] for i in indices] + new_spark_df = self._spark_dataframe.select(selected_column_names) + return SparkInterchangeDataframe( + _spark_dataframe=new_spark_df, + _allow_copy=self._allow_copy, + _nan_as_null=self._nan_as_null, + ) + + def select_columns_by_name(self, names: Iterable[str]) -> "SparkInterchangeDataframe": + """Create a new DataFrame by selecting a subset of columns by name.""" + new_spark_df = self._spark_dataframe.select(list(names)) + return SparkInterchangeDataframe( + _spark_dataframe=new_spark_df, + _allow_copy=self._allow_copy, + _nan_as_null=self._nan_as_null, + ) + + def get_chunks( + self, n_chunks: Optional[int] = None + ) -> Iterable[_PyArrowDataFrame]: + """Return an iterator yielding the chunks of the dataframe.""" + if n_chunks is not None: + raise NotImplementedError("n_chunks would require repartitioning, which is not implemented.") + arrow_array_partitions = _get_arrow_array_partition_stream(self._spark_dataframe) + for part in arrow_array_partitions: + yield _PyArrowDataFrame( + part, + allow_copy=self._allow_copy, + ) diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index e5aaecbb64fd..38bed135b1b2 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -13824,6 +13824,44 @@ def __class_getitem__(cls, params: Any) -> object: # we always wraps the given type hints by a tuple to mimic the variadic generic. 
return create_tuple_for_frame_type(params) + def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True): + """ + Return a DataFrame interchange protocol object. + + Parameters + ---------- + nan_as_null : bool, default False + Whether to treat NaN values as nulls. + allow_copy : bool, default True + Whether the implementation is allowed to return a copy of the data. + + Returns + ------- + SparkInterChangeDataFrame object. + """ + from pyspark.interchange import SparkInterchangeDataFrame + + return SparkInterchangeDataFrame(self._internal.spark_frame, nan_as_null, allow_copy) + + def __arrow_c_stream__(self, requested_schema: Optional["pyarrow.Schema"] = None) -> object: + """ + Export to a C PyCapsule stream object. + + Parameters + ---------- + requested_schema : pyarrow.Schema, optional + The schema to attempt to use for the output stream. This is a best effort request, + + Returns + ------- + A C PyCapsule stream object. + """ + from pyspark.interchange import SparkArrowCStreamer + + return SparkArrowCStreamer( + self._internal.to_internal_spark_frame + ).__arrow_c_stream__(requested_schema) + def _reduce_spark_multi(sdf: PySparkDataFrame, aggs: List[PySparkColumn]) -> Any: """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index bf81c13a7bac..b64c26c0182a 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -7057,6 +7057,41 @@ def replace( replace.__doc__ = DataFrame.replace.__doc__ + def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True): + """ + Return a DataFrame interchange protocol object. + + Parameters + ---------- + nan_as_null : bool, default False + Whether to treat NaN values as nulls. + allow_copy : bool, default True + Whether the implementation is allowed to return a copy of the data. + + Returns + ------- + SparkInterChangeDataFrame object. + """ + from pyspark.interchange import SparkInterchangeDataFrame + + return SparkInterchangeDataFrame(self, nan_as_null, allow_copy) + + def __arrow_c_stream__(self, requested_schema: Optional["pyarrow.Schema"] = None) -> object: + """ + Export to a C PyCapsule stream object. + + Parameters + ---------- + requested_schema : pyarrow.Schema, optional + The schema to attempt to use for the output stream. This is a best effort request, + + Returns + ------- + A C PyCapsule stream object. + """ + from pyspark.interchange import SparkArrowCStreamer + + return SparkArrowCStreamer(self).__arrow_c_stream__(requested_schema) class DataFrameStatFunctions: """Functionality for statistic functions with :class:`DataFrame`. 
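
For reviewers who want to see the serialization trick from batch_to_bytes_iter in isolation, the following is a minimal, Spark-free sketch of the round trip described in the comments above: each partition batch is written to an in-memory Arrow IPC stream, the resulting bytes are wrapped in a 1-row binary RecordBatch matching the "interchange_arrow_bytes" schema used with mapInArrow, and the driver-side loop reopens each cell as an IPC stream. The toy batch and its column values are illustrative only, not part of the patch.

import pyarrow as pa

# A toy batch standing in for what mapInArrow hands to batch_to_bytes_iter.
batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["id", "value"]
)

# Executor side: serialize the batch to the Arrow IPC stream format ...
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, batch.schema) as writer:
    writer.write_batch(batch)
ipc_bytes = sink.getvalue().to_pybytes()

# ... and wrap the bytes in a 1-row, 1-column binary RecordBatch so the
# mapInArrow return schema (a single BinaryType column) is satisfied.
wrapped = pa.RecordBatch.from_arrays(
    [pa.array([ipc_bytes], type=pa.binary())], names=["interchange_arrow_bytes"]
)

# Driver side: every cell of the binary column is a complete IPC stream, so it
# can be reopened and iterated to recover the original batch unchanged.
with pa.ipc.open_stream(wrapped.column(0)[0].as_py()) as reader:
    for restored in reader:
        assert restored.equals(batch)
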
From 51c341fbbb9100330a8441f0a02c1524b9540df9 Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Mon, 8 Dec 2025 19:06:09 -0600 Subject: [PATCH 02/12] Lint Signed-off-by: Devin Petersohn --- python/pyspark/interchange.py | 39 ++++++++++++++++++++++++--------- python/pyspark/pandas/frame.py | 6 ++--- python/pyspark/sql/dataframe.py | 1 + 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/python/pyspark/interchange.py b/python/pyspark/interchange.py index 406e19184dfb..7f8faed380ea 100644 --- a/python/pyspark/interchange.py +++ b/python/pyspark/interchange.py @@ -1,7 +1,13 @@ import dataclasses from typing import Iterable, Optional, Iterator, Any, Tuple import pyarrow -from pyarrow.interchange.column import DtypeKind, _PyArrowColumn, ColumnBuffers, ColumnNullType, CategoricalDescription +from pyarrow.interchange.column import ( + DtypeKind, + _PyArrowColumn, + ColumnBuffers, + ColumnNullType, + CategoricalDescription, +) from pyarrow.interchange.dataframe import _PyArrowDataFrame import pyspark.sql @@ -14,7 +20,9 @@ def _get_arrow_array_partition_stream(df: pyspark.sql.DataFrame) -> Iterator[pya # We will be using mapInArrow to convert each partition to Arrow RecordBatches. # The return type of the function will be a single binary column containing # the serialized RecordBatch in Arrow IPC format. - binary_schema = StructType([StructField("interchange_arrow_bytes", BinaryType(), nullable=False)]) + binary_schema = StructType( + [StructField("interchange_arrow_bytes", BinaryType(), nullable=False)] + ) def batch_to_bytes_iter(batch_iter): """ @@ -56,6 +64,7 @@ class SparkArrowCStreamer: This class is implemented in a way that allows consumers to consume each partition one at a time without materializing all partitions at once on the driver side. """ + def __init__(self, df: pyspark.sql.DataFrame): self._df = df self._schema = to_arrow_schema(df.schema) @@ -78,6 +87,7 @@ class SparkInterchangeColumn(_PyArrowColumn): This class leverages the Arrow-based dataframe interchange protocol by returning Spark partitions (chunks) in Arrow's dataframe interchange format. """ + _spark_dataframe: "pyspark.sql.DataFrame" _spark_column: "pyspark.sql.Column" _allow_copy: bool @@ -135,8 +145,12 @@ def get_chunks(self, n_chunks: Optional[int] = None) -> Iterable[_PyArrowColumn] SparkInterchangeDataframe.get_chunks for details. """ if n_chunks is not None: - raise NotImplementedError("n_chunks would require repartitioning, which is not implemented.") - arrow_array_partitions = _get_arrow_array_partition_stream(self._spark_dataframe.select(self._spark_column)) + raise NotImplementedError( + "n_chunks would require repartitioning, which is not implemented." + ) + arrow_array_partitions = _get_arrow_array_partition_stream( + self._spark_dataframe.select(self._spark_column) + ) for part in arrow_array_partitions: yield _PyArrowColumn( column=part.column(0), @@ -145,7 +159,9 @@ def get_chunks(self, n_chunks: Optional[int] = None) -> Iterable[_PyArrowColumn] def get_buffers(self) -> ColumnBuffers: """Return a dictionary of buffers for the column.""" - raise NotImplementedError("get_buffers would force materialization, so it is not implemented.") + raise NotImplementedError( + "get_buffers would force materialization, so it is not implemented." + ) @dataclasses.dataclass(frozen=True) @@ -158,11 +174,14 @@ class SparkInterchangeDataframe(_PyArrowDataFrame): implementation attempts to avoid materializing all the data on the driver side at once. 
""" + _spark_dataframe: "pyspark.sql.DataFrame" _allow_copy: bool _nan_as_null: bool - def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True) -> "SparkInterchangeDataframe": + def __dataframe__( + self, nan_as_null: bool = False, allow_copy: bool = True + ) -> "SparkInterchangeDataframe": """Construct a new interchange dataframe, potentially changing the options.""" return SparkInterchangeDataframe( _spark_dataframe=self._spark_dataframe, @@ -234,12 +253,12 @@ def select_columns_by_name(self, names: Iterable[str]) -> "SparkInterchangeDataf _nan_as_null=self._nan_as_null, ) - def get_chunks( - self, n_chunks: Optional[int] = None - ) -> Iterable[_PyArrowDataFrame]: + def get_chunks(self, n_chunks: Optional[int] = None) -> Iterable[_PyArrowDataFrame]: """Return an iterator yielding the chunks of the dataframe.""" if n_chunks is not None: - raise NotImplementedError("n_chunks would require repartitioning, which is not implemented.") + raise NotImplementedError( + "n_chunks would require repartitioning, which is not implemented." + ) arrow_array_partitions = _get_arrow_array_partition_stream(self._spark_dataframe) for part in arrow_array_partitions: yield _PyArrowDataFrame( diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 38bed135b1b2..263c6166a556 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -13858,9 +13858,9 @@ def __arrow_c_stream__(self, requested_schema: Optional["pyarrow.Schema"] = None """ from pyspark.interchange import SparkArrowCStreamer - return SparkArrowCStreamer( - self._internal.to_internal_spark_frame - ).__arrow_c_stream__(requested_schema) + return SparkArrowCStreamer(self._internal.to_internal_spark_frame).__arrow_c_stream__( + requested_schema + ) def _reduce_spark_multi(sdf: PySparkDataFrame, aggs: List[PySparkColumn]) -> Any: diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b64c26c0182a..6fb120cb76ad 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -7093,6 +7093,7 @@ def __arrow_c_stream__(self, requested_schema: Optional["pyarrow.Schema"] = None return SparkArrowCStreamer(self).__arrow_c_stream__(requested_schema) + class DataFrameStatFunctions: """Functionality for statistic functions with :class:`DataFrame`. 
From 0de26e36931665378f29676e114ef98a93b793d4 Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Tue, 9 Dec 2025 11:28:22 -0600 Subject: [PATCH 03/12] Add test Signed-off-by: Devin Petersohn --- python/pyspark/sql/tests/test_interchange.py | 28 ++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 python/pyspark/sql/tests/test_interchange.py diff --git a/python/pyspark/sql/tests/test_interchange.py b/python/pyspark/sql/tests/test_interchange.py new file mode 100644 index 000000000000..7395018ad86d --- /dev/null +++ b/python/pyspark/sql/tests/test_interchange.py @@ -0,0 +1,28 @@ +import unittest +import pyarrow as pa +import pandas as pd +import pyspark.pandas as ps + +try: + import duckdb + + DUCKDB_TESTS = True +except ImportError: + DUCKDB_TESTS = False + + +class TestSparkArrowCStreamer(unittest.TestCase): + def test_spark_arrow_c_streamer(self): + if not DUCKDB_TESTS: + self.skipTest("duckdb is not installed") + + pdf = pd.DataFrame({"A": [1, "a"], "B": [2, "b"], "C": [3, "c"], "D": [4, "d"]}) + psdf = ps.from_pandas(pdf) + # Use Spark Arrow C Streamer to convert PyArrow Table to DuckDB relation + stream = pa.RecordBatchReader.from_stream(psdf) + assert isinstance(stream, pa.RecordBatchReader) + + # Verify the contents of the DuckDB relation + result = duckdb.execute("SELECT * from stream").fetchall() + expected = [(1, "a"), (2, "b"), (3, "c"), (4, "d")] + self.assertEqual(result, expected) From 6ab12e35c236f32cf1f5ca8ea6837497bcdd7035 Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Tue, 9 Dec 2025 14:07:34 -0600 Subject: [PATCH 04/12] License on new files Signed-off-by: Devin Petersohn --- python/pyspark/interchange.py | 16 ++++++++++++++++ python/pyspark/sql/tests/test_interchange.py | 16 ++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/python/pyspark/interchange.py b/python/pyspark/interchange.py index 7f8faed380ea..892929b924e7 100644 --- a/python/pyspark/interchange.py +++ b/python/pyspark/interchange.py @@ -1,3 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import dataclasses from typing import Iterable, Optional, Iterator, Any, Tuple import pyarrow diff --git a/python/pyspark/sql/tests/test_interchange.py b/python/pyspark/sql/tests/test_interchange.py index 7395018ad86d..2d8747f6713f 100644 --- a/python/pyspark/sql/tests/test_interchange.py +++ b/python/pyspark/sql/tests/test_interchange.py @@ -1,3 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import unittest import pyarrow as pa import pandas as pd From db30a25961b3ae47abdede20c69e48b709b5bf96 Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Wed, 10 Dec 2025 08:51:48 -0600 Subject: [PATCH 05/12] Remove __dataframe__ Signed-off-by: Devin Petersohn --- python/pyspark/interchange.py | 203 +------------------------------- python/pyspark/pandas/frame.py | 23 +--- python/pyspark/sql/dataframe.py | 23 +--- 3 files changed, 7 insertions(+), 242 deletions(-) diff --git a/python/pyspark/interchange.py b/python/pyspark/interchange.py index 892929b924e7..0e398028c66d 100644 --- a/python/pyspark/interchange.py +++ b/python/pyspark/interchange.py @@ -14,17 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import dataclasses -from typing import Iterable, Optional, Iterator, Any, Tuple +from typing import Iterator import pyarrow -from pyarrow.interchange.column import ( - DtypeKind, - _PyArrowColumn, - ColumnBuffers, - ColumnNullType, - CategoricalDescription, -) -from pyarrow.interchange.dataframe import _PyArrowDataFrame import pyspark.sql from pyspark.sql.types import StructType, StructField, BinaryType @@ -57,8 +48,8 @@ def batch_to_bytes_iter(batch_iter): with pyarrow.ipc.new_stream(sink, arrow_batch.schema) as writer: writer.write_batch(arrow_batch) buf = sink.getvalue().to_pybytes() - # Wrap the bytes in a new 1-row, 1-column RecordBatch to satisfy mapInArrow return signature - # This serializes the whole batch into a single pyarrow serialized cell. + # Wrap the bytes in a new 1-row, 1-column RecordBatch to satisfy mapInArrow return + # signature. This serializes the whole batch into a single pyarrow serialized cell. storage_arr = pyarrow.array([buf]) yield pyarrow.RecordBatch.from_arrays([storage_arr], names=["interchange_arrow_bytes"]) @@ -93,191 +84,3 @@ def __arrow_c_stream__(self, requested_schema=None): self._schema, _get_arrow_array_partition_stream(self._df) ) return reader.__arrow_c_stream__(requested_schema=requested_schema) - - -@dataclasses.dataclass(frozen=True) -class SparkInterchangeColumn(_PyArrowColumn): - """ - A class that conforms to the dataframe interchange protocol column interface. - - This class leverages the Arrow-based dataframe interchange protocol by returning - Spark partitions (chunks) in Arrow's dataframe interchange format. - """ - - _spark_dataframe: "pyspark.sql.DataFrame" - _spark_column: "pyspark.sql.Column" - _allow_copy: bool - - def size(self) -> Optional[int]: - """ - The number of values in the column. - - This would trigger computation to get the size, so we return None. - """ - return None - - def offset(self) -> int: - """ - Return the offset of the first element, which is always 0 in Spark. - - The only case where the offset would not be 0 would be when this object - represents a chunk. Since we have a separate class for the column chunks, - we can safely return 0 here. 
- """ - return 0 - - @property - def dtype(self) -> Tuple[DtypeKind, int, str, str]: - """Return the Dtype of the column.""" - return self._dtype_from_arrowdtype( - to_arrow_schema(self._spark_dataframe.select(self._spark_column).schema).field(0).type, - bit_width=8, - ) - - @property - def describe_categorical(self) -> CategoricalDescription: - """Return the categorical description of the column, if applicable.""" - raise NotImplementedError("Categorical description is not implemented for Spark columns.") - - @property - def describe_null(self) -> Tuple[ColumnNullType, Any]: - """Return the null description of the column.""" - raise NotImplementedError("Null description is not implemented for Spark columns.") - - @property - def null_count(self) -> Optional[int]: - """Return the number of nulls in the column, or None if not known.""" - # Always return None to avoid triggering computation - return None - - @property - def num_chunks(self) -> int: - """Return the number of chunks in the column (partitions in this case).""" - return self._spark_dataframe.rdd.getNumPartitions() - - def get_chunks(self, n_chunks: Optional[int] = None) -> Iterable[_PyArrowColumn]: - """ - Return an iterator yielding the chunks of the column. See - SparkInterchangeDataframe.get_chunks for details. - """ - if n_chunks is not None: - raise NotImplementedError( - "n_chunks would require repartitioning, which is not implemented." - ) - arrow_array_partitions = _get_arrow_array_partition_stream( - self._spark_dataframe.select(self._spark_column) - ) - for part in arrow_array_partitions: - yield _PyArrowColumn( - column=part.column(0), - allow_copy=self._allow_copy, - ) - - def get_buffers(self) -> ColumnBuffers: - """Return a dictionary of buffers for the column.""" - raise NotImplementedError( - "get_buffers would force materialization, so it is not implemented." - ) - - -@dataclasses.dataclass(frozen=True) -class SparkInterchangeDataframe(_PyArrowDataFrame): - """ - A class that conforms to the dataframe interchange protocol. - - This class leverages the Arrow-based dataframe interchange protocol by returning - Spark partitions (chunks) in Arrow's dataframe interchange format. This - implementation attempts to avoid materializing all the data on the driver side at - once. - """ - - _spark_dataframe: "pyspark.sql.DataFrame" - _allow_copy: bool - _nan_as_null: bool - - def __dataframe__( - self, nan_as_null: bool = False, allow_copy: bool = True - ) -> "SparkInterchangeDataframe": - """Construct a new interchange dataframe, potentially changing the options.""" - return SparkInterchangeDataframe( - _spark_dataframe=self._spark_dataframe, - _allow_copy=allow_copy, - _nan_as_null=nan_as_null, - ) - - @property - def metadata(self) -> dict[str, Any]: - """ - The metadata for the dataframe. - - In Spark's case, there is no additional metadata to provide. 
- """ - return {} - - def num_columns(self) -> int: - return len(self._spark_dataframe.columns) - - def num_chunks(self) -> int: - """Return the number of chunks in the dataframe (partitions in this case).""" - return self._spark_dataframe.rdd.getNumPartitions() - - def column_names(self) -> Iterable[str]: - return self._spark_dataframe.columns - - def get_column(self, i: int) -> "SparkInterchangeColumn": - """Get a column by the 0-based index.""" - col_name = self._spark_dataframe.columns[i] - return SparkInterchangeColumn( - _spark_dataframe=self._spark_dataframe, - _spark_column=self._spark_dataframe[col_name], - _allow_copy=self._allow_copy, - ) - - def get_column_by_name(self, name: str) -> "SparkInterchangeColumn": - """Get a column by name.""" - return SparkInterchangeColumn( - _spark_dataframe=self._spark_dataframe, - _spark_column=self._spark_dataframe[name], - _allow_copy=self._allow_copy, - ) - - def get_columns(self) -> Iterable[_PyArrowColumn]: - """Return an iterator yielding the columns.""" - for col_name in self._spark_dataframe.columns: - yield SparkInterchangeColumn( - _spark_dataframe=self._spark_dataframe, - _spark_column=self._spark_dataframe[col_name], - _allow_copy=self._allow_copy, - ) - - def select_columns(self, indices: Iterable[int]) -> "SparkInterchangeDataframe": - """Create a new DataFrame by selecting a subset of columns by index.""" - selected_column_names = [self._spark_dataframe.columns[i] for i in indices] - new_spark_df = self._spark_dataframe.select(selected_column_names) - return SparkInterchangeDataframe( - _spark_dataframe=new_spark_df, - _allow_copy=self._allow_copy, - _nan_as_null=self._nan_as_null, - ) - - def select_columns_by_name(self, names: Iterable[str]) -> "SparkInterchangeDataframe": - """Create a new DataFrame by selecting a subset of columns by name.""" - new_spark_df = self._spark_dataframe.select(list(names)) - return SparkInterchangeDataframe( - _spark_dataframe=new_spark_df, - _allow_copy=self._allow_copy, - _nan_as_null=self._nan_as_null, - ) - - def get_chunks(self, n_chunks: Optional[int] = None) -> Iterable[_PyArrowDataFrame]: - """Return an iterator yielding the chunks of the dataframe.""" - if n_chunks is not None: - raise NotImplementedError( - "n_chunks would require repartitioning, which is not implemented." - ) - arrow_array_partitions = _get_arrow_array_partition_stream(self._spark_dataframe) - for part in arrow_array_partitions: - yield _PyArrowDataFrame( - part, - allow_copy=self._allow_copy, - ) diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 263c6166a556..c6ea125215c0 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -13824,32 +13824,13 @@ def __class_getitem__(cls, params: Any) -> object: # we always wraps the given type hints by a tuple to mimic the variadic generic. return create_tuple_for_frame_type(params) - def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True): - """ - Return a DataFrame interchange protocol object. - - Parameters - ---------- - nan_as_null : bool, default False - Whether to treat NaN values as nulls. - allow_copy : bool, default True - Whether the implementation is allowed to return a copy of the data. - - Returns - ------- - SparkInterChangeDataFrame object. 
- """ - from pyspark.interchange import SparkInterchangeDataFrame - - return SparkInterchangeDataFrame(self._internal.spark_frame, nan_as_null, allow_copy) - - def __arrow_c_stream__(self, requested_schema: Optional["pyarrow.Schema"] = None) -> object: + def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object: """ Export to a C PyCapsule stream object. Parameters ---------- - requested_schema : pyarrow.Schema, optional + requested_schema : PyCapsule, optional The schema to attempt to use for the output stream. This is a best effort request, Returns diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 6fb120cb76ad..a4e792adc782 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -7057,32 +7057,13 @@ def replace( replace.__doc__ = DataFrame.replace.__doc__ - def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True): - """ - Return a DataFrame interchange protocol object. - - Parameters - ---------- - nan_as_null : bool, default False - Whether to treat NaN values as nulls. - allow_copy : bool, default True - Whether the implementation is allowed to return a copy of the data. - - Returns - ------- - SparkInterChangeDataFrame object. - """ - from pyspark.interchange import SparkInterchangeDataFrame - - return SparkInterchangeDataFrame(self, nan_as_null, allow_copy) - - def __arrow_c_stream__(self, requested_schema: Optional["pyarrow.Schema"] = None) -> object: + def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object: """ Export to a C PyCapsule stream object. Parameters ---------- - requested_schema : pyarrow.Schema, optional + requested_schema : PyCapsule, optional The schema to attempt to use for the output stream. This is a best effort request, Returns From b3c0cd95bd958b39493583ecfb7d581db2337f51 Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Wed, 10 Dec 2025 09:51:37 -0600 Subject: [PATCH 06/12] Avoid copy to python bytes, manually build the pyarrow array from buffer. Signed-off-by: Devin Petersohn --- python/pyspark/interchange.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/python/pyspark/interchange.py b/python/pyspark/interchange.py index 0e398028c66d..6d3712c9a67a 100644 --- a/python/pyspark/interchange.py +++ b/python/pyspark/interchange.py @@ -28,7 +28,7 @@ def _get_arrow_array_partition_stream(df: pyspark.sql.DataFrame) -> Iterator[pya # The return type of the function will be a single binary column containing # the serialized RecordBatch in Arrow IPC format. binary_schema = StructType( - [StructField("interchange_arrow_bytes", BinaryType(), nullable=False)] + [StructField("arrow_ipc_bytes", BinaryType(), nullable=False)] ) def batch_to_bytes_iter(batch_iter): @@ -47,18 +47,27 @@ def batch_to_bytes_iter(batch_iter): # Write the batch to the stream using Arrow IPC format with pyarrow.ipc.new_stream(sink, arrow_batch.schema) as writer: writer.write_batch(arrow_batch) - buf = sink.getvalue().to_pybytes() + buf = sink.getvalue() + # The second buffer contains the offsets we are manually creating. + offset_buf = pyarrow.array([0, len(buf)], type=pyarrow.int32()).buffers()[1] + null_bitmap = None # Wrap the bytes in a new 1-row, 1-column RecordBatch to satisfy mapInArrow return # signature. This serializes the whole batch into a single pyarrow serialized cell. 
- storage_arr = pyarrow.array([buf]) - yield pyarrow.RecordBatch.from_arrays([storage_arr], names=["interchange_arrow_bytes"]) + storage_arr = pyarrow.Array.from_buffers( + type=pyarrow.binary(), + length=1, + buffers=[null_bitmap, offset_buf, buf] + ) + yield pyarrow.RecordBatch.from_arrays( + [storage_arr], + names=["arrow_ipc_bytes"] + ) # Convert all partitions to Arrow RecordBatches and map to binary blobs. byte_df = df.mapInArrow(batch_to_bytes_iter, binary_schema) - # A row is actually a batch of data in Arrow IPC format. Fetch the batches one by one. for row in byte_df.toLocalIterator(): - with pyarrow.ipc.open_stream(row.interchange_arrow_bytes) as reader: + with pyarrow.ipc.open_stream(row.arrow_ipc_bytes) as reader: for batch in reader: # Each batch corresponds to a chunk of data in the partition. yield batch From b380413471c48febea84c40067239c109f6af7cd Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Wed, 10 Dec 2025 13:27:26 -0600 Subject: [PATCH 07/12] lint Signed-off-by: Devin Petersohn --- python/pyspark/interchange.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/python/pyspark/interchange.py b/python/pyspark/interchange.py index 6d3712c9a67a..42ab0a57f469 100644 --- a/python/pyspark/interchange.py +++ b/python/pyspark/interchange.py @@ -27,9 +27,7 @@ def _get_arrow_array_partition_stream(df: pyspark.sql.DataFrame) -> Iterator[pya # We will be using mapInArrow to convert each partition to Arrow RecordBatches. # The return type of the function will be a single binary column containing # the serialized RecordBatch in Arrow IPC format. - binary_schema = StructType( - [StructField("arrow_ipc_bytes", BinaryType(), nullable=False)] - ) + binary_schema = StructType([StructField("arrow_ipc_bytes", BinaryType(), nullable=False)]) def batch_to_bytes_iter(batch_iter): """ @@ -54,14 +52,9 @@ def batch_to_bytes_iter(batch_iter): # Wrap the bytes in a new 1-row, 1-column RecordBatch to satisfy mapInArrow return # signature. This serializes the whole batch into a single pyarrow serialized cell. storage_arr = pyarrow.Array.from_buffers( - type=pyarrow.binary(), - length=1, - buffers=[null_bitmap, offset_buf, buf] - ) - yield pyarrow.RecordBatch.from_arrays( - [storage_arr], - names=["arrow_ipc_bytes"] + type=pyarrow.binary(), length=1, buffers=[null_bitmap, offset_buf, buf] ) + yield pyarrow.RecordBatch.from_arrays([storage_arr], names=["arrow_ipc_bytes"]) # Convert all partitions to Arrow RecordBatches and map to binary blobs. byte_df = df.mapInArrow(batch_to_bytes_iter, binary_schema) From a2009f17b8d8a44762f6471f3f817312555d822a Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Thu, 11 Dec 2025 10:38:22 -0600 Subject: [PATCH 08/12] Apply suggestions from code review Co-authored-by: Takuya UESHIN --- python/pyspark/interchange.py | 2 +- python/pyspark/sql/tests/test_interchange.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/interchange.py b/python/pyspark/interchange.py index 42ab0a57f469..84464fbe4e90 100644 --- a/python/pyspark/interchange.py +++ b/python/pyspark/interchange.py @@ -15,7 +15,7 @@ # limitations under the License. 
# from typing import Iterator -import pyarrow +import pyarrow as pa import pyspark.sql from pyspark.sql.types import StructType, StructField, BinaryType diff --git a/python/pyspark/sql/tests/test_interchange.py b/python/pyspark/sql/tests/test_interchange.py index 2d8747f6713f..eea639b4970e 100644 --- a/python/pyspark/sql/tests/test_interchange.py +++ b/python/pyspark/sql/tests/test_interchange.py @@ -27,6 +27,7 @@ DUCKDB_TESTS = False +@unittest.skipIf(not DUCKDB_TESTS, " ... ") class TestSparkArrowCStreamer(unittest.TestCase): def test_spark_arrow_c_streamer(self): if not DUCKDB_TESTS: From 9a56f09fbc3571dbb541ef216360414f687dee3b Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Thu, 11 Dec 2025 10:51:19 -0600 Subject: [PATCH 09/12] Address comments Signed-off-by: Devin Petersohn --- dev/requirements.txt | 1 + dev/spark-test-image/python-311/Dockerfile | 2 +- dev/sparktestsupport/modules.py | 1 + python/pyspark/pandas/frame.py | 2 +- python/pyspark/sql/dataframe.py | 2 +- python/pyspark/sql/interchange.py | 89 ++++++++++++++++++++ python/pyspark/sql/tests/test_interchange.py | 16 +++- 7 files changed, 108 insertions(+), 5 deletions(-) create mode 100644 python/pyspark/sql/interchange.py diff --git a/dev/requirements.txt b/dev/requirements.txt index cde0957715bf..d46df7e18be7 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -17,6 +17,7 @@ pyyaml>=3.11 # PySpark test dependencies unittest-xml-reporting openpyxl +duckdb # PySpark test dependencies (optional) coverage diff --git a/dev/spark-test-image/python-311/Dockerfile b/dev/spark-test-image/python-311/Dockerfile index f8a9df5842ce..0fde61763a8e 100644 --- a/dev/spark-test-image/python-311/Dockerfile +++ b/dev/spark-test-image/python-311/Dockerfile @@ -68,7 +68,7 @@ RUN apt-get update && apt-get install -y \ && rm -rf /var/lib/apt/lists/* -ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2" +ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2 duckdb" # Python deps for Spark Connect ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3" diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 306a3b69223f..cd2702ad7168 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -589,6 +589,7 @@ def __hash__(self): "pyspark.sql.tests.test_connect_compatibility", "pyspark.sql.tests.udf_type_tests.test_udf_input_types", "pyspark.sql.tests.udf_type_tests.test_udf_return_types", + "pyspark.sql.tests.test_interchange", ], ) diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index c6ea125215c0..e23828ef20f3 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -13837,7 +13837,7 @@ def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> objec ------- A C PyCapsule stream object. 
""" - from pyspark.interchange import SparkArrowCStreamer + from pyspark.sql.interchange import SparkArrowCStreamer return SparkArrowCStreamer(self._internal.to_internal_spark_frame).__arrow_c_stream__( requested_schema diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a4e792adc782..2648f054e026 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -7070,7 +7070,7 @@ def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> objec ------- A C PyCapsule stream object. """ - from pyspark.interchange import SparkArrowCStreamer + from pyspark.sql.interchange import SparkArrowCStreamer return SparkArrowCStreamer(self).__arrow_c_stream__(requested_schema) diff --git a/python/pyspark/sql/interchange.py b/python/pyspark/sql/interchange.py new file mode 100644 index 000000000000..e687fa71b011 --- /dev/null +++ b/python/pyspark/sql/interchange.py @@ -0,0 +1,89 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Iterator + +import pyarrow as pa + +import pyspark.sql +from pyspark.sql.types import StructType, StructField, BinaryType +from pyspark.sql.pandas.types import to_arrow_schema + + +def _get_arrow_array_partition_stream(df: pyspark.sql.DataFrame) -> Iterator[pa.RecordBatch]: + """Return all the partitions as Arrow arrays in an Iterator.""" + # We will be using mapInArrow to convert each partition to Arrow RecordBatches. + # The return type of the function will be a single binary column containing + # the serialized RecordBatch in Arrow IPC format. + binary_schema = StructType([StructField("arrow_ipc_bytes", BinaryType(), nullable=False)]) + + def batch_to_bytes_iter(batch_iter): + """ + A generator function that converts RecordBatches to serialized Arrow IPC format. + + Spark sends each partition as an iterator of RecordBatches. In order to return + the entire partition as a stream of Arrow RecordBatches, we need to serialize + each RecordBatch to Arrow IPC format and yield it as a single binary blob. + """ + # The size of the batch can be controlled by the Spark config + # `spark.sql.execution.arrow.maxRecordsPerBatch`. + for arrow_batch in batch_iter: + # We create an in-memory byte stream to hold the serialized batch + sink = pa.BufferOutputStream() + # Write the batch to the stream using Arrow IPC format + with pa.ipc.new_stream(sink, arrow_batch.schema) as writer: + writer.write_batch(arrow_batch) + buf = sink.getvalue() + # The second buffer contains the offsets we are manually creating. + offset_buf = pa.array([0, len(buf)], type=pa.int32()).buffers()[1] + null_bitmap = None + # Wrap the bytes in a new 1-row, 1-column RecordBatch to satisfy mapInArrow return + # signature. This serializes the whole batch into a single pyarrow serialized cell. 
+ storage_arr = pa.Array.from_buffers( + type=pa.binary(), length=1, buffers=[null_bitmap, offset_buf, buf] + ) + yield pa.RecordBatch.from_arrays([storage_arr], names=["arrow_ipc_bytes"]) + + # Convert all partitions to Arrow RecordBatches and map to binary blobs. + byte_df = df.mapInArrow(batch_to_bytes_iter, binary_schema) + # A row is actually a batch of data in Arrow IPC format. Fetch the batches one by one. + for row in byte_df.toLocalIterator(): + with pa.ipc.open_stream(row.arrow_ipc_bytes) as reader: + for batch in reader: + # Each batch corresponds to a chunk of data in the partition. + yield batch + + +class SparkArrowCStreamer: + """ + A class that implements that __arrow_c_stream__ protocol for Spark partitions. + + This class is implemented in a way that allows consumers to consume each partition + one at a time without materializing all partitions at once on the driver side. + """ + + def __init__(self, df: pyspark.sql.DataFrame): + self._df = df + self._schema = to_arrow_schema(df.schema) + + def __arrow_c_stream__(self, requested_schema=None): + """ + Return the Arrow C stream for the dataframe partitions. + """ + reader: pa.RecordBatchReader = pa.RecordBatchReader.from_batches( + self._schema, _get_arrow_array_partition_stream(self._df) + ) + return reader.__arrow_c_stream__(requested_schema=requested_schema) diff --git a/python/pyspark/sql/tests/test_interchange.py b/python/pyspark/sql/tests/test_interchange.py index eea639b4970e..8df207b1a9df 100644 --- a/python/pyspark/sql/tests/test_interchange.py +++ b/python/pyspark/sql/tests/test_interchange.py @@ -33,13 +33,25 @@ def test_spark_arrow_c_streamer(self): if not DUCKDB_TESTS: self.skipTest("duckdb is not installed") - pdf = pd.DataFrame({"A": [1, "a"], "B": [2, "b"], "C": [3, "c"], "D": [4, "d"]}) + pdf = pd.DataFrame([[1, "a"], [2, "b"], [3, "c"], [4, "d"]], columns=["id", "value"]) psdf = ps.from_pandas(pdf) # Use Spark Arrow C Streamer to convert PyArrow Table to DuckDB relation stream = pa.RecordBatchReader.from_stream(psdf) assert isinstance(stream, pa.RecordBatchReader) # Verify the contents of the DuckDB relation - result = duckdb.execute("SELECT * from stream").fetchall() + result = duckdb.execute("SELECT id, value from stream").fetchall() expected = [(1, "a"), (2, "b"), (3, "c"), (4, "d")] self.assertEqual(result, expected) + + +if __name__ == "__main__": + from pyspark.sql.tests.test_interchange import * # noqa: F401 + + try: + import xmlrunner # type: ignore + + test_runner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + test_runner = None + unittest.main(testRunner=test_runner, verbosity=2) From 101636f3c169e60b9f602a6b530ff097deadb521 Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Thu, 11 Dec 2025 12:22:54 -0600 Subject: [PATCH 10/12] Move file Signed-off-by: Devin Petersohn --- python/pyspark/interchange.py | 88 ----------------------------------- 1 file changed, 88 deletions(-) delete mode 100644 python/pyspark/interchange.py diff --git a/python/pyspark/interchange.py b/python/pyspark/interchange.py deleted file mode 100644 index 84464fbe4e90..000000000000 --- a/python/pyspark/interchange.py +++ /dev/null @@ -1,88 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. 
-# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from typing import Iterator -import pyarrow as pa - -import pyspark.sql -from pyspark.sql.types import StructType, StructField, BinaryType -from pyspark.sql.pandas.types import to_arrow_schema - - -def _get_arrow_array_partition_stream(df: pyspark.sql.DataFrame) -> Iterator[pyarrow.RecordBatch]: - """Return all the partitions as Arrow arrays in an Iterator.""" - # We will be using mapInArrow to convert each partition to Arrow RecordBatches. - # The return type of the function will be a single binary column containing - # the serialized RecordBatch in Arrow IPC format. - binary_schema = StructType([StructField("arrow_ipc_bytes", BinaryType(), nullable=False)]) - - def batch_to_bytes_iter(batch_iter): - """ - A generator function that converts RecordBatches to serialized Arrow IPC format. - - Spark sends each partition as an iterator of RecordBatches. In order to return - the entire partition as a stream of Arrow RecordBatches, we need to serialize - each RecordBatch to Arrow IPC format and yield it as a single binary blob. - """ - # The size of the batch can be controlled by the Spark config - # `spark.sql.execution.arrow.maxRecordsPerBatch`. - for arrow_batch in batch_iter: - # We create an in-memory byte stream to hold the serialized batch - sink = pyarrow.BufferOutputStream() - # Write the batch to the stream using Arrow IPC format - with pyarrow.ipc.new_stream(sink, arrow_batch.schema) as writer: - writer.write_batch(arrow_batch) - buf = sink.getvalue() - # The second buffer contains the offsets we are manually creating. - offset_buf = pyarrow.array([0, len(buf)], type=pyarrow.int32()).buffers()[1] - null_bitmap = None - # Wrap the bytes in a new 1-row, 1-column RecordBatch to satisfy mapInArrow return - # signature. This serializes the whole batch into a single pyarrow serialized cell. - storage_arr = pyarrow.Array.from_buffers( - type=pyarrow.binary(), length=1, buffers=[null_bitmap, offset_buf, buf] - ) - yield pyarrow.RecordBatch.from_arrays([storage_arr], names=["arrow_ipc_bytes"]) - - # Convert all partitions to Arrow RecordBatches and map to binary blobs. - byte_df = df.mapInArrow(batch_to_bytes_iter, binary_schema) - # A row is actually a batch of data in Arrow IPC format. Fetch the batches one by one. - for row in byte_df.toLocalIterator(): - with pyarrow.ipc.open_stream(row.arrow_ipc_bytes) as reader: - for batch in reader: - # Each batch corresponds to a chunk of data in the partition. - yield batch - - -class SparkArrowCStreamer: - """ - A class that implements that __arrow_c_stream__ protocol for Spark partitions. - - This class is implemented in a way that allows consumers to consume each partition - one at a time without materializing all partitions at once on the driver side. 
- """ - - def __init__(self, df: pyspark.sql.DataFrame): - self._df = df - self._schema = to_arrow_schema(df.schema) - - def __arrow_c_stream__(self, requested_schema=None): - """ - Return the Arrow C stream for the dataframe partitions. - """ - reader: pyarrow.RecordBatchReader = pyarrow.RecordBatchReader.from_batches( - self._schema, _get_arrow_array_partition_stream(self._df) - ) - return reader.__arrow_c_stream__(requested_schema=requested_schema) From 0af8766273a02f3cf742d3cc4ea2e350b3d79c67 Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Thu, 11 Dec 2025 13:41:58 -0600 Subject: [PATCH 11/12] Update skip comment Signed-off-by: Devin Petersohn --- python/pyspark/sql/tests/test_interchange.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests/test_interchange.py b/python/pyspark/sql/tests/test_interchange.py index 8df207b1a9df..5ea73ab787a9 100644 --- a/python/pyspark/sql/tests/test_interchange.py +++ b/python/pyspark/sql/tests/test_interchange.py @@ -27,12 +27,9 @@ DUCKDB_TESTS = False -@unittest.skipIf(not DUCKDB_TESTS, " ... ") +@unittest.skipIf(not DUCKDB_TESTS, "duckdb is not installed") class TestSparkArrowCStreamer(unittest.TestCase): def test_spark_arrow_c_streamer(self): - if not DUCKDB_TESTS: - self.skipTest("duckdb is not installed") - pdf = pd.DataFrame([[1, "a"], [2, "b"], [3, "c"], [4, "d"]], columns=["id", "value"]) psdf = ps.from_pandas(pdf) # Use Spark Arrow C Streamer to convert PyArrow Table to DuckDB relation From deb4df5d1fb9a1e67b45967e1fa8b8c8fc667b7e Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Thu, 11 Dec 2025 16:17:26 -0600 Subject: [PATCH 12/12] Lint Signed-off-by: Devin Petersohn --- python/pyspark/sql/dataframe.py | 34 +++++++++++++++---------------- python/pyspark/sql/interchange.py | 6 +++--- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b12c3056c3a0..c5b060b9f59e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -6983,6 +6983,23 @@ def plot(self) -> "PySparkPlotAccessor": """ ... + def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object: + """ + Export to a C PyCapsule stream object. + + Parameters + ---------- + requested_schema : PyCapsule, optional + The schema to attempt to use for the output stream. This is a best effort request, + + Returns + ------- + A C PyCapsule stream object. + """ + from pyspark.sql.interchange import SparkArrowCStreamer + + return SparkArrowCStreamer(self).__arrow_c_stream__(requested_schema) + class DataFrameNaFunctions: """Functionality for working with missing data in :class:`DataFrame`. @@ -7064,23 +7081,6 @@ def replace( replace.__doc__ = DataFrame.replace.__doc__ - def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object: - """ - Export to a C PyCapsule stream object. - - Parameters - ---------- - requested_schema : PyCapsule, optional - The schema to attempt to use for the output stream. This is a best effort request, - - Returns - ------- - A C PyCapsule stream object. - """ - from pyspark.sql.interchange import SparkArrowCStreamer - - return SparkArrowCStreamer(self).__arrow_c_stream__(requested_schema) - class DataFrameStatFunctions: """Functionality for statistic functions with :class:`DataFrame`. 
diff --git a/python/pyspark/sql/interchange.py b/python/pyspark/sql/interchange.py index e687fa71b011..141d9f37148e 100644 --- a/python/pyspark/sql/interchange.py +++ b/python/pyspark/sql/interchange.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Iterator +from typing import Iterator, Optional import pyarrow as pa @@ -30,7 +30,7 @@ def _get_arrow_array_partition_stream(df: pyspark.sql.DataFrame) -> Iterator[pa. # the serialized RecordBatch in Arrow IPC format. binary_schema = StructType([StructField("arrow_ipc_bytes", BinaryType(), nullable=False)]) - def batch_to_bytes_iter(batch_iter): + def batch_to_bytes_iter(batch_iter: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]: """ A generator function that converts RecordBatches to serialized Arrow IPC format. @@ -79,7 +79,7 @@ def __init__(self, df: pyspark.sql.DataFrame): self._df = df self._schema = to_arrow_schema(df.schema) - def __arrow_c_stream__(self, requested_schema=None): + def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object: """ Return the Arrow C stream for the dataframe partitions. """
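
End to end, the new protocol method lets any PyCapsule consumer pull a Spark DataFrame partition by partition. A hedged usage sketch follows; it assumes a locally running SparkSession with pyarrow installed, the builder options and column check are illustrative, and DuckDB or Polars could consume the same stream the way the new test does.

import pyarrow as pa
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("pycapsule-demo").getOrCreate()
df = spark.range(10)

# DataFrame.__arrow_c_stream__ lets pyarrow build a reader over the partitions,
# fetching them one at a time instead of collecting everything up front.
reader = pa.RecordBatchReader.from_stream(df)
table = reader.read_all()
assert table.num_rows == 10
assert table.column_names == ["id"]

spark.stop()
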