Skip to content

Commit

Permalink
implementing ScanGroup class
Browse files Browse the repository at this point in the history
  • Loading branch information
bingli621 committed Oct 20, 2024
1 parent 82321cc commit e0b3193
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 116 deletions.
32 changes: 16 additions & 16 deletions src/tavi/data/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,26 +47,25 @@ def make_labels(
self,
x_str: str,
y_str: str,
norm_to: Optional[tuple[float, str]],
scan_info,
norm_to: tuple[float, str],
label: Optional[str] = None,
title: Optional[str] = None,
):
"""Create axes labels, plot title and curve label"""
if norm_to is not None:
norm_val, norm_channel = norm_to
if norm_channel == "time":
norm_channel_str = "seconds"
else:
norm_channel_str = norm_channel
if norm_val == 1:
self.ylabel = y_str + "/ " + norm_channel_str
else:
self.ylabel = y_str + f" / {norm_val} " + norm_channel_str

norm_val, norm_channel = norm_to
if norm_channel == "time":
norm_channel_str = "seconds"
else:
norm_channel_str = norm_channel
if norm_val == 1:
self.ylabel = y_str + "/ " + norm_channel_str
else:
self.ylabel = f"{y_str} / {scan_info.preset_value} {scan_info.preset_channel}"
self.ylabel = y_str + f" / {norm_val} " + norm_channel_str

self.xlabel = x_str
self.label = "scan " + str(scan_info.scan_num)
self.title = self.label + ": " + scan_info.scan_title
self.label = label
self.title = title

def plot_curve(self, ax):
if self.yerr is None:
Expand All @@ -79,7 +78,8 @@ def plot_curve(self, ax):
if self.ylim is not None:
ax.set_ylim(bottom=self.ylim[0], top=self.ylim[1])

ax.set_title(self.title)
if self.title is not None:
ax.set_title(self.title)
ax.set_xlabel(self.xlabel)
ax.set_ylabel(self.ylabel)
ax.grid(alpha=0.6)
Expand Down
12 changes: 9 additions & 3 deletions src/tavi/data/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def get_data(self) -> dict:
return data_dict

