Skip to content

Commit

Permalink
syncing scangroup class
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Li committed Oct 22, 2024
1 parent ca020d5 commit 5e0275e
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 142 deletions.
5 changes: 4 additions & 1 deletion src/tavi/data/scan_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,10 @@ class ScanData2D(object):
def __init__(self, x: np.ndarray, y: np.ndarray, z: np.ndarray) -> None:
self.x = x
self.y = y
self.y = z
self.z = z

self.err = np.sqrt(z)
self.title = ""

def __sub__(self, other):
pass
Expand Down
197 changes: 57 additions & 140 deletions src/tavi/data/scan_group.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Literal, Optional, Union
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np

from tavi.data.plotter import Plot1D, Plot2D
from tavi.data.scan import Scan
from tavi.data.scan_data import ScanData1D, ScanData2D

Expand All @@ -18,8 +17,6 @@ class ScanGroup(object):
Methods:
add_scan
remove_scan
get_plot_data
plot
"""

scan_group_number: int = 1
Expand Down Expand Up @@ -130,16 +127,48 @@ def _get_data_2d(
norm_to: Optional[tuple[float, str]],
**rebin_params_dict: Optional[tuple],
) -> ScanData2D:
"""
Note:
rebin_params_dict should be grid=(float |tuple[float,float,float],float|tuple[float,float,float]) only
"""

x_axis, y_axis, z_axis = axes
x_array = np.array([])
y_array = np.array([])
z_array = np.array([])

title = "Combined scans: "

if not isinstance(rebin_params, tuple):
raise ValueError(f"rebin parameters ={rebin_params} needs to be a tuple.")
if not len(rebin_params) == 2:
raise ValueError(f"rebin parameters ={rebin_params} needs to be a tuple of size 2.")
rebin_params_list = []
for rebin in rebin_params:
if isinstance(rebin, float | int | tuple):
rebin_params_list.append(Scan.validate_rebin_params(rebin))
rebin_params = tuple(rebin_params_list)
for scan in self.scans:
x_array = np.append(x_array, scan.data[x_axis])
y_array = np.append(y_array, scan.data[y_axis])
z_array = np.append(y_array, scan.data[z_axis])
title += f"{scan.scan_info.scan_num} "

scan_data_2d = ScanData2D(x=x_array, y=y_array, z=z_array)
rebin_params = rebin_params_dict.get("grid")

if not rebin_params: # no rebin,
if norm_to is not None: # renorm
norm_val, norm_channel = norm_to
norm_list = self._get_norm_list(norm_channel)
scan_data_1d.renorm(norm_col=norm_list, norm_val=norm_val)
else: # no renorm, check if all presets are the same
norm_to = self._get_default_renorm_params()

scan_data_1d.make_labels(axes, norm_to, title=title)
return scan_data_1d

# if not isinstance(rebin_params, tuple):
# raise ValueError(f"rebin parameters ={rebin_params} needs to be a tuple.")
# if not len(rebin_params) == 2:
# raise ValueError(f"rebin parameters ={rebin_params} needs to be a tuple of size 2.")
# rebin_params_list = []
# for rebin in rebin_params:
# if isinstance(rebin, float | int | tuple):
# rebin_params_list.append(Scan.validate_rebin_params(rebin))
# rebin_params = tuple(rebin_params_list)
return scan_data_2d

def get_data(
Expand All @@ -156,140 +185,28 @@ def get_data(
rebin_params_dict could be either "tol" or "grid" for 1D data, but only
"grid" for 2D data.
"""
if axes is None:
x_axes = []
y_axes = []
for scan in self.scans:
x_axes.append(scan.scan_info.def_x)
y_axes.append(scan.scan_info.def_y)
x_axis = set(x_axes)
y_axis = set(y_axes)

if not (len(x_axis) == 1 and len(y_axis) == 1):
raise ValueError(f"x axes={x_axis} or y axes={y_axis} are not identical.")
axes = (*x_axis, *y_axis)

match len(axes):
case 2:
if axes is not None:
if len(axes) == 2:
return self._get_data_1d(axes, norm_to, **rebin_params_dict)

case 3:
elif len(axes) == 3:
return self._get_data_2d(axes, norm_to, **rebin_params_dict)

case _:
else:
raise ValueError(f"length of axes={axes} should be either 2 or 3.")

def get_plot_data_1d(
self,
rebin_type: Literal["tol", "grid", None] = None,
rebin_params: Union[float, tuple] = 0.0,
) -> Plot1D:
"""
rebin_type (str | None): "tol" or "grid"
rebin_params (float | tuple(flot, float, float)): take as step size if a numer is given,
take as (min, max, step) if a tuple of size 3 is given
"""

# shape = (num_scans, num_pts)
# x_array = [scan.data[signal_x[i]] for i, scan in enumerate(self.signals)]
# y_array = [scan.data[signal_y[i]] for i, scan in enumerate(self.signals)]

