Skip to content

Commit

Permalink
Adds util to compute binary predictive posterior variance
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 339414502
  • Loading branch information
Edward2 Team authored and edward-bot committed Oct 28, 2020
1 parent 7321555 commit 3368fa3
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion edward2/tensorflow/layers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,27 @@ def smart_constant_value(pred):
return pred_value


def mean_field_binary_predictive_variance(logits, covmat, mean_field_factor=1.):
"""Compute predictive variance for Laplace-approximated logit posterior, assuming sigmoid link.
Arguments:
logits: A float tensor of shape (batch_size, num_classes).
covmat: A float tensor of shape (batch_size, batch_size).
mean_field_factor: The scale factor for mean-field approximation, used to
adjust the influence of posterior variance in posterior mean
approximation.
Returns:
Mean-field posterior variance.
"""
logits_scale = tf.sqrt(1. + tf.linalg.diag_part(covmat) * mean_field_factor)
logits = logits / tf.expand_dims(logits_scale, axis=-1)
posterior_mean = tf.sigmoid(tf.squeeze(logits, axis=(1,)))

return posterior_mean * (1 - posterior_mean) * (1 / logits_scale)


def mean_field_logits(logits, covmat, mean_field_factor=1.):
"""Adjust the SNGP logits so its softmax approximates posterior mean [1].
Expand All @@ -356,7 +377,7 @@ def mean_field_logits(logits, covmat, mean_field_factor=1.):
approximation.
Returns:
True or False if `pred` has a constant boolean value, None otherwise.
Calibrated logits.
"""
logits_scale = tf.sqrt(1. + tf.linalg.diag_part(covmat) * mean_field_factor)
Expand Down

0 comments on commit 3368fa3

Please sign in to comment.