diff --git a/src/tavi/data/fit.py b/src/tavi/data/fit.py index ab15842..ff73789 100644 --- a/src/tavi/data/fit.py +++ b/src/tavi/data/fit.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Literal, Optional import numpy as np @@ -7,6 +8,24 @@ from tavi.data.scan_data import ScanData1D +@dataclass +class FitParams: + name: str + value: Optional[float] = None + vary: bool = True + min: float = -np.inf + max: float = np.inf + expr: Optional[str] = None + brute_step: Optional[float] = None + + +# @dataclass +# class FitData1D: +# x: np.ndarray +# y: np.ndarray +# fmt: dict = {} + + class FitData1D(object): def __init__( @@ -64,13 +83,13 @@ def __init__( self.background_models: models = [] self.signal_models: models = [] - self.pars = Parameters() + self.parameters: Optional[Parameters] = None self._num_backgrounds = 0 self._num_signals = 0 - self.result = None + self.result: Optional[ModelResult] = None self.fit_data: Optional[FitData1D] = None - self.PLOT_SEPARATELY = False + self.PLOT_COMPONENTS = False self.nan_policy = nan_policy if fit_range is not None: @@ -133,19 +152,59 @@ def add_background( ) @staticmethod - def _get_model_params(models) -> list[list[str]]: + def _get_param_names(models) -> list[list[str]]: params = [] for model in models: params.append(model.param_names) return params @property - def signal_params(self) -> list[list[str]]: - return Fit1D._get_model_params(self.signal_models) + def signal_param_names(self): + return Fit1D._get_param_names(self.signal_models) @property - def background_params(self) -> list[list[str]]: - return Fit1D._get_model_params(self.background_models) + def background_param_names(self): + return Fit1D._get_param_names(self.background_models) + + @staticmethod + def _get_params(all_pars, parsms_names: list[list[str]]) -> tuple[tuple[FitParams, ...], ...]: + signal_params_list = [] + + for params in parsms_names: + params_list = [] + for param_name in params: + param = all_pars[param_name] + params_list.append( + FitParams( + name=param.name, + value=param.value, + vary=param.vary, + min=param.min, + max=param.max, + expr=param.expr, + brute_step=param.brute_step, + ) + ) + signal_params_list.append(tuple(params_list)) + return tuple(signal_params_list) + + @property + def signal_params(self) -> tuple[tuple[FitParams, ...], ...]: + + pars = self.guess() if self.parameters is None else self.parameters + names = Fit1D._get_param_names(self.signal_models) + signal_params = Fit1D._get_params(pars, names) + + return signal_params + + @property + def background_params(self) -> tuple[tuple[FitParams, ...], ...]: + + pars = self.guess() if self.parameters is None else self.parameters + names = Fit1D._get_param_names(self.background_models) + background_params = Fit1D._get_params(pars, names) + + return background_params def guess(self) -> Parameters: pars = Parameters() @@ -153,14 +212,15 @@ def guess(self) -> Parameters: pars += signal.guess(self.y, x=self.x) for bkg in self.background_models: pars += bkg.guess(self.y, x=self.x) - self.pars = pars + self.parameters = pars return pars - def _build_composite_model(self): + @property + def model(self): compposite_model = np.sum(self.signal_models + self.background_models) return compposite_model - def _get_x_to_plot(self, num_of_pts: Optional[int]): + def x_to_plot(self, num_of_pts: Optional[int]): if num_of_pts is None: x_to_plot = self.x elif isinstance(num_of_pts, int): @@ -170,14 +230,14 @@ def _get_x_to_plot(self, num_of_pts: Optional[int]): return x_to_plot def eval(self, pars: Parameters, num_of_pts: Optional[int] = 100) -> FitData1D: - model = self._build_composite_model() - x_to_plot = self._get_x_to_plot(num_of_pts) - y_to_plot = model.eval(pars, x=x_to_plot) + + x_to_plot = self.x_to_plot(num_of_pts) + y_to_plot = self.model.eval(pars, x=x_to_plot) return FitData1D(x_to_plot, y_to_plot) - def fit(self, pars: Parameters, num_of_pts: Optional[int] = 100) -> ModelResult: - mod = self._build_composite_model() - result = mod.fit(self.y, pars, x=self.x, weights=self.err) + def fit(self, pars: Parameters) -> ModelResult: + + result = self.model.fit(self.y, pars, x=self.x, weights=self.err) self.result = result - return result + return self diff --git a/src/tavi/plotter.py b/src/tavi/plotter.py index 13222b5..a910fde 100644 --- a/src/tavi/plotter.py +++ b/src/tavi/plotter.py @@ -3,11 +3,10 @@ from typing import Optional, Union import numpy as np -from lmfit.model import ModelResult from mpl_toolkits.axisartist.grid_finder import MaxNLocator from mpl_toolkits.axisartist.grid_helper_curvelinear import GridHelperCurveLinear -from tavi.data.fit import FitData1D +from tavi.data.fit import Fit1D, FitData1D from tavi.data.scan_data import ScanData1D, ScanData2D from tavi.instrument.resolution.ellipse import ResoEllipse @@ -55,13 +54,37 @@ def add_scan(self, scan_data: ScanData1D, **kwargs): for key, val in kwargs.items(): scan_data.fmt.update({key: val}) - def add_fit(self, fit_data: Union[FitData1D, ModelResult], PLOT_COMPONENTS=False, **kwargs): - if PLOT_COMPONENTS: - pass - else: + def add_fit(self, fit_data: Union[FitData1D, Fit1D], num_of_pts: Optional[int] = 100, **kwargs): + """ + Note: + PLOT_COMPONENTS is ignored if fit_data has the type FitData1D + """ + if isinstance(fit_data, FitData1D): self.fit_data.append(fit_data) for key, val in kwargs.items(): fit_data.fmt.update({key: val}) + elif isinstance(fit_data, Fit1D) and (result := fit_data.result) is not None: + x = fit_data.x_to_plot(num_of_pts) + data = FitData1D(x=x, y=result.eval(param=result.params, x=x)) + self.fit_data.append(data) + for key, val in kwargs.items(): + data.fmt.update({key: val}) + else: + raise ValueError(f"Invalid input fit_data={fit_data}") + + def add_fit_components(self, fit_data: Fit1D, num_of_pts: Optional[int] = 100, **kwargs): + if isinstance(fit_data, Fit1D) and (result := fit_data.result) is not None: + x = fit_data.x_to_plot(num_of_pts) + components = result.eval_components(result.params, x=x) + for i, (prefix, y) in enumerate(components.items()): + data = FitData1D(x=x, y=y) + self.fit_data.append(data) + data.fmt.update({"label": prefix[:-1]}) + for key, val in kwargs.items(): + data.fmt.update({key: val[i]}) + + else: + raise ValueError(f"Invalid input fit_data={fit_data}") def plot(self, ax): for data in self.scan_data: @@ -70,10 +93,7 @@ def plot(self, ax): else: ax.errorbar(x=data.x, y=data.y, yerr=data.err, **data.fmt) for fit in self.fit_data: - if fit.PLOT_INDIVIDUALLY: - pass - else: - ax.plot(fit.x, fit.y, **fit.fmt) + ax.plot(fit.x, fit.y, **fit.fmt) if self.xlim is not None: ax.set_xlim(left=self.xlim[0], right=self.xlim[1]) diff --git a/tests/test_fit.py b/tests/test_fit.py index f8c83b8..ba3325d 100644 --- a/tests/test_fit.py +++ b/tests/test_fit.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -* import matplotlib.pyplot as plt import numpy as np +import pytest from lmfit.models import ConstantModel, GaussianModel from tavi.data.fit import Fit1D @@ -8,12 +9,20 @@ from tavi.plotter import Plot1D -def test_fit_single_peak_external_model(): - +@pytest.fixture +def fit_data(): + PLOT = True path_to_spice_folder = "./test_data/exp424" scan42 = Scan.from_spice(path_to_spice_folder=path_to_spice_folder, scan_num=42) s1_scan = scan42.get_data(norm_to=(30, "mcu")) + return s1_scan, PLOT + + +def test_fit_single_peak_external_model(fit_data): + + s1_scan, PLOT = fit_data + f1 = Fit1D(s1_scan, fit_range=(0.5, 4.0)) bkg = ConstantModel(prefix="bkg_", nan_policy="propagate") @@ -31,55 +40,61 @@ def test_fit_single_peak_external_model(): p1 = Plot1D() p1.add_scan(s1_scan, fmt="o") - fig, ax = plt.subplots() - p1.plot(ax) - ax.plot(f1.x, out.best_fit) - plt.show() + if PLOT: + fig, ax = plt.subplots() + p1.plot(ax) + ax.plot(f1.x, out.best_fit) + plt.show() -def test_get_fitting_variables(): - path_to_spice_folder = "./test_data/exp424" - scan42 = Scan.from_spice(path_to_spice_folder=path_to_spice_folder, scan_num=42) - s1_scan = scan42.get_data(norm_to=(30, "mcu")) +def test_get_fitting_variables(fit_data): + s1_scan, _ = fit_data f1 = Fit1D(s1_scan, fit_range=(0.5, 4.0)) f1.add_signal(model="Gaussian") f1.add_background(model="Constant") - assert f1.signal_params == [["s1_amplitude", "s1_center", "s1_sigma"]] - assert f1.background_params == [["b1_c"]] + assert f1.signal_param_names == [["s1_amplitude", "s1_center", "s1_sigma"]] + assert f1.background_param_names == [["b1_c"]] + assert len(f1.signal_params) == 1 + assert len(f1.signal_params[0]) == 5 -def test_guess_initial(): - path_to_spice_folder = "./test_data/exp424" - scan42 = Scan.from_spice(path_to_spice_folder=path_to_spice_folder, scan_num=42) + assert f1.signal_params[0][0].name == "s1_amplitude" + assert f1.signal_params[0][1].name == "s1_center" + assert f1.signal_params[0][2].name == "s1_sigma" + assert f1.signal_params[0][3].name == "s1_fwhm" + assert f1.signal_params[0][4].name == "s1_height" + assert f1.signal_params[0][4].expr == "0.3989423*s1_amplitude/max(1e-15, s1_sigma)" + assert f1.background_params[0][0].name == "b1_c" - s1_scan = scan42.get_data(norm_to=(30, "mcu")) + +def test_guess_initial(fit_data): + s1_scan, PLOT = fit_data f1 = Fit1D(s1_scan, fit_range=(0.5, 4.0), name="scan42_fit") f1.add_signal(model="Gaussian") f1.add_background(model="Constant") pars = f1.guess() - inital = f1.eval(pars) + inital = f1.eval(pars, num_of_pts=50) # fit_result = f1.fit(pars) + assert inital.y.shape == (50,) + assert inital.y.min() > 6 - p1 = Plot1D() - p1.add_scan(s1_scan, fmt="o", label="data") - p1.add_fit(inital, label="guess") - # p1.add_fit(fit_result, label="fit") - - fig, ax = plt.subplots() - p1.plot(ax) - plt.show() + if PLOT: + p1 = Plot1D() + p1.add_scan(s1_scan, fmt="o", label="data") + p1.add_fit(inital, label="guess", color="C1", marker="s", linestyle="dashed", linewidth=2, markersize=4) + fig, ax = plt.subplots() + p1.plot(ax) + plt.show() -def test_fit_single_peak_internal_model(): - path_to_spice_folder = "./test_data/exp424" - scan42 = Scan.from_spice(path_to_spice_folder=path_to_spice_folder, scan_num=42) +def test_fit_single_peak_internal_model(fit_data): - s1_scan = scan42.get_data(norm_to=(30, "mcu")) + s1_scan, PLOT = fit_data f1 = Fit1D(s1_scan, fit_range=(0.5, 4.0)) f1.add_signal(model="Gaussian") @@ -87,24 +102,23 @@ def test_fit_single_peak_internal_model(): pars = f1.guess() fit_result = f1.fit(pars) - p1 = Plot1D() - p1.add_scan(s1_scan, fmt="o", label="data") - p1.add_fit(fit_result, label="fit") - p1.add_fit(fit_result, label="fit", PLOT_COMPONENTS=True) + if PLOT: + p1 = Plot1D() + p1.add_scan(s1_scan, fmt="o", label="data") + p1.add_fit(fit_result, label="fit", color="C3", num_of_pts=50, marker="^") + p1.add_fit_components(fit_result, color=["C4", "C5"]) - fig, ax = plt.subplots() - p1.plot(ax) - plt.show() + fig, ax = plt.subplots() + p1.plot(ax) + plt.show() -def test_fit_two_peak(): +def test_fit_two_peak(fit_data): - nexus_file_name = "./test_data/IPTS32124_CG4C_exp0424/scan0042.h5" - _, s1 = Scan.from_nexus_file(nexus_file_name) + s1_scan, PLOT = fit_data + + f1 = Fit1D(s1_scan, fit_range=(0.0, 4.0)) - plot1d = s1.generate_curve(norm_channel="mcu", norm_val=30) - f1 = Fit1D(plot1d) - f1.set_range(0.0, 4.0) f1.add_background(values=(0.7,)) f1.add_signal(values=(None, 3.5, 0.29), vary=(True, True, True)) f1.add_signal( @@ -118,51 +132,12 @@ def test_fit_two_peak(): assert np.allclose(f1.result.params["s1_center"].value, 3.54, atol=0.01) assert np.allclose(f1.result.params["s1_fwhm"].value, 0.40, atol=0.01) - fig, ax = plt.subplots() - plot1d.plot_curve(ax) - f1.fit_plot.plot_curve(ax) - plt.show() - - -def test_plot(): - path_to_spice_folder = "./test_data/exp424" - scan42 = Scan.from_spice(path_to_spice_folder=path_to_spice_folder, scan_num=42) - - s1_scan = scan42.get_data(norm_to=(30, "mcu")) - f1 = Fit1D(s1_scan, fit_range=(0.5, 4.0)) - - f1.add_signal(model="Gaussian") - f1.add_background(model="Constant") - pars = f1.guess() - inital = f1.eval(pars) - fit_result = f1.fit(pars) - - p1 = Plot1D() - p1.add_scan(s1_scan, fmt="o", label="data") - p1.add_fit(inital, label="guess") - p1.add_fit(fit_result, label="fit_result") - - _, ax = plt.subplots() - p1.plot(ax) - plt.show() - - -def test_plot_indiviaully(): - path_to_spice_folder = "./test_data/exp424" - scan42 = Scan.from_spice(path_to_spice_folder=path_to_spice_folder, scan_num=42) - - s1_scan = scan42.get_data(norm_to=(30, "mcu")) - f1 = Fit1D(s1_scan, fit_range=(0.5, 4.0)) - - f1.add_signal(model="Gaussian") - f1.add_background(model="Constant") - pars = f1.guess() - fit_result = f1.fit(pars) - - p1 = Plot1D() - p1.add_scan(s1_scan, fmt="o", label="data") - p1.add_fit(fit_result, label="fit_result", PLOT_INDIVIDUALLY=True) + if PLOT: + p1 = Plot1D() + p1.add_scan(s1_scan, fmt="o", label="data") + p1.add_fit(fit_result, label="fit", color="C3", num_of_pts=50, marker="^") + p1.add_fit_components(fit_result, color=["C4", "C5"]) - _, ax = plt.subplots() - p1.plot(ax) - plt.show() + fig, ax = plt.subplots() + p1.plot(ax) + plt.show()