Optimal control of Wong Wang model #257

Open
wants to merge 14 commits into base: master
84 changes: 71 additions & 13 deletions neurolib/control/optimal_control/oc.py
@@ -80,6 +80,7 @@ def solve_adjoint(
dxdoth,
state_vars,
output_vars,
model_name=None,
):
"""Backwards integration of the adjoint state.

@@ -120,9 +121,13 @@ if sv == ov:
if sv == ov:
fx_fullstate[:, sv_ind, :] = fx[:, ov_ind, :]

krange = range(state_dim[1])
if model_name == "wongwang":
krange = range(state_dim[1] - 1, -1, -1)

for t in range(T - 2, -1, -1): # backwards iteration including 0th index
for n in range(N): # iterate through nodes
for k in range(state_dim[1]):
for k in krange:
if dxdoth[n, k, k] == 0:
res = fx_fullstate[n, k, t + 1]
res += adjoint_input(hx_list, del_list, t, T, state_dim[1], adjoint_state, n, k)
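The reversed iteration range for the Wong-Wang model matters because some adjoint components are read within the same backward step: the rate and gating variables are coupled inside one time step (see the hx_min1 Jacobian with shift -1 introduced later in this diff), so components that others reference in the same step presumably have to be updated first in the k loop, while all other models keep the original order. A minimal sketch of the idea with made-up two-component dynamics, not the actual neurolib state ordering or adjoint code:

import numpy as np

T, K = 6, 2                      # time steps and state dimension (illustrative)
adjoint = np.zeros((K, T))
fx = np.ones((K, T))             # stand-in for the cost gradient terms df/dx
same_step_coupling = 0.5         # stand-in for a Jacobian entry with shift -1

for t in range(T - 2, -1, -1):             # backward in time
    for k in range(K - 1, -1, -1):         # reversed order over components
        if k == K - 1:
            # component that only couples to the next time step
            adjoint[k, t] = fx[k, t + 1] + adjoint[k, t + 1]
        else:
            # component that reads adjoint[K - 1, t], a value computed earlier
            # in the SAME backward step; this is why the reversed k loop matters
            adjoint[k, t] = fx[k, t + 1] + same_step_coupling * adjoint[K - 1, t]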
@@ -367,7 +372,11 @@ def __init__(
if k in weights.keys():
defaultweights[k] = weights[k]
else:
print("Weight ", k, " not in provided weight dictionary. Use default value.")
print(
"Weight ",
k,
" not in provided weight dictionary. Use default value.",
)

self.weights = defaultweights

@@ -471,14 +480,20 @@ def __init__(
for v, iv in enumerate(self.model.input_vars):
control[:, v, :] = self.model.params[iv]

self.control = control.copy()
self.check_params()

self.control = update_control_with_limit(
self.N,
self.dim_in,
self.T,
control,
0.0,
np.zeros(control.shape),
self.maximum_control_strength,
)

self.model_params = self.get_model_params()

def check_params(self):
"""Checks a subset of parameters and throws an error if a wrong dimension is found."""
@@ -539,7 +554,7 @@ def get_xs(self):

return xs

def get_xs_delay(self):
def get_xs_delayed(self):
"""Extract the complete state of the delayed dynamical system."""
maxdel = self.model.getMaxDelay()
if maxdel == 0:
@@ -683,6 +698,7 @@ def solve_adjoint(self):
dxdoth,
numba.typed.List(self.model.state_vars),
numba.typed.List(self.model.output_vars),
self.model.name,
)

def decrease_step(self, cost, cost0, step, control0, factor_down, cost_gradient):
@@ -721,7 +737,13 @@ def decrease_step(self, cost, cost0, step, control0, factor_down, cost_gradient)

# Inplace updating of models control bc. forward-sim relies on models parameters.
self.control = update_control_with_limit(
self.N,
self.dim_in,
self.T,
control0,
step,
cost_gradient,
self.maximum_control_strength,
)
self.update_input()

@@ -737,7 +759,13 @@ def decrease_step(self, cost, cost0, step, control0, factor_down, cost_gradient)
# cost.
step = 0.0 # For later analysis only.
self.control = update_control_with_limit(
self.N,
self.dim_in,
self.T,
control0,
0.0,
np.zeros(control0.shape),
self.maximum_control_strength,
)
self.update_input()

@@ -782,7 +810,13 @@ def increase_step(self, cost, cost0, step, control0, factor_up, cost_gradient):

# Inplace updating of models control bc. forward-sim relies on models parameters
self.control = update_control_with_limit(
self.N,
self.dim_in,
self.T,
control0,
step,
cost_gradient,
self.maximum_control_strength,
)
self.update_input()

@@ -792,7 +826,13 @@ def increase_step(self, cost, cost0, step, control0, factor_up, cost_gradient):
logging.info("Increasing step encountered NAN.")
step /= factor_up # Undo the last step update by inverse operation.
self.control = update_control_with_limit(
self.N,
self.dim_in,
self.T,
control0,
step,
cost_gradient,
self.maximum_control_strength,
)
self.update_input()
break
@@ -807,7 +847,13 @@ def increase_step(self, cost, cost0, step, control0, factor_up, cost_gradient):
# then) and exit.
step /= factor_up # Undo the last step update by inverse operation.
self.control = update_control_with_limit(
self.N,
self.dim_in,
self.T,
control0,
step,
cost_gradient,
self.maximum_control_strength,
)
self.update_input()
break
@@ -850,7 +896,13 @@ def step_size(self, cost_gradient):
while True: # Reduce the step size, if numerical instability occurs in the forward-simulation.
# inplace updating of models control bc. forward-sim relies on models parameters
self.control = update_control_with_limit(
self.N,
self.dim_in,
self.T,
control0,
step,
cost_gradient,
self.maximum_control_strength,
)
self.update_input()

@@ -909,7 +961,13 @@ def optimize(self, n_max_iterations):
self.control_interval = convert_interval(self.control_interval, self.T)

self.control = update_control_with_limit(
self.N,
self.dim_in,
self.T,
self.control,
0.0,
np.zeros(self.control.shape),
self.maximum_control_strength,
) # To avoid issues in repeated executions.
self.control = limit_control_to_interval(self.N, self.dim_in, self.T, self.control, self.control_interval)

6 changes: 3 additions & 3 deletions neurolib/control/optimal_control/oc_wc/oc_wc.py
@@ -73,7 +73,7 @@ def Duh(self):
"""

xs = self.get_xs()
xsd = self.get_xs_delay()
xsd = self.get_xs_delayed()

return Duh(
self.model_params,
@@ -116,7 +117,7 @@ def compute_hx(self):
self.dim_vars,
self.T,
self.get_xs(),
self.get_xs_delay(),
self.get_xs_delayed(),
self.control,
self.state_vars_dict,
)
@@ -140,7 +141,7 @@ def compute_hx_nw(self):
self.T,
xs[:, self.state_vars_dict["exc"], :],
xs[:, self.state_vars_dict["inh"], :],
self.get_xs_delay()[:, self.state_vars_dict["exc"], :],
self.get_xs_delayed()[:, self.state_vars_dict["exc"], :],
self.control[:, self.state_vars_dict["exc"], :],
self.state_vars_dict,
)
1 change: 1 addition & 0 deletions neurolib/control/optimal_control/oc_ww/__init__.py
@@ -0,0 +1 @@
from .oc_ww import OcWw
173 changes: 173 additions & 0 deletions neurolib/control/optimal_control/oc_ww/oc_ww.py
@@ -0,0 +1,173 @@
import numba

from neurolib.control.optimal_control.oc import OC
from neurolib.models.ww.timeIntegration import (
compute_hx,
compute_hx_min1,
compute_hx_nw,
Duh,
Dxdoth,
)


class OcWw(OC):
"""Class for optimal control specific to neurolib's implementation of the two-population Wong-Wang model
("WWmodel").

:param model: Instance of Wong-Wang model (can describe a single Wong-Wang node or a network of coupled
Wong-Wang nodes.
:type model: neurolib.models.ww.model.WWModel
"""

def __init__(
self,
model,
target,
weights=None,
print_array=[],
cost_interval=(None, None),
cost_matrix=None,
control_matrix=None,
M=1,
M_validation=0,
validate_per_step=False,
):
super().__init__(
model,
target,
weights=weights,
print_array=print_array,
cost_interval=cost_interval,
cost_matrix=cost_matrix,
control_matrix=control_matrix,
M=M,
M_validation=M_validation,
validate_per_step=validate_per_step,
)

assert self.model.name == "wongwang"

def compute_dxdoth(self):
"""Derivative of systems dynamics wrt. change of systems variables."""
return Dxdoth(self.N, self.dim_vars)

def get_model_params(self):
"""Model params as an ordered tuple.

:rtype: tuple
"""
return (
self.model.params.a_exc,
self.model.params.b_exc,
self.model.params.d_exc,
self.model.params.tau_exc,
self.model.params.gamma_exc,
self.model.params.w_exc,
self.model.params.exc_current_baseline,
self.model.params.a_inh,
self.model.params.b_inh,
self.model.params.d_inh,
self.model.params.tau_inh,
self.model.params.w_inh,
self.model.params.inh_current_baseline,
self.model.params.J_NMDA,
self.model.params.J_I,
self.model.params.w_ee,
)
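For orientation, this parameter tuple corresponds, in my reading of the names (to be checked against neurolib.models.ww.timeIntegration), to the standard reduced Wong-Wang excitatory/inhibitory population model of Deco et al., whose gating dynamics are commonly written as

\begin{aligned}
r_E &= H_E(I_E) = \frac{a_E I_E - b_E}{1 - \exp\!\left(-d_E (a_E I_E - b_E)\right)}, &
r_I &= H_I(I_I) = \frac{a_I I_I - b_I}{1 - \exp\!\left(-d_I (a_I I_I - b_I)\right)}, \\
\dot S_E &= -\frac{S_E}{\tau_E} + (1 - S_E)\,\gamma\, r_E, &
\dot S_I &= -\frac{S_I}{\tau_I} + r_I, \\
I_E &= w_E I_0 + w_{EE} J_{NMDA} S_E + G J_{NMDA} \sum_j C_{ij} S_E^{(j)} - J_I S_I + I_{ext,E}, &
I_I &= w_I I_0 + J_{NMDA} S_E - S_I + I_{ext,I},
\end{aligned}

where, presumably, a_exc, b_exc, d_exc, tau_exc, gamma_exc map to a_E, b_E, d_E, \tau_E, \gamma; w_exc and w_inh to w_E and w_I; exc_current_baseline and inh_current_baseline to the baseline currents; w_ee to w_{EE}; J_NMDA and J_I to the synaptic couplings; and K_gl together with Cmat to the global coupling term G \sum_j C_{ij}.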

def Duh(self):
"""Jacobian of systems dynamics wrt. external control input.

:return: N x 4 x 4 x T Jacobians.
:rtype: np.ndarray
"""

xs = self.get_xs()
xsd = self.get_xs_delayed()

return Duh(
self.model_params,
self.N,
self.dim_in,
self.dim_vars,
self.T,
self.control[:, self.state_vars_dict["r_exc"], :],
self.control[:, self.state_vars_dict["r_inh"], :],
xs[:, self.state_vars_dict["se"], :],
xs[:, self.state_vars_dict["si"], :],
self.model.params.K_gl,
self.model.params.Cmat,
self.Dmat_ndt,
xsd[:, self.state_vars_dict["se"], :],
self.state_vars_dict,
)
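Assuming the control enters the population input currents I_E and I_I additively (which the baseline-current parameters suggest), the nonzero entries of this Jacobian follow from the chain rule through the transfer functions; a sketch under that assumption, not necessarily the exact factorization used in timeIntegration:

\frac{\partial \dot S_E}{\partial u_E} = (1 - S_E)\,\gamma\, H_E'(I_E), \qquad
\frac{\partial \dot S_I}{\partial u_I} = H_I'(I_I),

or, if the firing rates r_E, r_I are carried as separate algebraic state variables, simply \partial r_E / \partial u_E = H_E'(I_E) and \partial r_I / \partial u_I = H_I'(I_I), with the gating equations then reaching the control only indirectly through r.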

def compute_hx_list(self):
"""List of Jacobians without and with time delays (e.g. in the ALN model) and list of respective time step delays as integers (0 for undelayed)

:return: List of Jacobian matrices, list of time step delays
: rtype: List of np.ndarray, List of integers
"""
hx = self.compute_hx()
hx_min1 = self.compute_hx_min1()
return numba.typed.List([hx, hx_min1]), numba.typed.List([0, -1])
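The shift list [0, -1] tells the backward pass at which time index each Jacobian couples into the adjoint state. Schematically (the exact discretization lives in solve_adjoint and adjoint_input in oc.py), every listed Jacobian h^{(i)} with integer shift d_i contributes a term of the form

\sum_{k'} \left[h^{(i)}(t + 1 + d_i)\right]_{k' k}\, \lambda_{k'}(t + 1 + d_i)

to \lambda_k(t). A shift of d_i = 0 therefore couples to the adjoint of the next time step (the ordinary backward recursion), while d_i = -1, used here for the dse/dre and dsi/dri dependencies, couples components within the current step; this is presumably also the reason for the Wong-Wang-specific loop order added to solve_adjoint above.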

def compute_hx(self):
"""Jacobians of WwModel wrt. all variables.

:return: N x T x 6 x 6 Jacobians.
:rtype: np.ndarray
"""
return compute_hx(
self.model_params,
self.model.params.K_gl,
self.model.Cmat,
self.Dmat_ndt,
self.N,
self.dim_vars,
self.T,
self.get_xs(),
self.get_xs_delayed(),
self.control,
self.state_vars_dict,
)

def compute_hx_min1(self):
"""Jacobians of WWModel dse/dre and dsi/dri.
Dependency is in same time step, so shift by -1 in time is required for OC computation.

:return: N x T x 6 x 6 Jacobians.
:rtype: np.ndarray
"""
return compute_hx_min1(
self.model_params,
self.N,
self.dim_vars,
self.T,
self.get_xs(),
self.state_vars_dict,
)

def compute_hx_nw(self):
"""Jacobians for each time step for the network coupling.

:return: N x N x T x (4x4) array
:rtype: np.ndarray
"""

xs = self.get_xs()

return compute_hx_nw(
self.model_params,
self.model.params.K_gl,
self.model.Cmat,
self.Dmat_ndt,
self.N,
self.dim_vars,
self.T,
xs[:, self.state_vars_dict["se"], :],
xs[:, self.state_vars_dict["se"], :],
self.get_xs_delayed()[:, self.state_vars_dict["se"], :],
self.control[:, self.state_vars_dict["r_exc"], :],
self.state_vars_dict,
)
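Finally, a minimal usage sketch for the new class, assuming the same workflow as the existing optimal-control examples (e.g. for OcWc); the duration, the zero target, and the target shape are illustrative and may need adjusting to the conventions neurolib actually uses for the Wong-Wang outputs:

import numpy as np

from neurolib.models.ww import WWModel
from neurolib.control.optimal_control.oc_ww import OcWw

# Set up a single Wong-Wang node (N = 1) with a short simulation horizon.
model = WWModel()
model.params["duration"] = 200.0  # ms, illustrative value

# Illustrative target: drive the model's output variables towards zero.
# Shape (N, number of output variables, number of time steps), following the
# pattern of the existing OC examples.
n_t = int(model.params["duration"] / model.params["dt"]) + 1
target = np.zeros((1, len(model.output_vars), n_t))

# weights, cost_matrix, control_matrix, ... can be passed as in the base OC class.
oc = OcWw(model, target)
oc.optimize(n_max_iterations=30)

optimal_control = oc.control  # array of shape (N, dim_in, T)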