[math] Add get JIT weight matrix methods (Uniform & Normal) for brainpy.dnn.linear #673

Merged · 12 commits · Jun 1, 2024
43 changes: 32 additions & 11 deletions brainpy/_src/dnn/linear.py
@@ -980,15 +980,36 @@ def __init__(
self.sharding = sharding


-class JitFPLinear(Layer):
+class JitLinear(Layer):
   def get_conn_matrix(self):
-    return bm.jitconn.get_conn_matrix(self.prob, self.seed,
-                                      shape=(self.num_out, self.num_in),
-                                      transpose=self.transpose,
-                                      outdim_parallel=not self.atomic)
+    pass


+class JitFPHomoLayer(JitLinear):
+  def get_conn_matrix(self):
+    return bm.jitconn.get_homo_weight_matrix(self.weight, self.prob, self.seed,
+                                             shape=(self.num_out, self.num_in),
+                                             transpose=self.transpose,
+                                             outdim_parallel=not self.atomic)
+
+
+class JitFPUniformLayer(JitLinear):
+  def get_conn_matrix(self):
+    return bm.jitconn.get_uniform_weight_matrix(self.w_low, self.w_high, self.prob, self.seed,
+                                                shape=(self.num_out, self.num_in),
+                                                transpose=self.transpose,
+                                                outdim_parallel=not self.atomic)
+
+
+class JitFPNormalLayer(JitLinear):
+  def get_conn_matrix(self):
+    return bm.jitconn.get_normal_weight_matrix(self.w_mu, self.w_sigma, self.prob, self.seed,
+                                               shape=(self.num_out, self.num_in),
+                                               transpose=self.transpose,
+                                               outdim_parallel=not self.atomic)


-class JitFPHomoLinear(JitFPLinear):
+class JitFPHomoLinear(JitFPHomoLayer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.

It performs the computation of:
@@ -1067,7 +1088,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


-class JitFPUniformLinear(JitFPLinear):
+class JitFPUniformLinear(JitFPUniformLayer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.

It performs the computation of:
@@ -1147,7 +1168,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


-class JitFPNormalLinear(JitFPLinear):
+class JitFPNormalLinear(JitFPNormalLayer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.

It performs the computation of:
@@ -1227,7 +1248,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


-class EventJitFPHomoLinear(JitFPLinear):
+class EventJitFPHomoLinear(JitFPHomoLayer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.

It performs the computation of:
@@ -1306,7 +1327,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


-class EventJitFPUniformLinear(JitFPLinear):
+class EventJitFPUniformLinear(JitFPUniformLayer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.

It performs the computation of:
@@ -1386,7 +1407,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)


-class EventJitFPNormalLinear(JitFPLinear):
+class EventJitFPNormalLinear(JitFPNormalLayer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.

It performs the computation of:
…
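Taken together, the linear.py changes give every JIT-connectivity layer a `get_conn_matrix` method that materializes the dense weight matrix the layer otherwise regenerates on the fly from its seed. A minimal usage sketch, distilled from the tests below (constructor arguments follow the test suite; the concrete values here are illustrative):

```python
import brainpy as bp
import brainpy.math as bm

bm.random.seed()

# A jitconn layer stores no weights; they are re-sampled from `seed` on
# every call. get_conn_matrix() reconstructs the equivalent dense matrix.
f = bp.dnn.JitFPUniformLinear(100, 200, 0.1, -0.1, 0.1, seed=123)

x = bm.random.random((16, 100))
y = f(x)                          # JIT forward pass, shape (16, 200)

W = f.get_conn_matrix()           # dense (num_out, num_in) = (200, 100)
assert bm.allclose(y, x @ W.T)    # identical to the implicit computation
```

Because the matrix is laid out as `(num_out, num_in)`, the equivalence check multiplies by its transpose, matching the `x @ conn_matrix.T` assertions added to the tests below.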
29 changes: 26 additions & 3 deletions brainpy/_src/dnn/tests/test_linear.py
@@ -141,6 +141,11 @@ def test_JitFPHomoLinear(self, prob, weight, shape):
x = bm.random.random(shape + (100,))
y = f(x)
self.assertTrue(y.shape == shape + (200,))

+conn_matrix = f.get_conn_matrix()
+self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
+# print(conn_matrix.shape)
+# self.assertTrue(conn_matrix.shape == (200, 100))
bm.clear_buffer_memory()

@parameterized.product(
@@ -155,6 +160,9 @@ def test_JitFPUniformLinear(self, prob, w_low, w_high, shape):
x = bm.random.random(shape + (100,))
y = f(x)
self.assertTrue(y.shape == shape + (200,))

+conn_matrix = f.get_conn_matrix()
+self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
bm.clear_buffer_memory()

@parameterized.product(
@@ -169,6 +177,9 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
x = bm.random.random(shape + (100,))
y = f(x)
self.assertTrue(y.shape == shape + (200,))

+conn_matrix = f.get_conn_matrix()
+self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
bm.clear_buffer_memory()

@parameterized.product(
@@ -179,11 +190,15 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
def test_EventJitFPHomoLinear(self, prob, weight, shape):
bm.random.seed()
f = bp.dnn.EventJitFPHomoLinear(100, 200, prob, weight, seed=123)
-y = f(bm.random.random(shape + (100,)) < 0.1)
+x = bm.random.random(shape + (100,)) < 0.1
+y = f(x)
self.assertTrue(y.shape == shape + (200,))

y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float))
self.assertTrue(y2.shape == shape + (200,))

+conn_matrix = f.get_conn_matrix()
+self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
bm.clear_buffer_memory()

@parameterized.product(
@@ -195,11 +210,15 @@ def test_EventJitFPHomoLinear(self, prob, weight, shape):
def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape):
bm.random.seed()
f = bp.dnn.EventJitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123)
-y = f(bm.random.random(shape + (100,)) < 0.1)
+x = bm.random.random(shape + (100,)) < 0.1
+y = f(x)
self.assertTrue(y.shape == shape + (200,))

y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float))
self.assertTrue(y2.shape == shape + (200,))

+conn_matrix = f.get_conn_matrix()
+self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
bm.clear_buffer_memory()

@parameterized.product(
@@ -211,11 +230,15 @@ def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape):
def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
bm.random.seed()
f = bp.dnn.EventJitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123)
-y = f(bm.random.random(shape + (100,)) < 0.1)
+x = bm.random.random(shape + (100,)) < 0.1
+y = f(x)
self.assertTrue(y.shape == shape + (200,))

y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float))
self.assertTrue(y2.shape == shape + (200,))

+conn_matrix = f.get_conn_matrix()
+self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
bm.clear_buffer_memory()
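The same identity is exercised for the event-driven variants, whose inputs are boolean spike vectors rather than dense activations. A short sketch under the same assumptions (illustrative `prob` and `weight` values):

```python
import brainpy as bp
import brainpy.math as bm

bm.random.seed()

f = bp.dnn.EventJitFPHomoLinear(100, 200, 0.1, 1.5, seed=123)

spikes = bm.random.random((100,)) < 0.1   # boolean spike vector
y = f(spikes)                             # event-driven forward pass

# The reconstructed dense matrix reproduces the event-driven product too.
W = f.get_conn_matrix()
assert bm.allclose(y, spikes @ W.T)
```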

