Skip to content

Commit

Permalink
Fixes for push
Browse files Browse the repository at this point in the history
  • Loading branch information
aleslamitz committed Nov 11, 2023
1 parent 5efcb09 commit 37b01c1
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 194 deletions.
107 changes: 8 additions & 99 deletions tensorflow_probability/python/experimental/mcmc/particle_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,6 @@ def _default_trace_fn(state, kernel_results):
kernel_results.incremental_log_marginal_likelihood)


def _identity_rejuvenation(particles, log_weights, particles_dim, extra, seed):
return particles, log_weights


def _default_extra_fn(step,
state,
seed
):
return state.extra


particle_filter_arg_str = """\
Each latent state is a `Tensor` or nested structure of `Tensor`s, as defined
by the `initial_state_prior`.
Expand Down Expand Up @@ -141,8 +130,6 @@ def infer_trajectories(observations,
resample_fn=weighted_resampling.resample_systematic,
resample_criterion_fn=smc_kernel.ess_below_threshold,
unbiased_gradients=True,
rejuvenation_fn=_identity_rejuvenation,
rejuvenation_criterion_fn=None,
num_transitions_per_observation=1,
seed=None,
name=None): # pylint: disable=g-doc-args
Expand Down Expand Up @@ -253,9 +240,6 @@ def observation_fn(_, state):
pf_seed, resample_seed = samplers.split_seed(
seed, salt='infer_trajectories')

if rejuvenation_criterion_fn is None:
rejuvenation_criterion_fn = lambda *_: False

(particles,
log_weights,
parent_indices,
Expand All @@ -270,8 +254,6 @@ def observation_fn(_, state):
resample_fn=resample_fn,
resample_criterion_fn=resample_criterion_fn,
unbiased_gradients=unbiased_gradients,
rejuvenation_fn=rejuvenation_fn,
rejuvenation_criterion_fn=rejuvenation_criterion_fn,
num_transitions_per_observation=num_transitions_per_observation,
trace_fn=_default_trace_fn,
trace_criterion_fn=lambda *_: True,
Expand Down Expand Up @@ -303,7 +285,6 @@ def sequential_monte_carlo(loop_seed,
resample_criterion_fn,
unbiased_gradients,
trace_fn,
extra_fn=_default_extra_fn,
particles_dim=0,
static_trace_allocation_size=None,
never_trace=lambda *_: False,
Expand Down Expand Up @@ -412,13 +393,10 @@ def particle_filter(observations,
resample_fn=weighted_resampling.resample_systematic,
resample_criterion_fn=smc_kernel.ess_below_threshold,
unbiased_gradients=True,
rejuvenation_fn=_identity_rejuvenation,
rejuvenation_criterion_fn=None,
num_transitions_per_observation=1,
trace_fn=_default_trace_fn,
trace_criterion_fn=_always_trace,
static_trace_allocation_size=None,
extra_fn=_default_extra_fn,
parallel_iterations=1,
seed=None,
name=None): # pylint: disable=g-doc-args
Expand Down Expand Up @@ -476,9 +454,6 @@ def particle_filter(observations,
num_timesteps = (
1 + num_transitions_per_observation * (num_observation_steps - 1))

if rejuvenation_criterion_fn is None:
rejuvenation_criterion_fn = lambda *_: tf.constant(False)

# If trace criterion is `None`, we'll return only the final results.
never_trace = lambda *_: False
if trace_criterion_fn is None:
Expand All @@ -490,7 +465,7 @@ def particle_filter(observations,
observation_fn=observation_fn,
initial_state_prior=initial_state_prior,
initial_state_proposal=initial_state_proposal,
num_inner_particles=num_particles,
num_particles=num_particles,
particles_dim=particles_dim,
seed=init_seed)
propose_and_update_log_weights_fn = (
Expand All @@ -500,9 +475,6 @@ def particle_filter(observations,
proposal_fn=proposal_fn,
observation_fn=observation_fn,
particles_dim=particles_dim,
extra_fn=extra_fn,
rejuvenation_fn=rejuvenation_fn,
rejuvenation_criterion_fn=rejuvenation_criterion_fn,
num_transitions_per_observation=num_transitions_per_observation))

traced_results = sequential_monte_carlo(
Expand All @@ -524,43 +496,21 @@ def particle_filter(observations,
return traced_results


def sample_at_dim(initial_state_prior, dim, num_samples, seed=None):
if type(initial_state_prior.batch_shape) is dict:
model_dict = initial_state_prior.model
sampled_model = {}

for key in model_dict.keys():
d = model_dict[key]
batch_shape = d.batch_shape
d = batch_reshape.BatchReshape(d, batch_shape[:dim] + [1] + batch_shape[dim:])
d = batch_broadcast.BatchBroadcast(d, batch_shape[:dim] + [num_samples] + batch_shape[dim:])
sampled_model[key] = d.sample(seed=seed)

return sampled_model

else:
batch_shape = initial_state_prior.batch_shape
initial_state_prior = batch_reshape.BatchReshape(initial_state_prior, batch_shape[:dim] + [1] + batch_shape[dim:])
initial_state_prior = batch_broadcast.BatchBroadcast(initial_state_prior, batch_shape[:dim] + [num_samples] + batch_shape[dim:])
return initial_state_prior.sample(seed=seed)


def _particle_filter_initial_weighted_particles(observations,
observation_fn,
initial_state_prior,
initial_state_proposal,
num_inner_particles,
num_particles,
particles_dim=0,
extra=np.nan,
seed=None):
"""Initialize a set of weighted particles including the first observation."""
# Propose an initial state.
if initial_state_proposal is None:
if particles_dim == 0:
initial_state = initial_state_prior.sample(num_inner_particles, seed=seed)
initial_state = initial_state_prior.sample(num_particles, seed=seed)
initial_log_weights = ps.zeros_like(initial_state_prior.log_prob(initial_state))
else:
particles_draw = initial_state_prior.sample(num_inner_particles)
particles_draw = initial_state_prior.sample(num_particles)
initial_state = tf.nest.map_structure(
lambda x: dist_util.move_dimension(x, source_idx=0, dest_idx=particles_dim),
particles_draw
Expand All @@ -569,31 +519,22 @@ def _particle_filter_initial_weighted_particles(observations,
dist_util.move_dimension(initial_state_prior.log_prob(particles_draw), source_idx=0, dest_idx=particles_dim)
)
else:
initial_state = initial_state_proposal.sample(num_inner_particles, seed=seed)
initial_state = initial_state_proposal.sample(num_particles, seed=seed)
initial_log_weights = (initial_state_prior.log_prob(initial_state) -
initial_state_proposal.log_prob(initial_state))
# Normalize the initial weights. If we used a proposal, the weights are
# normalized in expectation, but actually normalizing them reduces variance.
initial_log_weights = tf.nn.log_softmax(initial_log_weights, axis=particles_dim)

# Return particles weighted by the initial observation.

if extra is np.nan:
if len(ps.shape(initial_log_weights)) == 1:
# initial extra for particle filter
extra = tf.constant(0)
else:
# initial extra for inner particles of smc_squared
extra = tf.constant(0, shape=ps.shape(initial_log_weights))

return smc_kernel.WeightedParticles(
particles=initial_state,
log_weights=initial_log_weights + _compute_observation_log_weights(
step=0,
particles=initial_state,
observations=observations,
observation_fn=observation_fn),
extra=extra)
observation_fn=observation_fn)
)


def _particle_filter_propose_and_update_log_weights_fn(
Expand All @@ -602,9 +543,6 @@ def _particle_filter_propose_and_update_log_weights_fn(
proposal_fn,
observation_fn,
num_transitions_per_observation=1,
extra_fn=_default_extra_fn,
rejuvenation_criterion_fn=None,
rejuvenation_fn=_identity_rejuvenation,
particles_dim=0):
"""Build a function specifying a particle filter update step."""
def propose_and_update_log_weights_fn(step, state, seed=None):
Expand Down Expand Up @@ -634,42 +572,13 @@ def propose_and_update_log_weights_fn(step, state, seed=None):
else:
proposed_particles = transition_dist.sample(seed=seed)

if rejuvenation_criterion_fn == None:
do_rejuvenation = False
else:
do_rejuvenation = rejuvenation_criterion_fn(state, particles_dim)

[
rej_particles,
rej_log_weights
] = rejuvenation_fn(
particles=particles,
log_weights=tf.stop_gradient(state.log_weights),
particles_dim=particles_dim,
extra=state.extra,
seed=seed)

(
proposed_particles,
log_weights
) = tf.nest.map_structure(
lambda r, p: mcmc_util.choose(do_rejuvenation, r, p),
(rej_particles, rej_log_weights),
(proposed_particles, log_weights))

proposed_extra = extra_fn(
step,
state,
seed
)

with tf.control_dependencies(assertions):
return smc_kernel.WeightedParticles(
particles=proposed_particles,
log_weights=log_weights + _compute_observation_log_weights(
step + 1, proposed_particles, observations, observation_fn,
num_transitions_per_observation=num_transitions_per_observation),
extra=proposed_extra)
)
return propose_and_update_log_weights_fn


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -615,83 +615,6 @@ def marginal_log_likelihood(level_scale, noise_scale):
self.assertAllAssertsNested(self.assertNotAllZero, grads)


def test_extra(self):
def step_hundred(step,
state,
seed
):
return step + 100

results = self.evaluate(
particle_filter.particle_filter(
observations=tf.convert_to_tensor([1., 3., 5., 7., 9.]),
initial_state_prior=normal.Normal(1., 0.2),
transition_fn=lambda _, state: normal.Normal(state + 2, 0.5),
observation_fn=lambda _, state: normal.Normal(state, 0.5),
num_particles=64,
extra_fn=step_hundred,
trace_fn=lambda s, r: s.extra,
seed=test_util.test_seed())
)

self.assertAllEqual(results[1:], [100, 101, 102, 103])

def test_rejuvenation_fn(self):
# A simple HMM with 10 hidden states
stream = test_util.test_seed_stream()
d = hidden_markov_model.HiddenMarkovModel(
initial_distribution=categorical.Categorical(logits=tf.zeros(10)),
transition_distribution=categorical.Categorical(logits=tf.zeros((10, 10))),
observation_distribution=normal.Normal(loc=tf.range(10.), scale=0.3),
num_steps=10
)
observation = categorical.Categorical(
logits=[0] * 10,
dtype=tf.float32).sample(10, seed=test_util.test_seed())

# A dimension for each particle of the particles filters
observations = tf.reshape(tf.tile(observation, [10]),
[10, tf.shape(observation)[0]])

def rejuvenation_fn(particles, log_weights, particles_dim, extra, seed):
posterior = d.posterior_marginals(observation).sample(seed=test_util.test_seed())
initial_weights = ps.zeros_like(posterior, dtype=tf.float32)
initial_log_weights = tf.nn.log_softmax(initial_weights)
return posterior, initial_log_weights

def rejuvenation_criterion_fn(state, particles_dim):
return True

rej_particles, _, _, _ =\
particle_filter.particle_filter(
observations=observation,
initial_state_prior=d.initial_distribution,
transition_fn=lambda _, s: categorical.Categorical(logits=tf.zeros(s.shape + tuple([10]))),
observation_fn=lambda _, s: normal.Normal(loc=tf.cast(s, tf.float32), scale=0.3),
rejuvenation_criterion_fn=rejuvenation_criterion_fn,
rejuvenation_fn=rejuvenation_fn,
num_particles=10,
seed=test_util.test_seed()
)

delta_rej = tf.where(observations - tf.cast(rej_particles, tf.float32) != 0, 1, 0)

nonrej_particles, _, _, _ =\
particle_filter.particle_filter(
observations=observation,
initial_state_prior=d.initial_distribution,
transition_fn=lambda _, s: categorical.Categorical(logits=tf.zeros(s.shape + tuple([10]))),
observation_fn=lambda _, s: normal.Normal(loc=tf.cast(s, tf.float32), scale=0.3),
num_particles=10,
seed=test_util.test_seed()
)
delta_nonrej = tf.where(observations - tf.cast(nonrej_particles, tf.float32) != 0, 1, 0)

delta = tf.reduce_sum(delta_nonrej - delta_rej)

self.assertAllGreaterEqual(self.evaluate(delta), 0)


# TODO(b/186068104): add tests with dynamic shapes.
class ParticleFilterTestFloat32(_ParticleFilterTest):
dtype = np.float32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

# SequentialMonteCarlo `state` structure.
class WeightedParticles(collections.namedtuple(
'WeightedParticles', ['particles', 'log_weights', 'extra'])):
'WeightedParticles', ['particles', 'log_weights'])):
"""Particles with corresponding log weights.
This structure serves as the `state` for the `SequentialMonteCarlo` transition
Expand All @@ -50,9 +50,6 @@ class WeightedParticles(collections.namedtuple(
`exp(reduce_logsumexp(log_weights, axis=0)) == 1.`. These must be used in
conjunction with `particles` to compute expectations under the target
distribution.
extra: a (structure of) Tensor(s) each of shape
`concat([[b1, ..., bN], event_shape])`, where `event_shape`
may differ across component `Tensor`s.
In some contexts, particles may be stacked across multiple inference steps,
in which case all `Tensor` shapes will be prefixed by an additional dimension
Expand Down Expand Up @@ -319,8 +316,7 @@ def one_step(self, state, kernel_results, seed=None):
normalized_log_weights))

