Add Noise Contrastive Estimation Loss #29

Merged: 9 commits, Feb 18, 2024
src/modalities/loss_functions.py (80 additions, 0 deletions)

@@ -41,3 +41,83 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor:
        # Flatten the tokens
        loss = self.loss_fun(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return loss


def nce_loss(
    embedding1: torch.Tensor, embedding2: torch.Tensor, device: torch.device, is_asymmetric: bool, temperature: float
) -> torch.Tensor:
    """
    Calculates the noise contrastive estimation (NCE) loss between embeddings of two different modalities.

    The implementation is slightly adapted from MIL-NCE (https://arxiv.org/pdf/1912.06430.pdf,
    https://github.com/antoine77340/MIL-NCE_HowTo100M); the changes are the addition of a temperature value
    and the option to calculate an asymmetric loss w.r.t. one modality. It is further adapted to the
    contrastive loss of the CoCa model (https://arxiv.org/pdf/2205.01917.pdf).

    Args:
        embedding1 (torch.Tensor): embeddings from modality 1 of size batch_size x embed_dim.
        embedding2 (torch.Tensor): embeddings from modality 2 of size batch_size x embed_dim.
        device (torch.device): torch device for calculating the loss.
        is_asymmetric (bool): whether the loss is calculated in one direction only or in both directions.
        temperature (float): temperature value for scaling the similarity scores.

    Returns:
        torch.Tensor: loss tensor.
    """
    # calculate the similarity matrix of size (batch_size x batch_size)
    sim_matrix = torch.matmul(embedding1, embedding2.t()) / temperature
    # numerator of the loss: similarity scores of all positive pairs (e.g., an image and its caption)
    numerator = sim_matrix * torch.eye(sim_matrix.shape[0], device=device)
    numerator = numerator.sum(dim=0).view(sim_matrix.shape[0], -1)
    numerator = torch.logsumexp(numerator, dim=1)
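    # note: the eye-mask, column sum, and logsumexp over singleton rows above are together
    # equivalent to extracting the diagonal, i.e. numerator == torch.diag(sim_matrix)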
    if is_asymmetric:
        # denominator of the loss: similarity scores of all pairs (positive and negative)
        denominator = torch.logsumexp(sim_matrix, dim=1)
    else:
        # calculate the bidirectional loss
        numerator *= 2
        denominator = torch.logsumexp(sim_matrix, dim=1) + torch.logsumexp(sim_matrix.t(), dim=1)
    return torch.mean(denominator - numerator)  # calculated in log space
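
Note: with temperature 1.0, the asymmetric branch reduces to a standard cross-entropy over the similarity logits with the diagonal entries as targets (the usual InfoNCE formulation). A minimal sketch checking this equivalence, not part of the diff; shapes and values are arbitrary:

import torch
import torch.nn.functional as F

from modalities.loss_functions import nce_loss

emb1 = torch.rand(8, 16)
emb2 = torch.rand(8, 16)
loss = nce_loss(emb1, emb2, device=torch.device("cpu"), is_asymmetric=True, temperature=1.0)
# equivalent formulation: cross-entropy over similarity logits, target i for row i
logits = emb1 @ emb2.t()
reference = F.cross_entropy(logits, torch.arange(logits.shape[0]))
assert torch.allclose(loss, reference, atol=1e-5)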


class NCELoss(Loss):
    def __init__(
        self,
        prediction_key1: str,
        prediction_key2: str,
        is_asymmetric: bool = True,
        temperature: float = 1.0,
        tag: str = "NCELoss",
    ):
        """
        Noise Contrastive Estimation Loss

        Args:
            prediction_key1 (str): key to access embedding 1.
            prediction_key2 (str): key to access embedding 2.
            is_asymmetric (bool, optional): specifies asymmetric or symmetric calculation of the NCE loss.
                Defaults to True.
            temperature (float, optional): temperature value for scaling the similarity scores. Defaults to 1.0.
            tag (str, optional): Defaults to "NCELoss".
        """
        super().__init__(tag)
        self.prediction_key1 = prediction_key1
        self.prediction_key2 = prediction_key2
        self.is_asymmetric = is_asymmetric
        self.temperature = temperature

    def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor:
        """
        Args:
            forward_batch (InferenceResultBatch): data batch.

        Returns:
            torch.Tensor: loss tensor.
        """
        embedding1 = forward_batch.get_predictions(self.prediction_key1)
        embedding2 = forward_batch.get_predictions(self.prediction_key2)

        contiguous_embedding1 = embedding1.contiguous()
        contiguous_embedding2 = embedding2.contiguous()

        loss = nce_loss(
            contiguous_embedding1, contiguous_embedding2, embedding1.device, self.is_asymmetric, self.temperature
        )
        return loss
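
A hypothetical usage sketch; the prediction keys are made up for illustration, and the positional InferenceResultBatch constructor (targets, predictions, batch_dim) is assumed from the test fixture below:

import torch

from modalities.batch import InferenceResultBatch
from modalities.loss_functions import NCELoss

# "vision_emb" and "text_emb" are hypothetical keys under which a multimodal
# model would store its per-modality embeddings in the forward batch
predictions = {"vision_emb": torch.rand(32, 512), "text_emb": torch.rand(32, 512)}
batch = InferenceResultBatch({"target": torch.zeros(32, 512)}, predictions, 32)
loss_fn = NCELoss(prediction_key1="vision_emb", prediction_key2="text_emb", is_asymmetric=False, temperature=0.07)
loss = loss_fn(batch)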
tests/test_loss_functions.py (38 additions, 0 deletions)

@@ -0,0 +1,38 @@
import pytest
import torch

from modalities.batch import InferenceResultBatch
from modalities.loss_functions import NCELoss, nce_loss


@pytest.fixture
def dummy_result_batch() -> InferenceResultBatch:
    predictions = {"embedding": torch.rand(1024, 512)}
    targets = {"target": torch.zeros(1024, 512)}
    batch_dim = 1024
    result_batch = InferenceResultBatch(targets, predictions, batch_dim)
    return result_batch


# calculating the asymmetric NCE loss between a batch of embeddings and itself should yield (approximately) zero
@pytest.mark.parametrize("key", ["embedding"])
def test_asymm_NCELoss_is_zero(dummy_result_batch, key):
    loss_func = NCELoss(prediction_key1=key, prediction_key2=key)
    assert loss_func(dummy_result_batch) <= 10e-6


# calculating nce_loss for two randomly generated batches of embeddings (expected values calculated manually)
@pytest.mark.parametrize(
    "embedding1,embedding2",
    [
        (
            torch.Tensor([[0.38, 0.18], [0.36, 0.66], [0.72, 0.09]]),
            torch.Tensor([[0.48, 0.01], [0.54, 0.28], [0.08, 0.34]]),
        )
    ],
)
def test_nce_loss_correctness(embedding1, embedding2):
    unidirectional_loss = nce_loss(embedding1, embedding2, device="cpu", is_asymmetric=True, temperature=1.0)
    bidirectional_loss = nce_loss(embedding1, embedding2, device="cpu", is_asymmetric=False, temperature=1.0)
    assert unidirectional_loss == pytest.approx(1.1300, 0.0001)
    assert bidirectional_loss == pytest.approx(2.2577, 0.0001)
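
For reference, the expected constants above can be reproduced directly from the similarity matrix; a short sketch, not part of the test file:

import torch

e1 = torch.tensor([[0.38, 0.18], [0.36, 0.66], [0.72, 0.09]])
e2 = torch.tensor([[0.48, 0.01], [0.54, 0.28], [0.08, 0.34]])
sim = e1 @ e2.t()  # temperature of 1.0 leaves the logits unscaled
diag = sim.diag()
# asymmetric: mean over rows of (logsumexp of the row minus the positive-pair score)
uni = (torch.logsumexp(sim, dim=1) - diag).mean()  # ~1.1300
# bidirectional: row-wise and column-wise terms, each against the same positives
bi = (torch.logsumexp(sim, dim=1) + torch.logsumexp(sim.t(), dim=1) - 2 * diag).mean()  # ~2.2577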