Skip to content

Commit

Permalink
added time zone info when converting to nexus
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Li committed Nov 22, 2024
1 parent 0aa4bd4 commit 64d5484
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 122 deletions.
102 changes: 58 additions & 44 deletions src/tavi/data/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
from lmfit import Parameters, models
from lmfit.model import ModelResult

from tavi.data.scan_data import ScanData1D

Expand All @@ -20,7 +21,7 @@ def __init__(


class Fit1D(object):
"""Fit a 1d curve
"""Fit a 1D curve
Attributes:
Expand Down Expand Up @@ -50,9 +51,13 @@ def __init__(
self,
data: ScanData1D,
fit_range: Optional[tuple[float, float]] = None,
nan_policy: Literal["raise", "propagate", "omit"] = "propagate",
name="",
):
"""initialize a fit model, mask based on fit_range if given"""

self.name = name

self.x: np.ndarray = data.x
self.y: np.ndarray = data.y
self.err: Optional[np.ndarray] = data.err
Expand All @@ -66,6 +71,7 @@ def __init__(
self.fit_data: Optional[FitData1D] = None

self.PLOT_SEPARATELY = False
self.nan_policy = nan_policy

if fit_range is not None:
self.set_range(fit_range)
Expand All @@ -80,41 +86,68 @@ def set_range(self, fit_range: tuple[float, float]):
self.err = self.err[mask]

@staticmethod
def _add_model(model, prefix):
def _add_model(model, prefix, nan_policy):
model = Fit1D.models[model]
return model(prefix=prefix, nan_policy="propagate")
return model(prefix=prefix, nan_policy=nan_policy)

def add_signal(
self,
model_name: Literal[
"Gaussian", "Lorentzian", "Voigt", "PseudoVoigt", "DampedOscillator", "DampedHarmonicOscillator"
model: Literal[
"Gaussian",
"Lorentzian",
"Voigt",
"PseudoVoigt",
"DampedOscillator",
"DampedHarmonicOscillator",
],
):
self._num_signals += 1
prefix = f"s{self._num_signals}_"
self.signal_models.append(Fit1D._add_model(model_name, prefix))
self.signal_models.append(
Fit1D._add_model(
model,
prefix,
nan_policy=self.nan_policy,
)
)

def add_background(
self, model_name: Literal["Constant", "Linear", "Quadratic", "Polynomial", "Exponential", "PowerLaw"]
self,
model: Literal[
"Constant",
"Linear",
"Quadratic",
"Polynomial",
"Exponential",
"PowerLaw",
],
):
self._num_backgrounds += 1
prefix = f"b{self._num_backgrounds}_"
self.background_models.append(Fit1D._add_model(model_name, prefix))
self.background_models.append(
Fit1D._add_model(
model,
prefix,
nan_policy=self.nan_policy,
)
)

@staticmethod
def _get_model_params(models):
def _get_model_params(models) -> list[list[str]]:
params = []
for model in models:
params.append(model.param_names)
return params

def get_signal_params(self):
@property
def signal_params(self) -> list[list[str]]:
return Fit1D._get_model_params(self.signal_models)

def get_background_params(self):
@property
def background_params(self) -> list[list[str]]:
return Fit1D._get_model_params(self.background_models)

def guess(self):
def guess(self) -> Parameters:
pars = Parameters()
for signal in self.signal_models:
pars += signal.guess(self.y, x=self.x)
Expand All @@ -123,47 +156,28 @@ def guess(self):
self.pars = pars
return pars

@property
def x_to_plot(self):
return

def eval(self, pars: Parameters, num_of_pts: Optional[int] = 100) -> FitData1D:
mod = self.signal_models[0]
if (sz := len(self.signal_models)) > 1:
for i in range(1, sz):
mod += self.signal_models[i]

for bkg in self.background_models:
mod += bkg
def _build_composite_model(self):
compposite_model = np.sum(self.signal_models + self.background_models)
return compposite_model

def _get_x_to_plot(self, num_of_pts: Optional[int]):
if num_of_pts is None:
x_to_plot = self.x
elif isinstance(num_of_pts, int):
x_to_plot = np.linspace(self.x.min(), self.x.max(), num=num_of_pts)
else:
raise ValueError(f"num_of_points={num_of_pts} needs to be an integer.")
y_to_plot = mod.eval(pars, x=x_to_plot)
return FitData1D(x_to_plot, y_to_plot)
return x_to_plot

def fit(self, pars: Parameters, num_of_pts: Optional[int] = 100) -> FitData1D:
mod = self.signal_models[0]
if (sz := len(self.signal_models)) > 1:
for i in range(1, sz):
mod += self.signal_models[i]
def eval(self, pars: Parameters, num_of_pts: Optional[int] = 100) -> FitData1D:
model = self._build_composite_model()
x_to_plot = self._get_x_to_plot(num_of_pts)
y_to_plot = model.eval(pars, x=x_to_plot)

for bkg in self.background_models:
mod += bkg
return FitData1D(x_to_plot, y_to_plot)

def fit(self, pars: Parameters, num_of_pts: Optional[int] = 100) -> ModelResult:
mod = self._build_composite_model()
result = mod.fit(self.y, pars, x=self.x, weights=self.err)
self.result = result

if num_of_pts is None:
x_to_plot = self.x
elif isinstance(num_of_pts, int):
x_to_plot = np.linspace(self.x.min(), self.x.max(), num=num_of_pts)
else:
raise ValueError(f"num_of_points={num_of_pts} needs to be an integer.")

y_to_plot = mod.eval(result.params, x=x_to_plot)

return FitData1D(x_to_plot, y_to_plot)
return result
14 changes: 12 additions & 2 deletions src/tavi/data/nxdict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime
from pathlib import Path
from typing import Optional
from zoneinfo import ZoneInfo

import h5py
import numpy as np
Expand Down Expand Up @@ -75,7 +76,17 @@ def __init__(self, ds, **kwargs):
case {"type": "NX_INT"} | {"type": "NX_FLOAT"}:
dataset = _recast_type(ds, kwargs["type"])
case {"type": "NX_DATE_TIME"}:
dataset = datetime.strptime(ds, "%m/%d/%Y %I:%M:%S %p").isoformat()
dt = datetime.strptime(ds, "%m/%d/%Y %I:%M:%S %p")
date = datetime(
dt.year,
dt.month,
dt.day,
dt.hour,
dt.minute,
dt.second,
tzinfo=ZoneInfo("America/New_York"),
)
dataset = date.isoformat()
case _:
dataset = ds

Expand Down Expand Up @@ -464,7 +475,6 @@ def spice_scan_to_nxdict(
nxdata = NXentry(NX_class="NXdata", EX_required="true", signal=def_y, axes=def_x)
# ---------------------------------------- scan ---------------------------------------------

# TODO timezone
start_date_time = "{} {}".format(metadata.get("date"), metadata.get("time"))
# TODO what is last scan never finished?
# if "end_time" in das_logs.attrs:
Expand Down
2 changes: 1 addition & 1 deletion src/tavi/data/nxentry.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def to_nexus(self, path_to_nexus: str, name="scan") -> None:

# Create the ATTRIBUTES
nexus_file.attrs["file_name"] = os.path.abspath(path_to_nexus)
nexus_file.attrs["file_time"] = datetime.now().isoformat()
nexus_file.attrs["file_time"] = datetime.now().astimezone().isoformat()
nexus_file.attrs["h5py_version"] = h5py.version.version
nexus_file.attrs["HDF5_Version"] = h5py.version.hdf5_version

Expand Down
19 changes: 13 additions & 6 deletions src/tavi/plotter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# import matplotlib.colors as colors
from functools import partial
from typing import Optional
from typing import Optional, Union

import numpy as np
from lmfit.model import ModelResult
from mpl_toolkits.axisartist.grid_finder import MaxNLocator
from mpl_toolkits.axisartist.grid_helper_curvelinear import GridHelperCurveLinear

Expand Down Expand Up @@ -54,10 +55,13 @@ def add_scan(self, scan_data: ScanData1D, **kwargs):
for key, val in kwargs.items():
scan_data.fmt.update({key: val})

def add_fit(self, fit_data: FitData1D, **kwargs):
self.fit_data.append(fit_data)
for key, val in kwargs.items():
fit_data.fmt.update({key: val})
def add_fit(self, fit_data: Union[FitData1D, ModelResult], PLOT_COMPONENTS=False, **kwargs):
if PLOT_COMPONENTS:
pass
else:
self.fit_data.append(fit_data)
for key, val in kwargs.items():
fit_data.fmt.update({key: val})

def plot(self, ax):
for data in self.scan_data:
Expand All @@ -66,7 +70,10 @@ def plot(self, ax):
else:
ax.errorbar(x=data.x, y=data.y, yerr=data.err, **data.fmt)
for fit in self.fit_data:
ax.plot(fit.x, fit.y, **fit.fmt)
if fit.PLOT_INDIVIDUALLY:
pass
else:
ax.plot(fit.x, fit.y, **fit.fmt)

if self.xlim is not None:
ax.set_xlim(left=self.xlim[0], right=self.xlim[1])
Expand Down
Binary file modified test_data/scan_to_nexus_test.h5
Binary file not shown.
Binary file modified test_data/spice_to_nxdict_test_scan0034.h5
Binary file not shown.
Loading

0 comments on commit 64d5484

Please sign in to comment.