135 changes: 83 additions & 52 deletions encord_agents/tasks/runner/sequential_runner.py
@@ -1,7 +1,7 @@
import logging
import os
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack
from datetime import datetime, timedelta
from typing import Callable, Iterable, Optional
@@ -10,7 +10,7 @@
import rich
from encord.http.bundle import Bundle
from encord.orm.workflow import WorkflowStageType
from encord.workflow.stages.agent import AgentStage
from encord.workflow.stages.agent import AgentStage, AgentTask
from rich.live import Live
from rich.panel import Panel
from rich.progress import (
@@ -28,9 +28,8 @@
from encord_agents.core.data_model import LabelRowInitialiseLabelsArgs, LabelRowMetadataIncludeArgs
from encord_agents.core.dependencies.models import (
Context,
Dependant,
)
from encord_agents.core.dependencies.utils import get_dependant, solve_dependencies
from encord_agents.core.dependencies.utils import SolvedDependency, solve_dependencies
from encord_agents.core.rich_columns import TaskSpeedColumn
from encord_agents.core.utils import batch_iterator
from encord_agents.exceptions import PrintableError
@@ -208,66 +207,94 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable:
return decorator

@staticmethod
def _execute_task_with_dependencies(
context: Context,
runner_agent: RunnerAgent,
num_retries: int,
dependencies: SolvedDependency,
task: AgentTask,
stage: AgentStage,
task_bundle: Bundle,
label_bundle: Bundle,
) -> bool:
for attempt in range(num_retries + 1):
try:
agent_response: TaskAgentReturnType = runner_agent.callable(**dependencies.values)
if isinstance(agent_response, TaskAgentReturnStruct):
pathway_to_follow = agent_response.pathway
if agent_response.label_row:
agent_response.label_row.save(bundle=label_bundle)
if agent_response.label_row_priority:
assert (
context.label_row is not None
), f"Label row is not set for task {task} setting the priority requires either setting the `will_set_priority` to True on the stage decorator or depending on the label row."
context.label_row.set_priority(agent_response.label_row_priority, bundle=label_bundle)
else:
pathway_to_follow = agent_response
if pathway_to_follow is None:
pass
elif next_stage_uuid := try_coerce_UUID(pathway_to_follow):
if next_stage_uuid not in [pathway.uuid for pathway in stage.pathways]:
raise PrintableError(
f"No pathway with UUID: {next_stage_uuid} found. Accepted pathway UUIDs are: {[pathway.uuid for pathway in stage.pathways]}"
)
task.proceed(pathway_uuid=str(next_stage_uuid), bundle=task_bundle)
else:
if pathway_to_follow not in [str(pathway.name) for pathway in stage.pathways]:
raise PrintableError(
f"No pathway with name: {pathway_to_follow} found. Accepted pathway names are: {[pathway.name for pathway in stage.pathways]}"
)
task.proceed(pathway_name=str(pathway_to_follow), bundle=task_bundle)
return True
except KeyboardInterrupt:
raise
except PrintableError:
raise
except Exception:
print(f"[attempt {attempt+1}/{num_retries+1}] Agent failed with error: ")
traceback.print_exc()
return False
return False

def _execute_tasks(
self,
contexts: Iterable[Context],
runner_agent: RunnerAgent,
stage: AgentStage,
num_retries: int,
pbar_update: Callable[[float | None], bool | None] | None = None,
pbar_update: Callable[[float | None], bool | None],
pre_fetch_factor: int,
) -> None:
"""
INVARIANT: Tasks should always be for the stage that the runner_agent is associated with
"""
with Bundle() as task_bundle:
with Bundle(bundle_size=min(MAX_LABEL_ROW_BATCH_SIZE, len(list(contexts)))) as label_bundle:
for context in contexts:
assert context.task
for batch in batch_iterator(contexts, pre_fetch_factor):
with ExitStack() as stack:
task = context.task
dependencies = solve_dependencies(
context=context, dependant=runner_agent.dependant, stack=stack
)
for attempt in range(num_retries + 1):
try:
agent_response: TaskAgentReturnType = runner_agent.callable(**dependencies.values)
if isinstance(agent_response, TaskAgentReturnStruct):
pathway_to_follow = agent_response.pathway
if agent_response.label_row:
agent_response.label_row.save(bundle=label_bundle)
if agent_response.label_row_priority:
assert (
context.label_row is not None
), f"Label row is not set for task {task} setting the priority requires either setting the `will_set_priority` to True on the stage decorator or depending on the label row."
context.label_row.set_priority(
agent_response.label_row_priority, bundle=label_bundle
)
else:
pathway_to_follow = agent_response
if pathway_to_follow is None:
pass
elif next_stage_uuid := try_coerce_UUID(pathway_to_follow):
if next_stage_uuid not in [pathway.uuid for pathway in stage.pathways]:
raise PrintableError(
f"No pathway with UUID: {next_stage_uuid} found. Accepted pathway UUIDs are: {[pathway.uuid for pathway in stage.pathways]}"
)
task.proceed(pathway_uuid=str(next_stage_uuid), bundle=task_bundle)
else:
if pathway_to_follow not in [str(pathway.name) for pathway in stage.pathways]:
raise PrintableError(
f"No pathway with name: {pathway_to_follow} found. Accepted pathway names are: {[pathway.name for pathway in stage.pathways]}"
)
task.proceed(pathway_name=str(pathway_to_follow), bundle=task_bundle)
if pbar_update is not None:
pbar_update(1.0)
break

except KeyboardInterrupt:
raise
except PrintableError:
raise
except Exception:
print(f"[attempt {attempt+1}/{num_retries+1}] Agent failed with error: ")
traceback.print_exc()
with ThreadPoolExecutor() as executor:
Contributor: Should we not specify how many threads (perhaps even pass it as an argument)?

Contributor (author): We could do. By default, ThreadPoolExecutor picks the thread count based on the number of CPUs. I didn't necessarily want to expand the interface too much, but I definitely see your point.
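
A minimal sketch of what exposing the thread count could look like, assuming a hypothetical max_workers option forwarded to the executor (not part of this PR):

from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def resolve_concurrently(items: Iterable[T], fn: Callable[[T], R], max_workers: int | None = None) -> list[R]:
    # max_workers=None falls back to ThreadPoolExecutor's default, which is derived from the CPU count.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        return list(executor.map(fn, items))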

dependency_list = list(
executor.map(
lambda context: solve_dependencies(
context=context, dependant=runner_agent.dependant, stack=stack
),
batch,
)
)
Comment on lines +275 to +283

Contributor: I don't see why we'd want to wait for all tasks to be fetched before starting the compute. Can't we just call the agent while iterating the output of the executor.map call rather than wrapping it in a list?

Contributor (author): One could. I was just replicating the customer's original behaviour, where all fetching happens first, ahead of task execution. Notably, the customer wanted explicitly sequential inference; if the map also included the agent execution, we would lose that benefit.
for context, dependency in zip(batch, dependency_list, strict=True):
assert context.task
task = context.task
if self._execute_task_with_dependencies(
context,
runner_agent,
num_retries,
dependency,
task,
stage,
task_bundle,
label_bundle,
):
pbar_update(1.0)

def _validate_agent_stages(
self, valid_stages: list[AgentStage], agent_stages: dict[str | UUID, AgentStage]
@@ -336,6 +363,9 @@ def __call__(
help="Max number of tasks to try to process per stage on a given run. If `None`, will attempt all",
),
] = None,
pre_fetch_factor: Annotated[
Contributor: Not sure the name and functionality match here. A factor is something you multiply by; this seems to be an absolute number, at least based on the docstring.

Contributor (author): It could be "pre-fetch batch size". My view on "factor" was: at a value of x, we perform a (grouped) dependency fetch N/x times rather than N times. (A short illustration of this counting follows the hunk below.)

int, Option(help="Number of tasks to pre-fetch for dependency resolution at once.")
] = 1,
) -> None:
"""
Run your task agent `runner(...)`.
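
To make the counting in the pre_fetch_factor discussion concrete, a tiny illustration with hypothetical numbers:

import math

N, x = 100, 8                    # hypothetical: 100 queued tasks, pre-fetch 8 at a time
fetch_rounds = math.ceil(N / x)  # 13 grouped dependency fetches instead of 100 individual ones
print(fetch_rounds)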
@@ -472,6 +502,7 @@ def __call__(
stage,
num_retries,
pbar_update=lambda x: batch_pbar.advance(batch_task, x or 1),
pre_fetch_factor=pre_fetch_factor,
)
total += len(task_batch)

15 changes: 15 additions & 0 deletions encord_agents/utils/generic_utils.py
@@ -1,3 +1,5 @@
import itertools
from typing import Iterable, TypeVar
from uuid import UUID


@@ -9,3 +11,16 @@ def try_coerce_UUID(candidate_uuid: str | UUID) -> UUID | None:
return UUID(candidate_uuid)
except ValueError:
return None


T = TypeVar("T")


def batch_iterator(iterable: Iterable[T], batch_size: int) -> Iterable[list[T]]:
"""Batch an iterable into a list of lists"""
iterator = iter(iterable)
while True:
batch = list(itertools.islice(iterator, batch_size))
if not batch:
break
yield batch
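
A quick usage check of the helper above (illustrative):

from encord_agents.utils.generic_utils import batch_iterator

# Seven items in batches of three; the final batch is shorter.
assert list(batch_iterator(range(7), 3)) == [[0, 1, 2], [3, 4, 5], [6]]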
1 change: 0 additions & 1 deletion tests/test_fastapi.py
@@ -1,6 +1,5 @@
from typing import Annotated

from cv2 import meanShift
from encord.user_client import EncordUserClient
from fastapi import Depends
from fastapi.testclient import TestClient