Skip to content

Commit

Permalink
Merge branch 'develop' into docs/red-2023-flight-sim
Browse files Browse the repository at this point in the history
  • Loading branch information
LUCKIN13 authored Nov 23, 2024
2 parents 239a396 + ac4d3af commit 05882f1
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Attention: The newest changes should be on top -->
### Added

- DOC: Camoes Flight Example [#733](https://github.com/RocketPy-Team/RocketPy/pull/733)
- ENH: Callback function for collecting additional data from Monte Carlo sims [#702](https://github.com/RocketPy-Team/RocketPy/pull/702)
- ENH: Implement optional plot saving [#597](https://github.com/RocketPy-Team/RocketPy/pull/597)

### Changed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,100 @@
" type=\"impact\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Custom exports using callback functions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have shown, so far, how to perform to use the `MonteCarlo` class and visualize its results. By default, some variables exported to the output files, such as *apogee* and *x_impact*. The `export_list` argument provides a simplified way for the user to export additional variables listed in the documentation, such as *inclination* and *heading*. \n",
"\n",
"There are applications in which you might need to extract more information in the results than the `export_list` argument can handle. To that end, the `MonteCarlo` class has a `data_collector` argument which allows you customize further the output of the simulation.\n",
"\n",
"To exemplify its use, we show how to export the *date* of the environment used in the simulation together with the *average reynolds number* along with the default variables."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use the `stochastic_env`, `stochastic_rocket` and `stochastic_flight` objects previously defined, and only change the `MonteCarlo` object. First, we need to define our customized data collector."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"# Defining custom callback functions\n",
"def get_average_reynolds_number(flight):\n",
" reynold_number_list = flight.reynolds_number(flight.time)\n",
" average_reynolds_number = np.mean(reynold_number_list)\n",
" return average_reynolds_number\n",
"\n",
"\n",
"def get_date(flight):\n",
" return flight.env.date\n",
"\n",
"\n",
"custom_data_collector = {\n",
" \"average_reynolds_number\": get_average_reynolds_number,\n",
" \"date\": get_date,\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `data_collector` must be a dictionary whose keys are the names of the variables we want to export and the values are callback functions (python callables) that compute these variable values. Notice how we can compute complex expressions in this function and just export the result. For instance, the *get_average_reynolds_number* calls the `flight.reynolds_number` method for each value in `flight.time` list and computes the average value using numpy's `mean`. The *date* variable is straightforward.\n",
"\n",
"After we define the data collector, we pass it as an argument to the `MonteCarlo` class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_dispersion = MonteCarlo(\n",
" filename=\"monte_carlo_analysis_outputs/monte_carlo_class_example_customized\",\n",
" environment=stochastic_env,\n",
" rocket=stochastic_rocket,\n",
" flight=stochastic_flight,\n",
" export_list=[\"apogee\", \"apogee_time\", \"x_impact\"],\n",
" data_collector=custom_data_collector,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_dispersion.simulate(number_of_simulations=10, append=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_dispersion.prints.all()"
]
}
],
"metadata": {
Expand All @@ -1134,7 +1228,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.11.2"
}
},
"nbformat": 4,
Expand Down
5 changes: 4 additions & 1 deletion rocketpy/prints/monte_carlo_prints.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,7 @@ def all(self):
print(f"{'Parameter':>25} {'Mean':>15} {'Std. Dev.':>15}")
print("-" * 60)
for key, value in self.monte_carlo.processed_results.items():
print(f"{key:>25} {value[0]:>15.3f} {value[1]:>15.3f}")
try:
print(f"{key:>25} {value[0]:>15.3f} {value[1]:>15.3f}")
except TypeError:
print(f"{key:>25} {str(value[0]):>15} {str(value[1]):>15}")
75 changes: 71 additions & 4 deletions rocketpy/simulation/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class MonteCarlo:
The stochastic flight object to be iterated over.
export_list : list
The list of variables to export at each simulation.
data_collector : dict
A dictionary whose keys are the names of the additional
exported variables and the values are callback functions.
inputs_log : list
List of dictionaries with the inputs used in each simulation.
outputs_log : list
Expand Down Expand Up @@ -80,7 +83,13 @@ class MonteCarlo:
"""

def __init__(
self, filename, environment, rocket, flight, export_list=None
self,
filename,
environment,
rocket,
flight,
export_list=None,
data_collector=None,
): # pylint: disable=too-many-statements
"""
Initialize a MonteCarlo object.
Expand All @@ -104,6 +113,17 @@ def __init__(
`out_of_rail_stability_margin`, `out_of_rail_time`,
`out_of_rail_velocity`, `max_mach_number`, `frontal_surface_wind`,
`lateral_surface_wind`. Default is None.
data_collector : dict, optional
A dictionary whose keys are the names of the exported variables
and the values are callback functions. The keys (variable names) must not
overwrite the default names on 'export_list'. The callback functions receive
a Flight object and returns a value of that variable. For instance
.. code-block:: python
custom_data_collector = {
"max_acceleration": lambda flight: max(flight.acceleration(flight.time)),
"date": lambda flight: flight.env.date,
}
Returns
-------
Expand Down Expand Up @@ -132,6 +152,8 @@ def __init__(
self._last_print_len = 0 # used to print on the same line

self.export_list = self.__check_export_list(export_list)
self._check_data_collector(data_collector)
self.data_collector = data_collector

try:
self.import_inputs()
Expand Down Expand Up @@ -359,6 +381,17 @@ def __export_flight_data(
for export_item in self.export_list
}

if self.data_collector is not None:
additional_exports = {}
for key, callback in self.data_collector.items():
try:
additional_exports[key] = callback(flight)
except Exception as e:
raise ValueError(
f"An error was encountered running 'data_collector' callback {key}. "
) from e
results = results | additional_exports

input_file.write(json.dumps(inputs_dict, cls=RocketPyEncoder) + "\n")
output_file.write(json.dumps(results, cls=RocketPyEncoder) + "\n")

Expand Down Expand Up @@ -466,6 +499,37 @@ def __check_export_list(self, export_list):

return export_list

def _check_data_collector(self, data_collector):
"""Check if data collector provided is a valid
Parameters
----------
data_collector : dict
A dictionary whose keys are the names of the exported variables
and the values are callback functions that receive a Flight object
and returns a value of that variable
"""

if data_collector is not None:

if not isinstance(data_collector, dict):
raise ValueError(
"Invalid 'data_collector' argument! "
"It must be a dict of callback functions."
)

for key, callback in data_collector.items():
if key in self.export_list:
raise ValueError(
"Invalid 'data_collector' key! "
f"Variable names overwrites 'export_list' key '{key}'."
)
if not callable(callback):
raise ValueError(
f"Invalid value in 'data_collector' for key '{key}'! "
"Values must be python callables (callback functions)."
)

def __reprint(self, msg, end="\n", flush=False):
"""
Prints a message on the same line as the previous one and replaces the
Expand Down Expand Up @@ -654,9 +718,12 @@ def set_processed_results(self):
"""
self.processed_results = {}
for result, values in self.results.items():
mean = np.mean(values)
stdev = np.std(values)
self.processed_results[result] = (mean, stdev)
try:
mean = np.mean(values)
stdev = np.std(values)
self.processed_results[result] = (mean, stdev)
except TypeError:
self.processed_results[result] = (None, None)

# Import methods

Expand Down
55 changes: 55 additions & 0 deletions tests/integration/test_monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,58 @@ def test_monte_carlo_export_ellipses_to_kml(monte_carlo_calisto_pre_loaded):
)

os.remove("monte_carlo_class_example.kml")


@pytest.mark.slow
def test_monte_carlo_callback(monte_carlo_calisto):
"""Tests the data_collector argument of the MonteCarlo class.
Parameters
----------
monte_carlo_calisto : MonteCarlo
The MonteCarlo object, this is a pytest fixture.
"""

# define valid data collector
valid_data_collector = {
"name": lambda flight: flight.name,
"density_t0": lambda flight: flight.env.density(0),
}

monte_carlo_calisto.data_collector = valid_data_collector
# NOTE: this is really slow, it runs 10 flight simulations
monte_carlo_calisto.simulate(number_of_simulations=10, append=False)

# tests if print works when we have None in summary
monte_carlo_calisto.info()

## tests if an error is raised for invalid data_collector definitions
# invalid type
def invalid_data_collector(flight):
return flight.name

with pytest.raises(ValueError):
monte_carlo_calisto._check_data_collector(invalid_data_collector)

# invalid key overwrite
invalid_data_collector = {"apogee": lambda flight: flight.apogee}
with pytest.raises(ValueError):
monte_carlo_calisto._check_data_collector(invalid_data_collector)

# invalid callback definition
invalid_data_collector = {"name": "Calisto"} # callbacks must be callables!
with pytest.raises(ValueError):
monte_carlo_calisto._check_data_collector(invalid_data_collector)

# invalid logic (division by zero)
invalid_data_collector = {
"density_t0": lambda flight: flight.env.density(0) / "0",
}
monte_carlo_calisto.data_collector = invalid_data_collector
# NOTE: this is really slow, it runs 10 flight simulations
with pytest.raises(ValueError):
monte_carlo_calisto.simulate(number_of_simulations=10, append=False)

os.remove("monte_carlo_test.errors.txt")
os.remove("monte_carlo_test.outputs.txt")
os.remove("monte_carlo_test.inputs.txt")

0 comments on commit 05882f1

Please sign in to comment.