Skip to content

Commit

Permalink
test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
aleslamitz committed Nov 13, 2023
1 parent 37b01c1 commit d5ddceb
Showing 1 changed file with 9 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,12 @@ def resample(particles, log_weights, resample_fn, target_log_weights=None,

# Normalize the weights and sample the ancestral indices.
log_probs = tf.math.log_softmax(log_weights, axis=particles_dim)
resampled_indices = resample_fn(log_probs, num_particles, (),
particles_dim=particles_dim, seed=seed)
if particles_dim == 0:
# For resample functions that don't yet support the particles_dim argument.
resampled_indices = resample_fn(log_probs, num_particles, (), seed=seed)
else:
resampled_indices = resample_fn(log_probs, num_particles, (),
particles_dim=particles_dim, seed=seed)

gather_ancestors = lambda x: ( # pylint: disable=g-long-lambda
mcmc_util.index_remapping_gather(x,
Expand Down Expand Up @@ -281,6 +285,9 @@ def resample_systematic(log_probs, event_size, sample_shape,
The remaining dimensions are batch dimensions.
event_size: the dimension of the vector considered a single draw.
sample_shape: the `sample_shape` determining the number of draws.
particles_dim: Python `int` axis of each state `Tensor` indexing into the
particles. This is almost always zero, but nonzero values may be necessary
when running SMC in nested contexts.
seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
Default value: None (i.e. no seed).
name: Python `str` name for ops created by this method.
Expand Down

0 comments on commit d5ddceb

Please sign in to comment.