From cfafe715c31439cad9312eac862d7540014666c3 Mon Sep 17 00:00:00 2001 From: Ratish Panda Date: Tue, 13 Aug 2024 14:13:51 +0530 Subject: [PATCH] predictive values dim-stacking/sample-subselecting logic --- src/arviz_plots/plots/ppcplot.py | 47 +++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/src/arviz_plots/plots/ppcplot.py b/src/arviz_plots/plots/ppcplot.py index d76080e..4ad7ca2 100644 --- a/src/arviz_plots/plots/ppcplot.py +++ b/src/arviz_plots/plots/ppcplot.py @@ -234,12 +234,6 @@ def plot_ppc( coords=coords, ) - # creating random number generator - if random_seed is not None: - rng = np.random.default_rng(random_seed) - else: - rng = np.random.default_rng() - # picking random pp dataset sample indexes and subsetting pp_distribution accordingly total_pp_samples = np.prod( [pp_distribution.sizes[dim] for dim in sample_dims if dim in pp_distribution.dims] @@ -257,19 +251,40 @@ def plot_ppc( ): raise TypeError(f"`num_pp_samples` must be an integer between 1 and {total_pp_samples}.") - # stacking sample dimensions and selecting randomly if sample_dims length>1 or - # num_pp_samples=total_pp_samples - if num_pp_samples != total_pp_samples or len(sample_dims) == 1: - pp_sample_ix = rng.choice(total_pp_samples, size=num_pp_samples, replace=False) + # if num_pp_samples==1, no stacking or subselecting required + + # stacking sample dimensions and subselecting from + # if len(sample_dims)>1 and num_pp_samples!=total_pp_samples + + # subselecting from without stacking if len(sample_dims)==1 and + # num_pp_samples!=total_pp_samples + + if num_pp_samples != total_pp_samples: + # creating random number generator for random subselecting + if random_seed is not None: + rng = np.random.default_rng(random_seed) + else: + rng = np.random.default_rng() + + if len(sample_dims) > 1: + # stacking into one dim (ppc_dim) required before subselecting + pp_sample_ix = rng.choice(total_pp_samples, size=num_pp_samples, replace=False) + + # stacking sample dims into a new 'ppc_dim' dimension + pp_distribution = pp_distribution.stack(ppc_dim=sample_dims) - # stacking sample dims into a new 'ppc_dim' dimension - pp_distribution = pp_distribution.stack(ppc_dim=sample_dims) + # Select the desired samples + pp_distribution = pp_distribution.isel(ppc_dim=pp_sample_ix) - # Select the desired samples - pp_distribution = pp_distribution.isel(ppc_dim=pp_sample_ix) + # renaming sample_dims so that rest of plot will consider this as sample_dims + sample_dims = ["ppc_dim"] + elif len(sample_dims) == 1: + # subselecting without stacking + pp_sample_ix = rng.choice(total_pp_samples, size=num_pp_samples, replace=False) - # renaming sample_dims so that rest of plot will consider this as sample_dims - sample_dims = ["ppc_dim"] + # Select the desired samples + sample_dim = sample_dims[0] # provided dimension used for subsetting + pp_distribution = pp_distribution.isel({sample_dim: pp_sample_ix}) # wrap plot collection with pp distribution if plot_collection is None: