Skip to content

Commit

Permalink
Merge pull request #702 from emtee14/enh/monte-carlo-callback
Browse files Browse the repository at this point in the history
ENH: Callback function for collecting additional data from Monte Carlo sims
  • Loading branch information
Lucas-Prates authored Nov 22, 2024
2 parents 5856353 + 6c477e3 commit ac4d3af
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Attention: The newest changes should be on top -->

### Added

- ENH: Callback function for collecting additional data from Monte Carlo sims [#702](https://github.com/RocketPy-Team/RocketPy/pull/702)
- ENH: Implement optional plot saving [#597](https://github.com/RocketPy-Team/RocketPy/pull/597)

### Changed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,100 @@
" type=\"impact\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Custom exports using callback functions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have shown, so far, how to perform to use the `MonteCarlo` class and visualize its results. By default, some variables exported to the output files, such as *apogee* and *x_impact*. The `export_list` argument provides a simplified way for the user to export additional variables listed in the documentation, such as *inclination* and *heading*. \n",
"\n",
"There are applications in which you might need to extract more information in the results than the `export_list` argument can handle. To that end, the `MonteCarlo` class has a `data_collector` argument which allows you customize further the output of the simulation.\n",
"\n",
"To exemplify its use, we show how to export the *date* of the environment used in the simulation together with the *average reynolds number* along with the default variables."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use the `stochastic_env`, `stochastic_rocket` and `stochastic_flight` objects previously defined, and only change the `MonteCarlo` object. First, we need to define our customized data collector."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"# Defining custom callback functions\n",
"def get_average_reynolds_number(flight):\n",
" reynold_number_list = flight.reynolds_number(flight.time)\n",
" average_reynolds_number = np.mean(reynold_number_list)\n",
" return average_reynolds_number\n",
"\n",
"\n",
"def get_date(flight):\n",
" return flight.env.date\n",
"\n",
"\n",
"custom_data_collector = {\n",
" \"average_reynolds_number\": get_average_reynolds_number,\n",
" \"date\": get_date,\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `data_collector` must be a dictionary whose keys are the names of the variables we want to export and the values are callback functions (python callables) that compute these variable values. Notice how we can compute complex expressions in this function and just export the result. For instance, the *get_average_reynolds_number* calls the `flight.reynolds_number` method for each value in `flight.time` list and computes the average value using numpy's `mean`. The *date* variable is straightforward.\n",
"\n",
"After we define the data collector, we pass it as an argument to the `MonteCarlo` class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_dispersion = MonteCarlo(\n",
" filename=\"monte_carlo_analysis_outputs/monte_carlo_class_example_customized\",\n",
" environment=stochastic_env,\n",
" rocket=stochastic_rocket,\n",
" flight=stochastic_flight,\n",
" export_list=[\"apogee\", \"apogee_time\", \"x_impact\"],\n",
" data_collector=custom_data_collector,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_dispersion.simulate(number_of_simulations=10, append=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_dispersion.prints.all()"
]
}
],
"metadata": {
Expand All @@ -1134,7 +1228,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.11.2"
}
},
"nbformat": 4,
Expand Down
5 changes: 4 additions & 1 deletion rocketpy/prints/monte_carlo_prints.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,7 @@ def all(self):
print(f"{'Parameter':>25} {'Mean':>15} {'Std. Dev.':>15}")
print("-" * 60)
for key, value in self.monte_carlo.processed_results.items():
print(f"{key:>25} {value[0]:>15.3f} {value[1]:>15.3f}")
try:
print(f"{key:>25} {value[0]:>15.3f} {value[1]:>15.3f}")
except TypeError:
print(f"{key:>25} {str(value[0]):>15} {str(value[1]):>15}")

Check warning on line 30 in rocketpy/prints/monte_carlo_prints.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/prints/monte_carlo_prints.py#L29-L30

