diff --git a/tensorflow_probability/python/experimental/mcmc/weighted_resampling.py b/tensorflow_probability/python/experimental/mcmc/weighted_resampling.py index 5dfff6807a..aa0bfe92e0 100644 --- a/tensorflow_probability/python/experimental/mcmc/weighted_resampling.py +++ b/tensorflow_probability/python/experimental/mcmc/weighted_resampling.py @@ -79,8 +79,12 @@ def resample(particles, log_weights, resample_fn, target_log_weights=None, # Normalize the weights and sample the ancestral indices. log_probs = tf.math.log_softmax(log_weights, axis=particles_dim) - resampled_indices = resample_fn(log_probs, num_particles, (), - particles_dim=particles_dim, seed=seed) + if particles_dim == 0: + # For resample functions that don't yet support the particles_dim argument. + resampled_indices = resample_fn(log_probs, num_particles, (), seed=seed) + else: + resampled_indices = resample_fn(log_probs, num_particles, (), + particles_dim=particles_dim, seed=seed) gather_ancestors = lambda x: ( # pylint: disable=g-long-lambda mcmc_util.index_remapping_gather(x, @@ -281,6 +285,9 @@ def resample_systematic(log_probs, event_size, sample_shape, The remaining dimensions are batch dimensions. event_size: the dimension of the vector considered a single draw. sample_shape: the `sample_shape` determining the number of draws. + particles_dim: Python `int` axis of each state `Tensor` indexing into the + particles. This is almost always zero, but nonzero values may be necessary + when running SMC in nested contexts. seed: PRNG seed; see `tfp.random.sanitize_seed` for details. Default value: None (i.e. no seed). name: Python `str` name for ops created by this method.