Skip to content

Commit 20d3773

Browse files
authored
Add support for DataTree to xarray.merge() (#10790)
* Add support for DataTree to xarray.merge()
* Add path context to errors
* add re.escape
* use level instead of counting /
* fix whats new
1 parent ad51404 commit 20d3773

File tree

8 files changed

+229
-43
lines changed

8 files changed

+229
-43
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ v2025.10.2 (unreleased)
1313
New Features
1414
~~~~~~~~~~~~
1515

16+
- :py:func:`merge` now supports merging :py:class:`DataTree` objects
17+
(:issue:`9790`).
18+
By `Stephan Hoyer <https://github.com/shoyer>`_.
1619

1720
Breaking Changes
1821
~~~~~~~~~~~~~~~~

xarray/core/datatree.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from xarray.core.dataset import Dataset
3838
from xarray.core.dataset_variables import DataVariables
3939
from xarray.core.datatree_mapping import (
40-
_handle_errors_with_path_context,
40+
add_path_context_to_errors,
4141
map_over_datasets,
4242
)
4343
from xarray.core.formatting import (
@@ -2213,8 +2213,8 @@ def _selective_indexing(
22132213
result = {}
22142214
for path, node in self.subtree_with_keys:
22152215
node_indexers = {k: v for k, v in indexers.items() if k in node.dims}
2216-
func_with_error_context = _handle_errors_with_path_context(path)(func)
2217-
node_result = func_with_error_context(node.dataset, node_indexers)
2216+
with add_path_context_to_errors(path):
2217+
node_result = func(node.dataset, node_indexers)
22182218
# Indexing datasets corresponding to each node results in redundant
22192219
# coordinates when indexes from a parent node are inherited.
22202220
# Ideally, we would avoid creating such coordinates in the first

xarray/core/datatree_mapping.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Callable, Mapping
4+
from contextlib import contextmanager
45
from typing import TYPE_CHECKING, Any, cast, overload
56

67
from xarray.core.dataset import Dataset
@@ -112,9 +113,8 @@ def map_over_datasets(
112113
for i, arg in enumerate(args):
113114
if not isinstance(arg, DataTree):
114115
node_dataset_args.insert(i, arg)
115-
116-
func_with_error_context = _handle_errors_with_path_context(path)(func)
117-
results = func_with_error_context(*node_dataset_args, **kwargs)
116+
with add_path_context_to_errors(path):
117+
results = func(*node_dataset_args, **kwargs)
118118
out_data_objects[path] = results
119119

120120
num_return_values = _check_all_return_values(out_data_objects)
@@ -138,27 +138,14 @@ def map_over_datasets(
138138
)
139139

140140

141-
def _handle_errors_with_path_context(path: str):
142-
"""Wraps given function so that if it fails it also raises path to node on which it failed."""
143-
144-
def decorator(func):
145-
def wrapper(*args, **kwargs):
146-
try:
147-
return func(*args, **kwargs)
148-
except Exception as e:
149-
# Add the context information to the error message
150-
add_note(
151-
e, f"Raised whilst mapping function over node with path {path!r}"
152-
)
153-
raise
154-
155-
return wrapper
156-
157-
return decorator
158-
159-
160-
def add_note(err: BaseException, msg: str) -> None:
161-
err.add_note(msg)
141+
@contextmanager
142+
def add_path_context_to_errors(path: str):
143+
"""Add path context to any errors."""
144+
try:
145+
yield
146+
except Exception as e:
147+
e.add_note(f"Raised whilst mapping function over node(s) with path {path!r}")
148+
raise
162149

163150

164151
def _check_single_set_return_values(path_to_node: str, obj: Any) -> int | None:

xarray/core/treenode.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,7 @@
44
import sys
55
from collections.abc import Iterator, Mapping
66
from pathlib import PurePosixPath
7-
from typing import (
8-
TYPE_CHECKING,
9-
Any,
10-
TypeVar,
11-
)
7+
from typing import TYPE_CHECKING, Any, TypeVar
128

139
from xarray.core.types import Self
1410
from xarray.core.utils import Frozen, is_dict_like

xarray/structure/merge.py

Lines changed: 111 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections import defaultdict
44
from collections.abc import Hashable, Iterable, Mapping, Sequence
55
from collections.abc import Set as AbstractSet
6-
from typing import TYPE_CHECKING, Any, NamedTuple, Union
6+
from typing import TYPE_CHECKING, Any, NamedTuple, Union, cast, overload
77

88
import pandas as pd
99

@@ -34,6 +34,7 @@
3434
from xarray.core.coordinates import Coordinates
3535
from xarray.core.dataarray import DataArray
3636
from xarray.core.dataset import Dataset
37+
from xarray.core.datatree import DataTree
3738
from xarray.core.types import (
3839
CombineAttrsOptions,
3940
CompatOptions,
@@ -793,18 +794,101 @@ def merge_core(
793794
return _MergeResult(variables, coord_names, dims, out_indexes, attrs)
794795

795796

def merge_trees(
    trees: Iterable[DataTree],
    compat: CompatOptions | CombineKwargDefault = _COMPAT_DEFAULT,
    join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT,
    fill_value: object = dtypes.NA,
    combine_attrs: CombineAttrsOptions = "override",
) -> DataTree:
    """Merge specialized to DataTree objects.

    Nodes are grouped by their path across all input trees and the datasets
    found at each path are merged together. The root (path ``"."``) is merged
    first; the remaining paths are processed in order of increasing node
    level and assigned into the result tree.

    Parameters
    ----------
    trees : iterable of DataTree
        Trees to merge. An empty iterable produces a tree whose root holds
        the result of merging no datasets.
    compat, join, combine_attrs
        Forwarded unchanged to ``merge`` (for the root) and ``merge_core``
        (for every other path).
    fill_value : object, default: dtypes.NA
        Must be left at its default; any other value is rejected.

    Returns
    -------
    DataTree
        Tree containing one merged dataset per path present in any input.

    Raises
    ------
    NotImplementedError
        If ``fill_value`` is anything other than ``dtypes.NA``.

    Notes
    -----
    Exceptions raised while merging the datasets at a given path are
    annotated with that path via ``add_path_context_to_errors``.
    """
    from xarray.core.dataset import Dataset
    from xarray.core.datatree import DataTree
    from xarray.core.datatree_mapping import add_path_context_to_errors

    if fill_value is not dtypes.NA:
        # fill_value supports dicts, which probably should be mapped to
        # sub-groups?
        raise NotImplementedError(
            "fill_value is not yet supported for DataTree objects in merge"
        )

    # Collect, for every path, the nodes that live at that path in any tree.
    node_lists: defaultdict[str, list[DataTree]] = defaultdict(list)
    for tree in trees:
        for key, node in tree.subtree_with_keys:
            node_lists[key].append(node)

    # ``pop`` with a default keeps an empty ``trees`` iterable from failing
    # with an opaque ``KeyError('.')``; it then merges zero root datasets.
    root_datasets = [node.dataset for node in node_lists.pop(".", [])]
    with add_path_context_to_errors("."):
        root_ds = merge(
            root_datasets, compat=compat, join=join, combine_attrs=combine_attrs
        )
    result = DataTree(dataset=root_ds)

    def level(kv):
        # all nodes with the same path have the same level
        _, nodes_at_path = kv
        return nodes_at_path[0].level

    # Visit shallower paths first so a node is assigned into ``result``
    # before any of its descendants.
    for key, nodes in sorted(node_lists.items(), key=level):
        # Merge datasets, including inherited indexes to ensure alignment.
        datasets = [node.dataset for node in nodes]
        with add_path_context_to_errors(key):
            merge_result = merge_core(
                datasets,
                compat=compat,
                join=join,
                combine_attrs=combine_attrs,
            )
        # Remove inherited coordinates/indexes/dimensions.
        for var_name in list(merge_result.coord_names):
            if not any(var_name in node._coord_variables for node in nodes):
                del merge_result.variables[var_name]
                merge_result.coord_names.remove(var_name)
        for index_name in list(merge_result.indexes):
            if not any(index_name in node._node_indexes for node in nodes):
                del merge_result.indexes[index_name]
        for dim in list(merge_result.dims):
            if not any(dim in node._node_dims for node in nodes):
                del merge_result.dims[dim]

        merged_ds = Dataset._construct_direct(**merge_result._asdict())
        result[key] = DataTree(dataset=merged_ds)

    return result
858+
859+
860+
@overload
861+
def merge(
862+
objects: Iterable[DataTree],
863+
compat: CompatOptions | CombineKwargDefault = ...,
864+
join: JoinOptions | CombineKwargDefault = ...,
865+
fill_value: object = ...,
866+
combine_attrs: CombineAttrsOptions = ...,
867+
) -> DataTree: ...
868+
869+
870+
@overload
871+
def merge(
872+
objects: Iterable[DataArray | Dataset | Coordinates | dict],
873+
compat: CompatOptions | CombineKwargDefault = ...,
874+
join: JoinOptions | CombineKwargDefault = ...,
875+
fill_value: object = ...,
876+
combine_attrs: CombineAttrsOptions = ...,
877+
) -> Dataset: ...
878+
879+
796880
def merge(
797-
objects: Iterable[DataArray | CoercibleMapping],
881+
objects: Iterable[DataTree | DataArray | Dataset | Coordinates | dict],
798882
compat: CompatOptions | CombineKwargDefault = _COMPAT_DEFAULT,
799883
join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT,
800884
fill_value: object = dtypes.NA,
801885
combine_attrs: CombineAttrsOptions = "override",
802-
) -> Dataset:
886+
) -> DataTree | Dataset:
803887
"""Merge any number of xarray objects into a single Dataset as variables.
804888
805889
Parameters
806890
----------
807-
objects : iterable of Dataset or iterable of DataArray or iterable of dict-like
891+
objects : iterable of DataArray, Dataset, DataTree or dict
808892
Merge together all variables from these objects. If any of them are
809893
DataArray objects, they must have a name.
810894
compat : {"identical", "equals", "broadcast_equals", "no_conflicts", \
@@ -859,8 +943,9 @@ def merge(
859943
860944
Returns
861945
-------
862-
Dataset
863-
Dataset with combined variables from each object.
946+
Dataset or DataTree
947+
Objects with combined variables from the inputs. If any inputs are a
948+
DataTree, this will also be a DataTree. Otherwise it will be a Dataset.
864949
865950
Examples
866951
--------
@@ -1023,13 +1108,31 @@ def merge(
10231108
from xarray.core.coordinates import Coordinates
10241109
from xarray.core.dataarray import DataArray
10251110
from xarray.core.dataset import Dataset
1111+
from xarray.core.datatree import DataTree
1112+
1113+
objects = list(objects)
1114+
1115+
if any(isinstance(obj, DataTree) for obj in objects):
1116+
if not all(isinstance(obj, DataTree) for obj in objects):
1117+
raise TypeError(
1118+
"merge does not support mixed type arguments when one argument "
1119+
f"is a DataTree: {objects}"
1120+
)
1121+
trees = cast(list[DataTree], objects)
1122+
return merge_trees(
1123+
trees,
1124+
compat=compat,
1125+
join=join,
1126+
combine_attrs=combine_attrs,
1127+
fill_value=fill_value,
1128+
)
10261129

10271130
dict_like_objects = []
10281131
for obj in objects:
10291132
if not isinstance(obj, DataArray | Dataset | Coordinates | dict):
10301133
raise TypeError(
1031-
"objects must be an iterable containing only "
1032-
"Dataset(s), DataArray(s), and dictionaries."
1134+
"objects must be an iterable containing only DataTree(s), "
1135+
f"Dataset(s), DataArray(s), and dictionaries: {objects}"
10331136
)
10341137

10351138
if isinstance(obj, DataArray):

xarray/tests/test_datatree.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2219,13 +2219,17 @@ def test_sel_isel_error_has_node_info(self) -> None:
22192219

22202220
with pytest.raises(
22212221
KeyError,
2222-
match="Raised whilst mapping function over node with path 'second'",
2222+
match=re.escape(
2223+
"Raised whilst mapping function over node(s) with path 'second'"
2224+
),
22232225
):
22242226
tree.sel(x=1)
22252227

22262228
with pytest.raises(
22272229
IndexError,
2228-
match="Raised whilst mapping function over node with path 'first'",
2230+
match=re.escape(
2231+
"Raised whilst mapping function over node(s) with path 'first'"
2232+
),
22292233
):
22302234
tree.isel(x=4)
22312235

xarray/tests/test_datatree_mapping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def fail_on_specific_node(ds):
192192
with pytest.raises(
193193
ValueError,
194194
match=re.escape(
195-
r"Raised whilst mapping function over node with path 'set1'"
195+
r"Raised whilst mapping function over node(s) with path 'set1'"
196196
),
197197
):
198198
dt.map_over_datasets(fail_on_specific_node)

xarray/tests/test_merge.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import re
34
import warnings
45

56
import numpy as np
@@ -867,3 +868,95 @@ def test_merge_auto_align(self):
867868
with set_options(use_new_combine_kwarg_defaults=True):
868869
with pytest.raises(ValueError, match="might be related to new default"):
869870
expected.identical(ds2.merge(ds1))
871+
872+
class TestMergeDataTree:
    """Tests for ``xr.merge`` when the inputs are ``DataTree`` objects."""

    def test_mixed(self) -> None:
        """A DataTree combined with a Dataset is rejected with TypeError."""
        empty_tree = xr.DataTree()
        empty_ds = xr.Dataset()
        expected_message = "merge does not support mixed type arguments when one argument is a DataTree"
        with pytest.raises(TypeError, match=expected_message):
            xr.merge([empty_tree, empty_ds])  # type: ignore[list-item]

    def test_distinct(self) -> None:
        """Trees whose entries sit under different sub-groups are combined."""
        first = xr.DataTree.from_dict({"/a/b/c": 1})
        second = xr.DataTree.from_dict({"/a/d/e": 2})
        expected = xr.DataTree.from_dict({"/a/b/c": 1, "/a/d/e": 2})
        actual = xr.merge([first, second])
        assert_equal(actual, expected)

    def test_overlap(self) -> None:
        """Three trees sharing the group ``a`` merge into a single group."""
        first = xr.DataTree.from_dict({"/a/b": 1})
        second = xr.DataTree.from_dict({"/a/c": 2})
        third = xr.DataTree.from_dict({"/a/d": 3})
        expected = xr.DataTree.from_dict({"/a/b": 1, "/a/c": 2, "/a/d": 3})
        actual = xr.merge([first, second, third])
        assert_equal(actual, expected)

    def test_inherited(self) -> None:
        """A coordinate given on only one input appears once in the result."""
        with_coord = xr.DataTree.from_dict({"/a/b": ("x", [1])}, coords={"x": [0]})
        without_coord = xr.DataTree.from_dict({"/a/c": ("x", [2])})
        expected = xr.DataTree.from_dict(
            {"/a/b": ("x", [1]), "a/c": ("x", [2])}, coords={"x": [0]}
        )
        actual = xr.merge([with_coord, without_coord])
        assert_equal(actual, expected)

    def test_inherited_join(self) -> None:
        """Each ``join`` mode is applied to the ``x`` coordinate of both inputs."""
        first = xr.DataTree.from_dict({"/a/b": ("x", [0, 1])}, coords={"x": [0, 1]})
        second = xr.DataTree.from_dict({"/a/c": ("x", [1, 2])}, coords={"x": [1, 2]})

        # join="left": keep the first tree's x labels.
        expected_left = xr.DataTree.from_dict(
            {"/a/b": ("x", [0, 1]), "a/c": ("x", [np.nan, 1])}, coords={"x": [0, 1]}
        )
        actual_left = xr.merge([first, second], join="left")
        assert_equal(actual_left, expected_left)

        # join="right": keep the second tree's x labels.
        expected_right = xr.DataTree.from_dict(
            {"/a/b": ("x", [1, np.nan]), "a/c": ("x", [1, 2])}, coords={"x": [1, 2]}
        )
        actual_right = xr.merge([first, second], join="right")
        assert_equal(actual_right, expected_right)

        # join="inner": only the shared label survives.
        expected_inner = xr.DataTree.from_dict(
            {"/a/b": ("x", [1]), "a/c": ("x", [1])}, coords={"x": [1]}
        )
        actual_inner = xr.merge([first, second], join="inner")
        assert_equal(actual_inner, expected_inner)

        # join="outer": union of labels, gaps filled with NaN.
        expected_outer = xr.DataTree.from_dict(
            {"/a/b": ("x", [0, 1, np.nan]), "a/c": ("x", [np.nan, 1, 2])},
            coords={"x": [0, 1, 2]},
        )
        actual_outer = xr.merge([first, second], join="outer")
        assert_equal(actual_outer, expected_outer)

        # join="exact": differing labels must raise.
        exact_message = re.escape("cannot align objects with join='exact'")
        with pytest.raises(xr.AlignmentError, match=exact_message):
            xr.merge([first, second], join="exact")

    def test_merge_error_includes_path(self) -> None:
        """A merge failure reports the path of the node being merged."""
        first = xr.DataTree.from_dict({"/a/b": ("x", [0, 1])})
        second = xr.DataTree.from_dict({"/a/b": ("x", [1, 2])})
        path_message = re.escape(
            "Raised whilst mapping function over node(s) with path 'a'"
        )
        with pytest.raises(xr.MergeError, match=path_message):
            xr.merge([first, second], join="exact")

    def test_fill_value_errors(self) -> None:
        """Passing ``fill_value`` with DataTree inputs is not implemented."""
        inputs = [xr.DataTree(), xr.DataTree()]

        unsupported_message = re.escape(
            "fill_value is not yet supported for DataTree objects in merge"
        )
        with pytest.raises(NotImplementedError, match=unsupported_message):
            xr.merge(inputs, fill_value=None)

0 commit comments

Comments
 (0)