diff --git a/neurolib/models/jax/aln/timeIntegration.py b/neurolib/models/jax/aln/timeIntegration.py
index 33bb9f45..bdd18b2b 100644
--- a/neurolib/models/jax/aln/timeIntegration.py
+++ b/neurolib/models/jax/aln/timeIntegration.py
@@ -8,6 +8,10 @@
 
 
 def timeIntegration(params):
+    return timeIntegration_elementwise(*timeIntegration_args(params))
+
+
+def timeIntegration_args(params):
     """Sets up the parameters for time integration
 
     Return:
@@ -208,7 +212,7 @@ def timeIntegration(params):
 
     # ------------------------------------------------------------------------
 
-    return timeIntegration_elementwise(
+    return (
         dt,
         duration,
         filter_sigma,
diff --git a/neurolib/models/jax/wc/timeIntegration.py b/neurolib/models/jax/wc/timeIntegration.py
index 4ade7c03..42b8e89f 100644
--- a/neurolib/models/jax/wc/timeIntegration.py
+++ b/neurolib/models/jax/wc/timeIntegration.py
@@ -9,6 +9,10 @@
 
 
 def timeIntegration(params):
+    return timeIntegration_elementwise(*timeIntegration_args(params))
+
+
+def timeIntegration_args(params):
     """Sets up the parameters for time integration in a JAX-compatible manner.
 
     :param params: Parameter dictionary of the model
@@ -98,7 +102,7 @@ def timeIntegration(params):
 
     # ------------------------------------------------------------------------
 
-    return timeIntegration_elementwise(
+    return (
         startind,
         t,
         dt,
@@ -167,6 +171,91 @@ def timeIntegration_elementwise(
     sigma_ou,
     key,
 ):
+
+    update_step = get_update_step(
+        startind,
+        t,
+        dt,
+        sqrt_dt,
+        N,
+        Cmat,
+        K_gl,
+        Dmat_ndt,
+        exc_init,
+        inh_init,
+        exc_ext_baseline,
+        inh_ext_baseline,
+        exc_ext,
+        inh_ext,
+        tau_exc,
+        tau_inh,
+        a_exc,
+        a_inh,
+        mu_exc,
+        mu_inh,
+        c_excexc,
+        c_excinh,
+        c_inhexc,
+        c_inhinh,
+        exc_ou_init,
+        inh_ou_init,
+        exc_ou_mean,
+        inh_ou_mean,
+        tau_ou,
+        sigma_ou,
+        key,
+    )
+
+    # Iterating through time steps
+    (exc_history, inh_history, exc_ou, inh_ou, i), (excs_new, inhs_new) = jax.lax.scan(
+        update_step,
+        (exc_init, inh_init, exc_ou_init, inh_ou_init, startind),
+        xs=None,
+        length=len(t),
+    )
+
+    return (
+        t,
+        jnp.concatenate((exc_init, excs_new.T), axis=1),
+        jnp.concatenate((inh_init, inhs_new.T), axis=1),
+        exc_ou,
+        inh_ou,
+    )
+
+
+def get_update_step(
+    startind,
+    t,
+    dt,
+    sqrt_dt,
+    N,
+    Cmat,
+    K_gl,
+    Dmat_ndt,
+    exc_init,
+    inh_init,
+    exc_ext_baseline,
+    inh_ext_baseline,
+    exc_ext,
+    inh_ext,
+    tau_exc,
+    tau_inh,
+    a_exc,
+    a_inh,
+    mu_exc,
+    mu_inh,
+    c_excexc,
+    c_excinh,
+    c_inhexc,
+    c_inhinh,
+    exc_ou_init,
+    inh_ou_init,
+    exc_ou_mean,
+    inh_ou_mean,
+    tau_ou,
+    sigma_ou,
+    key,
+):
     key, subkey_exc = random.split(key)
     noise_exc = random.normal(subkey_exc, (N, len(t)))
     key, subkey_inh = random.split(key)
@@ -180,7 +269,6 @@ def S_E(x):
         return 1.0 / (1.0 + jnp.exp(-a_exc * (x - mu_exc)))
 
     def S_I(x):
         return 1.0 / (1.0 + jnp.exp(-a_inh * (x - mu_inh)))
-    ### integrate ODE system:
     def update_step(state, _):
         exc_history, inh_history, exc_ou, inh_ou, i = state
@@ -243,18 +331,4 @@ def update_step(state, _):
             (exc_new, inh_new),
         )
 
-    # Iterating through time steps
-    (exc_history, inh_history, exc_ou, inh_ou, i), (excs_new, inhs_new) = jax.lax.scan(
-        update_step,
-        (exc_init, inh_init, exc_ou_init, inh_ou_init, startind),
-        xs=None,
-        length=len(t),
-    )
-
-    return (
-        t,
-        jnp.concatenate((exc_init, excs_new.T), axis=1),
-        jnp.concatenate((inh_init, inhs_new.T), axis=1),
-        exc_ou,
-        inh_ou,
-    )
+    return update_step
diff --git a/neurolib/optimize/autodiff/aln_optimizer.py b/neurolib/optimize/autodiff/aln_optimizer.py
new file mode 100644
index 00000000..123f8119
--- /dev/null
+++ b/neurolib/optimize/autodiff/aln_optimizer.py
@@ -0,0 +1,98 @@
+from jax import jit
+from neurolib.models.jax.aln.timeIntegration import timeIntegration_args, timeIntegration_elementwise
+
+args_names = [
+    "dt",
+    "duration",
+    "filter_sigma",
+    "Cmat",
+    "Dmat",
+    "c_gl",
+    "Ke_gl",
+    "tau_ou",
+    "sigma_ou",
+    "mue_ext_mean",
+    "mui_ext_mean",
+    "sigmae_ext",
+    "sigmai_ext",
+    "Ke",
+    "Ki",
+    "de",
+    "di",
+    "tau_se",
+    "tau_si",
+    "tau_de",
+    "tau_di",
+    "cee",
+    "cie",
+    "cii",
+    "cei",
+    "Jee_max",
+    "Jei_max",
+    "Jie_max",
+    "Jii_max",
+    "a",
+    "b",
+    "EA",
+    "tauA",
+    "C",
+    "gL",
+    "EL",
+    "DeltaT",
+    "VT",
+    "Vr",
+    "Vs",
+    "Tref",
+    "taum",
+    "mufe",
+    "mufi",
+    "IA_init",
+    "seem",
+    "seim",
+    "seev",
+    "seiv",
+    "siim",
+    "siem",
+    "siiv",
+    "siev",
+    "precalc_r",
+    "precalc_V",
+    "precalc_tau_mu",
+    "precalc_tau_sigma",
+    "dI",
+    "ds",
+    "sigmarange",
+    "Irange",
+    "N",
+    "Dmat_ndt",
+    "t",
+    "rates_exc_init",
+    "rates_inh_init",
+    "rd_exc",
+    "rd_inh",
+    "sqrt_dt",
+    "startind",
+    "ndt_de",
+    "ndt_di",
+    "mue_ou",
+    "mui_ou",
+    "ext_exc_rate",
+    "ext_inh_rate",
+    "ext_exc_current",
+    "ext_inh_current",
+    "key",
+]
+
+
+def get_loss(model_params, loss_f, opt_params):
+    args_values = timeIntegration_args(model_params)
+    args = dict(zip(args_names, args_values))
+
+    @jit
+    def loss(x):
+        args_local = args.copy()
+        args_local.update(dict(zip(opt_params, x)))
+        simulation_outputs = timeIntegration_elementwise(**args_local)
+        return loss_f(*simulation_outputs)
+
+    return loss
diff --git a/neurolib/optimize/autodiff/loss_functions/aln.py b/neurolib/optimize/autodiff/loss_functions/aln.py
new file mode 100644
index 00000000..5605758b
--- /dev/null
+++ b/neurolib/optimize/autodiff/loss_functions/aln.py
@@ -0,0 +1,19 @@
+def excitation_l1(
+    t,
+    exc,
+    inh,
+    mufe,
+    mufi,
+    IA,
+    seem,
+    seim,
+    siem,
+    siim,
+    seev,
+    seiv,
+    siev,
+    siiv,
+    mue_ou,
+    mui_ou,
+):
+    return -exc.mean()
diff --git a/neurolib/optimize/autodiff/wc_optimizer.py b/neurolib/optimize/autodiff/wc_optimizer.py
new file mode 100644
index 00000000..1b3aa2b2
--- /dev/null
+++ b/neurolib/optimize/autodiff/wc_optimizer.py
@@ -0,0 +1,55 @@
+from jax import jit
+from neurolib.models.jax.wc.timeIntegration import timeIntegration_args, timeIntegration_elementwise
+
+args_names = [
+    "startind",
+    "t",
+    "dt",
+    "sqrt_dt",
+    "N",
+    "Cmat",
+    "K_gl",
+    "Dmat_ndt",
+    "exc_init",
+    "inh_init",
+    "exc_ext_baseline",
+    "inh_ext_baseline",
+    "exc_ext",
+    "inh_ext",
+    "tau_exc",
+    "tau_inh",
+    "a_exc",
+    "a_inh",
+    "mu_exc",
+    "mu_inh",
+    "c_excexc",
+    "c_excinh",
+    "c_inhexc",
+    "c_inhinh",
+    "exc_ou_init",
+    "inh_ou_init",
+    "exc_ou_mean",
+    "inh_ou_mean",
+    "tau_ou",
+    "sigma_ou",
+    "key",
+]
+
+
+# example usage:
+# model = WCModel()
+# wc_loss = get_loss(model.params, loss_f, ['exc_ext'])
+# grad_wc_loss = jax.jit(jax.grad(wc_loss))
+# grad_wc_loss([exc_ext])
+def get_loss(model_params, loss_f, opt_params):
+    args_values = timeIntegration_args(model_params)
+    args = dict(zip(args_names, args_values))
+
+    @jit
+    def loss(x):
+        args_local = args.copy()
+        args_local.update(dict(zip(opt_params, x)))
+        simulation_outputs = timeIntegration_elementwise(**args_local)
+        return loss_f(*simulation_outputs)
+
+    return loss
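Example usage of the new autodiff entry points, expanding the comment block in wc_optimizer.py. This is a minimal sketch under stated assumptions, not part of the patch itself: the WCModel import path and the quadratic loss exc_l2 are illustrative assumptions; any callable accepting the five outputs of the Wilson-Cowan timeIntegration_elementwise (t, exc, inh, exc_ou, inh_ou) can serve as loss_f.

import jax
import jax.numpy as jnp

from neurolib.models.wc import WCModel  # assumption: any model exposing a compatible params dict
from neurolib.models.jax.wc.timeIntegration import timeIntegration_args
from neurolib.optimize.autodiff.wc_optimizer import args_names, get_loss


# Illustrative loss (an assumption, not shipped with this patch): mean squared
# excitatory activity. timeIntegration_elementwise returns (t, exc, inh,
# exc_ou, inh_ou), so loss_f must accept exactly these five outputs.
def exc_l2(t, exc, inh, exc_ou, inh_ou):
    return jnp.mean(exc**2)


model = WCModel()

# Recover the default value (and shape) of exc_ext from the argument setup, so
# the optimized variable matches what the integrator expects.
args = dict(zip(args_names, timeIntegration_args(model.params)))
exc_ext_0 = args["exc_ext"]

# Build the scalar loss over the chosen parameters and differentiate it;
# jitting the grad of the already-jitted loss is fine in JAX.
wc_loss = get_loss(model.params, exc_l2, ["exc_ext"])
grad_wc_loss = jax.jit(jax.grad(wc_loss))

# The argument carries the same pytree structure as the opt_params list.
gradient = grad_wc_loss([exc_ext_0])

# The ALN side works the same way, e.g. with the shipped loss function:
# from neurolib.optimize.autodiff.aln_optimizer import get_loss as get_aln_loss
# from neurolib.optimize.autodiff.loss_functions.aln import excitation_l1
# aln_loss = get_aln_loss(aln_model.params, excitation_l1, ["mue_ext_mean"])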