diff --git a/cdiutils/load/loader.py b/cdiutils/load/loader.py index edd2fce..f52a29a 100644 --- a/cdiutils/load/loader.py +++ b/cdiutils/load/loader.py @@ -56,21 +56,24 @@ def from_setup(cls, beamline_setup: str, metadata: dict) -> "Loader": Returns: Loader: the subclass loader according to the provided name. """ - if beamline_setup == "ID01BLISS": + if beamline_setup.lower() == "id01bliss": from . import BlissLoader return BlissLoader(**metadata) - if beamline_setup == "ID01SPEC": + if beamline_setup.lower() == "id01spec": from . import SpecLoader return SpecLoader(**metadata) - if beamline_setup == "SIXS2022": + if beamline_setup.lower() == "sixs2022": from . import SIXS2022Loader return SIXS2022Loader(**metadata) - if "P10" in beamline_setup: + if "p10" in beamline_setup.lower(): from . import P10Loader - if beamline_setup == "P10EH2": + if beamline_setup.lower() == "p10eh2": return P10Loader(hutch="EH2", **metadata) else: return P10Loader(**metadata) + if beamline_setup.lower == "cristal": + from . import CristalLoader + return CristalLoader(**metadata) raise ValueError(f"Invalid beamline setup: {beamline_setup}") @staticmethod @@ -93,7 +96,14 @@ def _check_load(data_or_path: np.ndarray | str) -> np.ndarray: return np.load(data_or_path) if data_or_path.endswith(".npz"): with np.load(data_or_path, "r") as file: - return file["arr_0"] + for possible_key in ( + "arr_0", "data", "mask", "flatfield", "flat_field" + ): + if possible_key in file.keys(): + return file[possible_key] + raise KeyError( + f"Unvalid file provided containing {file.keys()}." + ) elif data_or_path is None or isinstance(data_or_path, np.ndarray): return data_or_path raise ValueError(