Skip to content

Commit

Permalink
implemented NexusEntry class
Browse files Browse the repository at this point in the history
  • Loading branch information
Bing Li committed Sep 23, 2024
1 parent 4845925 commit 1029dc3
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 22 deletions.
91 changes: 69 additions & 22 deletions src/tavi/data/nxentry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@
class NexusEntry(dict):

@staticmethod
def _getitem_recursively(obj, key, ATTRS):
def _getitem_recursively(obj: dict, key: str, ATTRS: bool):
"find key in obj recursively, return None if nonexsiting"
value = None
if key in obj:
if key.split("/")[0] in obj:
for key_item in key.split("/"):
obj = obj[key_item]
if ATTRS:
try:
value = obj[key]["attrs"]
value = obj["attrs"]
except KeyError:
print("Attribute does not exist.")
print(f"Attribute of {key} does not exist.")
else:
try:
value = obj[key]["dataset"]
value = obj["dataset"]
except KeyError:
print("Dataset does not exist.")
print(f"Dataset of {key} does not exist.")

for k, v in obj.items():
if value is not None:
Expand All @@ -30,36 +32,81 @@ def _getitem_recursively(obj, key, ATTRS):

return value

def get(self, key, ATTRS=False, default=None):
"""
Return dataset spicified by key regardless of the hierarchy.
Return attributes instead if ATTRS is True.
@staticmethod
def _write_recursively(items, nexus_entry):
"""write items to nexus entry recursively
Note:
Only works if the key is unique
string encoded with utf-8
"""
value = NexusEntry._getitem_recursively(self, key, ATTRS)
if value is not None:
return value
else:
return default

@staticmethod
def _write_recursively(items, nexus_entry):
"""write items to nexus entry recursively"""
for key, value in items.items():
if key == "attrs":
for attr_key, attr_value in value.items():
if isinstance(attr_value, str):
attr_value = attr_value.encode("utf-8")
nexus_entry.attrs[attr_key] = attr_value
else:
if isinstance(value, dict):
if "dataset" in value.keys():
ds = nexus_entry.create_dataset(name=key, data=value["dataset"], maxshape=None)
dv = value["dataset"]
if isinstance(dv, str):
dv = dv.encode("utf-8")
ds = nexus_entry.create_dataset(
name=key,
data=dv,
maxshape=None,
)
NexusEntry._write_recursively(value, ds)
else:
grp = nexus_entry.create_group(key + "/")
NexusEntry._write_recursively(value, grp)

def to_nexus(self, path_to_nexus):
@staticmethod
def _read_recursively(nexus_entry, items=None):
    """Copy the children of an HDF5 group-like object into nested dicts.

    Args:
        nexus_entry: object whose ``items()`` yields ``(name, child)`` pairs;
            only ``h5py.Group`` and ``h5py.Dataset`` children are handled,
            any other child is ignored.
        items: dict to fill in place; a new dict is created when ``None``.

    Returns:
        ``items``. Each group becomes ``{"attrs": {...}, <children>...}`` and
        each dataset becomes ``{"attrs": {...}, "dataset": <value>}``.

    Note:
        Only attributes of the children are read; the attributes of
        ``nexus_entry`` itself are not.
    """
    if items is None:
        items = {}

    for name, node in nexus_entry.items():
        if isinstance(node, h5py.Group):
            items[name] = {"attrs": dict(node.attrs.items())}
            # Descend, filling the dict that was just stored for this group.
            NexusEntry._read_recursively(node, items[name])
            continue

        if not isinstance(node, h5py.Dataset):
            continue

        attributes = dict(node.attrs.items())
        try:
            # String datasets: decode and convert to a Python str.
            # NOTE(review): str() is applied to whatever asstr()[...] returns;
            # for a non-scalar string dataset this presumably gives the
            # printed form of the array rather than its elements -- confirm.
            content = str(node.asstr()[...])
        except TypeError:
            # asstr() is rejected for this dataset: read the raw data instead.
            content = node[...]
        items[name] = {"attrs": attributes, "dataset": content}

    return items

def get(self, key, ATTRS=False, default=None):
    """
    Return the dataset specified by key regardless of the hierarchy.
    Return the attributes instead if ATTRS is True.
    ``default`` is returned when the lookup yields None.

    Note:
        Unique keys like 's1' or 'm2' can be found straightforwardly.
        To find monitor or detector data use monitor/data or detector/data
    """
    found = NexusEntry._getitem_recursively(self, key, ATTRS)
    return default if found is None else found

@classmethod
def from_nexus(cls, path_to_nexus: str) -> dict:
    """Build an instance from the contents of the HDF5 file at path_to_nexus.

    The file is opened read-only, converted to nested dicts with
    ``_read_recursively`` and the top-level (key, value) pairs are passed
    to ``cls``.
    """
    with h5py.File(path_to_nexus, "r") as h5_file:
        contents = NexusEntry._read_recursively(h5_file)
    return cls(list(contents.items()))

def to_nexus(self, path_to_nexus: str) -> None:
    """Write this entry to the HDF5 file at path_to_nexus.

    The file is opened in "w" mode and filled by ``_write_recursively``.
    """
    with h5py.File(path_to_nexus, "w") as h5_file:
        NexusEntry._write_recursively(self, h5_file)
Binary file modified test_data/scan_to_nexus_test.h5
Binary file not shown.
20 changes: 20 additions & 0 deletions tests/test_nxentry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def test_get_dataset(nexus_entry):
}
assert nexus_entry.get("a3") is None
assert nexus_entry.get("detector") is None
assert nexus_entry.get("detector/data", ATTRS=True) == {"EX_required": "true", "type": "NX_INT", "units": "counts"}
assert np.allclose(nexus_entry.get("instrument/analyser/a2"), np.array([242.0, 242.1, 242.2]))


def test_to_nexus(nexus_entry):
Expand All @@ -27,6 +29,24 @@ def test_to_nexus(nexus_entry):
assert nexus_file["scan0034"].attrs["NX_class"] == "NXentry"


def test_from_nexus():
    """Load the test NeXus file and check lookups by key and by path."""
    # Alternative input: "./test_data/IPTS32124_CG4C_exp0424/scan0034.h5"
    entry = NexusEntry.from_nexus("./test_data/scan_to_nexus_test.h5")

    expected_a2 = np.array([242.0, 242.1, 242.2])
    expected_data_attrs = {
        "EX_required": "true",
        "NX_class": "NXdata",
        "axes": "en",
        "signal": "detector",
    }
    expected_detector_attrs = {"EX_required": "true", "type": "NX_INT", "units": "counts"}

    assert entry.get("definition") == "NXtas"
    assert np.allclose(entry.get("a2"), expected_a2)
    assert entry.get("data", ATTRS=True) == expected_data_attrs
    # Keys that are missing give the default (None).
    assert entry.get("a3") is None
    assert entry.get("detector") is None
    # Path-style keys select a specific branch of the hierarchy.
    assert entry.get("detector/data", ATTRS=True) == expected_detector_attrs
    assert np.allclose(entry.get("instrument/analyser/a2"), expected_a2)


@pytest.fixture
def nexus_entry():

Expand Down

0 comments on commit 1029dc3

Please sign in to comment.