Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Resting state integration #1123

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 101 additions & 2 deletions brian2/groups/neurongroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
from brian2.core.spikesource import SpikeSource
from brian2.core.variables import (Variables, LinkedVariable,
DynamicArrayVariable, Subexpression)
from brian2.core.namespace import get_local_namespace
from brian2.equations.equations import (Equations, DIFFERENTIAL_EQUATION,
SUBEXPRESSION, PARAMETER,
check_subexpressions,
extract_constant_subexpressions)
extract_constant_subexpressions,
SingleEquation)
from brian2.equations.refractory import add_refractoriness
from brian2.parsing.expressions import (parse_expression_dimensions,
is_boolean_expression)
Expand All @@ -31,10 +33,16 @@
fail_for_dimension_mismatch)
from brian2.utils.logger import get_logger
from brian2.utils.stringtools import get_identifiers

from brian2.codegen.runtime.numpy_rt.numpy_rt import NumpyCodeObject
from .group import Group, CodeRunner, get_dtype
from .subgroup import Subgroup

try:
from scipy.optimize import root
scipy_available = True
except ImportError:
scipy_available = False

__all__ = ['NeuronGroup']

logger = get_logger(__name__)
Expand Down Expand Up @@ -920,3 +928,94 @@ def add_event_to_text(event):
add_event_to_text(event)

return '\n'.join(text)

def resting_state(self, x0 = {}):
'''
Calculate resting state of the system.

Parameters
----------
x0 : dict
Initial guess for the state variables. If any of the system's state variables are not
added, default value of 0 is mapped as the initial guess to the missing state variables.
Note: Time elapsed to locate the resting state would be lesser for better initial guesses.

Returns
-------
rest_state : dict
Dictioary with pair of state variables and resting state values. Returned values
are represented in SI units.
'''
# check scipy availability
if scipy_available == False:
raise NotImplementedError("Scipy is not available for using `scipy.optimize.root()`")
# check state variables defined in initial guess are valid
if(x0.keys() - self.equations.diff_eq_names):
raise KeyError("Unknown State Variable: {}".format(next(iter(x0.keys() -
self.equations.diff_eq_names))))

# Add 0 as the intial value for non-mentioned state variables in x0
x0.update({name : 0 for name in self.equations.diff_eq_names - x0.keys()})
sorted_variable_values = list(dict(sorted(x0.items())).values())
result = root(_wrapper, sorted_variable_values, args = (self.equations, get_local_namespace(1)))
# check the result message for the status of convergence
if result.success == False:
raise Exception("The model failed to converge at a resting state. Trying better initial guess shall fix the problem")
return dict(zip(sorted(self.equations.diff_eq_names), result.x))

def _evaluate_rhs(eqs, values, namespace=None):
"""
Evaluates the RHS of a system of differential equations for given state
variable values. External constants can be provided via the namespace or
will be taken from the local namespace.
This function could be used for example to find a resting state of the
system, i.e. a fixed point where the RHS of all equations are approximately
0.
Parameters
----------
eqs : `Equations`
The equations
values : dict-like
Values for each of the state variables (differential equations and
parameters).
Returns
-------
rhs : dict
A dictionary with the names of all variables defined by differential
equations as keys and the respective RHS of the equations as values.
"""
# Make a new set of equations, where differential equations are replaced
# by parameters, and a new subexpression defines their RHS.
# E.g. for 'dv/dt = -v / tau : volt' use:
# '''v : volt
# RHS_v = -v / tau : volt'''
new_equations = []
for eq in eqs.values():
if eq.type == DIFFERENTIAL_EQUATION:
new_equations.append(SingleEquation(PARAMETER, eq.varname,
dimensions=eq.dim,
var_type=eq.var_type))
new_equations.append(SingleEquation(SUBEXPRESSION, 'RHS_'+eq.varname,
dimensions=eq.dim/second.dim,
var_type=eq.var_type,
expr=eq.expr))
else:
new_equations.append(eq)
# TODO: Hide this from standalone mode
group = NeuronGroup(1, model=Equations(new_equations),
codeobj_class=NumpyCodeObject,
namespace=namespace)

# Set the values of the state variables/parameters and units are not taken into account
group.set_states(values, units = False)

