diff --git a/python/sedona/utils/geoarrow.py b/python/sedona/utils/geoarrow.py
index 353b4ff7f87..6960538a403 100644
--- a/python/sedona/utils/geoarrow.py
+++ b/python/sedona/utils/geoarrow.py
@@ -26,6 +26,11 @@
 from pyspark.sql import DataFrame
 from pyspark.sql.types import StructType, StructField, DataType, ArrayType, MapType
 
+try:
+    from pyspark.util import _load_from_socket
+except ImportError:
+    from pyspark.rdd import _load_from_socket
+
 from sedona.sql.types import GeometryType
 import geopandas as gpd
 from pyspark.sql.pandas.types import (
@@ -50,22 +55,138 @@ def dataframe_to_arrow(df, crs=None):
     the output if exactly one CRS is present in the output.
     :return:
     """
-    import pyarrow as pa
-
     col_is_geometry = [isinstance(f.dataType, GeometryType) for f in df.schema.fields]
 
     if not any(col_is_geometry):
         return dataframe_to_arrow_raw(df)
 
+    df_projected = project_dataframe_geoarrow(df, col_is_geometry)
+    table = dataframe_to_arrow_raw(df_projected)
+    return wrap_table_or_batch(table, col_is_geometry, crs)
+
+
+class GeoArrowDataFrameReader:
+    def __init__(self, df, crs=None):
+        from pyspark.sql.pandas.types import to_arrow_schema
+        from pyspark.traceback_utils import SCCallSiteSync
+        from pyspark.sql.pandas.serializers import ArrowCollectSerializer
+
+        self._crs = crs
+        self._batch_order = None
+        self._col_is_geometry = [
+            isinstance(f.dataType, GeometryType) for f in df.schema.fields
+        ]
+
+        if any(self._col_is_geometry):
+            df = project_dataframe_geoarrow(df, self._col_is_geometry)
+            raw_schema = to_arrow_schema(df.schema)
+            self._schema = raw_schema_to_geoarrow_schema(
+                raw_schema, self._col_is_geometry, self._crs
+            )
+        else:
+            self._schema = to_arrow_schema(df.schema)
+
+        with SCCallSiteSync(df._sc):
+            (
+                port,
+                auth_secret,
+                self._jsocket_auth_server,
+            ) = df._jdf.collectAsArrowToPython()
+
+        self._batch_stream = _load_from_socket(
+            (port, auth_secret), ArrowCollectSerializer()
+        )
+
+    @property
+    def schema(self):
+        return self._schema
+
+    @property
+    def batch_order(self):
+        return self._batch_order
+
+    def to_table(self):
+        import pyarrow as pa
+
+        batches = list(self)
+        if not batches:
+            return pa.Table.from_batches([], schema=self.schema)
+
+        batches_in_order = [batches[i] for i in self.batch_order]
+        return pa.Table.from_batches(batches_in_order)
+
+    def __iter__(self):
+        import pyarrow as pa
+
+        try:
+            for batch_or_indices in self._batch_stream:
+                if isinstance(batch_or_indices, pa.RecordBatch):
+                    yield wrap_table_or_batch(
+                        batch_or_indices, self._col_is_geometry, self._crs
+                    )
+                else:
+                    self._batch_order = batch_or_indices
+        finally:
+            self._finish_stream()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        self._finish_stream()
+
+    def __del__(self):
+        self._finish_stream()
+
+    def _finish_stream(self):
+        from pyspark.errors.exceptions.captured import unwrap_spark_exception
+
+        if self._jsocket_auth_server is None:
+            return
+
+        with unwrap_spark_exception():
+            # Join serving thread and raise any exceptions from collectAsArrowToPython
+            auth_server = self._jsocket_auth_server
+            self._jsocket_auth_server = None
+            auth_server.getResult()
+
+
+def project_dataframe_geoarrow(df, col_is_geometry):
     df_columns = list(df)
     df_column_names = df.schema.fieldNames()
     for i, is_geom in enumerate(col_is_geometry):
         if is_geom:
             df_columns[i] = ST_AsEWKB(df_columns[i]).alias(df_column_names[i])
 
-    df_projected = df.select(*df_columns)
-    table = dataframe_to_arrow_raw(df_projected)
-
+    return df.select(*df_columns)
+
+
+def raw_schema_to_geoarrow_schema(raw_schema, col_is_geometry, crs, columns=None):
+    import pyarrow as pa
+
+    try:
+        import geoarrow.types as gat
+
+        spec = gat.wkb()
+    except ImportError:
+        spec = None
+
+    if columns is None:
+        columns = [None] * len(col_is_geometry)
+
+    new_fields = [
+        (
+            wrap_geoarrow_field(raw_schema.field(i), columns[i], crs, spec)
+            if is_geom
+            else raw_schema.field(i)
+        )
+        for i, is_geom in enumerate(col_is_geometry)
+    ]
+
+    return pa.schema(new_fields)
+
+
+def wrap_table_or_batch(table_or_batch, col_is_geometry, crs):
     try:
         # Using geoarrow-types is the preferred mechanism for Arrow output.
         # Using the extension type ensures that the type and its metadata will
@@ -78,25 +199,19 @@ def dataframe_to_arrow(df, crs=None):
 
         new_cols = [
             wrap_geoarrow_extension(col, spec, crs) if is_geom else col
-            for is_geom, col in zip(col_is_geometry, table.columns)
+            for is_geom, col in zip(col_is_geometry, table_or_batch.columns)
         ]
 
-        return pa.table(new_cols, table.column_names)
+        return table_or_batch.from_arrays(new_cols, table_or_batch.column_names)
     except ImportError:
         # In the event that we don't have access to GeoArrow extension types,
         # we can still add field metadata that will propagate through some types
         # of operations (e.g., writing this table to a file or passing it to
         # DuckDB as long as no intermediate transformations were applied).
-        new_fields = [
-            (
-                wrap_geoarrow_field(table.schema.field(i), table[i], crs)
-                if is_geom
-                else table.schema.field(i)
-            )
-            for i, is_geom in enumerate(col_is_geometry)
-        ]
-
-        return table.from_arrays(table.columns, schema=pa.schema(new_fields))
+        schema = raw_schema_to_geoarrow_schema(
+            table_or_batch.schema, col_is_geometry, crs, table_or_batch.columns
+        )
+        return table_or_batch.from_arrays(table_or_batch.columns, schema=schema)
 
 
 def dataframe_to_arrow_raw(df):
@@ -125,7 +240,7 @@
 
 
 def wrap_geoarrow_extension(col, spec, crs):
-    if crs is None:
+    if crs is None and col is not None:
         crs = unique_srid_from_ewkb(col)
     elif not hasattr(crs, "to_json"):
         import pyproj
@@ -135,8 +250,8 @@
     return spec.override(crs=crs).to_pyarrow().wrap_array(col)
 
 
-def wrap_geoarrow_field(field, col, crs):
-    if crs is None:
+def wrap_geoarrow_field(field, col, crs, spec=None):
+    if crs is None and col is not None:
         crs = unique_srid_from_ewkb(col)
 
     if crs is not None:
@@ -144,12 +259,16 @@
     else:
         metadata = ""
 
-    return field.with_metadata(
-        {
-            "ARROW:extension:name": "geoarrow.wkb",
-            "ARROW:extension:metadata": "{" + metadata + "}",
-        }
-    )
+    if spec is None:
+        return field.with_metadata(
+            {
+                "ARROW:extension:name": "geoarrow.wkb",
+                "ARROW:extension:metadata": "{" + metadata + "}",
+            }
+        )
+    else:
+        spec_metadata = spec.from_extension_metadata("{" + metadata + "}")
+        return field.with_type(spec_metadata.coalesce(spec).to_pyarrow())
 
 
 def crs_to_json(crs):
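
For reviewers, a quick usage sketch of the two entry points this patch touches. This snippet is not part of the diff: the `sedona` session and the `buildings` view are hypothetical names, and it assumes `pyarrow` (and optionally `geoarrow-types`) is installed.

```python
from sedona.utils.geoarrow import GeoArrowDataFrameReader, dataframe_to_arrow

# Hypothetical DataFrame with a geometry column (names are for illustration only).
df = sedona.sql("SELECT id, geometry FROM buildings")

# One-shot path: collect everything into a pyarrow.Table whose geometry
# columns carry GeoArrow extension types (or plain field metadata when
# geoarrow-types is not installed).
table = dataframe_to_arrow(df)

# Streaming path: iterate record batches as they arrive from the JVM instead
# of materializing the whole result; __exit__ joins the serving thread and
# surfaces any exception raised by collectAsArrowToPython.
with GeoArrowDataFrameReader(df) as reader:
    print(reader.schema)  # GeoArrow-annotated schema, available up front
    for batch in reader:
        ...  # each batch is a pyarrow.RecordBatch with wrapped geometry columns
```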