diff --git a/src/daq_config_server/app.py b/src/daq_config_server/app.py index dec2af1..ee45aea 100644 --- a/src/daq_config_server/app.py +++ b/src/daq_config_server/app.py @@ -56,7 +56,7 @@ class ParamList(BaseModel): @app.get(ENDPOINTS.BL_PARAM) -def get_all_beamline_parameters(param_list_data: ParamList | None): +def get_all_beamline_parameters(param_list_data: ParamList | None = None): """Get a dict of all the current beamline parameters.""" assert BEAMLINE_PARAMS is not None if param_list_data is None: @@ -140,12 +140,20 @@ async def catch_all(request: Request, full_path: str): } -def main(args): - global BEAMLINE_PARAM_PATH +def _load_beamline_params(): global BEAMLINE_PARAMS - if args.dev: + BEAMLINE_PARAMS = GDABeamlineParameters.from_file(BEAMLINE_PARAM_PATH) + + +def _set_beamline_param_path(dev: bool = True): + global BEAMLINE_PARAM_PATH + if dev: BEAMLINE_PARAM_PATH = "tests/test_data/beamline_parameters.txt" else: BEAMLINE_PARAM_PATH = BEAMLINE_PARAMETER_PATHS["i03"] - BEAMLINE_PARAMS = GDABeamlineParameters.from_file(BEAMLINE_PARAM_PATH) + + +def main(args): + _set_beamline_param_path(args.dev) + _load_beamline_params() uvicorn.run(app="daq_config_server.app:app", host="0.0.0.0", port=8555) diff --git a/tests/test_api.py b/tests/test_api.py index 0825231..01e0252 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -7,7 +7,7 @@ from fastapi.testclient import TestClient from mockito import when -from daq_config_server.app import app +from daq_config_server import app from daq_config_server.beamline_parameters import GDABeamlineParameters from daq_config_server.constants import ENDPOINTS @@ -17,7 +17,9 @@ @pytest.fixture def mock_app(): - return TestClient(app) + app._set_beamline_param_path(True) + app._load_beamline_params() + return TestClient(app.app) async def _assert_get_and_response(