Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions autobot-backend/orchestration/dag_executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pytest

from constants.status_enums import TaskStatus
from orchestration.dag_executor import (
DAGExecutionContext,
DAGExecutor,
Expand Down Expand Up @@ -182,7 +183,7 @@ async def test_single_node(self):
dag = WorkflowDAG(nodes, [])
executor = DAGExecutor(_noop_executor)
ctx = await executor.execute(dag, "wf1")
assert ctx.status == "completed"
assert ctx.status == TaskStatus.COMPLETED.value
assert "a" in ctx.step_results

@pytest.mark.asyncio
Expand All @@ -191,7 +192,7 @@ async def test_linear_three_nodes(self):
dag = WorkflowDAG(nodes, _linear_edges("a", "b", "c"))
executor = DAGExecutor(_noop_executor)
ctx = await executor.execute(dag, "wf2")
assert ctx.status == "completed"
assert ctx.status == TaskStatus.COMPLETED.value
assert set(ctx.step_results.keys()) == {"a", "b", "c"}

@pytest.mark.asyncio
Expand All @@ -210,15 +211,15 @@ async def test_cycle_aborts_immediately(self):
dag = WorkflowDAG(nodes, edges)
executor = DAGExecutor(_noop_executor)
ctx = await executor.execute(dag, "wf_cycle")
assert ctx.status == "failed"
assert ctx.status == TaskStatus.FAILED.value
assert "cycle" in ctx.error.lower()

@pytest.mark.asyncio
async def test_empty_dag_fails(self):
dag = WorkflowDAG([], [])
executor = DAGExecutor(_noop_executor)
ctx = await executor.execute(dag, "wf_empty")
assert ctx.status == "failed"
assert ctx.status == TaskStatus.FAILED.value


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -254,7 +255,7 @@ async def recording_executor(

executor = DAGExecutor(recording_executor)
ctx = await executor.execute(dag, "wf_branch")
assert ctx.status == "completed"
assert ctx.status == TaskStatus.COMPLETED.value
assert "true_branch" in executed
assert "false_branch" not in executed
assert "end" in executed
Expand All @@ -273,7 +274,7 @@ async def recording_executor(

executor = DAGExecutor(recording_executor)
ctx = await executor.execute(dag, "wf_false")
assert ctx.status == "completed"
assert ctx.status == TaskStatus.COMPLETED.value
assert "false_branch" in executed
assert "true_branch" not in executed
assert ctx.branches_taken["cond"] is False
Expand Down Expand Up @@ -321,7 +322,7 @@ async def recording_executor(

executor = DAGExecutor(recording_executor)
ctx = await executor.execute(dag, "wf_fork")
assert ctx.status == "completed"
assert ctx.status == TaskStatus.COMPLETED.value
assert set(executed) == {"root", "branch_a", "branch_b", "end"}

@pytest.mark.asyncio
Expand Down
20 changes: 11 additions & 9 deletions autobot-backend/orchestration/error_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import pytest

from constants.status_enums import TaskStatus

from .error_handler import (
BackoffStrategy,
StepCheckpoint,
Expand Down Expand Up @@ -70,18 +72,18 @@ def test_from_dict_invalid_action_raises(self) -> None:

class TestStepCheckpoint:
def test_round_trip(self) -> None:
cp = StepCheckpoint(step_id="s1", status="completed", output={"success": True})
cp = StepCheckpoint(step_id="s1", status=TaskStatus.COMPLETED.value, output={"success": True})
restored = StepCheckpoint.from_dict(cp.to_dict())
assert restored.step_id == "s1"
assert restored.status == "completed"
assert restored.status == TaskStatus.COMPLETED.value
assert restored.output == {"success": True}

def test_timestamp_is_populated(self) -> None:
cp = StepCheckpoint(step_id="s1", status="completed", output={})
cp = StepCheckpoint(step_id="s1", status=TaskStatus.COMPLETED.value, output={})
assert cp.timestamp != ""

def test_to_dict_is_json_serialisable(self) -> None:
cp = StepCheckpoint(step_id="s1", status="completed", output={"k": "v"})
cp = StepCheckpoint(step_id="s1", status=TaskStatus.COMPLETED.value, output={"k": "v"})
serialised = json.dumps(cp.to_dict())
assert "s1" in serialised

Expand Down Expand Up @@ -119,13 +121,13 @@ def _manager_with_fake_redis(self) -> WorkflowCheckpointManager:
def test_save_and_load(self) -> None:
mgr = self._manager_with_fake_redis()
cp = StepCheckpoint(
step_id="step1", status="completed", output={"success": True}
step_id="step1", status=TaskStatus.COMPLETED.value, output={"success": True}
)
mgr.save("wf-1", cp)

loaded = mgr.load_all("wf-1")
assert "step1" in loaded
assert loaded["step1"].status == "completed"
assert loaded["step1"].status == TaskStatus.COMPLETED.value
assert loaded["step1"].output == {"success": True}

def test_load_empty_when_no_checkpoints(self) -> None:
Expand All @@ -136,14 +138,14 @@ def test_save_multiple_steps(self) -> None:
mgr = self._manager_with_fake_redis()
for i in range(3):
mgr.save(
"wf-2", StepCheckpoint(step_id=f"s{i}", status="completed", output={})
"wf-2", StepCheckpoint(step_id=f"s{i}", status=TaskStatus.COMPLETED.value, output={})
)
loaded = mgr.load_all("wf-2")
assert set(loaded.keys()) == {"s0", "s1", "s2"}

def test_clear_removes_all(self) -> None:
mgr = self._manager_with_fake_redis()
mgr.save("wf-3", StepCheckpoint(step_id="s1", status="completed", output={}))
mgr.save("wf-3", StepCheckpoint(step_id="s1", status=TaskStatus.COMPLETED.value, output={}))
mgr.clear("wf-3")
assert mgr.load_all("wf-3") == {}

Expand All @@ -166,7 +168,7 @@ def test_redis_error_on_save_logged_not_raised(self) -> None:
bad_redis = MagicMock()
bad_redis.hset.side_effect = ConnectionError("Redis down")
mgr._redis = bad_redis
cp = StepCheckpoint(step_id="s1", status="completed", output={})
cp = StepCheckpoint(step_id="s1", status=TaskStatus.COMPLETED.value, output={})
# Must not raise
mgr.save("wf-fail", cp)

Expand Down
8 changes: 5 additions & 3 deletions autobot-backend/orchestration/sub_workflow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import pytest

from constants.status_enums import TaskStatus

from orchestration.sub_workflow import (
MAX_NESTING_DEPTH,
SubWorkflowExecutor,
Expand All @@ -26,12 +28,12 @@ def _make_workflow_executor() -> MagicMock:
"""Return a MagicMock that satisfies WorkflowExecutor's interface."""
executor = MagicMock()
executor.execute_coordinated_workflow = AsyncMock(
return_value={"status": "completed", "step_results": {}}
return_value={"status": TaskStatus.COMPLETED.value, "step_results": {}}
)
return executor


def _make_step_output(data: Dict[str, Any], status: str = "completed") -> StepOutput:
def _make_step_output(data: Dict[str, Any], status: str = TaskStatus.COMPLETED.value) -> StepOutput:
import json

stdout = json.dumps(data)
Expand Down Expand Up @@ -144,7 +146,7 @@ async def test_failed_child_returns_success_false(self):
workflow_def = {"steps": []}
wf_executor = _make_workflow_executor()
wf_executor.execute_coordinated_workflow = AsyncMock(
return_value={"status": "failed", "step_results": {}}
return_value={"status": TaskStatus.FAILED.value, "step_results": {}}
)
fetcher = MagicMock(return_value=workflow_def)
executor = SubWorkflowExecutor(
Expand Down
21 changes: 11 additions & 10 deletions autobot-backend/orchestration/variable_resolver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
from typing import Any, Dict

from constants.status_enums import TaskStatus
from orchestration.variable_resolver import (
StepOutput,
VariableResolver,
Expand All @@ -17,7 +18,7 @@
# ---------------------------------------------------------------------------


def _output(stdout: str = "", status: str = "completed") -> StepOutput:
def _output(stdout: str = "", status: str = TaskStatus.COMPLETED.value) -> StepOutput:
"""Build a StepOutput from raw stdout string."""
parsed: Any = None
if stdout:
Expand All @@ -28,7 +29,7 @@ def _output(stdout: str = "", status: str = "completed") -> StepOutput:
return StepOutput(status=status, stdout=stdout, parsed_json=parsed)


def _json_output(data: Dict[str, Any], status: str = "completed") -> StepOutput:
def _json_output(data: Dict[str, Any], status: str = TaskStatus.COMPLETED.value) -> StepOutput:
"""Build a StepOutput whose parsed_json is *data*."""
stdout = json.dumps(data)
return StepOutput(status=status, stdout=stdout, parsed_json=data)
Expand All @@ -43,15 +44,15 @@ class TestStepOutput:
def test_from_step_result_success(self):
result = {"success": True, "stdout": '{"key": "value"}', "exit_code": 0}
so = StepOutput.from_step_result(result)
assert so.status == "completed"
assert so.status == TaskStatus.COMPLETED.value
assert so.stdout == '{"key": "value"}'
assert so.parsed_json == {"key": "value"}
assert so.metadata["exit_code"] == 0

def test_from_step_result_failure(self):
result = {"success": False, "stdout": "", "error": "boom"}
so = StepOutput.from_step_result(result)
assert so.status == "failed"
assert so.status == TaskStatus.FAILED.value
assert so.parsed_json is None

def test_from_step_result_non_json_stdout(self):
Expand Down Expand Up @@ -103,17 +104,17 @@ class TestVariableResolverStatus:
def setup_method(self):
self.resolver = VariableResolver()
self.outputs = {
"step1": StepOutput(status="completed", stdout=""),
"step2": StepOutput(status="failed", stdout=""),
"step1": StepOutput(status=TaskStatus.COMPLETED.value, stdout=""),
"step2": StepOutput(status=TaskStatus.FAILED.value, stdout=""),
}

def test_status_completed(self):
result = self.resolver.resolve("${steps.step1.status}", self.outputs)
assert result == "completed"
assert result == TaskStatus.COMPLETED.value

def test_status_failed(self):
result = self.resolver.resolve("${steps.step2.status}", self.outputs)
assert result == "failed"
assert result == TaskStatus.FAILED.value

def test_status_in_sentence(self):
result = self.resolver.resolve(
Expand Down Expand Up @@ -203,7 +204,7 @@ def setup_method(self):

def test_two_tokens_in_one_string(self):
outputs = {
"s1": StepOutput(status="completed", stdout=""),
"s1": StepOutput(status=TaskStatus.COMPLETED.value, stdout=""),
"s2": _json_output({"msg": "hello"}),
}
result = self.resolver.resolve(
Expand Down Expand Up @@ -276,7 +277,7 @@ def setup_method(self):
def test_metadata_field_access(self):
outputs = {
"s1": StepOutput(
status="completed",
status=TaskStatus.COMPLETED.value,
stdout="",
metadata={"exit_code": 0, "execution_time": 2.5},
)
Expand Down
Loading