Skip to content

Commit

Permalink
updating Fit1D class
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Li committed Nov 26, 2024
1 parent 56bc809 commit 4fef66a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 49 deletions.
51 changes: 13 additions & 38 deletions src/tavi/data/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(

self.background_models: models = []
self.signal_models: models = []
self.parameters: Optional[Parameters] = None
self._parameters: Optional[Parameters] = None
self._num_backgrounds = 0
self._num_signals = 0
self.result: Optional[ModelResult] = None
Expand Down Expand Up @@ -122,13 +122,7 @@ def add_signal(
):
self._num_signals += 1
prefix = f"s{self._num_signals}_"
self.signal_models.append(
Fit1D._add_model(
model,
prefix,
nan_policy=self.nan_policy,
)
)
self.signal_models.append(Fit1D._add_model(model, prefix, nan_policy=self.nan_policy))

def add_background(
self,
Expand All @@ -143,13 +137,7 @@ def add_background(
):
self._num_backgrounds += 1
prefix = f"b{self._num_backgrounds}_"
self.background_models.append(
Fit1D._add_model(
model,
prefix,
nan_policy=self.nan_policy,
)
)
self.background_models.append(Fit1D._add_model(model, prefix, nan_policy=self.nan_policy))

@staticmethod
def _get_param_names(models) -> list[list[str]]:
Expand All @@ -166,11 +154,15 @@ def signal_param_names(self):
def background_param_names(self):
return Fit1D._get_param_names(self.background_models)

@staticmethod
def _get_params(all_pars, parsms_names: list[list[str]]) -> tuple[tuple[FitParams, ...], ...]:
signal_params_list = []
@property
def params(self) -> dict[str, tuple[FitParams, ...]]:

all_pars = self.guess() if self._parameters is None else self._parameters
parsms_names = Fit1D._get_param_names(self.signal_models + self.background_models)
params_dict = {}

for params in parsms_names:
key = params[0].split("_")[0]
params_list = []
for param_name in params:
param = all_pars[param_name]
Expand All @@ -185,34 +177,17 @@ def _get_params(all_pars, parsms_names: list[list[str]]) -> tuple[tuple[FitParam
brute_step=param.brute_step,
)
)
signal_params_list.append(tuple(params_list))
return tuple(signal_params_list)

@property
def signal_params(self) -> tuple[tuple[FitParams, ...], ...]:

pars = self.guess() if self.parameters is None else self.parameters
names = Fit1D._get_param_names(self.signal_models)
signal_params = Fit1D._get_params(pars, names)

return signal_params

@property
def background_params(self) -> tuple[tuple[FitParams, ...], ...]:

pars = self.guess() if self.parameters is None else self.parameters
names = Fit1D._get_param_names(self.background_models)
background_params = Fit1D._get_params(pars, names)
params_dict.update({key: tuple(params_list)})

return background_params
return params_dict

def guess(self) -> Parameters:
pars = Parameters()
for signal in self.signal_models:
pars += signal.guess(self.y, x=self.x)
for bkg in self.background_models:
pars += bkg.guess(self.y, x=self.x)
self.parameters = pars
self._parameters = pars
return pars

@property
Expand Down
23 changes: 12 additions & 11 deletions tests/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,17 @@ def test_get_fitting_variables(fit_data):
assert f1.signal_param_names == [["s1_amplitude", "s1_center", "s1_sigma"]]
assert f1.background_param_names == [["b1_c"]]

assert len(f1.signal_params) == 1
assert len(f1.signal_params[0]) == 5
assert len(f1.params) == 2

assert f1.signal_params[0][0].name == "s1_amplitude"
assert f1.signal_params[0][1].name == "s1_center"
assert f1.signal_params[0][2].name == "s1_sigma"
assert f1.signal_params[0][3].name == "s1_fwhm"
assert f1.signal_params[0][4].name == "s1_height"
assert f1.signal_params[0][4].expr == "0.3989423*s1_amplitude/max(1e-15, s1_sigma)"
assert f1.background_params[0][0].name == "b1_c"
s1_params = f1.params["s1"]
assert len(s1_params) == 5
assert s1_params[0].name == "s1_amplitude"
assert s1_params[1].name == "s1_center"
assert s1_params[2].name == "s1_sigma"
assert s1_params[3].name == "s1_fwhm"
assert s1_params[4].name == "s1_height"
assert s1_params[4].expr == "0.3989423*s1_amplitude/max(1e-15, s1_sigma)"
assert f1.params["b1"][0].name == "b1_c"


def test_guess_initial(fit_data):
Expand Down Expand Up @@ -117,9 +118,9 @@ def test_fit_two_peak(fit_data):

s1_scan, PLOT = fit_data

f1 = Fit1D(s1_scan, fit_range=(0.0, 4.0))
f1 = Fit1D(s1_scan, fit_range=(0.0, 4.0), name="scan42_fit2peaks")

f1.add_background(values=(0.7,))
f1.add_background(model="Constant")
f1.add_signal(values=(None, 3.5, 0.29), vary=(True, True, True))
f1.add_signal(
model="Gaussian",
Expand Down

0 comments on commit 4fef66a

Please sign in to comment.