Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/spring_cleaning' into jax
Browse files Browse the repository at this point in the history
  • Loading branch information
1b15 committed Mar 12, 2024
2 parents aa9a80e + 19212bb commit 7774761
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 22 deletions.
12 changes: 6 additions & 6 deletions neurolib/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_bold_variable(self, variables):
default_index = self.state_vars.index(self.default_output)
return variables[default_index]

def simulateBold(self, bold_variable, append=False):
def simulateBold(self, bold_variable, append=True):
"""Gets the default output of the model and simulates the BOLD model.
Adds the simulated BOLD signal to outputs.
"""
Expand Down Expand Up @@ -192,7 +192,7 @@ def run(
The model can be run in three different ways:
1) `model.run()` starts a new run.
2) `model.run(chunkwise=True)` runs the simulation in chunks of length `chunksize`.
3) `model.run(continue_run=True)` continues the simulation of a previous run.
3) `model.run(continue_run=True)` continues the simulation of a previous run. This has no effect during the first run.
:param inputs: list of inputs to the model, must have the same order as model.input_vars. Note: no sanity check is performed for performance reasons. Take care of the inputs yourself.
:type inputs: list[np.ndarray]
Expand All @@ -202,9 +202,9 @@ def run(
:type chunksize: int, optional
:param bold: simulate BOLD signal (only for chunkwise integration), defaults to False
:type bold: bool, optional
:param append_outputs: append new and chunkwise outputs to the outputs attribute, defaults to False. Note: BOLD outputs are always appended
:param append_outputs: append new and chunkwise outputs to the outputs attribute, defaults to False. Note: BOLD outputs are always appended.
:type append_outputs: bool, optional
:param continue_run: continue a simulation by using the initial values from a previous simulation
:param continue_run: continue a simulation by using the initial values from a previous simulation. This has no effect during the first run.
:type continue_run: bool
"""
self.initializeRun(initializeBold=bold)
Expand Down Expand Up @@ -325,8 +325,8 @@ def storeOutputsAndStates(self, t, variables, append=False):

def setInitialValuesToLastState(self):
"""Reads the last state of the model and sets the initial conditions to that state for continuing a simulation."""
if not hasattr(self, "t"):
raise ValueError("You tried using continue_run=True on the first run.")
if not all([sv in self.state for sv in self.state_vars]):
return
for iv, sv in zip(self.init_vars, self.state_vars):
# if state variables are one-dimensional (in space only)
if (self.state[sv].ndim == 0) or (self.state[sv].ndim == 1):
Expand Down
16 changes: 3 additions & 13 deletions neurolib/models/multimodel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ def __init__(self, model_instance):
self.boldInitialized = False
self.params["sampling_dt"] = self.params["sampling_dt"] or self.params["dt"]

self.start_t = 0.0

logging.info(f"{self.name}: Model initialized.")

def _set_model_params(self):
Expand Down Expand Up @@ -201,13 +199,11 @@ def integrate(self, append_outputs=False, simulate_bold=False, noise_input=None)

# bold simulation after integration
if simulate_bold and self.boldInitialized:
self.simulateBold(result[self.default_output].values.T, append=append_outputs)
self.simulateBold(result[self.default_output].values.T, append=True)

def setInitialValuesToLastState(self):
if not hasattr(self, "t"):
raise ValueError("You tried using continue_run=True on the first run.")
# set start t for next run for the last value now
self.start_t = self.t[-1]
if not self.state:
return
new_initial_state = np.zeros((self.model_instance.initial_state.shape[0], self.maxDelay + 1))
total_vars_counter = 0
for node_idx, node_vars in enumerate(self.state_vars):
Expand All @@ -217,12 +213,6 @@ def setInitialValuesToLastState(self):
# set initial state
self.model_instance.initial_state = new_initial_state

def clearModelState(self):
# set start_t to zero again
self.start_t = 0.0
# `clearModelState` as per base class
super().clearModelState()

def integrateChunkwise(self, chunksize, bold, append_outputs):
raise NotImplementedError("for now...")

Expand Down
3 changes: 0 additions & 3 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ def test_init(self):
self.assertEqual(model.model_instance, fhn_net)
self.assertTrue(isinstance(model.params, star_dotdict))
self.assertTrue(model.integration is None)
self.assertEqual(model.start_t, 0.0)
self.assertEqual(model.num_noise_variables, 4)
self.assertEqual(model.num_state_variables, 4)
max_delay = int(DELAY / model.params["dt"])
Expand Down Expand Up @@ -352,7 +351,6 @@ def test_continue_run_node(self):
self.assertAlmostEqual(model.t[0] - last_t, model.params["dt"] / 1000.0)
# assert start_t is reset to 0, when continue_run=False
model.run()
self.assertEqual(model.start_t, 0.0)

def test_continue_run_network(self):
DELAY = 13.0
Expand All @@ -371,7 +369,6 @@ def test_continue_run_network(self):
self.assertAlmostEqual(model.t[0] - last_t, model.params["dt"] / 1000.0)
# assert start_t is reset to 0, when continue_run=False
model.run()
self.assertEqual(model.start_t, 0.0)


if __name__ == "__main__":
Expand Down

0 comments on commit 7774761

Please sign in to comment.