diff --git a/encord_agents/tasks/runner/sequential_runner.py b/encord_agents/tasks/runner/sequential_runner.py index 034460b7..61345c9b 100644 --- a/encord_agents/tasks/runner/sequential_runner.py +++ b/encord_agents/tasks/runner/sequential_runner.py @@ -1,7 +1,7 @@ -import logging import os import time import traceback +from concurrent.futures import ThreadPoolExecutor from contextlib import ExitStack from datetime import datetime, timedelta from typing import Callable, Iterable, Optional @@ -10,7 +10,7 @@ import rich from encord.http.bundle import Bundle from encord.orm.workflow import WorkflowStageType -from encord.workflow.stages.agent import AgentStage +from encord.workflow.stages.agent import AgentStage, AgentTask from rich.live import Live from rich.panel import Panel from rich.progress import ( @@ -28,9 +28,8 @@ from encord_agents.core.data_model import LabelRowInitialiseLabelsArgs, LabelRowMetadataIncludeArgs from encord_agents.core.dependencies.models import ( Context, - Dependant, ) -from encord_agents.core.dependencies.utils import get_dependant, solve_dependencies +from encord_agents.core.dependencies.utils import SolvedDependency, solve_dependencies from encord_agents.core.rich_columns import TaskSpeedColumn from encord_agents.core.utils import batch_iterator from encord_agents.exceptions import PrintableError @@ -208,66 +207,94 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable: return decorator @staticmethod + def _execute_task_with_dependencies( + context: Context, + runner_agent: RunnerAgent, + num_retries: int, + dependencies: SolvedDependency, + task: AgentTask, + stage: AgentStage, + task_bundle: Bundle, + label_bundle: Bundle, + ) -> bool: + for attempt in range(num_retries + 1): + try: + agent_response: TaskAgentReturnType = runner_agent.callable(**dependencies.values) + if isinstance(agent_response, TaskAgentReturnStruct): + pathway_to_follow = agent_response.pathway + if agent_response.label_row: + agent_response.label_row.save(bundle=label_bundle) + if agent_response.label_row_priority: + assert ( + context.label_row is not None + ), f"Label row is not set for task {task} setting the priority requires either setting the `will_set_priority` to True on the stage decorator or depending on the label row." + context.label_row.set_priority(agent_response.label_row_priority, bundle=label_bundle) + else: + pathway_to_follow = agent_response + if pathway_to_follow is None: + pass + elif next_stage_uuid := try_coerce_UUID(pathway_to_follow): + if next_stage_uuid not in [pathway.uuid for pathway in stage.pathways]: + raise PrintableError( + f"No pathway with UUID: {next_stage_uuid} found. Accepted pathway UUIDs are: {[pathway.uuid for pathway in stage.pathways]}" + ) + task.proceed(pathway_uuid=str(next_stage_uuid), bundle=task_bundle) + else: + if pathway_to_follow not in [str(pathway.name) for pathway in stage.pathways]: + raise PrintableError( + f"No pathway with name: {pathway_to_follow} found. Accepted pathway names are: {[pathway.name for pathway in stage.pathways]}" + ) + task.proceed(pathway_name=str(pathway_to_follow), bundle=task_bundle) + return True + except KeyboardInterrupt: + raise + except PrintableError: + raise + except Exception: + print(f"[attempt {attempt+1}/{num_retries+1}] Agent failed with error: ") + traceback.print_exc() + return False + return False + def _execute_tasks( + self, contexts: Iterable[Context], runner_agent: RunnerAgent, stage: AgentStage, num_retries: int, - pbar_update: Callable[[float | None], bool | None] | None = None, + pbar_update: Callable[[float | None], bool | None], + pre_fetch_factor: int, ) -> None: """ INVARIANT: Tasks should always be for the stage that the runner_agent is associated too """ with Bundle() as task_bundle: with Bundle(bundle_size=min(MAX_LABEL_ROW_BATCH_SIZE, len(list(contexts)))) as label_bundle: - for context in contexts: - assert context.task + for batch in batch_iterator(contexts, pre_fetch_factor): with ExitStack() as stack: - task = context.task - dependencies = solve_dependencies( - context=context, dependant=runner_agent.dependant, stack=stack - ) - for attempt in range(num_retries + 1): - try: - agent_response: TaskAgentReturnType = runner_agent.callable(**dependencies.values) - if isinstance(agent_response, TaskAgentReturnStruct): - pathway_to_follow = agent_response.pathway - if agent_response.label_row: - agent_response.label_row.save(bundle=label_bundle) - if agent_response.label_row_priority: - assert ( - context.label_row is not None - ), f"Label row is not set for task {task} setting the priority requires either setting the `will_set_priority` to True on the stage decorator or depending on the label row." - context.label_row.set_priority( - agent_response.label_row_priority, bundle=label_bundle - ) - else: - pathway_to_follow = agent_response - if pathway_to_follow is None: - pass - elif next_stage_uuid := try_coerce_UUID(pathway_to_follow): - if next_stage_uuid not in [pathway.uuid for pathway in stage.pathways]: - raise PrintableError( - f"No pathway with UUID: {next_stage_uuid} found. Accepted pathway UUIDs are: {[pathway.uuid for pathway in stage.pathways]}" - ) - task.proceed(pathway_uuid=str(next_stage_uuid), bundle=task_bundle) - else: - if pathway_to_follow not in [str(pathway.name) for pathway in stage.pathways]: - raise PrintableError( - f"No pathway with name: {pathway_to_follow} found. Accepted pathway names are: {[pathway.name for pathway in stage.pathways]}" - ) - task.proceed(pathway_name=str(pathway_to_follow), bundle=task_bundle) - if pbar_update is not None: - pbar_update(1.0) - break - - except KeyboardInterrupt: - raise - except PrintableError: - raise - except Exception: - print(f"[attempt {attempt+1}/{num_retries+1}] Agent failed with error: ") - traceback.print_exc() + with ThreadPoolExecutor() as executor: + dependency_list = list( + executor.map( + lambda context: solve_dependencies( + context=context, dependant=runner_agent.dependant, stack=stack + ), + batch, + ) + ) + for context, dependency in zip(batch, dependency_list, strict=True): + assert context.task + task = context.task + if self._execute_task_with_dependencies( + context, + runner_agent, + num_retries, + dependency, + task, + stage, + task_bundle, + label_bundle, + ): + pbar_update(1.0) def _validate_agent_stages( self, valid_stages: list[AgentStage], agent_stages: dict[str | UUID, AgentStage] @@ -336,6 +363,9 @@ def __call__( help="Max number of tasks to try to process per stage on a given run. If `None`, will attempt all", ), ] = None, + pre_fetch_factor: Annotated[ + int, Option(help="Number of tasks to pre-fetch for dependency resolution at once.") + ] = 1, ) -> None: """ Run your task agent `runner(...)`. @@ -472,6 +502,7 @@ def __call__( stage, num_retries, pbar_update=lambda x: batch_pbar.advance(batch_task, x or 1), + pre_fetch_factor=pre_fetch_factor, ) total += len(task_batch) diff --git a/encord_agents/utils/generic_utils.py b/encord_agents/utils/generic_utils.py index 2b24906e..602fcaee 100644 --- a/encord_agents/utils/generic_utils.py +++ b/encord_agents/utils/generic_utils.py @@ -1,3 +1,5 @@ +import itertools +from typing import Iterable, TypeVar from uuid import UUID @@ -9,3 +11,16 @@ def try_coerce_UUID(candidate_uuid: str | UUID) -> UUID | None: return UUID(candidate_uuid) except ValueError: return None + + +T = TypeVar("T") + + +def batch_iterator(iterable: Iterable[T], batch_size: int) -> Iterable[list[T]]: + """Batch an iterable into a list of lists""" + iterator = iter(iterable) + while True: + batch = list(itertools.islice(iterator, batch_size)) + if not batch: + break + yield batch diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 4fad12bf..f2a40d1c 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -1,6 +1,5 @@ from typing import Annotated -from cv2 import meanShift from encord.user_client import EncordUserClient from fastapi import Depends from fastapi.testclient import TestClient