@staticmethod
def validate_rebin_params(rebin_params: float | tuple) -> tuple:
def validate_rebin_params(rebin_params: float | int | tuple) -> tuple:
if isinstance(rebin_params, tuple):
if len(rebin_params) != 3:
raise ValueError("Rebin parameters should have the form (min, max, step)")
Expand Down Expand Up @@ -211,14 +211,18 @@ def get_plot_data(
y_str = self.scan_info.def_y if y_str is None else y_str

scan_data_1d = ScanData1D(x=self.data[x_str], y=self.data[y_str])
label = "scan " + str(self.scan_info.scan_num)
title = f"{label}: {self.scan_info.scan_title}"

if rebin_type is None: # no rebin
if norm_to is not None: # normalize y-axis without rebining along x-axis
norm_val, norm_channel = norm_to
scan_data_1d.renorm(norm_col=self.data[norm_channel] / norm_val)
else:
norm_to = (self.scan_info.preset_value, self.scan_info.preset_channel)

plot1d = Plot1D(x=scan_data_1d.x, y=scan_data_1d.y, yerr=scan_data_1d.err)
plot1d.make_labels(x_str, y_str, norm_to, self.scan_info)
plot1d.make_labels(x_str, y_str, norm_to, label, title)

return plot1d

Expand All @@ -230,6 +234,7 @@ def get_plot_data(
if norm_to is None: # x weighted by preset channel
weight_channel = self.scan_info.preset_channel
scan_data_1d.rebin_tol(rebin_params_tuple, weight_col=self.data[weight_channel])
norm_to = (self.scan_info.preset_value, self.scan_info.preset_channel)
else: # x weighted by normalization channel
norm_val, norm_channel = norm_to
scan_data_1d.rebin_tol_renorm(
Expand All @@ -240,6 +245,7 @@ def get_plot_data(
case "grid":
if norm_to is None:
scan_data_1d.rebin_grid(rebin_params_tuple)
norm_to = (self.scan_info.preset_value, self.scan_info.preset_channel)
else:
norm_val, norm_channel = norm_to
scan_data_1d.rebin_grid_renorm(
Expand All @@ -251,7 +257,7 @@ def get_plot_data(
raise ValueError('Unrecogonized rebin type. Needs to be "tol" or "grid".')

plot1d = Plot1D(x=scan_data_1d.x, y=scan_data_1d.y, yerr=scan_data_1d.err)
plot1d.make_labels(x_str, y_str, norm_to, self.scan_info)
plot1d.make_labels(x_str, y_str, norm_to, label, title)
return plot1d

def plot(
Expand Down
16 changes: 9 additions & 7 deletions src/tavi/data/scan_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ class ScanData1D(object):

def __init__(self, x: np.ndarray, y: np.ndarray) -> None:

ind = np.argsort(x)
self.x = x[ind]
self.y = y[ind]
# ind = np.argsort(x)
# self.x = x[ind]
# self.y = y[ind]
self.x = x
self.y = y
self.err = np.sqrt(y)
self._ind = ind
# self._ind = ind

def __add__(self, other):
# check x length, rebin other if do not match
Expand Down Expand Up @@ -72,7 +74,7 @@ def __sub__(self, other):

def renorm(self, norm_col: np.ndarray, norm_val: float = 1.0):
"""Renormalized to norm_val"""
norm_col = norm_col[self._ind]
# norm_col = norm_col[self._ind]
self.y = self.y / norm_col * norm_val
self.err = self.err / norm_col * norm_val

Expand Down Expand Up @@ -111,7 +113,7 @@ def rebin_tol_renorm(self, rebin_params: tuple, norm_col: np.ndarray, norm_val:
y = np.zeros_like(x_grid)
counts = np.zeros_like(x_grid)

norm_col = norm_col[self._ind]
# norm_col = norm_col[self._ind]

for i, x0 in enumerate(self.x):
idx = np.nanargmax(x_grid + rebin_step / 2 + ScanData1D.ZERO >= x0)
Expand Down Expand Up @@ -153,7 +155,7 @@ def rebin_grid_renorm(self, rebin_params: tuple, norm_col: np.ndarray, norm_val:
y = np.zeros_like(x)
counts = np.zeros_like(x)

norm_col = norm_col[self._ind]
# norm_col = norm_col[self._ind]

for i, x0 in enumerate(self.x): # plus ZERO helps improve precision
idx = np.nanargmax(x + rebin_step / 2 + ScanData1D.ZERO >= x0)
Expand Down
193 changes: 104 additions & 89 deletions src/tavi/data/scan_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ class ScanGroup(object):
name (string): Name of combined scans
Methods:
add_scan
remove_scan
get_plot_data
plot_contour
plot
"""

scan_group_number: int = 1
Expand All @@ -41,62 +43,110 @@ def remove_scan(self, scan_num: Union[tuple[str, int], int]):

# TODO non-orthogonal axes for constant E contours

def set_axes(
self,
x: Union[str, tuple[str], None] = None,
y: Union[str, tuple[str], None] = None,
z: Union[str, tuple[str], None] = None,
norm_to: Union[tuple[float, str], tuple[tuple[float, str]], None] = None,
):
"""Set axes and normalization parameters
def _get_plot_data_1d(self, axes, rebin_params, rebin_type, norm_to) -> Plot1D:
x_axis, y_axis = axes
x_list = np.array([])
y_list = np.array([])
norm_list = np.array([])

Args:
norm_to (norm_val (float), norm_channel(str)): value and channel for normalization
norm_channel should be "time", "monitor" or"mcu".
"""
num = len(self.scans)

if x is None:
x_axes = [scan.scan_info.def_x for scan in self.scans]
elif isinstance(x, str):
x_axes = [x] * num
elif isinstance(x, tuple):
if num != len(x):
raise ValueError(f"length of x-axes={x} does not match number of scans.")
x_axes = list(x)

if y is None:
y_axes = [scan.scan_info.def_y for scan in self.scans]
elif isinstance(y, str):
y_axes = [y] * num
elif isinstance(y, tuple):
if num != len(y):
raise ValueError(f"length of y-axes={y} does not match number of scans.")
y_axes = list(y)

if z is None:
z_axes = [None] * num
elif isinstance(z, str):
z_axes = [z] * num
elif isinstance(z, tuple):
if num != len(z):
raise ValueError(f"length of z-axes={z} does not match number of scans.")
z_axes = list(z)

if norm_to is None:
norms = [None] * num
elif isinstance(norm_to, tuple):
for item in norm_to:
if isinstance(item, tuple):
if num != len(norm_to):
raise ValueError(f"length of normalization channels={norm_to} does not match number of scans.")
norms = list(norm_to)
for scan in self.scans:
x_list = np.append(x_list, scan.data[x_axis])
y_list = np.append(y_list, scan.data[y_axis])
scan_data_1d = ScanData1D(np.array(x_list), np.array(y_list))

if rebin_params is None: # no rebin
if norm_to is not None: # normalize y-axis without rebining along x-axis
norm_val, norm_channel = norm_to
for scan in self.scans:
norm_list = np.append(y_list, scan.data[y_axis])
scan_data_1d.renorm(norm_col=norm_list / norm_val)
else:
norm_to = (self.scans[0].scan_info.preset_value, self.scans[0].scan_info.preset_channel)

plot1d = Plot1D(x=scan_data_1d.x, y=scan_data_1d.y, yerr=scan_data_1d.err)
plot1d.make_labels(x_axis, y_axis, norm_to)

return plot1d

match rebin_type:
case "tol":
if norm_to is None: # x weighted by preset channel
weight_channel = self.scan_info.preset_channel
scan_data_1d.rebin_tol(rebin_params_tuple, weight_col=self.data[weight_channel])
else: # x weighted by normalization channel
norm_val, norm_channel = norm_to
scan_data_1d.rebin_tol_renorm(
rebin_params_tuple,
norm_col=self.data[norm_channel],
norm_val=norm_val,
)
case "grid":
if norm_to is None:
scan_data_1d.rebin_grid(rebin_params_tuple)
else:
norms = [norm_to] * num

self.axes = list(zip(x_axes, y_axes, z_axes, norms))
norm_val, norm_channel = norm_to
scan_data_1d.rebin_grid_renorm(
rebin_params_tuple,
norm_col=self.data[norm_channel],
norm_val=norm_val,
)
return scan_data_1d

def _get_plot_data_2d(self, axes, rebin_params, norm_to) -> Plot2D:
return Plot2D()

def get_plot_data(
self,
axes: Union[tuple[str, str], tuple[str, str, str], None] = None,
rebin_params: Optional[tuple] = None,
norm_to: Optional[tuple[float, str]] = None,
rebin_type: Literal["grid", "tol"] = "grid",
) -> Union[Plot1D, Plot2D]:
"""Get plot data
If axes is None, get default axes and return 1D data"""
if axes is None:
x_axes = []
y_axes = []
for scan in self.scans:
x_axes.append(scan.scan_info.def_x)
y_axes.append(scan.scan_info.def_y)
x_axis = set(x_axes)
y_axis = set(y_axes)

if not (len(x_axis) == 1 and len(y_axis) == 1):
raise ValueError(f"x axes={x_axis} or y axes={y_axis} are not identical.")
axes = (*x_axis, *y_axis)

match len(axes):
case 2:
if rebin_params is not None:
if isinstance(rebin_params, float | int | tuple):
rebin_params = Scan.validate_rebin_params(rebin_params)
else:
raise ValueError(
f"rebin parameters ={rebin_params} needs to be float or int of tuple of size 3"
)
plot_data_1d = self._get_plot_data_1d(axes, rebin_params, rebin_type, norm_to)
return plot_data_1d

case 3:
if not isinstance(rebin_params, tuple):
raise ValueError(f"rebin parameters ={rebin_params} needs to be a tuple.")
if not len(rebin_params) == 2:
raise ValueError(f"rebin parameters ={rebin_params} needs to be a tuple of size 2.")
rebin_params_list = []
for rebin in rebin_params:
if isinstance(rebin, float | int | tuple):
rebin_params_list.append(Scan.validate_rebin_params(rebin))
rebin_params = tuple(rebin_params_list)

plot_data_2d = self._get_plot_data_2d(axes, rebin_params, norm_to)
return plot_data_2d

case _:
raise ValueError(f"length of axes={axes} should be either 2 or 3.")

# TODO
def get_plot_data_1d(
self,
rebin_type: Literal["tol", "grid", None] = None,
Expand All @@ -107,20 +157,6 @@ def get_plot_data_1d(
rebin_params (float | tuple(flot, float, float)): take as step size if a numer is given,
take as (min, max, step) if a tuple of size 3 is given
"""
ScanData1D()
num_scans = np.size(self.signals)

signal_x, signal_y, signal_z = self.signal_axes

if np.size(signal_x) == 1:
signal_x = [signal_x] * num_scans
xlabel = signal_x[0]
if np.size(signal_y) == 1:
signal_y = [signal_y] * num_scans
ylabel = signal_y[0]
if np.size(signal_z) == 1:
signal_z = [signal_z] * num_scans
zlabel = signal_z[0]

# shape = (num_scans, num_pts)
# x_array = [scan.data[signal_x[i]] for i, scan in enumerate(self.signals)]
Expand Down Expand Up @@ -183,27 +219,6 @@ def get_plot_data_1d(

return (xv, yv, z, x_step, y_step, xlabel, ylabel, zlabel, title)

@staticmethod
def validate_rebin_params_2d(rebin_params_2d: tuple) -> tuple:

params = []
for rebin_params in rebin_params_2d:
if isinstance(rebin_params, tuple):
if len(rebin_params) != 3:
raise ValueError("Rebin parameters should have the form (min, max, step)")
rebin_min, rebin_max, rebin_step = rebin_params
if (rebin_min >= rebin_max) or (rebin_step < 0):
raise ValueError(f"Nonsensical rebin parameters {rebin_params}")
params.append(rebin_params)

elif isinstance(rebin_params, float | int):
if rebin_params < 0:
raise ValueError("Rebin step needs to be greater than zero.")
params.append((None, None, float(rebin_params)))
else:
raise ValueError(f"Unrecogonized rebin parameters {rebin_params}")
return tuple(params)

def get_plot_data_2d(
self,
axes: tuple[str, str, str],
Expand Down
Loading

0 comments on commit e0b3193

Please sign in to comment.