Skip to content

Commit

Permalink
monte carlo kernel tests fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
aleslamitz committed Nov 13, 2023
1 parent d5ddceb commit ca98cd5
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ def propose_and_update_log_weights_fn(_, weighted_particles, seed=None):
return WeightedParticles(
particles=proposed_particles,
log_weights=weighted_particles.log_weights +
normal.Normal(loc=-2.6, scale=0.1).log_prob(proposed_particles),
extra=tf.constant(np.nan)
normal.Normal(loc=-2.6, scale=0.1).log_prob(proposed_particles)
)

num_particles = 16
Expand All @@ -52,8 +51,7 @@ def propose_and_update_log_weights_fn(_, weighted_particles, seed=None):
particles=tf.random.normal([num_particles],
seed=test_util.test_seed()),
log_weights=tf.fill([num_particles],
-tf.math.log(float(num_particles))),
extra=tf.constant(np.nan)
-tf.math.log(float(num_particles)))
))

# Run a couple of steps.
Expand Down Expand Up @@ -100,8 +98,7 @@ def testMarginalLikelihoodGradientIsDefined(self):
WeightedParticles(
particles=samplers.normal([num_particles], seed=seeds[0]),
log_weights=tf.fill([num_particles],
-tf.math.log(float(num_particles))),
extra=tf.constant(np.nan)
-tf.math.log(float(num_particles)))
))

def propose_and_update_log_weights_fn(_,
Expand All @@ -116,8 +113,7 @@ def propose_and_update_log_weights_fn(_,
particles=proposed_particles,
log_weights=(weighted_particles.log_weights +
transition_dist.log_prob(proposed_particles) -
proposal_dist.log_prob(proposed_particles)),
extra=tf.constant(np.nan)
proposal_dist.log_prob(proposed_particles))
)

def marginal_logprob(transition_scale):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def test_with_target_distribution_dim_one(self):
self.assertAllClose(
tf.reduce_sum(tf.nn.softmax(new_log_weights) * new_particles, axis=1),
[30., 30.],
atol=1.)
atol=1.5)

def maybe_compiler(self, f):
if self.use_xla:
Expand Down

0 comments on commit ca98cd5

Please sign in to comment.