Commit e02094a
RL state switching env and example
1b15 committed May 13, 2024
1 parent 6be0d37 commit e02094a
Showing 4 changed files with 462 additions and 0 deletions.
332 changes: 332 additions & 0 deletions examples/example-6.1-rl-stateswitching.ipynb

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions neurolib/control/reinforcement_learning/__init__.py
@@ -0,0 +1,6 @@
from gymnasium.envs.registration import register

register(
id="StateSwitching-v0",
entry_point="neurolib.control.reinforcement_learning.environments.state_switching:StateSwitchingEnv",
)
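
For orientation (not part of the diff): once this package is imported, the register() call above makes the environment available through gymnasium's factory. A minimal usage sketch, assuming neurolib with these files is importable:

    import gymnasium as gym
    import neurolib.control.reinforcement_learning  # importing the package runs the register() call above

    env = gym.make("StateSwitching-v0")  # constructor kwargs can be forwarded, e.g. target="down"
    obs, info = env.reset()
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())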
1 change: 1 addition & 0 deletions neurolib/control/reinforcement_learning/environments/__init__.py
@@ -0,0 +1 @@
from neurolib.control.reinforcement_learning.environments.state_switching import StateSwitchingEnv
123 changes: 123 additions & 0 deletions neurolib/control/reinforcement_learning/environments/state_switching.py
@@ -0,0 +1,123 @@
from neurolib.utils.stimulus import ZeroInput

import numpy as np

import gymnasium as gym
from gymnasium import spaces

from neurolib.models.wc import WCModel


class StateSwitchingEnv(gym.Env):
    """Single-node Wilson-Cowan environment: the agent applies external excitatory
    and inhibitory input to switch the node between its bistable down and up states."""

    def __init__(
        self,
        duration=200,
        dt=0.1,
        target="up",
        exc_ext_baseline=2.9,
        inh_ext_baseline=3.3,
        control_strength_loss_scale=0.005,
    ):
        self.exc_ext_baseline = exc_ext_baseline
        self.inh_ext_baseline = inh_ext_baseline
        self.compute_up_and_down_states()

        self.duration = duration
        self.dt = dt
        self.target = target
        self.control_strength_loss_scale = control_strength_loss_scale

        assert self.target in ("up", "down")
        if self.target == "up":
            self.targetstate = self.upstate
            self.initstate = self.downstate
        elif self.target == "down":
            self.targetstate = self.downstate
            self.initstate = self.upstate

        self.model = WCModel()
        self.model.params["dt"] = self.dt
        self.model.params["duration"] = self.dt  # one step at a time
        self.model.params["exc_init"] = np.array([[self.initstate[0]]])
        self.model.params["inh_init"] = np.array([[self.initstate[1]]])
        self.model.params["exc_ext_baseline"] = self.exc_ext_baseline
        self.model.params["inh_ext_baseline"] = self.inh_ext_baseline

        self.n_steps = round(self.duration / self.dt)

        self.observation_space = spaces.Dict(
            {
                "exc": spaces.Box(0, 1, shape=(1,), dtype=float),
                "inh": spaces.Box(0, 1, shape=(1,), dtype=float),
            }
        )

        self.action_space = spaces.Tuple(
            (
                spaces.Box(-5, 5, shape=(1,), dtype=float),  # exc
                spaces.Box(-5, 5, shape=(1,), dtype=float),  # inh
            )
        )

    def compute_up_and_down_states(self):
        # Find the two stable fixed points of the bistable WC node by driving it with a
        # negative and then a positive excitatory pulse and reading out the activity
        # after the transients have decayed.
        model = WCModel()

        dt = model.params["dt"]
        duration = 500
        model.params["duration"] = duration

        zero_input = ZeroInput().generate_input(duration=duration + dt, dt=dt)
        bi_control = zero_input.copy()
        bi_control[0, :500] = -5.0  # negative pulse -> settle into the down state
        bi_control[0, 2500:3000] = +5.0  # positive pulse -> kick into the up state

        model.params["exc_ext_baseline"] = self.exc_ext_baseline
        model.params["inh_ext_baseline"] = self.inh_ext_baseline
        model.params["exc_ext"] = bi_control
        model.params["inh_ext"] = zero_input
        model.run()
        self.downstate = [model.exc[0, 2000], model.inh[0, 2000]]  # before the positive pulse
        self.upstate = [model.exc[0, -1], model.inh[0, -1]]  # end of simulation

    def _get_obs(self):
        return {"exc": self.model.exc[0], "inh": self.model.inh[0]}

    def _get_info(self):
        return {"t": self.t_i * self.dt}

    def reset(self, seed=None, options=None):
        super().reset(seed=seed, options=options)
        self.t_i = 0
        self.model.clearModelState()

        self.model.params["exc_init"] = np.array([[self.initstate[0]]])
        self.model.params["inh_init"] = np.array([[self.initstate[1]]])
        self.model.exc = np.array([[self.initstate[0]]])
        self.model.inh = np.array([[self.initstate[1]]])

        observation = self._get_obs()
        info = self._get_info()
        return observation, info

    def _loss(self, obs, action):
        accuracy_loss = abs(self.targetstate[0] - obs["exc"].item()) + abs(self.targetstate[1] - obs["inh"].item())
        # exc_ext, inh_ext = action
        control_strength_loss = np.abs(action).sum() * self.control_strength_loss_scale
        return accuracy_loss + control_strength_loss

    def step(self, action):
        exc, inh = action
        self.model.params["exc_ext"] = np.array([exc])
        self.model.params["inh_ext"] = np.array([inh])
        self.model.run(continue_run=True)  # advance the WC model by a single dt step

        observation = self._get_obs()

        reward = -self._loss(observation, action)

        self.t_i += 1
        terminated = self.t_i >= self.n_steps
        info = self._get_info()

        return observation, reward, terminated, False, info
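
As a rough illustration of how the environment is driven (a sketch, not part of this commit; the accompanying example notebook is not rendered above), an episode can also be rolled out directly with a hand-written policy. The constant drive below is only a placeholder; whether it actually switches the node into the up state is an assumption:

    import numpy as np
    from neurolib.control.reinforcement_learning.environments import StateSwitchingEnv

    env = StateSwitchingEnv(duration=200, dt=0.1, target="up")
    obs, info = env.reset()
    total_reward, terminated = 0.0, False
    while not terminated:
        # placeholder policy: constant positive excitatory drive, no inhibitory drive
        action = (np.array([1.0]), np.array([0.0]))
        obs, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
    print(info["t"], total_reward, obs["exc"], obs["inh"])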
