# feat: Allow prefetching dependencies using a Threadpool #175
```diff
@@ -1,7 +1,7 @@
 import logging
 import os
 import time
 import traceback
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import ExitStack
 from datetime import datetime, timedelta
 from typing import Callable, Iterable, Optional
```
```diff
@@ -10,7 +10,7 @@
 import rich
 from encord.http.bundle import Bundle
 from encord.orm.workflow import WorkflowStageType
-from encord.workflow.stages.agent import AgentStage
+from encord.workflow.stages.agent import AgentStage, AgentTask
 from rich.live import Live
 from rich.panel import Panel
 from rich.progress import (
```
```diff
@@ -28,9 +28,8 @@
 from encord_agents.core.data_model import LabelRowInitialiseLabelsArgs, LabelRowMetadataIncludeArgs
 from encord_agents.core.dependencies.models import (
     Context,
-    Dependant,
 )
-from encord_agents.core.dependencies.utils import get_dependant, solve_dependencies
+from encord_agents.core.dependencies.utils import SolvedDependency, solve_dependencies
 from encord_agents.core.rich_columns import TaskSpeedColumn
 from encord_agents.core.utils import batch_iterator
 from encord_agents.exceptions import PrintableError
```
```diff
@@ -208,66 +207,94 @@ def decorator(func: DecoratedCallable) -> DecoratedCallable:
         return decorator

+    @staticmethod
+    def _execute_task_with_dependencies(
+        context: Context,
+        runner_agent: RunnerAgent,
+        num_retries: int,
+        dependencies: SolvedDependency,
+        task: AgentTask,
+        stage: AgentStage,
+        task_bundle: Bundle,
+        label_bundle: Bundle,
+    ) -> bool:
+        for attempt in range(num_retries + 1):
+            try:
+                agent_response: TaskAgentReturnType = runner_agent.callable(**dependencies.values)
+                if isinstance(agent_response, TaskAgentReturnStruct):
+                    pathway_to_follow = agent_response.pathway
+                    if agent_response.label_row:
+                        agent_response.label_row.save(bundle=label_bundle)
+                    if agent_response.label_row_priority:
+                        assert (
+                            context.label_row is not None
+                        ), f"Label row is not set for task {task} setting the priority requires either setting the `will_set_priority` to True on the stage decorator or depending on the label row."
+                        context.label_row.set_priority(agent_response.label_row_priority, bundle=label_bundle)
+                else:
+                    pathway_to_follow = agent_response
+                if pathway_to_follow is None:
+                    pass
+                elif next_stage_uuid := try_coerce_UUID(pathway_to_follow):
+                    if next_stage_uuid not in [pathway.uuid for pathway in stage.pathways]:
+                        raise PrintableError(
+                            f"No pathway with UUID: {next_stage_uuid} found. Accepted pathway UUIDs are: {[pathway.uuid for pathway in stage.pathways]}"
+                        )
+                    task.proceed(pathway_uuid=str(next_stage_uuid), bundle=task_bundle)
+                else:
+                    if pathway_to_follow not in [str(pathway.name) for pathway in stage.pathways]:
+                        raise PrintableError(
+                            f"No pathway with name: {pathway_to_follow} found. Accepted pathway names are: {[pathway.name for pathway in stage.pathways]}"
+                        )
+                    task.proceed(pathway_name=str(pathway_to_follow), bundle=task_bundle)
+                return True
+            except KeyboardInterrupt:
+                raise
+            except PrintableError:
+                raise
+            except Exception:
+                print(f"[attempt {attempt+1}/{num_retries+1}] Agent failed with error: ")
+                traceback.print_exc()
+        return False
+
     def _execute_tasks(
         self,
         contexts: Iterable[Context],
         runner_agent: RunnerAgent,
         stage: AgentStage,
         num_retries: int,
-        pbar_update: Callable[[float | None], bool | None] | None = None,
+        pbar_update: Callable[[float | None], bool | None],
+        pre_fetch_factor: int,
     ) -> None:
         """
         INVARIANT: Tasks should always be for the stage that the runner_agent is associated too
         """
         with Bundle() as task_bundle:
             with Bundle(bundle_size=min(MAX_LABEL_ROW_BATCH_SIZE, len(list(contexts)))) as label_bundle:
-                for context in contexts:
-                    assert context.task
+                for batch in batch_iterator(contexts, pre_fetch_factor):
                     with ExitStack() as stack:
-                        task = context.task
-                        dependencies = solve_dependencies(
-                            context=context, dependant=runner_agent.dependant, stack=stack
-                        )
-                        for attempt in range(num_retries + 1):
-                            try:
-                                agent_response: TaskAgentReturnType = runner_agent.callable(**dependencies.values)
-                                if isinstance(agent_response, TaskAgentReturnStruct):
-                                    pathway_to_follow = agent_response.pathway
-                                    if agent_response.label_row:
-                                        agent_response.label_row.save(bundle=label_bundle)
-                                    if agent_response.label_row_priority:
-                                        assert (
-                                            context.label_row is not None
-                                        ), f"Label row is not set for task {task} setting the priority requires either setting the `will_set_priority` to True on the stage decorator or depending on the label row."
-                                        context.label_row.set_priority(
-                                            agent_response.label_row_priority, bundle=label_bundle
-                                        )
-                                else:
-                                    pathway_to_follow = agent_response
-                                if pathway_to_follow is None:
-                                    pass
-                                elif next_stage_uuid := try_coerce_UUID(pathway_to_follow):
-                                    if next_stage_uuid not in [pathway.uuid for pathway in stage.pathways]:
-                                        raise PrintableError(
-                                            f"No pathway with UUID: {next_stage_uuid} found. Accepted pathway UUIDs are: {[pathway.uuid for pathway in stage.pathways]}"
-                                        )
-                                    task.proceed(pathway_uuid=str(next_stage_uuid), bundle=task_bundle)
-                                else:
-                                    if pathway_to_follow not in [str(pathway.name) for pathway in stage.pathways]:
-                                        raise PrintableError(
-                                            f"No pathway with name: {pathway_to_follow} found. Accepted pathway names are: {[pathway.name for pathway in stage.pathways]}"
-                                        )
-                                    task.proceed(pathway_name=str(pathway_to_follow), bundle=task_bundle)
-                                if pbar_update is not None:
-                                    pbar_update(1.0)
-                                break
-                            except KeyboardInterrupt:
-                                raise
-                            except PrintableError:
-                                raise
-                            except Exception:
-                                print(f"[attempt {attempt+1}/{num_retries+1}] Agent failed with error: ")
-                                traceback.print_exc()
+                        with ThreadPoolExecutor() as executor:
+                            dependency_list = list(
+                                executor.map(
+                                    lambda context: solve_dependencies(
+                                        context=context, dependant=runner_agent.dependant, stack=stack
+                                    ),
+                                    batch,
+                                )
+                            )
```
**Comment on lines +275 to +283**

**Contributor:** I don't see why we'd want to wait for all tasks to be fetched before starting the compute? Can't we just call the agent when iterating the output of the `executor.map`?

**Contributor (Author):** One could do. I was just replicating the customer's original behaviour, where they do all fetching first, ahead of task execution.
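As a minimal sketch of the streaming alternative raised above (`solve` and `execute` are illustrative placeholders, not the runner's actual API): `executor.map` submits all work up front but yields results lazily in input order, so the agent could run on each task as soon as its dependencies arrive, overlapping compute with the remaining fetches.

```python
from concurrent.futures import ThreadPoolExecutor

def run_batch_streaming(batch, solve, execute):
    # Hypothetical sketch: `solve` fetches dependencies for one task (I/O bound),
    # `execute` runs the agent on (task, dependencies). executor.map submits all
    # fetches immediately but yields results in input order, so early tasks start
    # executing while later fetches are still in flight.
    with ThreadPoolExecutor() as executor:
        for task, dependencies in zip(batch, executor.map(solve, batch)):
            execute(task, dependencies)
```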
```diff
+                        for context, dependency in zip(batch, dependency_list, strict=True):
+                            assert context.task
+                            task = context.task
+                            if self._execute_task_with_dependencies(
+                                context,
+                                runner_agent,
+                                num_retries,
+                                dependency,
+                                task,
+                                stage,
+                                task_bundle,
+                                label_bundle,
+                            ):
+                                pbar_update(1.0)

     def _validate_agent_stages(
         self, valid_stages: list[AgentStage], agent_stages: dict[str | UUID, AgentStage]
```
```diff
@@ -336,6 +363,9 @@ def __call__(
                 help="Max number of tasks to try to process per stage on a given run. If `None`, will attempt all",
             ),
         ] = None,
+        pre_fetch_factor: Annotated[
```
**Contributor:** Not sure the name and the functionality match here. A factor is something you multiply by, whereas this seems to be an absolute number, at least based on the docstring.

**Contributor (Author):** Could be "pre-fetch batch size". My view on "factor" was: at a factor of x, we perform a (grouped) dependency fetch N/x times rather than N times.
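To pin down the arithmetic behind that reading, here is an illustrative stand-in for `batch_iterator` (the real one lives in `encord_agents.core.utils`; this version only demonstrates the N/x grouping): with N = 100 tasks and a factor of 8, dependencies are fetched in ceil(100/8) = 13 grouped calls instead of 100 individual ones.

```python
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def batch_iterator(items: Iterable[T], batch_size: int) -> Iterator[list[T]]:
    # Illustrative stand-in, not the actual encord_agents implementation:
    # yields consecutive batches of up to `batch_size` items, so N items
    # produce ceil(N / batch_size) batches, i.e. grouped dependency fetches.
    iterator = iter(items)
    while batch := list(islice(iterator, batch_size)):
        yield batch

# 100 tasks at pre_fetch_factor=8 -> 13 grouped fetches rather than 100.
assert sum(1 for _ in batch_iterator(range(100), 8)) == 13
```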
```diff
+            int, Option(help="Number of tasks to pre-fetch for dependency resolution at once.")
+        ] = 1,
     ) -> None:
         """
         Run your task agent `runner(...)`.
```
```diff
@@ -472,6 +502,7 @@ def __call__(
                         stage,
                         num_retries,
                         pbar_update=lambda x: batch_pbar.advance(batch_task, x or 1),
+                        pre_fetch_factor=pre_fetch_factor,
                     )
                     total += len(task_batch)
```
**Contributor:** Should we not specify how many threads (perhaps even expose it as an argument)?

**Contributor (Author):** We could do. By default, `ThreadPoolExecutor` picks the thread count based on the number of CPUs. I didn't necessarily want to expand the interface too much, but I definitely see your point.
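If the thread count did become part of the interface, a hedged sketch of the shape it might take (the `max_workers` parameter and helper below are hypothetical, not part of this PR):

```python
from concurrent.futures import ThreadPoolExecutor

def make_prefetch_executor(max_workers: int | None = None) -> ThreadPoolExecutor:
    # Hypothetical knob: None keeps CPython's default of
    # min(32, (os.cpu_count() or 1) + 4) worker threads (Python 3.8+),
    # while an explicit value caps concurrent dependency fetches.
    return ThreadPoolExecutor(max_workers=max_workers)

# e.g. allow at most 4 dependency fetches in flight at once:
with make_prefetch_executor(max_workers=4) as executor:
    pass  # executor.map(solve_dependencies_for_task, tasks) would go here
```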