diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py
index a750142a..8e09f95b 100644
--- a/brainpy/_src/dnn/linear.py
+++ b/brainpy/_src/dnn/linear.py
@@ -980,15 +980,36 @@ def __init__(
     self.sharding = sharding


-class JitFPLinear(Layer):
+class JitLinear(Layer):
   def get_conn_matrix(self):
-    return bm.jitconn.get_conn_matrix(self.prob, self.seed,
-                                      shape=(self.num_out, self.num_in),
-                                      transpose=self.transpose,
-                                      outdim_parallel=not self.atomic)
+    pass
+
+
+class JitFPHomoLayer(JitLinear):
+  def get_conn_matrix(self):
+    return bm.jitconn.get_homo_weight_matrix(self.weight, self.prob, self.seed,
+                                             shape=(self.num_out, self.num_in),
+                                             transpose=self.transpose,
+                                             outdim_parallel=not self.atomic)
+
+
+class JitFPUniformLayer(JitLinear):
+  def get_conn_matrix(self):
+    return bm.jitconn.get_uniform_weight_matrix(self.w_low, self.w_high, self.prob, self.seed,
+                                                shape=(self.num_out, self.num_in),
+                                                transpose=self.transpose,
+                                                outdim_parallel=not self.atomic)
+
+
+class JitFPNormalLayer(JitLinear):
+  def get_conn_matrix(self):
+    return bm.jitconn.get_normal_weight_matrix(self.w_mu, self.w_sigma, self.prob, self.seed,
+                                               shape=(self.num_out, self.num_in),
+                                               transpose=self.transpose,
+                                               outdim_parallel=not self.atomic)


-class JitFPHomoLinear(JitFPLinear):
+class JitFPHomoLinear(JitFPHomoLayer):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.

   It performs the computation of:
@@ -1067,7 +1088,7 @@ def _batch_mv(self, x):
                                    outdim_parallel=not self.atomic)


-class JitFPUniformLinear(JitFPLinear):
+class JitFPUniformLinear(JitFPUniformLayer):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.

   It performs the computation of:
@@ -1147,7 +1168,7 @@ def _batch_mv(self, x):
                                       outdim_parallel=not self.atomic)


-class JitFPNormalLinear(JitFPLinear):
+class JitFPNormalLinear(JitFPNormalLayer):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.

   It performs the computation of:
@@ -1227,7 +1248,7 @@ def _batch_mv(self, x):
                                      outdim_parallel=not self.atomic)


-class EventJitFPHomoLinear(JitFPLinear):
+class EventJitFPHomoLinear(JitFPHomoLayer):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.

   It performs the computation of:
@@ -1306,7 +1327,7 @@ def _batch_mv(self, x):
                                          outdim_parallel=not self.atomic)


-class EventJitFPUniformLinear(JitFPLinear):
+class EventJitFPUniformLinear(JitFPUniformLayer):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.

   It performs the computation of:
@@ -1386,7 +1407,7 @@ def _batch_mv(self, x):
                                            outdim_parallel=not self.atomic)


-class EventJitFPNormalLinear(JitFPLinear):
+class EventJitFPNormalLinear(JitFPNormalLayer):
   r"""Synaptic matrix multiplication with the just-in-time connectivity.

   It performs the computation of:
diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py
index 6cc44538..9f011cb8 100644
--- a/brainpy/_src/dnn/tests/test_linear.py
+++ b/brainpy/_src/dnn/tests/test_linear.py
@@ -141,6 +141,11 @@ def test_JitFPHomoLinear(self, prob, weight, shape):
     x = bm.random.random(shape + (100,))
     y = f(x)
     self.assertTrue(y.shape == shape + (200,))
+
+    conn_matrix = f.get_conn_matrix()
+    self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
+    # print(conn_matrix.shape)
+    # self.assertTrue(conn_matrix.shape == (200, 100))
     bm.clear_buffer_memory()

   @parameterized.product(
@@ -155,6 +160,9 @@ def test_JitFPUniformLinear(self, prob, w_low, w_high, shape):
     x = bm.random.random(shape + (100,))
     y = f(x)
     self.assertTrue(y.shape == shape + (200,))
+
+    conn_matrix = f.get_conn_matrix()
+    self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
     bm.clear_buffer_memory()

   @parameterized.product(
@@ -169,6 +177,9 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
     x = bm.random.random(shape + (100,))
     y = f(x)
     self.assertTrue(y.shape == shape + (200,))
+
+    conn_matrix = f.get_conn_matrix()
+    self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
     bm.clear_buffer_memory()

   @parameterized.product(
@@ -179,11 +190,15 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
   def test_EventJitFPHomoLinear(self, prob, weight, shape):
     bm.random.seed()
     f = bp.dnn.EventJitFPHomoLinear(100, 200, prob, weight, seed=123)
-    y = f(bm.random.random(shape + (100,)) < 0.1)
+    x = bm.random.random(shape + (100,)) < 0.1
+    y = f(x)
     self.assertTrue(y.shape == shape + (200,))

     y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float))
     self.assertTrue(y2.shape == shape + (200,))
+
+    conn_matrix = f.get_conn_matrix()
+    self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
     bm.clear_buffer_memory()

   @parameterized.product(
@@ -195,11 +210,15 @@ def test_EventJitFPHomoLinear(self, prob, weight, shape):
   def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape):
     bm.random.seed()
     f = bp.dnn.EventJitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123)
-    y = f(bm.random.random(shape + (100,)) < 0.1)
+    x = bm.random.random(shape + (100,)) < 0.1
+    y = f(x)
     self.assertTrue(y.shape == shape + (200,))

     y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float))
     self.assertTrue(y2.shape == shape + (200,))
+
+    conn_matrix = f.get_conn_matrix()
+    self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
     bm.clear_buffer_memory()

   @parameterized.product(
@@ -211,11 +230,15 @@ def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape):
   def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
     bm.random.seed()
     f = bp.dnn.EventJitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123)
-    y = f(bm.random.random(shape + (100,)) < 0.1)
+    x = bm.random.random(shape + (100,)) < 0.1
+    y = f(x)
     self.assertTrue(y.shape == shape + (200,))

     y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float))
     self.assertTrue(y2.shape == shape + (200,))
+
+    conn_matrix = f.get_conn_matrix()
+    self.assertTrue(bm.allclose(y, x @ conn_matrix.T))
     bm.clear_buffer_memory()

diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py
index 8a7ba398..296a7994 100644
--- a/brainpy/_src/math/jitconn/matvec.py
+++ b/brainpy/_src/math/jitconn/matvec.py
@@ -1,18 +1,17 @@
 # -*- coding: utf-8 -*-
-
-
+import numbers
 from typing import Tuple, Optional, Union

 import jax
 import numpy as np
-from jax import numpy as jnp
-from jax.interpreters import ad
-
 from brainpy._src.dependency_check import import_taichi
+from brainpy._src.math import defaults
 from brainpy._src.math.interoperability import as_jax
 from brainpy._src.math.ndarray import Array, _get_dtype
 from brainpy._src.math.op_register import XLACustomOp
 from brainpy.errors import PackageMissingError
+from jax import numpy as jnp
+from jax.interpreters import ad

 ti = import_taichi(error_if_not_found=False)

@@ -20,7 +19,9 @@
   'mv_prob_homo',
   'mv_prob_uniform',
   'mv_prob_normal',
-  'get_conn_matrix',
+  'get_homo_weight_matrix',
+  'get_uniform_weight_matrix',
+  'get_normal_weight_matrix'
 ]


@@ -258,7 +259,8 @@ def mv_prob_normal(
                             transpose=transpose,
                             outdim_parallel=outdim_parallel)[0]

-def get_conn_matrix(
+def get_homo_weight_matrix(
+    weight: float,
     conn_prob: float,
     seed: Optional[int] = None,
     *,
@@ -288,17 +290,135 @@
   out: Array, ndarray
     The connection matrix :math:`M`.
   """
+  if isinstance(weight, numbers.Number):
+    weight = jnp.atleast_1d(jnp.asarray(weight, dtype=defaults.float_))
+  else:
+    raise ValueError(f'weight must be a number type, but got {type(weight)}')
+  if ti is None:
+    raise PackageMissingError.by_purpose('taichi', purpose='customized operators')
+
+  conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
+  conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
+  if seed is None:
+    with jax.ensure_compile_time_eval():
+      seed = np.random.randint(0, int(1e8), 1)
+  seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
+  r = raw_get_homo_weight_matrix(conn_len, seed, shape=shape,
+                                 transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_)
+  r *= weight
+  if transpose:
+    return r.transpose()
+  else:
+    return r
+
+
+def get_uniform_weight_matrix(
+    w_low: float,
+    w_high: float,
+    conn_prob: float,
+    seed: Optional[int] = None,
+    *,
+    shape: Tuple[int, int],
+    transpose: bool = False,
+    outdim_parallel: bool = True,
+) -> jax.Array:
+  r"""Get the weight matrix :math:`M` with a uniform distribution for its value.
+
+  Parameters
+  ----------
+  w_low: float
+    Lower boundary of the output interval.
+  w_high: float
+    Upper boundary of the output interval.
+  conn_prob: float
+    The connection probability.
+  shape: tuple of int
+    The matrix shape.
+  seed: int
+    The random number generation seed.
+  transpose: bool
+    Transpose the random matrix or not.
+  outdim_parallel: bool
+    Perform the parallel random generations along the out dimension or not.
+    It can be used to ensure that the just-in-time generated :math:`M^T` is the same
+    as the just-in-time generated :math:`M` when ``transpose=True``.
+
+  Returns
+  -------
+  out: Array, ndarray
+    The weight matrix :math:`M`.
+  """
+  if ti is None:
+    raise PackageMissingError.by_purpose('taichi', purpose='customized operators')
+
+  w_low = jnp.atleast_1d(as_jax(w_low))
+  w_high = jnp.atleast_1d(as_jax(w_high))
+  conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
+  conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
+  if seed is None:
+    with jax.ensure_compile_time_eval():
+      seed = np.random.randint(0, int(1e8), 1)
+  seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
+  r = raw_get_uniform_weight_matrix(w_low, w_high, conn_len, seed, shape=shape,
+                                    transpose=transpose, outdim_parallel=outdim_parallel)[0]
+  if transpose:
+    return r.transpose()
+  else:
+    return r
+
+
+def get_normal_weight_matrix(
+    w_mu: float,
+    w_sigma: float,
+    conn_prob: float,
+    seed: Optional[int] = None,
+    *,
+    shape: Tuple[int, int],
+    transpose: bool = False,
+    outdim_parallel: bool = True,
+) -> jax.Array:
+  r"""Get the weight matrix :math:`M` with a normal distribution for its value.
+
+  Parameters
+  ----------
+  w_mu: float
+    Mean (centre) of the distribution.
+  w_sigma: float
+    Standard deviation (spread or “width”) of the distribution. Must be non-negative.
+  shape: tuple of int
+    The matrix shape.
+  seed: int
+    The random number generation seed.
+  transpose: bool
+    Transpose the random matrix or not.
+  outdim_parallel: bool
+    Perform the parallel random generations along the out dimension or not.
+    It can be used to ensure that the just-in-time generated :math:`M^T` is the same
+    as the just-in-time generated :math:`M` when ``transpose=True``.
+
+  Returns
+  -------
+  out: Array, ndarray
+    The weight matrix :math:`M`.
+  """
   if ti is None:
     raise PackageMissingError.by_purpose('taichi', purpose='customized operators')

+  w_mu = jnp.atleast_1d(as_jax(w_mu))
+  w_sigma = jnp.atleast_1d(as_jax(w_sigma))
   conn_len = jnp.ceil(1 / conn_prob) * 2 - 1
   conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)
   if seed is None:
     with jax.ensure_compile_time_eval():
       seed = np.random.randint(0, int(1e8), 1)
   seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32))
-  return raw_get_connect_matrix(conn_len, seed, shape=shape,
-                                transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_)
+  r = raw_get_normal_weight_matrix(w_mu, w_sigma, conn_len, seed,
+                                   shape=shape,
+                                   transpose=transpose, outdim_parallel=outdim_parallel)[0]
+  if transpose:
+    return r.transpose()
+  else:
+    return r


 def raw_mv_prob_homo(
@@ -386,7 +506,7 @@ def raw_mv_prob_normal(
                          outdim_parallel=outdim_parallel)


-def raw_get_connect_matrix(
+def raw_get_homo_weight_matrix(
     conn_len: jax.Array,
     seed: jax.Array,
     *,
@@ -394,7 +514,6 @@ def raw_get_connect_matrix(
     transpose: bool = False,
     outdim_parallel: bool = True,
 ) -> jax.Array:
-  out_shape = shape if not transpose else (shape[1], shape[0])
   if outdim_parallel:
     prim = _get_connect_matrix_outdim_parallel_p
   else:
@@ -402,7 +521,57 @@
   return prim(conn_len,
               seed,
-              outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=jnp.int32)],
+              outs=[jax.ShapeDtypeStruct(shape=shape, dtype=jnp.int32)],
+              shape=shape,
+              transpose=transpose,
+              outdim_parallel=outdim_parallel)
+
+
+def raw_get_uniform_weight_matrix(
+    w_low: jax.Array,
+    w_high: jax.Array,
+    conn_len: jax.Array,
+    seed: jax.Array,
+    *,
+    shape: Tuple[int, int],
+    transpose: bool = False,
+    outdim_parallel: bool = True,
+) -> jax.Array:
+  if outdim_parallel:
+    prim = _get_uniform_weight_matrix_outdim_parallel_p
+  else:
+    prim = _get_uniform_weight_matrix_p
+
+  return prim(w_low,
+              w_high,
+              conn_len,
+              seed,
+              outs=[jax.ShapeDtypeStruct(shape=shape, dtype=jnp.float32)],
+              shape=shape,
+              transpose=transpose,
+              outdim_parallel=outdim_parallel)
+
+
+def raw_get_normal_weight_matrix(
+    w_mu: jax.Array,
+    w_sigma: jax.Array,
+    conn_len: jax.Array,
+    seed: jax.Array,
+    *,
+    shape: Tuple[int, int],
+    transpose: bool = False,
+    outdim_parallel: bool = True,
+) -> jax.Array:
+  if outdim_parallel:
+    prim = _get_normal_weight_matrix_outdim_parallel_p
+  else:
+    prim = _get_normal_weight_matrix_p
+
+  return prim(w_mu,
+              w_sigma,
+              conn_len,
+              seed,
+              outs=[jax.ShapeDtypeStruct(shape=shape, dtype=jnp.float32)],
               shape=shape,
               transpose=transpose,
               outdim_parallel=outdim_parallel)

@@ -1029,3 +1198,115 @@ def _get_connect_matrix_outdim_parallel(
   _get_connect_matrix_p = XLACustomOp(cpu_kernel=_get_connect_matrix, gpu_kernel=_get_connect_matrix)
   _get_connect_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_connect_matrix_outdim_parallel,
                                                        gpu_kernel=_get_connect_matrix_outdim_parallel)
+
+
+  @ti.kernel
+  def _get_uniform_weight_matrix(
+      w_low: ti.types.ndarray(),
+      w_high: ti.types.ndarray(),
+      clen: ti.types.ndarray(),
+      seed: ti.types.ndarray(),
+      out: ti.types.ndarray(),
+  ):
+    num_row = out.shape[0]
+    num_col = out.shape[1]
+    w_low0 = w_low[0]
+    w_high0 = w_high[0]
+    clen0 = clen[0]
+    seed0 = seed[0]
+
+    for i_col in range(num_col):
+      key = lfsr88_key(seed0 + i_col)
+      key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
+      while i_row < num_row:
+        key, raw_v = lfsr88_uniform(key, w_low0, w_high0)
+        out[i_row, i_col] = raw_v
+        key, inc = lfsr88_random_integers(key, 1, clen0)
+        i_row += inc
+
+
+  @ti.kernel
+  def _get_uniform_weight_matrix_outdim_parallel(
+      w_low: ti.types.ndarray(),
+      w_high: ti.types.ndarray(),
+      clen: ti.types.ndarray(),
+      seed: ti.types.ndarray(),
+      out: ti.types.ndarray(),
+  ):
+    num_row = out.shape[0]
+    num_col = out.shape[1]
+    w_low0 = w_low[0]
+    w_high0 = w_high[0]
+    clen0 = clen[0]
+    seed0 = seed[0]
+
+    for i_row in range(num_row):
+      key = lfsr88_key(seed0 + i_row)
+      key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
+      while i_col < num_col:
+        key, raw_v = lfsr88_uniform(key, w_low0, w_high0)
+        out[i_row, i_col] = raw_v
+        key, inc = lfsr88_random_integers(key, 1, clen0)
+        i_col += inc
+
+
+  _get_uniform_weight_matrix_p = XLACustomOp(cpu_kernel=_get_uniform_weight_matrix,
+                                             gpu_kernel=_get_uniform_weight_matrix)
+  _get_uniform_weight_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_uniform_weight_matrix_outdim_parallel,
+                                                             gpu_kernel=_get_uniform_weight_matrix_outdim_parallel)
+
+
+  @ti.kernel
+  def _get_normal_weight_matrix(
+      w_mu: ti.types.ndarray(),
+      w_sigma: ti.types.ndarray(),
+      clen: ti.types.ndarray(),
+      seed: ti.types.ndarray(),
+      out: ti.types.ndarray(),
+  ):
+    num_row = out.shape[0]
+    num_col = out.shape[1]
+    w_mu0 = w_mu[0]
+    w_sigma0 = w_sigma[0]
+    clen0 = clen[0]
+    seed0 = seed[0]
+
+    for i_col in range(num_col):
+      key = lfsr88_key(seed0 + i_col)
+      key, i_row = lfsr88_random_integers(key, 0, clen0 - 1)
+      while i_row < num_row:
+        key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0)
+        out[i_row, i_col] = raw_v
+        key, inc = lfsr88_random_integers(key, 1, clen0)
+        i_row += inc
+
+
+  @ti.kernel
+  def _get_normal_weight_matrix_outdim_parallel(
+      w_mu: ti.types.ndarray(),
+      w_sigma: ti.types.ndarray(),
+      clen: ti.types.ndarray(),
+      seed: ti.types.ndarray(),
+      out: ti.types.ndarray(),
+  ):
+    num_row = out.shape[0]
+    num_col = out.shape[1]
+    w_mu0 = w_mu[0]
+    w_sigma0 = w_sigma[0]
+    clen0 = clen[0]
+    seed0 = seed[0]
+
+    for i_row in range(num_row):
+      key = lfsr88_key(seed0 + i_row)
+      key, i_col = lfsr88_random_integers(key, 0, clen0 - 1)
+      while i_col < num_col:
+        key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0)
+        out[i_row, i_col] = raw_v
+        key, inc = lfsr88_random_integers(key, 1, clen0)
+        i_col += inc
+
+
+  _get_normal_weight_matrix_p = XLACustomOp(cpu_kernel=_get_normal_weight_matrix,
+                                            gpu_kernel=_get_normal_weight_matrix)
+  _get_normal_weight_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_normal_weight_matrix_outdim_parallel,
+                                                            gpu_kernel=_get_normal_weight_matrix_outdim_parallel)
diff --git a/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py b/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py
deleted file mode 100644
index a58be6e8..00000000
--- a/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# -*- coding: utf-8 -*-
-import jax.numpy as jnp
-import pytest
-from absl.testing import parameterized
-
-import brainpy.math as bm
-from brainpy._src.dependency_check import import_taichi
-
-if import_taichi(error_if_not_found=False) is None:
-  pytest.skip('no taichi', allow_module_level=True)
-
-import platform
-
-force_test = False  # turn on to force test on windows locally
-if platform.system() == 'Windows' and not force_test:
-  pytest.skip('skip windows', allow_module_level=True)
-
-shapes = [(100, 200), (1000, 10)]
-
-
-# SEED = 1234
-
-class TestGetConnectMatrix(parameterized.TestCase):
-  def __init__(self, *args, platform='cpu', **kwargs):
-    super(TestGetConnectMatrix, self).__init__(*args, **kwargs)
-    bm.set_platform(platform)
-    print()
-
-  @parameterized.product(
-    transpose=[True, False],
-    outdim_parallel=[True, False],
-    shape=shapes,
-    prob=[0.1],
-  )
-  def test_get_conn_matrix(self, transpose, outdim_parallel, shape, prob):
-    print(
-      f'test_get_connect_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}')
-    conn = bm.jitconn.get_conn_matrix(prob, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
-    shape = (shape[1], shape[0]) if transpose else shape
-    assert conn.shape == shape
-    assert conn.dtype == jnp.bool_
-    # sum all true values
-    # assert jnp.sum(conn) == jnp.round(prob * shape[0] * shape[1])
-    print(
-      f'jnp.sum(conn): {jnp.sum(conn)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}')
-    # print(f'conn: {conn}')
diff --git a/brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py b/brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py
new file mode 100644
index 00000000..9f10505a
--- /dev/null
+++ b/brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py
@@ -0,0 +1,169 @@
+# -*- coding: utf-8 -*-
+import jax.numpy as jnp
+import pytest
+from absl.testing import parameterized
+
+import brainpy.math as bm
+from brainpy._src.dependency_check import import_taichi
+
+if import_taichi(error_if_not_found=False) is None:
+  pytest.skip('no taichi', allow_module_level=True)
+
+import platform
+
+force_test = False  # turn on to force test on windows locally
+# if platform.system() == 'Windows' and not force_test:
+#   pytest.skip('skip windows', allow_module_level=True)
+
+shapes = [
+  (2, 2),
+  # (1000, 10)
+]
+
+SEED = 1234
+
+
+class TestGetHomoWeightMatrix(parameterized.TestCase):
+  def __init__(self, *args, platform='cpu', **kwargs):
+    super(TestGetHomoWeightMatrix, self).__init__(*args, **kwargs)
+    bm.set_platform(platform)
+    print()
+
+  @parameterized.product(
+    transpose=[True, False],
+    outdim_parallel=[True, False],
+    shape=shapes,
+    prob=[0.1],
+  )
+  def test_get_homo_weight_matrix(self, transpose, outdim_parallel, shape, prob):
+    homo_data = 1.
+    print(
+      f'test_get_homo_weight_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}')
+    conn = bm.jitconn.get_homo_weight_matrix(homo_data, prob, SEED, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
+    shape = (shape[1], shape[0]) if transpose else shape
+    print(conn.shape)
+    assert conn.shape == shape
+    # assert conn.dtype == jnp.float_
+    # sum all true values
+    print(
+      f'jnp.sum(conn): {jnp.sum(conn)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}')
+
+    # compare with jitconn op
+
+    print(f'conn: {conn}')
+    rng = bm.random.RandomState()
+    vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+    r1 = bm.jitconn.mv_prob_homo(vector,
+                                 homo_data,
+                                 conn_prob=prob,
+                                 shape=shape,
+                                 seed=SEED,
+                                 outdim_parallel=outdim_parallel,
+                                 transpose=transpose)
+
+    r2 = vector @ conn if transpose else conn @ vector
+    self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
+
+    bm.clear_buffer_memory()
+
+
+class TestGetUniformWeightMatrix(parameterized.TestCase):
+  def __init__(self, *args, platform='cpu', **kwargs):
+    super(TestGetUniformWeightMatrix, self).__init__(*args, **kwargs)
+    bm.set_platform(platform)
+    print()
+
+  @parameterized.product(
+    transpose=[True, False],
+    outdim_parallel=[True, False],
+    shape=shapes,
+    prob=[0.5],
+    w_low=[0.1],
+    w_high=[0.9],
+  )
+  def test_get_uniform_weight_matrix(self, transpose, outdim_parallel, shape, prob, w_low, w_high):
+    print(
+      f'test_get_uniform_weight_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}, w_low={w_low}, w_high={w_high}')
+    weight = bm.jitconn.get_uniform_weight_matrix(w_low, w_high, prob, SEED, shape=shape, transpose=transpose,
+                                                  outdim_parallel=outdim_parallel)
+    shape = (shape[1], shape[0]) if transpose else shape
+    assert weight.shape == shape
+    assert weight.dtype == jnp.float32
+
+    weight_true = weight > 0.
+
+    print(
+      f'jnp.sum(conn): {jnp.sum(weight_true)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}')
+
+    # compare with jitconn op
+
+    print(f'weight: {weight}')
+
+    rng = bm.random.RandomState()
+    events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+    r1 = bm.jitconn.mv_prob_uniform(events,
+                                    w_low=w_low,
+                                    w_high=w_high,
+                                    conn_prob=prob,
+                                    shape=shape,
+                                    seed=SEED,
+                                    outdim_parallel=outdim_parallel,
+                                    transpose=transpose)
+
+    # r2 = weight @ events if transpose else events @ weight
+    r2 = events @ weight if transpose else weight @ events
+    print(f'r1: {r1}\n r2: {r2}')
+    self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
+
+    bm.clear_buffer_memory()
+
+
+class TestGetNormalWeightMatrix(parameterized.TestCase):
+  def __init__(self, *args, platform='cpu', **kwargs):
+    super(TestGetNormalWeightMatrix, self).__init__(*args, **kwargs)
+    bm.set_platform(platform)
+    print()
+
+  @parameterized.product(
+    transpose=[True, False],
+    outdim_parallel=[True, False],
+    shape=shapes,
+    prob=[0.1],
+    w_mu=[0.0],
+    w_sigma=[1.0],
+  )
+  def test_get_normal_weight_matrix(self, transpose, outdim_parallel, shape, prob, w_mu, w_sigma):
+    print(
+      f'test_get_normal_weight_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}, w_mu={w_mu}, w_sigma={w_sigma}')
+    weight = bm.jitconn.get_normal_weight_matrix(w_mu, w_sigma, prob, SEED, shape=shape, transpose=transpose,
+                                                 outdim_parallel=outdim_parallel)
+    shape = (shape[1], shape[0]) if transpose else shape
+    assert weight.shape == shape
+    assert weight.dtype == jnp.float32
+
+    weight_true = weight > 0.
+
+    print(
+      f'jnp.sum(conn): {jnp.sum(weight_true)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}')
+
+    # compare with jitconn op
+
+    rng = bm.random.RandomState()
+    vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
+
+    r1 = bm.jitconn.mv_prob_normal(vector,
+                                   w_mu=w_mu,
+                                   w_sigma=w_sigma,
+                                   conn_prob=prob,
+                                   shape=shape,
+                                   seed=SEED,
+                                   outdim_parallel=outdim_parallel,
+                                   transpose=transpose)
+
+    r2 = vector @ weight if transpose else weight @ vector
+    print(f'r1: {r1}\n r2: {r2}')
+    self.assertTrue(jnp.allclose(r1, r2, atol=1e-6))
+
+    bm.clear_buffer_memory()
diff --git a/brainpy/math/jitconn.py b/brainpy/math/jitconn.py
index e1c4eafb..3c99b7de 100644
--- a/brainpy/math/jitconn.py
+++ b/brainpy/math/jitconn.py
@@ -7,6 +7,8 @@
   mv_prob_uniform as mv_prob_uniform,
   mv_prob_normal as mv_prob_normal,

-  get_conn_matrix as get_conn_matrix,
+  get_homo_weight_matrix as get_homo_weight_matrix,
+  get_uniform_weight_matrix as get_uniform_weight_matrix,
+  get_normal_weight_matrix as get_normal_weight_matrix,
 )
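
Usage sketch (reviewer note, not part of the patch): the new `brainpy.math.jitconn` getters exported above can be exercised as follows. This mirrors the checks in the new `test_get_weight_matrix.py`; it assumes taichi is installed (the getters raise `PackageMissingError` otherwise), and the shape, probability, and seed values are illustrative only.

import brainpy.math as bm
import jax.numpy as jnp

prob, seed, shape = 0.1, 1234, (100, 200)

# Materialize the dense weight matrix that the JIT connectivity operator samples on the fly.
W = bm.jitconn.get_homo_weight_matrix(1., prob, seed, shape=shape,
                                      transpose=False, outdim_parallel=True)
assert W.shape == shape

# For the same prob/seed/flags, the dense product matches the matrix-free operator.
vector = bm.as_jax(bm.random.random(shape[1]))
r1 = bm.jitconn.mv_prob_homo(vector, 1., conn_prob=prob, shape=shape, seed=seed,
                             outdim_parallel=True, transpose=False)
r2 = W @ vector
print(jnp.allclose(r1, r2, atol=1e-6))  # should print True, as asserted in the tests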
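A corresponding sketch for the layer-level `get_conn_matrix()` accessor added in `linear.py` (mirrors `test_linear.py`; layer sizes, probability, weight, and batch size are illustrative):

import brainpy as bp
import brainpy.math as bm

# A JIT-connectivity linear layer: 100 inputs -> 200 outputs, 10% connectivity, homogeneous weight 0.5.
layer = bp.dnn.JitFPHomoLinear(100, 200, 0.1, 0.5, seed=123)

x = bm.random.random((16, 100))  # a batch of 16 input vectors
y = layer(x)

# get_conn_matrix() returns the dense weight matrix sampled by the layer;
# the new tests check that the on-the-fly output equals an ordinary matmul.
W = layer.get_conn_matrix()
print(bm.allclose(y, x @ W.T))  # should print True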