diff --git a/brian2/codegen/runtime/cython_rt/templates/ratemonitor.pyx b/brian2/codegen/runtime/cython_rt/templates/ratemonitor.pyx index a96582d04..7b0cfab8d 100644 --- a/brian2/codegen/runtime/cython_rt/templates/ratemonitor.pyx +++ b/brian2/codegen/runtime/cython_rt/templates/ratemonitor.pyx @@ -1,31 +1,44 @@ -{# USES_VARIABLES { N, t, rate, _clock_t, _clock_dt, _spikespace, - _num_source_neurons, _source_start, _source_stop } #} +{# USES_VARIABLES { N, t, rate, _clock_t, _clock_dt, _spikespace} #} {% extends 'common_group.pyx' %} {% block maincode %} - cdef size_t _num_spikes = {{_spikespace}}[_num{{_spikespace}}-1] - - # For subgroups, we do not want to record all spikes + + {% if subgroup %} + cdef int32_t _filtered_spikes = 0 + cdef size_t _source_index_counter = 0 + {% if contiguous %}{# contiguous subgroup #} # We assume that spikes are ordered - cdef int _start_idx = -1 - cdef int _end_idx = -1 - cdef size_t _j + _start_idx = _num_spikes + _end_idx = _num_spikes for _j in range(_num_spikes): _idx = {{_spikespace}}[_j] if _idx >= _source_start: _start_idx = _j break - if _start_idx == -1: - _start_idx = _num_spikes - for _j in range(_start_idx, _num_spikes): + for _j in range(_num_spikes-1, _start_idx-1, -1): _idx = {{_spikespace}}[_j] - if _idx >= _source_stop: - _end_idx = _j + if _idx < _source_stop: + break + _end_idx = _j + _filtered_spikes = _end_idx - _start_idx + {% else %}{# non-contiguous subgroup #} + for _j in range(_num_spikes): + _idx = {{_spikespace}}[_j] + if _idx < {{_source_indices}}[_source_index_counter]: + continue + while {{_source_indices}}[_source_index_counter] < _idx: + _source_index_counter += 1 + if (_source_index_counter < {{source_N}} and + _idx == {{_source_indices}}[_source_index_counter]): + _source_index_counter += 1 + _filtered_spikes += 1 + + if _source_index_counter == {{source_N}}: break - if _end_idx == -1: - _end_idx =_num_spikes - _num_spikes = _end_idx - _start_idx + {% endif %} + _num_spikes = _filtered_spikes + {% endif %} # Calculate the new length for the arrays cdef size_t _new_len = {{_dynamic_t}}.shape[0] + 1 @@ -36,6 +49,6 @@ # Set the new values {{_dynamic_t}}.data[_new_len-1] = {{_clock_t}} - {{_dynamic_rate}}.data[_new_len-1] = _num_spikes/{{_clock_dt}}/_num_source_neurons + {{_dynamic_rate}}.data[_new_len-1] = _num_spikes/{{_clock_dt}}/{{source_N}} {% endblock %} diff --git a/brian2/codegen/runtime/cython_rt/templates/spikemonitor.pyx b/brian2/codegen/runtime/cython_rt/templates/spikemonitor.pyx index b5abb20e0..1017a671f 100644 --- a/brian2/codegen/runtime/cython_rt/templates/spikemonitor.pyx +++ b/brian2/codegen/runtime/cython_rt/templates/spikemonitor.pyx @@ -1,5 +1,4 @@ -{# USES_VARIABLES { N, _clock_t, count, - _source_start, _source_stop} #} +{# USES_VARIABLES { N, _clock_t, count} #} {% extends 'common_group.pyx' %} {% block maincode %} @@ -9,11 +8,20 @@ cdef size_t _num_events = {{_eventspace}}[_num{{_eventspace}}-1] cdef size_t _start_idx, _end_idx, _curlen, _newlen, _j + {% if subgroup and not contiguous %} + # We use the same data structure as for the eventspace to store the + # "filtered" events, i.e. the events that are indexed in the subgroup + cdef int[{{source_N}} + 1] _filtered_events + cdef size_t _source_index_counter = 0 + _filtered_events[{{source_N}}] = 0 + {% endif %} {% for varname, var in record_variables | dictsort %} cdef {{cpp_dtype(var.dtype)}}[:] _{{varname}}_view {% endfor %} if _num_events > 0: + {% if subgroup %} # For subgroups, we do not want to record all spikes + {% if contiguous %} # We assume that spikes are ordered _start_idx = _num_events _end_idx = _num_events @@ -28,6 +36,25 @@ break _end_idx = _j _num_events = _end_idx - _start_idx + {% else %} + for _j in range(_num_events): + _idx = {{_eventspace}}[_j] + if _idx < {{_source_indices}}[_source_index_counter]: + continue + if _idx > {{_source_indices}}[{{source_N}}-1]: + break + while {{_source_indices}}[_source_index_counter] < _idx: + _source_index_counter += 1 + if (_source_index_counter < {{source_N}} and + _idx == {{_source_indices}}[_source_index_counter]): + _source_index_counter += 1 + _filtered_events[_filtered_events[{{source_N}}]] = _idx + _filtered_events[{{source_N}}] += 1 + if _source_index_counter == {{source_N}}: + break + _num_events = _filtered_events[{{source_N}}] + {% endif %} + {% endif %} if _num_events > 0: # scalar code _vectorisation_idx = 1 @@ -41,6 +68,8 @@ _{{varname}}_view = {{get_array_name(var, access_data=False)}}.data {% endfor %} # Copy the values across + {% if subgroup %} + {% if contiguous %} for _j in range(_start_idx, _end_idx): _idx = {{_eventspace}}[_j] _vectorisation_idx = _idx @@ -49,4 +78,24 @@ _{{varname}}_view [_curlen + _j - _start_idx] = _to_record_{{varname}} {% endfor %} {{count}}[_idx - _source_start] += 1 + {% else %} + for _j in range(_num_events): + _idx = _filtered_events[_j] + _vectorisation_idx = _idx + {{ vector_code|autoindent }} + {% for varname in record_variables | sort %} + _{{varname}}_view [_curlen + _j] = _to_record_{{varname}} + {% endfor %} + {{count}}[_to_record_i] += 1 + {% endif %} + {% else %} + for _j in range(_num_events): + _idx = {{_eventspace}}[_j] + _vectorisation_idx = _idx + {{ vector_code|autoindent }} + {% for varname in record_variables | sort %} + _{{varname}}_view [_curlen + _j] = _to_record_{{varname}} + {% endfor %} + {{count}}[_idx] += 1 + {% endif %} {% endblock %} diff --git a/brian2/codegen/runtime/cython_rt/templates/summed_variable.pyx b/brian2/codegen/runtime/cython_rt/templates/summed_variable.pyx index 4d2b06827..0911b14a1 100644 --- a/brian2/codegen/runtime/cython_rt/templates/summed_variable.pyx +++ b/brian2/codegen/runtime/cython_rt/templates/summed_variable.pyx @@ -8,7 +8,11 @@ # Set all the target variable values to zero for _target_idx in range({{_target_size_name}}): + {% if _target_contiguous %} {{_target_var_array}}[_target_idx + {{_target_start}}] = 0 + {% else %} + {{_target_var_array}}[{{_target_indices}}[_target_idx]] = 0 + {% endif %} # scalar code _vectorisation_idx = 1 diff --git a/brian2/codegen/runtime/cython_rt/templates/synapses_create_generator.pyx b/brian2/codegen/runtime/cython_rt/templates/synapses_create_generator.pyx index eb9200616..ce313e090 100644 --- a/brian2/codegen/runtime/cython_rt/templates/synapses_create_generator.pyx +++ b/brian2/codegen/runtime/cython_rt/templates/synapses_create_generator.pyx @@ -68,8 +68,11 @@ cdef void _flush_buffer(buf, dynarr, int buf_len): {{scalar_code['update']|autoindent}} for _{{outer_index}} in range({{outer_index_size}}): + {% if outer_contiguous %} _raw{{outer_index_array}} = _{{outer_index}} + {{outer_index_offset}} - + {% else %} + _raw{{outer_index_array}} = {{get_array_name(variables[outer_sub_idx])}}[_{{outer_index}}] + {% endif %} {% if not result_index_condition %} {{vector_code['create_cond']|autoindent}} if not _cond: @@ -162,8 +165,11 @@ cdef void _flush_buffer(buf, dynarr, int buf_len): {% endif %} {{vector_code['generator_expr']|autoindent}} + {% if result_contiguous %} _raw{{result_index_array}} = _{{result_index}} + {{result_index_offset}} - + {% else %} + _raw{{result_index_array}} = {{get_array_name(variables[result_sub_idx])}}[_{{result_index}}] + {% endif %} {% if result_index_condition %} {% if result_index_used %} {# The condition could index outside of array range #} diff --git a/brian2/codegen/runtime/numpy_rt/numpy_rt.py b/brian2/codegen/runtime/numpy_rt/numpy_rt.py index d56e8b855..e94211ef8 100644 --- a/brian2/codegen/runtime/numpy_rt/numpy_rt.py +++ b/brian2/codegen/runtime/numpy_rt/numpy_rt.py @@ -127,7 +127,7 @@ def __iter__(self): return iter(self.indices) # Allow conversion to a proper array with np.array(...) - def __array__(self, dtype=None, copy=None): + def __array__(self, dtype=np.int32, copy=None): if copy is False: raise ValueError("LazyArange does not support copy=False") if self.indices is None: diff --git a/brian2/codegen/runtime/numpy_rt/templates/ratemonitor.py_ b/brian2/codegen/runtime/numpy_rt/templates/ratemonitor.py_ index 46da9356f..aa4bb1149 100644 --- a/brian2/codegen/runtime/numpy_rt/templates/ratemonitor.py_ +++ b/brian2/codegen/runtime/numpy_rt/templates/ratemonitor.py_ @@ -1,14 +1,19 @@ -{# USES_VARIABLES { rate, t, _spikespace, _num_source_neurons, - _clock_t, _clock_dt, _source_start, _source_stop, N } #} +{# USES_VARIABLES { rate, t, _spikespace, _clock_t, _clock_dt, N } #} {% extends 'common_group.py_' %} {% block maincode %} _spikes = {{_spikespace}}[:{{_spikespace}}[-1]] +{% if subgroup %} # Take subgroups into account +{% if contiguous %} _spikes = _spikes[(_spikes >= _source_start) & (_spikes < _source_stop)] +{% else %} +_spikes = _numpy.intersect1d(_spikes, {{_source_indices}}, assume_unique=True) +{% endif %} +{% endif %} _new_len = {{N}} + 1 _owner.resize(_new_len) {{N}} = _new_len {{_dynamic_t}}[-1] = {{_clock_t}} -{{_dynamic_rate}}[-1] = 1.0 * len(_spikes) / {{_clock_dt}} / _num_source_neurons +{{_dynamic_rate}}[-1] = 1.0 * len(_spikes) / {{_clock_dt}} / {{source_N}} {% endblock %} \ No newline at end of file diff --git a/brian2/codegen/runtime/numpy_rt/templates/spikemonitor.py_ b/brian2/codegen/runtime/numpy_rt/templates/spikemonitor.py_ index 639328f41..60274dcf1 100644 --- a/brian2/codegen/runtime/numpy_rt/templates/spikemonitor.py_ +++ b/brian2/codegen/runtime/numpy_rt/templates/spikemonitor.py_ @@ -1,4 +1,4 @@ -{# USES_VARIABLES {N, count, _clock_t, _source_start, _source_stop, _source_N} #} +{# USES_VARIABLES {N, count, _clock_t} #} {% extends 'common_group.py_' %} {% block maincode %} @@ -9,11 +9,15 @@ _n_events = {{_eventspace}}[-1] if _n_events > 0: _events = {{_eventspace}}[:_n_events] + {% if subgroup %} # Take subgroups into account - if _source_start != 0 or _source_stop != _source_N: - _events = _events[(_events >= _source_start) & (_events < _source_stop)] - _n_events = len(_events) - + {% if contiguous %} + _events = _events[(_events >= _source_start) & (_events < _source_stop)] + {% else %} + _events = _numpy.intersect1d(_events, {{_source_indices}}, assume_unique=True) + {% endif %} + _n_events = len(_events) + {% endif %} if _n_events > 0: _vectorisation_idx = 1 {{scalar_code|autoindent}} @@ -28,5 +32,13 @@ if _n_events > 0: {% set dynamic_varname = get_array_name(var, access_data=False) %} {{dynamic_varname}}[_curlen:_newlen] = _to_record_{{varname}} {% endfor %} + {% if not subgroup %} + {{count}}[_events] += 1 + {% else %} + {% if contiguous %} {{count}}[_events - _source_start] += 1 + {% else %} + {{count}}[_to_record_i] += 1 + {% endif %} + {% endif %} {% endblock %} diff --git a/brian2/codegen/runtime/numpy_rt/templates/summed_variable.py_ b/brian2/codegen/runtime/numpy_rt/templates/summed_variable.py_ index b925f66f3..734497a39 100644 --- a/brian2/codegen/runtime/numpy_rt/templates/summed_variable.py_ +++ b/brian2/codegen/runtime/numpy_rt/templates/summed_variable.py_ @@ -17,6 +17,9 @@ _vectorisation_idx = LazyArange(N) # We write to the array, using the name provided as a keyword argument to the # template # Note that for subgroups, we do not want to overwrite the full array +{% if not _target_contiguous %} +{{_target_var_array}}[{{_target_indices}}] = _numpy.broadcast_to(_synaptic_var, (N, )) +{% else %} {% if _target_start > 0 %} _indices = {{_index_array}} - {{_target_start}} {% else %} @@ -32,4 +35,5 @@ _length = _target_stop - {{_target_start}} {{_target_var_array}}[{{_target_start}}:_target_stop] = _numpy.bincount(_indices, minlength=_length, weights=_numpy.broadcast_to(_synaptic_var, (N, ))) +{% endif %} {% endblock %} diff --git a/brian2/codegen/runtime/numpy_rt/templates/synapses_create_generator.py_ b/brian2/codegen/runtime/numpy_rt/templates/synapses_create_generator.py_ index 0e201765c..86d1a1cd6 100644 --- a/brian2/codegen/runtime/numpy_rt/templates/synapses_create_generator.py_ +++ b/brian2/codegen/runtime/numpy_rt/templates/synapses_create_generator.py_ @@ -106,7 +106,11 @@ _vectorisation_idx = 1 "k" is called the inner variable #} for _{{outer_index}} in range({{outer_index_size}}): + {% if outer_contiguous %} _raw{{outer_index_array}} = _{{outer_index}} + {{outer_index_offset}} + {% else %} + _raw{{outer_index_array}} = {{get_array_name(variables[outer_sub_idx])}}[_{{outer_index}}] + {% endif %} {% if not result_index_condition %} {{vector_code['create_cond']|autoindent}} if not _cond: @@ -126,7 +130,11 @@ for _{{outer_index}} in range({{outer_index_size}}): _vectorisation_idx = {{inner_variable}} {{vector_code['generator_expr']|autoindent}} _vectorisation_idx = _{{result_index}} - _raw{{result_index_array}} = _{{result_index}} + {{result_index_offset}} + {% if result_contiguous %} + _raw{{result_index_array}} = _{{result_index}} + {{result_index_offset}}; + {% else %} + _raw{{result_index_array}} = {{get_array_name(variables[result_sub_idx])}}[_{{result_index}}] + {% endif %} {% if result_index_condition %} {% if result_index_used %} {# The condition could index outside of array range #} @@ -184,7 +192,11 @@ for _{{outer_index}} in range({{outer_index_size}}): {% endif %} _vectorisation_idx = _{{result_index}} + {% if result_contiguous %} _raw{{result_index_array}} = _{{result_index}} + {{result_index_offset}} + {% else %} + _raw{{result_index_array}} = {{get_array_name(variables[result_sub_idx])}}[_{{result_index}}] + {% endif %} {{vector_code['update']|autoindent}} if not _numpy.isscalar(_n): diff --git a/brian2/core/network.py b/brian2/core/network.py index 4f2a44424..dfaee8d2f 100644 --- a/brian2/core/network.py +++ b/brian2/core/network.py @@ -16,6 +16,8 @@ from collections import Counter, defaultdict, namedtuple from collections.abc import Mapping, Sequence +import numpy as np + from brian2.core.base import BrianObject, BrianObjectException from brian2.core.clocks import Clock, defaultclock from brian2.core.names import Nameable @@ -321,18 +323,26 @@ def _check_multiple_summed_updaters(objects): "the target group instead." ) raise NotImplementedError(msg) - elif ( - obj.target.start < other_target.stop - and other_target.start < obj.target.stop - ): - # Overlapping subgroups - msg = ( - "Multiple 'summed variables' target the " - f"variable '{obj.target_var.name}' in overlapping " - f"groups '{other_target.name}' and '{obj.target.name}'. " - "Use separate variables in the target groups instead." - ) - raise NotImplementedError(msg) + else: + if getattr(obj.target, "contiguous", True): + target_indices = np.arange(obj.target.start, obj.target.stop) + else: + target_indices = obj.target.indices[:] + if getattr(other_target, "contiguous", True): + other_indices = np.arange(other_target.start, other_target.stop) + else: + other_indices = other_target.indices[:] + if np.intersect1d( + target_indices, other_indices, assume_unique=True + ).size: + # Overlapping subgroups + msg = ( + "Multiple 'summed variables' target the " + f"variable '{obj.target_var.name}' in overlapping " + f"groups '{other_target.name}' and '{obj.target.name}'. " + "Use separate variables in the target groups instead." + ) + raise NotImplementedError(msg) summed_targets[obj.target_var] = obj.target diff --git a/brian2/core/variables.py b/brian2/core/variables.py index f5e5e7f91..ec26494e6 100644 --- a/brian2/core/variables.py +++ b/brian2/core/variables.py @@ -196,7 +196,7 @@ def __setstate__(self, state): @property def is_boolean(self): - return np.issubdtype(self.dtype, np.bool_) + return np.issubdtype(self.dtype, bool) @property def is_integer(self): diff --git a/brian2/devices/cpp_standalone/codeobject.py b/brian2/devices/cpp_standalone/codeobject.py index 50fdec9cb..eae731411 100644 --- a/brian2/devices/cpp_standalone/codeobject.py +++ b/brian2/devices/cpp_standalone/codeobject.py @@ -108,8 +108,12 @@ def __init__(self, *args, **kwds): super().__init__(*args, **kwds) #: Store whether this code object defines before/after blocks self.before_after_blocks = [] + #: Store variables that are updated by this code object and therefore invalidate the cache + self.invalidate_cache_variables = set() def __call__(self, **kwds): + for var in self.invalidate_cache_variables: + get_device().array_cache[var] = None return self.run() def compile_block(self, block): diff --git a/brian2/devices/cpp_standalone/device.py b/brian2/devices/cpp_standalone/device.py index 0ceed3c15..e4e6281ff 100644 --- a/brian2/devices/cpp_standalone/device.py +++ b/brian2/devices/cpp_standalone/device.py @@ -19,9 +19,10 @@ import numpy as np -from brian2.codegen.codeobject import check_compiler_kwds +from brian2.codegen.codeobject import check_compiler_kwds, create_runner_codeobj from brian2.codegen.cpp_prefs import get_compiler_and_args, get_msvc_env from brian2.codegen.generators.cpp_generator import c_data_type +from brian2.codegen.runtime.numpy_rt import NumpyCodeObject from brian2.core.functions import Function from brian2.core.namespace import get_local_namespace from brian2.core.preferences import BrianPreference, prefs @@ -30,6 +31,7 @@ Constant, DynamicArrayVariable, Variable, + Variables, VariableView, ) from brian2.devices.device import Device, all_devices, reset_device, set_device @@ -480,8 +482,8 @@ def resize(self, var, new_size): self.main_queue.append(("resize_array", (array_name, new_size))) def variableview_set_with_index_array(self, variableview, item, value, check_units): - if isinstance(item, slice) and item == slice(None): - item = "True" + if item == "True": + item = slice(None) value = Quantity(value) if ( @@ -496,7 +498,7 @@ def variableview_set_with_index_array(self, variableview, item, value, check_uni ("set_by_single_value", (array_name, item, value_str)) ) # Simple case where we don't have to do any indexing - elif item == "True" and variableview.index_var in ("_idx", "0"): + elif item == slice(None) and variableview.index_var in ("_idx", "0"): self.fill_with_array(variableview.variable, value) else: # We have to calculate indices. This will not work for synaptic @@ -528,7 +530,10 @@ def variableview_set_with_index_array(self, variableview, item, value, check_uni staticarrayname_index = self.static_array(f"_index_{arrayname}", indices) staticarrayname_value = self.static_array(f"_value_{arrayname}", value) - self.array_cache[variableview.variable] = None + # Put values into the cache + cache_variable = self.array_cache[variableview.variable] + if cache_variable is not None: + cache_variable[indices] = value self.main_queue.append( ( "set_array_by_array", @@ -543,7 +548,14 @@ def get_value(self, var, access_data=True): # however (values that have been set with explicit values and not # changed in code objects) if self.array_cache.get(var, None) is not None: - return self.array_cache[var] + values = self.array_cache[var] + if isinstance(var, DynamicArrayVariable): + # Make sure that size information is up-to-date as well + if var.ndim == 2: + var.size = values.shape + else: + var.size = values.size + return values else: # After the network has been run, we can retrieve the values from # disk @@ -552,27 +564,16 @@ def get_value(self, var, access_data=True): fname = os.path.join(self.results_dir, self.get_array_filename(var)) with open(fname, "rb") as f: data = np.fromfile(f, dtype=dtype) - # This is a bit of an heuristic, but our 2d dynamic arrays are - # only expanding in one dimension, we assume here that the - # other dimension has size 0 at the beginning - if isinstance(var.size, tuple) and len(var.size) == 2: - if var.size[0] * var.size[1] == len(data): - size = var.size - elif var.size[0] == 0: - size = (len(data) // var.size[1], var.size[1]) - elif var.size[1] == 0: - size = (var.size[0], len(data) // var.size[0]) - else: - raise IndexError( - "Do not now how to deal with 2d " - f"array of size {var.size!s}, the array on " - f"disk has length {len(data)}." - ) + if ( + isinstance(var.size, tuple) and len(var.size) == 2 + ): # 2D dynamic array + values = data.reshape(var.size) + else: + values = data + # assert (np.atleast_1d(values.shape) == np.atleast_1d(var.size)).all(), f"{values.shape} ≠ {var.size} ({var.name})" + self.array_cache[var] = values + return values - var.size = size - return data.reshape(var.size) - var.size = len(data) - return data raise NotImplementedError( "Cannot retrieve the values of state " "variables in standalone code before the " @@ -604,6 +605,39 @@ def variableview_get_with_expression(self, variableview, code, run_namespace=Non "standalone scripts." ) + def index_wrapper_get_item(self, index_wrapper, item, level): + if isinstance(item, str): + variables = Variables(None) + variables.add_auxiliary_variable("_indices", dtype=np.int32) + variables.add_auxiliary_variable("_cond", dtype=bool) + + abstract_code = "_cond = " + item + namespace = get_local_namespace(level=level + 2) + try: + codeobj = create_runner_codeobj( + index_wrapper.group, + abstract_code, + "group_get_indices", + run_namespace=namespace, + additional_variables=variables, + codeobj_class=NumpyCodeObject, + ) + except NotImplementedError: + raise NotImplementedError( + "Cannot calculate indices with string " + "expressions in standalone mode if " + "the expression refers to variable " + "with values not known before running " + "the simulation." + ) + indices = codeobj() + # Delete code object from device to avoid trying to build it later + del self.code_objects[codeobj.name] + # Handle subgroups correctly + return index_wrapper.indices(indices) + else: + return index_wrapper.indices(item) + def code_object_class(self, codeobj_class=None, fallback_pref=None): """ Return `CodeObject` class (either `CPPStandaloneCodeObject` class or input) @@ -699,6 +733,9 @@ def code_object( cache.get(var, np.empty(0, dtype=int)), value ) do_not_invalidate.add(var) + # Also update the size attribute + variables["_synaptic_pre"].size += variables["sources"].size + variables["_synaptic_post"].size += variables["targets"].size codeobj = super().code_object( owner, @@ -746,7 +783,7 @@ def code_object( and var not in do_not_invalidate and (not var.read_only or var in written_readonly_vars) ): - self.array_cache[var] = None + codeobj.invalidate_cache_variables.add(var) return codeobj @@ -1246,7 +1283,6 @@ def run( list_rep.append(f"{name}={string_value}") run_args = list_rep - # Invalidate array cache for all variables set on the command line for arg in run_args: s = arg.split("=") @@ -1257,6 +1293,12 @@ def run( and var.owner.name + "." + var.name == s[0] ): self.array_cache[var] = None + + # Invalidate array cache for all variables written by code objects + for codeobj in self.code_objects.values(): + for var in codeobj.invalidate_cache_variables: + self.array_cache[var] = None + run_args = ["--results_dir", self.results_dir] + run_args # Invalidate the cached end time of the clock and network, to deal with stopped simulations for clock in self.clocks: @@ -1330,12 +1372,28 @@ def run( except ReferenceError: pass + # Make sure the size information is up-to-date for dynamic variables + # Note that the actual data will only be loaded on demand + with in_directory(directory): + for var in self.dynamic_arrays: + fname = self.get_array_filename(var) + "_size" + new_size = int(np.loadtxt(fname, delimiter=" ", dtype=np.int32)) + var.size = new_size + for var in self.dynamic_arrays_2d: + fname = self.get_array_filename(var) + "_size" + new_size = tuple(np.loadtxt(fname, delimiter=" ", dtype=np.int32)) + var.size = new_size + def _prepare_variableview_run_arg(self, key, value): fail_for_dimension_mismatch(key.dim, value) # TODO: Give name of variable value_ar = np.asarray(value, dtype=key.dtype) if value_ar.ndim == 0 or value_ar.size == 1: # single value, give value directly on command line - string_value = repr(value_ar.item()) + value = value_ar.item() + if value is True or value is None: # translate True/False to 1/0 + string_value = "1" if value else "0" + else: + string_value = repr(value) value_name = None else: if value_ar.ndim != 1 or ( @@ -1358,7 +1416,11 @@ def _prepare_timed_array_run_arg(self, key, value): value_ar = np.asarray(value, dtype=key.values.dtype) if value_ar.ndim == 0 or value_ar.size == 1: # single value, give value directly on command line - string_value = repr(value_ar.item()) + value = value_ar.item() + if value is True or value is None: # translate True/False to 1/0 + string_value = "1" if value else "0" + else: + string_value = repr(value) value_name = None elif value_ar.shape == key.values.shape: value_name = f"init_{key.name}_values_{md5(value_ar.data).hexdigest()}.dat" @@ -1919,6 +1981,11 @@ def network_run( net.after_run() + # Invalidate array cache for all variables written by code objects + for _, codeobj in code_objects: + for var in codeobj.invalidate_cache_variables: + self.array_cache[var] = None + # Manually set the cache for the clocks, simulation scripts might # want to access the time (which has been set in code and is therefore # not accessible by the normal means until the code has been built and diff --git a/brian2/devices/cpp_standalone/templates/objects.cpp b/brian2/devices/cpp_standalone/templates/objects.cpp index fab6bf3a0..411a29500 100644 --- a/brian2/devices/cpp_standalone/templates/objects.cpp +++ b/brian2/devices/cpp_standalone/templates/objects.cpp @@ -277,15 +277,21 @@ void _write_arrays() outfile_{{varname}}.open(results_dir + "{{get_array_filename(var)}}", ios::binary | ios::out); if(outfile_{{varname}}.is_open()) { - if (! {{varname}}.empty() ) - { - outfile_{{varname}}.write(reinterpret_cast(&{{varname}}[0]), {{varname}}.size()*sizeof({{varname}}[0])); - outfile_{{varname}}.close(); - } + outfile_{{varname}}.write(reinterpret_cast(&{{varname}}[0]), {{varname}}.size()*sizeof({{varname}}[0])); + outfile_{{varname}}.close(); } else { std::cout << "Error writing output file for {{varname}}." << endl; } + outfile_{{varname}}.open("{{get_array_filename(var) | replace('\\', '\\\\')}}_size", ios::out); + if (outfile_{{varname}}.is_open()) + { + outfile_{{varname}} << {{varname}}.size(); + outfile_{{varname}}.close(); + } else + { + std::cout << "Error writing size file for {{varname}}." << endl; + } {% endfor %} {% for var, varname in dynamic_array_2d_specs | dictsort(by='value') %} @@ -305,6 +311,15 @@ void _write_arrays() { std::cout << "Error writing output file for {{varname}}." << endl; } + outfile_{{varname}}.open("{{get_array_filename(var) | replace('\\', '\\\\')}}_size", ios::out); + if (outfile_{{varname}}.is_open()) { + outfile_{{varname}} << {{varname}}.n << " " << {{varname}}.m; + outfile_{{varname}}.close(); + } else + { + std::cout << "Error writing size file for {{varname}}." << endl; + } + {% endfor %} {% if profiled_codeobjects is defined and profiled_codeobjects %} // Write profiling info to disk diff --git a/brian2/devices/cpp_standalone/templates/ratemonitor.cpp b/brian2/devices/cpp_standalone/templates/ratemonitor.cpp index c06362543..9e22eaf12 100644 --- a/brian2/devices/cpp_standalone/templates/ratemonitor.cpp +++ b/brian2/devices/cpp_standalone/templates/ratemonitor.cpp @@ -1,36 +1,61 @@ -{# USES_VARIABLES { N, rate, t, _spikespace, _clock_t, _clock_dt, - _num_source_neurons, _source_start, _source_stop } #} +{# USES_VARIABLES { N, rate, t, _spikespace, _clock_t, _clock_dt} #} {# WRITES_TO_READ_ONLY_VARIABLES { N } #} {% extends 'common_group.cpp' %} {% block maincode %} size_t _num_spikes = {{_spikespace}}[_num_spikespace-1]; - // For subgroups, we do not want to record all spikes - // We assume that spikes are ordered - int _start_idx = -1; - int _end_idx = -1; - for(size_t _j=0; _j<_num_spikes; _j++) + {% if subgroup %} + int32_t _filtered_spikes = 0; + size_t _source_index_counter = 0; + size_t _start_idx = _num_spikes; + size_t _end_idx = _num_spikes; + if (_num_spikes > 0) { - const size_t _idx = {{_spikespace}}[_j]; - if (_idx >= _source_start) { - _start_idx = _j; - break; + {% if contiguous %} {# contiguous subgroup #} + // We filter the spikes, making use of the fact that they are sorted + for(size_t _j=0; _j<_num_spikes; _j++) + { + const int _idx = {{_spikespace}}[_j]; + if (_idx >= _source_start) { + _start_idx = _j; + break; + } } - } - if (_start_idx == -1) - _start_idx = _num_spikes; - for(size_t _j=_start_idx; _j<_num_spikes; _j++) - { - const size_t _idx = {{_spikespace}}[_j]; - if (_idx >= _source_stop) { + for(size_t _j=_num_spikes-1; _j>=_start_idx; _j--) + { + const int _idx = {{_spikespace}}[_j]; + if (_idx < _source_stop) { + break; + } _end_idx = _j; - break; } + _filtered_spikes = _end_idx - _start_idx; + {% else %} {# non-contiguous subgroup #} + for (size_t _j=0; _j<_num_spikes; _j++) + { + const size_t _idx = {{_spikespace}}[_j]; + if (_idx < {{_source_indices}}[_source_index_counter]) + continue; + while ({{_source_indices}}[_source_index_counter] < _idx) + { + _source_index_counter++; + } + if (_source_index_counter < {{source_N}} && + _idx == {{_source_indices}}[_source_index_counter]) + { + _source_index_counter += 1; + _filtered_spikes += 1; + if (_source_index_counter == {{source_N}}) + break; + } + if (_source_index_counter == {{source_N}}) + break; + } + {% endif %} + _num_spikes = _filtered_spikes; } - if (_end_idx == -1) - _end_idx =_num_spikes; - _num_spikes = _end_idx - _start_idx; - {{_dynamic_rate}}.push_back(1.0*_num_spikes/{{_clock_dt}}/_num_source_neurons); + {% endif %} + {{_dynamic_rate}}.push_back(1.0*_num_spikes/{{_clock_dt}}/{{source_N}}); {{_dynamic_t}}.push_back({{_clock_t}}); {{N}}++; {% endblock %} diff --git a/brian2/devices/cpp_standalone/templates/spikemonitor.cpp b/brian2/devices/cpp_standalone/templates/spikemonitor.cpp index ca62a5fe8..96ec3e41d 100644 --- a/brian2/devices/cpp_standalone/templates/spikemonitor.cpp +++ b/brian2/devices/cpp_standalone/templates/spikemonitor.cpp @@ -1,6 +1,5 @@ -{# USES_VARIABLES { N, _clock_t, count, - _source_start, _source_stop} #} - {# WRITES_TO_READ_ONLY_VARIABLES { N, count } #} +{# USES_VARIABLES { N, _clock_t, count} #} +{# WRITES_TO_READ_ONLY_VARIABLES { N, count } #} {% extends 'common_group.cpp' %} {% block maincode %} @@ -9,11 +8,20 @@ {% set _eventspace = get_array_name(eventspace_variable) %} int32_t _num_events = {{_eventspace}}[_num{{eventspace_variable.name}}-1]; - + {% if subgroup and not contiguous %} + // We use the same data structure as for the eventspace to store the + // "filtered" events, i.e. the events that are indexed in the subgroup + int32_t _filtered_events[{{source_N}} + 1]; + _filtered_events[{{source_N}}] = 0; + size_t _source_index_counter = 0; + {% endif %} + {% if subgroup %} + // For subgroups, we do not want to record all spikes + size_t _start_idx = _num_events; + size_t _end_idx = _num_events; if (_num_events > 0) { - size_t _start_idx = _num_events; - size_t _end_idx = _num_events; + {% if contiguous %} for(size_t _j=0; _j<_num_events; _j++) { const int _idx = {{_eventspace}}[_j]; @@ -31,21 +39,74 @@ _end_idx = _j; } _num_events = _end_idx - _start_idx; - if (_num_events > 0) { - const size_t _vectorisation_idx = 1; - {{scalar_code|autoindent}} - for(size_t _j=_start_idx; _j<_end_idx; _j++) + {% else %} + const size_t _max_source_index = {{_source_indices}}[{{source_N}}-1]; + for (size_t _j=0; _j<_num_events; _j++) + { + const size_t _idx = {{_eventspace}}[_j]; + if (_idx < {{_source_indices}}[_source_index_counter]) + continue; + if (_idx > _max_source_index) + break; + while ({{_source_indices}}[_source_index_counter] < _idx) { - const size_t _idx = {{_eventspace}}[_j]; - const size_t _vectorisation_idx = _idx; - {{vector_code|autoindent}} - {% for varname, var in record_variables | dictsort %} - {{get_array_name(var, access_data=False)}}.push_back(_to_record_{{varname}}); - {% endfor %} - {{count}}[_idx-_source_start]++; + _source_index_counter++; } - {{N}} += _num_events; + if (_source_index_counter < {{source_N}} && + _idx == {{_source_indices}}[_source_index_counter]) + { + _source_index_counter += 1; + _filtered_events[_filtered_events[{{source_N}}]++] = _idx; + if (_source_index_counter == {{source_N}}) + break; + } + if (_source_index_counter == {{source_N}}) + break; + } + _num_events = _filtered_events[{{source_N}}]; + {% endif %} + } + {% endif %} + if (_num_events > 0) { + const size_t _vectorisation_idx = 1; + {{scalar_code|autoindent}} + {% if subgroup %} + {% if contiguous %} + for(size_t _j=_start_idx; _j<_end_idx; _j++) + { + const size_t _idx = {{_eventspace}}[_j]; + const size_t _vectorisation_idx = _idx; + {{vector_code|autoindent}} + {% for varname, var in record_variables | dictsort %} + {{get_array_name(var, access_data=False)}}.push_back(_to_record_{{varname}}); + {% endfor %} + {{count}}[_idx-_source_start]++; + } + {% else %} + for(size_t _j=0; _j < _num_events; _j++) + { + const size_t _idx = _filtered_events[_j]; + const size_t _vectorisation_idx = _idx; + {{vector_code|autoindent}} + {% for varname, var in record_variables | dictsort %} + {{get_array_name(var, access_data=False)}}.push_back(_to_record_{{varname}}); + {% endfor %} + {{count}}[_to_record_i]++; + } + {% endif %} + {% else %} + for (size_t _j=0; _j < _num_events; _j++) + { + const size_t _idx = {{_eventspace}}[_j]; + const size_t _vectorisation_idx = _idx; + {{ vector_code|autoindent }} + {% for varname, var in record_variables | dictsort %} + {{get_array_name(var, access_data=False)}}.push_back(_to_record_{{varname}}); + {% endfor %} + {{count}}[_idx]++; } + {% endif %} + {{N}} += _num_events; } {% endblock %} diff --git a/brian2/devices/cpp_standalone/templates/summed_variable.cpp b/brian2/devices/cpp_standalone/templates/summed_variable.cpp index 9967d0fda..1a9f4d273 100644 --- a/brian2/devices/cpp_standalone/templates/summed_variable.cpp +++ b/brian2/devices/cpp_standalone/templates/summed_variable.cpp @@ -7,12 +7,15 @@ //// MAIN CODE //////////// {# This enables summed variables for connections to a synapse #} const int _target_size = {{constant_or_scalar(_target_size_name, variables[_target_size_name])}}; - // Set all the target variable values to zero {{ openmp_pragma('parallel-static') }} for (int _target_idx=0; _target_idx<_target_size; _target_idx++) { + {% if _target_contiguous %} {{_target_var_array}}[_target_idx + {{_target_start}}] = 0; + {% else %} + {{_target_var_array}}[{{_target_indices}}[_target_idx]] = 0; + {% endif %} } // scalar code diff --git a/brian2/devices/cpp_standalone/templates/synapses_create_array.cpp b/brian2/devices/cpp_standalone/templates/synapses_create_array.cpp index 2d652bc70..b13881ab5 100644 --- a/brian2/devices/cpp_standalone/templates/synapses_create_array.cpp +++ b/brian2/devices/cpp_standalone/templates/synapses_create_array.cpp @@ -16,8 +16,16 @@ const size_t _new_num_synapses = _old_num_synapses + _numsources; constants or scalar arrays#} const size_t _N_pre = {{constant_or_scalar('N_pre', variables['N_pre'])}}; const size_t _N_post = {{constant_or_scalar('N_post', variables['N_post'])}}; +{% if "_target_sub_idx" in variables %} +{{_dynamic_N_incoming}}.resize({{get_array_name(variables['_target_sub_idx'])}}[_num_target_sub_idx - 1] + 1); +{% else %} {{_dynamic_N_incoming}}.resize(_N_post + _target_offset); +{% endif %} +{% if "_source_sub_idx" in variables %} +{{_dynamic_N_outgoing}}.resize({{get_array_name(variables['_source_sub_idx'])}}[_num_source_sub_idx - 1] + 1); +{% else %} {{_dynamic_N_outgoing}}.resize(_N_pre + _source_offset); +{% endif %} for (size_t _idx=0; _idx<_numsources; _idx++) { {# After this code has been executed, the arrays _real_sources and diff --git a/brian2/devices/cpp_standalone/templates/synapses_create_generator.cpp b/brian2/devices/cpp_standalone/templates/synapses_create_generator.cpp index 04cfe1c17..bd2fab5a8 100644 --- a/brian2/devices/cpp_standalone/templates/synapses_create_generator.cpp +++ b/brian2/devices/cpp_standalone/templates/synapses_create_generator.cpp @@ -18,8 +18,16 @@ constants or scalar arrays#} const size_t _N_pre = {{constant_or_scalar('N_pre', variables['N_pre'])}}; const size_t _N_post = {{constant_or_scalar('N_post', variables['N_post'])}}; + {% if "_target_sub_idx" in variables %} + {{_dynamic_N_incoming}}.resize({{get_array_name(variables['_target_sub_idx'])}}[_num_target_sub_idx - 1] + 1); + {% else %} {{_dynamic_N_incoming}}.resize(_N_post + _target_offset); + {% endif %} + {% if "_source_sub_idx" in variables %} + {{_dynamic_N_outgoing}}.resize({{get_array_name(variables['_source_sub_idx'])}}[_num_source_sub_idx - 1] + 1); + {% else %} {{_dynamic_N_outgoing}}.resize(_N_pre + _source_offset); + {% endif %} size_t _raw_pre_idx, _raw_post_idx; {# For a connect call j='k+i for k in range(0, N_post, 2) if k+i < N_post' "j" is called the "result index" (and "_post_idx" the "result index array", etc.) @@ -35,7 +43,11 @@ for(size_t _{{outer_index}}=0; _{{outer_index}}<_{{outer_index_size}}; _{{outer_index}}++) { bool __cond, _cond; + {% if outer_contiguous %} _raw{{outer_index_array}} = _{{outer_index}} + {{outer_index_offset}}; + {% else %} + _raw{{outer_index_array}} = {{get_array_name(variables[outer_sub_idx])}}[_{{outer_index}}]; + {% endif %} {% if not result_index_condition %} { {{vector_code['create_cond']|autoindent}} @@ -181,7 +193,11 @@ } _{{result_index}} = __{{result_index}}; // make the previously locally scoped var available {{outer_index_array}} = _{{outer_index_array}}; + {% if result_contiguous %} _raw{{result_index_array}} = _{{result_index}} + {{result_index_offset}}; + {% else %} + _raw{{result_index_array}} = {{get_array_name(variables[result_sub_idx])}}[_{{result_index}}]; + {% endif %} {% if result_index_condition %} { {% if result_index_used %} diff --git a/brian2/groups/group.py b/brian2/groups/group.py index 9a867e471..86c1d048f 100644 --- a/brian2/groups/group.py +++ b/brian2/groups/group.py @@ -15,7 +15,7 @@ import numpy as np from brian2.codegen.codeobject import create_runner_codeobj -from brian2.core.base import BrianObject, weakproxy_with_fallback +from brian2.core.base import BrianObject, device_override, weakproxy_with_fallback from brian2.core.functions import Function from brian2.core.names import Nameable, find_name from brian2.core.namespace import ( @@ -219,7 +219,8 @@ class Indexing: specific indices. Stores strong references to the necessary variables so that basic indexing (i.e. slicing, integer arrays/values, ...) works even when the respective `VariableOwner` no longer exists. Note that this object - does not handle string indexing. + does not handle string indexing (handled by `IndexWrapper` and + `VariableView.get_with_expression`). """ def __init__(self, group, default_index="_idx"): @@ -257,41 +258,82 @@ def __call__(self, item=slice(None), index_var=None): # noqa: B008 raise IndexError( f"Can only interpret 1-d indices, got {len(item)} dimensions." ) - else: - if isinstance(item, str) and item == "True": - item = slice(None) - if isinstance(item, slice): - if index_var == "0": - return 0 - if index_var == "_idx": - start, stop, step = item.indices(self.N.item()) + + index_array = self._to_index_array(item, index_var) + + if index_array.size == 0: + return index_array + + if index_var not in ("_idx", "0"): + try: + index_array = index_var.get_value()[index_array] + except IndexError as ex: + # We try to emulate numpy's indexing semantics here: + # slices never lead to IndexErrors, instead they return an + # empty array if they don't match anything + if isinstance(item, slice): + return np.array([], dtype=np.int32) else: - start, stop, step = item.indices(index_var.size) - index_array = np.arange(start, stop, step, dtype=np.int32) - else: - index_array = np.asarray(item) - if index_array.dtype == bool: - index_array = np.nonzero(index_array)[0] - elif not np.issubdtype(index_array.dtype, np.signedinteger): - raise TypeError( - "Indexing is only supported for integer " - "and boolean arrays, not for type " - f"{index_array.dtype}" + raise ex + else: + N = self.N.item() + if np.min(index_array) < -N: + raise IndexError( + "Illegal index {} for a group of size {}".format( + np.min(index_array), N + ) + ) + if np.max(index_array) >= N: + raise IndexError( + "Illegal index {} for a group of size {}".format( + np.max(index_array), N ) + ) - if index_var != "_idx": - try: - return index_var.get_value()[index_array] - except IndexError as ex: - # We try to emulate numpy's indexing semantics here: - # slices never lead to IndexErrors, instead they return an - # empty array if they don't match anything - if isinstance(item, slice): - return np.array([], dtype=np.int32) - else: - raise ex + index_array = index_array % N # interpret negative indices + + return index_array + + def _to_index_array(self, item, index_var): + """ + Convert slices, integer/boolean arrays to an integer array of indices. + + Parameters + ---------- + item: slice, array, int + The indices to translate. + index_var : `ArrayVariable`, str + The index variable. + Returns + ------- + indices : `numpy.ndarray` + The flat indices corresponding to the indices given in `item`. + """ + if isinstance(item, slice): + if index_var == "0": + index_array = np.array(0) else: - return index_array + index_size = ( + int(self.N.get_value()) if index_var == "_idx" else index_var.size + ) + start, stop, step = item.indices(index_size) + index_array = np.arange(start, stop, step) + else: # array, sequence, or single value + index_array = np.asarray(item) + if index_array.dtype == bool: + if not index_array.shape == (self.N.get_value(),): + raise IndexError( + "Boolean index did not match shape of indexed array;shape is" + f" {(self.N.get_value(), )}, but index is {index_array.shape}" + ) + index_array = np.flatnonzero(index_array) + elif not np.issubdtype(index_array.dtype, np.signedinteger): + raise TypeError( + "Indexing is only supported for integer " + "and boolean arrays, not for type " + f"{index_array.dtype}" + ) + return index_array class IndexWrapper: @@ -307,13 +349,17 @@ def __init__(self, group): self.indices = group._indices def __getitem__(self, item): + return self.get_item(item, level=1) + + @device_override("index_wrapper_get_item") + def get_item(self, item, level): if isinstance(item, str): variables = Variables(None) variables.add_auxiliary_variable("_indices", dtype=np.int32) variables.add_auxiliary_variable("_cond", dtype=bool) abstract_code = f"_cond = {item}" - namespace = get_local_namespace(level=1) + namespace = get_local_namespace(level=level + 2) # decorated function from brian2.devices.device import get_device device = get_device() @@ -327,7 +373,9 @@ def __getitem__(self, item): fallback_pref="codegen.string_expression_target" ), ) - return codeobj() + indices = codeobj() + # Handle subgroups correctly + return self.indices(indices) else: return self.indices(item) diff --git a/brian2/groups/neurongroup.py b/brian2/groups/neurongroup.py index 4e3b27491..36558da57 100644 --- a/brian2/groups/neurongroup.py +++ b/brian2/groups/neurongroup.py @@ -2,9 +2,8 @@ This model defines the `NeuronGroup`, the core of most simulations. """ -import numbers import string -from collections.abc import MutableMapping, Sequence +from collections.abc import MutableMapping import numpy as np import sympy @@ -44,7 +43,7 @@ from brian2.utils.stringtools import get_identifiers from .group import CodeRunner, Group, get_dtype -from .subgroup import Subgroup +from .subgroup import Subgroup, to_start_stop_or_index __all__ = ["NeuronGroup"] @@ -109,82 +108,6 @@ def check_identifier_pre_post(identifier): ) -def to_start_stop(item, N): - """ - Helper function to transform a single number, a slice or an array of - contiguous indices to a start and stop value. This is used to allow for - some flexibility in the syntax of specifying subgroups in `.NeuronGroup` - and `.SpatialNeuron`. - - Parameters - ---------- - item : slice, int or sequence - The slice, index, or sequence of indices to use. Note that a sequence - of indices has to be a sorted ascending sequence of subsequent integers. - N : int - The total number of elements in the group. - - Returns - ------- - start : int - The start value of the slice. - stop : int - The stop value of the slice. - - Examples - -------- - >>> from brian2.groups.neurongroup import to_start_stop - >>> to_start_stop(slice(3, 6), 10) - (3, 6) - >>> to_start_stop(slice(3, None), 10) - (3, 10) - >>> to_start_stop(5, 10) - (5, 6) - >>> to_start_stop([3, 4, 5], 10) - (3, 6) - >>> to_start_stop([3, 5, 7], 10) - Traceback (most recent call last): - ... - IndexError: Subgroups can only be constructed using contiguous indices. - - """ - if isinstance(item, slice): - start, stop, step = item.indices(N) - elif isinstance(item, numbers.Integral): - start = item - stop = item + 1 - step = 1 - elif isinstance(item, (Sequence, np.ndarray)) and not isinstance(item, str): - if not (len(item) > 0 and np.all(np.diff(item) == 1)): - raise IndexError( - "Subgroups can only be constructed using contiguous indices." - ) - if not np.issubdtype(np.asarray(item).dtype, np.integer): - raise TypeError("Subgroups can only be constructed using integer values.") - start = int(item[0]) - stop = int(item[-1]) + 1 - step = 1 - else: - raise TypeError( - "Subgroups can only be constructed using slicing " - "syntax, a single index, or an array of contiguous " - "indices." - ) - if step != 1: - raise IndexError("Subgroups have to be contiguous") - if start >= stop: - raise IndexError( - f"Illegal start/end values for subgroup, {int(start)}>={int(stop)}" - ) - if start >= N: - raise IndexError(f"Illegal start value for subgroup, {int(start)}>={int(N)}") - if stop > N: - raise IndexError(f"Illegal stop value for subgroup, {int(stop)}>{int(N)}") - if start < 0: - raise IndexError("Indices have to be positive.") - return start, stop - - class StateUpdater(CodeRunner): """ The `CodeRunner` that updates the state variables of a `NeuronGroup` @@ -902,9 +825,8 @@ def __setattr__(self, key, value): Group.__setattr__(self, key, value, level=1) def __getitem__(self, item): - start, stop = to_start_stop(item, self._N) - - return Subgroup(self, start, stop) + start, stop, indices = to_start_stop_or_index(item, self, level=1) + return Subgroup(self, start, stop, indices) def _create_variables(self, user_dtype, events): """ diff --git a/brian2/groups/subgroup.py b/brian2/groups/subgroup.py index 34a2a73e7..688e4f160 100644 --- a/brian2/groups/subgroup.py +++ b/brian2/groups/subgroup.py @@ -1,11 +1,73 @@ +import numpy as np + from brian2.core.base import weakproxy_with_fallback from brian2.core.spikesource import SpikeSource from brian2.core.variables import Variables +from brian2.utils.logger import get_logger from .group import Group, Indexing __all__ = ["Subgroup"] +logger = get_logger(__name__) + + +def to_start_stop_or_index(item, group, level=0): + """ + Helper function to transform a single number, a slice or an array of + indices to a start and stop value (if possible), or to an index of positive + indices (interpreting negative indices correctly). This is used to allow for + some flexibility in the syntax of specifying subgroups in `.NeuronGroup` + and `.SpatialNeuron`. + + Parameters + ---------- + item : slice, int, str, or sequence + The slice, index, or sequence of indices to use, or a boolean string + expression that can be evaluated in the context of the group. + group : `Group` + The group providing the context for the interpretation. + Returns + ------- + start : int or None + The start value of the slice. + stop : int or None + The stop value of the slice. + indices : `np.ndarray` or None + The indices. + + Examples + -------- + >>> from brian2.groups.neurongroup import NeuronGroup, to_start_stop_or_index + >>> group = NeuronGroup(10, '') + >>> to_start_stop_or_index(slice(3, 6), group) + (3, 6, None) + >>> to_start_stop_or_index(slice(3, None), group) + (3, 10, None) + >>> to_start_stop_or_index(5, group) + (5, 6, None) + >>> to_start_stop_or_index(slice(None, None, 2), group) # doctest: +ELLIPSIS + (None, None, array([0, 2, 4, 6, 8]...)) + >>> to_start_stop_or_index([3, 4, 5], group) + (3, 6, None) + >>> to_start_stop_or_index([3, 5, 7], group) # doctest: +ELLIPSIS + (None, None, array([3, 5, 7]...)) + >>> to_start_stop_or_index([-3, -2, -1], group) + (7, 10, None) + """ + start = stop = None + indices = group.indices.get_item(item, level=level + 1) + # For convenience, allow subgroups with a single value instead of x:x+1 slice + if indices.shape == (): + indices = np.array([indices]) + + if np.all(np.diff(indices) == 1): + start = int(indices[0]) + stop = int(indices[-1]) + 1 + indices = None + + return start, stop, indices + class Subgroup(Group, SpikeSource): """ @@ -15,13 +77,53 @@ class Subgroup(Group, SpikeSource): ---------- source : SpikeSource The source object to subgroup. - start, stop : int - Select only spikes with indices from ``start`` to ``stop-1``. + start, stop : int, optional + Select only spikes with indices from ``start`` to ``stop-1``. Cannot + be specified at the same time as ``indices``. + indices : `np.ndarray`, optional + The indices of the subgroup. Note that subgroups with non-contiguous + indices cannot be used everywhere. Cannot be specified at the same time + as ``start`` and ``stop``. name : str, optional A unique name for the group, or use ``source.name+'_subgroup_0'``, etc. """ - def __init__(self, source, start, stop, name=None): + def __init__(self, source, start=None, stop=None, indices=None, name=None): + if start is stop is indices is None: + raise TypeError("Need to specify either start and stop or indices.") + if start != stop and (start is None or stop is None): + raise TypeError("start and stop have to be specified together.") + if indices is not None and (start is not None): + raise TypeError("Cannot specify both indices and start and stop.") + if start is not None: + self.contiguous = True + if start < 0: + raise IndexError("Start index cannot be negative.") + if stop <= start: + raise IndexError("Stop index has to be bigger than start.") + if stop > len(source): + raise IndexError( + "Stop index cannot be > the size of the group " + f"({stop} > {len(source)})." + ) + else: + self.contiguous = False + if not len(indices): + raise IndexError("Cannot create an empty subgroup.") + min_index = np.min(indices) + max_index = np.max(indices) + if min_index < 0: + raise IndexError("Indices cannot contain negative values.") + if max_index >= len(source): + raise IndexError( + "Indices cannot be ≥ the size of the group " + f"({max_index} ≥ {len(source)})." + ) + if not np.all(np.diff(indices) > 0): + raise IndexError( + "indices need to be sorted and cannot contain repeated values." + ) + # A Subgroup should never be constructed from another Subgroup # Instead, use Subgroup(source.source, # start + source.start, stop + source.start) @@ -46,9 +148,13 @@ def __init__(self, source, start, stop, name=None): order=source.order + 1, name=name, ) - self._N = stop - start + if self.contiguous: + self._N = stop - start + else: + self._N = len(indices) self.start = start self.stop = stop + self.sub_indices = indices self.events = self.source.events @@ -57,7 +163,7 @@ def __init__(self, source, start, stop, name=None): self.variables = Variables(self, default_index="_sub_idx") # overwrite the meaning of N and i - if self.start > 0: + if self.contiguous and self.start > 0: self.variables.add_constant("_offset", value=self.start) self.variables.add_reference("_source_i", source, "i") self.variables.add_subexpression( @@ -66,9 +172,23 @@ def __init__(self, source, start, stop, name=None): expr="_source_i - _offset", index="_idx", ) - else: + elif self.contiguous: # no need to calculate anything if this is a subgroup starting at 0 self.variables.add_reference("i", source) + else: + # We need an array to invert the indexing, i.e. an array where you + # can use the indices and get back 0, 1, 2, ... + inv_idx = np.zeros(np.max(indices) + 1) + inv_idx[indices] = np.arange(len(indices)) + self.variables.add_array( + "i", + size=len(inv_idx), + dtype=source.variables["i"].dtype, + values=inv_idx, + constant=True, + read_only=True, + unique=True, + ) self.variables.add_constant("N", value=self._N) self.variables.add_constant("_source_N", value=len(source)) @@ -77,14 +197,26 @@ def __init__(self, source, start, stop, name=None): # Only the variable _sub_idx itself is stored in the subgroup # and needs the normal index for this group - self.variables.add_arange( - "_sub_idx", size=self._N, start=self.start, index="_idx" - ) + if self.contiguous: + self.variables.add_arange( + "_sub_idx", size=self._N, start=self.start, index="_idx" + ) + else: + self.variables.add_array( + "_sub_idx", + size=self._N, + dtype=np.int32, + values=indices, + index="_idx", + constant=True, + read_only=True, + unique=True, + ) # special indexing for subgroups self._indices = Indexing(self, self.variables["_sub_idx"]) - # Deal with special indices + # Deal with special sub_indices for key, value in self.source.variables.indices.items(): if value == "0": self.variables.indices[key] = "0" @@ -104,22 +236,25 @@ def __init__(self, source, start, stop, name=None): spikes = property(lambda self: self.source.spikes) def __getitem__(self, item): - if not isinstance(item, slice): - raise TypeError("Subgroups can only be constructed using slicing syntax") - start, stop, step = item.indices(self._N) - if step != 1: - raise IndexError("Subgroups have to be contiguous") - if start >= stop: - raise IndexError( - f"Illegal start/end values for subgroup, {int(start)}>={int(stop)}" - ) - return Subgroup(self.source, self.start + start, self.start + stop) + start, stop, indices = to_start_stop_or_index(item, self, level=1) + return Subgroup(self.source, start, stop, indices) def __repr__(self): - classname = self.__class__.__name__ - return ( - f"<{classname} {self.name!r} of {self.source.name!r} " - f"from {self.start} to {self.stop}>" + if self.contiguous: + description = "<{classname} {name} of {source} from {start} to {end}>" + str_indices = None + else: + description = "<{classname} {name} of {source} with indices {indices}>" + str_indices = np.array2string( + self.sub_indices, threshold=10, separator=", " + ) + return description.format( + classname=self.__class__.__name__, + name=repr(self.name), + source=repr(self.source.name), + start=self.start, + indices=str_indices, + end=self.stop, ) def __del__(self): diff --git a/brian2/input/poissongroup.py b/brian2/input/poissongroup.py index 50e48acfb..70394f614 100644 --- a/brian2/input/poissongroup.py +++ b/brian2/input/poissongroup.py @@ -8,7 +8,7 @@ from brian2.core.variables import Subexpression, Variables from brian2.groups.group import Group from brian2.groups.neurongroup import Thresholder -from brian2.groups.subgroup import Subgroup +from brian2.groups.subgroup import Subgroup, to_start_stop_or_index from brian2.parsing.expressions import parse_expression_dimensions from brian2.units.fundamentalunits import check_units, fail_for_dimension_mismatch from brian2.units.stdunits import Hz @@ -111,17 +111,8 @@ def __init__( self.rates = rates def __getitem__(self, item): - if not isinstance(item, slice): - raise TypeError("Subgroups can only be constructed using slicing syntax") - start, stop, step = item.indices(self._N) - if step != 1: - raise IndexError("Subgroups have to be contiguous") - if start >= stop: - raise IndexError( - f"Illegal start/end values for subgroup, {int(start)}>={int(stop)}" - ) - - return Subgroup(self, start, stop) + start, stop, indices = to_start_stop_or_index(item, self, level=1) + return Subgroup(self, start, stop, indices) def before_run(self, run_namespace=None): rates_var = self.variables["rates"] diff --git a/brian2/input/timedarray.py b/brian2/input/timedarray.py index 03cf473de..f1e03da84 100644 --- a/brian2/input/timedarray.py +++ b/brian2/input/timedarray.py @@ -4,6 +4,7 @@ import numpy as np +from brian2.codegen.generators import c_data_type from brian2.core.clocks import defaultclock from brian2.core.functions import Function from brian2.core.names import Nameable @@ -44,7 +45,7 @@ def cpp_impl(owner): K = _find_K(owner.clock.dt_, dt) code = ( """ - static inline double %NAME%(const double t) + static inline %TYPE% %NAME%(const double t) { const double epsilon = %DT% / %K%; int i = (int)((t/epsilon + 0.5)/%K%); @@ -60,6 +61,7 @@ def cpp_impl(owner): .replace("%DT%", f"{dt:.18f}") .replace("%K%", str(K)) .replace("%NUM_VALUES%", str(len(values))) + .replace("%TYPE%", c_data_type(values.dtype)) ) return code @@ -71,7 +73,7 @@ def _generate_cpp_code_2d(values, dt, name): def cpp_impl(owner): K = _find_K(owner.clock.dt_, dt) support_code = """ - static inline double %NAME%(const double t, const int i) + static inline %TYPE% %NAME%(const double t, const int i) { const double epsilon = %DT% / %K%; if (i < 0 || i >= %COLS%) @@ -92,6 +94,7 @@ def cpp_impl(owner): "%K%": str(K), "%COLS%": str(values.shape[1]), "%ROWS%": str(values.shape[0]), + "%TYPE%": c_data_type(values.dtype), }, ) return code @@ -104,7 +107,7 @@ def cython_impl(owner): K = _find_K(owner.clock.dt_, dt) code = ( """ - cdef double %NAME%(const double t): + cdef %TYPE% %NAME%(const double t): global _namespace%NAME%_values cdef double epsilon = %DT% / %K% cdef int i = (int)((t/epsilon + 0.5)/%K%) @@ -119,6 +122,7 @@ def cython_impl(owner): .replace("%DT%", f"{dt:.18f}") .replace("%K%", str(K)) .replace("%NUM_VALUES%", str(len(values))) + .replace("%TYPE%", c_data_type(values.dtype)) ) return code @@ -130,7 +134,7 @@ def _generate_cython_code_2d(values, dt, name): def cython_impl(owner): K = _find_K(owner.clock.dt_, dt) code = """ - cdef double %NAME%(const double t, const int i): + cdef %TYPE% %NAME%(const double t, const int i): global _namespace%NAME%_values cdef double epsilon = %DT% / %K% if i < 0 or i >= %COLS%: @@ -150,6 +154,7 @@ def cython_impl(owner): "%K%": str(K), "%COLS%": str(values.shape[1]), "%ROWS%": str(values.shape[0]), + "%TYPE%": c_data_type(values.dtype), }, ) return code @@ -227,12 +232,21 @@ class TimedArray(Function, Nameable, CacheKey): @check_units(dt=second) def __init__(self, values, dt, name=None): + from brian2.core.preferences import prefs + if name is None: name = "_timedarray*" Nameable.__init__(self, name) dimensions = get_dimensions(values) self.dim = dimensions - values = np.asarray(values, dtype=np.float64) + values = np.asarray(values) # infer dtype + if values.dtype == object: + raise TypeError("TimedArray does not support arrays with dtype 'object'") + elif ( + values.dtype == np.float64 and prefs.core.default_float_dtype != np.float64 + ): + # Reduce the precision of the values array to the default scalar type + values = values.astype(prefs.core.default_float_dtype) self.values = values dt = float(dt) self.dt = dt @@ -337,7 +351,9 @@ def unitless_timed_array_func(t, i): self.implementations.add_dynamic_implementation( "numpy", create_numpy_implementation ) - values_flat = self.values.astype(np.double, order="C", copy=False).ravel() + values_flat = self.values.astype( + self.values.dtype, order="C", copy=False + ).ravel() namespace = lambda owner: {f"{self.name}_values": values_flat} for target, (_, func_2d) in TimedArray.implementations.items(): diff --git a/brian2/monitors/ratemonitor.py b/brian2/monitors/ratemonitor.py index 3a0a74576..d52cc2731 100644 --- a/brian2/monitors/ratemonitor.py +++ b/brian2/monitors/ratemonitor.py @@ -4,8 +4,10 @@ import numpy as np +from brian2.core.names import Nameable from brian2.core.variables import Variables from brian2.groups.group import CodeRunner, Group +from brian2.groups.subgroup import Subgroup from brian2.units.allunits import hertz, second from brian2.units.fundamentalunits import Quantity, check_units from brian2.utils.logger import get_logger @@ -46,10 +48,32 @@ class PopulationRateMonitor(Group, CodeRunner): def __init__( self, source, name="ratemonitor*", codeobj_class=None, dtype=np.float64 ): + Nameable.__init__(self, name=name) #: The group we are recording from self.source = source self.codeobj_class = codeobj_class + + self.variables = Variables(self) + + # Handle subgroups correctly + subgroup = isinstance(source, Subgroup) + contiguous = not subgroup or source.contiguous + self.variables.add_arange("_source_idx", size=len(source)) + needed_variables = {} + if subgroup: + if contiguous: + self.variables.add_constant("_source_start", source.start) + self.variables.add_constant("_source_stop", source.stop) + needed_variables = {"_source_start", "_source_stop"} + else: + self.variables.add_reference( + "_source_indices", source, "_sub_idx", index="_source_idx" + ) + needed_variables = {"_source_indices"} + self.variables.add_dynamic_array( + "rate", size=0, dimensions=hertz.dim, read_only=True, dtype=dtype + ) CodeRunner.__init__( self, group=self, @@ -59,20 +83,14 @@ def __init__( when="end", order=0, name=name, + needed_variables=needed_variables, + template_kwds={ + "subgroup": subgroup, + "contiguous": contiguous, + "source_N": source.N, + }, ) - - self.add_dependency(source) - - self.variables = Variables(self) - # Handle subgroups correctly - start = getattr(source, "start", 0) - stop = getattr(source, "stop", len(source)) - self.variables.add_constant("_source_start", start) - self.variables.add_constant("_source_stop", stop) - self.variables.add_reference("_spikespace", source) - self.variables.add_dynamic_array( - "rate", size=0, dimensions=hertz.dim, read_only=True, dtype=dtype - ) + self.variables.create_clock_variables(self._clock, prefix="_clock_") self.variables.add_dynamic_array( "t", size=0, @@ -80,11 +98,12 @@ def __init__( read_only=True, dtype=self._clock.variables["t"].dtype, ) - self.variables.add_reference("_num_source_neurons", source, "N") self.variables.add_array( "N", dtype=np.int32, size=1, scalar=True, read_only=True ) - self.variables.create_clock_variables(self._clock, prefix="_clock_") + self.variables.add_reference("_spikespace", source) + + self.add_dependency(source) self._enable_group_attributes() def resize(self, new_size): diff --git a/brian2/monitors/spikemonitor.py b/brian2/monitors/spikemonitor.py index 5c899d4d5..0c5f0a96e 100644 --- a/brian2/monitors/spikemonitor.py +++ b/brian2/monitors/spikemonitor.py @@ -8,6 +8,7 @@ from brian2.core.spikesource import SpikeSource from brian2.core.variables import Variables from brian2.groups.group import CodeRunner, Group +from brian2.groups.subgroup import Subgroup from brian2.units.fundamentalunits import Quantity __all__ = ["EventMonitor", "SpikeMonitor"] @@ -161,6 +162,9 @@ def __init__( dtype=source_var.dtype, read_only=True, ) + needed_variables = {eventspace_name} | self.record_variables + subgroup = isinstance(source, Subgroup) + contiguous = not subgroup or source.contiguous self.variables.add_arange("_source_idx", size=len(source)) self.variables.add_array( "count", @@ -169,12 +173,20 @@ def __init__( read_only=True, index="_source_idx", ) - self.variables.add_constant("_source_start", start) - self.variables.add_constant("_source_stop", stop) - self.variables.add_constant("_source_N", source_N) self.variables.add_array( "N", size=1, dtype=np.int32, read_only=True, scalar=True ) + if subgroup: + if contiguous: + self.variables.add_constant("_source_start", start) + self.variables.add_constant("_source_stop", stop) + self.variables.add_constant("_source_N", source_N) + needed_variables |= {"_source_start", "_source_stop", "_source_N"} + else: + self.variables.add_reference( + "_source_indices", source, "_sub_idx", index="_source_idx" + ) + needed_variables |= {"_source_indices"} record_variables = { varname: self.variables[varname] for varname in self.record_variables @@ -183,8 +195,10 @@ def __init__( "eventspace_variable": source.variables[eventspace_name], "record_variables": record_variables, "record": self.record, + "subgroup": subgroup, + "contiguous": contiguous, + "source_N": source.N, } - needed_variables = {eventspace_name} | self.record_variables CodeRunner.__init__( self, group=self, diff --git a/brian2/spatialneuron/spatialneuron.py b/brian2/spatialneuron/spatialneuron.py index f23840629..855b81744 100644 --- a/brian2/spatialneuron/spatialneuron.py +++ b/brian2/spatialneuron/spatialneuron.py @@ -21,7 +21,11 @@ extract_constant_subexpressions, ) from brian2.groups.group import CodeRunner, Group -from brian2.groups.neurongroup import NeuronGroup, SubexpressionUpdater, to_start_stop +from brian2.groups.neurongroup import ( + NeuronGroup, + SubexpressionUpdater, + to_start_stop_or_index, +) from brian2.groups.subgroup import Subgroup from brian2.parsing.sympytools import str_to_sympy, sympy_to_str from brian2.units.allunits import amp, meter, ohm, siemens, volt @@ -527,9 +531,11 @@ def spatialneuron_attribute(neuron, name): if name == "main": # Main section, without the subtrees indices = neuron.morphology.indices[:] start, stop = indices[0], indices[-1] - return SpatialSubgroup( - neuron, start, stop + 1, morphology=neuron.morphology - ) + morpho = neuron.morphology + if isinstance(neuron, SpatialSubgroup): + # For subtrees, make the new Subgroup a child of the original neuron + neuron = neuron.source + return SpatialSubgroup(neuron, start, stop + 1, morphology=morpho) elif (name != "morphology") and ( (name in getattr(neuron.morphology, "children", [])) or all([c in "LR123456789" for c in name]) @@ -537,6 +543,9 @@ def spatialneuron_attribute(neuron, name): morpho = neuron.morphology[name] start = morpho.indices[0] stop = SpatialNeuron._find_subtree_end(morpho) + if isinstance(neuron, SpatialSubgroup): + neuron = neuron.source + return SpatialSubgroup(neuron, start, stop + 1, morphology=morpho) else: return Group.__getattr__(neuron, name) @@ -561,25 +570,19 @@ def spatialneuron_segment(neuron, item): "Start and stop should have units of meter", start, stop ) # Convert to integers (compartment numbers) - indices = neuron.morphology.indices[item] - start, stop = indices[0], indices[-1] + 1 + compartment_indices = neuron.morphology.indices[item] + start, stop = compartment_indices[0], compartment_indices[-1] + 1 + indices = None elif not isinstance(item, slice) and hasattr(item, "indices"): - start, stop = to_start_stop(item.indices[:], neuron._N) + start, stop, indices = to_start_stop_or_index(item.indices[:], neuron) else: - start, stop = to_start_stop(item, neuron._N) - if isinstance(neuron, SpatialSubgroup): - start += neuron.start - stop += neuron.start + start, stop, indices = to_start_stop_or_index(item, neuron) - if start >= stop: - raise IndexError( - f"Illegal start/end values for subgroup, {int(start)}>={int(stop)}" - ) if isinstance(neuron, SpatialSubgroup): - # Note that the start/stop values calculated above are always - # absolute values, even for subgroups + # For subtrees, make the new Subgroup a child of the original neuron neuron = neuron.source - return Subgroup(neuron, start, stop) + + return Subgroup(neuron, start, stop, indices) class SpatialSubgroup(Subgroup): @@ -601,11 +604,7 @@ class SpatialSubgroup(Subgroup): def __init__(self, source, start, stop, morphology, name=None): self.morphology = morphology - if isinstance(source, SpatialSubgroup): - source = source.source - start += source.start - stop += source.start - Subgroup.__init__(self, source, start, stop, name) + Subgroup.__init__(self, source, start, stop, name=name) def __getattr__(self, name): return SpatialNeuron.spatialneuron_attribute(self, name) diff --git a/brian2/synapses/synapses.py b/brian2/synapses/synapses.py index 0ff46b8bb..030178a6f 100644 --- a/brian2/synapses/synapses.py +++ b/brian2/synapses/synapses.py @@ -26,12 +26,13 @@ Equations, check_subexpressions, ) -from brian2.groups.group import CodeRunner, Group, get_dtype +from brian2.groups.group import CodeRunner, Group, Indexing, get_dtype from brian2.groups.neurongroup import ( SubexpressionUpdater, check_identifier_pre_post, extract_constant_subexpressions, ) +from brian2.groups.subgroup import Subgroup from brian2.parsing.bast import brian_ast from brian2.parsing.expressions import ( is_boolean_expression, @@ -122,7 +123,7 @@ class SummedVariableUpdater(CodeRunner): def __init__( self, expression, target_varname, synapses, target, target_size_name, index_var ): - # Handling sumped variables using the standard mechanisms is not + # Handling summed variables using the standard mechanisms is not # possible, we therefore also directly give the names of the arrays # to the template. @@ -139,14 +140,21 @@ def __init__( "_index_var": synapses.variables[index_var], "_target_start": getattr(target, "start", 0), "_target_stop": getattr(target, "stop", -1), + "_target_contiguous": True, } + needed_variables = [target_varname, target_size_name, index_var] + self.variables = Variables(synapses) + if not getattr(target, "contiguous", True): + self.variables.add_reference("_target_indices", target, "_sub_idx") + needed_variables.append("_target_indices") + template_kwds["_target_contiguous"] = False CodeRunner.__init__( self, group=synapses, template="summed_variable", code=code, - needed_variables=[target_varname, target_size_name, index_var], + needed_variables=needed_variables, # We want to update the summed variable before # the target group gets updated clock=target.clock, @@ -473,7 +481,7 @@ def slice_to_test(x): pass if isinstance(x, slice): - if isinstance(x, slice) and x == slice(None): + if x == slice(None): # No need for testing return lambda y: np.repeat(True, len(y)) start, stop, step = x.start, x.stop, x.step @@ -528,9 +536,9 @@ def find_synapses(index, synaptic_neuron): return synapses -class SynapticSubgroup: +class SynapticSubgroup(Group): """ - A simple subgroup of `Synapses` that can be used for indexing. + A subgroup of `Synapses` that can be used for indexing and for accessing variables. Parameters ---------- @@ -541,21 +549,41 @@ class SynapticSubgroup: when new synapses where added after creating this object. """ - def __init__(self, synapses, indices): + def __init__(self, synapses, indices, name=None): + indices = np.atleast_1d(indices) # Deal with scalar indices self.synapses = weakproxy_with_fallback(synapses) + self.source = weakproxy_with_fallback(synapses.source) + self.target = weakproxy_with_fallback(synapses.target) + self.multisynaptic_index = self.synapses.multisynaptic_index self._stored_indices = indices + self._N = len(indices) self._synaptic_pre = synapses.variables["_synaptic_pre"] self._source_N = self._synaptic_pre.size # total number of synapses - def _indices(self, index_var="_idx"): - if index_var != "_idx": - raise AssertionError(f"Did not expect index {index_var} here.") - if len(self._synaptic_pre.get_value()) != self._source_N: - raise RuntimeError( - "Synapses have been added/removed since this " - "synaptic subgroup has been created" - ) - return self._stored_indices + if name is None: + name = f"{self.synapses.name}_subgroup*" + + Group.__init__( + self, + name=name, + ) + self.variables = Variables(self, default_index="_sub_idx") + self.variables.add_constant("N", self._N) + self.variables.add_references(synapses, list(synapses.variables.keys())) + self.variables.add_array( + "_sub_idx", + size=self._N, + dtype=np.int32, + values=indices, + index="_idx", + constant=True, + read_only=True, + unique=True, + ) + + self._indices = SynapticIndexing(self, self.variables["_sub_idx"]) + + self._enable_group_attributes() def __len__(self): return len(self._stored_indices) @@ -567,11 +595,11 @@ def __repr__(self): ) -class SynapticIndexing: - def __init__(self, synapses): - self.synapses = weakref.proxy(synapses) - self.source = weakproxy_with_fallback(self.synapses.source) - self.target = weakproxy_with_fallback(self.synapses.target) +class SynapticIndexing(Indexing): + def __init__(self, synapses, default_idx="_idx"): + super().__init__(synapses, default_idx) + self.source = weakproxy_with_fallback(synapses.source) + self.target = weakproxy_with_fallback(synapses.target) self.synaptic_pre = synapses.variables["_synaptic_pre"] self.synaptic_post = synapses.variables["_synaptic_post"] if synapses.multisynaptic_index is not None: @@ -579,90 +607,101 @@ def __init__(self, synapses): else: self.synapse_number = None - def __call__(self, index=None, index_var="_idx"): + def __call__(self, item=slice(None), index_var=None): # noqa: B008 """ Returns synaptic indices for `index`, which can be a tuple of indices (including arrays and slices), a single index or a string. """ - if index is None or (isinstance(index, str) and index == "True"): - index = slice(None) - - if not isinstance(index, (tuple, str)) and ( - isinstance(index, (numbers.Integral, np.ndarray, slice, Sequence)) - or hasattr(index, "_indices") - ): - if hasattr(index, "_indices"): - final_indices = index._indices(index_var=index_var).astype(np.int32) - elif isinstance(index, slice): - start, stop, step = index.indices(len(self.synaptic_pre.get_value())) - final_indices = np.arange(start, stop, step, dtype=np.int32) - else: - final_indices = np.asarray(index) - elif isinstance(index, tuple): - if len(index) == 2: # two indices (pre- and postsynaptic cell) - index = (index[0], index[1], slice(None)) - elif len(index) > 3: - raise IndexError(f"Need 1, 2 or 3 indices, got {len(index)}.") - - i_indices, j_indices, k_indices = index - # Convert to absolute indices (e.g. for subgroups) - # Allow the indexing to fail, we'll later return an empty array in - # that case - try: - if hasattr( - i_indices, "_indices" - ): # will return absolute indices already - i_indices = i_indices._indices() - else: - i_indices = self.source._indices(i_indices) - pre_synapses = find_synapses(i_indices, self.synaptic_pre.get_value()) - except IndexError: - pre_synapses = np.array([], dtype=np.int32) - try: - if hasattr(j_indices, "_indices"): - j_indices = j_indices._indices() - else: - j_indices = self.target._indices(j_indices) - post_synapses = find_synapses(j_indices, self.synaptic_post.get_value()) - except IndexError: - post_synapses = np.array([], dtype=np.int32) + if index_var is None: + index_var = self.default_index + + if not isinstance(item, tuple): + # 1d indexing = synaptic indices + if hasattr(item, "_indices"): + item = item._indices() + final_indices = self._to_index_array(item, index_var) + if final_indices.size > 0: + if np.min(final_indices) < 0: + raise IndexError("Negative indices are not allowed.") + try: + N = self.N.get_value() + if np.max(final_indices) >= N: + raise IndexError( + f"Index {np.max(final_indices)} is out of bounds for " + f"{N} synapses." + ) + except NotImplementedError: + logger.warn("Cannot check synaptic indices for correctness") + else: + # 2d or 3d indexing = pre-/post-synaptic indices and (optionally) synapse number + if len(item) == 2: # two indices (pre- and postsynaptic cell) + item = (item[0], item[1], slice(None)) + elif len(item) > 3: + raise IndexError(f"Need 1, 2 or 3 indices, got {len(item)}.") + + i_indices, j_indices, k_indices = item + + if k_indices != slice(None) and self.synapse_number is None: + raise IndexError( + "To index by the third dimension you need " + "to switch on the calculation of the " + "'multisynaptic_index' when you create " + "the Synapses object." + ) - matching_synapses = np.intersect1d( - pre_synapses, post_synapses, assume_unique=True + final_indices = self._from_3d_indices_to_index_array( + i_indices, j_indices, k_indices ) - if isinstance(k_indices, slice) and k_indices == slice(None): - final_indices = matching_synapses - else: - if self.synapse_number is None: - raise IndexError( - "To index by the third dimension you need " - "to switch on the calculation of the " - "'multisynaptic_index' when you create " - "the Synapses object." - ) - if isinstance(k_indices, (numbers.Integral, slice)): - test_k = slice_to_test(k_indices) + if index_var not in ("_idx", "0"): + try: + return index_var.get_value()[final_indices.astype(np.int32)] + except IndexError as ex: + # We try to emulate numpy's indexing semantics here: + # slices never lead to IndexErrors, instead they return an + # empty array if they don't match anything + if isinstance(item, slice): + return np.array([], dtype=np.int32) else: - raise NotImplementedError( - "Indexing synapses with arrays notimplemented yet" - ) + raise ex + else: + return final_indices.astype(np.int32) - # We want to access the raw arrays here, not go through the Variable - synapse_numbers = self.synapse_number.get_value()[matching_synapses] - final_indices = np.intersect1d( - matching_synapses, - np.flatnonzero(test_k(synapse_numbers)), - assume_unique=True, - ) + def _from_3d_indices_to_index_array(self, i_indices, j_indices, k_indices): + # Convert to absolute indices (e.g. for subgroups) + if hasattr(i_indices, "_indices"): # will return absolute indices already + i_indices = i_indices._indices() else: - raise IndexError(f"Unsupported index type {type(index)}") + i_indices = self.source._indices(i_indices) + pre_synapses = find_synapses(i_indices, self.synaptic_pre.get_value()) - if index_var not in ("_idx", "0"): - return index_var.get_value()[final_indices.astype(np.int32)] + if hasattr(j_indices, "_indices"): + j_indices = j_indices._indices() else: - return final_indices.astype(np.int32) + j_indices = self.target._indices(j_indices) + post_synapses = find_synapses(j_indices, self.synaptic_post.get_value()) + + matching_synapses = np.intersect1d( + pre_synapses, post_synapses, assume_unique=True + ) + + if k_indices == slice(None): + final_indices = matching_synapses + else: + # We want to access the raw arrays here, not go through the Variable + synapse_numbers = self.synapse_number.get_value()[matching_synapses] + if isinstance(k_indices, (numbers.Integral, slice)): + test_k = slice_to_test(k_indices) + final_indices = matching_synapses[test_k(synapse_numbers)] + else: + k_indices = np.asarray(k_indices) + if k_indices.dtype == bool: + raise NotImplementedError( + "Boolean indexing not supported for synapse number" + ) + final_indices = matching_synapses[np.in1d(synapse_numbers, k_indices)] + return final_indices class Synapses(Group): @@ -1299,10 +1338,14 @@ def _create_variables(self, equations, user_dtype=None): self.variables.add_reference("_target_offset", self.target, "_offset") else: self.variables.add_constant("_target_offset", value=0) + if "_sub_idx" in self.source.variables: + self.variables.add_reference("_source_sub_idx", self.source, "_sub_idx") if "_offset" in self.source.variables: self.variables.add_reference("_source_offset", self.source, "_offset") else: self.variables.add_constant("_source_offset", value=0) + if "_sub_idx" in self.target.variables: + self.variables.add_reference("_target_sub_idx", self.target, "_sub_idx") # To cope with connections to/from other synapses, N_incoming/N_outgoing # will be resized when synapses are created self.variables.add_dynamic_array( @@ -1331,35 +1374,55 @@ def _create_variables(self, equations, user_dtype=None): self.variables.add_reference("_presynaptic_idx", self, "_synaptic_pre") self.variables.add_reference("_postsynaptic_idx", self, "_synaptic_post") - # Except for subgroups (which potentially add an offset), the "i" and - # "j" variables are simply equivalent to `_synaptic_pre` and - # `_synaptic_post` - if getattr(self.source, "start", 0) == 0: - self.variables.add_reference("i", self, "_synaptic_pre") - else: + # Except for subgroups, the "i" and "j" variables are simply equivalent to + # `_synaptic_pre` and `_synaptic_post` + if ( + isinstance(self.source, Subgroup) + and not getattr(self.source, "start", -1) == 0 + ): self.variables.add_reference( "_source_i", self.source.source, "i", index="_presynaptic_idx" ) - self.variables.add_reference("_source_offset", self.source, "_offset") - self.variables.add_subexpression( - "i", - dtype=self.source.source.variables["i"].dtype, - expr="_source_i - _source_offset", - index="_presynaptic_idx", - ) - if getattr(self.target, "start", 0) == 0: - self.variables.add_reference("j", self, "_synaptic_post") + if getattr(self.source, "contiguous", True): + # Contiguous subgroup simply shift the indices + self.variables.add_subexpression( + "i", + dtype=self.source.source.variables["i"].dtype, + expr="_source_i - _source_offset", + index="_presynaptic_idx", + ) + else: + # Non-contiguous subgroups need a full translation + self.variables.add_reference( + "i", self.source, "i", index="_presynaptic_idx" + ) else: + # No subgroup or subgroup starting at 0 + self.variables.add_reference("i", self, "_synaptic_pre") + + if ( + isinstance(self.target, Subgroup) + and not getattr(self.target, "start", -1) == 0 + ): self.variables.add_reference( "_target_j", self.target.source, "i", index="_postsynaptic_idx" ) - self.variables.add_reference("_target_offset", self.target, "_offset") - self.variables.add_subexpression( - "j", - dtype=self.target.source.variables["i"].dtype, - expr="_target_j - _target_offset", - index="_postsynaptic_idx", - ) + if getattr(self.target, "contiguous", True): + # Contiguous subgroup simply shift the indices + self.variables.add_subexpression( + "j", + dtype=self.target.source.variables["i"].dtype, + expr="_target_j - _target_offset", + index="_postsynaptic_idx", + ) + else: + # Non-contiguous subgroups need a full translation + self.variables.add_reference( + "j", self.target, "i", index="_postsynaptic_idx" + ) + else: + # No subgroup or subgroup starting at 0 + self.variables.add_reference("j", self, "_synaptic_post") # Add the standard variables self.variables.add_array( @@ -1770,13 +1833,19 @@ def _resize(self, number): self.variables["N"].set_value(number) def _update_synapse_numbers(self, old_num_synapses): - source_offset = self.variables["_source_offset"].get_value() - target_offset = self.variables["_target_offset"].get_value() - # This resizing is only necessary if we are connecting to/from synapses - post_with_offset = self.variables["N_post"].item() + target_offset - pre_with_offset = self.variables["N_pre"].item() + source_offset + if "_source_sub_idx" in self.variables: + pre_with_offset = self.variables["_source_sub_idx"].get_value()[-1] + 1 + else: + source_offset = self.variables["_source_offset"].get_value() + pre_with_offset = self.variables["N_pre"].item() + source_offset + if "_target_sub_idx" in self.variables: + post_with_offset = self.variables["_target_sub_idx"].get_value()[-1] + 1 + else: + target_offset = self.variables["_target_offset"].get_value() + post_with_offset = self.variables["N_post"].item() + target_offset self.variables["N_incoming"].resize(post_with_offset) self.variables["N_outgoing"].resize(pre_with_offset) + N_outgoing = self.variables["N_outgoing"].get_value() N_incoming = self.variables["N_incoming"].get_value() synaptic_pre = self.variables["_synaptic_pre"].get_value() @@ -1891,11 +1960,15 @@ def _add_synapses_from_arrays(self, sources, targets, n, p, namespace=None): if "_offset" in self.source.variables: variables.add_reference("_source_offset", self.source, "_offset") abstract_code += "_real_sources = sources + _source_offset\n" + elif not getattr(self.source, "contiguous", True): + abstract_code += "_real_sources = _source_sub_idx\n" else: abstract_code += "_real_sources = sources\n" if "_offset" in self.target.variables: variables.add_reference("_target_offset", self.target, "_offset") abstract_code += "_real_targets = targets + _target_offset\n" + elif not getattr(self.target, "contiguous", True): + abstract_code += "_real_targets = _target_sub_idx\n" else: abstract_code += "_real_targets = targets" logger.debug( @@ -1979,11 +2052,25 @@ def _add_synapses_generator( outer_index_size = "N_pre" if over_presynaptic else "N_post" outer_index_array = "_pre_idx" if over_presynaptic else "_post_idx" outer_index_offset = "_source_offset" if over_presynaptic else "_target_offset" + outer_sub_idx = "_source_sub_idx" if over_presynaptic else "_target_sub_idx" + outer_contiguous = ( + getattr(self.source, "contiguous", True) + if over_presynaptic + else getattr(self.target, "contiguous", True) + ) + result_index = "j" if over_presynaptic else "i" result_index_size = "N_post" if over_presynaptic else "N_pre" target_idx = "_postsynaptic_idx" if over_presynaptic else "_presynaptic_idx" result_index_array = "_post_idx" if over_presynaptic else "_pre_idx" result_index_offset = "_target_offset" if over_presynaptic else "_source_offset" + result_sub_idx = "_target_sub_idx" if over_presynaptic else "_source_sub_idx" + result_contiguous = ( + getattr(self.target, "contiguous", True) + if over_presynaptic + else getattr(self.source, "contiguous", True) + ) + result_index_name = "postsynaptic" if over_presynaptic else "presynaptic" template_kwds.update( { @@ -1991,11 +2078,15 @@ def _add_synapses_generator( "outer_index_size": outer_index_size, "outer_index_array": outer_index_array, "outer_index_offset": outer_index_offset, + "outer_sub_idx": outer_sub_idx, + "outer_contiguous": outer_contiguous, "result_index": result_index, "result_index_size": result_index_size, "result_index_name": result_index_name, "result_index_array": result_index_array, "result_index_offset": result_index_offset, + "result_sub_idx": result_sub_idx, + "result_contiguous": result_contiguous, } ) abstract_code = { @@ -2083,11 +2174,17 @@ def _add_synapses_generator( else: variables.add_constant("_source_offset", value=0) + if not getattr(self.source, "contiguous", True): + needed_variables.append("_source_sub_idx") + if "_offset" in self.target.variables: variables.add_reference("_target_offset", self.target, "_offset") else: variables.add_constant("_target_offset", value=0) + if not getattr(self.target, "contiguous", True): + needed_variables.append("_target_sub_idx") + variables.add_auxiliary_variable("_raw_pre_idx", dtype=np.int32) variables.add_auxiliary_variable("_raw_post_idx", dtype=np.int32) diff --git a/brian2/tests/test_cpp_standalone.py b/brian2/tests/test_cpp_standalone.py index 7a23c3c26..4130521b2 100644 --- a/brian2/tests/test_cpp_standalone.py +++ b/brian2/tests/test_cpp_standalone.py @@ -845,6 +845,7 @@ def test_change_parameter_without_recompile_dict_syntax(): G.n: ar, G.b: np.arange(10) % 2 != 0, stim: ar2, + on_off: [False, True, False], } ) assert array_equal(G.x, ar) @@ -855,9 +856,9 @@ def test_change_parameter_without_recompile_dict_syntax(): mon.s.T / nA, np.array( [ - [0, 2, 4, 6, 8, 10, 12, 14, 16, 18], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # on_off(t) == False - [40, 42, 44, 46, 48, 50, 52, 54, 56, 58], + [20, 22, 24, 26, 28, 30, 32, 34, 36, 38], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # on_off(t) == False ] ), ) diff --git a/brian2/tests/test_monitor.py b/brian2/tests/test_monitor.py index efb2259b4..3e7c18a0d 100644 --- a/brian2/tests/test_monitor.py +++ b/brian2/tests/test_monitor.py @@ -16,10 +16,8 @@ def test_spike_monitor(): G_without_threshold = NeuronGroup(5, "x : 1") G = NeuronGroup( 3, - """ - dv/dt = rate : 1 - rate: Hz - """, + """dv/dt = rate : 1 + rate: Hz""", threshold="v>1", reset="v=0", ) @@ -91,11 +89,9 @@ def test_spike_monitor_indexing(): def test_spike_monitor_variables(): G = NeuronGroup( 3, - """ - dv/dt = rate : 1 - rate : Hz - prev_spikes : integer - """, + """dv/dt = rate : 1 + rate : Hz + prev_spikes : integer""", threshold="v>1", reset="v=0; prev_spikes += 1", ) @@ -127,8 +123,8 @@ def test_spike_monitor_get_states(): G = NeuronGroup( 3, """dv/dt = rate : 1 - rate : Hz - prev_spikes : integer""", + rate : Hz + prev_spikes : integer""", threshold="v>1", reset="v=0; prev_spikes += 1", ) @@ -155,7 +151,11 @@ def test_spike_monitor_subgroups(): spikes_1 = SpikeMonitor(G[:2]) spikes_2 = SpikeMonitor(G[2:4]) spikes_3 = SpikeMonitor(G[4:]) + spikes_indexed = SpikeMonitor(G[::2]) + with pytest.raises(IndexError): + SpikeMonitor(G[[4, 0, 2]]) # unsorted run(defaultclock.dt) + # Spikes assert_allclose(spikes_all.i, [0, 4, 5]) assert_allclose(spikes_all.t, [0, 0, 0] * ms) assert_allclose(spikes_1.i, [0]) @@ -164,6 +164,14 @@ def test_spike_monitor_subgroups(): assert len(spikes_2.t) == 0 assert_allclose(spikes_3.i, [0, 1]) # recorded spike indices are relative assert_allclose(spikes_3.t, [0, 0] * ms) + assert_allclose(spikes_indexed.i, [0, 2]) + assert_allclose(spikes_indexed.t, [0, 0] * ms) + # Spike count + assert_allclose(spikes_all.count, [1, 0, 0, 0, 1, 1]) + assert_allclose(spikes_1.count, [1, 0]) + assert_allclose(spikes_2.count, [0, 0]) + assert_allclose(spikes_3.count, [1, 1]) + assert_allclose(spikes_indexed.count, [1, 0, 1]) def test_spike_monitor_bug_824(): @@ -184,10 +192,8 @@ def test_spike_monitor_bug_824(): def test_event_monitor(): G = NeuronGroup( 3, - """ - dv/dt = rate : 1 - rate: Hz - """, + """dv/dt = rate : 1 + rate: Hz""", events={"my_event": "v>1"}, ) G.run_on_event("my_event", "v=0") @@ -233,10 +239,8 @@ def test_event_monitor_no_record(): # Check that you can switch off recording spike times/indices G = NeuronGroup( 3, - """ - dv/dt = rate : 1 - rate: Hz - """, + """dv/dt = rate : 1 + rate: Hz""", events={"my_event": "v>1"}, threshold="v>1", reset="v=0", @@ -316,11 +320,9 @@ def test_state_monitor(): # Check that all kinds of variables can be recorded G = NeuronGroup( 2, - """ - dv/dt = -v / (10*ms) : 1 - f = clip(v, 0.1, 0.9) : 1 - rate: Hz - """, + """dv/dt = -v / (10*ms) : 1 + f = clip(v, 0.1, 0.9) : 1 + rate: Hz""", threshold="v>1", reset="v=0", refractory=2 * ms, @@ -391,6 +393,43 @@ def test_state_monitor(): ) +@pytest.mark.standalone_compatible +def test_state_monitor_subgroups(): + G = NeuronGroup( + 10, + """v : volt + step_size : volt (constant)""", + ) + G.run_regularly("v += step_size") + G.step_size = "i*mV" + + SG1 = G[3:8] + SG2 = G[::2] + + state_mon_full = StateMonitor(G, "v", record=True) + + # monitor subgroups and record from all + state_mon1 = StateMonitor(SG1, "v", record=True) + state_mon2 = StateMonitor(SG2, "v", record=True) + + # monitor subgroup and use (relative) indices + state_mon3 = StateMonitor(SG1, "v", record=[0, 3]) + state_mon4 = StateMonitor(SG2, "v", record=[0, 3]) + + # monitor full group and use subgroup as indices + state_mon5 = StateMonitor(G, "v", record=SG1) + state_mon6 = StateMonitor(G, "v", record=SG2) + + run(2 * defaultclock.dt) + + assert_allclose(state_mon1.v, state_mon_full.v[3:8]) + assert_allclose(state_mon2.v, state_mon_full.v[::2]) + assert_allclose(state_mon3.v, state_mon_full.v[3:8][[0, 3]]) + assert_allclose(state_mon4.v, state_mon_full.v[::2][[0, 3]]) + assert_allclose(state_mon5.v, state_mon_full.v[3:8]) + assert_allclose(state_mon6.v, state_mon_full.v[::2]) + + @pytest.mark.standalone_compatible @pytest.mark.multiple_runs def test_state_monitor_record_single_timestep(): @@ -457,11 +496,9 @@ def test_state_monitor_indexing(): def test_state_monitor_get_states(): G = NeuronGroup( 2, - """ - dv/dt = -v / (10*ms) : 1 - f = clip(v, 0.1, 0.9) : 1 - rate: Hz - """, + """dv/dt = -v / (10*ms) : 1 + f = clip(v, 0.1, 0.9) : 1 + rate: Hz""", threshold="v>1", reset="v=0", refractory=2 * ms, @@ -650,10 +687,8 @@ def test_rate_monitor_subgroups(): defaultclock.dt = 0.01 * ms G = NeuronGroup( 4, - """ - dv/dt = rate : 1 - rate : Hz - """, + """dv/dt = rate : 1 + rate : Hz""", threshold="v>0.999", reset="v=0", ) @@ -677,11 +712,13 @@ def test_rate_monitor_subgroups_2(): rate_1 = PopulationRateMonitor(G[:2]) rate_2 = PopulationRateMonitor(G[2:4]) rate_3 = PopulationRateMonitor(G[4:]) + rate_indexed = PopulationRateMonitor(G[::2]) run(2 * defaultclock.dt) assert_allclose(rate_all.rate, 0.5 / defaultclock.dt) assert_allclose(rate_1.rate, 0.5 / defaultclock.dt) assert_allclose(rate_2.rate, 0 * Hz) assert_allclose(rate_3.rate, 1 / defaultclock.dt) + assert_allclose(rate_indexed.rate, 2 / 3 * (1 / defaultclock.dt)) # 2 out of 3 @pytest.mark.codegen_independent diff --git a/brian2/tests/test_neurongroup.py b/brian2/tests/test_neurongroup.py index 28b094f11..30bd82aea 100644 --- a/brian2/tests/test_neurongroup.py +++ b/brian2/tests/test_neurongroup.py @@ -1201,19 +1201,26 @@ def test_state_variable_access(): assert_allclose(np.asarray(G.v[:]), np.arange(10)) assert have_same_dimensions(G.v[:], volt) assert_allclose(np.asarray(G.v[:]), G.v_[:]) - # Accessing single elements, slices and arrays + # Accessing single elements, slices and integer and boolean arrays assert G.v[5] == 5 * volt assert G.v_[5] == 5 assert_allclose(G.v[:5], np.arange(5) * volt) assert_allclose(G.v_[:5], np.arange(5)) assert_allclose(G.v[[0, 5]], [0, 5] * volt) assert_allclose(G.v_[[0, 5]], np.array([0, 5])) + assert_allclose(G.v[[True, False] * 5], np.arange(10)[::2] * volt) # Illegal indexing with pytest.raises(IndexError): G.v[0, 0] with pytest.raises(IndexError): G.v_[0, 0] + with pytest.raises(IndexError): + G.v[[0, 10]] # out of range array indices + with pytest.raises(IndexError): + G.v[[True, False]] # too few boolean indices + with pytest.raises(IndexError): + G.v[[True] * 11] # too many boolean indices with pytest.raises(TypeError): G.v[object()] with pytest.raises(TypeError): diff --git a/brian2/tests/test_spatialneuron.py b/brian2/tests/test_spatialneuron.py index 67906ed73..605e2e7ad 100644 --- a/brian2/tests/test_spatialneuron.py +++ b/brian2/tests/test_spatialneuron.py @@ -625,42 +625,36 @@ def user_fun(voltage): allowed_eqs = [ "Im = gL*(EL-v) : amp/meter**2", + """Im = gl * (El-v) + gNa * m**3 * h * (ENa-v) : amp/meter**2 + dm/dt = alpham * (1-m) - betam * m : 1 + dh/dt = alphah * (1-h) - betah * h : 1 + alpham = (0.1/mV) * (-v+25*mV) / (exp((-v+25*mV) / (10*mV)) - 1)/ms : Hz + betam = 4 * exp(-v/(18*mV))/ms : Hz + alphah = 0.07 * exp(-v/(20*mV))/ms : Hz + betah = 1/(exp((-v+30*mV) / (10*mV)) + 1)/ms : Hz""", + """Im = gl * (El-v) : amp/meter**2 + I_ext = 1*nA + sin(2*pi*100*Hz*t)*nA : amp (point current)""", + """Im = I_leak + I_spike : amp/meter**2 + I_leak = gL*(EL - v) : amp/meter**2 + I_spike = gL*DeltaT*exp((v - VT)/DeltaT): amp/meter**2 (constant over dt) + """, """ - Im = gl * (El-v) + gNa * m**3 * h * (ENa-v) : amp/meter**2 - dm/dt = alpham * (1-m) - betam * m : 1 - dh/dt = alphah * (1-h) - betah * h : 1 - alpham = (0.1/mV) * (-v+25*mV) / (exp((-v+25*mV) / (10*mV)) - 1)/ms : Hz - betam = 4 * exp(-v/(18*mV))/ms : Hz - alphah = 0.07 * exp(-v/(20*mV))/ms : Hz - betah = 1/(exp((-v+30*mV) / (10*mV)) + 1)/ms : Hz - """, - """ - Im = gl * (El-v) : amp/meter**2 - I_ext = 1*nA + sin(2*pi*100*Hz*t)*nA : amp (point current) - """, - """ - Im = I_leak + I_spike : amp/meter**2 - I_leak = gL*(EL - v) : amp/meter**2 - I_spike = gL*DeltaT*exp((v - VT)/DeltaT): amp/meter**2 (constant over dt) - """, - """ - Im = gL*(EL-v) : amp/meter**2 - I_NMDA = gNMDA*(ENMDA-v)*Mgblock : amp (point current) - gNMDA : siemens - Mgblock = 1./(1. + exp(-0.062*v/mV)/3.57) : 1 (constant over dt) - """, + Im = gL*(EL-v) : amp/meter**2 + I_NMDA = gNMDA*(ENMDA-v)*Mgblock : amp (point current) + gNMDA : siemens + Mgblock = 1./(1. + exp(-0.062*v/mV)/3.57) : 1 (constant over dt) + """, "Im = gL*(EL - v) + gL*DeltaT*exp((v - VT)/DeltaT) : amp/meter**2", + """Im = I_leak + I_spike : amp/meter**2 + I_leak = gL*(EL - v) : amp/meter**2 + I_spike = gL*DeltaT*exp((v - VT)/DeltaT): amp/meter**2 + """, """ - Im = I_leak + I_spike : amp/meter**2 - I_leak = gL*(EL - v) : amp/meter**2 - I_spike = gL*DeltaT*exp((v - VT)/DeltaT): amp/meter**2 - """, - """ - Im = gL*(EL-v) : amp/meter**2 - I_NMDA = gNMDA*(ENMDA-v)*Mgblock : amp (point current) - gNMDA : siemens - Mgblock = 1./(1. + exp(-0.062*v/mV)/3.57) : 1 - """, + Im = gL*(EL-v) : amp/meter**2 + I_NMDA = gNMDA*(ENMDA-v)*Mgblock : amp (point current) + gNMDA : siemens + Mgblock = 1./(1. + exp(-0.062*v/mV)/3.57) : 1 + """, ] forbidden_eqs = [ """Im = gl * (El-v + user_fun(v)) : amp/meter**2""", @@ -718,6 +712,15 @@ def test_spatialneuron_indexing(): assert len(neuron[0:1].indices[:]) == 1 assert len(neuron[sec.sec2.indices[:]]) == 16 assert len(neuron[sec.sec2]) == 16 + assert len(neuron[:16:2].indices[:]) == 8 + assert len(neuron["i < 16 and i % 2 == 0"].indices[:]) == 8 + assert len(neuron[:8][4:].indices[:]) == 4 + assert len(neuron[:8][::2].indices[:]) == 4 + assert len(neuron[:8]["i % 2 == 0"].indices[:]) == 4 + assert len(neuron[::2][:8].indices[:]) == 8 + assert len(neuron["i % 2 == 0"][:8].indices[:]) == 8 + assert len(neuron["distance >= 100*um"].indices[:]) == 44 + assert len(neuron["distance >= 100*um"][:8].indices[:]) == 8 assert_equal(neuron.sec1.sec11.v, [3, 4, 5, 6] * volt) assert_equal(neuron.sec1.sec11[1].v, neuron.sec1.sec11.v[1]) assert_equal(neuron.sec1.sec11[1:3].v, neuron.sec1.sec11.v[1:3]) diff --git a/brian2/tests/test_subgroup.py b/brian2/tests/test_subgroup.py index c3e99664b..a141fa916 100644 --- a/brian2/tests/test_subgroup.py +++ b/brian2/tests/test_subgroup.py @@ -16,9 +16,51 @@ def test_str_repr(): """ G = NeuronGroup(10, "v:1") SG = G[5:8] + SGi = G[[3, 6, 9]] # very basic test, only make sure no error is raised assert len(str(SG)) assert len(repr(SG)) + assert len(str(SGi)) + assert len(repr(SGi)) + + +@pytest.mark.codegen_independent +def test_creation(): + G = NeuronGroup(10, "") + SG = G[5:8] + SGi = G[[3, 6, 9]] + SGi2 = G[[3, 4, 5]] + assert len(SG) == 3 + assert SG.contiguous + assert len(SGi) == 3 + assert not SGi.contiguous + assert len(SGi2) == 3 + assert SGi2.contiguous + + with pytest.raises(TypeError): + Subgroup(G) + with pytest.raises(TypeError): + Subgroup(G, start=3) + with pytest.raises(TypeError): + Subgroup(G, stop=3) + with pytest.raises(IndexError): + Subgroup(G, start=-1, stop=5) + with pytest.raises(IndexError): + Subgroup(G, start=1, stop=1) + with pytest.raises(IndexError): + Subgroup(G, start=1, stop=11) + with pytest.raises(TypeError): + Subgroup(G, start=1, stop=11, indices=[2, 4, 6]) + with pytest.raises(IndexError): + Subgroup(G, indices=[]) + with pytest.raises(IndexError): + Subgroup(G, indices=[1, 1, 2]) + with pytest.raises(IndexError): + Subgroup(G, indices=[-1, 1, 2]) + with pytest.raises(IndexError): + Subgroup(G, indices=[1, 2, 10]) + with pytest.raises(IndexError): + Subgroup(G, indices=[3, 2, 1]) def test_state_variables(): @@ -66,12 +108,11 @@ def test_state_variables(): def test_state_variables_simple(): G = NeuronGroup( 10, - """ - a : 1 - b : 1 - c : 1 - d : 1 - """, + """a : 1 + b : 1 + c : 1 + d : 1 + """, ) SG = G[3:7] SG.a = 1 @@ -88,6 +129,38 @@ def test_state_variables_simple(): assert_equal(G.d[:], [0, 0, 0, 0, 4, 1, 2, 0, 0, 0]) +@pytest.mark.standalone_compatible +def test_state_variables_simple_indexed(): + G = NeuronGroup( + 10, + """a : 1 + b : 1 + c : 1 + d : 1 + """, + ) + # Illegal indices: + with pytest.raises(IndexError): + G[[3, 5, 5, 7, 9]] # duplicate indices + with pytest.raises(IndexError): + G[[9, 7, 5, 3]] # unsorted + with pytest.raises(IndexError): + G[[8, 10]] # out of range + SG = G[[3, 5, 7, 9]] + SG.a = 1 + SG.a["i == 0"] = 2 + SG.b = "i" + SG.b["i == 3"] = "i * 2" + SG.c = np.arange(3, 7) + SG.d[1:2] = 4 + SG.d[2:4] = [1, 2] + run(0 * ms) + assert_equal(G.a[:], [0, 0, 0, 2, 0, 1, 0, 1, 0, 1]) + assert_equal(G.b[:], [0, 0, 0, 0, 0, 1, 0, 2, 0, 6]) + assert_equal(G.c[:], [0, 0, 0, 3, 0, 4, 0, 5, 0, 6]) + assert_equal(G.d[:], [0, 0, 0, 0, 0, 4, 0, 1, 0, 2]) + + def test_state_variables_string_indices(): """ Test accessing subgroups with string indices. @@ -133,6 +206,20 @@ def test_state_variables_group_as_index_problematic(): ) +@pytest.mark.standalone_compatible +def test_state_variables_string_group(): + G = NeuronGroup(10, "v : 1") + G.v = "i" + c = 3 + SG1 = G[ + "i > 5" + ] # indexing with constant expressions should even work in standalone + SG1.v = "v * 2" + run(0 * ms) # for standalone + assert_equal(G.v[:], [0, 1, 2, 3, 4, 5, 12, 14, 16, 18]) + assert_equal(SG1.v[:], [12, 14, 16, 18]) + + @pytest.mark.standalone_compatible def test_variableview_calculations(): # Check that you can directly calculate with "variable views" @@ -222,16 +309,43 @@ def test_state_monitor(): G = NeuronGroup(10, "v : volt") G.v = np.arange(10) * volt SG = G[5:] + SG_indices = G[[2, 4, 6]] mon_all = StateMonitor(SG, "v", record=True) mon_0 = StateMonitor(SG, "v", record=0) + mon_all_indices = StateMonitor(SG_indices, "v", record=True) + mon_0_indices = StateMonitor(SG_indices, "v", record=0) run(defaultclock.dt) assert_allclose(mon_0[0].v, mon_all[0].v) assert_allclose(mon_0[0].v, np.array([5]) * volt) assert_allclose(mon_all.v.flatten(), np.arange(5, 10) * volt) + assert_allclose(mon_0_indices[0].v, mon_all_indices[0].v) + assert_allclose(mon_0_indices[0].v, np.array([2]) * volt) + assert_allclose(mon_all_indices.v.flatten(), np.array([2, 4, 6]) * volt) + with pytest.raises(IndexError): mon_all[5] + with pytest.raises(IndexError): + mon_all_indices[3] + + +@pytest.mark.standalone_compatible +def test_rate_monitor(): + G = NeuronGroup(10, "", threshold="i%2 == 0") + SG = G[5:] + SG_indices = G[[2, 4, 6]] + SG_indices2 = G[[3, 5, 7]] + pop_mon = PopulationRateMonitor(SG) + pop_mon_indices = PopulationRateMonitor(SG_indices) + pop_mon_indices2 = PopulationRateMonitor(SG_indices2) + run(defaultclock.dt) + r = 1 / defaultclock.dt + assert_allclose( + pop_mon.rate[0], (2 * r + 3 * 0) / 5 + ) # 2 out of 3 neurons are firing + assert_allclose(pop_mon_indices.rate[0], r) # all neurons firing + assert_allclose(pop_mon_indices2.rate[0], 0 * Hz) # no neurons firing def test_shared_variable(): @@ -349,6 +463,54 @@ def test_synapse_creation_generator(): assert all(S5.v_pre[:] < 25) +@pytest.mark.standalone_compatible +def test_synapse_creation_generator_non_contiguous(): + G1 = NeuronGroup(10, "v:1") + G2 = NeuronGroup(20, "v:1") + G1.v = "i" + G2.v = "10 + i" + SG1 = G1[[0, 2, 4, 6, 8]] + SG2 = G2[1::2] + S = Synapses(SG1, SG2, "w:1") + S.connect(j="i*2 + k for k in range(2)") # diverging connections + + # connect based on pre-/postsynaptic state variables + S2 = Synapses(SG1, SG2, "w:1") + S2.connect(j="k for k in range(N_post) if v_pre > 2") + + S3 = Synapses(SG1, SG2, "w:1") + S3.connect(j="k for k in range(N_post) if v_post < 25") + + S4 = Synapses(SG2, SG1, "w:1") + S4.connect(j="k for k in range(N_post) if v_post > 2") + + S5 = Synapses(SG2, SG1, "w:1") + S5.connect(j="k for k in range(N_post) if v_pre < 25") + + run(0 * ms) # for standalone + + # Internally, the "real" neuron indices should be used + assert_equal(S._synaptic_pre[:], np.arange(0, 10, 2).repeat(2)) + assert_equal(S._synaptic_post[:], np.arange(1, 20, 2)) + # For the user, the subgroup-relative indices should be presented + assert_equal(S.i[:], np.arange(5).repeat(2)) + assert_equal(S.j[:], np.arange(10)) + + # N_incoming and N_outgoing should also be correct + assert all(S.N_outgoing[:] == 2) + assert all(S.N_incoming[:] == 1) + + assert len(S2) == 3 * len(SG2), str(len(S2)) + assert all(S2.v_pre[:] > 2) + assert len(S3) == 7 * len(SG1), f"{len(S3)} != {7 * len(SG1)} " + assert all(S3.v_post[:] < 25) + + assert len(S4) == 3 * len(SG2), str(len(S4)) + assert all(S4.v_post[:] > 2) + assert len(S5) == 7 * len(SG1), f"{len(S5)} != {7 * len(SG1)} " + assert all(S5.v_pre[:] < 25) + + @pytest.mark.standalone_compatible def test_synapse_creation_generator_multiple_synapses(): G1 = NeuronGroup(10, "v:1") @@ -487,7 +649,6 @@ def test_synapse_access(): S.w = "2*j" assert all(S.w[:, 1] == 2) - assert len(S.w[:, 10]) == 0 assert len(S.w["j==10"]) == 0 # Test referencing pre- and postsynaptic variables @@ -502,7 +663,6 @@ def test_synapse_access(): assert len(S) == len(S.w[SG1, SG2]) assert_equal(S.w[SG1, 1], S.w[:, 1]) assert_equal(S.w[1, SG2], S.w[1, :]) - assert len(S.w[SG1, 10]) == 0 def test_synapses_access_subgroups(): @@ -569,20 +729,67 @@ def test_synapses_access_subgroups_problematic(): @pytest.mark.standalone_compatible def test_subgroup_summed_variable(): # Check in particular that only neurons targeted are reset to 0 (see github issue #925) - source = NeuronGroup(1, "") - target = NeuronGroup(5, "Iin : 1") + source = NeuronGroup(1, "x : 1") + target = NeuronGroup( + 7, + """Iin : 1 + x : 1""", + ) + source.x = 5 target.Iin = 10 + target.x = "i" target1 = target[1:2] - target2 = target[3:] - - syn1 = Synapses(source, target1, "Iin_post = 5 : 1 (summed)") + target2 = target[3:5] + target3 = target[[0, 6]] + syn1 = Synapses(source, target1, "Iin_post = x_pre + x_post : 1 (summed)") syn1.connect(True) - syn2 = Synapses(source, target2, "Iin_post = 1 : 1 (summed)") + syn2 = Synapses(source, target2, "Iin_post = x_pre + x_post : 1 (summed)") syn2.connect(True) + syn3 = Synapses(source, target3, "Iin_post = x_pre + x_post : 1 (summed)") + syn3.connect(True) run(2 * defaultclock.dt) - assert_array_equal(target.Iin, [10, 5, 10, 1, 1]) + assert_array_equal(target.Iin, [5, 6, 10, 8, 9, 10, 11]) + + +@pytest.mark.codegen_independent +def test_subgroup_summed_variable_overlap(): + # Check that overlapping subgroups raise an error + source = NeuronGroup(1, "") + target = NeuronGroup(10, "Iin : 1") + target1 = target[1:3] + target2 = target[2:5] + target3 = target[[1, 6]] + target4 = target[[4, 6]] + + syn1 = Synapses(source, target1, "Iin_post = 1 : 1 (summed)") + syn1.connect(True) + + syn2 = Synapses(source, target2, "Iin_post = 2 : 1 (summed)") + syn2.connect(True) + + syn3 = Synapses(source, target3, "Iin_post = 3 : 1 (summed)") + syn3.connect(True) + + syn4 = Synapses(source, target4, "Iin_post = 4 : 1 (summed)") + syn4.connect(True) + + net1 = Network(source, target, syn1, syn2) # overlap between contiguous subgroups + with pytest.raises(NotImplementedError): + net1.run(0 * ms) + + net2 = Network( + source, target, syn1, syn3 + ) # overlap between contiguous and non-contiguous subgroups + with pytest.raises(NotImplementedError): + net2.run(0 * ms) + + net3 = Network( + source, target, syn3, syn4 + ) # overlap between non-contiguous subgroups + with pytest.raises(NotImplementedError): + net3.run(0 * ms) def test_subexpression_references(): @@ -591,10 +798,8 @@ def test_subexpression_references(): """ G = NeuronGroup( 10, - """ - v : 1 - v2 = 2*v : 1 - """, + """v : 1 + v2 = 2*v : 1""", ) G.v = np.arange(10) SG1 = G[:5] @@ -603,11 +808,9 @@ def test_subexpression_references(): S1 = Synapses( SG1, SG2, - """ - w : 1 - u = v2_post + 1 : 1 - x = v2_pre + 1 : 1 - """, + """w : 1 + u = v2_post + 1 : 1 + x = v2_pre + 1 : 1""", ) S1.connect("i==(5-1-j)") assert_equal(S1.i[:], np.arange(5)) @@ -618,11 +821,9 @@ def test_subexpression_references(): S2 = Synapses( G, SG2, - """ - w : 1 - u = v2_post + 1 : 1 - x = v2_pre + 1 : 1 - """, + """w : 1 + u = v2_post + 1 : 1 + x = v2_pre + 1 : 1""", ) S2.connect("i==(5-1-j)") assert_equal(S2.i[:], np.arange(5)) @@ -633,11 +834,9 @@ def test_subexpression_references(): S3 = Synapses( SG1, G, - """ - w : 1 - u = v2_post + 1 : 1 - x = v2_pre + 1 : 1 - """, + """w : 1 + u = v2_post + 1 : 1 + x = v2_pre + 1 : 1""", ) S3.connect("i==(10-1-j)") assert_equal(S3.i[:], np.arange(5)) @@ -653,10 +852,8 @@ def test_subexpression_no_references(): """ G = NeuronGroup( 10, - """ - v : 1 - v2 = 2*v : 1 - """, + """v : 1 + v2 = 2*v : 1""", ) G.v = np.arange(10) @@ -665,11 +862,9 @@ def test_subexpression_no_references(): S1 = Synapses( G[:5], G[5:], - """ - w : 1 - u = v2_post + 1 : 1 - x = v2_pre + 1 : 1 - """, + """w : 1 + u = v2_post + 1 : 1 + x = v2_pre + 1 : 1""", ) S1.connect("i==(5-1-j)") assert_equal(S1.i[:], np.arange(5)) @@ -680,11 +875,9 @@ def test_subexpression_no_references(): S2 = Synapses( G, G[5:], - """ - w : 1 - u = v2_post + 1 : 1 - x = v2_pre + 1 : 1 - """, + """w : 1 + u = v2_post + 1 : 1 + x = v2_pre + 1 : 1""", ) S2.connect("i==(5-1-j)") assert_equal(S2.i[:], np.arange(5)) @@ -695,11 +888,9 @@ def test_subexpression_no_references(): S3 = Synapses( G[:5], G, - """ - w : 1 - u = v2_post + 1 : 1 - x = v2_pre + 1 : 1 - """, + """w : 1 + u = v2_post + 1 : 1 + x = v2_pre + 1 : 1""", ) S3.connect("i==(10-1-j)") assert_equal(S3.i[:], np.arange(5)) @@ -761,14 +952,14 @@ def test_run_regularly(): @pytest.mark.standalone_compatible def test_spike_monitor(): G = NeuronGroup(10, "v:1", threshold="v>1", reset="v=0") - G.v[0] = 1.1 - G.v[2] = 1.1 - G.v[5] = 1.1 + G.v[[0, 2, 5]] = 1.1 SG = G[3:] SG2 = G[:3] + SGi = G[[3, 5, 7]] s_mon = SpikeMonitor(G) sub_s_mon = SpikeMonitor(SG) sub_s_mon2 = SpikeMonitor(SG2) + sub_s_moni = SpikeMonitor(SGi) run(defaultclock.dt) assert_equal(s_mon.i, np.array([0, 2, 5])) assert_equal(s_mon.t_, np.zeros(3)) @@ -776,6 +967,8 @@ def test_spike_monitor(): assert_equal(sub_s_mon.t_, np.zeros(1)) assert_equal(sub_s_mon2.i, np.array([0, 2])) assert_equal(sub_s_mon2.t_, np.zeros(2)) + assert_equal(sub_s_moni.i, np.array([1])) + assert_equal(sub_s_moni.t, np.zeros(1)) expected = np.zeros(10, dtype=int) expected[[0, 2, 5]] = 1 assert_equal(s_mon.count, expected) @@ -783,45 +976,45 @@ def test_spike_monitor(): expected[[2]] = 1 assert_equal(sub_s_mon.count, expected) assert_equal(sub_s_mon2.count, np.array([1, 0, 1])) + assert_equal(sub_s_moni.count, np.array([0, 1, 0])) @pytest.mark.codegen_independent -def test_wrong_indexing(): +@pytest.mark.parametrize( + "item", + [ + slice(10, None), + slice(3, 2), + [9, 10], + [10, 11], + [2.5, 3.5, 4.5], + [5, 5, 5], + [], + [5, 4, 3], + ], +) +def test_wrong_indexing(item): G = NeuronGroup(10, "v:1") - with pytest.raises(TypeError): - G["string"] - - with pytest.raises(IndexError): - G[10] - with pytest.raises(IndexError): - G[10:] - with pytest.raises(IndexError): - G[::2] - with pytest.raises(IndexError): - G[3:2] - with pytest.raises(IndexError): - G[[5, 4, 3]] - with pytest.raises(IndexError): - G[[2, 4, 6]] - with pytest.raises(IndexError): - G[[-1, 0, 1]] - with pytest.raises(IndexError): - G[[9, 10, 11]] - with pytest.raises(IndexError): - G[[9, 10]] - with pytest.raises(IndexError): - G[[10, 11]] - with pytest.raises(TypeError): - G[[2.5, 3.5, 4.5]] + with pytest.raises((TypeError, IndexError)): + G[item] @pytest.mark.codegen_independent -def test_alternative_indexing(): +@pytest.mark.parametrize( + "item,expected", + [ + (slice(-3, None), np.array([7, 8, 9])), + (slice(None, None, 2), np.array([0, 2, 4, 6, 8])), + (3, np.array([3])), + ([3, 4, 5], np.array([3, 4, 5])), + ([3, 5, 7], np.array([3, 5, 7])), + ([3, -1], np.array([3, 9])), + ], +) +def test_alternative_indexing(item, expected): G = NeuronGroup(10, "v : integer") G.v = "i" - assert_equal(G[-3:].v, np.array([7, 8, 9])) - assert_equal(G[3].v, np.array([3])) - assert_equal(G[[3, 4, 5]].v, np.array([3, 4, 5])) + assert_equal(G[item].v, expected) def test_no_reference_1(): @@ -887,10 +1080,20 @@ def test_recursive_subgroup(): G = NeuronGroup(10, "v : 1") G.v = "i" SG = G[3:8] + SGi = G[[3, 5, 7, 9]] SG2 = SG[2:4] + SGi2 = SGi[2:4] + SGii = SGi[[1, 3]] + SG2i = SG[[1, 3]] assert_equal(SG2.v[:], np.array([5, 6])) assert_equal(SG2.v[:], SG.v[2:4]) + assert_equal(SGi2.v[:], np.array([7, 9])) + assert_equal(SGii.v[:], np.array([5, 9])) + assert_equal(SG2i.v[:], np.array([4, 6])) assert SG2.source.name == G.name + assert SGi2.source.name == G.name + assert SGii.source.name == G.name + assert SG2i.source.name == G.name if __name__ == "__main__": diff --git a/brian2/tests/test_synapses.py b/brian2/tests/test_synapses.py index 7e00063f9..f7c8b698d 100644 --- a/brian2/tests/test_synapses.py +++ b/brian2/tests/test_synapses.py @@ -85,9 +85,9 @@ def test_creation_errors(): # Check that using pre and on_pre (resp. post/on_post) at the same time # raises an error with pytest.raises(TypeError): - Synapses(G, G, "w:1", pre="v+=w", on_pre="v+=w", connect=True) + Synapses(G, G, "w:1", pre="v+=w", on_pre="v+=w") with pytest.raises(TypeError): - Synapses(G, G, "w:1", post="v+=w", on_post="v+=w", connect=True) + Synapses(G, G, "w:1", post="v+=w", on_post="v+=w") @pytest.mark.codegen_independent @@ -307,10 +307,9 @@ def test_connection_string_deterministic_full_one_to_one(): G, G, """ - sub_1 = v_pre : 1 - sub_2 = v_post : 1 - w:1 - """, + sub_1 = v_pre : 1 + sub_2 = v_post : 1 + w:1""", ) S3.connect("sub_1 == sub_2") @@ -532,18 +531,14 @@ def test_connection_random_with_indices(): def test_connection_random_without_condition(): G = NeuronGroup( 4, - """ - v: 1 - x : integer - """, + """v: 1 + x : integer""", ) G.x = "i" G2 = NeuronGroup( 7, - """ - v: 1 - y : 1 - """, + """v: 1 + y : 1""", ) G2.y = "1.0*i/N" @@ -720,9 +715,23 @@ def test_state_variable_indexing(): assert len(S.w[:]) == len(S.w[np.arange(len(G1) * len(G2) * 2)]) assert S.w[3] == S.w[np.int32(3)] == S.w[np.int64(3)] # See issue #888 - # Array-indexing (not yet supported for synapse index) + # Array-indexing assert_equal(S.w[:, 0:3], S.w[:, [0, 1, 2]]) assert_equal(S.w[:, 0:3], S.w[np.arange(len(G1)), [0, 1, 2]]) + assert_equal(S.w[:, 0:3], S.w[:, [0, 1, 2], [0, 1]]) + assert_equal(S.w[:, 0:3, 0], S.w[:, [0, 1, 2], [0]]) + + # Array indexing with boolean arrays + assert_equal( + S.w[:, 0:3], S.w[:, np.array([True, True, True, False, False, False, False])] + ) + assert_equal( + S.w[:, 0:3], + S.w[ + [True, True, True, True, True], + [True, True, True, False, False, False, False], + ], + ) # string-based indexing assert_equal(S.w[0:3, :], S.w["i<3"]) @@ -735,11 +744,81 @@ def test_state_variable_indexing(): with pytest.raises(IndexError): S.w.__getitem__((1, 2, 3, 4)) with pytest.raises(IndexError): - S.w.__getitem__(object()) + S.w[[0, 5], :] # out-of-range array index + with pytest.raises(IndexError): + S.w[[True, False], :] # too few boolean indices with pytest.raises(IndexError): + S.w[[True, False, True, False, True, False], :] # too many boolean indices + with pytest.raises(TypeError): + S.w.__getitem__(object()) + with pytest.raises(TypeError): S.w.__getitem__(1.5) +def test_state_variable_indexing_with_subgroups(): + G1 = NeuronGroup(5, "v:volt") + G1.v = "i*mV" + G2 = NeuronGroup(7, "v:volt") + G2.v = "10*mV + i*mV" + S = Synapses(G1, G2, "w:1", multisynaptic_index="k") + S.connect(True, n=2) + S.w[:, :, 0] = "5*i + j" + S.w[:, :, 1] = "35 + 5*i + j" + + # Slicing + assert len(S[:].w) == len(S[:, :].w) == len(S[:, :, :].w) == len(G1) * len(G2) * 2 + assert len(S[0:, 0:].w) == len(S[0:, 0:, 0:].w) == len(G1) * len(G2) * 2 + assert len(S[0::2, 0:].w) == 3 * len(G2) * 2 + assert len(S[0, :].w) == len(S[0, :, :].w) == len(G2) * 2 + assert len(S[0:2, :].w) == len(S[0:2, :, :].w) == 2 * len(G2) * 2 + assert len(S[:2, :].w) == len(S[:2, :, :].w) == 2 * len(G2) * 2 + assert len(S[0:4:2, :].w) == len(S[0:4:2, :, :].w) == 2 * len(G2) * 2 + assert len(S[:4:2, :].w) == len(S[:4:2, :, :].w) == 2 * len(G2) * 2 + assert len(S[:, 0].w) == len(S[:, 0, :].w) == len(G1) * 2 + assert len(S[:, 0:2].w) == len(S[:, 0:2, :].w) == 2 * len(G1) * 2 + assert len(S[:, :2].w) == len(S[:, :2, :].w) == 2 * len(G1) * 2 + assert len(S[:, 0:4:2].w) == len(S[:, 0:4:2, :].w) == 2 * len(G1) * 2 + assert len(S[:, :4:2].w) == len(S[:, :4:2, :].w) == 2 * len(G1) * 2 + assert len(S[:, :, 0].w) == len(G1) * len(G2) + assert len(S[:, :, 0:2].w) == len(G1) * len(G2) * 2 + assert len(S[:, :, :2].w) == len(G1) * len(G2) * 2 + assert len(S[:, :, 0:2:2].w) == len(G1) * len(G2) + assert len(S[:, :, :2:2].w) == len(G1) * len(G2) + + # 1d indexing is directly indexing synapses and should be equivalent to numpy syntax! + assert len(S[:].w) == len(S[0:].w) == len(S.w[:]) + assert len(S[[0, 1]].w) == len(S[3:5].w) == len(S.w[:][[0, 1]]) == 2 + assert ( + len(S[:].w) + == len(S[np.arange(len(G1) * len(G2) * 2)].w) + == len(S.w[:][np.arange(len(G1) * len(G2) * 2)]) + ) + assert S[3].w == S[np.int32(3)].w == S[np.int64(3)].w # See issue #888 + + # Array-indexing + assert_equal(S[:, 0:3].w[:], S[:, [0, 1, 2]].w[:]) + assert_equal(S[:, 0:3].w[:], S[np.arange(len(G1)), [0, 1, 2]].w[:]) + assert_equal(S[:, 0:3].w[:], S[:, [0, 1, 2], [0, 1]].w[:]) + assert_equal(S[:, 0:3, 0].w[:], S[:, [0, 1, 2], [0]].w[:]) + + # string-based indexing + assert_equal(S[0:3, :].w[:], S["i<3"].w[:]) + assert_equal(S[:, 0:3].w[:], S["j<3"].w[:]) + assert_equal(S[:, :, 0].w[:], S["k == 0"].w[:]) + assert_equal(S[0:3, :].w[:], S["v_pre < 2.5*mV"].w[:]) + assert_equal(S[:, 0:3].w[:], S["v_post < 12.5*mV"].w[:]) + + # invalid indices + with pytest.raises(IndexError): + S.__getitem__((1, 2, 3, 4)) + with pytest.raises(TypeError): + S.__getitem__(object()) + with pytest.raises(TypeError): + S.__getitem__(1.5) + with pytest.raises(IndexError): + S.__getitem__(np.arange(len(G1) * len(G2) * 2 + 1)) + + def test_indices(): G = NeuronGroup(10, "v : 1") S = Synapses(G, G, "") @@ -758,20 +837,16 @@ def test_subexpression_references(): """ G = NeuronGroup( 10, - """ - v : 1 - v2 = 2*v : 1 - """, + """v : 1 + v2 = 2*v : 1""", ) G.v = np.arange(10) S = Synapses( G, G, - """ - w : 1 - u = v2_post + 1 : 1 - x = v2_pre + 1 : 1 - """, + """w : 1 + u = v2_post + 1 : 1 + x = v2_pre + 1 : 1""", ) S.connect("i==(10-1-j)") assert_equal(S.u[:], np.arange(10)[::-1] * 2 + 1) @@ -784,15 +859,13 @@ def test_constant_variable_subexpression_in_synapses(): S = Synapses( G, G, - """ - dv1/dt = -v1**2 / (10*ms) : 1 (clock-driven) - dv2/dt = -v_const**2 / (10*ms) : 1 (clock-driven) - dv3/dt = -v_var**2 / (10*ms) : 1 (clock-driven) - dv4/dt = -v_noflag**2 / (10*ms) : 1 (clock-driven) - v_const = v2 : 1 (constant over dt) - v_var = v3 : 1 - v_noflag = v4 : 1 - """, + """ dv1/dt = -v1**2 / (10*ms) : 1 (clock-driven) + dv2/dt = -v_const**2 / (10*ms) : 1 (clock-driven) + dv3/dt = -v_var**2 / (10*ms) : 1 (clock-driven) + dv4/dt = -v_noflag**2 / (10*ms) : 1 (clock-driven) + v_const = v2 : 1 (constant over dt) + v_var = v3 : 1 + v_noflag = v4 : 1""", method="rk2", ) S.connect(j="i") @@ -818,11 +891,9 @@ def test_nested_subexpression_references(): """ G = NeuronGroup( 10, - """ - v : 1 - v2 = 2*v : 1 - v3 = 1.5*v2 : 1 - """, + """v : 1 + v2 = 2*v : 1 + v3 = 1.5*v2 : 1""", threshold="v>=5", ) G2 = NeuronGroup(10, "v : 1") @@ -840,10 +911,8 @@ def test_equations_unit_check(): syn = Synapses( group, group, - """ - sub1 = 3 : 1 - sub2 = sub1 + 1*mV : volt - """, + """sub1 = 3 : 1 + sub2 = sub1 + 1*mV : volt""", on_pre="v += sub2", ) syn.connect() @@ -979,10 +1048,8 @@ def test_pre_before_post(): # The pre pathway should be executed before the post pathway G = NeuronGroup( 1, - """ - x : 1 - y : 1 - """, + """x : 1 + y : 1""", threshold="True", ) S = Synapses(G, G, "", on_pre="x=1; y=1", on_post="x=2") @@ -1003,10 +1070,8 @@ def test_pre_post_simple(): S = Synapses( G1, G2, - """ - pre_value : 1 - post_value : 1 - """, + """pre_value : 1 + post_value : 1""", pre="pre_value +=1", post="post_value +=2", ) @@ -1122,10 +1187,8 @@ def test_transmission(): # Make sure that the Synapses class actually propagates spikes :) source = NeuronGroup( 4, - """ - dv/dt = rate : 1 - rate : Hz - """, + """dv/dt = rate : 1 + rate : Hz""", threshold="v>1", reset="v=0", ) @@ -1236,8 +1299,8 @@ def test_transmission_scalar_delay_different_clocks(): run(2 * ms) assert len(l) == 1, "expected a warning, got %d" % len(l) assert l[0][1].endswith("synapses_dt_mismatch") - - run(0 * ms) + else: + run(2 * ms) assert_allclose(mon[0].v[mon.t < 0.5 * ms], 0) assert_allclose(mon[0].v[mon.t >= 0.5 * ms], 1) assert_allclose(mon[1].v[mon.t < 1.5 * ms], 0) @@ -1370,11 +1433,9 @@ def test_summed_variable(): S = Synapses( source, target, - """ - w : volt - x : volt - v_post = 2*x : volt (summed) - """, + """w : volt + x : volt + v_post = 2*x : volt (summed)""", on_pre="x+=w", multisynaptic_index="k", ) @@ -1392,33 +1453,28 @@ def test_summed_variable(): def test_summed_variable_pre_and_post(): G1 = NeuronGroup( 4, - """ - neuron_var : 1 - syn_sum : 1 - neuron_sum : 1 - """, + """neuron_var : 1 + syn_sum : 1 + neuron_sum : 1""", ) G1.neuron_var = "i" G2 = NeuronGroup( 4, - """ - neuron_var : 1 - syn_sum : 1 - neuron_sum : 1 - """, + """neuron_var : 1 + syn_sum : 1 + neuron_sum : 1""", ) G2.neuron_var = "i+4" synapses = Synapses( G1, G2, - """ - syn_var : 1 - neuron_sum_pre = neuron_var_post : 1 (summed) - syn_sum_pre = syn_var : 1 (summed) - neuron_sum_post = neuron_var_pre : 1 (summed) - syn_sum_post = syn_var : 1 (summed) - """, + """syn_var : 1 + neuron_sum_pre = neuron_var_post : 1 (summed) + syn_sum_pre = syn_var : 1 (summed) + neuron_sum_post = neuron_var_pre : 1 (summed) + syn_sum_post = syn_var : 1 (summed) + """, ) # The first three cells in G1 connect to the first cell in G2 # The remaining three cells of G2 all connect to the last cell of G1 @@ -1441,10 +1497,8 @@ def test_summed_variable_differing_group_size(): syn1 = Synapses( G1, G2, - """ - syn_var : 1 - var_pre = syn_var + var_post : 1 (summed) - """, + """syn_var : 1 + var_pre = syn_var + var_post : 1 (summed)""", ) syn1.connect(i=0, j=[0, 1, 2, 3, 4]) syn1.connect(i=1, j=[5, 6, 7, 8, 9]) @@ -1457,10 +1511,8 @@ def test_summed_variable_differing_group_size(): syn2 = Synapses( G3, G4, - """ - syn_var : 1 - var_post = syn_var + var_pre : 1 (summed) - """, + """syn_var : 1 + var_post = syn_var + var_pre : 1 (summed)""", ) syn2.connect(i=[0, 1, 2, 3, 4], j=0) syn2.connect(i=[5, 6, 7, 8, 9], j=1) @@ -1478,11 +1530,9 @@ def test_summed_variable_differing_group_size(): def test_summed_variable_errors(): G = NeuronGroup( 10, - """ - dv/dt = -v / (10*ms) : volt - sub = 2*v : volt - p : volt - """, + """dv/dt = -v / (10*ms) : volt + sub = 2*v : volt + p : volt""", threshold="False", reset="", ) @@ -1515,10 +1565,8 @@ def test_summed_variable_errors(): Synapses( G, G, - """ - p_post = 3*volt : volt (summed) - p_pre = 3*volt : volt (summed) - """, + """p_post = 3*volt : volt (summed) + p_pre = 3*volt : volt (summed)""", ) # Summed variable referring to an event-driven variable @@ -1526,10 +1574,8 @@ def test_summed_variable_errors(): Synapses( G, G, - """ - ds/dt = -s/(3*ms) : volt (event-driven) - p_post = s : volt (summed) - """, + """ds/dt = -s/(3*ms) : volt (event-driven) + p_post = s : volt (summed)""", on_pre="s += 1*mV", ) assert "'p_post'" in str(ex.value) and "'s'" in str(ex.value) @@ -1539,12 +1585,10 @@ def test_summed_variable_errors(): Synapses( G, G, - """ - ds/dt = -s/(3*ms) : volt (event-driven) - x = s : volt - y = x : volt - p_post = y : 1 (summed) - """, + """ds/dt = -s/(3*ms) : volt (event-driven) + x = s : volt + y = x : volt + p_post = y : 1 (summed)""", on_pre="s += 1*mV", ) assert "'p_post'" in str(ex.value) and "'s'" in str(ex.value) @@ -1554,10 +1598,8 @@ def test_summed_variable_errors(): S = Synapses( G, G, - """ - y : siemens - p_post = y : volt (summed) - """, + """y : siemens + p_post = y : volt (summed)""", ) run(0 * ms) @@ -1629,20 +1671,16 @@ def test_summed_variables_linked_variables(): def test_scalar_parameter_access(): G = NeuronGroup( 10, - """ - v : 1 - scalar : Hz (shared) - """, + """v : 1 + scalar : Hz (shared)""", threshold="False", ) S = Synapses( G, G, - """ - w : 1 - s : Hz (shared) - number : 1 (shared) - """, + """w : 1 + s : Hz (shared) + number : 1 (shared)""", on_pre="v+=w*number", ) S.connect() @@ -1689,19 +1727,15 @@ def test_scalar_parameter_access(): def test_scalar_subexpression(): G = NeuronGroup( 10, - """ - v : 1 - number : 1 (shared) - """, + """v : 1 + number : 1 (shared)""", threshold="False", ) S = Synapses( G, G, - """ - s : 1 (shared) - sub = number_post + s : 1 (shared) - """, + """s : 1 (shared) + sub = number_post + s : 1 (shared)""", on_pre="v+=s", ) S.connect() @@ -1713,10 +1747,8 @@ def test_scalar_subexpression(): Synapses( G, G, - """ - s : 1 (shared) - sub = v_post + s : 1 (shared) - """, + """s : 1 (shared) + sub = v_post + s : 1 (shared)""", on_pre="v+=s", ) @@ -1728,10 +1760,8 @@ def test_sim_with_scalar_variable(): syn = Synapses( inp, out, - """ - w : 1 - s : 1 (shared) - """, + """w : 1 + s : 1 (shared)""", on_pre="v += s + w", ) syn.connect(j="i") @@ -1748,10 +1778,8 @@ def test_sim_with_scalar_subexpression(): syn = Synapses( inp, out, - """ - w : 1 - s = 5 : 1 (shared) - """, + """w : 1 + s = 5 : 1 (shared)""", on_pre="v += s + w", ) syn.connect(j="i") @@ -1767,10 +1795,8 @@ def test_sim_with_constant_subexpression(): syn = Synapses( inp, out, - """ - w : 1 - s = 5 : 1 (constant over dt) - """, + """w : 1 + s = 5 : 1 (constant over dt)""", on_pre="v += s + w", ) syn.connect(j="i") @@ -1799,20 +1825,16 @@ def test_event_driven(): # trains with different rates pre = NeuronGroup( 2, - """ - dv/dt = rate : 1 - rate : Hz - """, + """dv/dt = rate : 1 + rate : Hz""", threshold="v>1", reset="v=0", ) pre.rate = [1000, 1500] * Hz post = NeuronGroup( 2, - """ - dv/dt = rate : 1 - rate : Hz - """, + """dv/dt = rate : 1 + rate : Hz""", threshold="v>1", reset="v=0", ) @@ -1829,43 +1851,31 @@ def test_event_driven(): S1 = Synapses( pre, post, - """ - w : 1 - dApre/dt = -Apre/taupre : 1 (event-driven) - dApost/dt = -Apost/taupost : 1 (event-driven) - """, - on_pre=""" - Apre += dApre - w = clip(w+Apost, 0, gmax) - """, - on_post=""" - Apost += dApost - w = clip(w+Apre, 0, gmax) - """, + """w : 1 + dApre/dt = -Apre/taupre : 1 (event-driven) + dApost/dt = -Apost/taupost : 1 (event-driven)""", + on_pre="""Apre += dApre + w = clip(w+Apost, 0, gmax)""", + on_post="""Apost += dApost + w = clip(w+Apre, 0, gmax)""", ) S1.connect(j="i") # not event-driven S2 = Synapses( pre, post, - """ - w : 1 - Apre : 1 - Apost : 1 - lastupdate : second - """, - on_pre=""" - Apre=Apre*exp((lastupdate-t)/taupre)+dApre - Apost=Apost*exp((lastupdate-t)/taupost) - w = clip(w+Apost, 0, gmax) - lastupdate = t - """, - on_post=""" - Apre=Apre*exp((lastupdate-t)/taupre) - Apost=Apost*exp((lastupdate-t)/taupost) +dApost - w = clip(w+Apre, 0, gmax) - lastupdate = t - """, + """w : 1 + Apre : 1 + Apost : 1 + lastupdate : second""", + on_pre="""Apre=Apre*exp((lastupdate-t)/taupre)+dApre + Apost=Apost*exp((lastupdate-t)/taupost) + w = clip(w+Apost, 0, gmax) + lastupdate = t""", + on_post="""Apre=Apre*exp((lastupdate-t)/taupre) + Apost=Apost*exp((lastupdate-t)/taupost) +dApost + w = clip(w+Apre, 0, gmax) + lastupdate = t""", ) S2.connect(j="i") S1.w = 0.5 * gmax @@ -1884,8 +1894,8 @@ def test_event_driven_dependency_checks(): dummy, dummy, """ - da/dt = (a-b) / (5*ms): 1 (event-driven) - b : 1""", + da/dt = (a-b) / (5*ms): 1 (event-driven) + b : 1""", on_pre="b+=1", ) syn.connect() @@ -1895,9 +1905,9 @@ def test_event_driven_dependency_checks(): dummy, dummy, """ - da/dt = (a-b) / (5*ms): 1 (event-driven) - b = c : 1 - c : 1""", + da/dt = (a-b) / (5*ms): 1 (event-driven) + b = c : 1 + c : 1""", on_pre="c+=1", ) syn2.connect() @@ -1911,9 +1921,9 @@ def test_event_driven_dependency_error(): stim, stim, """ - da/dt = -a / (5*ms) : 1 (event-driven) - db/dt = -b / (5*ms) : 1 (event-driven) - dc/dt = a*b / (5*ms) : 1 (event-driven)""", + da/dt = -a / (5*ms) : 1 (event-driven) + db/dt = -b / (5*ms) : 1 (event-driven) + dc/dt = a*b / (5*ms) : 1 (event-driven)""", on_pre="a+=1", ) syn.connect() @@ -1932,10 +1942,9 @@ def test_event_driven_dependency_error2(): stim, stim, """ - da/dt = -a / (5*ms) : 1 (clock-driven) - db/dt = -b / (5*ms) : 1 (clock-driven) - dc/dt = a*b / (5*ms) : 1 (event-driven) - """, + da/dt = -a / (5*ms) : 1 (clock-driven) + db/dt = -b / (5*ms) : 1 (clock-driven) + dc/dt = a*b / (5*ms) : 1 (event-driven)""", on_pre="a+=1", ) assert "'c'" in str(exc.value) and ( @@ -1948,10 +1957,9 @@ def test_event_driven_dependency_error2(): stim, stim, """ - da/dt = -a / (5*ms) : 1 (clock-driven) - b = a : 1 - dc/dt = b / (5*ms) : 1 (event-driven) - """, + da/dt = -a / (5*ms) : 1 (clock-driven) + b = a : 1 + dc/dt = b / (5*ms) : 1 (event-driven)""", on_pre="a+=1", ) assert ( @@ -1966,10 +1974,9 @@ def test_event_driven_dependency_error3(): Synapses( P, P, - """ - ds/dt = -s/(3*ms) : 1 (event-driven) - df/dt = f*s/(5*ms) : 1 (clock-driven) - """, + """ds/dt = -s/(3*ms) : 1 (event-driven) + df/dt = f*s/(5*ms) : 1 (clock-driven) + """, on_pre="s += 1", ) assert "'s'" in str(ex.value) and "'f'" in str(ex.value) @@ -1979,12 +1986,11 @@ def test_event_driven_dependency_error3(): Synapses( P, P, - """ - ds/dt = -s/(3*ms) : 1 (event-driven) - x = s : 1 - y = x : 1 - df/dt = f*y/(5*ms) : 1 (clock-driven) - """, + """ds/dt = -s/(3*ms) : 1 (event-driven) + x = s : 1 + y = x : 1 + df/dt = f*y/(5*ms) : 1 (clock-driven) + """, on_pre="s += 1", ) assert "'s'" in str(ex.value) and "'f'" in str(ex.value) @@ -1997,19 +2003,13 @@ def test_repr(): S = Synapses( G, G, - """ - w : 1 - dApre/dt = -Apre/taupre : 1 (event-driven) - dApost/dt = -Apost/taupost : 1 (event-driven) - """, - on_pre=""" - Apre += dApre - w = clip(w+Apost, 0, gmax) - """, - on_post=""" - Apost += dApost - w = clip(w+Apre, 0, gmax) - """, + """w : 1 + dApre/dt = -Apre/taupre : 1 (event-driven) + dApost/dt = -Apost/taupost : 1 (event-driven)""", + on_pre="""Apre += dApre + w = clip(w+Apost, 0, gmax)""", + on_post="""Apost += dApost + w = clip(w+Apre, 0, gmax)""", ) # Test that string/LaTeX representations do not raise errors for func in [str, repr, sympy.latex]: @@ -2021,10 +2021,8 @@ def test_pre_post_variables(): G = NeuronGroup(10, "v : 1", threshold="False") G2 = NeuronGroup( 10, - """ - v : 1 - w : 1 - """, + """v : 1 + w : 1""", threshold="False", ) S = Synapses(G, G2, "x : 1") @@ -2063,10 +2061,8 @@ def test_variables_by_owner(): G = NeuronGroup(10, "v : 1") G2 = NeuronGroup( 10, - """ - v : 1 - w : 1 - """, + """v : 1 + w : 1""", ) S = Synapses(G, G2, "x : 1") @@ -2363,19 +2359,16 @@ def test_vectorisation(): source = NeuronGroup(10, "v : 1", threshold="v>1") target = NeuronGroup( 10, - """ - x : 1 - y : 1 - """, + """x : 1 + y : 1""", ) syn = Synapses( source, target, "w_syn : 1", - on_pre=""" - v_pre += w_syn - x_post = y_post - """, + on_pre="""v_pre += w_syn + x_post = y_post + """, ) syn.connect() syn.w_syn = 1 @@ -2394,12 +2387,10 @@ def test_vectorisation_STDP_like(): w_max = 10 neurons = NeuronGroup( 6, - """ - dv/dt = rate : 1 - ge : 1 - rate : Hz - dA/dt = -A/(1*ms) : 1 - """, + """dv/dt = rate : 1 + ge : 1 + rate : Hz + dA/dt = -A/(1*ms) : 1""", threshold="v>1", reset="v=0", ) @@ -2412,19 +2403,15 @@ def test_vectorisation_STDP_like(): syn = Synapses( neurons[:3], neurons[3:], - """ - w_dep : 1 - w_fac : 1 - """, - on_pre=""" - ge_post += w_dep - w_fac - A_pre += 1 - w_dep = clip(w_dep + A_post, 0, w_max) - """, - on_post=""" - A_post += 1 - w_fac = clip(w_fac + A_pre, 0, w_max) - """, + """w_dep : 1 + w_fac : 1""", + on_pre="""ge_post += w_dep - w_fac + A_pre += 1 + w_dep = clip(w_dep + A_post, 0, w_max) + """, + on_post="""A_post += 1 + w_fac = clip(w_fac + A_pre, 0, w_max) + """, ) syn.connect() neurons.rate = 1000 * Hz @@ -2716,10 +2703,8 @@ def test_synapses_to_synapses_summed_variable(): summed_conn = Synapses( source, conn, - """ - w_post = x : integer (summed) - x : integer - """, + """w_post = x : integer (summed) + x : integer""", ) summed_conn.connect("i>=j") summed_conn.x = "i" @@ -3625,17 +3610,47 @@ def test_synaptic_subgroups(): from_3 = syn[3, :] assert len(from_3) == 3 assert all(syn.i[from_3] == 3) + assert all(from_3.i == 3) assert_array_equal(syn.j[from_3], np.arange(3)) + assert_array_equal(from_3.j, np.arange(3)) to_2 = syn[:, 2] assert len(to_2) == 5 assert all(syn.j[to_2] == 2) - assert_array_equal(syn.i[to_2], np.arange(5)) + assert all(to_2.j == 2) mixed = syn[1:3, :2] assert len(mixed) == 4 connections = {(i, j) for i, j in zip(syn.i[mixed], syn.j[mixed])} - assert connections == {(1, 0), (1, 1), (2, 0), (2, 1)} + expected = {(1, 0), (1, 1), (2, 0), (2, 1)} + # assert connections == expected + connections = {(i, j) for i, j in zip(mixed.i[:], mixed.j[:])} + assert connections == expected + + string_based = syn["i > 1 and j % 2 == 0"] + assert len(string_based) == 6 + connections = {(i, j) for i, j in zip(syn.i[string_based], syn.j[string_based])} + expected = {(2, 0), (2, 2), (3, 0), (3, 2), (4, 0), (4, 2)} + assert connections == expected + connections = {(i, j) for i, j in zip(string_based.i[:], string_based.j[:])} + assert connections == expected + + array_based_2d = syn[[2, 3, 4], [0, 2]] # same as above + assert len(array_based_2d) == 6 + connections = {(i, j) for i, j in zip(syn.i[array_based_2d], syn.j[array_based_2d])} + expected = {(2, 0), (2, 2), (3, 0), (3, 2), (4, 0), (4, 2)} + assert connections == expected + connections = {(i, j) for i, j in zip(array_based_2d.i[:], array_based_2d.j[:])} + assert connections == expected + + indices = np.flatnonzero((syn.i > 1) & (syn.j % 2 == 0)) + array_based_1d = syn[indices] # same as above + assert len(array_based_1d) == 6 + connections = {(i, j) for i, j in zip(syn.i[array_based_1d], syn.j[array_based_1d])} + expected = {(2, 0), (2, 2), (3, 0), (3, 2), (4, 0), (4, 2)} + assert connections == expected + connections = {(i, j) for i, j in zip(array_based_1d.i[:], array_based_1d.j[:])} + assert connections == expected @pytest.mark.codegen_independent diff --git a/docs_sphinx/user/models.rst b/docs_sphinx/user/models.rst index 3d6007e80..a9b8d00bf 100644 --- a/docs_sphinx/user/models.rst +++ b/docs_sphinx/user/models.rst @@ -182,11 +182,12 @@ neurons. In general ``G[i:j]`` refers to the neurons with indices from ``i`` to ``j-1``, as in general in Python. For convenience, you can also use a single index, i.e. ``G[i]`` is equivalent -to ``G[i:i+1]``. In some situations, it can be easier to provide a list of -indices instead of a slice, Brian therefore also allows for this syntax. Note -that this is restricted to cases that are strictly equivalent with slicing -syntax, e.g. you can write ``G[[3, 4, 5]]`` instead of ``G[3:6]``, but you -*cannot* write ``G[[3, 5, 7]]`` or ``G[[5, 4, 3]]``. +to ``G[i:i+1]``. Brian also allows a simplified form of numpy's +`integer array indexing `_, +to create "non-contiguous" subgroups (subgroups with "gaps" in them). +You can for example refer to ``G[[0, 2, 4, 6, 8]]`` or ``G[[-5, -3, -1]]``. +There are two restrictions to index arrays for subgroups: they cannot contain +repeated indices, and the indices need to be provided in ascending order. Subgroups can be used in most places where regular groups are used, e.g. their state variables or spiking activity can be recorded using monitors, they can be @@ -194,6 +195,11 @@ connected via `Synapses`, etc. In such situations, indices (e.g. the indices of the neurons to record from in a `StateMonitor`) are relative to the subgroup, not to the main group +.. note:: + + Non-contiguous subgroups (i.e. subgroups created with the index array syntax) + cannot be used as the source/target groups for `Synapses`. + .. admonition:: The following topics are not essential for beginners. | diff --git a/docs_sphinx/user/multicompartmental.rst b/docs_sphinx/user/multicompartmental.rst index 4c3b92d49..5fdcf2671 100644 --- a/docs_sphinx/user/multicompartmental.rst +++ b/docs_sphinx/user/multicompartmental.rst @@ -322,15 +322,22 @@ indices of compartments, or with the distance from the root:: first_compartments = neuron[:3] first_compartments = neuron[0*um:30*um] -However, note that this is restricted to contiguous indices which most of the -time means that all compartments indexed in this way have to be part of the -same section. Such indices can be acquired directly from the morphology:: +Subgroups can also consist of several compartments that are not directly connected to +each other and lie in different sections. This can be achieved by using a list of +indices or a string expression instead of a slice. For example, to create a subgroup +referring to the compartments with indices 1, 2, 5, and 6, you can use:: - axon = neuron[morpho.axon.indices[:]] + subset = neuron[[1, 2, 5, 6]] -or, more concisely:: +The same restrictions as for general :ref:`subgroups` apply: the given indices need to +be in ascending order without duplicates. Note that you can get indices from the +`Morphology`, e.g. by asking for ``morpho.axon.indices[:]``. - axon = neuron[morpho.axon] +With string indexing, the compartments can be selected by referring to their attributes. +For example, to select all compartments that are at a distance of more than 100µm from +the soma, you can use:: + + distal = neuron['distance > 100*um'] Synaptic inputs ~~~~~~~~~~~~~~~ @@ -409,8 +416,3 @@ Again the location of the threshold can be specified with spatial position:: threshold_location=morpho.axon[30*um], refractory='m > 0.4') -Subgroups -~~~~~~~~~ - -In the same way that you can refer to a subset of neurons in a `NeuronGroup`, -you can also refer to a subset of compartments in a `SpatialNeuron` diff --git a/docs_sphinx/user/recording.rst b/docs_sphinx/user/recording.rst index 007ec99f1..34b819d6d 100644 --- a/docs_sphinx/user/recording.rst +++ b/docs_sphinx/user/recording.rst @@ -214,8 +214,8 @@ monitor:: Note that this technique cannot be applied in :ref:`standalone mode `. -Recording random subsets of neurons ------------------------------------ +Recording subsets of neurons +---------------------------- In large networks, you might only be interested in the activity of a random subset of neurons. While you can specify a ``record`` argument @@ -227,65 +227,15 @@ by using a :ref:`subgroup `:: group = NeuronGroup(1000, ...) spike_mon = SpikeMonitor(group[:100]) # only record first 100 neurons -It might seem like a restriction that such a subgroup has to be contiguous, but -the order of neurons in a group does not have any meaning as such; in a randomly -ordered group of neurons, any contiguous group of neurons can be considered a -random subset. If some aspects of your model *do* depend on the position of the -neuron in a group (e.g. a ring model, where neurons are connected based on their -distance in the ring, or a model where initial values or parameters span a -range of values in a regular fashion), then this requires an extra step: instead -of using the order of neurons in the group directly, or depending on the neuron -index ``i``, create a new, shuffled, index variable as part of the model -definition and then depend on this index instead:: - - group = NeuronGroup(10000, '''.... - index : integer (constant)''') - indices = group.i[:] - np.random.shuffle(indices) - group.index = indices - # Then use 'index' in string expressions or use it as an index array - # for initial values/parameters defined as numpy arrays - -If this solution is not feasible for some reason, there is another approach that -works for a `SpikeMonitor`/`EventMonitor`. You can add an additional flag to -each neuron, stating whether it should be recorded or not. Then, you define a -new :doc:`custom event ` that is identical to the event you are -interested in, but additionally requires the flag to be set. E.g. to only record -the spikes of neurons with the ``to_record`` attribute set:: - - group = NeuronGroup(..., '''... - to_record : boolean (constant)''', - threshold='...', reset='...', - events={'recorded_spike': '... and to_record'}) - group.to_record = ... - mon_events = EventMonitor(group, 'recorded_spike') - -Note that this solution will evaluate the threshold condition for each neuron -twice, and is therefore slightly less efficient. There's one additional caveat: -you'll have to manually include ``and not_refractory`` in your ``events`` -definition if your neuron uses refractoriness. This is done automatically -for the ``threshold`` condition, but not for any user-defined events. - -Recording population averages ------------------------------ - -Continuous recordings from large groups over long simulation times can -fill up the working memory quickly: recording a single variable from -1000 neurons for 100 seconds at the default time resolution results in -an array of about 8 Gigabytes. While this issue can be ameliorated using the -above approaches, the downstream data analysis is often based on -population averages. These can be recorded efficiently using a dummy -group and the `Synapses` class' :ref:`summed variable syntax -`:: - - group = NeuronGroup(..., 'dv/dt = ... : volt', ...) - - # Dummy group to store the average membrane potential at every time step - vm_container = NeuronGroup(1, 'average_vm : volt') - - # Synapses averaging the membrane potential of all neurons in group - vm_averager = Synapses(group, vm_container, 'average_vm_post = v_pre/N_pre : volt (summed)') - vm_averager.connect() - - # Monitor recording the average membrane potential - vm_monitor = StateMonitor(vm_container, 'average_vm', record=True) +This also extends to non-contiguous subgroups, e.g. to record every second cell +you can use:: + + group = NeuronGroup(1000, ...) + spike_mon = SpikeMonitor(group[::2]) # record every second neuron + +Note that this is less efficient than recording from a contiguous subgroup, since +for every spike in ``group``, Brian needs to check whether its index is part of the +subgroup (this check is quicker to do when the subgroup is a contiguous range). +In many models, the order of neurons in a group does not have any meaning as such. +If your model is randomly ordered, then recording from the first half of neurons +is equivalent to recording from every second neuron, but more efficient.