Skip to content

Commit

Permalink
fixing test
Browse files Browse the repository at this point in the history
  • Loading branch information
aleslamitz committed Nov 14, 2023
1 parent 0368dff commit 67ec97d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,6 @@ def observation_fn(_, state):
with tf.name_scope(name or 'infer_trajectories') as name:
pf_seed, resample_seed = samplers.split_seed(
seed, salt='infer_trajectories')

(particles,
log_weights,
parent_indices,
Expand Down Expand Up @@ -377,7 +376,7 @@ def seeded_one_step(seed_state_results, _):
# Return results from just the final step.
traced_results = trace_fn(*final_seed_state_result[1:])

return traced_results
return traced_results


@docstring_util.expand_docstring(
Expand All @@ -393,6 +392,7 @@ def particle_filter(observations,
resample_fn=weighted_resampling.resample_systematic,
resample_criterion_fn=smc_kernel.ess_below_threshold,
unbiased_gradients=True,
rejuvenation_kernel_fn=None, # TODO(davmre): not yet supported. pylint: disable=unused-argument
num_transitions_per_observation=1,
trace_fn=_default_trace_fn,
trace_criterion_fn=_always_trace,
Expand Down Expand Up @@ -448,6 +448,7 @@ def particle_filter(observations,
Filtering without Modifying the Forward Pass. _arXiv preprint
arXiv:2106.10314_, 2021. https://arxiv.org/abs/2106.10314
"""

init_seed, loop_seed = samplers.split_seed(seed, salt='particle_filter')
with tf.name_scope(name or 'particle_filter'):
num_observation_steps = ps.size0(tf.nest.flatten(observations)[0])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ def testMarginalLikelihoodGradientIsDefined(self):
WeightedParticles(
particles=samplers.normal([num_particles], seed=seeds[0]),
log_weights=tf.fill([num_particles],
-tf.math.log(float(num_particles)))
))
-tf.math.log(float(num_particles)))))

def propose_and_update_log_weights_fn(_,
weighted_particles,
Expand All @@ -111,8 +110,7 @@ def propose_and_update_log_weights_fn(_,
particles=proposed_particles,
log_weights=(weighted_particles.log_weights +
transition_dist.log_prob(proposed_particles) -
proposal_dist.log_prob(proposed_particles))
)
proposal_dist.log_prob(proposed_particles)))

def marginal_logprob(transition_scale):
kernel = SequentialMonteCarlo(
Expand Down

0 comments on commit 67ec97d

Please sign in to comment.