Skip to content

Commit

Permalink
pylinted
Browse files Browse the repository at this point in the history
  • Loading branch information
aleslamitz committed Nov 13, 2023
1 parent ca98cd5 commit 62b7701
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 94 deletions.
177 changes: 92 additions & 85 deletions tensorflow_probability/python/experimental/mcmc/particle_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.distributions import batch_reshape
from tensorflow_probability.python.distributions import batch_broadcast
from tensorflow_probability.python.experimental.mcmc import sequential_monte_carlo_kernel as smc_kernel
from tensorflow_probability.python.experimental.mcmc import weighted_resampling
from tensorflow_probability.python.internal import assert_util
Expand Down Expand Up @@ -290,92 +288,94 @@ def sequential_monte_carlo(loop_seed,
never_trace=lambda *_: False,
):

"""Samples a series of particles representing filtered latent states.
The particle filter samples from the sequence of "filtering" distributions
`p(state[t] | observations[:t])` over latent
states: at each point in time, this is a sample from the distribution conditioned
on all observations *up to that time*. Because particles may be resampled, a particle
at time `t` may be different from the particle with the same index at time
`t + 1`. To reconstruct trajectories by tracing back through the resampling
process, see `tfp.mcmc.experimental.reconstruct_trajectories`.
${particle_filter_arg_str}
trace_fn: Python `callable` defining the values to be traced at each step,
with signature `traced_values = trace_fn(weighted_particles, results)`
in which the first argument is an instance of
`tfp.experimental.mcmc.WeightedParticles` and the second an instance of
`SequentialMonteCarloResults` tuple, and the return value is a structure
of `Tensor`s.
Default value: `lambda s, r: (s.particles, s.log_weights,
r.parent_indices, r.incremental_log_marginal_likelihood)`
trace_criterion_fn: optional Python `callable` with signature
`trace_this_step = trace_criterion_fn(weighted_particles, results)` taking
the same arguments as `trace_fn` and returning a boolean `Tensor`. If
`None`, only values from the final step are returned.
Default value: `lambda *_: True` (trace every step).
static_trace_allocation_size: Optional Python `int` size of trace to
allocate statically. This should be an upper bound on the number of steps
traced and is used only when the length cannot be
statically inferred (for example, if a `trace_criterion_fn` is specified).
It is primarily intended for contexts where static shapes are required,
such as in XLA-compiled code.
Default value: `None`.
parallel_iterations: Passed to the internal `tf.while_loop`.
Default value: `1`.
seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
name: Python `str` name for ops created by this method.
Default value: `None` (i.e., `'particle_filter'`).
Returns:
traced_results: A structure of Tensors as returned by `trace_fn`. If
`trace_criterion_fn==None`, this is computed from the final step;
otherwise, each Tensor will have initial dimension `num_steps_traced`
and stacks the traced results across all steps.
#### References
[1] Adam Scibior, Vaden Masrani, and Frank Wood. Differentiable Particle
Filtering without Modifying the Forward Pass. _arXiv preprint
arXiv:2106.10314_, 2021. https://arxiv.org/abs/2106.10314
"""
kernel = smc_kernel.SequentialMonteCarlo(
propose_and_update_log_weights_fn=propose_and_update_log_weights_fn,
resample_fn=resample_fn,
resample_criterion_fn=resample_criterion_fn,
particles_dim=particles_dim,
unbiased_gradients=unbiased_gradients
)
"""Samples a series of particles representing filtered latent states.
The particle filter samples from the sequence of "filtering" distributions
`p(state[t] | observations[:t])` over latent
states: at each point in time, this is a sample from the distribution
conditioned on all observations *up to that time*. Because particles may be
resampled, a particle at time `t` may be different from the particle with
the same index at time `t + 1`. To reconstruct trajectories by tracing back
through the resampling process,
see `tfp.mcmc.experimental.reconstruct_trajectories`.
${particle_filter_arg_str}
trace_fn: Python `callable` defining the values to be traced at each step,
with signature `traced_values = trace_fn(weighted_particles, results)`
in which the first argument is an instance of
`tfp.experimental.mcmc.WeightedParticles` and the second an instance of
`SequentialMonteCarloResults` tuple, and the return value is a structure
of `Tensor`s.
Default value: `lambda s, r: (s.particles, s.log_weights,
r.parent_indices, r.incremental_log_marginal_likelihood)`
trace_criterion_fn: optional Python `callable` with signature
`trace_this_step = trace_criterion_fn(weighted_particles, results)`
taking the same arguments as `trace_fn` and returning a boolean `Tensor`.
If `None`, only values from the final step are returned.
Default value: `lambda *_: True` (trace every step).
static_trace_allocation_size: Optional Python `int` size of trace to
allocate statically. This should be an upper bound on the number of steps
traced and is used only when the length cannot be
statically inferred (for example, if a `trace_criterion_fn` is
specified).
It is primarily intended for contexts where static shapes are required,
such as in XLA-compiled code.
Default value: `None`.
parallel_iterations: Passed to the internal `tf.while_loop`.
Default value: `1`.
seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
name: Python `str` name for ops created by this method.
Default value: `None` (i.e., `'particle_filter'`).
Returns:
traced_results: A structure of Tensors as returned by `trace_fn`. If
`trace_criterion_fn==None`, this is computed from the final step;
otherwise, each Tensor will have initial dimension `num_steps_traced`
and stacks the traced results across all steps.
# Use `trace_scan` rather than `sample_chain` directly because the latter
# would force us to trace the state history (with or without thinning),
# which is not always appropriate.
def seeded_one_step(seed_state_results, _):
#### References
seed, state, results = seed_state_results
[1] Adam Scibior, Vaden Masrani, and Frank Wood. Differentiable Particle
Filtering without Modifying the Forward Pass. _arXiv preprint
arXiv:2106.10314_, 2021. https://arxiv.org/abs/2106.10314
"""
kernel = smc_kernel.SequentialMonteCarlo(
propose_and_update_log_weights_fn=propose_and_update_log_weights_fn,
resample_fn=resample_fn,
resample_criterion_fn=resample_criterion_fn,
particles_dim=particles_dim,
unbiased_gradients=unbiased_gradients
)

