Skip to content

Commit

Permalink
feat+fix(engine)!: Support enums and defaults in Action Templates (#644)
Browse files Browse the repository at this point in the history
Signed-off-by: Chris Lo <[email protected]>
  • Loading branch information
topher-lo authored Dec 23, 2024
1 parent 0d820b9 commit 74bb3d7
Show file tree
Hide file tree
Showing 10 changed files with 401 additions and 132 deletions.
8 changes: 2 additions & 6 deletions registry/tracecat_registry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,8 @@
"Could not import tracecat. Please install `tracecat` to use the registry."
) from None

from tracecat_registry._internal import registry, secrets
from tracecat_registry._internal.exceptions import ( # noqa: E402
RegistryActionError,
RegistryValidationError,
)
from tracecat_registry._internal import exceptions, registry, secrets
from tracecat_registry._internal.exceptions import RegistryActionError
from tracecat_registry._internal.logger import logger
from tracecat_registry._internal.models import RegistrySecret

Expand All @@ -24,6 +21,5 @@
"logger",
"secrets",
"exceptions",
"RegistryValidationError",
"RegistryActionError",
]
11 changes: 0 additions & 11 deletions registry/tracecat_registry/_internal/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Any

from pydantic_core import ValidationError


class TracecatException(Exception):
"""Tracecat generic user-facing exception"""
Expand All @@ -13,12 +11,3 @@ def __init__(self, *args, detail: Any | None = None, **kwargs):

class RegistryActionError(TracecatException):
"""Exception raised when a registry UDF error occurs."""


class RegistryValidationError(TracecatException):
"""Exception raised when a registry validation error occurs."""

def __init__(self, *args, key: str, err: ValidationError | str | None = None):
super().__init__(*args)
self.key = key
self.err = err
86 changes: 85 additions & 1 deletion tests/unit/test_expectation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from datetime import datetime

import lark
import pytest
from pydantic import ValidationError

Expand Down Expand Up @@ -247,4 +248,87 @@ def test_validate_schema_failure():
assert "end_time" in str(e.value)


# ... existing tests ...
@pytest.mark.parametrize(
"status,priority",
[
("PENDING", "low"),
("running", "low"),
("Completed", "low"),
],
)
def test_validate_schema_with_enum(status, priority):
schema = {
"status": {
"type": 'enum["PENDING", "running", "Completed"]',
"description": "The status of the job",
},
"priority": {
"type": 'enum["low", "medium", "high"]',
"description": "The priority level",
"default": "low",
},
}

mapped = {k: ExpectedField(**v) for k, v in schema.items()}
model = create_expectation_model(mapped)

# Test with provided priority
model_instance = model(status=status, priority=priority)
assert model_instance.status.__class__.__name__ == "EnumStatus"
assert model_instance.priority.__class__.__name__ == "EnumPriority"

# Test default priority
model_instance_default = model(status=status)
assert str(model_instance_default.priority) == "low"


@pytest.mark.parametrize(
"schema_def,error_type,error_message",
[
(
{"status": {"type": "enum[]", "description": "Empty enum"}},
lark.exceptions.UnexpectedCharacters,
"No terminal matches ']'",
),
(
{
"status": {
"type": 'enum["Pending", "PENDING"]',
"description": "Duplicate values",
}
},
lark.exceptions.VisitError,
"Duplicate enum value",
),
],
)
def test_validate_schema_with_invalid_enum_definition(
schema_def, error_type, error_message
):
with pytest.raises(error_type, match=error_message):
mapped = {k: ExpectedField(**v) for k, v in schema_def.items()}
create_expectation_model(mapped)


@pytest.mark.parametrize(
"invalid_value",
[
"invalid_status",
"INVALID",
"pending!",
"",
],
)
def test_validate_schema_with_invalid_enum_values(invalid_value):
schema = {
"status": {
"type": 'enum["PENDING", "running", "Completed"]',
"description": "The status of the job",
}
}

mapped = {k: ExpectedField(**v) for k, v in schema.items()}
DynamicModel = create_expectation_model(mapped)

with pytest.raises(ValidationError):
DynamicModel(status=invalid_value)
2 changes: 1 addition & 1 deletion tests/unit/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
import textwrap

import pytest
from tracecat_registry import RegistryValidationError

from tracecat.concurrency import GatheringTaskGroup
from tracecat.registry.actions.models import RegistryActionRead
from tracecat.registry.actions.service import RegistryActionsService
from tracecat.registry.repository import Repository
from tracecat.types.exceptions import RegistryValidationError


@pytest.fixture
Expand Down
Loading

0 comments on commit 74bb3d7

Please sign in to comment.