diff --git a/pyproject.toml b/pyproject.toml index 129e39b..480f4ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", ] description = "A service to put and get your config values from" -dependencies = [] +dependencies = ["urllib3", "requests"] dynamic = ["version"] license.file = "LICENSE" readme = "README.md" @@ -56,6 +56,10 @@ addopts = """ --tb=native -vv --asyncio-mode=auto """ +markers = """ + uses_live_server: mark a test which uses the live config server +""" + # https://iscinumpy.gitlab.io/post/bound-version-constraints/#watch-for-warnings filterwarnings = "error" # Doctest python code in docs, python code in src docstrings, test functions in tests @@ -86,7 +90,7 @@ allowlist_externals = commands = pre-commit: pre-commit run --all-files --show-diff-on-failure {posargs} type-checking: pyright src tests {posargs} - tests: pytest --cov=daq_config_server --cov-report term --cov-report xml:cov.xml {posargs} + tests: pytest -m "not uses_live_server" --cov=daq_config_server --cov-report term --cov-report xml:cov.xml {posargs} """ [tool.ruff] diff --git a/src/daq_config_server/app.py b/src/daq_config_server/app.py index 485ba41..dec2af1 100644 --- a/src/daq_config_server/app.py +++ b/src/daq_config_server/app.py @@ -3,6 +3,7 @@ import uvicorn from fastapi import FastAPI, Request, Response, status from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel from redis import Redis from .beamline_parameters import ( @@ -50,11 +51,17 @@ def get_beamline_parameter(param: str): return {param: BEAMLINE_PARAMS.params.get(param)} +class ParamList(BaseModel): + param_list: list[str] + + @app.get(ENDPOINTS.BL_PARAM) -def get_all_beamline_parameters(): +def get_all_beamline_parameters(param_list_data: ParamList | None): """Get a dict of all the current beamline parameters.""" assert BEAMLINE_PARAMS is not None - return BEAMLINE_PARAMS.params + if param_list_data is None: + return BEAMLINE_PARAMS.params + return {k: BEAMLINE_PARAMS.params.get(k) for k in param_list_data.param_list} @app.get(ENDPOINTS.FEATURE) diff --git a/src/daq_config_server/client.py b/src/daq_config_server/client.py index 1db2c01..2be9bdd 100644 --- a/src/daq_config_server/client.py +++ b/src/daq_config_server/client.py @@ -1,22 +1,17 @@ -import json -from http.client import HTTPConnection, HTTPSConnection from logging import Logger, getLogger -from typing import TypeVar +from typing import Any, TypeVar -from urllib3.util import parse_url +import requests from .constants import ENDPOINTS T = TypeVar("T") +BlParamDType = str | int | float | bool class ConfigServer: def __init__(self, url: str, log: Logger | None = None) -> None: - self._url = parse_url(url) - self._uri_prefix = self._url.request_uri if self._url.request_uri != "/" else "" - if self._url.scheme != "http" and self._url.scheme != "https": - raise ValueError("ConfigServer must use HTTP or HTTPS!") - self._Conn = HTTPSConnection if self._url.scheme == "https" else HTTPConnection + self._url = url.rstrip("/") self._log = log if log else getLogger("daq_config_server.client") def _get( @@ -24,32 +19,26 @@ def _get( endpoint: str, item: str | None = None, options: dict[str, str] | None = None, + data: dict[str, Any] | None = None, ): - req_item = f"/{item}" if item else "" - conn = self._Conn(self._url.host, self._url.port or self._url.scheme) - req_ops = ( - f"?{''.join(f'{k}={v}&' for k,v in options.items())}"[:-1] - if options - else "" + r = requests.get( + self._url + endpoint + (f"/{item}" if item else ""), options, json=data ) - complete_req = self._uri_prefix + endpoint + req_item + req_ops - conn.connect() - conn.request("GET", complete_req) - resp = conn.getresponse() - assert resp.status == 200, f"Failed to get response: {resp!r}" - body = json.loads(resp.read()) - assert item in body, f"Malformed response: {body} does not contain {item}" - resp.close() - conn.close() - return body[item] + return r.json() - def get_beamline_param(self, param: str) -> str | int | float | bool | None: - return self._get(ENDPOINTS.BL_PARAM, param) + def get_beamline_param(self, param: str) -> BlParamDType | None: + return self._get(ENDPOINTS.BL_PARAM, param).get(param) + + def get_some_beamline_params(self, params: list[str]) -> dict[str, BlParamDType]: + return self._get(ENDPOINTS.BL_PARAM, data={"param_list": params}) + + def get_all_beamline_params(self) -> dict[str, BlParamDType]: + return self._get(ENDPOINTS.BL_PARAM) def get_feature_flag(self, flag_name: str) -> bool | None: """Get the specified feature flag; returns None if it does not exist. Will check that the HTTP response is correct and raise an AssertionError if not.""" - return self._get(ENDPOINTS.FEATURE, flag_name) + return self._get(ENDPOINTS.FEATURE, flag_name).get(flag_name) def get_all_feature_flags(self) -> dict | None: """Get the values for all flags; returns None if it does not exist. Will check @@ -68,7 +57,10 @@ def best_effort_get_feature_flag( doesn't exist or if there is a connection error - in the latter case logs to error.""" try: - return self._get(ENDPOINTS.FEATURE, flag_name) + assert ( + result := self._get(ENDPOINTS.FEATURE, flag_name).get(flag_name) + ) is not None + return result except (AssertionError, OSError): self._log.error( "Encountered an error reading from the config service.", exc_info=True @@ -79,7 +71,10 @@ def best_effort_get_all_feature_flags(self) -> dict[str, bool]: """Get all flags, returns an empty dict if there are no flags, or if there is a connection error - in the latter case logs to error.""" try: - return self._get(ENDPOINTS.FEATURE, options={"get_values": "true"}) + assert ( + result := self._get(ENDPOINTS.FEATURE, options={"get_values": "true"}) + ) is not None + return result except (AssertionError, OSError): self._log.error( "Encountered an error reading from the config service.", exc_info=True diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..0c5fdff --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,54 @@ +import pytest +import requests + +from daq_config_server.client import ConfigServer + +SERVER_ADDRESS = "https://daq-config.diamond.ac.uk/api" +USE_PANDA_FLAG = "use_panda_for_gridscan" + + +@pytest.fixture +def server(): + return ConfigServer(SERVER_ADDRESS) + + +@pytest.mark.uses_live_server +class TestConfigServerClient: + def test_fetch_one_flag_from_server(self, server: ConfigServer): + assert isinstance(server.get_feature_flag(USE_PANDA_FLAG), bool) + + def test_fetch_all_flags_from_server(self, server: ConfigServer): + assert isinstance(flags := server.get_feature_flag_list(), list) + assert all(isinstance(server.get_feature_flag(flag), bool) for flag in flags) + + def test_best_effort_gets_real_flag(self, server: ConfigServer): + assert isinstance(current := server.get_feature_flag(USE_PANDA_FLAG), bool) + assert ( + server.best_effort_get_feature_flag(USE_PANDA_FLAG, not current) is current + ) + + @pytest.mark.skip(reason="ONLY RUN THIS MANUALLY") + def test_fetch_and_set_flags(self, server: ConfigServer): + """ONLY RUN THIS IF YOU ARE SURE NOTHING IS RUNNING USING THE SERVICE!!!""" + flag = USE_PANDA_FLAG + assert isinstance(initial_value := server.get_feature_flag(flag), bool) + r = requests.put( + SERVER_ADDRESS + + f"/featureflag/{flag}?value={str(not initial_value).lower()}" + ) + assert r.json()["success"] is True + assert server.get_feature_flag(flag) is not initial_value + r = requests.put( + SERVER_ADDRESS + f"/featureflag/{flag}?value={str(initial_value).lower()}" + ) + assert r.json()["success"] is True + assert server.get_feature_flag(flag) is initial_value + + def test_get_some_beamline_params(self, server: ConfigServer): + params_list = [ + "miniap_x_ROBOT_LOAD", + "miniap_y_ROBOT_LOAD", + "miniap_z_ROBOT_LOAD", + ] + params = server.get_some_beamline_params(params_list) + assert all(p in params.keys() for p in params_list)