Skip to content

Commit

Permalink
Merge pull request #77 from deel-ai/feat/TauSparseCategoricalCrossent…
Browse files Browse the repository at this point in the history
…ropy

New cross-entropy losses with temperature

This PR introduces two new cross-entropy losses, built on the standard Keras losses, with a settable temperature applied to the logits (before the softmax or sigmoid): TauSparseCategoricalCrossentropy, equivalent to Keras SparseCategoricalCrossentropy, and TauBinaryCrossentropy, equivalent to Keras BinaryCrossentropy.
  • Loading branch information
cofri committed Oct 10, 2023
2 parents f37c1e0 + f2c5da2 commit 47a4c3f
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 3 deletions.
70 changes: 69 additions & 1 deletion deel/lip/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from functools import partial
import numpy as np
import tensorflow as tf
from tensorflow.keras.losses import categorical_crossentropy, Loss, Reduction
from tensorflow.keras.losses import (
categorical_crossentropy,
sparse_categorical_crossentropy,
Loss,
Reduction,
)
from tensorflow.keras.utils import register_keras_serializable


Expand Down Expand Up @@ -493,3 +498,66 @@ def get_config(self):
config = {"tau": self.tau.numpy()}
base_config = super(TauCategoricalCrossentropy, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


@register_keras_serializable("deel-lip", "TauSparseCategoricalCrossentropy")
class TauSparseCategoricalCrossentropy(Loss):
    def __init__(
        self, tau, reduction=Reduction.AUTO, name="TauSparseCategoricalCrossentropy"
    ):
        """
        Similar to original sparse categorical crossentropy, but with a settable
        temperature parameter.

        The loss is computed on logits as
        `sparse_categorical_crossentropy(y_true, tau * y_pred, from_logits=True)
        / tau`.

        Args:
            tau (float): temperature parameter.
            reduction: reduction of the loss, passed to original loss.
            name (str): name of the loss
        """
        # tau is stored as a float32 variable; it is created before the base
        # constructor is called, exactly as in the original implementation.
        self.tau = tf.Variable(tau, dtype=tf.float32)
        super().__init__(name=name, reduction=reduction)

    def call(self, y_true, y_pred):
        """
        Compute the per-sample loss.

        Args:
            y_true: labels, forwarded unchanged to Keras
                `sparse_categorical_crossentropy`.
            y_pred: logits (`from_logits=True` is used); they are multiplied by
                `tau` before the cross-entropy is evaluated.

        Returns:
            The Keras sparse categorical cross-entropy of the scaled logits,
            divided by `tau`.
        """
        return (
            sparse_categorical_crossentropy(y_true, self.tau * y_pred, from_logits=True)
            / self.tau
        )

    def get_config(self):
        """
        Return the configuration of the loss (base config plus `tau`).

        `tau` is exported as a built-in Python float rather than a NumPy scalar,
        so that the returned dictionary can be handled by any JSON encoder.
        """
        config = {"tau": float(self.tau.numpy())}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))


@register_keras_serializable("deel-lip", "TauBinaryCrossentropy")
class TauBinaryCrossentropy(Loss):
    def __init__(self, tau, reduction=Reduction.AUTO, name="TauBinaryCrossentropy"):
        """
        Similar to the original binary crossentropy, but with a settable temperature
        parameter. y_pred must be a logits tensor (before sigmoid) and not
        probabilities.

        The loss is computed as
        `binary_crossentropy(y_true, tau * y_pred, from_logits=True) / tau`.

        Note that `y_true` and `y_pred` must be of rank 2: (batch_size, 1). `y_true`
        accepts labels encoded either as {0, 1} or as {-1, 1}: every strictly
        positive value is treated as the positive class, any other value as the
        negative class.

        Args:
            tau: temperature parameter.
            reduction: reduction of the loss, passed to original loss.
            name: name of the loss
        """
        # tau is stored as a float32 variable; it is created before the base
        # constructor is called, exactly as in the original implementation.
        self.tau = tf.Variable(tau, dtype=tf.float32)
        super().__init__(name=name, reduction=reduction)

    def call(self, y_true, y_pred):
        """
        Compute the per-sample loss.

        Args:
            y_true: labels; binarised with `y_true > 0` and cast to the dtype of
                `y_pred`.
            y_pred: logits (`from_logits=True` is used); they are multiplied by
                `tau` before the cross-entropy is evaluated.

        Returns:
            The Keras binary cross-entropy of the scaled logits, divided by `tau`.
        """
        # Map both {0, 1} and {-1, 1} encodings to {0, 1} in y_pred's dtype.
        y_true = tf.cast(y_true > 0, y_pred.dtype)
        return (
            tf.keras.losses.binary_crossentropy(
                y_true, self.tau * y_pred, from_logits=True
            )
            / self.tau
        )

    def get_config(self):
        """
        Return the configuration of the loss (base config plus `tau`).

        `tau` is exported as a built-in Python float rather than a NumPy scalar,
        so that the returned dictionary can be handled by any JSON encoder.
        """
        config = {"tau": float(self.tau.numpy())}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
