From dd084a7a75fac9af5f3fcefd9e65d6cd77a668c1 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 27 May 2024 12:59:31 +0800 Subject: [PATCH 01/11] [math] Add get JIT connect matrix methods for `brainpy.dnn.linear` --- brainpy/_src/dnn/linear.py | 36 ++++++ brainpy/_src/math/jitconn/matvec.py | 111 ++++++++++++++++++ .../jitconn/tests/test_get_connect_matrix.py | 46 ++++++++ brainpy/_src/math/op_register/base.py | 13 +- brainpy/math/jitconn.py | 2 + 5 files changed, 203 insertions(+), 5 deletions(-) create mode 100644 brainpy/_src/math/jitconn/tests/test_get_connect_matrix.py diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index c524fb0b..2570835f 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -1058,6 +1058,12 @@ def _batch_mv(self, x): transpose=self.transpose, outdim_parallel=not self.atomic) + def get_connect_matrix(self): + return bm.jitconn.get_connect_matrix(self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + class JitFPUniformLinear(Layer): r"""Synaptic matrix multiplication with the just-in-time connectivity. @@ -1138,6 +1144,12 @@ def _batch_mv(self, x): transpose=self.transpose, outdim_parallel=not self.atomic) + def get_connect_matrix(self): + return bm.jitconn.get_connect_matrix(self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + class JitFPNormalLinear(Layer): r"""Synaptic matrix multiplication with the just-in-time connectivity. @@ -1218,6 +1230,12 @@ def _batch_mv(self, x): transpose=self.transpose, outdim_parallel=not self.atomic) + def get_connect_matrix(self): + return bm.jitconn.get_connect_matrix(self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + class EventJitFPHomoLinear(Layer): r"""Synaptic matrix multiplication with the just-in-time connectivity. @@ -1297,6 +1315,12 @@ def _batch_mv(self, x): transpose=self.transpose, outdim_parallel=not self.atomic) + def get_connect_matrix(self): + return bm.jitconn.get_connect_matrix(self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + class EventJitFPUniformLinear(Layer): r"""Synaptic matrix multiplication with the just-in-time connectivity. @@ -1377,6 +1401,12 @@ def _batch_mv(self, x): transpose=self.transpose, outdim_parallel=not self.atomic) + def get_connect_matrix(self): + return bm.jitconn.get_connect_matrix(self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + class EventJitFPNormalLinear(Layer): r"""Synaptic matrix multiplication with the just-in-time connectivity. 
@@ -1456,3 +1486,9 @@ def _batch_mv(self, x): shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) + + def get_connect_matrix(self): + return bm.jitconn.get_connect_matrix(self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py index 00e5778f..ad168133 100644 --- a/brainpy/_src/math/jitconn/matvec.py +++ b/brainpy/_src/math/jitconn/matvec.py @@ -20,6 +20,7 @@ 'mv_prob_homo', 'mv_prob_uniform', 'mv_prob_normal', + 'get_connect_matrix', ] @@ -257,6 +258,49 @@ def mv_prob_normal( transpose=transpose, outdim_parallel=outdim_parallel)[0] +def get_connect_matrix( + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Get the connection matrix :math:`M` with a connection probability `conn_prob`. + + Parameters + ---------- + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The connection matrix :math:`M`. + """ + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + return raw_get_connect_matrix(conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_) + + def raw_mv_prob_homo( vector: jax.Array, weight: jax.Array, # vector with size 1 @@ -342,6 +386,28 @@ def raw_mv_prob_normal( outdim_parallel=outdim_parallel) +def raw_get_connect_matrix( + conn_len: jax.Array, + seed: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + out_shape = shape if not transpose else (shape[1], shape[0]) + if outdim_parallel: + prim = _get_connect_matrix_outdim_parallel_p + else: + prim = _get_connect_matrix_p + + return prim(conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=jnp.int32)], + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): if vector.ndim != 1: raise ValueError('vector should be a 1D vector.') @@ -918,3 +984,48 @@ def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel): cpu_kernel=_mv_prob_normal_cpu, gpu_kernel=_mv_prob_normal_gpu ) + + + @ti.kernel + def _get_connect_matrix( + clen: ti.types.ndarray(), + seed: ti.types.ndarray(), + out: ti.types.ndarray(), + ): + num_row = out.shape[0] + num_col = out.shape[1] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + out[i_row, i_col] = 1 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def 
_get_connect_matrix_outdim_parallel( + clen: ti.types.ndarray(), + seed: ti.types.ndarray(), + out: ti.types.ndarray(), + ): + num_row = out.shape[0] + num_col = out.shape[1] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + out[i_row, i_col] = 1 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + + + _get_connect_matrix_p = XLACustomOp(cpu_kernel=_get_connect_matrix, gpu_kernel=_get_connect_matrix) + _get_connect_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_connect_matrix_outdim_parallel, + gpu_kernel=_get_connect_matrix_outdim_parallel) diff --git a/brainpy/_src/math/jitconn/tests/test_get_connect_matrix.py b/brainpy/_src/math/jitconn/tests/test_get_connect_matrix.py new file mode 100644 index 00000000..4b948fc2 --- /dev/null +++ b/brainpy/_src/math/jitconn/tests/test_get_connect_matrix.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +import jax.numpy as jnp +import pytest +from absl.testing import parameterized + +import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + +import platform + +force_test = False # turn on to force test on windows locally +if platform.system() == 'Windows' and not force_test: + pytest.skip('skip windows', allow_module_level=True) + +shapes = [(100, 200), (1000, 10)] + + +# SEED = 1234 + +class TestGetConnectMatrix(parameterized.TestCase): + def __init__(self, *args, platform='cpu', **kwargs): + super(TestGetConnectMatrix, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.product( + transpose=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_get_connect_matrix(self, transpose, outdim_parallel, shape, prob): + print( + f'test_get_connect_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}') + conn = bm.jitconn.get_connect_matrix(prob, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + shape = (shape[1], shape[0]) if transpose else shape + assert conn.shape == shape + assert conn.dtype == jnp.bool_ + # sum all true values + # assert jnp.sum(conn) == jnp.round(prob * shape[0] * shape[1]) + print( + f'jnp.sum(conn): {jnp.sum(conn)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}') + # print(f'conn: {conn}') diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py index 5af5a7e3..20a48778 100644 --- a/brainpy/_src/math/op_register/base.py +++ b/brainpy/_src/math/op_register/base.py @@ -13,14 +13,16 @@ from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule, register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule) - from .cupy_based import (register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule, - register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule) + from .cupy_based import ( + register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule, + register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule) else: from 
.numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule, register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule) - from .cupy_based import (register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule, - register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule) + from .cupy_based import ( + register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule, + register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule) from .utils import register_general_batching from brainpy._src.math.op_register.ad_support import defjvp @@ -116,7 +118,8 @@ def __init__( register_taichi_gpu_translation_rule(self.primitive, gpu_kernel) gpu_checked = True if not gpu_checked: - raise ValueError(f'"gpu_kernel" must be a taichi kernel function, cupy raw module or cupy jit kernel. But we got {gpu_kernel}') + raise ValueError( + f'"gpu_kernel" must be a taichi kernel function, cupy raw module or cupy jit kernel. But we got {gpu_kernel}') # batching rule if batching_translation is None: diff --git a/brainpy/math/jitconn.py b/brainpy/math/jitconn.py index a87d27d5..441b19f2 100644 --- a/brainpy/math/jitconn.py +++ b/brainpy/math/jitconn.py @@ -6,5 +6,7 @@ mv_prob_homo as mv_prob_homo, mv_prob_uniform as mv_prob_uniform, mv_prob_normal as mv_prob_normal, + + get_connect_matrix as get_connect_matrix, ) From 11a6624d2c7b68cce6a4ece389190649ff8fa025 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 28 May 2024 11:05:28 +0800 Subject: [PATCH 02/11] Update --- brainpy/_src/dnn/linear.py | 56 +++++-------------- brainpy/_src/math/jitconn/matvec.py | 4 +- ...nect_matrix.py => test_get_conn_matrix.py} | 4 +- brainpy/math/jitconn.py | 2 +- 4 files changed, 19 insertions(+), 47 deletions(-) rename brainpy/_src/math/jitconn/tests/{test_get_connect_matrix.py => test_get_conn_matrix.py} (87%) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 2570835f..a1d31e08 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -980,7 +980,15 @@ def __init__( self.sharding = sharding -class JitFPHomoLinear(Layer): +class JitFPLinear(Layer): + def get_connect_matrix(self): + return bm.jitconn.get_conn_matrix(self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class JitFPHomoLinear(JitFPLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1058,14 +1066,8 @@ def _batch_mv(self, x): transpose=self.transpose, outdim_parallel=not self.atomic) - def get_connect_matrix(self): - return bm.jitconn.get_connect_matrix(self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - -class JitFPUniformLinear(Layer): +class JitFPUniformLinear(JitFPLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. 
It performs the computation of: @@ -1144,14 +1146,8 @@ def _batch_mv(self, x): transpose=self.transpose, outdim_parallel=not self.atomic) - def get_connect_matrix(self): - return bm.jitconn.get_connect_matrix(self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - -class JitFPNormalLinear(Layer): +class JitFPNormalLinear(JitFPLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1230,14 +1226,8 @@ def _batch_mv(self, x): transpose=self.transpose, outdim_parallel=not self.atomic) - def get_connect_matrix(self): - return bm.jitconn.get_connect_matrix(self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - -class EventJitFPHomoLinear(Layer): +class EventJitFPHomoLinear(JitFPLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1315,14 +1305,8 @@ def _batch_mv(self, x): transpose=self.transpose, outdim_parallel=not self.atomic) - def get_connect_matrix(self): - return bm.jitconn.get_connect_matrix(self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - -class EventJitFPUniformLinear(Layer): +class EventJitFPUniformLinear(JitFPLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1401,14 +1385,8 @@ def _batch_mv(self, x): transpose=self.transpose, outdim_parallel=not self.atomic) - def get_connect_matrix(self): - return bm.jitconn.get_connect_matrix(self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - -class EventJitFPNormalLinear(Layer): +class EventJitFPNormalLinear(JitFPLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. 
It performs the computation of: @@ -1486,9 +1464,3 @@ def _batch_mv(self, x): shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) - - def get_connect_matrix(self): - return bm.jitconn.get_connect_matrix(self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py index ad168133..8a7ba398 100644 --- a/brainpy/_src/math/jitconn/matvec.py +++ b/brainpy/_src/math/jitconn/matvec.py @@ -20,7 +20,7 @@ 'mv_prob_homo', 'mv_prob_uniform', 'mv_prob_normal', - 'get_connect_matrix', + 'get_conn_matrix', ] @@ -258,7 +258,7 @@ def mv_prob_normal( transpose=transpose, outdim_parallel=outdim_parallel)[0] -def get_connect_matrix( +def get_conn_matrix( conn_prob: float, seed: Optional[int] = None, *, diff --git a/brainpy/_src/math/jitconn/tests/test_get_connect_matrix.py b/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py similarity index 87% rename from brainpy/_src/math/jitconn/tests/test_get_connect_matrix.py rename to brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py index 4b948fc2..a58be6e8 100644 --- a/brainpy/_src/math/jitconn/tests/test_get_connect_matrix.py +++ b/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py @@ -32,10 +32,10 @@ def __init__(self, *args, platform='cpu', **kwargs): shape=shapes, prob=[0.1], ) - def test_get_connect_matrix(self, transpose, outdim_parallel, shape, prob): + def test_get_conn_matrix(self, transpose, outdim_parallel, shape, prob): print( f'test_get_connect_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}') - conn = bm.jitconn.get_connect_matrix(prob, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + conn = bm.jitconn.get_conn_matrix(prob, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) shape = (shape[1], shape[0]) if transpose else shape assert conn.shape == shape assert conn.dtype == jnp.bool_ diff --git a/brainpy/math/jitconn.py b/brainpy/math/jitconn.py index 441b19f2..e1c4eafb 100644 --- a/brainpy/math/jitconn.py +++ b/brainpy/math/jitconn.py @@ -7,6 +7,6 @@ mv_prob_uniform as mv_prob_uniform, mv_prob_normal as mv_prob_normal, - get_connect_matrix as get_connect_matrix, + get_conn_matrix as get_conn_matrix, ) From 3f46507a3f6bfd1ca7fb23418210ea1fff8d4b76 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 28 May 2024 11:23:43 +0800 Subject: [PATCH 03/11] Update linear.py --- brainpy/_src/dnn/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index a1d31e08..a750142a 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -981,7 +981,7 @@ def __init__( class JitFPLinear(Layer): - def get_connect_matrix(self): + def get_conn_matrix(self): return bm.jitconn.get_conn_matrix(self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, From 0e6dbc2210064a2651ddc67bd0c9e23aa498da5f Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 29 May 2024 17:05:55 +0800 Subject: [PATCH 04/11] [math] Add get JIT weight matrix methods(Uniform & Normal) for `brainpy.dnn.linear` --- brainpy/_src/dnn/linear.py | 27 +- brainpy/_src/math/jitconn/matvec.py | 284 +++++++++++++++++- .../jitconn/tests/test_get_conn_matrix.py | 131 +++++++- brainpy/math/jitconn.py | 2 + 4 files changed, 426 insertions(+), 18 deletions(-) diff --git 
a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index a750142a..03398766 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -980,15 +980,28 @@ def __init__( self.sharding = sharding -class JitFPLinear(Layer): +class JitFPHomoLinear(Layer): def get_conn_matrix(self): return bm.jitconn.get_conn_matrix(self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) +class JitFPUniformLinear(Layer): + def get_conn_matrix(self): + return bm.jitconn.get_uniform_weight_matrix(self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + +class JitFPNormalLinear(Layer): + def get_conn_matrix(self): + return bm.jitconn.get_normal_weight_matrix(self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) -class JitFPHomoLinear(JitFPLinear): +class JitFPHomoLinear(JitFPHomoLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1067,7 +1080,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class JitFPUniformLinear(JitFPLinear): +class JitFPUniformLinear(JitFPUniformLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1147,7 +1160,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class JitFPNormalLinear(JitFPLinear): +class JitFPNormalLinear(JitFPNormalLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1227,7 +1240,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class EventJitFPHomoLinear(JitFPLinear): +class EventJitFPHomoLinear(JitFPHomoLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1306,7 +1319,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class EventJitFPUniformLinear(JitFPLinear): +class EventJitFPUniformLinear(JitFPUniformLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1386,7 +1399,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class EventJitFPNormalLinear(JitFPLinear): +class EventJitFPNormalLinear(JitFPNormalLinear): r"""Synaptic matrix multiplication with the just-in-time connectivity. 
It performs the computation of: diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py index 8a7ba398..e2ae2322 100644 --- a/brainpy/_src/math/jitconn/matvec.py +++ b/brainpy/_src/math/jitconn/matvec.py @@ -21,6 +21,8 @@ 'mv_prob_uniform', 'mv_prob_normal', 'get_conn_matrix', + 'get_uniform_weight_matrix', + 'get_normal_weight_matrix' ] @@ -297,8 +299,121 @@ def get_conn_matrix( with jax.ensure_compile_time_eval(): seed = np.random.randint(0, int(1e8), 1) seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_get_connect_matrix(conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_) + r = raw_get_connect_matrix(conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_) + if transpose: + return r.transpose() + else: + return r + + +def get_uniform_weight_matrix( + w_low: float, + w_high: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Get the weight matrix :math:`M` with a uniform distribution for its value. + + Parameters + ---------- + w_low: float + Lower boundary of the output interval. + w_high: float + Upper boundary of the output interval. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The weight matrix :math:`M`. + """ + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + + w_low = jnp.atleast_1d(as_jax(w_low)) + w_high = jnp.atleast_1d(as_jax(w_high)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + r = raw_get_uniform_weight_matrix(w_low, w_high, conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] + if transpose: + return r.transpose() + else: + return r + + +def get_normal_weight_matrix( + w_mu: float, + w_sigma: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Get the weight matrix :math:`M` with a normal distribution for its value. + + Parameters + ---------- + w_mu: float + Mean (centre) of the distribution. + w_sigma: float + Standard deviation (spread or “width”) of the distribution. Must be non-negative. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The weight matrix :math:`M`. 
+ """ + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + + w_mu = jnp.atleast_1d(as_jax(w_mu)) + w_sigma = jnp.atleast_1d(as_jax(w_sigma)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + r = raw_get_normal_weight_matrix(w_mu, w_sigma, conn_len, seed, + shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] + if transpose: + return r.transpose() + else: + return r def raw_mv_prob_homo( @@ -394,7 +509,6 @@ def raw_get_connect_matrix( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - out_shape = shape if not transpose else (shape[1], shape[0]) if outdim_parallel: prim = _get_connect_matrix_outdim_parallel_p else: @@ -402,7 +516,57 @@ def raw_get_connect_matrix( return prim(conn_len, seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=jnp.int32)], + outs=[jax.ShapeDtypeStruct(shape=shape, dtype=jnp.int32)], + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +def raw_get_uniform_weight_matrix( + w_low: jax.Array, + w_high: jax.Array, + conn_len: jax.Array, + seed: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + if outdim_parallel: + prim = _get_uniform_weight_matrix_outdim_parallel_p + else: + prim = _get_uniform_weight_matrix_p + + return prim(w_low, + w_high, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=shape, dtype=jnp.float32)], + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +def raw_get_normal_weight_matrix( + w_mu: jax.Array, + w_sigma: jax.Array, + conn_len: jax.Array, + seed: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + if outdim_parallel: + prim = _get_normal_weight_matrix_outdim_parallel_p + else: + prim = _get_normal_weight_matrix_p + + return prim(w_mu, + w_sigma, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=shape, dtype=jnp.float32)], shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) @@ -1029,3 +1193,115 @@ def _get_connect_matrix_outdim_parallel( _get_connect_matrix_p = XLACustomOp(cpu_kernel=_get_connect_matrix, gpu_kernel=_get_connect_matrix) _get_connect_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_connect_matrix_outdim_parallel, gpu_kernel=_get_connect_matrix_outdim_parallel) + + + @ti.kernel + def _get_uniform_weight_matrix( + w_low: ti.types.ndarray(), + w_high: ti.types.ndarray(), + clen: ti.types.ndarray(), + seed: ti.types.ndarray(), + out: ti.types.ndarray(), + ): + num_row = out.shape[0] + num_col = out.shape[1] + w_low0 = w_low[0] + w_high0 = w_high[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, raw_v = lfsr88_uniform(key, w_low0, w_high0) + out[i_row, i_col] += raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _get_uniform_weight_matrix_outdim_parallel( + w_low: ti.types.ndarray(), + w_high: ti.types.ndarray(), + clen: ti.types.ndarray(), + seed: ti.types.ndarray(), + out: ti.types.ndarray(), + ): + num_row = out.shape[0] + num_col = out.shape[1] + w_low0 = w_low[0] + w_high0 = w_high[0] + clen0 = clen[0] + seed0 = 
seed[0] + + for i_row in range(num_row): + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, raw_v = lfsr88_uniform(key, w_low0, w_high0) + out[i_row, i_col] += raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + + + _get_uniform_weight_matrix_p = XLACustomOp(cpu_kernel=_get_uniform_weight_matrix, + gpu_kernel=_get_uniform_weight_matrix) + _get_uniform_weight_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_uniform_weight_matrix_outdim_parallel, + gpu_kernel=_get_uniform_weight_matrix_outdim_parallel) + + + @ti.kernel + def _get_normal_weight_matrix( + w_mu: ti.types.ndarray(), + w_sigma: ti.types.ndarray(), + clen: ti.types.ndarray(), + seed: ti.types.ndarray(), + out: ti.types.ndarray(), + ): + num_row = out.shape[0] + num_col = out.shape[1] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row, i_col] += raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _get_normal_weight_matrix_outdim_parallel( + w_mu: ti.types.ndarray(), + w_sigma: ti.types.ndarray(), + clen: ti.types.ndarray(), + seed: ti.types.ndarray(), + out: ti.types.ndarray(), + ): + num_row = out.shape[0] + num_col = out.shape[1] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row, i_col] += raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + + + _get_normal_weight_matrix_p = XLACustomOp(cpu_kernel=_get_normal_weight_matrix, + gpu_kernel=_get_normal_weight_matrix) + _get_normal_weight_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_normal_weight_matrix_outdim_parallel, + gpu_kernel=_get_normal_weight_matrix_outdim_parallel) diff --git a/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py b/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py index a58be6e8..e671199d 100644 --- a/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py +++ b/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py @@ -12,13 +12,13 @@ import platform force_test = False # turn on to force test on windows locally -if platform.system() == 'Windows' and not force_test: - pytest.skip('skip windows', allow_module_level=True) +# if platform.system() == 'Windows' and not force_test: +# pytest.skip('skip windows', allow_module_level=True) -shapes = [(100, 200), (1000, 10)] +shapes = [(10, 20), (1000, 10)] +SEED = 1234 -# SEED = 1234 class TestGetConnectMatrix(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): @@ -35,12 +35,129 @@ def __init__(self, *args, platform='cpu', **kwargs): def test_get_conn_matrix(self, transpose, outdim_parallel, shape, prob): print( f'test_get_connect_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}') - conn = bm.jitconn.get_conn_matrix(prob, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + conn = bm.jitconn.get_conn_matrix(prob, SEED, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) shape = (shape[1], shape[0]) if transpose else shape + print(conn.shape) assert conn.shape == shape 
assert conn.dtype == jnp.bool_ # sum all true values - # assert jnp.sum(conn) == jnp.round(prob * shape[0] * shape[1]) print( f'jnp.sum(conn): {jnp.sum(conn)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}') - # print(f'conn: {conn}') + + # compare with jitconn op + homo_data = 1. + rng = bm.random.RandomState() + vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = bm.jitconn.mv_prob_homo(vector, + homo_data, + conn_prob=prob, + shape=shape, + seed=SEED, + outdim_parallel=outdim_parallel, + transpose=transpose) + + r2 = vector @ conn if transpose else conn @ vector + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + bm.clear_buffer_memory() + + +class TestGetUniformWeightMatrix(parameterized.TestCase): + def __init__(self, *args, platform='cpu', **kwargs): + super(TestGetUniformWeightMatrix, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.product( + transpose=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + w_low=[0.1], + w_high=[0.9], + ) + def test_get_uniform_weight_matrix(self, transpose, outdim_parallel, shape, prob, w_low, w_high): + print( + f'test_get_uniform_weight_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}, w_low={w_low}, w_high={w_high}') + weight = bm.jitconn.get_uniform_weight_matrix(w_low, w_high, prob, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + shape = (shape[1], shape[0]) if transpose else shape + assert weight.shape == shape + assert weight.dtype == jnp.float32 + + weight_true = weight > 0. + + print( + f'jnp.sum(conn): {jnp.sum(weight_true)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}') + + # CANNOT BE TESTED IN THIS WAY, BECAUSE UNIFORM JITCONN OP HAS BEEN OPTIMIZED + # compare with jitconn op + + # rng = bm.random.RandomState() + # events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + # + # r1 = bm.jitconn.mv_prob_uniform(events, + # w_low=w_low, + # w_high=w_high, + # conn_prob=prob, + # shape=shape, + # seed=SEED, + # outdim_parallel=outdim_parallel, + # transpose=transpose) + # + # r2 = events @ weight if transpose else weight @ events + # print(f'r1: {r1}\n r2: {r2}') + # self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + bm.clear_buffer_memory() + + +class TestGetNormalWeightMatrix(parameterized.TestCase): + def __init__(self, *args, platform='cpu', **kwargs): + super(TestGetNormalWeightMatrix, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.product( + transpose=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + w_mu=[0.0], + w_sigma=[1.0], + ) + def test_get_normal_weight_matrix(self, transpose, outdim_parallel, shape, prob, w_mu, w_sigma): + print( + f'test_get_normal_weight_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}, w_mu={w_mu}, w_sigma={w_sigma}') + weight = bm.jitconn.get_normal_weight_matrix(w_mu, w_sigma, prob, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + shape = (shape[1], shape[0]) if transpose else shape + assert weight.shape == shape + assert weight.dtype == jnp.float32 + + weight_true = weight > 0. 
+ + print( + f'jnnp.sum(conn): {jnp.sum(weight_true)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}') + + # CANNOT BE TESTED IN THIS WAY, BECAUSE UNIFORM JITCONN OP HAS BEEN OPTIMIZED + # compare with jitconn op + + # rng = bm.random.RandomState() + # vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + # + # r1 = bm.jitconn.mv_prob_normal(vector, + # w_mu=w_mu, + # w_sigma=w_sigma, + # conn_prob=prob, + # shape=shape, + # seed=SEED, + # outdim_parallel=outdim_parallel, + # transpose=transpose) + # + # r2 = vector @ weight if transpose else weight @ vector + # print(f'r1: {r1}\n r2: {r2}') + # self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + bm.clear_buffer_memory() diff --git a/brainpy/math/jitconn.py b/brainpy/math/jitconn.py index e1c4eafb..79a9bc87 100644 --- a/brainpy/math/jitconn.py +++ b/brainpy/math/jitconn.py @@ -8,5 +8,7 @@ mv_prob_normal as mv_prob_normal, get_conn_matrix as get_conn_matrix, + get_uniform_weight_matrix as get_uniform_weight_matrix, + get_normal_weight_matrix as get_normal_weight_matrix, ) From 452b2a3a9ad1bca6530f5102c717addd6d3f62d4 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 30 May 2024 12:40:45 +0800 Subject: [PATCH 05/11] Update linear.py --- brainpy/_src/dnn/linear.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 03398766..729fdc70 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -980,28 +980,28 @@ def __init__( self.sharding = sharding -class JitFPHomoLinear(Layer): +class JitFPHomoLayer(Layer): def get_conn_matrix(self): return bm.jitconn.get_conn_matrix(self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) -class JitFPUniformLinear(Layer): +class JitFPUniformLayer(Layer): def get_conn_matrix(self): return bm.jitconn.get_uniform_weight_matrix(self.w_low, self.w_high, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) -class JitFPNormalLinear(Layer): +class JitFPNormalLayer(Layer): def get_conn_matrix(self): return bm.jitconn.get_normal_weight_matrix(self.w_mu, self.w_sigma, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) -class JitFPHomoLinear(JitFPHomoLinear): +class JitFPHomoLinear(JitFPHomoLayer): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1080,7 +1080,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class JitFPUniformLinear(JitFPUniformLinear): +class JitFPUniformLinear(JitFPUniformLayer): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1160,7 +1160,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class JitFPNormalLinear(JitFPNormalLinear): +class JitFPNormalLinear(JitFPNormalLayer): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1240,7 +1240,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class EventJitFPHomoLinear(JitFPHomoLinear): +class EventJitFPHomoLinear(JitFPHomoLayer): r"""Synaptic matrix multiplication with the just-in-time connectivity. 
It performs the computation of: @@ -1319,7 +1319,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class EventJitFPUniformLinear(JitFPUniformLinear): +class EventJitFPUniformLinear(JitFPUniformLayer): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: @@ -1399,7 +1399,7 @@ def _batch_mv(self, x): outdim_parallel=not self.atomic) -class EventJitFPNormalLinear(JitFPNormalLinear): +class EventJitFPNormalLinear(JitFPNormalLayer): r"""Synaptic matrix multiplication with the just-in-time connectivity. It performs the computation of: From 8defda9e59391aa47fd607be19ca34c83f8f6bee Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sat, 1 Jun 2024 18:21:19 +0800 Subject: [PATCH 06/11] Update --- brainpy/_src/dnn/linear.py | 9 ++++++--- brainpy/_src/math/jitconn/matvec.py | 20 ++++++++++++------- ...nn_matrix.py => test_get_weight_matrix.py} | 7 ++++--- brainpy/math/jitconn.py | 2 +- 4 files changed, 24 insertions(+), 14 deletions(-) rename brainpy/_src/math/jitconn/tests/{test_get_conn_matrix.py => test_get_weight_matrix.py} (97%) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 729fdc70..d1dbd308 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -979,22 +979,25 @@ def __init__( self.conn = conn self.sharding = sharding +class JitLinear(Layer): + def get_conn_matrix(self): + pass -class JitFPHomoLayer(Layer): +class JitFPHomoLayer(JitLinear): def get_conn_matrix(self): return bm.jitconn.get_conn_matrix(self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) -class JitFPUniformLayer(Layer): +class JitFPUniformLayer(JitLinear): def get_conn_matrix(self): return bm.jitconn.get_uniform_weight_matrix(self.w_low, self.w_high, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) -class JitFPNormalLayer(Layer): +class JitFPNormalLayer(JitLinear): def get_conn_matrix(self): return bm.jitconn.get_normal_weight_matrix(self.w_mu, self.w_sigma, self.prob, self.seed, shape=(self.num_out, self.num_in), diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py index e2ae2322..08310506 100644 --- a/brainpy/_src/math/jitconn/matvec.py +++ b/brainpy/_src/math/jitconn/matvec.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- - - +import numbers from typing import Tuple, Optional, Union import jax @@ -9,6 +8,7 @@ from jax.interpreters import ad from brainpy._src.dependency_check import import_taichi +from brainpy._src.math.defaults import float_ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array, _get_dtype from brainpy._src.math.op_register import XLACustomOp @@ -20,7 +20,7 @@ 'mv_prob_homo', 'mv_prob_uniform', 'mv_prob_normal', - 'get_conn_matrix', + 'get_homo_weight_matrix', 'get_uniform_weight_matrix', 'get_normal_weight_matrix' ] @@ -260,7 +260,8 @@ def mv_prob_normal( transpose=transpose, outdim_parallel=outdim_parallel)[0] -def get_conn_matrix( +def get_homo_weight_matrix( + weight: float, conn_prob: float, seed: Optional[int] = None, *, @@ -290,6 +291,10 @@ def get_conn_matrix( out: Array, ndarray The connection matrix :math:`M`. 
""" + if isinstance(weight, numbers.Number): + weight = jnp.atleast_1d(jnp.asarray(weight, dtype=float_)) + else: + raise ValueError(f'weight must be a number type, but get {type(weight)}') if ti is None: raise PackageMissingError.by_purpose('taichi', purpose='customized operators') @@ -299,8 +304,9 @@ def get_conn_matrix( with jax.ensure_compile_time_eval(): seed = np.random.randint(0, int(1e8), 1) seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - r = raw_get_connect_matrix(conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_) + r = raw_get_homo_weight_matrix(conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_) + r *= weight if transpose: return r.transpose() else: @@ -501,7 +507,7 @@ def raw_mv_prob_normal( outdim_parallel=outdim_parallel) -def raw_get_connect_matrix( +def raw_get_homo_weight_matrix( conn_len: jax.Array, seed: jax.Array, *, diff --git a/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py b/brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py similarity index 97% rename from brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py rename to brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py index bf94065e..fca88d6d 100644 --- a/brainpy/_src/math/jitconn/tests/test_get_conn_matrix.py +++ b/brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py @@ -33,19 +33,20 @@ def __init__(self, *args, platform='cpu', **kwargs): prob=[0.1], ) def test_get_conn_matrix(self, transpose, outdim_parallel, shape, prob): + homo_data = 1. print( f'test_get_connect_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}') - conn = bm.jitconn.get_conn_matrix(prob, SEED, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + conn = bm.jitconn.get_homo_weight_matrix(homo_data, prob, SEED, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) shape = (shape[1], shape[0]) if transpose else shape print(conn.shape) assert conn.shape == shape - assert conn.dtype == jnp.bool_ + # assert conn.dtype == jnp.float_ # sum all true values print( f'jnp.sum(conn): {jnp.sum(conn)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}') # compare with jitconn op - homo_data = 1. 
+ rng = bm.random.RandomState() vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) diff --git a/brainpy/math/jitconn.py b/brainpy/math/jitconn.py index 79a9bc87..3c99b7de 100644 --- a/brainpy/math/jitconn.py +++ b/brainpy/math/jitconn.py @@ -7,7 +7,7 @@ mv_prob_uniform as mv_prob_uniform, mv_prob_normal as mv_prob_normal, - get_conn_matrix as get_conn_matrix, + get_homo_weight_matrix as get_homo_weight_matrix, get_uniform_weight_matrix as get_uniform_weight_matrix, get_normal_weight_matrix as get_normal_weight_matrix, ) From 2f5802910bd46214dcc3e3df935fee4d690ff024 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sat, 1 Jun 2024 19:06:11 +0800 Subject: [PATCH 07/11] Update linear.py --- brainpy/_src/dnn/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index d1dbd308..3d4d61e3 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -985,7 +985,7 @@ def get_conn_matrix(self): class JitFPHomoLayer(JitLinear): def get_conn_matrix(self): - return bm.jitconn.get_conn_matrix(self.prob, self.seed, + return bm.jitconn.get_uniform_weight_matrix(self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) From dc57c69a608371aa18afc31d356314a12d02db3b Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sat, 1 Jun 2024 19:16:07 +0800 Subject: [PATCH 08/11] Add test for get_conn_matrix` at `brainpy.dnn.linear` module --- brainpy/_src/dnn/linear.py | 25 +++++++++++++++---------- brainpy/_src/dnn/tests/test_linear.py | 19 +++++++++++++++++++ 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 3d4d61e3..1fdbf43c 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -979,30 +979,35 @@ def __init__( self.conn = conn self.sharding = sharding + class JitLinear(Layer): def get_conn_matrix(self): pass + class JitFPHomoLayer(JitLinear): def get_conn_matrix(self): - return bm.jitconn.get_uniform_weight_matrix(self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) + return bm.jitconn.get_uniform_weight_matrix(self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + class JitFPUniformLayer(JitLinear): def get_conn_matrix(self): return bm.jitconn.get_uniform_weight_matrix(self.w_low, self.w_high, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + class JitFPNormalLayer(JitLinear): def get_conn_matrix(self): return bm.jitconn.get_normal_weight_matrix(self.w_mu, self.w_sigma, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + class JitFPHomoLinear(JitFPHomoLayer): r"""Synaptic matrix multiplication with the just-in-time connectivity. 
diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index 6cc44538..43defbfa 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -141,6 +141,10 @@ def test_JitFPHomoLinear(self, prob, weight, shape): x = bm.random.random(shape + (100,)) y = f(x) self.assertTrue(y.shape == shape + (200,)) + + conn_matrix = f.get_conn_matrix() + print(conn_matrix.shape) + self.assertTrue(conn_matrix.shape == (200, 100)) bm.clear_buffer_memory() @parameterized.product( @@ -155,6 +159,9 @@ def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): x = bm.random.random(shape + (100,)) y = f(x) self.assertTrue(y.shape == shape + (200,)) + + conn_matrix = f.get_conn_matrix() + self.assertTrue(conn_matrix.shape == (200, 100)) bm.clear_buffer_memory() @parameterized.product( @@ -169,6 +176,9 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): x = bm.random.random(shape + (100,)) y = f(x) self.assertTrue(y.shape == shape + (200,)) + + conn_matrix = f.get_conn_matrix() + self.assertTrue(conn_matrix.shape == (200, 100)) bm.clear_buffer_memory() @parameterized.product( @@ -184,6 +194,9 @@ def test_EventJitFPHomoLinear(self, prob, weight, shape): y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) self.assertTrue(y2.shape == shape + (200,)) + + conn_matrix = f.get_conn_matrix() + self.assertTrue(conn_matrix.shape == (200, 100)) bm.clear_buffer_memory() @parameterized.product( @@ -200,6 +213,9 @@ def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) self.assertTrue(y2.shape == shape + (200,)) + + conn_matrix = f.get_conn_matrix() + self.assertTrue(conn_matrix.shape == (200, 100)) bm.clear_buffer_memory() @parameterized.product( @@ -216,6 +232,9 @@ def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) self.assertTrue(y2.shape == shape + (200,)) + + conn_matrix = f.get_conn_matrix() + self.assertTrue(conn_matrix.shape == (200, 100)) bm.clear_buffer_memory() From 939f03024c60c349f24df413fdc0b1ffee022ac1 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sat, 1 Jun 2024 22:06:13 +0800 Subject: [PATCH 09/11] Update linear.py --- brainpy/_src/dnn/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 1fdbf43c..331efaed 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -987,7 +987,7 @@ def get_conn_matrix(self): class JitFPHomoLayer(JitLinear): def get_conn_matrix(self): - return bm.jitconn.get_uniform_weight_matrix(self.weight, self.prob, self.seed, + return bm.jitconn.get_homo_weight_matrix(self.weight, self.prob, self.seed, shape=(self.num_out, self.num_in), transpose=self.transpose, outdim_parallel=not self.atomic) From 3908afe480bc536281e59bd638f4cc169931e4c0 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sat, 1 Jun 2024 22:13:54 +0800 Subject: [PATCH 10/11] Update --- brainpy/_src/dnn/linear.py | 6 +++--- brainpy/_src/math/jitconn/matvec.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 331efaed..8e09f95b 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -988,9 +988,9 @@ def get_conn_matrix(self): class JitFPHomoLayer(JitLinear): def get_conn_matrix(self): return 
bm.jitconn.get_homo_weight_matrix(self.weight, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) class JitFPUniformLayer(JitLinear): diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py index 08310506..b9ed789c 100644 --- a/brainpy/_src/math/jitconn/matvec.py +++ b/brainpy/_src/math/jitconn/matvec.py @@ -8,11 +8,11 @@ from jax.interpreters import ad from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.defaults import float_ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array, _get_dtype from brainpy._src.math.op_register import XLACustomOp from brainpy.errors import PackageMissingError +from brainpy._src.math import defauts ti = import_taichi(error_if_not_found=False) @@ -292,7 +292,7 @@ def get_homo_weight_matrix( The connection matrix :math:`M`. """ if isinstance(weight, numbers.Number): - weight = jnp.atleast_1d(jnp.asarray(weight, dtype=float_)) + weight = jnp.atleast_1d(jnp.asarray(weight, dtype=defauts.float_)) else: raise ValueError(f'weight must be a number type, but get {type(weight)}') if ti is None: @@ -1221,7 +1221,7 @@ def _get_uniform_weight_matrix( key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) while i_row < num_row: key, raw_v = lfsr88_uniform(key, w_low0, w_high0) - out[i_row, i_col] += raw_v + out[i_row, i_col] = raw_v key, inc = lfsr88_random_integers(key, 1, clen0) i_row += inc @@ -1246,7 +1246,7 @@ def _get_uniform_weight_matrix_outdim_parallel( key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) while i_col < num_col: key, raw_v = lfsr88_uniform(key, w_low0, w_high0) - out[i_row, i_col] += raw_v + out[i_row, i_col] = raw_v key, inc = lfsr88_random_integers(key, 1, clen0) i_col += inc @@ -1277,7 +1277,7 @@ def _get_normal_weight_matrix( key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) while i_row < num_row: key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row, i_col] += raw_v + out[i_row, i_col] = raw_v key, inc = lfsr88_random_integers(key, 1, clen0) i_row += inc @@ -1302,7 +1302,7 @@ def _get_normal_weight_matrix_outdim_parallel( key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) while i_col < num_col: key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row, i_col] += raw_v + out[i_row, i_col] = raw_v key, inc = lfsr88_random_integers(key, 1, clen0) i_col += inc From fae9b2acf6e606d6079803934fcafa88d49382a1 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sat, 1 Jun 2024 22:45:38 +0800 Subject: [PATCH 11/11] Fix bugs --- brainpy/_src/dnn/tests/test_linear.py | 24 +++--- brainpy/_src/math/jitconn/matvec.py | 9 +- .../jitconn/tests/test_get_weight_matrix.py | 85 ++++++++++--------- 3 files changed, 63 insertions(+), 55 deletions(-) diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index 43defbfa..9f011cb8 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -143,8 +143,9 @@ def test_JitFPHomoLinear(self, prob, weight, shape): self.assertTrue(y.shape == shape + (200,)) conn_matrix = f.get_conn_matrix() - print(conn_matrix.shape) - self.assertTrue(conn_matrix.shape == (200, 100)) + self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) + # print(conn_matrix.shape) + # self.assertTrue(conn_matrix.shape == (200, 100)) bm.clear_buffer_memory() @parameterized.product( 
@@ -161,7 +162,7 @@ def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): self.assertTrue(y.shape == shape + (200,)) conn_matrix = f.get_conn_matrix() - self.assertTrue(conn_matrix.shape == (200, 100)) + self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) bm.clear_buffer_memory() @parameterized.product( @@ -178,7 +179,7 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): self.assertTrue(y.shape == shape + (200,)) conn_matrix = f.get_conn_matrix() - self.assertTrue(conn_matrix.shape == (200, 100)) + self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) bm.clear_buffer_memory() @parameterized.product( @@ -189,14 +190,15 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): def test_EventJitFPHomoLinear(self, prob, weight, shape): bm.random.seed() f = bp.dnn.EventJitFPHomoLinear(100, 200, prob, weight, seed=123) - y = f(bm.random.random(shape + (100,)) < 0.1) + x = bm.random.random(shape + (100,)) < 0.1 + y = f(x) self.assertTrue(y.shape == shape + (200,)) y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) self.assertTrue(y2.shape == shape + (200,)) conn_matrix = f.get_conn_matrix() - self.assertTrue(conn_matrix.shape == (200, 100)) + self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) bm.clear_buffer_memory() @parameterized.product( @@ -208,14 +210,15 @@ def test_EventJitFPHomoLinear(self, prob, weight, shape): def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): bm.random.seed() f = bp.dnn.EventJitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) - y = f(bm.random.random(shape + (100,)) < 0.1) + x = bm.random.random(shape + (100,)) < 0.1 + y = f(x) self.assertTrue(y.shape == shape + (200,)) y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) self.assertTrue(y2.shape == shape + (200,)) conn_matrix = f.get_conn_matrix() - self.assertTrue(conn_matrix.shape == (200, 100)) + self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) bm.clear_buffer_memory() @parameterized.product( @@ -227,14 +230,15 @@ def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): bm.random.seed() f = bp.dnn.EventJitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) - y = f(bm.random.random(shape + (100,)) < 0.1) + x = bm.random.random(shape + (100,)) < 0.1 + y = f(x) self.assertTrue(y.shape == shape + (200,)) y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) self.assertTrue(y2.shape == shape + (200,)) conn_matrix = f.get_conn_matrix() - self.assertTrue(conn_matrix.shape == (200, 100)) + self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) bm.clear_buffer_memory() diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py index b9ed789c..296a7994 100644 --- a/brainpy/_src/math/jitconn/matvec.py +++ b/brainpy/_src/math/jitconn/matvec.py @@ -4,15 +4,14 @@ import jax import numpy as np -from jax import numpy as jnp -from jax.interpreters import ad - from brainpy._src.dependency_check import import_taichi +from brainpy._src.math import defaults from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array, _get_dtype from brainpy._src.math.op_register import XLACustomOp from brainpy.errors import PackageMissingError -from brainpy._src.math import defauts +from jax import numpy as jnp +from jax.interpreters import ad ti = import_taichi(error_if_not_found=False) @@ -292,7 +291,7 @@ def get_homo_weight_matrix( The connection matrix :math:`M`. 
""" if isinstance(weight, numbers.Number): - weight = jnp.atleast_1d(jnp.asarray(weight, dtype=defauts.float_)) + weight = jnp.atleast_1d(jnp.asarray(weight, dtype=defaults.float_)) else: raise ValueError(f'weight must be a number type, but get {type(weight)}') if ti is None: diff --git a/brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py b/brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py index fca88d6d..9f10505a 100644 --- a/brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py +++ b/brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py @@ -15,14 +15,17 @@ # if platform.system() == 'Windows' and not force_test: # pytest.skip('skip windows', allow_module_level=True) -shapes = [(100, 200), (1000, 10)] +shapes = [ + (2, 2), + # (1000, 10) +] SEED = 1234 -class TestGetConnectMatrix(parameterized.TestCase): +class TestGetHomoWeightMatrix(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): - super(TestGetConnectMatrix, self).__init__(*args, **kwargs) + super(TestGetHomoWeightMatrix, self).__init__(*args, **kwargs) bm.set_platform(platform) print() @@ -32,10 +35,10 @@ def __init__(self, *args, platform='cpu', **kwargs): shape=shapes, prob=[0.1], ) - def test_get_conn_matrix(self, transpose, outdim_parallel, shape, prob): + def test_get_homo_weight_matrix(self, transpose, outdim_parallel, shape, prob): homo_data = 1. print( - f'test_get_connect_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}') + f'test_get_homo_weight_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}') conn = bm.jitconn.get_homo_weight_matrix(homo_data, prob, SEED, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) shape = (shape[1], shape[0]) if transpose else shape print(conn.shape) @@ -47,6 +50,7 @@ def test_get_conn_matrix(self, transpose, outdim_parallel, shape, prob): # compare with jitconn op + print(f'conn: {conn}') rng = bm.random.RandomState() vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) @@ -74,14 +78,14 @@ def __init__(self, *args, platform='cpu', **kwargs): transpose=[True, False], outdim_parallel=[True, False], shape=shapes, - prob=[0.1], + prob=[0.5], w_low=[0.1], w_high=[0.9], ) def test_get_uniform_weight_matrix(self, transpose, outdim_parallel, shape, prob, w_low, w_high): print( f'test_get_uniform_weight_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}, w_low={w_low}, w_high={w_high}') - weight = bm.jitconn.get_uniform_weight_matrix(w_low, w_high, prob, shape=shape, transpose=transpose, + weight = bm.jitconn.get_uniform_weight_matrix(w_low, w_high, prob, SEED, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) shape = (shape[1], shape[0]) if transpose else shape assert weight.shape == shape @@ -92,24 +96,26 @@ def test_get_uniform_weight_matrix(self, transpose, outdim_parallel, shape, prob print( f'jnp.sum(conn): {jnp.sum(weight_true)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}') - # CANNOT BE TESTED IN THIS WAY, BECAUSE UNIFORM JITCONN OP HAS BEEN OPTIMIZED # compare with jitconn op - # rng = bm.random.RandomState() - # events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - # - # r1 = bm.jitconn.mv_prob_uniform(events, - # w_low=w_low, - # w_high=w_high, - # conn_prob=prob, - # shape=shape, - # seed=SEED, - # outdim_parallel=outdim_parallel, - # transpose=transpose) - # - # r2 = events @ weight if transpose else weight @ events 
- # print(f'r1: {r1}\n r2: {r2}') - # self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + print(f'weight: {weight}') + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = bm.jitconn.mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=SEED, + outdim_parallel=outdim_parallel, + transpose=transpose) + + # r2 = weight @ events if transpose else events @ weight + r2 = events @ weight if transpose else weight @ events + print(f'r1: {r1}\n r2: {r2}') + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) bm.clear_buffer_memory() @@ -131,7 +137,7 @@ def __init__(self, *args, platform='cpu', **kwargs): def test_get_normal_weight_matrix(self, transpose, outdim_parallel, shape, prob, w_mu, w_sigma): print( f'test_get_normal_weight_matrix: transpose={transpose}, outdim_parallel={outdim_parallel}, shape={shape}, prob={prob}, w_mu={w_mu}, w_sigma={w_sigma}') - weight = bm.jitconn.get_normal_weight_matrix(w_mu, w_sigma, prob, shape=shape, transpose=transpose, + weight = bm.jitconn.get_normal_weight_matrix(w_mu, w_sigma, prob, SEED, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) shape = (shape[1], shape[0]) if transpose else shape assert weight.shape == shape @@ -142,23 +148,22 @@ def test_get_normal_weight_matrix(self, transpose, outdim_parallel, shape, prob, print( f'jnnp.sum(conn): {jnp.sum(weight_true)}, jnp.round(prob * shape[0] * shape[1]): {jnp.round(prob * shape[0] * shape[1])}') - # CANNOT BE TESTED IN THIS WAY, BECAUSE UNIFORM JITCONN OP HAS BEEN OPTIMIZED # compare with jitconn op - # rng = bm.random.RandomState() - # vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - # - # r1 = bm.jitconn.mv_prob_normal(vector, - # w_mu=w_mu, - # w_sigma=w_sigma, - # conn_prob=prob, - # shape=shape, - # seed=SEED, - # outdim_parallel=outdim_parallel, - # transpose=transpose) - # - # r2 = vector @ weight if transpose else weight @ vector - # print(f'r1: {r1}\n r2: {r2}') - # self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + rng = bm.random.RandomState() + vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = bm.jitconn.mv_prob_normal(vector, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=SEED, + outdim_parallel=outdim_parallel, + transpose=transpose) + + r2 = vector @ weight if transpose else weight @ vector + print(f'r1: {r1}\n r2: {r2}') + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) bm.clear_buffer_memory()
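
A minimal usage sketch of the API introduced by this patch series, distilled from the tests added in test_linear.py and test_get_weight_matrix.py. The layer sizes, probability, weight values, and seeds below are illustrative only, not taken from the patches:

    import brainpy as bp
    import brainpy.math as bm

    # Build a JIT-connectivity layer and recover the dense weight matrix it realizes.
    layer = bp.dnn.JitFPHomoLinear(100, 200, 0.1, 0.5, seed=123)
    x = bm.random.random((16, 100))
    y = layer(x)

    # get_conn_matrix() returns the (num_out, num_in) matrix, so y == x @ W.T,
    # as asserted in the new test_linear.py tests.
    W = layer.get_conn_matrix()
    assert bm.allclose(y, x @ W.T)

    # The functional API added to brainpy.math.jitconn can also be called directly
    # (weight parameters first, then conn_prob and seed, with shape keyword-only).
    W_homo = bm.jitconn.get_homo_weight_matrix(0.5, 0.1, 1234, shape=(200, 100))
    W_uniform = bm.jitconn.get_uniform_weight_matrix(0.1, 0.9, 0.1, 1234, shape=(200, 100))
    W_normal = bm.jitconn.get_normal_weight_matrix(0.0, 1.0, 0.1, 1234, shape=(200, 100))

With the same seed, probability, and weight, the functional form stays consistent with the matrix-vector operators: for a vector v of length 100, bm.jitconn.mv_prob_homo(v, 0.5, conn_prob=0.1, shape=(200, 100), seed=1234) should match W_homo @ v (or v @ W_homo when transpose=True), which is exactly what the new tests assert.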