def _loss(self, obs, action):
    """Return the scalar loss for one step: state error plus control penalty.

    The loss is the sum of two parts:

    * the absolute deviation of the observed state from ``self.targetstate``
      (``targetstate[0]`` against ``obs["exc"]``, ``targetstate[1]`` against
      ``obs["inh"]``);
    * a penalty on the control signal: its L1 norm scaled by
      ``self.l1_control_strength_loss_scale`` plus its L2 (Euclidean) norm
      scaled by ``self.l2_control_strength_loss_scale``.

    Args:
        obs: Mapping with ``"exc"`` and ``"inh"`` entries, each exposing
            ``.item()`` to yield a scalar (e.g. a one-element NumPy array).
        action: Control signal. Any array-like is accepted (ndarray, list,
            tuple); it is converted with ``np.asarray`` before use.

    Returns:
        The combined loss as a scalar.
    """
    # Convert once so that plain Python sequences work for both norms;
    # ``action ** 2`` would raise TypeError on a list or tuple.
    control = np.asarray(action)

    control_loss = abs(self.targetstate[0] - obs["exc"].item()) + abs(self.targetstate[1] - obs["inh"].item())

    control_strength_loss = np.abs(control).sum() * self.l1_control_strength_loss_scale
    control_strength_loss += np.sqrt(np.sum(np.square(control))) * self.l2_control_strength_loss_scale

    return control_loss + control_strength_loss