From 288ca67adaacebd7c19517ec05fc64b4c91679f3 Mon Sep 17 00:00:00 2001 From: Georg Reich <39133020+1b15@users.noreply.github.com> Date: Wed, 19 Jun 2024 17:31:27 +0200 Subject: [PATCH] prepare heun wc numerical integration --- neurolib/models/jax/wc/timeIntegration.py | 59 +++++++++++++++++++++-- neurolib/models/wc/loadDefaultParams.py | 2 + tests/test_jax.py | 3 ++ 3 files changed, 60 insertions(+), 4 deletions(-) diff --git a/neurolib/models/jax/wc/timeIntegration.py b/neurolib/models/jax/wc/timeIntegration.py index 42b8e89f..02cc9362 100644 --- a/neurolib/models/jax/wc/timeIntegration.py +++ b/neurolib/models/jax/wc/timeIntegration.py @@ -102,6 +102,8 @@ def timeIntegration_args(params): # ------------------------------------------------------------------------ + integration_method = params['integration_method'] + return ( startind, t, @@ -134,6 +136,7 @@ def timeIntegration_args(params): tau_ou, sigma_ou, key, + integration_method ) @@ -170,6 +173,7 @@ def timeIntegration_elementwise( tau_ou, sigma_ou, key, + integration_method ): update_step = get_update_step( @@ -204,6 +208,7 @@ def timeIntegration_elementwise( tau_ou, sigma_ou, key, + integration_method ) # Iterating through time steps @@ -222,7 +227,6 @@ def timeIntegration_elementwise( inh_ou, ) - def get_update_step( startind, t, @@ -255,6 +259,7 @@ def get_update_step( tau_ou, sigma_ou, key, + integration_method ): key, subkey_exc = random.split(key) noise_exc = random.normal(subkey_exc, (N, len(t))) @@ -269,7 +274,7 @@ def S_E(x): def S_I(x): return 1.0 / (1.0 + jnp.exp(-a_inh * (x - mu_inh))) - def update_step(state, _): + def step_rhs(state): exc_history, inh_history, exc_ou, inh_ou, i = state # Vectorized calculation of delayed excitatory input @@ -307,6 +312,15 @@ def update_step(state, _): + inh_ou # ou noise ) ) + + exc_ou_rhs = (exc_ou_mean - exc_ou) * dt / tau_ou + sigma_ou * sqrt_dt * noise_exc[:, i - startind] + inh_ou_rhs = (inh_ou_mean - inh_ou) * dt / tau_ou + sigma_ou * sqrt_dt * noise_inh[:, i - startind] + + return exc_rhs, inh_rhs, exc_ou_rhs, inh_ou_rhs + + def euler(state): + exc_rhs, inh_rhs, exc_ou_rhs, inh_ou_rhs = step_rhs(state) + exc_history, inh_history, exc_ou, inh_ou, i = state # Euler integration # make sure e and i variables do not exceed 1 (can only happen with noise) exc_new = jnp.clip(exc_history[:, -1] + dt * exc_rhs, 0, 1) @@ -314,11 +328,48 @@ def update_step(state, _): # Update Ornstein-Uhlenbeck process for noise exc_ou = ( - exc_ou + (exc_ou_mean - exc_ou) * dt / tau_ou + sigma_ou * sqrt_dt * noise_exc[:, i - startind] + exc_ou + exc_ou_rhs ) # mV/ms inh_ou = ( - inh_ou + (inh_ou_mean - inh_ou) * dt / tau_ou + sigma_ou * sqrt_dt * noise_inh[:, i - startind] + inh_ou + inh_ou_rhs + ) # mV/ms + + return exc_new, inh_new, exc_ou, inh_ou + + def heun(state): + # TODO + exc_k1, inh_k1, exc_ou_rhs, inh_ou_rhs = step_rhs(state) + + # Update Ornstein-Uhlenbeck process for noise + exc_ou = ( + exc_ou + exc_ou_rhs ) # mV/ms + inh_ou = ( + inh_ou + inh_ou_rhs + ) # mV/ms + + # make sure e and i variables do not exceed 1 (can only happen with noise) + exc_new = jnp.clip(exc_history[:, -1] + dt * exc_rhs, 0, 1) + inh_new = jnp.clip(inh_history[:, -1] + dt * inh_rhs, 0, 1) + + exc_k1_history = jnp.concatenate((exc_history[:, 1:], jnp.expand_dims(exc_new, axis=1)), axis=1) + inh_k1_history = jnp.concatenate((inh_history[:, 1:], jnp.expand_dims(inh_new, axis=1)), axis=1) + + new_state = exc_k1_history, inh_k1_history, exc_ou, inh_ou + exc_k2, inh_k2, _, _ = step_rhs(new_state) + exc_new = ... + inh_new = ... + return exc_new, inh_new, exc_ou, inh_ou + + def update_step(state, _): + exc_history, inh_history, exc_ou, inh_ou, i = state + if integration_method == 'euler': + integration_f = euler + else if integration_method == 'heun': + integration_f = heun + else: + raise Exception(f'Integration method {integration_method} not implemented.') + exc_new, inh_new, exc_ou, inh_ou = integration_f(state) return ( ( diff --git a/neurolib/models/wc/loadDefaultParams.py b/neurolib/models/wc/loadDefaultParams.py index d6d87bff..e2ded712 100644 --- a/neurolib/models/wc/loadDefaultParams.py +++ b/neurolib/models/wc/loadDefaultParams.py @@ -81,4 +81,6 @@ def loadDefaultParams(Cmat=None, Dmat=None, seed=None): params.exc_ou = np.zeros((params.N,)) params.inh_ou = np.zeros((params.N,)) + params.integration_method = 'euler' + return params diff --git a/tests/test_jax.py b/tests/test_jax.py index 7d0dadd8..94ae6ffb 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -26,6 +26,7 @@ def test_single_node_deterministic(self): model_jax = WCModel_jax(seed=0) model_jax.params["duration"] = 1.0 * 1000 model_jax.params["sigma_ou"] = 0.0 + model_jax.params['integration_method'] = 'euler' model_jax.run() @@ -48,6 +49,7 @@ def test_single_node_dist(self): model_jax = WCModel_jax() model_jax.params["duration"] = 5.0 * 1000 model_jax.params["sigma_ou"] = 0.01 + model_jax.params['integration_method'] = 'euler' model_jax.run() @@ -86,6 +88,7 @@ def test_network(self): model.params["duration"] = 10 * 1000 model.params["sigma_ou"] = 0.0 model.params["K_gl"] = 0.6 + model_jax.params['integration_method'] = 'euler' # local node input parameter model.params["exc_ext"] = 0.72