diff --git a/tests/test_postprocessing.py b/tests/test_postprocessing.py index 3ceabcb4..18dc6455 100644 --- a/tests/test_postprocessing.py +++ b/tests/test_postprocessing.py @@ -7,8 +7,8 @@ import xesmf from xmip.postprocessing import ( + _construct_and_promote_member_id, _parse_metric, - _promote_member_id, combine_datasets, concat_experiments, concat_members, @@ -249,7 +249,7 @@ def _assert_parsed_ds_dict(ddict_parsed, expected, match_keys, strict=True): _assert_parsed_ds_dict(ds_dict_parsed, ds_metric[metricname], ["a"]) -def test_match_metrics_missing_attr(): +def test_match_metrics_missing_non_match_attr(): """This test ensures that as long as the provided `match_metrics` are given they will be matched. This is relevant if e.g. the variant label has been removed due to merging""" @@ -559,8 +559,11 @@ def test_concat_members(concat_kwargs): # Group together the expected 'matches' # promote the member_id like in concat_members expected = { - "a.a.a.a": [_promote_member_id(ds_a_temp), _promote_member_id(ds_b_temp)], - "c.a.a.a": [_promote_member_id(ds_c_other)], + "a.a.a.a": [ + _construct_and_promote_member_id(ds_a_temp), + _construct_and_promote_member_id(ds_b_temp), + ], + "c.a.a.a": [_construct_and_promote_member_id(ds_c_other)], } result = concat_members( @@ -579,6 +582,109 @@ def test_concat_members(concat_kwargs): assert member in result["a.a.a.a"].member_id +def test_concat_members_existing_member_dim(): + attrs_a = { + "source_id": "a", + "grid_label": "a", + "experiment_id": "a", + "table_id": "a", + "variant_label": "a", + "version": "a", + } + + attrs_b = {k: v for k, v in attrs_a.items()} + attrs_b["variant_label"] = "b" + + # Create some datasets with a/b attrs + ds_a = random_ds(attrs=attrs_a).rename({"data": "temp"}) + ds_b = random_ds(attrs=attrs_b).rename({"data": "temp"}) + + ds_a_promoted = ds_a.expand_dims({"member_id": [ds_a.attrs["variant_label"]]}) + ds_b_promoted = ds_b.expand_dims({"member_id": [ds_b.attrs["variant_label"]]}) + + # testing mixed case + ds_dict = {"some": ds_a_promoted, "thing": ds_b} + + # promote the member_id like in concat_members + expected = xr.concat([ds_a_promoted, ds_b_promoted], "member_id") + + result = concat_members( + ds_dict, + ) + + xr.testing.assert_equal( + result["a.a.a.a"], + expected, + ) + + +def test_concat_members_existing_member_dim_different_warning(): + attrs_a = { + "source_id": "a", + "grid_label": "a", + "experiment_id": "a", + "table_id": "a", + "variant_label": "a", + "version": "a", + } + + attrs_b = {k: v for k, v in attrs_a.items()} + attrs_b["variant_label"] = "b" + + # Create some datasets with a/b attrs + ds_a = random_ds(attrs=attrs_a).rename({"data": "temp"}) + ds_b = random_ds(attrs=attrs_b).rename({"data": "temp"}) + + ds_a_promoted_wrong = ds_a.expand_dims({"member_id": ["something"]}) + + # testing mixed case + ds_dict = {"some": ds_a_promoted_wrong, "thing": ds_b} + msg = "but this is different from the reconstructed value" + # TODO: Had trouble here when putting in the actual values I expected. + # Probably some regex shit. This should be enough for now + with pytest.warns(UserWarning, match=msg): + concat_members( + ds_dict, + ) + + +def test_concat_members_reconstruct_from_sub_experiment_id(): + attrs_a = { + "source_id": "a", + "grid_label": "a", + "experiment_id": "a", + "table_id": "a", + "variant_label": "a", + "version": "a", + } + + attrs_b = {k: v for k, v in attrs_a.items()} + attrs_b["variant_label"] = "b" + attrs_b["sub_experiment_id"] = "sub_something" + + # Create some datasets with a/b attrs + ds_a = random_ds(attrs=attrs_a).rename({"data": "temp"}) + ds_b = random_ds(attrs=attrs_b).rename({"data": "temp"}) + + ds_a_promoted = ds_a.expand_dims({"member_id": ["a"]}) + ds_b_promoted = ds_b.expand_dims({"member_id": ["sub_something-b"]}) + + # testing mixed case + ds_dict = {"some": ds_a, "thing": ds_b} + + # promote the member_id like in concat_members + expected = xr.concat([ds_a_promoted, ds_b_promoted], "member_id") + + result = concat_members( + ds_dict, + ) + + xr.testing.assert_equal( + result["a.a.a.a"], + expected, + ) + + @pytest.mark.parametrize("concat_kwargs", [{}, {"compat": "override"}]) def test_concat_experiments(concat_kwargs): concat_kwargs = {} @@ -905,7 +1011,7 @@ def test_nested_operations(): ) ddict = {"ds1": ds_1, "ds2": ds_2, "ds3": ds_3, "ds4": ds_4} - ddict = {k: _promote_member_id(ds) for k, ds in ddict.items()} + ddict = {k: _construct_and_promote_member_id(ds) for k, ds in ddict.items()} ds_expected = xr.Dataset( { diff --git a/xmip/postprocessing.py b/xmip/postprocessing.py index 238b9886..43145b2a 100644 --- a/xmip/postprocessing.py +++ b/xmip/postprocessing.py @@ -197,8 +197,26 @@ def merge_variables( ) -def _promote_member_id(ds): - ds = ds.assign_coords(member_id=ds.attrs["variant_label"]) +def _construct_and_promote_member_id(ds): + # construct member_id according to https://docs.google.com/document/d/1h0r8RZr_f3-8egBMMh7aqLwy3snpD6_MrDz1q8n5XUk/edit + sub_experiment_id = ds.attrs.get("sub_experiment_id", "none") + variant_label = ds.attrs[ + "variant_label" + ] # if this should fail in the future, we could build an error/warning in here + if sub_experiment_id == "none": + member_id = variant_label + else: + member_id = f"{sub_experiment_id}-{variant_label}" + + if "member_id" not in ds.dims: + ds = ds.expand_dims({"member_id": [member_id]}) + else: + existing_member_id = ds.member_id.data[0] + if not existing_member_id == member_id: + warnings.warn( + f"{cmip6_dataset_id(ds)} already contained a member_id ({existing_member_id}) but this is different from the reconstructed value ({member_id}). The existing value is not modified, but we recommend checking the input.", + UserWarning, + ) return ds @@ -234,7 +252,7 @@ def concat_members( # promote variant_label attr to coordinate, to have member_ids as coordinates - ds_dict = {k: _promote_member_id(ds) for k, ds in ds_dict.items()} + ds_dict = {k: _construct_and_promote_member_id(ds) for k, ds in ds_dict.items()} return combine_datasets( ds_dict,