Skip to content

Commit 3863962

Browse files
authored
[GH-2149] Geopandas: Implement to_file, from_file, read_file (#2150)
* Implement to_file, from_file, read_file * Fix ci * Delete __repr__()'s and _process_geometry_columns and PR feedback * Make format and extension case insensitive in read_file * Remove sort by GeoHash logic in to_file parquet
1 parent 39f478d commit 3863962

File tree

8 files changed

+736
-145
lines changed

8 files changed

+736
-145
lines changed

python/sedona/geopandas/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,5 @@
2424
from sedona.geopandas.geodataframe import GeoDataFrame
2525

2626
from sedona.geopandas.tools import sjoin
27+
28+
from sedona.geopandas.io import read_file

python/sedona/geopandas/geodataframe.py

Lines changed: 123 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from pyspark.pandas._typing import Dtype
3333
from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
3434
from pyspark.pandas.internal import InternalFrame
35+
from pyspark.pandas.utils import log_advice
3536

3637
from sedona.geopandas._typing import Label
3738
from sedona.geopandas.base import GeoFrame
@@ -351,10 +352,10 @@ def __getitem__(self, key: Any) -> Any:
351352

352353
try:
353354
result = sgpd.GeoSeries(ps_series)
354-
first_idx = ps_series.first_valid_index()
355-
if first_idx is not None:
356-
geom = ps_series.iloc[int(first_idx)]
357-
srid = shapely.get_srid(geom)
355+
not_null = ps_series[ps_series.notnull()]
356+
if len(not_null) > 0:
357+
first_geom = not_null.iloc[0]
358+
srid = shapely.get_srid(first_geom)
358359

359360
# Shapely objects stored in the ps.Series retain their srid
360361
# but the GeoSeries does not, so we manually re-set it here
@@ -425,7 +426,7 @@ def __init__(
425426
# instead of calling e.g assert not dtype ourselves.
426427
# This way, if Spark adds support later, than we inherit those changes naturally
427428
super().__init__(data, index=index, columns=columns, dtype=dtype, copy=copy)
428-
elif isinstance(data, PandasOnSparkDataFrame):
429+
elif isinstance(data, (PandasOnSparkDataFrame, SparkDataFrame)):
429430

430431
super().__init__(data, index=index, columns=columns, dtype=dtype, copy=copy)
431432
elif isinstance(data, PandasOnSparkSeries):
@@ -436,14 +437,6 @@ def __init__(
436437
pass
437438

438439
super().__init__(data, index=index, columns=columns, dtype=dtype, copy=copy)
439-
elif isinstance(data, SparkDataFrame):
440-
assert columns is None
441-
assert dtype is None
442-
assert not copy
443-
444-
if index is None:
445-
internal = InternalFrame(spark_frame=data, index_spark_columns=None)
446-
object.__setattr__(self, "_internal_frame", internal)
447440
else:
448441
# below are not distributed dataframe types
449442
if isinstance(data, pd.DataFrame):
@@ -480,6 +473,9 @@ def __init__(
480473
if crs is not None and data.crs != crs:
481474
raise ValueError(crs_mismatch_error)
482475

476+
if geometry:
477+
self.set_geometry(geometry, inplace=True)
478+
483479
if geometry is None and "geometry" in self.columns:
484480

485481
if (self.columns == "geometry").sum() > 1:
@@ -828,63 +824,13 @@ def active_geometry_name(self) -> Any:
828824
"""
829825
return self._geometry_column_name
830826

831-
def _process_geometry_columns(
832-
self, operation: str, rename_suffix: str = "", *args, **kwargs
833-
) -> GeoDataFrame:
834-
"""
835-
Helper method to process geometry columns with a specified operation.
836-
837-
Parameters
838-
----------
839-
operation : str
840-
The spatial operation to apply (e.g., 'ST_Area', 'ST_Buffer').
841-
rename_suffix : str, default ""
842-
Suffix to append to the resulting column name.
843-
args : tuple
844-
Positional arguments for the operation.
845-
kwargs : dict
846-
Keyword arguments for the operation.
847-
848-
Returns
849-
-------
850-
GeoDataFrame
851-
A new GeoDataFrame with the operation applied to geometry columns.
852-
"""
853-
select_expressions = []
854-
855-
for field in self._internal.spark_frame.schema.fields:
856-
col_name = field.name
857-
858-
# Skip index and order columns
859-
if col_name in ("__index_level_0__", "__natural_order__"):
860-
continue
861-
862-
if field.dataType.typeName() in ("geometrytype", "binary"):
863-
# Prepare arguments for the operation
864-
positional_params = ", ".join([repr(v) for v in args])
865-
keyword_params = ", ".join([repr(v) for v in kwargs.values()])
866-
params = ", ".join(filter(None, [positional_params, keyword_params]))
867-
868-
if field.dataType.typeName() == "binary":
869-
expr = f"{operation}(ST_GeomFromWKB(`{col_name}`){', ' + params if params else ''}) as {col_name}{rename_suffix}"
870-
else:
871-
expr = f"{operation}(`{col_name}`{', ' + params if params else ''}) as {col_name}{rename_suffix}"
872-
select_expressions.append(expr)
873-
else:
874-
# Keep non-geometry columns as they are
875-
select_expressions.append(f"`{col_name}`")
876-
877-
sdf = self._internal.spark_frame.selectExpr(*select_expressions)
878-
return GeoDataFrame(sdf)
879-
880827
def to_geopandas(self) -> gpd.GeoDataFrame:
881828
"""
882829
Note: Unlike in pandas and geopandas, Sedona will always return a general Index.
883830
This differs from pandas and geopandas, which will return a RangeIndex by default.
884831
885832
e.g pd.Index([0, 1, 2]) instead of pd.RangeIndex(start=0, stop=3, step=1)
886833
"""
887-
from pyspark.pandas.utils import log_advice
888834

889835
log_advice(
890836
"`to_geopandas` loads all data into the driver's memory. "
@@ -1007,10 +953,6 @@ def from_dict(
1007953
) -> GeoDataFrame:
1008954
raise NotImplementedError("from_dict() is not implemented yet.")
1009955

1010-
@classmethod
1011-
def from_file(cls, filename: os.PathLike | typing.IO, **kwargs) -> GeoDataFrame:
1012-
raise NotImplementedError("from_file() is not implemented yet.")
1013-
1014956
@classmethod
1015957
def from_features(
1016958
cls, features, crs: Any | None = None, columns: Iterable[str] | None = None
@@ -1290,16 +1232,6 @@ def to_feather(
12901232
):
12911233
raise NotImplementedError("to_feather() is not implemented yet.")
12921234

1293-
def to_file(
1294-
self,
1295-
filename: str,
1296-
driver: str | None = None,
1297-
schema: dict | None = None,
1298-
index: bool | None = None,
1299-
**kwargs,
1300-
):
1301-
raise NotImplementedError("to_file() is not implemented yet.")
1302-
13031235
@property
13041236
def geom_type(self) -> str:
13051237
# Implementation of the abstract method
@@ -1552,9 +1484,9 @@ def buffer(
15521484
mitre_limit=5.0,
15531485
single_sided=False,
15541486
**kwargs,
1555-
) -> GeoDataFrame:
1487+
) -> sgpd.GeoSeries:
15561488
"""
1557-
Returns a GeoDataFrame with all geometries buffered by the specified distance.
1489+
Returns a GeoSeries with all geometries buffered by the specified distance.
15581490
15591491
Parameters
15601492
----------
@@ -1573,8 +1505,8 @@ def buffer(
15731505
15741506
Returns
15751507
-------
1576-
GeoDataFrame
1577-
A new GeoDataFrame with buffered geometries.
1508+
GeoSeries
1509+
A new GeoSeries with buffered geometries.
15781510
15791511
Examples
15801512
--------
@@ -1588,8 +1520,14 @@ def buffer(
15881520
>>> gdf = GeoDataFrame(data)
15891521
>>> buffered = gdf.buffer(0.5)
15901522
"""
1591-
return self._process_geometry_columns(
1592-
"ST_Buffer", rename_suffix="_buffered", distance=distance
1523+
return self.geometry.buffer(
1524+
distance,
1525+
resolution=16,
1526+
cap_style="round",
1527+
join_style="round",
1528+
mitre_limit=5.0,
1529+
single_sided=False,
1530+
**kwargs,
15931531
)
15941532

15951533
def sjoin(
@@ -1666,18 +1604,117 @@ def sjoin(
16661604
# I/O OPERATIONS
16671605
# ============================================================================
16681606

1607+
@classmethod
1608+
def from_file(
1609+
cls, filename: str, format: str | None = None, **kwargs
1610+
) -> GeoDataFrame:
1611+
"""
1612+
Alternate constructor to create a ``GeoDataFrame`` from a file.
1613+
1614+
Parameters
1615+
----------
1616+
filename : str
1617+
File path or file handle to read from. If the path is a directory,
1618+
Sedona will read all files in the directory into a dataframe.
1619+
format : str, default None
1620+
The format of the file to read. If None, Sedona will infer the format
1621+
from the file extension. Note, inferring the format from the file extension
1622+
is not supported for directories.
1623+
Options:
1624+
- "shapefile"
1625+
- "geojson"
1626+
- "geopackage"
1627+
- "geoparquet"
1628+
1629+
table_name : str, default None
1630+
The name of the table to read from a geopackage file. Required if format is geopackage.
1631+
1632+
See also
1633+
--------
1634+
GeoDataFrame.to_file : write GeoDataFrame to file
1635+
"""
1636+
return sgpd.io.read_file(filename, format, **kwargs)
1637+
1638+
def to_file(
1639+
self,
1640+
path: str,
1641+
driver: str | None = None,
1642+
schema: dict | None = None,
1643+
index: bool | None = None,
1644+
**kwargs,
1645+
):
1646+
"""
1647+
Write the ``GeoDataFrame`` to a file.
1648+
1649+
Parameters
1650+
----------
1651+
path : string
1652+
File path or file handle to write to.
1653+
driver : string, default None
1654+
The format driver used to write the file.
1655+
If not specified, it attempts to infer it from the file extension.
1656+
If no extension is specified, Sedona will error.
1657+
Options:
1658+
- "geojson"
1659+
- "geopackage"
1660+
- "geoparquet"
1661+
schema : dict, default None
1662+
Not applicable to Sedona's implementation
1663+
index : bool, default None
1664+
If True, write index into one or more columns (for MultiIndex).
1665+
Default None writes the index into one or more columns only if
1666+
the index is named, is a MultiIndex, or has a non-integer data
1667+
type. If False, no index is written.
1668+
mode : string, default 'w'
1669+
The write mode, 'w' to overwrite the existing file and 'a' to append.
1670+
'overwrite' and 'append' are equivalent to 'w' and 'a' respectively.
1671+
crs : pyproj.CRS, default None
1672+
If specified, the CRS is passed to Fiona to
1673+
better control how the file is written. If None, GeoPandas
1674+
will determine the crs based on crs df attribute.
1675+
The value can be anything accepted
1676+
by :meth:`pyproj.CRS.from_user_input() <pyproj.crs.CRS.from_user_input>`,
1677+
such as an authority string (eg "EPSG:4326") or a WKT string.
1678+
engine : str
1679+
Not applicable to Sedona's implementation
1680+
metadata : dict[str, str], default None
1681+
Optional metadata to be stored in the file. Keys and values must be
1682+
strings. Supported only for "GPKG" driver. Not supported by Sedona
1683+
**kwargs :
1684+
Keyword args to be passed to the engine, and can be used to write
1685+
to multi-layer data, store data within archives (zip files), etc.
1686+
In case of the "pyogrio" engine, the keyword arguments are passed to
1687+
`pyogrio.write_dataframe`. In case of the "fiona" engine, the keyword
1688+
arguments are passed to `fiona.open`. For more information on possible
1689+
keywords, type: ``import pyogrio; help(pyogrio.write_dataframe)``.
1690+
1691+
Examples
1692+
--------
1693+
1694+
>>> gdf = GeoDataFrame({"geometry": [Point(0, 0), LineString([(0, 0), (1, 1)])], "int": [1, 2]})
1695+
>>> gdf.to_file(filepath, format="geoparquet")
1696+
1697+
With selected drivers you can also append to a file with `mode="a"`:
1698+
1699+
>>> gdf.to_file(filepath, driver="geojson", mode="a")
1700+
1701+
When the index is of non-integer dtype, index=None (default) is treated as True, writing the index to the file.
1702+
1703+
>>> gdf = GeoDataFrame({"geometry": [Point(0, 0), Point(1, 1)]}, index=["a", "b"])
1704+
>>> gdf.to_file(filepath, driver="geoparquet")
1705+
"""
1706+
sgpd.io._to_file(self, path, driver, index, **kwargs)
1707+
16691708
def to_parquet(self, path, **kwargs):
16701709
"""
16711710
Write the GeoSeries to a GeoParquet file.
1672-
16731711
Parameters:
16741712
- path: str
16751713
The file path where the GeoParquet file will be written.
16761714
- kwargs: Any
16771715
Additional arguments to pass to the Sedona DataFrame output function.
16781716
"""
1679-
# Use the Spark DataFrame's write method to write to GeoParquet format
1680-
self._internal.spark_frame.write.format("geoparquet").save(path, **kwargs)
1717+
self.to_file(path, driver="geoparquet", **kwargs)
16811718

16821719

16831720
# -----------------------------------------------------------------------------

0 commit comments

Comments
 (0)