diff --git a/pydough/mask_server/mask_server.py b/pydough/mask_server/mask_server.py index 38f50bd50..8a16d07e2 100644 --- a/pydough/mask_server/mask_server.py +++ b/pydough/mask_server/mask_server.py @@ -93,7 +93,9 @@ class MaskServerInfo: perform the evaluation. """ - def __init__(self, base_url: str, token: str | None = None): + def __init__( + self, base_url: str, token: str | None = None, api_key: str | None = None + ): """ Initialize the MaskServerInfo with the given server URL. @@ -102,7 +104,7 @@ def __init__(self, base_url: str, token: str | None = None): `token`: Optional authentication token for the server. """ self.connection: ServerConnection = ServerConnection( - base_url=base_url, token=token + base_url=base_url, token=token, api_key=api_key ) def get_server_response_case(self, server_case: str) -> MaskServerResponse: diff --git a/pydough/mask_server/server_connection.py b/pydough/mask_server/server_connection.py index d21025c18..f02095a81 100644 --- a/pydough/mask_server/server_connection.py +++ b/pydough/mask_server/server_connection.py @@ -50,13 +50,21 @@ class ServerRequest: Optional headers to include in the request. """ + def add_header(self, header: dict[str, str]) -> None: + """ + Adds or updates headers. + """ + self.headers.update(header) + class ServerConnection: """ Class that manages the connection to the server. """ - def __init__(self, base_url: str, token: str | None = None): + def __init__( + self, base_url: str, token: str | None = None, api_key: str | None = None + ): """ Initialize the server connection with the given base URL and token if given. @@ -67,6 +75,7 @@ def __init__(self, base_url: str, token: str | None = None): """ self.base_url = base_url.rstrip("/") self.token = token + self.api_key = api_key self._client = httpx.Client(base_url=self.base_url) def set_timeout(self, timeout: float) -> None: @@ -111,9 +120,11 @@ def send_server_request(self, request: ServerRequest) -> dict: try: method: Callable[..., Response] = self.method_mapping(request.method) - headers = request.headers.copy() if request.headers else {} if self.token: - headers.setdefault("Authorization", f"Bearer {self.token}") + request.add_header({"Authorization": f"Bearer {self.token}"}) + if self.api_key: + request.add_header({"X-API-Key": self.api_key}) + headers = request.headers.copy() if request.headers else {} kwargs = {"headers": headers} # Choose params vs json depending on method diff --git a/tests/mock_server/api_mock_server.py b/tests/mock_server/api_mock_server.py index 2a1b125d9..9167f9e26 100644 --- a/tests/mock_server/api_mock_server.py +++ b/tests/mock_server/api_mock_server.py @@ -28,23 +28,28 @@ class RequestPayload(BaseModel): expression_format: dict[str, str] = {"name": "linear", "version": "0.2.0"} -def verify_token(request: Request): +def authentication(request: Request): auth_header = request.headers.get("Authorization", None) + api_key = request.headers.get("X-API-Key", None) - if auth_header and auth_header != "Bearer test-token-123": + if (auth_header and auth_header != "Bearer test-token-123") or ( + api_key and api_key != "api-key-123" + ): raise HTTPException(status_code=401, detail="Unauthorized request") return True @app.get("/health") -def health(request: Request, authorized: bool = Depends(verify_token)): +def health(request: Request, authorized: bool = Depends(authentication)): return {"status": "ok"} @app.post("/v1/predicates/batch-evaluate") def batch_evaluate( - request: Request, payload: RequestPayload, authorized: bool = Depends(verify_token) + request: Request, + payload: RequestPayload, + authorized: bool = Depends(authentication), ): responses: list[dict] = [] for item in payload.items: diff --git a/tests/test_mock_mask_server.py b/tests/test_mock_mask_server.py index 4d0e44479..2b51b38a6 100644 --- a/tests/test_mock_mask_server.py +++ b/tests/test_mock_mask_server.py @@ -16,9 +16,10 @@ @pytest.mark.server @pytest.mark.parametrize( - "token, batch, answer", + "token, api_key, batch, answer", [ pytest.param( + None, None, [ MaskServerInput( @@ -81,6 +82,7 @@ id="alternated_supported_response", ), pytest.param( + None, None, [ MaskServerInput( @@ -102,6 +104,7 @@ id="single_supported_response", ), pytest.param( + None, None, [ MaskServerInput( @@ -120,6 +123,7 @@ ), pytest.param( "test-token-123", + None, [ MaskServerInput( table_path="srv.db.orders", @@ -141,8 +145,57 @@ ], id="with_token", ), + pytest.param( + None, + "api-key-123", + [ + MaskServerInput( + table_path="srv.db.orders", + column_name="order_date", + expression=["BETWEEN", 3, "__col__", "2025-01-01", "2025-02-01"], + ), + ], + [ + MaskServerOutput( + response_case=MaskServerResponse.IN_ARRAY, + payload=[ + "2025-01-01", + "2025-01-02", + "2025-01-03", + "2025-01-04", + "2025-01-05", + ], + ), + ], + id="with_api_key", + ), pytest.param( "test-token-123", + "api-key-123", + [ + MaskServerInput( + table_path="srv.db.orders", + column_name="order_date", + expression=["BETWEEN", 3, "__col__", "2025-01-01", "2025-02-01"], + ), + ], + [ + MaskServerOutput( + response_case=MaskServerResponse.IN_ARRAY, + payload=[ + "2025-01-01", + "2025-01-02", + "2025-01-03", + "2025-01-04", + "2025-01-05", + ], + ), + ], + id="with_token_and_api_key", + ), + pytest.param( + "test-token-123", + None, [ MaskServerInput( table_path="srv.db.tbl", @@ -159,6 +212,7 @@ id="booleans", ), pytest.param( + None, None, [ MaskServerInput( @@ -196,6 +250,7 @@ id="decimal_and_money", ), pytest.param( + None, None, [ MaskServerInput( @@ -232,7 +287,8 @@ ], ) def test_mock_mask_server( - token: str, + token: str | None, + api_key: str | None, batch: list[MaskServerInput], answer: list[MaskServerOutput], mock_server_setup, @@ -243,7 +299,7 @@ def test_mock_mask_server( """ mask_server: MaskServerInfo = MaskServerInfo( - base_url="http://localhost:8000", token=token + base_url="http://localhost:8000", token=token, api_key=api_key ) # Doing the request @@ -258,18 +314,18 @@ def test_mock_mask_server( @pytest.mark.server @pytest.mark.parametrize( - "base_url, token, batch, error_msg", + "token, api_key, batch, error_msg", [ pytest.param( - "http://localhost:8000", + None, None, [], "Batch cannot be empty.", id="empty_list_request", ), pytest.param( - "http://localhost:8000", "bad_token_123", + None, [ MaskServerInput( table_path="srv.db.tbl", @@ -280,11 +336,37 @@ def test_mock_mask_server( "Bad response 401: Unauthorized request", id="wrong_token", ), + pytest.param( + None, + "bad_api_key_123", + [ + MaskServerInput( + table_path="srv.db.tbl", + column_name="col", + expression=["OR", 2, "__col__", 5], + ) + ], + "Bad response 401: Unauthorized request", + id="wrong_api_key", + ), + pytest.param( + "bad_token_123", + "bad_api-key-123", + [ + MaskServerInput( + table_path="srv.db.tbl", + column_name="col", + expression=["OR", 2, "__col__", 5], + ) + ], + "Bad response 401: Unauthorized request", + id="wrong_token_and_api_key", + ), ], ) def test_mock_mask_server_errors( - base_url: str, token: str | None, + api_key: str | None, batch: list[MaskServerInput], error_msg: str, mock_server_setup, @@ -293,7 +375,9 @@ def test_mock_mask_server_errors( Testing that the MaskServer raises an exception with the expected error message """ with pytest.raises(Exception, match=re.escape(error_msg)): - mask_server: MaskServerInfo = MaskServerInfo(base_url=base_url, token=token) + mask_server: MaskServerInfo = MaskServerInfo( + base_url="http://localhost:8000", token=token, api_key=api_key + ) mask_server.connection.set_timeout(0.5) # Doing the request