diff --git a/src/arviz_plots/plots/ppcplot.py b/src/arviz_plots/plots/ppcplot.py index 4ad7ca2..b8ca140 100644 --- a/src/arviz_plots/plots/ppcplot.py +++ b/src/arviz_plots/plots/ppcplot.py @@ -186,6 +186,15 @@ def plot_ppc( plot element""" ) + observed_rug_kwargs = copy(plot_kwargs.get("observed_rug", {})) + if observed_rug_kwargs is False: + raise ValueError( + """plot_kwargs['observed_rug'] can't be False, use observed_rug=False to remove + observed_rug plot element""" + ) + if observed is False and observed_rug is True: + raise ValueError("""observed_rug=True is only valid when observed=True""") + aggregate_kwargs = copy(plot_kwargs.get("aggregate", {})) if aggregate_kwargs is False: raise ValueError( diff --git a/tests/test_hypothesis_plots.py b/tests/test_hypothesis_plots.py index 6542e9c..2b128a1 100644 --- a/tests/test_hypothesis_plots.py +++ b/tests/test_hypothesis_plots.py @@ -33,9 +33,9 @@ def datatree(seed=31): "posterior_predictive": {"obs": posterior_predictive}, }, dims={ - "theta": ["chain", "draw", "hierarchy", "group"], - "tau": ["chain", "draw", "hierarchy"], - "obs": ["chain", "draw", "hierarchy", "group"], + "theta": ["hierarchy", "group"], + "tau": ["hierarchy"], + "obs": ["hierarchy", "group"], }, ) dt["point_estimate"] = dt.posterior.mean(("chain", "draw")) @@ -223,7 +223,7 @@ def draw_num_pp_samples(draw, group, sample_dims): num_pp_samples = draw(st.integers(min_value=1, max_value=total_num_samples)) # print(f"\nnum_pp_samples = {num_pp_samples}") - return num_pp_samples + return [num_pp_samples, total_num_samples] @given( @@ -232,9 +232,9 @@ def draw_num_pp_samples(draw, group, sample_dims): optional={ "kind": plot_kwargs_value, "predictive": plot_kwargs_value, - "observed": plot_kwargs_value, - "aggregate": plot_kwargs_value, - "observed_rug": plot_kwargs_value, + "observed": st.sampled_from(({}, {"color": "red"})), + "aggregate": st.sampled_from(({}, {"color": "red"})), + "observed_rug": st.sampled_from(({}, {"color": "red"})), "title": plot_kwargs_value, "remove_axis": st.just(False), }, @@ -246,7 +246,7 @@ def draw_num_pp_samples(draw, group, sample_dims): aggregate=ppc_aggregate, facet_dims=ppc_facet_dims, sample_dims=ppc_sample_dims, - num_pp_samples=draw_num_pp_samples(ppc_group, ppc_sample_dims), + drawed_samples=draw_num_pp_samples(ppc_group, ppc_sample_dims), ) def test_plot_ppc( kind, @@ -256,16 +256,12 @@ def test_plot_ppc( aggregate, facet_dims, sample_dims, - num_pp_samples, + drawed_samples, plot_kwargs, ): kind_kwargs = plot_kwargs.pop("kind", None) if kind_kwargs is not None: plot_kwargs[kind] = kind_kwargs - if plot_kwargs.get("observed", False) is False: - plot_kwargs["observed"] = True # cannot be False - if plot_kwargs.get("aggregate", False) is False: - plot_kwargs["aggregate"] = True # cannot be False pc = plot_ppc( datatree, backend="none", @@ -276,10 +272,19 @@ def test_plot_ppc( aggregate=aggregate, facet_dims=facet_dims, sample_dims=sample_dims, - num_pp_samples=num_pp_samples, + num_pp_samples=drawed_samples[0], plot_kwargs=plot_kwargs, ) assert all("plot" in child for child in pc.viz.children.values()) + num_pp_samples = drawed_samples[0] + total_num_samples = drawed_samples[1] + if num_pp_samples == total_num_samples: + assert sample_dims in pc.viz["obs"].dims + else: + if len(sample_dims) > 1: + assert "ppc_dim" in pc.viz["obs"].dims + else: + assert sample_dims in pc.viz["obs"].dims for key, value in plot_kwargs.items(): if value is False: assert all(key not in child for child in pc.viz.children.values()) diff --git a/tests/test_plots.py b/tests/test_plots.py index cd45aec..26e2c18 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -285,12 +285,13 @@ def test_plot_ppc(self, datatree, kind, backend): assert "ecdf" in pc.viz["obs"] assert "overlay" in pc.aes["obs"].data_vars + # no stacking of sample_dims into ppc_dim @pytest.mark.parametrize("kind", ("kde", "cumulative")) def test_plot_ppc_sample(self, datatree_sample, kind, backend): pc = plot_ppc(datatree_sample, kind=kind, sample_dims="sample", backend=backend) assert "chart" in pc.viz.data_vars assert "obs" in pc.viz - # assert "ppc_dim" in pc.viz["obs"].dims + assert "sample" in pc.viz["obs"].dims if kind == "kde": assert "kde" in pc.viz["obs"] elif kind == "cumulative":