x_array = [getattr(scan.data, signal_x[i]) for i, scan in enumerate(self.signals)]
y_array = [getattr(scan.data, signal_y[i]) for i, scan in enumerate(self.signals)]

x_min = np.min([np.min(np.round(x, 3)) for x in x_array])
x_max = np.max([np.max(np.round(x, 3)) for x in x_array])
y_min = np.min([np.min(np.round(y, 3)) for y in y_array])
y_max = np.max([np.max(np.round(y, 3)) for y in y_array])

# TODO problem if irregular size
x_step, y_step = rebin_steps
if x_step is None:
x_precision = 1
x_unique = np.unique(np.concatenate([np.unique(np.round(x, x_precision)) for x in x_array]))
x_diff = np.unique(np.round(np.diff(x_unique), x_precision))
x_diff = x_diff[x_diff > 0]
x_step = x_diff[0]

if y_step is None:
y_precision = 5
y_unique = np.unique(np.concatenate([np.unique(np.round(y, y_precision)) for y in y_array]))
y_diff = np.unique(np.round(np.diff(y_unique), y_precision))
y_diff = y_diff[y_diff > 0]
y_step = y_diff[0]

x_list = np.round(np.arange(x_min, x_max + x_step / 2, x_step), 3)
y_list = np.round(np.arange(y_min, y_max + y_step / 2, y_step), 3)
# shape = (num_pts, num_scans)
xv, yv = np.meshgrid(x_list, y_list)

# finding bin boxes
cts = np.zeros_like(xv)
z = np.zeros_like(xv)
for i in range(num_scans):
scan = self.signals[i]
scan_len = np.size(getattr(scan.data, signal_z[i]))
for j in range(scan_len):
# if SCAN_ALONG_Y:
x0 = getattr(scan.data, signal_x[i])[j]
y0 = getattr(scan.data, signal_y[i])[j]
z0 = getattr(scan.data, signal_z[i])[j]
idx = np.nanargmax(x_list + x_step / 2 >= x0)
idy = np.nanargmax(y_list + y_step / 2 >= y0)
z[idy, idx] += z0
if norm_channel is None:
cts[idy, idx] += 1
else:
cts[idy, idx] += getattr(scan.data, norm_channel)[j] / norm_val

z = z / cts

title = self.name
if norm_channel is not None:
zlabel += f" / {norm_val} " + norm_channel
title += f" nomralized by {norm_val} " + norm_channel

return (xv, yv, z, x_step, y_step, xlabel, ylabel, zlabel, title)

def get_plot_data_2d(
self,
axes: tuple[str, str, str],
rebin_params: tuple[Union[float, tuple], Union[float, tuple]],
norm_to: Optional[tuple[float, Literal["monitor", "time", "mcu"]]] = None,
) -> Plot2D:
"""
Args:
rebin_params (float | tuple(flot, float, float)): take as step size if a numer is given,
take as (min, max, step) if a tuple of size 3 is given
"""
x_axis, y_axis, z_axis = axes

x_data = []
y_data = []
z_data = []

x_axes = []
y_axes = []
for scan in self.scans:
x_data.append(scan.data.get(x_axis))
y_data.append(scan.data.get(y_axis))
z_data.append(scan.data.get(z_axis))

if norm_to is not None:
norm_data = []
norm_val, norm_channel = norm_to
for scan in self.scans:
norm_data.append(scan.data.get(norm_channel))

scan_data_2d = ScanData2D(x=np.concatenate(x_data), y=np.concatenate(y_data), z=np.concatenate(z_data))
# Rebin, first validate rebin params
rebin_params_2d = ScanGroup.validate_rebin_params_2d(rebin_params)
if norm_to is not None:
pass
else:
scan_data_2d.rebin_grid(rebin_params_2d)
plot2d = Plot2D(scan_data_2d.x, scan_data_2d.y, scan_data_2d.z)
# plot2d.make_labels(self.axes)
return plot2d
x_axes.append(scan.scan_info.def_x)
y_axes.append(scan.scan_info.def_y)
x_axis = set(x_axes)
y_axis = set(y_axes)

if not (len(x_axis) == 1 and len(y_axis) == 1):
raise ValueError(f"x axes={x_axis} or y axes={y_axis} are not identical.")
axes = (*x_axis, *y_axis)
return self._get_data_1d(axes, norm_to, **rebin_params_dict)

def plot(self, contour_plot, cmap="turbo", vmax=100, vmin=0, ylim=None, xlim=None):
"""Plot contour"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_scan_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,5 @@ def test_scan_group_2d():

fig, ax = plt.subplots()
plot2d = Plot2D()
plot2d.plot_contour(scan_data_2d, cmap="turbo", vmax=80)
plot2d.plot(ax, scan_data_2d, cmap="turbo", vmax=80)
plt.show()

0 comments on commit 5e0275e

Please sign in to comment.