Skip to content

Commit

Permalink
implemented the fit class
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Li committed Aug 8, 2024
1 parent cba1dea commit 1415e5c
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 45 deletions.
7 changes: 4 additions & 3 deletions scripts/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ def test_fit_scan(tavi):
curve1 = s1.generate_curve(norm_channel="mcu", norm_val=30, rebin_type="grid", rebin_step=0.25)

x, y, xerr, yerr, xlabel, ylabel, title, label = curve1
f1 = Fit(x=x, y=y, err=yerr, fit_range=(0.5, 3))
f1 = Fit(x=x, y=y, fit_range=(0.0, 4))
f1.add_background()
f1.add_signal()
# f1.add_signal()
f1.add_signal(model="Gaussian", values=(0, None, None), vary=(False, True, True))
f1.perform_fit()

p1.plot_curve(*curve1)
p1.plot_curve(f1.x_plot, f1.y_plot, fmt="-")

# s2 = tavi.data[datasets]["scan0043"]
# curve2 = s2.generate_curve(norm_channel="mcu", norm_val=30, rebin_type="grid", rebin_step=0.25)
Expand All @@ -30,7 +31,7 @@ def test_fit_scan(tavi):
if __name__ == "__main__":
tavi = TAVI()

tavi_file_name = "./tests/test_data_folder/tavi_test_exp424.h5"
tavi_file_name = "./test_data/tavi_test_exp424.h5"
tavi.open_tavi_file(tavi_file_name)

test_fit_scan(tavi)
124 changes: 82 additions & 42 deletions src/tavi/data/fit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import matplotlib.pyplot as plt
import numpy as np
from lmfit import models
from lmfit import Parameters, models


class Fit(object):
Expand Down Expand Up @@ -36,7 +35,7 @@ def __init__(self, x, y, err=None, fit_range=None):
err (list | None)
fit_range (tuple)
"""

self.NUM_PTS = 100
self.range = fit_range
self.x = np.array(x)
self.y = np.array(y)
Expand All @@ -51,22 +50,31 @@ def __init__(self, x, y, err=None, fit_range=None):
if self.err is not None:
self.err = self.err[mask]

self.x_plot = np.linspace(
self.x.min(),
self.x.max(),
num=self.NUM_PTS,
)

self.background_models = []
self.signal_models = []
self.pars = Parameters()
self.num_backgrounds = 0
self.num_signals = 0
self.FIT_STATUS = None
self.chi_squred = 0
self.chisqr = 0
self.PLOT_SEPARATELY = False
self.y_plot = None
self.best_values = None

def add_background(
self,
model="Constant",
p0=None,
min=None,
max=None,
fixed=None,
expr=None,
values=None,
vary=None,
mins=None,
maxs=None,
exprs=None,
):
"""Set the model for background
Expand All @@ -87,21 +95,42 @@ def add_background(
else:
prefix = ""
model = Fit.models[model](prefix=prefix, nan_policy="propagate")
num_params = len(model.param_names)

pass

# pars = model.guess(self.y, x=self.x)
# pars["c"].set(value=0.7, vary=True, expr="")
pars = model.guess(self.y, x=self.x)
if values is None:
values = []
for idx, param in enumerate(model.param_names):
values.append(pars[param].value)
if vary is None:
vary = [True] * num_params
if mins is None:
mins = [None] * num_params
if maxs is None:
maxs = [None] * num_params
if exprs is None:
exprs = [None] * num_params

for idx, param in enumerate(model.param_names):
self.pars.add(
param,
value=pars[param].value,
vary=vary[idx],
min=mins[idx],
max=maxs[idx],
expr=exprs[idx],
)

self.background_models.append(model)

def add_signal(
self,
model="Gaussian",
p0=None,
min=None,
max=None,
expr=None,
values=None,
vary=None,
mins=None,
maxs=None,
exprs=None,
):
"""Set the model for background
Expand All @@ -116,40 +145,51 @@ def add_signal(
self.num_signals += 1
prefix = f"s{self.num_signals}_"
model = Fit.models[model](prefix=prefix, nan_policy="propagate")
print(model.param_names)
self.signal_models.append(model)

pars = model.guess(self.y, x=self.x)
pars["c"].set(value=0.7, vary=True, expr="")

def perform_fit(self):
num_params = len(model.param_names)
if values is None:
values = []
for idx, param in enumerate(model.param_names):
values.append(pars[param].value)
if vary is None:
vary = [True] * num_params
if mins is None:
mins = [None] * num_params
if maxs is None:
maxs = [None] * num_params
if exprs is None:
exprs = [None] * num_params

# pars = model.guess(self.y, x=self.x)
for idx, param in enumerate(model.param_names):
self.pars.add(
param,
value=values[idx],
vary=vary[idx],
min=mins[idx],
max=maxs[idx],
expr=exprs[idx],
)

self.signal_models.append(model)

def perform_fit(self):
model = np.sum(self.signal_models)

if self.num_backgrounds > 0:
model += np.sum(self.background_models)

if self.err is None:
# pars = model.guess(self.y, x=self.x)
out = model.fit(self.y, pars, x=self.x)
out = model.fit(self.y, self.pars, x=self.x)
else:
pars = model.fit(self.y, x=self.x, weights=self.err)

# out = model.fit(self.y, pars, x=self.x)
print(out.fit_report(min_correl=0.25))
out = model.fit(self.y, self.pars, x=self.x, weights=self.err)

self.chisqr = out.chisqr
self.FIT_STATUS = out.success

# # plot fitting results
# fig, ax = plt.subplots()
# if std is None:
# ax.plot(self.scan_data[x_str], self.scan_data[y_str], "o")
# else:
# ax.errorbar(x, y, yerr=std, fmt="o")
# ax.plot(x, out.best_fit, "-")

# if "scan_title" in self.scan_params:
# ax.set_title(self.scan_params["scan_title"])
# ax.set_xlabel(x_str)
# ax.set_ylabel(y_str)
# ax.grid(alpha=0.6)
# ax.set_ylim(bottom=0)
# plt.tight_layout()
# self.y_plot = model.eval(self.pars, x=self.x_plot)
self.y_plot = model.eval(out.params, x=self.x_plot)
self.best_values = out.best_values
print(out.fit_report(min_correl=0.25))

0 comments on commit 1415e5c

Please sign in to comment.