diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index a718ef5e911..71ccd1c12df 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -130,7 +130,7 @@ def infer_template( "Please supply the 'template' kwarg to map_blocks." ) from e - if not isinstance(template, Dataset | DataArray): + if not isinstance(template, (Dataset, DataArray)): raise TypeError( "Function must return an xarray DataArray or Dataset. Instead it returned " f"{type(template)}" @@ -351,7 +351,7 @@ def _wrapper( result = func(*converted_args, **kwargs) merged_coordinates = merge( - [arg.coords for arg in args if isinstance(arg, Dataset | DataArray)], + [arg.coords for arg in args if isinstance(arg, (Dataset, DataArray))], join="exact", compat="override", ).coords diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 2d103994410..f6c2ec43b67 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1211,6 +1211,17 @@ def really_bad_func(darray): xr.map_blocks(bad_func, map_da, kwargs=dict(a=map_da.chunk())) +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def test_map_blocks_accepts_da_ds_template(obj): + def identity(x): + return x + + with raise_if_dask_computes(): + result = xr.map_blocks(identity, obj, template=obj) + + assert_identical(result, obj) + + @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks(obj): def func(obj):