Skip to content

Commit

Permalink
added ppcplot-beta
Browse files Browse the repository at this point in the history
  • Loading branch information
imperorrp committed Jun 25, 2024
1 parent 6f8b750 commit a976e76
Show file tree
Hide file tree
Showing 2 changed files with 387 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/arviz_plots/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from .distplot import plot_dist
from .forestplot import plot_forest
from .ppcplot import plot_ppc
from .tracedistplot import plot_trace_dist
from .traceplot import plot_trace

__all__ = ["plot_dist", "plot_forest", "plot_trace", "plot_trace_dist"]
__all__ = ["plot_dist", "plot_forest", "plot_trace", "plot_trace_dist", "plot_ppc"]
385 changes: 385 additions & 0 deletions src/arviz_plots/plots/ppcplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,385 @@
"""ppc plot code."""

# import warnings
from copy import copy
from numbers import Integral

import arviz_stats # pylint: disable=unused-import
import numpy as np

# import xarray as xr
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller

from arviz_plots.plot_collection import PlotCollection, concat_model_dict
from arviz_plots.plots.utils import filter_aes, process_group_variables_coords
from arviz_plots.visuals import line_xy


# define function
def plot_ppc(
    dt,
    var_names=None,
    filter_vars=None,
    group="posterior",
    observed=None,
    coords=None,
    sample_dims=None,
    kind=None,
    data_pairs=None,
    # mean=True,
    flatten=None,
    flatten_pp=None,
    num_pp_samples=None,
    random_seed=None,
    # jitter=None,
    animated=False,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_map=None,
    plot_kwargs=None,
    stats_kwargs=None,
    pc_kwargs=None,
):
    """Plot prior/posterior predictive samples and observed values.

    Work in progress: whatever `kind` is requested, the observed data and every
    selected predictive sample are currently drawn as KDE lines through
    :func:`~arviz_plots.visuals.line_xy`.

    Parameters
    ----------
    dt : DataTree
        Input data. It must expose both the ``f"{group}_predictive"`` and the
        ``observed_data`` groups as attributes.
    var_names : str or list of str, optional
        One or more variables to be plotted.
        Prefix the variables by ~ when you want to exclude them from the plot.
    filter_vars : {None, "like", "regex"}, default=None
        If None, interpret var_names as the real variables names.
        If "like", interpret var_names as substrings of the real variables names.
        If "regex", interpret var_names as regular expressions on the real variables names.
    group : {"posterior", "prior"}, default "posterior"
        "posterior" plots the ``posterior_predictive`` group and "prior" the
        ``prior_predictive`` group.
    observed : bool, optional
        Whether or not to plot the observed data. Defaults to True for
        ``group="posterior"`` and False for ``group="prior"``.
    coords : dict, optional
        Dictionary mapping dimensions to selected coordinates to be plotted.
        Dimensions without a mapping specified will include all coordinates for that dimension.
        Defaults to including all coordinates for all dimensions if None.
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    kind : {"kde", "cumulative", "scatter"}, optional
        Requested representation, matched case-insensitively.
        Defaults to ``rcParams["plot.density_kind"]``. For now it only selects the
        `plot_kwargs` entry to use and the default of `num_pp_samples`.
    data_pairs : dict, optional
        Mapping of observed data variable name to predictive variable name, for
        example ``{"y": "y_hat"}``. Accepted but not used yet.
    flatten : list of str, optional
        Dimensions to flatten in the observed data.
        Defaults to all the dimensions of the observed data.
    flatten_pp : list of str, optional
        Dimensions to flatten in the predictive data. Defaults to `flatten`.
    num_pp_samples : int, optional
        Number of predictive samples to plot, between 1 and ``chain * draw``.
        Defaults to ``min(5, chain * draw)`` when `kind` is "scatter" and `animated`
        is False, and to all the samples otherwise.
    random_seed : int, optional
        Seed passed to :func:`numpy.random.seed` before choosing the predictive
        samples. If None, the global random state is left as it is.
    animated : bool, default False
        Only used to choose the default of `num_pp_samples`.
    plot_collection : PlotCollection, optional
        Existing collection to draw on. If None, a new one is created with
        :meth:`arviz_plots.PlotCollection.wrap`.
    backend : str, optional
        Plotting backend used when a new `plot_collection` is created.
        Defaults to ``rcParams["plot.backend"]``.
    labeller : labeller, optional
        Accepted but not used yet.
    aes_map : mapping of {str : sequence of str}, optional
        Mapping of artists to aesthetics, forwarded to ``filter_aes``.
    plot_kwargs : mapping of {str : mapping or False}, optional
        The entry whose key equals `kind` is passed as keyword arguments to
        :func:`~arviz_plots.visuals.line_xy`. Setting it to ``False`` skips all drawing.
    stats_kwargs : mapping, optional
        The "density" entry is passed as keyword arguments to ``.azstats.kde``.
    pc_kwargs : mapping, optional
        Passed to :meth:`arviz_plots.PlotCollection.wrap`. The input is not modified.

    Returns
    -------
    PlotCollection

    Raises
    ------
    TypeError
        If `group` or `kind` are not among the accepted values, if `dt` lacks one of
        the required groups, or if `num_pp_samples` is not an integer in the valid range.
    """
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    if kind is None:
        kind = rcParams["plot.density_kind"]
    if plot_kwargs is None:
        plot_kwargs = {}
    # work on a copy so that the defaults set below never leak into the caller's dict
    pc_kwargs = {} if pc_kwargs is None else pc_kwargs.copy()
    if stats_kwargs is None:
        stats_kwargs = {}

    if group not in ("posterior", "prior"):
        raise TypeError("`group` argument must be either `posterior` or `prior`")

    # both group names are defined unconditionally: the observed data is needed for the
    # default of `flatten` and for the merged dataset even when it is not drawn
    predictive_data_group = f"{group}_predictive"
    observed_data_group = "observed_data"

    # NOTE(review): the check is done on `dt` itself, so a dict of DataTree objects
    # is rejected here
    for required_group in (predictive_data_group, observed_data_group):
        if not hasattr(dt, required_group):
            raise TypeError(f'`dt` argument must have the group "{required_group}" for ppcplot')

    # normalize once so validation, default selection and `plot_kwargs` lookup agree
    kind = kind.lower()
    if kind not in ("kde", "cumulative", "scatter"):
        raise TypeError("`kind` argument must be either `kde`, `cumulative`, or `scatter`")

    # TODO: `data_pairs` is not used yet, observed and predictive variables are
    # selected with the same `var_names`
    if data_pairs is None:
        data_pairs = {}

    if observed is None:
        observed = group == "posterior"

    # select the requested variables/coords from each of the two groups
    obs_distribution = process_group_variables_coords(
        dt, group=observed_data_group, var_names=var_names, filter_vars=filter_vars, coords=coords
    )
    if flatten is None:
        # flatten every dimension of the observed data unless told otherwise
        flatten = list(obs_distribution.dims)

    pp_distribution = process_group_variables_coords(
        dt, group=predictive_data_group, var_names=var_names, filter_vars=filter_vars, coords=coords
    )
    if flatten_pp is None:
        flatten_pp = flatten

    # merge both datasets into a single one along a new "model" dimension; the
    # predictive data is labelled "posterior_predictive" for both values of `group`
    distribution = concat_model_dict(
        {"posterior_predictive": pp_distribution, "observed_data": obs_distribution}
    )

    if plot_collection is None:
        if backend is None:
            backend = rcParams["plot.backend"]
        pc_kwargs.setdefault("col_wrap", 5)
        # one plot per variable and no facetting over dimensions by default; users can
        # still facet over a dimension explicitly through `pc_kwargs`
        pc_kwargs.setdefault("cols", ["__variable__"])
        if "model" in distribution:
            pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
            pc_kwargs["aes"].setdefault("color", ["model"])
            pc_kwargs["aes"].setdefault("y", ["model"])
        plot_collection = PlotCollection.wrap(
            distribution,
            backend=backend,
            **pc_kwargs,
        )

    aes_map = {} if aes_map is None else aes_map.copy()
    # TODO: `labeller` is not used yet
    if labeller is None:
        labeller = BaseLabeller()

    if random_seed is not None:
        # documented behaviour: seeds numpy's global random state
        np.random.seed(random_seed)

    # choose which predictive samples are drawn
    total_pp_samples = plot_collection.data.sizes["chain"] * plot_collection.data.sizes["draw"]
    if num_pp_samples is None:
        if kind == "scatter" and not animated:
            num_pp_samples = min(5, total_pp_samples)
        else:
            num_pp_samples = total_pp_samples

    if (
        not isinstance(num_pp_samples, Integral)
        or num_pp_samples < 1
        or num_pp_samples > total_pp_samples
    ):
        raise TypeError(f"`num_pp_samples` must be an integer between 1 and {total_pp_samples}.")

    pp_sample_ix = np.random.choice(total_pp_samples, size=num_pp_samples, replace=False)

    density_kwargs = copy(plot_kwargs.get(kind, {}))
    if density_kwargs is False:
        # drawing explicitly disabled by the caller
        return plot_collection

    density_dims, _, density_ignore = filter_aes(plot_collection, aes_map, "kde", sample_dims)
    density_stats_kwargs = stats_kwargs.get("density", {})

    # ---------STEP 1 (observed data)-----------
    if observed:
        # merging broadcast the observed data over chain and draw, so a single
        # (chain, draw) slice already holds all the observed values
        observed_distribution = distribution.sel(model="observed_data").sel(chain=0, draw=0)

        # reduce over all the dimensions in `flatten`, not only the sample dimensions
        obs_density_dims = set(flatten).union(density_dims)
        obs_density = observed_distribution.azstats.kde(
            dims=obs_density_dims, **density_stats_kwargs
        )
        plot_collection.map(
            line_xy, "kde", data=obs_density, ignore_aes=density_ignore, **density_kwargs
        )

    # ---------STEP 2 (predictive data)-------------
    # stack chain and draw into a single integer-indexed "ppc_dim" so that every
    # position along it is one predictive sample, then keep the chosen samples
    predictive_distribution = distribution.sel(model="posterior_predictive")
    predictive_distribution = predictive_distribution.stack(ppc_dim=("chain", "draw"))
    predictive_distribution = predictive_distribution.assign_coords(
        ppc_dim=np.arange(predictive_distribution.sizes["ppc_dim"])
    )
    predictive_distribution = predictive_distribution.isel(ppc_dim=pp_sample_ix)

    # "ppc_dim" is never reduced: one density line is drawn per predictive sample
    pp_density_dims = set(flatten_pp).union(density_dims)
    for i in range(predictive_distribution.sizes["ppc_dim"]):
        subselection = predictive_distribution.isel(ppc_dim=i)
        pp_density = subselection.azstats.kde(dims=pp_density_dims, **density_stats_kwargs)
        plot_collection.map(
            line_xy, "kde", data=pp_density, ignore_aes=density_ignore, **density_kwargs
        )

    # ---------STEP 3 (predictive mean)------------- (WIP, not implemented yet)

    return plot_collection

0 comments on commit a976e76

Please sign in to comment.