Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test(engine): Add router tests #685

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import uuid

import pytest
from fastapi import APIRouter, FastAPI
from fastapi.testclient import TestClient

from tracecat.auth.credentials import RoleACL
from tracecat.logger import logger
from tracecat.types.auth import AccessLevel, Role

WORKSPACE_ID = uuid.uuid4()
USER_ID = uuid.uuid4()
SERVICE_ID = "tracecat-api"


router = APIRouter()


@router.get("/test")
async def test(role: Role = RoleACL(allow_user=True, allow_service=True)):
return {"test": "test", "role": role.model_dump()}


@pytest.fixture
def test_client(test_role):
"""Create a FastAPI test client with our router and mocked RoleACL."""
app = FastAPI()

app.include_router(router)
return TestClient(app)


def test_endpoint_can_be_hit(test_client: TestClient):
"""
Test that the /test endpoint is accessible and returns the expected response.

Args:
test_client: Pytest fixture that provides a configured TestClient
"""
response = test_client.get("/test")

assert response.status_code == 200
assert response.json() == {
"test": "test",
"role": {
"type": "user",
"access_level": 0,
"workspace_id": str(WORKSPACE_ID),
"user_id": str(USER_ID),
"service_id": SERVICE_ID,
},
}


@pytest.fixture
def override_test_role():
"""
Fixture that creates a different test role for dependency override testing.


Role: A test role with modified values
"""

async def _override_test_role(*args, **kwargs):
return Role(
type="service",
workspace_id=WORKSPACE_ID,
user_id=USER_ID,
service_id=SERVICE_ID,
access_level=AccessLevel.ADMIN,
)

return _override_test_role


@pytest.fixture
def test_client_with_override(
monkeypatch: pytest.MonkeyPatch, override_test_role: Role
):
"""
Create a FastAPI test client with our router and overridden role dependency.

Args:
override_test_role: Pytest fixture providing the override role

Returns:
TestClient: Configured test client with dependency override
"""

app = FastAPI()
app.include_router(router)

# Override the role dependency with our custom role
monkeypatch.setattr(
"tracecat.auth.credentials._role_dependency",
override_test_role,
)

return TestClient(app)


def test_endpoint_access_with_override(test_client_with_override: TestClient):
"""
Test that the /test endpoint uses the overridden role dependency.

Args:
test_client_with_override: Pytest fixture that provides a TestClient with overridden dependencies
"""
response = test_client_with_override.get(
"/test", params={"workspace_id": str(WORKSPACE_ID)}
)
details = response.json()
logger.info("RESPONSE", response=response, details=details)

assert response.status_code == 200
assert response.json() == {
"test": "test",
"role": {
"type": "service",
"access_level": AccessLevel.ADMIN,
"workspace_id": str(WORKSPACE_ID),
"user_id": str(USER_ID),
"service_id": SERVICE_ID,
},
}
121 changes: 121 additions & 0 deletions tests/unit/test_dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import uuid
from collections.abc import AsyncGenerator

import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient

from tracecat.contexts import RunContext
from tracecat.dsl.common import create_default_execution_context
from tracecat.dsl.models import ActionStatement, RunActionInput
from tracecat.executor.router import router
from tracecat.logger import logger
from tracecat.registry.actions.models import RegistryActionErrorInfo
from tracecat.types.auth import Role
from tracecat.types.exceptions import WrappedExecutionError

# Unit tests


@pytest.fixture
def mock_workspace_id():
return uuid.uuid4()


@pytest.fixture
def mock_role(mock_org_id: uuid.UUID, mock_workspace_id: uuid.UUID):
return Role(
type="service",
user_id=mock_org_id,
workspace_id=mock_workspace_id,
service_id="tracecat-runner",
)


@pytest.fixture
def override_role_dependency(mock_role: Role):
async def dep(*args, **kwargs):
return mock_role

return dep


@pytest.fixture
async def test_client_noauth(
monkeypatch: pytest.MonkeyPatch, override_role_dependency: Role
) -> AsyncGenerator[AsyncClient, None]:
app = FastAPI()
app.include_router(router)
# Override the role dependency with our custom role
monkeypatch.setattr(
"tracecat.auth.credentials._role_dependency", override_role_dependency
)

host = "localhost"
port = 8000
async with AsyncClient(
transport=ASGITransport(app=app, client=(host, port)), # type: ignore
base_url="http://test",
) as client:
yield client


async def mock_dispatch_error(*args, **kwargs):
"""Mock dispatch that raises a WrappedExecutionError"""
error_info = RegistryActionErrorInfo(
type="ValueError",
message="Test error message",
action_name="test.action",
filename="test_file.py",
function="test_function",
)
raise WrappedExecutionError(error=error_info)


@pytest.mark.anyio
async def test_run_action_endpoint_error_handling(
test_client_noauth: AsyncClient,
mock_run_context: RunContext,
mock_role: Role,
monkeypatch: pytest.MonkeyPatch,
):
"""Test that the run_action endpoint properly handles wrapped execution errors."""

# Mock the dispatch function to raise our error
if not mock_role.workspace_id:
return pytest.fail("Workspace ID is not set in test role")

monkeypatch.setattr(
"tracecat.executor.service.dispatch_action_on_cluster", mock_dispatch_error
)

# Create test input
input_data = RunActionInput(
task=ActionStatement(
ref="test",
action="test.action",
args={},
run_if=None,
for_each=None,
),
exec_context=create_default_execution_context(),
run_context=mock_run_context,
).model_dump(mode="json")

logger.warning("BASE URL", base_url=test_client_noauth.base_url)
response = await test_client_noauth.post(
"/run/test.action",
params={"workspace_id": str(mock_role.workspace_id)},
json=input_data,
)

# Verify response
logger.info("RESPONSE", response=response)
assert response.status_code == 500
error_detail = response.json()
err_info = RegistryActionErrorInfo.model_validate(error_detail["detail"])
assert err_info.type == "ValueError"
assert err_info.message == "Test error message"
assert err_info.action_name == "test.action"
assert err_info.filename == "test_file.py"
assert err_info.function == "mock_dispatch_error"
Loading