Skip to content

Commit

Permalink
updates to if/lif
Browse files Browse the repository at this point in the history
  • Loading branch information
ago109 committed Jul 26, 2024
1 parent c894b8a commit ba08453
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
16 changes: 8 additions & 8 deletions ngclearn/components/neurons/spiking/IFCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,21 +213,21 @@ def reset(self, j, v, s, rfr, tols, surrogate):
def save(self, directory, **kwargs):
## do a protected save of constants, depending on whether they are floats or arrays
tau_m = (self.tau_m if isinstance(self.tau_m, float)
else jnp.ones([[self.tau_m]]))
else jnp.asarray([[self.tau_m * 1.]]))
thr = (self.thr if isinstance(self.thr, float)
else jnp.ones([[self.thr]]))
else jnp.asarray([[self.thr * 1.]]))
v_rest = (self.v_rest if isinstance(self.v_rest, float)
else jnp.ones([[self.v_rest]]))
else jnp.asarray([[self.v_rest * 1.]]))
v_reset = (self.v_reset if isinstance(self.v_reset, float)
else jnp.ones([[self.v_reset]]))
else jnp.asarray([[self.v_reset * 1.]]))
v_decay = (self.v_decay if isinstance(self.v_decay, float)
else jnp.ones([[self.v_decay]]))
else jnp.asarray([[self.v_decay * 1.]]))
resist_m = (self.resist_m if isinstance(self.resist_m, float)
else jnp.ones([[self.resist_m]]))
else jnp.asarray([[self.resist_m * 1.]]))
tau_theta = (self.tau_theta if isinstance(self.tau_theta, float)
else jnp.ones([[self.tau_theta]]))
else jnp.asarray([[self.tau_theta * 1.]]))
theta_plus = (self.theta_plus if isinstance(self.theta_plus, float)
else jnp.ones([[self.theta_plus]]))
else jnp.asarray([[self.theta_plus * 1.]]))

file_name = directory + "/" + self.name + ".npz"
jnp.savez(file_name,
Expand Down
16 changes: 8 additions & 8 deletions ngclearn/components/neurons/spiking/LIFCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,21 +279,21 @@ def reset(self, j, v, s, s_raw, rfr, tols, surrogate):
def save(self, directory, **kwargs):
## do a protected save of constants, depending on whether they are floats or arrays
tau_m = (self.tau_m if isinstance(self.tau_m, float)
else jnp.ones([[self.tau_m]]))
else jnp.asarray([[self.tau_m * 1.]]))
thr = (self.thr if isinstance(self.thr, float)
else jnp.ones([[self.thr]]))
else jnp.asarray([[self.thr * 1.]]))
v_rest = (self.v_rest if isinstance(self.v_rest, float)
else jnp.ones([[self.v_rest]]))
else jnp.asarray([[self.v_rest * 1.]]))
v_reset = (self.v_reset if isinstance(self.v_reset, float)
else jnp.ones([[self.v_reset]]))
else jnp.asarray([[self.v_reset * 1.]]))
v_decay = (self.v_decay if isinstance(self.v_decay, float)
else jnp.ones([[self.v_decay]]))
else jnp.asarray([[self.v_decay * 1.]]))
resist_m = (self.resist_m if isinstance(self.resist_m, float)
else jnp.ones([[self.resist_m]]))
else jnp.asarray([[self.resist_m * 1.]]))
tau_theta = (self.tau_theta if isinstance(self.tau_theta, float)
else jnp.ones([[self.tau_theta]]))
else jnp.asarray([[self.tau_theta * 1.]]))
theta_plus = (self.theta_plus if isinstance(self.theta_plus, float)
else jnp.ones([[self.theta_plus]]))
else jnp.asarray([[self.theta_plus * 1.]]))

file_name = directory + "/" + self.name + ".npz"
jnp.savez(file_name,
Expand Down

0 comments on commit ba08453

Please sign in to comment.