Skip to content

Commit

Permalink
implementing ScanGroup class
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Li committed Oct 15, 2024
1 parent ecd7390 commit 7027701
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 56 deletions.
18 changes: 9 additions & 9 deletions src/tavi/data/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def get_data(self) -> dict:
return data_dict

@staticmethod
def validate_rebin_params(rebin_params):
def validate_rebin_params(rebin_params: float | tuple) -> tuple:
if isinstance(rebin_params, tuple):
if len(rebin_params) != 3:
raise ValueError("Rebin parameters should have the form (min, max, step)")
Expand Down Expand Up @@ -210,13 +210,13 @@ def get_plot_data(
x_str = self.scan_info.def_x if x_str is None else x_str
y_str = self.scan_info.def_y if y_str is None else y_str

scan_data = ScanData1D(x=self.data[x_str], y=self.data[y_str])
scan_data_1d = ScanData1D(x=self.data[x_str], y=self.data[y_str])

if rebin_type is None: # no rebin
if norm_channel is not None: # normalize y-axis without rebining along x-axis
scan_data.renorm(norm_col=self.data[norm_channel] / norm_val)
scan_data_1d.renorm(norm_col=self.data[norm_channel] / norm_val)

plot1d = Plot1D(x=scan_data.x, y=scan_data.y, yerr=scan_data.err)
plot1d = Plot1D(x=scan_data_1d.x, y=scan_data_1d.y, yerr=scan_data_1d.err)
plot1d.make_labels(x_str, y_str, norm_channel, norm_val, self.scan_info)

return plot1d
Expand All @@ -228,26 +228,26 @@ def get_plot_data(
case "tol":
if norm_channel is None: # x weighted by preset channel
weight_channel = self.scan_info.preset_channel
scan_data.rebin_tol(rebin_params_tuple, weight_col=self.data[weight_channel])
scan_data_1d.rebin_tol(rebin_params_tuple, weight_col=self.data[weight_channel])
else: # x weighted by normalization channel
scan_data.rebin_tol_renorm(
scan_data_1d.rebin_tol_renorm(
rebin_params_tuple,
norm_col=self.data[norm_channel],
norm_val=norm_val,
)
case "grid":
if norm_channel is None:
scan_data.rebin_grid(rebin_params_tuple)
scan_data_1d.rebin_grid(rebin_params_tuple)
else:
scan_data.rebin_grid_renorm(
scan_data_1d.rebin_grid_renorm(
rebin_params_tuple,
norm_col=self.data[norm_channel],
norm_val=norm_val,
)
case _:
raise ValueError('Unrecogonized rebin type. Needs to be "tol" or "grid".')

plot1d = Plot1D(x=scan_data.x, y=scan_data.y, yerr=scan_data.err)
plot1d = Plot1D(x=scan_data_1d.x, y=scan_data_1d.y, yerr=scan_data_1d.err)
plot1d.make_labels(x_str, y_str, norm_channel, norm_val, self.scan_info)
return plot1d

Expand Down
87 changes: 81 additions & 6 deletions src/tavi/data/scan_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,69 @@ class ScanData1D(object):

def __init__(self, x: np.ndarray, y: np.ndarray) -> None:

self.ind = np.argsort(x)
self.x = x[self.ind]
self.y = y[self.ind]
ind = np.argsort(x)
self.x = x[ind]
self.y = y[ind]
self.err = np.sqrt(y)
self._ind = ind

def __add__(self, other): # addition is not really needed
# check x length, rebin other if do not match
if len(self.x) != len(other.x):
rebin_intervals = np.diff(self.x)
rebin_intervals = np.append(rebin_intervals, rebin_intervals[-1])
rebin_boundary = self.x + rebin_intervals / 2

y = np.zeros_like(rebin_boundary)
counts = np.zeros_like(rebin_boundary)
err = np.zeros_like(rebin_boundary)
(x_min, x_max) = (self.x[0] - rebin_intervals[0] / 2, self.x[-1] + rebin_intervals[-1] / 2)

for i, x0 in enumerate(other.x):
if x0 > x_max or x0 < x_min:
continue
idx = np.nanargmax(rebin_boundary + ScanData1D.ZERO >= x0)
y[idx] += other.y[i]
err[idx] += other.err[i] ** 2
counts[idx] += 1

other.err = err / counts
other.y = y / counts

scan_data_1d = ScanData1D(self.x, self.y + other.y)
scan_data_1d.err = np.sqrt(self.err**2 + other.err**2)
return scan_data_1d

def __sub__(self, other):
# check x length, rebin other if do not match
if len(self.x) != len(other.x):
rebin_intervals = np.diff(self.x)
rebin_intervals = np.append(rebin_intervals, rebin_intervals[-1])
rebin_boundary = self.x + rebin_intervals / 2

y = np.zeros_like(rebin_boundary)
counts = np.zeros_like(rebin_boundary)
err = np.zeros_like(rebin_boundary)
(x_min, x_max) = (self.x[0] - rebin_intervals[0] / 2, self.x[-1] + rebin_intervals[-1] / 2)

for i, x0 in enumerate(other.x):
if x0 > x_max or x0 < x_min:
continue
idx = np.nanargmax(rebin_boundary + ScanData1D.ZERO >= x0)
y[idx] += other.y[i]
err[idx] += other.err[i] ** 2
counts[idx] += 1

other.err = err / counts
other.y = y / counts

scan_data_1d = ScanData1D(self.x, self.y - other.y)
scan_data_1d.err = np.sqrt(self.err**2 + other.err**2)
return scan_data_1d

def renorm(self, norm_col: np.ndarray, norm_val: float = 1.0):
"""Renormalized to norm_val"""
norm_col = norm_col[self.ind]
norm_col = norm_col[self._ind]
self.y = self.y / norm_col * norm_val
self.err = self.err / norm_col * norm_val

Expand Down Expand Up @@ -56,7 +111,7 @@ def rebin_tol_renorm(self, rebin_params: tuple, norm_col: np.ndarray, norm_val:
y = np.zeros_like(x_grid)
counts = np.zeros_like(x_grid)

norm_col = norm_col[self.ind]
norm_col = norm_col[self._ind]

for i, x0 in enumerate(self.x):
idx = np.nanargmax(x_grid + rebin_step / 2 + ScanData1D.ZERO >= x0)
Expand Down Expand Up @@ -98,7 +153,7 @@ def rebin_grid_renorm(self, rebin_params: tuple, norm_col: np.ndarray, norm_val:
y = np.zeros_like(x)
counts = np.zeros_like(x)

norm_col = norm_col[self.ind]
norm_col = norm_col[self._ind]

for i, x0 in enumerate(self.x): # plus ZERO helps improve precision
idx = np.nanargmax(x + rebin_step / 2 + ScanData1D.ZERO >= x0)
Expand All @@ -108,3 +163,23 @@ def rebin_grid_renorm(self, rebin_params: tuple, norm_col: np.ndarray, norm_val:
self.x = x
self.err = np.sqrt(y) / counts * norm_val
self.y = y / counts * norm_val


class ScanData2D(object):

ZEROS = 1e-6

def __init__(self, x: np.ndarray, y: np.ndarray, z: np.ndarray) -> None:
self.x = x
self.y = y
self.y = z
self.err = np.sqrt(z)

def __sub__(self, other):
pass

def renorm(self):
pass

def rebin_grid(self):
pass
65 changes: 63 additions & 2 deletions src/tavi/data/scan_group.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
from dataclasses import dataclass
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np

from tavi.data.scan import Scan


@dataclass
class SGInfo:
"""Information needed to generate a ScanGroup"""

scan_num: int
x_axis: Optional[str] = None
y_axis: Optional[str] = None
z_axis: Optional[str] = None
norm_channel: Optional[str] = None
exp_id: Optional[str] = None


class ScanGroup(object):
"""
Expand All @@ -14,14 +31,58 @@ class ScanGroup(object):
plot_contour
"""

scan_group_number: int = 1

def __init__(
self,
scan_path_list,
):

self.name = ""
scans = {}
for scan_path in scan_path_list:
if "/" in scan_path:
exp_id, scan_name = scan_path.split("/")
else:
exp_id = next(iter(self.data))
scan_name = scan_path
scan_path = "/".join([exp_id, scan_name])
scans.update({scan_path: Scan(scan_name, self.data[exp_id][scan_name])})

# axes: tuple,
# rebin_params: tuple,
# sg_info_list: list[SGInfo],
# scan_group_name: Optional[str] = None,
# self.axes = axes
# self.dim = len(axes)
# if len(rebin_params) != self.dim:
# raise ValueError(f"Mismatched dimension with axes={axes} and rebin_params={rebin_params}")

# for scan in sg_info_list:
# self.add_scan(scan)

# if self.dim == 2: # 1D data
# ScanData1D()
# elif self.dim == 3: # 2D data
# ScanData2D()

# self.axes = axes

self.name = scan_group_name if scan_group_name is not None else f"ScanGroup{ScanGroup.scan_group_number}"
ScanGroup.scan_group_number += 1

# TODO
def add_scan(self, scan_path: str):
pass

# TODO
def remove_scan(self, scan_path: str):
pass

# TODO non-orthogonal axes for constant E contours

# @staticmethod
# def validate_rebin_params(rebin_params: float | tuple) -> tuple:
# return rebin_params

def get_plot_data(
self,
norm_channel=None,
Expand Down
41 changes: 18 additions & 23 deletions src/tavi/data/tavi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from tavi.data.nxentry import NexusEntry
from tavi.data.scan import Scan
from tavi.data.scan_group import ScanGroup


class TAVI(object):
Expand Down Expand Up @@ -134,34 +135,28 @@ def save(self, file_path: Optional[str] = None):
except OSError:
print(f"Cannot create tavi file at {self.file_path}")

def get_scan(
self,
scan_num: int,
exp_id: Optional[str] = None,
) -> Scan:
def get_scan(self, scan_path: str) -> Scan:
"""Get the scan at location /data/exp_id/scanXXXX, return a Scan instance
Arguments:
scan_num (int): scan number
exp_id (str | None): in the format of IPTSXXXXX_INSTRU_expXXXX, needed when
scan_path (str): exp_id /scan_name. exp_id is in the format of
IPTSXXXXX_INSTRU_expXXXX, it is needed when
more than one experiment is loaded as data
Return:
Scan: an instance of Scan class
"""
if exp_id is None:
if "/" in scan_path:
exp_id, scan_name = scan_path.split("/")
else:
exp_id = next(iter(self.data))
dataset = self.data[exp_id]
scan_name = f"scan{scan_num:04}"
return Scan(scan_name, dataset[scan_name])

# def generate_scan_group(
# self,
# signals=None,
# backgrounds=None,
# signal_axes=(None, None, None),
# background_axes=(None, None, None),
# ):
# """Generate a scan group."""
# sg = ScanGroup(signals, backgrounds, signal_axes, background_axes)

# return sg
scan_name = scan_path
return Scan(scan_name, self.data[exp_id][scan_name])

def get_scan_group(self, scan_group_name: str):
pass

def make_scan_group(
self,
scan_path_list: list,
):
sg = ScanGroup(scan_path_list)
6 changes: 3 additions & 3 deletions tests/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_scan_from_nexus():
assert np.allclose(scan.data["detector"][0:3], [569, 194, 40])


def test_generate_curve():
def test_get_plot_data():
nexus_file_name = "./test_data/IPTS32124_CG4C_exp0424/scan0042.h5"
scan = Scan.from_nexus(nexus_file_name)
plot = scan.get_plot_data()
Expand All @@ -65,7 +65,7 @@ def test_generate_curve():
assert np.allclose(scan.data["time"], time_data)


def test_generate_curve_norm():
def test_get_plot_data_norm():
nexus_file_name = "./test_data/IPTS32124_CG4C_exp0424/scan0042.h5"
scan = Scan.from_nexus(nexus_file_name)
plot = scan.get_plot_data(norm_channel="mcu", norm_val=5)
Expand All @@ -82,7 +82,7 @@ def test_generate_curve_norm():
assert np.allclose(plot.yerr, yerr_data / 12)


def test_generate_curve_rebin_grid():
def test_get_plot_data_rebin_grid():
nexus_file_name = "./test_data/IPTS32124_CG4C_exp0424/scan0042.h5"
scan = Scan.from_nexus(nexus_file_name)
plot = scan.get_plot_data(rebin_type="grid", rebin_params=0.25)
Expand Down
Loading

0 comments on commit 7027701

Please sign in to comment.