From 6602eed449de75ddba9c18a698fdd19ccfd5c5c4 Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Tue, 13 Aug 2024 17:21:46 +0530 Subject: [PATCH] modified tests and added observed_rug check --- src/arviz_plots/plots/ppcplot.py | 9 +++++++++ tests/test_hypothesis_plots.py | 33 ++++++++++++++++++-------------- tests/test_plots.py | 3 ++- 3 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/arviz_plots/plots/ppcplot.py b/src/arviz_plots/plots/ppcplot.py index 4ad7ca20..b8ca140e 100644 --- a/src/arviz_plots/plots/ppcplot.py +++ b/src/arviz_plots/plots/ppcplot.py @@ -186,6 +186,15 @@ def plot_ppc( plot element""" ) + observed_rug_kwargs = copy(plot_kwargs.get("observed_rug", {})) + if observed_rug_kwargs is False: + raise ValueError( + """plot_kwargs['observed_rug'] can't be False, use observed_rug=False to remove + observed_rug plot element""" + ) + if observed is False and observed_rug is True: + raise ValueError("""observed_rug=True is only valid when observed=True""") + aggregate_kwargs = copy(plot_kwargs.get("aggregate", {})) if aggregate_kwargs is False: raise ValueError( diff --git a/tests/test_hypothesis_plots.py b/tests/test_hypothesis_plots.py index 1c924a48..0eebaa94 100644 --- a/tests/test_hypothesis_plots.py +++ b/tests/test_hypothesis_plots.py @@ -31,9 +31,9 @@ def datatree(seed=31): "posterior_predictive": {"obs": posterior_predictive}, }, dims={ - "theta": ["chain", "draw", "hierarchy", "group"], - "tau": ["chain", "draw", "hierarchy"], - "obs": ["chain", "draw", "hierarchy", "group"], + "theta": ["hierarchy", "group"], + "tau": ["hierarchy"], + "obs": ["hierarchy", "group"], }, ) @@ -203,7 +203,7 @@ def draw_num_pp_samples(draw, group, sample_dims): num_pp_samples = draw(st.integers(min_value=1, max_value=total_num_samples)) # print(f"\nnum_pp_samples = {num_pp_samples}") - return num_pp_samples + return [num_pp_samples, total_num_samples] @given( @@ -212,9 +212,9 @@ def draw_num_pp_samples(draw, group, sample_dims): optional={ "kind": plot_kwargs_value, "predictive": plot_kwargs_value, - "observed": plot_kwargs_value, - "aggregate": plot_kwargs_value, - "observed_rug": plot_kwargs_value, + "observed": st.sampled_from(({}, {"color": "red"})), + "aggregate": st.sampled_from(({}, {"color": "red"})), + "observed_rug": st.sampled_from(({}, {"color": "red"})), "title": plot_kwargs_value, "remove_axis": st.just(False), }, @@ -226,7 +226,7 @@ def draw_num_pp_samples(draw, group, sample_dims): aggregate=ppc_aggregate, facet_dims=ppc_facet_dims, sample_dims=ppc_sample_dims, - num_pp_samples=draw_num_pp_samples(ppc_group, ppc_sample_dims), + drawed_samples=draw_num_pp_samples(ppc_group, ppc_sample_dims), ) def test_plot_ppc( kind, @@ -236,16 +236,12 @@ def test_plot_ppc( aggregate, facet_dims, sample_dims, - num_pp_samples, + drawed_samples, plot_kwargs, ): kind_kwargs = plot_kwargs.pop("kind", None) if kind_kwargs is not None: plot_kwargs[kind] = kind_kwargs - if plot_kwargs.get("observed", False) is False: - plot_kwargs["observed"] = True # cannot be False - if plot_kwargs.get("aggregate", False) is False: - plot_kwargs["aggregate"] = True # cannot be False pc = plot_ppc( datatree, backend="none", @@ -256,10 +252,19 @@ def test_plot_ppc( aggregate=aggregate, facet_dims=facet_dims, sample_dims=sample_dims, - num_pp_samples=num_pp_samples, + num_pp_samples=drawed_samples[0], plot_kwargs=plot_kwargs, ) assert all("plot" in child for child in pc.viz.children.values()) + num_pp_samples = drawed_samples[0] + total_num_samples = drawed_samples[1] + if num_pp_samples == total_num_samples: + assert sample_dims in pc.viz["obs"].dims + else: + if len(sample_dims) > 1: + assert "ppc_dim" in pc.viz["obs"].dims + else: + assert sample_dims in pc.viz["obs"].dims for key, value in plot_kwargs.items(): if value is False: assert all(key not in child for child in pc.viz.children.values()) diff --git a/tests/test_plots.py b/tests/test_plots.py index 3a857276..910578da 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -285,12 +285,13 @@ def test_plot_ppc(self, datatree, kind, backend): assert "ecdf" in pc.viz["obs"] assert "overlay" in pc.aes["obs"].data_vars + # no stacking of sample_dims into ppc_dim @pytest.mark.parametrize("kind", ("kde", "cumulative")) def test_plot_ppc_sample(self, datatree_sample, kind, backend): pc = plot_ppc(datatree_sample, kind=kind, sample_dims="sample", backend=backend) assert "chart" in pc.viz.data_vars assert "obs" in pc.viz - # assert "ppc_dim" in pc.viz["obs"].dims + assert "sample" in pc.viz["obs"].dims if kind == "kde": assert "kde" in pc.viz["obs"] elif kind == "cumulative":