From 62b7701a20dfa17c4f840b1f161539ef53adb3d8 Mon Sep 17 00:00:00 2001 From: slamitza Date: Mon, 13 Nov 2023 23:54:24 +0100 Subject: [PATCH] pylinted --- .../experimental/mcmc/particle_filter.py | 177 +++++++++--------- .../experimental/mcmc/particle_filter_test.py | 2 - .../mcmc/sequential_monte_carlo_kernel.py | 6 +- .../experimental/mcmc/weighted_resampling.py | 9 +- .../mcmc/weighted_resampling_test.py | 1 - 5 files changed, 101 insertions(+), 94 deletions(-) diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py index b996ed6bec..4f18f4782d 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter.py @@ -16,8 +16,6 @@ import numpy as np import tensorflow.compat.v2 as tf -from tensorflow_probability.python.distributions import batch_reshape -from tensorflow_probability.python.distributions import batch_broadcast from tensorflow_probability.python.experimental.mcmc import sequential_monte_carlo_kernel as smc_kernel from tensorflow_probability.python.experimental.mcmc import weighted_resampling from tensorflow_probability.python.internal import assert_util @@ -290,92 +288,94 @@ def sequential_monte_carlo(loop_seed, never_trace=lambda *_: False, ): - """Samples a series of particles representing filtered latent states. - - The particle filter samples from the sequence of "filtering" distributions - `p(state[t] | observations[:t])` over latent - states: at each point in time, this is a sample from the distribution conditioned - on all observations *up to that time*. Because particles may be resampled, a particle - at time `t` may be different from the particle with the same index at time - `t + 1`. To reconstruct trajectories by tracing back through the resampling - process, see `tfp.mcmc.experimental.reconstruct_trajectories`. - - ${particle_filter_arg_str} - trace_fn: Python `callable` defining the values to be traced at each step, - with signature `traced_values = trace_fn(weighted_particles, results)` - in which the first argument is an instance of - `tfp.experimental.mcmc.WeightedParticles` and the second an instance of - `SequentialMonteCarloResults` tuple, and the return value is a structure - of `Tensor`s. - Default value: `lambda s, r: (s.particles, s.log_weights, - r.parent_indices, r.incremental_log_marginal_likelihood)` - trace_criterion_fn: optional Python `callable` with signature - `trace_this_step = trace_criterion_fn(weighted_particles, results)` taking - the same arguments as `trace_fn` and returning a boolean `Tensor`. If - `None`, only values from the final step are returned. - Default value: `lambda *_: True` (trace every step). - static_trace_allocation_size: Optional Python `int` size of trace to - allocate statically. This should be an upper bound on the number of steps - traced and is used only when the length cannot be - statically inferred (for example, if a `trace_criterion_fn` is specified). - It is primarily intended for contexts where static shapes are required, - such as in XLA-compiled code. - Default value: `None`. - parallel_iterations: Passed to the internal `tf.while_loop`. - Default value: `1`. - seed: PRNG seed; see `tfp.random.sanitize_seed` for details. - name: Python `str` name for ops created by this method. - Default value: `None` (i.e., `'particle_filter'`). - Returns: - traced_results: A structure of Tensors as returned by `trace_fn`. If - `trace_criterion_fn==None`, this is computed from the final step; - otherwise, each Tensor will have initial dimension `num_steps_traced` - and stacks the traced results across all steps. - - #### References - - [1] Adam Scibior, Vaden Masrani, and Frank Wood. Differentiable Particle - Filtering without Modifying the Forward Pass. _arXiv preprint - arXiv:2106.10314_, 2021. https://arxiv.org/abs/2106.10314 - """ - kernel = smc_kernel.SequentialMonteCarlo( - propose_and_update_log_weights_fn=propose_and_update_log_weights_fn, - resample_fn=resample_fn, - resample_criterion_fn=resample_criterion_fn, - particles_dim=particles_dim, - unbiased_gradients=unbiased_gradients - ) + """Samples a series of particles representing filtered latent states. + + The particle filter samples from the sequence of "filtering" distributions + `p(state[t] | observations[:t])` over latent + states: at each point in time, this is a sample from the distribution + conditioned on all observations *up to that time*. Because particles may be + resampled, a particle at time `t` may be different from the particle with + the same index at time `t + 1`. To reconstruct trajectories by tracing back + through the resampling process, + see `tfp.mcmc.experimental.reconstruct_trajectories`. + + ${particle_filter_arg_str} + trace_fn: Python `callable` defining the values to be traced at each step, + with signature `traced_values = trace_fn(weighted_particles, results)` + in which the first argument is an instance of + `tfp.experimental.mcmc.WeightedParticles` and the second an instance of + `SequentialMonteCarloResults` tuple, and the return value is a structure + of `Tensor`s. + Default value: `lambda s, r: (s.particles, s.log_weights, + r.parent_indices, r.incremental_log_marginal_likelihood)` + trace_criterion_fn: optional Python `callable` with signature + `trace_this_step = trace_criterion_fn(weighted_particles, results)` + taking the same arguments as `trace_fn` and returning a boolean `Tensor`. + If `None`, only values from the final step are returned. + Default value: `lambda *_: True` (trace every step). + static_trace_allocation_size: Optional Python `int` size of trace to + allocate statically. This should be an upper bound on the number of steps + traced and is used only when the length cannot be + statically inferred (for example, if a `trace_criterion_fn` is + specified). + It is primarily intended for contexts where static shapes are required, + such as in XLA-compiled code. + Default value: `None`. + parallel_iterations: Passed to the internal `tf.while_loop`. + Default value: `1`. + seed: PRNG seed; see `tfp.random.sanitize_seed` for details. + name: Python `str` name for ops created by this method. + Default value: `None` (i.e., `'particle_filter'`). + Returns: + traced_results: A structure of Tensors as returned by `trace_fn`. If + `trace_criterion_fn==None`, this is computed from the final step; + otherwise, each Tensor will have initial dimension `num_steps_traced` + and stacks the traced results across all steps. - # Use `trace_scan` rather than `sample_chain` directly because the latter - # would force us to trace the state history (with or without thinning), - # which is not always appropriate. - def seeded_one_step(seed_state_results, _): + #### References - seed, state, results = seed_state_results + [1] Adam Scibior, Vaden Masrani, and Frank Wood. Differentiable Particle + Filtering without Modifying the Forward Pass. _arXiv preprint + arXiv:2106.10314_, 2021. https://arxiv.org/abs/2106.10314 + """ + kernel = smc_kernel.SequentialMonteCarlo( + propose_and_update_log_weights_fn=propose_and_update_log_weights_fn, + resample_fn=resample_fn, + resample_criterion_fn=resample_criterion_fn, + particles_dim=particles_dim, + unbiased_gradients=unbiased_gradients + ) - one_step_seed, next_seed = samplers.split_seed(seed) + # Use `trace_scan` rather than `sample_chain` directly because the latter + # would force us to trace the state history (with or without thinning), + # which is not always appropriate. + def seeded_one_step(seed_state_results, _): - next_state, next_results = kernel.one_step( - state, results, seed=one_step_seed) + seed, state, results = seed_state_results - return next_seed, next_state, next_results + one_step_seed, next_seed = samplers.split_seed(seed) - final_seed_state_result, traced_results = loop_util.trace_scan( - loop_fn=seeded_one_step, - initial_state=(loop_seed, - initial_weighted_particles, - kernel.bootstrap_results(initial_weighted_particles)), - elems=tf.ones([num_timesteps]), - trace_fn=lambda seed_state_results: trace_fn(*seed_state_results[1:]), - trace_criterion_fn=( - lambda seed_state_results: trace_criterion_fn( # pylint: disable=g-long-lambda - *seed_state_results[1:])), - static_trace_allocation_size=static_trace_allocation_size, - parallel_iterations=parallel_iterations) + next_state, next_results = kernel.one_step( + state, results, seed=one_step_seed) + + return next_seed, next_state, next_results + + final_seed_state_result, traced_results = loop_util.trace_scan( + loop_fn=seeded_one_step, + initial_state=(loop_seed, + initial_weighted_particles, + kernel.bootstrap_results(initial_weighted_particles)), + elems=tf.ones([num_timesteps]), + trace_fn=lambda seed_state_results: trace_fn(*seed_state_results[1:]), + trace_criterion_fn=( + lambda seed_state_results: trace_criterion_fn( # pylint: disable=g-long-lambda + *seed_state_results[1:])), + static_trace_allocation_size=static_trace_allocation_size, + parallel_iterations=parallel_iterations) - if trace_criterion_fn is never_trace: - # Return results from just the final step. - traced_results = trace_fn(*final_seed_state_result[1:]) + if trace_criterion_fn is never_trace: + # Return results from just the final step. + traced_results = trace_fn(*final_seed_state_result[1:]) return traced_results @@ -508,15 +508,21 @@ def _particle_filter_initial_weighted_particles(observations, if initial_state_proposal is None: if particles_dim == 0: initial_state = initial_state_prior.sample(num_particles, seed=seed) - initial_log_weights = ps.zeros_like(initial_state_prior.log_prob(initial_state)) + initial_log_weights = ps.zeros_like( + initial_state_prior.log_prob(initial_state) + ) else: particles_draw = initial_state_prior.sample(num_particles) initial_state = tf.nest.map_structure( - lambda x: dist_util.move_dimension(x, source_idx=0, dest_idx=particles_dim), + lambda x: dist_util.move_dimension(x, + source_idx=0, + dest_idx=particles_dim), particles_draw ) initial_log_weights = ps.zeros_like( - dist_util.move_dimension(initial_state_prior.log_prob(particles_draw), source_idx=0, dest_idx=particles_dim) + dist_util.move_dimension(initial_state_prior.log_prob(particles_draw), + source_idx=0, + dest_idx=particles_dim) ) else: initial_state = initial_state_proposal.sample(num_particles, seed=seed) @@ -524,7 +530,8 @@ def _particle_filter_initial_weighted_particles(observations, initial_state_proposal.log_prob(initial_state)) # Normalize the initial weights. If we used a proposal, the weights are # normalized in expectation, but actually normalizing them reduces variance. - initial_log_weights = tf.nn.log_softmax(initial_log_weights, axis=particles_dim) + initial_log_weights = tf.nn.log_softmax(initial_log_weights, + axis=particles_dim) # Return particles weighted by the initial observation. return smc_kernel.WeightedParticles( diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py index b7761c3895..5e0628dd02 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py @@ -20,9 +20,7 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import shift from tensorflow_probability.python.distributions import bernoulli -from tensorflow_probability.python.distributions import categorical from tensorflow_probability.python.distributions import deterministic -from tensorflow_probability.python.distributions import hidden_markov_model from tensorflow_probability.python.distributions import joint_distribution_auto_batched as jdab from tensorflow_probability.python.distributions import joint_distribution_named as jdn from tensorflow_probability.python.distributions import linear_gaussian_ssm as lgssm diff --git a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py index dc4d294b7e..7a5f5be045 100644 --- a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py +++ b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py @@ -127,7 +127,8 @@ def ess_below_threshold(weighted_particles, particles_dim=0, threshold=0.5): """Determines if the effective sample size is much less than num_particles.""" with tf.name_scope('ess_below_threshold'): num_particles = ps.size0(weighted_particles.log_weights) - log_ess = log_ess_from_log_weights(weighted_particles.log_weights, particles_dim) + log_ess = log_ess_from_log_weights(weighted_particles.log_weights, + particles_dim) return tf.expand_dims(log_ess < (ps.log(num_particles) + ps.log(threshold)), axis=particles_dim) @@ -273,7 +274,8 @@ def one_step(self, state, kernel_results, seed=None): state = tf.nest.map_structure( lambda a, b: tf.where(is_initial_step, a, b), state, proposed_state) - normalized_log_weights = tf.nn.log_softmax(state.log_weights, axis=self._particles_dim) + normalized_log_weights = tf.nn.log_softmax(state.log_weights, + axis=self._particles_dim) # Every entry of `log_weights` differs from `normalized_log_weights` # by the same normalizing constant. We extract that constant by # examining an arbitrary entry. diff --git a/tensorflow_probability/python/experimental/mcmc/weighted_resampling.py b/tensorflow_probability/python/experimental/mcmc/weighted_resampling.py index aa0bfe92e0..613419cf46 100644 --- a/tensorflow_probability/python/experimental/mcmc/weighted_resampling.py +++ b/tensorflow_probability/python/experimental/mcmc/weighted_resampling.py @@ -80,11 +80,12 @@ def resample(particles, log_weights, resample_fn, target_log_weights=None, # Normalize the weights and sample the ancestral indices. log_probs = tf.math.log_softmax(log_weights, axis=particles_dim) if particles_dim == 0: - # For resample functions that don't yet support the particles_dim argument. - resampled_indices = resample_fn(log_probs, num_particles, (), seed=seed) + # For resample functions that don't yet support the + # particles_dim argument. + resampled_indices = resample_fn(log_probs, num_particles, (), seed=seed) else: - resampled_indices = resample_fn(log_probs, num_particles, (), - particles_dim=particles_dim, seed=seed) + resampled_indices = resample_fn(log_probs, num_particles, (), + particles_dim=particles_dim, seed=seed) gather_ancestors = lambda x: ( # pylint: disable=g-long-lambda mcmc_util.index_remapping_gather(x, diff --git a/tensorflow_probability/python/experimental/mcmc/weighted_resampling_test.py b/tensorflow_probability/python/experimental/mcmc/weighted_resampling_test.py index 00faa676db..ace87de1e1 100644 --- a/tensorflow_probability/python/experimental/mcmc/weighted_resampling_test.py +++ b/tensorflow_probability/python/experimental/mcmc/weighted_resampling_test.py @@ -300,7 +300,6 @@ def resample_with_target_distribution(self): 30., atol=1.) def test_with_target_distribution_dim_one(self): - particles = np.linspace(0., 500., num=2500, dtype=np.float32) stacked_particles = np.stack([ np.linspace(0., 500., num=2500, dtype=np.float32), np.linspace(0.17, 433., num=2500, dtype=np.float32),