Skip to content

Commit

Permalink
TST: adds more unit tests to the codebase
Browse files Browse the repository at this point in the history
MNT: linters

TST: complementing tests for sensitivity analysis and removing duplicate piece of code.

DEV: add pragma comments to exclude specific lines from coverage

MNT: fix pylint error
  • Loading branch information
Gui-FernandesBR committed Dec 22, 2024
1 parent 6c656f5 commit 9bd4383
Show file tree
Hide file tree
Showing 15 changed files with 440 additions and 32 deletions.
2 changes: 1 addition & 1 deletion rocketpy/environment/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2577,7 +2577,7 @@ def set_earth_geometry(self, datum):
}
try:
return ellipsoid[datum]
except KeyError as e:
except KeyError as e: # pragma: no cover
available_datums = ', '.join(ellipsoid.keys())
raise AttributeError(
f"The reference system '{datum}' is not recognized. Please use one of "
Expand Down
8 changes: 5 additions & 3 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3119,12 +3119,12 @@ def compose(self, func, extrapolate=False):
The result of inputting the function into the function.
"""
# Check if the input is a function
if not isinstance(func, Function):
if not isinstance(func, Function): # pragma: no cover
raise TypeError("Input must be a Function object.")

if isinstance(self.source, np.ndarray) and isinstance(func.source, np.ndarray):
# Perform bounds check for composition
if not extrapolate:
if not extrapolate: # pragma: no cover
if func.min < self.x_initial or func.max > self.x_final:
raise ValueError(
f"Input Function image {func.min, func.max} must be within "
Expand Down Expand Up @@ -3197,7 +3197,7 @@ def savetxt(

# create the datapoints
if callable(self.source):
if lower is None or upper is None or samples is None:
if lower is None or upper is None or samples is None: # pragma: no cover
raise ValueError(
"If the source is a callable, lower, upper and samples"
+ " must be provided."
Expand Down Expand Up @@ -3323,6 +3323,7 @@ def __validate_inputs(self, inputs):
if isinstance(inputs, (list, tuple)):
if len(inputs) == 1:
return inputs
# pragma: no cover
raise ValueError(
"Inputs must be a string or a list of strings with "
"the length of the domain dimension."
Expand All @@ -3335,6 +3336,7 @@ def __validate_inputs(self, inputs):
isinstance(i, str) for i in inputs
):
return inputs
# pragma: no cover
raise ValueError(
"Inputs must be a list of strings with "
"the length of the domain dimension."
Expand Down
2 changes: 1 addition & 1 deletion rocketpy/rocket/aero_surface/nose_cone.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def bluffness(self, value):
raise ValueError(
"Parameter 'bluffness' must be None or 0 when using a nose cone kind 'powerseries'."
)
if value is not None and not (0 <= value <= 1): # pragma: no cover
if value is not None and not 0 <= value <= 1: # pragma: no cover
raise ValueError(
f"Bluffness ratio of {value} is out of range. "
"It must be between 0 and 1."
Expand Down
8 changes: 2 additions & 6 deletions rocketpy/sensitivity/sensitivity_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,6 @@ def set_target_variables_nominal(self, target_variables_nominal_value):
self.target_variables_info[target_variable]["nominal_value"] = (
target_variables_nominal_value[i]
)
for i, target_variable in enumerate(self.target_variables_names):
self.target_variables_info[target_variable]["nominal_value"] = (
target_variables_nominal_value[i]
)

self._nominal_target_passed = True

Expand Down Expand Up @@ -356,12 +352,12 @@ def __check_requirements(self):
version = ">=0" if not version else version
try:
check_requirement_version(module_name, version)
except (ValueError, ImportError) as e:
except (ValueError, ImportError) as e: # pragma: no cover
has_error = True
print(
f"The following error occurred while importing {module_name}: {e}"
)
if has_error:
if has_error: # pragma: no cover
print(
"Given the above errors, some methods may not work. Please run "
+ "'pip install rocketpy[sensitivity]' to install extra requirements."
Expand Down
10 changes: 5 additions & 5 deletions rocketpy/simulation/flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
self.env = environment
self.rocket = rocket
self.rail_length = rail_length
if self.rail_length <= 0:
if self.rail_length <= 0: # pragma: no cover
raise ValueError("Rail length must be a positive value.")
self.parachutes = self.rocket.parachutes[:]
self.inclination = inclination
Expand Down Expand Up @@ -951,7 +951,7 @@ def __simulate(self, verbose):
for t_root in t_roots
if abs(t_root.imag) < 0.001 and 0 < t_root.real < t1
]
if len(valid_t_root) > 1:
if len(valid_t_root) > 1: # pragma: no cover
raise ValueError(
"Multiple roots found when solving for impact time."
)
Expand Down Expand Up @@ -1226,7 +1226,7 @@ def __init_controllers(self):
self._controllers = self.rocket._controllers[:]
self.sensors = self.rocket.sensors.get_components()
if self._controllers or self.sensors:
if self.time_overshoot:
if self.time_overshoot: # pragma: no cover
self.time_overshoot = False
warnings.warn(
"time_overshoot has been set to False due to the presence "
Expand Down Expand Up @@ -1266,7 +1266,7 @@ def __set_ode_solver(self, solver):
else:
try:
self._solver = ODE_SOLVER_MAP[solver]
except KeyError as e:
except KeyError as e: # pragma: no cover
raise ValueError(
f"Invalid ``ode_solver`` input: {solver}. "
f"Available options are: {', '.join(ODE_SOLVER_MAP.keys())}"
Expand Down Expand Up @@ -1398,7 +1398,7 @@ def udot_rail1(self, t, u, post_processing=False):

return [vx, vy, vz, ax, ay, az, 0, 0, 0, 0, 0, 0, 0]

def udot_rail2(self, t, u, post_processing=False):
def udot_rail2(self, t, u, post_processing=False): # pragma: no cover
"""[Still not implemented] Calculates derivative of u state vector with
respect to time when rocket is flying in 3 DOF motion in the rail.
Expand Down
5 changes: 2 additions & 3 deletions rocketpy/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,9 +978,8 @@ def wrapper(*args, **kwargs):
for i in range(max_attempts):
try:
return func(*args, **kwargs)
except (
Exception
) as e: # pragma: no cover # pylint: disable=broad-except
# pylint: disable=broad-except
except Exception as e: # pragma: no cover
if i == max_attempts - 1:
raise e from None
delay = min(delay * 2, max_delay)
Expand Down
23 changes: 21 additions & 2 deletions tests/fixtures/surfaces/surface_fixtures.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import pytest

from rocketpy import NoseCone, RailButtons, Tail, TrapezoidalFins
from rocketpy.rocket.aero_surface.fins.free_form_fins import FreeFormFins
from rocketpy.rocket.aero_surface import (
EllipticalFins,
FreeFormFins,
NoseCone,
RailButtons,
Tail,
TrapezoidalFins,
)


@pytest.fixture
Expand Down Expand Up @@ -94,3 +100,16 @@ def calisto_rail_buttons():
angular_position=45,
name="Rail Buttons",
)


