From 67ec97d94c69b408b0f6d5d36661100fe45d09c5 Mon Sep 17 00:00:00 2001 From: slamitza Date: Tue, 14 Nov 2023 17:01:49 +0100 Subject: [PATCH] fixing test --- .../python/experimental/mcmc/particle_filter.py | 5 +++-- .../experimental/mcmc/sequential_monte_carlo_kernel_test.py | 6 ++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py index 4f18f4782d..4bfaf2dd1a 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter.py @@ -237,7 +237,6 @@ def observation_fn(_, state): with tf.name_scope(name or 'infer_trajectories') as name: pf_seed, resample_seed = samplers.split_seed( seed, salt='infer_trajectories') - (particles, log_weights, parent_indices, @@ -377,7 +376,7 @@ def seeded_one_step(seed_state_results, _): # Return results from just the final step. traced_results = trace_fn(*final_seed_state_result[1:]) - return traced_results + return traced_results @docstring_util.expand_docstring( @@ -393,6 +392,7 @@ def particle_filter(observations, resample_fn=weighted_resampling.resample_systematic, resample_criterion_fn=smc_kernel.ess_below_threshold, unbiased_gradients=True, + rejuvenation_kernel_fn=None, # TODO(davmre): not yet supported. pylint: disable=unused-argument num_transitions_per_observation=1, trace_fn=_default_trace_fn, trace_criterion_fn=_always_trace, @@ -448,6 +448,7 @@ def particle_filter(observations, Filtering without Modifying the Forward Pass. _arXiv preprint arXiv:2106.10314_, 2021. https://arxiv.org/abs/2106.10314 """ + init_seed, loop_seed = samplers.split_seed(seed, salt='particle_filter') with tf.name_scope(name or 'particle_filter'): num_observation_steps = ps.size0(tf.nest.flatten(observations)[0]) diff --git a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py index b586e44b86..2a9302a420 100644 --- a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py +++ b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py @@ -96,8 +96,7 @@ def testMarginalLikelihoodGradientIsDefined(self): WeightedParticles( particles=samplers.normal([num_particles], seed=seeds[0]), log_weights=tf.fill([num_particles], - -tf.math.log(float(num_particles))) - )) + -tf.math.log(float(num_particles))))) def propose_and_update_log_weights_fn(_, weighted_particles, @@ -111,8 +110,7 @@ def propose_and_update_log_weights_fn(_, particles=proposed_particles, log_weights=(weighted_particles.log_weights + transition_dist.log_prob(proposed_particles) - - proposal_dist.log_prob(proposed_particles)) - ) + proposal_dist.log_prob(proposed_particles))) def marginal_logprob(transition_scale): kernel = SequentialMonteCarlo(