Skip to content

Commit

Permalink
Merge pull request #748 from RocketPy-Team/enh/new-ode-solvers
Browse files Browse the repository at this point in the history
ENH: Allow for Alternative and Custom ODE Solvers.
  • Loading branch information
phmbressan authored Dec 7, 2024
2 parents 39d47cf + 83aa20e commit 1e06469
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 13 deletions.
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@
"pytest",
"pytz",
"quantile",
"Radau",
"Rdot",
"referece",
"relativetoground",
Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Attention: The newest changes should be on top -->

### Added

-
- ENH: Allow for Alternative and Custom ODE Solvers. [#748](https://github.com/RocketPy-Team/RocketPy/pull/748)

### Changed

Expand Down
74 changes: 63 additions & 11 deletions rocketpy/simulation/flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import simplekml
from scipy import integrate
from scipy.integrate import BDF, DOP853, LSODA, RK23, RK45, OdeSolver, Radau

from ..mathutils.function import Function, funcify_method
from ..mathutils.vector_matrix import Matrix, Vector
Expand All @@ -24,8 +24,19 @@
quaternions_to_spin,
)

ODE_SOLVER_MAP = {
'RK23': RK23,
'RK45': RK45,
'DOP853': DOP853,
'Radau': Radau,
'BDF': BDF,
'LSODA': LSODA,
}

class Flight: # pylint: disable=too-many-public-methods

# pylint: disable=too-many-public-methods
# pylint: disable=too-many-instance-attributes
class Flight:
"""Keeps all flight information and has a method to simulate flight.
Attributes
Expand Down Expand Up @@ -506,6 +517,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
verbose=False,
name="Flight",
equations_of_motion="standard",
ode_solver="LSODA",
):
"""Run a trajectory simulation.
Expand Down Expand Up @@ -581,10 +593,23 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
more restricted set of equations of motion that only works for
solid propulsion rockets. Such equations were used in RocketPy v0
and are kept here for backwards compatibility.
ode_solver : str, ``scipy.integrate.OdeSolver``, optional
Integration method to use to solve the equations of motion ODE.
Available options are: 'RK23', 'RK45', 'DOP853', 'Radau', 'BDF',
'LSODA' from ``scipy.integrate.solve_ivp``.
Default is 'LSODA', which is recommended for most flights.
A custom ``scipy.integrate.OdeSolver`` can be passed as well.
For more information on the integration methods, see the scipy
documentation [1]_.
Returns
-------
None
References
----------
.. [1] https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html
"""
# Save arguments
self.env = environment
Expand All @@ -605,6 +630,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
self.terminate_on_apogee = terminate_on_apogee
self.name = name
self.equations_of_motion = equations_of_motion
self.ode_solver = ode_solver

# Controller initialization
self.__init_controllers()
Expand Down Expand Up @@ -651,15 +677,16 @@ def __simulate(self, verbose):

# Create solver for this flight phase # TODO: allow different integrators
self.function_evaluations.append(0)
phase.solver = integrate.LSODA(

phase.solver = self._solver(
phase.derivative,
t0=phase.t,
y0=self.y_sol,
t_bound=phase.time_bound,
min_step=self.min_time_step,
max_step=self.max_time_step,
rtol=self.rtol,
atol=self.atol,
max_step=self.max_time_step,
min_step=self.min_time_step,
)

# Initialize phase time nodes
Expand Down Expand Up @@ -691,13 +718,14 @@ def __simulate(self, verbose):
for node_index, node in self.time_iterator(phase.time_nodes):
# Determine time bound for this time node
node.time_bound = phase.time_nodes[node_index + 1].t
# NOTE: Setting the time bound and status for the phase solver,
# and updating its internal state for the next integration step.
phase.solver.t_bound = node.time_bound
phase.solver._lsoda_solver._integrator.rwork[0] = phase.solver.t_bound
phase.solver._lsoda_solver._integrator.call_args[4] = (
phase.solver._lsoda_solver._integrator.rwork
)
if self.__is_lsoda:
phase.solver._lsoda_solver._integrator.rwork[0] = (
phase.solver.t_bound
)
phase.solver._lsoda_solver._integrator.call_args[4] = (
phase.solver._lsoda_solver._integrator.rwork
)
phase.solver.status = "running"

# Feed required parachute and discrete controller triggers
Expand Down Expand Up @@ -1185,6 +1213,8 @@ def __init_solver_monitors(self):
self.t = self.solution[-1][0]
self.y_sol = self.solution[-1][1:]

self.__set_ode_solver(self.ode_solver)

def __init_equations_of_motion(self):
"""Initialize equations of motion."""
if self.equations_of_motion == "solid_propulsion":
Expand Down Expand Up @@ -1222,6 +1252,28 @@ def __cache_sensor_data(self):
sensor_data[sensor] = sensor.measured_data[:]
self.sensor_data = sensor_data

def __set_ode_solver(self, solver):
"""Sets the ODE solver to be used in the simulation.
Parameters
----------
solver : str, ``scipy.integrate.OdeSolver``
Integration method to use to solve the equations of motion ODE,
or a custom ``scipy.integrate.OdeSolver``.
"""
if isinstance(solver, OdeSolver):
self._solver = solver
else:
try:
self._solver = ODE_SOLVER_MAP[solver]
except KeyError as e:
raise ValueError(
f"Invalid ``ode_solver`` input: {solver}. "
f"Available options are: {', '.join(ODE_SOLVER_MAP.keys())}"
) from e

self.__is_lsoda = hasattr(self._solver, "_lsoda_solver")

@cached_property
def effective_1rl(self):
"""Original rail length minus the distance measured from nozzle exit
Expand Down
39 changes: 38 additions & 1 deletion tests/integration/test_flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@


@patch("matplotlib.pyplot.show")
def test_all_info(mock_show, flight_calisto_robust): # pylint: disable=unused-argument
# pylint: disable=unused-argument
def test_all_info(mock_show, flight_calisto_robust):
"""Test that the flight class is working as intended. This basically calls
the all_info() method and checks if it returns None. It is not testing if
the values are correct, but whether the method is working without errors.
Expand All @@ -27,6 +28,42 @@ def test_all_info(mock_show, flight_calisto_robust): # pylint: disable=unused-a
assert flight_calisto_robust.all_info() is None


@pytest.mark.slow
@patch("matplotlib.pyplot.show")
@pytest.mark.parametrize("solver_method", ["RK45", "DOP853", "Radau", "BDF"])
# RK23 is unstable and requires a very low tolerance to work
# pylint: disable=unused-argument
def test_all_info_different_solvers(
mock_show, calisto_robust, example_spaceport_env, solver_method
):
"""Test that the flight class is working as intended with different solver
methods. This basically calls the all_info() method and checks if it returns
None. It is not testing if the values are correct, but whether the method is
working without errors.
Parameters
----------
mock_show : unittest.mock.MagicMock
Mock object to replace matplotlib.pyplot.show
calisto_robust : rocketpy.Rocket
Rocket to be simulated. See the conftest.py file for more info.
example_spaceport_env : rocketpy.Environment
Environment to be simulated. See the conftest.py file for more info.
solver_method : str
The solver method to be used in the simulation.
"""
test_flight = Flight(
environment=example_spaceport_env,
rocket=calisto_robust,
rail_length=5.2,
inclination=85,
heading=0,
terminate_on_apogee=False,
ode_solver=solver_method,
)
assert test_flight.all_info() is None


class TestExportData:
"""Tests the export_data method of the Flight class."""

Expand Down

0 comments on commit 1e06469

Please sign in to comment.