Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pydough/mask_server/mask_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ class MaskServerInfo:
perform the evaluation.
"""

def __init__(self, base_url: str, token: str | None = None):
def __init__(
self, base_url: str, token: str | None = None, api_key: str | None = None
):
"""
Initialize the MaskServerInfo with the given server URL.

Expand All @@ -102,7 +104,7 @@ def __init__(self, base_url: str, token: str | None = None):
`token`: Optional authentication token for the server.
"""
self.connection: ServerConnection = ServerConnection(
base_url=base_url, token=token
base_url=base_url, token=token, api_key=api_key
)

def get_server_response_case(self, server_case: str) -> MaskServerResponse:
Expand Down
17 changes: 14 additions & 3 deletions pydough/mask_server/server_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,21 @@ class ServerRequest:
Optional headers to include in the request.
"""

def add_header(self, header: dict[str, str]) -> None:
"""
Adds or updates headers.
"""
self.headers.update(header)


class ServerConnection:
"""
Class that manages the connection to the server.
"""

def __init__(self, base_url: str, token: str | None = None):
def __init__(
self, base_url: str, token: str | None = None, api_key: str | None = None
):
"""
Initialize the server connection with the given base URL and token if
given.
Expand All @@ -67,6 +75,7 @@ def __init__(self, base_url: str, token: str | None = None):
"""
self.base_url = base_url.rstrip("/")
self.token = token
self.api_key = api_key
self._client = httpx.Client(base_url=self.base_url)

def set_timeout(self, timeout: float) -> None:
Expand Down Expand Up @@ -111,9 +120,11 @@ def send_server_request(self, request: ServerRequest) -> dict:
try:
method: Callable[..., Response] = self.method_mapping(request.method)

headers = request.headers.copy() if request.headers else {}
if self.token:
headers.setdefault("Authorization", f"Bearer {self.token}")
request.add_header({"Authorization": f"Bearer {self.token}"})
if self.api_key:
request.add_header({"X-API-Key": self.api_key})
headers = request.headers.copy() if request.headers else {}
kwargs = {"headers": headers}

# Choose params vs json depending on method
Expand Down
13 changes: 9 additions & 4 deletions tests/mock_server/api_mock_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,28 @@ class RequestPayload(BaseModel):
expression_format: dict[str, str] = {"name": "linear", "version": "0.2.0"}


def verify_token(request: Request):
def authentication(request: Request):
auth_header = request.headers.get("Authorization", None)
api_key = request.headers.get("X-API-Key", None)

if auth_header and auth_header != "Bearer test-token-123":
if (auth_header and auth_header != "Bearer test-token-123") or (
api_key and api_key != "api-key-123"
):
raise HTTPException(status_code=401, detail="Unauthorized request")

return True


@app.get("/health")
def health(request: Request, authorized: bool = Depends(verify_token)):
def health(request: Request, authorized: bool = Depends(authentication)):
return {"status": "ok"}


@app.post("/v1/predicates/batch-evaluate")
def batch_evaluate(
request: Request, payload: RequestPayload, authorized: bool = Depends(verify_token)
request: Request,
payload: RequestPayload,
authorized: bool = Depends(authentication),
):
responses: list[dict] = []
for item in payload.items:
Expand Down
100 changes: 92 additions & 8 deletions tests/test_mock_mask_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

