|
3 | 3 | from collections import defaultdict
|
4 | 4 | from collections.abc import Hashable, Iterable, Mapping, Sequence
|
5 | 5 | from collections.abc import Set as AbstractSet
|
6 |
| -from typing import TYPE_CHECKING, Any, NamedTuple, Union |
| 6 | +from typing import TYPE_CHECKING, Any, NamedTuple, Union, cast, overload |
7 | 7 |
|
8 | 8 | import pandas as pd
|
9 | 9 |
|
|
34 | 34 | from xarray.core.coordinates import Coordinates
|
35 | 35 | from xarray.core.dataarray import DataArray
|
36 | 36 | from xarray.core.dataset import Dataset
|
| 37 | + from xarray.core.datatree import DataTree |
37 | 38 | from xarray.core.types import (
|
38 | 39 | CombineAttrsOptions,
|
39 | 40 | CompatOptions,
|
@@ -793,18 +794,101 @@ def merge_core(
|
793 | 794 | return _MergeResult(variables, coord_names, dims, out_indexes, attrs)
|
794 | 795 |
|
795 | 796 |
|
| 797 | +def merge_trees( |
| 798 | + trees: Iterable[DataTree], |
| 799 | + compat: CompatOptions | CombineKwargDefault = _COMPAT_DEFAULT, |
| 800 | + join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT, |
| 801 | + fill_value: object = dtypes.NA, |
| 802 | + combine_attrs: CombineAttrsOptions = "override", |
| 803 | +) -> DataTree: |
| 804 | + """Merge specialized to DataTree objects.""" |
| 805 | + from xarray.core.dataset import Dataset |
| 806 | + from xarray.core.datatree import DataTree |
| 807 | + from xarray.core.datatree_mapping import add_path_context_to_errors |
| 808 | + |
| 809 | + if fill_value is not dtypes.NA: |
| 810 | + # fill_value support dicts, which probably should be mapped to sub-groups? |
| 811 | + raise NotImplementedError( |
| 812 | + "fill_value is not yet supported for DataTree objects in merge" |
| 813 | + ) |
| 814 | + |
| 815 | + node_lists: defaultdict[str, list[DataTree]] = defaultdict(list) |
| 816 | + for tree in trees: |
| 817 | + for key, node in tree.subtree_with_keys: |
| 818 | + node_lists[key].append(node) |
| 819 | + |
| 820 | + root_datasets = [node.dataset for node in node_lists.pop(".")] |
| 821 | + with add_path_context_to_errors("."): |
| 822 | + root_ds = merge( |
| 823 | + root_datasets, compat=compat, join=join, combine_attrs=combine_attrs |
| 824 | + ) |
| 825 | + result = DataTree(dataset=root_ds) |
| 826 | + |
| 827 | + def level(kv): |
| 828 | + # all trees with the same path have the same level |
| 829 | + _, trees = kv |
| 830 | + return trees[0].level |
| 831 | + |
| 832 | + for key, nodes in sorted(node_lists.items(), key=level): |
| 833 | + # Merge datasets, including inherited indexes to ensure alignment. |
| 834 | + datasets = [node.dataset for node in nodes] |
| 835 | + with add_path_context_to_errors(key): |
| 836 | + merge_result = merge_core( |
| 837 | + datasets, |
| 838 | + compat=compat, |
| 839 | + join=join, |
| 840 | + combine_attrs=combine_attrs, |
| 841 | + ) |
| 842 | + # Remove inherited coordinates/indexes/dimensions. |
| 843 | + for var_name in list(merge_result.coord_names): |
| 844 | + if not any(var_name in node._coord_variables for node in nodes): |
| 845 | + del merge_result.variables[var_name] |
| 846 | + merge_result.coord_names.remove(var_name) |
| 847 | + for index_name in list(merge_result.indexes): |
| 848 | + if not any(index_name in node._node_indexes for node in nodes): |
| 849 | + del merge_result.indexes[index_name] |
| 850 | + for dim in list(merge_result.dims): |
| 851 | + if not any(dim in node._node_dims for node in nodes): |
| 852 | + del merge_result.dims[dim] |
| 853 | + |
| 854 | + merged_ds = Dataset._construct_direct(**merge_result._asdict()) |
| 855 | + result[key] = DataTree(dataset=merged_ds) |
| 856 | + |
| 857 | + return result |
| 858 | + |
| 859 | + |
| 860 | +@overload |
| 861 | +def merge( |
| 862 | + objects: Iterable[DataTree], |
| 863 | + compat: CompatOptions | CombineKwargDefault = ..., |
| 864 | + join: JoinOptions | CombineKwargDefault = ..., |
| 865 | + fill_value: object = ..., |
| 866 | + combine_attrs: CombineAttrsOptions = ..., |
| 867 | +) -> DataTree: ... |
| 868 | + |
| 869 | + |
| 870 | +@overload |
| 871 | +def merge( |
| 872 | + objects: Iterable[DataArray | Dataset | Coordinates | dict], |
| 873 | + compat: CompatOptions | CombineKwargDefault = ..., |
| 874 | + join: JoinOptions | CombineKwargDefault = ..., |
| 875 | + fill_value: object = ..., |
| 876 | + combine_attrs: CombineAttrsOptions = ..., |
| 877 | +) -> Dataset: ... |
| 878 | + |
| 879 | + |
796 | 880 | def merge(
|
797 |
| - objects: Iterable[DataArray | CoercibleMapping], |
| 881 | + objects: Iterable[DataTree | DataArray | Dataset | Coordinates | dict], |
798 | 882 | compat: CompatOptions | CombineKwargDefault = _COMPAT_DEFAULT,
|
799 | 883 | join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT,
|
800 | 884 | fill_value: object = dtypes.NA,
|
801 | 885 | combine_attrs: CombineAttrsOptions = "override",
|
802 |
| -) -> Dataset: |
| 886 | +) -> DataTree | Dataset: |
803 | 887 | """Merge any number of xarray objects into a single Dataset as variables.
|
804 | 888 |
|
805 | 889 | Parameters
|
806 | 890 | ----------
|
807 |
| - objects : iterable of Dataset or iterable of DataArray or iterable of dict-like |
| 891 | + objects : iterable of DataArray, Dataset, DataTree or dict |
808 | 892 | Merge together all variables from these objects. If any of them are
|
809 | 893 | DataArray objects, they must have a name.
|
810 | 894 | compat : {"identical", "equals", "broadcast_equals", "no_conflicts", \
|
@@ -859,8 +943,9 @@ def merge(
|
859 | 943 |
|
860 | 944 | Returns
|
861 | 945 | -------
|
862 |
| - Dataset |
863 |
| - Dataset with combined variables from each object. |
| 946 | + Dataset or DataTree |
| 947 | + Objects with combined variables from the inputs. If any inputs are a |
| 948 | + DataTree, this will also be a DataTree. Otherwise it will be a Dataset. |
864 | 949 |
|
865 | 950 | Examples
|
866 | 951 | --------
|
@@ -1023,13 +1108,31 @@ def merge(
|
1023 | 1108 | from xarray.core.coordinates import Coordinates
|
1024 | 1109 | from xarray.core.dataarray import DataArray
|
1025 | 1110 | from xarray.core.dataset import Dataset
|
| 1111 | + from xarray.core.datatree import DataTree |
| 1112 | + |
| 1113 | + objects = list(objects) |
| 1114 | + |
| 1115 | + if any(isinstance(obj, DataTree) for obj in objects): |
| 1116 | + if not all(isinstance(obj, DataTree) for obj in objects): |
| 1117 | + raise TypeError( |
| 1118 | + "merge does not support mixed type arguments when one argument " |
| 1119 | + f"is a DataTree: {objects}" |
| 1120 | + ) |
| 1121 | + trees = cast(list[DataTree], objects) |
| 1122 | + return merge_trees( |
| 1123 | + trees, |
| 1124 | + compat=compat, |
| 1125 | + join=join, |
| 1126 | + combine_attrs=combine_attrs, |
| 1127 | + fill_value=fill_value, |
| 1128 | + ) |
1026 | 1129 |
|
1027 | 1130 | dict_like_objects = []
|
1028 | 1131 | for obj in objects:
|
1029 | 1132 | if not isinstance(obj, DataArray | Dataset | Coordinates | dict):
|
1030 | 1133 | raise TypeError(
|
1031 |
| - "objects must be an iterable containing only " |
1032 |
| - "Dataset(s), DataArray(s), and dictionaries." |
| 1134 | + "objects must be an iterable containing only DataTree(s), " |
| 1135 | + f"Dataset(s), DataArray(s), and dictionaries: {objects}" |
1033 | 1136 | )
|
1034 | 1137 |
|
1035 | 1138 | if isinstance(obj, DataArray):
|
|
0 commit comments