Feat 22: Binary Crossentropy Loss Implementation (#51)
* feat: version0, BinaryCrossEntropy loss function implementation
* feat: version1, BinaryCrossEntropy loss function implementation
* refactor: BinaryCrossEntropy class documentation
* refactor: binary_cross_entropy function
* refactor: binary_cross_entropy function
* refactor: binary-cross-entropy function
* docs
* refactor: tests Binary-Cross-Entropy loss
* refactor: tests Binary-Cross-Entropy loss
* Fixed logit-bce calculation and refactored naming
* Added compatibility tests wrt tf.keras
* Exported BCE class and function from losses module
* Added BCE to current documentation
* Improved logit-bce for numerical stability
* Refined logit testing
* Added upper (1.0) clipping for non-logit bce
* Added class docstring

Co-authored-by: Sebastian Arango <[email protected]>
Commit 9fc946d (1 parent: 7bcaf18). Showing 6 changed files with 184 additions and 1 deletion.
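For reference, the element-wise loss this commit adds is the standard binary cross-entropy for a label $y \in \{0, 1\}$ and a predicted probability $p$:

$$\mathrm{BCE}(y, p) = -\bigl[\, y \log p + (1 - y) \log(1 - p) \,\bigr]$$

The new `binary_crossentropy` function below averages this quantity over the last axis of its inputs, producing one loss value per sample.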
`.gitignore` (adds the `.idea/` editor directory to the ignore list):

```
@@ -24,6 +24,7 @@ share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
.idea/
MANIFEST

# PyInstaller
```
New documentation page for `elegy.losses.BinaryCrossentropy` (new file, a mkdocstrings stub):

```markdown
@@ -0,0 +1,8 @@
# elegy.losses.BinaryCrossentropy

::: elegy.losses.BinaryCrossentropy
    selection:
        inherited_members: true
        members:
            - call
            - __init__
```
`elegy/losses/__init__.py` (exports the new class and function):

```python
@@ -1,10 +1,10 @@
from .loss import Loss, Reduction
from .categorical_crossentropy import CategoricalCrossentropy

from .mean_squared_error import MeanSquaredError, mean_squared_error
from .sparse_categorical_crossentropy import (
    SparseCategoricalCrossentropy,
    sparse_categorical_crossentropy,
)

from .mean_absolute_error import MeanAbsoluteError, mean_absolute_error
from .binary_crossentropy import BinaryCrossentropy, binary_crossentropy
```
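A minimal usage sketch of the two new exports. The printed values mirror those asserted in the docstring and tests further down; the `transform_and_run` wrapper is borrowed from the test file on the assumption that the loss call expects a Haiku context:

```python
import jax.numpy as jnp
import elegy
from haiku.testing import transform_and_run


@transform_and_run  # same wrapper the tests below use
def usage_sketch():
    y_true = jnp.array([[0.0, 1.0], [0.0, 0.0]])
    y_pred = jnp.array([[0.6, 0.4], [0.4, 0.6]])

    # Class form: applies the configured Reduction over the batch.
    bce = elegy.losses.BinaryCrossentropy()
    print(bce(y_true, y_pred))  # ~0.815

    # Functional form: per-sample losses (mean over the last axis).
    print(elegy.losses.binary_crossentropy(y_true, y_pred))  # ~[0.916, 0.713]


usage_sketch()
```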
`elegy/losses/binary_crossentropy.py` (new file):

````python
@@ -0,0 +1,107 @@
from elegy import types
import typing as tp
import jax
import jax.numpy as jnp
from elegy import utils
from elegy.losses.loss import Loss, Reduction


def binary_crossentropy(
    y_true: jnp.ndarray,
    y_pred: jnp.ndarray,
    from_logits: bool = False,
) -> jnp.ndarray:

    if from_logits:
        return -jnp.mean(y_true * y_pred - jnp.logaddexp(0.0, y_pred), axis=-1)

    y_pred = jnp.clip(y_pred, utils.EPSILON, 1.0 - utils.EPSILON)
    return -jnp.mean(
        y_true * jnp.log(y_pred) + (1 - y_true) * jnp.log(1 - y_pred), axis=-1
    )


class BinaryCrossentropy(Loss):
    """
    Computes the cross-entropy loss between true labels and predicted labels.

    Use this cross-entropy loss when there are only two label classes (assumed to
    be 0 and 1). For each example, there should be a single floating-point value
    per prediction.

    In the snippet below, each of the four examples has only a single
    floating-point value, and both `y_pred` and `y_true` have the shape
    `[batch_size]`.

    Usage:
    ```python
    y_true = jnp.array([[0., 1.], [0., 0.]])
    y_pred = jnp.array([[0.6, 0.4], [0.4, 0.6]])

    # Using 'auto'/'sum_over_batch_size' reduction type.
    bce = elegy.losses.BinaryCrossentropy()
    result = bce(y_true, y_pred)
    assert jnp.isclose(result, 0.815, rtol=0.01)

    # Calling with 'sample_weight'.
    bce = elegy.losses.BinaryCrossentropy()
    result = bce(y_true, y_pred, sample_weight=jnp.array([1, 0]))
    assert jnp.isclose(result, 0.458, rtol=0.01)

    # Using 'sum' reduction type.
    bce = elegy.losses.BinaryCrossentropy(reduction=elegy.losses.Reduction.SUM)
    result = bce(y_true, y_pred)
    assert jnp.isclose(result, 1.630, rtol=0.01)

    # Using 'none' reduction type.
    bce = elegy.losses.BinaryCrossentropy(reduction=elegy.losses.Reduction.NONE)
    result = bce(y_true, y_pred)
    assert jnp.all(jnp.isclose(result, [0.916, 0.713], rtol=0.01))
    ```

    Usage with the `compile` API:
    ```python
    model = elegy.Model(
        module_fn,
        loss=elegy.losses.BinaryCrossentropy(),
        metrics=elegy.metrics.Accuracy.defer(),
        optimizer=optix.adam(1e-3),
    )
    ```
    """

    def __init__(
        self,
        from_logits: bool = False,
        label_smoothing: float = 0,
        reduction: tp.Optional[Reduction] = None,
        name: tp.Optional[str] = None,
    ):
        super().__init__(reduction=reduction, name=name)
        self._from_logits = from_logits
        # Stored for API parity; label smoothing is not yet applied in `call`.
        self._label_smoothing = label_smoothing

    def call(
        self,
        y_true: jnp.ndarray,
        y_pred: jnp.ndarray,
        sample_weight: tp.Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """
        Invokes the `BinaryCrossentropy` instance.

        Arguments:
            y_true: Ground truth values.
            y_pred: The predicted values.
            sample_weight: Acts as a coefficient for the loss. If a scalar is
                provided, then the loss is simply scaled by the given value. If
                `sample_weight` is a tensor of size `[batch_size]`, then the total
                loss for each sample of the batch is rescaled by the corresponding
                element in the `sample_weight` vector. If the shape of
                `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be
                broadcast to this shape), then each loss element of `y_pred` is
                scaled by the corresponding value of `sample_weight`. (Note on
                `dN-1`: all loss functions reduce by 1 dimension, usually axis=-1.)

        Returns:
            Loss values per sample.
        """

        return binary_crossentropy(y_true, y_pred, from_logits=self._from_logits)
````
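A note on the `from_logits` branch: with a logit $z$, a label $y$, and $p = \sigma(z) = 1 / (1 + e^{-z})$, the element-wise binary cross-entropy rearranges to

$$-\bigl[\, y \log p + (1 - y) \log(1 - p) \,\bigr] = \log\!\left(1 + e^{z}\right) - y z = \operatorname{logaddexp}(0, z) - y z,$$

which is exactly what `-(y_true * y_pred - jnp.logaddexp(0.0, y_pred))` computes element-wise before the mean over the last axis. Working on the logits directly through `logaddexp` avoids exponentiating large values, which is the numerical-stability improvement the commit message refers to.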
The accompanying test module (new file), covering basic behaviour and compatibility with `tf.keras`:

```python
@@ -0,0 +1,66 @@
import elegy
from haiku.testing import transform_and_run
import jax.numpy as jnp
import tensorflow.keras as tfk


@transform_and_run
def test_basic():
    y_true = jnp.array([[0., 1.], [0., 0.]])
    y_pred = jnp.array([[0.6, 0.4], [0.4, 0.6]])

    bce = elegy.losses.BinaryCrossentropy()
    result = bce(y_true, y_pred)
    assert jnp.isclose(result, 0.815, rtol=0.01)

    y_logits = jnp.log(y_pred) - jnp.log(1 - y_pred)
    bce = elegy.losses.BinaryCrossentropy(from_logits=True)
    result_from_logits = bce(y_true, y_logits)
    assert jnp.isclose(result_from_logits, 0.815, rtol=0.01)
    assert jnp.isclose(result_from_logits, result, rtol=0.01)

    bce = elegy.losses.BinaryCrossentropy()
    result = bce(y_true, y_pred, sample_weight=jnp.array([1, 0]))
    assert jnp.isclose(result, 0.458, rtol=0.01)

    bce = elegy.losses.BinaryCrossentropy(reduction=elegy.losses.Reduction.SUM)
    result = bce(y_true, y_pred)
    assert jnp.isclose(result, 1.630, rtol=0.01)

    bce = elegy.losses.BinaryCrossentropy(reduction=elegy.losses.Reduction.NONE)
    result = bce(y_true, y_pred)
    assert jnp.all(jnp.isclose(result, [0.916, 0.713], rtol=0.01))


@transform_and_run
def test_compatibility():

    # Input: true (y_true) and predicted (y_pred) tensors
    y_true = jnp.array([[0., 1.], [0., 0.]])
    y_pred = jnp.array([[0.6, 0.4], [0.4, 0.6]])

    # Standard BCE, considering the prediction tensor as probabilities
    bce_elegy = elegy.losses.BinaryCrossentropy()
    bce_tfk = tfk.losses.BinaryCrossentropy()
    assert jnp.isclose(bce_elegy(y_true, y_pred), bce_tfk(y_true, y_pred), rtol=0.0001)

    # Standard BCE, considering the prediction tensor as logits
    y_logits = jnp.log(y_pred) - jnp.log(1 - y_pred)
    bce_elegy = elegy.losses.BinaryCrossentropy(from_logits=True)
    bce_tfk = tfk.losses.BinaryCrossentropy(from_logits=True)
    assert jnp.isclose(bce_elegy(y_true, y_logits), bce_tfk(y_true, y_logits), rtol=0.0001)

    # BCE using sample_weight
    bce_elegy = elegy.losses.BinaryCrossentropy()
    bce_tfk = tfk.losses.BinaryCrossentropy()
    assert jnp.isclose(
        bce_elegy(y_true, y_pred, sample_weight=jnp.array([1, 0])),
        bce_tfk(y_true, y_pred, sample_weight=jnp.array([1, 0])),
        rtol=0.0001,
    )

    # BCE with reduction method: SUM
    bce_elegy = elegy.losses.BinaryCrossentropy(reduction=elegy.losses.Reduction.SUM)
    bce_tfk = tfk.losses.BinaryCrossentropy(reduction=tfk.losses.Reduction.SUM)
    assert jnp.isclose(bce_elegy(y_true, y_pred), bce_tfk(y_true, y_pred), rtol=0.0001)

    # BCE with reduction method: NONE
    bce_elegy = elegy.losses.BinaryCrossentropy(reduction=elegy.losses.Reduction.NONE)
    bce_tfk = tfk.losses.BinaryCrossentropy(reduction=tfk.losses.Reduction.NONE)
    assert jnp.all(jnp.isclose(bce_elegy(y_true, y_pred), bce_tfk(y_true, y_pred), rtol=0.0001))
```
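As a sanity check on the constants asserted above, here is a small standalone sketch (plain `jax.numpy`, no elegy required) that reproduces them from the element-wise definition; the `sample_weight` line assumes the Keras-style `sum_over_batch_size` behaviour of dividing by the batch size:

```python
import jax.numpy as jnp

y_true = jnp.array([[0.0, 1.0], [0.0, 0.0]])
y_pred = jnp.array([[0.6, 0.4], [0.4, 0.6]])

# Element-wise BCE, then mean over the last axis -> one value per sample.
per_sample = -jnp.mean(
    y_true * jnp.log(y_pred) + (1 - y_true) * jnp.log(1 - y_pred), axis=-1
)

print(per_sample)         # ~[0.916, 0.713]  (Reduction.NONE)
print(per_sample.mean())  # ~0.815           (default reduction)
print(per_sample.sum())   # ~1.630           (Reduction.SUM)

# sample_weight=[1, 0]: weighted per-sample losses, averaged over the batch size.
weights = jnp.array([1.0, 0.0])
print((per_sample * weights).sum() / per_sample.shape[0])  # ~0.458
```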