Skip to content

Commit

Permalink
Implementing Fit class
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Li committed Nov 13, 2024
1 parent c6cecbd commit 7c8b1d8
Show file tree
Hide file tree
Showing 90 changed files with 134 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
tas = CN(SPICE_CONVENTION=True)
tas.load_instrument_params_from_json(instrument_config_json_path)

sample_json_path = "./scripts/IPTS32816_HB1A_exp1034/fesn.json"
sample_json_path = "./test_data/IPTS32816_HB1A_exp1034/fesn.json"
sample = Xtal.from_json(sample_json_path)
tas.mount_sample(sample)

Expand All @@ -19,7 +19,7 @@
R0 = False


path_to_spice_folder = "./scripts/IPTS32816_HB1A_exp1034/exp1034/"
path_to_spice_folder = "./test_data/IPTS32816_HB1A_exp1034/exp1034/"
scan35 = Scan.from_spice(path_to_spice_folder, scan_num=35)
fesn000p5_lscan = scan35.get_data(norm_to=(120, "mcu"))

Expand Down
56 changes: 33 additions & 23 deletions src/tavi/data/fit.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from typing import Optional
from typing import Literal, Optional

import numpy as np
from lmfit import Parameters, models

from tavi.data.scan_data import ScanData1D
from tavi.plotter import Plot1D


class Fit1D(object):
"""Fit a 1d curve
Attributes:
NUM_PTS (int): number of points for the fit curve"""
"""

models = {
# ---------- peak models ---------------
Expand All @@ -32,39 +34,46 @@ class Fit1D(object):
"Spline": models.SplineModel,
}

def __init__(self, plot1d: Plot1D):
"""initialize a fit model"""
def __init__(
self,
data: ScanData1D,
fit_range: Optional[tuple[float, float]] = None,
):
"""initialize a fit model, mask based on fit_range if given"""

self.NUM_PTS: int = 100
self.x = plot1d.x
self.y = plot1d.y
self.yerr = plot1d.yerr
self.x: np.ndarray = data.x
self.y: np.ndarray = data.y
self.err: Optional[np.ndarray] = data.err

self.background_models: models = []
self.signal_models: models = []
self.pars = Parameters()
self.num_backgrounds = 0
self.num_signals = 0
self._num_backgrounds = 0
self._num_signals = 0
self.fit_result = None

self.PLOT_SEPARATELY = False
self.fit_plot: Optional[Plot1D] = None

def set_range(self, fit_min, fit_max):
"""set the range used for fitting"""
if fit_range is not None:
self.set_range(fit_range)

def set_range(self, fit_range: tuple[float, float]):
"""set the range used for fitting"""
fit_min, fit_max = fit_range
mask = np.bitwise_and(self.x >= fit_min, self.x <= fit_max)
self.x = self.x[mask]
self.y = self.y[mask]
if self.yerr is not None:
self.yerr = self.yerr[mask]
if self.err is not None:
self.err = self.err[mask]

@property
def x_plot(self):
return np.linspace(self.x.min(), self.x.max(), num=self.NUM_PTS)

def add_background(
self,
model="Constant",
model: Literal["Constant", "Linear", "Quadratic", "Polynomial", "Exponential", "PowerLaw"] = "Constant",
values=None,
vary=None,
mins=None,
Expand All @@ -82,13 +91,14 @@ def add_background(
fixed (tuple | None): tuple of flags
expr (tuple| None ): constraint expressions
"""
self.num_backgrounds += 1
self._num_backgrounds += 1

# add prefix if more than one background
if self.num_backgrounds > 1:
prefix = f"b{self.num_backgrounds}_"
if self._num_backgrounds > 1:
prefix = f"b{self._num_backgrounds}_"
else:
prefix = ""