one_step_seed, next_seed = samplers.split_seed(seed)
# Use `trace_scan` rather than `sample_chain` directly because the latter
# would force us to trace the state history (with or without thinning),
# which is not always appropriate.
def seeded_one_step(seed_state_results, _):

next_state, next_results = kernel.one_step(
state, results, seed=one_step_seed)
seed, state, results = seed_state_results

return next_seed, next_state, next_results
one_step_seed, next_seed = samplers.split_seed(seed)

final_seed_state_result, traced_results = loop_util.trace_scan(
loop_fn=seeded_one_step,
initial_state=(loop_seed,
initial_weighted_particles,
kernel.bootstrap_results(initial_weighted_particles)),
elems=tf.ones([num_timesteps]),
trace_fn=lambda seed_state_results: trace_fn(*seed_state_results[1:]),
trace_criterion_fn=(
lambda seed_state_results: trace_criterion_fn( # pylint: disable=g-long-lambda
*seed_state_results[1:])),
static_trace_allocation_size=static_trace_allocation_size,
parallel_iterations=parallel_iterations)
next_state, next_results = kernel.one_step(
state, results, seed=one_step_seed)

return next_seed, next_state, next_results

final_seed_state_result, traced_results = loop_util.trace_scan(
loop_fn=seeded_one_step,
initial_state=(loop_seed,
initial_weighted_particles,
kernel.bootstrap_results(initial_weighted_particles)),
elems=tf.ones([num_timesteps]),
trace_fn=lambda seed_state_results: trace_fn(*seed_state_results[1:]),
trace_criterion_fn=(
lambda seed_state_results: trace_criterion_fn( # pylint: disable=g-long-lambda
*seed_state_results[1:])),
static_trace_allocation_size=static_trace_allocation_size,
parallel_iterations=parallel_iterations)

if trace_criterion_fn is never_trace:
# Return results from just the final step.
traced_results = trace_fn(*final_seed_state_result[1:])
if trace_criterion_fn is never_trace:
# Return results from just the final step.
traced_results = trace_fn(*final_seed_state_result[1:])

return traced_results

