Skip to content

Commit

Permalink
syncing
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Li committed Dec 6, 2024
1 parent 3853fea commit a74d4d7
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 52 deletions.
45 changes: 3 additions & 42 deletions src/tavi/data/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,6 @@

from tavi.data.scan_data import ScanData1D

# @dataclass
# class FitParam:
# name: str
# value: Optional[float] = None
# vary: bool = True
# min: float = -np.inf
# max: float = np.inf
# expr: Optional[str] = None


# brute_step: Optional[float] = None


# @dataclass
# class FitData1D:
# x: np.ndarray
# y: np.ndarray
# fmt: dict = {}


class FitData1D(object):

def __init__(self, x: np.ndarray, y: np.ndarray) -> None:

self.x = x
self.y = y
self.fmt: dict = {}


class Fit1D(object):
"""Fit a 1D curve
Expand Down Expand Up @@ -83,7 +55,6 @@ def __init__(
self._num_backgrounds = 0
self._num_signals = 0
self.result: Optional[ModelResult] = None
self.fit_data: Optional[FitData1D] = None

self.PLOT_COMPONENTS = False
self.nan_policy = nan_policy
Expand Down Expand Up @@ -202,7 +173,7 @@ def model(self):
compposite_model = np.sum(self._signal_models + self._background_models)
return compposite_model

def x_to_plot(self, num_of_pts: Optional[int]):
def x_to_plot(self, num_of_pts: Optional[int] = 100):
if num_of_pts is None:
x_to_plot = self.x
elif isinstance(num_of_pts, int):
Expand All @@ -211,18 +182,8 @@ def x_to_plot(self, num_of_pts: Optional[int]):
raise ValueError(f"num_of_points={num_of_pts} needs to be an integer.")
return x_to_plot

def eval(self, pars: Optional[Parameters], num_of_pts: Optional[int] = 100, x=None) -> FitData1D:
if pars is None:
pars = self.result.params

if x is not None:
return self.model.eval(pars, x=x)

else:
x_to_plot = self.x_to_plot(num_of_pts)
y_to_plot = self.model.eval(pars, x=x_to_plot)

return FitData1D(x_to_plot, y_to_plot)
def eval(self, pars: Parameters, x: np.ndarray) -> np.ndarray:
return self.model.eval(pars, x=x)

def fit(self, pars: Parameters) -> ModelResult:

Expand Down
21 changes: 17 additions & 4 deletions src/tavi/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,22 @@
from mpl_toolkits.axisartist.grid_finder import MaxNLocator
from mpl_toolkits.axisartist.grid_helper_curvelinear import GridHelperCurveLinear

from tavi.data.fit import Fit1D, FitData1D
from tavi.data.fit import Fit1D
from tavi.data.scan_data import ScanData1D, ScanData2D
from tavi.instrument.resolution.ellipse import ResoEllipse


class FitData1D(object):

def __init__(self, x: np.ndarray, y: np.ndarray) -> None:

self.x = x
self.y = y
self.fmt: dict = {}


class ResoBar(object):

def __init__(self, pos: tuple[float, float], fwhm: float) -> None:

self.pos = pos
Expand Down Expand Up @@ -77,9 +87,12 @@ def _add_fit_from_fitting(self, fit_data: Fit1D, num_of_pts: Optional[int] = 100
for key, val in kwargs.items():
data.fmt.update({key: val})

def add_fit(self, fit_data: Union[FitData1D, Fit1D], num_of_pts: Optional[int] = 100, **kwargs):
if isinstance(fit_data, FitData1D):
self._add_fit_from_eval(fit_data, **kwargs)
def add_fit(
self, fit_data: Union[tuple[np.ndarray, np.ndarray], Fit1D], num_of_pts: Optional[int] = 100, **kwargs
):
if isinstance(fit_data, tuple):
x, y = fit_data
self._add_fit_from_eval(FitData1D(x, y), **kwargs)
elif isinstance(fit_data, Fit1D):
self._add_fit_from_fitting(fit_data, num_of_pts, **kwargs)
else:
Expand Down
13 changes: 7 additions & 6 deletions tests/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,16 @@ def test_guess_initial(fit_data):
f1.add_signal(model="Gaussian")
f1.add_background(model="Constant")
pars = f1.guess()
inital = f1.eval(pars, num_of_pts=50)
x = f1.x_to_plot(num_of_pts=100)
y = f1.eval(pars, x)
# fit_result = f1.fit(pars)
assert inital.y.shape == (50,)
assert inital.y.min() > 6
assert y.shape == (100,)
assert y.min() > 6

if PLOT:
p1 = Plot1D()
p1.add_scan(s1_scan, fmt="o", label="data")
p1.add_fit(inital, label="guess", color="C1", marker="s", linestyle="dashed", linewidth=2, markersize=4)
p1.add_fit((x, y), label="guess", color="C1", marker="s", linestyle="dashed", linewidth=2, markersize=4)

fig, ax = plt.subplots()
p1.plot(ax)
Expand All @@ -101,8 +102,8 @@ def test_fit_single_peak_internal_model(fit_data):
f1.add_signal(model="Gaussian")
f1.add_background(model="Constant")
pars = f1.guess()
fit_result = f1.fit(pars)
assert np.allclose(fit_result.redchi, 37.6, atol=1)
result = f1.fit(pars)
assert np.allclose(result.redchi, 37.6, atol=1)

if PLOT:
p1 = Plot1D()
Expand Down

0 comments on commit a74d4d7

Please sign in to comment.