From 2218f0f38f842cc5bfae684b46f673af824480e7 Mon Sep 17 00:00:00 2001 From: Yogiraj Gutte <53410698+yogirajgutte@users.noreply.github.com> Date: Mon, 16 Dec 2024 08:22:33 +0530 Subject: [PATCH] MNT: move piecewise functions to separate file (#746) * MNT: move piecewise functions to separate file closes #667 * improved import for linting * MNT: applying code formaters * ENH: simplifying and optimizing the function, implementing tests. * MNT: update changelog and apply changes suggested in review --------- Co-authored-by: Lucas Prates <57069366+Lucas-Prates@users.noreply.github.com> Co-authored-by: Lucas de Oliveira Prates Co-authored-by: Gui-FernandesBR <63590233+Gui-FernandesBR@users.noreply.github.com> --- CHANGELOG.md | 1 + rocketpy/mathutils/__init__.py | 8 +- rocketpy/mathutils/function.py | 103 ----------------------- rocketpy/mathutils/piecewise_function.py | 94 +++++++++++++++++++++ rocketpy/motors/tank_geometry.py | 3 +- tests/unit/test_piecewise_function.py | 35 ++++++++ 6 files changed, 134 insertions(+), 110 deletions(-) create mode 100644 rocketpy/mathutils/piecewise_function.py create mode 100644 tests/unit/test_piecewise_function.py diff --git a/CHANGELOG.md b/CHANGELOG.md index be9484b04..7e951ff67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,7 @@ Attention: The newest changes should be on top --> ### Changed +- MNT: move piecewise functions to separate file [#746](https://github.com/RocketPy-Team/RocketPy/pull/746) - DOC: flight comparison improvements [#755](https://github.com/RocketPy-Team/RocketPy/pull/755) ### Fixed diff --git a/rocketpy/mathutils/__init__.py b/rocketpy/mathutils/__init__.py index fad155583..181b40e55 100644 --- a/rocketpy/mathutils/__init__.py +++ b/rocketpy/mathutils/__init__.py @@ -1,7 +1,3 @@ -from .function import ( - Function, - PiecewiseFunction, - funcify_method, - reset_funcified_methods, -) +from .function import Function, funcify_method, reset_funcified_methods +from .piecewise_function import PiecewiseFunction from .vector_matrix import Matrix, Vector diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index ca6005cf3..8ae7a2100 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -3419,109 +3419,6 @@ def __validate_extrapolation(self, extrapolation): return extrapolation -class PiecewiseFunction(Function): - """Class for creating piecewise functions. These kind of functions are - defined by a dictionary of functions, where the keys are tuples that - represent the domain of the function. The domains must be disjoint. - """ - - def __new__( - cls, - source, - inputs=None, - outputs=None, - interpolation="spline", - extrapolation=None, - datapoints=100, - ): - """ - Creates a piecewise function from a dictionary of functions. The keys of - the dictionary must be tuples that represent the domain of the function. - The domains must be disjoint. The piecewise function will be evaluated - at datapoints points to create Function object. - - Parameters - ---------- - source: dictionary - A dictionary of Function objects, where the keys are the domains. - inputs : list of strings - A list of strings that represent the inputs of the function. - outputs: list of strings - A list of strings that represent the outputs of the function. - interpolation: str - The type of interpolation to use. The default value is 'spline'. - extrapolation: str - The type of extrapolation to use. The default value is None. - datapoints: int - The number of points in which the piecewise function will be - evaluated to create a base function. The default value is 100. - """ - if inputs is None: - inputs = ["Scalar"] - if outputs is None: - outputs = ["Scalar"] - # Check if source is a dictionary - if not isinstance(source, dict): - raise TypeError("source must be a dictionary") - # Check if all keys are tuples - for key in source.keys(): - if not isinstance(key, tuple): - raise TypeError("keys of source must be tuples") - # Check if all domains are disjoint - for key1 in source.keys(): - for key2 in source.keys(): - if key1 != key2: - if key1[0] < key2[1] and key1[1] > key2[0]: - raise ValueError("domains must be disjoint") - - # Crate Function - def calc_output(func, inputs): - """Receives a list of inputs value and a function, populates another - list with the results corresponding to the same results. - - Parameters - ---------- - func : Function - The Function object to be - inputs : list, tuple, np.array - The array of points to applied the func to. - - Examples - -------- - >>> inputs = [0, 1, 2, 3, 4, 5] - >>> def func(x): - ... return x*10 - >>> calc_output(func, inputs) - [0, 10, 20, 30, 40, 50] - - Notes - ----- - In the future, consider using the built-in map function from python. - """ - output = np.zeros(len(inputs)) - for j, value in enumerate(inputs): - output[j] = func.get_value_opt(value) - return output - - input_data = [] - output_data = [] - for key in sorted(source.keys()): - i = np.linspace(key[0], key[1], datapoints) - i = i[~np.isin(i, input_data)] - input_data = np.concatenate((input_data, i)) - - f = Function(source[key]) - output_data = np.concatenate((output_data, calc_output(f, i))) - - return Function( - np.concatenate(([input_data], [output_data])).T, - inputs=inputs, - outputs=outputs, - interpolation=interpolation, - extrapolation=extrapolation, - ) - - def funcify_method(*args, **kwargs): # pylint: disable=too-many-statements """Decorator factory to wrap methods as Function objects and save them as cached properties. diff --git a/rocketpy/mathutils/piecewise_function.py b/rocketpy/mathutils/piecewise_function.py new file mode 100644 index 000000000..086e6d1da --- /dev/null +++ b/rocketpy/mathutils/piecewise_function.py @@ -0,0 +1,94 @@ +import numpy as np + +from rocketpy.mathutils.function import Function + + +class PiecewiseFunction(Function): + """Class for creating piecewise functions. These kind of functions are + defined by a dictionary of functions, where the keys are tuples that + represent the domain of the function. The domains must be disjoint. + """ + + def __new__( + cls, + source, + inputs=None, + outputs=None, + interpolation="spline", + extrapolation=None, + datapoints=100, + ): + """ + Creates a piecewise function from a dictionary of functions. The keys of + the dictionary must be tuples that represent the domain of the function. + The domains must be disjoint. The piecewise function will be evaluated + at datapoints points to create Function object. + + Parameters + ---------- + source: dictionary + A dictionary of Function objects, where the keys are the domains. + inputs : list of strings + A list of strings that represent the inputs of the function. + outputs: list of strings + A list of strings that represent the outputs of the function. + interpolation: str + The type of interpolation to use. The default value is 'spline'. + extrapolation: str + The type of extrapolation to use. The default value is None. + datapoints: int + The number of points in which the piecewise function will be + evaluated to create a base function. The default value is 100. + """ + cls.__validate__source(source) + if inputs is None: + inputs = ["Scalar"] + if outputs is None: + outputs = ["Scalar"] + + input_data = np.array([]) + output_data = np.array([]) + for lower, upper in sorted(source.keys()): + grid = np.linspace(lower, upper, datapoints) + + # since intervals are disjoint and sorted, we only need to check + # if the first point is already included + if input_data.size != 0: + if lower == input_data[-1]: + grid = np.delete(grid, 0) + input_data = np.concatenate((input_data, grid)) + + f = Function(source[(lower, upper)]) + output_data = np.concatenate((output_data, f.get_value(grid))) + + return Function( + np.concatenate(([input_data], [output_data])).T, + inputs=inputs, + outputs=outputs, + interpolation=interpolation, + extrapolation=extrapolation, + ) + + @staticmethod + def __validate__source(source): + """Validates that source is dictionary with non-overlapping + intervals + + Parameters + ---------- + source : dict + A dictionary of Function objects, where the keys are the domains. + """ + # Check if source is a dictionary + if not isinstance(source, dict): + raise TypeError("source must be a dictionary") + # Check if all keys are tuples + for key in source.keys(): + if not isinstance(key, tuple): + raise TypeError("keys of source must be tuples") + # Check if all domains are disjoint + for lower1, upper1 in source.keys(): + for lower2, upper2 in source.keys(): + if (lower1, upper1) != (lower2, upper2): + if lower1 < upper2 and upper1 > lower2: + raise ValueError("domains must be disjoint") diff --git a/rocketpy/motors/tank_geometry.py b/rocketpy/motors/tank_geometry.py index 272f8fc93..4fd5910c3 100644 --- a/rocketpy/motors/tank_geometry.py +++ b/rocketpy/motors/tank_geometry.py @@ -2,7 +2,8 @@ import numpy as np -from ..mathutils.function import Function, PiecewiseFunction, funcify_method +from ..mathutils.function import Function, funcify_method +from ..mathutils.piecewise_function import PiecewiseFunction from ..plots.tank_geometry_plots import _TankGeometryPlots from ..prints.tank_geometry_prints import _TankGeometryPrints diff --git a/tests/unit/test_piecewise_function.py b/tests/unit/test_piecewise_function.py new file mode 100644 index 000000000..347f3de27 --- /dev/null +++ b/tests/unit/test_piecewise_function.py @@ -0,0 +1,35 @@ +import pytest + +from rocketpy import PiecewiseFunction + + +@pytest.mark.parametrize( + "source", + [ + ((0, 4), lambda x: x), + {"0-4": lambda x: x}, + {(0, 4): lambda x: x, (3, 5): lambda x: 2 * x}, + ], +) +def test_invalid_source(source): + """Test an error is raised when the source parameter is invalid""" + with pytest.raises((TypeError, ValueError)): + PiecewiseFunction(source) + + +@pytest.mark.parametrize( + "source", + [ + {(-1, 0): lambda x: -x, (0, 1): lambda x: x}, + { + (0, 1): lambda x: x, + (1, 2): lambda x: 1, + (2, 3): lambda x: 3 - x, + }, + ], +) +@pytest.mark.parametrize("inputs", [None, "X"]) +@pytest.mark.parametrize("outputs", [None, "Y"]) +def test_new(source, inputs, outputs): + """Test if PiecewiseFunction.__new__ runs correctly""" + PiecewiseFunction(source, inputs, outputs)