model = Fit1D.models[model](prefix=prefix, nan_policy="propagate")
param_names = model.param_names
# guess initials
Expand Down Expand Up @@ -143,8 +153,8 @@ def add_signal(
max (tuple | None): maximum
expr (str| None ): constraint expression
"""
self.num_signals += 1
prefix = f"s{self.num_signals}_"
self._num_signals += 1
prefix = f"s{self._num_signals}_"
model = Fit1D.models[model](prefix=prefix, nan_policy="propagate")
param_names = model.param_names
# guess initials
Expand Down Expand Up @@ -182,13 +192,13 @@ def add_signal(
def perform_fit(self) -> None:
model = np.sum(self.signal_models)

if self.num_backgrounds > 0:
if self._num_backgrounds > 0:
model += np.sum(self.background_models)

if self.yerr is None:
if self.err is None:
out = model.fit(self.y, self.pars, x=self.x)
else:
out = model.fit(self.y, self.pars, x=self.x, weights=self.yerr)
out = model.fit(self.y, self.pars, x=self.x, weights=self.err)

self.result = out
self.y_plot = model.eval(out.params, x=self.x_plot)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
112 changes: 99 additions & 13 deletions tests/test_fit.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,84 @@
# -*- coding: utf-8 -*
import matplotlib.pyplot as plt
import numpy as np
from lmfit.models import ConstantModel, GaussianModel

from tavi.data.fit import Fit1D
from tavi.data.scan import Scan
from tavi.plotter import Plot1D


def test_fit_single_peak():
def test_fit_single_peak_external_model():

nexus_file_name = "./test_data/IPTS32124_CG4C_exp0424/scan0042.h5"
_, s1 = Scan.from_nexus_file(nexus_file_name)
path_to_spice_folder = "./test_data/exp424"
scan42 = Scan.from_spice(path_to_spice_folder=path_to_spice_folder, scan_num=42)

plot1d = s1.generate_curve(norm_channel="mcu", norm_val=30)
f1 = Fit1D(plot1d)
f1.set_range(0.5, 4.0)
f1.add_background(values=(0.7,))
f1.add_signal(values=(None, 3.5, None), vary=(True, True, True))
f1.perform_fit()
assert np.allclose(f1.result.params["s1_center"].value, 3.54, atol=0.01)
assert np.allclose(f1.result.params["s1_fwhm"].value, 0.39, atol=0.01)
s1_scan = scan42.get_data(norm_channel="mcu", norm_val=30)
f1 = Fit1D(s1_scan, fit_range=(0.5, 4.0))

bkg = ConstantModel(prefix="bkg_", nan_policy="propagate")
peak = GaussianModel(prefix="peak_", nan_policy="propagate")
model = peak + bkg
pars = peak.guess(f1.y, x=f1.x)
pars += bkg.make_params(c=0)
out = model.fit(f1.y, pars, x=f1.x, weight=f1.err)

assert np.allclose(out.values["peak_center"], 3.54, atol=0.01)
assert np.allclose(out.values["peak_fwhm"], 0.39, atol=0.01)
assert np.allclose(out.redchi, 10.012, atol=0.01)

# p1 = Plot1D()
# p1.add_scan(s1_scan, fmt="o")
# fig, ax = plt.subplots()
# p1.plot(ax)
# ax.plot(f1.x, out.best_fit)
# plt.show()

# f1.add_background(model="Constant", values=(0.7,))
# f1.add_signal(values=(None, 3.5, None), vary=(True, True, True))
# f1.perform_fit()

# fig, ax = plt.subplots()
# plot1d.plot_curve(ax)
# f1.fit_plot.plot_curve(ax)
# plt.show()


def test_fit_single_peak_internal_model():

path_to_spice_folder = "./test_data/exp424"
scan42 = Scan.from_spice(path_to_spice_folder=path_to_spice_folder, scan_num=42)

s1_scan = scan42.get_data(norm_channel="mcu", norm_val=30)
f1 = Fit1D(s1_scan, fit_range=(0.5, 4.0))

bkg = ConstantModel(prefix="bkg_", nan_policy="propagate")
peak = GaussianModel(prefix="peak_", nan_policy="propagate")
model = peak + bkg
pars = peak.guess(f1.y, x=f1.x)
pars += bkg.make_params(c=0)
out = model.fit(f1.y, pars, x=f1.x, weight=f1.err)

assert np.allclose(out.values["peak_center"], 3.54, atol=0.01)
assert np.allclose(out.values["peak_fwhm"], 0.39, atol=0.01)
assert np.allclose(out.redchi, 10.012, atol=0.01)

p1 = Plot1D()
p1.add_scan(s1_scan, fmt="o")
fig, ax = plt.subplots()
plot1d.plot_curve(ax)
f1.fit_plot.plot_curve(ax)
p1.plot(ax)
ax.plot(f1.x, out.best_fit)
plt.show()

# f1.add_background(model="Constant", values=(0.7,))
# f1.add_signal(values=(None, 3.5, None), vary=(True, True, True))
# f1.perform_fit()

# fig, ax = plt.subplots()
# plot1d.plot_curve(ax)
# f1.fit_plot.plot_curve(ax)
# plt.show()


def test_fit_two_peak():

Expand All @@ -50,3 +105,34 @@ def test_fit_two_peak():
plot1d.plot_curve(ax)
f1.fit_plot.plot_curve(ax)
plt.show()


def test_plot_fit():
path_to_spice_folder = "./test_data/exp424"
scan42 = Scan.from_spice(path_to_spice_folder=path_to_spice_folder, scan_num=42)

s1_scan = scan42.get_data(norm_channel="mcu", norm_val=30)
f1 = Fit1D(s1_scan, fit_range=(0.5, 4.0))

bkg = ConstantModel(prefix="bkg_", nan_policy="propagate")
peak = GaussianModel(prefix="peak_", nan_policy="propagate")
model = peak + bkg
pars = peak.guess(f1.y, x=f1.x)
pars += bkg.make_params(c=0)
out = model.fit(f1.y, pars, x=f1.x, weight=f1.err)

assert np.allclose(out.values["peak_center"], 3.54, atol=0.01)
assert np.allclose(out.values["peak_fwhm"], 0.39, atol=0.01)
assert np.allclose(out.redchi, 10.012, atol=0.01)

comps = out.eval_components(x=f1.x)

p1 = Plot1D()
p1.add_scan(s1_scan, fmt="o")
fig, ax = plt.subplots()
p1.plot(ax)
ax.plot(f1.x, out.best_fit, label="peak+bkg")
ax.plot(f1.x, comps["peak_"], label="peak")
ax.plot(f1.x, comps["bkg_"], label="bkg")
ax.legend()
plt.show()

0 comments on commit 7c8b1d8

Please sign in to comment.