Skip to content

Commit

Permalink
updated plot_ppc and modified tests
Browse files Browse the repository at this point in the history
  • Loading branch information
imperorrp committed Aug 13, 2024
1 parent f7737e4 commit e8a22ae
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 97 deletions.
117 changes: 44 additions & 73 deletions src/arviz_plots/plots/ppcplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def plot_ppc(
stats_kwargs=None,
pc_kwargs=None,
):
"""Plot prior/posterior predictive and observed values as kde, cumulative or scatter plots.
"""Plot prior/posterior predictive and observed values as kde, ecdf, hist, or scatter plots.
Parameters
----------
Expand Down Expand Up @@ -69,25 +69,24 @@ def plot_ppc(
Dimensions to loop over in plotting posterior/prior predictive.
Note: Dims not in sample_dims or facet_dims (below) will be reduced by default.
Defaults to ``rcParams["data.sample_dims"]``
kind : {"kde", "cumulative", "scatter"}, optional
kind : {"kde", "ecdf", "hist", "scatter"}, optional
How to represent the marginal density. Defaults to ``rcParams["plot.density_kind"]``.
facet_dims : list, optional
Dimensions to facet over (for which multiple plots will be generated).
Defaults to empty list. A warning is raised if `pc_kwargs` is also used to define
dims to facet over with the `cols` key and `pc_kwargs` takes precedence.
Defaults to empty list.
data_pairs : dict, optional
Dictionary containing relations between observed data and posterior/prior predictive data.
Dictionary structure:
* key = observed data var_name
* value = posterior/prior predictive var_name
Dictionary keys are variable names corresponding to observed data and dictionary values
are variable names corresponding to posterior/prior predictive.
For example, data_pairs = {'y' : 'y_hat'}.
If None, it will assume that the observed data and the posterior/prior predictive data
have the same variable name
aggregate: bool, optional
By default, it will assume that the observed data and the posterior/prior predictive data
have the same variable names
aggregate: bool, default False
If True, predictive data will be aggregated over both sample_dims and reduce_dims.
Defaults to False.
num_pp_samples : int, optional
Number of prior/posterior predictive samples to plot.
Defaults to the total sample size (product of sample_dim dimension lengths) or minimum
between total sample size and 5 in case of kind='scatter'
random_seed : int, optional
Random number generator seed passed to numpy.random.seed to allow reproducibility of the
plot.
Expand All @@ -108,15 +107,16 @@ def plot_ppc(
plot_kwargs : mapping of {str : mapping or False}, optional
Valid keys are:
* "predictive" -> Passed to either of "kde", "cumulative", "scatter" based on `kind`
* "observed" -> passed to either of "kde", "cumulative", "scatter" based on `kind`
* "aggregate" -> passed to either of "kde", "cumulative", "scatter" based on `kind`
* "predictive" -> Passed to either of "kde", "ecdf", "hist", "scatter" based on `kind`
* "observed" -> passed to either of "kde", "ecdf", "hist", "scatter" based on `kind`
* "aggregate" -> passed to either of "kde", "ecdf", "hist", "scatter" based on `kind`
Values of the above plot_kwargs keys are passed to one of "kde", "cumulative", "scatter",
Values of the above plot_kwargs keys are passed to one of "kde", "ecdf", "hist", "scatter",
matching the `kind` argument.
* "kde" -> passed to :func:`~arviz_plots.visuals.line_xy`
* "cumulative" -> passed to :func:`~arviz_plots.visuals.ecdf_line`
* "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line`
* "hist" -> passed to :func:`~arviz_plots.visuals.hist_line`
* "scatter" -> passed to :func:`~arviz_plots.visuals.scatter_x`
* "observed_rug" -> passed to :func:`arviz_plots.visuals.trace_rug`
Expand All @@ -126,9 +126,9 @@ def plot_ppc(
stats_kwargs : mapping, optional
Valid keys are:
* predictive -> passed to kde, cumulative, ...
* aggregate -> passed to kde, cumulative, ...
* observed -> passed to kde, cumulative, ...
* predictive -> passed to kde, ecdf, ...
* aggregate -> passed to kde, ecdf, ...
* observed -> passed to kde, ecdf, ...
pc_kwargs : mapping
Passed to :class:`arviz_plots.PlotCollection.wrap`
Expand Down Expand Up @@ -175,50 +175,22 @@ def plot_ppc(
raise TypeError("`group` argument must be either `posterior` or `prior`")

predictive_data_group = f"{group}_predictive"
default_observed = ""
if observed is None:
default_observed = f"(set as default for chosen group {group})"
observed = group == "posterior" # by default true if posterior, false if prior

# checking if plot_kwargs["observed"] or plot_kwargs["aggregate"] is not inconsistent with
# top level bool arguments `observed`` and `aggregate`

# observed args logic check:
# observed will be True/False depending on user-input/group- true for posterior, false for prior
# if observed and no plot_kwargs['observed'], no prob
# if observed = True and plot_kwargs['observed'] = True, no prob (observed is plotted)
# if observed = True and plot_kwargs['observed'] = False, prob (error/warning raised)
# if observed = False and plot_kwargs['observed'] = True, prob (error/warning raised)
# if observed = False and plot_kwargs['observed'] = False, no prob (observed is not plotted)
if (
observed
and plot_kwargs.get("observed", True) is False
or not observed
and plot_kwargs.get("observed", False) is not False
):
# raise warning or error
# checking to make sure plot_kwargs["observed"] or plot_kwargs["aggregate"] are not False
observed_kwargs = copy(plot_kwargs.get("observed", {}))
if observed_kwargs is False:
raise ValueError(
f"""
`observed` and `plot_kwargs["observed"]` inconsistency detected.
`observed` = {observed}{default_observed}
`plot_kwargs["observed"]` = {plot_kwargs["observed"]}
Please make sure `observed` and `plot_kwargs["observed"]` have the same value."""
"""plot_kwargs['observed'] can't be False, use observed=False to remove observed
plot element"""
)

# same check for aggregate:
if (
aggregate
and plot_kwargs.get("aggregate", True) is False
or not aggregate
and plot_kwargs.get("aggregate", False) is not False
):
# raise warning or error
aggregate_kwargs = copy(plot_kwargs.get("aggregate", {}))
if aggregate_kwargs is False:
raise ValueError(
f"""
`aggregate` and `plot_kwargs["aggregate"]` inconsistency detected.
`aggregate` = {aggregate}
`plot_kwargs["aggregate"]` = {plot_kwargs["aggregate"]}
Please make sure `aggregate` and `plot_kwargs["aggregate"]` have the same value."""
"""plot_kwargs['aggregate'] can't be False, use aggregate=False to remove
aggregate plot element"""
)

# making sure both posterior/prior predictive group and observed_data group exists in
Expand Down Expand Up @@ -269,7 +241,6 @@ def plot_ppc(
rng = np.random.default_rng()

# picking random pp dataset sample indexes and subsetting pp_distribution accordingly
# total_pp_samples = plot_collection.data.sizes["chain"] * plot_collection.data.sizes["draw"]
total_pp_samples = np.prod(
[pp_distribution.sizes[dim] for dim in sample_dims if dim in pp_distribution.dims]
)
Expand All @@ -286,18 +257,19 @@ def plot_ppc(
):
raise TypeError(f"`num_pp_samples` must be an integer between 1 and {total_pp_samples}.")

pp_sample_ix = rng.choice(total_pp_samples, size=num_pp_samples, replace=False)

# print(f"\npp_sample_ix: {pp_sample_ix!r}")
# stack the sample dimensions and subsample at random when only a subset is requested
# (num_pp_samples != total_pp_samples) or when there is a single sample dimension
if num_pp_samples != total_pp_samples or len(sample_dims) == 1:
pp_sample_ix = rng.choice(total_pp_samples, size=num_pp_samples, replace=False)

# stacking sample dims into a new 'ppc_dim' dimension
pp_distribution = pp_distribution.stack(ppc_dim=sample_dims)
# stacking sample dims into a new 'ppc_dim' dimension
pp_distribution = pp_distribution.stack(ppc_dim=sample_dims)

# Select the desired samples
pp_distribution = pp_distribution.isel(ppc_dim=pp_sample_ix)
# Select the desired samples
pp_distribution = pp_distribution.isel(ppc_dim=pp_sample_ix)

# renaming sample_dims so that rest of plot will consider this as sample_dims
sample_dims = ["ppc_dim"]
# renaming sample_dims so that rest of plot will consider this as sample_dims
sample_dims = ["ppc_dim"]

# wrap plot collection with pp distribution
if plot_collection is None:
Expand Down Expand Up @@ -379,11 +351,11 @@ def plot_ppc(
plot_collection=plot_collection,
labeller=labeller,
aes_map={
kind: value for key, value in aes_map.items() if key == "predictive"
kind: aes_map.get("predictive", {})
}, # aes_map[kind] is set to "predictive" aes_map
plot_kwargs=plot_kwargs_dist, # plot_kwargs[kind] is set to "predictive" plot_kwargs
stats_kwargs={
"density": value for key, value in stats_kwargs.items() if key == "predictive"
"density": stats_kwargs.get("predictive", {})
}, # stats_kwargs["density"] is set to "predictive" stats_kwargs
# via plot_dist
)
Expand Down Expand Up @@ -424,11 +396,11 @@ def plot_ppc(
plot_collection=plot_collection,
labeller=labeller,
aes_map={
kind: value for key, value in aes_map.items() if key == "aggregate"
kind: aes_map.get("aggregate", {})
}, # aes_map[kind] is set to "aggregate" aes_map
plot_kwargs=plot_kwargs_dist, # plot_kwargs[kind] is set to "aggregate" plot_kwargs
stats_kwargs={
"density": value for key, value in stats_kwargs.items() if key == "aggregate"
"density": stats_kwargs.get("aggregate", {})
}, # stats_kwargs["density"] is set to "aggregate" stats_kwargs
)

Expand Down Expand Up @@ -465,11 +437,11 @@ def plot_ppc(
plot_collection=plot_collection,
labeller=labeller,
aes_map={
kind: value for key, value in aes_map.items() if key == "observed"
kind: aes_map.get("observed", {})
}, # aes_map[kind] is set to "observed" aes_map
plot_kwargs=plot_kwargs_dist, # plot_kwargs[kind] is set to "observed" plot_kwargs
stats_kwargs={
"density": value for key, value in stats_kwargs.items() if key == "observed"
"density": stats_kwargs.get("observed", {})
}, # stats_kwargs["density"] is set to "observed" stats_kwargs
)

Expand All @@ -496,7 +468,6 @@ def plot_ppc(
data=obs_distribution,
ignore_aes=rug_ignore,
xname=False,
flatten=True,
y=0,
**rug_kwargs,
)
Expand Down
9 changes: 2 additions & 7 deletions src/arviz_plots/visuals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,19 @@ def line(da, target, backend, xname=None, **kwargs):
return plot_backend.line(xvalues, yvalues, target, **kwargs)


def trace_rug(da, target, backend, mask=None, flatten=False, xname=None, y=None, **kwargs):
def trace_rug(da, target, backend, mask=None, xname=None, y=None, **kwargs):
"""Create a rug plot with the subset of `da` indicated by `mask`."""
xname = xname.item() if hasattr(xname, "item") else xname
if xname is False:
xvalues = da
else:
if xname is None:
if len(da.shape) != 1:
raise ValueError(f"Expected unidimensional data but got {da.sizes}")
xvalues = np.arange(len(da))
else:
xvalues = da[xname]
if y is None:
y = da.min().item()
if flatten is not False:
xvalues = xvalues.values.flatten() # flatten xvalues
if len(xvalues.shape) != 1:
raise ValueError(f"Expected unidimensional data but got {xvalues.sizes}")
xvalues = xvalues.values.flatten() # flatten xvalues by default
if mask is not None:
xvalues = xvalues[mask]
return scatter_x(xvalues, target=target, backend=backend, y=y, **kwargs)
Expand Down
53 changes: 50 additions & 3 deletions tests/test_hypothesis_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ def datatree(seed=31):
"prior_predictive": {"obs": prior_predictive},
"posterior_predictive": {"obs": posterior_predictive},
},
dims={"theta": ["hierarchy", "group"], "tau": ["hierarchy"], "obs": ["hierarchy", "group"]},
dims={
"theta": ["chain", "draw", "hierarchy", "group"],
"tau": ["chain", "draw", "hierarchy"],
"obs": ["chain", "draw", "hierarchy", "group"],
},
)
dt["point_estimate"] = dt.posterior.mean(("chain", "draw"))
# TODO: should become dt.azstats.eti() after fix in arviz-stats
Expand Down Expand Up @@ -200,9 +204,26 @@ def test_plot_ridge(datatree, combined, plot_kwargs, labels_shade_label):
assert all(key in child for child in pc.viz.children.values())


# plot_ppc tests
ppc_kind_value = st.sampled_from(("kde", "cumulative"))
ppc_group = st.sampled_from(("prior", "posterior"))
ppc_observed = st.booleans()
ppc_aggregate = st.booleans()
ppc_sample_dims = st.sampled_from((["chain"], ["chain", "draw"]))
ppc_facet_dims = st.sampled_from((["group"], ["hierarchy"], None))


@st.composite  # composite strategy: bounds num_pp_samples by the drawn group and sample_dims
def draw_num_pp_samples(draw, group, sample_dims):
    """Draw a ``num_pp_samples`` value bounded by an assumed total sample size.

    Parameters
    ----------
    draw : callable
        Draw function injected by ``st.composite``.
    group : strategy
        Strategy whose drawn value is compared against ``"prior"``.
    sample_dims : strategy
        Strategy whose drawn value is compared against ``["chain", "draw"]``.

    Returns
    -------
    int
        Integer between 1 and the product of the assumed chain and draw lengths.
    """
    # NOTE(review): these draws are independent of any other draw made from the
    # same strategies, so the bound computed here may not correspond to the
    # `group`/`sample_dims` values a test receives as its own arguments.
    group = draw(group)
    sample_dims = draw(sample_dims)
    # NOTE(review): hard-coded sizes; presumably they mirror the test datatree
    # (1 prior chain, 3 posterior chains, 50 draws) -- confirm against the fixture.
    chain_dim_length = 1 if group == "prior" else 3
    draw_dim_length = 50 if sample_dims == ["chain", "draw"] else 1
    # plain integer product keeps the bound a native int
    total_num_samples = chain_dim_length * draw_dim_length
    return draw(st.integers(min_value=1, max_value=total_num_samples))


@given(
Expand All @@ -220,16 +241,42 @@ def test_plot_ridge(datatree, combined, plot_kwargs, labels_shade_label):
),
kind=ppc_kind_value,
group=ppc_group,
observed=ppc_observed,
observed_rug=ppc_observed,
aggregate=ppc_aggregate,
facet_dims=ppc_facet_dims,
sample_dims=ppc_sample_dims,
num_pp_samples=draw_num_pp_samples(ppc_group, ppc_sample_dims),
)
def test_plot_ppc(datatree, kind, group, plot_kwargs):
def test_plot_ppc(
kind,
group,
observed,
observed_rug,
aggregate,
facet_dims,
sample_dims,
num_pp_samples,
plot_kwargs,
):
kind_kwargs = plot_kwargs.pop("kind", None)
if kind_kwargs is not None:
plot_kwargs[kind] = kind_kwargs
if plot_kwargs.get("observed", False) is False:
plot_kwargs["observed"] = True # cannot be False
if plot_kwargs.get("aggregate", False) is False:
plot_kwargs["aggregate"] = True # cannot be False
pc = plot_ppc(
datatree,
backend="none",
kind=kind,
group=group,
observed=observed,
observed_rug=observed_rug,
aggregate=aggregate,
facet_dims=facet_dims,
sample_dims=sample_dims,
num_pp_samples=num_pp_samples,
plot_kwargs=plot_kwargs,
)
assert all("plot" in child for child in pc.viz.children.values())
Expand Down
Loading

0 comments on commit e8a22ae

Please sign in to comment.