diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index a544c7a37b1..177c2de7041 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -100,27 +100,22 @@ def _build_discrete_cmap(cmap, levels, extend, filled): # copy colors to use for bad, under, and over values in case they have been # set to non-default values - try: - # matplotlib<3.2 only uses bad color for masked values - bad = cmap(np.ma.masked_invalid([np.nan]))[0] - except TypeError: - # cmap was a str or list rather than a color-map object, so there are - # no bad, under or over values to check or copy - pass - else: - under = cmap(-np.inf) - over = cmap(np.inf) - - new_cmap.set_bad(bad) + if isinstance(cmap, mpl.colors.Colormap): + bad = cmap(np.nan) # Only update under and over if they were explicitly changed by the user # (i.e. are different from the lowest or highest values in cmap). Otherwise # leave unchanged so new_cmap uses its default values (its own lowest and # highest values). - if under != cmap(0): - new_cmap.set_under(under) - if over != cmap(cmap.N - 1): - new_cmap.set_over(over) + under = cmap(-np.inf) + if under == cmap(0): + under = None + + over = cmap(np.inf) + if over == cmap(cmap.N - 1): + over = None + + new_cmap = new_cmap.with_extremes(bad=bad, under=under, over=over) return new_cmap, cnorm diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 64054bad8ab..40f950d246d 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -482,21 +482,20 @@ def test_contourf_cmap_set(self) -> None: def test_contourf_cmap_set_with_bad_under_over(self) -> None: a = DataArray(easy_array((4, 4)), dims=["z", "time"]) - # make a copy here because we want a local cmap that we will modify. - cmap_expected = copy(mpl.colormaps["viridis"]) + # make a copy using with_extremes because we want a local cmap: + cmap_expected = mpl.colormaps["viridis"].with_extremes( + bad="w", under="r", over="g" + ) - cmap_expected.set_bad("w") # check we actually changed the set_bad color assert np.all( cmap_expected(np.ma.masked_invalid([np.nan]))[0] != mpl.colormaps["viridis"](np.ma.masked_invalid([np.nan]))[0] ) - cmap_expected.set_under("r") # check we actually changed the set_under color assert cmap_expected(-np.inf) != mpl.colormaps["viridis"](-np.inf) - cmap_expected.set_over("g") # check we actually changed the set_over color assert cmap_expected(np.inf) != mpl.colormaps["viridis"](-np.inf)