diff --git a/src/tavi/data/fit.py b/src/tavi/data/fit.py index afad0aa..ab15842 100644 --- a/src/tavi/data/fit.py +++ b/src/tavi/data/fit.py @@ -2,6 +2,7 @@ import numpy as np from lmfit import Parameters, models +from lmfit.model import ModelResult from tavi.data.scan_data import ScanData1D @@ -20,7 +21,7 @@ def __init__( class Fit1D(object): - """Fit a 1d curve + """Fit a 1D curve Attributes: @@ -50,9 +51,13 @@ def __init__( self, data: ScanData1D, fit_range: Optional[tuple[float, float]] = None, + nan_policy: Literal["raise", "propagate", "omit"] = "propagate", + name="", ): """initialize a fit model, mask based on fit_range if given""" + self.name = name + self.x: np.ndarray = data.x self.y: np.ndarray = data.y self.err: Optional[np.ndarray] = data.err @@ -66,6 +71,7 @@ def __init__( self.fit_data: Optional[FitData1D] = None self.PLOT_SEPARATELY = False + self.nan_policy = nan_policy if fit_range is not None: self.set_range(fit_range) @@ -80,41 +86,68 @@ def set_range(self, fit_range: tuple[float, float]): self.err = self.err[mask] @staticmethod - def _add_model(model, prefix): + def _add_model(model, prefix, nan_policy): model = Fit1D.models[model] - return model(prefix=prefix, nan_policy="propagate") + return model(prefix=prefix, nan_policy=nan_policy) def add_signal( self, - model_name: Literal[ - "Gaussian", "Lorentzian", "Voigt", "PseudoVoigt", "DampedOscillator", "DampedHarmonicOscillator" + model: Literal[ + "Gaussian", + "Lorentzian", + "Voigt", + "PseudoVoigt", + "DampedOscillator", + "DampedHarmonicOscillator", ], ): self._num_signals += 1 prefix = f"s{self._num_signals}_" - self.signal_models.append(Fit1D._add_model(model_name, prefix)) + self.signal_models.append( + Fit1D._add_model( + model, + prefix, + nan_policy=self.nan_policy, + ) + ) def add_background( - self, model_name: Literal["Constant", "Linear", "Quadratic", "Polynomial", "Exponential", "PowerLaw"] + self, + model: Literal[ + "Constant", + "Linear", + "Quadratic", + "Polynomial", + "Exponential", + "PowerLaw", + ], ): self._num_backgrounds += 1 prefix = f"b{self._num_backgrounds}_" - self.background_models.append(Fit1D._add_model(model_name, prefix)) + self.background_models.append( + Fit1D._add_model( + model, + prefix, + nan_policy=self.nan_policy, + ) + ) @staticmethod - def _get_model_params(models): + def _get_model_params(models) -> list[list[str]]: params = [] for model in models: params.append(model.param_names) return params - def get_signal_params(self): + @property + def signal_params(self) -> list[list[str]]: return Fit1D._get_model_params(self.signal_models) - def get_background_params(self): + @property + def background_params(self) -> list[list[str]]: return Fit1D._get_model_params(self.background_models) - def guess(self): + def guess(self) -> Parameters: pars = Parameters() for signal in self.signal_models: pars += signal.guess(self.y, x=self.x) @@ -123,47 +156,28 @@ def guess(self): self.pars = pars return pars - @property - def x_to_plot(self): - return - - def eval(self, pars: Parameters, num_of_pts: Optional[int] = 100) -> FitData1D: - mod = self.signal_models[0] - if (sz := len(self.signal_models)) > 1: - for i in range(1, sz): - mod += self.signal_models[i] - - for bkg in self.background_models: - mod += bkg + def _build_composite_model(self): + compposite_model = np.sum(self.signal_models + self.background_models) + return compposite_model + def _get_x_to_plot(self, num_of_pts: Optional[int]): if num_of_pts is None: x_to_plot = self.x elif isinstance(num_of_pts, int): x_to_plot = np.linspace(self.x.min(), self.x.max(), num=num_of_pts) else: raise ValueError(f"num_of_points={num_of_pts} needs to be an integer.") - y_to_plot = mod.eval(pars, x=x_to_plot) - return FitData1D(x_to_plot, y_to_plot) + return x_to_plot - def fit(self, pars: Parameters, num_of_pts: Optional[int] = 100) -> FitData1D: - mod = self.signal_models[0] - if (sz := len(self.signal_models)) > 1: - for i in range(1, sz): - mod += self.signal_models[i] + def eval(self, pars: Parameters, num_of_pts: Optional[int] = 100) -> FitData1D: + model = self._build_composite_model() + x_to_plot = self._get_x_to_plot(num_of_pts) + y_to_plot = model.eval(pars, x=x_to_plot) - for bkg in self.background_models: - mod += bkg + return FitData1D(x_to_plot, y_to_plot) + def fit(self, pars: Parameters, num_of_pts: Optional[int] = 100) -> ModelResult: + mod = self._build_composite_model() result = mod.fit(self.y, pars, x=self.x, weights=self.err) self.result = result - - if num_of_pts is None: - x_to_plot = self.x - elif isinstance(num_of_pts, int): - x_to_plot = np.linspace(self.x.min(), self.x.max(), num=num_of_pts) - else: - raise ValueError(f"num_of_points={num_of_pts} needs to be an integer.") - - y_to_plot = mod.eval(result.params, x=x_to_plot) - - return FitData1D(x_to_plot, y_to_plot) + return result diff --git a/src/tavi/data/nxdict.py b/src/tavi/data/nxdict.py index 7051acd..56d90de 100644 --- a/src/tavi/data/nxdict.py +++ b/src/tavi/data/nxdict.py @@ -4,6 +4,7 @@ from datetime import datetime from pathlib import Path from typing import Optional +from zoneinfo import ZoneInfo import h5py import numpy as np @@ -75,7 +76,17 @@ def __init__(self, ds, **kwargs): case {"type": "NX_INT"} | {"type": "NX_FLOAT"}: dataset = _recast_type(ds, kwargs["type"]) case {"type": "NX_DATE_TIME"}: - dataset = datetime.strptime(ds, "%m/%d/%Y %I:%M:%S %p").isoformat() + dt = datetime.strptime(ds, "%m/%d/%Y %I:%M:%S %p") + date = datetime( + dt.year, + dt.month, + dt.day, + dt.hour, + dt.minute, + dt.second, + tzinfo=ZoneInfo("America/New_York"), + ) + dataset = date.isoformat() case _: dataset = ds @@ -464,7 +475,6 @@ def spice_scan_to_nxdict( nxdata = NXentry(NX_class="NXdata", EX_required="true", signal=def_y, axes=def_x) # ---------------------------------------- scan --------------------------------------------- - # TODO timezone start_date_time = "{} {}".format(metadata.get("date"), metadata.get("time")) # TODO what is last scan never finished? # if "end_time" in das_logs.attrs: diff --git a/src/tavi/data/nxentry.py b/src/tavi/data/nxentry.py index 92227a3..08980a7 100644 --- a/src/tavi/data/nxentry.py +++ b/src/tavi/data/nxentry.py @@ -256,7 +256,7 @@ def to_nexus(self, path_to_nexus: str, name="scan") -> None: # Create the ATTRIBUTES nexus_file.attrs["file_name"] = os.path.abspath(path_to_nexus) - nexus_file.attrs["file_time"] = datetime.now().isoformat() + nexus_file.attrs["file_time"] = datetime.now().astimezone().isoformat() nexus_file.attrs["h5py_version"] = h5py.version.version nexus_file.attrs["HDF5_Version"] = h5py.version.hdf5_version diff --git a/src/tavi/plotter.py b/src/tavi/plotter.py index 0c8e2ac..13222b5 100644 --- a/src/tavi/plotter.py +++ b/src/tavi/plotter.py @@ -1,8 +1,9 @@ # import matplotlib.colors as colors from functools import partial -from typing import Optional +from typing import Optional, Union import numpy as np +from lmfit.model import ModelResult from mpl_toolkits.axisartist.grid_finder import MaxNLocator from mpl_toolkits.axisartist.grid_helper_curvelinear import GridHelperCurveLinear @@ -54,10 +55,13 @@ def add_scan(self, scan_data: ScanData1D, **kwargs): for key, val in kwargs.items(): scan_data.fmt.update({key: val}) - def add_fit(self, fit_data: FitData1D, **kwargs): - self.fit_data.append(fit_data) - for key, val in kwargs.items(): - fit_data.fmt.update({key: val}) + def add_fit(self, fit_data: Union[FitData1D, ModelResult], PLOT_COMPONENTS=False, **kwargs): + if PLOT_COMPONENTS: + pass + else: + self.fit_data.append(fit_data) + for key, val in kwargs.items(): + fit_data.fmt.update({key: val}) def plot(self, ax): for data in self.scan_data: @@ -66,7 +70,10 @@ def plot(self, ax): else: ax.errorbar(x=data.x, y=data.y, yerr=data.err, **data.fmt) for fit in self.fit_data: - ax.plot(fit.x, fit.y, **fit.fmt) + if fit.PLOT_INDIVIDUALLY: + pass + else: + ax.plot(fit.x, fit.y, **fit.fmt) if self.xlim is not None: ax.set_xlim(left=self.xlim[0], right=self.xlim[1]) diff --git a/test_data/scan_to_nexus_test.h5 b/test_data/scan_to_nexus_test.h5 index bc1eb57..4c5952b 100644 Binary files a/test_data/scan_to_nexus_test.h5 and b/test_data/scan_to_nexus_test.h5 differ diff --git a/test_data/spice_to_nxdict_test_scan0034.h5 b/test_data/spice_to_nxdict_test_scan0034.h5 index ae0bb21..85bc27c 100644 Binary files a/test_data/spice_to_nxdict_test_scan0034.h5 and b/test_data/spice_to_nxdict_test_scan0034.h5 differ diff --git a/tests/test_fit.py b/tests/test_fit.py index 25031bb..f8c83b8 100644 --- a/tests/test_fit.py +++ b/tests/test_fit.py @@ -29,21 +29,12 @@ def test_fit_single_peak_external_model(): assert np.allclose(out.values["peak_fwhm"], 0.39, atol=0.01) assert np.allclose(out.redchi, 2.50, atol=0.01) - # p1 = Plot1D() - # p1.add_scan(s1_scan, fmt="o") - # fig, ax = plt.subplots() - # p1.plot(ax) - # ax.plot(f1.x, out.best_fit) - # plt.show() - - # f1.add_background(model="Constant", values=(0.7,)) - # f1.add_signal(values=(None, 3.5, None), vary=(True, True, True)) - # f1.perform_fit() - - # fig, ax = plt.subplots() - # plot1d.plot_curve(ax) - # f1.fit_plot.plot_curve(ax) - # plt.show() + p1 = Plot1D() + p1.add_scan(s1_scan, fmt="o") + fig, ax = plt.subplots() + p1.plot(ax) + ax.plot(f1.x, out.best_fit) + plt.show() def test_get_fitting_variables(): @@ -53,11 +44,11 @@ def test_get_fitting_variables(): s1_scan = scan42.get_data(norm_to=(30, "mcu")) f1 = Fit1D(s1_scan, fit_range=(0.5, 4.0)) - f1.add_signal(model_name="Gaussian") - f1.add_background(model_name="Constant") + f1.add_signal(model="Gaussian") + f1.add_background(model="Constant") - assert f1.get_signal_params() == [["s1_amplitude", "s1_center", "s1_sigma"]] - assert f1.get_background_params() == [["b1_c"]] + assert f1.signal_params == [["s1_amplitude", "s1_center", "s1_sigma"]] + assert f1.background_params == [["b1_c"]] def test_guess_initial(): @@ -65,18 +56,18 @@ def test_guess_initial(): scan42 = Scan.from_spice(path_to_spice_folder=path_to_spice_folder, scan_num=42) s1_scan = scan42.get_data(norm_to=(30, "mcu")) - f1 = Fit1D(s1_scan, fit_range=(0.5, 4.0)) + f1 = Fit1D(s1_scan, fit_range=(0.5, 4.0), name="scan42_fit") - f1.add_signal(model_name="Gaussian") - f1.add_background(model_name="Constant") + f1.add_signal(model="Gaussian") + f1.add_background(model="Constant") pars = f1.guess() inital = f1.eval(pars) - fit = f1.fit(pars) + # fit_result = f1.fit(pars) p1 = Plot1D() p1.add_scan(s1_scan, fmt="o", label="data") p1.add_fit(inital, label="guess") - p1.add_fit(fit, label="fit") + # p1.add_fit(fit_result, label="fit") fig, ax = plt.subplots() p1.plot(ax) @@ -91,38 +82,20 @@ def test_fit_single_peak_internal_model(): s1_scan = scan42.get_data(norm_to=(30, "mcu")) f1 = Fit1D(s1_scan, fit_range=(0.5, 4.0)) + f1.add_signal(model="Gaussian") f1.add_background(model="Constant") - f1.add_signal(model_name="Gaussian") - f1.eval() - f1.fit() - - bkg = ConstantModel(prefix="bkg_", nan_policy="propagate") - peak = GaussianModel(prefix="peak_", nan_policy="propagate") - model = peak + bkg - pars = peak.guess(f1.y, x=f1.x) - pars += bkg.make_params(c=0) - out = model.fit(f1.y, pars, x=f1.x, weight=f1.err) - - assert np.allclose(out.values["peak_center"], 3.54, atol=0.01) - assert np.allclose(out.values["peak_fwhm"], 0.39, atol=0.01) - assert np.allclose(out.redchi, 10.012, atol=0.01) + pars = f1.guess() + fit_result = f1.fit(pars) p1 = Plot1D() - p1.add_scan(s1_scan, fmt="o") + p1.add_scan(s1_scan, fmt="o", label="data") + p1.add_fit(fit_result, label="fit") + p1.add_fit(fit_result, label="fit", PLOT_COMPONENTS=True) + fig, ax = plt.subplots() p1.plot(ax) - ax.plot(f1.x, out.best_fit) plt.show() - # f1.add_background(model="Constant", values=(0.7,)) - # f1.add_signal(values=(None, 3.5, None), vary=(True, True, True)) - # f1.perform_fit() - - # fig, ax = plt.subplots() - # plot1d.plot_curve(ax) - # f1.fit_plot.plot_curve(ax) - # plt.show() - def test_fit_two_peak(): @@ -135,7 +108,7 @@ def test_fit_two_peak(): f1.add_background(values=(0.7,)) f1.add_signal(values=(None, 3.5, 0.29), vary=(True, True, True)) f1.add_signal( - model_name="Gaussian", + model="Gaussian", values=(None, 0, None), vary=(True, False, True), mins=(0, 0, 0.1), @@ -151,32 +124,45 @@ def test_fit_two_peak(): plt.show() -def test_plot_fit(): +def test_plot(): path_to_spice_folder = "./test_data/exp424" scan42 = Scan.from_spice(path_to_spice_folder=path_to_spice_folder, scan_num=42) - s1_scan = scan42.get_data(norm_channel="mcu", norm_val=30) + s1_scan = scan42.get_data(norm_to=(30, "mcu")) f1 = Fit1D(s1_scan, fit_range=(0.5, 4.0)) - bkg = ConstantModel(prefix="bkg_", nan_policy="propagate") - peak = GaussianModel(prefix="peak_", nan_policy="propagate") - model = peak + bkg - pars = peak.guess(f1.y, x=f1.x) - pars += bkg.make_params(c=0) - out = model.fit(f1.y, pars, x=f1.x, weight=f1.err) + f1.add_signal(model="Gaussian") + f1.add_background(model="Constant") + pars = f1.guess() + inital = f1.eval(pars) + fit_result = f1.fit(pars) - assert np.allclose(out.values["peak_center"], 3.54, atol=0.01) - assert np.allclose(out.values["peak_fwhm"], 0.39, atol=0.01) - assert np.allclose(out.redchi, 10.012, atol=0.01) + p1 = Plot1D() + p1.add_scan(s1_scan, fmt="o", label="data") + p1.add_fit(inital, label="guess") + p1.add_fit(fit_result, label="fit_result") - comps = out.eval_components(x=f1.x) + _, ax = plt.subplots() + p1.plot(ax) + plt.show() + + +def test_plot_indiviaully(): + path_to_spice_folder = "./test_data/exp424" + scan42 = Scan.from_spice(path_to_spice_folder=path_to_spice_folder, scan_num=42) + + s1_scan = scan42.get_data(norm_to=(30, "mcu")) + f1 = Fit1D(s1_scan, fit_range=(0.5, 4.0)) + + f1.add_signal(model="Gaussian") + f1.add_background(model="Constant") + pars = f1.guess() + fit_result = f1.fit(pars) p1 = Plot1D() - p1.add_scan(s1_scan, fmt="o") - fig, ax = plt.subplots() + p1.add_scan(s1_scan, fmt="o", label="data") + p1.add_fit(fit_result, label="fit_result", PLOT_INDIVIDUALLY=True) + + _, ax = plt.subplots() p1.plot(ax) - ax.plot(f1.x, out.best_fit, label="peak+bkg") - ax.plot(f1.x, comps["peak_"], label="peak") - ax.plot(f1.x, comps["bkg_"], label="bkg") - ax.legend() plt.show() diff --git a/tests/test_nxdict.py b/tests/test_nxdict.py index d261bb8..5422b20 100644 --- a/tests/test_nxdict.py +++ b/tests/test_nxdict.py @@ -74,7 +74,7 @@ def test_spice_scan_to_nxdict(): nxdict = spice_scan_to_nxdict(path_to_spice_data) assert nxdict["SPICElogs"]["attrs"]["scan"] == "34" - assert nxdict["start_time"]["dataset"] == "2024-07-03T01:44:46" + assert nxdict["start_time"]["dataset"] == "2024-07-03T01:44:46-04:00" assert np.allclose(nxdict["instrument"]["monochromator"]["ei"]["dataset"][0:3], [4.9, 5, 5.1]) entries = {"scan0034": nxdict}