Skip to content

Commit 40c75d0

Browse files
authored
Fix state and parameters getter (#1180)
1 parent 0431f83 commit 40c75d0

File tree

3 files changed

+39
-22
lines changed

3 files changed

+39
-22
lines changed

pynestml/codegeneration/nest_desktop_code_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,6 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict:
6969
namespace["nestml_version"] = pynestml.__version__
7070
namespace["neuronName"] = neuron.get_name()
7171
namespace["neuron"] = neuron
72-
namespace["parameters"], namespace["state"] = PythonStandaloneTargetTools.get_neuron_parameters_and_state(neuron.get_name())
72+
namespace["parameters"], namespace["state"] = PythonStandaloneTargetTools.get_neuron_parameters_and_state(neuron.file_path)
7373

7474
return namespace

pynestml/codegeneration/python_standalone_target_tools.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,61 +18,78 @@
1818
#
1919
# You should have received a copy of the GNU General Public License
2020
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
21+
2122
import importlib
23+
import multiprocessing
2224
import os
2325
import sys
26+
import tempfile
2427

2528
from pynestml.frontend.frontend_configuration import FrontendConfiguration
2629
from pynestml.frontend.pynestml_frontend import generate_python_standalone_target
30+
from pynestml.meta_model.ast_model import ASTModel
2731
from pynestml.utils.logger import LoggingLevel, Logger
32+
from pynestml.utils.model_parser import ModelParser
2833

2934

3035
class PythonStandaloneTargetTools:
31-
"""
32-
Helper functions for Python standalone target.
36+
r"""
37+
Helper functions for the Python standalone target.
3338
"""
3439
@classmethod
35-
def _get_model_parameters_and_state(cls, model_name: str):
36-
input_path = os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join(
37-
os.pardir, os.pardir, "models", "neurons", model_name + ".nestml"))))
38-
suffix = "_nestml"
40+
def _get_model_parameters_and_state(cls, nestml_file_name: str):
41+
suffix = ""
3942
module_name = FrontendConfiguration.get_module_name()
40-
target_path = FrontendConfiguration.get_module_name()
41-
generate_python_standalone_target(input_path=input_path,
42-
target_path=target_path,
43-
suffix=suffix,
44-
module_name=module_name,
45-
logging_level="INFO")
43+
target_path = tempfile.mkdtemp(prefix="nestml_python_target_", suffix="", dir=".") # dir = "." is necessary for Python import
44+
# this has to run in a different process, because otherwise the frontend configuration gets overwritten
45+
process = multiprocessing.Process(target=generate_python_standalone_target, kwargs={"input_path": nestml_file_name,
46+
"target_path": target_path,
47+
"suffix": suffix,
48+
"module_name": module_name,
49+
"logging_level": "ERROR"})
50+
process.start()
51+
process.join() # wait for code generation to complete
52+
53+
ast_compilation_unit = ModelParser.parse_file(nestml_file_name)
54+
if ast_compilation_unit is None or len(ast_compilation_unit.get_model_list()) == 0:
55+
raise Exception("Error(s) occurred during code generation; please check error messages")
4656

47-
py_module_name = module_name + "." + model_name + suffix
57+
model: ASTModel = ast_compilation_unit.get_model_list()[0]
58+
model_name = model.get_name()
59+
60+
py_module_name = os.path.basename(target_path) + "." + model_name
4861
module = importlib.import_module(py_module_name)
49-
neuron_name = "Neuron_" + model_name + suffix + "(1.0)"
62+
neuron_name = "Neuron_" + model_name + "(1.0)" # 1.0 is a dummy value for the timestep
5063
neuron = eval("module." + neuron_name)
64+
5165
parameters_list = [p for p in dir(neuron.Parameters_) if not "__" in p]
5266
parameters = {p: getattr(neuron, "get_" + p)() for p in parameters_list}
5367

54-
state_list = [p for p in dir(neuron.State_) if not "__" in p]
68+
if "ode_state_variable_name_to_index" in dir(neuron.State_):
69+
state_list = neuron.State_.ode_state_variable_name_to_index.keys()
70+
else:
71+
state_list = [p for p in dir(neuron.State_) if not "__" in p]
5572
state_vars = {p: getattr(neuron, "get_" + p)() for p in state_list}
5673

5774
return parameters, state_vars
5875

5976
@classmethod
60-
def get_neuron_parameters_and_state(cls, neuron_model_name: str) -> tuple[dict, dict]:
77+
def get_neuron_parameters_and_state(cls, nestml_file_name: str) -> tuple[dict, dict]:
6178
r"""
6279
Get the parameters for the given neuron model. The code is generated for the model for Python standalone target
6380
The parameters and state variables are then queried by creating the neuron in Python standalone simulator.
64-
:param neuron_model_name: Name of the neuron model
81+
:param nestml_file_name: File name of the neuron model
6582
:return: A dictionary of parameters and state variables
6683
"""
67-
parameters, state = cls._get_model_parameters_and_state(neuron_model_name)
84+
parameters, state = cls._get_model_parameters_and_state(nestml_file_name)
6885

6986
if not parameters or not state:
7087
Logger.log_message(None, -1,
71-
"An error occurred while creating the neuron for python standalone target: " + neuron_model_name,
88+
"An error occurred while creating the neuron for Python standalone target: " + nestml_file_name,
7289
None, LoggingLevel.ERROR)
7390
sys.exit(1)
7491
else:
75-
Logger.log_message(None, -1, "The model parameters were successfully queried from python standalone target.",
92+
Logger.log_message(None, -1, "The model parameters were successfully queried from Python standalone target.",
7693
None, LoggingLevel.INFO)
7794

7895
return parameters, state

tests/nest_desktop_tests/nest_desktop_code_generator_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def test_nest_desktop_code_generator(self):
4242
target_path=target_path,
4343
target_platform=target_platform,
4444
module_name=target_path,
45-
logging_level="INFO")
45+
logging_level="DEBUG")
4646

4747
# Read the parameters from the generated json file and match them with the actual values
4848
with open(os.path.join(target_path, "iaf_psc_exp_neuron.json")) as f:

0 commit comments

Comments
 (0)