Skip to content

Commit

Permalink
Merge pull request #286 from Deltares/285-improve-performance-of-_get…
Browse files Browse the repository at this point in the history
…_topology-by-not-accessing-dataarray-each-time

prevent accessing each var again when looping over dataset vars
  • Loading branch information
Huite authored Aug 20, 2024
2 parents 3dee693 + e4dcd3b commit 447ccc2
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion xugrid/core/dataarray_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def to_geodataframe(
else:
ds = self.obj.to_dataset()

variables = [var for var in ds.data_vars if dim in ds[var].dims]
variables = [var for var in ds.data_vars if dim in ds.variables[var].dims]
# TODO deal with time-dependent data, etc.
# Basically requires checking which variables are static, which aren't.
# For non-static, requires repeating all geometries.
Expand Down
2 changes: 1 addition & 1 deletion xugrid/core/dataset_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def to_geodataframe(self, dim_order=None) -> "geopandas.GeoDataFrame": # type:
else:
raise ValueError("invalid topology dimension on grid")

variables = [var for var in ds.data_vars if dim in ds[var].dims]
variables = [var for var in ds.data_vars if dim in ds.variables[var].dims]
if variables:
data = ds[variables].to_dataframe(dim_order=dim_order)
else:
Expand Down
6 changes: 5 additions & 1 deletion xugrid/ugrid/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,11 @@ def default_topology_attrs(name: str, topology_dimension: int):


def _get_topology(ds: xr.Dataset) -> List[str]:
return [k for k in ds.data_vars if ds[k].attrs.get("cf_role") == "mesh_topology"]
return [
var
for var in ds.data_vars
if ds.variables[var].attrs.get("cf_role") == "mesh_topology"
]


def _infer_xy_coords(
Expand Down
2 changes: 1 addition & 1 deletion xugrid/ugrid/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def validate_partition_objects(
unique_vars = set(chain(*allvars))
# Check dimensions
dims_per_var = [
{ds[var].dims for ds in data_objects if var in ds.data_vars}
{ds.variables[var].dims for ds in data_objects if var in ds.data_vars}
for var in unique_vars
]
for var, vardims in zip(unique_vars, dims_per_var):
Expand Down

0 comments on commit 447ccc2

Please sign in to comment.