New cross-entropy losses with temperature #77

Merged: 4 commits, Oct 10, 2023
deel/lip/losses.py: 69 additions, 1 deletion

@@ -9,7 +9,12 @@
from functools import partial
import numpy as np
import tensorflow as tf
-from tensorflow.keras.losses import categorical_crossentropy, Loss, Reduction
+from tensorflow.keras.losses import (
+    categorical_crossentropy,
+    sparse_categorical_crossentropy,
+    Loss,
+    Reduction,
+)
from tensorflow.keras.utils import register_keras_serializable


@@ -493,3 +498,66 @@ def get_config(self):
config = {"tau": self.tau.numpy()}
base_config = super(TauCategoricalCrossentropy, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


@register_keras_serializable("deel-lip", "TauSparseCategoricalCrossentropy")
class TauSparseCategoricalCrossentropy(Loss):
def __init__(
self, tau, reduction=Reduction.AUTO, name="TauSparseCategoricalCrossentropy"
):
"""
Similar to the original sparse categorical crossentropy, but with a settable
temperature parameter.

Args:
tau (float): temperature parameter.
reduction: reduction of the loss, passed to original loss.
name (str): name of the loss
"""
self.tau = tf.Variable(tau, dtype=tf.float32)
super().__init__(name=name, reduction=reduction)

def call(self, y_true, y_pred):
return (
sparse_categorical_crossentropy(y_true, self.tau * y_pred, from_logits=True)
/ self.tau
)

def get_config(self):
config = {"tau": self.tau.numpy()}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))


@register_keras_serializable("deel-lip", "TauBinaryCrossentropy")
class TauBinaryCrossentropy(Loss):
def __init__(self, tau, reduction=Reduction.AUTO, name="TauBinaryCrossentropy"):
"""
Similar to the original binary crossentropy, but with a settable temperature
parameter. y_pred must be a logits tensor (before sigmoid) and not
probabilities.

Note that `y_true` and `y_pred` must be of rank 2: (batch_size, 1). `y_true`
accepts label values in {0, 1} or {-1, 1}.

Args:
tau: temperature parameter.
reduction: reduction of the loss, passed to original loss.
name: name of the loss
"""
self.tau = tf.Variable(tau, dtype=tf.float32)
super().__init__(name=name, reduction=reduction)

def call(self, y_true, y_pred):
y_true = tf.cast(y_true > 0, y_pred.dtype)
return (
tf.keras.losses.binary_crossentropy(
y_true, self.tau * y_pred, from_logits=True
)
/ self.tau
)

def get_config(self):
config = {"tau": self.tau.numpy()}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
tests/test_losses.py: 60 additions, 2 deletions

@@ -14,6 +14,8 @@
MulticlassHKR,
MultiMargin,
TauCategoricalCrossentropy,
TauSparseCategoricalCrossentropy,
TauBinaryCrossentropy,
CategoricalHinge,
)
from deel.lip.utils import process_labels_for_multi_gpu
@@ -218,6 +220,43 @@ def test_tau_catcrossent(self):
)
check_serialization(1, taucatcrossent_loss)

def test_tau_sparse_catcrossent(self):
tau_sparse_catcrossent_loss = TauSparseCategoricalCrossentropy(1.0)
n_class = 10
n_items = 10000
y_true = np.random.randint(0, n_class, n_items)
y_pred = tf.random.normal((n_items, n_class))
loss_val = tau_sparse_catcrossent_loss(y_true, y_pred).numpy()
loss_val_2 = tau_sparse_catcrossent_loss(
tf.cast(y_true, dtype=tf.int32), y_pred
).numpy()
np.testing.assert_almost_equal(
loss_val_2, loss_val, 1, "test failed when y_true has dtype int32"
)
check_serialization(n_class, tau_sparse_catcrossent_loss)

def test_tau_binary_crossent(self):
loss = TauBinaryCrossentropy(2.0)
y_true = binary_tf_data([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])
y_pred = binary_tf_data([0.5, 1.5, -0.5, -0.5, -1.5, 0.5])

# Assert that loss value is equal to expected value
expected_loss_val = 0.279185
loss_val = loss(y_true, y_pred).numpy()
np.testing.assert_allclose(loss_val, expected_loss_val, rtol=1e-6)

# Assert that loss value is the same when y_true is of type int32
loss_val_2 = loss(tf.cast(y_true, dtype=tf.int32), y_pred).numpy()
np.testing.assert_allclose(loss_val_2, expected_loss_val, rtol=1e-6)

# Assert that loss value is the same when y_true is [-1, 1] instead of [0, 1]
y_true2 = tf.where(y_true == 1.0, 1.0, -1.0)
loss_val_3 = loss(y_true2, y_pred).numpy()
np.testing.assert_allclose(loss_val_3, expected_loss_val, rtol=1e-6)

# Assert that loss object is correctly serialized
check_serialization(1, loss)

def test_no_reduction_binary_losses(self):
"""
Assert binary losses without reduction. Three losses are tested on hardcoded
@@ -230,12 +269,14 @@
KR(reduction="none"),
HingeMargin(0.7 * 2.0, reduction="none"),
HKR(alpha=2.5, min_margin=2.0, reduction="none"),
TauBinaryCrossentropy(tau=0.5, reduction="none"),
)

expected_loss_values = (
np.array([1.0, 2.2, -0.2, 1.4, 2.6, 0.8, -0.4, 1.8]),
np.array([0.2, 0, 0.8, 0, 0, 0.3, 0.9, 0]),
np.array([0.25, -2.2, 2.95, -0.65, -2.6, 0.7, 3.4, -1.55]),
[1.15188, 0.91098, 1.43692, 1.06676, 0.84011, 1.19628, 1.48879, 0.98650],
)

for loss, expected_loss_val in zip(losses, expected_loss_values):
@@ -275,6 +316,7 @@ def test_no_reduction_multiclass_losses(self):
MulticlassHKR(alpha=2.5, min_margin=1.0, reduction="none"),
CategoricalHinge(1.1, reduction="none"),
TauCategoricalCrossentropy(2.0, reduction="none"),
TauSparseCategoricalCrossentropy(2.0, reduction="none"),
)

expected_loss_values = (
@@ -295,10 +337,25 @@
0.076357,
]
),
np.float64(
[
0.044275,
0.115109,
1.243572,
0.084923,
0.010887,
2.802300,
0.114224,
0.076357,
]
),
)

for loss, expected_loss_val in zip(losses, expected_loss_values):
-            loss_val = loss(y_true, y_pred)
+            if isinstance(loss, TauSparseCategoricalCrossentropy):
+                loss_val = loss(tf.argmax(y_true, axis=-1), y_pred)
+            else:
+                loss_val = loss(y_true, y_pred)
np.testing.assert_allclose(
loss_val,
expected_loss_val,
@@ -325,9 +382,10 @@ def test_minibatches_binary_losses(self):
KR(multi_gpu=True, reduction=reduction),
HingeMargin(0.7 * 2.0, reduction=reduction),
HKR(alpha=2.5, min_margin=2.0, multi_gpu=True, reduction=reduction),
TauBinaryCrossentropy(tau=1.5, reduction=reduction),
)

-        expected_loss_values = (9.2, 2.2, 0.3)
+        expected_loss_values = (9.2, 2.2, 0.3, 2.19262)

# Losses are tested on full batch
for loss, expected_loss_val in zip(losses, expected_loss_values):