@pytest.mark.server
@pytest.mark.parametrize(
"token, batch, answer",
"token, api_key, batch, answer",
[
pytest.param(
None,
None,
[
MaskServerInput(
Expand Down Expand Up @@ -81,6 +82,7 @@
id="alternated_supported_response",
),
pytest.param(
None,
None,
[
MaskServerInput(
Expand All @@ -102,6 +104,7 @@
id="single_supported_response",
),
pytest.param(
None,
None,
[
MaskServerInput(
Expand All @@ -120,6 +123,7 @@
),
pytest.param(
"test-token-123",
None,
[
MaskServerInput(
table_path="srv.db.orders",
Expand All @@ -141,8 +145,57 @@
],
id="with_token",
),
pytest.param(
None,
"api-key-123",
[
MaskServerInput(
table_path="srv.db.orders",
column_name="order_date",
expression=["BETWEEN", 3, "__col__", "2025-01-01", "2025-02-01"],
),
],
[
MaskServerOutput(
response_case=MaskServerResponse.IN_ARRAY,
payload=[
"2025-01-01",
"2025-01-02",
"2025-01-03",
"2025-01-04",
"2025-01-05",
],
),
],
id="with_api_key",
),
pytest.param(
"test-token-123",
"api-key-123",
[
MaskServerInput(
table_path="srv.db.orders",
column_name="order_date",
expression=["BETWEEN", 3, "__col__", "2025-01-01", "2025-02-01"],
),
],
[
MaskServerOutput(
response_case=MaskServerResponse.IN_ARRAY,
payload=[
"2025-01-01",
"2025-01-02",
"2025-01-03",
"2025-01-04",
"2025-01-05",
],
),
],
id="with_token_and_api_key",
),
pytest.param(
"test-token-123",
None,
[
MaskServerInput(
table_path="srv.db.tbl",
Expand All @@ -159,6 +212,7 @@
id="booleans",
),
pytest.param(
None,
None,
[
MaskServerInput(
Expand Down Expand Up @@ -196,6 +250,7 @@
id="decimal_and_money",
),
pytest.param(
None,
None,
[
MaskServerInput(
Expand Down Expand Up @@ -232,7 +287,8 @@
],
)
def test_mock_mask_server(
token: str,
token: str | None,
api_key: str | None,
batch: list[MaskServerInput],
answer: list[MaskServerOutput],
mock_server_setup,
Expand All @@ -243,7 +299,7 @@ def test_mock_mask_server(
"""

mask_server: MaskServerInfo = MaskServerInfo(
base_url="http://localhost:8000", token=token
base_url="http://localhost:8000", token=token, api_key=api_key
)

# Doing the request
Expand All @@ -258,18 +314,18 @@ def test_mock_mask_server(

@pytest.mark.server
@pytest.mark.parametrize(
"base_url, token, batch, error_msg",
"token, api_key, batch, error_msg",
[
pytest.param(
"http://localhost:8000",
None,
None,
[],
"Batch cannot be empty.",
id="empty_list_request",
),
pytest.param(
"http://localhost:8000",
"bad_token_123",
None,
[
MaskServerInput(
table_path="srv.db.tbl",
Expand All @@ -280,11 +336,37 @@ def test_mock_mask_server(
"Bad response 401: Unauthorized request",
id="wrong_token",
),
pytest.param(
None,
"bad_api_key_123",
[
MaskServerInput(
table_path="srv.db.tbl",
column_name="col",
expression=["OR", 2, "__col__", 5],
)
],
"Bad response 401: Unauthorized request",
id="wrong_api_key",
),
pytest.param(
"bad_token_123",
"bad_api-key-123",
[
MaskServerInput(
table_path="srv.db.tbl",
column_name="col",
expression=["OR", 2, "__col__", 5],
)
],
"Bad response 401: Unauthorized request",
id="wrong_token_and_api_key",
),
],
)
def test_mock_mask_server_errors(
base_url: str,
token: str | None,
api_key: str | None,
batch: list[MaskServerInput],
error_msg: str,
mock_server_setup,
Expand All @@ -293,7 +375,9 @@ def test_mock_mask_server_errors(
Testing that the MaskServer raises an exception with the expected error message
"""
with pytest.raises(Exception, match=re.escape(error_msg)):
mask_server: MaskServerInfo = MaskServerInfo(base_url=base_url, token=token)
mask_server: MaskServerInfo = MaskServerInfo(
base_url="http://localhost:8000", token=token, api_key=api_key
)
mask_server.connection.set_timeout(0.5)

# Doing the request
Expand Down