Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(engine): Use parsed workflow trigger inputs #652

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions tests/unit/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from tracecat.secrets.models import SecretCreate, SecretKeyValue
from tracecat.secrets.service import SecretsService
from tracecat.types.auth import Role
from tracecat.types.exceptions import TracecatValidationError
from tracecat.workflow.management.definitions import WorkflowDefinitionsService
from tracecat.workflow.management.management import WorkflowsManagementService

Expand Down Expand Up @@ -2065,3 +2066,175 @@ async def test_workflow_runs_template_for_each(
),
)
assert result == [101, 102, 103, 104, 105]


@pytest.mark.anyio
@pytest.mark.parametrize(
"trigger_inputs,expected_result,should_raise",
[
# Test case 1: All required fields with valid values
(
{"user_id": "[email protected]", "priority": "high", "count": 5},
{
"ACTIONS": {},
"INPUTS": {},
"TRIGGER": {
"user_id": "[email protected]",
"priority": "high",
"count": 5,
},
},
False,
),
# Test case 2: Using default values
(
{"user_id": "[email protected]"},
{
"ACTIONS": {},
"INPUTS": {},
"TRIGGER": {
"user_id": "[email protected]",
"priority": "low",
"count": 10,
},
},
False,
),
# Test case 3: Invalid enum value
(
{"user_id": "[email protected]", "priority": "INVALID"},
None,
True,
),
# Test case 4: Missing required field
(
{"priority": "medium", "count": 3},
None,
True,
),
# Test case 5: No expects defined
(
{"any": "value"},
{
"ACTIONS": {},
"INPUTS": {},
"TRIGGER": {"any": "value"},
},
False,
),
],
ids=[
"valid_all_fields",
"with_defaults",
"invalid_enum",
"missing_required",
"no_expects",
],
)
async def test_workflow_trigger_validation(
trigger_inputs, expected_result, should_raise, test_role, temporal_client
):
"""Test workflow trigger input validation.

This test verifies that:
1. Required trigger inputs are properly validated
2. Default values are correctly applied
3. Enum values are properly validated
4. Missing required fields are rejected
5. Workflows with no expects defined accept any trigger inputs
"""
test_name = f"{test_workflow_trigger_validation.__name__}"
wf_exec_id = generate_test_exec_id(test_name)

# Base DSL with validation
dsl_with_validation = {
"title": "Test Workflow Trigger Validation",
"description": "Test workflow with trigger input validation",
"entrypoint": {
"expects": {
"user_id": {
"type": "str",
"description": "User identifier",
},
"priority": {
"type": 'enum["low", "medium", "high"]',
"description": "Task priority level",
"default": "low",
},
"count": {
"type": "int",
"description": "Number of items",
"default": 10,
},
},
"ref": "start",
},
"actions": [
{
"ref": "start",
"action": "core.transform.reshape",
"args": {"value": "START"},
},
],
"inputs": {},
"returns": None,
"tests": [],
"triggers": [],
}

# DSL without expects for the "no_expects" test case
dsl_without_validation = {
**dsl_with_validation,
"entrypoint": {"ref": "start"},
}

# Use appropriate DSL based on test case
dsl = DSLInput(
**(
dsl_without_validation
if trigger_inputs.get("any") == "value"
else dsl_with_validation
)
)

run_args = DSLRunArgs(
dsl=dsl,
role=test_role,
wf_id=TEST_WF_ID,
trigger_inputs=trigger_inputs,
)

if should_raise:
with pytest.raises(TracecatValidationError) as exc_info:
async with Worker(
temporal_client,
task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"],
activities=DSLActivities.load() + DSL_UTILITIES,
workflows=[DSLWorkflow],
workflow_runner=new_sandbox_runner(),
):
await temporal_client.execute_workflow(
DSLWorkflow.run,
run_args,
id=wf_exec_id,
task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"],
retry_policy=retry_policies["workflow:fail_fast"],
)
# Verify that it's a validation error
assert "ValidationError" in str(exc_info.value)
else:
async with Worker(
temporal_client,
task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"],
activities=DSLActivities.load() + DSL_UTILITIES,
workflows=[DSLWorkflow],
workflow_runner=new_sandbox_runner(),
):
result = await temporal_client.execute_workflow(
DSLWorkflow.run,
run_args,
id=wf_exec_id,
task_queue=os.environ["TEMPORAL__CLUSTER_QUEUE"],
retry_policy=retry_policies["workflow:fail_fast"],
)
assert result == expected_result
11 changes: 8 additions & 3 deletions tracecat/dsl/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def validate_trigger_inputs(
}
if isinstance(payload, dict):
# NOTE: We only validate dict payloads for now
validator = create_expectation_model(expects_schema, model_name=model_name)
model = create_expectation_model(expects_schema, model_name=model_name)
try:
validator(**payload)
validated_payload = model(**payload).model_dump(mode="json")
except ValidationError as e:
if raise_exceptions:
raise
Expand All @@ -44,7 +44,12 @@ def validate_trigger_inputs(
msg=f"Validation error in trigger inputs ({e.title}). Please refer to the schema for more details.",
detail={"errors": e.errors()},
)
return ValidationResult(status="success", msg="Trigger inputs are valid.")
result = ValidationResult(
status="success",
msg="Trigger inputs are valid.",
payload=validated_payload,
)
return result


class ValidateTriggerInputsActivityInputs(BaseModel):
Expand Down
3 changes: 2 additions & 1 deletion tracecat/dsl/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,11 @@ async def run(self, args: DSLRunArgs) -> Any:
) from e

# Prepare user facing context
validated_payload = validation_result.payload or {}
self.context: ExecutionContext = {
ExprContext.ACTIONS: {},
ExprContext.INPUTS: self.dsl.inputs,
ExprContext.TRIGGER: trigger_inputs,
ExprContext.TRIGGER: validated_payload,
ExprContext.ENV: DSLEnvironment(
workflow={
"start_time": wf_info.start_time,
Expand Down
1 change: 1 addition & 0 deletions tracecat/validation/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class ValidationResult(BaseModel):
msg: str = ""
detail: Any | None = None
ref: str | None = None
payload: dict[str, Any] | None = None

def __hash__(self) -> int:
detail = json.dumps(self.detail, sort_keys=True)
Expand Down
Loading