Feat 22: Binary Crossentropy Loss Implementation (#51)
* feat:version0, BinaryCrossEntropy loss function implementation

* feat:version1, BinaryCrossEntropy loss function implementation

* refactor:BinaryCrossEntropy class documentation

* refactor: binary_cross_entropy function

* refactor: binary_cross_entropy function

* refactor: binary-cross-entropy function

* docs

* refactor: tests Binary-Cross-Entropy loss

* refactor: tests Binary-Cross-Entropy loss

* Fixed logit-bce calculation and refactored naming

* Added compatibility tests wrt tf.keras

* Exported BCE class and function from losses module

* Added BCE to current documentation

* Improved logit-bce for numerical stability

* Refined logit testing

* Added upper (1.0) clipping for non-logit bce

* Added class docstring

Co-authored-by: Sebastian Arango <[email protected]>
haruiz and sebasarango1180 authored Jul 21, 2020
1 parent 7bcaf18 commit 9fc946d
Showing 6 changed files with 184 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
@@ -24,6 +24,7 @@ share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
.idea/
MANIFEST

# PyInstaller
8 changes: 8 additions & 0 deletions docs/api/losses/BinaryCrossentropy.md
@@ -0,0 +1,8 @@
# elegy.losses.BinaryCrossentropy

::: elegy.losses.BinaryCrossentropy
    selection:
        inherited_members: true
        members:
            - call
            - __init__
2 changes: 1 addition & 1 deletion elegy/losses/__init__.py
@@ -1,10 +1,10 @@
from .loss import Loss, Reduction
from .categorical_crossentropy import CategoricalCrossentropy

from .mean_squared_error import MeanSquaredError, mean_squared_error
from .sparse_categorical_crossentropy import (
SparseCategoricalCrossentropy,
sparse_categorical_crossentropy,
)

from .mean_absolute_error import MeanAbsoluteError, mean_absolute_error
from .binary_crossentropy import BinaryCrossentropy, binary_crossentropy
107 changes: 107 additions & 0 deletions elegy/losses/binary_crossentropy.py
@@ -0,0 +1,107 @@
from elegy import types
import typing as tp
import jax
import jax.numpy as jnp
from elegy import utils
from elegy.losses.loss import Loss, Reduction


def binary_crossentropy(
    y_true: jnp.ndarray,
    y_pred: jnp.ndarray,
    from_logits: bool = False,
) -> jnp.ndarray:

    if from_logits:
        # Stable form of -[z * log(sigmoid(x)) + (1 - z) * log(1 - sigmoid(x))]:
        # it reduces to logaddexp(0, x) - z * x, so no sigmoid is ever taken.
        return -jnp.mean(y_true * y_pred - jnp.logaddexp(0.0, y_pred), axis=-1)

    # Clip probabilities away from 0 and 1 so both logs stay finite.
    y_pred = jnp.clip(y_pred, utils.EPSILON, 1.0 - utils.EPSILON)
    return -jnp.mean(
        y_true * jnp.log(y_pred) + (1 - y_true) * jnp.log(1 - y_pred), axis=-1
    )



class BinaryCrossentropy(Loss):
    """
    Computes the cross-entropy loss between true labels and predicted labels.

    Use this cross-entropy loss when there are only two label classes (assumed to
    be 0 and 1). For each example, there should be a single floating-point value
    per prediction.

    In the snippet below, each of the four examples has only a single
    floating-point value, and both `y_pred` and `y_true` have the shape
    `[batch_size]`.

    Usage:
    ```python
    y_true = jnp.array([[0., 1.], [0., 0.]])
    y_pred = jnp.array([[0.6, 0.4], [0.4, 0.6]])

    # Using 'auto'/'sum_over_batch_size' reduction type.
    bce = elegy.losses.BinaryCrossentropy()
    result = bce(y_true, y_pred)
    assert jnp.isclose(result, 0.815, rtol=0.01)

    # Calling with 'sample_weight'.
    bce = elegy.losses.BinaryCrossentropy()
    result = bce(y_true, y_pred, sample_weight=jnp.array([1, 0]))
    assert jnp.isclose(result, 0.458, rtol=0.01)

    # Using 'sum' reduction type.
    bce = elegy.losses.BinaryCrossentropy(reduction=elegy.losses.Reduction.SUM)
    result = bce(y_true, y_pred)
    assert jnp.isclose(result, 1.630, rtol=0.01)

    # Using 'none' reduction type.
    bce = elegy.losses.BinaryCrossentropy(reduction=elegy.losses.Reduction.NONE)
    result = bce(y_true, y_pred)
    assert jnp.all(jnp.isclose(result, [0.916, 0.713], rtol=0.01))
    ```

    Usage with the `compile` API:
    ```python
    model = elegy.Model(
        module_fn,
        loss=elegy.losses.BinaryCrossentropy(),
        metrics=elegy.metrics.Accuracy.defer(),
        optimizer=optix.adam(1e-3),
    )
    ```
    """

    def __init__(
        self,
        from_logits: bool = False,
        label_smoothing: float = 0,
        reduction: tp.Optional[Reduction] = None,
        name: tp.Optional[str] = None,
    ):
        super().__init__(reduction=reduction, name=name)
        self._from_logits = from_logits
        self._label_smoothing = label_smoothing

    def call(
        self,
        y_true: jnp.ndarray,
        y_pred: jnp.ndarray,
        sample_weight: tp.Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """
        Invokes the `BinaryCrossentropy` instance.

        Arguments:
            y_true: Ground truth values.
            y_pred: The predicted values.
            sample_weight: Acts as a coefficient for the loss. If a scalar is
                provided, then the loss is simply scaled by the given value. If
                `sample_weight` is a tensor of size `[batch_size]`, then the
                total loss for each sample of the batch is rescaled by the
                corresponding element in the `sample_weight` vector. If the shape
                of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be
                broadcast to this shape), then each loss element of `y_pred` is
                scaled by the corresponding value of `sample_weight`. (Note on
                `dN-1`: all loss functions reduce by 1 dimension, usually
                axis=-1.)

        Returns:
            Loss values per sample.
        """

        if self._label_smoothing:
            # Keras-style label smoothing: squeeze the targets toward 0.5.
            y_true = (
                y_true * (1.0 - self._label_smoothing) + 0.5 * self._label_smoothing
            )

        return binary_crossentropy(y_true, y_pred, from_logits=self._from_logits)
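
Worth noting for reviewers: the `from_logits` branch leans on the identity −[z · log σ(x) + (1 − z) · log(1 − σ(x))] = logaddexp(0, x) − z · x, so the loss is computed without ever materializing the sigmoid and cannot overflow or hit log(0) for large logits. A minimal standalone sketch of that equivalence, using only `jax.numpy`; the `bce_naive`/`bce_stable` helpers are illustrative names for this note, not part of elegy:

```python
import jax.numpy as jnp

def bce_naive(y_true, logits):
    # Reference form: sigmoid first, then log. Hits log(0) once |logits| is large.
    probs = 1.0 / (1.0 + jnp.exp(-logits))
    return -jnp.mean(
        y_true * jnp.log(probs) + (1 - y_true) * jnp.log(1 - probs), axis=-1
    )

def bce_stable(y_true, logits):
    # Same quantity rewritten as logaddexp(0, x) - z * x, finite for any x.
    return -jnp.mean(y_true * logits - jnp.logaddexp(0.0, logits), axis=-1)

y_true = jnp.array([[0.0, 1.0], [0.0, 0.0]])
logits = jnp.array([[0.4055, -0.4055], [-0.4055, 0.4055]])  # log(p/(1-p)) for p = 0.6/0.4

print(bce_naive(y_true, logits))   # ~[0.916, 0.713], matches the docstring example
print(bce_stable(y_true, logits))  # same values
print(bce_stable(y_true, jnp.array([[1000.0, -1000.0], [0.0, 0.0]])))  # still finite
```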
66 changes: 66 additions & 0 deletions elegy/losses/binary_crossentropy_test.py
@@ -0,0 +1,66 @@
import elegy
from haiku.testing import transform_and_run
import jax.numpy as jnp
import tensorflow.keras as tfk


@transform_and_run
def test_basic():
    y_true = jnp.array([[0., 1.], [0., 0.]])
    y_pred = jnp.array([[0.6, 0.4], [0.4, 0.6]])

    bce = elegy.losses.BinaryCrossentropy()
    result = bce(y_true, y_pred)
    assert jnp.isclose(result, 0.815, rtol=0.01)

    # Converting the probabilities to logits must yield the same loss.
    y_logits = jnp.log(y_pred) - jnp.log(1 - y_pred)
    bce = elegy.losses.BinaryCrossentropy(from_logits=True)
    result_from_logits = bce(y_true, y_logits)
    assert jnp.isclose(result_from_logits, 0.815, rtol=0.01)
    assert jnp.isclose(result_from_logits, result, rtol=0.01)

    bce = elegy.losses.BinaryCrossentropy()
    result = bce(y_true, y_pred, sample_weight=jnp.array([1, 0]))
    assert jnp.isclose(result, 0.458, rtol=0.01)

    bce = elegy.losses.BinaryCrossentropy(reduction=elegy.losses.Reduction.SUM)
    result = bce(y_true, y_pred)
    assert jnp.isclose(result, 1.630, rtol=0.01)

    bce = elegy.losses.BinaryCrossentropy(reduction=elegy.losses.Reduction.NONE)
    result = bce(y_true, y_pred)
    assert jnp.all(jnp.isclose(result, [0.916, 0.713], rtol=0.01))


@transform_and_run
def test_compatibility():

    # Input: true (y_true) and predicted (y_pred) tensors
    y_true = jnp.array([[0., 1.], [0., 0.]])
    y_pred = jnp.array([[0.6, 0.4], [0.4, 0.6]])

    # Standard BCE, considering prediction tensor as probabilities
    bce_elegy = elegy.losses.BinaryCrossentropy()
    bce_tfk = tfk.losses.BinaryCrossentropy()
    assert jnp.isclose(bce_elegy(y_true, y_pred), bce_tfk(y_true, y_pred), rtol=0.0001)

    # Standard BCE, considering prediction tensor as logits
    y_logits = jnp.log(y_pred) - jnp.log(1 - y_pred)
    bce_elegy = elegy.losses.BinaryCrossentropy(from_logits=True)
    bce_tfk = tfk.losses.BinaryCrossentropy(from_logits=True)
    assert jnp.isclose(
        bce_elegy(y_true, y_logits), bce_tfk(y_true, y_logits), rtol=0.0001
    )

    # BCE using sample_weight
    bce_elegy = elegy.losses.BinaryCrossentropy()
    bce_tfk = tfk.losses.BinaryCrossentropy()
    assert jnp.isclose(
        bce_elegy(y_true, y_pred, sample_weight=jnp.array([1, 0])),
        bce_tfk(y_true, y_pred, sample_weight=jnp.array([1, 0])),
        rtol=0.0001,
    )

    # BCE with reduction method: SUM
    bce_elegy = elegy.losses.BinaryCrossentropy(reduction=elegy.losses.Reduction.SUM)
    bce_tfk = tfk.losses.BinaryCrossentropy(reduction=tfk.losses.Reduction.SUM)
    assert jnp.isclose(bce_elegy(y_true, y_pred), bce_tfk(y_true, y_pred), rtol=0.0001)

    # BCE with reduction method: NONE
    bce_elegy = elegy.losses.BinaryCrossentropy(reduction=elegy.losses.Reduction.NONE)
    bce_tfk = tfk.losses.BinaryCrossentropy(reduction=tfk.losses.Reduction.NONE)
    assert jnp.all(
        jnp.isclose(bce_elegy(y_true, y_pred), bce_tfk(y_true, y_pred), rtol=0.0001)
    )
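
The expected values these tests pin down follow tf.keras reduction semantics: a per-sample loss is first computed along the last axis, multiplied elementwise by `sample_weight` if one is given, and only then reduced. A self-contained sketch of that pipeline in plain `jax.numpy` — `per_sample_bce` is an illustrative helper written for this note, not elegy's internal API:

```python
import jax.numpy as jnp

def per_sample_bce(y_true, y_pred, eps=1e-7):
    # Elementwise BCE averaged over the last axis -> one loss value per sample.
    y_pred = jnp.clip(y_pred, eps, 1.0 - eps)
    return -jnp.mean(
        y_true * jnp.log(y_pred) + (1 - y_true) * jnp.log(1 - y_pred), axis=-1
    )

y_true = jnp.array([[0.0, 1.0], [0.0, 0.0]])
y_pred = jnp.array([[0.6, 0.4], [0.4, 0.6]])

per_sample = per_sample_bce(y_true, y_pred)

print(per_sample)            # NONE                -> [0.916, 0.713]
print(jnp.sum(per_sample))   # SUM                 -> 1.630
print(jnp.mean(per_sample))  # sum_over_batch_size -> 0.815

# sample_weight scales each sample's loss before the final reduction; with
# sum_over_batch_size the divisor stays the full batch size, hence 0.458.
weights = jnp.array([1.0, 0.0])
print(jnp.sum(per_sample * weights) / y_true.shape[0])  # -> 0.458
```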
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -30,6 +30,7 @@ nav:
        MeanSquaredError: api/metrics/MeanSquaredError.md
    losses:
        Loss: api/losses/Loss.md
        BinaryCrossentropy: api/losses/BinaryCrossentropy.md
        CategoricalCrossentropy: api/losses/CategoricalCrossentropy.md
        SparseCategoricalCrossentropy: api/losses/SparseCategoricalCrossentropy.md
        MeanAbsoluteError: api/losses/MeanAbsoluteError.md
