diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py new file mode 100644 index 000000000..e048ab387 --- /dev/null +++ b/tests/unit/test_client.py @@ -0,0 +1,125 @@ +import uuid + +import pytest +from fastapi import APIRouter, FastAPI +from fastapi.testclient import TestClient + +from tracecat.auth.credentials import RoleACL +from tracecat.logger import logger +from tracecat.types.auth import AccessLevel, Role + +WORKSPACE_ID = uuid.uuid4() +USER_ID = uuid.uuid4() +SERVICE_ID = "tracecat-api" + + +router = APIRouter() + + +@router.get("/test") +async def test(role: Role = RoleACL(allow_user=True, allow_service=True)): + return {"test": "test", "role": role.model_dump()} + + +@pytest.fixture +def test_client(test_role): + """Create a FastAPI test client with our router and mocked RoleACL.""" + app = FastAPI() + + app.include_router(router) + return TestClient(app) + + +def test_endpoint_can_be_hit(test_client: TestClient): + """ + Test that the /test endpoint is accessible and returns the expected response. + + Args: + test_client: Pytest fixture that provides a configured TestClient + """ + response = test_client.get("/test") + + assert response.status_code == 200 + assert response.json() == { + "test": "test", + "role": { + "type": "user", + "access_level": 0, + "workspace_id": str(WORKSPACE_ID), + "user_id": str(USER_ID), + "service_id": SERVICE_ID, + }, + } + + +@pytest.fixture +def override_test_role(): + """ + Fixture that creates a different test role for dependency override testing. + + + Role: A test role with modified values + """ + + async def _override_test_role(*args, **kwargs): + return Role( + type="service", + workspace_id=WORKSPACE_ID, + user_id=USER_ID, + service_id=SERVICE_ID, + access_level=AccessLevel.ADMIN, + ) + + return _override_test_role + + +@pytest.fixture +def test_client_with_override( + monkeypatch: pytest.MonkeyPatch, override_test_role: Role +): + """ + Create a FastAPI test client with our router and overridden role dependency. + + Args: + override_test_role: Pytest fixture providing the override role + + Returns: + TestClient: Configured test client with dependency override + """ + + app = FastAPI() + app.include_router(router) + + # Override the role dependency with our custom role + monkeypatch.setattr( + "tracecat.auth.credentials._role_dependency", + override_test_role, + ) + + return TestClient(app) + + +def test_endpoint_access_with_override(test_client_with_override: TestClient): + """ + Test that the /test endpoint uses the overridden role dependency. + + Args: + test_client_with_override: Pytest fixture that provides a TestClient with overridden dependencies + """ + response = test_client_with_override.get( + "/test", params={"workspace_id": str(WORKSPACE_ID)} + ) + details = response.json() + logger.info("RESPONSE", response=response, details=details) + + assert response.status_code == 200 + assert response.json() == { + "test": "test", + "role": { + "type": "service", + "access_level": AccessLevel.ADMIN, + "workspace_id": str(WORKSPACE_ID), + "user_id": str(USER_ID), + "service_id": SERVICE_ID, + }, + } diff --git a/tests/unit/test_dependencies.py b/tests/unit/test_dependencies.py new file mode 100644 index 000000000..44a5acde1 --- /dev/null +++ b/tests/unit/test_dependencies.py @@ -0,0 +1,121 @@ +import uuid +from collections.abc import AsyncGenerator + +import pytest +from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient + +from tracecat.contexts import RunContext +from tracecat.dsl.common import create_default_execution_context +from tracecat.dsl.models import ActionStatement, RunActionInput +from tracecat.executor.router import router +from tracecat.logger import logger +from tracecat.registry.actions.models import RegistryActionErrorInfo +from tracecat.types.auth import Role +from tracecat.types.exceptions import WrappedExecutionError + +# Unit tests + + +@pytest.fixture +def mock_workspace_id(): + return uuid.uuid4() + + +@pytest.fixture +def mock_role(mock_org_id: uuid.UUID, mock_workspace_id: uuid.UUID): + return Role( + type="service", + user_id=mock_org_id, + workspace_id=mock_workspace_id, + service_id="tracecat-runner", + ) + + +@pytest.fixture +def override_role_dependency(mock_role: Role): + async def dep(*args, **kwargs): + return mock_role + + return dep + + +@pytest.fixture +async def test_client_noauth( + monkeypatch: pytest.MonkeyPatch, override_role_dependency: Role +) -> AsyncGenerator[AsyncClient, None]: + app = FastAPI() + app.include_router(router) + # Override the role dependency with our custom role + monkeypatch.setattr( + "tracecat.auth.credentials._role_dependency", override_role_dependency + ) + + host = "localhost" + port = 8000 + async with AsyncClient( + transport=ASGITransport(app=app, client=(host, port)), # type: ignore + base_url="http://test", + ) as client: + yield client + + +async def mock_dispatch_error(*args, **kwargs): + """Mock dispatch that raises a WrappedExecutionError""" + error_info = RegistryActionErrorInfo( + type="ValueError", + message="Test error message", + action_name="test.action", + filename="test_file.py", + function="test_function", + ) + raise WrappedExecutionError(error=error_info) + + +@pytest.mark.anyio +async def test_run_action_endpoint_error_handling( + test_client_noauth: AsyncClient, + mock_run_context: RunContext, + mock_role: Role, + monkeypatch: pytest.MonkeyPatch, +): + """Test that the run_action endpoint properly handles wrapped execution errors.""" + + # Mock the dispatch function to raise our error + if not mock_role.workspace_id: + return pytest.fail("Workspace ID is not set in test role") + + monkeypatch.setattr( + "tracecat.executor.service.dispatch_action_on_cluster", mock_dispatch_error + ) + + # Create test input + input_data = RunActionInput( + task=ActionStatement( + ref="test", + action="test.action", + args={}, + run_if=None, + for_each=None, + ), + exec_context=create_default_execution_context(), + run_context=mock_run_context, + ).model_dump(mode="json") + + logger.warning("BASE URL", base_url=test_client_noauth.base_url) + response = await test_client_noauth.post( + "/run/test.action", + params={"workspace_id": str(mock_role.workspace_id)}, + json=input_data, + ) + + # Verify response + logger.info("RESPONSE", response=response) + assert response.status_code == 500 + error_detail = response.json() + err_info = RegistryActionErrorInfo.model_validate(error_detail["detail"]) + assert err_info.type == "ValueError" + assert err_info.message == "Test error message" + assert err_info.action_name == "test.action" + assert err_info.filename == "test_file.py" + assert err_info.function == "mock_dispatch_error"