
Commit

added phase shifting env
1b15 committed May 23, 2024
1 parent 899616c commit 6109825
Showing 5 changed files with 419 additions and 3 deletions.
277 changes: 277 additions & 0 deletions examples/example-6.2-rl-phaseshifting.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions neurolib/control/reinforcement_learning/__init__.py
@@ -4,3 +4,8 @@
id="StateSwitching-v0",
entry_point="neurolib.control.reinforcement_learning.environments.state_switching:StateSwitchingEnv",
)

register(
id="PhaseShifting-v0",
entry_point="neurolib.control.reinforcement_learning.environments.phase_shifting:PhaseShiftingEnv",
)
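
For illustration, the ID registered above can be instantiated through the standard Gymnasium factory. A minimal sketch, assuming gymnasium and neurolib are installed:

import gymnasium as gym
import neurolib.control.reinforcement_learning  # importing the package executes the register() calls above

# Create the phase-shifting environment by its registered ID and take an initial observation.
env = gym.make("PhaseShifting-v0")
obs, info = env.reset()
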
1 change: 1 addition & 0 deletions neurolib/control/reinforcement_learning/environments/__init__.py
@@ -1 +1,2 @@
from neurolib.control.reinforcement_learning.environments.state_switching import StateSwitchingEnv
from neurolib.control.reinforcement_learning.environments.phase_shifting import PhaseShiftingEnv
133 changes: 133 additions & 0 deletions neurolib/control/reinforcement_learning/environments/phase_shifting.py
@@ -0,0 +1,133 @@
from neurolib.utils.stimulus import ZeroInput

import numpy as np
import scipy.signal

import gymnasium as gym
from gymnasium import spaces

from neurolib.models.wc import WCModel


class PhaseShiftingEnv(gym.Env):

def __init__(
self,
duration=300,
dt=0.1,
target_shift=1 * np.pi,
exc_ext_baseline=2.8,
inh_ext_baseline=1.2,
x_init=0.04201540010391125,
y_init=0.1354067401509556,
sigma_ou=0.0,
c_inhexc=16,
c_excinh=10,
c_inhinh=1,
control_strength_loss_scale=0.005,
):
self.exc_ext_baseline = exc_ext_baseline
self.inh_ext_baseline = inh_ext_baseline

self.duration = duration
self.dt = dt
self.target_shift = target_shift
self.x_init = x_init
self.y_init = y_init
self.control_strength_loss_scale = control_strength_loss_scale

assert 0 < self.target_shift < 2 * np.pi

self.model = WCModel()
self.model.params["dt"] = self.dt
self.model.params["sigma_ou"] = sigma_ou
self.model.params["duration"] = self.dt # one step at a time
self.model.params["exc_init"] = np.array([[x_init]])
self.model.params["inh_init"] = np.array([[y_init]])
self.model.params["exc_ext_baseline"] = self.exc_ext_baseline
self.model.params["inh_ext_baseline"] = self.inh_ext_baseline

self.model.params["c_inhexc"] = c_inhexc
self.model.params["c_excinh"] = c_excinh
self.model.params["c_inhinh"] = c_inhinh
self.params = self.model.params.copy()

self.n_steps = round(self.duration / self.dt)

self.target = self.get_target()

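        # Observations: instantaneous excitatory/inhibitory activity, each bounded in [0, 1].
        # Actions: additive external inputs to the E and I populations, each bounded in [-5, 5].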
self.observation_space = spaces.Dict(
{
"exc": spaces.Box(0, 1, shape=(1,), dtype=float),
"inh": spaces.Box(0, 1, shape=(1,), dtype=float),
}
)

self.action_space = spaces.Tuple(
(
spaces.Box(-5, 5, shape=(1,), dtype=float), # exc
spaces.Box(-5, 5, shape=(1,), dtype=float), # inh
)
)

def get_target(self):
wc = WCModel()
wc.params = self.model.params.copy()
wc.params["duration"] = self.duration + 100.0
wc.run()

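        # Estimate the oscillation period (in ms) from the mean spacing of successive peaks in the
        # uncontrolled excitatory trace; the earliest peaks are skipped to let transients die out.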
peaks = scipy.signal.find_peaks(wc.exc[0, :])[0]
p_list = []
for i in range(3, len(peaks)):
p_list.append(peaks[i] - peaks[i - 1])
period = np.mean(p_list) * self.dt
self.period = period

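        # Convert the requested phase shift (radians) into a sample offset and take the target
        # trajectory as the uncontrolled run shifted by that offset.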
raw = np.stack((wc.exc, wc.inh), axis=1)[0]
index = np.round(self.target_shift * period / (2.0 * np.pi) / self.dt).astype(int)
target = raw[:, index : index + np.round(1 + self.duration / self.dt, 1).astype(int)]

return target

def _get_obs(self):
return {"exc": self.model.exc[0], "inh": self.model.inh[0]}

def _get_info(self):
return {"t": self.t_i * self.dt}

def reset(self, seed=None, options=None):
super().reset(seed=seed, options=options)
self.t_i = 0
self.model.clearModelState()

self.model.params = self.params.copy()
self.model.exc = np.array([[self.x_init]])
self.model.inh = np.array([[self.y_init]])

observation = self._get_obs()
info = self._get_info()
return observation, info

def _loss(self, obs, action):
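        # control_loss: Euclidean distance in the (exc, inh) plane between the current state and the
        # phase-shifted target at this step; control_strength_loss: scaled L1 penalty on the applied input.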
control_loss = np.sqrt(
(self.target[0, self.t_i] - obs["exc"].item()) ** 2 + (self.target[1, self.t_i] - obs["inh"].item()) ** 2
)
control_strength_loss = np.abs(action).sum() * self.control_strength_loss_scale
return control_loss + control_strength_loss

def step(self, action):
assert self.action_space.contains(action)
exc, inh = action
self.model.params["exc_ext"] = np.array([exc])
self.model.params["inh_ext"] = np.array([inh])
self.model.run(continue_run=True)

observation = self._get_obs()

reward = -self._loss(observation, action)

self.t_i += 1
terminated = self.t_i >= self.n_steps
info = self._get_info()

return observation, reward, terminated, False, info
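
For illustration, a minimal interaction loop with the environment defined above could look like the sketch below; the random policy is only a placeholder for a trained agent:

from neurolib.control.reinforcement_learning.environments.phase_shifting import PhaseShiftingEnv

# Construct the environment with its default Wilson-Cowan parameters (target shift of pi radians).
env = PhaseShiftingEnv()
obs, info = env.reset()

total_reward = 0.0
terminated, truncated = False, False
while not (terminated or truncated):
    # Placeholder policy: sample random (exc, inh) control inputs from the action space.
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    total_reward += reward

print("episode return:", total_reward)
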
neurolib/control/reinforcement_learning/environments/state_switching.py
@@ -101,12 +101,12 @@ def reset(self, seed=None, options=None):
return observation, info

def _loss(self, obs, action):
-   accuracy_loss = abs(self.targetstate[0] - obs["exc"].item()) + abs(self.targetstate[1] - obs["inh"].item())
-   # exc_ext, inh_ext = action
+   control_loss = abs(self.targetstate[0] - obs["exc"].item()) + abs(self.targetstate[1] - obs["inh"].item())
    control_strength_loss = np.abs(action).sum() * self.control_strength_loss_scale
-   return accuracy_loss + control_strength_loss
+   return control_loss + control_strength_loss

def step(self, action):
assert self.action_space.contains(action)
exc, inh = action
self.model.params["exc_ext"] = np.array([exc])
self.model.params["inh_ext"] = np.array([inh])
