Skip to content

Commit

Permalink
Add rng seed and rng distributions for Nest parameters
Browse files Browse the repository at this point in the history
Update documentation.
  • Loading branch information
drodarie committed May 2, 2024
1 parent cfed414 commit d5aceb8
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 6 deletions.
8 changes: 5 additions & 3 deletions bsb_nest/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
from neo import SpikeTrain
from tqdm import tqdm

from .exceptions import NestConnectError, NestModelError, NestModuleError
from .exceptions import KernelWarning, NestConnectError, NestModelError, NestModuleError

if typing.TYPE_CHECKING:
from bsb import Simulation
from .simulation import NestSimulation


class NestResult(SimulationResult):
Expand Down Expand Up @@ -167,10 +167,12 @@ def create_devices(self, simulation):
for device_model in simulation.devices.values():
device_model.implement(self, simulation, simdata)

def set_settings(self, simulation: "Simulation"):
def set_settings(self, simulation: "NestSimulation"):
nest.set_verbosity(simulation.verbosity)
nest.resolution = simulation.resolution
nest.overwrite_files = True
if simulation.seed is not None:
nest.rng_seed = simulation.seed

def check_comm(self):
if nest.NumProcesses() != MPI.get_size():
Expand Down
6 changes: 4 additions & 2 deletions bsb_nest/cell.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import nest
from bsb import CellModel, config, types
from bsb import CellModel, config

from .distributions import nest_parameter


@config.node
class NestCell(CellModel):
model = config.attr(type=str, default="iaf_psc_alpha")
constants = config.dict(type=types.any_())
constants = config.dict(type=nest_parameter())

def create_population(self, simdata):
n = len(simdata.placement[self])
Expand Down
3 changes: 2 additions & 1 deletion bsb_nest/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from bsb import MPI, ConnectionModel, compose_nodes, config, types
from tqdm import tqdm

from .distributions import nest_parameter
from .exceptions import NestConnectError


Expand All @@ -16,7 +17,7 @@ class NestSynapseSettings:
weight = config.attr(type=float, required=True)
delay = config.attr(type=float, required=True)
receptor_type = config.attr(type=int)
constants = config.catch_all(type=types.any_())
constants = config.catch_all(type=nest_parameter())


@config.node
Expand Down
61 changes: 61 additions & 0 deletions bsb_nest/distributions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import builtins
import typing

import errr
import nest.random.hl_api_random as _distributions
from bsb import DistributionCastError, TypeHandler, config, types

if typing.TYPE_CHECKING:
from bsb import Scaffold

_available_distributions = [d for d in _distributions.__all__]


@config.node
class NestRandomDistribution:
"""
Class to handle NEST random distributions.
"""

scaffold: "Scaffold"
distribution: str = config.attr(
type=types.in_(_available_distributions), required=True
)
"""Distribution name. Should correspond to a function of nest.random.hl_api_random"""
parameters: dict[str, typing.Any] = config.catch_all(type=types.any_())
"""Dictionary of parameters to assign to the distribution. Should correspond to NEST's"""

def __init__(self, **kwargs):
try:
self._distr = getattr(_distributions, self.distribution)(**self.parameters)
except Exception as e:
errr.wrap(
DistributionCastError, e, prepend=f"Can't cast to '{self.distribution}': "
)

def __call__(self):
return self._distr

def __getattr__(self, attr):
if "_distr" not in self.__dict__:
raise AttributeError("No underlying _distr found for distribution node.")
return getattr(self._distr, attr)


class nest_parameter(TypeHandler):
"""
Type validator. Type casts the value or node to a Nest parameter, that can be either a value or
a NestRandomDistribution.
"""

def __call__(self, value, _key=None, _parent=None):
if isinstance(value, builtins.dict) and "distribution" in value.keys():
return NestRandomDistribution(**value, _key=_key, _parent=_parent)
return value

@property
def __name__(self): # pragma: nocover
return "nest_parameter"

def __inv__(self, value):
return value
6 changes: 6 additions & 0 deletions bsb_nest/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@ class NestSimulation(Simulation):
"""

modules = config.list(type=str)
"""List of NEST modules to load at the beginning of the simulation"""
threads = config.attr(type=types.int(min=1), default=1)
"""Number of threads to use during simulation"""
resolution = config.attr(type=types.float(min=0.0), required=True)
"""Simulation time step size in milliseconds"""
verbosity = config.attr(type=str, default="M_ERROR")
"""NEST verbosity level"""
seed = config.attr(type=int, default=None)
"""Random seed for the simulations"""

cell_models = config.dict(type=NestCell, required=True)
connection_models = config.dict(type=NestConnection, required=True)
Expand Down

0 comments on commit d5aceb8

Please sign in to comment.