-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Change-Id: I3260d343b414ac9666c587542de740e92c7e0b89
- Loading branch information
Electronic Vision(s)
committed
May 5, 2023
0 parents
commit 89f0e78
Showing
87 changed files
with
8,279 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
@Library("jenlib") | ||
|
||
withCcache() { | ||
wafDefaultPipeline(projects: ["jax-snn"], | ||
container: [app: "dls-core"], | ||
configureInstallOptions: "--build-profile=ci", | ||
notificationChannel: "#jenkins-trashbin") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
[gerrit] | ||
host=gerrit.bioai.eu | ||
port=29418 | ||
project=jax-snn | ||
defaultbranch=main | ||
defaultremote=review | ||
defaultrebase=0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# jaxsnn | ||
|
||
`jaxsnn` is an event-based approach to machine-learning-inspired training and | ||
simulation of SNNs, including support for neuromorphic backends (BrainScaleS-2). | ||
We build upon [jax](https://github.com/google/jax), a Python library providing | ||
autograd and XLA functionality for high-performance machine learning research. | ||
|
||
|
||
## Building the Software | ||
|
||
The software builds upon existing libraries, such as | ||
[jax](https://github.com/google/jax), | ||
[optax](https://github.com/deepmind/optax), | ||
and [tree-math](https://github.com/google/tree-math). | ||
When using the neuromorphic BrainScaleS-2 backend, the software stack of the | ||
platform is required. | ||
|
||
We provide a container image (based on the [Singularity format](https://sylabs.io/docs/)) including all build-time and runtime dependencies. | ||
Feel free to download the most recent version from [here](https://openproject.bioai.eu/containers/). | ||
|
||
For all following steps, we assume that the most recent Singularity container is located at `/containers/stable/latest`. | ||
|
||
|
||
### Github-based Build | ||
To build this project from public resources, adhere to the following guide: | ||
|
||
```shell | ||
# 1) Most of the following steps will be executed within a singularity container | ||
# To keep the steps clutter-free, we start by defining an alias | ||
shopt -s expand_aliases | ||
alias c="singularity exec --app dls /containers/stable/latest" | ||
|
||
# 2) Prepare a fresh workspace and change directory into it | ||
mkdir workspace && cd workspace | ||
|
||
# 3) Fetch a current copy of the symwaf2ic build tool | ||
git clone https://github.com/electronicvisions/waf -b symwaf2ic symwaf2ic | ||
|
||
# 4) Build symwaf2ic | ||
c make -C symwaf2ic | ||
ln -s symwaf2ic/waf | ||
|
||
# 5) Setup your workspace and clone all dependencies (--clone-depth=1 to skip history) | ||
c ./waf setup --repo-db-url=https://github.com/electronicvisions/projects --project=jaxsnn | ||
|
||
# 6) Load PPU cross-compiler toolchain (or build https://github.com/electronicvisions/oppulance) | ||
module load ppu-toolchain | ||
|
||
# 7) Build the project | ||
# Adjust -j1 to your own needs, beware that high parallelism will increase memory consumption! | ||
c ./waf configure | ||
c ./waf build -j1 | ||
|
||
# 8) Install the project to ./bin and ./lib | ||
c ./waf install | ||
|
||
# 9) If you run programs outside waf, you'll need to add ./lib and ./bin to your path specifications | ||
export SINGULARITYENV_PREPEND_PATH=`pwd`/bin:$SINGULARITYENV_PREPEND_PATH | ||
export SINGULARITYENV_LD_LIBRARY_PATH=`pwd`/lib:$SINGULARITYENV_LD_LIBRARY_PATH | ||
export PYTHONPATH=`pwd`/lib:$PYTHONPATH | ||
``` | ||
|
||
|
||
## First Steps | ||
|
||
Check out our examples: | ||
|
||
|
||
``` | ||
python -m jaxsnn.event.tasks.yinyang | ||
``` | ||
|
||
``` | ||
python -m jaxsnn.event.tasks.yinyang_event_prop | ||
``` | ||
|
||
|
||
## Acknowledgements | ||
|
||
The software in this repository has been developed by staff and students | ||
of Heidelberg University as part of the research carried out by the | ||
Electronic Vision(s) group at the Kirchhoff-Institute for Physics. | ||
|
||
This work has received funding from the EC Horizon 2020 Framework Programme | ||
under grant agreements 785907 (HBP SGA2) and 945539 (HBP SGA3), the Deutsche | ||
Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's | ||
Excellence Strategy EXC 2181/1-390900948 (the Heidelberg STRUCTURES Excellence | ||
Cluster), the German Federal Ministry of Education and Research under grant | ||
number 16ES1127 as part of the Pilotinnovationswettbewerb Energieeffizientes | ||
KI-System, the Helmholtz Association Initiative and Networking Fund [Advanced | ||
Computing Architectures (ACA)] under Project SO-092, as well as from the | ||
Manfred Stärk Foundation, and the Lautenschläger-Forschungspreis 2018 for | ||
Karlheinz Meier. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
import jaxsnn.base | ||
|
||
from .functional import euler_integrate, serial | ||
from .functional.leaky_integrate import LI, LIStep | ||
from .functional.lif import LIF, LIFStep |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
import jaxsnn.base.explicit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import dataclasses | ||
from typing import Callable, Sequence, TypeVar | ||
import jax | ||
import tree_math | ||
|
||
|
||
PyTreeState = TypeVar("PyTreeState") | ||
TimeStepFn = Callable[[PyTreeState], PyTreeState] | ||
|
||
|
||
class ExplicitConstrainedODE: | ||
""" | ||
The equation is given by: | ||
∂u/∂t = explicit_terms(u) | ||
0 = constraint(u) | ||
""" | ||
|
||
def __init__(self, explicit_terms, projection): | ||
self.explicit_terms = explicit_terms | ||
self.projection = projection | ||
|
||
def explicit_terms(self, state): | ||
"""Explicitly evaluate the ODE.""" | ||
raise NotImplementedError | ||
|
||
def projection(self, state): | ||
"""Enforce the constraint.""" | ||
raise NotImplementedError | ||
|
||
|
||
class ExplicitConstrainedCDE: | ||
""" | ||
The equation is given by: | ||
∂u/∂t = explicit_terms(u, x) | ||
0 = constraint(u) | ||
""" | ||
|
||
def __init__(self, explicit_terms, projection, output): | ||
self.explicit_terms = explicit_terms | ||
self.projection = projection | ||
self.output = output | ||
|
||
def explicit_terms(self, state, input): | ||
"""Explicitly evaluate the ODE.""" | ||
raise NotImplementedError | ||
|
||
def projection(self, state, input): | ||
"""Enforce the constraint.""" | ||
raise NotImplementedError | ||
|
||
def output(self, state): | ||
"""""" | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
"""Implicit-explicit time stepping routines for ODEs.""" | ||
import dataclasses | ||
from typing import Callable, Sequence, TypeVar | ||
import tree_math | ||
|
||
|
||
PyTreeState = TypeVar("PyTreeState") | ||
ControlInput = TypeVar("ControlInput") | ||
|
||
TimeStepFn = Callable[[PyTreeState], PyTreeState] | ||
ControlledTimeStepFn = Callable[[PyTreeState, ControlInput], PyTreeState] | ||
|
||
|
||
class ImplicitExplicitODE: | ||
"""Describes a set of ODEs with implicit & explicit terms. | ||
The equation is given by: | ||
∂x/∂t = explicit_terms(x) + implicit_terms(x) | ||
`explicit_terms(x)` includes terms that should use explicit time-stepping and | ||
`implicit_terms(x)` includes terms that should be modeled implicitly. | ||
Typically the explicit terms are non-linear and the implicit terms are linear. | ||
This simplifies solves but isn't strictly necessary. | ||
""" | ||
|
||
def explicit_terms(self, state: PyTreeState) -> PyTreeState: | ||
"""Evaluates explicit terms in the ODE.""" | ||
raise NotImplementedError | ||
|
||
def implicit_terms(self, state: PyTreeState) -> PyTreeState: | ||
"""Evaluates implicit terms in the ODE.""" | ||
raise NotImplementedError | ||
|
||
def implicit_solve( | ||
self, | ||
state: PyTreeState, | ||
step_size: float, | ||
) -> PyTreeState: | ||
"""Solves `y - step_size * implicit_terms(y) = x` for y.""" | ||
raise NotImplementedError | ||
|
||
|
||
class ImplicitExplicitCDE: | ||
"""Describes a set of CDEs with implicit & explicit terms. | ||
We assume that only the explicit terms are subject to control input. | ||
The equation is given by: | ||
∂x/∂t = explicit_terms(x, u) + implicit_terms(x) | ||
`explicit_terms(x, u)` includes terms that should use explicit time-stepping and are controlled | ||
`implicit_terms(x)` includes terms that should be modeled implicitly. | ||
Typically the explicit terms are non-linear and the implicit terms are linear. | ||
This simplifies solves but isn't strictly necessary. | ||
""" | ||
|
||
def explicit_terms(self, state: PyTreeState, u: ControlInput) -> PyTreeState: | ||
"""Evaluates explicit terms in the ODE.""" | ||
raise NotImplementedError | ||
|
||
def implicit_terms(self, state: PyTreeState) -> PyTreeState: | ||
"""Evaluates implicit terms in the ODE.""" | ||
raise NotImplementedError | ||
|
||
def implicit_solve( | ||
self, | ||
state: PyTreeState, | ||
step_size: float, | ||
) -> PyTreeState: | ||
"""Solves `y - step_size * implicit_terms(y) = x` for y.""" | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
|
||
def linear_interpolation(f_a, f_b, a, b, x): | ||
return (x - a) / (b - a) * f_b + (b - x) / (b - a) * f_a | ||
|
||
|
||
def linear_interpolated_root(f_a, f_b, a, b): | ||
return (a * f_b - b * f_a) / f_b - f_a | ||
|
||
|
||
def newton_1d(f, x0): | ||
initial_state = (0, x0) | ||
|
||
def cond(state): | ||
it, x = state | ||
return it < 10 | ||
|
||
def body(state): | ||
it, x = state | ||
fx, dfx = f(x), jax.grad(f)(x) | ||
step = fx / dfx | ||
new_state = it + 1, x - step | ||
return new_state | ||
|
||
return jax.lax.while_loop( | ||
cond, | ||
body, | ||
initial_state, | ||
)[1] | ||
|
||
|
||
def newton_nd(f, x0): | ||
|
||
initial_state = (0, x0) | ||
|
||
def cond(state): | ||
it, x = state | ||
return it < 10 | ||
|
||
def body(state): | ||
it, x = state | ||
fx, dfx = f(x), jax.grad(f)(x) | ||
step = jax.numpy.linalg.solve(dfx, -fx) | ||
|
||
new_state = it + 1, x + step | ||
return new_state | ||
|
||
return jax.lax.while_loop( | ||
cond, | ||
body, | ||
initial_state, | ||
)[1] | ||
|
||
|
||
def bisection(f, x_min, x_max, tol): | ||
"""Bisection root finding method | ||
Based on the intermediate value theorem, which | ||
guarantees for a continuous function that there | ||
is a zero in the interval [x_min, x_max] as long | ||
as sign(f(x_min)) != sign(f(x_max)). | ||
NOTE: We do not check the precondition sign(f(x_min)) != sign(f(x_max)) here | ||
""" | ||
initial_state = (0, x_min, x_max) # (iteration, x) | ||
|
||
def cond(state): | ||
it, x_min, x_max = state | ||
return jnp.abs(f(x_min)) > tol # it > 10 | ||
|
||
def body(state): | ||
it, x_min, x_max = state | ||
x = (x_min + x_max) / 2 | ||
|
||
sfxm = jnp.sign(f(x_min)) | ||
sfx = jnp.sign(f(x)) | ||
|
||
x_min = jnp.where(sfx == sfxm, x, x_min) | ||
x_max = jnp.where(sfx == sfxm, x_max, x) | ||
|
||
new_state = (it + 1, x_min, x_max) | ||
return new_state | ||
|
||
return jax.lax.while_loop( | ||
cond, | ||
body, | ||
initial_state, | ||
)[1] |
Oops, something went wrong.