Skip to content

Commit

Permalink
Add test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
vigneswaran-chandrasekaran committed Nov 7, 2019
1 parent 3ac0adf commit 6d77b8e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 7 deletions.
12 changes: 8 additions & 4 deletions brian2/groups/neurongroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,15 +944,19 @@ def resting_state(self, x0 = {}):
Dictioary with pair of state variables and resting state values. Returned values
are represented in SI units.
'''
# check whether the model is currently unsupported
if self.thresholder != {} or self.events != {}:
raise NotImplementedError('Event based and Neuron-specific models are currently not supported for resting state calculation')

if(x0.keys() - self.equations.diff_eq_names):
raise KeyError("Unknown State Variable: {}".format(next(iter(x0.keys() - self.equations.diff_eq_names))))
raise KeyError("Unknown State Variable: {}".format(next(iter(x0.keys() -
self.equations.diff_eq_names))))

# Add 0 as the intial value for non-mentioned state variables in x0
x0.update({name : 0 for name in self.equations.diff_eq_names - x0.keys()})

return dict(zip(sorted(self.equations.diff_eq_names), root(_wrapper, list(dict(sorted(x0.items())).values()),
args = (self.equations, get_local_namespace(1))).x))
sorted_variable_values = list(dict(sorted(x0.items())).values())
result = root(_wrapper, sorted_variable_values, args = (self.equations, get_local_namespace(1)))
return dict(zip(sorted(self.equations.diff_eq_names), result.x))

def _evaluate_rhs(eqs, values, namespace=None):
"""
Expand Down
37 changes: 34 additions & 3 deletions brian2/tests/test_neurongroup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import division
from __future__ import absolute_import
import uuid

import sys
import sympy
import numpy as np
from numpy.testing.utils import assert_raises, assert_equal
Expand All @@ -22,8 +22,8 @@
from brian2.units.fundamentalunits import (DimensionMismatchError,
have_same_dimensions)
from brian2.units.unitsafefunctions import linspace
from brian2.units.allunits import second, volt
from brian2.units.stdunits import ms, mV, Hz
from brian2.units.allunits import second, volt, umetre, siemens, ufarad
from brian2.units.stdunits import ms, mV, Hz, cm
from brian2.utils.logger import catch_logs

from brian2.tests.utils import assert_allclose
Expand Down Expand Up @@ -1639,6 +1639,36 @@ def test_semantics_mod():
assert_allclose(G.x[:], float_values % 3)
assert_allclose(G.y[:], float_values % 3)

def test_resting_value():
"""
Test the resting state values of the system
"""
# simple model with single dependent variable, here it is not necessary
# to run the model as the resting value is certain
epsilon = sys.float_info.epsilon
El = - 100
tau = 1 * ms
eqs = '''
dv/dt = (El - v)/tau : 1
'''
grp = NeuronGroup(1, eqs, method = 'exact')
resting_state = grp.resting_state()
assert abs(resting_state['v'] - El) < epsilon * max(abs(resting_state['v']), abs(El))

# one more example
area = 100 * umetre ** 2
g_L = 1e-2 * siemens * cm ** -2 * area
E_L = 1000
Cm = 1 * ufarad * cm ** -2 * area
grp = NeuronGroup(10, '''dv/dt = I_leak / Cm : volt
I_leak = g_L*(E_L - v) : amp''')
resting_state = grp.resting_state({'v': float(10000)})
assert abs(resting_state['v'] - E_L) < epsilon * max(abs(resting_state['v']), abs(E_L))

# check unsupported models are identified
tau = 10 * ms
grp = NeuronGroup(1, 'dv/dt = -v/tau : volt', threshold='v > -50*mV', reset='v = -70*mV')
assert_raises(NotImplementedError, lambda: grp.resting_state())

if __name__ == '__main__':
test_set_states()
Expand Down Expand Up @@ -1714,3 +1744,4 @@ def test_semantics_mod():
test_semantics_floor_division()
test_semantics_floating_point_division()
test_semantics_mod()
test_resting_value()

0 comments on commit 6d77b8e

Please sign in to comment.