Skip to content

Commit

Permalink
initial untested autodiff of wc and aln
Browse files Browse the repository at this point in the history
  • Loading branch information
1b15 committed May 3, 2024
1 parent 0a682c0 commit 7333c4c
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 18 deletions.
6 changes: 5 additions & 1 deletion neurolib/models/jax/aln/timeIntegration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@


def timeIntegration(params):
    """Run the ALN time integration for a parameter dictionary.

    Unpacks ``params`` into the positional argument tuple expected by the
    element-wise integrator and forwards it.

    :param params: parameter dictionary of the model
    :return: outputs of ``timeIntegration_elementwise``
    """
    integration_args = timeIntegration_args(params)
    return timeIntegration_elementwise(*integration_args)


def timeIntegration_args(params):
"""Sets up the parameters for time integration
Return:
Expand Down Expand Up @@ -208,7 +212,7 @@ def timeIntegration(params):

# ------------------------------------------------------------------------

return timeIntegration_elementwise(
return (
dt,
duration,
filter_sigma,
Expand Down
108 changes: 91 additions & 17 deletions neurolib/models/jax/wc/timeIntegration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@


def timeIntegration(params):
    """Run the Wilson-Cowan time integration for a parameter dictionary.

    Thin wrapper: builds the positional arguments from ``params`` and hands
    them to the element-wise integration routine.

    :param params: parameter dictionary of the model
    :return: outputs of ``timeIntegration_elementwise``
    """
    return timeIntegration_elementwise(*timeIntegration_args(params))


def timeIntegration_args(params):
"""Sets up the parameters for time integration in a JAX-compatible manner.
:param params: Parameter dictionary of the model
Expand Down Expand Up @@ -98,7 +102,7 @@ def timeIntegration(params):

# ------------------------------------------------------------------------

