From fa6dc4a9ae02d3f14c34125baf3a02130baf2bab Mon Sep 17 00:00:00 2001 From: Honglin Date: Wed, 6 Dec 2023 18:29:04 +0800 Subject: [PATCH] [SDK/CLI] Add pfazure download run feature. (#1378) # Description ![image](https://github.com/microsoft/promptflow/assets/2572521/1f1f2cdf-6b2f-4706-aa42-114a0249afdd) ![image](https://github.com/microsoft/promptflow/assets/2572521/b016e614-b8c9-4404-ae4f-fbc8659c0a2c) # All Promptflow Contribution checklist: - [ ] **The pull request does not introduce [breaking changes].** - [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.** - [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).** - [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).** ## General Guidelines and Best Practices - [ ] Title of the pull request is clear and informative. - [ ] There are a small number of commits, each of which have an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md). ### Testing Guidelines - [ ] Pull request includes test coverage for the included changes. --- src/promptflow/CHANGELOG.md | 3 + src/promptflow/promptflow/_cli/_params.py | 64 +---- .../promptflow/_cli/_pf_azure/_run.py | 39 ++- src/promptflow/promptflow/_sdk/_constants.py | 1 + src/promptflow/promptflow/_sdk/_errors.py | 6 + .../azure/operations/_async_run_downloader.py | 252 ++++++++++++++++++ .../azure/operations/_run_operations.py | 60 ++++- .../e2etests/test_run_operations.py | 17 ++ .../sdk_cli_azure_test/unittests/test_cli.py | 17 ++ 9 files changed, 405 insertions(+), 54 deletions(-) create mode 100644 src/promptflow/promptflow/azure/operations/_async_run_downloader.py diff --git a/src/promptflow/CHANGELOG.md b/src/promptflow/CHANGELOG.md index cd49623225d..38dfa0a6d1d 100644 --- a/src/promptflow/CHANGELOG.md +++ b/src/promptflow/CHANGELOG.md @@ -2,6 +2,9 @@ ## 1.2.0 (upcoming) +### Features Added +- [SDK/CLI] Support `pfazure run download` to download run data from Azure AI. + ### Bugs Fixed - [SDK/CLI] Removing telemetry warning when running commands. diff --git a/src/promptflow/promptflow/_cli/_params.py b/src/promptflow/promptflow/_cli/_params.py index df17434c5fe..3aea229b36b 100644 --- a/src/promptflow/promptflow/_cli/_params.py +++ b/src/promptflow/promptflow/_cli/_params.py @@ -56,20 +56,20 @@ def add_param_ua(parser): def add_param_flow_display_name(parser): - parser.add_argument("--flow", type=str, required=True, help="the flow name to create.") + parser.add_argument("--flow", type=str, required=True, help="The flow name to create.") def add_param_entry(parser): - parser.add_argument("--entry", type=str, help="the entry file.") + parser.add_argument("--entry", type=str, help="The entry file.") def add_param_function(parser): - parser.add_argument("--function", type=str, help="the function name in entry file.") + parser.add_argument("--function", type=str, help="The function name in entry file.") def add_param_prompt_template(parser): parser.add_argument( - "--prompt-template", action=AppendToDictAction, help="the prompt template parameter and assignment.", nargs="+" + "--prompt-template", action=AppendToDictAction, help="The prompt template parameter and assignment.", nargs="+" ) @@ -146,58 +146,30 @@ def add_param_inputs(parser): ) -def add_param_input(parser): - parser.add_argument( - "--input", type=str, required=True, help="the input file path. Note that we accept jsonl file only for now." - ) - - def add_param_env(parser): parser.add_argument( "--env", type=str, default=None, - help="the dotenv file path containing the environment variables to be used in the flow.", + help="The dotenv file path containing the environment variables to be used in the flow.", ) def add_param_output(parser): - parser.add_argument("--output", type=str, default="outputs", help="the output directory to store the results.") - - -def add_param_flow(parser): - parser.add_argument("--flow", type=str, required=True, help="the evaluation flow to be used.") - - -def add_param_source(parser): - parser.add_argument("--source", type=str, required=True, help="The flow or run source to be used.") - - -def add_param_bulk_run_output(parser): - parser.add_argument("--bulk-run-output", type=str, help="the output directory of the bulk run.") - - -def add_param_eval_output(parser): - parser.add_argument("--eval-output", type=str, help="the output file path of the evaluation result.") - - -def add_param_column_mapping(parser): parser.add_argument( - "--column-mapping", type=str, required=True, help="the column mapping to be used in the evaluation." + "-o", + "--output", + type=str, + help="The output directory to store the results. Default to be current working directory if not specified.", ) -def add_param_runtime(parser): - parser.add_argument( - "--runtime", - type=str, - default="local", - help="Name of your runtime in Azure ML workspace, will run in cloud when runtime is not none.", - ) +def add_param_overwrite(parser): + parser.add_argument("--overwrite", action="store_true", help="Overwrite the existing results.") -def add_param_connection(parser): - parser.add_argument("--connection", type=str, help="Name of your connection in Azure ML workspace.") +def add_param_source(parser): + parser.add_argument("--source", type=str, required=True, help="The flow or run source to be used.") def add_param_run_name(parser): @@ -208,16 +180,6 @@ def add_param_connection_name(parser): parser.add_argument("-n", "--name", type=str, help="Name of the connection to create.") -def add_param_variants(parser): - parser.add_argument( - "--variants", - type=str, - nargs="+", - help="the variant run ids to be used in the evaluation. Note that we only support one variant for now.", - default=[], - ) - - def add_param_max_results(parser): parser.add_argument( # noqa: E731 "-r", diff --git a/src/promptflow/promptflow/_cli/_pf_azure/_run.py b/src/promptflow/promptflow/_cli/_pf_azure/_run.py index ff2a7c701e4..6eb10af7be9 100644 --- a/src/promptflow/promptflow/_cli/_pf_azure/_run.py +++ b/src/promptflow/promptflow/_cli/_pf_azure/_run.py @@ -13,7 +13,9 @@ add_param_debug, add_param_include_archived, add_param_max_results, + add_param_output, add_param_output_format, + add_param_overwrite, add_param_run_name, add_param_set, add_param_verbose, @@ -52,6 +54,7 @@ def add_parser_run(subparsers): add_parser_run_archive(subparsers) add_parser_run_restore(subparsers) add_parser_run_update(subparsers) + add_parser_run_download(subparsers) run_parser.set_defaults(action="run") @@ -334,7 +337,7 @@ def add_parser_run_update(subparsers): Example: # Update a run metadata: -pf run update --name --set display_name="" description="" tags.key="" +pfazure run update --name --set display_name="" description="" tags.key="" """ add_params = [ _set_workspace_argument_for_subparsers, @@ -353,6 +356,32 @@ def add_parser_run_update(subparsers): ) +def add_parser_run_download(subparsers): + """Add run download parser to the pfazure subparsers.""" + epilog = """ +Example: + +# Download a run data to local: +pfazure run download --name --output +""" + add_params = [ + add_param_run_name, + add_param_output, + add_param_overwrite, + _set_workspace_argument_for_subparsers, + ] + base_params + + activate_action( + name="download", + description="A CLI tool to download a run.", + epilog=epilog, + add_params=add_params, + subparsers=subparsers, + help_message="Download a run.", + action_param_name="sub_action", + ) + + def dispatch_run_commands(args: argparse.Namespace): if args.sub_action == "create": pf = _get_azure_pf_client(args.subscription, args.resource_group, args.workspace_name, debug=args.debug) @@ -403,6 +432,8 @@ def dispatch_run_commands(args: argparse.Namespace): restore_run(args.subscription, args.resource_group, args.workspace_name, args.name) elif args.sub_action == "update": update_run(args.subscription, args.resource_group, args.workspace_name, args.name, params=args.params_override) + elif args.sub_action == "download": + download_run(args) @exception_handler("List runs") @@ -530,3 +561,9 @@ def update_run( pf = _get_azure_pf_client(subscription_id, resource_group, workspace_name) run = pf.runs.update(run=run_name, display_name=display_name, description=description, tags=tags) print(json.dumps(run._to_dict(), indent=4)) + + +@exception_handler("Download run") +def download_run(args: argparse.Namespace): + pf = _get_azure_pf_client(args.subscription, args.resource_group, args.workspace_name, debug=args.debug) + pf.runs.download(run=args.name, output=args.output, overwrite=args.overwrite) diff --git a/src/promptflow/promptflow/_sdk/_constants.py b/src/promptflow/promptflow/_sdk/_constants.py index 2ef182e6d43..399251669ec 100644 --- a/src/promptflow/promptflow/_sdk/_constants.py +++ b/src/promptflow/promptflow/_sdk/_constants.py @@ -18,6 +18,7 @@ FLOW_TOOLS_JSON = "flow.tools.json" FLOW_TOOLS_JSON_GEN_TIMEOUT = 60 PROMPT_FLOW_DIR_NAME = ".promptflow" +PROMPT_FLOW_RUNS_DIR_NAME = ".runs" HOME_PROMPT_FLOW_DIR = (Path.home() / PROMPT_FLOW_DIR_NAME).resolve() SERVICE_CONFIG_FILE = "pf.yaml" PF_SERVICE_PORT_FILE = "pfs.port" diff --git a/src/promptflow/promptflow/_sdk/_errors.py b/src/promptflow/promptflow/_sdk/_errors.py index c850b567f77..9e3e686c2eb 100644 --- a/src/promptflow/promptflow/_sdk/_errors.py +++ b/src/promptflow/promptflow/_sdk/_errors.py @@ -92,6 +92,12 @@ class RunOperationParameterError(PromptflowException): pass +class RunOperationError(PromptflowException): + """Exception raised when run operation failed.""" + + pass + + class FlowOperationError(PromptflowException): """Exception raised when flow operation failed.""" diff --git a/src/promptflow/promptflow/azure/operations/_async_run_downloader.py b/src/promptflow/promptflow/azure/operations/_async_run_downloader.py new file mode 100644 index 00000000000..75176e109e1 --- /dev/null +++ b/src/promptflow/promptflow/azure/operations/_async_run_downloader.py @@ -0,0 +1,252 @@ +import asyncio +import json +import logging +from pathlib import Path +from typing import Optional, Union + +import httpx +from azure.storage.blob.aio import BlobServiceClient + +from promptflow._sdk._constants import DEFAULT_ENCODING, LOGGER_NAME +from promptflow._sdk._errors import RunNotFoundError, RunOperationError +from promptflow._utils.logger_utils import LoggerFactory +from promptflow.exceptions import UserErrorException + +logger = LoggerFactory.get_logger(name=LOGGER_NAME, verbosity=logging.WARNING) + + +class AsyncRunDownloader: + """Download run results from the service asynchronously. + + :param run: The run id. + :type run: str + :param run_ops: The run operations. + :type run_ops: ~promptflow.azure.operations.RunOperations + :param output_folder: The output folder to save the run results. + :type output_folder: Union[Path, str] + """ + + LOCAL_SNAPSHOT_FOLDER = "snapshot" + LOCAL_METRICS_FILE_NAME = "metrics.json" + LOCAL_LOGS_FILE_NAME = "logs.txt" + + IGNORED_PATTERN = ["__pycache__"] + + def __init__(self, run: str, run_ops: "RunOperations", output_folder: Union[str, Path]) -> None: + self.run = run + self.run_ops = run_ops + self.datastore = run_ops._workspace_default_datastore + self.output_folder = Path(output_folder) + self.blob_service_client = self._init_blob_service_client() + self._use_flow_outputs = False # old runtime does not write debug_info output asset, use flow_outputs instead + + def _init_blob_service_client(self): + logger.debug("Initializing blob service client.") + account_url = f"{self.datastore.account_name}.blob.{self.datastore.endpoint}" + return BlobServiceClient(account_url=account_url, credential=self.run_ops._credential) + + async def download(self) -> str: + """Download the run results asynchronously.""" + try: + # pass verify=False to client to disable SSL verification. + # Source: https://github.com/encode/httpx/issues/1331 + async with httpx.AsyncClient(verify=False) as client: + + async_tasks = [ + # put async functions in tasks to run in coroutines + self._download_artifacts_and_snapshot(client), + ] + sync_tasks = [ + # below functions are actually synchronous functions in order to reuse code, + # the execution time of these functions should be shorter than the above async functions + # so it won't increase the total execution time. + # the reason we still put them in the tasks is, on one hand the code is more consistent and + # we can use asyncio.gather() to wait for all tasks to finish, on the other hand, we can + # also evaluate below functions to be shorter than the async functions with the help of logs + self._download_run_metrics(), + self._download_run_logs(), + ] + tasks = async_tasks + sync_tasks + await asyncio.gather(*tasks) + except Exception as e: + raise RunOperationError(f"Failed to download run {self.run!r}. Error: {e}") from e + + return self.output_folder.resolve().as_posix() + + async def _download_artifacts_and_snapshot(self, httpx_client: httpx.AsyncClient): + run_data = await self._get_run_data_from_run_history(httpx_client) + + logger.debug("Parsing run data from run history to get necessary information.") + # extract necessary information from run data + snapshot_id = run_data["runMetadata"]["properties"]["azureml.promptflow.snapshot_id"] + output_data = run_data["runMetadata"]["outputs"].get("debug_info", None) + if output_data is None: + logger.warning( + "Downloading run '%s' but the 'debug_info' output assets is not available, " + "maybe because the job ran on old version runtime, trying to get `flow_outputs` output asset instead.", + self.run, + ) + self._use_flow_outputs = True + output_data = run_data["runMetadata"]["outputs"].get("flow_outputs", None) + output_asset_id = output_data["assetId"] + + async with self.blob_service_client: + container_name = self.datastore.container_name + logger.debug("Getting container client (%s) from workspace default datastore.", container_name) + container_client = self.blob_service_client.get_container_client(container_name) + + async with container_client: + tasks = [ + self._download_flow_artifacts(httpx_client, container_client, output_asset_id), + self._download_snapshot(httpx_client, container_client, snapshot_id), + ] + await asyncio.gather(*tasks) + + async def _get_run_data_from_run_history(self, client: httpx.AsyncClient): + """Get the run data from the run history.""" + logger.debug("Getting run data from run history.") + headers = self.run_ops._get_headers() + url = self.run_ops._run_history_endpoint_url + "/rundata" + + payload = { + "runId": self.run, + "selectRunMetadata": True, + "selectRunDefinition": True, + "selectJobSpecification": True, + } + + response = await client.post(url, headers=headers, json=payload) + if response.status_code == 200: + return response.json() + elif response.status_code == 404: + raise RunNotFoundError(f"Run {self.run!r} not found.") + else: + raise RunOperationError( + f"Failed to get run from service. Code: {response.status_code}, text: {response.text}" + ) + + async def _download_run_metrics( + self, + ): + """Download the run metrics.""" + logger.debug("Downloading run metrics.") + metrics = self.run_ops.get_metrics(self.run) + with open(self.output_folder / self.LOCAL_METRICS_FILE_NAME, "w", encoding=DEFAULT_ENCODING) as f: + json.dump(metrics, f, ensure_ascii=False) + logger.debug("Downloaded run metrics.") + + async def _download_flow_artifacts(self, httpx_client: httpx.AsyncClient, container_client, output_data): + """Download the output data.""" + asset_path = await self._get_asset_path(httpx_client, output_data) + await self._download_blob_folder_from_asset_path(container_client, asset_path) + + async def _download_blob_folder_from_asset_path( + self, container_client, asset_path: str, local_folder: Optional[Path] = None + ): + """Download the blob data from the data path.""" + logger.debug("Downloading all blobs from data path prefix '%s'", asset_path) + if local_folder is None: + local_folder = self.output_folder + + tasks = [] + async for blob in container_client.list_blobs(name_starts_with=asset_path): + blob_client = container_client.get_blob_client(blob.name) + relative_path = Path(blob.name).relative_to(asset_path) + local_path = local_folder / relative_path + tasks.append(self._download_single_blob(blob_client, local_path)) + await asyncio.gather(*tasks) + + async def _download_single_blob(self, blob_client, local_path: Optional[Path] = None): + """Download a single blob.""" + if local_path is None: + local_path = Path(self.output_folder / blob_client.blob_name) + elif local_path.exists(): + raise UserErrorException(f"Local file {local_path.resolve().as_posix()!r} already exists.") + + # ignore some files + for item in self.IGNORED_PATTERN: + if item in blob_client.blob_name: + logger.warning( + "Ignoring file '%s' because it matches the ignored pattern '%s'", local_path.as_posix(), item + ) + return None + + logger.debug("Downloading blob '%s' to local path '%s'", blob_client.blob_name, local_path.resolve().as_posix()) + local_path.parent.mkdir(parents=True, exist_ok=True) + async with blob_client: + with open(local_path, "wb") as f: + stream = await blob_client.download_blob() + data = await stream.readall() + # TODO: File IO may block the event loop, consider using thread pool. e.g. to_thread() method + f.write(data) + return local_path + + async def _download_snapshot(self, httpx_client: httpx.AsyncClient, container_client, snapshot_id): + """Download the flow snapshot.""" + snapshot_urls = await self._get_flow_snapshot_urls(httpx_client, snapshot_id) + + logger.debug("Downloading all snapshot blobs from snapshot urls.") + tasks = [] + for url in snapshot_urls: + blob_name = url.split(self.datastore.container_name)[-1].lstrip("/") + blob_client = container_client.get_blob_client(blob_name) + relative_path = url.split(self.run)[-1].lstrip("/") + local_path = Path(self.output_folder / self.LOCAL_SNAPSHOT_FOLDER / relative_path) + tasks.append(self._download_single_blob(blob_client, local_path)) + await asyncio.gather(*tasks) + + async def _get_flow_snapshot_urls(self, httpx_client: httpx.AsyncClient, snapshot_id): + logger.debug("Getting flow snapshot blob urls from snapshot id with calling to content service.") + headers = self.run_ops._get_headers() + endpoint = self.run_ops._run_history_endpoint_url.replace("/history/v1.0", "/content/v2.0") + url = endpoint + "/snapshots/sas" + payload = { + "snapshotOrAssetId": snapshot_id, + } + + response = await httpx_client.post(url, headers=headers, json=payload) + if response.status_code == 200: + return self._parse_snapshot_response(response.json()) + elif response.status_code == 404: + raise UserErrorException(f"Snapshot {snapshot_id!r} not found.") + else: + raise RunOperationError( + f"Failed to get snapshot {snapshot_id!r} from content service. " + f"Code: {response.status_code}, text: {response.text}" + ) + + async def _get_asset_path(self, client: httpx.AsyncClient, asset_id): + """Get the asset path from asset id.""" + logger.debug("Getting asset path from asset id with calling to data service.") + headers = self.run_ops._get_headers() + endpoint = self.run_ops._run_history_endpoint_url.replace("/history", "/data") + url = endpoint + "/dataversion/getByAssetId" + payload = { + "value": asset_id, + } + + response = await client.post(url, headers=headers, json=payload) + response_data = response.json() + data_path = response_data["dataVersion"]["dataUri"].split("/paths/")[-1] + if self._use_flow_outputs: + data_path = data_path.replace("flow_outputs", "flow_artifacts") + return data_path + + def _parse_snapshot_response(self, response: dict): + """Parse the snapshot response.""" + urls = [] + if response["absoluteUrl"]: + urls.append(response["absoluteUrl"]) + for value in response["children"].values(): + urls += self._parse_snapshot_response(value) + + return urls + + async def _download_run_logs(self): + """Download the run logs.""" + logger.debug("Downloading run logs.") + logs = self.run_ops._get_log(self.run) + + with open(self.output_folder / self.LOCAL_LOGS_FILE_NAME, "w", encoding=DEFAULT_ENCODING) as f: + f.write(logs) + logger.debug("Downloaded run logs.") diff --git a/src/promptflow/promptflow/azure/operations/_run_operations.py b/src/promptflow/promptflow/azure/operations/_run_operations.py index fe64c3c845d..d7d56f8b03a 100644 --- a/src/promptflow/promptflow/azure/operations/_run_operations.py +++ b/src/promptflow/promptflow/azure/operations/_run_operations.py @@ -1,6 +1,7 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- +import asyncio import concurrent import copy import hashlib @@ -8,10 +9,12 @@ import logging import os import re +import shutil import sys import time from concurrent.futures import ThreadPoolExecutor from functools import cached_property +from pathlib import Path from typing import Any, Dict, List, Optional, Union import requests @@ -33,6 +36,8 @@ LOGGER_NAME, MAX_RUN_LIST_RESULTS, MAX_SHOW_DETAILS_RESULTS, + PROMPT_FLOW_DIR_NAME, + PROMPT_FLOW_RUNS_DIR_NAME, REGISTRY_URI_PREFIX, VIS_PORTAL_URL_TMPL, AzureRunTypes, @@ -45,6 +50,7 @@ from promptflow._sdk._telemetry import ActivityType, WorkspaceTelemetryMixin, monitor_operation from promptflow._sdk._utils import in_jupyter_notebook, incremental_print, is_remote_uri, print_red_error from promptflow._sdk.entities import Run +from promptflow._utils.async_utils import async_run_allowing_running_loop from promptflow._utils.flow_utils import get_flow_lineage_id from promptflow._utils.logger_utils import LoggerFactory from promptflow.azure._constants._flow import ( @@ -59,6 +65,7 @@ from promptflow.azure._restclient.flow_service_caller import FlowServiceCaller from promptflow.azure._utils.gerneral import get_user_alias_from_credential from promptflow.azure.operations._flow_operations import FlowOperations +from promptflow.exceptions import UserErrorException RUNNING_STATUSES = RunStatus.get_running_statuses() @@ -109,7 +116,7 @@ def __init__( self._credential = credential self._flow_operations = flow_operations self._orchestrators = OperationOrchestrator(self._all_operations, self._operation_scope, self._operation_config) - self._workspace_default_datastore = self._datastore_operations.get_default().name + self._workspace_default_datastore = self._datastore_operations.get_default() @property def _data_operations(self): @@ -761,7 +768,7 @@ def _get_data_type(_data): self._operation_scope, self._datastore_operations, test_data, - datastore_name=self._workspace_default_datastore, + datastore_name=self._workspace_default_datastore.name, show_progress=self._show_progress, ) if data_type == AssetTypes.URI_FOLDER and test_data and not test_data.endswith("/"): @@ -994,3 +1001,52 @@ def _resolve_flow_definition_resource_id(self, run: Run): workspace_id = self._workspace._workspace_id location = self._workspace.location return f"azureml://locations/{location}/workspaces/{workspace_id}/flows/{run._flow_name}" + + @monitor_operation(activity_name="pfazure.runs.download", activity_type=ActivityType.PUBLICAPI) + def download( + self, run: Union[str, Run], output: Optional[Union[str, Path]] = None, overwrite: Optional[bool] = False + ) -> str: + """Download the data of a run, including input, output, snapshot and other run information. + + :param run: The run name or run object + :type run: Union[str, ~promptflow.entities.Run] + :param output: The output directory. Default to be default to be "~/.promptflow/.runs" folder. + :type output: Optional[str] + :param overwrite: Whether to overwrite the existing run folder. Default to be False. + :type overwrite: Optional[bool] + :return: The run directory path + :rtype: str + """ + import platform + + from promptflow.azure.operations._async_run_downloader import AsyncRunDownloader + + run = Run._validate_and_return_run_name(run) + if output is None: + # default to be "~/.promptflow/.runs" folder + output_directory = Path.home() / PROMPT_FLOW_DIR_NAME / PROMPT_FLOW_RUNS_DIR_NAME + else: + output_directory = Path(output) + + run_folder = output_directory / run + if run_folder.exists(): + if overwrite is True: + logger.warning("Removing existing run folder %r.", run_folder.resolve().as_posix()) + shutil.rmtree(run_folder) + else: + raise UserErrorException( + f"Run folder {run_folder.resolve().as_posix()!r} already exists, please specify a new output path " + f"or set the overwrite flag to be true." + ) + run_folder.mkdir(parents=True) + + run_downloader = AsyncRunDownloader(run=run, run_ops=self, output_folder=run_folder) + if platform.system().lower() == "windows": + # Reference: https://stackoverflow.com/questions/45600579/asyncio-event-loop-is-closed-when-getting-loop + # On Windows seems to be a problem with EventLoopPolicy, use this snippet to work around it + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + async_run_allowing_running_loop(run_downloader.download) + result_path = run_folder.resolve().as_posix() + logger.info(f"Successfully downloaded run {run!r} to {result_path!r}.") + return result_path diff --git a/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_run_operations.py b/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_run_operations.py index 63f55fa5098..9f2e7c92326 100644 --- a/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_run_operations.py +++ b/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_run_operations.py @@ -796,6 +796,23 @@ def test_vnext_workspace_base_url(self): ) assert service_caller.caller._client._base_url == "https://promptflow.azure-api.net/" + @pytest.mark.skipif(condition=not is_live(), reason="need to fix recording") + def test_download_run(self, pf): + from promptflow.azure.operations._async_run_downloader import AsyncRunDownloader + + run = "c619f648-c809-4545-9f94-f67b0a680706" + + expected_files = [ + AsyncRunDownloader.LOCAL_LOGS_FILE_NAME, + AsyncRunDownloader.LOCAL_METRICS_FILE_NAME, + f"{AsyncRunDownloader.LOCAL_SNAPSHOT_FOLDER}/flow.dag.yaml", + ] + + with TemporaryDirectory() as tmp_dir: + pf.runs.download(run=run, output=tmp_dir) + for file in expected_files: + assert Path(tmp_dir, run, file).exists() + def test_request_id_when_making_http_requests(self, pf, runtime: str, randstr: Callable[[str], str]): from azure.core.exceptions import HttpResponseError diff --git a/src/promptflow/tests/sdk_cli_azure_test/unittests/test_cli.py b/src/promptflow/tests/sdk_cli_azure_test/unittests/test_cli.py index 7a496d8fb0a..dcf8761c62e 100644 --- a/src/promptflow/tests/sdk_cli_azure_test/unittests/test_cli.py +++ b/src/promptflow/tests/sdk_cli_azure_test/unittests/test_cli.py @@ -320,3 +320,20 @@ def check_workspace_info(*args, **kwargs): "--include-archived", *operation_scope_args, ) + + def test_run_download(self, mocker: MockFixture, operation_scope_args): + from promptflow.azure.operations._run_operations import RunOperations + + mocked = mocker.patch.object(RunOperations, "download") + mocked.return_value = "fake_output_run_dir" + run_pf_command( + "run", + "download", + "--name", + "test_run", + "--output", + "fake_output_dir", + "--overwrite", + *operation_scope_args, + ) + mocked.assert_called_once()