From 4fef66acf087fe1dde6482e87f3336c271cc7a45 Mon Sep 17 00:00:00 2001 From: Bing Li Date: Tue, 26 Nov 2024 16:39:15 -0500 Subject: [PATCH] updating Fit1D class --- src/tavi/data/fit.py | 51 +++++++++++--------------------------------- tests/test_fit.py | 23 ++++++++++---------- 2 files changed, 25 insertions(+), 49 deletions(-) diff --git a/src/tavi/data/fit.py b/src/tavi/data/fit.py index ff73789..012a85d 100644 --- a/src/tavi/data/fit.py +++ b/src/tavi/data/fit.py @@ -83,7 +83,7 @@ def __init__( self.background_models: models = [] self.signal_models: models = [] - self.parameters: Optional[Parameters] = None + self._parameters: Optional[Parameters] = None self._num_backgrounds = 0 self._num_signals = 0 self.result: Optional[ModelResult] = None @@ -122,13 +122,7 @@ def add_signal( ): self._num_signals += 1 prefix = f"s{self._num_signals}_" - self.signal_models.append( - Fit1D._add_model( - model, - prefix, - nan_policy=self.nan_policy, - ) - ) + self.signal_models.append(Fit1D._add_model(model, prefix, nan_policy=self.nan_policy)) def add_background( self, @@ -143,13 +137,7 @@ def add_background( ): self._num_backgrounds += 1 prefix = f"b{self._num_backgrounds}_" - self.background_models.append( - Fit1D._add_model( - model, - prefix, - nan_policy=self.nan_policy, - ) - ) + self.background_models.append(Fit1D._add_model(model, prefix, nan_policy=self.nan_policy)) @staticmethod def _get_param_names(models) -> list[list[str]]: @@ -166,11 +154,15 @@ def signal_param_names(self): def background_param_names(self): return Fit1D._get_param_names(self.background_models) - @staticmethod - def _get_params(all_pars, parsms_names: list[list[str]]) -> tuple[tuple[FitParams, ...], ...]: - signal_params_list = [] + @property + def params(self) -> dict[str, tuple[FitParams, ...]]: + + all_pars = self.guess() if self._parameters is None else self._parameters + parsms_names = Fit1D._get_param_names(self.signal_models + self.background_models) + params_dict = {} for params in parsms_names: + key = params[0].split("_")[0] params_list = [] for param_name in params: param = all_pars[param_name] @@ -185,26 +177,9 @@ def _get_params(all_pars, parsms_names: list[list[str]]) -> tuple[tuple[FitParam brute_step=param.brute_step, ) ) - signal_params_list.append(tuple(params_list)) - return tuple(signal_params_list) - - @property - def signal_params(self) -> tuple[tuple[FitParams, ...], ...]: - - pars = self.guess() if self.parameters is None else self.parameters - names = Fit1D._get_param_names(self.signal_models) - signal_params = Fit1D._get_params(pars, names) - - return signal_params - - @property - def background_params(self) -> tuple[tuple[FitParams, ...], ...]: - - pars = self.guess() if self.parameters is None else self.parameters - names = Fit1D._get_param_names(self.background_models) - background_params = Fit1D._get_params(pars, names) + params_dict.update({key: tuple(params_list)}) - return background_params + return params_dict def guess(self) -> Parameters: pars = Parameters() @@ -212,7 +187,7 @@ def guess(self) -> Parameters: pars += signal.guess(self.y, x=self.x) for bkg in self.background_models: pars += bkg.guess(self.y, x=self.x) - self.parameters = pars + self._parameters = pars return pars @property diff --git a/tests/test_fit.py b/tests/test_fit.py index ba3325d..6f9a7e0 100644 --- a/tests/test_fit.py +++ b/tests/test_fit.py @@ -58,16 +58,17 @@ def test_get_fitting_variables(fit_data): assert f1.signal_param_names == [["s1_amplitude", "s1_center", "s1_sigma"]] assert f1.background_param_names == [["b1_c"]] - assert len(f1.signal_params) == 1 - assert len(f1.signal_params[0]) == 5 + assert len(f1.params) == 2 - assert f1.signal_params[0][0].name == "s1_amplitude" - assert f1.signal_params[0][1].name == "s1_center" - assert f1.signal_params[0][2].name == "s1_sigma" - assert f1.signal_params[0][3].name == "s1_fwhm" - assert f1.signal_params[0][4].name == "s1_height" - assert f1.signal_params[0][4].expr == "0.3989423*s1_amplitude/max(1e-15, s1_sigma)" - assert f1.background_params[0][0].name == "b1_c" + s1_params = f1.params["s1"] + assert len(s1_params) == 5 + assert s1_params[0].name == "s1_amplitude" + assert s1_params[1].name == "s1_center" + assert s1_params[2].name == "s1_sigma" + assert s1_params[3].name == "s1_fwhm" + assert s1_params[4].name == "s1_height" + assert s1_params[4].expr == "0.3989423*s1_amplitude/max(1e-15, s1_sigma)" + assert f1.params["b1"][0].name == "b1_c" def test_guess_initial(fit_data): @@ -117,9 +118,9 @@ def test_fit_two_peak(fit_data): s1_scan, PLOT = fit_data - f1 = Fit1D(s1_scan, fit_range=(0.0, 4.0)) + f1 = Fit1D(s1_scan, fit_range=(0.0, 4.0), name="scan42_fit2peaks") - f1.add_background(values=(0.7,)) + f1.add_background(model="Constant") f1.add_signal(values=(None, 3.5, 0.29), vary=(True, True, True)) f1.add_signal( model="Gaussian",