@pytest.fixture
def elliptical_fin_set():
return EllipticalFins(
n=4,
span=0.100,
root_chord=0.120,
rocket_radius=0.0635,
cant_angle=0,
airfoil=None,
name="Test Elliptical Fins",
)
68 changes: 68 additions & 0 deletions tests/unit/test_aero_surfaces.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import patch

import pytest

from rocketpy import NoseCone
Expand Down Expand Up @@ -71,3 +73,69 @@ def test_powerseries_nosecones_setters(power, invalid_power, new_power):
expected_k = (2 * new_power) / ((2 * new_power) + 1)

assert pytest.approx(test_nosecone.k) == expected_k


@patch("matplotlib.pyplot.show")
def test_elliptical_fins_draw(
mock_show, elliptical_fin_set
): # pylint: disable=unused-argument
assert elliptical_fin_set.plots.draw(filename=None) is None


def test_nose_cone_info(calisto_nose_cone):
assert calisto_nose_cone.info() is None


@patch("matplotlib.pyplot.show")
def test_nose_cone_draw(
mock_show, calisto_nose_cone
): # pylint: disable=unused-argument
assert calisto_nose_cone.draw(filename=None) is None


def test_trapezoidal_fins_info(calisto_trapezoidal_fins):
assert calisto_trapezoidal_fins.info() is None


def test_trapezoidal_fins_tip_chord_setter(calisto_trapezoidal_fins):
calisto_trapezoidal_fins.tip_chord = 0.1
assert calisto_trapezoidal_fins.tip_chord == 0.1


def test_trapezoidal_fins_root_chord_setter(calisto_trapezoidal_fins):
calisto_trapezoidal_fins.root_chord = 0.1
assert calisto_trapezoidal_fins.root_chord == 0.1


def test_trapezoidal_fins_sweep_angle_setter(calisto_trapezoidal_fins):
calisto_trapezoidal_fins.sweep_angle = 0.1
assert calisto_trapezoidal_fins.sweep_angle == 0.1


