Skip to content

Commit

Permalink
added l2 control strength loss, and random target shift in phase shif…
Browse files Browse the repository at this point in the history
…ting env
  • Loading branch information
1b15 committed Jun 12, 2024
1 parent 4ebd2aa commit 5d0fb52
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 12 deletions.
129 changes: 124 additions & 5 deletions examples/example-6.2-rl-phaseshifting.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(
self,
duration=300,
dt=0.1,
random_target_shift=True,
target_shift=1 * np.pi,
exc_ext_baseline=2.8,
inh_ext_baseline=1.2,
Expand All @@ -24,17 +25,20 @@ def __init__(
c_inhexc=16,
c_excinh=10,
c_inhinh=1,
control_strength_loss_scale=0.005,
l1_control_strength_loss_scale=0.01,
l2_control_strength_loss_scale=0.01,
):
self.exc_ext_baseline = exc_ext_baseline
self.inh_ext_baseline = inh_ext_baseline

self.duration = duration
self.dt = dt
self.random_target_shift = random_target_shift
self.target_shift = target_shift
self.x_init = x_init
self.y_init = y_init
self.control_strength_loss_scale = control_strength_loss_scale
self.l1_control_strength_loss_scale = l1_control_strength_loss_scale
self.l2_control_strength_loss_scale = l2_control_strength_loss_scale

assert 0 < self.target_shift < 2 * np.pi

Expand Down Expand Up @@ -88,7 +92,11 @@ def get_target(self):
self.period = period

raw = np.stack((wc.exc, wc.inh), axis=1)[0]
index = np.round(self.target_shift * period / (2.0 * np.pi) / self.dt).astype(int)
if self.random_target_shift:
target_shift = np.random.random() * 2 * np.pi
else:
target_shift = self.target_shift
index = np.round(target_shift * period / (2.0 * np.pi) / self.dt).astype(int)
target = raw[:, index : index + np.round(1 + self.duration / self.dt, 1).astype(int)]
self.target_time = wc.t[index : index + target.shape[1]]
self.target_phase = (self.target_time % self.period) / self.period * 2 * np.pi
Expand Down Expand Up @@ -131,7 +139,9 @@ def _loss(self, obs, action):
control_loss = np.sqrt(
(self.target[0, self.t_i] - obs["exc"][-1]) ** 2 + (self.target[1, self.t_i] - obs["inh"][-1]) ** 2
)
control_strength_loss = np.abs(action).sum() * self.control_strength_loss_scale
control_strength_loss = np.abs(action).sum() * self.l1_control_strength_loss_scale
control_strength_loss += np.sqrt(np.sum(np.square(action))) * self.l2_control_strength_loss_scale

return control_loss + control_strength_loss

def step(self, action):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def __init__(
target="up",
exc_ext_baseline=2.9,
inh_ext_baseline=3.3,
control_strength_loss_scale=0.005,
l1_control_strength_loss_scale=0.005,
l2_control_strength_loss_scale=0.005,
):
self.exc_ext_baseline = exc_ext_baseline
self.inh_ext_baseline = inh_ext_baseline
Expand All @@ -26,7 +27,8 @@ def __init__(
self.duration = duration
self.dt = dt
self.target = target
self.control_strength_loss_scale = control_strength_loss_scale
self.l1_control_strength_loss_scale = l1_control_strength_loss_scale
self.l2_control_strength_loss_scale = l2_control_strength_loss_scale

assert self.target in ("up", "down")
if self.target == "up":
Expand Down Expand Up @@ -102,7 +104,8 @@ def reset(self, seed=None, options=None):

def _loss(self, obs, action):
control_loss = abs(self.targetstate[0] - obs["exc"].item()) + abs(self.targetstate[1] - obs["inh"].item())
control_strength_loss = np.abs(action).sum() * self.control_strength_loss_scale
control_strength_loss = np.abs(action).sum() * self.l1_control_strength_loss_scale
control_strength_loss += np.sqrt(np.sum(action**2)) * self.l2_control_strength_loss_scale
return control_loss + control_strength_loss

def step(self, action):
Expand Down

0 comments on commit 5d0fb52

Please sign in to comment.