From 192ce0d8db3956e7f3814762b85e23d8480be62b Mon Sep 17 00:00:00 2001 From: Bing Li Date: Tue, 10 Dec 2024 10:18:28 -0500 Subject: [PATCH] updated plotting scheme for fits --- src/tavi/data/fit.py | 15 +++++++++++---- src/tavi/plotter.py | 45 ++++++++++++++++++++------------------------ tests/test_fit.py | 8 ++++---- 3 files changed, 35 insertions(+), 33 deletions(-) diff --git a/src/tavi/data/fit.py b/src/tavi/data/fit.py index b237c797..4ec6c1ea 100644 --- a/src/tavi/data/fit.py +++ b/src/tavi/data/fit.py @@ -127,9 +127,12 @@ def add_background( def params(self) -> Parameters: """Get fitting parameters as a dictionary with the model prefix being the key""" - if self._parameters is None: - self._parameters = self.guess() + if self.result is not None: + params = self.result.params + else: + params = self.guess() + self._parameters = params return self._parameters def guess(self) -> Parameters: @@ -171,5 +174,9 @@ def eval(self, pars: Parameters, x: np.ndarray) -> np.ndarray: def fit(self, pars: Parameters) -> ModelResult: result = self.model.fit(self.y, pars, x=self.x, weights=self.err) - self.result = result - return result + if result.success: + self.result = result + self._parameters = result.params + return result + else: + return None diff --git a/src/tavi/plotter.py b/src/tavi/plotter.py index 1027ade9..82fc27ba 100644 --- a/src/tavi/plotter.py +++ b/src/tavi/plotter.py @@ -1,6 +1,6 @@ # import matplotlib.colors as colors from functools import partial -from typing import Optional, Union +from typing import Optional import numpy as np from mpl_toolkits.axisartist.grid_finder import MaxNLocator @@ -73,11 +73,6 @@ def add_scan(self, scan_data: ScanData1D, **kwargs): for key, val in kwargs.items(): scan_data.fmt.update({key: val}) - def _add_fit_from_eval(self, fit_data: FitData1D, **kwargs): - self.fit_data.append(fit_data) - for key, val in kwargs.items(): - fit_data.fmt.update({key: val}) - def _add_fit_from_fitting(self, fit_data: Fit1D, num_of_pts: Optional[int] = 100, **kwargs): if (result := fit_data.result) is None: raise ValueError("Fitting result is None.") @@ -87,22 +82,24 @@ def _add_fit_from_fitting(self, fit_data: Fit1D, num_of_pts: Optional[int] = 100 for key, val in kwargs.items(): data.fmt.update({key: val}) - # TODO - def add_fit( - self, fit_data: Union[tuple[np.ndarray, np.ndarray], Fit1D], num_of_pts: Optional[int] = 100, **kwargs - ): - if isinstance(fit_data, tuple): - x, y = fit_data - self._add_fit_from_eval(FitData1D(x, y), **kwargs) - elif isinstance(fit_data, Fit1D): - self._add_fit_from_fitting(fit_data, num_of_pts, **kwargs) - else: - raise ValueError(f"Invalid input fit_data={fit_data}") + def add_fit(self, fit1d: Fit1D, x: Optional[np.ndarray] = None, DISPLAY_PARAMS=True, **kwargs): + if x is None: + x = fit1d.x + if (result := fit1d.result) is None: # evaluate + y = fit1d.eval(fit1d.params, x) + else: # fit + y = result.eval(param=fit1d.params, x=x) + fit_data = FitData1D(x, y) + self.fit_data.append(fit_data) + for key, val in kwargs.items(): + fit_data.fmt.update({key: val}) - # TODO - def add_fit_components(self, fit_data: Fit1D, num_of_pts: Optional[int] = 100, **kwargs): - if isinstance(fit_data, Fit1D) and (result := fit_data.result) is not None: - x = fit_data.x_to_plot(num_of_pts) + def add_fit_components(self, fit1d: Fit1D, x: Optional[np.ndarray] = None, DISPLAY_PARAMS=True, **kwargs): + if x is None: + x = fit1d.x + if (result := fit1d.result) is None: # fit first + pass + else: components = result.eval_components(result.params, x=x) num_components = len(components) @@ -115,13 +112,11 @@ def add_fit_components(self, fit_data: Fit1D, num_of_pts: Optional[int] = 100, * for i, (prefix, y) in enumerate(components.items()): data = FitData1D(x=x, y=y) self.fit_data.append(data) - data.fmt.update({"label": prefix[:-1]}) # remove "_" + label = prefix[:-1] # remove "_" + data.fmt.update({"label": label}) for key, val in kwargs.items(): data.fmt.update({key: val[i]}) - else: - raise ValueError(f"Invalid input fit_data={fit_data}") - def add_reso_bar(self, pos: tuple, fwhm: float, **kwargs): reso_data = ResoBar(pos, fwhm) for key, val in kwargs.items(): diff --git a/tests/test_fit.py b/tests/test_fit.py index c1d27c79..ba0b8200 100644 --- a/tests/test_fit.py +++ b/tests/test_fit.py @@ -99,7 +99,7 @@ def test_fit_single_peak_internal_model(fit_data): if PLOT: p1 = Plot1D() p1.add_scan(s1_scan, fmt="o", label="data") - p1.add_fit(f1, label="fit", color="C3", num_of_pts=50, marker="^") + p1.add_fit(f1, label="fit", color="C3", marker="^") p1.add_fit_components(f1, color=["C4", "C5"]) fig, ax = plt.subplots() @@ -114,7 +114,7 @@ def test_fit_two_peak(fit_data): f1 = Fit1D(s1_scan, fit_range=(0.0, 4.0), name="scan42_fit2peaks") - # f1.add_background(model="Constant") + f1.add_background(model="Constant") f1.add_signal(model="Gaussian") f1.add_signal(model="Gaussian") @@ -130,14 +130,14 @@ def test_fit_two_peak(fit_data): result = f1.fit(pars) assert np.allclose(result.params["s1_center"].value, 3.54, atol=0.01) assert np.allclose(result.params["s1_fwhm"].value, 0.40, atol=0.01) - x = f1.x_to_plot(num_of_pts=100, min=-1, max=5) + x = f1.x_to_plot(num_of_pts=200, min=-1, max=5) # y = f1.eval(result.params, x) if PLOT: p1 = Plot1D() p1.add_scan(s1_scan, fmt="o", label="data") p1.add_fit(f1, x=x, label="fit", color="C3") - p1.add_fit_components(f1, x=x, color=["C4", "C5"]) + p1.add_fit_components(f1, x=x, color=["C1", "C2", "C4"]) fig, ax = plt.subplots() p1.plot(ax)