From 9f2bc059b2cff0a318e0812d30415afbd4edf48c Mon Sep 17 00:00:00 2001
From: Zhengfei Wang <38847871+zhengfeiwang@users.noreply.github.com>
Date: Thu, 25 Apr 2024 17:46:08 +0800
Subject: [PATCH 1/6] [trace][refactor] Move tracing related functions to a better place (#2990)

# Description

**Move tracing-related functions**

- `promptflow._sdk._tracing`: the import location for tracing-related functions, covered and guarded by unit tests
  - for `promptflow-tracing`: `start_trace_with_devkit`, `setup_exporter_to_pfs`
  - for the OTLP collector, runtime and others: `process_otlp_trace_request`, which parses spans from Protocol Buffer
- `promptflow._sdk._utils.tracing`: utilities for tracing

Remove the previous tracing utilities file `_tracing_utils.py`.

**Pass a function that gets the credential**

When calling `process_otlp_trace_request`, the caller should now pass a function that returns the credential, instead of the credential object itself. However, since the environment may not have the Azure extension installed, we cannot pass `AzureCliCredential` in directly from the outside; instead, default logic inside the function resolves a credential when none is provided.

# All Promptflow Contribution checklist:
- [x] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.**
- [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [x] Title of the pull request is clear and informative.
- [x] There are a small number of commits, each of which has an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [x] Pull request includes test coverage for the included changes.
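As an illustration of the new convention, here is a minimal sketch of how a caller might resolve the `get_credential` argument. The helper name `resolve_get_credential` is hypothetical, and a plain try/except on the import stands in for the `_is_azure_ext_installed` check used in the actual collector change below:

```python
# Hypothetical helper sketching the get_credential convention; `credential`
# may be a credential object, a callable returning one, or None.
from promptflow._sdk._errors import MissingAzurePackage


def resolve_get_credential(credential=None):
    """Return a zero-argument callable that yields a credential."""
    if credential is not None:
        # old runtime versions may pass either a credential object or a callable
        return credential if callable(credential) else lambda: credential
    try:
        # local prompt flow service scenario: rely on Azure CLI credential
        from azure.identity import AzureCliCredential

        return AzureCliCredential
    except ImportError:
        # without `promptflow-azure`, hand over the exception class so the
        # failure surfaces clearly when a credential is first requested
        return MissingAzurePackage
```

The collector change in this patch applies the same idea to stay compatible with old runtime versions that still pass a credential object.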
--- .../promptflow/azure/_storage/blob/client.py | 17 +- .../azure/_storage/cosmosdb/client.py | 14 +- .../unittests/test_run_operations.py | 4 +- .../_sdk/_service/apis/collector.py | 37 ++- .../promptflow/_sdk/_tracing.py | 51 ++-- .../promptflow/_sdk/_tracing_utils.py | 145 --------- .../promptflow/_sdk/_utils/general_utils.py | 28 -- .../promptflow/_sdk/_utils/tracing.py | 285 ++++++++++++++++++ .../_sdk/operations/_trace_operations.py | 118 +------- .../sdk_cli_test/e2etests/test_flow_run.py | 5 +- .../sdk_cli_test/unittests/test_trace.py | 9 +- 11 files changed, 363 insertions(+), 350 deletions(-) delete mode 100644 src/promptflow-devkit/promptflow/_sdk/_tracing_utils.py create mode 100644 src/promptflow-devkit/promptflow/_sdk/_utils/tracing.py diff --git a/src/promptflow-azure/promptflow/azure/_storage/blob/client.py b/src/promptflow-azure/promptflow/azure/_storage/blob/client.py index 6f2085229d1..75f9e2bbb4c 100644 --- a/src/promptflow-azure/promptflow/azure/_storage/blob/client.py +++ b/src/promptflow-azure/promptflow/azure/_storage/blob/client.py @@ -2,7 +2,7 @@ import logging import threading import traceback -from typing import Optional, Tuple +from typing import Callable, Tuple from azure.ai.ml import MLClient from azure.ai.ml._azure_environments import _get_storage_endpoint_from_metadata @@ -25,17 +25,10 @@ def get_datastore_container_client( subscription_id: str, resource_group_name: str, workspace_name: str, - credential: Optional[object] = None, + get_credential: Callable, ) -> Tuple[ContainerClient, str]: try: - if credential is None: - # in cloud scenario, runtime will pass in credential - # so this is local to cloud only code, happens in prompt flow service - # which should rely on Azure CLI credential only - from azure.identity import AzureCliCredential - - credential = AzureCliCredential() - + credential = get_credential() datastore_definition, datastore_credential = _get_default_datastore( subscription_id, resource_group_name, workspace_name, credential ) @@ -68,7 +61,7 @@ def get_datastore_container_client( def _get_default_datastore( - subscription_id: str, resource_group_name: str, workspace_name: str, credential: Optional[object] + subscription_id: str, resource_group_name: str, workspace_name: str, credential ) -> Tuple[Datastore, str]: datastore_key = _get_datastore_client_key(subscription_id, resource_group_name, workspace_name) @@ -103,7 +96,7 @@ def _get_datastore_client_key(subscription_id: str, resource_group_name: str, wo def _get_aml_default_datastore( - subscription_id: str, resource_group_name: str, workspace_name: str, credential: Optional[object] + subscription_id: str, resource_group_name: str, workspace_name: str, credential ) -> Tuple[Datastore, str]: ml_client = MLClient( diff --git a/src/promptflow-azure/promptflow/azure/_storage/cosmosdb/client.py b/src/promptflow-azure/promptflow/azure/_storage/cosmosdb/client.py index 6e013ad7cfc..01a741da654 100644 --- a/src/promptflow-azure/promptflow/azure/_storage/cosmosdb/client.py +++ b/src/promptflow-azure/promptflow/azure/_storage/cosmosdb/client.py @@ -5,7 +5,7 @@ import ast import datetime import threading -from typing import Optional +from typing import Callable client_map = {} _thread_lock = threading.Lock() @@ -18,7 +18,7 @@ def get_client( subscription_id: str, resource_group_name: str, workspace_name: str, - credential: Optional[object] = None, + get_credential: Callable, ): client_key = _get_db_client_key(container_name, subscription_id, resource_group_name, workspace_name) container_client = 
_get_client_from_map(client_key) @@ -28,13 +28,7 @@ def get_client( with container_lock: container_client = _get_client_from_map(client_key) if container_client is None: - if credential is None: - # in cloud scenario, runtime will pass in credential - # so this is local to cloud only code, happens in prompt flow service - # which should rely on Azure CLI credential only - from azure.identity import AzureCliCredential - - credential = AzureCliCredential() + credential = get_credential() token = _get_resource_token( container_name, subscription_id, resource_group_name, workspace_name, credential ) @@ -77,7 +71,7 @@ def _get_resource_token( subscription_id: str, resource_group_name: str, workspace_name: str, - credential: Optional[object], + credential, ) -> object: from promptflow.azure import PFClient diff --git a/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_run_operations.py b/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_run_operations.py index f85ace022af..a334bc296d9 100644 --- a/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_run_operations.py +++ b/src/promptflow-azure/tests/sdk_cli_azure_test/unittests/test_run_operations.py @@ -9,7 +9,7 @@ from sdk_cli_azure_test.conftest import DATAS_DIR, EAGER_FLOWS_DIR, FLOWS_DIR from promptflow._sdk._errors import RunOperationParameterError, UploadUserError, UserAuthenticationError -from promptflow._sdk._utils import parse_otel_span_status_code +from promptflow._sdk._utils.tracing import _parse_otel_span_status_code from promptflow._sdk.entities import Run from promptflow._sdk.operations._run_operations import RunOperations from promptflow._utils.async_utils import async_run_allowing_running_loop @@ -88,7 +88,7 @@ def test_flex_flow_with_imported_func(self, pf: PFClient): # TODO(3017093): won't support this for now with pytest.raises(UserErrorException) as e: pf.run( - flow=parse_otel_span_status_code, + flow=_parse_otel_span_status_code, data=f"{DATAS_DIR}/simple_eager_flow_data.jsonl", # set code folder to avoid snapshot too big code=f"{EAGER_FLOWS_DIR}/multiple_entries", diff --git a/src/promptflow-devkit/promptflow/_sdk/_service/apis/collector.py b/src/promptflow-devkit/promptflow/_sdk/_service/apis/collector.py index 766eef4ffcd..8820a11a168 100644 --- a/src/promptflow-devkit/promptflow/_sdk/_service/apis/collector.py +++ b/src/promptflow-devkit/promptflow/_sdk/_service/apis/collector.py @@ -13,7 +13,8 @@ from flask import request from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceRequest -from promptflow._sdk._tracing import process_otlp_trace_request +from promptflow._sdk._errors import MissingAzurePackage +from promptflow._sdk._tracing import _is_azure_ext_installed, process_otlp_trace_request def trace_collector( @@ -41,13 +42,33 @@ def trace_collector( if "application/x-protobuf" in content_type: trace_request = ExportTraceServiceRequest() trace_request.ParseFromString(request.data) - process_otlp_trace_request( - trace_request=trace_request, - get_created_by_info_with_cache=get_created_by_info_with_cache, - logger=logger, - cloud_trace_only=cloud_trace_only, - credential=credential, - ) + # this function will be called in some old runtime versions + # where runtime will pass either credential object, or the function to get credential + # as we need to be compatible with this, need to handle both cases + if credential is not None: + # local prompt flow service will not pass credential, so this is runtime scenario + get_credential = credential if 
callable(credential) else lambda: credential # noqa: F841 + process_otlp_trace_request( + trace_request=trace_request, + get_created_by_info_with_cache=get_created_by_info_with_cache, + logger=logger, + get_credential=get_credential, + cloud_trace_only=cloud_trace_only, + ) + else: + # if `promptflow-azure` is not installed, pass an exception class to the function + get_credential = MissingAzurePackage + if _is_azure_ext_installed(): + from azure.identity import AzureCliCredential + + get_credential = AzureCliCredential + process_otlp_trace_request( + trace_request=trace_request, + get_created_by_info_with_cache=get_created_by_info_with_cache, + logger=logger, + get_credential=get_credential, + cloud_trace_only=cloud_trace_only, + ) return "Traces received", 200 # JSON protobuf encoding diff --git a/src/promptflow-devkit/promptflow/_sdk/_tracing.py b/src/promptflow-devkit/promptflow/_sdk/_tracing.py index c8516c61d86..e167adc6c39 100644 --- a/src/promptflow-devkit/promptflow/_sdk/_tracing.py +++ b/src/promptflow-devkit/promptflow/_sdk/_tracing.py @@ -51,12 +51,8 @@ is_port_in_use, is_run_from_built_binary, ) -from promptflow._sdk._tracing_utils import get_workspace_kind -from promptflow._sdk._utils import ( - add_executable_script_to_env_path, - extract_workspace_triad_from_trace_provider, - parse_kv_from_pb_attribute, -) +from promptflow._sdk._utils import add_executable_script_to_env_path, extract_workspace_triad_from_trace_provider +from promptflow._sdk._utils.tracing import get_workspace_kind, parse_kv_from_pb_attribute, parse_protobuf_span from promptflow._utils.logger_utils import get_cli_sdk_logger from promptflow._utils.thread_utils import ThreadWithContextVars from promptflow.tracing._integrations._openai_injector import inject_openai_api @@ -559,8 +555,8 @@ def process_otlp_trace_request( trace_request: ExportTraceServiceRequest, get_created_by_info_with_cache: typing.Callable, logger: logging.Logger, + get_credential: typing.Callable, cloud_trace_only: bool = False, - credential: typing.Optional[object] = None, ): """Process ExportTraceServiceRequest and write data to local/remote storage. @@ -572,13 +568,12 @@ def process_otlp_trace_request( :type get_created_by_info_with_cache: Callable :param logger: The logger object used for logging. :type logger: logging.Logger + :param get_credential: A function that gets credential for Cosmos DB operation. + :type get_credential: Callable :param cloud_trace_only: If True, only write trace to cosmosdb and skip local trace. Default is False. :type cloud_trace_only: bool - :param credential: The credential object used to authenticate with cosmosdb. Default is None. - :type credential: Optional[object] """ from promptflow._sdk.entities._trace import Span - from promptflow._sdk.operations._trace_operations import TraceOperations all_spans = [] for resource_span in trace_request.resource_spans: @@ -596,7 +591,7 @@ def process_otlp_trace_request( for scope_span in resource_span.scope_spans: for span in scope_span.spans: # TODO: persist with batch - span: Span = TraceOperations._parse_protobuf_span(span, resource=resource, logger=logger) + span: Span = parse_protobuf_span(span, resource=resource, logger=logger) if not cloud_trace_only: all_spans.append(copy.deepcopy(span)) span._persist() @@ -606,12 +601,14 @@ def process_otlp_trace_request( if cloud_trace_only: # If we only trace to cloud, we should make sure the data writing is success before return. 
- _try_write_trace_to_cosmosdb(all_spans, get_created_by_info_with_cache, logger, credential, is_cloud_trace=True) + _try_write_trace_to_cosmosdb( + all_spans, get_created_by_info_with_cache, logger, get_credential, is_cloud_trace=True + ) else: # Create a new thread to write trace to cosmosdb to avoid blocking the main thread ThreadWithContextVars( target=_try_write_trace_to_cosmosdb, - args=(all_spans, get_created_by_info_with_cache, logger, credential, False), + args=(all_spans, get_created_by_info_with_cache, logger, get_credential, False), ).start() return @@ -621,7 +618,7 @@ def _try_write_trace_to_cosmosdb( all_spans: typing.List, get_created_by_info_with_cache: typing.Callable, logger: logging.Logger, - credential: typing.Optional[object] = None, + get_credential: typing.Callable, is_cloud_trace: bool = False, ): if not all_spans: @@ -649,19 +646,31 @@ def _try_write_trace_to_cosmosdb( # So, we load clients in parallel for warm up. span_client_thread = ThreadWithContextVars( target=get_client, - args=(CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name, credential), + args=(CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name, get_credential), ) span_client_thread.start() collection_client_thread = ThreadWithContextVars( target=get_client, - args=(CosmosDBContainerName.COLLECTION, subscription_id, resource_group_name, workspace_name, credential), + args=( + CosmosDBContainerName.COLLECTION, + subscription_id, + resource_group_name, + workspace_name, + get_credential, + ), ) collection_client_thread.start() line_summary_client_thread = ThreadWithContextVars( target=get_client, - args=(CosmosDBContainerName.LINE_SUMMARY, subscription_id, resource_group_name, workspace_name, credential), + args=( + CosmosDBContainerName.LINE_SUMMARY, + subscription_id, + resource_group_name, + workspace_name, + get_credential, + ), ) line_summary_client_thread.start() @@ -677,7 +686,7 @@ def _try_write_trace_to_cosmosdb( subscription_id=subscription_id, resource_group_name=resource_group_name, workspace_name=workspace_name, - credential=credential, + get_credential=get_credential, ) span_client_thread.join() @@ -687,7 +696,7 @@ def _try_write_trace_to_cosmosdb( created_by = get_created_by_info_with_cache() collection_client = get_client( - CosmosDBContainerName.COLLECTION, subscription_id, resource_group_name, workspace_name, credential + CosmosDBContainerName.COLLECTION, subscription_id, resource_group_name, workspace_name, get_credential ) collection_db = CollectionCosmosDB(first_span, is_cloud_trace, created_by) @@ -701,7 +710,7 @@ def _try_write_trace_to_cosmosdb( for span in all_spans: try: span_client = get_client( - CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name, credential + CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name, get_credential ) result = SpanCosmosDB(span, collection_id, created_by).persist( span_client, blob_container_client, blob_base_uri @@ -713,7 +722,7 @@ def _try_write_trace_to_cosmosdb( subscription_id, resource_group_name, workspace_name, - credential, + get_credential, ) Summary(span, collection_id, created_by, logger).persist(line_summary_client) except Exception as e: diff --git a/src/promptflow-devkit/promptflow/_sdk/_tracing_utils.py b/src/promptflow-devkit/promptflow/_sdk/_tracing_utils.py deleted file mode 100644 index b56848e652d..00000000000 --- a/src/promptflow-devkit/promptflow/_sdk/_tracing_utils.py +++ /dev/null @@ -1,145 +0,0 @@ -# 
--------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# --------------------------------------------------------- - -import datetime -import json -import logging -import typing -from dataclasses import dataclass -from pathlib import Path - -from promptflow._sdk._constants import HOME_PROMPT_FLOW_DIR, AzureMLWorkspaceTriad -from promptflow._sdk._utils import json_load -from promptflow._utils.logger_utils import get_cli_sdk_logger -from promptflow.core._errors import MissingRequiredPackage - -_logger = get_cli_sdk_logger() - - -# SCENARIO: local to cloud -# distinguish Azure ML workspace and AI project -@dataclass -class WorkspaceKindLocalCache: - subscription_id: str - resource_group_name: str - workspace_name: str - kind: typing.Optional[str] = None - timestamp: typing.Optional[datetime.datetime] = None - - SUBSCRIPTION_ID = "subscription_id" - RESOURCE_GROUP_NAME = "resource_group_name" - WORKSPACE_NAME = "workspace_name" - KIND = "kind" - TIMESTAMP = "timestamp" - # class-related constants - PF_DIR_TRACING = "tracing" - WORKSPACE_KIND_LOCAL_CACHE_EXPIRE_DAYS = 1 - - def __post_init__(self): - if self.is_cache_exists: - cache = json_load(self.cache_path) - self.kind = cache[self.KIND] - self.timestamp = datetime.datetime.fromisoformat(cache[self.TIMESTAMP]) - - @property - def cache_path(self) -> Path: - tracing_dir = HOME_PROMPT_FLOW_DIR / self.PF_DIR_TRACING - if not tracing_dir.exists(): - tracing_dir.mkdir(parents=True) - filename = f"{self.subscription_id}_{self.resource_group_name}_{self.workspace_name}.json" - return (tracing_dir / filename).resolve() - - @property - def is_cache_exists(self) -> bool: - return self.cache_path.is_file() - - @property - def is_expired(self) -> bool: - if not self.is_cache_exists: - return True - time_delta = datetime.datetime.now() - self.timestamp - return time_delta.days > self.WORKSPACE_KIND_LOCAL_CACHE_EXPIRE_DAYS - - def get_kind(self) -> str: - if not self.is_cache_exists or self.is_expired: - _logger.debug(f"refreshing local cache for resource {self.workspace_name}...") - self._refresh() - _logger.debug(f"local cache kind for resource {self.workspace_name}: {self.kind}") - return self.kind - - def _refresh(self) -> None: - self.kind = self._get_workspace_kind_from_azure() - self.timestamp = datetime.datetime.now() - cache = { - self.SUBSCRIPTION_ID: self.subscription_id, - self.RESOURCE_GROUP_NAME: self.resource_group_name, - self.WORKSPACE_NAME: self.workspace_name, - self.KIND: self.kind, - self.TIMESTAMP: self.timestamp.isoformat(), - } - with open(self.cache_path, "w") as f: - f.write(json.dumps(cache)) - - def _get_workspace_kind_from_azure(self) -> str: - try: - from azure.ai.ml import MLClient - - from promptflow.azure._cli._utils import get_credentials_for_cli - except ImportError: - error_message = "Please install 'promptflow-azure' to use Azure related tracing features." - raise MissingRequiredPackage(message=error_message) - - _logger.debug("trying to get workspace from Azure...") - ml_client = MLClient( - credential=get_credentials_for_cli(), - subscription_id=self.subscription_id, - resource_group_name=self.resource_group_name, - workspace_name=self.workspace_name, - ) - ws = ml_client.workspaces.get(name=self.workspace_name) - return ws._kind - - -def get_workspace_kind(ws_triad: AzureMLWorkspaceTriad) -> str: - """Get workspace kind. 
- - Note that we will cache this result locally with timestamp, so that we don't - really need to request every time, but need to check timestamp. - """ - return WorkspaceKindLocalCache( - subscription_id=ws_triad.subscription_id, - resource_group_name=ws_triad.resource_group_name, - workspace_name=ws_triad.workspace_name, - ).get_kind() - - -# SCENARIO: local trace UI search experience -# append condition(s) to user specified query -def append_conditions( - expression: str, - collection: typing.Optional[str] = None, - runs: typing.Optional[typing.Union[str, typing.List[str]]] = None, - session_id: typing.Optional[str] = None, - logger: typing.Optional[logging.Logger] = None, -) -> str: - if logger is None: - logger = _logger - logger.debug("received original search expression: %s", expression) - if collection is not None: - logger.debug("received search parameter collection: %s", collection) - expression += f" and collection == '{collection}'" - if runs is not None: - logger.debug("received search parameter runs: %s", runs) - if isinstance(runs, str): - expression += f" and run == '{runs}'" - elif len(runs) == 1: - expression += f" and run == '{runs[0]}'" - else: - runs_expr = " or ".join([f"run == '{run}'" for run in runs]) - expression += f" and ({runs_expr})" - if session_id is not None: - logger.debug("received search parameter session_id: %s", session_id) - expression += f" and session_id == '{session_id}'" - logger.debug("final search expression: %s", expression) - return expression diff --git a/src/promptflow-devkit/promptflow/_sdk/_utils/general_utils.py b/src/promptflow-devkit/promptflow/_sdk/_utils/general_utils.py index d332a63fa66..afb33415baf 100644 --- a/src/promptflow-devkit/promptflow/_sdk/_utils/general_utils.py +++ b/src/promptflow-devkit/promptflow/_sdk/_utils/general_utils.py @@ -909,34 +909,6 @@ def convert_time_unix_nano_to_timestamp(time_unix_nano: str) -> datetime.datetim return datetime.datetime.utcfromtimestamp(seconds) -def parse_kv_from_pb_attribute(attribute: Dict) -> Tuple[str, str]: - attr_key = attribute["key"] - # suppose all values are flattened here - # so simply regard the first value as the attribute value - attr_value = list(attribute["value"].values())[0] - return attr_key, attr_value - - -def flatten_pb_attributes(attributes: List[Dict]) -> Dict: - flattened_attributes = {} - for attribute in attributes: - attr_key, attr_value = parse_kv_from_pb_attribute(attribute) - flattened_attributes[attr_key] = attr_value - return flattened_attributes - - -def parse_otel_span_status_code(value: int) -> str: - # map int value to string - # https://github.com/open-telemetry/opentelemetry-specification/blob/v1.22.0/specification/trace/api.md#set-status - # https://github.com/open-telemetry/opentelemetry-python/blob/v1.22.0/opentelemetry-api/src/opentelemetry/trace/status.py#L22-L32 - if value == 0: - return "Unset" - elif value == 1: - return "Ok" - else: - return "Error" - - def extract_workspace_triad_from_trace_provider(trace_provider: str) -> AzureMLWorkspaceTriad: match = re.match(AZURE_WORKSPACE_REGEX_FORMAT, trace_provider) if not match or len(match.groups()) != 5: diff --git a/src/promptflow-devkit/promptflow/_sdk/_utils/tracing.py b/src/promptflow-devkit/promptflow/_sdk/_utils/tracing.py new file mode 100644 index 00000000000..2bbf6058988 --- /dev/null +++ b/src/promptflow-devkit/promptflow/_sdk/_utils/tracing.py @@ -0,0 +1,285 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +import datetime +import json +import logging +import typing +from dataclasses import dataclass +from pathlib import Path + +from google.protobuf.json_format import MessageToJson +from opentelemetry.proto.trace.v1.trace_pb2 import Span as PBSpan +from opentelemetry.trace.span import format_span_id as otel_format_span_id +from opentelemetry.trace.span import format_trace_id as otel_format_trace_id + +from promptflow._constants import ( + SpanContextFieldName, + SpanEventFieldName, + SpanFieldName, + SpanLinkFieldName, + SpanStatusFieldName, +) +from promptflow._sdk._constants import HOME_PROMPT_FLOW_DIR, AzureMLWorkspaceTriad +from promptflow._sdk._utils import convert_time_unix_nano_to_timestamp, json_load +from promptflow._sdk.entities._trace import Span +from promptflow._utils.logger_utils import get_cli_sdk_logger +from promptflow.core._errors import MissingRequiredPackage + +_logger = get_cli_sdk_logger() + + +# SCENARIO: OTLP trace collector +# prompt flow service, runtime parse OTLP trace +def format_span_id(span_id: bytes) -> str: + """Format span id to hex string. + Note that we need to add 0x since it is how opentelemetry-sdk does. + Reference: https://github.com/open-telemetry/opentelemetry-python/blob/ + 642f8dd18eea2737b4f8cd2f6f4d08a7e569c4b2/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py#L505 + """ + return f"0x{otel_format_span_id(int.from_bytes(span_id, byteorder='big', signed=False))}" + + +def format_trace_id(trace_id: bytes) -> str: + """Format trace_id id to hex string. + Note that we need to add 0x since it is how opentelemetry-sdk does. + Reference: https://github.com/open-telemetry/opentelemetry-python/blob/ + 642f8dd18eea2737b4f8cd2f6f4d08a7e569c4b2/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py#L505 + """ + return f"0x{otel_format_trace_id(int.from_bytes(trace_id, byteorder='big', signed=False))}" + + +def parse_kv_from_pb_attribute(attribute: typing.Dict) -> typing.Tuple[str, str]: + attr_key = attribute["key"] + # suppose all values are flattened here + # so simply regard the first value as the attribute value + attr_value = list(attribute["value"].values())[0] + return attr_key, attr_value + + +def _flatten_pb_attributes(attributes: typing.List[typing.Dict]) -> typing.Dict: + flattened_attributes = {} + for attribute in attributes: + attr_key, attr_value = parse_kv_from_pb_attribute(attribute) + flattened_attributes[attr_key] = attr_value + return flattened_attributes + + +def _parse_otel_span_status_code(value: int) -> str: + # map int value to string + # https://github.com/open-telemetry/opentelemetry-specification/blob/v1.22.0/specification/trace/api.md#set-status + # https://github.com/open-telemetry/opentelemetry-python/blob/v1.22.0/opentelemetry-api/src/opentelemetry/trace/status.py#L22-L32 + if value == 0: + return "Unset" + elif value == 1: + return "Ok" + else: + return "Error" + + +def parse_protobuf_events(obj: typing.List[PBSpan.Event], logger: logging.Logger) -> typing.List[typing.Dict]: + events = [] + if len(obj) == 0: + logger.debug("No events found in span") + return events + for pb_event in obj: + event_dict: dict = json.loads(MessageToJson(pb_event)) + logger.debug("Received event: %s", json.dumps(event_dict)) + event = { + SpanEventFieldName.NAME: pb_event.name, + # .isoformat() here to make this dumpable to JSON + SpanEventFieldName.TIMESTAMP: convert_time_unix_nano_to_timestamp(pb_event.time_unix_nano).isoformat(), + SpanEventFieldName.ATTRIBUTES: 
_flatten_pb_attributes( + event_dict.get(SpanEventFieldName.ATTRIBUTES, dict()) + ), + } + events.append(event) + return events + + +def parse_protobuf_links(obj: typing.List[PBSpan.Link], logger: logging.Logger) -> typing.List[typing.Dict]: + links = [] + if len(obj) == 0: + logger.debug("No links found in span") + return links + for pb_link in obj: + link_dict: dict = json.loads(MessageToJson(pb_link)) + logger.debug("Received link: %s", json.dumps(link_dict)) + link = { + SpanLinkFieldName.CONTEXT: { + SpanContextFieldName.TRACE_ID: format_trace_id(pb_link.trace_id), + SpanContextFieldName.SPAN_ID: format_span_id(pb_link.span_id), + SpanContextFieldName.TRACE_STATE: pb_link.trace_state, + }, + SpanLinkFieldName.ATTRIBUTES: _flatten_pb_attributes(link_dict.get(SpanLinkFieldName.ATTRIBUTES, dict())), + } + links.append(link) + return links + + +def parse_protobuf_span(span: PBSpan, resource: typing.Dict, logger: logging.Logger) -> Span: + # Open Telemetry does not provide official way to parse Protocol Buffer Span object + # so we need to parse it manually relying on `MessageToJson` + # reference: https://github.com/open-telemetry/opentelemetry-python/issues/3700#issuecomment-2010704554 + span_dict: dict = json.loads(MessageToJson(span)) + logger.debug("Received span: %s, resource: %s", json.dumps(span_dict), json.dumps(resource)) + span_id = format_span_id(span.span_id) + trace_id = format_trace_id(span.trace_id) + parent_id = format_span_id(span.parent_span_id) if span.parent_span_id else None + # we have observed in some scenarios, there is not `attributes` field + attributes = _flatten_pb_attributes(span_dict.get(SpanFieldName.ATTRIBUTES, dict())) + logger.debug("Parsed attributes: %s", json.dumps(attributes)) + links = parse_protobuf_links(span.links, logger) + events = parse_protobuf_events(span.events, logger) + + return Span( + trace_id=trace_id, + span_id=span_id, + name=span.name, + context={ + SpanContextFieldName.TRACE_ID: trace_id, + SpanContextFieldName.SPAN_ID: span_id, + SpanContextFieldName.TRACE_STATE: span.trace_state, + }, + kind=span.kind, + parent_id=parent_id if parent_id else None, + start_time=convert_time_unix_nano_to_timestamp(span.start_time_unix_nano), + end_time=convert_time_unix_nano_to_timestamp(span.end_time_unix_nano), + status={ + SpanStatusFieldName.STATUS_CODE: _parse_otel_span_status_code(span.status.code), + SpanStatusFieldName.DESCRIPTION: span.status.message, + }, + attributes=attributes, + links=links, + events=events, + resource=resource, + ) + + +# SCENARIO: local to cloud +# distinguish Azure ML workspace and AI project +@dataclass +class WorkspaceKindLocalCache: + subscription_id: str + resource_group_name: str + workspace_name: str + kind: typing.Optional[str] = None + timestamp: typing.Optional[datetime.datetime] = None + + SUBSCRIPTION_ID = "subscription_id" + RESOURCE_GROUP_NAME = "resource_group_name" + WORKSPACE_NAME = "workspace_name" + KIND = "kind" + TIMESTAMP = "timestamp" + # class-related constants + PF_DIR_TRACING = "tracing" + WORKSPACE_KIND_LOCAL_CACHE_EXPIRE_DAYS = 1 + + def __post_init__(self): + if self.is_cache_exists: + cache = json_load(self.cache_path) + self.kind = cache[self.KIND] + self.timestamp = datetime.datetime.fromisoformat(cache[self.TIMESTAMP]) + + @property + def cache_path(self) -> Path: + tracing_dir = HOME_PROMPT_FLOW_DIR / self.PF_DIR_TRACING + if not tracing_dir.exists(): + tracing_dir.mkdir(parents=True) + filename = f"{self.subscription_id}_{self.resource_group_name}_{self.workspace_name}.json" + return 
(tracing_dir / filename).resolve() + + @property + def is_cache_exists(self) -> bool: + return self.cache_path.is_file() + + @property + def is_expired(self) -> bool: + if not self.is_cache_exists: + return True + time_delta = datetime.datetime.now() - self.timestamp + return time_delta.days > self.WORKSPACE_KIND_LOCAL_CACHE_EXPIRE_DAYS + + def get_kind(self) -> str: + if not self.is_cache_exists or self.is_expired: + _logger.debug(f"refreshing local cache for resource {self.workspace_name}...") + self._refresh() + _logger.debug(f"local cache kind for resource {self.workspace_name}: {self.kind}") + return self.kind + + def _refresh(self) -> None: + self.kind = self._get_workspace_kind_from_azure() + self.timestamp = datetime.datetime.now() + cache = { + self.SUBSCRIPTION_ID: self.subscription_id, + self.RESOURCE_GROUP_NAME: self.resource_group_name, + self.WORKSPACE_NAME: self.workspace_name, + self.KIND: self.kind, + self.TIMESTAMP: self.timestamp.isoformat(), + } + with open(self.cache_path, "w") as f: + f.write(json.dumps(cache)) + + def _get_workspace_kind_from_azure(self) -> str: + try: + from azure.ai.ml import MLClient + + from promptflow.azure._cli._utils import get_credentials_for_cli + except ImportError: + error_message = "Please install 'promptflow-azure' to use Azure related tracing features." + raise MissingRequiredPackage(message=error_message) + + _logger.debug("trying to get workspace from Azure...") + ml_client = MLClient( + credential=get_credentials_for_cli(), + subscription_id=self.subscription_id, + resource_group_name=self.resource_group_name, + workspace_name=self.workspace_name, + ) + ws = ml_client.workspaces.get(name=self.workspace_name) + return ws._kind + + +def get_workspace_kind(ws_triad: AzureMLWorkspaceTriad) -> str: + """Get workspace kind. + + Note that we will cache this result locally with timestamp, so that we don't + really need to request every time, but need to check timestamp. 
+ """ + return WorkspaceKindLocalCache( + subscription_id=ws_triad.subscription_id, + resource_group_name=ws_triad.resource_group_name, + workspace_name=ws_triad.workspace_name, + ).get_kind() + + +# SCENARIO: local trace UI search experience +# append condition(s) to user specified query +def append_conditions( + expression: str, + collection: typing.Optional[str] = None, + runs: typing.Optional[typing.Union[str, typing.List[str]]] = None, + session_id: typing.Optional[str] = None, + logger: typing.Optional[logging.Logger] = None, +) -> str: + if logger is None: + logger = _logger + logger.debug("received original search expression: %s", expression) + if collection is not None: + logger.debug("received search parameter collection: %s", collection) + expression += f" and collection == '{collection}'" + if runs is not None: + logger.debug("received search parameter runs: %s", runs) + if isinstance(runs, str): + expression += f" and run == '{runs}'" + elif len(runs) == 1: + expression += f" and run == '{runs[0]}'" + else: + runs_expr = " or ".join([f"run == '{run}'" for run in runs]) + expression += f" and ({runs_expr})" + if session_id is not None: + logger.debug("received search parameter session_id: %s", session_id) + expression += f" and session_id == '{session_id}'" + logger.debug("final search expression: %s", expression) + return expression diff --git a/src/promptflow-devkit/promptflow/_sdk/operations/_trace_operations.py b/src/promptflow-devkit/promptflow/_sdk/operations/_trace_operations.py index 2bafffbaaa6..228ec836d3e 100644 --- a/src/promptflow-devkit/promptflow/_sdk/operations/_trace_operations.py +++ b/src/promptflow-devkit/promptflow/_sdk/operations/_trace_operations.py @@ -3,21 +3,8 @@ # --------------------------------------------------------- import datetime -import json -import logging import typing -from google.protobuf.json_format import MessageToJson -from opentelemetry.proto.trace.v1.trace_pb2 import Span as PBSpan -from opentelemetry.trace.span import format_span_id, format_trace_id - -from promptflow._constants import ( - SpanContextFieldName, - SpanEventFieldName, - SpanFieldName, - SpanLinkFieldName, - SpanStatusFieldName, -) from promptflow._sdk._constants import TRACE_DEFAULT_COLLECTION, TRACE_LIST_DEFAULT_LIMIT from promptflow._sdk._orm.retry import sqlite_retry from promptflow._sdk._orm.session import trace_mgmt_db_session @@ -25,12 +12,7 @@ from promptflow._sdk._orm.trace import LineRun as ORMLineRun from promptflow._sdk._orm.trace import Span as ORMSpan from promptflow._sdk._telemetry import ActivityType, monitor_operation -from promptflow._sdk._tracing_utils import append_conditions -from promptflow._sdk._utils import ( - convert_time_unix_nano_to_timestamp, - flatten_pb_attributes, - parse_otel_span_status_code, -) +from promptflow._sdk._utils.tracing import append_conditions from promptflow._sdk.entities._trace import Event, LineRun, Span from promptflow._utils.logger_utils import get_cli_sdk_logger from promptflow.exceptions import UserErrorException @@ -40,104 +22,6 @@ class TraceOperations: def __init__(self): self._logger = get_cli_sdk_logger() - def _parse_protobuf_events(obj: typing.List[PBSpan.Event], logger: logging.Logger) -> typing.List[typing.Dict]: - events = [] - if len(obj) == 0: - logger.debug("No events found in span") - return events - for pb_event in obj: - event_dict: dict = json.loads(MessageToJson(pb_event)) - logger.debug("Received event: %s", json.dumps(event_dict)) - event = { - SpanEventFieldName.NAME: pb_event.name, - # 
.isoformat() here to make this dumpable to JSON - SpanEventFieldName.TIMESTAMP: convert_time_unix_nano_to_timestamp(pb_event.time_unix_nano).isoformat(), - SpanEventFieldName.ATTRIBUTES: flatten_pb_attributes( - event_dict.get(SpanEventFieldName.ATTRIBUTES, dict()) - ), - } - events.append(event) - return events - - @staticmethod - def _parse_protobuf_links(obj: typing.List[PBSpan.Link], logger: logging.Logger) -> typing.List[typing.Dict]: - links = [] - if len(obj) == 0: - logger.debug("No links found in span") - return links - for pb_link in obj: - link_dict: dict = json.loads(MessageToJson(pb_link)) - logger.debug("Received link: %s", json.dumps(link_dict)) - link = { - SpanLinkFieldName.CONTEXT: { - SpanContextFieldName.TRACE_ID: TraceOperations.format_trace_id(pb_link.trace_id), - SpanContextFieldName.SPAN_ID: TraceOperations.format_span_id(pb_link.span_id), - SpanContextFieldName.TRACE_STATE: pb_link.trace_state, - }, - SpanLinkFieldName.ATTRIBUTES: flatten_pb_attributes( - link_dict.get(SpanLinkFieldName.ATTRIBUTES, dict()) - ), - } - links.append(link) - return links - - @staticmethod - def format_span_id(span_id: bytes) -> str: - """Format span id to hex string. - Note that we need to add 0x since it is how opentelemetry-sdk does. - Reference: https://github.com/open-telemetry/opentelemetry-python/blob/ - 642f8dd18eea2737b4f8cd2f6f4d08a7e569c4b2/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py#L505 - """ - return f"0x{format_span_id(int.from_bytes(span_id, byteorder='big', signed=False))}" - - @staticmethod - def format_trace_id(trace_id: bytes) -> str: - """Format trace_id id to hex string. - Note that we need to add 0x since it is how opentelemetry-sdk does. - Reference: https://github.com/open-telemetry/opentelemetry-python/blob/ - 642f8dd18eea2737b4f8cd2f6f4d08a7e569c4b2/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py#L505 - """ - return f"0x{format_trace_id(int.from_bytes(trace_id, byteorder='big', signed=False))}" - - @staticmethod - def _parse_protobuf_span(span: PBSpan, resource: typing.Dict, logger: logging.Logger) -> Span: - # Open Telemetry does not provide official way to parse Protocol Buffer Span object - # so we need to parse it manually relying on `MessageToJson` - # reference: https://github.com/open-telemetry/opentelemetry-python/issues/3700#issuecomment-2010704554 - span_dict: dict = json.loads(MessageToJson(span)) - logger.debug("Received span: %s, resource: %s", json.dumps(span_dict), json.dumps(resource)) - span_id = TraceOperations.format_span_id(span.span_id) - trace_id = TraceOperations.format_trace_id(span.trace_id) - parent_id = TraceOperations.format_span_id(span.parent_span_id) if span.parent_span_id else None - # we have observed in some scenarios, there is not `attributes` field - attributes = flatten_pb_attributes(span_dict.get(SpanFieldName.ATTRIBUTES, dict())) - logger.debug("Parsed attributes: %s", json.dumps(attributes)) - links = TraceOperations._parse_protobuf_links(span.links, logger) - events = TraceOperations._parse_protobuf_events(span.events, logger) - - return Span( - trace_id=trace_id, - span_id=span_id, - name=span.name, - context={ - SpanContextFieldName.TRACE_ID: trace_id, - SpanContextFieldName.SPAN_ID: span_id, - SpanContextFieldName.TRACE_STATE: span.trace_state, - }, - kind=span.kind, - parent_id=parent_id if parent_id else None, - start_time=convert_time_unix_nano_to_timestamp(span.start_time_unix_nano), - end_time=convert_time_unix_nano_to_timestamp(span.end_time_unix_nano), - status={ - 
SpanStatusFieldName.STATUS_CODE: parse_otel_span_status_code(span.status.code), - SpanStatusFieldName.DESCRIPTION: span.status.message, - }, - attributes=attributes, - links=links, - events=events, - resource=resource, - ) - def get_event(self, event_id: str) -> typing.Dict: return Event.get(event_id=event_id) diff --git a/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_flow_run.py b/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_flow_run.py index be0e213f4a4..53c77bc7311 100644 --- a/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_flow_run.py +++ b/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_flow_run.py @@ -33,7 +33,8 @@ from promptflow._sdk._load_functions import load_flow, load_run from promptflow._sdk._orchestrator.utils import SubmitterHelper from promptflow._sdk._run_functions import create_yaml_run -from promptflow._sdk._utils import _get_additional_includes, parse_otel_span_status_code +from promptflow._sdk._utils import _get_additional_includes +from promptflow._sdk._utils.tracing import _parse_otel_span_status_code from promptflow._sdk.entities import Run from promptflow._sdk.operations._local_storage_operations import LocalStorageOperations from promptflow._utils.context_utils import _change_working_dir, inject_sys_path @@ -1409,7 +1410,7 @@ def test_flex_flow_with_local_imported_func(self, pf): def test_flex_flow_with_imported_func(self, pf): # run eager flow against a function from module run = pf.run( - flow=parse_otel_span_status_code, + flow=_parse_otel_span_status_code, data=f"{DATAS_DIR}/simple_eager_flow_data.jsonl", # set code folder to avoid snapshot too big code=f"{EAGER_FLOWS_DIR}/multiple_entries", diff --git a/src/promptflow-devkit/tests/sdk_cli_test/unittests/test_trace.py b/src/promptflow-devkit/tests/sdk_cli_test/unittests/test_trace.py index 0f8c6577e45..0ab655fdcfe 100644 --- a/src/promptflow-devkit/tests/sdk_cli_test/unittests/test_trace.py +++ b/src/promptflow-devkit/tests/sdk_cli_test/unittests/test_trace.py @@ -33,8 +33,7 @@ ContextAttributeKey, ) from promptflow._sdk._tracing import start_trace_with_devkit -from promptflow._sdk._tracing_utils import WorkspaceKindLocalCache, append_conditions -from promptflow._sdk.operations._trace_operations import TraceOperations +from promptflow._sdk._utils.tracing import WorkspaceKindLocalCache, append_conditions, parse_protobuf_span from promptflow.client import PFClient from promptflow.exceptions import UserErrorException from promptflow.tracing._operation_context import OperationContext @@ -150,7 +149,7 @@ def test_trace_without_attributes_collection(self, mock_resource: Dict) -> None: pb_span.parent_span_id = base64.b64decode("C+++WS+OuxI=") pb_span.kind = PBSpan.SpanKind.SPAN_KIND_INTERNAL # below line should execute successfully - span = TraceOperations._parse_protobuf_span(pb_span, resource=mock_resource, logger=logging.getLogger(__name__)) + span = parse_protobuf_span(pb_span, resource=mock_resource, logger=logging.getLogger(__name__)) # as the above span do not have any attributes, so the parsed span should not have any attributes assert isinstance(span.attributes, dict) assert len(span.attributes) == 0 @@ -265,7 +264,7 @@ def test_no_cache(self): # mock `WorkspaceKindLocalCache._get_workspace_kind_from_azure` mock_kind = str(uuid.uuid4()) with patch( - "promptflow._sdk._tracing_utils.WorkspaceKindLocalCache._get_workspace_kind_from_azure" + "promptflow._sdk._utils.tracing.WorkspaceKindLocalCache._get_workspace_kind_from_azure" ) as mock_get_kind: mock_get_kind.return_value = 
mock_kind assert ws_local_cache.get_kind() == mock_kind @@ -306,7 +305,7 @@ def test_expired_cache(self): # mock `WorkspaceKindLocalCache._get_workspace_kind_from_azure` kind = str(uuid.uuid4()) with patch( - "promptflow._sdk._tracing_utils.WorkspaceKindLocalCache._get_workspace_kind_from_azure" + "promptflow._sdk._utils.tracing.WorkspaceKindLocalCache._get_workspace_kind_from_azure" ) as mock_get_kind: mock_get_kind.return_value = kind assert ws_local_cache.get_kind() == kind
From 600d0bfee9f6a3236ca0fd26aa0e68b6b551b6a4 Mon Sep 17 00:00:00 2001
From: zhen
Date: Thu, 25 Apr 2024 18:07:36 +0800
Subject: [PATCH 2/6] [changelog] Add prompty to changelog (#3003)

# Description

Please add an informative description that covers the changes made by the pull request and link all relevant issues.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which has an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

--- src/promptflow-core/CHANGELOG.md | 5 +++++ src/promptflow-devkit/CHANGELOG.md | 2 ++ 2 files changed, 7 insertions(+) diff --git a/src/promptflow-core/CHANGELOG.md b/src/promptflow-core/CHANGELOG.md index 5d5c9c645f4..b5762a04501 100644 --- a/src/promptflow-core/CHANGELOG.md +++ b/src/promptflow-core/CHANGELOG.md @@ -1,6 +1,11 @@ # promptflow-core package ## v1.10.0 (Upcoming) + +### Features Added +- Add prompty feature to simplify the development of prompt templates for customers, reach [here](https://microsoft.github.io/promptflow/how-to-guides/develop-a-prompty/index.html) for more details. + +### Others - Add fastapi serving engine support. ## v1.9.0 (2024.04.17) diff --git a/src/promptflow-devkit/CHANGELOG.md b/src/promptflow-devkit/CHANGELOG.md index 465076bc662..1f2f0da1275 100644 --- a/src/promptflow-devkit/CHANGELOG.md +++ b/src/promptflow-devkit/CHANGELOG.md @@ -7,6 +7,8 @@ - The `pf config set ` support set the folder where the config is saved by `--path config_folder` parameter, and the config will take effect when **os.getcwd** is a subdirectory of the specified folder. - Local serving container support using fastapi engine and tuning worker/thread num via environment variables, reach [here](https://microsoft.github.io/promptflow/how-to-guides/deploy-a-flow/deploy-using-docker.html) for more details. +- Prompty supports to flow test and batch run, reach [here](https://microsoft.github.io/promptflow/how-to-guides/develop-a-prompty/index.html#testing-prompty) for more details.
+ ## v1.9.0 (2024.04.17)
From 6399905a0f4073e222d5b3c4a9c6707be649681f Mon Sep 17 00:00:00 2001
From: Honglin
Date: Thu, 25 Apr 2024 18:39:23 +0800
Subject: [PATCH 3/6] [SDK/CLI] Write instance_results.jsonl path in run properties (#3014)

# Description

Please add an informative description that covers the changes made by the pull request and link all relevant issues.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which has an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

--- .../e2etests/test_run_upload.py | 27 +++++-------------- .../promptflow/_sdk/_constants.py | 2 +- .../promptflow/_sdk/entities/_run.py | 6 ++++- 3 files changed, 12 insertions(+), 23 deletions(-) diff --git a/src/promptflow-azure/tests/sdk_cli_azure_test/e2etests/test_run_upload.py b/src/promptflow-azure/tests/sdk_cli_azure_test/e2etests/test_run_upload.py index 699517ea992..74600c144a9 100644 --- a/src/promptflow-azure/tests/sdk_cli_azure_test/e2etests/test_run_upload.py +++ b/src/promptflow-azure/tests/sdk_cli_azure_test/e2etests/test_run_upload.py @@ -51,6 +51,7 @@ def check_local_to_cloud_run(pf: PFClient, run: Run, check_run_details_in_cloud: assert cloud_run.properties["azureml.promptflow.local_to_cloud"] == "true" assert cloud_run.properties["azureml.promptflow.snapshot_id"] assert cloud_run.properties[Local2CloudProperties.TOTAL_TOKENS] + assert cloud_run.properties[Local2CloudProperties.EVAL_ARTIFACTS] # if no description or tags, skip the check, since one could be {} but the other is None if run.description: @@ -74,12 +75,12 @@ "mock_set_headers_with_user_aml_token", "single_worker_thread_pool", "vcr_recording", + "mock_isinstance_for_mock_datastore", + "mock_get_azure_pf_client", + "mock_trace_destination_to_cloud", ) class TestFlowRunUpload: @pytest.mark.skipif(condition=not pytest.is_live, reason="Bug - 3089145 Replay failed for test 'test_upload_run'") - @pytest.mark.usefixtures( - "mock_isinstance_for_mock_datastore", "mock_get_azure_pf_client", "mock_trace_destination_to_cloud" - ) def test_upload_run( self, pf: PFClient, @@ -103,9 +104,6 @@ def test_upload_run( Local2CloudTestHelper.check_local_to_cloud_run(pf, run, check_run_details_in_cloud=True) @pytest.mark.skipif(condition=not pytest.is_live, reason="Bug - 3089145 Replay failed for test 'test_upload_run'") - @pytest.mark.usefixtures( - "mock_isinstance_for_mock_datastore", "mock_get_azure_pf_client", "mock_trace_destination_to_cloud" - ) def test_upload_flex_flow_run_with_yaml(self, pf: PFClient, randstr: Callable[[str], str]): name = randstr("flex_run_name_with_yaml_for_upload") local_pf =
Local2CloudTestHelper.get_local_pf(name) @@ -125,9 +123,6 @@ def test_upload_flex_flow_run_with_yaml(self, pf: PFClient, randstr: Callable[[s Local2CloudTestHelper.check_local_to_cloud_run(pf, run) @pytest.mark.skipif(condition=not pytest.is_live, reason="Bug - 3089145 Replay failed for test 'test_upload_run'") - @pytest.mark.usefixtures( - "mock_isinstance_for_mock_datastore", "mock_get_azure_pf_client", "mock_trace_destination_to_cloud" - ) def test_upload_flex_flow_run_without_yaml(self, pf: PFClient, randstr: Callable[[str], str]): name = randstr("flex_run_name_without_yaml_for_upload") local_pf = Local2CloudTestHelper.get_local_pf(name) @@ -148,9 +143,6 @@ def test_upload_flex_flow_run_without_yaml(self, pf: PFClient, randstr: Callable Local2CloudTestHelper.check_local_to_cloud_run(pf, run) @pytest.mark.skipif(condition=not pytest.is_live, reason="Bug - 3089145 Replay failed for test 'test_upload_run'") - @pytest.mark.usefixtures( - "mock_isinstance_for_mock_datastore", "mock_get_azure_pf_client", "mock_trace_destination_to_cloud" - ) def test_upload_prompty_run(self, pf: PFClient, randstr: Callable[[str], str]): # currently prompty run is skipped for upload, this test should be finished without error name = randstr("prompty_run_name_for_upload") @@ -167,9 +159,6 @@ def test_upload_prompty_run(self, pf: PFClient, randstr: Callable[[str], str]): Local2CloudTestHelper.check_local_to_cloud_run(pf, run) @pytest.mark.skipif(condition=not pytest.is_live, reason="Bug - 3089145 Replay failed for test 'test_upload_run'") - @pytest.mark.usefixtures( - "mock_isinstance_for_mock_datastore", "mock_get_azure_pf_client", "mock_trace_destination_to_cloud" - ) def test_upload_run_with_customized_run_properties(self, pf: PFClient, randstr: Callable[[str], str]): name = randstr("batch_run_name_for_upload_with_customized_properties") local_pf = Local2CloudTestHelper.get_local_pf(name) @@ -200,9 +189,6 @@ def test_upload_run_with_customized_run_properties(self, pf: PFClient, randstr: assert cloud_run.properties[Local2CloudUserProperties.EVAL_ARTIFACTS] == eval_artifacts @pytest.mark.skipif(condition=not pytest.is_live, reason="Bug - 3089145 Replay failed for test 'test_upload_run'") - @pytest.mark.usefixtures( - "mock_isinstance_for_mock_datastore", "mock_get_azure_pf_client", "mock_trace_destination_to_cloud" - ) def test_upload_eval_run(self, pf: PFClient, randstr: Callable[[str], str]): main_run_name = randstr("main_run_name_for_test_upload_eval_run") local_pf = Local2CloudTestHelper.get_local_pf(main_run_name) @@ -216,8 +202,8 @@ def test_upload_eval_run(self, pf: PFClient, randstr: Callable[[str], str]): # run an evaluation run eval_run_name = randstr("eval_run_name_for_test_upload_eval_run") - local_lpf = Local2CloudTestHelper.get_local_pf(eval_run_name) - eval_run = local_lpf.run( + local_pf = Local2CloudTestHelper.get_local_pf(eval_run_name) + eval_run = local_pf.run( flow=f"{FLOWS_DIR}/simple_hello_world", data=f"{DATAS_DIR}/webClassification3.jsonl", run=main_run_name, @@ -229,7 +215,6 @@ def test_upload_eval_run(self, pf: PFClient, randstr: Callable[[str], str]): assert eval_run.properties["azureml.promptflow.variant_run_id"] == main_run_name @pytest.mark.skipif(condition=not pytest.is_live, reason="Bug - 3089145 Replay failed for test 'test_upload_run'") - @pytest.mark.usefixtures("mock_isinstance_for_mock_datastore", "mock_get_azure_pf_client") def test_upload_flex_flow_run_with_global_azureml(self, pf: PFClient, randstr: Callable[[str], str]): with 
patch("promptflow._sdk._configuration.Configuration.get_config", return_value="azureml"): name = randstr("flex_run_name_with_global_azureml_for_upload") diff --git a/src/promptflow-devkit/promptflow/_sdk/_constants.py b/src/promptflow-devkit/promptflow/_sdk/_constants.py index f4dc2ec4b24..a90ac6df468 100644 --- a/src/promptflow-devkit/promptflow/_sdk/_constants.py +++ b/src/promptflow-devkit/promptflow/_sdk/_constants.py @@ -483,13 +483,13 @@ class Local2CloudProperties: """Run properties that server needs when uploading local run to cloud.""" TOTAL_TOKENS = "azureml.promptflow.total_tokens" + EVAL_ARTIFACTS = "_azureml.evaluate_artifacts" class Local2CloudUserProperties: """Run properties that user can specify when uploading local run to cloud.""" RUN_TYPE = "runType" - EVAL_ARTIFACTS = "_azureml.evaluate_artifacts" @staticmethod def get_all_values(): diff --git a/src/promptflow-devkit/promptflow/_sdk/entities/_run.py b/src/promptflow-devkit/promptflow/_sdk/entities/_run.py index bb91d25d3ca..ef593597fd4 100644 --- a/src/promptflow-devkit/promptflow/_sdk/entities/_run.py +++ b/src/promptflow-devkit/promptflow/_sdk/entities/_run.py @@ -711,7 +711,11 @@ def _to_rest_object_for_local_to_cloud(self, local_to_cloud_info: dict, variant_ # extract properties that needs to be passed to the request total_tokens = self.properties[FlowRunProperties.SYSTEM_METRICS].get("total_tokens", 0) - properties = {Local2CloudProperties.TOTAL_TOKENS: total_tokens} + properties = { + Local2CloudProperties.TOTAL_TOKENS: total_tokens, + # add instance_results.jsonl path to run properties, which is required by UI feature. + Local2CloudProperties.EVAL_ARTIFACTS: '[{"path": "instance_results.jsonl", "type": "table"}]', + } for property_key in Local2CloudUserProperties.get_all_values(): value = self.properties.get(property_key, None) if value is not None: From b4c9f37999d05d833806d59de26bb263b8a8b220 Mon Sep 17 00:00:00 2001 From: chjinche <49483542+chjinche@users.noreply.github.com> Date: Thu, 25 Apr 2024 19:08:32 +0800 Subject: [PATCH 4/6] [Bugfix] Support parsing chat prompt if role property has \r around colon (#3007) # Description ## Issue: UX may add extra '\r' to user input, which may throw confusing error to user because user does not write '\r' explicitly. - user input ![image](https://github.com/microsoft/promptflow/assets/49483542/727be9b3-a8d5-42fc-ab98-592816f85f91) - ux adding extra '\r' ![image](https://github.com/microsoft/promptflow/assets/49483542/2628a0d5-2cec-4bd1-a4ef-8149f05db51d) - confusing error ![image](https://github.com/microsoft/promptflow/assets/49483542/27bab56a-e318-47ae-b377-842061e4928d) ## Solution: Handling '\r' around role property colon. The same way as '\r' around role colon. # All Promptflow Contribution checklist: - [X] **The pull request does not introduce [breaking changes].** - [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.** - [X] **I have read the [contribution guidelines](../CONTRIBUTING.md).** - [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).** ## General Guidelines and Best Practices - [X] Title of the pull request is clear and informative. - [X] There are a small number of commits, each of which have an informative message. This means that previously merged commits do not appear in the history of the PR. 
For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md). ### Testing Guidelines - [X] Pull request includes test coverage for the included changes. --- .../promptflow/tools/common.py | 4 +-- src/promptflow-tools/tests/test_common.py | 27 ++++++++++++++++++- .../tests/test_handle_openai_error.py | 1 + 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/promptflow-tools/promptflow/tools/common.py b/src/promptflow-tools/promptflow/tools/common.py index f957f55f6f6..1737c114328 100644 --- a/src/promptflow-tools/promptflow/tools/common.py +++ b/src/promptflow-tools/promptflow/tools/common.py @@ -222,7 +222,7 @@ def validate_tools(tools): def try_parse_name_and_content(role_prompt): # customer can add ## in front of name/content for markdown highlight. # and we still support name/content without ## prefix for backward compatibility. - pattern = r"\n*#{0,2}\s*name:\n+\s*(\S+)\s*\n*#{0,2}\s*content:\n?(.*)" + pattern = r"\n*#{0,2}\s*name\s*:\s*\n+\s*(\S+)\s*\n*#{0,2}\s*content\s*:\s*\n?(.*)" match = re.search(pattern, role_prompt, re.DOTALL) if match: return match.group(1), match.group(2) @@ -232,7 +232,7 @@ def try_parse_name_and_content(role_prompt): def try_parse_tool_call_id_and_content(role_prompt): # customer can add ## in front of tool_call_id/content for markdown highlight. # and we still support tool_call_id/content without ## prefix for backward compatibility. - pattern = r"\n*#{0,2}\s*tool_call_id:\n+\s*(\S+)\s*\n*#{0,2}\s*content:\n?(.*)" + pattern = r"\n*#{0,2}\s*tool_call_id\s*:\s*\n+\s*(\S+)\s*\n*#{0,2}\s*content\s*:\s*\n?(.*)" match = re.search(pattern, role_prompt, re.DOTALL) if match: return match.group(1), match.group(2) diff --git a/src/promptflow-tools/tests/test_common.py b/src/promptflow-tools/tests/test_common.py index 78f30790b0c..97373babdaa 100644 --- a/src/promptflow-tools/tests/test_common.py +++ b/src/promptflow-tools/tests/test_common.py @@ -214,7 +214,10 @@ def test_success_parse_role_prompt(self, chat_str, images, image_detail, expecte ("\nsystem:\nname:\n\n content:\nfirst", [ {'role': 'system', 'content': 'name:\n\n content:\nfirst'}]), ("\nsystem:\nname:\n\n", [ - {'role': 'system', 'content': 'name:'}]) + {'role': 'system', 'content': 'name:'}]), + # portal may add extra \r to new line character. + ("function:\r\nname:\r\n AI\ncontent :\r\nfirst", [ + {'role': 'function', 'name': 'AI', 'content': 'first'}]), ], ) def test_parse_chat_with_name_in_role_prompt(self, chat_str, expected_result): @@ -240,6 +243,20 @@ def test_try_parse_chat_with_tools(self, example_prompt_template_with_tool, pars actual_result = parse_chat(example_prompt_template_with_tool) assert actual_result == parsed_chat_with_tools + @pytest.mark.parametrize( + "chat_str, expected_result", + [ + ("\n#tool:\n## tool_call_id:\nid \n content:\nfirst\n\n#user:\nsecond", [ + {'role': 'tool', 'tool_call_id': 'id', 'content': 'first'}, {'role': 'user', 'content': 'second'}]), + # portal may add extra \r to new line character. 
+ ("\ntool:\ntool_call_id :\r\nid\n content:\r\n", [ + {'role': 'tool', 'tool_call_id': 'id', 'content': ''}]), + ], + ) + def test_parse_tool_call_id_and_content(self, chat_str, expected_result): + actual_result = parse_chat(chat_str) + assert actual_result == expected_result + @pytest.mark.parametrize("chunk, error_msg, success", [ (""" ## tool_calls: @@ -275,6 +292,14 @@ def test_try_parse_chat_with_tools(self, example_prompt_template_with_tool, pars "function": {"name": "func1", "arguments": ""} }] """, "", True), + # portal may add extra \r to new line character. + (""" + ## tool_calls:\r + [{ + "id": "tool_call_id", "type": "function", + "function": {"name": "func1", "arguments": ""} + }] + """, "", True), ]) def test_parse_tool_calls_for_assistant(self, chunk: str, error_msg: str, success: bool): last_message = {'role': 'assistant'} diff --git a/src/promptflow-tools/tests/test_handle_openai_error.py b/src/promptflow-tools/tests/test_handle_openai_error.py index 686a2c7a8a2..cfad9281161 100644 --- a/src/promptflow-tools/tests/test_handle_openai_error.py +++ b/src/promptflow-tools/tests/test_handle_openai_error.py @@ -261,6 +261,7 @@ def test_input_invalid_function_role_prompt(self, azure_open_ai_connection): ) assert "'name' is required if role is function," in exc_info.value.message + @pytest.mark.skip(reason="Skip temporarily because there is something issue with test AOAI resource response.") def test_completion_with_chat_model(self, azure_open_ai_connection): with pytest.raises(UserErrorException) as exc_info: completion(connection=azure_open_ai_connection, prompt="hello", deployment_name="gpt-35-turbo") From 40c054d56c1ecf3dde841f8f047f49d2f385a66a Mon Sep 17 00:00:00 2001 From: Xingzhi Zhang <37076709+elliotzh@users.noreply.github.com> Date: Thu, 25 Apr 2024 19:32:41 +0800 Subject: [PATCH 5/6] hotfix: extra string at the beginning of yaml in snapshot (#3006) # Description Fix a bug that !!omap tag will generated at the beginning of `flow.flex.yaml` in snapshot and break local to cloud # All Promptflow Contribution checklist: - [x] **The pull request does not introduce [breaking changes].** - [x] **CHANGELOG is updated for new features, bug fixes or other significant changes.** - [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).** - [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).** ## General Guidelines and Best Practices - [x] Title of the pull request is clear and informative. - [x] There are a small number of commits, each of which have an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md). ### Testing Guidelines - [x] Pull request includes test coverage for the included changes. 
--- .cspell.json | 5 ++- .../promptflow/_utils/flow_utils.py | 5 ++- .../promptflow/_utils/utils.py | 45 +++++++++++++++++++ .../sdk_cli_test/e2etests/test_flow_run.py | 2 + 4 files changed, 53 insertions(+), 4 deletions(-) diff --git a/.cspell.json b/.cspell.json index 1d97d3c469d..b13c4312bc2 100644 --- a/.cspell.json +++ b/.cspell.json @@ -217,10 +217,11 @@ "dcid", "piezo", "Piezo", - "cmpop" + "cmpop", + "omap" ], "flagWords": [ "Prompt Flow" ], "allowCompoundWords": true -} \ No newline at end of file +} diff --git a/src/promptflow-core/promptflow/_utils/flow_utils.py b/src/promptflow-core/promptflow/_utils/flow_utils.py index 0962a675de8..58fb7839edc 100644 --- a/src/promptflow-core/promptflow/_utils/flow_utils.py +++ b/src/promptflow-core/promptflow/_utils/flow_utils.py @@ -21,7 +21,7 @@ ) from promptflow._core._errors import MetaFileNotFound, MetaFileReadError from promptflow._utils.logger_utils import LoggerFactory -from promptflow._utils.utils import strip_quotation +from promptflow._utils.utils import convert_ordered_dict_to_dict, strip_quotation from promptflow._utils.yaml_utils import dump_yaml, load_yaml from promptflow.contracts.flow import Flow as ExecutableFlow from promptflow.exceptions import ErrorTarget, UserErrorException, ValidationException @@ -157,7 +157,8 @@ def dump_flow_dag(flow_dag: dict, flow_path: Path): flow_dir, flow_filename = resolve_flow_path(flow_path, check_flow_exist=False) flow_path = flow_dir / flow_filename with open(flow_path, "w", encoding=DEFAULT_ENCODING) as f: - dump_yaml(flow_dag, f) + # directly dumping ordered dict will bring !!omap tag in yaml + dump_yaml(convert_ordered_dict_to_dict(flow_dag, remove_empty=False), f) return flow_path diff --git a/src/promptflow-core/promptflow/_utils/utils.py b/src/promptflow-core/promptflow/_utils/utils.py index 7af01b61774..26f52e3fabd 100644 --- a/src/promptflow-core/promptflow/_utils/utils.py +++ b/src/promptflow-core/promptflow/_utils/utils.py @@ -434,3 +434,48 @@ def strip_quotation(value): return value[1:-1] else: return value + + +def is_empty_target(obj: Optional[Dict]) -> bool: + """Determines if it's empty target + + :param obj: The object to check + :type obj: Optional[Dict] + :return: True if obj is None or an empty Dict + :rtype: bool + """ + return ( + obj is None + # some objs have overloaded "==" and will cause error. e.g CommandComponent obj + or (isinstance(obj, dict) and len(obj) == 0) + ) + + +def convert_ordered_dict_to_dict(target_object: Union[Dict, List], remove_empty: bool = True) -> Union[Dict, List]: + """Convert ordered dict to dict. Remove keys with None value. + This is a workaround for rest request must be in dict instead of + ordered dict. + + :param target_object: The object to convert + :type target_object: Union[Dict, List] + :param remove_empty: Whether to omit values that are None or empty dictionaries. Defaults to True. 
+ :type remove_empty: bool + :return: Converted ordered dict with removed None values + :rtype: Union[Dict, List] + """ + # OrderedDict can appear nested in a list + if isinstance(target_object, list): + new_list = [] + for item in target_object: + item = convert_ordered_dict_to_dict(item, remove_empty=remove_empty) + if not is_empty_target(item) or not remove_empty: + new_list.append(item) + return new_list + if isinstance(target_object, dict): + new_dict = {} + for key, value in target_object.items(): + value = convert_ordered_dict_to_dict(value, remove_empty=remove_empty) + if not is_empty_target(value) or not remove_empty: + new_dict[key] = value + return new_dict + return target_object diff --git a/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_flow_run.py b/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_flow_run.py index 53c77bc7311..3539d113d0f 100644 --- a/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_flow_run.py +++ b/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_flow_run.py @@ -1376,6 +1376,8 @@ def test_flex_flow_run( yaml_dict = load_yaml(local_storage._dag_path) assert yaml_dict == expected_snapshot_yaml + assert not local_storage._dag_path.read_text().startswith("!!omap") + # actual result will be entry2:my_flow2 details = pf.get_details(run.name) # convert DataFrame to dict From 935176fa65b9959deeda44d0f3fffa0900503652 Mon Sep 17 00:00:00 2001 From: Xingzhi Zhang <37076709+elliotzh@users.noreply.github.com> Date: Thu, 25 Apr 2024 20:07:11 +0800 Subject: [PATCH 6/6] fix: csharp executor proxy ci failure (#3018) # Description Fix 2 tests on Windows in promptflow-devkit/tests/sdk_cli_test/e2etests/test_csharp_sdk.py: - test_destroy_with_terminates_gracefully - test_destroy_with_force_kill A sketch of the intended shutdown sequence follows below. # All Promptflow Contribution checklist: - [ ] **The pull request does not introduce [breaking changes].** - [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.** - [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).** - [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).** ## General Guidelines and Best Practices - [ ] Title of the pull request is clear and informative. - [ ] There are a small number of commits, each of which have an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md). ### Testing Guidelines - [ ] Pull request includes test coverage for the included changes.
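For context, here is a sketch of the platform-aware shutdown sequence the updated tests assert on. It is an illustrative stand-in, not the actual executor proxy code, and it assumes the child process was started with `creationflags=subprocess.CREATE_NEW_PROCESS_GROUP` on Windows (otherwise `CTRL_BREAK_EVENT` cannot be delivered to it).

```python
# Illustrative sketch of the shutdown sequence the fixed tests expect;
# destroy_sketch is a hypothetical name, not the real proxy method.
import platform
import signal
import subprocess


def destroy_sketch(process: subprocess.Popen, timeout: int = 5) -> None:
    if process.poll() is not None:
        return  # child already exited, nothing to clean up
    if platform.system() == "Windows":
        # terminate() maps to the abrupt TerminateProcess on Windows;
        # CTRL_BREAK_EVENT instead gives the child a chance to exit cleanly.
        process.send_signal(signal.CTRL_BREAK_EVENT)
    else:
        process.terminate()  # SIGTERM on POSIX
    try:
        process.wait(timeout=timeout)  # the tests assert wait(timeout=5)
    except subprocess.TimeoutExpired:
        process.kill()  # escalate, mirroring test_destroy_with_force_kill
```

This is why the assertions below branch on `platform.system()`: graceful shutdown is observed as `terminate()` on POSIX but as `send_signal(signal.CTRL_BREAK_EVENT)` on Windows.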
--- .../tests/sdk_cli_test/e2etests/test_csharp_sdk.py | 8 ++++++-- .../unittests/batch/test_csharp_executor_proxy.py | 12 ++++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_csharp_sdk.py b/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_csharp_sdk.py index 0d66d21bd2e..0261997a2df 100644 --- a/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_csharp_sdk.py +++ b/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_csharp_sdk.py @@ -31,7 +31,11 @@ class TestCSharpSdk: "language": {"default": "chinese", "type": "string"}, "topic": {"default": "ocean", "type": "string"}, }, - "outputs": {"output": {"type": "object"}}, + "outputs": { + "Answer": {"type": "string"}, + "AnswerLength": {"type": "int"}, + "PoemLanguage": {"type": "string"}, + }, }, id="function_mode_basic", ), @@ -39,7 +43,7 @@ class TestCSharpSdk: { "init": {"connection": {"type": "AzureOpenAIConnection"}, "name": {"type": "string"}}, "inputs": {"question": {"default": "What is Promptflow?", "type": "string"}}, - "outputs": {"output": {"type": "object"}}, + "outputs": {"output": {"type": "string"}}, }, id="class_init_flex_flow", ), diff --git a/src/promptflow/tests/executor/unittests/batch/test_csharp_executor_proxy.py b/src/promptflow/tests/executor/unittests/batch/test_csharp_executor_proxy.py index dcc22683270..0e36f198ca0 100644 --- a/src/promptflow/tests/executor/unittests/batch/test_csharp_executor_proxy.py +++ b/src/promptflow/tests/executor/unittests/batch/test_csharp_executor_proxy.py @@ -1,4 +1,6 @@ import json +import platform +import signal import socket import subprocess from pathlib import Path @@ -62,7 +64,10 @@ async def test_destroy_with_terminates_gracefully(self): await executor_proxy.destroy() mock_process.poll.assert_called_once() - mock_process.terminate.assert_called_once() + if platform.system() != "Windows": + mock_process.terminate.assert_called_once() + else: + mock_process.send_signal.assert_called_once_with(signal.CTRL_BREAK_EVENT) mock_process.wait.assert_called_once_with(timeout=5) mock_process.kill.assert_not_called() @@ -77,7 +82,10 @@ async def test_destroy_with_force_kill(self): await executor_proxy.destroy() mock_process.poll.assert_called_once() - mock_process.terminate.assert_called_once() + if platform.system() != "Windows": + mock_process.terminate.assert_called_once() + else: + mock_process.send_signal.assert_called_once_with(signal.CTRL_BREAK_EVENT) mock_process.wait.assert_called_once_with(timeout=5) mock_process.kill.assert_called_once()
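Closing out the series: a standalone demo of the regex relaxation from PATCH 4/6 above. The `OLD` and `NEW` patterns are copied from that diff; the sample prompt is made up to mimic the portal's rewriting.

```python
# Demo of the PATCH 4/6 regex relaxation; OLD/NEW are the name/content
# patterns before and after the change, role_prompt is an invented sample.
import re

OLD = r"\n*#{0,2}\s*name:\n+\s*(\S+)\s*\n*#{0,2}\s*content:\n?(.*)"
NEW = r"\n*#{0,2}\s*name\s*:\s*\n+\s*(\S+)\s*\n*#{0,2}\s*content\s*:\s*\n?(.*)"

# The portal can silently rewrite "name:\n" as "name :\r\n".
role_prompt = "name :\r\n AI\ncontent :\r\nfirst"

print(re.search(OLD, role_prompt, re.DOTALL))           # None -> confusing error
print(re.search(NEW, role_prompt, re.DOTALL).groups())  # ('AI', 'first')
```

With the old pattern, a stray space or `\r` around the colon made the search return `None`, which is what surfaced the confusing error shown in the PATCH 4/6 screenshots; `\s*` on both sides of each colon absorbs the extra characters without changing what is captured.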