Expand Down Expand Up @@ -508,23 +508,30 @@ def _particle_filter_initial_weighted_particles(observations,
if initial_state_proposal is None:
if particles_dim == 0:
initial_state = initial_state_prior.sample(num_particles, seed=seed)
initial_log_weights = ps.zeros_like(initial_state_prior.log_prob(initial_state))
initial_log_weights = ps.zeros_like(
initial_state_prior.log_prob(initial_state)
)
else:
particles_draw = initial_state_prior.sample(num_particles)
initial_state = tf.nest.map_structure(
lambda x: dist_util.move_dimension(x, source_idx=0, dest_idx=particles_dim),
lambda x: dist_util.move_dimension(x,
source_idx=0,
dest_idx=particles_dim),
particles_draw
)
initial_log_weights = ps.zeros_like(
dist_util.move_dimension(initial_state_prior.log_prob(particles_draw), source_idx=0, dest_idx=particles_dim)
dist_util.move_dimension(initial_state_prior.log_prob(particles_draw),
source_idx=0,
dest_idx=particles_dim)
)
else:
initial_state = initial_state_proposal.sample(num_particles, seed=seed)
initial_log_weights = (initial_state_prior.log_prob(initial_state) -
initial_state_proposal.log_prob(initial_state))
# Normalize the initial weights. If we used a proposal, the weights are
# normalized in expectation, but actually normalizing them reduces variance.
initial_log_weights = tf.nn.log_softmax(initial_log_weights, axis=particles_dim)
initial_log_weights = tf.nn.log_softmax(initial_log_weights,
axis=particles_dim)

# Return particles weighted by the initial observation.
return smc_kernel.WeightedParticles(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.bijectors import shift
from tensorflow_probability.python.distributions import bernoulli
from tensorflow_probability.python.distributions import categorical
from tensorflow_probability.python.distributions import deterministic
from tensorflow_probability.python.distributions import hidden_markov_model
from tensorflow_probability.python.distributions import joint_distribution_auto_batched as jdab
from tensorflow_probability.python.distributions import joint_distribution_named as jdn
from tensorflow_probability.python.distributions import linear_gaussian_ssm as lgssm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def ess_below_threshold(weighted_particles, particles_dim=0, threshold=0.5):
"""Determines if the effective sample size is much less than num_particles."""
with tf.name_scope('ess_below_threshold'):
num_particles = ps.size0(weighted_particles.log_weights)
log_ess = log_ess_from_log_weights(weighted_particles.log_weights, particles_dim)
log_ess = log_ess_from_log_weights(weighted_particles.log_weights,
particles_dim)
return tf.expand_dims(log_ess < (ps.log(num_particles) +
ps.log(threshold)), axis=particles_dim)

Expand Down Expand Up @@ -273,7 +274,8 @@ def one_step(self, state, kernel_results, seed=None):
state = tf.nest.map_structure(
lambda a, b: tf.where(is_initial_step, a, b), state, proposed_state)

normalized_log_weights = tf.nn.log_softmax(state.log_weights, axis=self._particles_dim)
normalized_log_weights = tf.nn.log_softmax(state.log_weights,
axis=self._particles_dim)
# Every entry of `log_weights` differs from `normalized_log_weights`
# by the same normalizing constant. We extract that constant by
# examining an arbitrary entry.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,12 @@ def resample(particles, log_weights, resample_fn, target_log_weights=None,
# Normalize the weights and sample the ancestral indices.
log_probs = tf.math.log_softmax(log_weights, axis=particles_dim)
if particles_dim == 0:
# For resample functions that don't yet support the particles_dim argument.
resampled_indices = resample_fn(log_probs, num_particles, (), seed=seed)
# For resample functions that don't yet support the
# particles_dim argument.
resampled_indices = resample_fn(log_probs, num_particles, (), seed=seed)
else:
resampled_indices = resample_fn(log_probs, num_particles, (),
particles_dim=particles_dim, seed=seed)
resampled_indices = resample_fn(log_probs, num_particles, (),
particles_dim=particles_dim, seed=seed)

gather_ancestors = lambda x: ( # pylint: disable=g-long-lambda
mcmc_util.index_remapping_gather(x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,6 @@ def resample_with_target_distribution(self):
30., atol=1.)

def test_with_target_distribution_dim_one(self):
particles = np.linspace(0., 500., num=2500, dtype=np.float32)
stacked_particles = np.stack([
np.linspace(0., 500., num=2500, dtype=np.float32),
np.linspace(0.17, 433., num=2500, dtype=np.float32),
Expand Down

0 comments on commit 62b7701

Please sign in to comment.