Skip to content

Commit

Permalink
Merge main on feature/ref_parser
Browse files Browse the repository at this point in the history
  • Loading branch information
drodarie committed May 20, 2024
2 parents 7c379db + c21176a commit e105f80
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 8 deletions.
8 changes: 5 additions & 3 deletions bsb_nest/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
from neo import SpikeTrain
from tqdm import tqdm

from .exceptions import NestConnectError, NestModelError, NestModuleError
from .exceptions import KernelWarning, NestConnectError, NestModelError, NestModuleError

if typing.TYPE_CHECKING:
from bsb import Simulation
from .simulation import NestSimulation


class NestResult(SimulationResult):
Expand Down Expand Up @@ -167,10 +167,12 @@ def create_devices(self, simulation):
for device_model in simulation.devices.values():
device_model.implement(self, simulation, simdata)

def set_settings(self, simulation: "Simulation"):
def set_settings(self, simulation: "NestSimulation"):
nest.set_verbosity(simulation.verbosity)
nest.resolution = simulation.resolution
nest.overwrite_files = True
if simulation.seed is not None:
nest.rng_seed = simulation.seed

def check_comm(self):
if nest.NumProcesses() != MPI.get_size():
Expand Down
13 changes: 10 additions & 3 deletions bsb_nest/cell.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import nest
from bsb import CellModel, config, types
from bsb import CellModel, config

from .distributions import NestRandomDistribution, nest_parameter


@config.node
class NestCell(CellModel):
model = config.attr(type=str, default="iaf_psc_alpha")
constants = config.dict(type=types.any_())
constants = config.dict(type=nest_parameter())

def create_population(self, simdata):
n = len(simdata.placement[self])
Expand All @@ -15,7 +17,12 @@ def create_population(self, simdata):
return population

def set_constants(self, population):
population.set(self.constants)
population.set(
{
k: (v() if isinstance(v, NestRandomDistribution) else v)
for k, v in self.constants.items()
}
)

def set_parameters(self, population, simdata):
ps = simdata.placement[self]
Expand Down
3 changes: 2 additions & 1 deletion bsb_nest/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from bsb import MPI, ConnectionModel, compose_nodes, config, types
from tqdm import tqdm

from .distributions import nest_parameter
from .exceptions import NestConnectError


Expand All @@ -16,7 +17,7 @@ class NestSynapseSettings:
weight = config.attr(type=float, required=True)
delay = config.attr(type=float, required=True)
receptor_type = config.attr(type=int)
constants = config.catch_all(type=types.any_())
constants = config.catch_all(type=nest_parameter())


@config.node
Expand Down
62 changes: 62 additions & 0 deletions bsb_nest/distributions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import builtins
import typing

import errr
import nest.random.hl_api_random as _distributions
from bsb import DistributionCastError, TypeHandler, config, types

if typing.TYPE_CHECKING:
from bsb import Scaffold

_available_distributions = [d for d in _distributions.__all__]


@config.node
class NestRandomDistribution:
"""
Class to handle NEST random distributions.
"""

scaffold: "Scaffold"
distribution: str = config.attr(
type=types.in_(_available_distributions), required=True
)
"""Distribution name. Should correspond to a function of nest.random.hl_api_random"""
parameters: dict[str, typing.Any] = config.catch_all(type=types.any_())
"""Dictionary of parameters to assign to the distribution. Should correspond to NEST's"""

def __init__(self, **kwargs):
try:
self._distr = getattr(_distributions, self.distribution)(**self.parameters)
except Exception as e:
errr.wrap(
DistributionCastError, e, prepend=f"Can't cast to '{self.distribution}': "
)

def __call__(self):
return self._distr

def __getattr__(self, attr):
# hasattr does not work here. So we use __dict__
if "_distr" not in self.__dict__:
raise AttributeError("No underlying _distr found for distribution node.")
return getattr(self._distr, attr)


class nest_parameter(TypeHandler):
"""
Type validator. Type casts the value or node to a Nest parameter, that can be either a value or
a NestRandomDistribution.
"""

