diff --git a/README.md b/README.md index fe03c4d..a171437 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,10 @@ Easily build agents for the Encord echo system. With just few lines of code, you can take automation to the next level. +```shell +python -m pip install encord-agents +``` + **Key features:** 1. ⚡**Easy**: Multiple template agents to be adapted and hosted via GCP, own infra, or cloud. @@ -66,13 +70,13 @@ from encord_agents.tasks import Runner runner = Runner(project_hash="") -@runner.stage(UUID("")) +@runner.stage("") def by_file_name(lr: LabelRowV2) -> UUID | None: # Assuming the data_title is of the format "%d.jpg" # and in the range [0; 100] priority = int(lr.data_title.split(".")[0]) / 100 lr.set_priority(priority=priority) - return UUID("") + return "" if __name__ == "__main__": diff --git a/docs/code_examples/gcp/add_bounding_box.py b/docs/code_examples/gcp/add_bounding_box.py index 656ed20..9a97432 100644 --- a/docs/code_examples/gcp/add_bounding_box.py +++ b/docs/code_examples/gcp/add_bounding_box.py @@ -11,7 +11,7 @@ Successively, you can test it locally by pasting an editor url into the following command: ```shell -encord-agents test add_bounding_box +encord-agents test add_bounding_box '' ``` """ diff --git a/docs/code_examples/tasks/prioritize_by_data_title.py b/docs/code_examples/tasks/prioritize_by_data_title.py index a1269a9..a20d7da 100644 --- a/docs/code_examples/tasks/prioritize_by_data_title.py +++ b/docs/code_examples/tasks/prioritize_by_data_title.py @@ -1,17 +1,16 @@ -from uuid import UUID from encord.objects.ontology_labels_impl import LabelRowV2 from encord_agents.tasks import Runner runner = Runner(project_hash="") -@runner.stage(UUID("")) -def by_file_name(lr: LabelRowV2) -> UUID | None: +@runner.stage("") +def by_file_name(lr: LabelRowV2) -> str | None: # Assuming the data_title is of the format "%d.jpg" # and in the range [0; 100] priority = int(lr.data_title.split(".")[0]) / 100 lr.set_priority(priority=priority) - return UUID("") + return "" if __name__ == "__main__": diff --git a/docs/code_examples/tasks/prioritize_by_data_title_specific.py b/docs/code_examples/tasks/prioritize_by_data_title_specific.py index 764d46e..2f62b41 100644 --- a/docs/code_examples/tasks/prioritize_by_data_title_specific.py +++ b/docs/code_examples/tasks/prioritize_by_data_title_specific.py @@ -1,11 +1,10 @@ -from uuid import UUID from encord.objects.ontology_labels_impl import LabelRowV2 from encord_agents.tasks import Runner runner = Runner(project_hash="") -@runner.stage(UUID("1e7751b3-6dc8-4796-a64b-d1323918b8f4")) +@runner.stage("1e7751b3-6dc8-4796-a64b-d1323918b8f4") def by_file_name(lr: LabelRowV2) -> str | None: # Assuming the data_title is of the format "%d.jpg" # and in the range [0; 100] diff --git a/docs/code_examples/tasks/twin_project.py b/docs/code_examples/tasks/twin_project.py index 4d83428..d2fbf54 100644 --- a/docs/code_examples/tasks/twin_project.py +++ b/docs/code_examples/tasks/twin_project.py @@ -14,11 +14,9 @@ ) checklist_attribute = checklist_classification.attributes[0] -# 3. Define the agent -from uuid import UUID - -@runner.stage(stage=UUID("")) +# 3. Define the agent +@runner.stage(stage="") def copy_labels( manually_annotated_lr: LabelRowV2, twin: Annotated[ diff --git a/docs/code_examples/tasks/wrong_stage_and_pathway_names.py b/docs/code_examples/tasks/wrong_stage_and_pathway_names.py new file mode 100644 index 0000000..f5fabcf --- /dev/null +++ b/docs/code_examples/tasks/wrong_stage_and_pathway_names.py @@ -0,0 +1,14 @@ +from encord.objects.ontology_labels_impl import LabelRowV2 +from encord_agents.tasks import Runner + +r = Runner() + + +@r.stage(stage="wrong") +def my_stage(lr: LabelRowV2) -> None: + print(lr.data_title) + + +if __name__ == "__main__": + # r() + r.run() diff --git a/docs/editor_agents/fastapi.md b/docs/editor_agents/fastapi.md index 170fdf2..3a8af63 100644 --- a/docs/editor_agents/fastapi.md +++ b/docs/editor_agents/fastapi.md @@ -89,9 +89,12 @@ Then, run ```shell source venv/bin/activate -encord-agents test local my_agent +encord-agents test local my_agent '' ``` +!!! warning + Notice the single quotes around ``. They are important and should be there because you might copy a url with, e.g., an `&` character that have a [special meaning](https://www.howtogeek.com/439199/15-special-characters-you-need-to-know-for-bash/#amp-background-process){ target="_blank", rel="noopener noreferrer" } if it is not within a string (or escaped). + Refresh the label editor in your browser to see the effect that you applied to the `label_row: LabelRowV2` happening. ## Deployment diff --git a/docs/editor_agents/gcp.md b/docs/editor_agents/gcp.md index 1d8f891..0e9c2e5 100644 --- a/docs/editor_agents/gcp.md +++ b/docs/editor_agents/gcp.md @@ -88,9 +88,12 @@ Then, run ```shell source venv/bin/activate -encord-agents test local my_agent +encord-agents test local my_agent '' ``` +!!! warning + Notice the single quotes around ``. They are important and should be there because you might copy a url with, e.g., an `&` character that have a [special meaning](https://www.howtogeek.com/439199/15-special-characters-you-need-to-know-for-bash/#amp-background-process){ target="_blank", rel="noopener noreferrer" } if it is not within a string (or escaped). + Refresh the label editor in your browser to see the effect that you applied to the `label_row: LabelRowV2` happening. ## Deployment diff --git a/docs/getting_started.md b/docs/getting_started.md index 3602973..43fcf0b 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -39,6 +39,17 @@ Furthermore, is has just one pathway called "annotate." Copy the `Project ID` in the top left of the project page. +!!! tip + After [authenticating](./authentication.md), you can check if your existing project has any agent nodes by running this command: + ```shell + encord-agents print agent-nodes + ``` + If the project has agent nodes in the workflow, you should see a list similar to this: + ```shell + AgentStage(title="pre-label", uuid="b9c1363c-615f-4125-ae1c-a81e19331c96") + AgentStage(title="evaluate", uuid="28d1bcc9-6a3a-4229-8c06-b498fcaf94a0") + ``` + ### 3. Define your agent In your freshly created directory, create a python file. diff --git a/docs/task_agents/examples/index.md b/docs/task_agents/examples/index.md index 955d210..1e1fccc 100644 --- a/docs/task_agents/examples/index.md +++ b/docs/task_agents/examples/index.md @@ -80,11 +80,10 @@ Which would mean that the agents would be defined with the following decorator t make the workflow stage association explicit. ```python -from uuid import UUID -@runner.stage(stage=UUID("60d9f14f-755e-40fd-...")) # <- last bit omitted +@runner.stage(stage="60d9f14f-755e-40fd-...") # <- last bit omitted ``` -> Notice the match between the uuid in the "label transfer" agent stage of the workflow in Project A and the UUID in the decorator. +> Notice the match between the uuid in the "label transfer" agent stage of the workflow in Project A and the uuid in the decorator. **To prepare your projects:** diff --git a/encord_agents/cli/main.py b/encord_agents/cli/main.py index 36e61a5..04463b9 100644 --- a/encord_agents/cli/main.py +++ b/encord_agents/cli/main.py @@ -12,20 +12,23 @@ import typer from .gcp import app as gcp_app -from .test import app as test_app from .print import app as print_app +from .test import app as test_app app = typer.Typer(rich_markup_mode="rich") app.add_typer(gcp_app, name="gcp") app.add_typer(test_app, name="test") app.add_typer(print_app, name="print") + @app.callback(invoke_without_command=True) -def version(version_: bool = typer.Option(False, "--version", "-v", "-V", help="Print the current version of Encord Agents")): +def version( + version_: bool = typer.Option(False, "--version", "-v", "-V", help="Print the current version of Encord Agents"), +): if version_: import rich + from encord_agents import __version__ as ea_version rich.print(f"[purple]encord-agents[/purple] version: [green]{ea_version}[/green]") exit() - diff --git a/encord_agents/cli/print.py b/encord_agents/cli/print.py index 74cdf02..f9fc029 100644 --- a/encord_agents/cli/print.py +++ b/encord_agents/cli/print.py @@ -1,6 +1,10 @@ import sys + +from encord.orm.workflow import WorkflowStageType from typer import Typer +from encord_agents.core.settings import Settings + app = Typer( name="print", help="Utility to print system info, e.g., for bug reporting.", @@ -9,21 +13,57 @@ ) +@app.command(name="agent-nodes") +def print_agent_nodes(project_hash: str): + """ + Prints agent nodes from project. + + Given the project hash, loads the project and prints the agent nodes. + + Args: + project_hash: The project hash for which to print agent nodes. + + """ + import rich + from encord.exceptions import AuthorisationError + from encord.user_client import EncordUserClient + + _ = Settings() + client = EncordUserClient.create_with_ssh_private_key() + try: + project = client.get_project(project_hash) + except AuthorisationError: + rich.print(f"You do not seem to have access to project with project hash `[purple]{project_hash}[/purple]`") + exit() + + agent_nodes = [ + f'AgentStage(title="{n.title}", uuid="{n.uuid}")' + for n in project.workflow.stages + if n.stage_type == WorkflowStageType.AGENT + ] + if not agent_nodes: + print("Project does not have any agent nodes.") + return + + for node in agent_nodes: + rich.print(node) + + @app.command(name="system-info") def print_system_info(): - """ - [bold]Prints[/bold] the information of the system for the purpose of bug reporting. - """ - import platform - print("System Information:") - uname = platform.uname() - print(f"\tSystem: {uname.system}") - print(f"\tRelease: {uname.release}") - print(f"\tMachine: {uname.machine}") - print(f"\tProcessor: {uname.processor}") - print(f"\tPython: {sys.version}") - - import encord_agents - print(f"encord-agents version: {encord_agents.__version__}") + """ + [bold]Prints[/bold] the information of the system for the purpose of bug reporting. + """ + import platform + + print("System Information:") + uname = platform.uname() + print(f"\tSystem: {uname.system}") + print(f"\tRelease: {uname.release}") + print(f"\tMachine: {uname.machine}") + print(f"\tProcessor: {uname.processor}") + print(f"\tPython: {sys.version}") + import encord_agents + print(f"encord-agents version: {encord_agents.__version__}") diff --git a/encord_agents/core/settings.py b/encord_agents/core/settings.py index 6291dc7..647d56e 100644 --- a/encord_agents/core/settings.py +++ b/encord_agents/core/settings.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import Optional -from pydantic import Field, model_validator +from pydantic import Field, field_validator, model_validator from pydantic_settings import BaseSettings @@ -29,6 +29,16 @@ class Settings(BaseSettings): [the platform docs](https://docs.encord.com/platform-documentation/Annotate/annotate-api-keys). """ + @field_validator("ssh_key_file") + @classmethod + def check_path_expand_and_exists(cls, path: Path | None): + if path is None: + return path + + path = path.expanduser() + assert path.is_file(), f"Provided ssh key file (ENCORD_SSH_KEY_FILE: '{path}') does not exist" + return path + @model_validator(mode="after") def check_key(self): assert any( diff --git a/encord_agents/tasks/runner.py b/encord_agents/tasks/runner.py index c202f5b..5c157d3 100644 --- a/encord_agents/tasks/runner.py +++ b/encord_agents/tasks/runner.py @@ -1,3 +1,5 @@ +import io +import os import time import traceback from contextlib import ExitStack @@ -8,10 +10,14 @@ import rich from encord.http.bundle import Bundle from encord.objects.ontology_labels_impl import LabelRowV2 +from encord.orm.project import ProjectType from encord.orm.workflow import WorkflowStageType from encord.project import Project -from encord.workflow.stages.agent import AgentTask +from encord.workflow.stages.agent import AgentStage, AgentTask from encord.workflow.workflow import WorkflowStage +from rich.console import Console +from rich.panel import Panel +from rich.text import Text from tqdm.auto import tqdm from typer import Abort @@ -22,6 +28,9 @@ TaskAgentReturn = str | UUID | None +class PrintableError(ValueError): ... + + class RunnerAgent: def __init__(self, identity: str | UUID, callable: Callable[..., TaskAgentReturn]): self.name = identity @@ -36,6 +45,29 @@ class Runner: When called, it will iteratively run agent stages till they are empty. By default, runner will exit after finishing the tasks identified at the point of trigger. To automatically re-run, you can use the `refresh_every` keyword. + + **Example:** + + ```python title="example_agent.py" + from uuid import UUID + from encord.tasks import Runner + runner = Runner() + + @runner.stage("") + # or + @runner.stage("") + def my_agent(task: AgentTask) -> str | UUID | None + ... + return "pathway name" # or pathway uuid + + + runner(project_hash="") # (see __call__ for more arguments) + # or + if __name__ == "__main__": + # for CLI usage: `python example_agent.py --project-hash ""` + runner.run() + ``` + """ @staticmethod @@ -50,43 +82,151 @@ def verify_project_hash(ph: str) -> str: def __init__(self, project_hash: str | None = None): self.project_hash = self.verify_project_hash(project_hash) if project_hash else None self.client = get_user_client() + self.project: Project | None = self.client.get_project(self.project_hash) if self.project_hash else None + self.validate_project(self.project) - self.valid_stages: list[WorkflowStage] | None = None + self.valid_stages: list[AgentStage] | None = None if self.project is not None: self.valid_stages = [s for s in self.project.workflow.stages if s.stage_type == WorkflowStageType.AGENT] self.agents: list[RunnerAgent] = [] + self.was_called_from_cli = False + + @staticmethod + def validate_project(project: Project | None): + if project is None: + return + PROJECT_MUSTS = "Task agents only work for workflow projects that have agent nodes in the workflow." + assert ( + project.project_type == ProjectType.WORKFLOW + ), f"Provided project is not a workflow project. {PROJECT_MUSTS}" + assert ( + len([s for s in project.workflow.stages if s.stage_type == WorkflowStageType.AGENT]) > 0 + ), f"Provided project does not have any agent stages in it's workflow. {PROJECT_MUSTS}" def _add_stage_agent(self, identity: str | UUID, func: Callable[..., TaskAgentReturn]): self.agents.append(RunnerAgent(identity=identity, callable=func)) def stage(self, stage: str | UUID) -> Callable[[DecoratedCallable], DecoratedCallable]: - if stage in [a.name for a in self.agents]: - self.abort_with_message( - f"Stage name [blue]`{stage}`[/blue] has already been assigned a function. You can only assign one callable to each agent stage." - ) + r""" + Decorator to associate a function with an agent stage. + + A function decorated with a stage is added to the list of stages + that will be handled by the runner. + The runner will call the function for every task which is in that + stage. + + + **Example:** + + ```python + runner = Runner() + + @runner.stage("") + def my_func() -> str | None: + ... + return "" + ``` + + The function declaration can be any function that takes parameters + that are type annotated with the following types: + + * [Project][docs-project]{ target="\_blank", rel="noopener noreferrer" }: the `encord.project.Project` + that the runner is operating on. + * [LabelRowV2][docs-label-row]{ target="\_blank", rel="noopener noreferrer" }: the `encord.objects.LabelRowV2` + that the task is associated with. + * [AgentTask][docs-project]{ target="\_blank", rel="noopener noreferrer" }: the `encord.workflow.stages.agent.AgentTask` + that the task is associated with. + * Any other type: which is annotated with a [dependency](/dependencies.md) + + All those parameters will be automatically injected when the agent is called. + + **Example:** + + ```python + from typing import Iterator + from typing_extensions import Annotated + + from encord.project import Project + from encord_agents.tasks import Depends + from encord_agents.tasks.dependencies import dep_video_iterator + from encord.workflow.stages.agent import AgentTask + + runner = Runner() + + def random_value() -> float: + import random + return random.random() + + @runner.stage("") + def my_func( + project: Project, + lr: LabelRowV2, + task: AgentTask, + video_frames: Annotated[Iterator[Frame], Depends(dep_video_iterator)], + custom: Annotated[float, Depends(random_value)] + ) -> str | None: + ... + return "" + ``` + + [docs-project]: https://docs.encord.com/sdk-documentation/sdk-references/project + [docs-label-row]: https://docs.encord.com/sdk-documentation/sdk-references/LabelRowV2 + [docs-agent-task]: https://docs.encord.com/sdk-documentation/sdk-references/AgentTask + + Args: + stage: The name or uuid of the stage that the function should be + associated with. - if self.valid_stages is not None: - selected_stage: WorkflowStage | None = None - for v_stage in self.valid_stages: - attr = v_stage.title if isinstance(stage, str) else v_stage.uuid - if attr == stage: - selected_stage = v_stage - - if selected_stage is None: - agent_stage_names = ",".join([f"[magenta]`{k}`[/magenta]" for k in self.valid_stages]) - self.abort_with_message( - rf"Stage name [blue]`{stage}`[/blue] could not be matched against a project stage. Valid stages are \[{agent_stage_names}]." + Returns: + The decorated function. + """ + try: + try: + stage = UUID(str(stage)) + except ValueError: + pass + + if stage in [a.name for a in self.agents]: + raise PrintableError( + f"Stage name [blue]`{stage}`[/blue] has already been assigned a function. You can only assign one callable to each agent stage." ) - else: - stage = selected_stage.uuid - def decorator(func: DecoratedCallable) -> DecoratedCallable: - self._add_stage_agent(stage, func) - return func + if self.valid_stages is not None: + selected_stage: WorkflowStage | None = None + for v_stage in self.valid_stages: + attr = v_stage.title if isinstance(stage, str) else v_stage.uuid + if attr == stage: + selected_stage = v_stage + + if selected_stage is None: + agent_stage_names = self.get_stage_names(self.valid_stages) + raise PrintableError( + rf"Stage name [blue]`{stage}`[/blue] could not be matched against a project stage. Valid stages are \[{agent_stage_names}]." + ) + else: + stage = selected_stage.uuid + + def decorator(func: DecoratedCallable) -> DecoratedCallable: + self._add_stage_agent(stage, func) + return func + + return decorator + except PrintableError as err: + output = io.StringIO() + console = Console( + force_terminal=True, + color_system="auto", + file=output, + force_interactive=False, + width=1000, + ) - return decorator + text_obj = Text.from_markup(err.args[0]) + console.print(text_obj, end="") + err.args = (output.getvalue(),) + raise def _execute_tasks( self, @@ -120,14 +260,17 @@ def _execute_tasks( pbar.update(1) break + except KeyboardInterrupt: + raise except Exception: print(f"[attempt {attempt}/{num_retries}] Agent failed with error: ") traceback.print_exc() @staticmethod - def abort_with_message(error: str): - rich.print(error) - raise Abort() + def get_stage_names(valid_stages: list[AgentStage], join_str: str = ", "): + return join_str.join( + [f'[magenta]AgentStage(title="{k.title}", uuid="{k.uuid}")[/magenta]' for k in valid_stages] + ) def __call__( self, @@ -159,60 +302,113 @@ def __call__( project = self.project if project is None: - self.abort_with_message( - """Please specify project hash in one of the following ways: -At instantiation: [blue]`runner = Runner(project_hash="")`[/blue] -or when called: [blue]`runner(project_hash="")`[/blue] + import sys + + raise PrintableError( + f"""Please specify project hash in one of the following ways: +* At instantiation: [blue]`runner = Runner(project_hash="[green][/green]")`[/blue] +* When called directly: [blue]`runner(project_hash="[green][/green]")`[/blue] +* When called from CLI: [blue]`python {sys.argv[0]} --project-hash [green][/green]`[/blue] """ ) - exit() ## shouldn't be necessary but pleases pyright + + self.validate_project(project) # Verify stages + valid_stages = [s for s in project.workflow.stages if s.stage_type == WorkflowStageType.AGENT] agent_stages: dict[str | UUID, WorkflowStage] = { - **{s.title: s for s in project.workflow.stages if s.stage_type == WorkflowStageType.AGENT}, - **{s.uuid: s for s in project.workflow.stages if s.stage_type == WorkflowStageType.AGENT}, + **{s.title: s for s in valid_stages}, + **{s.uuid: s for s in valid_stages}, } - for runner_agent in self.agents: - fn_name = getattr(callable, "__name__", "agent function") - agent_stage_names = ",".join([f"[magenta]`{k}`[/magenta]" for k in agent_stages.keys()]) - if runner_agent.name not in agent_stages: - self.abort_with_message( - rf"Your function [blue]`{fn_name}`[/blue] was annotated to match agent stage [blue]`{runner_agent.name}`[/blue] but that stage is not present as an agent stage in your project workflow. The workflow has following agent stages : \[{agent_stage_names}]" - ) - - stage = agent_stages[runner_agent.name] - if stage.stage_type != WorkflowStageType.AGENT: - self.abort_with_message( - f"You cannot use the stage of type `{stage.stage_type}` as an agent stage. It has to be one of the agent stages: [{agent_stage_names}]." - ) - - # Run - delta = timedelta(seconds=refresh_every) if refresh_every else None - next_execution = None - - while True: - if isinstance(next_execution, datetime): - if next_execution > datetime.now(): - duration = next_execution - datetime.now() - print(f"Sleeping {duration.total_seconds()} secs until next execution time.") - time.sleep(duration.total_seconds()) - elif next_execution is not None: - break - - next_execution = datetime.now() + delta if delta else False + try: for runner_agent in self.agents: - stage = agent_stages[runner_agent.name] + fn_name = getattr(runner_agent.callable, "__name__", "agent function") + separator = f"{os.linesep}\t" + agent_stage_names = separator + self.get_stage_names(valid_stages, join_str=separator) + os.linesep + if runner_agent.name not in agent_stages: + suggestion: str + if len(valid_stages) == 1: + suggestion = f'Did you mean to wrap [blue]`{fn_name}`[/blue] with{os.linesep}[magenta]@runner.stage(stage="{valid_stages[0].title}")[/magenta]{os.linesep}or{os.linesep}[magenta]@runner.stage(stage="{valid_stages[0].uuid}")[/magenta]' + else: + suggestion = f""" +Please use either name annoitations: +[magenta]@runner.stage(stage="")[/magenta] + +or uuid annotations: +[magenta]@runner.stage(stage="")[/magenta] + +For example, if we use the first agent stage listed above, we can use: +[magenta]@runner.stage(stage="{valid_stages[0].title}") +def {fn_name}(...): + ... +[/magenta] +# or +[magenta]@runner.stage(stage="{valid_stages[0].uuid}") +def {fn_name}(...): + ...[/magenta]""" + raise PrintableError( + rf"""Your function [blue]`{fn_name}`[/blue] was annotated to match agent stage [blue]`{runner_agent.name}`[/blue] but that stage is not present as an agent stage in your project workflow. The workflow has following agent stages: + +[{agent_stage_names}] + +{suggestion} + """ + ) - batch: list[AgentTask] = [] - batch_lrs: list[LabelRowV2] = [] - - tasks = list(stage.get_tasks()) - pbar = tqdm(desc="Executing tasks", total=len(tasks)) - for task in tasks: - if not isinstance(task, AgentTask): - continue - batch.append(task) - if len(batch) == task_batch_size: + stage = agent_stages[runner_agent.name] + if stage.stage_type != WorkflowStageType.AGENT: + raise PrintableError( + f"""You cannot use the stage of type `{stage.stage_type}` as an agent stage. It has to be one of the agent stages: +[{agent_stage_names}].""" + ) + + # Run + delta = timedelta(seconds=refresh_every) if refresh_every else None + next_execution = None + + while True: + if isinstance(next_execution, datetime): + if next_execution > datetime.now(): + duration = next_execution - datetime.now() + print(f"Sleeping {duration.total_seconds()} secs until next execution time.") + time.sleep(duration.total_seconds()) + elif next_execution is not None: + break + + next_execution = datetime.now() + delta if delta else False + for runner_agent in self.agents: + stage = agent_stages[runner_agent.name] + + batch: list[AgentTask] = [] + batch_lrs: list[LabelRowV2] = [] + + tasks = list(stage.get_tasks()) + pbar = tqdm(desc="Executing tasks", total=len(tasks)) + for task in tasks: + if not isinstance(task, AgentTask): + continue + batch.append(task) + if len(batch) == task_batch_size: + label_rows = { + UUID(lr.data_hash): lr + for lr in project.list_label_rows_v2(data_hashes=[t.data_hash for t in batch]) + } + batch_lrs = [label_rows[t.data_hash] for t in batch] + with project.create_bundle() as lr_bundle: + for lr in batch_lrs: + lr.initialise_labels(bundle=lr_bundle) + + self._execute_tasks( + zip(batch, batch_lrs), + runner_agent, + num_retries, + pbar=pbar, + ) + + batch = [] + batch_lrs = [] + + if len(batch) > 0: label_rows = { UUID(lr.data_hash): lr for lr in project.list_label_rows_v2(data_hashes=[t.data_hash for t in batch]) @@ -221,27 +417,19 @@ def __call__( with project.create_bundle() as lr_bundle: for lr in batch_lrs: lr.initialise_labels(bundle=lr_bundle) + self._execute_tasks(zip(batch, batch_lrs), runner_agent, num_retries, pbar=pbar) + except (PrintableError, AssertionError) as err: + if self.was_called_from_cli: + panel = Panel(err.args[0], width=None) + rich.print(panel) + raise Abort() + else: + if isinstance(err, PrintableError): + from rich.text import Text - self._execute_tasks( - zip(batch, batch_lrs), - runner_agent, - num_retries, - pbar=pbar, - ) - - batch = [] - batch_lrs = [] - - if len(batch) > 0: - label_rows = { - UUID(lr.data_hash): lr - for lr in project.list_label_rows_v2(data_hashes=[t.data_hash for t in batch]) - } - batch_lrs = [label_rows[t.data_hash] for t in batch] - with project.create_bundle() as lr_bundle: - for lr in batch_lrs: - lr.initialise_labels(bundle=lr_bundle) - self._execute_tasks(zip(batch, batch_lrs), runner_agent, num_retries, pbar=pbar) + plain_text = Text.from_markup(err.args[0]).plain + err.args = (plain_text,) + raise def run(self): """ @@ -275,6 +463,7 @@ def your_func() -> str: """ from typer import Typer + self.was_called_from_cli = True app = Typer(add_completion=False, rich_markup_mode="rich") app.command()(self.__call__) app()