# Get the values of all RHS_... subexpressions
states = ['RHS_' + name for name in eqs.diff_eq_names]
return group.get_states(states)

def _wrapper(args, equations, namespace):
"""
Function for which root needs to be calculated. Callable function of scipy.optimize.root()
"""
rhs = _evaluate_rhs(equations, {name : arg for name, arg in zip(sorted(equations.diff_eq_names), args)}, namespace)
return [float(rhs['RHS_{}'.format(name)]) for name in sorted(equations.diff_eq_names)]
56 changes: 55 additions & 1 deletion brian2/tests/test_neurongroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
from brian2.units.allunits import second, volt
from brian2.units.fundamentalunits import (DimensionMismatchError,
have_same_dimensions)
from brian2.units.stdunits import ms, mV, Hz
from brian2.units.stdunits import ms, mV, Hz, cm, msiemens, nA
from brian2.units.unitsafefunctions import linspace
from brian2.units.allunits import second, volt, umetre, siemens, ufarad
from brian2.utils.logger import catch_logs


Expand Down Expand Up @@ -1716,6 +1717,57 @@ def test_semantics_mod():
assert_allclose(G.x[:], float_values % 3)
assert_allclose(G.y[:], float_values % 3)

def test_simple_resting_value():
"""
Test the resting state values of the system
"""
# simple model with single dependent variable, here it is not necessary
# to run the model as the resting value is certain
El = - 100
tau = 1 * ms
eqs = '''
dv/dt = (El - v)/tau : 1
'''
grp = NeuronGroup(1, eqs, method = 'exact')
resting_state = grp.resting_state()
assert_allclose(resting_state['v'], El)

# one more example
area = 100 * umetre ** 2
g_L = 1e-2 * siemens * cm ** -2 * area
E_L = 1000
Cm = 1 * ufarad * cm ** -2 * area
grp = NeuronGroup(10, '''dv/dt = I_leak / Cm : volt
I_leak = g_L*(E_L - v) : amp''')
resting_state = grp.resting_state({'v': float(10000)})
assert_allclose(resting_state['v'], E_L)

def test_failed_resting_state():
# check the failed to converge system is correctly notified to the user
area = 20000 * umetre ** 2
Cm = 1 * ufarad * cm ** -2 * area
gl = 5e-5 * siemens * cm ** -2 * area
El = -65 * mV
EK = -90 * mV
ENa = 50 * mV
g_na = 100 * msiemens * cm ** -2 * area
g_kd = 30 * msiemens * cm ** -2 * area
VT = -63 * mV
I = 0.01*nA
eqs = Equations('''
dv/dt = (gl*(El-v) - g_na*(m*m*m)*h*(v-ENa) - g_kd*(n*n*n*n)*(v-EK) + I)/Cm : volt
dm/dt = 0.32*(mV**-1)*(13.*mV-v+VT)/
(exp((13.*mV-v+VT)/(4.*mV))-1.)/ms*(1-m)-0.28*(mV**-1)*(v-VT-40.*mV)/
(exp((v-VT-40.*mV)/(5.*mV))-1.)/ms*m : 1
dn/dt = 0.032*(mV**-1)*(15.*mV-v+VT)/
(exp((15.*mV-v+VT)/(5.*mV))-1.)/ms*(1.-n)-.5*exp((10.*mV-v+VT)/(40.*mV))/ms*n : 1
dh/dt = 0.128*exp((17.*mV-v+VT)/(18.*mV))/ms*(1.-h)-4./(1+exp((40.*mV-v+VT)/(5.*mV)))/ms*h : 1
''')
group = NeuronGroup(1, eqs, method='exponential_euler')
group.v = -70*mV
# very poor choice of initial values causing the convergence to fail
with pytest.raises(Exception):
group.resting_state({'v': 0, 'm': 100000000, 'n': 1000000, 'h': 100000000})

if __name__ == '__main__':
test_set_states()
Expand Down Expand Up @@ -1792,3 +1844,5 @@ def test_semantics_mod():
test_semantics_floor_division()
test_semantics_floating_point_division()
test_semantics_mod()
test_simple_resting_value()
test_failed_resting_state()