Skip to content

Commit

Permalink
all flakes
Browse files Browse the repository at this point in the history
  • Loading branch information
aleslamitz committed Dec 4, 2023
1 parent 85ed8e6 commit 774e5b1
Showing 1 changed file with 149 additions and 139 deletions.
288 changes: 149 additions & 139 deletions tensorflow_probability/python/experimental/mcmc/particle_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,177 +628,187 @@ def _outer_particle_filter_propose_and_update_log_weights_fn(
num_outer_particles,
extra_fn
):
"""Build a function specifying a particle filter update step."""
def _outer_propose_and_update_log_weights_fn(step, state, seed=None):
outside_parameters = state.particles[0]
inner_weighted_particles, log_weights = state.particles[1], state.log_weights
"""Build a function specifying a particle filter update step."""
def _outer_propose_and_update_log_weights_fn(step, state, seed=None):
outside_parameters = state.particles[0]
inner_weighted_particles, log_weights = state.particles[1], state.log_weights

filter_results = smc_kernel.SequentialMonteCarloResults(
filter_results = smc_kernel.SequentialMonteCarloResults(
steps=step,
parent_indices=state.particles[2],
incremental_log_marginal_likelihood=state.particles[3],
accumulated_log_marginal_likelihood=state.particles[4],
seed=state.extra[1])

inner_propose_and_update_log_weights_fn = (
inner_propose_and_update_log_weights_fn = (
_particle_filter_propose_and_update_log_weights_fn(
observations=inner_observations,
transition_fn=inner_transition_fn(outside_parameters),
proposal_fn=(inner_proposal_fn(outside_parameters)
if inner_proposal_fn is not None else None),
observation_fn=inner_observation_fn(outside_parameters),
particles_dim=1,
num_transitions_per_observation=num_transitions_per_observation,
extra_fn=extra_fn
)
)

kernel = smc_kernel.SequentialMonteCarlo(
propose_and_update_log_weights_fn=inner_propose_and_update_log_weights_fn,
resample_fn=inner_resample_fn,
resample_criterion_fn=inner_resample_criterion_fn,
particles_dim=1,
unbiased_gradients=unbiased_gradients
)

inner_weighted_particles, filter_results = kernel.one_step(inner_weighted_particles,
filter_results,
seed=seed)

updated_log_weights = log_weights + filter_results.incremental_log_marginal_likelihood

do_rejuvenation = outer_rejuvenation_criterion_fn(step, state)

def rejuvenate_particles(outside_parameters, updated_log_weights, inner_weighted_particles, filter_results):
proposed_parameters = parameter_proposal_kernel(outside_parameters).sample(seed=seed)

rej_params_log_weights = ps.zeros_like(
initial_parameter_prior.log_prob(proposed_parameters)
)
rej_params_log_weights = tf.nn.log_softmax(rej_params_log_weights, axis=0)

rej_inner_weighted_particles = _particle_filter_initial_weighted_particles(
observations=inner_observations,
observation_fn=inner_observation_fn(proposed_parameters),
initial_state_prior=inner_initial_state_prior(0, proposed_parameters),
initial_state_proposal=(inner_initial_state_proposal(0, proposed_parameters)
if inner_initial_state_proposal is not None else None),
num_particles=num_inner_particles,
particles_dim=1,
seed=seed)

batch_zeros = tf.zeros(ps.shape(log_weights))

rej_filter_results = smc_kernel.SequentialMonteCarloResults(
steps=tf.constant(0, dtype=tf.int32),
parent_indices=smc_kernel._dummy_indices_like(
rej_inner_weighted_particles.log_weights
),
incremental_log_marginal_likelihood=batch_zeros,
accumulated_log_marginal_likelihood=batch_zeros,
seed=samplers.zeros_seed())

rej_inner_particles_weights = rej_inner_weighted_particles.log_weights

rej_inner_propose_and_update_log_weights_fn = (
_particle_filter_propose_and_update_log_weights_fn(
observations=inner_observations,
transition_fn=inner_transition_fn(outside_parameters),
proposal_fn=(inner_proposal_fn(outside_parameters)
transition_fn=inner_transition_fn(proposed_parameters),
proposal_fn=(inner_proposal_fn(proposed_parameters)
if inner_proposal_fn is not None else None),
observation_fn=inner_observation_fn(outside_parameters),
observation_fn=inner_observation_fn(proposed_parameters),
extra_fn=extra_fn,
particles_dim=1,
num_transitions_per_observation=num_transitions_per_observation,
extra_fn=extra_fn
)
num_transitions_per_observation=num_transitions_per_observation)
)

kernel = smc_kernel.SequentialMonteCarlo(
propose_and_update_log_weights_fn=inner_propose_and_update_log_weights_fn,
rej_kernel = smc_kernel.SequentialMonteCarlo(
propose_and_update_log_weights_fn=rej_inner_propose_and_update_log_weights_fn,
resample_fn=inner_resample_fn,
resample_criterion_fn=inner_resample_criterion_fn,
particles_dim=1,
unbiased_gradients=unbiased_gradients)

inner_weighted_particles, filter_results = kernel.one_step(inner_weighted_particles,
filter_results,
seed=seed)

updated_log_weights = log_weights + filter_results.incremental_log_marginal_likelihood

do_rejuvenation = outer_rejuvenation_criterion_fn(step, state)

def rejuvenate_particles(outside_parameters, updated_log_weights, inner_weighted_particles, filter_results):
proposed_parameters = parameter_proposal_kernel(outside_parameters).sample(seed=seed)

rej_params_log_weights = ps.zeros_like(
initial_parameter_prior.log_prob(proposed_parameters)
def condition(i,
rej_inner_weighted_particles,
rej_filter_results,
rej_parameters_weights,
rej_params_log_weights):
return tf.less_equal(i, step)

def body(i,
rej_inner_weighted_particles,
rej_filter_results,
rej_parameters_weights,
rej_params_log_weights):

rej_inner_weighted_particles, rej_filter_results = rej_kernel.one_step(
rej_inner_weighted_particles, rej_filter_results, seed=seed
)
rej_params_log_weights = tf.nn.log_softmax(rej_params_log_weights, axis=0)

rej_inner_weighted_particles = _particle_filter_initial_weighted_particles(
observations=inner_observations,
observation_fn=inner_observation_fn(proposed_parameters),
initial_state_prior=inner_initial_state_prior(0, proposed_parameters),
initial_state_proposal=(inner_initial_state_proposal(0, proposed_parameters)
if inner_initial_state_proposal is not None else None),
num_particles=num_inner_particles,
particles_dim=1,
seed=seed)

batch_zeros = tf.zeros(ps.shape(log_weights))

rej_filter_results = smc_kernel.SequentialMonteCarloResults(
steps=tf.constant(0, dtype=tf.int32),
parent_indices=smc_kernel._dummy_indices_like(rej_inner_weighted_particles.log_weights),
incremental_log_marginal_likelihood=batch_zeros,
accumulated_log_marginal_likelihood=batch_zeros,
seed=samplers.zeros_seed())

rej_inner_particles_weights = rej_inner_weighted_particles.log_weights

rej_inner_propose_and_update_log_weights_fn = (
_particle_filter_propose_and_update_log_weights_fn(
observations=inner_observations,
transition_fn=inner_transition_fn(proposed_parameters),
proposal_fn=(inner_proposal_fn(proposed_parameters)
if inner_proposal_fn is not None else None),
observation_fn=inner_observation_fn(proposed_parameters),
extra_fn=extra_fn,
particles_dim=1,
num_transitions_per_observation=num_transitions_per_observation))

rej_kernel = smc_kernel.SequentialMonteCarlo(
propose_and_update_log_weights_fn=rej_inner_propose_and_update_log_weights_fn,
resample_fn=inner_resample_fn,
resample_criterion_fn=inner_resample_criterion_fn,
particles_dim=1,
unbiased_gradients=unbiased_gradients)

def condition(i,
rej_inner_weighted_particles,
rej_filter_results,
rej_parameters_weights,
rej_params_log_weights):
return tf.less_equal(i, step)

def body(i,
rej_inner_weighted_particles,
rej_filter_results,
rej_parameters_weights,
rej_params_log_weights):

rej_inner_weighted_particles, rej_filter_results = rej_kernel.one_step(
rej_inner_weighted_particles, rej_filter_results, seed=seed
)

rej_parameters_weights += rej_inner_weighted_particles.log_weights

rej_params_log_weights = rej_params_log_weights + rej_filter_results.incremental_log_marginal_likelihood
return i + 1, rej_inner_weighted_particles, rej_filter_results, rej_parameters_weights, rej_params_log_weights

i, rej_inner_weighted_particles, rej_filter_results, rej_inner_particles_weights, rej_params_log_weights = tf.while_loop(
condition,
body,
loop_vars=[0, rej_inner_weighted_particles, rej_filter_results, rej_inner_particles_weights,
rej_params_log_weights]
)
rej_parameters_weights += rej_inner_weighted_particles.log_weights

log_a = rej_filter_results.accumulated_log_marginal_likelihood - \
filter_results.accumulated_log_marginal_likelihood + \
parameter_proposal_kernel(proposed_parameters).log_prob(outside_parameters) - \
parameter_proposal_kernel(outside_parameters).log_prob(proposed_parameters)
rej_params_log_weights = rej_params_log_weights + rej_filter_results.incremental_log_marginal_likelihood
return i + 1, rej_inner_weighted_particles, rej_filter_results, rej_parameters_weights, rej_params_log_weights

acceptance_probs = tf.minimum(1., tf.exp(log_a))
i, rej_inner_weighted_particles, rej_filter_results, rej_inner_particles_weights, rej_params_log_weights = tf.while_loop(
condition,
body,
loop_vars=[0,
rej_inner_weighted_particles,
rej_filter_results,
rej_inner_particles_weights,
rej_params_log_weights
]
)

random_numbers = uniform.Uniform(0., 1.).sample(num_outer_particles, seed=seed)
log_a = rej_filter_results.accumulated_log_marginal_likelihood - \
filter_results.accumulated_log_marginal_likelihood + \
parameter_proposal_kernel(proposed_parameters).log_prob(outside_parameters) - \
parameter_proposal_kernel(outside_parameters).log_prob(proposed_parameters)

# Determine if the proposed particle should be accepted or reject
accept = random_numbers > acceptance_probs
acceptance_probs = tf.minimum(1., tf.exp(log_a))

# Update the chosen particles and filter restults based on the acceptance step
outside_parameters = tf.where(accept, outside_parameters, proposed_parameters)
updated_log_weights = tf.where(accept, updated_log_weights, rej_params_log_weights)
random_numbers = uniform.Uniform(0., 1.).sample(num_outer_particles, seed=seed)

inner_weighted_particles_particles = mcmc_util.choose(accept,
inner_weighted_particles.particles,
rej_inner_weighted_particles.particles
)
inner_weighted_particles_log_weights = mcmc_util.choose(accept,
inner_weighted_particles.log_weights,
rej_inner_weighted_particles.log_weights
)
# Determine if the proposed particle should be accepted or reject
accept = random_numbers > acceptance_probs

inner_weighted_particles = smc_kernel.WeightedParticles(
particles=inner_weighted_particles_particles,
log_weights=inner_weighted_particles_log_weights,
extra=inner_weighted_particles.extra
)
# Update the chosen particles and filter restults based on the acceptance step
outside_parameters = tf.where(accept, outside_parameters, proposed_parameters)
updated_log_weights = tf.where(accept, updated_log_weights, rej_params_log_weights)

filter_results = tf.nest.map_structure(
lambda a, b: where_fn(accept, a, b, num_outer_particles, num_inner_particles),
filter_results,
rej_filter_results
)
inner_weighted_particles_particles = mcmc_util.choose(
accept,
inner_weighted_particles.particles,
rej_inner_weighted_particles.particles
)
inner_weighted_particles_log_weights = mcmc_util.choose(
accept,
inner_weighted_particles.log_weights,
rej_inner_weighted_particles.log_weights
)

return outside_parameters, updated_log_weights, inner_weighted_particles, filter_results
inner_weighted_particles = smc_kernel.WeightedParticles(
particles=inner_weighted_particles_particles,
log_weights=inner_weighted_particles_log_weights,
extra=inner_weighted_particles.extra
)

outside_parameters, updated_log_weights, inner_weighted_particles, filter_results = tf.cond(
do_rejuvenation,
lambda: (rejuvenate_particles(outside_parameters, updated_log_weights, inner_weighted_particles, filter_results)),
lambda: (outside_parameters, updated_log_weights, inner_weighted_particles, filter_results)
filter_results = tf.nest.map_structure(
lambda a, b: where_fn(accept, a, b, num_outer_particles, num_inner_particles),
filter_results,
rej_filter_results
)

return smc_kernel.WeightedParticles(
particles=(outside_parameters,
inner_weighted_particles,
filter_results.parent_indices,
filter_results.incremental_log_marginal_likelihood,
filter_results.accumulated_log_marginal_likelihood),
log_weights=updated_log_weights,
extra=(step,
filter_results.seed))
return _outer_propose_and_update_log_weights_fn
return outside_parameters, updated_log_weights, inner_weighted_particles, filter_results

outside_parameters, updated_log_weights, inner_weighted_particles, filter_results = tf.cond(
do_rejuvenation,
lambda: (rejuvenate_particles(outside_parameters, updated_log_weights, inner_weighted_particles, filter_results)),
lambda: (outside_parameters, updated_log_weights, inner_weighted_particles, filter_results)
)

return smc_kernel.WeightedParticles(
particles=(outside_parameters,
inner_weighted_particles,
filter_results.parent_indices,
filter_results.incremental_log_marginal_likelihood,
filter_results.accumulated_log_marginal_likelihood),
log_weights=updated_log_weights,
extra=(step,
filter_results.seed))
return _outer_propose_and_update_log_weights_fn


@docstring_util.expand_docstring(
Expand Down

0 comments on commit 774e5b1

Please sign in to comment.