return (WeightedParticles(particles=resampled_particles,
log_weights=log_weights,
extra=state.extra),
log_weights=log_weights),
SequentialMonteCarloResults(
steps=kernel_results.steps + 1,
parent_indices=resample_indices,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
]


def resample(particles, log_weights, resample_fn, target_log_weights=None, particles_dim=0,
seed=None):
def resample(particles, log_weights, resample_fn, target_log_weights=None,
particles_dim=0, seed=None):
"""Resamples the current particles according to provided weights.
Args:
Expand Down Expand Up @@ -250,8 +250,8 @@ def resample_independent(log_probs, event_size, sample_shape,


# TODO(b/153689734): rewrite so as not to use `move_dimension`.
def resample_systematic(log_probs, event_size, sample_shape, particles_dim=0,
seed=None, name=None):
def resample_systematic(log_probs, event_size, sample_shape,
particles_dim=0, seed=None, name=None):
"""A systematic resampler for sequential Monte Carlo.
The value returned from this function is similar to sampling with
Expand Down Expand Up @@ -302,7 +302,9 @@ def resample_systematic(log_probs, event_size, sample_shape, particles_dim=0,
"""
with tf.name_scope(name or 'resample_systematic') as name:
log_probs = tf.convert_to_tensor(log_probs, dtype_hint=tf.float32)
log_probs = dist_util.move_dimension(log_probs, source_idx=particles_dim, dest_idx=-1)
log_probs = dist_util.move_dimension(log_probs,
source_idx=particles_dim,
dest_idx=-1)
working_shape = ps.concat([sample_shape,
ps.shape(log_probs)[:-1]], axis=0)
points_shape = ps.concat([working_shape, [event_size]], axis=0)
Expand All @@ -319,7 +321,9 @@ def resample_systematic(log_probs, event_size, sample_shape, particles_dim=0,
log_points = tf.broadcast_to(tf.math.log(even_spacing), points_shape)

resampled = _resample_using_log_points(log_probs, sample_shape, log_points)
return dist_util.move_dimension(resampled, source_idx=-1, dest_idx=particles_dim)
return dist_util.move_dimension(resampled,
source_idx=-1,
dest_idx=particles_dim)


# TODO(b/153689734): rewrite so as not to use `move_dimension`.
Expand Down
Loading

0 comments on commit 37b01c1

Please sign in to comment.