Commit

[WIP] Refactor: reference preserving hooks. (#85)
cgarciae authored Aug 17, 2020
1 parent ae58afe commit 9a0b5b5
Showing 176 changed files with 6,139 additions and 2,164 deletions.
6 changes: 6 additions & 0 deletions .dockerignore
@@ -0,0 +1,6 @@
*

!elegy
!tests
!pyproject.toml
!poetry.lock
3 changes: 2 additions & 1 deletion .gitignore
@@ -143,4 +143,5 @@ cython_debug/
/test.*
/summaries
/runs
/docs/guides/model
/docs/model
/models
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -1,5 +1,8 @@
# Changelog

## [0.2.0] - 2020-08-17
* Big refactor. Elegy has its own Module system, independent of Haiku, and is now incompatible with it. #85

## [0.1.5] - 2020-07-28
* Mean Absolute Percentage Error Implementation @Ciroye
* Adds `elegy.nn.Linear`, `elegy.nn.Conv2D`, `elegy.nn.Flatten`, `elegy.nn.Sequential` @cgarciae
@@ -11,7 +14,7 @@
* Adds `elegy.metrics.BinaryCrossentropy` @sebasarango1180
* Adds `elegy.nn.Dropout` and `elegy.nn.BatchNormalization` @cgarciae
* Improves documentation
-* Fixes bug that caused an error when using `is_training` via dependency injection on `Model.predict`.
+* Fixes bug that caused an error when using `training` via dependency injection on `Model.predict`.

## [0.1.3] - 2020-07-22
* Initial release
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
@@ -17,7 +17,7 @@ For this you can follow these guidelines:
* You must provide documentation for the following:
* The class definition.
* The `__init__` method.
-* The `__apply__` method.
+* The `call` method.
* Try to port the documentation + signature from its Keras counterpart.
* If so you must give credits to the original source file.
* You must include tests.
@@ -43,7 +43,7 @@ We use `mkdocs`. If you create a new object that requires documentation please d
    selection:
      inherited_members: true
      members:
-       - __apply__
+       - call
        - __init__
```
* Add an entry to `mkdocs.yml` inside `nav` pointing to this file. Check out `mkdocs.yml`.
64 changes: 43 additions & 21 deletions README.md
@@ -10,16 +10,16 @@

-----------------

-_Elegy is a Neural Networks framework based on Jax and Haiku._
+_Elegy is a Neural Networks framework based on Jax, inspired by Keras and Haiku._

-Elegy implements the Keras API but makes changes to play better with Jax & Haiku and gives more flexibility around losses and metrics (more on this soon). Elegy is still in a very early stage, feel free to test it and send us your feedback!
+Elegy implements the Keras API but makes changes to play better with Jax and gives more flexibility around [losses and metrics](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/). It also ports Haiku's excellent [module system](https://poets-ai.github.io/elegy/guides/module-system/) and makes it easier to use. Elegy is in an early stage; feel free to send us your feedback!

#### Main Features

* **Familiar**: Elegy should feel very familiar to Keras users.
* **Flexible**: Elegy improves upon the basic Keras API by letting users optionally take more control over the definition of losses and metrics.
* **Easy-to-use**: Elegy maintains all the simplicity and ease of use that Keras brings with it.
-* **Compatible**: Elegy strives to be compatible with the rest of the Jax and Haiku ecosystem.
+* **Compatible**: Elegy strives to be compatible with the rest of the Jax ecosystem.

For more information take a look at the [Documentation](https://poets-ai.github.io/elegy).

@@ -33,29 +33,28 @@ pip install elegy
For Windows users we recommend the Windows Subsystem for Linux 2 [WSL2](https://docs.microsoft.com/es-es/windows/wsl/install-win10?redirectedfrom=MSDN) since [jax](https://github.com/google/jax/issues/438) does not support Windows yet.

## Quick Start
-Elegy greatly simplifies the training of Deep Learning models compared to pure Jax / Haiku where, due to Jax's functional nature, users have to do a lot of bookkeeping around the state of the model. In Elegy you just have to follow 3 basic steps:
+Elegy greatly simplifies the training of Deep Learning models compared to pure Jax where, due to Jax's functional nature, users have to do a lot of bookkeeping around the state of the model. In Elegy you just have to follow 3 basic steps:

**1.** Define the architecture inside an `elegy.Module`:
```python
class MLP(elegy.Module):
-    def __apply__(self, image: jnp.ndarray) -> jnp.ndarray:
-        mlp = elegy.nn.Sequential([
-            elegy.nn.Flatten(),
-            elegy.nn.Linear(300),
-            jax.nn.relu,
-            elegy.nn.Linear(10),
-        ])
-        return mlp(image)
+    def call(self, x: jnp.ndarray) -> jnp.ndarray:
+        x = elegy.nn.Linear(300)(x)
+        x = jax.nn.relu(x)
+        x = elegy.nn.Linear(10)(x)
+        return x
```
Note that we can define sub-modules on-the-fly directly in the `call` (forward) method.
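The on-the-fly sub-module mechanism can be sketched in plain Python. This is a hypothetical illustration, not Elegy's actual internals: a module remembers the sub-modules created during a call (keyed by creation order) and reuses them on later calls, and a layer infers its weight shape from the first input it sees.

```python
# Hypothetical sketch of on-the-fly sub-module reuse; names are illustrative only.
class Module:
    def __init__(self):
        self._submodules = []   # sub-modules in creation order
        self._call_index = 0    # position within the current call

    def _get_or_create(self, factory):
        # First call: instantiate and remember; later calls: reuse by position.
        if self._call_index == len(self._submodules):
            self._submodules.append(factory())
        sub = self._submodules[self._call_index]
        self._call_index += 1
        return sub

    def __call__(self, x):
        self._call_index = 0
        return self.call(x)

class Linear:
    def __init__(self, units):
        self.units = units
        self.w = None  # created lazily once the input shape is known

    def __call__(self, x):
        if self.w is None:
            # Shape inference "for free": rows come from the input length.
            self.w = [[1.0] * self.units for _ in range(len(x))]
        return [sum(xi * wij for xi, wij in zip(x, col)) for col in zip(*self.w)]

class MLP(Module):
    def call(self, x):
        return self._get_or_create(lambda: Linear(2))(x)

mlp = MLP()
out1 = mlp([1.0, 2.0, 3.0])
out2 = mlp([1.0, 2.0, 3.0])  # reuses the same Linear instance
```

Both calls go through the single cached `Linear`, which is what makes parameters stable across training steps.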

**2.** Create a `Model` from this module and specify additional things like losses, metrics, and optimizers:
```python
model = elegy.Model(
-    module=MLP.defer(),
+    module=MLP(),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-5),
    ],
-    metrics=elegy.metrics.SparseCategoricalAccuracy.defer(),
+    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optix.rmsprop(1e-3),
)
```
@@ -77,22 +76,45 @@ And you are done! For more information check out:


* Our [Getting Started](https://poets-ai.github.io/elegy/getting-started/) tutorial.
* Couple of examples in [examples](https://github.com/poets-ai/elegy/tree/master/examples) directory.
* Haiku's [User Manual](https://github.com/deepmind/dm-haiku#user-manual) and [Documentation](https://dm-haiku.readthedocs.io/en/latest/)
* Elegy's [Documentation](https://poets-ai.github.io/elegy).
* The [examples](https://github.com/poets-ai/elegy/tree/master/examples) directory.
* [What is Jax?](https://github.com/google/jax#what-is-jax)

-## Why Jax + Haiku?
+## Why Jax & Elegy?

Given all the well-established Deep Learning frameworks like TensorFlow + Keras or Pytorch + Pytorch-Lightning/Skorch, it is fair to ask why we need something like Jax + Elegy. Here are some of the reasons this framework exists.

#### Why Jax?

**Jax** is a linear algebra library with the perfect recipe:
* Numpy's familiar API
* The speed and hardware support of XLA
* Automatic Differentiation
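The "Automatic Differentiation" ingredient can be illustrated with a tiny forward-mode sketch in plain Python (dual numbers); this is purely an illustration of the concept, Jax's implementation is far more general:

```python
# Minimal forward-mode autodiff via dual numbers (illustration only, not Jax).
class Dual:
    def __init__(self, val, grad=0.0):
        self.val, self.grad = val, grad

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.grad + other.grad)

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val,
                    self.grad * other.val + self.val * other.grad)

def grad(f):
    # Seed the input derivative with 1.0 and read the output derivative.
    return lambda x: f(Dual(x, 1.0)).grad

df = grad(lambda x: x * x * x)  # d/dx x^3 = 3x^2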

-The awesome thing about Jax is that Deep Learning is just a usecase that it happens to excel at but you can use it for most task you would use NumPy for.
+The awesome thing about Jax is that Deep Learning is just a use-case that it happens to excel at; you can use it for most tasks you would use NumPy for. Jax is so compatible with Numpy that its array type actually inherits from `np.ndarray`.

In a sense, Jax takes the best of both TensorFlow and Pytorch in a principled manner: while both TF and Pytorch historically converged to the same set of features, their APIs still contain quirks they have to keep for compatibility.

#### Why Elegy?

We believe that **Elegy** can offer the best experience for coding Deep Learning applications by leveraging the power and familiarity of the Jax API, an easy-to-use and succinct Module system, and packaging everything on top of a convenient Keras-like API. Elegy improves upon other Deep Learning frameworks in the following ways:

1. Its hook-based [Module System](https://poets-ai.github.io/elegy/guides/module-system/) makes it easier (less verbose) to write model code than Keras & Pytorch since it lets you declare sub-modules, parameters, and states directly in your `call` (forward) method. Thanks to this you get shape inference for free, so there is no need for a `build` method (Keras) or for propagating shape information all over the place (Pytorch). A naive implementation of `Linear` could be as simple as:

```python
class Linear(elegy.Module):
    def __init__(self, units):
        super().__init__()
        self.units = units

    def call(self, x):
        w = self.get_parameter("w", [x.shape[-1], self.units], initializer=jnp.ones)
        b = self.get_parameter("b", [self.units], initializer=jnp.ones)

        return jnp.dot(x, w) + b
```
-On the other hand, **Haiku** is a Neural Networks library built on top of Jax that implements a `Module` system, common Neural Network layers, and even some full architectures. Compared to other Jax-based libraries like Trax or Flax, Haiku is very minimal, polished, well documented, and makes it super easy / clean to implement Deep Learning code!
-We believe that **Elegy** can offer the best experience for coding Deep Learning applications by leveraging the power and familiarity of Jax API, the ease-of-use of Haiku's Module system, and packaging everything on top of a convenient Keras-like API.
2. It has a very flexible system for defining the inputs for [losses and metrics](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/) based on _dependency injection_, as opposed to Keras' rigid requirement of matching (output, label) pairs and its inability to use additional information like inputs, parameters, and states in the definition of losses and metrics.
3. Its hook system preserves [reference information](https://poets-ai.github.io/elegy/guides/module-system/) from a module to its sub-modules, parameters, and states while maintaining a functional API. This is crucial since most Jax-based frameworks like Flax and Haiku tend to lose this information, which makes it very tricky to perform tasks like transfer learning, where you need to mix pre-trained models into a new model (easier to do if you keep references).
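The dependency-injection idea from point 2 can be sketched in plain Python. The `inject` helper below is hypothetical, not Elegy's actual API: the framework inspects a loss function's signature and passes only the arguments it asks for, so one loss can request `y_true`/`y_pred` while another requests only the inputs `x`.

```python
import inspect

# Hypothetical sketch of signature-based dependency injection (not Elegy's real internals).
def inject(fn, available):
    """Call fn with only the keyword arguments its signature requests."""
    wanted = inspect.signature(fn).parameters
    return fn(**{name: available[name] for name in wanted})

def mse(y_true, y_pred):
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def l2_on_inputs(x):
    # A "loss" that only needs the model inputs, impossible in vanilla Keras.
    return sum(v ** 2 for v in x)

context = {"x": [1.0, 2.0], "y_true": [1.0, 0.0], "y_pred": [0.5, 0.5]}
loss = inject(mse, context) + inject(l2_on_inputs, context)
```

Each loss declares its dependencies by naming parameters, and the framework wires them up; no (output, label) pairing is imposed.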

## Features
* `Model` estimator class
@@ -124,7 +146,7 @@ To cite this project:
author = {PoetsAI},
title = {Elegy: A Keras-like deep learning framework based on Jax & Haiku},
url = {https://github.com/poets-ai/elegy},
-version = {0.1.5},
+version = {0.2.0},
year = {2020},
}
```
18 changes: 18 additions & 0 deletions docs/api/Loss.md
@@ -0,0 +1,18 @@

# elegy.Loss

::: elegy.Loss
    selection:
      inherited_members: true
      members:
        - __init__
        - call
        - init
        - apply
        - reset
        - get_parameters
        - set_parameters
        - get_states
        - set_states
        - submodules

18 changes: 18 additions & 0 deletions docs/api/Metric.md
@@ -0,0 +1,18 @@

# elegy.Metric

::: elegy.Metric
    selection:
      inherited_members: true
      members:
        - __init__
        - call
        - init
        - apply
        - reset
        - get_parameters
        - set_parameters
        - get_states
        - set_states
        - submodules

23 changes: 23 additions & 0 deletions docs/api/Model.md
@@ -0,0 +1,23 @@

# elegy.Model

::: elegy.Model
    selection:
      inherited_members: true
      members:
        - evaluate
        - fit
        - load
        - predict
        - predict_on_batch
        - reset
        - reset_metrics
        - save
        - summary
        - test_on_batch
        - train_on_batch
        - full_state
        - parameters
        - seed
        - states

18 changes: 18 additions & 0 deletions docs/api/Module.md
@@ -0,0 +1,18 @@

# elegy.Module

::: elegy.Module
    selection:
      inherited_members: true
      members:
        - __init__
        - call
        - init
        - apply
        - reset
        - get_parameters
        - set_parameters
        - get_states
        - set_states
        - submodules

9 changes: 9 additions & 0 deletions docs/api/add_loss.md
@@ -0,0 +1,9 @@

# elegy.add_loss

::: elegy.add_loss
    selection:
      inherited_members: true
      members:
        - __NONE__

9 changes: 9 additions & 0 deletions docs/api/add_metric.md
@@ -0,0 +1,9 @@

# elegy.add_metric

::: elegy.add_metric
    selection:
      inherited_members: true
      members:
        - __NONE__

9 changes: 9 additions & 0 deletions docs/api/add_summary.md
@@ -0,0 +1,9 @@

# elegy.add_summary

::: elegy.add_summary
    selection:
      inherited_members: true
      members:
        - __NONE__

19 changes: 18 additions & 1 deletion docs/api/callbacks/CSVLogger.md
@@ -1,7 +1,24 @@

# elegy.callbacks.CSVLogger

::: elegy.callbacks.CSVLogger
    selection:
      inherited_members: true
      members:
        - __init__
        - on_epoch_begin
        - on_epoch_end
        - on_predict_batch_begin
        - on_predict_batch_end
        - on_predict_begin
        - on_predict_end
        - on_test_batch_begin
        - on_test_batch_end
        - on_test_begin
        - on_test_end
        - on_train_batch_begin
        - on_train_batch_end
        - on_train_begin
        - on_train_end
        - set_model
        - set_params

21 changes: 21 additions & 0 deletions docs/api/callbacks/Callback.md
@@ -1,3 +1,24 @@

# elegy.callbacks.Callback

::: elegy.callbacks.Callback
    selection:
      inherited_members: true
      members:
        - on_epoch_begin
        - on_epoch_end
        - on_predict_batch_begin
        - on_predict_batch_end
        - on_predict_begin
        - on_predict_end
        - on_test_batch_begin
        - on_test_batch_end
        - on_test_begin
        - on_test_end
        - on_train_batch_begin
        - on_train_batch_end
        - on_train_begin
        - on_train_end
        - set_model
        - set_params

9 changes: 9 additions & 0 deletions docs/api/callbacks/CallbackList.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@

# elegy.callbacks.CallbackList

::: elegy.callbacks.CallbackList
    selection:
      inherited_members: true
      members:
        - __NONE__

19 changes: 18 additions & 1 deletion docs/api/callbacks/EarlyStopping.md
@@ -1,7 +1,24 @@

# elegy.callbacks.EarlyStopping

::: elegy.callbacks.EarlyStopping
    selection:
      inherited_members: true
      members:
        - __init__
        - on_epoch_begin
        - on_epoch_end
        - on_predict_batch_begin
        - on_predict_batch_end
        - on_predict_begin
        - on_predict_end
        - on_test_batch_begin
        - on_test_batch_end
        - on_test_begin
        - on_test_end
        - on_train_batch_begin
        - on_train_batch_end
        - on_train_begin
        - on_train_end
        - set_model
        - set_params
