diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py index c0bdabf415..63f139a63e 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py @@ -744,14 +744,13 @@ def rejuvenation_criterion(step, state): num_outer_particles = 3 num_inner_particles = 5 - loc = tf.broadcast_to([0., 0.], [num_outer_particles, 2]) - scale_diag = tf.broadcast_to([0.01, 0.01], [num_outer_particles, 2]) - params, _ = self.evaluate(particle_filter.smc_squared( observations=observations, - inner_initial_state_prior=lambda _, params: mvn_diag.MultivariateNormalDiag( + inner_initial_state_prior=lambda _, params: + mvn_diag.MultivariateNormalDiag( loc=tf.broadcast_to([0., 0.], params.shape + [2]), - scale_diag=tf.broadcast_to([0.01, 0.01], params.shape + [2])), + scale_diag=tf.broadcast_to([0.01, 0.01], params.shape + [2]) + ), initial_parameter_prior=normal.Normal(5., 0.5), num_outer_particles=num_outer_particles, num_inner_particles=num_inner_particles,