From 66b2af49110b8d85dbd1825b781d404fd7a2ea89 Mon Sep 17 00:00:00 2001 From: aleslamitz <109731102+aleslamitz@users.noreply.github.com> Date: Sun, 24 Jul 2022 13:13:41 +0200 Subject: [PATCH 1/7] Added sequential_monte_carlo in particle_filter --- .../experimental/mcmc/particle_filter.py | 207 ++++++++---------- 1 file changed, 90 insertions(+), 117 deletions(-) diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py index 83e7b99f77..8aba1af5a0 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter.py @@ -354,130 +354,103 @@ def particle_filter(observations, static_trace_allocation_size = 0 trace_criterion_fn = never_trace - initial_weighted_particles = _particle_filter_initial_weighted_particles( - observations=observations, - observation_fn=observation_fn, - initial_state_prior=initial_state_prior, - initial_state_proposal=initial_state_proposal, - num_particles=num_particles, - seed=init_seed) - propose_and_update_log_weights_fn = ( - _particle_filter_propose_and_update_log_weights_fn( - observations=observations, - transition_fn=transition_fn, - proposal_fn=proposal_fn, - observation_fn=observation_fn, - num_transitions_per_observation=num_transitions_per_observation)) - - kernel = smc_kernel.SequentialMonteCarlo( + if initial_state_proposal is None: + initial_state_proposal = initial_state_prior + + # Initializing particles and weights + initial_particles, initial_log_weights = \ + initial_state_proposal.experimental_sample_and_log_prob( + [num_particles], seed=seed + ) + initial_log_weights = ( + initial_state_prior.log_prob(initial_particles) - + initial_state_proposal.log_prob(initial_particles) + ) + + if proposal_fn is None: + proposal_fn = transition_fn + + def propose_and_update_log_weights_fn(step, state, seed): + """Particle filter propose and update for single steps""" + particles, log_weights = ( + proposal_fn(step, state.particles).experimental_sample_and_log_prob(seed=seed) + ) + + log_weights = ( + observation_fn(step, state.particles).log_prob(tf.gather(observations, step)) + + transition_fn(step, state.particles).log_prob(particles) + - log_weights) + return smc_kernel.WeightedParticles(particles, log_weights) + + + def sequential_monte_carlo( + initial_state, + propose_and_update_log_weights_fn, + condition_fn, + resample_fn, + resample_criterion_fn + ): + kernel = smc_kernel.SequentialMonteCarlo( + propose_and_update_log_weights_fn=propose_and_update_log_weights_fn, + resample_fn=resample_fn, + resample_criterion_fn=resample_criterion_fn, + unbiased_gradients=unbiased_gradients) + + # Use `trace_scan` rather than `sample_chain` directly because the latter + # would force us to trace the state history (with or without thinning), + # which is not always appropriate. 
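    # (A sketch of the contract `trace_scan` provides, with simplified
    # semantics -- the real helper compiles to a `tf.while_loop`:
    #
    #   state = initial_state
    #   for elem in elems:
    #     state = loop_fn(state, elem)
    #     if trace_criterion_fn(state):
    #       trace.append(trace_fn(state))
    #
    # Only steps passing `trace_criterion_fn` are materialized, which is why
    # it can avoid storing the full state history.)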
+ def seeded_one_step(seed_state_results, _): + seed, state, results = seed_state_results + one_step_seed, next_seed = samplers.split_seed(seed) + next_state, next_results = kernel.one_step( + state, results, seed=one_step_seed) + return next_seed, next_state, next_results + + + final_seed_state_result, traced_results = loop_util.trace_scan( + loop_fn=seeded_one_step, + initial_state=(loop_seed, + initial_state, + kernel.bootstrap_results(initial_state)), + elems=tf.ones([num_timesteps]), + trace_fn=lambda seed_state_results: trace_fn(*seed_state_results[1:]), + trace_criterion_fn=( + lambda seed_state_results: trace_criterion_fn( # pylint: disable=g-long-lambda + *seed_state_results[1:])), + static_trace_allocation_size=static_trace_allocation_size, + parallel_iterations=parallel_iterations, + condition_fn=condition_fn + ) + + if trace_criterion_fn is never_trace: + # Return results from just the final step. + traced_results = trace_fn(*final_seed_state_result[1:]) + + return traced_results + + + traced_results = sequential_monte_carlo( + # Weighted particles weighted also considering the observations. + initial_state=smc_kernel.WeightedParticles( + particles=initial_particles, + log_weights=initial_log_weights + _compute_observation_log_weights( + step=0, + particles=initial_particles, + observations=observations, + observation_fn=observation_fn) + ), propose_and_update_log_weights_fn=propose_and_update_log_weights_fn, + condition_fn=lambda step, state, num_traced, trace: step < len(observations), resample_fn=resample_fn, resample_criterion_fn=resample_criterion_fn, - unbiased_gradients=unbiased_gradients) - - # Use `trace_scan` rather than `sample_chain` directly because the latter - # would force us to trace the state history (with or without thinning), - # which is not always appropriate. - def seeded_one_step(seed_state_results, _): - seed, state, results = seed_state_results - one_step_seed, next_seed = samplers.split_seed(seed) - next_state, next_results = kernel.one_step( - state, results, seed=one_step_seed) - return next_seed, next_state, next_results - - final_seed_state_result, traced_results = loop_util.trace_scan( - loop_fn=seeded_one_step, - initial_state=(loop_seed, - initial_weighted_particles, - kernel.bootstrap_results(initial_weighted_particles)), - elems=tf.ones([num_timesteps]), - trace_fn=lambda seed_state_results: trace_fn(*seed_state_results[1:]), - trace_criterion_fn=( - lambda seed_state_results: trace_criterion_fn( # pylint: disable=g-long-lambda - *seed_state_results[1:])), - static_trace_allocation_size=static_trace_allocation_size, - parallel_iterations=parallel_iterations) - - if trace_criterion_fn is never_trace: - # Return results from just the final step. - traced_results = trace_fn(*final_seed_state_result[1:]) + # rejuvenation_fn=..., + # rejuvenation_criterion_fn=..., + # trace_fn=..., + ) return traced_results -def _particle_filter_initial_weighted_particles(observations, - observation_fn, - initial_state_prior, - initial_state_proposal, - num_particles, - seed=None): - """Initialize a set of weighted particles including the first observation.""" - # Propose an initial state. 
- if initial_state_proposal is None: - initial_state = initial_state_prior.sample(num_particles, seed=seed) - initial_log_weights = ps.zeros_like( - initial_state_prior.log_prob(initial_state)) - else: - initial_state = initial_state_proposal.sample(num_particles, seed=seed) - initial_log_weights = (initial_state_prior.log_prob(initial_state) - - initial_state_proposal.log_prob(initial_state)) - # Normalize the initial weights. If we used a proposal, the weights are - # normalized in expectation, but actually normalizing them reduces variance. - initial_log_weights = tf.nn.log_softmax(initial_log_weights, axis=0) - - # Return particles weighted by the initial observation. - return smc_kernel.WeightedParticles( - particles=initial_state, - log_weights=initial_log_weights + _compute_observation_log_weights( - step=0, - particles=initial_state, - observations=observations, - observation_fn=observation_fn)) - - -def _particle_filter_propose_and_update_log_weights_fn( - observations, - transition_fn, - proposal_fn, - observation_fn, - num_transitions_per_observation=1): - """Build a function specifying a particle filter update step.""" - def propose_and_update_log_weights_fn(step, state, seed=None): - particles, log_weights = state.particles, state.log_weights - transition_dist = transition_fn(step, particles) - assertions = _assert_batch_shape_matches_weights( - distribution=transition_dist, - weights_shape=ps.shape(log_weights), - diststr='transition') - - if proposal_fn: - proposal_dist = proposal_fn(step, particles) - assertions += _assert_batch_shape_matches_weights( - distribution=proposal_dist, - weights_shape=ps.shape(log_weights), - diststr='proposal') - proposed_particles = proposal_dist.sample(seed=seed) - - log_weights += (transition_dist.log_prob(proposed_particles) - - proposal_dist.log_prob(proposed_particles)) - # The normalizing constant E~q[p(x)/q(x)] is 1 in expectation, - # so we reduce variance by dividing it out. Intuitively: the marginal - # likelihood of a model with no observations is constant - # (equal to 1.), so the transition and proposal distributions shouldn't - # affect it. 
- log_weights = tf.nn.log_softmax(log_weights, axis=0) - else: - proposed_particles = transition_dist.sample(seed=seed) - - with tf.control_dependencies(assertions): - return smc_kernel.WeightedParticles( - particles=proposed_particles, - log_weights=log_weights + _compute_observation_log_weights( - step + 1, proposed_particles, observations, observation_fn, - num_transitions_per_observation=num_transitions_per_observation)) - return propose_and_update_log_weights_fn - - def _compute_observation_log_weights(step, particles, observations, From 491e770f934a143a80d4f5c1b1166017db574f95 Mon Sep 17 00:00:00 2001 From: aleslamitz <109731102+aleslamitz@users.noreply.github.com> Date: Thu, 8 Sep 2022 15:37:54 +0200 Subject: [PATCH 2/7] Refactor of Particle Filter --- .../experimental/mcmc/particle_filter.py | 269 ++++++++++++------ 1 file changed, 189 insertions(+), 80 deletions(-) diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py index 8aba1af5a0..8136fc4072 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter.py @@ -273,6 +273,62 @@ def observation_fn(_, state): return trajectories, incremental_log_marginal_likelihoods +def sequential_monte_carlo( + initial_state, + propose_and_update_log_weights_fn, + condition_fn, + resample_fn, + resample_criterion_fn, + num_timesteps, + unbiased_gradients=True, + trace_fn=_default_trace_fn, + trace_criterion_fn=_always_trace, + static_trace_allocation_size=None, + parallel_iterations=1, + seed=None, + name=None +): + init_seed, loop_seed = samplers.split_seed(seed, salt='particle_filter') + # If trace criterion is `None`, we'll return only the final results. + never_trace = lambda *_: False + kernel = smc_kernel.SequentialMonteCarlo( + propose_and_update_log_weights_fn=propose_and_update_log_weights_fn, + resample_fn=resample_fn, + resample_criterion_fn=resample_criterion_fn, + unbiased_gradients=unbiased_gradients) + + # Use `trace_scan` rather than `sample_chain` directly because the latter + # would force us to trace the state history (with or without thinning), + # which is not always appropriate. + def seeded_one_step(seed_state_results, _): + seed, state, results = seed_state_results + one_step_seed, next_seed = samplers.split_seed(seed) + next_state, next_results = kernel.one_step( + state, results, seed=one_step_seed) + return next_seed, next_state, next_results + + final_seed_state_result, traced_results = loop_util.trace_scan( + loop_fn=seeded_one_step, + initial_state=(loop_seed, + initial_state, + kernel.bootstrap_results(initial_state)), + elems=tf.ones([num_timesteps]), + trace_fn=lambda seed_state_results: trace_fn(*seed_state_results[1:]), + trace_criterion_fn=( + lambda seed_state_results: trace_criterion_fn( # pylint: disable=g-long-lambda + *seed_state_results[1:])), + static_trace_allocation_size=static_trace_allocation_size, + parallel_iterations=parallel_iterations, + condition_fn=condition_fn + ) + + if trace_criterion_fn is never_trace: + # Return results from just the final step. + traced_results = trace_fn(*final_seed_state_result[1:]) + + return traced_results + + @docstring_util.expand_docstring( particle_filter_arg_str=particle_filter_arg_str.format(scibor_ref_idx=1)) def particle_filter(observations, @@ -341,8 +397,6 @@ def particle_filter(observations, Filtering without Modifying the Forward Pass. 
   _arXiv preprint arXiv:2106.10314_, 2021. https://arxiv.org/abs/2106.10314
   """
-
-  init_seed, loop_seed = samplers.split_seed(seed, salt='particle_filter')
   with tf.name_scope(name or 'particle_filter'):
     num_observation_steps = ps.size0(tf.nest.flatten(observations)[0])
     num_timesteps = (
@@ -354,19 +408,121 @@ def particle_filter(observations,
       static_trace_allocation_size = 0
       trace_criterion_fn = never_trace

-    if initial_state_proposal is None:
-      initial_state_proposal = initial_state_prior
+    initial_particles = _particle_filter_initial_weighted_particles(
+        observations,
+        observation_fn,
+        initial_state_prior,
+        initial_state_proposal,
+        num_particles,
+        seed=None
+    )
+
+    propose_and_update_log_weights_fn = (
+        _particle_filter_propose_and_update_log_weights_fn(
+            observations=observations,
+            transition_fn=transition_fn,
+            proposal_fn=proposal_fn,
+            observation_fn=observation_fn,
+            num_transitions_per_observation=num_transitions_per_observation))

-    # Initializing particles and weights
-    initial_particles, initial_log_weights = \
-      initial_state_proposal.experimental_sample_and_log_prob(
-        [num_particles], seed=seed
-      )
-    initial_log_weights = (
-        initial_state_prior.log_prob(initial_particles) -
-        initial_state_proposal.log_prob(initial_particles)
+    traced_results = sequential_monte_carlo(
+        # Weighted particles weighted also considering the observations.
+        initial_state=initial_particles,
+        propose_and_update_log_weights_fn=propose_and_update_log_weights_fn,
+        condition_fn=lambda step, state, num_traced, trace: step < ps.size0(observations),
+        resample_fn=resample_fn,
+        resample_criterion_fn=resample_criterion_fn,
+        # rejuvenation_fn=...,
+        # rejuvenation_criterion_fn=...,
+        trace_fn=trace_fn,
+        num_timesteps=num_timesteps
     )

+    return traced_results
+
+
+def _particle_filter_initial_weighted_particles(observations,
+                                                observation_fn,
+                                                initial_state_prior,
+                                                initial_state_proposal,
+                                                num_particles,
+                                                seed=None):
+  """Initialize a set of weighted particles including the first observation."""
+  # Propose an initial state.
+  if initial_state_proposal is None:
+    initial_particles = initial_state_prior.sample(num_particles, seed=seed)
+    initial_log_weights = ps.zeros_like(
+        initial_state_prior.log_prob(initial_particles)
+    )
+  else:
+    initial_particles = initial_state_proposal.sample(num_particles, seed=seed)
+    initial_log_weights = (
+        initial_state_prior.log_prob(initial_particles) -
+        initial_state_proposal.log_prob(initial_particles)
+    )
+
+  # Normalize the initial weights. If we used a proposal, the weights are
+  # normalized in expectation, but actually normalizing them reduces variance.
+  initial_log_weights = tf.nn.log_softmax(initial_log_weights, axis=0)
+
+  # Return particles weighted by the initial observation.
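  # (In importance-sampling terms, each particle x0 drawn from the proposal
  # q0 carries
  #
  #   log w0 = log p0(x0) - log q0(x0) + log p(y0 | x0),
  #
  # where p0 is `initial_state_prior` and y0 the first observation; the first
  # two terms are computed above, and the observation term is added by
  # `_compute_observation_log_weights` in the return below.)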
+ return smc_kernel.WeightedParticles( + particles=initial_particles, + log_weights=initial_log_weights + _compute_observation_log_weights( + step=0, + particles=initial_particles, + observations=observations, + observation_fn=observation_fn)) + + +def _propose_and_update_log_weights_fn( + observations, + transition_fn, + proposal_fn, + observation_fn, + step, + state, + seed, + num_transitions_per_observation=1): + """Particle filter propose and update for single steps""" + particles, log_weights = ( + proposal_fn(step, state.particles).experimental_sample_and_log_prob(seed=seed) + ) + + assertions = _assert_batch_shape_matches_weights( + distribution=transition_fn(step, particles), + weights_shape=ps.shape(log_weights), + diststr='transition' + ) + if proposal_fn(step, particles) != transition_fn(step, particles): + assertions += _assert_batch_shape_matches_weights( + distribution=transition_fn(step, particles), + weights_shape=ps.shape(log_weights), + diststr='proposal' + ) + + log_weights = ( + observation_fn(step, state.particles).log_prob(tf.gather(observations, step)) + + transition_fn(step, state.particles).log_prob(particles) + - log_weights) + + log_weights = tf.nn.log_softmax(log_weights, axis=0) + with tf.control_dependencies(assertions): + return smc_kernel.WeightedParticles( + particles=particles, + log_weights=log_weights + _compute_observation_log_weights( + step + 1, particles, observations, observation_fn, + num_transitions_per_observation=num_transitions_per_observation) + ) + + +def _particle_filter_propose_and_update_log_weights_fn( + observations, + transition_fn, + proposal_fn, + observation_fn, + num_transitions_per_observation=1): + """Build a function specifying a particle filter update step.""" if proposal_fn is None: proposal_fn = transition_fn @@ -376,79 +532,32 @@ def propose_and_update_log_weights_fn(step, state, seed): proposal_fn(step, state.particles).experimental_sample_and_log_prob(seed=seed) ) + assertions = _assert_batch_shape_matches_weights( + distribution=transition_fn(step, particles), + weights_shape=ps.shape(log_weights), + diststr='transition' + ) + if proposal_fn(step, particles) != transition_fn(step, particles): + assertions += _assert_batch_shape_matches_weights( + distribution=transition_fn(step, particles), + weights_shape=ps.shape(log_weights), + diststr='proposal' + ) + log_weights = ( observation_fn(step, state.particles).log_prob(tf.gather(observations, step)) + transition_fn(step, state.particles).log_prob(particles) - log_weights) - return smc_kernel.WeightedParticles(particles, log_weights) - - - def sequential_monte_carlo( - initial_state, - propose_and_update_log_weights_fn, - condition_fn, - resample_fn, - resample_criterion_fn - ): - kernel = smc_kernel.SequentialMonteCarlo( - propose_and_update_log_weights_fn=propose_and_update_log_weights_fn, - resample_fn=resample_fn, - resample_criterion_fn=resample_criterion_fn, - unbiased_gradients=unbiased_gradients) - - # Use `trace_scan` rather than `sample_chain` directly because the latter - # would force us to trace the state history (with or without thinning), - # which is not always appropriate. 
- def seeded_one_step(seed_state_results, _): - seed, state, results = seed_state_results - one_step_seed, next_seed = samplers.split_seed(seed) - next_state, next_results = kernel.one_step( - state, results, seed=one_step_seed) - return next_seed, next_state, next_results - - - final_seed_state_result, traced_results = loop_util.trace_scan( - loop_fn=seeded_one_step, - initial_state=(loop_seed, - initial_state, - kernel.bootstrap_results(initial_state)), - elems=tf.ones([num_timesteps]), - trace_fn=lambda seed_state_results: trace_fn(*seed_state_results[1:]), - trace_criterion_fn=( - lambda seed_state_results: trace_criterion_fn( # pylint: disable=g-long-lambda - *seed_state_results[1:])), - static_trace_allocation_size=static_trace_allocation_size, - parallel_iterations=parallel_iterations, - condition_fn=condition_fn - ) - - if trace_criterion_fn is never_trace: - # Return results from just the final step. - traced_results = trace_fn(*final_seed_state_result[1:]) - return traced_results - - - traced_results = sequential_monte_carlo( - # Weighted particles weighted also considering the observations. - initial_state=smc_kernel.WeightedParticles( - particles=initial_particles, - log_weights=initial_log_weights + _compute_observation_log_weights( - step=0, - particles=initial_particles, - observations=observations, - observation_fn=observation_fn) - ), - propose_and_update_log_weights_fn=propose_and_update_log_weights_fn, - condition_fn=lambda step, state, num_traced, trace: step < len(observations), - resample_fn=resample_fn, - resample_criterion_fn=resample_criterion_fn, - # rejuvenation_fn=..., - # rejuvenation_criterion_fn=..., - # trace_fn=..., - ) - - return traced_results + log_weights = tf.nn.log_softmax(log_weights, axis=0) + with tf.control_dependencies(assertions): + return smc_kernel.WeightedParticles( + particles=particles, + log_weights=log_weights + _compute_observation_log_weights( + step + 1, particles, observations, observation_fn, + num_transitions_per_observation=num_transitions_per_observation) + ) + return propose_and_update_log_weights_fn def _compute_observation_log_weights(step, From 317173d5cb318618b2c68646f4838c73aa0b5ffe Mon Sep 17 00:00:00 2001 From: aleslamitz <109731102+aleslamitz@users.noreply.github.com> Date: Thu, 8 Sep 2022 18:26:19 +0200 Subject: [PATCH 3/7] Update particle_filter.py --- .../experimental/mcmc/particle_filter.py | 265 +++++++++--------- 1 file changed, 136 insertions(+), 129 deletions(-) diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py index 8136fc4072..f237fb55b5 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter.py @@ -273,60 +273,61 @@ def observation_fn(_, state): return trajectories, incremental_log_marginal_likelihoods -def sequential_monte_carlo( - initial_state, - propose_and_update_log_weights_fn, - condition_fn, - resample_fn, - resample_criterion_fn, - num_timesteps, - unbiased_gradients=True, - trace_fn=_default_trace_fn, - trace_criterion_fn=_always_trace, - static_trace_allocation_size=None, - parallel_iterations=1, - seed=None, - name=None -): - init_seed, loop_seed = samplers.split_seed(seed, salt='particle_filter') - # If trace criterion is `None`, we'll return only the final results. 
- never_trace = lambda *_: False - kernel = smc_kernel.SequentialMonteCarlo( - propose_and_update_log_weights_fn=propose_and_update_log_weights_fn, - resample_fn=resample_fn, - resample_criterion_fn=resample_criterion_fn, - unbiased_gradients=unbiased_gradients) - - # Use `trace_scan` rather than `sample_chain` directly because the latter - # would force us to trace the state history (with or without thinning), - # which is not always appropriate. - def seeded_one_step(seed_state_results, _): - seed, state, results = seed_state_results - one_step_seed, next_seed = samplers.split_seed(seed) - next_state, next_results = kernel.one_step( - state, results, seed=one_step_seed) - return next_seed, next_state, next_results - - final_seed_state_result, traced_results = loop_util.trace_scan( - loop_fn=seeded_one_step, - initial_state=(loop_seed, - initial_state, - kernel.bootstrap_results(initial_state)), - elems=tf.ones([num_timesteps]), - trace_fn=lambda seed_state_results: trace_fn(*seed_state_results[1:]), - trace_criterion_fn=( - lambda seed_state_results: trace_criterion_fn( # pylint: disable=g-long-lambda - *seed_state_results[1:])), - static_trace_allocation_size=static_trace_allocation_size, - parallel_iterations=parallel_iterations, - condition_fn=condition_fn +def sequential_monte_carlo(initial_state, + propose_and_update_log_weights_fn, + condition_fn, + resample_fn, + resample_criterion_fn, + num_timesteps, + unbiased_gradients=True, + trace_fn=_default_trace_fn, + trace_criterion_fn=_always_trace, + static_trace_allocation_size=None, + parallel_iterations=1, + seed=None, + name=None + ): + init_seed, loop_seed = samplers.split_seed(seed, salt='particle_filter') + # If trace criterion is `None`, we'll return only the final results. + never_trace = lambda *_: False + kernel = smc_kernel.SequentialMonteCarlo( + propose_and_update_log_weights_fn=propose_and_update_log_weights_fn, + resample_fn=resample_fn, + resample_criterion_fn=resample_criterion_fn, + unbiased_gradients=unbiased_gradients + ) + + # Use `trace_scan` rather than `sample_chain` directly because the latter + # would force us to trace the state history (with or without thinning), + # which is not always appropriate. + def seeded_one_step(seed_state_results, _): + seed, state, results = seed_state_results + one_step_seed, next_seed = samplers.split_seed(seed) + next_state, next_results = kernel.one_step( + state, results, seed=one_step_seed ) - - if trace_criterion_fn is never_trace: - # Return results from just the final step. - traced_results = trace_fn(*final_seed_state_result[1:]) - - return traced_results + return next_seed, next_state, next_results + final_seed_state_result, traced_results = loop_util.trace_scan( + loop_fn=seeded_one_step, + initial_state=(loop_seed, + initial_state, + kernel.bootstrap_results(initial_state)), + elems=tf.ones([num_timesteps]), + trace_fn=lambda seed_state_results: trace_fn(*seed_state_results[1:]), + trace_criterion_fn=( + lambda seed_state_results: trace_criterion_fn( # pylint: disable=g-long-lambda + *seed_state_results[1:]) + ), + static_trace_allocation_size=static_trace_allocation_size, + parallel_iterations=parallel_iterations, + condition_fn=condition_fn + ) + + if trace_criterion_fn is never_trace: + # Return results from just the final step. 
+    traced_results = trace_fn(*final_seed_state_result[1:])
+
+  return traced_results


 @docstring_util.expand_docstring(
     particle_filter_arg_str=particle_filter_arg_str.format(scibor_ref_idx=1))
 def particle_filter(observations,
@@ -423,19 +424,20 @@ def particle_filter(observations,
                 transition_fn=transition_fn,
                 proposal_fn=proposal_fn,
                 observation_fn=observation_fn,
-                num_transitions_per_observation=num_transitions_per_observation))
+                num_transitions_per_observation=num_transitions_per_observation)
+    )

     traced_results = sequential_monte_carlo(
-        # Weighted particles weighted also considering the observations.
-        initial_state=initial_particles,
-        propose_and_update_log_weights_fn=propose_and_update_log_weights_fn,
-        condition_fn=lambda step, state, num_traced, trace: step < ps.size0(observations),
-        resample_fn=resample_fn,
-        resample_criterion_fn=resample_criterion_fn,
-        # rejuvenation_fn=...,
-        # rejuvenation_criterion_fn=...,
-        trace_fn=trace_fn,
-        num_timesteps=num_timesteps
+      # Weighted particles weighted also considering the observations.
+      initial_state=initial_particles,
+      propose_and_update_log_weights_fn=propose_and_update_log_weights_fn,
+      condition_fn=lambda step, state, num_traced, trace: step < ps.size0(observations),
+      resample_fn=resample_fn,
+      resample_criterion_fn=resample_criterion_fn,
+      # rejuvenation_fn=...,
+      # rejuvenation_criterion_fn=...,
+      trace_fn=trace_fn,
+      num_timesteps=num_timesteps
     )

     return traced_results

@@ -450,16 +452,16 @@ def _particle_filter_initial_weighted_particles(observations,
   """Initialize a set of weighted particles including the first observation."""
   # Propose an initial state.
   if initial_state_proposal is None:
-    initial_particles = initial_state_prior.sample(num_particles, seed=seed)
-    initial_log_weights = ps.zeros_like(
-        initial_state_prior.log_prob(initial_particles)
-    )
+      initial_particles = initial_state_prior.sample(num_particles, seed=seed)
+      initial_log_weights = ps.zeros_like(
+          initial_state_prior.log_prob(initial_particles)
+      )
   else:
-    initial_particles = initial_state_proposal.sample(num_particles, seed=seed)
-    initial_log_weights = (
-        initial_state_prior.log_prob(initial_particles) -
-        initial_state_proposal.log_prob(initial_particles)
-    )
+      initial_particles = initial_state_proposal.sample(num_particles, seed=seed)
+      initial_log_weights = (
+          initial_state_prior.log_prob(initial_particles) -
+          initial_state_proposal.log_prob(initial_particles)
+      )

   # Normalize the initial weights. If we used a proposal, the weights are
   # normalized in expectation, but actually normalizing them reduces variance.
@@ -469,10 +471,12 @@ def _particle_filter_initial_weighted_particles(observations, return smc_kernel.WeightedParticles( particles=initial_particles, log_weights=initial_log_weights + _compute_observation_log_weights( - step=0, - particles=initial_particles, - observations=observations, - observation_fn=observation_fn)) + step=0, + particles=initial_particles, + observations=observations, + observation_fn=observation_fn + ) + ) def _propose_and_update_log_weights_fn( @@ -484,6 +488,50 @@ def _propose_and_update_log_weights_fn( state, seed, num_transitions_per_observation=1): + """Particle filter propose and update for single steps""" + particles, log_weights = ( + proposal_fn(step, state.particles).experimental_sample_and_log_prob(seed=seed) + ) + + assertions = _assert_batch_shape_matches_weights( + distribution=transition_fn(step, particles), + weights_shape=ps.shape(log_weights), + diststr='transition' + ) + if proposal_fn(step, particles) != transition_fn(step, particles): + assertions += _assert_batch_shape_matches_weights( + distribution=transition_fn(step, particles), + weights_shape=ps.shape(log_weights), + diststr='proposal' + ) + + log_weights = ( + observation_fn(step, state.particles).log_prob(tf.gather(observations, step)) + + transition_fn(step, state.particles).log_prob(particles) + - log_weights + ) + + log_weights = tf.nn.log_softmax(log_weights, axis=0) + with tf.control_dependencies(assertions): + return smc_kernel.WeightedParticles( + particles=particles, + log_weights=log_weights + _compute_observation_log_weights( + step + 1, particles, observations, observation_fn, + num_transitions_per_observation=num_transitions_per_observation) + ) + + +def _particle_filter_propose_and_update_log_weights_fn( + observations, + transition_fn, + proposal_fn, + observation_fn, + num_transitions_per_observation=1): + """Build a function specifying a particle filter update step.""" + if proposal_fn is None: + proposal_fn = transition_fn + + def propose_and_update_log_weights_fn(step, state, seed): """Particle filter propose and update for single steps""" particles, log_weights = ( proposal_fn(step, state.particles).experimental_sample_and_log_prob(seed=seed) @@ -502,62 +550,21 @@ def _propose_and_update_log_weights_fn( ) log_weights = ( - observation_fn(step, state.particles).log_prob(tf.gather(observations, step)) - + transition_fn(step, state.particles).log_prob(particles) - - log_weights) + observation_fn(step, state.particles).log_prob(tf.gather(observations, step)) + + transition_fn(step, state.particles).log_prob(particles) + - log_weights + ) log_weights = tf.nn.log_softmax(log_weights, axis=0) with tf.control_dependencies(assertions): - return smc_kernel.WeightedParticles( - particles=particles, - log_weights=log_weights + _compute_observation_log_weights( - step + 1, particles, observations, observation_fn, - num_transitions_per_observation=num_transitions_per_observation) - ) - - -def _particle_filter_propose_and_update_log_weights_fn( - observations, - transition_fn, - proposal_fn, - observation_fn, - num_transitions_per_observation=1): - """Build a function specifying a particle filter update step.""" - if proposal_fn is None: - proposal_fn = transition_fn - - def propose_and_update_log_weights_fn(step, state, seed): - """Particle filter propose and update for single steps""" - particles, log_weights = ( - proposal_fn(step, state.particles).experimental_sample_and_log_prob(seed=seed) + return smc_kernel.WeightedParticles( + particles=particles, + log_weights=log_weights + 
_compute_observation_log_weights( + step + 1, particles, observations, observation_fn, + num_transitions_per_observation=num_transitions_per_observation ) - - assertions = _assert_batch_shape_matches_weights( - distribution=transition_fn(step, particles), - weights_shape=ps.shape(log_weights), - diststr='transition' - ) - if proposal_fn(step, particles) != transition_fn(step, particles): - assertions += _assert_batch_shape_matches_weights( - distribution=transition_fn(step, particles), - weights_shape=ps.shape(log_weights), - diststr='proposal' - ) - - log_weights = ( - observation_fn(step, state.particles).log_prob(tf.gather(observations, step)) - + transition_fn(step, state.particles).log_prob(particles) - - log_weights) - - log_weights = tf.nn.log_softmax(log_weights, axis=0) - with tf.control_dependencies(assertions): - return smc_kernel.WeightedParticles( - particles=particles, - log_weights=log_weights + _compute_observation_log_weights( - step + 1, particles, observations, observation_fn, - num_transitions_per_observation=num_transitions_per_observation) - ) - return propose_and_update_log_weights_fn + ) + return propose_and_update_log_weights_fn def _compute_observation_log_weights(step, From f38d7d16310bfa9d916a0f846050c3bf6fcfa484 Mon Sep 17 00:00:00 2001 From: aleslamitz <109731102+aleslamitz@users.noreply.github.com> Date: Thu, 8 Sep 2022 18:34:42 +0200 Subject: [PATCH 4/7] Pylint Fixed formatting errors --- .../experimental/mcmc/particle_filter.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py index f237fb55b5..8cd19130ba 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter.py @@ -428,16 +428,16 @@ def particle_filter(observations, ) traced_results = sequential_monte_carlo( - # Weighted particles weighted also considering the observations. - initial_state=initial_particles, - propose_and_update_log_weights_fn=propose_and_update_log_weights_fn, - condition_fn=lambda step, state, num_traced, trace: step < ps.size0(observations), - resample_fn=resample_fn, - resample_criterion_fn=resample_criterion_fn, - # rejuvenation_fn=..., - # rejuvenation_criterion_fn=..., - trace_fn=trace_fn, - num_timesteps=num_timesteps + # Weighted particles weighted also considering the observations. 
+ initial_state=initial_particles, + propose_and_update_log_weights_fn=propose_and_update_log_weights_fn, + condition_fn=lambda step, state, num_traced, trace: step < ps.size0(observations), + resample_fn=resample_fn, + resample_criterion_fn=resample_criterion_fn, + # rejuvenation_fn=..., + # rejuvenation_criterion_fn=..., + trace_fn=trace_fn, + num_timesteps=num_timesteps ) return traced_results @@ -487,7 +487,8 @@ def _propose_and_update_log_weights_fn( step, state, seed, - num_transitions_per_observation=1): + num_transitions_per_observation=1 +): """Particle filter propose and update for single steps""" particles, log_weights = ( proposal_fn(step, state.particles).experimental_sample_and_log_prob(seed=seed) From b0c21f2f545e286a52b1e82aa5423eb555aafd3f Mon Sep 17 00:00:00 2001 From: aleslamitz <109731102+aleslamitz@users.noreply.github.com> Date: Mon, 7 Nov 2022 14:52:21 +0100 Subject: [PATCH 5/7] reset --- .../experimental/mcmc/particle_filter.py | 254 ++++++------------ 1 file changed, 82 insertions(+), 172 deletions(-) diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py index 8cd19130ba..83e7b99f77 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter.py @@ -273,63 +273,6 @@ def observation_fn(_, state): return trajectories, incremental_log_marginal_likelihoods -def sequential_monte_carlo(initial_state, - propose_and_update_log_weights_fn, - condition_fn, - resample_fn, - resample_criterion_fn, - num_timesteps, - unbiased_gradients=True, - trace_fn=_default_trace_fn, - trace_criterion_fn=_always_trace, - static_trace_allocation_size=None, - parallel_iterations=1, - seed=None, - name=None - ): - init_seed, loop_seed = samplers.split_seed(seed, salt='particle_filter') - # If trace criterion is `None`, we'll return only the final results. - never_trace = lambda *_: False - kernel = smc_kernel.SequentialMonteCarlo( - propose_and_update_log_weights_fn=propose_and_update_log_weights_fn, - resample_fn=resample_fn, - resample_criterion_fn=resample_criterion_fn, - unbiased_gradients=unbiased_gradients - ) - - # Use `trace_scan` rather than `sample_chain` directly because the latter - # would force us to trace the state history (with or without thinning), - # which is not always appropriate. - def seeded_one_step(seed_state_results, _): - seed, state, results = seed_state_results - one_step_seed, next_seed = samplers.split_seed(seed) - next_state, next_results = kernel.one_step( - state, results, seed=one_step_seed - ) - return next_seed, next_state, next_results - final_seed_state_result, traced_results = loop_util.trace_scan( - loop_fn=seeded_one_step, - initial_state=(loop_seed, - initial_state, - kernel.bootstrap_results(initial_state)), - elems=tf.ones([num_timesteps]), - trace_fn=lambda seed_state_results: trace_fn(*seed_state_results[1:]), - trace_criterion_fn=( - lambda seed_state_results: trace_criterion_fn( # pylint: disable=g-long-lambda - *seed_state_results[1:]) - ), - static_trace_allocation_size=static_trace_allocation_size, - parallel_iterations=parallel_iterations, - condition_fn=condition_fn - ) - - if trace_criterion_fn is never_trace: - # Return results from just the final step. 
- traced_results = trace_fn(*final_seed_state_result[1:]) - - return traced_results - - @docstring_util.expand_docstring( particle_filter_arg_str=particle_filter_arg_str.format(scibor_ref_idx=1)) def particle_filter(observations, @@ -398,6 +341,8 @@ def particle_filter(observations, Filtering without Modifying the Forward Pass. _arXiv preprint arXiv:2106.10314_, 2021. https://arxiv.org/abs/2106.10314 """ + + init_seed, loop_seed = samplers.split_seed(seed, salt='particle_filter') with tf.name_scope(name or 'particle_filter'): num_observation_steps = ps.size0(tf.nest.flatten(observations)[0]) num_timesteps = ( @@ -409,36 +354,53 @@ def particle_filter(observations, static_trace_allocation_size = 0 trace_criterion_fn = never_trace - initial_particles = _particle_filter_initial_weighted_particles( - observations, - observation_fn, - initial_state_prior, - initial_state_proposal, - num_particles, - seed=None - ) - + initial_weighted_particles = _particle_filter_initial_weighted_particles( + observations=observations, + observation_fn=observation_fn, + initial_state_prior=initial_state_prior, + initial_state_proposal=initial_state_proposal, + num_particles=num_particles, + seed=init_seed) propose_and_update_log_weights_fn = ( _particle_filter_propose_and_update_log_weights_fn( observations=observations, transition_fn=transition_fn, proposal_fn=proposal_fn, observation_fn=observation_fn, - num_transitions_per_observation=num_transitions_per_observation) - ) + num_transitions_per_observation=num_transitions_per_observation)) - traced_results = sequential_monte_carlo( - # Weighted particles weighted also considering the observations. - initial_state=initial_particles, + kernel = smc_kernel.SequentialMonteCarlo( propose_and_update_log_weights_fn=propose_and_update_log_weights_fn, - condition_fn=lambda step, state, num_traced, trace: step < ps.size0(observations), resample_fn=resample_fn, resample_criterion_fn=resample_criterion_fn, - # rejuvenation_fn=..., - # rejuvenation_criterion_fn=..., - trace_fn=trace_fn, - num_timesteps=num_timesteps - ) + unbiased_gradients=unbiased_gradients) + + # Use `trace_scan` rather than `sample_chain` directly because the latter + # would force us to trace the state history (with or without thinning), + # which is not always appropriate. + def seeded_one_step(seed_state_results, _): + seed, state, results = seed_state_results + one_step_seed, next_seed = samplers.split_seed(seed) + next_state, next_results = kernel.one_step( + state, results, seed=one_step_seed) + return next_seed, next_state, next_results + + final_seed_state_result, traced_results = loop_util.trace_scan( + loop_fn=seeded_one_step, + initial_state=(loop_seed, + initial_weighted_particles, + kernel.bootstrap_results(initial_weighted_particles)), + elems=tf.ones([num_timesteps]), + trace_fn=lambda seed_state_results: trace_fn(*seed_state_results[1:]), + trace_criterion_fn=( + lambda seed_state_results: trace_criterion_fn( # pylint: disable=g-long-lambda + *seed_state_results[1:])), + static_trace_allocation_size=static_trace_allocation_size, + parallel_iterations=parallel_iterations) + + if trace_criterion_fn is never_trace: + # Return results from just the final step. + traced_results = trace_fn(*final_seed_state_result[1:]) return traced_results @@ -452,119 +414,67 @@ def _particle_filter_initial_weighted_particles(observations, """Initialize a set of weighted particles including the first observation.""" # Propose an initial state. 
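  # (If no proposal is supplied, particles come directly from the prior, so
  # the importance correction log p0(x0) - log q0(x0) is identically zero --
  # which is why the first branch below starts the log-weights at zero.)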
   if initial_state_proposal is None:
-      initial_particles = initial_state_prior.sample(num_particles, seed=seed)
+    initial_state = initial_state_prior.sample(num_particles, seed=seed)
     initial_log_weights = ps.zeros_like(
-          initial_state_prior.log_prob(initial_particles)
-      )
+        initial_state_prior.log_prob(initial_state))
   else:
-      initial_particles = initial_state_proposal.sample(num_particles, seed=seed)
-      initial_log_weights = (
-          initial_state_prior.log_prob(initial_particles) -
-          initial_state_proposal.log_prob(initial_particles)
-      )
-
+    initial_state = initial_state_proposal.sample(num_particles, seed=seed)
+    initial_log_weights = (initial_state_prior.log_prob(initial_state) -
+                           initial_state_proposal.log_prob(initial_state))
   # Normalize the initial weights. If we used a proposal, the weights are
   # normalized in expectation, but actually normalizing them reduces variance.
   initial_log_weights = tf.nn.log_softmax(initial_log_weights, axis=0)

   # Return particles weighted by the initial observation.
   return smc_kernel.WeightedParticles(
-      particles=initial_particles,
+      particles=initial_state,
       log_weights=initial_log_weights + _compute_observation_log_weights(
-          step=0,
-          particles=initial_particles,
-          observations=observations,
-          observation_fn=observation_fn
-      )
-  )
-
-
-def _propose_and_update_log_weights_fn(
-    observations,
-    transition_fn,
-    proposal_fn,
-    observation_fn,
-    step,
-    state,
-    seed,
-    num_transitions_per_observation=1
-):
-  """Particle filter propose and update for single steps"""
-  particles, log_weights = (
-      proposal_fn(step, state.particles).experimental_sample_and_log_prob(seed=seed)
-  )
-
-  assertions = _assert_batch_shape_matches_weights(
-      distribution=transition_fn(step, particles),
-      weights_shape=ps.shape(log_weights),
-      diststr='transition'
-  )
-  if proposal_fn(step, particles) != transition_fn(step, particles):
-    assertions += _assert_batch_shape_matches_weights(
-        distribution=transition_fn(step, particles),
-        weights_shape=ps.shape(log_weights),
-        diststr='proposal'
-    )
-
-  log_weights = (
-      observation_fn(step, state.particles).log_prob(tf.gather(observations, step))
-      + transition_fn(step, state.particles).log_prob(particles)
-      - log_weights
-  )
-
-  log_weights = tf.nn.log_softmax(log_weights, axis=0)
-  with tf.control_dependencies(assertions):
-    return smc_kernel.WeightedParticles(
-        particles=particles,
-        log_weights=log_weights + _compute_observation_log_weights(
-            step + 1, particles, observations, observation_fn,
-            num_transitions_per_observation=num_transitions_per_observation)
-    )
+          step=0,
+          particles=initial_state,
+          observations=observations,
+          observation_fn=observation_fn))


 def _particle_filter_propose_and_update_log_weights_fn(
-        observations,
-        transition_fn,
-        proposal_fn,
-        observation_fn,
-        num_transitions_per_observation=1):
+    observations,
+    transition_fn,
+    proposal_fn,
+    observation_fn,
+    num_transitions_per_observation=1):
   """Build a function specifying a particle filter update step."""
-  if proposal_fn is None:
-    proposal_fn = transition_fn
-
-  def propose_and_update_log_weights_fn(step, state, seed):
-    """Particle filter propose and update for single steps"""
-    particles, log_weights = (
-        proposal_fn(step, state.particles).experimental_sample_and_log_prob(seed=seed)
-    )
-
+  def propose_and_update_log_weights_fn(step, state, seed=None):
+    particles, log_weights = state.particles, state.log_weights
+    transition_dist = transition_fn(step, particles)
     assertions = _assert_batch_shape_matches_weights(
-        distribution=transition_fn(step,
particles), + distribution=transition_dist, weights_shape=ps.shape(log_weights), - diststr='transition' - ) - if proposal_fn(step, particles) != transition_fn(step, particles): - assertions += _assert_batch_shape_matches_weights( - distribution=transition_fn(step, particles), - weights_shape=ps.shape(log_weights), - diststr='proposal' - ) - - log_weights = ( - observation_fn(step, state.particles).log_prob(tf.gather(observations, step)) - + transition_fn(step, state.particles).log_prob(particles) - - log_weights - ) - - log_weights = tf.nn.log_softmax(log_weights, axis=0) + diststr='transition') + + if proposal_fn: + proposal_dist = proposal_fn(step, particles) + assertions += _assert_batch_shape_matches_weights( + distribution=proposal_dist, + weights_shape=ps.shape(log_weights), + diststr='proposal') + proposed_particles = proposal_dist.sample(seed=seed) + + log_weights += (transition_dist.log_prob(proposed_particles) - + proposal_dist.log_prob(proposed_particles)) + # The normalizing constant E~q[p(x)/q(x)] is 1 in expectation, + # so we reduce variance by dividing it out. Intuitively: the marginal + # likelihood of a model with no observations is constant + # (equal to 1.), so the transition and proposal distributions shouldn't + # affect it. + log_weights = tf.nn.log_softmax(log_weights, axis=0) + else: + proposed_particles = transition_dist.sample(seed=seed) + with tf.control_dependencies(assertions): return smc_kernel.WeightedParticles( - particles=particles, + particles=proposed_particles, log_weights=log_weights + _compute_observation_log_weights( - step + 1, particles, observations, observation_fn, - num_transitions_per_observation=num_transitions_per_observation - ) - ) + step + 1, proposed_particles, observations, observation_fn, + num_transitions_per_observation=num_transitions_per_observation)) return propose_and_update_log_weights_fn From 771c987b7139b83a36bdd5576669934d189a03bc Mon Sep 17 00:00:00 2001 From: slamitza Date: Tue, 21 Nov 2023 21:34:26 +0100 Subject: [PATCH 6/7] pf refactor and update --- .../experimental/mcmc/particle_filter.py | 266 ++++++++++++++---- .../experimental/mcmc/particle_filter_test.py | 122 ++++++++ .../mcmc/sequential_monte_carlo_kernel.py | 40 ++- .../sequential_monte_carlo_kernel_test.py | 3 +- .../experimental/mcmc/weighted_resampling.py | 37 ++- .../mcmc/weighted_resampling_test.py | 42 +++ 6 files changed, 432 insertions(+), 78 deletions(-) diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py index 83e7b99f77..8bb2a36f34 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter.py @@ -24,6 +24,7 @@ from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import samplers from tensorflow_probability.python.mcmc.internal import util as mcmc_util +from tensorflow_probability.python.internal import distribution_util as dist_util __all__ = [ 'infer_trajectories', @@ -123,11 +124,11 @@ def infer_trajectories(observations, observation_fn, num_particles, initial_state_proposal=None, + particles_dim=0, proposal_fn=None, resample_fn=weighted_resampling.resample_systematic, resample_criterion_fn=smc_kernel.ess_below_threshold, unbiased_gradients=True, - rejuvenation_kernel_fn=None, num_transitions_per_observation=1, seed=None, name=None): # pylint: disable=g-doc-args @@ -247,32 +248,145 @@ def observation_fn(_, 
state): observation_fn=observation_fn, num_particles=num_particles, initial_state_proposal=initial_state_proposal, + particles_dim=particles_dim, proposal_fn=proposal_fn, resample_fn=resample_fn, resample_criterion_fn=resample_criterion_fn, unbiased_gradients=unbiased_gradients, - rejuvenation_kernel_fn=rejuvenation_kernel_fn, num_transitions_per_observation=num_transitions_per_observation, trace_fn=_default_trace_fn, trace_criterion_fn=lambda *_: True, seed=pf_seed, name=name) - weighted_trajectories = reconstruct_trajectories(particles, parent_indices) + weighted_trajectories = reconstruct_trajectories( + particles, + parent_indices, + particles_dim=particles_dim + ) # Resample all steps of the trajectories using the final weights. resample_indices = resample_fn(log_probs=log_weights[-1], event_size=num_particles, + particles_dim=particles_dim, sample_shape=(), seed=resample_seed) trajectories = tf.nest.map_structure( lambda x: mcmc_util.index_remapping_gather(x, # pylint: disable=g-long-lambda resample_indices, - axis=1), + axis=particles_dim+1, + indices_axis=particles_dim), weighted_trajectories) return trajectories, incremental_log_marginal_likelihoods +def sequential_monte_carlo(loop_seed, + initial_weighted_particles, + num_timesteps, + parallel_iterations, + trace_criterion_fn, + propose_and_update_log_weights_fn, + resample_fn, + resample_criterion_fn, + unbiased_gradients, + trace_fn, + particles_dim=0, + static_trace_allocation_size=None, + never_trace=lambda *_: False, + ): + + """Samples a series of particles representing filtered latent states. + + The particle filter samples from the sequence of "filtering" distributions + `p(state[t] | observations[:t])` over latent + states: at each point in time, this is a sample from the distribution + conditioned on all observations *up to that time*. Because particles may be + resampled, a particle at time `t` may be different from the particle with + the same index at time `t + 1`. To reconstruct trajectories by tracing back + through the resampling process, + see `tfp.mcmc.experimental.reconstruct_trajectories`. + + ${particle_filter_arg_str} + trace_fn: Python `callable` defining the values to be traced at each step, + with signature `traced_values = trace_fn(weighted_particles, results)` + in which the first argument is an instance of + `tfp.experimental.mcmc.WeightedParticles` and the second an instance of + `SequentialMonteCarloResults` tuple, and the return value is a structure + of `Tensor`s. + Default value: `lambda s, r: (s.particles, s.log_weights, + r.parent_indices, r.incremental_log_marginal_likelihood)` + trace_criterion_fn: optional Python `callable` with signature + `trace_this_step = trace_criterion_fn(weighted_particles, results)` + taking the same arguments as `trace_fn` and returning a boolean `Tensor`. + If `None`, only values from the final step are returned. + Default value: `lambda *_: True` (trace every step). + static_trace_allocation_size: Optional Python `int` size of trace to + allocate statically. This should be an upper bound on the number of steps + traced and is used only when the length cannot be + statically inferred (for example, if a `trace_criterion_fn` is + specified). + It is primarily intended for contexts where static shapes are required, + such as in XLA-compiled code. + Default value: `None`. + parallel_iterations: Passed to the internal `tf.while_loop`. + Default value: `1`. + seed: PRNG seed; see `tfp.random.sanitize_seed` for details. 
+ name: Python `str` name for ops created by this method. + Default value: `None` (i.e., `'particle_filter'`). + Returns: + traced_results: A structure of Tensors as returned by `trace_fn`. If + `trace_criterion_fn==None`, this is computed from the final step; + otherwise, each Tensor will have initial dimension `num_steps_traced` + and stacks the traced results across all steps. + + #### References + + [1] Adam Scibior, Vaden Masrani, and Frank Wood. Differentiable Particle + Filtering without Modifying the Forward Pass. _arXiv preprint + arXiv:2106.10314_, 2021. https://arxiv.org/abs/2106.10314 + """ + kernel = smc_kernel.SequentialMonteCarlo( + propose_and_update_log_weights_fn=propose_and_update_log_weights_fn, + resample_fn=resample_fn, + resample_criterion_fn=resample_criterion_fn, + particles_dim=particles_dim, + unbiased_gradients=unbiased_gradients + ) + + # Use `trace_scan` rather than `sample_chain` directly because the latter + # would force us to trace the state history (with or without thinning), + # which is not always appropriate. + def seeded_one_step(seed_state_results, _): + + seed, state, results = seed_state_results + + one_step_seed, next_seed = samplers.split_seed(seed) + + next_state, next_results = kernel.one_step( + state, results, seed=one_step_seed) + + return next_seed, next_state, next_results + + final_seed_state_result, traced_results = loop_util.trace_scan( + loop_fn=seeded_one_step, + initial_state=(loop_seed, + initial_weighted_particles, + kernel.bootstrap_results(initial_weighted_particles)), + elems=tf.ones([num_timesteps]), + trace_fn=lambda seed_state_results: trace_fn(*seed_state_results[1:]), + trace_criterion_fn=( + lambda seed_state_results: trace_criterion_fn( # pylint: disable=g-long-lambda + *seed_state_results[1:])), + static_trace_allocation_size=static_trace_allocation_size, + parallel_iterations=parallel_iterations) + + if trace_criterion_fn is never_trace: + # Return results from just the final step. + traced_results = trace_fn(*final_seed_state_result[1:]) + + return traced_results + + @docstring_util.expand_docstring( particle_filter_arg_str=particle_filter_arg_str.format(scibor_ref_idx=1)) def particle_filter(observations, @@ -281,6 +395,7 @@ def particle_filter(observations, observation_fn, num_particles, initial_state_proposal=None, + particles_dim=0, proposal_fn=None, resample_fn=weighted_resampling.resample_systematic, resample_criterion_fn=smc_kernel.ess_below_threshold, @@ -298,10 +413,10 @@ def particle_filter(observations, The particle filter samples from the sequence of "filtering" distributions `p(state[t] | observations[:t])` over latent states: at each point in time, this is the distribution conditioned on all - observations *up to that time*. Because particles may be resampled, a particle - at time `t` may be different from the particle with the same index at time - `t + 1`. To reconstruct trajectories by tracing back through the resampling - process, see `tfp.mcmc.experimental.reconstruct_trajectories`. + observations *up to that time*. Because particles may be resampled, a + particle at time `t` may be different from the particle with the same index + at time `t + 1`. To reconstruct trajectories by tracing back through the + resampling process, see `tfp.mcmc.experimental.reconstruct_trajectories`. 
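  For example, filtering a 1-D random walk from noisy observations (a
  minimal sketch with illustrative parameters; the default `trace_fn`
  returns particles, log weights, parent indices, and incremental log
  marginal likelihoods):

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  observations = tf.random.stateless_normal([20], seed=[0, 1])
  particles, log_weights, parents, lps = (
      tfp.experimental.mcmc.particle_filter(
          observations=observations,
          initial_state_prior=tfd.Normal(0., 1.),
          transition_fn=lambda _, x: tfd.Normal(x, 0.1),
          observation_fn=lambda _, x: tfd.Normal(x, 0.5),
          num_particles=1024,
          seed=[2, 3]))
  ```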
${particle_filter_arg_str} trace_fn: Python `callable` defining the values to be traced at each step, @@ -313,9 +428,9 @@ def particle_filter(observations, Default value: `lambda s, r: (s.particles, s.log_weights, r.parent_indices, r.incremental_log_marginal_likelihood)` trace_criterion_fn: optional Python `callable` with signature - `trace_this_step = trace_criterion_fn(weighted_particles, results)` taking - the same arguments as `trace_fn` and returning a boolean `Tensor`. If - `None`, only values from the final step are returned. + `trace_this_step = trace_criterion_fn(weighted_particles, results)` + taking the same arguments as `trace_fn` and returning a boolean `Tensor`. + If `None`, only values from the final step are returned. Default value: `lambda *_: True` (trace every step). static_trace_allocation_size: Optional Python `int` size of trace to allocate statically. This should be an upper bound on the number of steps @@ -329,6 +444,8 @@ def particle_filter(observations, seed: PRNG seed; see `tfp.random.sanitize_seed` for details. name: Python `str` name for ops created by this method. Default value: `None` (i.e., `'particle_filter'`). + particles_dim: int `Tensor` specifying the dimension in `observations` + that corresponds to the particles dimension. Returns: traced_results: A structure of Tensors as returned by `trace_fn`. If `trace_criterion_fn==None`, this is computed from the final step; @@ -360,47 +477,32 @@ def particle_filter(observations, initial_state_prior=initial_state_prior, initial_state_proposal=initial_state_proposal, num_particles=num_particles, + particles_dim=particles_dim, seed=init_seed) propose_and_update_log_weights_fn = ( _particle_filter_propose_and_update_log_weights_fn( observations=observations, transition_fn=transition_fn, + particles_dim=particles_dim, proposal_fn=proposal_fn, observation_fn=observation_fn, num_transitions_per_observation=num_transitions_per_observation)) - kernel = smc_kernel.SequentialMonteCarlo( + traced_results = sequential_monte_carlo( + initial_weighted_particles=initial_weighted_particles, + num_timesteps=num_timesteps, + parallel_iterations=parallel_iterations, + particles_dim=particles_dim, propose_and_update_log_weights_fn=propose_and_update_log_weights_fn, resample_fn=resample_fn, resample_criterion_fn=resample_criterion_fn, - unbiased_gradients=unbiased_gradients) - - # Use `trace_scan` rather than `sample_chain` directly because the latter - # would force us to trace the state history (with or without thinning), - # which is not always appropriate. - def seeded_one_step(seed_state_results, _): - seed, state, results = seed_state_results - one_step_seed, next_seed = samplers.split_seed(seed) - next_state, next_results = kernel.one_step( - state, results, seed=one_step_seed) - return next_seed, next_state, next_results - - final_seed_state_result, traced_results = loop_util.trace_scan( - loop_fn=seeded_one_step, - initial_state=(loop_seed, - initial_weighted_particles, - kernel.bootstrap_results(initial_weighted_particles)), - elems=tf.ones([num_timesteps]), - trace_fn=lambda seed_state_results: trace_fn(*seed_state_results[1:]), - trace_criterion_fn=( - lambda seed_state_results: trace_criterion_fn( # pylint: disable=g-long-lambda - *seed_state_results[1:])), static_trace_allocation_size=static_trace_allocation_size, - parallel_iterations=parallel_iterations) - - if trace_criterion_fn is never_trace: - # Return results from just the final step. 
@@ -360,47 +477,32 @@ def particle_filter(observations,
         initial_state_prior=initial_state_prior,
         initial_state_proposal=initial_state_proposal,
         num_particles=num_particles,
+        particles_dim=particles_dim,
         seed=init_seed)
     propose_and_update_log_weights_fn = (
         _particle_filter_propose_and_update_log_weights_fn(
             observations=observations,
             transition_fn=transition_fn,
+            particles_dim=particles_dim,
             proposal_fn=proposal_fn,
             observation_fn=observation_fn,
             num_transitions_per_observation=num_transitions_per_observation))
 
-    kernel = smc_kernel.SequentialMonteCarlo(
+    traced_results = sequential_monte_carlo(
+        initial_weighted_particles=initial_weighted_particles,
+        num_timesteps=num_timesteps,
+        parallel_iterations=parallel_iterations,
+        particles_dim=particles_dim,
         propose_and_update_log_weights_fn=propose_and_update_log_weights_fn,
         resample_fn=resample_fn,
         resample_criterion_fn=resample_criterion_fn,
-        unbiased_gradients=unbiased_gradients)
-
-    # Use `trace_scan` rather than `sample_chain` directly because the latter
-    # would force us to trace the state history (with or without thinning),
-    # which is not always appropriate.
-    def seeded_one_step(seed_state_results, _):
-      seed, state, results = seed_state_results
-      one_step_seed, next_seed = samplers.split_seed(seed)
-      next_state, next_results = kernel.one_step(
-          state, results, seed=one_step_seed)
-      return next_seed, next_state, next_results
-
-    final_seed_state_result, traced_results = loop_util.trace_scan(
-        loop_fn=seeded_one_step,
-        initial_state=(loop_seed,
-                       initial_weighted_particles,
-                       kernel.bootstrap_results(initial_weighted_particles)),
-        elems=tf.ones([num_timesteps]),
-        trace_fn=lambda seed_state_results: trace_fn(*seed_state_results[1:]),
-        trace_criterion_fn=(
-            lambda seed_state_results: trace_criterion_fn(  # pylint: disable=g-long-lambda
-                *seed_state_results[1:])),
         static_trace_allocation_size=static_trace_allocation_size,
-        parallel_iterations=parallel_iterations)
-
-    if trace_criterion_fn is never_trace:
-      # Return results from just the final step.
-      traced_results = trace_fn(*final_seed_state_result[1:])
+        trace_criterion_fn=trace_criterion_fn,
+        trace_fn=trace_fn,
+        unbiased_gradients=unbiased_gradients,
+        never_trace=never_trace,
+        loop_seed=loop_seed)
 
     return traced_results
 
 
@@ -410,20 +512,55 @@ def _particle_filter_initial_weighted_particles(observations,
                                                 initial_state_prior,
                                                 initial_state_proposal,
                                                 num_particles,
+                                                particles_dim=0,
                                                 seed=None):
   """Initialize a set of weighted particles including the first observation."""
   # Propose an initial state.
   if initial_state_proposal is None:
-    initial_state = initial_state_prior.sample(num_particles, seed=seed)
-    initial_log_weights = ps.zeros_like(
-        initial_state_prior.log_prob(initial_state))
+    if particles_dim == 0:
+      initial_state = initial_state_prior.sample(num_particles, seed=seed)
+      initial_log_weights = ps.zeros_like(
+          initial_state_prior.log_prob(initial_state))
+    else:
+      particles_draw = initial_state_prior.sample(num_particles, seed=seed)
+      initial_state = tf.nest.map_structure(
+          lambda x: dist_util.move_dimension(x,
+                                             source_idx=0,
+                                             dest_idx=particles_dim),
+          particles_draw)
+      initial_log_weights = ps.zeros_like(
+          dist_util.move_dimension(
+              initial_state_prior.log_prob(particles_draw),
+              source_idx=0,
+              dest_idx=particles_dim))
   else:
-    initial_state = initial_state_proposal.sample(num_particles, seed=seed)
-    initial_log_weights = (initial_state_prior.log_prob(initial_state) -
-                           initial_state_proposal.log_prob(initial_state))
+    if particles_dim == 0:
+      initial_state = initial_state_proposal.sample(num_particles, seed=seed)
+      initial_log_weights = (initial_state_prior.log_prob(initial_state) -
+                             initial_state_proposal.log_prob(initial_state))
+    else:
+      particles_draw = initial_state_proposal.sample(num_particles, seed=seed)
+      initial_state = tf.nest.map_structure(
+          lambda x: dist_util.move_dimension(x,
+                                             source_idx=0,
+                                             dest_idx=particles_dim),
+          particles_draw)
+      initial_log_weights = dist_util.move_dimension(
+          initial_state_prior.log_prob(particles_draw) -
+          initial_state_proposal.log_prob(particles_draw),
+          source_idx=0,
+          dest_idx=particles_dim)
 
   # Normalize the initial weights. If we used a proposal, the weights are
   # normalized in expectation, but actually normalizing them reduces variance.
-  initial_log_weights = tf.nn.log_softmax(initial_log_weights, axis=0)
+  initial_log_weights = tf.nn.log_softmax(initial_log_weights,
+                                          axis=particles_dim)
 
   # Return particles weighted by the initial observation.
   return smc_kernel.WeightedParticles(
@@ -432,7 +569,9 @@ def _particle_filter_initial_weighted_particles(observations,
           step=0,
           particles=initial_state,
           observations=observations,
-          observation_fn=observation_fn))
+          observation_fn=observation_fn,
+          particles_dim=particles_dim))
 
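The proposal branches above rely on the standard importance-sampling identity: particles drawn from a proposal `q` rather than the prior `p` carry initial log-weights `log p(x) - log q(x)`, which `log_softmax` then normalizes along the particle axis. A standalone sketch of that identity (toy distributions and seed, chosen for illustration):

    import tensorflow as tf
    import tensorflow_probability as tfp
    tfd = tfp.distributions

    p = tfd.Normal(0., 1.)    # plays the role of initial_state_prior
    q = tfd.Normal(0., 2.)    # plays the role of initial_state_proposal

    x = q.sample(1000, seed=(0, 1))
    log_w = tf.nn.log_softmax(p.log_prob(x) - q.log_prob(x), axis=0)
    # Weighted averages under log_w now estimate expectations under p.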
@@ -440,7 +579,8 @@ def _particle_filter_propose_and_update_log_weights_fn(
     transition_fn,
     proposal_fn,
     observation_fn,
-    num_transitions_per_observation=1):
+    num_transitions_per_observation=1,
+    particles_dim=0):
   """Build a function specifying a particle filter update step."""
   def propose_and_update_log_weights_fn(step, state, seed=None):
     particles, log_weights = state.particles, state.log_weights
@@ -465,7 +605,7 @@ def propose_and_update_log_weights_fn(step, state, seed=None):
       # likelihood of a model with no observations is constant
       # (equal to 1.), so the transition and proposal distributions shouldn't
       # affect it.
-      log_weights = tf.nn.log_softmax(log_weights, axis=0)
+      log_weights = tf.nn.log_softmax(log_weights, axis=particles_dim)
     else:
       proposed_particles = transition_dist.sample(seed=seed)
 
@@ -474,7 +614,9 @@ def propose_and_update_log_weights_fn(step, state, seed=None):
         particles=proposed_particles,
         log_weights=log_weights + _compute_observation_log_weights(
             step + 1, proposed_particles, observations, observation_fn,
-            num_transitions_per_observation=num_transitions_per_observation))
+            num_transitions_per_observation=num_transitions_per_observation,
+            particles_dim=particles_dim))
 
   return propose_and_update_log_weights_fn
 
 
@@ -482,7 +624,8 @@ def _compute_observation_log_weights(step,
                                      particles,
                                      observations,
                                      observation_fn,
-                                     num_transitions_per_observation=1):
+                                     num_transitions_per_observation=1,
+                                     particles_dim=0):
   """Computes particle importance weights from an observation step.
 
   Args:
@@ -502,6 +645,8 @@ def _compute_observation_log_weights(step,
     num_transitions_per_observation: optional int `Tensor` number of times
       to apply the transition model between successive observation steps.
       Default value: `1`.
+    particles_dim: optional Python `int` axis of `particles` (and of the
+      returned `log_weights`) indexing into the particles. Default: `0`.
   Returns:
     log_weights: `Tensor` of shape `concat([num_particles, b1, ..., bN])`.
   """
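In the hunk that follows, the gathered observation gains a length-1 axis at `particles_dim` so that `log_prob` can broadcast it against the per-particle distributions. A shape-only sketch of why that matters (toy shapes, invented for illustration):

    import tensorflow as tf
    import tensorflow_probability as tfp

    # With particles_dim=1: state is [batch, num_particles]; the gathered
    # observation is [batch]. A length-1 particle axis lets log_prob broadcast.
    state = tf.zeros([3, 1000])
    observation = tf.zeros([3])
    dist = tfp.distributions.Normal(loc=state, scale=0.1)
    log_weights = dist.log_prob(
        tf.expand_dims(observation, axis=1))    # -> shape [3, 1000]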
@@ -514,7 +659,12 @@ def _compute_observation_log_weights(step,
       ps.equal(step % num_transitions_per_observation, 0))
   observation_idx = step // num_transitions_per_observation
   observation = tf.nest.map_structure(
-      lambda x, step=step: tf.gather(x, observation_idx), observations)
+      lambda x, step=step: tf.gather(x, observation_idx),
+      observations)
+
+  observation = tf.nest.map_structure(lambda x:
+                                      tf.expand_dims(x, axis=particles_dim),
+                                      observation)
 
   log_weights = observation_fn(step, particles).log_prob(observation)
   return tf.where(step_has_observation,
@@ -522,14 +672,17 @@ def _compute_observation_log_weights(step,
                   log_weights,
                   tf.zeros_like(log_weights))
 
 
-def reconstruct_trajectories(particles, parent_indices, name=None):
+def reconstruct_trajectories(particles,
+                             parent_indices,
+                             particles_dim=0,
+                             name=None):
   """Reconstructs the ancestor trajectory that generated each final particle."""
   with tf.name_scope(name or 'reconstruct_trajectories'):
     # Walk backwards to compute the ancestor of each final particle at time t.
     final_indices = smc_kernel._dummy_indices_like(parent_indices[-1])  # pylint: disable=protected-access
     ancestor_indices = tf.scan(
         fn=lambda ancestor, parent: mcmc_util.index_remapping_gather(  # pylint: disable=g-long-lambda
-            parent, ancestor, axis=0),
+            parent, ancestor, axis=particles_dim, indices_axis=particles_dim),
         elems=parent_indices[1:],
         initializer=final_indices,
         reverse=True)
@@ -537,7 +690,10 @@ def reconstruct_trajectories(particles, parent_indices, name=None):
 
   return tf.nest.map_structure(
       lambda part: mcmc_util.index_remapping_gather(  # pylint: disable=g-long-lambda
-          part, ancestor_indices, axis=1, indices_axis=1),
+          part,
+          ancestor_indices,
+          axis=particles_dim + 1,
+          indices_axis=particles_dim + 1),
       particles)
 
 
diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py
index 5e0628dd02..6508eb6231 100644
--- a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py
+++ b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py
@@ -177,6 +177,128 @@ def observation_fn(_, state):
     self.assertAllEqual(incremental_log_marginal_likelihoods.shape,
                         [num_timesteps] + batch_shape)
 
+  def test_batch_of_filters_particles_dim_1(self):
+    batch_shape = [3, 2]
+    num_particles = 1000
+    num_timesteps = 40
+
+    # Batch of priors on object 1D positions and velocities.
+    initial_state_prior = jdn.JointDistributionNamed({
+        'position': normal.Normal(loc=0., scale=tf.ones(batch_shape)),
+        'velocity': normal.Normal(loc=0., scale=tf.ones(batch_shape) * 0.1)
+    })
+
+    def transition_fn(_, previous_state):
+      return jdn.JointDistributionNamed({
+          'position':
+              normal.Normal(
+                  loc=previous_state['position'] + previous_state['velocity'],
+                  scale=0.1),
+          'velocity':
+              normal.Normal(loc=previous_state['velocity'], scale=0.01)
+      })
+
+    def observation_fn(_, state):
+      return normal.Normal(loc=state['position'], scale=0.1)
+
+    # Batch of synthetic observations.
+    true_initial_positions = np.random.randn(*batch_shape).astype(self.dtype)
+    true_velocities = 0.1 * np.random.randn(
+        *batch_shape).astype(self.dtype)
+    observed_positions = (
+        true_velocities *
+        np.arange(num_timesteps).astype(
+            self.dtype)[..., tf.newaxis, tf.newaxis] +
+        true_initial_positions)
+
+    (particles, log_weights, parent_indices,
+     incremental_log_marginal_likelihoods) = self.evaluate(
+         particle_filter.particle_filter(
+             observations=observed_positions,
+             initial_state_prior=initial_state_prior,
+             transition_fn=transition_fn,
+             observation_fn=observation_fn,
+             num_particles=num_particles,
+             seed=test_util.test_seed(),
+             particles_dim=1))
+
+    self.assertAllEqual(particles['position'].shape,
+                        [num_timesteps,
+                         batch_shape[0],
+                         num_particles,
+                         batch_shape[1]])
+    self.assertAllEqual(particles['velocity'].shape,
+                        [num_timesteps,
+                         batch_shape[0],
+                         num_particles,
+                         batch_shape[1]])
+    self.assertAllEqual(parent_indices.shape,
+                        [num_timesteps,
+                         batch_shape[0],
+                         num_particles,
+                         batch_shape[1]])
+    self.assertAllEqual(incremental_log_marginal_likelihoods.shape,
+                        [num_timesteps] + batch_shape)
+
+    self.assertAllClose(
+        self.evaluate(
+            tf.reduce_sum(tf.exp(log_weights) *
+                          particles['position'], axis=2)),
+        observed_positions,
+        atol=0.3)
+
+    velocity_means = tf.reduce_sum(tf.exp(log_weights) *
+                                   particles['velocity'], axis=2)
+
+    self.assertAllClose(
+        self.evaluate(tf.reduce_mean(velocity_means, axis=0)),
+        true_velocities, atol=0.05)
+
+    # Uncertainty in velocity should decrease over time.
+    velocity_stddev = self.evaluate(
+        tf.math.reduce_std(particles['velocity'], axis=2))
+    self.assertAllLess((velocity_stddev[-1] - velocity_stddev[0]), 0.)
+
+    trajectories = self.evaluate(
+        particle_filter.reconstruct_trajectories(particles,
+                                                 parent_indices,
+                                                 particles_dim=1))
+    self.assertAllEqual([num_timesteps,
+                         batch_shape[0],
+                         num_particles,
+                         batch_shape[1]],
+                        trajectories['position'].shape)
+    self.assertAllEqual([num_timesteps,
+                         batch_shape[0],
+                         num_particles,
+                         batch_shape[1]],
+                        trajectories['velocity'].shape)
+
+    # Verify that `infer_trajectories` also works on batches.
+    trajectories, incremental_log_marginal_likelihoods = self.evaluate(
+        particle_filter.infer_trajectories(
+            observations=observed_positions,
+            initial_state_prior=initial_state_prior,
+            transition_fn=transition_fn,
+            observation_fn=observation_fn,
+            num_particles=num_particles,
+            particles_dim=1,
+            seed=test_util.test_seed()))
+
+    self.assertAllEqual([num_timesteps,
+                         batch_shape[0],
+                         num_particles,
+                         batch_shape[1]],
+                        trajectories['position'].shape)
+    self.assertAllEqual([num_timesteps,
+                         batch_shape[0],
+                         num_particles,
+                         batch_shape[1]],
+                        trajectories['velocity'].shape)
+    self.assertAllEqual(incremental_log_marginal_likelihoods.shape,
+                        [num_timesteps] + batch_shape)
+
   def test_reconstruct_trajectories_toy_example(self):
     particles = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6,], [7, 8, 9]])
     # 1  --  4  --  7
diff --git a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py
index 6cb9b65003..7304622489 100644
--- a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py
+++ b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel.py
@@ -21,6 +21,7 @@ from tensorflow_probability.python.internal import prefer_static as ps
 from tensorflow_probability.python.internal import samplers
 from tensorflow_probability.python.mcmc import kernel as kernel_base
+from tensorflow_probability.python.mcmc.internal import util as mcmc_util
 
 __all__ = [
     'SequentialMonteCarlo',
@@ -115,19 +116,21 @@ def _dummy_indices_like(indices):
                     indices_shape)
 
 
-def log_ess_from_log_weights(log_weights):
-  """Computes log-ESS estimate from log-weights along axis=0."""
+def log_ess_from_log_weights(log_weights, particles_dim=0):
+  """Computes log-ESS estimate from log-weights along axis=particles_dim."""
   with tf.name_scope('ess_from_log_weights'):
-    log_weights = tf.math.log_softmax(log_weights, axis=0)
-    return -tf.math.reduce_logsumexp(2 * log_weights, axis=0)
+    log_weights = tf.math.log_softmax(log_weights, axis=particles_dim)
+    return -tf.math.reduce_logsumexp(2 * log_weights, axis=particles_dim)
 
 
-def ess_below_threshold(weighted_particles, threshold=0.5):
+def ess_below_threshold(weighted_particles, particles_dim=0, threshold=0.5):
   """Determines if the effective sample size is much less than num_particles."""
   with tf.name_scope('ess_below_threshold'):
-    num_particles = ps.size0(weighted_particles.log_weights)
-    log_ess = log_ess_from_log_weights(weighted_particles.log_weights)
-    return log_ess < (ps.log(num_particles) + ps.log(threshold))
+    num_particles = ps.dimension_size(weighted_particles.log_weights,
+                                      particles_dim)
+    log_ess = log_ess_from_log_weights(weighted_particles.log_weights,
+                                       particles_dim)
+    return tf.expand_dims(
+        log_ess < (ps.log(num_particles) + ps.log(threshold)),
+        axis=particles_dim)
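`log_ess_from_log_weights` implements the standard estimate `ESS = 1 / sum_i w_i**2` over normalized weights, computed stably in log space as `-logsumexp(2 * log_w)`. A quick standalone check (plain TensorFlow; values invented):

    import tensorflow as tf

    log_w = tf.math.log_softmax(tf.zeros([100]), axis=0)   # uniform weights
    log_ess = -tf.math.reduce_logsumexp(2 * log_w, axis=0)
    # exp(log_ess) == 100: with equal weights every particle is "effective",
    # so the default threshold of 0.5 * num_particles triggers no resampling.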
 
 
 class SequentialMonteCarlo(kernel_base.TransitionKernel):
 
@@ -145,6 +148,7 @@ def __init__(self,
                resample_fn=weighted_resampling.resample_systematic,
                resample_criterion_fn=ess_below_threshold,
                unbiased_gradients=True,
+               particles_dim=0,
                name=None):
     """Initializes a sequential Monte Carlo transition kernel.
 
@@ -202,6 +206,7 @@ def __init__(self,
     self._resample_fn = resample_fn
     self._resample_criterion_fn = resample_criterion_fn
     self._unbiased_gradients = unbiased_gradients
+    self._particles_dim = particles_dim
     self._name = name or 'SequentialMonteCarlo'
 
   @property
@@ -269,15 +274,17 @@ def one_step(self, state, kernel_results, seed=None):
         state = tf.nest.map_structure(
             lambda a, b: tf.where(is_initial_step, a, b), state,
             proposed_state)
 
-      normalized_log_weights = tf.nn.log_softmax(state.log_weights, axis=0)
+      normalized_log_weights = tf.nn.log_softmax(state.log_weights,
+                                                 axis=self._particles_dim)
       # Every entry of `log_weights` differs from `normalized_log_weights`
       # by the same normalizing constant. We extract that constant by
       # examining an arbitrary entry.
-      incremental_log_marginal_likelihood = (state.log_weights[0] -
-                                             normalized_log_weights[0])
+      incremental_log_marginal_likelihood = (
+          tf.gather(state.log_weights, 0, axis=self._particles_dim) -
+          tf.gather(normalized_log_weights, 0, axis=self._particles_dim))
 
-      do_resample = self.resample_criterion_fn(state)
+      do_resample = self.resample_criterion_fn(state, self._particles_dim)
       # Some batch elements may require resampling and others not, so
       # we first do the resampling for all elements, then select whether to
       # use the resampled values for each batch element according to
@@ -300,11 +307,12 @@ def one_step(self, state, kernel_results, seed=None):
             resample_fn=self.resample_fn,
             target_log_weights=(normalized_log_weights
                                 if self.unbiased_gradients else None),
+            particles_dim=self._particles_dim,
             seed=resample_seed)
        (resampled_particles,
         resample_indices,
         log_weights) = tf.nest.map_structure(
-            lambda r, p: tf.where(do_resample, r, p),
+            lambda r, p: mcmc_util.choose(do_resample, r, p),
             (resampled_particles, resample_indices, weights_after_resampling),
             (state.particles, _dummy_indices_like(resample_indices),
              normalized_log_weights))
@@ -326,9 +334,13 @@ def bootstrap_results(self, init_state):
     with tf.name_scope('bootstrap_results'):
       init_state = WeightedParticles(*init_state)
 
+      particles_shape = ps.shape(init_state.log_weights)
+      weights_shape = ps.concat([
+          particles_shape[:self._particles_dim],
+          particles_shape[self._particles_dim + 1:]
+      ], axis=0)
       batch_zeros = tf.zeros(
-          ps.shape(init_state.log_weights)[1:],
-          dtype=init_state.log_weights.dtype)
+          weights_shape, dtype=init_state.log_weights.dtype)
 
       return SequentialMonteCarloResults(
           steps=0,
diff --git a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py
index 2a9302a420..8dd11dd69b 100644
--- a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py
+++ b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py
@@ -62,7 +62,8 @@ def propose_and_update_log_weights_fn(_, weighted_particles, seed=None):
     state, results = kernel.one_step(
         state=initial_state,
         kernel_results=kernel.bootstrap_results(initial_state),
-        seed=seeds[0])
+        seed=seeds[0],
+    )
     state, results = kernel.one_step(state=state,
                                      kernel_results=results,
                                      seed=seeds[1])
     state, results = self.evaluate(
diff --git a/tensorflow_probability/python/experimental/mcmc/weighted_resampling.py b/tensorflow_probability/python/experimental/mcmc/weighted_resampling.py
index 9632d916f5..613419cf46 100644
--- a/tensorflow_probability/python/experimental/mcmc/weighted_resampling.py
+++ b/tensorflow_probability/python/experimental/mcmc/weighted_resampling.py
@@ -34,7 +34,7 @@
 
 
 def resample(particles, log_weights, resample_fn, target_log_weights=None,
-             seed=None):
+             particles_dim=0, seed=None):
   """Resamples the current particles according to provided weights.
 
   Args:
@@ -54,6 +54,10 @@ def resample(particles, log_weights, resample_fn, target_log_weights=None,
       `None`, the target measure is implicitly taken to be the normalized
       log weights (`log_weights - tf.reduce_logsumexp(log_weights, axis=0)`).
       Default value: `None`.
+    particles_dim: Python `int` axis of each state `Tensor` indexing into the
+      particles. This is almost always zero, but nonzero values may be
+      necessary when running SMC in nested contexts.
+      Default value: `0`.
     seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
 
   Returns:
@@ -69,15 +73,25 @@ def resample(particles, log_weights, resample_fn, target_log_weights=None,
     resampling are uniformly equal to `-log(num_particles)`.
   """
   with tf.name_scope('resample'):
-    num_particles = ps.size0(log_weights)
+    num_particles = ps.dimension_size(log_weights, particles_dim)
+    log_num_particles = tf.math.log(tf.cast(num_particles, log_weights.dtype))
 
     # Normalize the weights and sample the ancestral indices.
-    log_probs = tf.math.log_softmax(log_weights, axis=0)
-    resampled_indices = resample_fn(log_probs, num_particles, (), seed=seed)
+    log_probs = tf.math.log_softmax(log_weights, axis=particles_dim)
+    if particles_dim == 0:
+      # For resample functions that don't yet support the
+      # particles_dim argument.
+      resampled_indices = resample_fn(log_probs, num_particles, (), seed=seed)
+    else:
+      resampled_indices = resample_fn(log_probs, num_particles, (),
+                                      particles_dim=particles_dim, seed=seed)
 
     gather_ancestors = lambda x: (  # pylint: disable=g-long-lambda
-        mcmc_util.index_remapping_gather(x, resampled_indices, axis=0))
+        mcmc_util.index_remapping_gather(x,
+                                         resampled_indices,
+                                         axis=particles_dim,
+                                         indices_axis=particles_dim))
     resampled_particles = tf.nest.map_structure(gather_ancestors, particles)
     if target_log_weights is None:
       log_weights_after_resampling = tf.fill(ps.shape(log_weights),
@@ -242,7 +256,7 @@ def resample_independent(log_probs, event_size, sample_shape,
 
 # TODO(b/153689734): rewrite so as not to use `move_dimension`.
 def resample_systematic(log_probs, event_size, sample_shape,
-                        seed=None, name=None):
+                        particles_dim=0, seed=None, name=None):
   """A systematic resampler for sequential Monte Carlo.
 
   The value returned from this function is similar to sampling with
@@ -272,6 +286,9 @@ def resample_systematic(log_probs, event_size, sample_shape,
       The remaining dimensions are batch dimensions.
     event_size: the dimension of the vector considered a single draw.
     sample_shape: the `sample_shape` determining the number of draws.
+    particles_dim: Python `int` axis of each state `Tensor` indexing into the
+      particles. This is almost always zero, but nonzero values may be
+      necessary when running SMC in nested contexts.
     seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
       Default value: None (i.e. no seed).
     name: Python `str` name for ops created by this method.
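For reviewers unfamiliar with the algorithm: systematic resampling draws one shared uniform offset and probes the CDF of the normalized weights at evenly spaced points, which is why the operation can be phrased along a single axis (moved to the last dimension and back in the hunk below). A standalone NumPy sketch of the idea, not the library implementation:

    import numpy as np

    def systematic_resample_indices(log_probs, rng):
      # Normalize, then probe the CDF at u/N, (1+u)/N, ..., (N-1+u)/N
      # for a single shared uniform offset u.
      probs = np.exp(log_probs - np.logaddexp.reduce(log_probs))
      n = len(probs)
      points = (rng.uniform() + np.arange(n)) / n
      return np.searchsorted(np.cumsum(probs), points)

    rng = np.random.default_rng(0)
    indices = systematic_resample_indices(np.log([0.1, 0.2, 0.3, 0.4]), rng)
    # High-weight particles appear proportionally more often in `indices`.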
@@ -293,7 +310,9 @@ def resample_systematic(log_probs, event_size, sample_shape,
   """
   with tf.name_scope(name or 'resample_systematic') as name:
     log_probs = tf.convert_to_tensor(log_probs, dtype_hint=tf.float32)
-    log_probs = dist_util.move_dimension(log_probs, source_idx=0, dest_idx=-1)
+    log_probs = dist_util.move_dimension(log_probs,
+                                         source_idx=particles_dim,
+                                         dest_idx=-1)
     working_shape = ps.concat([sample_shape, ps.shape(log_probs)[:-1]],
                               axis=0)
     points_shape = ps.concat([working_shape, [event_size]], axis=0)
@@ -310,7 +329,9 @@ def resample_systematic(log_probs, event_size, sample_shape,
       log_points = tf.broadcast_to(tf.math.log(even_spacing), points_shape)
 
     resampled = _resample_using_log_points(log_probs, sample_shape, log_points)
-    return dist_util.move_dimension(resampled, source_idx=-1, dest_idx=0)
+    return dist_util.move_dimension(resampled,
+                                    source_idx=-1,
+                                    dest_idx=particles_dim)
 
 
 # TODO(b/153689734): rewrite so as not to use `move_dimension`.
diff --git a/tensorflow_probability/python/experimental/mcmc/weighted_resampling_test.py b/tensorflow_probability/python/experimental/mcmc/weighted_resampling_test.py
index e415b4c99e..ace87de1e1 100644
--- a/tensorflow_probability/python/experimental/mcmc/weighted_resampling_test.py
+++ b/tensorflow_probability/python/experimental/mcmc/weighted_resampling_test.py
@@ -299,6 +299,48 @@ def resample_with_target_distribution(self):
         tf.reduce_sum(tf.nn.softmax(new_log_weights) * new_particles),
         30., atol=1.)
 
+  def test_with_target_distribution_dim_one(self):
+    stacked_particles = np.stack([
+        np.linspace(0., 500., num=2500, dtype=np.float32),
+        np.linspace(0.17, 433., num=2500, dtype=np.float32),
+    ], axis=0)
+
+    stacked_log_weights = poisson.Poisson(20.).log_prob(stacked_particles)
+
+    # Resample particles to target a Poisson(20.) distribution.
+    new_particles, _, new_log_weights = resample(
+        stacked_particles, stacked_log_weights,
+        resample_fn=resample_systematic,
+        particles_dim=1,
+        seed=test_util.test_seed(sampler_type='stateless'))
+
+    self.assertAllMeansClose(new_particles,
+                             [20., 20.],
+                             axis=1,
+                             atol=1e-2)
+    self.assertAllClose(
+        tf.reduce_sum(tf.nn.softmax(new_log_weights) * new_particles, axis=1),
+        [20., 20.],
+        atol=1e-2)
+
+    # Reweight the resampled particles to target a Poisson(30.) distribution.
+ new_particles, _, new_log_weights = resample( + stacked_particles, + stacked_log_weights, + resample_fn=resample_systematic, + particles_dim=1, + target_log_weights=poisson.Poisson(30).log_prob(stacked_particles), + seed=test_util.test_seed(sampler_type='stateless')) + self.assertAllMeansClose(new_particles, + [20., 20.], + axis=1, + atol=1e-2) + + self.assertAllClose( + tf.reduce_sum(tf.nn.softmax(new_log_weights) * new_particles, axis=1), + [30., 30.], + atol=1.5) + def maybe_compiler(self, f): if self.use_xla: return tf.function(f, autograph=False, jit_compile=True) From fc1e466d7d24cbc96cb249eff55a73b164dc9189 Mon Sep 17 00:00:00 2001 From: slamitza Date: Tue, 21 Nov 2023 23:04:35 +0100 Subject: [PATCH 7/7] whitespace --- .../python/experimental/mcmc/particle_filter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py index 8bb2a36f34..8e23166adf 100644 --- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py +++ b/tensorflow_probability/python/experimental/mcmc/particle_filter.py @@ -663,7 +663,7 @@ def _compute_observation_log_weights(step, observations) observation = tf.nest.map_structure(lambda x: - tf.expand_dims(x, axis=particles_dim), + tf.expand_dims(x, axis=particles_dim), observation) log_weights = observation_fn(step, particles).log_prob(observation)
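Taken together, the series threads `particles_dim` from resampling through trajectory reconstruction. As a closing illustration of the gather primitive those functions rely on, a tiny sketch of one backward step of `reconstruct_trajectories` (toy values; `index_remapping_gather` is the internal utility already imported above as `mcmc_util`):

    import tensorflow as tf
    from tensorflow_probability.python.mcmc.internal import util as mcmc_util

    # Particles at time t and the parent index chosen for each final particle.
    particles_t = tf.constant([10., 11., 12.])
    parents = tf.constant([0, 0, 2])

    ancestors = mcmc_util.index_remapping_gather(particles_t, parents, axis=0)
    # ancestors == [10., 10., 12.]: two final particles share ancestor 0,
    # which is exactly the structure the toy-example test above diagrams.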