Skip to content

Commit

Permalink
modified tests and added observed_rug check
Browse files Browse the repository at this point in the history
  • Loading branch information
imperorrp committed Aug 13, 2024
1 parent cfafe71 commit 60aca81
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 15 deletions.
9 changes: 9 additions & 0 deletions src/arviz_plots/plots/ppcplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,15 @@ def plot_ppc(
plot element"""
)

observed_rug_kwargs = copy(plot_kwargs.get("observed_rug", {}))
if observed_rug_kwargs is False:
raise ValueError(
"""plot_kwargs['observed_rug'] can't be False, use observed_rug=False to remove
observed_rug plot element"""
)
if observed is False and observed_rug is True:
raise ValueError("""observed_rug=True is only valid when observed=True""")

aggregate_kwargs = copy(plot_kwargs.get("aggregate", {}))
if aggregate_kwargs is False:
raise ValueError(
Expand Down
33 changes: 19 additions & 14 deletions tests/test_hypothesis_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def datatree(seed=31):
"posterior_predictive": {"obs": posterior_predictive},
},
dims={
"theta": ["chain", "draw", "hierarchy", "group"],
"tau": ["chain", "draw", "hierarchy"],
"obs": ["chain", "draw", "hierarchy", "group"],
"theta": ["hierarchy", "group"],
"tau": ["hierarchy"],
"obs": ["hierarchy", "group"],
},
)
dt["point_estimate"] = dt.posterior.mean(("chain", "draw"))
Expand Down Expand Up @@ -223,7 +223,7 @@ def draw_num_pp_samples(draw, group, sample_dims):

num_pp_samples = draw(st.integers(min_value=1, max_value=total_num_samples))
# print(f"\nnum_pp_samples = {num_pp_samples}")
return num_pp_samples
return [num_pp_samples, total_num_samples]


@given(
Expand All @@ -232,9 +232,9 @@ def draw_num_pp_samples(draw, group, sample_dims):
optional={
"kind": plot_kwargs_value,
"predictive": plot_kwargs_value,
"observed": plot_kwargs_value,
"aggregate": plot_kwargs_value,
"observed_rug": plot_kwargs_value,
"observed": st.sampled_from(({}, {"color": "red"})),
"aggregate": st.sampled_from(({}, {"color": "red"})),
"observed_rug": st.sampled_from(({}, {"color": "red"})),
"title": plot_kwargs_value,
"remove_axis": st.just(False),
},
Expand All @@ -246,7 +246,7 @@ def draw_num_pp_samples(draw, group, sample_dims):
aggregate=ppc_aggregate,
facet_dims=ppc_facet_dims,
sample_dims=ppc_sample_dims,
num_pp_samples=draw_num_pp_samples(ppc_group, ppc_sample_dims),
drawed_samples=draw_num_pp_samples(ppc_group, ppc_sample_dims),
)
def test_plot_ppc(
kind,
Expand All @@ -256,16 +256,12 @@ def test_plot_ppc(
aggregate,
facet_dims,
sample_dims,
num_pp_samples,
drawed_samples,
plot_kwargs,
):
kind_kwargs = plot_kwargs.pop("kind", None)
if kind_kwargs is not None:
plot_kwargs[kind] = kind_kwargs
if plot_kwargs.get("observed", False) is False:
plot_kwargs["observed"] = True # cannot be False
if plot_kwargs.get("aggregate", False) is False:
plot_kwargs["aggregate"] = True # cannot be False
pc = plot_ppc(
datatree,
backend="none",
Expand All @@ -276,10 +272,19 @@ def test_plot_ppc(
aggregate=aggregate,
facet_dims=facet_dims,
sample_dims=sample_dims,
num_pp_samples=num_pp_samples,
num_pp_samples=drawed_samples[0],
plot_kwargs=plot_kwargs,
)
assert all("plot" in child for child in pc.viz.children.values())
num_pp_samples = drawed_samples[0]
total_num_samples = drawed_samples[1]
if num_pp_samples == total_num_samples:
assert sample_dims in pc.viz["obs"].dims
else:
if len(sample_dims) > 1:
assert "ppc_dim" in pc.viz["obs"].dims
else:
assert sample_dims in pc.viz["obs"].dims
for key, value in plot_kwargs.items():
if value is False:
assert all(key not in child for child in pc.viz.children.values())
Expand Down
3 changes: 2 additions & 1 deletion tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,12 +285,13 @@ def test_plot_ppc(self, datatree, kind, backend):
assert "ecdf" in pc.viz["obs"]
assert "overlay" in pc.aes["obs"].data_vars

# no stacking of sample_dims into ppc_dim
@pytest.mark.parametrize("kind", ("kde", "cumulative"))
def test_plot_ppc_sample(self, datatree_sample, kind, backend):
pc = plot_ppc(datatree_sample, kind=kind, sample_dims="sample", backend=backend)
assert "chart" in pc.viz.data_vars
assert "obs" in pc.viz
# assert "ppc_dim" in pc.viz["obs"].dims
assert "sample" in pc.viz["obs"].dims
if kind == "kde":
assert "kde" in pc.viz["obs"]
elif kind == "cumulative":
Expand Down

0 comments on commit 60aca81

Please sign in to comment.