def test_trapezoidal_fins_sweep_length_setter(calisto_trapezoidal_fins):
calisto_trapezoidal_fins.sweep_length = 0.1
assert calisto_trapezoidal_fins.sweep_length == 0.1


def test_tail_info(calisto_tail):
assert calisto_tail.info() is None


def test_tail_length_setter(calisto_tail):
calisto_tail.length = 0.1
assert calisto_tail.length == 0.1


def test_tail_rocket_radius_setter(calisto_tail):
calisto_tail.rocket_radius = 0.1
assert calisto_tail.rocket_radius == 0.1


def test_tail_bottom_radius_setter(calisto_tail):
calisto_tail.bottom_radius = 0.1
assert calisto_tail.bottom_radius == 0.1


def test_tail_top_radius_setter(calisto_tail):
calisto_tail.top_radius = 0.1
assert calisto_tail.top_radius == 0.1
10 changes: 10 additions & 0 deletions tests/unit/test_flight_time_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,13 @@ def test_time_node_lt(flight_calisto):
node2 = flight_calisto.TimeNodes.TimeNode(2.0, [], [], [])
assert node1 < node2
assert not node2 < node1


def test_time_node_repr(flight_calisto):
node = flight_calisto.TimeNodes.TimeNode(1.0, [], [], [])
assert isinstance(repr(node), str)


def test_time_nodes_repr(flight_calisto):
time_nodes = flight_calisto.TimeNodes()
assert isinstance(repr(time_nodes), str)
104 changes: 104 additions & 0 deletions tests/unit/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,3 +787,107 @@ def test_low_pass_filter(alpha):
f"The filtered value at index {i} is not the expected value. "
f"Expected: {expected}, Actual: {filtered_func.source[i][1]}"
)


def test_average_function_ndarray():

dummy_function = Function(
source=[
[0, 0],
[1, 1],
[2, 0],
[3, 1],
[4, 0],
[5, 1],
[6, 0],
[7, 1],
[8, 0],
[9, 1],
],
inputs=["x"],
outputs=["y"],
)
avg_function = dummy_function.average_function()

assert isinstance(avg_function, Function)
assert np.isclose(avg_function(0), 0)
assert np.isclose(avg_function(9), 0.5)


def test_average_function_callable():

dummy_function = Function(lambda x: 2)
avg_function = dummy_function.average_function(lower=0)

assert isinstance(avg_function, Function)
assert np.isclose(avg_function(1), 2)
assert np.isclose(avg_function(9), 2)


@pytest.mark.parametrize(
"lower, upper, sampling_frequency, window_size, step_size, remove_dc, only_positive",
[
(0, 10, 100, 1, 0.5, True, True),
(0, 10, 100, 1, 0.5, True, False),
(0, 10, 100, 1, 0.5, False, True),
(0, 10, 100, 1, 0.5, False, False),
(0, 20, 200, 2, 1, True, True),
],
)
def test_short_time_fft(
lower, upper, sampling_frequency, window_size, step_size, remove_dc, only_positive
):
"""Test the short_time_fft method of the Function class.
Parameters
----------
lower : float
Lower bound of the time range.
upper : float
Upper bound of the time range.
sampling_frequency : float
Sampling frequency at which to perform the Fourier transform.
window_size : float
Size of the window for the STFT, in seconds.
step_size : float
Step size for the window, in seconds.
remove_dc : bool
If True, the DC component is removed from each window before
computing the Fourier transform.
only_positive: bool
If True, only the positive frequencies are returned.
"""
# Generate a test signal
t = np.linspace(lower, upper, int((upper - lower) * sampling_frequency))
signal = np.sin(2 * np.pi * 5 * t) # 5 Hz sine wave
func = Function(np.column_stack((t, signal)))

# Perform STFT
stft_results = func.short_time_fft(
lower=lower,
upper=upper,
sampling_frequency=sampling_frequency,
window_size=window_size,
step_size=step_size,
remove_dc=remove_dc,
only_positive=only_positive,
)

# Check the results
assert isinstance(stft_results, list)
assert all(isinstance(f, Function) for f in stft_results)

for f in stft_results:
assert f.get_inputs() == ["Frequency (Hz)"]
assert f.get_outputs() == ["Amplitude"]
assert f.get_interpolation_method() == "linear"
assert f.get_extrapolation_method() == "zero"

frequencies = f.source[:, 0]
# amplitudes = f.source[:, 1]

if only_positive:
assert np.all(frequencies >= 0)
else:
assert np.all(frequencies >= -sampling_frequency / 2)
assert np.all(frequencies <= sampling_frequency / 2)
Loading

0 comments on commit 9bd4383

Please sign in to comment.