Skip to content

Commit

Permalink
predictive values dim-stacking/sample-subselecting logic
Browse files Browse the repository at this point in the history
  • Loading branch information
imperorrp committed Aug 13, 2024
1 parent e8a22ae commit cfafe71
Showing 1 changed file with 31 additions and 16 deletions.
47 changes: 31 additions & 16 deletions src/arviz_plots/plots/ppcplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,6 @@ def plot_ppc(
coords=coords,
)

# creating random number generator
if random_seed is not None:
rng = np.random.default_rng(random_seed)
else:
rng = np.random.default_rng()

# picking random pp dataset sample indexes and subsetting pp_distribution accordingly
total_pp_samples = np.prod(
[pp_distribution.sizes[dim] for dim in sample_dims if dim in pp_distribution.dims]
Expand All @@ -257,19 +251,40 @@ def plot_ppc(
):
raise TypeError(f"`num_pp_samples` must be an integer between 1 and {total_pp_samples}.")

# stacking sample dimensions and selecting randomly if sample_dims length>1 or
# num_pp_samples=total_pp_samples
if num_pp_samples != total_pp_samples or len(sample_dims) == 1:
pp_sample_ix = rng.choice(total_pp_samples, size=num_pp_samples, replace=False)
# if num_pp_samples==total_pp_samples, no stacking or subselecting required

# stacking sample dimensions into a single 'ppc_dim' and subselecting from it
# if len(sample_dims)>1 and num_pp_samples!=total_pp_samples

# subselecting from the single sample dimension without stacking if len(sample_dims)==1 and
# num_pp_samples!=total_pp_samples

if num_pp_samples != total_pp_samples:
# creating random number generator for random subselecting
if random_seed is not None:
rng = np.random.default_rng(random_seed)
else:
rng = np.random.default_rng()

if len(sample_dims) > 1:
# stacking into one dim (ppc_dim) required before subselecting
pp_sample_ix = rng.choice(total_pp_samples, size=num_pp_samples, replace=False)

# stacking sample dims into a new 'ppc_dim' dimension
pp_distribution = pp_distribution.stack(ppc_dim=sample_dims)

# stacking sample dims into a new 'ppc_dim' dimension
pp_distribution = pp_distribution.stack(ppc_dim=sample_dims)
# Select the desired samples
pp_distribution = pp_distribution.isel(ppc_dim=pp_sample_ix)

# Select the desired samples
pp_distribution = pp_distribution.isel(ppc_dim=pp_sample_ix)
# renaming sample_dims so that rest of plot will consider this as sample_dims
sample_dims = ["ppc_dim"]
elif len(sample_dims) == 1:
# subselecting without stacking
pp_sample_ix = rng.choice(total_pp_samples, size=num_pp_samples, replace=False)

# renaming sample_dims so that rest of plot will consider this as sample_dims
sample_dims = ["ppc_dim"]
# Select the desired samples
sample_dim = sample_dims[0] # provided dimension used for subsetting
pp_distribution = pp_distribution.isel({sample_dim: pp_sample_ix})

# wrap plot collection with pp distribution
if plot_collection is None:
Expand Down

0 comments on commit cfafe71

Please sign in to comment.