From ba08453b09f6a90954218fa0a02272d92ce5192e Mon Sep 17 00:00:00 2001 From: ago109 Date: Fri, 26 Jul 2024 13:33:07 -0400 Subject: [PATCH] updates to if/lif --- ngclearn/components/neurons/spiking/IFCell.py | 16 ++++++++-------- ngclearn/components/neurons/spiking/LIFCell.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/ngclearn/components/neurons/spiking/IFCell.py b/ngclearn/components/neurons/spiking/IFCell.py index 2c9acf52..68e51bdf 100755 --- a/ngclearn/components/neurons/spiking/IFCell.py +++ b/ngclearn/components/neurons/spiking/IFCell.py @@ -213,21 +213,21 @@ def reset(self, j, v, s, rfr, tols, surrogate): def save(self, directory, **kwargs): ## do a protected save of constants, depending on whether they are floats or arrays tau_m = (self.tau_m if isinstance(self.tau_m, float) - else jnp.ones([[self.tau_m]])) + else jnp.asarray([[self.tau_m * 1.]])) thr = (self.thr if isinstance(self.thr, float) - else jnp.ones([[self.thr]])) + else jnp.asarray([[self.thr * 1.]])) v_rest = (self.v_rest if isinstance(self.v_rest, float) - else jnp.ones([[self.v_rest]])) + else jnp.asarray([[self.v_rest * 1.]])) v_reset = (self.v_reset if isinstance(self.v_reset, float) - else jnp.ones([[self.v_reset]])) + else jnp.asarray([[self.v_reset * 1.]])) v_decay = (self.v_decay if isinstance(self.v_decay, float) - else jnp.ones([[self.v_decay]])) + else jnp.asarray([[self.v_decay * 1.]])) resist_m = (self.resist_m if isinstance(self.resist_m, float) - else jnp.ones([[self.resist_m]])) + else jnp.asarray([[self.resist_m * 1.]])) tau_theta = (self.tau_theta if isinstance(self.tau_theta, float) - else jnp.ones([[self.tau_theta]])) + else jnp.asarray([[self.tau_theta * 1.]])) theta_plus = (self.theta_plus if isinstance(self.theta_plus, float) - else jnp.ones([[self.theta_plus]])) + else jnp.asarray([[self.theta_plus * 1.]])) file_name = directory + "/" + self.name + ".npz" jnp.savez(file_name, diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py index ead51f02..d8a0d763 100644 --- a/ngclearn/components/neurons/spiking/LIFCell.py +++ b/ngclearn/components/neurons/spiking/LIFCell.py @@ -279,21 +279,21 @@ def reset(self, j, v, s, s_raw, rfr, tols, surrogate): def save(self, directory, **kwargs): ## do a protected save of constants, depending on whether they are floats or arrays tau_m = (self.tau_m if isinstance(self.tau_m, float) - else jnp.ones([[self.tau_m]])) + else jnp.asarray([[self.tau_m * 1.]])) thr = (self.thr if isinstance(self.thr, float) - else jnp.ones([[self.thr]])) + else jnp.asarray([[self.thr * 1.]])) v_rest = (self.v_rest if isinstance(self.v_rest, float) - else jnp.ones([[self.v_rest]])) + else jnp.asarray([[self.v_rest * 1.]])) v_reset = (self.v_reset if isinstance(self.v_reset, float) - else jnp.ones([[self.v_reset]])) + else jnp.asarray([[self.v_reset * 1.]])) v_decay = (self.v_decay if isinstance(self.v_decay, float) - else jnp.ones([[self.v_decay]])) + else jnp.asarray([[self.v_decay * 1.]])) resist_m = (self.resist_m if isinstance(self.resist_m, float) - else jnp.ones([[self.resist_m]])) + else jnp.asarray([[self.resist_m * 1.]])) tau_theta = (self.tau_theta if isinstance(self.tau_theta, float) - else jnp.ones([[self.tau_theta]])) + else jnp.asarray([[self.tau_theta * 1.]])) theta_plus = (self.theta_plus if isinstance(self.theta_plus, float) - else jnp.ones([[self.theta_plus]])) + else jnp.asarray([[self.theta_plus * 1.]])) file_name = directory + "/" + self.name + ".npz" jnp.savez(file_name,