62 changes: 60 additions & 2 deletions tests/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
MulticlassHKR,
MultiMargin,
TauCategoricalCrossentropy,
TauSparseCategoricalCrossentropy,
TauBinaryCrossentropy,
CategoricalHinge,
)
from deel.lip.utils import process_labels_for_multi_gpu
Expand Down Expand Up @@ -218,6 +220,43 @@ def test_tau_catcrossent(self):
)
check_serialization(1, taucatcrossent_loss)

def test_tau_sparse_catcrossent(self):
    """
    Check that TauSparseCategoricalCrossentropy gives the same value whether the
    labels come straight from NumPy or are cast to int32, then check
    serialization of the loss.
    """
    n_class, n_items = 10, 10000
    loss_fn = TauSparseCategoricalCrossentropy(1.0)

    # Random integer labels in [0, n_class) and random normal logits.
    labels = np.random.randint(0, n_class, n_items)
    logits = tf.random.normal((n_items, n_class))

    reference = loss_fn(labels, logits).numpy()
    labels_int32 = tf.cast(labels, dtype=tf.int32)
    from_int32 = loss_fn(labels_int32, logits).numpy()

    np.testing.assert_almost_equal(
        from_int32,
        reference,
        decimal=1,
        err_msg="test failed when y_true has dtype int32",
    )
    check_serialization(n_class, loss_fn)

def test_tau_binary_crossent(self):
    """
    Check TauBinaryCrossentropy (tau=2) against a hardcoded expected value for
    several encodings of the same labels, then check serialization of the loss.
    """
    tau_bce = TauBinaryCrossentropy(2.0)
    labels = binary_tf_data([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])
    logits = binary_tf_data([0.5, 1.5, -0.5, -0.5, -1.5, 0.5])
    expected = 0.279185

    # The loss value must be identical for every encoding of the labels.
    label_variants = (
        labels,  # labels as produced by binary_tf_data
        tf.cast(labels, dtype=tf.int32),  # same labels with dtype int32
        tf.where(labels == 1.0, 1.0, -1.0),  # labels in [-1, 1] instead of [0, 1]
    )
    for variant in label_variants:
        value = tau_bce(variant, logits).numpy()
        np.testing.assert_allclose(value, expected, rtol=1e-6)

    # The loss object must be correctly serialized.
    check_serialization(1, tau_bce)

def test_no_reduction_binary_losses(self):
"""
Assert binary losses without reduction. Three losses are tested on hardcoded
Expand All @@ -230,12 +269,14 @@ def test_no_reduction_binary_losses(self):
KR(reduction="none"),
HingeMargin(0.7 * 2.0, reduction="none"),
HKR(alpha=2.5, min_margin=2.0, reduction="none"),
TauBinaryCrossentropy(tau=0.5, reduction="none"),
)

expected_loss_values = (
np.array([1.0, 2.2, -0.2, 1.4, 2.6, 0.8, -0.4, 1.8]),
np.array([0.2, 0, 0.8, 0, 0, 0.3, 0.9, 0]),
np.array([0.25, -2.2, 2.95, -0.65, -2.6, 0.7, 3.4, -1.55]),
[1.15188, 0.91098, 1.43692, 1.06676, 0.84011, 1.19628, 1.48879, 0.98650],
)

for loss, expected_loss_val in zip(losses, expected_loss_values):
Expand Down Expand Up @@ -275,6 +316,7 @@ def test_no_reduction_multiclass_losses(self):
MulticlassHKR(alpha=2.5, min_margin=1.0, reduction="none"),
CategoricalHinge(1.1, reduction="none"),
TauCategoricalCrossentropy(2.0, reduction="none"),
TauSparseCategoricalCrossentropy(2.0, reduction="none"),
)

expected_loss_values = (
Expand All @@ -295,10 +337,25 @@ def test_no_reduction_multiclass_losses(self):
0.076357,
]
),
np.float64(
[
0.044275,
0.115109,
1.243572,
0.084923,
0.010887,
2.802300,
0.114224,
0.076357,
]
),
)

for loss, expected_loss_val in zip(losses, expected_loss_values):
loss_val = loss(y_true, y_pred)
if isinstance(loss, TauSparseCategoricalCrossentropy):
loss_val = loss(tf.argmax(y_true, axis=-1), y_pred)
else:
loss_val = loss(y_true, y_pred)
np.testing.assert_allclose(
loss_val,
expected_loss_val,
Expand All @@ -325,9 +382,10 @@ def test_minibatches_binary_losses(self):
KR(multi_gpu=True, reduction=reduction),
HingeMargin(0.7 * 2.0, reduction=reduction),
HKR(alpha=2.5, min_margin=2.0, multi_gpu=True, reduction=reduction),
TauBinaryCrossentropy(tau=1.5, reduction=reduction),
)

expected_loss_values = (9.2, 2.2, 0.3)
expected_loss_values = (9.2, 2.2, 0.3, 2.19262)

# Losses are tested on full batch
for loss, expected_loss_val in zip(losses, expected_loss_values):
Expand Down

0 comments on commit 47a4c3f

Please sign in to comment.