Skip to content

Commit

Permalink
Fixed phase-shifting env observation space
Browse files Browse the repository at this point in the history
  • Loading branch information
1b15 committed May 29, 2024
1 parent 6109825 commit 4ebd2aa
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 47 deletions.
59 changes: 16 additions & 43 deletions examples/example-6.2-rl-phaseshifting.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ def __init__(

self.observation_space = spaces.Dict(
{
"exc": spaces.Box(0, 1, shape=(1,), dtype=float),
"inh": spaces.Box(0, 1, shape=(1,), dtype=float),
"exc": spaces.Box(0, 1, shape=(self.period_n,), dtype=float),
"inh": spaces.Box(0, 1, shape=(self.period_n,), dtype=float),
"target_phase": spaces.Box(0, 2 * np.pi, shape=(1,), dtype=float),
}
)

Expand All @@ -80,17 +81,26 @@ def get_target(self):
p_list = []
for i in range(3, len(peaks)):
p_list.append(peaks[i] - peaks[i - 1])

self.period_n = np.ceil(np.mean(p_list)).astype(int)

period = np.mean(p_list) * self.dt
self.period = period

raw = np.stack((wc.exc, wc.inh), axis=1)[0]
index = np.round(self.target_shift * period / (2.0 * np.pi) / self.dt).astype(int)
target = raw[:, index : index + np.round(1 + self.duration / self.dt, 1).astype(int)]
self.target_time = wc.t[index : index + target.shape[1]]
self.target_phase = (self.target_time % self.period) / self.period * 2 * np.pi

return target

def _get_obs(self):
    """Build the observation dict returned to the RL agent.

    NOTE(review): this span is a rendered git diff, not a single source
    version — the first return below is the PRE-change body (instantaneous
    scalar model state) and the dict after it is the POST-change body
    (one-period history windows plus the current target phase). A real
    source file contains only one of them; the second return here is
    unreachable.
    """
    return {"exc": self.model.exc[0], "inh": self.model.inh[0]}
    return {
        # sliding windows of the last `period_n` samples — maintained in
        # reset()/step() elsewhere in the diff
        "exc": self.exc_history,
        "inh": self.inh_history,
        # target phase at the current step, wrapped in a length-1 array to
        # match the (1,)-shaped Box declared in the observation space
        "target_phase": np.array([self.target_phase[self.t_i]]),
    }

def _get_info(self):
return {"t": self.t_i * self.dt}
Expand All @@ -101,16 +111,25 @@ def reset(self, seed=None, options=None):
self.model.clearModelState()

self.model.params = self.params.copy()

# init history window
self.model.params["duration"] = self.period_n * self.dt
self.model.exc = np.array([[self.x_init]])
self.model.inh = np.array([[self.y_init]])
self.model.run()
self.exc_history = self.model.exc[0]
self.inh_history = self.model.inh[0]

# reset duration parameter
self.model.params = self.params.copy()

observation = self._get_obs()
info = self._get_info()
return observation, info

def _loss(self, obs, action):
    """Step loss: distance to the target trajectory plus a control-effort penalty.

    NOTE(review): this span is a rendered git diff — the two parallel lines
    inside ``np.sqrt`` are the PRE-change term (scalar ``.item()`` obs) and
    the POST-change term (last sample of the history window, ``[-1]``).
    Only one belongs in the real file; as written here the pair is not
    valid Python.
    """
    # Euclidean distance between target (exc, inh) at step t_i and the
    # observed activity.
    control_loss = np.sqrt(
        (self.target[0, self.t_i] - obs["exc"].item()) ** 2 + (self.target[1, self.t_i] - obs["inh"].item()) ** 2
        (self.target[0, self.t_i] - obs["exc"][-1]) ** 2 + (self.target[1, self.t_i] - obs["inh"][-1]) ** 2
    )
    # L1 penalty on the applied control, scaled by a configurable weight.
    control_strength_loss = np.abs(action).sum() * self.control_strength_loss_scale
    return control_loss + control_strength_loss
Expand All @@ -122,6 +141,10 @@ def step(self, action):
self.model.params["inh_ext"] = np.array([inh])
self.model.run(continue_run=True)

# shift observation window
self.exc_history = np.concatenate((self.exc_history[-self.period_n + 1 :], self.model.exc[0]))
self.inh_history = np.concatenate((self.inh_history[-self.period_n + 1 :], self.model.inh[0]))

observation = self._get_obs()

reward = -self._loss(observation, action)
Expand Down

0 comments on commit 4ebd2aa

Please sign in to comment.