diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py index 83e7b99f77..1bcbc870f4 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter.py @@ -19,6 +19,7 @@ from tensorflow_probability.python.experimental.mcmc import sequential_monte_carlo_kernel as smc_kernel from tensorflow_probability.python.experimental.mcmc import weighted_resampling from tensorflow_probability.python.internal import assert_util +from tensorflow_probability.python.internal import distribution_util as dist_util from tensorflow_probability.python.internal import docstring_util from tensorflow_probability.python.internal import loop_util from tensorflow_probability.python.internal import prefer_static as ps @@ -112,6 +113,9 @@ def _default_trace_fn(state, kernel_results): approximate continuous-time dynamics. The initial and final steps (steps `0` and `num_timesteps - 1`) are always observed. Default value: `None`. + particles_dim: `int` dimension that indexes the particles in the state of + this particle filter. + Default value: `0`. """ @@ -127,8 +131,8 @@ def infer_trajectories(observations, resample_fn=weighted_resampling.resample_systematic, resample_criterion_fn=smc_kernel.ess_below_threshold, unbiased_gradients=True, - rejuvenation_kernel_fn=None, num_transitions_per_observation=1, + particles_dim=0, seed=None, name=None): # pylint: disable=g-doc-args """Use particle filtering to sample from the posterior over trajectories. @@ -247,32 +251,190 @@ def observation_fn(_, state): observation_fn=observation_fn, num_particles=num_particles, initial_state_proposal=initial_state_proposal, + particles_dim=particles_dim, proposal_fn=proposal_fn, resample_fn=resample_fn, resample_criterion_fn=resample_criterion_fn, unbiased_gradients=unbiased_gradients, - rejuvenation_kernel_fn=rejuvenation_kernel_fn, num_transitions_per_observation=num_transitions_per_observation, trace_fn=_default_trace_fn, trace_criterion_fn=lambda *_: True, seed=pf_seed, name=name) - weighted_trajectories = reconstruct_trajectories(particles, parent_indices) + weighted_trajectories = reconstruct_trajectories( + particles, + parent_indices, + particles_dim=particles_dim) # Resample all steps of the trajectories using the final weights. resample_indices = resample_fn(log_probs=log_weights[-1], event_size=num_particles, + particles_dim=particles_dim, sample_shape=(), seed=resample_seed) trajectories = tf.nest.map_structure( lambda x: mcmc_util.index_remapping_gather(x, # pylint: disable=g-long-lambda resample_indices, - axis=1), + axis=particles_dim + 1, + indices_axis=particles_dim), weighted_trajectories) return trajectories, incremental_log_marginal_likelihoods +def sequential_monte_carlo( + initial_weighted_particles, + propose_and_update_log_weights_fn, + num_steps, + resample_fn, + resample_criterion_fn, + trace_fn, + trace_criterion_fn, + particles_dim=0, + parallel_iterations=1, + unbiased_gradients=True, + static_trace_allocation_size=None, + seed=None, + name=None, +): + """Run Sequential Monte Carlo. + + Sequential Monte Carlo maintains a population of weighted particles + representing samples from a sequence of target distributions. + + Args: + initial_weighted_particles: The initial + `tfp.experimental.mcmc.WeightedParticles`. 
+    propose_and_update_log_weights_fn: Python `callable` with signature
+      `new_weighted_particles = propose_and_update_log_weights_fn(step,
+      weighted_particles, seed=None)`. Its input is a
+      `tfp.experimental.mcmc.WeightedParticles` structure representing
+      weighted samples (with normalized weights) from the `step`th
+      target distribution, and it returns another such structure representing
+      unnormalized weighted samples from the next (`step + 1`th) target
+      distribution. This will typically include particles
+      sampled from a proposal distribution `q(x[step + 1] | x[step])`, and
+      weights that account for some or all of: the proposal density,
+      a transition density `p(x[step + 1] | x[step])`,
+      observation weights `p(y[step + 1] | x[step + 1])`, and/or a backwards
+      or 'L'-kernel `L(x[step] | x[step + 1])`. The (log) normalization
+      constant of the weights is interpreted as the incremental (log) marginal
+      likelihood.
+    num_steps: Number of steps to run Sequential Monte Carlo.
+    resample_fn: Resampling scheme specified as a `callable` with signature
+      `indices = resample_fn(log_probs, event_size, sample_shape, seed)`,
+      where `log_probs` is a `Tensor` of the same shape as `state.log_weights`
+      containing a normalized log-probability for every current
+      particle, `event_size` is the number of new particle indices to
+      generate, `sample_shape` is the number of independent index sets to
+      return, and the return value `indices` is an `int` Tensor of shape
+      `concat([sample_shape, [event_size, B1, ..., BN]])`. Typically one of
+      `tfp.experimental.mcmc.resample_deterministic_minimum_error`,
+      `tfp.experimental.mcmc.resample_independent`,
+      `tfp.experimental.mcmc.resample_stratified`, or
+      `tfp.experimental.mcmc.resample_systematic`.
+      Default value: `tfp.experimental.mcmc.resample_systematic`.
+    resample_criterion_fn: optional Python `callable` with signature
+      `do_resample = resample_criterion_fn(weighted_particles)`,
+      passed an instance of `tfp.experimental.mcmc.WeightedParticles`. The
+      return value `do_resample` determines whether particles are resampled
+      at the current step. The default behavior is to resample particles
+      when the effective sample size falls below half of the total number
+      of particles.
+      Default value: `tfp.experimental.mcmc.ess_below_threshold`.
+    trace_fn: Python `callable` defining the values to be traced at each step,
+      with signature `traced_values = trace_fn(weighted_particles, results)`
+      in which the first argument is an instance of
+      `tfp.experimental.mcmc.WeightedParticles` and the second an instance of
+      `SequentialMonteCarloResults`, and the return value is a structure
+      of `Tensor`s.
+      Default value: `lambda s, r: (s.particles, s.log_weights,
+      r.parent_indices, r.incremental_log_marginal_likelihood)`
+    trace_criterion_fn: optional Python `callable` with signature
+      `trace_this_step = trace_criterion_fn(weighted_particles, results)`
+      taking the same arguments as `trace_fn` and returning a boolean `Tensor`.
+      If `None`, only values from the final step are returned.
+      Default value: `lambda *_: True` (trace every step).
+    particles_dim: `int` dimension that indexes the particles in the
+      `tfp.experimental.mcmc.WeightedParticles` structures on which this
+      function operates.
+      Default value: `0`.
+    parallel_iterations: Passed to the internal `tf.while_loop`.
+      Default value: `1`.
+    unbiased_gradients: If `True`, use the stop-gradient
+      resampling trick of Scibior, Masrani, and Wood [1] to correct for
+      gradient bias introduced by the discrete resampling step. This will
+      generally increase the variance of stochastic gradients.
+      Default value: `True`.
+    static_trace_allocation_size: Optional Python `int` size of trace to
+      allocate statically. This should be an upper bound on the number of
+      steps traced and is used only when the length cannot be statically
+      inferred (for example, if a `trace_criterion_fn` is specified).
+      It is primarily intended for contexts where static shapes are required,
+      such as in XLA-compiled code.
+      Default value: `None`.
+    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
+    name: Python `str` name for ops created by this method.
+      Default value: `None` (i.e., `'sequential_monte_carlo'`).
+  Returns:
+    traced_results: A structure of Tensors as returned by `trace_fn`. If
+      `trace_criterion_fn` is `None`, this is computed from the final step;
+      otherwise, each Tensor will have initial dimension `num_steps_traced`
+      and stacks the traced results across all steps.
+
+  #### References
+
+  [1] Adam Scibior, Vaden Masrani, and Frank Wood. Differentiable Particle
+      Filtering without Modifying the Forward Pass. _arXiv preprint
+      arXiv:2106.10314_, 2021. https://arxiv.org/abs/2106.10314
+  """
+  with tf.name_scope(name or 'sequential_monte_carlo'):
+    seed = samplers.sanitize_seed(seed)
+
+    kernel = smc_kernel.SequentialMonteCarlo(
+        propose_and_update_log_weights_fn=propose_and_update_log_weights_fn,
+        resample_fn=resample_fn,
+        resample_criterion_fn=resample_criterion_fn,
+        particles_dim=particles_dim,
+        unbiased_gradients=unbiased_gradients)
+
+    # If trace criterion is `None`, we'll return only the final results.
+    never_trace = lambda *_: False
+    if trace_criterion_fn is None:
+      static_trace_allocation_size = 0
+      trace_criterion_fn = never_trace
+
+    # Use `trace_scan` rather than `sample_chain` directly because the latter
+    # would force us to trace the state history (with or without thinning),
+    # which is not always appropriate.
+    def seeded_one_step(seed_state_results, _):
+      seed, state, results = seed_state_results
+      one_step_seed, next_seed = samplers.split_seed(seed)
+      next_state, next_results = kernel.one_step(
+          state, results, seed=one_step_seed)
+      return next_seed, next_state, next_results
+
+    final_seed_state_result, traced_results = loop_util.trace_scan(
+        loop_fn=seeded_one_step,
+        initial_state=(seed,
+                       initial_weighted_particles,
+                       kernel.bootstrap_results(initial_weighted_particles)),
+        elems=tf.ones([num_steps]),
+        trace_fn=lambda seed_state_results: trace_fn(*seed_state_results[1:]),
+        trace_criterion_fn=(
+            lambda seed_state_results: trace_criterion_fn(  # pylint: disable=g-long-lambda
+                *seed_state_results[1:])),
+        static_trace_allocation_size=static_trace_allocation_size,
+        parallel_iterations=parallel_iterations)
+
+    if trace_criterion_fn is never_trace:
+      # Return results from just the final step.
+      traced_results = trace_fn(*final_seed_state_result[1:])
+
+    return traced_results
+
+
 @docstring_util.expand_docstring(
     particle_filter_arg_str=particle_filter_arg_str.format(scibor_ref_idx=1))
 def particle_filter(observations,
@@ -287,6 +449,7 @@ def particle_filter(observations,
                     unbiased_gradients=True,
                     rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported.
pylint: disable=unused-argument num_transitions_per_observation=1, + particles_dim=0, trace_fn=_default_trace_fn, trace_criterion_fn=_always_trace, static_trace_allocation_size=None, @@ -298,10 +461,10 @@ def particle_filter(observations, The particle filter samples from the sequence of "filtering" distributions `p(state[t] | observations[:t])` over latent states: at each point in time, this is the distribution conditioned on all - observations *up to that time*. Because particles may be resampled, a particle - at time `t` may be different from the particle with the same index at time - `t + 1`. To reconstruct trajectories by tracing back through the resampling - process, see `tfp.mcmc.experimental.reconstruct_trajectories`. + observations *up to that time*. Because particles may be resampled, a + particle at time `t` may be different from the particle with the same index + at time `t + 1`. To reconstruct trajectories by tracing back through the + resampling process, see `tfp.mcmc.experimental.reconstruct_trajectories`. ${particle_filter_arg_str} trace_fn: Python `callable` defining the values to be traced at each step, @@ -348,61 +511,36 @@ def particle_filter(observations, num_timesteps = ( 1 + num_transitions_per_observation * (num_observation_steps - 1)) - # If trace criterion is `None`, we'll return only the final results. - never_trace = lambda *_: False - if trace_criterion_fn is None: - static_trace_allocation_size = 0 - trace_criterion_fn = never_trace - initial_weighted_particles = _particle_filter_initial_weighted_particles( observations=observations, observation_fn=observation_fn, initial_state_prior=initial_state_prior, initial_state_proposal=initial_state_proposal, num_particles=num_particles, + particles_dim=particles_dim, seed=init_seed) propose_and_update_log_weights_fn = ( _particle_filter_propose_and_update_log_weights_fn( observations=observations, transition_fn=transition_fn, + particles_dim=particles_dim, proposal_fn=proposal_fn, observation_fn=observation_fn, num_transitions_per_observation=num_transitions_per_observation)) - kernel = smc_kernel.SequentialMonteCarlo( + return sequential_monte_carlo( + initial_weighted_particles=initial_weighted_particles, + num_steps=num_timesteps, + parallel_iterations=parallel_iterations, + particles_dim=particles_dim, propose_and_update_log_weights_fn=propose_and_update_log_weights_fn, resample_fn=resample_fn, resample_criterion_fn=resample_criterion_fn, - unbiased_gradients=unbiased_gradients) - - # Use `trace_scan` rather than `sample_chain` directly because the latter - # would force us to trace the state history (with or without thinning), - # which is not always appropriate. 
- def seeded_one_step(seed_state_results, _): - seed, state, results = seed_state_results - one_step_seed, next_seed = samplers.split_seed(seed) - next_state, next_results = kernel.one_step( - state, results, seed=one_step_seed) - return next_seed, next_state, next_results - - final_seed_state_result, traced_results = loop_util.trace_scan( - loop_fn=seeded_one_step, - initial_state=(loop_seed, - initial_weighted_particles, - kernel.bootstrap_results(initial_weighted_particles)), - elems=tf.ones([num_timesteps]), - trace_fn=lambda seed_state_results: trace_fn(*seed_state_results[1:]), - trace_criterion_fn=( - lambda seed_state_results: trace_criterion_fn( # pylint: disable=g-long-lambda - *seed_state_results[1:])), static_trace_allocation_size=static_trace_allocation_size, - parallel_iterations=parallel_iterations) - - if trace_criterion_fn is never_trace: - # Return results from just the final step. - traced_results = trace_fn(*final_seed_state_result[1:]) - - return traced_results + trace_criterion_fn=trace_criterion_fn, + trace_fn=trace_fn, + unbiased_gradients=unbiased_gradients, + seed=loop_seed) def _particle_filter_initial_weighted_particles(observations, @@ -410,6 +548,7 @@ def _particle_filter_initial_weighted_particles(observations, initial_state_prior, initial_state_proposal, num_particles, + particles_dim=0, seed=None): """Initialize a set of weighted particles including the first observation.""" # Propose an initial state. @@ -421,9 +560,18 @@ def _particle_filter_initial_weighted_particles(observations, initial_state = initial_state_proposal.sample(num_particles, seed=seed) initial_log_weights = (initial_state_prior.log_prob(initial_state) - initial_state_proposal.log_prob(initial_state)) + if particles_dim != 0: + initial_state = tf.nest.map_structure( + lambda x: dist_util.move_dimension( + x, source_idx=0, dest_idx=particles_dim), + initial_state) + initial_log_weights = dist_util.move_dimension( + initial_log_weights, source_idx=0, dest_idx=particles_dim) + # Normalize the initial weights. If we used a proposal, the weights are # normalized in expectation, but actually normalizing them reduces variance. - initial_log_weights = tf.nn.log_softmax(initial_log_weights, axis=0) + initial_log_weights = tf.nn.log_softmax(initial_log_weights, + axis=particles_dim) # Return particles weighted by the initial observation. return smc_kernel.WeightedParticles( @@ -432,7 +580,8 @@ def _particle_filter_initial_weighted_particles(observations, step=0, particles=initial_state, observations=observations, - observation_fn=observation_fn)) + observation_fn=observation_fn, + particles_dim=particles_dim)) def _particle_filter_propose_and_update_log_weights_fn( @@ -440,7 +589,8 @@ def _particle_filter_propose_and_update_log_weights_fn( transition_fn, proposal_fn, observation_fn, - num_transitions_per_observation=1): + num_transitions_per_observation=1, + particles_dim=0): """Build a function specifying a particle filter update step.""" def propose_and_update_log_weights_fn(step, state, seed=None): particles, log_weights = state.particles, state.log_weights @@ -465,7 +615,7 @@ def propose_and_update_log_weights_fn(step, state, seed=None): # likelihood of a model with no observations is constant # (equal to 1.), so the transition and proposal distributions shouldn't # affect it. 
- log_weights = tf.nn.log_softmax(log_weights, axis=0) + log_weights = tf.nn.log_softmax(log_weights, axis=particles_dim) else: proposed_particles = transition_dist.sample(seed=seed) @@ -474,7 +624,8 @@ def propose_and_update_log_weights_fn(step, state, seed=None): particles=proposed_particles, log_weights=log_weights + _compute_observation_log_weights( step + 1, proposed_particles, observations, observation_fn, - num_transitions_per_observation=num_transitions_per_observation)) + num_transitions_per_observation=num_transitions_per_observation, + particles_dim=particles_dim)) return propose_and_update_log_weights_fn @@ -482,7 +633,8 @@ def _compute_observation_log_weights(step, particles, observations, observation_fn, - num_transitions_per_observation=1): + num_transitions_per_observation=1, + particles_dim=0): """Computes particle importance weights from an observation step. Args: @@ -502,6 +654,8 @@ def _compute_observation_log_weights(step, num_transitions_per_observation: optional int `Tensor` number of times to apply the transition model between successive observation steps. Default value: `1`. + particles_dim: `int` dimension that indexes the particles in `particles`. + Default value: `0`. Returns: log_weights: `Tensor` of shape `concat([num_particles, b1, ..., bN])`. """ @@ -516,20 +670,26 @@ def _compute_observation_log_weights(step, observation = tf.nest.map_structure( lambda x, step=step: tf.gather(x, observation_idx), observations) + observation = tf.nest.map_structure( + lambda x: tf.expand_dims(x, axis=particles_dim), observation) + log_weights = observation_fn(step, particles).log_prob(observation) return tf.where(step_has_observation, log_weights, tf.zeros_like(log_weights)) -def reconstruct_trajectories(particles, parent_indices, name=None): +def reconstruct_trajectories(particles, + parent_indices, + particles_dim=0, + name=None): """Reconstructs the ancestor trajectory that generated each final particle.""" with tf.name_scope(name or 'reconstruct_trajectories'): # Walk backwards to compute the ancestor of each final particle at time t. final_indices = smc_kernel._dummy_indices_like(parent_indices[-1]) # pylint: disable=protected-access ancestor_indices = tf.scan( fn=lambda ancestor, parent: mcmc_util.index_remapping_gather( # pylint: disable=g-long-lambda - parent, ancestor, axis=0), + parent, ancestor, axis=particles_dim, indices_axis=particles_dim), elems=parent_indices[1:], initializer=final_indices, reverse=True) @@ -537,7 +697,10 @@ def reconstruct_trajectories(particles, parent_indices, name=None): return tf.nest.map_structure( lambda part: mcmc_util.index_remapping_gather( # pylint: disable=g-long-lambda - part, ancestor_indices, axis=1, indices_axis=1), + part, + ancestor_indices, + axis=particles_dim + 1, + indices_axis=particles_dim + 1), particles) diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py index 5e0628dd02..6508eb6231 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py @@ -177,6 +177,128 @@ def observation_fn(_, state): self.assertAllEqual(incremental_log_marginal_likelihoods.shape, [num_timesteps] + batch_shape) + def test_batch_of_filters_particles_dim_1(self): + + batch_shape = [3, 2] + num_particles = 1000 + num_timesteps = 40 + + # Batch of priors on object 1D positions and velocities. 
+    initial_state_prior = jdn.JointDistributionNamed({
+        'position': normal.Normal(loc=0., scale=tf.ones(batch_shape)),
+        'velocity': normal.Normal(loc=0., scale=tf.ones(batch_shape) * 0.1)
+    })
+
+    def transition_fn(_, previous_state):
+      return jdn.JointDistributionNamed({
+          'position':
+              normal.Normal(
+                  loc=previous_state['position'] + previous_state['velocity'],
+                  scale=0.1),
+          'velocity':
+              normal.Normal(loc=previous_state['velocity'], scale=0.01)
+      })
+
+    def observation_fn(_, state):
+      return normal.Normal(loc=state['position'], scale=0.1)
+
+    # Batch of synthetic observations.
+    true_initial_positions = np.random.randn(*batch_shape).astype(self.dtype)
+    true_velocities = 0.1 * np.random.randn(
+        *batch_shape).astype(self.dtype)
+    observed_positions = (
+        true_velocities *
+        np.arange(num_timesteps).astype(
+            self.dtype)[..., tf.newaxis, tf.newaxis] +
+        true_initial_positions)
+
+    (particles, log_weights, parent_indices,
+     incremental_log_marginal_likelihoods) = self.evaluate(
+         particle_filter.particle_filter(
+             observations=observed_positions,
+             initial_state_prior=initial_state_prior,
+             transition_fn=transition_fn,
+             observation_fn=observation_fn,
+             num_particles=num_particles,
+             seed=test_util.test_seed(),
+             particles_dim=1))
+
+    self.assertAllEqual(particles['position'].shape,
+                        [num_timesteps,
+                         batch_shape[0],
+                         num_particles,
+                         batch_shape[1]])
+    self.assertAllEqual(particles['velocity'].shape,
+                        [num_timesteps,
+                         batch_shape[0],
+                         num_particles,
+                         batch_shape[1]])
+    self.assertAllEqual(parent_indices.shape,
+                        [num_timesteps,
+                         batch_shape[0],
+                         num_particles,
+                         batch_shape[1]])
+    self.assertAllEqual(incremental_log_marginal_likelihoods.shape,
+                        [num_timesteps] + batch_shape)
+
+    self.assertAllClose(
+        self.evaluate(
+            tf.reduce_sum(tf.exp(log_weights) *
+                          particles['position'], axis=2)),
+        observed_positions,
+        atol=0.3)
+
+    velocity_means = tf.reduce_sum(tf.exp(log_weights) *
+                                   particles['velocity'], axis=2)
+
+    self.assertAllClose(
+        self.evaluate(tf.reduce_mean(velocity_means, axis=0)),
+        true_velocities, atol=0.05)
+
+    # Uncertainty in velocity should decrease over time.
+    velocity_stddev = self.evaluate(
+        tf.math.reduce_std(particles['velocity'], axis=2))
+    self.assertAllLess((velocity_stddev[-1] - velocity_stddev[0]), 0.)
+
+    trajectories = self.evaluate(
+        particle_filter.reconstruct_trajectories(particles,
+                                                 parent_indices,
+                                                 particles_dim=1))
+    self.assertAllEqual([num_timesteps,
+                         batch_shape[0],
+                         num_particles,
+                         batch_shape[1]],
+                        trajectories['position'].shape)
+    self.assertAllEqual([num_timesteps,
+                         batch_shape[0],
+                         num_particles,
+                         batch_shape[1]],
+                        trajectories['velocity'].shape)
+
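For intuition about what `reconstruct_trajectories` (exercised just above) computes, its backward pass can be sketched with plain `tf.scan` and `tf.gather`. This is a minimal sketch for the default `particles_dim=0` case with no batch dimensions; the parent indices below are made up. With a nonzero `particles_dim`, the same gathers run along `axis=particles_dim`, and the final gather over `particles` uses `particles_dim + 1` because stacking over time prepends a leading step axis.

import tensorflow as tf

# Hypothetical parent indices for 3 particles over 3 steps:
# parent_indices[t][i] is the time-(t-1) index of particle i's parent.
parent_indices = tf.constant([[0, 1, 2],
                              [0, 2, 1],
                              [1, 1, 2]])
final_indices = tf.range(3)  # Each final particle starts as its own ancestor.

# Walk backwards, composing parent lookups from the last step to the first.
ancestor_indices = tf.scan(
    fn=lambda ancestor, parent: tf.gather(parent, ancestor),
    elems=parent_indices[1:],
    initializer=final_indices,
    reverse=True)
# Append identity indices for the final step itself, so that
# ancestor_indices[t][i] is the time-t ancestor of final particle i.
ancestor_indices = tf.concat([ancestor_indices, [final_indices]], axis=0)

Gathering `particles` with these indices along the particle axis (offset by one for the leading time axis) then yields each surviving particle's full trajectory.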
+    # Verify that `infer_trajectories` also works on batches.
+    trajectories, incremental_log_marginal_likelihoods = self.evaluate(
+        particle_filter.infer_trajectories(
+            observations=observed_positions,
+            initial_state_prior=initial_state_prior,
+            transition_fn=transition_fn,
+            observation_fn=observation_fn,
+            num_particles=num_particles,
+            particles_dim=1,
+            seed=test_util.test_seed()))
+
+    self.assertAllEqual([num_timesteps,
+                         batch_shape[0],
+                         num_particles,
+                         batch_shape[1]],
+                        trajectories['position'].shape)
+    self.assertAllEqual([num_timesteps,
+                         batch_shape[0],
+                         num_particles,
+                         batch_shape[1]],
+                        trajectories['velocity'].shape)
+    self.assertAllEqual(incremental_log_marginal_likelihoods.shape,
+                        [num_timesteps] + batch_shape)
+
   def test_reconstruct_trajectories_toy_example(self):
     particles = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6,], [7, 8, 9]])
     # 1 -- 4 -- 7
diff --git a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py
index 6cb9b65003..73cb0f8414 100644
--- a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py
+++ b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py
@@ -21,6 +21,7 @@
 from tensorflow_probability.python.internal import prefer_static as ps
 from tensorflow_probability.python.internal import samplers
 from tensorflow_probability.python.mcmc import kernel as kernel_base
+from tensorflow_probability.python.mcmc.internal import util as mcmc_util
 
 __all__ = [
     'SequentialMonteCarlo',
@@ -115,19 +116,22 @@ def _dummy_indices_like(indices):
                     indices_shape)
 
 
-def log_ess_from_log_weights(log_weights):
-  """Computes log-ESS estimate from log-weights along axis=0."""
+def log_ess_from_log_weights(log_weights, particles_dim=0):
+  """Computes log-ESS estimate from log-weights along axis=particles_dim."""
   with tf.name_scope('ess_from_log_weights'):
-    log_weights = tf.math.log_softmax(log_weights, axis=0)
-    return -tf.math.reduce_logsumexp(2 * log_weights, axis=0)
+    log_weights = tf.math.log_softmax(log_weights, axis=particles_dim)
+    return -tf.math.reduce_logsumexp(2 * log_weights, axis=particles_dim)
 
 
-def ess_below_threshold(weighted_particles, threshold=0.5):
+def ess_below_threshold(weighted_particles, particles_dim=0, threshold=0.5):
   """Determines if the effective sample size is much less than num_particles."""
   with tf.name_scope('ess_below_threshold'):
-    num_particles = ps.size0(weighted_particles.log_weights)
-    log_ess = log_ess_from_log_weights(weighted_particles.log_weights)
-    return log_ess < (ps.log(num_particles) + ps.log(threshold))
+    num_particles = ps.dimension_size(
+        weighted_particles.log_weights, particles_dim)
+    log_ess = log_ess_from_log_weights(
+        weighted_particles.log_weights, particles_dim=particles_dim)
+    return tf.expand_dims(log_ess < (ps.log(num_particles) + ps.log(threshold)),
+                          axis=particles_dim)
 
 
 class SequentialMonteCarlo(kernel_base.TransitionKernel):
@@ -145,6 +148,7 @@ def __init__(self,
                resample_fn=weighted_resampling.resample_systematic,
                resample_criterion_fn=ess_below_threshold,
                unbiased_gradients=True,
+               particles_dim=0,
                name=None):
     """Initializes a sequential Monte Carlo transition kernel.
 
@@ -190,6 +194,10 @@ def __init__(self,
         correct for gradient bias introduced by the discrete resampling step.
         This will generally increase the variance of stochastic gradients.
         Default value: `True`.
+      particles_dim: `int` dimension that indexes the particles in the
+        `tfp.experimental.mcmc.WeightedParticles` structures on which this
+        kernel operates.
+        Default value: `0`.
       name: Python `str` name for ops created by this kernel.
#### References @@ -202,6 +210,7 @@ def __init__(self, self._resample_fn = resample_fn self._resample_criterion_fn = resample_criterion_fn self._unbiased_gradients = unbiased_gradients + self._particles_dim = particles_dim self._name = name or 'SequentialMonteCarlo' @property @@ -220,13 +229,17 @@ def propose_and_update_log_weights_fn(self): def resample_criterion_fn(self): return self._resample_criterion_fn + @property + def resample_fn(self): + return self._resample_fn + @property def unbiased_gradients(self): return self._unbiased_gradients @property - def resample_fn(self): - return self._resample_fn + def particles_dim(self): + return self._particles_dim def one_step(self, state, kernel_results, seed=None): """Takes one Sequential Monte Carlo inference step. @@ -269,15 +282,17 @@ def one_step(self, state, kernel_results, seed=None): state = tf.nest.map_structure( lambda a, b: tf.where(is_initial_step, a, b), state, proposed_state) - normalized_log_weights = tf.nn.log_softmax(state.log_weights, axis=0) + normalized_log_weights = tf.nn.log_softmax(state.log_weights, + axis=self.particles_dim) # Every entry of `log_weights` differs from `normalized_log_weights` # by the same normalizing constant. We extract that constant by # examining an arbitrary entry. - incremental_log_marginal_likelihood = (state.log_weights[0] - - normalized_log_weights[0]) - - do_resample = self.resample_criterion_fn(state) + incremental_log_marginal_likelihood = ( + tf.gather(state.log_weights, 0, axis=self.particles_dim) + - tf.gather(normalized_log_weights, 0, axis=self.particles_dim)) + do_resample = self.resample_criterion_fn( + state, particles_dim=self.particles_dim) # Some batch elements may require resampling and others not, so # we first do the resampling for all elements, then select whether to # use the resampled values for each batch element according to @@ -300,11 +315,12 @@ def one_step(self, state, kernel_results, seed=None): resample_fn=self.resample_fn, target_log_weights=(normalized_log_weights if self.unbiased_gradients else None), + particles_dim=self.particles_dim, seed=resample_seed) (resampled_particles, resample_indices, log_weights) = tf.nest.map_structure( - lambda r, p: tf.where(do_resample, r, p), + lambda r, p: mcmc_util.choose(do_resample, r, p), (resampled_particles, resample_indices, weights_after_resampling), (state.particles, _dummy_indices_like(resample_indices), normalized_log_weights)) @@ -326,9 +342,13 @@ def bootstrap_results(self, init_state): with tf.name_scope('bootstrap_results'): init_state = WeightedParticles(*init_state) + particles_shape = ps.shape(init_state.log_weights) + weights_shape = ps.concat([ + particles_shape[:self.particles_dim], + particles_shape[self.particles_dim+1:] + ], axis=0) batch_zeros = tf.zeros( - ps.shape(init_state.log_weights)[1:], - dtype=init_state.log_weights.dtype) + weights_shape, dtype=init_state.log_weights.dtype) return SequentialMonteCarloResults( steps=0, diff --git a/tensorflow_probability/python/experimental/mcmc/weighted_resampling.py b/tensorflow_probability/python/experimental/mcmc/weighted_resampling.py index 9632d916f5..613419cf46 100644 --- a/tensorflow_probability/python/experimental/mcmc/weighted_resampling.py +++ b/tensorflow_probability/python/experimental/mcmc/weighted_resampling.py @@ -34,7 +34,7 @@ def resample(particles, log_weights, resample_fn, target_log_weights=None, - seed=None): + particles_dim=0, seed=None): """Resamples the current particles according to provided weights. 
Args: @@ -54,6 +54,10 @@ def resample(particles, log_weights, resample_fn, target_log_weights=None, `None`, the target measure is implicitly taken to be the normalized log weights (`log_weights - tf.reduce_logsumexp(log_weights, axis=0)`). Default value: `None`. + particles_dim: Python `int` axis of each state `Tensor` indexing into the + particles. This is almost always zero, but nonzero values may be necessary + when running SMC in nested contexts. + Default value: `0`. seed: PRNG seed; see `tfp.random.sanitize_seed` for details. Returns: @@ -69,15 +73,25 @@ def resample(particles, log_weights, resample_fn, target_log_weights=None, resampling are uniformly equal to `-log(num_particles)`. """ with tf.name_scope('resample'): - num_particles = ps.size0(log_weights) + num_particles = ps.dimension_size(log_weights, particles_dim) + log_num_particles = tf.math.log(tf.cast(num_particles, log_weights.dtype)) # Normalize the weights and sample the ancestral indices. - log_probs = tf.math.log_softmax(log_weights, axis=0) - resampled_indices = resample_fn(log_probs, num_particles, (), seed=seed) + log_probs = tf.math.log_softmax(log_weights, axis=particles_dim) + if particles_dim == 0: + # For resample functions that don't yet support the + # particles_dim argument. + resampled_indices = resample_fn(log_probs, num_particles, (), seed=seed) + else: + resampled_indices = resample_fn(log_probs, num_particles, (), + particles_dim=particles_dim, seed=seed) gather_ancestors = lambda x: ( # pylint: disable=g-long-lambda - mcmc_util.index_remapping_gather(x, resampled_indices, axis=0)) + mcmc_util.index_remapping_gather(x, + resampled_indices, + axis=particles_dim, + indices_axis=particles_dim)) resampled_particles = tf.nest.map_structure(gather_ancestors, particles) if target_log_weights is None: log_weights_after_resampling = tf.fill(ps.shape(log_weights), @@ -242,7 +256,7 @@ def resample_independent(log_probs, event_size, sample_shape, # TODO(b/153689734): rewrite so as not to use `move_dimension`. def resample_systematic(log_probs, event_size, sample_shape, - seed=None, name=None): + particles_dim=0, seed=None, name=None): """A systematic resampler for sequential Monte Carlo. The value returned from this function is similar to sampling with @@ -272,6 +286,9 @@ def resample_systematic(log_probs, event_size, sample_shape, The remaining dimensions are batch dimensions. event_size: the dimension of the vector considered a single draw. sample_shape: the `sample_shape` determining the number of draws. + particles_dim: Python `int` axis of each state `Tensor` indexing into the + particles. This is almost always zero, but nonzero values may be necessary + when running SMC in nested contexts. seed: PRNG seed; see `tfp.random.sanitize_seed` for details. Default value: None (i.e. no seed). name: Python `str` name for ops created by this method. 
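The `move_dimension` round-trip above is the whole trick: the particle axis is shuttled to a fixed position, the existing axis-agnostic sampling code runs unchanged, and the axis is shuttled back. The following is a minimal sketch of the equivalent shuffle written against the released resampler, which expects particles along axis 0 and does not yet accept `particles_dim`; the shapes and seed are illustrative only.

import tensorflow as tf
import tensorflow_probability as tfp

# Three independent filters with 1000 particles each, particle axis at
# position 1 (i.e. particles_dim=1).
log_weights = tf.random.normal([3, 1000])
particles_dim, num_particles = 1, 1000

log_probs = tf.nn.log_softmax(log_weights, axis=particles_dim)
# Move the particle axis to the front, resample there, then move it back.
log_probs_front = tf.experimental.numpy.moveaxis(log_probs, particles_dim, 0)
indices_front = tfp.experimental.mcmc.resample_systematic(
    log_probs_front, num_particles, (), seed=tfp.random.sanitize_seed(42))
indices = tf.experimental.numpy.moveaxis(indices_front, 0, particles_dim)
# indices has shape [3, 1000]; indices[b] are ancestor indices for filter b.

The patched implementation performs the same shuffle internally, except that it moves the particle axis to the last position, which is the axis `_resample_using_log_points` operates on.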
@@ -293,7 +310,9 @@ def resample_systematic(log_probs, event_size, sample_shape, """ with tf.name_scope(name or 'resample_systematic') as name: log_probs = tf.convert_to_tensor(log_probs, dtype_hint=tf.float32) - log_probs = dist_util.move_dimension(log_probs, source_idx=0, dest_idx=-1) + log_probs = dist_util.move_dimension(log_probs, + source_idx=particles_dim, + dest_idx=-1) working_shape = ps.concat([sample_shape, ps.shape(log_probs)[:-1]], axis=0) points_shape = ps.concat([working_shape, [event_size]], axis=0) @@ -310,7 +329,9 @@ def resample_systematic(log_probs, event_size, sample_shape, log_points = tf.broadcast_to(tf.math.log(even_spacing), points_shape) resampled = _resample_using_log_points(log_probs, sample_shape, log_points) - return dist_util.move_dimension(resampled, source_idx=-1, dest_idx=0) + return dist_util.move_dimension(resampled, + source_idx=-1, + dest_idx=particles_dim) # TODO(b/153689734): rewrite so as not to use `move_dimension`. diff --git a/tensorflow_probability/python/experimental/mcmc/weighted_resampling_test.py b/tensorflow_probability/python/experimental/mcmc/weighted_resampling_test.py index e415b4c99e..ace87de1e1 100644 --- a/tensorflow_probability/python/experimental/mcmc/weighted_resampling_test.py +++ b/tensorflow_probability/python/experimental/mcmc/weighted_resampling_test.py @@ -299,6 +299,48 @@ def resample_with_target_distribution(self): tf.reduce_sum(tf.nn.softmax(new_log_weights) * new_particles), 30., atol=1.) + def test_with_target_distribution_dim_one(self): + stacked_particles = np.stack([ + np.linspace(0., 500., num=2500, dtype=np.float32), + np.linspace(0.17, 433., num=2500, dtype=np.float32), + ], axis=0) + + stacked_log_weights = poisson.Poisson(20.).log_prob(stacked_particles) + + # Resample particles to target a Poisson(20.) distribution. + new_particles, _, new_log_weights = resample( + stacked_particles, stacked_log_weights, + resample_fn=resample_systematic, + particles_dim=1, + seed=test_util.test_seed(sampler_type='stateless')) + + self.assertAllMeansClose(new_particles, + [20., 20.], + axis=1, + atol=1e-2) + self.assertAllClose( + tf.reduce_sum(tf.nn.softmax(new_log_weights) * new_particles, axis=1), + [20., 20.], + atol=1e-2) + + # Reweight the resampled particles to target a Poisson(30.) distribution. + new_particles, _, new_log_weights = resample( + stacked_particles, + stacked_log_weights, + resample_fn=resample_systematic, + particles_dim=1, + target_log_weights=poisson.Poisson(30).log_prob(stacked_particles), + seed=test_util.test_seed(sampler_type='stateless')) + self.assertAllMeansClose(new_particles, + [20., 20.], + axis=1, + atol=1e-2) + + self.assertAllClose( + tf.reduce_sum(tf.nn.softmax(new_log_weights) * new_particles, axis=1), + [30., 30.], + atol=1.5) + def maybe_compiler(self, f): if self.use_xla: return tf.function(f, autograph=False, jit_compile=True)
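Taken together, these changes let a filter run with the particle axis somewhere other than position 0, which is what batched or nested SMC setups need. Below is a minimal end-to-end sketch of the new option, mirroring the added test; the model, shapes, and seed are illustrative, and passing `particles_dim=1` assumes this patch is applied.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# State tensors will have shape [outer_batch, num_particles, inner_batch],
# i.e. [3, 1000, 2], with the particle axis at position 1.
batch_shape = [3, 2]
num_particles = 1000
num_timesteps = 40

initial_state_prior = tfd.JointDistributionNamed({
    'position': tfd.Normal(loc=0., scale=tf.ones(batch_shape)),
    'velocity': tfd.Normal(loc=0., scale=0.1 * tf.ones(batch_shape)),
})

def transition_fn(_, previous_state):
  return tfd.JointDistributionNamed({
      'position': tfd.Normal(
          loc=previous_state['position'] + previous_state['velocity'],
          scale=0.1),
      'velocity': tfd.Normal(loc=previous_state['velocity'], scale=0.01),
  })

def observation_fn(_, state):
  return tfd.Normal(loc=state['position'], scale=0.1)

# Placeholder observations; a real run would use measured positions.
observed_positions = tf.zeros([num_timesteps] + batch_shape)

(particles, log_weights, parent_indices,
 incremental_log_marginal_likelihoods) = tfp.experimental.mcmc.particle_filter(
     observations=observed_positions,
     initial_state_prior=initial_state_prior,
     transition_fn=transition_fn,
     observation_fn=observation_fn,
     num_particles=num_particles,
     particles_dim=1,  # New in this patch; previously always 0.
     seed=42)
# particles['position'].shape == [40, 3, 1000, 2]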