diff --git a/src/arviz_plots/plots/ppcplot.py b/src/arviz_plots/plots/ppcplot.py index b3cca40..9b32050 100644 --- a/src/arviz_plots/plots/ppcplot.py +++ b/src/arviz_plots/plots/ppcplot.py @@ -103,13 +103,9 @@ def plot_ppc( aes_map : mapping of {str : sequence of str}, optional Mapping of artists to aesthetics that should use their mapping in `plot_collection` when plotted. Valid keys are the same as for `plot_kwargs`. - (Note: like in `plot_kwargs` below, aes_map values are passed with key as `kind` - when passed to `plot_dist`) plot_kwargs : mapping of {str : mapping or False}, optional Valid keys are: - (Note: This function internally calls `plot_dist` so the first three artists here get - mapped to one of `plot_dist`'s visual element types.) * "predictive" -> Passed to either of "kde", "cumulative", "scatter" based on `kind` * "observed" -> passed to either of "kde", "cumulative", "scatter" based on `kind` @@ -117,7 +113,6 @@ def plot_ppc( Values of the above plot_kwargs keys are passed to one of "kde", "cumulative", "scatter", matching the `kind` argument. - These are passed to :func:`~arviz_plots.plots.distplot.plot_dist`. * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy` * "cumulative" -> passed to :func:`~arviz_plots.visuals.ecdf_line` @@ -129,7 +124,9 @@ def plot_ppc( stats_kwargs : mapping, optional Valid keys are: - * density -> passed to kde, cumulative, ... + * predictive -> passed to kde, cumulative, ... + * aggregate -> passed to kde, cumulative, ... + * observed -> passed to kde, cumulative, ... pc_kwargs : mapping Passed to :class:`arviz_plots.PlotCollection.wrap` @@ -150,6 +147,9 @@ def plot_ppc( kind = rcParams["plot.density_kind"] if plot_kwargs is None: plot_kwargs = {} + # Note: This function internally calls `plot_dist` so the 3 relevant artists "predictive", + # "aggregate", "observed" get mapped to one of `plot_dist`'s density artists- "kde", + # "ecdf", "scatter" based on the value of the top level arg `kind` if pc_kwargs is None: pc_kwargs = {} else: @@ -371,7 +371,9 @@ def plot_ppc( kind: value for key, value in aes_map.items() if key == "predictive" }, # aes_map[kind] is set to "predictive" aes_map plot_kwargs=plot_kwargs_dist, # plot_kwargs[kind] is set to "predictive" plot_kwargs - stats_kwargs=stats_kwargs, # common "density" key used for all artists generated + stats_kwargs={ + "density": value for key, value in stats_kwargs.items() if key == "predictive" + }, # stats_kwargs["density"] is set to "predictive" stats_kwargs # via plot_dist ) @@ -414,7 +416,9 @@ def plot_ppc( kind: value for key, value in aes_map.items() if key == "aggregate" }, # aes_map[kind] is set to "aggregate" aes_map plot_kwargs=plot_kwargs_dist, # plot_kwargs[kind] is set to "aggregate" plot_kwargs - stats_kwargs=stats_kwargs, + stats_kwargs={ + "density": value for key, value in stats_kwargs.items() if key == "aggregate" + }, # stats_kwargs["density"] is set to "aggregate" stats_kwargs ) # ---------STEP 3 (observed data)----------- @@ -453,7 +457,9 @@ def plot_ppc( kind: value for key, value in aes_map.items() if key == "observed" }, # aes_map[kind] is set to "observed" aes_map plot_kwargs=plot_kwargs_dist, # plot_kwargs[kind] is set to "observed" plot_kwargs - stats_kwargs=stats_kwargs, + stats_kwargs={ + "density": value for key, value in stats_kwargs.items() if key == "observed" + }, # stats_kwargs["density"] is set to "observed" stats_kwargs ) # adding plot title/s