diff --git a/src/murfey/server/api/clem.py b/src/murfey/server/api/clem.py index 128bb2801..df484e7f1 100644 --- a/src/murfey/server/api/clem.py +++ b/src/murfey/server/api/clem.py @@ -83,7 +83,7 @@ def validate_and_sanitise( machine_config = get_machine_config(instrument_name=instrument_name)[ instrument_name ] - rsync_basepath = machine_config.rsync_basepath.resolve() + rsync_basepath = (machine_config.rsync_basepath or Path("")).resolve() # Check that full file path doesn't contain unallowed characters # Currently allows only: diff --git a/src/murfey/server/api/file_io_frontend.py b/src/murfey/server/api/file_io_frontend.py index 9bf2b43ab..b9837f55c 100644 --- a/src/murfey/server/api/file_io_frontend.py +++ b/src/murfey/server/api/file_io_frontend.py @@ -14,6 +14,7 @@ process_gain as _process_gain, ) from murfey.server.murfey_db import murfey_db +from murfey.util import secure_path from murfey.util.config import get_machine_config from murfey.util.db import Session @@ -50,10 +51,23 @@ async def create_symlink( machine_config = get_machine_config(instrument_name=instrument_name)[ instrument_name ] - symlink_full_path = machine_config.rsync_basepath / symlink_params.symlink + rsync_basepath = (machine_config.rsync_basepath or Path("")).resolve() + symlink_full_path = secure_path( + rsync_basepath / symlink_params.symlink, keep_spaces=True + ) + # Verify that the symlink provided does not lead elsewhere + if not symlink_full_path.resolve().is_relative_to(rsync_basepath): + logger.warning( + "Symlink rejected because it will be created in a forbidden location" + ) + return "" + # Remove and replace symlink if it exists and 'override' is set if symlink_full_path.is_symlink() and symlink_params.override: symlink_full_path.unlink() + # If a file/folder already exists using the desired symlink name, return empty string if symlink_full_path.exists(): return "" - symlink_full_path.symlink_to(machine_config.rsync_basepath / symlink_params.target) + 
symlink_full_path.symlink_to( + secure_path(rsync_basepath / symlink_params.target, keep_spaces=True) + ) return str(symlink_params.symlink) diff --git a/src/murfey/server/api/file_io_instrument.py b/src/murfey/server/api/file_io_instrument.py index 2644b7f89..bbe12f1b0 100644 --- a/src/murfey/server/api/file_io_instrument.py +++ b/src/murfey/server/api/file_io_instrument.py @@ -57,7 +57,8 @@ def suggest_path( ) # Construct the full path to where the dataset is to be saved - check_path = machine_config.rsync_basepath / base_path + rsync_basepath = (machine_config.rsync_basepath or Path("")).resolve() + check_path = rsync_basepath / base_path # Check previous year to account for the year rolling over during data collection if not check_path.parent.exists(): @@ -69,7 +70,7 @@ def suggest_path( base_path_parts[year_idx] = str(int(part) - 1) base_path = "/".join(base_path_parts) check_path_prev = check_path - check_path = machine_config.rsync_basepath / base_path + check_path = rsync_basepath / base_path # If it's not in the previous year either, it's a genuine error if not check_path.parent.exists(): @@ -88,7 +89,7 @@ def suggest_path( check_path.mkdir(mode=0o750) if params.extra_directory: (check_path / secure_filename(params.extra_directory)).mkdir(mode=0o750) - return {"suggested_path": check_path.relative_to(machine_config.rsync_basepath)} + return {"suggested_path": check_path.relative_to(rsync_basepath)} class Dest(BaseModel): @@ -107,7 +108,9 @@ def make_rsyncer_destination(session_id: int, destination: Dest, db=murfey_db): ] if not machine_config: raise ValueError("No machine configuration set when making rsyncer destination") - full_destination_path = machine_config.rsync_basepath / destination_path + full_destination_path = ( + machine_config.rsync_basepath or Path("") + ).resolve() / destination_path for parent_path in full_destination_path.parents: parent_path.mkdir(mode=0o750, exist_ok=True) return destination @@ -151,7 +154,7 @@ async def 
write_eer_fractionation_file( ) / secure_filename(fractionation_params.fractionation_file_name) else: file_path = ( - Path(machine_config.rsync_basepath) + (machine_config.rsync_basepath or Path("")).resolve() / str(datetime.now().year) / secure_filename(visit_name) / machine_config.gain_directory_name diff --git a/src/murfey/server/api/file_io_shared.py b/src/murfey/server/api/file_io_shared.py index 310c8e4b9..d510d732a 100644 --- a/src/murfey/server/api/file_io_shared.py +++ b/src/murfey/server/api/file_io_shared.py @@ -37,8 +37,9 @@ async def process_gain( executables = machine_config.external_executables env = machine_config.external_environment safe_path_name = secure_filename(gain_reference_params.gain_ref.name) + rsync_basepath = machine_config.rsync_basepath or Path("") filepath = ( - Path(machine_config.rsync_basepath) + rsync_basepath / str(datetime.now().year) / secure_filename(visit_name) / machine_config.gain_directory_name @@ -48,7 +49,7 @@ async def process_gain( if not filepath.exists(): filepath_prev = filepath filepath = ( - Path(machine_config.rsync_basepath) + rsync_basepath / str(datetime.now().year - 1) / secure_filename(visit_name) / machine_config.gain_directory_name @@ -80,14 +81,12 @@ async def process_gain( ) if new_gain_ref and new_gain_ref_superres: return { - "gain_ref": new_gain_ref.relative_to(Path(machine_config.rsync_basepath)), - "gain_ref_superres": new_gain_ref_superres.relative_to( - Path(machine_config.rsync_basepath) - ), + "gain_ref": new_gain_ref.relative_to(rsync_basepath), + "gain_ref_superres": new_gain_ref_superres.relative_to(rsync_basepath), } elif new_gain_ref: return { - "gain_ref": new_gain_ref.relative_to(Path(machine_config.rsync_basepath)), + "gain_ref": new_gain_ref.relative_to(rsync_basepath), "gain_ref_superres": None, } else: diff --git a/src/murfey/server/api/session_control.py b/src/murfey/server/api/session_control.py index 9d9770029..c384fb62d 100644 --- a/src/murfey/server/api/session_control.py +++ 
b/src/murfey/server/api/session_control.py @@ -24,7 +24,6 @@ get_foil_holes_from_grid_square as _get_foil_holes_from_grid_square, get_grid_squares as _get_grid_squares, get_grid_squares_from_dcg as _get_grid_squares_from_dcg, - get_machine_config_for_instrument, get_tiff_file as _get_tiff_file, get_upstream_file as _get_upstream_file, remove_session_by_id, @@ -32,7 +31,7 @@ from murfey.server.ispyb import DB as ispyb_db, get_all_ongoing_visits from murfey.server.murfey_db import murfey_db from murfey.util import sanitise -from murfey.util.config import MachineConfig +from murfey.util.config import get_machine_config from murfey.util.db import ( AutoProcProgram, ClientEnvironment, @@ -80,8 +79,8 @@ async def get_current_timestamp(): @router.get("/instruments/{instrument_name}/machine") -def machine_info_by_instrument(instrument_name: str) -> Optional[MachineConfig]: - return get_machine_config_for_instrument(instrument_name) +def machine_info_by_instrument(instrument_name: str): + return get_machine_config(instrument_name)[instrument_name] @router.get("/new_client_id/") diff --git a/src/murfey/server/api/session_info.py b/src/murfey/server/api/session_info.py index 8da98536a..d43dd1d8d 100644 --- a/src/murfey/server/api/session_info.py +++ b/src/murfey/server/api/session_info.py @@ -24,7 +24,6 @@ get_foil_holes_from_grid_square as _get_foil_holes_from_grid_square, get_grid_squares as _get_grid_squares, get_grid_squares_from_dcg as _get_grid_squares_from_dcg, - get_machine_config_for_instrument, get_tiff_file as _get_tiff_file, get_upstream_file as _get_upstream_file, remove_session_by_id, @@ -32,7 +31,7 @@ from murfey.server.ispyb import DB as ispyb_db, get_all_ongoing_visits from murfey.server.murfey_db import murfey_db from murfey.util import sanitise -from murfey.util.config import MachineConfig +from murfey.util.config import get_machine_config from murfey.util.db import ( ClassificationFeedbackParameters, ClientEnvironment, @@ -78,8 +77,8 @@ def 
connections_check(): @router.get("/instruments/{instrument_name}/machine") def machine_info_by_instrument( instrument_name: MurfeyInstrumentName, -) -> Optional[MachineConfig]: - return get_machine_config_for_instrument(instrument_name) +): + return get_machine_config(instrument_name)[instrument_name] @router.get("/instruments/{instrument_name}/visits_raw", response_model=List[Visit]) diff --git a/src/murfey/server/api/session_shared.py b/src/murfey/server/api/session_shared.py index 29607d98a..644b9b35b 100644 --- a/src/murfey/server/api/session_shared.py +++ b/src/murfey/server/api/session_shared.py @@ -1,7 +1,6 @@ import logging -from functools import lru_cache from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List from sqlmodel import select from sqlmodel.orm.session import Session as SQLModelSession @@ -9,7 +8,7 @@ import murfey.server.prometheus as prom from murfey.util import safe_run, sanitise, secure_path -from murfey.util.config import MachineConfig, from_file, get_machine_config, settings +from murfey.util.config import get_machine_config from murfey.util.db import ( DataCollection, DataCollectionGroup, @@ -23,15 +22,6 @@ logger = logging.getLogger("murfey.server.api.shared") -@lru_cache(maxsize=5) -def get_machine_config_for_instrument(instrument_name: str) -> Optional[MachineConfig]: - if settings.murfey_machine_configuration: - return from_file(Path(settings.murfey_machine_configuration), instrument_name)[ - instrument_name - ] - return None - - def remove_session_by_id(session_id: int, db): session = db.exec(select(MurfeySession).where(MurfeySession.id == session_id)).one() sessions_for_visit = db.exec( diff --git a/src/murfey/server/api/workflow.py b/src/murfey/server/api/workflow.py index cd4056628..220d90e27 100644 --- a/src/murfey/server/api/workflow.py +++ b/src/murfey/server/api/workflow.py @@ -245,15 +245,14 @@ def start_dc( machine_config = get_machine_config(instrument_name=instrument_name)[ 
instrument_name ] + rsync_basepath = (machine_config.rsync_basepath or Path("")).resolve() logger.info( f"Starting data collection on microscope {instrument_name!r} " - f"with basepath {sanitise(str(machine_config.rsync_basepath))} and directory {sanitise(dc_params.image_directory)}" + f"with basepath {sanitise(str(rsync_basepath))} and directory {sanitise(dc_params.image_directory)}" ) dc_parameters = { "visit": visit_name, - "image_directory": str( - machine_config.rsync_basepath / dc_params.image_directory - ), + "image_directory": str(rsync_basepath / dc_params.image_directory), "start_time": str(datetime.now()), "voltage": dc_params.voltage, "pixel_size": str(float(dc_params.pixel_size_on_image) * 1e9), @@ -744,7 +743,10 @@ async def request_tomography_preprocessing( "fm_dose": proc_file.dose_per_frame, "frame_count": proc_file.frame_count, "gain_ref": ( - str(machine_config.rsync_basepath / proc_file.gain_ref) + str( + (machine_config.rsync_basepath or Path("")).resolve() + / proc_file.gain_ref + ) if proc_file.gain_ref and machine_config.data_transfer_enabled else proc_file.gain_ref ), @@ -1060,7 +1062,7 @@ async def make_gif( instrument_name ] output_dir = ( - Path(machine_config.rsync_basepath) + (machine_config.rsync_basepath or Path("")).resolve() / secure_filename(year) / secure_filename(visit_name) / "processed" diff --git a/src/murfey/server/demo_api.py b/src/murfey/server/demo_api.py index 79fab4c05..039ca9056 100644 --- a/src/murfey/server/demo_api.py +++ b/src/murfey/server/demo_api.py @@ -43,8 +43,8 @@ from murfey.util import sanitise_path from murfey.util.config import ( MachineConfig, - from_file, get_hostname, + machine_config_from_file, security_from_file, ) from murfey.util.db import ( @@ -93,7 +93,9 @@ class Settings(BaseSettings): machine_config: dict[str, MachineConfig] = {} if settings.murfey_machine_configuration: microscope = get_microscope() - machine_config = from_file(Path(settings.murfey_machine_configuration), microscope) + 
machine_config = machine_config_from_file( + Path(settings.murfey_machine_configuration), microscope + ) # This will be the homepage for a given microscope. @@ -114,9 +116,9 @@ async def root(request: Request): def machine_info() -> Optional[MachineConfig]: instrument_name = os.getenv("BEAMLINE") if settings.murfey_machine_configuration and instrument_name: - return from_file(Path(settings.murfey_machine_configuration), instrument_name)[ - instrument_name - ] + return machine_config_from_file( + Path(settings.murfey_machine_configuration), instrument_name + )[instrument_name] return None @@ -124,9 +126,9 @@ def machine_info() -> Optional[MachineConfig]: @router.get("/instruments/{instrument_name}/machine") def machine_info_by_name(instrument_name: str) -> Optional[MachineConfig]: if settings.murfey_machine_configuration: - return from_file(Path(settings.murfey_machine_configuration), instrument_name)[ - instrument_name - ] + return machine_config_from_file( + Path(settings.murfey_machine_configuration), instrument_name + )[instrument_name] return None diff --git a/src/murfey/server/feedback.py b/src/murfey/server/feedback.py index 480db7fad..2efad95ce 100644 --- a/src/murfey/server/feedback.py +++ b/src/murfey/server/feedback.py @@ -1100,7 +1100,9 @@ def _register_class_selection(message: dict, _db, demo: bool = False): def _find_initial_model(visit: str, machine_config: MachineConfig) -> Path | None: if machine_config.initial_model_search_directory: visit_directory = ( - machine_config.rsync_basepath / str(datetime.now().year) / visit + (machine_config.rsync_basepath or Path("")).resolve() + / str(datetime.now().year) + / visit ) possible_models = [ p @@ -1512,7 +1514,10 @@ def _flush_tomography_preprocessing(message: dict, _db): "fm_dose": proc_params.dose_per_frame, "frame_count": proc_params.frame_count, "gain_ref": ( - str(machine_config.rsync_basepath / proc_params.gain_ref) + str( + (machine_config.rsync_basepath or Path("")).resolve() + / 
proc_params.gain_ref + ) if proc_params.gain_ref else proc_params.gain_ref ), @@ -2042,7 +2047,10 @@ def feedback_callback(header: dict, message: dict, _db=murfey_db) -> None: angpix=float(message["pixel_size_on_image"]) * 1e10, dose_per_frame=message["dose_per_frame"], gain_ref=( - str(machine_config.rsync_basepath / message["gain_ref"]) + str( + (machine_config.rsync_basepath or Path("")).resolve() + / message["gain_ref"] + ) if message["gain_ref"] and machine_config.data_transfer_enabled else message["gain_ref"] ), diff --git a/src/murfey/util/config.py b/src/murfey/util/config.py index a4aa6220f..f989d5b38 100644 --- a/src/murfey/util/config.py +++ b/src/murfey/util/config.py @@ -30,35 +30,38 @@ class MachineConfig(BaseModel): # type: ignore # General info -------------------------------------------------------------------- display_name: str = "" instrument_name: str = "" + instrument_type: str = "" # For use with hierarchical config files image_path: Optional[Path] = None machine_override: str = "" # Hardware and software ----------------------------------------------------------- camera: str = "FALCON" superres: bool = False - calibrations: dict[str, Any] - acquisition_software: list[str] + calibrations: dict[str, Any] = {} + acquisition_software: list[str] = [] software_versions: dict[str, str] = {} software_settings_output_directories: dict[str, list[str]] = {} data_required_substrings: dict[str, dict[str, list[str]]] = {} # Client side directory setup ----------------------------------------------------- - data_directories: list[Path] + data_directories: list[Path] = [] create_directories: list[str] = ["atlas"] analyse_created_directories: list[str] = [] gain_reference_directory: Optional[Path] = None eer_fractionation_file_template: str = "" - substrings_blacklist: dict[str, list] = { + + # Data transfer setup ------------------------------------------------------------- + # General setup + data_transfer_enabled: bool = True + substrings_blacklist: 
dict[str, list[str]] = { "directories": [], "files": [], } - # Data transfer setup ------------------------------------------------------------- # Rsync setup - data_transfer_enabled: bool = True rsync_url: str = "" rsync_module: str = "" - rsync_basepath: Path + rsync_basepath: Optional[Path] = None allow_removal: bool = False # Upstream data download setup @@ -86,7 +89,7 @@ class MachineConfig(BaseModel): # type: ignore } # Particle picking setup - default_model: Path + default_model: Optional[Path] = None picking_model_search_directory: str = "processing" initial_model_search_directory: str = "processing/initial_model" @@ -147,14 +150,77 @@ def validate_software_versions(cls, v: dict[str, Any]) -> dict[str, str]: return v -def from_file(config_file_path: Path, instrument: str = "") -> dict[str, MachineConfig]: +@lru_cache(maxsize=1) +def machine_config_from_file( + config_file_path: Path, + instrument_name: str, +) -> dict[str, MachineConfig]: + """ + Loads the machine config YAML file and constructs instrument-specific configs from + a hierarchical set of dictionary key-value pairs. It will populate the keys listed + in the general dictionary, then update the keys specified in the shared instrument + dictionary, before finally updating the keys for that specific instrument. + """ + + def _recursive_update(base: dict[str, Any], new: dict[str, Any]): + """ + Helper function to recursively update nested dictionaries. + + If the old and new values are both dicts, it will add the new keys and values + to the existing dictionary recursively without overwriting entries. + + If the old and new values are both lists, it will extend the existing list. + For all other values, it will overwrite the existing value with the new one. 
+ """ + for key, value in new.items(): + # If new values are dicts and dict values already exist, do recursive update + if key in base and isinstance(base[key], dict) and isinstance(value, dict): + base[key] = _recursive_update(base[key], value) + # If new values are lists and a list already exists, extend the list + elif ( + key in base and isinstance(base[key], list) and isinstance(value, list) + ): + base[key].extend(value) + # Otherwise, overwrite/add values as normal + else: + base[key] = value + return base + + # Load the dict from the file with open(config_file_path, "r") as config_stream: - config = yaml.safe_load(config_stream) - return { - i: MachineConfig(**config[i]) - for i in config.keys() - if not instrument or i == instrument - } + master_config: dict[str, Any] = yaml.safe_load(config_stream) + + # Construct requested machine configs from the YAML file + all_machine_configs: dict[str, MachineConfig] = {} + for i in sorted(master_config.keys()): + # Skip reserved top-level keys + if i in ("general", "clem", "fib", "tem"): + continue + # If instrument name is set, skip irrelevant configs + if instrument_name and i != instrument_name: + continue + + # Construct instrument config hierarchically + config: dict[str, Any] = {} + + # Populate with general values + general_config: dict[str, Any] = master_config.get("general", {}) + config = _recursive_update(config, general_config) + + # Populate with shared instrument values + instrument_config: dict[str, Any] = master_config.get(i, {}) + instrument_shared_config: dict[str, Any] = master_config.get( + str(instrument_config.get("instrument_type", "")).lower(), {} + ) + config = _recursive_update(config, instrument_shared_config) + + # Insert instrument-specific values + config = _recursive_update(config, instrument_config) + + # Add to master dictionary + all_machine_configs[i] = MachineConfig(**config) + + return all_machine_configs class Security(BaseModel): @@ -248,22 +314,13 @@ def get_security_config() 
-> Security: @lru_cache(maxsize=1) def get_machine_config(instrument_name: str = "") -> dict[str, MachineConfig]: - machine_config = { - "": MachineConfig( - acquisition_software=[], - calibrations={}, - data_directories=[], - rsync_basepath=Path("dls/tmp"), - murfey_db_credentials="", - default_model="/tmp/weights.h5", - ) - } + # Create an empty machine config as a placeholder + machine_configs = {instrument_name: MachineConfig()} if settings.murfey_machine_configuration: - microscope = instrument_name - machine_config = from_file( - Path(settings.murfey_machine_configuration), microscope + machine_configs = machine_config_from_file( + Path(settings.murfey_machine_configuration), instrument_name ) - return machine_config + return machine_configs def get_extended_machine_config( diff --git a/src/murfey/util/processing_params.py b/src/murfey/util/processing_params.py index a40ba9872..760b69ff1 100644 --- a/src/murfey/util/processing_params.py +++ b/src/murfey/util/processing_params.py @@ -42,7 +42,9 @@ def cryolo_model_path(visit: str, instrument_name: str) -> Path: ] if machine_config.picking_model_search_directory: visit_directory = ( - machine_config.rsync_basepath / str(datetime.now().year) / visit + (machine_config.rsync_basepath or Path("")).resolve() + / str(datetime.now().year) + / visit ) possible_models = list( (visit_directory / machine_config.picking_model_search_directory).glob( @@ -51,7 +53,7 @@ def cryolo_model_path(visit: str, instrument_name: str) -> Path: ) if possible_models: return sorted(possible_models, key=lambda x: x.stat().st_ctime)[-1] - return machine_config.default_model + return (machine_config.default_model or Path("")).resolve() class CLEMProcessingParameters(BaseModel): diff --git a/src/murfey/workflows/clem/__init__.py b/src/murfey/workflows/clem/__init__.py index b40d48cb3..4aedd01c0 100644 --- a/src/murfey/workflows/clem/__init__.py +++ b/src/murfey/workflows/clem/__init__.py @@ -64,7 +64,7 @@ def _validate_and_sanitise( 
machine_config = get_machine_config(instrument_name=instrument_name)[ instrument_name ] - rsync_basepath = machine_config.rsync_basepath.resolve() + rsync_basepath = (machine_config.rsync_basepath or Path("")).resolve() # Check that full file path doesn't contain unallowed characters # Currently allows only: diff --git a/tests/util/test_config.py b/tests/util/test_config.py new file mode 100644 index 000000000..82068a8c1 --- /dev/null +++ b/tests/util/test_config.py @@ -0,0 +1,366 @@ +from pathlib import Path +from typing import Any + +import pytest +import yaml +from pytest_mock import MockerFixture + +from murfey.util.config import Settings, get_machine_config + + +@pytest.fixture +def mock_general_config(): + # Most extra keys go in this category + return { + "pkg_2": { + "url": "https://some-url.some.org", + "token": "pneumonoultrasmicroscopicsilicovolcanoconiosis", + } + } + + +@pytest.fixture +def mock_tem_shared_config(): + return { + # Hardware and software + "acquisition_software": ["epu", "tomo", "serialem"], + "software_versions": {"tomo": "5.12"}, + "data_required_substrings": { + "epu": { + ".mrc": ["fractions", "Fractions"], + ".tiff": ["fractions", "Fractions"], + ".eer": ["EER"], + }, + "tomo": { + ".mrc": ["fractions", "Fractions"], + ".tiff": ["fractions", "Fractions"], + ".eer": ["EER"], + }, + }, + # Client directory setup + "analyse_created_directories": ["atlas"], + "gain_reference_directory": "C:/ProgramData/Gatan/Reference Images/", + # Data transfer keys + "data_transfer_enabled": True, + "substrings_blacklist": { + "directories": ["some_str"], + "files": ["some_str"], + }, + "rsync_module": "rsync", + "allow_removal": True, + "upstream_data_directories": { + "upstream_instrument": "/path/to/upstream_instrument", + }, + "upstream_data_download_directory": "/path/to/download/directory", + "upstream_data_search_strings": { + "upstream_instrument": ["some_string"], + }, + # Data processing keys + "processing_enabled": True, + 
"gain_directory_name": "some_directory", + "processed_directory_name": "some_directory", + "processed_extra_directory": "some_directory", + "recipes": { + "recipe_1": "recipe_1", + "recipe_2": "recipe_2", + }, + "default_model": "some_file", + "external_executables": { + "app_1": "/path/to/app_1", + "app_2": "/path/to/app_2", + "app_3": "/path/to/app_3", + }, + "external_executables_eer": { + "app_1": "/path/to/app_1", + "app_2": "/path/to/app_2", + "app_3": "/path/to/app_3", + }, + "external_environment": { + "ENV_1": "/path/to/env_1", + "ENV_2": "/path/to/env_2", + }, + "plugin_packages": { + "pkg_1": "/path/to/pkg_1", + "pkg_2": "/path/to/pkg_2", + }, + # Extra keys + "pkg_1": { + "file_path": "", + "command": [ + "/path/to/executable", + "--some_arg", + "-a", + "./path/to/file", + ], + "step_size": 100, + }, + } + + +@pytest.fixture +def mock_instrument_config(): + return { + # Extra key to point to hierarchical dictionary to use + "instrument_type": "tem", + # General information + "display_name": "Some TEM", + "image_path": "/path/to/tem.jpg", + # Hardware and software + "camera": "Some camera", + "superres": True, + "calibrations": { + "magnification": { + 100: 0.1, + 200: 0.05, + 400: 0.025, + }, + }, + # Client directory setup + "data_directories": ["C:"], + # Data transfer keys + "rsync_basepath": "/path/to/data", + "rsync_url": "http://123.45.678.90:8000", + # Server and network keys + "security_configuration_path": "/path/to/security-config.yaml", + "murfey_url": "https://www.murfey.com", + "instrument_server_url": "http://10.123.4.5:8000", + "node_creator_queue": "node_creator", + # Extra keys + "pkg_1": { + "file_path": "/path/to/pkg_1/file.txt", + }, + } + + +@pytest.fixture +def mock_hierarchical_machine_config_yaml( + mock_general_config: dict[str, Any], + mock_tem_shared_config: dict[str, Any], + mock_instrument_config: dict[str, Any], + tmp_path: Path, +): + # Create machine config (with all currently supported keys) for the instrument + 
hierarchical_config = { + "general": mock_general_config, + "tem": mock_tem_shared_config, + "m01": mock_instrument_config, + "m02": mock_instrument_config, + } + config_file = tmp_path / "config" / "murfey-machine-config-hierarchical.yaml" + config_file.parent.mkdir(parents=True, exist_ok=True) + with open(config_file, "w") as file: + yaml.safe_dump(hierarchical_config, file, indent=2) + return config_file + + +@pytest.fixture +def mock_standard_machine_config_yaml( + mock_general_config: dict[str, Any], + mock_tem_shared_config: dict[str, Any], + mock_instrument_config: dict[str, Any], + tmp_path: Path, +): + # Compile the different dictionaries into one dictionary for the instrument + machine_config = { + key: value + for config in ( + mock_general_config, + mock_tem_shared_config, + mock_instrument_config, + ) + for key, value in config.items() + } + + # Correct nested dicts that would have been partially overwritten + machine_config["pkg_1"] = { + "file_path": "/path/to/pkg_1/file.txt", + "command": [ + "/path/to/executable", + "--some_arg", + "-a", + "./path/to/file", + ], + "step_size": 100, + } + + # Remove 'instrument_type' value (not needed in standard config) + machine_config["instrument_type"] = "" + + master_config = { + "m01": machine_config, + "m02": machine_config, + } + config_file = tmp_path / "config" / "murfey-machine-config-standard.yaml" + config_file.parent.mkdir(parents=True, exist_ok=True) + with open(config_file, "w") as file: + yaml.safe_dump(master_config, file, indent=2) + return config_file + + +get_machine_config_test_matrix: tuple[tuple[str, list[str]], ...] 
= ( + # Config to test | Instrument names to pass to function + ("hierarchical", ["", "m01", "m02"]), + ("standard", ["", "m01", "m02"]), +) + + +@pytest.mark.parametrize("test_params", get_machine_config_test_matrix) +def test_get_machine_config( + mocker: MockerFixture, + mock_general_config: dict[str, Any], + mock_tem_shared_config: dict[str, Any], + mock_instrument_config: dict[str, Any], + mock_hierarchical_machine_config_yaml: Path, + mock_standard_machine_config_yaml: Path, + test_params: tuple[str, list[str]], +): + # Unpack test params + config_to_test, instrument_names = test_params + + # Set up mocks + mock_settings = mocker.patch("murfey.util.config.settings", spec=Settings) + + # Run 'get_machine_config' using different instrument name parameters + for i in instrument_names: + # Patch the 'settings' environment variable with the YAML file to test + mock_settings.murfey_machine_configuration = ( + str(mock_hierarchical_machine_config_yaml) + if config_to_test == "hierarchical" + else str(mock_standard_machine_config_yaml) + ) + # Run the function + config = get_machine_config(i) + + # Validate that the config was loaded correctly + assert config + + # Multiple configs should be returned if instrument name was "" + assert len(config) == 2 if i == "" else len(config) == 1 + + # When getting the config for individual microscopes, validate key-by-key + if i != "": + # General info + assert config[i].display_name == mock_instrument_config["display_name"] + assert config[i].image_path == Path(mock_instrument_config["image_path"]) + assert ( + config[i].instrument_type == mock_instrument_config["instrument_type"] + if config_to_test == "hierarchical" + else not config[i].instrument_type + ) + # Hardware & software + assert config[i].camera == mock_instrument_config["camera"] + assert config[i].superres == mock_instrument_config["superres"] + assert config[i].calibrations == mock_instrument_config["calibrations"] + assert ( + config[i].acquisition_software + == 
mock_tem_shared_config["acquisition_software"] + ) + assert ( + config[i].software_versions + == mock_tem_shared_config["software_versions"] + ) + assert ( + config[i].data_required_substrings + == mock_tem_shared_config["data_required_substrings"] + ) + # Client directory setup + assert config[i].data_directories == [ + Path(p) for p in mock_instrument_config["data_directories"] + ] + assert ( + config[i].analyse_created_directories + == mock_tem_shared_config["analyse_created_directories"] + ) + assert config[i].gain_reference_directory == Path( + mock_tem_shared_config["gain_reference_directory"] + ) + # Data transfer setup + assert ( + config[i].data_transfer_enabled + == mock_tem_shared_config["data_transfer_enabled"] + ) + assert ( + config[i].substrings_blacklist + == mock_tem_shared_config["substrings_blacklist"] + ) + assert config[i].rsync_url == mock_instrument_config["rsync_url"] + assert config[i].rsync_basepath == Path( + mock_instrument_config["rsync_basepath"] + ) + assert config[i].rsync_module == mock_tem_shared_config["rsync_module"] + assert config[i].allow_removal == mock_tem_shared_config["allow_removal"] + assert config[i].upstream_data_directories == { + key: Path(value) + for key, value in mock_tem_shared_config[ + "upstream_data_directories" + ].items() + } + assert config[i].upstream_data_download_directory == Path( + mock_tem_shared_config["upstream_data_download_directory"] + ) + assert ( + config[i].upstream_data_search_strings + == mock_tem_shared_config["upstream_data_search_strings"] + ) + # Data processing setup + assert ( + config[i].processing_enabled + == mock_tem_shared_config["processing_enabled"] + ) + assert ( + config[i].gain_directory_name + == mock_tem_shared_config["gain_directory_name"] + ) + assert ( + config[i].processed_directory_name + == mock_tem_shared_config["processed_directory_name"] + ) + assert ( + config[i].processed_extra_directory + == mock_tem_shared_config["processed_extra_directory"] + ) + assert 
config[i].recipes == mock_tem_shared_config["recipes"] + assert config[i].default_model == Path( + mock_tem_shared_config["default_model"] + ) + assert ( + config[i].external_executables + == mock_tem_shared_config["external_executables"] + ) + assert ( + config[i].external_executables_eer + == mock_tem_shared_config["external_executables_eer"] + ) + assert ( + config[i].external_environment + == mock_tem_shared_config["external_environment"] + ) + assert config[i].plugin_packages == { + key: Path(value) + for key, value in mock_tem_shared_config["plugin_packages"].items() + } + # Server and network setup + assert config[i].security_configuration_path == Path( + mock_instrument_config["security_configuration_path"] + ) + assert config[i].murfey_url == mock_instrument_config["murfey_url"] + assert ( + config[i].instrument_server_url + == mock_instrument_config["instrument_server_url"] + ) + assert ( + config[i].node_creator_queue + == mock_instrument_config["node_creator_queue"] + ) + # Extra keys + assert config[i].pkg_1 == { + "file_path": "/path/to/pkg_1/file.txt", + "command": [ + "/path/to/executable", + "--some_arg", + "-a", + "./path/to/file", + ], + "step_size": 100, + } + assert config[i].pkg_2 == mock_general_config["pkg_2"]