-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add rng seed and rng distributions for Nest parameters
Update documentation.
- Loading branch information
Showing
5 changed files
with
78 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import builtins | ||
import typing | ||
|
||
import errr | ||
import nest.random.hl_api_random as _distributions | ||
from bsb import DistributionCastError, TypeHandler, config, types | ||
|
||
if typing.TYPE_CHECKING: | ||
from bsb import Scaffold | ||
|
||
_available_distributions = [d for d in _distributions.__all__] | ||
|
||
|
||
@config.node | ||
class NestRandomDistribution: | ||
""" | ||
Class to handle NEST random distributions. | ||
""" | ||
|
||
scaffold: "Scaffold" | ||
distribution: str = config.attr( | ||
type=types.in_(_available_distributions), required=True | ||
) | ||
"""Distribution name. Should correspond to a function of nest.random.hl_api_random""" | ||
parameters: dict[str, typing.Any] = config.catch_all(type=types.any_()) | ||
"""Dictionary of parameters to assign to the distribution. Should correspond to NEST's""" | ||
|
||
def __init__(self, **kwargs): | ||
try: | ||
self._distr = getattr(_distributions, self.distribution)(**self.parameters) | ||
except Exception as e: | ||
errr.wrap( | ||
DistributionCastError, e, prepend=f"Can't cast to '{self.distribution}': " | ||
) | ||
|
||
def __call__(self): | ||
return self._distr | ||
|
||
def __getattr__(self, attr): | ||
if "_distr" not in self.__dict__: | ||
raise AttributeError("No underlying _distr found for distribution node.") | ||
return getattr(self._distr, attr) | ||
|
||
|
||
class nest_parameter(TypeHandler): | ||
""" | ||
Type validator. Type casts the value or node to a Nest parameter, that can be either a value or | ||
a NestRandomDistribution. | ||
""" | ||
|
||
def __call__(self, value, _key=None, _parent=None): | ||
if isinstance(value, builtins.dict) and "distribution" in value.keys(): | ||
return NestRandomDistribution(**value, _key=_key, _parent=_parent) | ||
return value | ||
|
||
@property | ||
def __name__(self): # pragma: nocover | ||
return "nest_parameter" | ||
|
||
def __inv__(self, value): | ||
return value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters