diff --git a/brian2/groups/neurongroup.py b/brian2/groups/neurongroup.py index 916901b6f..d252fe280 100644 --- a/brian2/groups/neurongroup.py +++ b/brian2/groups/neurongroup.py @@ -944,15 +944,19 @@ def resting_state(self, x0 = {}): Dictioary with pair of state variables and resting state values. Returned values are represented in SI units. ''' + # check whether the model is currently unsupported + if self.thresholder != {} or self.events != {}: + raise NotImplementedError('Event based and Neuron-specific models are currently not supported for resting state calculation') if(x0.keys() - self.equations.diff_eq_names): - raise KeyError("Unknown State Variable: {}".format(next(iter(x0.keys() - self.equations.diff_eq_names)))) + raise KeyError("Unknown State Variable: {}".format(next(iter(x0.keys() - + self.equations.diff_eq_names)))) # Add 0 as the intial value for non-mentioned state variables in x0 x0.update({name : 0 for name in self.equations.diff_eq_names - x0.keys()}) - - return dict(zip(sorted(self.equations.diff_eq_names), root(_wrapper, list(dict(sorted(x0.items())).values()), - args = (self.equations, get_local_namespace(1))).x)) + sorted_variable_values = list(dict(sorted(x0.items())).values()) + result = root(_wrapper, sorted_variable_values, args = (self.equations, get_local_namespace(1))) + return dict(zip(sorted(self.equations.diff_eq_names), result.x)) def _evaluate_rhs(eqs, values, namespace=None): """ diff --git a/brian2/tests/test_neurongroup.py b/brian2/tests/test_neurongroup.py index a997ef248..53a95976d 100644 --- a/brian2/tests/test_neurongroup.py +++ b/brian2/tests/test_neurongroup.py @@ -1,7 +1,7 @@ from __future__ import division from __future__ import absolute_import import uuid - +import sys import sympy import numpy as np from numpy.testing.utils import assert_raises, assert_equal @@ -22,8 +22,8 @@ from brian2.units.fundamentalunits import (DimensionMismatchError, have_same_dimensions) from brian2.units.unitsafefunctions import linspace -from brian2.units.allunits import second, volt -from brian2.units.stdunits import ms, mV, Hz +from brian2.units.allunits import second, volt, umetre, siemens, ufarad +from brian2.units.stdunits import ms, mV, Hz, cm from brian2.utils.logger import catch_logs from brian2.tests.utils import assert_allclose @@ -1639,6 +1639,36 @@ def test_semantics_mod(): assert_allclose(G.x[:], float_values % 3) assert_allclose(G.y[:], float_values % 3) +def test_resting_value(): + """ + Test the resting state values of the system + """ + # simple model with single dependent variable, here it is not necessary + # to run the model as the resting value is certain + epsilon = sys.float_info.epsilon + El = - 100 + tau = 1 * ms + eqs = ''' + dv/dt = (El - v)/tau : 1 + ''' + grp = NeuronGroup(1, eqs, method = 'exact') + resting_state = grp.resting_state() + assert abs(resting_state['v'] - El) < epsilon * max(abs(resting_state['v']), abs(El)) + + # one more example + area = 100 * umetre ** 2 + g_L = 1e-2 * siemens * cm ** -2 * area + E_L = 1000 + Cm = 1 * ufarad * cm ** -2 * area + grp = NeuronGroup(10, '''dv/dt = I_leak / Cm : volt + I_leak = g_L*(E_L - v) : amp''') + resting_state = grp.resting_state({'v': float(10000)}) + assert abs(resting_state['v'] - E_L) < epsilon * max(abs(resting_state['v']), abs(E_L)) + + # check unsupported models are identified + tau = 10 * ms + grp = NeuronGroup(1, 'dv/dt = -v/tau : volt', threshold='v > -50*mV', reset='v = -70*mV') + assert_raises(NotImplementedError, lambda: grp.resting_state()) if __name__ == '__main__': test_set_states() @@ -1714,3 +1744,4 @@ def test_semantics_mod(): test_semantics_floor_division() test_semantics_floating_point_division() test_semantics_mod() + test_resting_value() \ No newline at end of file