Add CalNPELoss
Maciej Falkiewicz committed Dec 31, 2023
1 parent 7be0634 commit 9c472ec
Showing 2 changed files with 136 additions and 7 deletions.
125 changes: 124 additions & 1 deletion lampe/inference.py
@@ -11,6 +11,7 @@
    'AMNRELoss',
    'NPE',
    'NPELoss',
    'CalNPELoss',
    'FMPE',
    'FMPELoss',
    'MetropolisHastings',
@@ -24,7 +25,7 @@
import torchsort

from itertools import islice
from torch import Tensor, BoolTensor
from torch import Tensor, BoolTensor, Size
from torch.distributions import Distribution
from typing import *

@@ -719,6 +720,128 @@ def forward(self, theta: Tensor, x: Tensor) -> Tensor:
        return -log_p.mean()


class CalNPELoss(nn.Module):
    r"""Creates a module that calculates the negative log-likelihood loss and
    the calibration/conservativeness loss for an NPE normalizing flow.

    Given a batch of :math:`N \geq 2` pairs :math:`(\theta_i, x_i)`, the module returns

    .. math::

        l & = \frac{1}{N} \sum_{i = 1}^N -\log p_\phi(\theta_i | x_i) \\
          & + \lambda \max_j | \text{ECP}(1 - \alpha_j) - (1 - \alpha_j) |

    where :math:`\ell(p) = -\log p` is the negative log-likelihood and
    :math:`\text{ECP}(1 - \alpha_j)` is the Expected Coverage Probability at
    credibility level :math:`1 - \alpha_j`.

    References:
        | Calibrating Neural Simulation-Based Inference with Differentiable Coverage Probability (Falkiewicz et al., 2023)
        | https://arxiv.org/abs/2310.13402

    Arguments:
        estimator: A normalizing flow :math:`p_\phi(\theta | x)`.
        lmbda: The weight :math:`\lambda` controlling the strength of the regularizer.
        n_samples: The number of samples in the Monte Carlo estimate of the rank statistic (default: 16).
        calibration: Whether to use the two-sided calibration objective instead of the
            one-sided conservativeness objective (default: False).
        sort_kwargs: Keyword arguments of the differentiable sorting algorithm, see
            [doc](https://github.com/teddykoker/torchsort#usage) (default: None).
        vmap_chunk_size: Chunk size for vectorization, see
            [doc](https://pytorch.org/docs/stable/generated/torch.vmap.html) (default: None).
    """

    def __init__(
        self,
        estimator: nn.Module,
        lmbda: float = 1.0,
        n_samples: int = 16,
        calibration: bool = False,
        sort_kwargs: dict = None,
        vmap_chunk_size: int = None,
    ):
        super().__init__()

        self.estimator = estimator
        self.lmbda = lmbda
        self.n_samples = n_samples

        # Two-sided penalty for calibration, one-sided (ReLU) penalty for conservativeness.
        self.activation = torch.abs if calibration else torch.nn.ReLU()

        self.sort_kwargs = sort_kwargs if sort_kwargs is not None else {}
        self.vmap_chunk_size = vmap_chunk_size

    def rsample_and_log_prob(
        self, x: Tensor, shape: Size = ()
    ) -> Tuple[Tensor, Tensor]:
        r"""
        Arguments:
            x: The observation :math:`x`, with shape :math:`(*, L)`.
            shape: The shape of the samples drawn per observation.

        Returns:
            A tuple containing samples :math:`\theta \sim p_\phi(\theta | x)`,
            with shape :math:`(*, *shape, D)`, and their log-densities
            :math:`\log p_\phi(\theta | x)`, with shape :math:`(*, *shape)`.
        """

        # Sample (with reparameterization) from the conditional flow, then move the
        # batch dimension to the front (assuming a one-dimensional sample shape,
        # as used in this class).
        return tuple(
            map(
                lambda t: t.movedim(1, 0),
                self.estimator.flow(x).rsample_and_log_prob(shape),
            )
        )

    def forward(self, theta: Tensor, x: Tensor) -> Tensor:
        r"""
        Arguments:
            theta: The parameters :math:`\theta`, with shape :math:`(N, D)`.
            x: The observation :math:`x`, with shape :math:`(N, L)`.

        Returns:
            The scalar loss :math:`l`.
        """

        log_p = self.estimator(theta, x)
        lr = self.regularizer(theta, x, log_p)

        return -log_p.mean() + self.lmbda * lr

    def get_cdfs(self, ranks: Tensor) -> Tuple[Tensor, Tensor]:
        r"""Returns the evenly spaced target CDF levels and the soft-sorted rank
        statistics they are compared against."""

        # Differentiable sorting of the rank statistics.
        alpha = torchsort.soft_sort(ranks.unsqueeze(0), **self.sort_kwargs).squeeze()

        return (
            torch.linspace(0.0, 1.0, len(alpha) + 1, device=alpha.device)[1:],
            alpha,
        )

    def get_rank_statistics(self, x: Tensor, nominal_logq: Tensor) -> Tensor:
        r"""Estimates the rank statistic of each nominal density among the densities
        of samples drawn from the posterior estimator."""

        # Stack the nominal log-density with the log-densities of `n_samples`
        # posterior samples, and map to densities.
        q = torch.cat(
            [
                nominal_logq.unsqueeze(-1),
                self.rsample_and_log_prob(x, (self.n_samples,))[1],
            ],
            dim=1,
        ).exp()

        # Differentiable comparison of the nominal density against each sampled
        # density (via the STEhardtanh straight-through estimator, defined elsewhere),
        # averaged over samples to obtain the rank statistic.
        return STEhardtanh.apply(q[:, 0].unsqueeze(1) - q[:, 1:]).mean(dim=1)

    def regularizer(self, theta: Tensor, x: Tensor, nominal_logq: Tensor = None) -> Tensor:
        r"""
        Arguments:
            theta: The parameters :math:`\theta`, with shape :math:`(N, D)`.
            x: The observation :math:`x`, with shape :math:`(N, L)`.
            nominal_logq: Pre-computed log-densities :math:`\log p_\phi(\theta | x)`,
                with shape :math:`(N,)`. If None, they are recomputed from the estimator.

        Returns:
            The scalar regularizer term :math:`r`.
        """

        ranks = self.get_rank_statistics(
            x, nominal_logq if nominal_logq is not None else self.estimator(theta, x)
        )
        target_cdf, ecdf = self.get_cdfs(ranks)

        # Largest penalized deviation of the empirical coverage from the target coverage.
        return self.activation(target_cdf - ecdf).max()


class FMPE(nn.Module):
r"""Creates a flow matching posterior estimation (FMPE) network.
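For readers skimming the diff, here is a minimal, hypothetical usage sketch (not part of this commit) of the new loss in a training loop; the toy data, batch size, number of steps, and optimizer settings are illustrative assumptions:

```python
import torch

from lampe.inference import NPE, CalNPELoss

estimator = NPE(3, 5)  # 3 parameters, 5 observed features, as in the tests below
loss = CalNPELoss(
    estimator,
    lmbda=1.0,          # weight of the coverage regularizer
    n_samples=16,       # Monte Carlo samples for the rank statistic
    calibration=False,  # default: one-sided (conservativeness) objective
)
optimizer = torch.optim.Adam(estimator.parameters(), lr=1e-3)

for _ in range(8):  # toy loop over random (theta, x) batches
    theta, x = torch.randn(256, 3), torch.randn(256, 5)

    optimizer.zero_grad()
    l = loss(theta, x)  # NLL plus lmbda times the largest coverage deviation
    l.backward()
    optimizer.step()
```

Judging from the `activation` chosen in `__init__`, the default conservativeness objective appears to penalize only credible regions that cover less than their nominal level, whereas `calibration=True` penalizes deviations in both directions.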
18 changes: 12 additions & 6 deletions tests/test_inference.py
@@ -65,7 +65,9 @@ def test_AMNRE():

    assert log_r.shape == (256,)

    grad = torch.autograd.functional.jacobian(lambda theta: estimator(theta, x, b).sum(), theta)
    grad = torch.autograd.functional.jacobian(
        lambda theta: estimator(theta, x, b).sum(), theta
    )

    assert (grad[~b] == 0).all()

@@ -128,14 +130,18 @@ def test_NPE():

def test_NPELoss():
    estimator = NPE(3, 5)
    loss = NPELoss(estimator)
    losses = [
        NPELoss(estimator),
        CalNPELoss(estimator),
    ]

    theta, x = randn(256, 3), randn(256, 5)
    for loss in losses:
        theta, x = randn(256, 3), randn(256, 5)

    l = loss(theta, x)
        l = loss(theta, x)

    assert l.shape == ()
    assert l.requires_grad
        assert l.shape == ()
        assert l.requires_grad


def test_FMPE():
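A possible follow-up check, sketched here as an assumption rather than part of the commit, would exercise the regularizer term directly (this relies on `CalNPELoss.regularizer` remaining callable on its own):

```python
import torch

from lampe.inference import NPE, CalNPELoss

def test_CalNPELoss_regularizer():  # hypothetical extra test, not in this commit
    estimator = NPE(3, 5)
    loss = CalNPELoss(estimator, n_samples=8)

    theta, x = torch.randn(64, 3), torch.randn(64, 5)
    r = loss.regularizer(theta, x)

    assert r.shape == ()      # the coverage deviation is a scalar
    assert r.item() >= 0.0    # ReLU/abs activations make it non-negative
```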