Added lines #L29 - L30 were not covered by tests
75 changes: 71 additions & 4 deletions rocketpy/simulation/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class MonteCarlo:
The stochastic flight object to be iterated over.
export_list : list
The list of variables to export at each simulation.
data_collector : dict
A dictionary whose keys are the names of the additional
exported variables and the values are callback functions.
inputs_log : list
List of dictionaries with the inputs used in each simulation.
outputs_log : list
Expand Down Expand Up @@ -80,7 +83,13 @@ class MonteCarlo:
"""

def __init__(
self, filename, environment, rocket, flight, export_list=None
self,
filename,
environment,
rocket,
flight,
export_list=None,
data_collector=None,
): # pylint: disable=too-many-statements
"""
Initialize a MonteCarlo object.
Expand All @@ -104,6 +113,17 @@ def __init__(
`out_of_rail_stability_margin`, `out_of_rail_time`,
`out_of_rail_velocity`, `max_mach_number`, `frontal_surface_wind`,
`lateral_surface_wind`. Default is None.
data_collector : dict, optional
A dictionary whose keys are the names of the exported variables
and the values are callback functions. The keys (variable names) must not
overwrite the default names on 'export_list'. The callback functions receive
a Flight object and returns a value of that variable. For instance
.. code-block:: python
custom_data_collector = {
"max_acceleration": lambda flight: max(flight.acceleration(flight.time)),
"date": lambda flight: flight.env.date,
}
Returns
-------
Expand Down Expand Up @@ -132,6 +152,8 @@ def __init__(
self._last_print_len = 0 # used to print on the same line

self.export_list = self.__check_export_list(export_list)
self._check_data_collector(data_collector)
self.data_collector = data_collector

try:
self.import_inputs()
Expand Down Expand Up @@ -359,6 +381,17 @@ def __export_flight_data(
for export_item in self.export_list
}

if self.data_collector is not None:
additional_exports = {}
for key, callback in self.data_collector.items():
try:
additional_exports[key] = callback(flight)
except Exception as e:
raise ValueError(

Check warning on line 390 in rocketpy/simulation/monte_carlo.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/simulation/monte_carlo.py#L384-L390

Added lines #L384 - L390 were not covered by tests
f"An error was encountered running 'data_collector' callback {key}. "
) from e
results = results | additional_exports

Check warning on line 393 in rocketpy/simulation/monte_carlo.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/simulation/monte_carlo.py#L393

Added line #L393 was not covered by tests

input_file.write(json.dumps(inputs_dict, cls=RocketPyEncoder) + "\n")
output_file.write(json.dumps(results, cls=RocketPyEncoder) + "\n")

Expand Down Expand Up @@ -466,6 +499,37 @@ def __check_export_list(self, export_list):

return export_list

def _check_data_collector(self, data_collector):
"""Check if data collector provided is a valid
Parameters
----------
data_collector : dict
A dictionary whose keys are the names of the exported variables
and the values are callback functions that receive a Flight object
and returns a value of that variable
"""

if data_collector is not None:

if not isinstance(data_collector, dict):
raise ValueError(

Check warning on line 516 in rocketpy/simulation/monte_carlo.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/simulation/monte_carlo.py#L515-L516

Added lines #L515 - L516 were not covered by tests
"Invalid 'data_collector' argument! "
"It must be a dict of callback functions."
)

for key, callback in data_collector.items():
if key in self.export_list:
raise ValueError(

Check warning on line 523 in rocketpy/simulation/monte_carlo.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/simulation/monte_carlo.py#L521-L523

Added lines #L521 - L523 were not covered by tests
"Invalid 'data_collector' key! "
f"Variable names overwrites 'export_list' key '{key}'."
)
if not callable(callback):
raise ValueError(

Check warning on line 528 in rocketpy/simulation/monte_carlo.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/simulation/monte_carlo.py#L527-L528

Added lines #L527 - L528 were not covered by tests
f"Invalid value in 'data_collector' for key '{key}'! "
"Values must be python callables (callback functions)."
)

def __reprint(self, msg, end="\n", flush=False):
"""
Prints a message on the same line as the previous one and replaces the
Expand Down Expand Up @@ -654,9 +718,12 @@ def set_processed_results(self):
"""
self.processed_results = {}
for result, values in self.results.items():
mean = np.mean(values)
stdev = np.std(values)
self.processed_results[result] = (mean, stdev)
try:
mean = np.mean(values)
stdev = np.std(values)
self.processed_results[result] = (mean, stdev)
except TypeError:
self.processed_results[result] = (None, None)

Check warning on line 726 in rocketpy/simulation/monte_carlo.py

View check run for this annotation

Codecov / codecov/patch

rocketpy/simulation/monte_carlo.py#L725-L726

Added lines #L725 - L726 were not covered by tests

# Import methods

Expand Down
55 changes: 55 additions & 0 deletions tests/integration/test_monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,58 @@ def test_monte_carlo_export_ellipses_to_kml(monte_carlo_calisto_pre_loaded):
)

os.remove("monte_carlo_class_example.kml")


@pytest.mark.slow
def test_monte_carlo_callback(monte_carlo_calisto):
"""Tests the data_collector argument of the MonteCarlo class.
Parameters
----------
monte_carlo_calisto : MonteCarlo
The MonteCarlo object, this is a pytest fixture.
"""

# define valid data collector
valid_data_collector = {
"name": lambda flight: flight.name,
"density_t0": lambda flight: flight.env.density(0),
}

monte_carlo_calisto.data_collector = valid_data_collector
# NOTE: this is really slow, it runs 10 flight simulations
monte_carlo_calisto.simulate(number_of_simulations=10, append=False)

# tests if print works when we have None in summary
monte_carlo_calisto.info()

## tests if an error is raised for invalid data_collector definitions
# invalid type
def invalid_data_collector(flight):
return flight.name

with pytest.raises(ValueError):
monte_carlo_calisto._check_data_collector(invalid_data_collector)

# invalid key overwrite
invalid_data_collector = {"apogee": lambda flight: flight.apogee}
with pytest.raises(ValueError):
monte_carlo_calisto._check_data_collector(invalid_data_collector)

# invalid callback definition
invalid_data_collector = {"name": "Calisto"} # callbacks must be callables!
with pytest.raises(ValueError):
monte_carlo_calisto._check_data_collector(invalid_data_collector)

# invalid logic (division by zero)
invalid_data_collector = {
"density_t0": lambda flight: flight.env.density(0) / "0",
}
monte_carlo_calisto.data_collector = invalid_data_collector
# NOTE: this is really slow, it runs 10 flight simulations
with pytest.raises(ValueError):
monte_carlo_calisto.simulate(number_of_simulations=10, append=False)

os.remove("monte_carlo_test.errors.txt")
os.remove("monte_carlo_test.outputs.txt")
os.remove("monte_carlo_test.inputs.txt")

0 comments on commit ac4d3af

Please sign in to comment.