def __call__(self, value, _key=None, _parent=None):
if isinstance(value, builtins.dict) and "distribution" in value.keys():
return NestRandomDistribution(**value, _key=_key, _parent=_parent)
return value

@property
def __name__(self): # pragma: nocover
return "nest parameter"

def __inv__(self, value):
return value
6 changes: 6 additions & 0 deletions bsb_nest/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@ class NestSimulation(Simulation):
"""

modules = config.list(type=str)
"""List of NEST modules to load at the beginning of the simulation"""
threads = config.attr(type=types.int(min=1), default=1)
"""Number of threads to use during simulation"""
resolution = config.attr(type=types.float(min=0.0), required=True)
"""Simulation time step size in milliseconds"""
verbosity = config.attr(type=str, default="M_ERROR")
"""NEST verbosity level"""
seed = config.attr(type=int, default=None)
"""Random seed for the simulations"""

cell_models = config.dict(type=NestCell, required=True)
connection_models = config.dict(type=NestConnection, required=True)
Expand Down
89 changes: 88 additions & 1 deletion tests/test_nest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import nest
import numpy as np
from bsb import BootError, ConfigurationError
from bsb import BootError, CastError, ConfigurationError
from bsb.config import Configuration
from bsb.core import Scaffold
from bsb.services import MPI
Expand Down Expand Up @@ -345,3 +345,90 @@ def test_dc_generator(self):
v_ms[int(50 / resolution) + 1 : int(60 / resolution) + 1] > -70,
"Current injected should raise membrane potential",
)

def test_nest_randomness(self):
nest.ResetKernel()
nest.resolution = 0.1
nest.rng_seed = 1234
# gif_cond_exp implements a random spiking process.
# So it's perfect to test the seed
A = nest.Create(
"gif_cond_exp",
1,
params={"I_e": 200.0, "V_m": nest.random.normal(mean=-70, std=20.0)},
)
spikeA = nest.Create("spike_recorder")
nest.Connect(A, spikeA)
nest.Simulate(1000.0)
spike_times_nest = spikeA.get("events")["times"]
print(spike_times_nest)

conf = {
"name": "test",
"storage": {"engine": "hdf5"},
"network": {"x": 1, "y": 1, "z": 1},
"partitions": {"B": {"type": "layer", "thickness": 1}},
"cell_types": {"A": {"spatial": {"radius": 1, "count": 1}}},
"placement": {
"placement_A": {
"strategy": "bsb.placement.strategy.FixedPositions",
"cell_types": ["A"],
"partitions": ["B"],
"positions": [[1, 1, 1]],
}
},
"connectivity": {},
"after_connectivity": {},
"simulations": {
"test": {
"simulator": "nest",
"duration": 1000,
"resolution": 0.1,
"seed": 1234,
"cell_models": {
"A": {
"model": "gif_cond_exp",
"constants": {
"I_e": 200.0,
"V_m": {
"distribution": "normal",
"mean": -70,
"std": 20.0,
},
},
}
},
"connection_models": {},
"devices": {
"record_A_spikes": {
"device": "spike_recorder",
"delay": 0.5,
"targetting": {
"strategy": "cell_model",
"cell_models": ["A"],
},
}
},
}
},
}
cfg = Configuration(conf)
netw = Scaffold(cfg, self.storage)
netw.compile()
results = netw.run_simulation("test")
spike_times_bsb = results.spiketrains[0]
self.assertClose(np.array(spike_times_nest), np.array(spike_times_bsb))
self.assertEqual(
cfg.__tree__()["simulations"]["test"]["cell_models"]["A"]["constants"]["V_m"],
{
"distribution": "normal",
"mean": -70,
"std": 20.0,
},
)
# Test with an unknown distribution
conf["simulations"]["test"]["cell_models"]["A"]["constants"]["V_m"][
"distribution"
] = "bean"
with self.assertRaises(CastError):
Configuration(conf)

0 comments on commit e105f80

Please sign in to comment.