diff --git a/src/arviz_plots/plots/ppcplot.py b/src/arviz_plots/plots/ppcplot.py index 9eac1db..d76080e 100644 --- a/src/arviz_plots/plots/ppcplot.py +++ b/src/arviz_plots/plots/ppcplot.py @@ -40,7 +40,7 @@ def plot_ppc( stats_kwargs=None, pc_kwargs=None, ): - """Plot prior/posterior predictive and observed values as kde, cumulative or scatter plots. + """Plot prior/posterior predictive and observed values as kde, ecdf, hist, or scatter plots. Parameters ---------- @@ -69,25 +69,24 @@ def plot_ppc( Dimensions to loop over in plotting posterior/prior predictive. Note: Dims not in sample_dims or facet_dims (below) will be reduced by default. Defaults to ``rcParams["data.sample_dims"]`` - kind : {"kde", "cumulative", "scatter"}, optional + kind : {"kde", "ecdf", "hist, "scatter"}, optional How to represent the marginal density. Defaults to ``rcParams["plot.density_kind"]``. facet_dims : list, optional Dimensions to facet over (for which multiple plots will be generated). - Defaults to empty list. A warning is raised if `pc_kwargs` is also used to define - dims to facet over with the `cols` key and `pc_kwargs` takes precedence. + Defaults to empty list. data_pairs : dict, optional Dictionary containing relations between observed data and posterior/prior predictive data. - Dictionary structure: - * key = observed data var_name - * value = posterior/prior predictive var_name + Dictionary keys are variable names corresponding to observed data and dictionary values + are variable names corresponding to posterior/prior predictive. For example, data_pairs = {'y' : 'y_hat'}. - If None, it will assume that the observed data and the posterior/prior predictive data - have the same variable name - aggregate: bool, optional + By default, it will assume that the observed data and the posterior/prior predictive data + have the same variable names + aggregate: bool, default False If True, predictive data will be aggregated over both sample_dims and reduce_dims. - Defaults to False. num_pp_samples : int, optional Number of prior/posterior predictive samples to plot. + Defaults to the total sample size (product of sample_dim dimension lengths) or minimum + between total sample size and 5 in case of kind='scatter' random_seed : int, optional Random number generator seed passed to numpy.random.seed to allow reproducibility of the plot. @@ -108,15 +107,16 @@ def plot_ppc( plot_kwargs : mapping of {str : mapping or False}, optional Valid keys are: - * "predictive" -> Passed to either of "kde", "cumulative", "scatter" based on `kind` - * "observed" -> passed to either of "kde", "cumulative", "scatter" based on `kind` - * "aggregate" -> passed to either of "kde", "cumulative", "scatter" based on `kind` + * "predictive" -> Passed to either of "kde", "ecdf", "hist", "scatter" based on `kind` + * "observed" -> passed to either of "kde", "ecdf", "hist", "scatter" based on `kind` + * "aggregate" -> passed to either of "kde", "ecdf", "hist", "scatter" based on `kind` - Values of the above plot_kwargs keys are passed to one of "kde", "cumulative", "scatter", + Values of the above plot_kwargs keys are passed to one of "kde", "ecdf", "hist", "scatter", matching the `kind` argument. * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy` - * "cumulative" -> passed to :func:`~arviz_plots.visuals.ecdf_line` + * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line` + * "hist" -> passed to :func:`~arviz_plots.visuals.hist_line` * "scatter" -> passed to :func: `~arviz_plots.visuals.scatter_x` * "observed_rug" -> passed to :func:`arviz_plots.visuals.trace_rug` @@ -126,9 +126,9 @@ def plot_ppc( stats_kwargs : mapping, optional Valid keys are: - * predictive -> passed to kde, cumulative, ... - * aggregate -> passed to kde, cumulative, ... - * observed -> passed to kde, cumulative, ... + * predictive -> passed to kde, ecdf, ... + * aggregate -> passed to kde, ecdf, ... + * observed -> passed to kde, ecdf, ... pc_kwargs : mapping Passed to :class:`arviz_plots.PlotCollection.wrap` @@ -175,50 +175,22 @@ def plot_ppc( raise TypeError("`group` argument must be either `posterior` or `prior`") predictive_data_group = f"{group}_predictive" - default_observed = "" if observed is None: - default_observed = f"(set as default for chosen group {group})" observed = group == "posterior" # by default true if posterior, false if prior - # checking if plot_kwargs["observed"] or plot_kwargs["aggregate"] is not inconsistent with - # top level bool arguments `observed`` and `aggregate` - - # observed args logic check: - # observed will be True/False depending on user-input/group- true for posterior, false for prior - # if observed and no plot_kwargs['observed'], no prob - # if observed = True and plot_kwargs['observed'] = True, no prob (observed is plotted) - # if observed = True and plot_kwargs['observed'] = False, prob (error/warning raised) - # if observed = False and plot_kwargs['observed'] = True, prob (error/warning raised) - # if observed = False and plot_kwargs['observed'] = False, no prob (observed is not plotted) - if ( - observed - and plot_kwargs.get("observed", True) is False - or not observed - and plot_kwargs.get("observed", False) is not False - ): - # raise warning or error + # checking to make sure plot_kwargs["observed"] or plot_kwargs["aggregate"] are not False + observed_kwargs = copy(plot_kwargs.get("observed", {})) + if observed_kwargs is False: raise ValueError( - f""" - `observed` and `plot_kwargs["observed"]` inconsistency detected. - `observed` = {observed}{default_observed} - `plot_kwargs["observed"]` = {plot_kwargs["observed"]} - Please make sure `observed` and `plot_kwargs["observed"]` have the same value.""" + """plot_kwargs['observed'] can't be False, use observed=False to remove observed + plot element""" ) - # same check for aggregate: - if ( - aggregate - and plot_kwargs.get("aggregate", True) is False - or not aggregate - and plot_kwargs.get("aggregate", False) is not False - ): - # raise warning or error + aggregate_kwargs = copy(plot_kwargs.get("aggregate", {})) + if aggregate_kwargs is False: raise ValueError( - f""" - `aggregate` and `plot_kwargs["aggregate"]` inconsistency detected. - `aggregate` = {aggregate} - `plot_kwargs["aggregate"]` = {plot_kwargs["aggregate"]} - Please make sure `aggregate` and `plot_kwargs["aggregate"]` have the same value.""" + """plot_kwargs['aggregate'] can't be False, use aggregate=False to remove + aggregate plot element""" ) # making sure both posterior/prior predictive group and observed_data group exists in @@ -269,7 +241,6 @@ def plot_ppc( rng = np.random.default_rng() # picking random pp dataset sample indexes and subsetting pp_distribution accordingly - # total_pp_samples = plot_collection.data.sizes["chain"] * plot_collection.data.sizes["draw"] total_pp_samples = np.prod( [pp_distribution.sizes[dim] for dim in sample_dims if dim in pp_distribution.dims] ) @@ -286,18 +257,19 @@ def plot_ppc( ): raise TypeError(f"`num_pp_samples` must be an integer between 1 and {total_pp_samples}.") - pp_sample_ix = rng.choice(total_pp_samples, size=num_pp_samples, replace=False) - - # print(f"\npp_sample_ix: {pp_sample_ix!r}") + # stacking sample dimensions and selecting randomly if sample_dims length>1 or + # num_pp_samples=total_pp_samples + if num_pp_samples != total_pp_samples or len(sample_dims) == 1: + pp_sample_ix = rng.choice(total_pp_samples, size=num_pp_samples, replace=False) - # stacking sample dims into a new 'ppc_dim' dimension - pp_distribution = pp_distribution.stack(ppc_dim=sample_dims) + # stacking sample dims into a new 'ppc_dim' dimension + pp_distribution = pp_distribution.stack(ppc_dim=sample_dims) - # Select the desired samples - pp_distribution = pp_distribution.isel(ppc_dim=pp_sample_ix) + # Select the desired samples + pp_distribution = pp_distribution.isel(ppc_dim=pp_sample_ix) - # renaming sample_dims so that rest of plot will consider this as sample_dims - sample_dims = ["ppc_dim"] + # renaming sample_dims so that rest of plot will consider this as sample_dims + sample_dims = ["ppc_dim"] # wrap plot collection with pp distribution if plot_collection is None: @@ -379,11 +351,11 @@ def plot_ppc( plot_collection=plot_collection, labeller=labeller, aes_map={ - kind: value for key, value in aes_map.items() if key == "predictive" + kind: aes_map.get("predictive", {}) }, # aes_map[kind] is set to "predictive" aes_map plot_kwargs=plot_kwargs_dist, # plot_kwargs[kind] is set to "predictive" plot_kwargs stats_kwargs={ - "density": value for key, value in stats_kwargs.items() if key == "predictive" + "density": stats_kwargs.get("predictive", {}) }, # stats_kwargs["density"] is set to "predictive" stats_kwargs # via plot_dist ) @@ -424,11 +396,11 @@ def plot_ppc( plot_collection=plot_collection, labeller=labeller, aes_map={ - kind: value for key, value in aes_map.items() if key == "aggregate" + kind: aes_map.get("aggregate", {}) }, # aes_map[kind] is set to "aggregate" aes_map plot_kwargs=plot_kwargs_dist, # plot_kwargs[kind] is set to "aggregate" plot_kwargs stats_kwargs={ - "density": value for key, value in stats_kwargs.items() if key == "aggregate" + "density": stats_kwargs.get("aggregate", {}) }, # stats_kwargs["density"] is set to "aggregate" stats_kwargs ) @@ -465,11 +437,11 @@ def plot_ppc( plot_collection=plot_collection, labeller=labeller, aes_map={ - kind: value for key, value in aes_map.items() if key == "observed" + kind: aes_map.get("observed", {}) }, # aes_map[kind] is set to "observed" aes_map plot_kwargs=plot_kwargs_dist, # plot_kwargs[kind] is set to "observed" plot_kwargs stats_kwargs={ - "density": value for key, value in stats_kwargs.items() if key == "observed" + "density": stats_kwargs.get("observed", {}) }, # stats_kwargs["density"] is set to "observed" stats_kwargs ) @@ -496,7 +468,6 @@ def plot_ppc( data=obs_distribution, ignore_aes=rug_ignore, xname=False, - flatten=True, y=0, **rug_kwargs, ) diff --git a/src/arviz_plots/visuals/__init__.py b/src/arviz_plots/visuals/__init__.py index 37f81d9..e99ab94 100644 --- a/src/arviz_plots/visuals/__init__.py +++ b/src/arviz_plots/visuals/__init__.py @@ -46,24 +46,19 @@ def line(da, target, backend, xname=None, **kwargs): return plot_backend.line(xvalues, yvalues, target, **kwargs) -def trace_rug(da, target, backend, mask=None, flatten=False, xname=None, y=None, **kwargs): +def trace_rug(da, target, backend, mask=None, xname=None, y=None, **kwargs): """Create a rug plot with the subset of `da` indicated by `mask`.""" xname = xname.item() if hasattr(xname, "item") else xname if xname is False: xvalues = da else: if xname is None: - if len(da.shape) != 1: - raise ValueError(f"Expected unidimensional data but got {da.sizes}") xvalues = np.arange(len(da)) else: xvalues = da[xname] if y is None: y = da.min().item() - if flatten is not False: - xvalues = xvalues.values.flatten() # flatten xvalues - if len(xvalues.shape) != 1: - raise ValueError(f"Expected unidimensional data but got {xvalues.sizes}") + xvalues = xvalues.values.flatten() # flatten xvalues by default if mask is not None: xvalues = xvalues[mask] return scatter_x(xvalues, target=target, backend=backend, y=y, **kwargs) diff --git a/tests/test_hypothesis_plots.py b/tests/test_hypothesis_plots.py index b9bb804..6542e9c 100644 --- a/tests/test_hypothesis_plots.py +++ b/tests/test_hypothesis_plots.py @@ -32,7 +32,11 @@ def datatree(seed=31): "prior_predictive": {"obs": prior_predictive}, "posterior_predictive": {"obs": posterior_predictive}, }, - dims={"theta": ["hierarchy", "group"], "tau": ["hierarchy"], "obs": ["hierarchy", "group"]}, + dims={ + "theta": ["chain", "draw", "hierarchy", "group"], + "tau": ["chain", "draw", "hierarchy"], + "obs": ["chain", "draw", "hierarchy", "group"], + }, ) dt["point_estimate"] = dt.posterior.mean(("chain", "draw")) # TODO: should become dt.azstats.eti() after fix in arviz-stats @@ -200,9 +204,26 @@ def test_plot_ridge(datatree, combined, plot_kwargs, labels_shade_label): assert all(key in child for child in pc.viz.children.values()) -# plot_ppc tests ppc_kind_value = st.sampled_from(("kde", "cumulative")) ppc_group = st.sampled_from(("prior", "posterior")) +ppc_observed = st.booleans() +ppc_aggregate = st.booleans() +ppc_sample_dims = st.sampled_from((["chain"], ["chain", "draw"])) +ppc_facet_dims = st.sampled_from((["group"], ["hierarchy"], None)) + + +@st.composite # composite func to determine num_pp_samples based on draws of group, sample_dims +def draw_num_pp_samples(draw, group, sample_dims): + group = draw(group) + sample_dims = draw(sample_dims) + # print(f"\n sample_dims = {sample_dims}\ngroup = {group}") + chain_dim_length = 1 if group == "prior" else 3 + draw_dim_length = 50 if sample_dims == ["chain", "draw"] else 1 + total_num_samples = np.prod([chain_dim_length, draw_dim_length]) + + num_pp_samples = draw(st.integers(min_value=1, max_value=total_num_samples)) + # print(f"\nnum_pp_samples = {num_pp_samples}") + return num_pp_samples @given( @@ -220,16 +241,42 @@ def test_plot_ridge(datatree, combined, plot_kwargs, labels_shade_label): ), kind=ppc_kind_value, group=ppc_group, + observed=ppc_observed, + observed_rug=ppc_observed, + aggregate=ppc_aggregate, + facet_dims=ppc_facet_dims, + sample_dims=ppc_sample_dims, + num_pp_samples=draw_num_pp_samples(ppc_group, ppc_sample_dims), ) -def test_plot_ppc(datatree, kind, group, plot_kwargs): +def test_plot_ppc( + kind, + group, + observed, + observed_rug, + aggregate, + facet_dims, + sample_dims, + num_pp_samples, + plot_kwargs, +): kind_kwargs = plot_kwargs.pop("kind", None) if kind_kwargs is not None: plot_kwargs[kind] = kind_kwargs + if plot_kwargs.get("observed", False) is False: + plot_kwargs["observed"] = True # cannot be False + if plot_kwargs.get("aggregate", False) is False: + plot_kwargs["aggregate"] = True # cannot be False pc = plot_ppc( datatree, backend="none", kind=kind, group=group, + observed=observed, + observed_rug=observed_rug, + aggregate=aggregate, + facet_dims=facet_dims, + sample_dims=sample_dims, + num_pp_samples=num_pp_samples, plot_kwargs=plot_kwargs, ) assert all("plot" in child for child in pc.viz.children.values()) diff --git a/tests/test_plots.py b/tests/test_plots.py index 4db3e8c..cd45aec 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -273,51 +273,48 @@ def test_plot_ridge_aes_labels_shading(self, backend, datatree_4d, pseudo_dim): assert all(0 in child["alpha"] for child in pc.aes.children.values()) assert any(pseudo_dim in child["shade"].dims for child in pc.viz.children.values()) - @pytest.mark.parametrize("group", ("prior", "posterior")) @pytest.mark.parametrize("kind", ("kde", "cumulative")) - def test_plot_ppc(self, datatree, kind, group, backend): - pc = plot_ppc(datatree, kind=kind, group=group, backend=backend) + def test_plot_ppc(self, datatree, kind, backend): + pc = plot_ppc(datatree, kind=kind, backend=backend) assert "chart" in pc.viz.data_vars assert "obs" in pc.viz - assert "ppc_dim" in pc.viz["obs"].dims + # assert "ppc_dim" in pc.viz["obs"].dims if kind == "kde": assert "kde" in pc.viz["obs"] elif kind == "cumulative": assert "ecdf" in pc.viz["obs"] assert "overlay" in pc.aes["obs"].data_vars - @pytest.mark.parametrize("group", ("prior", "posterior")) @pytest.mark.parametrize("kind", ("kde", "cumulative")) - def test_plot_ppc_sample(self, datatree_sample, kind, group, backend): - pc = plot_ppc( - datatree_sample, kind=kind, group=group, sample_dims="sample", backend=backend - ) + def test_plot_ppc_sample(self, datatree_sample, kind, backend): + pc = plot_ppc(datatree_sample, kind=kind, sample_dims="sample", backend=backend) assert "chart" in pc.viz.data_vars assert "obs" in pc.viz - assert "ppc_dim" in pc.viz["obs"].dims + # assert "ppc_dim" in pc.viz["obs"].dims if kind == "kde": assert "kde" in pc.viz["obs"] elif kind == "cumulative": assert "ecdf" in pc.viz["obs"] assert "overlay" in pc.aes["obs"].data_vars - @pytest.mark.parametrize("group", ("prior", "posterior")) @pytest.mark.parametrize("kind", ("kde", "cumulative")) @pytest.mark.parametrize("facet_dims", (["group"], ["hierarchy"], None)) - def test_plot_ppc_4d(self, datatree_4d, facet_dims, kind, group, backend): + def test_plot_ppc_4d(self, datatree_4d, facet_dims, kind, backend): pc = plot_ppc( datatree_4d, facet_dims=facet_dims, kind=kind, - group=group, observed_rug=True, backend=backend, ) assert "chart" in pc.viz.data_vars assert "obs" in pc.viz - assert "ppc_dim" in pc.viz["obs"].dims + # assert "ppc_dim" in pc.viz["obs"].dims if kind == "kde": assert "kde" in pc.viz["obs"] elif kind == "cumulative": assert "ecdf" in pc.viz["obs"] assert "overlay" in pc.aes["obs"].data_vars + if facet_dims is not None: + for dim in facet_dims: + assert dim in pc.viz["obs"]["plot"].dims