Add CalNRELoss
Maciej Falkiewicz committed Dec 30, 2023
1 parent 19d82ee commit 31247d8
Showing 2 changed files with 151 additions and 0 deletions.
149 changes: 149 additions & 0 deletions lampe/inference.py
@@ -6,6 +6,7 @@
'BNRELoss',
'CNRELoss',
'BCNRELoss',
'CalNRELoss',
'AMNRE',
'AMNRELoss',
'NPE',
@@ -19,6 +20,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import vmap
import torchsort

from itertools import islice
from torch import Tensor, BoolTensor
@@ -199,6 +202,152 @@ def forward(self, theta: Tensor, x: Tensor) -> Tensor:
return (l1 + l0) / 2 + self.lmbda * lb


class STEhardtanh(torch.autograd.Function):
    r"""Heaviside step function with a straight-through (hardtanh) gradient."""

    @staticmethod
    def forward(ctx, input):
        # Hard threshold: 1 where the input is positive, 0 elsewhere.
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the incoming gradient, clamped to [-1, 1].
        return F.hardtanh(grad_output)


class CalNRELoss(nn.Module):
r"""Creates a module that calculates the calibration/conservativeness loss for an
NRE network.
    Given a batch of :math:`N \geq 2` pairs :math:`(\theta_i, x_i)`, the module returns

    .. math::
        l & = \frac{1}{2N} \sum_{i = 1}^N
        \ell(d_\phi(\theta_i, x_i)) + \ell(1 - d_\phi(\theta_{i+1}, x_i)) \\
        & + \lambda \max_j | \text{ECP}(1 - \alpha_j) - (1 - \alpha_j) |

    where :math:`\ell(p) = -\log p` is the negative log-likelihood and
    :math:`\text{ECP}(1 - \alpha_j)` is the Expected Coverage Probability at
    credibility level :math:`1 - \alpha_j`.

    References:
        | Calibrating Neural Simulation-Based Inference with Differentiable Coverage Probability (Falkiewicz et al., 2023)
        | https://arxiv.org/abs/2310.13402

    Arguments:
        estimator: A log-ratio network :math:`\log r_\phi(\theta, x)`.
        prior: Prior distribution module :math:`p(\theta)` providing `log_prob` and `sample`.
        lmbda: The weight :math:`\lambda` controlling the strength of the regularizer.
        n_samples: The number of prior samples used in the Monte Carlo estimate of the
            rank statistic.
        calibration: If `True`, deviations from calibration in both directions are
            penalized; otherwise only overconfidence is penalized, which enforces
            conservativeness (default: False).
        sort_kwargs: Arguments of the differentiable sorting algorithm, see
            [doc](https://github.com/teddykoker/torchsort#usage) (default: None).
        vmap_chunk_size: Chunk size for vectorization, see
            [doc](https://pytorch.org/docs/stable/generated/torch.vmap.html) (default: None).
    """

def __init__(
self,
estimator: nn.Module,
prior: nn.Module,
lmbda: float = 5.0,
n_samples: int = 16,
calibration: bool = False,
sort_kwargs: dict = None,
vmap_chunk_size: int = None,
):
super().__init__()

self.estimator = estimator
self.prior = prior
self.lmbda = lmbda
self.n_samples = n_samples
        if calibration:
            # Calibration objective: penalize coverage deviations in both directions.
            self.activation = lambda input: torch.abs(input)
        else:
            # Conservativeness objective: penalize only overconfident (under-covering) deviations.
            self.activation = torch.nn.ReLU()
if sort_kwargs is None:
self.sort_kwargs = dict()
else:
self.sort_kwargs = sort_kwargs
self.vmap_chunk_size = vmap_chunk_size

    def log_prob(self, parameter, observation):
        # Log-posterior surrogate: log p(theta) + log r_phi(theta, x).
        return self.prior.log_prob(parameter) + self.estimator(parameter, observation)

def forward(self, theta: Tensor, x: Tensor) -> Tensor:
r"""
Arguments:
theta: The parameters :math:`\theta`, with shape :math:`(N, D)`.
x: The observation :math:`x`, with shape :math:`(N, L)`.
Returns:
The scalar loss :math:`l`.
"""

theta_prime = torch.roll(theta, 1, dims=0)

log_r, log_r_prime = self.estimator(
torch.stack((theta, theta_prime)),
x,
)

l1 = -F.logsigmoid(log_r).mean()
l0 = -F.logsigmoid(-log_r_prime).mean()
lr = self.regularizer(theta, x)

return (l1 + l0) / 2 + self.lmbda * lr

    def get_cdfs(self, ranks):
        # Differentiable (soft) sort of the rank statistics, compared against the uniform
        # grid i/N expected under perfect calibration.
        alpha = torchsort.soft_sort(ranks.unsqueeze(0), **self.sort_kwargs).squeeze()
return (
torch.linspace(0.0, 1.0, len(alpha) + 1, device=alpha.device)[1:],
alpha,
)

def get_rank_statistics(
self,
theta,
x,
):
logq = vmap(
self.log_prob,
in_dims=(1, None),
out_dims=1,
randomness="different",
chunk_size=self.vmap_chunk_size,
)(
torch.cat(
[
theta.unsqueeze(1),
self.prior.sample(
(
theta.shape[0],
self.n_samples,
)
),
],
dim=1,
),
x,
)
        q = logq.exp()
        # Self-normalized importance sampling (with the prior as proposal) of the posterior
        # mass whose density falls below that of the reference parameter, made
        # differentiable by the straight-through step function.
        return (q[:, 1:] * STEhardtanh.apply(q[:, 0].unsqueeze(1) - q[:, 1:])).sum(
            dim=1
        ) / logq[:, 1:].logsumexp(dim=1).exp()

def regularizer(self, theta: Tensor, x: Tensor) -> Tensor:
r"""
Arguments:
theta: The parameters :math:`\theta`, with shape :math:`(N, D)`.
x: The observation :math:`x`, with shape :math:`(N, L)`.
        Returns:
            The scalar regularizer term, i.e. the maximal expected coverage deviation.
"""
ranks = self.get_rank_statistics(
theta,
x,
)
target_cdf, ecdf = self.get_cdfs(ranks)
return self.activation(target_cdf - ecdf).max()


class CNRELoss(nn.Module):
r"""Creates a module that calculates the cross-entropy loss for a contrastive NRE
(CNRE) network.
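Note: the STEhardtanh helper above is a straight-through estimator. Its forward pass is a hard Heaviside step, used to count the prior samples whose density falls below that of the reference parameter, while its backward pass substitutes a hardtanh-clamped gradient so that the rank statistic stays differentiable. A minimal standalone check of that behaviour, reusing the class exactly as committed:

import torch
import torch.nn.functional as F


class STEhardtanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return (input > 0).float()  # hard Heaviside step

    @staticmethod
    def backward(ctx, grad_output):
        return F.hardtanh(grad_output)  # clamped straight-through gradient


x = torch.tensor([-0.5, 0.2, 1.5], requires_grad=True)
y = STEhardtanh.apply(x)  # forward: tensor([0., 1., 1.])
y.sum().backward()        # backward: x.grad == tensor([1., 1., 1.])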
2 changes: 2 additions & 0 deletions tests/test_inference.py
@@ -10,6 +10,7 @@

def test_NRE():
estimator = NRE(3, 5)
prior = torch.distributions.MultivariateNormal(torch.zeros(3), torch.eye(3))

# Non-batched
theta, x = randn(3), randn(5)
@@ -36,6 +37,7 @@ def test_NRE():
BNRELoss(estimator),
CNRELoss(estimator),
BCNRELoss(estimator),
CalNRELoss(estimator, prior),
]

for loss in losses:
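For completeness, a minimal sketch of how the new loss might be driven in a training step, mirroring the test setup above (lampe's NRE estimator and a multivariate normal prior). The batch size, learning rate, and the use of random noise in place of real simulator outputs are illustrative assumptions, not part of the commit, and torchsort must be installed for the regularizer:

import torch
from torch import optim
from lampe.inference import NRE, CalNRELoss

estimator = NRE(3, 5)  # 3 parameter dimensions, 5 observation features
prior = torch.distributions.MultivariateNormal(torch.zeros(3), torch.eye(3))

loss_fn = CalNRELoss(estimator, prior, lmbda=5.0, n_samples=16)
optimizer = optim.Adam(estimator.parameters(), lr=1e-3)

theta = prior.sample((64,))   # parameters drawn from the prior
x = torch.randn(64, 5)        # stand-in for the matching simulator outputs

optimizer.zero_grad()
loss = loss_fn(theta, x)      # NRE cross-entropy + lambda * coverage regularizer
loss.backward()
optimizer.step()

As with the other NRE losses in lampe, the estimator and the loss module are separate objects, so the coverage regularizer can be enabled by swapping only the loss function in an existing training loop.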
