diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py index 8bb2a36f34..8e23166adf 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter.py @@ -663,7 +663,7 @@ def _compute_observation_log_weights(step, observations) observation = tf.nest.map_structure(lambda x: - tf.expand_dims(x, axis=particles_dim), + tf.expand_dims(x, axis=particles_dim), observation) log_weights = observation_fn(step, particles).log_prob(observation)