diff --git a/src/tavi/data/fit.py b/src/tavi/data/fit.py index c86421b..fc4872e 100644 --- a/src/tavi/data/fit.py +++ b/src/tavi/data/fit.py @@ -6,34 +6,6 @@ from tavi.data.scan_data import ScanData1D -# @dataclass -# class FitParam: -# name: str -# value: Optional[float] = None -# vary: bool = True -# min: float = -np.inf -# max: float = np.inf -# expr: Optional[str] = None - - -# brute_step: Optional[float] = None - - -# @dataclass -# class FitData1D: -# x: np.ndarray -# y: np.ndarray -# fmt: dict = {} - - -class FitData1D(object): - - def __init__(self, x: np.ndarray, y: np.ndarray) -> None: - - self.x = x - self.y = y - self.fmt: dict = {} - class Fit1D(object): """Fit a 1D curve @@ -83,7 +55,6 @@ def __init__( self._num_backgrounds = 0 self._num_signals = 0 self.result: Optional[ModelResult] = None - self.fit_data: Optional[FitData1D] = None self.PLOT_COMPONENTS = False self.nan_policy = nan_policy @@ -202,7 +173,7 @@ def model(self): compposite_model = np.sum(self._signal_models + self._background_models) return compposite_model - def x_to_plot(self, num_of_pts: Optional[int]): + def x_to_plot(self, num_of_pts: Optional[int] = 100): if num_of_pts is None: x_to_plot = self.x elif isinstance(num_of_pts, int): @@ -211,18 +182,8 @@ def x_to_plot(self, num_of_pts: Optional[int]): raise ValueError(f"num_of_points={num_of_pts} needs to be an integer.") return x_to_plot - def eval(self, pars: Optional[Parameters], num_of_pts: Optional[int] = 100, x=None) -> FitData1D: - if pars is None: - pars = self.result.params - - if x is not None: - return self.model.eval(pars, x=x) - - else: - x_to_plot = self.x_to_plot(num_of_pts) - y_to_plot = self.model.eval(pars, x=x_to_plot) - - return FitData1D(x_to_plot, y_to_plot) + def eval(self, pars: Parameters, x: np.ndarray) -> np.ndarray: + return self.model.eval(pars, x=x) def fit(self, pars: Parameters) -> ModelResult: diff --git a/src/tavi/plotter.py b/src/tavi/plotter.py index df95c9f..219d5c4 100644 --- a/src/tavi/plotter.py +++ b/src/tavi/plotter.py @@ -6,12 +6,22 @@ from mpl_toolkits.axisartist.grid_finder import MaxNLocator from mpl_toolkits.axisartist.grid_helper_curvelinear import GridHelperCurveLinear -from tavi.data.fit import Fit1D, FitData1D +from tavi.data.fit import Fit1D from tavi.data.scan_data import ScanData1D, ScanData2D from tavi.instrument.resolution.ellipse import ResoEllipse +class FitData1D(object): + + def __init__(self, x: np.ndarray, y: np.ndarray) -> None: + + self.x = x + self.y = y + self.fmt: dict = {} + + class ResoBar(object): + def __init__(self, pos: tuple[float, float], fwhm: float) -> None: self.pos = pos @@ -77,9 +87,12 @@ def _add_fit_from_fitting(self, fit_data: Fit1D, num_of_pts: Optional[int] = 100 for key, val in kwargs.items(): data.fmt.update({key: val}) - def add_fit(self, fit_data: Union[FitData1D, Fit1D], num_of_pts: Optional[int] = 100, **kwargs): - if isinstance(fit_data, FitData1D): - self._add_fit_from_eval(fit_data, **kwargs) + def add_fit( + self, fit_data: Union[tuple[np.ndarray, np.ndarray], Fit1D], num_of_pts: Optional[int] = 100, **kwargs + ): + if isinstance(fit_data, tuple): + x, y = fit_data + self._add_fit_from_eval(FitData1D(x, y), **kwargs) elif isinstance(fit_data, Fit1D): self._add_fit_from_fitting(fit_data, num_of_pts, **kwargs) else: diff --git a/tests/test_fit.py b/tests/test_fit.py index d4a338e..48046e7 100644 --- a/tests/test_fit.py +++ b/tests/test_fit.py @@ -78,15 +78,16 @@ def test_guess_initial(fit_data): f1.add_signal(model="Gaussian") f1.add_background(model="Constant") pars = f1.guess() - inital = f1.eval(pars, num_of_pts=50) + x = f1.x_to_plot(num_of_pts=100) + y = f1.eval(pars, x) # fit_result = f1.fit(pars) - assert inital.y.shape == (50,) - assert inital.y.min() > 6 + assert y.shape == (100,) + assert y.min() > 6 if PLOT: p1 = Plot1D() p1.add_scan(s1_scan, fmt="o", label="data") - p1.add_fit(inital, label="guess", color="C1", marker="s", linestyle="dashed", linewidth=2, markersize=4) + p1.add_fit((x, y), label="guess", color="C1", marker="s", linestyle="dashed", linewidth=2, markersize=4) fig, ax = plt.subplots() p1.plot(ax) @@ -101,8 +102,8 @@ def test_fit_single_peak_internal_model(fit_data): f1.add_signal(model="Gaussian") f1.add_background(model="Constant") pars = f1.guess() - fit_result = f1.fit(pars) - assert np.allclose(fit_result.redchi, 37.6, atol=1) + result = f1.fit(pars) + assert np.allclose(result.redchi, 37.6, atol=1) if PLOT: p1 = Plot1D()