Skip to content

Commit

Permalink
Initial commit for github
Browse files Browse the repository at this point in the history
Change-Id: I3260d343b414ac9666c587542de740e92c7e0b89
  • Loading branch information
Electronic Vision(s) committed May 5, 2023
0 parents commit 89f0e78
Show file tree
Hide file tree
Showing 87 changed files with 8,279 additions and 0 deletions.
8 changes: 8 additions & 0 deletions .ci/Jenkinsfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
@Library("jenlib")

withCcache() {
wafDefaultPipeline(projects: ["jax-snn"],
container: [app: "dls-core"],
configureInstallOptions: "--build-profile=ci",
notificationChannel: "#jenkins-trashbin")
}
7 changes: 7 additions & 0 deletions .gitreview
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[gerrit]
host=gerrit.bioai.eu
port=29418
project=jax-snn
defaultbranch=main
defaultremote=review
defaultrebase=0
93 changes: 93 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# jaxsnn

`jaxsnn` is an event-based approach to machine-learning-inspired training and
simulation of SNNs, including support for neuromorphic backends (BrainScaleS-2).
We build upon [jax](https://github.com/google/jax), a Python library providing
autograd and XLA functionality for high-performance machine learning research.


## Building the Software

The software builds upon existing libraries, such as
[jax](https://github.com/google/jax),
[optax](https://github.com/deepmind/optax),
and [tree-math](https://github.com/google/tree-math).
When using the neuromorphic BrainScaleS-2 backend, the software stack of the
platform is required.

We provide a container image (based on the [Singularity format](https://sylabs.io/docs/)) including all build-time and runtime dependencies.
Feel free to download the most recent version from [here](https://openproject.bioai.eu/containers/).

For all following steps, we assume that the most recent Singularity container is located at `/containers/stable/latest`.


### Github-based Build
To build this project from public resources, adhere to the following guide:

```shell
# 1) Most of the following steps will be executed within a singularity container
# To keep the steps clutter-free, we start by defining an alias
shopt -s expand_aliases
alias c="singularity exec --app dls /containers/stable/latest"

# 2) Prepare a fresh workspace and change directory into it
mkdir workspace && cd workspace

# 3) Fetch a current copy of the symwaf2ic build tool
git clone https://github.com/electronicvisions/waf -b symwaf2ic symwaf2ic

# 4) Build symwaf2ic
c make -C symwaf2ic
ln -s symwaf2ic/waf

# 5) Setup your workspace and clone all dependencies (--clone-depth=1 to skip history)
c ./waf setup --repo-db-url=https://github.com/electronicvisions/projects --project=jaxsnn

# 6) Load PPU cross-compiler toolchain (or build https://github.com/electronicvisions/oppulance)
module load ppu-toolchain

# 7) Build the project
# Adjust -j1 to your own needs, beware that high parallelism will increase memory consumption!
c ./waf configure
c ./waf build -j1

# 8) Install the project to ./bin and ./lib
c ./waf install

# 9) If you run programs outside waf, you'll need to add ./lib and ./bin to your path specifications
export SINGULARITYENV_PREPEND_PATH=`pwd`/bin:$SINGULARITYENV_PREPEND_PATH
export SINGULARITYENV_LD_LIBRARY_PATH=`pwd`/lib:$SINGULARITYENV_LD_LIBRARY_PATH
export PYTHONPATH=`pwd`/lib:$PYTHONPATH
```


## First Steps

Check out our examples:


```
python -m jaxsnn.event.tasks.yinyang
```

```
python -m jaxsnn.event.tasks.yinyang_event_prop
```


## Acknowledgements

The software in this repository has been developed by staff and students
of Heidelberg University as part of the research carried out by the
Electronic Vision(s) group at the Kirchhoff-Institute for Physics.

This work has received funding from the EC Horizon 2020 Framework Programme
under grant agreements 785907 (HBP SGA2) and 945539 (HBP SGA3), the Deutsche
Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's
Excellence Strategy EXC 2181/1-390900948 (the Heidelberg STRUCTURES Excellence
Cluster), the German Federal Ministry of Education and Research under grant
number 16ES1127 as part of the Pilotinnovationswettbewerb Energieeffizientes
KI-System, the Helmholtz Association Initiative and Networking Fund [Advanced
Computing Architectures (ACA)] under Project SO-092, as well as from the
Manfred Stärk Foundation, and the Lautenschläger-Forschungspreis 2018 for
Karlheinz Meier.
5 changes: 5 additions & 0 deletions src/pyjaxsnn/jaxsnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import jaxsnn.base

from .functional import euler_integrate, serial
from .functional.leaky_integrate import LI, LIStep
from .functional.lif import LIF, LIFStep
1 change: 1 addition & 0 deletions src/pyjaxsnn/jaxsnn/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import jaxsnn.base.explicit
54 changes: 54 additions & 0 deletions src/pyjaxsnn/jaxsnn/base/explicit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import dataclasses
from typing import Callable, Sequence, TypeVar
import jax
import tree_math


PyTreeState = TypeVar("PyTreeState")
TimeStepFn = Callable[[PyTreeState], PyTreeState]


class ExplicitConstrainedODE:
"""
The equation is given by:
∂u/∂t = explicit_terms(u)
0 = constraint(u)
"""

def __init__(self, explicit_terms, projection):
self.explicit_terms = explicit_terms
self.projection = projection

def explicit_terms(self, state):
"""Explicitly evaluate the ODE."""
raise NotImplementedError

def projection(self, state):
"""Enforce the constraint."""
raise NotImplementedError


class ExplicitConstrainedCDE:
"""
The equation is given by:
∂u/∂t = explicit_terms(u, x)
0 = constraint(u)
"""

def __init__(self, explicit_terms, projection, output):
self.explicit_terms = explicit_terms
self.projection = projection
self.output = output

def explicit_terms(self, state, input):
"""Explicitly evaluate the ODE."""
raise NotImplementedError

def projection(self, state, input):
"""Enforce the constraint."""
raise NotImplementedError

def output(self, state):
""""""
raise NotImplementedError
68 changes: 68 additions & 0 deletions src/pyjaxsnn/jaxsnn/base/implicit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Implicit-explicit time stepping routines for ODEs."""
import dataclasses
from typing import Callable, Sequence, TypeVar
import tree_math


PyTreeState = TypeVar("PyTreeState")
ControlInput = TypeVar("ControlInput")

TimeStepFn = Callable[[PyTreeState], PyTreeState]
ControlledTimeStepFn = Callable[[PyTreeState, ControlInput], PyTreeState]


class ImplicitExplicitODE:
"""Describes a set of ODEs with implicit & explicit terms.
The equation is given by:
∂x/∂t = explicit_terms(x) + implicit_terms(x)
`explicit_terms(x)` includes terms that should use explicit time-stepping and
`implicit_terms(x)` includes terms that should be modeled implicitly.
Typically the explicit terms are non-linear and the implicit terms are linear.
This simplifies solves but isn't strictly necessary.
"""

def explicit_terms(self, state: PyTreeState) -> PyTreeState:
"""Evaluates explicit terms in the ODE."""
raise NotImplementedError

def implicit_terms(self, state: PyTreeState) -> PyTreeState:
"""Evaluates implicit terms in the ODE."""
raise NotImplementedError

def implicit_solve(
self,
state: PyTreeState,
step_size: float,
) -> PyTreeState:
"""Solves `y - step_size * implicit_terms(y) = x` for y."""
raise NotImplementedError


class ImplicitExplicitCDE:
"""Describes a set of CDEs with implicit & explicit terms.
We assume that only the explicit terms are subject to control input.
The equation is given by:
∂x/∂t = explicit_terms(x, u) + implicit_terms(x)
`explicit_terms(x, u)` includes terms that should use explicit time-stepping and are controlled
`implicit_terms(x)` includes terms that should be modeled implicitly.
Typically the explicit terms are non-linear and the implicit terms are linear.
This simplifies solves but isn't strictly necessary.
"""

def explicit_terms(self, state: PyTreeState, u: ControlInput) -> PyTreeState:
"""Evaluates explicit terms in the ODE."""
raise NotImplementedError

def implicit_terms(self, state: PyTreeState) -> PyTreeState:
"""Evaluates implicit terms in the ODE."""
raise NotImplementedError

def implicit_solve(
self,
state: PyTreeState,
step_size: float,
) -> PyTreeState:
"""Solves `y - step_size * implicit_terms(y) = x` for y."""
raise NotImplementedError
90 changes: 90 additions & 0 deletions src/pyjaxsnn/jaxsnn/base/root_solving.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import jax
import jax.numpy as jnp


def linear_interpolation(f_a, f_b, a, b, x):
return (x - a) / (b - a) * f_b + (b - x) / (b - a) * f_a


def linear_interpolated_root(f_a, f_b, a, b):
return (a * f_b - b * f_a) / f_b - f_a


def newton_1d(f, x0):
initial_state = (0, x0)

def cond(state):
it, x = state
return it < 10

def body(state):
it, x = state
fx, dfx = f(x), jax.grad(f)(x)
step = fx / dfx
new_state = it + 1, x - step
return new_state

return jax.lax.while_loop(
cond,
body,
initial_state,
)[1]


def newton_nd(f, x0):

initial_state = (0, x0)

def cond(state):
it, x = state
return it < 10

def body(state):
it, x = state
fx, dfx = f(x), jax.grad(f)(x)
step = jax.numpy.linalg.solve(dfx, -fx)

new_state = it + 1, x + step
return new_state

return jax.lax.while_loop(
cond,
body,
initial_state,
)[1]


def bisection(f, x_min, x_max, tol):
"""Bisection root finding method
Based on the intermediate value theorem, which
guarantees for a continuous function that there
is a zero in the interval [x_min, x_max] as long
as sign(f(x_min)) != sign(f(x_max)).
NOTE: We do not check the precondition sign(f(x_min)) != sign(f(x_max)) here
"""
initial_state = (0, x_min, x_max) # (iteration, x)

def cond(state):
it, x_min, x_max = state
return jnp.abs(f(x_min)) > tol # it > 10

def body(state):
it, x_min, x_max = state
x = (x_min + x_max) / 2

sfxm = jnp.sign(f(x_min))
sfx = jnp.sign(f(x))

x_min = jnp.where(sfx == sfxm, x, x_min)
x_max = jnp.where(sfx == sfxm, x_max, x)

new_state = (it + 1, x_min, x_max)
return new_state

return jax.lax.while_loop(
cond,
body,
initial_state,
)[1]
Loading

0 comments on commit 89f0e78

Please sign in to comment.