Skip to content

Commit

Permalink
More flexible member concatenation (#277)
Browse files Browse the repository at this point in the history
* More flexible metric matching + better error for missing match_attrs

* fix type hints for py3.8

* Fixing formatting (not sure why that wasn't checked before)

* More flexible member concat

* trying to rebase

* remove duplicate tests
  • Loading branch information
jbusecke authored Jan 3, 2023
1 parent aee93bf commit 23db8d1
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 8 deletions.
116 changes: 111 additions & 5 deletions tests/test_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import xesmf

from xmip.postprocessing import (
_construct_and_promote_member_id,
_parse_metric,
_promote_member_id,
combine_datasets,
concat_experiments,
concat_members,
Expand Down Expand Up @@ -249,7 +249,7 @@ def _assert_parsed_ds_dict(ddict_parsed, expected, match_keys, strict=True):
_assert_parsed_ds_dict(ds_dict_parsed, ds_metric[metricname], ["a"])


def test_match_metrics_missing_attr():
def test_match_metrics_missing_non_match_attr():
"""This test ensures that as long as the provided `match_metrics` are
given they will be matched. This is relevant if e.g. the variant label
has been removed due to merging"""
Expand Down Expand Up @@ -559,8 +559,11 @@ def test_concat_members(concat_kwargs):
# Group together the expected 'matches'
# promote the member_id like in concat_members
expected = {
"a.a.a.a": [_promote_member_id(ds_a_temp), _promote_member_id(ds_b_temp)],
"c.a.a.a": [_promote_member_id(ds_c_other)],
"a.a.a.a": [
_construct_and_promote_member_id(ds_a_temp),
_construct_and_promote_member_id(ds_b_temp),
],
"c.a.a.a": [_construct_and_promote_member_id(ds_c_other)],
}

result = concat_members(
Expand All @@ -579,6 +582,109 @@ def test_concat_members(concat_kwargs):
assert member in result["a.a.a.a"].member_id


def test_concat_members_existing_member_dim():
    """`concat_members` accepts a mix of inputs with and without `member_id`.

    One input already carries a `member_id` dimension (matching its
    `variant_label`), the other does not. The result must equal the
    concatenation of both datasets promoted by hand.
    """
    attrs_a = dict.fromkeys(
        [
            "source_id",
            "grid_label",
            "experiment_id",
            "table_id",
            "variant_label",
            "version",
        ],
        "a",
    )
    # Same facets as `attrs_a`, only the member differs.
    attrs_b = {**attrs_a, "variant_label": "b"}

    ds_a = random_ds(attrs=attrs_a).rename({"data": "temp"})
    ds_b = random_ds(attrs=attrs_b).rename({"data": "temp"})

    # Manually promote the variant label to a `member_id` dimension.
    ds_a_promoted, ds_b_promoted = (
        ds.expand_dims({"member_id": [ds.attrs["variant_label"]]})
        for ds in (ds_a, ds_b)
    )

    # Mixed case: "some" is already promoted, "thing" is not.
    result = concat_members({"some": ds_a_promoted, "thing": ds_b})

    expected = xr.concat([ds_a_promoted, ds_b_promoted], "member_id")
    xr.testing.assert_equal(result["a.a.a.a"], expected)


def test_concat_members_existing_member_dim_different_warning():
    """A pre-existing `member_id` that disagrees with the attrs must warn.

    One input already has a `member_id` dimension whose value ("something")
    differs from the value reconstructed from its attributes ("a").
    `concat_members` is expected to emit a `UserWarning` naming both values.
    """
    # Local import: only needed to escape the expected message below.
    import re

    attrs_a = {
        "source_id": "a",
        "grid_label": "a",
        "experiment_id": "a",
        "table_id": "a",
        "variant_label": "a",
        "version": "a",
    }
    attrs_b = {**attrs_a, "variant_label": "b"}

    # Create some datasets with a/b attrs
    ds_a = random_ds(attrs=attrs_a).rename({"data": "temp"})
    ds_b = random_ds(attrs=attrs_b).rename({"data": "temp"})

    # Deliberately promote with a member_id that does not match the attrs.
    ds_a_promoted_wrong = ds_a.expand_dims({"member_id": ["something"]})

    # testing mixed case
    ds_dict = {"some": ds_a_promoted_wrong, "thing": ds_b}

    # `pytest.warns(match=...)` interprets the pattern as a regular expression
    # (it is passed to `re.search`). The warning text contains literal
    # parentheses, so it has to be escaped to be matched verbatim.
    msg = re.escape(
        "already contained a member_id (something) but this is different "
        "from the reconstructed value (a)"
    )
    with pytest.warns(UserWarning, match=msg):
        concat_members(ds_dict)


def test_concat_members_reconstruct_from_sub_experiment_id():
    """`concat_members` builds `member_id` from `sub_experiment_id` if present.

    Neither input has a `member_id` dimension. The dataset without a
    `sub_experiment_id` attribute should end up with its plain variant label,
    the other one with `<sub_experiment_id>-<variant_label>`.
    """
    attrs_a = dict.fromkeys(
        [
            "source_id",
            "grid_label",
            "experiment_id",
            "table_id",
            "variant_label",
            "version",
        ],
        "a",
    )
    attrs_b = {
        **attrs_a,
        "variant_label": "b",
        "sub_experiment_id": "sub_something",
    }

    ds_a = random_ds(attrs=attrs_a).rename({"data": "temp"})
    ds_b = random_ds(attrs=attrs_b).rename({"data": "temp"})

    # Hand-built expectation: each dataset promoted with the member_id that
    # should be reconstructed from its attributes.
    promoted = [
        ds.expand_dims({"member_id": [member]})
        for ds, member in ((ds_a, "a"), (ds_b, "sub_something-b"))
    ]
    expected = xr.concat(promoted, "member_id")

    result = concat_members({"some": ds_a, "thing": ds_b})

    xr.testing.assert_equal(result["a.a.a.a"], expected)


@pytest.mark.parametrize("concat_kwargs", [{}, {"compat": "override"}])
def test_concat_experiments(concat_kwargs):
concat_kwargs = {}
Expand Down Expand Up @@ -905,7 +1011,7 @@ def test_nested_operations():
)
ddict = {"ds1": ds_1, "ds2": ds_2, "ds3": ds_3, "ds4": ds_4}

ddict = {k: _promote_member_id(ds) for k, ds in ddict.items()}
ddict = {k: _construct_and_promote_member_id(ds) for k, ds in ddict.items()}

ds_expected = xr.Dataset(
{
Expand Down
24 changes: 21 additions & 3 deletions xmip/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,26 @@ def merge_variables(
)


def _promote_member_id(ds):
ds = ds.assign_coords(member_id=ds.attrs["variant_label"])
def _construct_and_promote_member_id(ds):
# construct member_id according to https://docs.google.com/document/d/1h0r8RZr_f3-8egBMMh7aqLwy3snpD6_MrDz1q8n5XUk/edit
sub_experiment_id = ds.attrs.get("sub_experiment_id", "none")
variant_label = ds.attrs[
"variant_label"
] # if this should fail in the future, we could build an error/warning in here
if sub_experiment_id == "none":
member_id = variant_label
else:
member_id = f"{sub_experiment_id}-{variant_label}"

if "member_id" not in ds.dims:
ds = ds.expand_dims({"member_id": [member_id]})
else:
existing_member_id = ds.member_id.data[0]
if not existing_member_id == member_id:
warnings.warn(
f"{cmip6_dataset_id(ds)} already contained a member_id ({existing_member_id}) but this is different from the reconstructed value ({member_id}). The existing value is not modified, but we recommend checking the input.",
UserWarning,
)
return ds


Expand Down Expand Up @@ -234,7 +252,7 @@ def concat_members(

# promote variant_label attr to coordinate, to have member_ids as coordinates

ds_dict = {k: _promote_member_id(ds) for k, ds in ds_dict.items()}
ds_dict = {k: _construct_and_promote_member_id(ds) for k, ds in ds_dict.items()}

return combine_datasets(
ds_dict,
Expand Down

0 comments on commit 23db8d1

Please sign in to comment.