diff --git a/neurolib/control/optimal_control/oc.py b/neurolib/control/optimal_control/oc.py index 82fcc145..0af3d9b0 100644 --- a/neurolib/control/optimal_control/oc.py +++ b/neurolib/control/optimal_control/oc.py @@ -80,6 +80,7 @@ def solve_adjoint( dxdoth, state_vars, output_vars, + model_name=None, ): """Backwards integration of the adjoint state. @@ -120,9 +121,13 @@ def solve_adjoint( if sv == ov: fx_fullstate[:, sv_ind, :] = fx[:, ov_ind, :] + krange = range(state_dim[1]) + if model_name == "wongwang": + krange = range(state_dim[1] - 1, -1, -1) + for t in range(T - 2, -1, -1): # backwards iteration including 0th index for n in range(N): # iterate through nodes - for k in range(state_dim[1]): + for k in krange: if dxdoth[n, k, k] == 0: res = fx_fullstate[n, k, t + 1] res += adjoint_input(hx_list, del_list, t, T, state_dim[1], adjoint_state, n, k) @@ -367,7 +372,11 @@ def __init__( if k in weights.keys(): defaultweights[k] = weights[k] else: - print("Weight ", k, " not in provided weight dictionary. Use default value.") + print( + "Weight ", + k, + " not in provided weight dictionary. Use default value.", + ) self.weights = defaultweights @@ -471,14 +480,20 @@ def __init__( for v, iv in enumerate(self.model.input_vars): control[:, v, :] = self.model.params[iv] - self.control = control.copy() + self.control = control.copy() self.check_params() self.control = update_control_with_limit( - self.N, self.dim_in, self.T, control, 0.0, np.zeros(control.shape), self.maximum_control_strength + self.N, + self.dim_in, + self.T, + control, + 0.0, + np.zeros(control.shape), + self.maximum_control_strength, ) - self.model_params = self.get_model_params() + self.model_params = self.get_model_params() def check_params(self): """Checks a subset of parameters and throws an error if a wrong dimension is found.""" @@ -539,7 +554,7 @@ def get_xs(self): return xs - def get_xs_delay(self): + def get_xs_delayed(self): """Extract the complete state of the delayed dynamical system.""" maxdel = self.model.getMaxDelay() if maxdel == 0: @@ -683,6 +698,7 @@ def solve_adjoint(self): dxdoth, numba.typed.List(self.model.state_vars), numba.typed.List(self.model.output_vars), + self.model.name, ) def decrease_step(self, cost, cost0, step, control0, factor_down, cost_gradient): @@ -721,7 +737,13 @@ def decrease_step(self, cost, cost0, step, control0, factor_down, cost_gradient) # Inplace updating of models control bc. forward-sim relies on models parameters. self.control = update_control_with_limit( - self.N, self.dim_in, self.T, control0, step, cost_gradient, self.maximum_control_strength + self.N, + self.dim_in, + self.T, + control0, + step, + cost_gradient, + self.maximum_control_strength, ) self.update_input() @@ -737,7 +759,13 @@ def decrease_step(self, cost, cost0, step, control0, factor_down, cost_gradient) # cost. step = 0.0 # For later analysis only. self.control = update_control_with_limit( - self.N, self.dim_in, self.T, control0, 0.0, np.zeros(control0.shape), self.maximum_control_strength + self.N, + self.dim_in, + self.T, + control0, + 0.0, + np.zeros(control0.shape), + self.maximum_control_strength, ) self.update_input() @@ -782,7 +810,13 @@ def increase_step(self, cost, cost0, step, control0, factor_up, cost_gradient): # Inplace updating of models control bc. forward-sim relies on models parameters self.control = update_control_with_limit( - self.N, self.dim_in, self.T, control0, step, cost_gradient, self.maximum_control_strength + self.N, + self.dim_in, + self.T, + control0, + step, + cost_gradient, + self.maximum_control_strength, ) self.update_input() @@ -792,7 +826,13 @@ def increase_step(self, cost, cost0, step, control0, factor_up, cost_gradient): logging.info("Increasing step encountered NAN.") step /= factor_up # Undo the last step update by inverse operation. self.control = update_control_with_limit( - self.N, self.dim_in, self.T, control0, step, cost_gradient, self.maximum_control_strength + self.N, + self.dim_in, + self.T, + control0, + step, + cost_gradient, + self.maximum_control_strength, ) self.update_input() break @@ -807,7 +847,13 @@ def increase_step(self, cost, cost0, step, control0, factor_up, cost_gradient): # then) and exit. step /= factor_up # Undo the last step update by inverse operation. self.control = update_control_with_limit( - self.N, self.dim_in, self.T, control0, step, cost_gradient, self.maximum_control_strength + self.N, + self.dim_in, + self.T, + control0, + step, + cost_gradient, + self.maximum_control_strength, ) self.update_input() break @@ -850,7 +896,13 @@ def step_size(self, cost_gradient): while True: # Reduce the step size, if numerical instability occurs in the forward-simulation. # inplace updating of models control bc. forward-sim relies on models parameters self.control = update_control_with_limit( - self.N, self.dim_in, self.T, control0, step, cost_gradient, self.maximum_control_strength + self.N, + self.dim_in, + self.T, + control0, + step, + cost_gradient, + self.maximum_control_strength, ) self.update_input() @@ -909,7 +961,13 @@ def optimize(self, n_max_iterations): self.control_interval = convert_interval(self.control_interval, self.T) self.control = update_control_with_limit( - self.N, self.dim_in, self.T, self.control, 0.0, np.zeros(self.control.shape), self.maximum_control_strength + self.N, + self.dim_in, + self.T, + self.control, + 0.0, + np.zeros(self.control.shape), + self.maximum_control_strength, ) # To avoid issues in repeated executions. self.control = limit_control_to_interval(self.N, self.dim_in, self.T, self.control, self.control_interval) diff --git a/neurolib/control/optimal_control/oc_wc/oc_wc.py b/neurolib/control/optimal_control/oc_wc/oc_wc.py index 1329228a..bb9cb586 100644 --- a/neurolib/control/optimal_control/oc_wc/oc_wc.py +++ b/neurolib/control/optimal_control/oc_wc/oc_wc.py @@ -73,7 +73,7 @@ def Duh(self): """ xs = self.get_xs() - xsd = self.get_xs_delay() + xsd = self.get_xs_delayed() return Duh( self.model_params, @@ -116,7 +116,7 @@ def compute_hx(self): self.dim_vars, self.T, self.get_xs(), - self.get_xs_delay(), + self.get_xs_delayed(), self.control, self.state_vars_dict, ) @@ -140,7 +140,7 @@ def compute_hx_nw(self): self.T, xs[:, self.state_vars_dict["exc"], :], xs[:, self.state_vars_dict["inh"], :], - self.get_xs_delay()[:, self.state_vars_dict["exc"], :], + self.get_xs_delayed()[:, self.state_vars_dict["exc"], :], self.control[:, self.state_vars_dict["exc"], :], self.state_vars_dict, ) diff --git a/neurolib/control/optimal_control/oc_ww/__init__.py b/neurolib/control/optimal_control/oc_ww/__init__.py new file mode 100644 index 00000000..22b96d13 --- /dev/null +++ b/neurolib/control/optimal_control/oc_ww/__init__.py @@ -0,0 +1 @@ +from .oc_ww import OcWw diff --git a/neurolib/control/optimal_control/oc_ww/oc_ww.py b/neurolib/control/optimal_control/oc_ww/oc_ww.py new file mode 100644 index 00000000..96e0daf4 --- /dev/null +++ b/neurolib/control/optimal_control/oc_ww/oc_ww.py @@ -0,0 +1,173 @@ +import numba + +from neurolib.control.optimal_control.oc import OC +from neurolib.models.ww.timeIntegration import ( + compute_hx, + compute_hx_min1, + compute_hx_nw, + Duh, + Dxdoth, +) + + +class OcWw(OC): + """Class for optimal control specific to neurolib's implementation of the two-population Wong-Wang model + ("WWmodel"). + + :param model: Instance of Wong-Wang model (can describe a single Wong-Wang node or a network of coupled + Wong-Wang nodes. + :type model: neurolib.models.ww.model.WWModel + """ + + def __init__( + self, + model, + target, + weights=None, + print_array=[], + cost_interval=(None, None), + cost_matrix=None, + control_matrix=None, + M=1, + M_validation=0, + validate_per_step=False, + ): + super().__init__( + model, + target, + weights=weights, + print_array=print_array, + cost_interval=cost_interval, + cost_matrix=cost_matrix, + control_matrix=control_matrix, + M=M, + M_validation=M_validation, + validate_per_step=validate_per_step, + ) + + assert self.model.name == "wongwang" + + def compute_dxdoth(self): + """Derivative of systems dynamics wrt. change of systems variables.""" + return Dxdoth(self.N, self.dim_vars) + + def get_model_params(self): + """Model params as an ordered tuple. + + :rtype: tuple + """ + return ( + self.model.params.a_exc, + self.model.params.b_exc, + self.model.params.d_exc, + self.model.params.tau_exc, + self.model.params.gamma_exc, + self.model.params.w_exc, + self.model.params.exc_current_baseline, + self.model.params.a_inh, + self.model.params.b_inh, + self.model.params.d_inh, + self.model.params.tau_inh, + self.model.params.w_inh, + self.model.params.inh_current_baseline, + self.model.params.J_NMDA, + self.model.params.J_I, + self.model.params.w_ee, + ) + + def Duh(self): + """Jacobian of systems dynamics wrt. external control input. + + :return: N x 4 x 4 x T Jacobians. + :rtype: np.ndarray + """ + + xs = self.get_xs() + xsd = self.get_xs_delayed() + + return Duh( + self.model_params, + self.N, + self.dim_in, + self.dim_vars, + self.T, + self.control[:, self.state_vars_dict["r_exc"], :], + self.control[:, self.state_vars_dict["r_inh"], :], + xs[:, self.state_vars_dict["se"], :], + xs[:, self.state_vars_dict["si"], :], + self.model.params.K_gl, + self.model.params.Cmat, + self.Dmat_ndt, + xsd[:, self.state_vars_dict["se"], :], + self.state_vars_dict, + ) + + def compute_hx_list(self): + """List of Jacobians without and with time delays (e.g. in the ALN model) and list of respective time step delays as integers (0 for undelayed) + + :return: List of Jacobian matrices, list of time step delays + : rtype: List of np.ndarray, List of integers + """ + hx = self.compute_hx() + hx_min1 = self.compute_hx_min1() + return numba.typed.List([hx, hx_min1]), numba.typed.List([0, -1]) + + def compute_hx(self): + """Jacobians of WwModel wrt. all variables. + + :return: N x T x 6 x 6 Jacobians. + :rtype: np.ndarray + """ + return compute_hx( + self.model_params, + self.model.params.K_gl, + self.model.Cmat, + self.Dmat_ndt, + self.N, + self.dim_vars, + self.T, + self.get_xs(), + self.get_xs_delayed(), + self.control, + self.state_vars_dict, + ) + + def compute_hx_min1(self): + """Jacobians of WWModel dse/dre and dsi/dri. + Dependency is in same time step, so shift by -1 in time is required for OC computation. + + :return: N x T x 6 x 6 Jacobians. + :rtype: np.ndarray + """ + return compute_hx_min1( + self.model_params, + self.N, + self.dim_vars, + self.T, + self.get_xs(), + self.state_vars_dict, + ) + + def compute_hx_nw(self): + """Jacobians for each time step for the network coupling. + + :return: N x N x T x (4x4) array + :rtype: np.ndarray + """ + + xs = self.get_xs() + + return compute_hx_nw( + self.model_params, + self.model.params.K_gl, + self.model.Cmat, + self.Dmat_ndt, + self.N, + self.dim_vars, + self.T, + xs[:, self.state_vars_dict["se"], :], + xs[:, self.state_vars_dict["se"], :], + self.get_xs_delayed()[:, self.state_vars_dict["se"], :], + self.control[:, self.state_vars_dict["r_exc"], :], + self.state_vars_dict, + ) diff --git a/neurolib/models/aln/timeIntegration.py b/neurolib/models/aln/timeIntegration.py index fd953540..0b6a8283 100644 --- a/neurolib/models/aln/timeIntegration.py +++ b/neurolib/models/aln/timeIntegration.py @@ -732,8 +732,8 @@ def jacobian_aln( uri, nw_input, nw_input_sq, - re_del, - ri_del, + re_delayed, + ri_delayed, sv, ): """Jacobian of the ALN dynamical system. @@ -754,10 +754,10 @@ def jacobian_aln( :type nw_input: float :param nw_input_sq: sum of all network inputs into current node at current time with squared prefactors :type nw_input_sq: float - :param re_del: E rate delayed by de - :type re_del: float - :param ri_del: I rate delayed by di - :type ri_del: float + :param re_delayed: E rate delayed by de + :type re_delayed: float + :param ri_delayed: I rate delayed by di + :type ri_delayed: float :param sv: dictionary of state vars and respective indices :type sv: dict @@ -806,15 +806,15 @@ def jacobian_aln( jacobian = np.zeros((V, V)) - z1ee = z1ee_f * re_del + nw_input + c_gl * Ke_gl * ure - z2ee = z2ee_f * re_del + nw_input_sq + c_gl**2 * Ke_gl * ure - z1ei = z1ei_f * ri_del - z2ei = z2ei_f * ri_del + z1ee = z1ee_f * re_delayed + nw_input + c_gl * Ke_gl * ure + z2ee = z2ee_f * re_delayed + nw_input_sq + c_gl**2 * Ke_gl * ure + z1ei = z1ei_f * ri_delayed + z2ei = z2ei_f * ri_delayed - z1ie = z1ie_f * re_del + c_gl * Ke_gl * uri - z2ie = z2ie_f * re_del + c_gl**2 * Ke_gl * uri - z1ii = z1ii_f * ri_del - z2ii = z2ii_f * ri_del + z1ie = z1ie_f * re_delayed + c_gl * Ke_gl * uri + z2ie = z2ie_f * re_delayed + c_gl**2 * Ke_gl * uri + z1ii = z1ii_f * ri_delayed + z2ii = z2ii_f * ri_delayed sig_ee_den = (1 + z1ee) * taum + tau_se sig_ei_den = (1 + z1ei) * taum + tau_si @@ -1004,7 +1004,7 @@ def compute_hx( ui = control[n, sv["rates_inh"], t] ure = control[n, sv["mufe"], t] uri = control[n, sv["mufi"], t] - re_del, ri_del = dyn_vars[n, sv["rates_exc"], t - ndt_de], dyn_vars[n, sv["rates_inh"], t - ndt_di] + re_delayed, ri_delayed = dyn_vars[n, sv["rates_exc"], t - ndt_de], dyn_vars[n, sv["rates_inh"], t - ndt_di] hx[n, t, :, :] = jacobian_aln( model_params, precomp_factors, @@ -1016,8 +1016,8 @@ def compute_hx( uri, nw_input[n, t], nw_input_sq[n, t], - re_del, - ri_del, + re_delayed, + ri_delayed, sv, ) @@ -1056,8 +1056,8 @@ def jacobian_de( ure, uri, nw_input, - re_del, - ri_del, + re_delayed, + ri_delayed, sv, ): """Jacobian of the ALN dynamical system wrt relations with delay de @@ -1076,10 +1076,10 @@ def jacobian_de( :type ui: float :param nw_input: sum of all network inputs into current node at current time :type nw_input: float - :param re_del: E rate delayed by de - :type re_del: float - :param ri_del: I rate delayed by di - :type ri_del: float + :param re_delayed: E rate delayed by de + :type re_delayed: float + :param ri_delayed: I rate delayed by di + :type ri_delayed: float :param sv: dictionary of state vars and respective indices :type sv: dict @@ -1127,10 +1127,10 @@ def jacobian_de( jacobian = np.zeros((V, V)) - z1ee = z1ee_f * re_del + nw_input + c_gl * Ke_gl * ure # factors 1e-3 are in z1ee_f and in ne_input - z1ei = z1ei_f * ri_del - z1ie = z1ie_f * re_del + c_gl * Ke_gl * uri - z1ii = z1ii_f * ri_del + z1ee = z1ee_f * re_delayed + nw_input + c_gl * Ke_gl * ure # factors 1e-3 are in z1ee_f and in ne_input + z1ei = z1ei_f * ri_delayed + z1ie = z1ie_f * re_delayed + c_gl * Ke_gl * uri + z1ii = z1ii_f * ri_delayed sig_ee_den = (1 + z1ee) * taum + tau_se sig_ei_den = (1 + z1ei) * taum + tau_si @@ -1283,7 +1283,7 @@ def compute_hx_de( ui = control[n, sv["rates_inh"], t] ure = control[n, sv["mufe"], t] uri = control[n, sv["mufi"], t] - re_del, ri_del = dyn_vars[n, sv["rates_exc"], t - ndt_de], dyn_vars[n, sv["rates_inh"], t - ndt_di] + re_delayed, ri_delayed = dyn_vars[n, sv["rates_exc"], t - ndt_de], dyn_vars[n, sv["rates_inh"], t - ndt_di] hx[n, t, :, :] = jacobian_de( model_params, precomp_factors, @@ -1294,8 +1294,8 @@ def compute_hx_de( ure, uri, nw_input[n, t], - re_del, - ri_del, + re_delayed, + ri_delayed, sv, ) @@ -1313,8 +1313,8 @@ def jacobian_di( ure, uri, nw_input, - re_del, - ri_del, + re_delayed, + ri_delayed, sv, ): """Jacobian of the ALN dynamical system wrt relations with delay di @@ -1332,10 +1332,10 @@ def jacobian_di( :type ui: float :param nw_input: sum of all network inputs into current node at current time :type nw_input: float - :param re_del: E rate delayed by de - :type re_del: float - :param ri_del: I rate delayed by di - :type ri_del: float + :param re_delayed: E rate delayed by de + :type re_delayed: float + :param ri_delayed: I rate delayed by di + :type ri_delayed: float :param sv: dictionary of state vars and respective indices :type sv: dict @@ -1383,10 +1383,10 @@ def jacobian_di( jacobian = np.zeros((V, V)) - z1ee = z1ee_f * re_del + nw_input + c_gl * Ke_gl * ure - z1ei = z1ei_f * ri_del - z1ie = z1ie_f * re_del + c_gl * Ke_gl * uri - z1ii = z1ii_f * ri_del + z1ee = z1ee_f * re_delayed + nw_input + c_gl * Ke_gl * ure + z1ei = z1ei_f * ri_delayed + z1ie = z1ie_f * re_delayed + c_gl * Ke_gl * uri + z1ii = z1ii_f * ri_delayed sig_ee_den = (1 + z1ee) * taum + tau_se sig_ei_den = (1 + z1ei) * taum + tau_si @@ -1541,7 +1541,7 @@ def compute_hx_di( ui = control[n, sv["rates_inh"], t] ure = control[n, sv["mufe"], t] uri = control[n, sv["mufi"], t] - re_del, ri_del = dyn_vars[n, sv["rates_exc"], t - ndt_de], dyn_vars[n, sv["rates_inh"], t - ndt_di] + re_delayed, ri_delayed = dyn_vars[n, sv["rates_exc"], t - ndt_de], dyn_vars[n, sv["rates_inh"], t - ndt_di] hx[n, t, :, :] = jacobian_di( model_params, precomp_factors, @@ -1552,8 +1552,8 @@ def compute_hx_di( ure, uri, nw_input[n, t], - re_del, - ri_del, + re_delayed, + ri_delayed, sv, ) @@ -1617,7 +1617,7 @@ def compute_hx_nw( if cmat[n1, n2] == 0.0: continue for t in range(T): - re_del, ri_del = dyn_vars[n1, sv["rates_exc"], t - ndt_de], dyn_vars[n1, sv["rates_inh"], t - ndt_di] + re_delayed, ri_delayed = dyn_vars[n1, sv["rates_exc"], t - ndt_de], dyn_vars[n1, sv["rates_inh"], t - ndt_di] ue = control[n1, sv["rates_exc"], t] ure = control[n1, sv["mufe"], t] hx_nw[n1, n2, t, :, :] = jacobian_nw( @@ -1625,8 +1625,8 @@ def compute_hx_nw( precomp_factors, V, dyn_vars[n1, :, t], - re_del, - ri_del, + re_delayed, + ri_delayed, nw_input[n1, t], cmat[n1, n2], ue, @@ -1643,8 +1643,8 @@ def jacobian_nw( precomp_factors, V, fullstate, - re_del, - ri_del, + re_delayed, + ri_delayed, nw_input, cmat_entry, ue, @@ -1661,10 +1661,10 @@ def jacobian_nw( :type V: int :param fullstate: Value of all V=16 dynamical variables at given time :type fullstate: np.ndarray - :param re_del: E rate delayed by de - :type re_del: float - :param ri_del: I rate delayed by di - :type ri_del: float + :param re_delayed: E rate delayed by de + :type re_delayed: float + :param ri_delayed: I rate delayed by di + :type ri_delayed: float :param nw_input: sum of all network inputs into current node at current time :type nw_input: float :param cmat_entry: Entry of the connectivity matrix at n1, n2 @@ -1719,8 +1719,8 @@ def jacobian_nw( jac_nw = np.zeros((V, V)) - z1ee = z1ee_f * re_del + nw_input + c_gl * Ke_gl * ure - z1ei = z1ei_f * ri_del + z1ee = z1ee_f * re_delayed + nw_input + c_gl * Ke_gl * ure + z1ei = z1ei_f * ri_delayed sig_ee_den = (1 + z1ee) * taum + tau_se sig_ei_den = (1 + z1ei) * taum + tau_si diff --git a/neurolib/models/wc/timeIntegration.py b/neurolib/models/wc/timeIntegration.py index b5a8d3ec..be57bb32 100644 --- a/neurolib/models/wc/timeIntegration.py +++ b/neurolib/models/wc/timeIntegration.py @@ -313,20 +313,20 @@ def jacobian_wc( :param model_params: Tuple of parameters in the WC Model in order :type model_params: tuple of float - :param nw_e: N x T input of network into each node's 'exc' - :type nw_e: np.ndarray - :param e: Value of the E-variable at specific time. - :type e: float - :param i: Value of the I-variable at specific time. - :type i: float - :param ue: N x T combined input of 'background' and 'control' into 'exc'. - :type ue: np.ndarray - :param ui: N x T combined input of 'background' and 'control' into 'inh'. - :type ui: np.ndarray - :param V: Number of system variables. - :type V: int - :param sv: dictionary of state vars and respective indices - :type sv: dict + :param nw_e: N x T input of network into each node's 'exc' + :type nw_e: np.ndarray + :param e: Value of the E-variable at specific time. + :type e: float + :param i: Value of the I-variable at specific time. + :type i: float + :param ue: Value of control input to into 'exc' at specific time. + :type ue: float + :param ui: Value of control input to into 'ihn' at specific time. + :type ui: float + :param V: Number of system variables. + :type V: int + :param sv: dictionary of state vars and respective indices + :type sv: dict :return: 4 x 4 Jacobian matrix. :rtype: np.ndarray @@ -372,7 +372,7 @@ def compute_hx( V, T, dyn_vars, - dyn_vars_delay, + dyn_vars_delayed, control, sv, ): @@ -394,8 +394,8 @@ def compute_hx( :type T: int :param dyn_vars: N x V x T array containing all values of 'exc' and 'inh'. :type dyn_vars: np.ndarray - :param dyn_vars_delay: - :type dyn_vars_delay: np.ndarray + :param dyn_vars_delayed: N x V x T array containing all values of delayed 'exc' and 'inh'. + :type dyn_vars_delayed: np.ndarray :param control: N x 2 x T control inputs to 'exc' and 'inh'. :type control: np.ndarray :param sv: dictionary of state vars and respective indices @@ -405,7 +405,7 @@ def compute_hx( :rtype: np.ndarray """ hx = np.zeros((N, T, V, V)) - nw_e = compute_nw_input(N, T, K_gl, cmat, dmat_ndt, dyn_vars_delay[:, sv["exc"], :]) + nw_e = compute_nw_input(N, T, K_gl, cmat, dmat_ndt, dyn_vars_delayed[:, sv["exc"], :]) for n in range(N): for t in range(T): @@ -465,7 +465,7 @@ def compute_hx_nw( T, e, i, - e_delay, + e_delayed, ue, sv, ): @@ -485,14 +485,16 @@ def compute_hx_nw( :type V: int :param T: Length of simulation (time dimension). :type T: int - :param e: Value of the E-variable at specific time. - :type e: float - :param i: Value of the I-variable at specific time. - :type i: float - :param ue: N x T array of the total input received by 'exc' population in every node at any time. - :type ue: np.ndarray - :param sv: dictionary of state vars and respective indices - :type sv: dict + :param e: Value of the E-variable at specific time. + :type e: float + :param i: Value of the I-variable at specific time. + :type i: float + :param e_delayed: Value of the delayed E-variable at specific time. + :type e_delayed: float + :param ue: N x T array of the total input received by 'exc' population in every node at any time. + :type ue: np.ndarray + :param sv: dictionary of state vars and respective indices + :type sv: dict :return: Jacobians for network connectivity in all time steps. :rtype: np.ndarray of shape N x N x T x 4 x 4 @@ -513,17 +515,14 @@ def compute_hx_nw( ) = model_params hx_nw = np.zeros((N, N, T, V, V)) - nw_e = compute_nw_input(N, T, K_gl, cmat, dmat_ndt, e_delay) + nw_e = compute_nw_input(N, T, K_gl, cmat, dmat_ndt, e_delayed) exc_input = c_excexc * e - c_inhexc * i + nw_e + exc_ext_baseline + ue for n1 in range(N): for n2 in range(N): for t in range(T - 1): hx_nw[n1, n2, t, sv["exc"], sv["exc"]] = ( - (1.0 - e[n1, t]) - * logistic_der(exc_input[n1, t], a_exc, mu_exc) - * K_gl - * cmat[n1, n2] + (1.0 - e[n1, t]) * logistic_der(exc_input[n1, t], a_exc, mu_exc) * K_gl * cmat[n1, n2] ) / tau_exc return -hx_nw @@ -543,7 +542,7 @@ def Duh( K_gl, cmat, dmat_ndt, - exc_values, + exc_delayed, sv, ): """Jacobian of systems dynamics wrt. external inputs (control signals). @@ -574,8 +573,8 @@ def Duh( :type cmat: np.ndarray :param dmat_ndt: delay index matrix :type dmat_ndt: np.ndarray - :param exc_values: N x T array containing values of 'exc' of all nodes through time. - :type exc_values: np.ndarray + :param exc_delayed: N x T array containing values of 'exc' of all nodes through time incl. delay + :type exc_delayed: np.ndarray :param sv: dictionary of state vars and respective indices :type sv: dict @@ -597,7 +596,7 @@ def Duh( inh_ext_baseline, ) = model_params - nw_e = compute_nw_input(N, T, K_gl, cmat, dmat_ndt, exc_values) + nw_e = compute_nw_input(N, T, K_gl, cmat, dmat_ndt, exc_delayed) duh = np.zeros((N, V_vars, V_in, T)) for t in range(T): diff --git a/neurolib/models/ww/loadDefaultParams.py b/neurolib/models/ww/loadDefaultParams.py index 974e3ab3..3f13dc51 100644 --- a/neurolib/models/ww/loadDefaultParams.py +++ b/neurolib/models/ww/loadDefaultParams.py @@ -82,8 +82,10 @@ def loadDefaultParams(Cmat=None, Dmat=None, seed=None): # ------------------------------------------------------------------------ - params.ses_init = 0.05 * np.random.uniform(0, 1, (params.N, 1)) - params.sis_init = 0.05 * np.random.uniform(0, 1, (params.N, 1)) + params.r_exc_init = 0.1 * np.random.uniform(0, 1, (params.N, 1)) + params.r_inh_init = 0.1 * np.random.uniform(0, 1, (params.N, 1)) + params.se_init = 0.05 * np.random.uniform(0, 1, (params.N, 1)) + params.si_init = 0.05 * np.random.uniform(0, 1, (params.N, 1)) # Ornstein-Uhlenbeck noise state variables params.exc_ou = np.zeros((params.N,)) diff --git a/neurolib/models/ww/model.py b/neurolib/models/ww/model.py index 12d72ebc..df6967b4 100644 --- a/neurolib/models/ww/model.py +++ b/neurolib/models/ww/model.py @@ -26,13 +26,14 @@ class WWModel(Model): name = "wongwang" description = "Wong-Wang neural mass model" - init_vars = ["r_exc", "r_inh", "ses_init", "sis_init", "exc_ou", "inh_ou"] + init_vars = ["r_exc_init", "r_inh_init", "se_init", "si_init", "exc_ou", "inh_ou"] state_vars = ["r_exc", "r_inh", "se", "si", "exc_ou", "inh_ou"] output_vars = ["r_exc", "r_inh", "se", "si"] default_output = "r_exc" + input_vars = ["exc_current", "inh_current"] + default_input = "exc_current" def __init__(self, params=None, Cmat=None, Dmat=None, seed=None): - self.Cmat = Cmat self.Dmat = Dmat self.seed = seed diff --git a/neurolib/models/ww/timeIntegration.py b/neurolib/models/ww/timeIntegration.py index ae213ac1..bf4f395a 100644 --- a/neurolib/models/ww/timeIntegration.py +++ b/neurolib/models/ww/timeIntegration.py @@ -108,13 +108,13 @@ def timeIntegration(params): # ------------------------------------------------------------------------ # Set initial values # if initial values are just a Nx1 array - if np.shape(params["ses_init"])[1] == 1: - ses_init = np.dot(params["ses_init"], np.ones((1, startind))) - sis_init = np.dot(params["sis_init"], np.ones((1, startind))) + if np.shape(params["se_init"])[1] == 1: + ses_init = np.dot(params["se_init"], np.ones((1, startind))) + sis_init = np.dot(params["si_init"], np.ones((1, startind))) # if initial values are a Nxt array else: - ses_init = params["ses_init"][:, -startind:] - sis_init = params["sis_init"][:, -startind:] + ses_init = params["se_init"][:, -startind:] + sis_init = params["si_init"][:, -startind:] # xsd = np.zeros((N,N)) # delayed activity ses_input_d = np.zeros(N) # delayed input to x @@ -290,3 +290,478 @@ def r(I, a, b, d): ) # mV/ms return t, r_exc, r_inh, ses, sis, exc_ou, inh_ou + + +@numba.njit +def logistic(x, a, b, d): + """Logistic function evaluated at point 'x'. + + :type x: float + :param a: Parameter of logistic function. + :type a: float + :param b: Parameter of logistic function. + :type b: float + :param d: Parameter of logistic function. + :type d: float + + :rtype: float + """ + return (a * x - b) / (1.0 - np.exp(-d * (a * x - b))) + + +@numba.njit +def logistic_der(x, a, b, d): + """Derivative of logistic function, evaluated at point 'x'. + + :type x: float + :param a: Parameter of logistic function. + :type a: float + :param b: Parameter of logistic function. + :type b: float + :param d: Parameter of logistic function. + :type d: float + + :rtype: float + """ + exp = np.exp(-d * (a * x - b)) + return (a * (1.0 - exp) - (a * x - b) * d * a * exp) / (1.0 - exp) ** 2 + + +@numba.njit +def jacobian_ww( + model_params, + nw_se, + re, + se, + si, + ue, + ui, + V, + sv, +): + """Jacobian of the WW dynamical system. + + :param model_params: Tuple of parameters in the WC Model in order + :type model_params: tuple of float + :param nw_se: N x T input of network into each node's 'exc' + :type nw_se: np.ndarray + :param re: Value of the r_exc-variable at specific time. + :type re: float + :param se: Value of the se-variable at specific time. + :type se: float + :param si: Value of the si-variable at specific time. + :type si: float + :param ue: Value of control input to into 'exc' at specific time. + :type ue: float + :param ui: Value of control input to into 'ihn' at specific time. + :type ui: float + :param V: Number of system variables. + :type V: int + :param sv: dictionary of state vars and respective indices + :type sv: dict + + :return: 4 x 4 Jacobian matrix. + :rtype: np.ndarray + """ + ( + a_exc, + b_exc, + d_exc, + tau_exc, + gamma_exc, + w_exc, + exc_current_baseline, + a_inh, + b_inh, + d_inh, + tau_inh, + w_inh, + inh_current_baseline, + J_NMDA, + J_I, + w_ee, + ) = model_params + + jacobian = np.zeros((V, V)) + IE = w_exc * (exc_current_baseline + ue) + w_ee * J_NMDA * se - J_I * si + J_NMDA * nw_se + jacobian[sv["r_exc"], sv["se"]] = -logistic_der(IE, a_exc, b_exc, d_exc) * w_ee * J_NMDA + jacobian[sv["r_exc"], sv["si"]] = logistic_der(IE, a_exc, b_exc, d_exc) * J_I + II = w_inh * (inh_current_baseline + ui) + J_NMDA * se - si + jacobian[sv["r_inh"], sv["se"]] = -logistic_der(II, a_inh, b_inh, d_inh) * J_NMDA + jacobian[sv["r_inh"], sv["si"]] = logistic_der(II, a_inh, b_inh, d_inh) + + # jacobian[sv["se"], sv["r_exc"]] = -(1.0 - se) * gamma_exc + jacobian[sv["se"], sv["se"]] = 1.0 / tau_exc + gamma_exc * re + + # jacobian[sv["si"], sv["r_inh"]] = -1.0 + jacobian[sv["si"], sv["si"]] = 1.0 / tau_inh + return jacobian + + +@numba.njit +def compute_hx( + wc_model_params, + K_gl, + cmat, + dmat_ndt, + N, + V, + T, + dyn_vars, + dyn_vars_delayed, + control, + sv, +): + """Jacobians of WWModel wrt. the all variables for each time step. + + :param model_params: Tuple of parameters in the WC Model in order + :type model_params: tuple of float + :param K_gl: Model parameter of global coupling strength. + :type K_gl: float + :param cmat: Model parameter, connectivity matrix. + :type cmat: ndarray + :param dmat_ndt: N x N delay matrix in multiples of dt. + :type dmat_ndt: np.ndarray + :param N: Number of nodes in the network. + :type N: int + :param V: Number of system variables. + :type V: int + :param T: Length of simulation (time dimension). + :type T: int + :param dyn_vars: N x V x T array containing all values of 'exc' and 'inh'. + :type dyn_vars: np.ndarray + :param dyn_vars_delayed: N x V x T array containing all delayed values of 'exc' and 'inh'. + :type dyn_vars_delayed: np.ndarray + :param control: N x 2 x T control inputs to 'exc' and 'inh'. + :type control: np.ndarray + :param sv: dictionary of state vars and respective indices + :type sv: dict + + :return: N x T x 4 x 4 Jacobians. + :rtype: np.ndarray + """ + hx = np.zeros((N, T, V, V)) + nw_e = compute_nw_input(N, T, K_gl, cmat, dmat_ndt, dyn_vars_delayed[:, sv["se"], :]) + + for n in range(N): + for t in range(T): + re = dyn_vars[n, sv["r_exc"], t] + se = dyn_vars[n, sv["se"], t] + si = dyn_vars[n, sv["si"], t] + ue = control[n, sv["r_exc"], t] + ui = control[n, sv["r_inh"], t] + hx[n, t, :, :] = jacobian_ww( + wc_model_params, + nw_e[n, t], + re, + se, + si, + ue, + ui, + V, + sv, + ) + return hx + + +@numba.njit +def jacobian_ww_min1( + model_params, + se, + V, + sv, +): + """Jacobian of the WW dynamical system. + + :param model_params: Tuple of parameters in the WC Model in order + :type model_params: tuple of float + :param se: Value of the se-variable at specific time. + :type se: float + :param V: Number of system variables. + :type V: int + :param sv: dictionary of state vars and respective indices + :type sv: dict + + :return: 4 x 4 Jacobian matrix. + :rtype: np.ndarray + """ + ( + a_exc, + b_exc, + d_exc, + tau_exc, + gamma_exc, + w_exc, + exc_current_baseline, + a_inh, + b_inh, + d_inh, + tau_inh, + w_inh, + inh_current_baseline, + J_NMDA, + J_I, + w_ee, + ) = model_params + + jacobian = np.zeros((V, V)) + + jacobian[sv["se"], sv["r_exc"]] = -(1.0 - se) * gamma_exc + jacobian[sv["si"], sv["r_inh"]] = -1.0 + return jacobian + + +@numba.njit +def compute_hx_min1( + wc_model_params, + N, + V, + T, + dyn_vars, + sv, +): + """Jacobians of WWModel wrt. the all variables for each time step. + + :param model_params: Tuple of parameters in the WC Model in order + :type model_params: tuple of float + :param N: Number of nodes in the network. + :type N: int + :param V: Number of system variables. + :type V: int + :param T: Length of simulation (time dimension). + :type T: int + :param dyn_vars: N x V x T array containing all values of 'exc' and 'inh'. + :type dyn_vars: np.ndarray + :param sv: dictionary of state vars and respective indices + :type sv: dict + + :return: N x T x 4 x 4 Jacobians. + :rtype: np.ndarray + """ + hx = np.zeros((N, T, V, V)) + + for n in range(N): + for t in range(T): + se = dyn_vars[n, sv["se"], t] + hx[n, t, :, :] = jacobian_ww_min1( + wc_model_params, + se, + V, + sv, + ) + return hx + + +@numba.njit +def compute_nw_input(N, T, K_gl, cmat, dmat_ndt, se): + """Compute input by other nodes of network into each node's 'exc' population at every timestep. + + :param N: Number of nodes in the network. + :type N: int + :param T: Length of simulation (time dimension). + :type T: int + :param K_gl: Model parameter of global coupling strength. + :type K_gl: float + :param cmat: Model parameter, connectivity matrix. + :type cmat: ndarray + :param dmat_ndt: N x N delay matrix in multiples of dt. + :type dmat_ndt: np.ndarray + :param se: N x T array containing values of 'exc' of all nodes through time. + :type se: np.ndarray + :return: N x T network inputs. + :rytpe: np.ndarray + """ + nw_input = np.zeros((N, T)) + + for t in range(1, T): + for n in range(N): + for l in range(N): + nw_input[n, t] += K_gl * cmat[n, l] * (se[l, t - dmat_ndt[n, l] - 1]) + return nw_input + + +@numba.njit +def compute_hx_nw( + model_params, + K_gl, + cmat, + dmat_ndt, + N, + V, + T, + se, + si, + se_delayed, + ue, + sv, +): + """Jacobians for network connectivity in all time steps. + + :param model_params: Tuple of parameters in the WC Model in order + :type model_params: tuple of float + :param K_gl: Model parameter of global coupling strength. + :type K_gl: float + :param cmat: Model parameter, connectivity matrix. + :type cmat: ndarray + :param dmat_ndt: N x N delay matrix in multiples of dt. + :type dmat_ndt: np.ndarray + :param N: Number of nodes in the network. + :type N: int + :param V: Number of system variables. + :type V: int + :param T: Length of simulation (time dimension). + :type T: int + :param se: Array of the se-variable. + :type se: np.ndarray + :param si: Array of the se-variable. + :type si: np.ndarray + :param se_delayed: Value of delayed se-variable. + :type se_delayed: np.ndarray + :param ue: N x T array of the total input received by 'exc' population in every node at any time. + :type ue: np.ndarray + :param sv: dictionary of state vars and respective indices + :type sv: dict + + :return: Jacobians for network connectivity in all time steps. + :rtype: np.ndarray of shape N x N x T x 4 x 4 + """ + ( + a_exc, + b_exc, + d_exc, + tau_exc, + gamma_exc, + w_exc, + exc_current_baseline, + a_inh, + b_inh, + d_inh, + tau_inh, + w_inh, + inh_current_baseline, + J_NMDA, + J_I, + w_ee, + ) = model_params + hx_nw = np.zeros((N, N, T, V, V)) + + nw_e = compute_nw_input(N, T, K_gl, cmat, dmat_ndt, se_delayed) + IE = w_exc * (exc_current_baseline + ue) + w_ee * J_NMDA * se - J_I * si + J_NMDA * nw_e + + for n1 in range(N): + for n2 in range(N): + for t in range(T - 1): + hx_nw[n1, n2, t, sv["r_exc"], sv["se"]] = ( + logistic_der(IE[n1, t], a_exc, b_exc, d_exc) * J_NMDA * K_gl * cmat[n1, n2] + ) + + return -hx_nw + + +@numba.njit +def Duh( + model_params, + N, + V_in, + V_vars, + T, + ue, + ui, + se, + si, + K_gl, + cmat, + dmat_ndt, + se_delayed, + sv, +): + """Jacobian of systems dynamics wrt. external inputs (control signals). + + :param model_params: Tuple of parameters in the WC Model in order + :type model_params: tuple of float + :param N: Number of nodes in the network. + :type N: int + :param V_in: Number of input variables. + :type V_in: int + :param V_vars: Number of system variables. + :type V_vars: int + :param T: Length of simulation (time dimension). + :type T: int + :param nw_e: N x T input of network into each node's 'exc' + :type nw_e: np.ndarray + :param ue: N x T array of the total input received by 'exc' population in every node at any time. + :type ue: np.ndarray + :param ui: N x T array of the total input received by 'inh' population in every node at any time. + :type ui: np.ndarray + :param se: Value of the se-variable for each node and timepoint + :type se: np.ndarray + :param si: Value of the si-variable for each node and timepoint + :type si: np.ndarray + :param K_gl: global coupling strength + :type K_gl float + :param cmat: coupling matrix + :type cmat: np.ndarray + :param dmat_ndt: delay index matrix + :type dmat_ndt: np.ndarray + :param se_delayed: N x T array containing values of 'exc' of all nodes through time. + :type se_delayed: np.ndarray + :param sv: dictionary of state vars and respective indices + :type sv: dict + + :rtype: np.ndarray of shape N x V x V x T + """ + + ( + a_exc, + b_exc, + d_exc, + tau_exc, + gamma_exc, + w_exc, + exc_current_baseline, + a_inh, + b_inh, + d_inh, + tau_inh, + w_inh, + inh_current_baseline, + J_NMDA, + J_I, + w_ee, + ) = model_params + + nw_e = compute_nw_input(N, T, K_gl, cmat, dmat_ndt, se_delayed) + + duh = np.zeros((N, V_vars, V_in, T)) + for t in range(T): + for n in range(N): + IE = ( + w_exc * (exc_current_baseline + ue[n, t]) + + w_ee * J_NMDA * se[n, t] + - J_I * si[n, t] + + J_NMDA * nw_e[n, t] + ) + duh[n, sv["r_exc"], sv["r_exc"], t] = -logistic_der(IE, a_exc, b_exc, d_exc) * w_exc + II = w_inh * (inh_current_baseline + ui[n, t]) + J_NMDA * se[n, t] - si[n, t] + duh[n, sv["r_inh"], sv["r_inh"], t] = -logistic_der(II, a_inh, b_inh, d_inh) * w_inh + return duh + + +@numba.njit +def Dxdoth(N, V): + """Derivative of system dynamics wrt x dot + + :param N: Number of nodes in the network. + :type N: int + :param V: Number of system variables. + :type V: int + + :return: N x V x V matrix. + :rtype: np.ndarray + """ + dxdoth = np.zeros((N, V, V)) + for n in range(N): + for v in range(2, V): + dxdoth[n, v, v] = 1 + + return dxdoth diff --git a/tests/control/optimal_control/test_oc_utils.py b/tests/control/optimal_control/test_oc_utils.py index f880908c..fc78d0e5 100644 --- a/tests/control/optimal_control/test_oc_utils.py +++ b/tests/control/optimal_control/test_oc_utils.py @@ -120,20 +120,50 @@ def gettarget_1n(model): return np.concatenate( ( - np.concatenate((model.params[model.init_vars[0]], model.params[model.init_vars[0]]), axis=1)[ - :, :, np.newaxis - ], + np.concatenate( + (model.params[model.init_vars[0]], model.params[model.init_vars[0]]), + axis=1, + )[:, :, np.newaxis], np.stack((model[model.state_vars[0]], model[model.state_vars[1]]), axis=1), ), axis=2, ) +def gettarget_1n_ww(model): + return np.concatenate( + ( + np.concatenate( + ( + model.params[model.init_vars[0]], + model.params[model.init_vars[0]], + model.params[model.init_vars[2]], + model.params[model.init_vars[3]], + ), + axis=1, + )[:, :, np.newaxis], + np.stack( + ( + model[model.state_vars[0]], + model[model.state_vars[1]], + model[model.state_vars[2]], + model[model.state_vars[3]], + ), + axis=1, + ), + ), + axis=2, + ) + + def gettarget_2n(model): return np.concatenate( ( np.stack( - (model.params[model.init_vars[0]][:, -1], model.params[model.init_vars[1]][:, -1]), + ( + model.params[model.init_vars[0]][:, -1], + model.params[model.init_vars[1]][:, -1], + ), axis=1, )[:, :, np.newaxis], np.stack((model[model.state_vars[0]], model[model.state_vars[1]]), axis=1), @@ -142,6 +172,32 @@ def gettarget_2n(model): ) +def gettarget_2n_ww(model): + return np.concatenate( + ( + np.stack( + ( + model.params[model.init_vars[0]][:, -1], + model.params[model.init_vars[1]][:, -1], + model.params[model.init_vars[2]][:, -1], + model.params[model.init_vars[3]][:, -1], + ), + axis=1, + )[:, :, np.newaxis], + np.stack( + ( + model[model.state_vars[0]], + model[model.state_vars[1]], + model[model.state_vars[2]], + model[model.state_vars[3]], + ), + axis=1, + ), + ), + axis=2, + ) + + def setinitzero_1n(model): for init_var in model.init_vars: if "ou" in init_var: diff --git a/tests/control/optimal_control/test_oc_wc.py b/tests/control/optimal_control/test_oc_wc.py index 4b7e21e5..ec4ab4c3 100644 --- a/tests/control/optimal_control/test_oc_wc.py +++ b/tests/control/optimal_control/test_oc_wc.py @@ -84,7 +84,6 @@ def test_2n(self): control_mat[0, 0] = 1.0 cost_mat[1, 0] = 1.0 - model.params.coupling = "additive" # test additive in undelayed network, diffusive in delayed network model.params["exc_ext"] = p.TEST_INPUT_2N_6 model.params["inh_ext"] = p.ZERO_INPUT_2N_6 model.run() @@ -101,7 +100,8 @@ def test_2n(self): model_controlled.maximum_control_strength = 2.0 model_controlled.control = np.concatenate( - [p.INIT_INPUT_2N_6[:, np.newaxis, :], p.ZERO_INPUT_2N_6[:, np.newaxis, :]], axis=1 + [p.INIT_INPUT_2N_6[:, np.newaxis, :], p.ZERO_INPUT_2N_6[:, np.newaxis, :]], + axis=1, ) model_controlled.update_input() @@ -137,8 +137,6 @@ def test_2n_delay(self): control_mat[0, 0] = 1.0 cost_mat[1, 0] = 1.0 - model.params.coupling = "diffusive" # test additive in undelayed network, diffusive in delayed network - model.params["exc_ext"] = p.TEST_INPUT_2N_8 model.params["inh_ext"] = p.ZERO_INPUT_2N_8 @@ -156,7 +154,8 @@ def test_2n_delay(self): model_controlled.maximum_control_strength = 2.0 model_controlled.control = np.concatenate( - [p.INIT_INPUT_2N_8[:, np.newaxis, :], p.ZERO_INPUT_2N_8[:, np.newaxis, :]], axis=1 + [p.INIT_INPUT_2N_8[:, np.newaxis, :], p.ZERO_INPUT_2N_8[:, np.newaxis, :]], + axis=1, ) model_controlled.update_input() diff --git a/tests/control/optimal_control/test_oc_ww.py b/tests/control/optimal_control/test_oc_ww.py new file mode 100644 index 00000000..1005e508 --- /dev/null +++ b/tests/control/optimal_control/test_oc_ww.py @@ -0,0 +1,213 @@ +import unittest +import numpy as np + +from neurolib.models.ww import WWModel +from neurolib.control.optimal_control import oc_ww + +import test_oc_utils as test_oc_utils + +p = test_oc_utils.params + + +class TestWW(unittest.TestCase): + """ + Test ww in neurolib/optimal_control/ + """ + + # tests if the control from OC computation coincides with a random input used for target forward-simulation + # single-node case + def test_1n(self): + print("Test OC in single-node system") + model = WWModel() + test_oc_utils.setinitzero_1n(model) + model.params["duration"] = p.TEST_DURATION_6 + # decrease time scale of sigmoidal function + # model.params["d_exc"] = 1.0 + # model.params["d_inh"] = 1.0 + + for input_channel in [0, 1]: + for measure_channel in range(4): + print("input_channel, measure_channel = ", input_channel, measure_channel) + + cost_mat = np.zeros((model.params.N, len(model.output_vars))) + control_mat = np.zeros((model.params.N, len(model.state_vars))) + control_mat[0, input_channel] = 1.0 # only allow inputs to input_channel + cost_mat[0, measure_channel] = 1.0 # only measure other channel + + test_oc_utils.set_input(model, p.ZERO_INPUT_1N_6) + model.params[model.input_vars[input_channel]] = p.TEST_INPUT_1N_6 + model.run() + target = test_oc_utils.gettarget_1n_ww(model) + + test_oc_utils.set_input(model, p.ZERO_INPUT_1N_6) + + model_controlled = oc_ww.OcWw(model, target) + model_controlled.maximum_control_strength = 2.0 + + model_controlled.control = np.concatenate( + [ + control_mat[0, 0] * p.INIT_INPUT_1N_6[:, np.newaxis, :], + control_mat[0, 1] * p.INIT_INPUT_1N_6[:, np.newaxis, :], + ], + axis=1, + ) + + model_controlled.update_input() + + control_coincide = False + + for i in range(p.LOOPS): + model_controlled.optimize(p.ITERATIONS) + + c_diff = np.abs(model_controlled.control[0, input_channel, :] - p.TEST_INPUT_1N_6[0, :]) + + if np.amax(c_diff) < p.LIMIT_DIFF: + control_coincide = True + break + + if model_controlled.zero_step_encountered: + break + + self.assertTrue(control_coincide) + + def test_2n(self): + print("Test OC in 2-node network") + ### communication between E and I is validated in test_onenode_oc. Test only E-E communication + ### Because of symmetry, test only inputs to 0 node, precision measuement in 1 node + + dmat = np.array([[0.0, 0.0], [0.0, 0.0]]) # no delay + cmat = np.array([[0.0, 1.0], [1.0, 0.0]]) + + model = WWModel(Cmat=cmat, Dmat=dmat) + test_oc_utils.setinitzero_2n(model) + model.params.duration = p.TEST_DURATION_10 + + # decrease time scale of sigmoidal function + # model.params["d_exc"] = 1.0 + # model.params["d_inh"] = 1.0 + + cost_mat = np.zeros((model.params.N, len(model.output_vars))) + control_mat = np.zeros((model.params.N, len(model.state_vars))) + control_mat[0, 0] = 1.0 + cost_mat[1, 0] = 1.0 + + model.params["exc_current"] = p.TEST_INPUT_2N_10 + model.params["inh_current"] = p.ZERO_INPUT_2N_10 + model.run() + + target = test_oc_utils.gettarget_2n_ww(model) + model.params["exc_current"] = p.ZERO_INPUT_2N_10 + + model_controlled = oc_ww.OcWw( + model, + target, + control_matrix=control_mat, + cost_matrix=cost_mat, + ) + model_controlled.maximum_control_strength = 2.0 + + model_controlled.control = np.concatenate( + [ + p.INIT_INPUT_2N_10[:, np.newaxis, :], + p.ZERO_INPUT_2N_10[:, np.newaxis, :], + ], + axis=1, + ) + model_controlled.update_input() + + control_coincide = False + + for i in range(p.LOOPS): + model_controlled.optimize(p.ITERATIONS) + c_diff = np.abs(model_controlled.control[0, 0, :] - p.TEST_INPUT_2N_10[0, :]) + if np.amax(c_diff) < p.LIMIT_DIFF: + control_coincide = True + break + + if model_controlled.zero_step_encountered: + break + + self.assertTrue(control_coincide) + + # tests if the control from OC computation coincides with a random input used for target forward-simulation + # delayed network case + def test_2n_delay(self): + print("Test OC in delayed 2-node network") + + cmat = np.array([[0.0, 0.0], [1.0, 0.0]]) + dmat = np.array([[0.0, 0.0], [p.TEST_DELAY, 0.0]]) + + model = WWModel(Cmat=cmat, Dmat=dmat) + test_oc_utils.setinitzero_2n(model) + model.params.duration = p.TEST_DURATION_8 + model.params.signalV = 1.0 + + cost_mat = np.zeros((model.params.N, len(model.output_vars))) + control_mat = np.zeros((model.params.N, len(model.state_vars))) + control_mat[0, 0] = 1.0 + cost_mat[1, 0] = 1.0 + + model.params["exc_current"] = p.TEST_INPUT_2N_8 + model.params["inh_current"] = p.ZERO_INPUT_2N_8 + + model.run() + + target = test_oc_utils.gettarget_2n_ww(model) + model.params["exc_current"] = p.ZERO_INPUT_2N_8 + + model_controlled = oc_ww.OcWw( + model, + target, + control_matrix=control_mat, + cost_matrix=cost_mat, + ) + model_controlled.maximum_control_strength = 2.0 + + model_controlled.control = np.concatenate( + [p.INIT_INPUT_2N_8[:, np.newaxis, :], p.ZERO_INPUT_2N_8[:, np.newaxis, :]], + axis=1, + ) + model_controlled.update_input() + + control_coincide = False + + for i in range(p.LOOPS): + model_controlled.optimize(p.ITERATIONS) + + # last entries of adjoint_state[0,0,:] are zero + self.assertTrue(np.amax(np.abs(model_controlled.adjoint_state[0, 0, -model.getMaxDelay() :])) == 0.0) + + c_diff_max = np.amax(np.abs(model_controlled.control[0, 0, :] - p.TEST_INPUT_2N_8[0, :])) + if c_diff_max < p.LIMIT_DIFF: + control_coincide = True + break + + if model_controlled.zero_step_encountered: + break + + self.assertTrue(control_coincide) + + # Arbitrary network and control setting, get_xs() returns correct array shape (despite initial values array longer than 1) + def test_get_xs(self): + print("Test state shape agrees with target shape") + + cmat = np.array([[0.0, 1.0], [1.0, 0.0]]) + dmat = np.array([[0.0, 0.0], [0.0, 0.0]]) # no delay + model = WWModel(Cmat=cmat, Dmat=dmat) + model.params.duration = p.TEST_DURATION_6 + test_oc_utils.set_input(model, p.TEST_INPUT_2N_6) + + target = np.ones((2, len(model.output_vars), p.TEST_INPUT_2N_6.shape[1])) + + model_controlled = oc_ww.OcWw( + model, + target, + ) + + model_controlled.optimize(1) + xs = model_controlled.get_xs() + self.assertTrue(xs.shape == target.shape) + + +if __name__ == "__main__": + unittest.main()