return timeIntegration_elementwise(
return (
startind,
t,
dt,
Expand Down Expand Up @@ -167,6 +171,91 @@ def timeIntegration_elementwise(
sigma_ou,
key,
):

update_step = get_update_step(
startind,
t,
dt,
sqrt_dt,
N,
Cmat,
K_gl,
Dmat_ndt,
exc_init,
inh_init,
exc_ext_baseline,
inh_ext_baseline,
exc_ext,
inh_ext,
tau_exc,
tau_inh,
a_exc,
a_inh,
mu_exc,
mu_inh,
c_excexc,
c_excinh,
c_inhexc,
c_inhinh,
exc_ou_init,
inh_ou_init,
exc_ou_mean,
inh_ou_mean,
tau_ou,
sigma_ou,
key,
)

# Iterating through time steps
(exc_history, inh_history, exc_ou, inh_ou, i), (excs_new, inhs_new) = jax.lax.scan(
update_step,
(exc_init, inh_init, exc_ou_init, inh_ou_init, startind),
xs=None,
length=len(t),
)

return (
t,
jnp.concatenate((exc_init, excs_new.T), axis=1),
jnp.concatenate((inh_init, inhs_new.T), axis=1),
exc_ou,
inh_ou,
)


def get_update_step(
startind,
t,
dt,
sqrt_dt,
N,
Cmat,
K_gl,
Dmat_ndt,
exc_init,
inh_init,
exc_ext_baseline,
inh_ext_baseline,
exc_ext,
inh_ext,
tau_exc,
tau_inh,
a_exc,
a_inh,
mu_exc,
mu_inh,
c_excexc,
c_excinh,
c_inhexc,
c_inhinh,
exc_ou_init,
inh_ou_init,
exc_ou_mean,
inh_ou_mean,
tau_ou,
sigma_ou,
key,
):
key, subkey_exc = random.split(key)
noise_exc = random.normal(subkey_exc, (N, len(t)))
key, subkey_inh = random.split(key)
Expand All @@ -180,7 +269,6 @@ def S_E(x):
def S_I(x):
return 1.0 / (1.0 + jnp.exp(-a_inh * (x - mu_inh)))

### integrate ODE system:
def update_step(state, _):
exc_history, inh_history, exc_ou, inh_ou, i = state

Expand Down Expand Up @@ -243,18 +331,4 @@ def update_step(state, _):
(exc_new, inh_new),
)

# Iterating through time steps
(exc_history, inh_history, exc_ou, inh_ou, i), (excs_new, inhs_new) = jax.lax.scan(
update_step,
(exc_init, inh_init, exc_ou_init, inh_ou_init, startind),
xs=None,
length=len(t),
)

return (
t,
jnp.concatenate((exc_init, excs_new.T), axis=1),
jnp.concatenate((inh_init, inhs_new.T), axis=1),
exc_ou,
inh_ou,
)
return update_step
98 changes: 98 additions & 0 deletions neurolib/optimize/autodiff/aln_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from jax import jit
from neurolib.models.jax.aln.timeIntegration import timeIntegration_args, timeIntegration_elementwise

# Names of the positional arguments produced by timeIntegration_args for the
# ALN model, in the same order. Zipping these names with the returned values
# lets individual integration arguments be replaced by name during
# optimization (see get_loss below).
args_names = [
    "dt",
    "duration",
    "filter_sigma",
    "Cmat",
    "Dmat",
    "c_gl",
    "Ke_gl",
    "tau_ou",
    "sigma_ou",
    "mue_ext_mean",
    "mui_ext_mean",
    "sigmae_ext",
    "sigmai_ext",
    "Ke",
    "Ki",
    "de",
    "di",
    "tau_se",
    "tau_si",
    "tau_de",
    "tau_di",
    "cee",
    "cie",
    "cii",
    "cei",
    "Jee_max",
    "Jei_max",
    "Jie_max",
    "Jii_max",
    "a",
    "b",
    "EA",
    "tauA",
    "C",
    "gL",
    "EL",
    "DeltaT",
    "VT",
    "Vr",
    "Vs",
    "Tref",
    "taum",
    "mufe",
    "mufi",
    "IA_init",
    "seem",
    "seim",
    "seev",
    "seiv",
    "siim",
    "siem",
    "siiv",
    "siev",
    "precalc_r",
    "precalc_V",
    "precalc_tau_mu",
    "precalc_tau_sigma",
    "dI",
    "ds",
    "sigmarange",
    "Irange",
    "N",
    "Dmat_ndt",
    "t",
    "rates_exc_init",
    "rates_inh_init",
    "rd_exc",
    "rd_inh",
    "sqrt_dt",
    "startind",
    "ndt_de",
    "ndt_di",
    "mue_ou",
    "mui_ou",
    "ext_exc_rate",
    "ext_inh_rate",
    "ext_exc_current",
    "ext_inh_current",
    "key",
]


def get_loss(model_params, loss_f, opt_params):
    """Build a jitted scalar loss over selected ALN integration arguments.

    The model's integration arguments are computed once from
    ``model_params``; the returned closure substitutes the entries named in
    ``opt_params`` with the optimization values ``x``, reruns the
    integration, and applies ``loss_f`` to its outputs.

    :param model_params: parameter dictionary of the model
    :param loss_f: callable applied to the simulation outputs
    :param opt_params: names (from ``args_names``) of arguments replaced by ``x``
    :return: jitted function mapping parameter values ``x`` to a loss value
    """
    # Evaluate the (fixed) baseline arguments a single time, outside the jit.
    baseline = dict(zip(args_names, timeIntegration_args(model_params)))

    @jit
    def loss(x):
        # Merge baseline with the optimized substitutions; later keys win.
        substituted = {**baseline, **dict(zip(opt_params, x))}
        outputs = timeIntegration_elementwise(**substituted)
        return loss_f(*outputs)

    return loss
19 changes: 19 additions & 0 deletions neurolib/optimize/autodiff/loss_functions/aln.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
def excitation_l1(
    t,
    exc,
    inh,
    mufe,
    mufi,
    IA,
    seem,
    seim,
    siem,
    siim,
    seev,
    seiv,
    siev,
    siiv,
    mue_ou,
    mui_ou,
):
    """Loss that rewards high mean excitatory activity.

    Only ``exc`` is used; the remaining arguments are accepted so the
    function can be called as ``loss_f(*simulation_outputs)`` on the full
    output tuple of the ALN time integration.

    :param exc: excitatory rate array (anything exposing ``.mean()``)
    :return: negative mean of ``exc`` (minimizing it maximizes excitation)
    """
    mean_excitation = exc.mean()
    return -mean_excitation
55 changes: 55 additions & 0 deletions neurolib/optimize/autodiff/wc_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from jax import jit
from neurolib.models.jax.wc.timeIntegration import timeIntegration_args, timeIntegration_elementwise

# Names of the positional arguments produced by timeIntegration_args for the
# Wilson-Cowan model, in the same order. Zipping these names with the
# returned values lets individual integration arguments be replaced by name
# during optimization (see get_loss below).
args_names = [
    "startind",
    "t",
    "dt",
    "sqrt_dt",
    "N",
    "Cmat",
    "K_gl",
    "Dmat_ndt",
    "exc_init",
    "inh_init",
    "exc_ext_baseline",
    "inh_ext_baseline",
    "exc_ext",
    "inh_ext",
    "tau_exc",
    "tau_inh",
    "a_exc",
    "a_inh",
    "mu_exc",
    "mu_inh",
    "c_excexc",
    "c_excinh",
    "c_inhexc",
    "c_inhinh",
    "exc_ou_init",
    "inh_ou_init",
    "exc_ou_mean",
    "inh_ou_mean",
    "tau_ou",
    "sigma_ou",
    "key",
]


def get_loss(model_params, loss_f, opt_params):
    """Build a jitted scalar loss over selected Wilson-Cowan arguments.

    Example::

        model = WCModel()
        wc_loss = get_loss(model.params, loss_f, ["exc_ext"])
        grad_wc_loss = jax.jit(jax.grad(wc_loss))
        grad_wc_loss([exc_ext])

    :param model_params: parameter dictionary of the model
    :param loss_f: callable applied to the simulation outputs
    :param opt_params: names (from ``args_names``) of arguments replaced by ``x``
    :return: jitted function mapping parameter values ``x`` to a loss value
    """
    # Compute the fixed integration arguments once, outside the jit.
    defaults = dict(zip(args_names, timeIntegration_args(model_params)))

    @jit
    def loss(x):
        # Overlay the optimized values onto the defaults; later keys win.
        merged = {**defaults, **dict(zip(opt_params, x))}
        return loss_f(*timeIntegration_elementwise(**merged))

    return loss

0 comments on commit 7333c4c

Please sign in to comment.