Skip to content

Commit

Permalink
Improve guides
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jul 23, 2020
1 parent 93d9d35 commit 7ea8749
Showing 1 changed file with 32 additions and 25 deletions.
57 changes: 32 additions & 25 deletions docs/guides/modules-losses-metrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

This guide goes into depth on how modules, losses and metrics work in Elegy and how to create your own. One of our goals with Elegy was to solve Keras restrictions around the type of losses and metrics you can define.

When creating a complex model with multiple outputs with Keras, say `output_a` and `output_b`, you are forced to define losses and metrics per-output only:
When creating a complex model with multiple outputs in Keras, say `output_a` and `output_b`, you are forced to define losses and metrics per-output only:

```python
model.compile(
Expand Down Expand Up @@ -42,13 +42,13 @@ Elegy solves the previous problems by introducing a _dependency injection_ mecha


!!! Note
The content of `x` is technically passed to the model's `Module` but the parameter name _x_ will bare no special meaning there.
The content of `x` is technically passed to the model's `Module` but the parameter name _x_ will bare no special meaning in that context.


## Modules
Modules define the architecture of the network, their primary task (Elegy terms) is transforming `x` into `y_pred`. To make it easy to consume the content of `x`, Elegy has some special but very simple rules on how the signature of any `Module` can be structured:
Modules define the architecture of the network, their primary task (in Elegy terms) is transforming the inputs `x` into outputs `y_pred`. To make it easy to consume the content of `x`, Elegy has some special but very simple rules on how the signature of any `Module` can be structured:

If `x` is a `tuple`, then `x` will be expanded positional arguments a.k.a. `*args`, this means that the module will have define EXACTLY as many arguments as there are inputs. For example:
**1.** If `x` is a `tuple`, then `x` will be expanded positional arguments a.k.a. `*args`, this means that the module will have define **exactly** as many arguments as there are inputs. For example:

```python hl_lines="2 10"
class SomeModule(elegy.Module):
Expand All @@ -66,7 +66,9 @@ model.fit(
```
In this case `a` is passed as `m` and `b` is passed as `n`.

On the other hand, if `x` is a `dict`, then `x` will be expanded as keyword arguments a.k.a. `**kwargs`, in this case the module can optionally request any key defined in `x` as an argument. For example:
**2.** If `x` is a single array it will be converted internally into a `tuple` containing that array so the module can expect it as a positional argument.

**3.** If `x` is a `dict`, then `x` will be expanded as keyword arguments a.k.a. `**kwargs`, in this case the module can optionally request any key defined in `x` as an argument. For example:

```python hl_lines="2 10"
class SomeModule(elegy.Module):
Expand All @@ -82,7 +84,9 @@ model.fit(
...
)
```
Here `n` is requested by name and you get as input its value `b`, and `m` is safely ignored.
Here `n` is requested by name and you get as input its value `b`, and `m` with the content of `a` is safely ignored.



## Losses
Losses can request all the available parameters that Elegy provides for dependency injection. A typical loss will request the `y_true` and `y_pred` values (as its common / enforced in Keras). The Mean Squared Error loss for example is easily defined in these terms:
Expand All @@ -102,7 +106,7 @@ model.fit(
...
)
```
Here the input `y` is passed as `y_true` as stated previously. However, if for example you want to build an autoencoder you don't actually need `y` since you just want to reconstruct `x`, therefore it makes perfect sense to just request `x` and Elegy lets you do exactly that:
Here the input `y` is passed as `y_true` to `MSE`. However, if you for example want to build an autoencoder then, according to the math, you actually don't need `y` because you are actually trying to reconstruct `x`. It makes perfect sense for this lossto be defined in terms of `x` and Elegy lets you do exactly that:

```python hl_lines="2"
class AutoEncoderLoss(elegy.Loss):
Expand All @@ -118,21 +122,21 @@ model.fit(
...
)
```
Here we only used `x` instead of `y_true` to define the loss as the math usually tells use, therefore no `y` was required on `fit`.

Notice thanks to this we didn't have to define `y` on the `fit` method.

!!! Note
In this case you could have easily just passed `y=X_train` and reused the previous `MSE` definition. However, avoiding the creation of redundant labels is good in general and being explicit about e.g. the function using `x` might even self-document its behaviour.
An alternative here is to just use the previous definition of `MSE` and define `y=X_train`. However, avoiding the creation of redundant information is good in general and being explicit about dependencies might help documenting the behaviour of the model in general.

### Partitioning a loss
If you have a complex loss function that is just a sum of different subparts but that you probably have to compute together, e.g. to reuse some computation, you might define something like this:
If you have a complex loss function that is just a sum of different parts that have to be compute together you might define something like this:
```python
class SomeComplexFunction(elegy.Loss):
def call(self, x, y_true, y_pred, params, ...):
...

return a + b + c
```
Purely for logging purposes you can instead return a `dict` of these losses:
Elegy lets you return a `dict` specifying the name of each part:

```python
class SomeComplexFunction(elegy.Loss):
Expand All @@ -150,21 +154,24 @@ Elegy will use this information to show you each loss separate in the logs / Ten
* `some_complex_function_loss/b`
* `some_complex_function_loss/c`

Each individual loss will still be subject to the `sample_weight` and `reduction` behavior as specified to `SomeComplexFunction`.

### Multiple Outputs + Labels
The models constructor `loss` argument can accept a single `Loss`, a `list` or `dict` of losses, and even nested structures of the previous, yet the form of `loss` is not strictly related to structure or numbers of labels and outputs of the model. This is very different to Keras where each loss has to be matched with exactly 1 label and 1 output. Elegy's method of dealing with multiple outputs and labels is super simple:
The `Model`'s constructor `loss` argument can accept a single `Loss`, a `list` or `dict` of losses, and even nested structures of the previous, yet in Elegy the form of `loss` is not strictly related to structure of input labels and outputs of the model. This is very different to Keras where each loss has to be matched with exactly one (label, output) pair. Elegy's method of dealing with multiple outputs and labels is super simple:

!!! Quote
`y_true` and `y_pred` will **always** be passed to each loss exactly as they are defined
- `y_true` will contain the **entire** structure passed to `y`.
- `y_pred` will contain the **entire** structure output by the `Module`.

This means there are no restrictions on how you structure the loss function. According to this rule, for simple cases where there is only 1 output and 1 label, Keras and Elegy will behave the same because there is no structure:
This means there are no restrictions on how you structure the loss function. According to this rule Keras and Elegy behave the same when there is only one output and one label because there is no structure. Both framework will allow you to define something like:

```python
model = Model(
...
loss=elegy.losses.CategoricalCrossentropy(from_logits=True)
)
```
But if you have many outputs and many labels Elegy will just pass them to you and you can just define your loss function by using indexing their structures:
However, if you have many outputs and many labels, Elegy will just pass their structures to your loss and you will be able to do whatever you want by e.g. indexing these structures:

```python
class MyLoss(Elegy.Loss):
Expand All @@ -178,10 +185,10 @@ model = Model(
loss=elegy.losses.MyLoss()
)
```
This example assumes they are dictionaries but they can also be tuples. This gives you maximal flexibility but come at the additional cost of having to implement a custom loss function.
This example assumes the `y_true` and `y_pred` are dictionaries but they can also be tuples or nested structures. This strategy gives you maximal flexibility but come with the additional cost of having to implement your own loss function.

### Keras-like behavior
While these examples show Elegy's flexibility, there is an inbetween scenario that Keras covers really well: what if you really just need 1 loss per (label, output) pair? For example the equivalent of:
While having this flexibility available is good, there is a common scenario that Keras covers really well: what if you really just need one loss per (label, output) pair? In other words, how can we achieve equivalent of the following Keras code in Elege?

```python
class MyModel(keras.Model):
Expand All @@ -206,7 +213,7 @@ model.compile(
},
)
```
Elegy recovers this behavior by letting each `Loss` filter (or rather index) `y_true` and `y_pred` based on a string key in the case of dictionaries or int key in the case of tuples using the constructor's `on` parameter:
To recover this behavior Elegy lets each `Loss` optionally filter / index the `y_true` and `y_pred` arguments based on a string key (for `dict`s) or integer key (for `tuple`s) in the constructor's `on` parameter:

```python
class MyModule(elegy.Module):
Expand All @@ -227,7 +234,7 @@ model = elegy.Model(
]
)
```
This is almost exactly how Keras behaves except each loss is explicitly aware of which part of the output / label its supposed to attend to. The previous is roughly equivalent to manually indexing `y_true` and `y_pred` and passing the resulting value to the loss like this:
This is almost exactly how Keras behaves except each loss is explicitly aware of which part of the output / label its supposed to attend to. The previous is roughly equivalent to manually indexing `y_true` and `y_pred` and passing the resulting value to the loss in question like this:

```python
model = elegy.Model(
Expand All @@ -246,15 +253,15 @@ model = elegy.Model(
)
```
!!! Note
Elegy doesn't support `loss_weights` like in `keras.compile`. Nonetheless, you just can add the weight directly in the constructor of each loss like the above example.
For the same reasons Elegy doesn't support the `loss_weights` parameter as defined in `keras.compile`. Nonetheless, each loss accepts a `weight` argument directly, as seen in the examples above, which you can use to recover this behavior.

## Metrics
Metrics behave very similar to losses, everything said about losses previously holds for metrics except for one thing: metrics can hold state. As in Keras, Elegy metrics are cumulative metrics which update their internal state on every step. From an user perspective this means a couple of things:
Metrics behave exactly like losses except for one thing: Metrics can hold state. As in Keras, Elegy metrics are cumulative metrics which update their internal state on every step. From an user's perspective this means a couple of things:

1. Metrics are implemented using Haiku `Module`s, this means that you can't instantiate them normally outside of Haiku, hence the `lambda` / `defer` trick.
1. Metrics are implemented using Haiku `Module`, this means that you can't instantiate them normally outside of Haiku, hence the `lambda` / `defer` trick.
2. You can use `hk.get_state` and `hk.set_state` when implementing your own metrics.

Here is an example of a simple implementation of Accuracy:
Here is an example of a simple implementation of `Accuracy` which uses this cumulative behavior:

```python
class Accuracy(elegy.Metric):
Expand All @@ -274,4 +281,4 @@ class Accuracy(elegy.Metric):


## A little secret
We think users should use the base classes provided by Elegy (Module, Loss, Metric) for convenience, but the fact is that Elegy also accepts ordinary callables. Being true to Haiku and Jax in general, you can just use functions, however you can run into trouble with Haiku due to not scoping you computation inside Modules.
We think users should use the base classes provided by Elegy (Module, Loss, Metric) for convenience, being true to Haiku and Jax in general Elegy also lets you use plain functions. Be cautious when doing this since you can easily run into trouble with Haiku's scoping rules.

0 comments on commit 7ea8749

Please sign in to comment.