diff --git a/src/promptflow/CHANGELOG.md b/src/promptflow/CHANGELOG.md index cd49623225d..38dfa0a6d1d 100644 --- a/src/promptflow/CHANGELOG.md +++ b/src/promptflow/CHANGELOG.md @@ -2,6 +2,9 @@ ## 1.2.0 (upcoming) +### Features Added +- [SDK/CLI] Support `pfazure run download` to download run data from Azure AI. + ### Bugs Fixed - [SDK/CLI] Removing telemetry warning when running commands. diff --git a/src/promptflow/promptflow/_cli/_params.py b/src/promptflow/promptflow/_cli/_params.py index df17434c5fe..3aea229b36b 100644 --- a/src/promptflow/promptflow/_cli/_params.py +++ b/src/promptflow/promptflow/_cli/_params.py @@ -56,20 +56,20 @@ def add_param_ua(parser): def add_param_flow_display_name(parser): - parser.add_argument("--flow", type=str, required=True, help="the flow name to create.") + parser.add_argument("--flow", type=str, required=True, help="The flow name to create.") def add_param_entry(parser): - parser.add_argument("--entry", type=str, help="the entry file.") + parser.add_argument("--entry", type=str, help="The entry file.") def add_param_function(parser): - parser.add_argument("--function", type=str, help="the function name in entry file.") + parser.add_argument("--function", type=str, help="The function name in entry file.") def add_param_prompt_template(parser): parser.add_argument( - "--prompt-template", action=AppendToDictAction, help="the prompt template parameter and assignment.", nargs="+" + "--prompt-template", action=AppendToDictAction, help="The prompt template parameter and assignment.", nargs="+" ) @@ -146,58 +146,30 @@ def add_param_inputs(parser): ) -def add_param_input(parser): - parser.add_argument( - "--input", type=str, required=True, help="the input file path. Note that we accept jsonl file only for now." 
- ) - - def add_param_env(parser): parser.add_argument( "--env", type=str, default=None, - help="the dotenv file path containing the environment variables to be used in the flow.", + help="The dotenv file path containing the environment variables to be used in the flow.", ) def add_param_output(parser): - parser.add_argument("--output", type=str, default="outputs", help="the output directory to store the results.") - - -def add_param_flow(parser): - parser.add_argument("--flow", type=str, required=True, help="the evaluation flow to be used.") - - -def add_param_source(parser): - parser.add_argument("--source", type=str, required=True, help="The flow or run source to be used.") - - -def add_param_bulk_run_output(parser): - parser.add_argument("--bulk-run-output", type=str, help="the output directory of the bulk run.") - - -def add_param_eval_output(parser): - parser.add_argument("--eval-output", type=str, help="the output file path of the evaluation result.") - - -def add_param_column_mapping(parser): parser.add_argument( - "--column-mapping", type=str, required=True, help="the column mapping to be used in the evaluation." + "-o", + "--output", + type=str, + help="The output directory to store the results. 
Default to be current working directory if not specified.", ) -def add_param_runtime(parser): - parser.add_argument( - "--runtime", - type=str, - default="local", - help="Name of your runtime in Azure ML workspace, will run in cloud when runtime is not none.", - ) +def add_param_overwrite(parser): + parser.add_argument("--overwrite", action="store_true", help="Overwrite the existing results.") -def add_param_connection(parser): - parser.add_argument("--connection", type=str, help="Name of your connection in Azure ML workspace.") +def add_param_source(parser): + parser.add_argument("--source", type=str, required=True, help="The flow or run source to be used.") def add_param_run_name(parser): @@ -208,16 +180,6 @@ def add_param_connection_name(parser): parser.add_argument("-n", "--name", type=str, help="Name of the connection to create.") -def add_param_variants(parser): - parser.add_argument( - "--variants", - type=str, - nargs="+", - help="the variant run ids to be used in the evaluation. 
Note that we only support one variant for now.", - default=[], - ) - - def add_param_max_results(parser): parser.add_argument( # noqa: E731 "-r", diff --git a/src/promptflow/promptflow/_cli/_pf_azure/_run.py b/src/promptflow/promptflow/_cli/_pf_azure/_run.py index ff2a7c701e4..6eb10af7be9 100644 --- a/src/promptflow/promptflow/_cli/_pf_azure/_run.py +++ b/src/promptflow/promptflow/_cli/_pf_azure/_run.py @@ -13,7 +13,9 @@ add_param_debug, add_param_include_archived, add_param_max_results, + add_param_output, add_param_output_format, + add_param_overwrite, add_param_run_name, add_param_set, add_param_verbose, @@ -52,6 +54,7 @@ def add_parser_run(subparsers): add_parser_run_archive(subparsers) add_parser_run_restore(subparsers) add_parser_run_update(subparsers) + add_parser_run_download(subparsers) run_parser.set_defaults(action="run") @@ -334,7 +337,7 @@ def add_parser_run_update(subparsers): Example: # Update a run metadata: -pf run update --name --set display_name="" description="" tags.key="" +pfazure run update --name --set display_name="" description="" tags.key="" """ add_params = [ _set_workspace_argument_for_subparsers, @@ -353,6 +356,32 @@ def add_parser_run_update(subparsers): ) +def add_parser_run_download(subparsers): + """Add run download parser to the pfazure subparsers.""" + epilog = """ +Example: + +# Download a run data to local: +pfazure run download --name --output +""" + add_params = [ + add_param_run_name, + add_param_output, + add_param_overwrite, + _set_workspace_argument_for_subparsers, + ] + base_params + + activate_action( + name="download", + description="A CLI tool to download a run.", + epilog=epilog, + add_params=add_params, + subparsers=subparsers, + help_message="Download a run.", + action_param_name="sub_action", + ) + + def dispatch_run_commands(args: argparse.Namespace): if args.sub_action == "create": pf = _get_azure_pf_client(args.subscription, args.resource_group, args.workspace_name, debug=args.debug) @@ -403,6 +432,8 @@ def 
dispatch_run_commands(args: argparse.Namespace): restore_run(args.subscription, args.resource_group, args.workspace_name, args.name) elif args.sub_action == "update": update_run(args.subscription, args.resource_group, args.workspace_name, args.name, params=args.params_override) + elif args.sub_action == "download": + download_run(args) @exception_handler("List runs") @@ -530,3 +561,9 @@ def update_run( pf = _get_azure_pf_client(subscription_id, resource_group, workspace_name) run = pf.runs.update(run=run_name, display_name=display_name, description=description, tags=tags) print(json.dumps(run._to_dict(), indent=4)) + + +@exception_handler("Download run") +def download_run(args: argparse.Namespace): + pf = _get_azure_pf_client(args.subscription, args.resource_group, args.workspace_name, debug=args.debug) + pf.runs.download(run=args.name, output=args.output, overwrite=args.overwrite) diff --git a/src/promptflow/promptflow/_sdk/_constants.py b/src/promptflow/promptflow/_sdk/_constants.py index 2ef182e6d43..399251669ec 100644 --- a/src/promptflow/promptflow/_sdk/_constants.py +++ b/src/promptflow/promptflow/_sdk/_constants.py @@ -18,6 +18,7 @@ FLOW_TOOLS_JSON = "flow.tools.json" FLOW_TOOLS_JSON_GEN_TIMEOUT = 60 PROMPT_FLOW_DIR_NAME = ".promptflow" +PROMPT_FLOW_RUNS_DIR_NAME = ".runs" HOME_PROMPT_FLOW_DIR = (Path.home() / PROMPT_FLOW_DIR_NAME).resolve() SERVICE_CONFIG_FILE = "pf.yaml" PF_SERVICE_PORT_FILE = "pfs.port" diff --git a/src/promptflow/promptflow/_sdk/_errors.py b/src/promptflow/promptflow/_sdk/_errors.py index c850b567f77..9e3e686c2eb 100644 --- a/src/promptflow/promptflow/_sdk/_errors.py +++ b/src/promptflow/promptflow/_sdk/_errors.py @@ -92,6 +92,12 @@ class RunOperationParameterError(PromptflowException): pass +class RunOperationError(PromptflowException): + """Exception raised when run operation failed.""" + + pass + + class FlowOperationError(PromptflowException): """Exception raised when flow operation failed.""" diff --git 
a/src/promptflow/promptflow/azure/operations/_async_run_downloader.py b/src/promptflow/promptflow/azure/operations/_async_run_downloader.py new file mode 100644 index 00000000000..75176e109e1 --- /dev/null +++ b/src/promptflow/promptflow/azure/operations/_async_run_downloader.py @@ -0,0 +1,252 @@ +import asyncio +import json +import logging +from pathlib import Path +from typing import Optional, Union + +import httpx +from azure.storage.blob.aio import BlobServiceClient + +from promptflow._sdk._constants import DEFAULT_ENCODING, LOGGER_NAME +from promptflow._sdk._errors import RunNotFoundError, RunOperationError +from promptflow._utils.logger_utils import LoggerFactory +from promptflow.exceptions import UserErrorException + +logger = LoggerFactory.get_logger(name=LOGGER_NAME, verbosity=logging.WARNING) + + +class AsyncRunDownloader: + """Download run results from the service asynchronously. + + :param run: The run id. + :type run: str + :param run_ops: The run operations. + :type run_ops: ~promptflow.azure.operations.RunOperations + :param output_folder: The output folder to save the run results. 
+ :type output_folder: Union[Path, str] + """ + + LOCAL_SNAPSHOT_FOLDER = "snapshot" + LOCAL_METRICS_FILE_NAME = "metrics.json" + LOCAL_LOGS_FILE_NAME = "logs.txt" + + IGNORED_PATTERN = ["__pycache__"] + + def __init__(self, run: str, run_ops: "RunOperations", output_folder: Union[str, Path]) -> None: + self.run = run + self.run_ops = run_ops + self.datastore = run_ops._workspace_default_datastore + self.output_folder = Path(output_folder) + self.blob_service_client = self._init_blob_service_client() + self._use_flow_outputs = False # old runtime does not write debug_info output asset, use flow_outputs instead + + def _init_blob_service_client(self): + logger.debug("Initializing blob service client.") + account_url = f"{self.datastore.account_name}.blob.{self.datastore.endpoint}" + return BlobServiceClient(account_url=account_url, credential=self.run_ops._credential) + + async def download(self) -> str: + """Download the run results asynchronously.""" + try: + # pass verify=False to client to disable SSL verification. + # Source: https://github.com/encode/httpx/issues/1331 + async with httpx.AsyncClient(verify=False) as client: + + async_tasks = [ + # put async functions in tasks to run in coroutines + self._download_artifacts_and_snapshot(client), + ] + sync_tasks = [ + # below functions are actually synchronous functions in order to reuse code, + # the execution time of these functions should be shorter than the above async functions + # so it won't increase the total execution time. 
+ # the reason we still put them in the tasks is, on one hand the code is more consistent and + # we can use asyncio.gather() to wait for all tasks to finish, on the other hand, we can + # also evaluate below functions to be shorter than the async functions with the help of logs + self._download_run_metrics(), + self._download_run_logs(), + ] + tasks = async_tasks + sync_tasks + await asyncio.gather(*tasks) + except Exception as e: + raise RunOperationError(f"Failed to download run {self.run!r}. Error: {e}") from e + + return self.output_folder.resolve().as_posix() + + async def _download_artifacts_and_snapshot(self, httpx_client: httpx.AsyncClient): + run_data = await self._get_run_data_from_run_history(httpx_client) + + logger.debug("Parsing run data from run history to get necessary information.") + # extract necessary information from run data + snapshot_id = run_data["runMetadata"]["properties"]["azureml.promptflow.snapshot_id"] + output_data = run_data["runMetadata"]["outputs"].get("debug_info", None) + if output_data is None: + logger.warning( + "Downloading run '%s' but the 'debug_info' output assets is not available, " + "maybe because the job ran on old version runtime, trying to get `flow_outputs` output asset instead.", + self.run, + ) + self._use_flow_outputs = True + output_data = run_data["runMetadata"]["outputs"].get("flow_outputs", None) + output_asset_id = output_data["assetId"] + + async with self.blob_service_client: + container_name = self.datastore.container_name + logger.debug("Getting container client (%s) from workspace default datastore.", container_name) + container_client = self.blob_service_client.get_container_client(container_name) + + async with container_client: + tasks = [ + self._download_flow_artifacts(httpx_client, container_client, output_asset_id), + self._download_snapshot(httpx_client, container_client, snapshot_id), + ] + await asyncio.gather(*tasks) + + async def _get_run_data_from_run_history(self, client: 
httpx.AsyncClient): + """Get the run data from the run history.""" + logger.debug("Getting run data from run history.") + headers = self.run_ops._get_headers() + url = self.run_ops._run_history_endpoint_url + "/rundata" + + payload = { + "runId": self.run, + "selectRunMetadata": True, + "selectRunDefinition": True, + "selectJobSpecification": True, + } + + response = await client.post(url, headers=headers, json=payload) + if response.status_code == 200: + return response.json() + elif response.status_code == 404: + raise RunNotFoundError(f"Run {self.run!r} not found.") + else: + raise RunOperationError( + f"Failed to get run from service. Code: {response.status_code}, text: {response.text}" + ) + + async def _download_run_metrics( + self, + ): + """Download the run metrics.""" + logger.debug("Downloading run metrics.") + metrics = self.run_ops.get_metrics(self.run) + with open(self.output_folder / self.LOCAL_METRICS_FILE_NAME, "w", encoding=DEFAULT_ENCODING) as f: + json.dump(metrics, f, ensure_ascii=False) + logger.debug("Downloaded run metrics.") + + async def _download_flow_artifacts(self, httpx_client: httpx.AsyncClient, container_client, output_data): + """Download the output data.""" + asset_path = await self._get_asset_path(httpx_client, output_data) + await self._download_blob_folder_from_asset_path(container_client, asset_path) + + async def _download_blob_folder_from_asset_path( + self, container_client, asset_path: str, local_folder: Optional[Path] = None + ): + """Download the blob data from the data path.""" + logger.debug("Downloading all blobs from data path prefix '%s'", asset_path) + if local_folder is None: + local_folder = self.output_folder + + tasks = [] + async for blob in container_client.list_blobs(name_starts_with=asset_path): + blob_client = container_client.get_blob_client(blob.name) + relative_path = Path(blob.name).relative_to(asset_path) + local_path = local_folder / relative_path + tasks.append(self._download_single_blob(blob_client, 
local_path)) + await asyncio.gather(*tasks) + + async def _download_single_blob(self, blob_client, local_path: Optional[Path] = None): + """Download a single blob.""" + if local_path is None: + local_path = Path(self.output_folder / blob_client.blob_name) + elif local_path.exists(): + raise UserErrorException(f"Local file {local_path.resolve().as_posix()!r} already exists.") + + # ignore some files + for item in self.IGNORED_PATTERN: + if item in blob_client.blob_name: + logger.warning( + "Ignoring file '%s' because it matches the ignored pattern '%s'", local_path.as_posix(), item + ) + return None + + logger.debug("Downloading blob '%s' to local path '%s'", blob_client.blob_name, local_path.resolve().as_posix()) + local_path.parent.mkdir(parents=True, exist_ok=True) + async with blob_client: + with open(local_path, "wb") as f: + stream = await blob_client.download_blob() + data = await stream.readall() + # TODO: File IO may block the event loop, consider using thread pool. e.g. to_thread() method + f.write(data) + return local_path + + async def _download_snapshot(self, httpx_client: httpx.AsyncClient, container_client, snapshot_id): + """Download the flow snapshot.""" + snapshot_urls = await self._get_flow_snapshot_urls(httpx_client, snapshot_id) + + logger.debug("Downloading all snapshot blobs from snapshot urls.") + tasks = [] + for url in snapshot_urls: + blob_name = url.split(self.datastore.container_name)[-1].lstrip("/") + blob_client = container_client.get_blob_client(blob_name) + relative_path = url.split(self.run)[-1].lstrip("/") + local_path = Path(self.output_folder / self.LOCAL_SNAPSHOT_FOLDER / relative_path) + tasks.append(self._download_single_blob(blob_client, local_path)) + await asyncio.gather(*tasks) + + async def _get_flow_snapshot_urls(self, httpx_client: httpx.AsyncClient, snapshot_id): + logger.debug("Getting flow snapshot blob urls from snapshot id with calling to content service.") + headers = self.run_ops._get_headers() + endpoint = 
self.run_ops._run_history_endpoint_url.replace("/history/v1.0", "/content/v2.0") + url = endpoint + "/snapshots/sas" + payload = { + "snapshotOrAssetId": snapshot_id, + } + + response = await httpx_client.post(url, headers=headers, json=payload) + if response.status_code == 200: + return self._parse_snapshot_response(response.json()) + elif response.status_code == 404: + raise UserErrorException(f"Snapshot {snapshot_id!r} not found.") + else: + raise RunOperationError( + f"Failed to get snapshot {snapshot_id!r} from content service. " + f"Code: {response.status_code}, text: {response.text}" + ) + + async def _get_asset_path(self, client: httpx.AsyncClient, asset_id): + """Get the asset path from asset id.""" + logger.debug("Getting asset path from asset id with calling to data service.") + headers = self.run_ops._get_headers() + endpoint = self.run_ops._run_history_endpoint_url.replace("/history", "/data") + url = endpoint + "/dataversion/getByAssetId" + payload = { + "value": asset_id, + } + + response = await client.post(url, headers=headers, json=payload) + response_data = response.json() + data_path = response_data["dataVersion"]["dataUri"].split("/paths/")[-1] + if self._use_flow_outputs: + data_path = data_path.replace("flow_outputs", "flow_artifacts") + return data_path + + def _parse_snapshot_response(self, response: dict): + """Parse the snapshot response.""" + urls = [] + if response["absoluteUrl"]: + urls.append(response["absoluteUrl"]) + for value in response["children"].values(): + urls += self._parse_snapshot_response(value) + + return urls + + async def _download_run_logs(self): + """Download the run logs.""" + logger.debug("Downloading run logs.") + logs = self.run_ops._get_log(self.run) + + with open(self.output_folder / self.LOCAL_LOGS_FILE_NAME, "w", encoding=DEFAULT_ENCODING) as f: + f.write(logs) + logger.debug("Downloaded run logs.") diff --git a/src/promptflow/promptflow/azure/operations/_run_operations.py 
b/src/promptflow/promptflow/azure/operations/_run_operations.py index fe64c3c845d..d7d56f8b03a 100644 --- a/src/promptflow/promptflow/azure/operations/_run_operations.py +++ b/src/promptflow/promptflow/azure/operations/_run_operations.py @@ -1,6 +1,7 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- +import asyncio import concurrent import copy import hashlib @@ -8,10 +9,12 @@ import logging import os import re +import shutil import sys import time from concurrent.futures import ThreadPoolExecutor from functools import cached_property +from pathlib import Path from typing import Any, Dict, List, Optional, Union import requests @@ -33,6 +36,8 @@ LOGGER_NAME, MAX_RUN_LIST_RESULTS, MAX_SHOW_DETAILS_RESULTS, + PROMPT_FLOW_DIR_NAME, + PROMPT_FLOW_RUNS_DIR_NAME, REGISTRY_URI_PREFIX, VIS_PORTAL_URL_TMPL, AzureRunTypes, @@ -45,6 +50,7 @@ from promptflow._sdk._telemetry import ActivityType, WorkspaceTelemetryMixin, monitor_operation from promptflow._sdk._utils import in_jupyter_notebook, incremental_print, is_remote_uri, print_red_error from promptflow._sdk.entities import Run +from promptflow._utils.async_utils import async_run_allowing_running_loop from promptflow._utils.flow_utils import get_flow_lineage_id from promptflow._utils.logger_utils import LoggerFactory from promptflow.azure._constants._flow import ( @@ -59,6 +65,7 @@ from promptflow.azure._restclient.flow_service_caller import FlowServiceCaller from promptflow.azure._utils.gerneral import get_user_alias_from_credential from promptflow.azure.operations._flow_operations import FlowOperations +from promptflow.exceptions import UserErrorException RUNNING_STATUSES = RunStatus.get_running_statuses() @@ -109,7 +116,7 @@ def __init__( self._credential = credential self._flow_operations = flow_operations self._orchestrators = OperationOrchestrator(self._all_operations, 
self._operation_scope, self._operation_config) - self._workspace_default_datastore = self._datastore_operations.get_default().name + self._workspace_default_datastore = self._datastore_operations.get_default() @property def _data_operations(self): @@ -761,7 +768,7 @@ def _get_data_type(_data): self._operation_scope, self._datastore_operations, test_data, - datastore_name=self._workspace_default_datastore, + datastore_name=self._workspace_default_datastore.name, show_progress=self._show_progress, ) if data_type == AssetTypes.URI_FOLDER and test_data and not test_data.endswith("/"): @@ -994,3 +1001,52 @@ def _resolve_flow_definition_resource_id(self, run: Run): workspace_id = self._workspace._workspace_id location = self._workspace.location return f"azureml://locations/{location}/workspaces/{workspace_id}/flows/{run._flow_name}" + + @monitor_operation(activity_name="pfazure.runs.download", activity_type=ActivityType.PUBLICAPI) + def download( + self, run: Union[str, Run], output: Optional[Union[str, Path]] = None, overwrite: Optional[bool] = False + ) -> str: + """Download the data of a run, including input, output, snapshot and other run information. + + :param run: The run name or run object + :type run: Union[str, ~promptflow.entities.Run] + :param output: The output directory. Default to be "~/.promptflow/.runs" folder. + :type output: Optional[Union[str, Path]] + :param overwrite: Whether to overwrite the existing run folder. Default to be False. 
+ :type overwrite: Optional[bool] + :return: The run directory path + :rtype: str + """ + import platform + + from promptflow.azure.operations._async_run_downloader import AsyncRunDownloader + + run = Run._validate_and_return_run_name(run) + if output is None: + # default to be "~/.promptflow/.runs" folder + output_directory = Path.home() / PROMPT_FLOW_DIR_NAME / PROMPT_FLOW_RUNS_DIR_NAME + else: + output_directory = Path(output) + + run_folder = output_directory / run + if run_folder.exists(): + if overwrite is True: + logger.warning("Removing existing run folder %r.", run_folder.resolve().as_posix()) + shutil.rmtree(run_folder) + else: + raise UserErrorException( + f"Run folder {run_folder.resolve().as_posix()!r} already exists, please specify a new output path " + f"or set the overwrite flag to be true." + ) + run_folder.mkdir(parents=True) + + run_downloader = AsyncRunDownloader(run=run, run_ops=self, output_folder=run_folder) + if platform.system().lower() == "windows": + # Reference: https://stackoverflow.com/questions/45600579/asyncio-event-loop-is-closed-when-getting-loop + # On Windows seems to be a problem with EventLoopPolicy, use this snippet to work around it + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + async_run_allowing_running_loop(run_downloader.download) + result_path = run_folder.resolve().as_posix() + logger.info(f"Successfully downloaded run {run!r} to {result_path!r}.") + return result_path diff --git a/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_run_operations.py b/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_run_operations.py index 63f55fa5098..9f2e7c92326 100644 --- a/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_run_operations.py +++ b/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_run_operations.py @@ -796,6 +796,23 @@ def test_vnext_workspace_base_url(self): ) assert service_caller.caller._client._base_url == "https://promptflow.azure-api.net/" + 
@pytest.mark.skipif(condition=not is_live(), reason="need to fix recording") + def test_download_run(self, pf): + from promptflow.azure.operations._async_run_downloader import AsyncRunDownloader + + run = "c619f648-c809-4545-9f94-f67b0a680706" + + expected_files = [ + AsyncRunDownloader.LOCAL_LOGS_FILE_NAME, + AsyncRunDownloader.LOCAL_METRICS_FILE_NAME, + f"{AsyncRunDownloader.LOCAL_SNAPSHOT_FOLDER}/flow.dag.yaml", + ] + + with TemporaryDirectory() as tmp_dir: + pf.runs.download(run=run, output=tmp_dir) + for file in expected_files: + assert Path(tmp_dir, run, file).exists() + def test_request_id_when_making_http_requests(self, pf, runtime: str, randstr: Callable[[str], str]): from azure.core.exceptions import HttpResponseError diff --git a/src/promptflow/tests/sdk_cli_azure_test/unittests/test_cli.py b/src/promptflow/tests/sdk_cli_azure_test/unittests/test_cli.py index 7a496d8fb0a..dcf8761c62e 100644 --- a/src/promptflow/tests/sdk_cli_azure_test/unittests/test_cli.py +++ b/src/promptflow/tests/sdk_cli_azure_test/unittests/test_cli.py @@ -320,3 +320,20 @@ def check_workspace_info(*args, **kwargs): "--include-archived", *operation_scope_args, ) + + def test_run_download(self, mocker: MockFixture, operation_scope_args): + from promptflow.azure.operations._run_operations import RunOperations + + mocked = mocker.patch.object(RunOperations, "download") + mocked.return_value = "fake_output_run_dir" + run_pf_command( + "run", + "download", + "--name", + "test_run", + "--output", + "fake_output_dir", + "--overwrite", + *operation_scope_args, + ) + mocked.assert_called_once()