Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ Deprecations
Bug Fixes
~~~~~~~~~

- :py:meth:`Dataset.map` now merges attrs from the function result and the original
using the ``drop_conflicts`` strategy when ``keep_attrs=True``, preserving attrs
set by the function (:issue:`11019`, :pull:`11020`).
By `Maximilian Roos <https://github.com/max-sixty>`_.
- Ensure that ``keep_attrs='drop'`` and ``keep_attrs=False`` remove attrs from result, even when there is
only one xarray object given to ``apply_ufunc`` (:issue:`10982` :pull:`10997`).
By `Julia Signell <https://github.com/jsignell>`_.
Expand Down
6 changes: 6 additions & 0 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ sparse = "0.15.*"
toolz = "0.12.*"
zarr = "2.18.*"

# TODO: Remove `platforms` restriction once pandas nightly has win-64 wheels again.
# Without this, `pixi lock` fails because it can't solve the nightly feature for win-64,
# which breaks RTD builds (RTD has no lock file cache, unlike GitHub Actions CI).
[feature.nightly]
platforms = ["linux-64", "osx-arm64"]

[feature.nightly.dependencies]
python = "*"

Expand Down
20 changes: 18 additions & 2 deletions xarray/computation/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,13 +544,29 @@ def _implementation(self, func, dim, **kwargs) -> DataArray:

dataset = self.obj._to_temp_dataset()
dataset = dataset.map(func, dim=dim, **kwargs)
return self.obj._from_temp_dataset(dataset)
result = self.obj._from_temp_dataset(dataset)

# Clear attrs when keep_attrs is explicitly False
# (weighted operations can propagate attrs from weights through internal computations)
if kwargs.get("keep_attrs") is False:
result.attrs = {}

return result


class DatasetWeighted(Weighted["Dataset"]):
def _implementation(self, func, dim, **kwargs) -> Dataset:
self._check_dim(dim)
return self.obj.map(func, dim=dim, **kwargs)
result = self.obj.map(func, dim=dim, **kwargs)

# Clear attrs when keep_attrs is explicitly False
# (weighted operations can propagate attrs from weights through internal computations)
if kwargs.get("keep_attrs") is False:
result.attrs = {}
for var in result.data_vars.values():
var.attrs = {}

return result


def _inject_docstring(cls, cls_name):
Expand Down
23 changes: 14 additions & 9 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6910,8 +6910,10 @@ def map(
DataArray.
keep_attrs : bool or None, optional
If True, both the dataset's and variables' attributes (`attrs`) will be
copied from the original objects to the new ones. If False, the new dataset
and variables will be returned without copying the attributes.
combined from the original objects and the function results using the
``drop_conflicts`` strategy: matching attrs are kept, conflicting attrs
are dropped. If False, the new dataset and variables will have only
the attributes set by the function.
args : iterable, optional
Positional arguments passed on to `func`.
**kwargs : Any
Expand Down Expand Up @@ -6960,16 +6962,19 @@ def map(
coords = Coordinates._construct_direct(coords=coord_vars, indexes=indexes)

if keep_attrs:
# Merge attrs from function result and original, dropping conflicts
from xarray.structure.merge import merge_attrs

for k, v in variables.items():
v._copy_attrs_from(self.data_vars[k])
v.attrs = merge_attrs(
[v.attrs, self.data_vars[k].attrs], "drop_conflicts"
)
for k, v in coords.items():
if k in self.coords:
v._copy_attrs_from(self.coords[k])
else:
for v in variables.values():
v.attrs = {}
for v in coords.values():
v.attrs = {}
v.attrs = merge_attrs(
[v.attrs, self.coords[k].attrs], "drop_conflicts"
)
# When keep_attrs=False, leave attrs as the function returned them
Comment on lines 6964 to +6977
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem with interpreting keep_attrs=False as "leave the attrs as returned by the function" is that that means we don't have keep_attrs="drop" anymore.

I'd argue that keep_attrs=True should be closer to what you're proposing for keep_attrs=False, which I do think would be more intuitive.

So instead we may need to consider supporting keep_attrs with strategy names / a strategy function, like apply_ufunc does. That would still allow you to choose "drop_conflicts" if preferred (or maybe as the default? Not sure), while not changing behavior too drastically.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, the proposed code treats keep_attrs=False as "remove all the input attrs". but not "remove all the output attrs".

@keewis can you see a reasonable change to fix the immediate issue without adding a whole strategy to keep_attrs? I don't have a particularly strong view on this specific implementation, but it does seem reasonable / logical, and it does let us solve this immediate bug...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(zooming out — as I mentioned before, for me the best "blank-slate" implementation for keep_attrs is to mostly not have a the option at all, and folks can drop attrs if they want. though I agree with you that merging is case that neither approach handles well...)


attrs = self.attrs if keep_attrs else None
return type(self)(variables, coords=coords, attrs=attrs)
Expand Down
13 changes: 10 additions & 3 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,10 @@ def map( # type: ignore[override]
DataArray.
keep_attrs : bool | None, optional
If True, both the dataset's and variables' attributes (`attrs`) will be
copied from the original objects to the new ones. If False, the new dataset
and variables will be returned without copying the attributes.
combined from the original objects and the function results using the
``drop_conflicts`` strategy: matching attrs are kept, conflicting attrs
are dropped. If False, the new dataset and variables will have only
the attributes set by the function.
args : iterable, optional
Positional arguments passed on to `func`.
**kwargs : Any
Expand Down Expand Up @@ -438,8 +440,13 @@ def map( # type: ignore[override]
for k, v in self.data_vars.items()
}
if keep_attrs:
# Merge attrs from function result and original, dropping conflicts
from xarray.structure.merge import merge_attrs

for k, v in variables.items():
v._copy_attrs_from(self.data_vars[k])
v.attrs = merge_attrs(
[v.attrs, self.data_vars[k].attrs], "drop_conflicts"
)
attrs = self.attrs if keep_attrs else None
# return type(self)(variables, attrs=attrs)
return Dataset(variables, attrs=attrs)
Expand Down
29 changes: 29 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6452,6 +6452,35 @@ def mixed_func(x):
expected = xr.Dataset({"foo": 42, "bar": ("y", [4, 5])})
assert_identical(result, expected)

def test_map_preserves_function_attrs(self) -> None:
# Regression test for GH11019
# Attrs added by function should be preserved in result
ds = xr.Dataset({"test": ("x", [1, 2, 3], {"original": "value"})})

def add_attr(da):
return da.assign_attrs(new_attr="foobar")

# With keep_attrs=True: merge using drop_conflicts (no conflict here)
result = ds.map(add_attr, keep_attrs=True)
assert result["test"].attrs == {"original": "value", "new_attr": "foobar"}

# With keep_attrs=False: function's attrs preserved
result = ds.map(add_attr, keep_attrs=False)
assert result["test"].attrs == {"original": "value", "new_attr": "foobar"}

# When function modifies existing attr with keep_attrs=True, conflict is dropped
def modify_attr(da):
return da.assign_attrs(original="modified", extra="added")

result = ds.map(modify_attr, keep_attrs=True)
assert result["test"].attrs == {
"extra": "added"
} # "original" dropped due to conflict

# When function modifies existing attr with keep_attrs=False, function wins
result = ds.map(modify_attr, keep_attrs=False)
assert result["test"].attrs == {"original": "modified", "extra": "added"}

def test_apply_pending_deprecated_map(self) -> None:
data = create_test_data()
data.attrs["foo"] = "bar"
Expand Down
Loading