Skip to content

Commit

Permalink
updated Fit and Plotter class
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Li committed Nov 25, 2024
1 parent 64d5484 commit 56bc809
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 117 deletions.
96 changes: 78 additions & 18 deletions src/tavi/data/fit.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Literal, Optional

import numpy as np
Expand All @@ -7,6 +8,24 @@
from tavi.data.scan_data import ScanData1D


@dataclass
class FitParams:
name: str
value: Optional[float] = None
vary: bool = True
min: float = -np.inf
max: float = np.inf
expr: Optional[str] = None
brute_step: Optional[float] = None


# @dataclass
# class FitData1D:
# x: np.ndarray
# y: np.ndarray
# fmt: dict = {}


class FitData1D(object):

def __init__(
Expand Down Expand Up @@ -64,13 +83,13 @@ def __init__(

self.background_models: models = []
self.signal_models: models = []
self.pars = Parameters()
self.parameters: Optional[Parameters] = None
self._num_backgrounds = 0
self._num_signals = 0
self.result = None
self.result: Optional[ModelResult] = None
self.fit_data: Optional[FitData1D] = None

self.PLOT_SEPARATELY = False
self.PLOT_COMPONENTS = False
self.nan_policy = nan_policy

if fit_range is not None:
Expand Down Expand Up @@ -133,34 +152,75 @@ def add_background(
)

@staticmethod
def _get_model_params(models) -> list[list[str]]:
def _get_param_names(models) -> list[list[str]]:
params = []
for model in models:
params.append(model.param_names)
return params

@property
def signal_params(self) -> list[list[str]]:
return Fit1D._get_model_params(self.signal_models)
def signal_param_names(self):
return Fit1D._get_param_names(self.signal_models)

@property
def background_params(self) -> list[list[str]]:
return Fit1D._get_model_params(self.background_models)
def background_param_names(self):
return Fit1D._get_param_names(self.background_models)

@staticmethod
def _get_params(all_pars, parsms_names: list[list[str]]) -> tuple[tuple[FitParams, ...], ...]:
signal_params_list = []

for params in parsms_names:
params_list = []
for param_name in params:
param = all_pars[param_name]
params_list.append(
FitParams(
name=param.name,
value=param.value,
vary=param.vary,
min=param.min,
max=param.max,
expr=param.expr,
brute_step=param.brute_step,
)
)
signal_params_list.append(tuple(params_list))
return tuple(signal_params_list)

@property
def signal_params(self) -> tuple[tuple[FitParams, ...], ...]:

pars = self.guess() if self.parameters is None else self.parameters
names = Fit1D._get_param_names(self.signal_models)
signal_params = Fit1D._get_params(pars, names)

return signal_params

@property
def background_params(self) -> tuple[tuple[FitParams, ...], ...]:

pars = self.guess() if self.parameters is None else self.parameters
names = Fit1D._get_param_names(self.background_models)
background_params = Fit1D._get_params(pars, names)

return background_params

def guess(self) -> Parameters:
pars = Parameters()
for signal in self.signal_models:
pars += signal.guess(self.y, x=self.x)
for bkg in self.background_models:
pars += bkg.guess(self.y, x=self.x)
self.pars = pars
self.parameters = pars
return pars

def _build_composite_model(self):
@property
def model(self):
compposite_model = np.sum(self.signal_models + self.background_models)
return compposite_model

def _get_x_to_plot(self, num_of_pts: Optional[int]):
def x_to_plot(self, num_of_pts: Optional[int]):
if num_of_pts is None:
x_to_plot = self.x
elif isinstance(num_of_pts, int):
Expand All @@ -170,14 +230,14 @@ def _get_x_to_plot(self, num_of_pts: Optional[int]):
return x_to_plot

def eval(self, pars: Parameters, num_of_pts: Optional[int] = 100) -> FitData1D:
model = self._build_composite_model()
x_to_plot = self._get_x_to_plot(num_of_pts)
y_to_plot = model.eval(pars, x=x_to_plot)

x_to_plot = self.x_to_plot(num_of_pts)
y_to_plot = self.model.eval(pars, x=x_to_plot)

return FitData1D(x_to_plot, y_to_plot)

def fit(self, pars: Parameters, num_of_pts: Optional[int] = 100) -> ModelResult:
mod = self._build_composite_model()
result = mod.fit(self.y, pars, x=self.x, weights=self.err)
def fit(self, pars: Parameters) -> ModelResult:

result = self.model.fit(self.y, pars, x=self.x, weights=self.err)
self.result = result
return result
return self
40 changes: 30 additions & 10 deletions src/tavi/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
from typing import Optional, Union

import numpy as np
from lmfit.model import ModelResult
from mpl_toolkits.axisartist.grid_finder import MaxNLocator
from mpl_toolkits.axisartist.grid_helper_curvelinear import GridHelperCurveLinear

from tavi.data.fit import FitData1D
from tavi.data.fit import Fit1D, FitData1D
from tavi.data.scan_data import ScanData1D, ScanData2D
from tavi.instrument.resolution.ellipse import ResoEllipse

Expand Down Expand Up @@ -55,13 +54,37 @@ def add_scan(self, scan_data: ScanData1D, **kwargs):
for key, val in kwargs.items():
scan_data.fmt.update({key: val})

def add_fit(self, fit_data: Union[FitData1D, ModelResult], PLOT_COMPONENTS=False, **kwargs):
if PLOT_COMPONENTS:
pass
else:
def add_fit(self, fit_data: Union[FitData1D, Fit1D], num_of_pts: Optional[int] = 100, **kwargs):
"""
Note:
PLOT_COMPONENTS is ignored if fit_data has the type FitData1D
"""
if isinstance(fit_data, FitData1D):
self.fit_data.append(fit_data)
for key, val in kwargs.items():
fit_data.fmt.update({key: val})
elif isinstance(fit_data, Fit1D) and (result := fit_data.result) is not None:
x = fit_data.x_to_plot(num_of_pts)
data = FitData1D(x=x, y=result.eval(param=result.params, x=x))
self.fit_data.append(data)
for key, val in kwargs.items():
data.fmt.update({key: val})
else:
raise ValueError(f"Invalid input fit_data={fit_data}")

def add_fit_components(self, fit_data: Fit1D, num_of_pts: Optional[int] = 100, **kwargs):
if isinstance(fit_data, Fit1D) and (result := fit_data.result) is not None:
x = fit_data.x_to_plot(num_of_pts)
components = result.eval_components(result.params, x=x)
for i, (prefix, y) in enumerate(components.items()):
data = FitData1D(x=x, y=y)
self.fit_data.append(data)
data.fmt.update({"label": prefix[:-1]})
for key, val in kwargs.items():
data.fmt.update({key: val[i]})

else:
raise ValueError(f"Invalid input fit_data={fit_data}")

def plot(self, ax):
for data in self.scan_data:
Expand All @@ -70,10 +93,7 @@ def plot(self, ax):
else:
ax.errorbar(x=data.x, y=data.y, yerr=data.err, **data.fmt)
for fit in self.fit_data:
if fit.PLOT_INDIVIDUALLY:
pass
else:
ax.plot(fit.x, fit.y, **fit.fmt)
ax.plot(fit.x, fit.y, **fit.fmt)

if self.xlim is not None:
ax.set_xlim(left=self.xlim[0], right=self.xlim[1])
Expand Down
Loading

0 comments on commit 56bc809

Please sign in to comment.