diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 75688ff8e..795087850 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -3117,12 +3117,12 @@ def compose(self, func, extrapolate=False): The result of inputting the function into the function. """ # Check if the input is a function - if not isinstance(func, Function): + if not isinstance(func, Function): # pragma: no cover raise TypeError("Input must be a Function object.") if isinstance(self.source, np.ndarray) and isinstance(func.source, np.ndarray): # Perform bounds check for composition - if not extrapolate: + if not extrapolate: # pragma: no cover if func.min < self.x_initial or func.max > self.x_final: raise ValueError( f"Input Function image {func.min, func.max} must be within " @@ -3195,7 +3195,7 @@ def savetxt( # create the datapoints if callable(self.source): - if lower is None or upper is None or samples is None: + if lower is None or upper is None or samples is None: # pragma: no cover raise ValueError( "If the source is a callable, lower, upper and samples" + " must be provided." @@ -3321,6 +3321,7 @@ def __validate_inputs(self, inputs): if isinstance(inputs, (list, tuple)): if len(inputs) == 1: return inputs + # pragma: no cover raise ValueError( "Inputs must be a string or a list of strings with " "the length of the domain dimension." @@ -3333,6 +3334,7 @@ def __validate_inputs(self, inputs): isinstance(i, str) for i in inputs ): return inputs + # pragma: no cover raise ValueError( "Inputs must be a list of strings with " "the length of the domain dimension." diff --git a/rocketpy/simulation/flight.py b/rocketpy/simulation/flight.py index 87ab6b33a..f49fee50f 100644 --- a/rocketpy/simulation/flight.py +++ b/rocketpy/simulation/flight.py @@ -615,7 +615,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements self.env = environment self.rocket = rocket self.rail_length = rail_length - if self.rail_length <= 0: + if self.rail_length <= 0: # pragma: no cover raise ValueError("Rail length must be a positive value.") self.parachutes = self.rocket.parachutes[:] self.inclination = inclination @@ -951,7 +951,7 @@ def __simulate(self, verbose): for t_root in t_roots if abs(t_root.imag) < 0.001 and 0 < t_root.real < t1 ] - if len(valid_t_root) > 1: + if len(valid_t_root) > 1: # pragma: no cover raise ValueError( "Multiple roots found when solving for impact time." ) @@ -1226,7 +1226,7 @@ def __init_controllers(self): self._controllers = self.rocket._controllers[:] self.sensors = self.rocket.sensors.get_components() if self._controllers or self.sensors: - if self.time_overshoot: + if self.time_overshoot: # pragma: no cover self.time_overshoot = False warnings.warn( "time_overshoot has been set to False due to the presence " @@ -1266,7 +1266,7 @@ def __set_ode_solver(self, solver): else: try: self._solver = ODE_SOLVER_MAP[solver] - except KeyError as e: + except KeyError as e: # pragma: no cover raise ValueError( f"Invalid ``ode_solver`` input: {solver}. " f"Available options are: {', '.join(ODE_SOLVER_MAP.keys())}" @@ -1398,7 +1398,7 @@ def udot_rail1(self, t, u, post_processing=False): return [vx, vy, vz, ax, ay, az, 0, 0, 0, 0, 0, 0, 0] - def udot_rail2(self, t, u, post_processing=False): + def udot_rail2(self, t, u, post_processing=False): # pragma: no cover """[Still not implemented] Calculates derivative of u state vector with respect to time when rocket is flying in 3 DOF motion in the rail. diff --git a/tests/unit/test_function.py b/tests/unit/test_function.py index df540c1ae..776bf7530 100644 --- a/tests/unit/test_function.py +++ b/tests/unit/test_function.py @@ -787,3 +787,107 @@ def test_low_pass_filter(alpha): f"The filtered value at index {i} is not the expected value. " f"Expected: {expected}, Actual: {filtered_func.source[i][1]}" ) + + +def test_average_function_ndarray(): + + dummy_function = Function( + source=[ + [0, 0], + [1, 1], + [2, 0], + [3, 1], + [4, 0], + [5, 1], + [6, 0], + [7, 1], + [8, 0], + [9, 1], + ], + inputs=["x"], + outputs=["y"], + ) + avg_function = dummy_function.average_function() + + assert isinstance(avg_function, Function) + assert np.isclose(avg_function(0), 0) + assert np.isclose(avg_function(9), 0.5) + + +def test_average_function_callable(): + + dummy_function = Function(lambda x: 2) + avg_function = dummy_function.average_function(lower=0) + + assert isinstance(avg_function, Function) + assert np.isclose(avg_function(1), 2) + assert np.isclose(avg_function(9), 2) + + +@pytest.mark.parametrize( + "lower, upper, sampling_frequency, window_size, step_size, remove_dc, only_positive", + [ + (0, 10, 100, 1, 0.5, True, True), + (0, 10, 100, 1, 0.5, True, False), + (0, 10, 100, 1, 0.5, False, True), + (0, 10, 100, 1, 0.5, False, False), + (0, 20, 200, 2, 1, True, True), + ], +) +def test_short_time_fft( + lower, upper, sampling_frequency, window_size, step_size, remove_dc, only_positive +): + """Test the short_time_fft method of the Function class. + + Parameters + ---------- + lower : float + Lower bound of the time range. + upper : float + Upper bound of the time range. + sampling_frequency : float + Sampling frequency at which to perform the Fourier transform. + window_size : float + Size of the window for the STFT, in seconds. + step_size : float + Step size for the window, in seconds. + remove_dc : bool + If True, the DC component is removed from each window before + computing the Fourier transform. + only_positive: bool + If True, only the positive frequencies are returned. + """ + # Generate a test signal + t = np.linspace(lower, upper, int((upper - lower) * sampling_frequency)) + signal = np.sin(2 * np.pi * 5 * t) # 5 Hz sine wave + func = Function(np.column_stack((t, signal))) + + # Perform STFT + stft_results = func.short_time_fft( + lower=lower, + upper=upper, + sampling_frequency=sampling_frequency, + window_size=window_size, + step_size=step_size, + remove_dc=remove_dc, + only_positive=only_positive, + ) + + # Check the results + assert isinstance(stft_results, list) + assert all(isinstance(f, Function) for f in stft_results) + + for f in stft_results: + assert f.get_inputs() == ["Frequency (Hz)"] + assert f.get_outputs() == ["Amplitude"] + assert f.get_interpolation_method() == "linear" + assert f.get_extrapolation_method() == "zero" + + frequencies = f.source[:, 0] + amplitudes = f.source[:, 1] + + if only_positive: + assert np.all(frequencies >= 0) + else: + assert np.all(frequencies >= -sampling_frequency / 2) + assert np.all(frequencies <= sampling_frequency / 2)