diff --git a/.cspell.json b/.cspell.json
index 13d6b357859..615002e1d61 100644
--- a/.cspell.json
+++ b/.cspell.json
@@ -126,6 +126,7 @@
"tcsetattr",
"pysqlite",
"AADSTS700082",
+ "levelno",
"Mobius"
],
"allowCompoundWords": true
diff --git a/README.md b/README.md
index 203caf93ebe..8a07452b501 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
# Prompt flow
[![Python package](https://img.shields.io/pypi/v/promptflow)](https://pypi.org/project/promptflow/)
-[![Python](https://img.shields.io/pypi/pyversions/promptflow.svg?maxAge=2592000)](https://pypi.python.org/pypi/promptflow/)
+[![Python](https://img.shields.io/pypi/pyversions/promptflow.svg?maxAge=2592000)](https://pypi.python.org/pypi/promptflow/)
[![PyPI - Downloads](https://img.shields.io/pypi/dm/promptflow)](https://pypi.org/project/promptflow/)
[![CLI](https://img.shields.io/badge/CLI-reference-blue)](https://microsoft.github.io/promptflow/reference/pf-command-reference.html)
[![vsc extension](https://img.shields.io/visual-studio-marketplace/i/prompt-flow.prompt-flow?logo=Visual%20Studio&label=Extension%20)](https://marketplace.visualstudio.com/items?itemName=prompt-flow.prompt-flow)
@@ -85,9 +85,9 @@ Prompt Flow is a tool designed to **build high quality LLM apps**, the developme
### Develop your own LLM apps
-#### VS Code Extension
+#### VS Code Extension
-We also offer a VS Code extension (a flow designer) for an interactive flow development experience with UI.
+We also offer a VS Code extension (a flow designer) for an interactive flow development experience with UI.
@@ -139,6 +139,25 @@ For more information see the
or contact [opencode@microsoft.com](mailto:opencode@microsoft.com)
with any additional questions or comments.
+## Data Collection
+
+The software may collect information about you and your use of the software and
+send it to Microsoft if configured to enable telemetry.
+Microsoft may use this information to provide services and improve our products and services.
+You may turn on the telemetry as described in the repository.
+There are also some features in the software that may enable you and Microsoft
+to collect data from users of your applications. If you use these features, you
+must comply with applicable law, including providing appropriate notices to
+users of your applications together with a copy of Microsoft's privacy
+statement. Our privacy statement is located at
+https://go.microsoft.com/fwlink/?LinkID=824704. You can learn more about data
+collection and use in the help documentation and our privacy statement. Your
+use of the software operates as your consent to these practices.
+
+### Telemetry Configuration
+
+Telemetry collection is off by default. To opt in, run `pf config set cli.telemetry_enabled=true`.
+
## License
Copyright (c) Microsoft Corporation. All rights reserved.
diff --git a/scripts/tool/utils/tool_utils.py b/scripts/tool/utils/tool_utils.py
index 999b44d0d3e..1afdcdf6cf9 100644
--- a/scripts/tool/utils/tool_utils.py
+++ b/scripts/tool/utils/tool_utils.py
@@ -30,6 +30,7 @@ def resolve_annotation(anno) -> Union[str, list]:
def param_to_definition(param, value_type) -> (InputDefinition, bool):
default_value = param.default
enum = None
+ custom_type = None
# Get value type and enum from default if no annotation
if default_value is not inspect.Parameter.empty and value_type == inspect.Parameter.empty:
value_type = default_value.__class__ if isinstance(default_value, Enum) else type(default_value)
@@ -39,17 +40,32 @@ def param_to_definition(param, value_type) -> (InputDefinition, bool):
value_type = str
is_connection = False
if ConnectionType.is_connection_value(value_type):
- typ = [value_type.__name__]
+ if ConnectionType.is_custom_strong_type(value_type):
+ typ = ["CustomConnection"]
+ custom_type = [value_type.__name__]
+ else:
+ typ = [value_type.__name__]
is_connection = True
elif isinstance(value_type, list):
if not all(ConnectionType.is_connection_value(t) for t in value_type):
typ = [ValueType.OBJECT]
else:
- typ = [t.__name__ for t in value_type]
+ custom_connection_added = False
+ typ = []
+ custom_type = []
+ for t in value_type:
+ if ConnectionType.is_custom_strong_type(t):
+ if not custom_connection_added:
+ custom_connection_added = True
+ typ.append("CustomConnection")
+ custom_type.append(t.__name__)
+ else:
+ typ.append(t.__name__)
is_connection = True
else:
typ = [ValueType.from_type(value_type)]
- return InputDefinition(type=typ, default=value_to_str(default_value), description=None, enum=enum), is_connection
+ return InputDefinition(type=typ, default=value_to_str(default_value),
+ description=None, enum=enum, custom_type=custom_type), is_connection
def function_to_interface(f: Callable, tool_type, initialize_inputs=None) -> tuple:
diff --git a/src/promptflow/CHANGELOG.md b/src/promptflow/CHANGELOG.md
index 9ab9b87bab3..30cede90d26 100644
--- a/src/promptflow/CHANGELOG.md
+++ b/src/promptflow/CHANGELOG.md
@@ -7,6 +7,7 @@
- **pf flow validate**: support validate flow
- **pf config set**: support set user-level promptflow config.
- Support workspace connection provider, usage: `pf config set connection.provider=azureml:/subscriptions//resourceGroups//providers/Microsoft.MachineLearningServices/workspaces/`
+- **Telemetry**: enable telemetry framework; data collection is disabled by default, use `pf config set cli.telemetry_enabled=true` to opt in.
### Bugs Fixed
- [Flow build] Fix flow build file name and environment variable name when connection name contains space.
@@ -14,12 +15,14 @@
- Read/write log file with encoding specified.
- Avoid inconsistent error message when executor exits abnormally.
- Align inputs & outputs row number in case partial completed run will break `pfazure run show-details`.
+- Fix bug that failed to parse portal url for run data when the form is an asset id.
### Improvements
- [Executor][Internal] Improve error message with more details and actionable information.
- [SDK/CLI] `pf/pfazure run show-details`:
- Add `--max-results` option to control the number of results to display.
- Add `--all-results` option to display all results.
+- Add validation for azure `PFClient` constructor in case wrong parameter is passed.
## 0.1.0b6 (2023.09.15)
diff --git a/src/promptflow/promptflow/_constants.py b/src/promptflow/promptflow/_constants.py
index 3aaca4ba1e5..faaae875e27 100644
--- a/src/promptflow/promptflow/_constants.py
+++ b/src/promptflow/promptflow/_constants.py
@@ -12,3 +12,4 @@
SERPAPI_API_KEY = "serpapi-api-key"
CONTENT_SAFETY_API_KEY = "content-safety-api-key"
ERROR_RESPONSE_COMPONENT_NAME = "promptflow"
+EXTENSION_UA = "prompt-flow-extension"
diff --git a/src/promptflow/promptflow/_core/_errors.py b/src/promptflow/promptflow/_core/_errors.py
index d191b5c2473..f92b1d360f4 100644
--- a/src/promptflow/promptflow/_core/_errors.py
+++ b/src/promptflow/promptflow/_core/_errors.py
@@ -24,12 +24,15 @@ class PackageToolNotFoundError(ValidationException):
pass
-class LoadToolError(ValidationException):
+class MissingRequiredInputs(ValidationException):
pass
-class MissingRequiredInputs(LoadToolError):
- pass
+class ToolLoadError(UserErrorException):
+ """Exception raised when tool load failed."""
+
+ def __init__(self, module: str = None, **kwargs):
+ super().__init__(target=ErrorTarget.TOOL, module=module, **kwargs)
class ToolExecutionError(UserErrorException):
diff --git a/src/promptflow/promptflow/_core/tools_manager.py b/src/promptflow/promptflow/_core/tools_manager.py
index 1062622811b..a8a2abd50a2 100644
--- a/src/promptflow/promptflow/_core/tools_manager.py
+++ b/src/promptflow/promptflow/_core/tools_manager.py
@@ -14,7 +14,7 @@
import pkg_resources
import yaml
-from promptflow._core._errors import MissingRequiredInputs, NotSupported, PackageToolNotFoundError
+from promptflow._core._errors import MissingRequiredInputs, NotSupported, PackageToolNotFoundError, ToolLoadError
from promptflow._core.tool_meta_generator import (
_parse_tool_from_function,
collect_tool_function_in_module,
@@ -213,8 +213,18 @@ def _load_package_tool(tool_name, module_name, class_name, method_name, node_inp
message=f"Required inputs {list(missing_inputs)} are not provided for tool '{tool_name}'.",
target=ErrorTarget.EXECUTOR,
)
+ try:
+ api = getattr(provider_class(**init_inputs_values), method_name)
+ except Exception as ex:
+ error_type_and_message = f"({ex.__class__.__name__}) {ex}"
+ raise ToolLoadError(
+ module=module_name,
+ message_format="Failed to load package tool '{tool_name}': {error_type_and_message}",
+ tool_name=tool_name,
+ error_type_and_message=error_type_and_message,
+ ) from ex
# Return the init_inputs to update node inputs in the afterward steps
- return getattr(provider_class(**init_inputs_values), method_name), init_inputs
+ return api, init_inputs
@staticmethod
def load_tool_by_api_name(api_name: str) -> Tool:
diff --git a/src/promptflow/promptflow/_sdk/_configuration.py b/src/promptflow/promptflow/_sdk/_configuration.py
index 921393976f0..92489272fd8 100644
--- a/src/promptflow/promptflow/_sdk/_configuration.py
+++ b/src/promptflow/promptflow/_sdk/_configuration.py
@@ -13,7 +13,7 @@
from promptflow._sdk._constants import LOGGER_NAME, ConnectionProvider
from promptflow._sdk._logger_factory import LoggerFactory
-from promptflow._sdk._utils import dump_yaml, load_yaml
+from promptflow._sdk._utils import call_from_extension, dump_yaml, load_yaml
from promptflow.exceptions import ErrorTarget, ValidationException
logger = LoggerFactory.get_logger(name=LOGGER_NAME, verbosity=logging.WARNING)
@@ -26,7 +26,8 @@ class ConfigFileNotFound(ValidationException):
class Configuration(object):
CONFIG_PATH = Path.home() / ".promptflow" / "pf.yaml"
- COLLECT_TELEMETRY = "cli.collect_telemetry"
+ COLLECT_TELEMETRY = "cli.telemetry_enabled"
+ EXTENSION_COLLECT_TELEMETRY = "extension.telemetry_enabled"
INSTALLATION_ID = "cli.installation_id"
CONNECTION_PROVIDER = "connection.provider"
_instance = None
@@ -41,6 +42,10 @@ def __init__(self):
if not self._config:
self._config = {}
+ @property
+ def config(self):
+ return self._config
+
@classmethod
def get_instance(cls):
"""Use this to get instance to avoid multiple copies of same global config."""
@@ -142,6 +147,8 @@ def get_connection_provider(self) -> Optional[str]:
def get_telemetry_consent(self) -> Optional[bool]:
"""Get the current telemetry consent value. Return None if not configured."""
+ if call_from_extension():
+ return self.get_config(key=self.EXTENSION_COLLECT_TELEMETRY)
return self.get_config(key=self.COLLECT_TELEMETRY)
def set_telemetry_consent(self, value):
diff --git a/src/promptflow/promptflow/_sdk/_constants.py b/src/promptflow/promptflow/_sdk/_constants.py
index beabc3d3b75..ad529a6fbf4 100644
--- a/src/promptflow/promptflow/_sdk/_constants.py
+++ b/src/promptflow/promptflow/_sdk/_constants.py
@@ -61,14 +61,20 @@ class CustomStrongTypeConnectionConfigs:
PREFIX = "promptflow.connection."
TYPE = "custom_type"
MODULE = "module"
+ PACKAGE = "package"
+ PACKAGE_VERSION = "package_version"
PROMPTFLOW_TYPE_KEY = PREFIX + TYPE
PROMPTFLOW_MODULE_KEY = PREFIX + MODULE
+ PROMPTFLOW_PACKAGE_KEY = PREFIX + PACKAGE
+ PROMPTFLOW_PACKAGE_VERSION_KEY = PREFIX + PACKAGE_VERSION
@staticmethod
def is_custom_key(key):
return key not in [
CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY,
CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY,
+ CustomStrongTypeConnectionConfigs.PROMPTFLOW_PACKAGE_KEY,
+ CustomStrongTypeConnectionConfigs.PROMPTFLOW_PACKAGE_VERSION_KEY,
]
@@ -291,3 +297,7 @@ class ConnectionProvider(str, Enum):
LOCAL_SERVICE_PORT = 5000
BULK_RUN_LINE_ERRORS = "BulkRunLineErrors"
+
+RUN_MACRO = "${run}"
+VARIANT_ID_MACRO = "${variant_id}"
+TIMESTAMP_MACRO = "${timestamp}"
diff --git a/src/promptflow/promptflow/_sdk/_utils.py b/src/promptflow/promptflow/_sdk/_utils.py
index 7795c560e69..2b5e960765b 100644
--- a/src/promptflow/promptflow/_sdk/_utils.py
+++ b/src/promptflow/promptflow/_sdk/_utils.py
@@ -3,12 +3,14 @@
# ---------------------------------------------------------
import collections
+import hashlib
import json
import logging
import multiprocessing
import os
import re
import shutil
+import sys
import tempfile
import zipfile
from contextlib import contextmanager
@@ -27,12 +29,14 @@
from marshmallow import ValidationError
import promptflow
+from promptflow._constants import EXTENSION_UA
from promptflow._core.tool_meta_generator import generate_tool_meta_dict_by_file
from promptflow._sdk._constants import (
DAG_FILE_NAME,
DEFAULT_ENCODING,
FLOW_TOOLS_JSON,
FLOW_TOOLS_JSON_GEN_TIMEOUT,
+ HOME_PROMPT_FLOW_DIR,
KEYRING_ENCRYPTION_KEY_NAME,
KEYRING_ENCRYPTION_LOCK_PATH,
KEYRING_SYSTEM,
@@ -695,14 +699,33 @@ def process_node(_node, _node_path):
return flow_tools
-def setup_user_agent_to_operation_context(user_agent):
+def update_user_agent_from_env_var():
+ """Update user agent from env var to OperationContext"""
from promptflow._core.operation_context import OperationContext
if "USER_AGENT" in os.environ:
# Append vscode or other user agent from env
OperationContext.get_instance().append_user_agent(os.environ["USER_AGENT"])
+
+
+def setup_user_agent_to_operation_context(user_agent):
+ """Setup user agent to OperationContext"""
+ from promptflow._core.operation_context import OperationContext
+
+ update_user_agent_from_env_var()
# Append user agent
- OperationContext.get_instance().append_user_agent(user_agent)
+ context = OperationContext.get_instance()
+ context.append_user_agent(user_agent)
+ return context.get_user_agent()
+
+
+def call_from_extension() -> bool:
+ """Return true if current request is from extension."""
+ from promptflow._core.operation_context import OperationContext
+
+ update_user_agent_from_env_var()
+    context = OperationContext.get_instance()
+ return EXTENSION_UA in context.get_user_agent()
def generate_random_string(length: int = 6) -> str:
@@ -750,3 +773,36 @@ def get_local_connections_from_executable(executable):
# ignore when connection not found since it can be configured with env var.
raise Exception(f"Connection {n!r} required for flow {executable.name!r} is not found.")
return result
+
+
+def _generate_connections_dir():
+ # Get Python executable path
+ python_path = sys.executable
+
+ # Hash the Python executable path
+ hash_object = hashlib.sha1(python_path.encode())
+ hex_dig = hash_object.hexdigest()
+
+ # Generate the connections system path using the hash
+ connections_dir = (HOME_PROMPT_FLOW_DIR / "envs" / hex_dig / "connections").resolve()
+ return connections_dir
+
+
+# This function is used by extension to generate the connection files every time collect tools.
+def refresh_connections_dir(connection_spec_files, connection_template_yamls):
+ connections_dir = _generate_connections_dir()
+ if os.path.isdir(connections_dir):
+ shutil.rmtree(connections_dir)
+ os.makedirs(connections_dir)
+
+ if connection_spec_files and connection_template_yamls:
+ for connection_name, content in connection_spec_files.items():
+ file_name = connection_name + ".spec.json"
+ with open(connections_dir / file_name, "w") as f:
+ json.dump(content, f, indent=2)
+
+ for connection_name, content in connection_template_yamls.items():
+ yaml_data = yaml.safe_load(content)
+ file_name = connection_name + ".template.yaml"
+ with open(connections_dir / file_name, "w") as f:
+ yaml.dump(yaml_data, f, sort_keys=False)
diff --git a/src/promptflow/promptflow/_sdk/entities/_connection.py b/src/promptflow/promptflow/_sdk/entities/_connection.py
index 015becfb2dc..987987367ea 100644
--- a/src/promptflow/promptflow/_sdk/entities/_connection.py
+++ b/src/promptflow/promptflow/_sdk/entities/_connection.py
@@ -669,6 +669,33 @@ def __init__(
super().__init__(configs=configs, secrets=secrets, **kwargs)
self.module = kwargs.get("module", self.__class__.__module__)
self.custom_type = custom_type or self.__class__.__name__
+ self.package = kwargs.get(CustomStrongTypeConnectionConfigs.PACKAGE, None)
+ self.package_version = kwargs.get(CustomStrongTypeConnectionConfigs.PACKAGE_VERSION, None)
+
+ def __getattribute__(self, item):
+ # Note: The reason to overwrite __getattribute__ instead of __getattr__ is as follows:
+ # Custom strong type connection is written this way:
+ # class MyCustomConnection(CustomStrongTypeConnection):
+ # api_key: Secret
+ # api_base: str = "This is a default value"
+ # api_base has a default value, my_custom_connection_instance.api_base would not trigger __getattr__.
+ # The default value will be returned directly instead of the real value in configs.
+ annotations = getattr(super().__getattribute__("__class__"), "__annotations__", {})
+ if item in annotations:
+ if annotations[item] == Secret:
+ return self.secrets[item]
+ else:
+ return self.configs[item]
+ return super().__getattribute__(item)
+
+ def __setattr__(self, key, value):
+ annotations = getattr(super().__getattribute__("__class__"), "__annotations__", {})
+ if key in annotations:
+ if annotations[key] == Secret:
+ self.secrets[key] = value
+ else:
+ self.configs[key] = value
+ return super().__setattr__(key, value)
def _to_orm_object(self) -> ORMConnection:
custom_connection = self._convert_to_custom()
@@ -678,6 +705,11 @@ def _convert_to_custom(self):
# update configs
self.configs.update({CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY: self.custom_type})
self.configs.update({CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY: self.module})
+ if self.package and self.package_version:
+ self.configs.update({CustomStrongTypeConnectionConfigs.PROMPTFLOW_PACKAGE_KEY: self.package})
+ self.configs.update(
+ {CustomStrongTypeConnectionConfigs.PROMPTFLOW_PACKAGE_VERSION_KEY: self.package_version}
+ )
custom_connection = CustomConnection(configs=self.configs, secrets=self.secrets, **self.kwargs)
return custom_connection
diff --git a/src/promptflow/promptflow/_sdk/entities/_run.py b/src/promptflow/promptflow/_sdk/entities/_run.py
index 77f2ce5b9b7..3b385e44d86 100644
--- a/src/promptflow/promptflow/_sdk/entities/_run.py
+++ b/src/promptflow/promptflow/_sdk/entities/_run.py
@@ -14,6 +14,9 @@
from promptflow._sdk._constants import (
BASE_PATH_CONTEXT_KEY,
PARAMS_OVERRIDE_KEY,
+ RUN_MACRO,
+ TIMESTAMP_MACRO,
+ VARIANT_ID_MACRO,
AzureRunTypes,
FlowRunProperties,
RestRunTypes,
@@ -396,21 +399,23 @@ def _get_default_display_name(self) -> str:
def _format_display_name(self) -> str:
"""
- Format display name.
- For run without upstream run (variant run)
- the pattern is {client_run_display_name}-{variant_id}-{timestamp}
- For run with upstream run (variant run)
- the pattern is {upstream_run_display_name}-{client_run_display_name}-{timestamp}
+ Format display name. Replace macros in display name with actual values.
+ The following macros are supported: ${variant_id}, ${run}, ${timestamp}
+
+ For example,
+ if the display name is "run-${variant_id}-${timestamp}"
+ it will be formatted to "run-variant_1-20210901123456"
"""
display_name = self._get_default_display_name()
time_stamp = datetime.datetime.now().strftime("%Y%m%d%H%M")
if self.run:
- display_name = f"{self.run.display_name}-{display_name}-{time_stamp}"
- else:
- variant = self.variant
- variant = parse_variant(variant)[1] if variant else "default"
- display_name = f"{display_name}-{variant}-{time_stamp}"
+ display_name = display_name.replace(RUN_MACRO, self._validate_and_return_run_name(self.run))
+ display_name = display_name.replace(TIMESTAMP_MACRO, time_stamp)
+ variant = self.variant
+ variant = parse_variant(variant)[1] if variant else "default"
+ display_name = display_name.replace(VARIANT_ID_MACRO, variant)
+
return display_name
def _get_flow_dir(self) -> Path:
diff --git a/src/promptflow/promptflow/_sdk/operations/_connection_operations.py b/src/promptflow/promptflow/_sdk/operations/_connection_operations.py
index 05a60b57298..37413760668 100644
--- a/src/promptflow/promptflow/_sdk/operations/_connection_operations.py
+++ b/src/promptflow/promptflow/_sdk/operations/_connection_operations.py
@@ -8,11 +8,17 @@
from promptflow._sdk._orm import Connection as ORMConnection
from promptflow._sdk._utils import safe_parse_object_list
from promptflow._sdk.entities._connection import _Connection
+from promptflow._telemetry.activity import ActivityType, monitor_operation
+from promptflow._telemetry.telemetry import TelemetryMixin
-class ConnectionOperations:
+class ConnectionOperations(TelemetryMixin):
"""ConnectionOperations."""
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ @monitor_operation(activity_name="pf.connections.list", activity_type=ActivityType.PUBLICAPI)
def list(
self,
max_results: int = MAX_LIST_CLI_RESULTS,
@@ -34,6 +40,7 @@ def list(
message_generator=lambda x: f"Failed to load connection {x.connectionName}, skipped.",
)
+ @monitor_operation(activity_name="pf.connections.get", activity_type=ActivityType.PUBLICAPI)
def get(self, name: str, **kwargs) -> _Connection:
"""Get a connection entity.
@@ -51,6 +58,7 @@ def get(self, name: str, **kwargs) -> _Connection:
return _Connection._from_orm_object_with_secrets(orm_connection)
return _Connection._from_orm_object(orm_connection)
+ @monitor_operation(activity_name="pf.connections.delete", activity_type=ActivityType.PUBLICAPI)
def delete(self, name: str) -> None:
"""Delete a connection entity.
@@ -59,6 +67,7 @@ def delete(self, name: str) -> None:
"""
ORMConnection.delete(name)
+ @monitor_operation(activity_name="pf.connections.create_or_update", activity_type=ActivityType.PUBLICAPI)
def create_or_update(self, connection: _Connection, **kwargs):
"""Create or update a connection.
diff --git a/src/promptflow/promptflow/_sdk/operations/_local_azure_connection_operations.py b/src/promptflow/promptflow/_sdk/operations/_local_azure_connection_operations.py
index bcc8aee91e1..b2acd066c43 100644
--- a/src/promptflow/promptflow/_sdk/operations/_local_azure_connection_operations.py
+++ b/src/promptflow/promptflow/_sdk/operations/_local_azure_connection_operations.py
@@ -8,6 +8,7 @@
from promptflow._sdk._constants import AZURE_WORKSPACE_REGEX_FORMAT, LOGGER_NAME, MAX_LIST_CLI_RESULTS
from promptflow._sdk._logger_factory import LoggerFactory
from promptflow._sdk.entities._connection import _Connection
+from promptflow._telemetry.activity import ActivityType, monitor_operation
logger = LoggerFactory.get_logger(name=LOGGER_NAME, verbosity=logging.WARNING)
@@ -30,7 +31,7 @@ def __init__(self, connection_provider):
@classmethod
def _extract_workspace(cls, connection_provider):
match = re.match(AZURE_WORKSPACE_REGEX_FORMAT, connection_provider)
- if not match:
+ if not match or len(match.groups()) != 5:
raise ValueError(
"Malformed connection provider string, expected azureml:/subscriptions//"
"resourceGroups//providers/Microsoft.MachineLearningServices/"
@@ -41,6 +42,7 @@ def _extract_workspace(cls, connection_provider):
workspace_name = match.group(5)
return subscription_id, resource_group, workspace_name
+ @monitor_operation(activity_name="pf.connections.azure.list", activity_type=ActivityType.PUBLICAPI)
def list(
self,
max_results: int = MAX_LIST_CLI_RESULTS,
@@ -57,6 +59,7 @@ def list(
)
return self._pfazure_client._connections.list()
+ @monitor_operation(activity_name="pf.connections.azure.get", activity_type=ActivityType.PUBLICAPI)
def get(self, name: str, **kwargs) -> _Connection:
"""Get a connection entity.
@@ -70,6 +73,7 @@ def get(self, name: str, **kwargs) -> _Connection:
return self._pfazure_client._arm_connections.get(name)
return self._pfazure_client._connections.get(name)
+ @monitor_operation(activity_name="pf.connections.azure.delete", activity_type=ActivityType.PUBLICAPI)
def delete(self, name: str) -> None:
"""Delete a connection entity.
@@ -81,6 +85,7 @@ def delete(self, name: str) -> None:
"please manage it in workspace portal, az ml cli or AzureML SDK."
)
+ @monitor_operation(activity_name="pf.connections.azure.create_or_update", activity_type=ActivityType.PUBLICAPI)
def create_or_update(self, connection: _Connection, **kwargs):
"""Create or update a connection.
diff --git a/src/promptflow/promptflow/_sdk/operations/_run_operations.py b/src/promptflow/promptflow/_sdk/operations/_run_operations.py
index 070d147a016..80369528eed 100644
--- a/src/promptflow/promptflow/_sdk/operations/_run_operations.py
+++ b/src/promptflow/promptflow/_sdk/operations/_run_operations.py
@@ -24,6 +24,8 @@
from promptflow._sdk._visualize_functions import dump_html, generate_html_string
from promptflow._sdk.entities import Run
from promptflow._sdk.operations._local_storage_operations import LocalStorageOperations
+from promptflow._telemetry.activity import ActivityType, monitor_operation
+from promptflow._telemetry.telemetry import TelemetryMixin
from promptflow.contracts._run_management import RunMetadata, RunVisualization
RUNNING_STATUSES = RunStatus.get_running_statuses()
@@ -31,12 +33,13 @@
logger = logging.getLogger(LOGGER_NAME)
-class RunOperations:
+class RunOperations(TelemetryMixin):
"""RunOperations."""
- def __init__(self):
- pass
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ @monitor_operation(activity_name="pf.runs.list", activity_type=ActivityType.PUBLICAPI)
def list(
self,
max_results: Optional[int] = MAX_RUN_LIST_RESULTS,
@@ -59,6 +62,7 @@ def list(
message_generator=lambda x: f"Error parsing run {x.name!r}, skipped.",
)
+ @monitor_operation(activity_name="pf.runs.get", activity_type=ActivityType.PUBLICAPI)
def get(self, name: str) -> Run:
"""Get a run entity.
@@ -73,6 +77,7 @@ def get(self, name: str) -> Run:
except RunNotFoundError as e:
raise e
+ @monitor_operation(activity_name="pf.runs.create_or_update", activity_type=ActivityType.PUBLICAPI)
def create_or_update(self, run: Run, **kwargs) -> Run:
"""Create or update a run.
@@ -104,6 +109,7 @@ def _print_run_summary(self, run: Run) -> None:
f'Output path: "{run._output_path}"\n'
)
+ @monitor_operation(activity_name="pf.runs.stream", activity_type=ActivityType.PUBLICAPI)
def stream(self, name: Union[str, Run]) -> Run:
"""Stream run logs to the console.
@@ -137,6 +143,7 @@ def stream(self, name: Union[str, Run]) -> Run:
print(error_message)
return run
+ @monitor_operation(activity_name="pf.runs.archive", activity_type=ActivityType.PUBLICAPI)
def archive(self, name: Union[str, Run]) -> Run:
"""Archive a run.
@@ -149,6 +156,7 @@ def archive(self, name: Union[str, Run]) -> Run:
ORMRun.get(name).archive()
return self.get(name)
+ @monitor_operation(activity_name="pf.runs.restore", activity_type=ActivityType.PUBLICAPI)
def restore(self, name: Union[str, Run]) -> Run:
"""Restore a run.
@@ -161,6 +169,7 @@ def restore(self, name: Union[str, Run]) -> Run:
ORMRun.get(name).restore()
return self.get(name)
+ @monitor_operation(activity_name="pf.runs.update", activity_type=ActivityType.PUBLICAPI)
def update(
self,
name: Union[str, Run],
@@ -184,6 +193,7 @@ def update(
ORMRun.get(name).update(display_name=display_name, description=description, tags=tags, **kwargs)
return self.get(name)
+ @monitor_operation(activity_name="pf.runs.get_details", activity_type=ActivityType.PUBLICAPI)
def get_details(
self, name: Union[str, Run], max_results: int = MAX_SHOW_DETAILS_RESULTS, all_results: bool = False
) -> pd.DataFrame:
@@ -229,6 +239,7 @@ def get_details(
df = pd.DataFrame(data).head(max_results).reindex(columns=columns)
return df
+ @monitor_operation(activity_name="pf.runs.get_metrics", activity_type=ActivityType.PUBLICAPI)
def get_metrics(self, name: Union[str, Run]) -> Dict[str, Any]:
"""Get run metrics.
@@ -270,6 +281,7 @@ def _visualize(self, runs: List[Run], html_path: Optional[str] = None) -> None:
# if html_path is specified, not open it in webbrowser(as it comes from VSC)
dump_html(html_string, html_path=html_path, open_html=html_path is None)
+ @monitor_operation(activity_name="pf.runs.visualize", activity_type=ActivityType.PUBLICAPI)
def visualize(self, runs: Union[str, Run, List[str], List[Run]], **kwargs) -> None:
"""Visualize run(s).
diff --git a/src/promptflow/promptflow/_sdk/schemas/_connection.py b/src/promptflow/promptflow/_sdk/schemas/_connection.py
index cc868d87293..e81c8563737 100644
--- a/src/promptflow/promptflow/_sdk/schemas/_connection.py
+++ b/src/promptflow/promptflow/_sdk/schemas/_connection.py
@@ -122,6 +122,8 @@ class CustomStrongTypeConnectionSchema(CustomConnectionSchema):
name = fields.Str(attribute="name")
module = fields.Str(required=True)
custom_type = fields.Str(required=True)
+ package = fields.Str(required=True)
+ package_version = fields.Str(required=True)
# TODO: validate configs and secrets
@validates("configs")
diff --git a/src/promptflow/promptflow/_telemetry/__init__.py b/src/promptflow/promptflow/_telemetry/__init__.py
new file mode 100644
index 00000000000..d540fd20468
--- /dev/null
+++ b/src/promptflow/promptflow/_telemetry/__init__.py
@@ -0,0 +1,3 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
diff --git a/src/promptflow/promptflow/_telemetry/activity.py b/src/promptflow/promptflow/_telemetry/activity.py
new file mode 100644
index 00000000000..5672387fe76
--- /dev/null
+++ b/src/promptflow/promptflow/_telemetry/activity.py
@@ -0,0 +1,142 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import contextlib
+import functools
+import uuid
+from datetime import datetime
+
+from promptflow._telemetry.telemetry import TelemetryMixin
+
+
+class ActivityType(object):
+ """The type of activity (code) monitored.
+
+ The default type is "PublicAPI".
+ """
+
+ PUBLICAPI = "PublicApi" # incoming public API call (default)
+ INTERNALCALL = "InternalCall" # internal (function) call
+ CLIENTPROXY = "ClientProxy" # an outgoing service API call
+
+
+class ActivityCompletionStatus(object):
+ """The activity (code) completion status, success, or failure."""
+
+ SUCCESS = "Success"
+ FAILURE = "Failure"
+
+
+@contextlib.contextmanager
+def log_activity(
+ logger,
+ activity_name,
+ activity_type=ActivityType.INTERNALCALL,
+ custom_dimensions=None,
+):
+ """Log an activity.
+
+ An activity is a logical block of code that consumers want to monitor.
+ To monitor, wrap the logical block of code with the ``log_activity()`` method. As an alternative, you can
+ also use the ``@monitor_with_activity`` decorator.
+
+ :param logger: The logger adapter.
+ :type logger: logging.LoggerAdapter
+ :param activity_name: The name of the activity. The name should be unique per the wrapped logical code block.
+ :type activity_name: str
+ :param activity_type: One of PUBLICAPI, INTERNALCALL, or CLIENTPROXY which represent an incoming API call,
+ an internal (function) call, or an outgoing API call. If not specified, INTERNALCALL is used.
+ :type activity_type: str
+ :param custom_dimensions: The custom properties of the activity.
+ :type custom_dimensions: dict
+ :return: None
+ """
+ activity_info = {
+ # TODO(2699383): use same request id with service caller
+ "request_id": str(uuid.uuid4()),
+ "activity_name": activity_name,
+ "activity_type": activity_type,
+ }
+ custom_dimensions = custom_dimensions or {}
+ activity_info.update(custom_dimensions)
+
+ start_time = datetime.utcnow()
+ completion_status = ActivityCompletionStatus.SUCCESS
+
+ message = f"{activity_name}.start"
+ logger.info(message, extra={"custom_dimensions": activity_info})
+ exception = None
+
+ try:
+ yield logger
+ except BaseException as e: # pylint: disable=broad-except
+ exception = e
+ completion_status = ActivityCompletionStatus.FAILURE
+ finally:
+ try:
+ end_time = datetime.utcnow()
+ duration_ms = round((end_time - start_time).total_seconds() * 1000, 2)
+
+ activity_info["completion_status"] = completion_status
+ activity_info["duration_ms"] = duration_ms
+ message = f"{activity_name}.complete"
+ if exception:
+ logger.error(message, extra={"custom_dimensions": activity_info})
+ else:
+ logger.info(message, extra={"custom_dimensions": activity_info})
+ except Exception: # pylint: disable=broad-except
+ # skip if logger failed to log
+ pass # pylint: disable=lost-exception
+ # raise the exception to align with the behavior of the with statement
+ if exception:
+ raise exception
+
+
+def extract_telemetry_info(self):
+ """Extract pf telemetry info from given telemetry mix-in instance."""
+ result = {}
+ try:
+ if isinstance(self, TelemetryMixin):
+ return self._get_telemetry_values()
+ except Exception:
+ pass
+ return result
+
+
+def monitor_operation(
+ activity_name,
+ activity_type=ActivityType.INTERNALCALL,
+ custom_dimensions=None,
+):
+ """A wrapper for monitoring an activity in operations class.
+
+ To monitor, use the ``@monitor_operation`` decorator.
+ Note: this decorator should only use in operations class methods.
+
+ :param activity_name: The name of the activity. The name should be unique per the wrapped logical code block.
+ :type activity_name: str
+ :param activity_type: One of PUBLICAPI, INTERNALCALL, or CLIENTPROXY which represent an incoming API call,
+ an internal (function) call, or an outgoing API call. If not specified, INTERNALCALL is used.
+ :type activity_type: str
+ :param custom_dimensions: The custom properties of the activity.
+ :type custom_dimensions: dict
+ :return:
+ """
+ if not custom_dimensions:
+ custom_dimensions = {}
+
+ def monitor(f):
+ @functools.wraps(f)
+ def wrapper(self, *args, **kwargs):
+ from promptflow._telemetry.telemetry import get_telemetry_logger
+
+ logger = get_telemetry_logger()
+
+ custom_dimensions.update(extract_telemetry_info(self))
+
+ with log_activity(logger, activity_name, activity_type, custom_dimensions):
+ return f(self, *args, **kwargs)
+
+ return wrapper
+
+ return monitor
diff --git a/src/promptflow/promptflow/_telemetry/logging_handler.py b/src/promptflow/promptflow/_telemetry/logging_handler.py
new file mode 100644
index 00000000000..094a8a95578
--- /dev/null
+++ b/src/promptflow/promptflow/_telemetry/logging_handler.py
@@ -0,0 +1,84 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import logging
+import platform
+
+from opencensus.ext.azure.log_exporter import AzureEventHandler
+
+from promptflow._cli._user_agent import USER_AGENT
+from promptflow._sdk._configuration import Configuration
+
+# TODO: replace with prod app insights
+INSTRUMENTATION_KEY = "b4ff2b60-2f72-4a5f-b7a6-571318b50ab2"
+
+
+# cspell:ignore overriden
+def get_appinsights_log_handler():
+ """
+ Enable the OpenCensus logging handler for specified logger and instrumentation key to send info to AppInsights.
+ """
+ from promptflow._sdk._utils import setup_user_agent_to_operation_context
+ from promptflow._telemetry.telemetry import is_telemetry_enabled
+
+ try:
+ # TODO: use different instrumentation key for Europe
+ instrumentation_key = INSTRUMENTATION_KEY
+ config = Configuration.get_instance()
+ user_agent = setup_user_agent_to_operation_context(USER_AGENT)
+ custom_properties = {
+ "python_version": platform.python_version(),
+ "user_agent": user_agent,
+ "installation_id": config.get_or_set_installation_id(),
+ }
+
+ handler = PromptFlowSDKLogHandler(
+ connection_string=f"InstrumentationKey={instrumentation_key}",
+ custom_properties=custom_properties,
+ enable_telemetry=is_telemetry_enabled(),
+ )
+ return handler
+ except Exception: # pylint: disable=broad-except
+ # ignore any exceptions, telemetry collection errors shouldn't block an operation
+ return logging.NullHandler()
+
+
+# cspell:ignore AzureMLSDKLogHandler
+class PromptFlowSDKLogHandler(AzureEventHandler):
+ """Customized AzureLogHandler for PromptFlow SDK"""
+
+ def __init__(self, custom_properties, enable_telemetry, **kwargs):
+ super().__init__(**kwargs)
+
+ self._is_telemetry_enabled = enable_telemetry
+ self._custom_dimensions = custom_properties
+
+ def emit(self, record):
+ # skip logging if telemetry is disabled
+ if not self._is_telemetry_enabled:
+ return
+
+ try:
+ self._queue.put(record, block=False)
+
+ # log the record immediately if it is an error
+ if record.exc_info and not all(item is None for item in record.exc_info):
+ self._queue.flush()
+ except Exception: # pylint: disable=broad-except
+ # ignore any exceptions, telemetry collection errors shouldn't block an operation
+ return
+
+ def log_record_to_envelope(self, record):
+ # skip logging if telemetry is disabled
+ if not self._is_telemetry_enabled:
+ return
+ custom_dimensions = {
+ "level": record.levelname,
+ }
+ custom_dimensions.update(self._custom_dimensions)
+ if hasattr(record, "custom_dimensions") and isinstance(record.custom_dimensions, dict):
+ record.custom_dimensions.update(custom_dimensions)
+ else:
+ record.custom_dimensions = custom_dimensions
+
+ return super().log_record_to_envelope(record=record)
diff --git a/src/promptflow/promptflow/_telemetry/telemetry.py b/src/promptflow/promptflow/_telemetry/telemetry.py
new file mode 100644
index 00000000000..b568e1066ae
--- /dev/null
+++ b/src/promptflow/promptflow/_telemetry/telemetry.py
@@ -0,0 +1,51 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import logging
+import os
+
+from promptflow._sdk._configuration import Configuration
+from promptflow._telemetry.logging_handler import get_appinsights_log_handler
+
+TELEMETRY_ENABLED = "TELEMETRY_ENABLED"
+PROMPTFLOW_LOGGER_NAMESPACE = "promptflow._telemetry"
+
+
+class TelemetryMixin(object):
+ def __init__(self, **kwargs):
+ # Need to call init for potential parent, otherwise it won't be initialized.
+ super().__init__(**kwargs)
+
+ def _get_telemetry_values(self, *args, **kwargs): # pylint: disable=unused-argument
+ """Return the telemetry values of object.
+
+ :return: The telemetry values
+ :rtype: Dict
+ """
+ return {}
+
+
+def is_telemetry_enabled():
+ """Check if telemetry is enabled. User can enable telemetry by
+ 1. setting environment variable TELEMETRY_ENABLED to true.
+ 2. running `pf config set cli.collect_telemetry=true` command.
+ If none of the above is set, telemetry is disabled by default.
+ """
+ telemetry_enabled = os.getenv(TELEMETRY_ENABLED)
+ if telemetry_enabled is not None:
+ return str(telemetry_enabled).lower() == "true"
+ config = Configuration.get_instance()
+ telemetry_consent = config.get_telemetry_consent()
+ if telemetry_consent is not None:
+ return telemetry_consent
+ return False
+
+
+def get_telemetry_logger():
+ current_logger = logging.getLogger(PROMPTFLOW_LOGGER_NAMESPACE)
+ # avoid telemetry log appearing in higher level loggers
+ current_logger.propagate = False
+ current_logger.setLevel(logging.INFO)
+ handler = get_appinsights_log_handler()
+ current_logger.addHandler(handler)
+ return current_logger
diff --git a/src/promptflow/promptflow/_utils/exception_utils.py b/src/promptflow/promptflow/_utils/exception_utils.py
index 36da4c8de06..e696ba60df6 100644
--- a/src/promptflow/promptflow/_utils/exception_utils.py
+++ b/src/promptflow/promptflow/_utils/exception_utils.py
@@ -198,55 +198,47 @@ def build_debug_info(self, ex: Exception):
"innerException": inner_exception,
}
- def to_dict(self, *, include_debug_info=False):
- """Return a dict representation of the exception.
+ @property
+ def error_codes(self):
+ """The hierarchy of the error codes.
- This dict specification corresponds to the specification of the Microsoft API Guidelines:
+ We follow the "Microsoft REST API Guidelines" to define error codes in a hierarchy style.
+ See the below link for details:
https://github.com/microsoft/api-guidelines/blob/vNext/Guidelines.md#7102-error-condition-responses
- Note that this dict representation the "error" field in the response body of the API.
- The whole error response is then populated in another place outside of this class.
+ This method returns the error codes in a list. It will be converted into a nested json format by
+ error_code_recursed.
"""
- if isinstance(self._ex, JsonSerializedPromptflowException):
- return self._ex.to_dict(include_debug_info=include_debug_info)
-
- # Otherwise, return general dict representation of the exception.
- result = {
- "code": infer_error_code_from_class(SystemErrorException),
- "message": str(self._ex),
- "messageFormat": "",
- "messageParameters": {},
- "innerError": {
- "code": self._ex.__class__.__name__,
- "innerError": None,
- },
- }
- if include_debug_info:
- result["debugInfo"] = self.debug_info
-
- return result
+ return [infer_error_code_from_class(SystemErrorException), self._ex.__class__.__name__]
-
-class PromptflowExceptionPresenter(ExceptionPresenter):
@property
def error_code_recursed(self):
"""Returns a dict of the error codes for this exception.
It is populated in a recursive manner, using the source from `error_codes` property.
- i.e. For ToolExcutionError which inherits from UserErrorException,
+ i.e. For PromptflowException, such as ToolExecutionError which inherits from UserErrorException,
The result would be:
{
- "code": "UserErrorException",
+ "code": "UserError",
"innerError": {
"code": "ToolExecutionError",
"innerError": None,
},
}
+ For other exception types, such as ValueError, the result would be:
+
+ {
+ "code": "SystemError",
+ "innerError": {
+ "code": "ValueError",
+ "innerError": None,
+ },
+ }
"""
current_error = None
- reversed_error_codes = reversed(self._ex.error_codes) if self._ex.error_codes else []
+ reversed_error_codes = reversed(self.error_codes) if self.error_codes else []
for code in reversed_error_codes:
current_error = {
"code": code,
@@ -255,6 +247,53 @@ def error_code_recursed(self):
return current_error
+ def to_dict(self, *, include_debug_info=False):
+ """Return a dict representation of the exception.
+
+ This dict specification corresponds to the specification of the Microsoft API Guidelines:
+ https://github.com/microsoft/api-guidelines/blob/vNext/Guidelines.md#7102-error-condition-responses
+
+ Note that this dict represents the "error" field in the response body of the API.
+ The whole error response is then populated in another place outside of this class.
+ """
+ if isinstance(self._ex, JsonSerializedPromptflowException):
+ return self._ex.to_dict(include_debug_info=include_debug_info)
+
+ # Otherwise, return general dict representation of the exception.
+ result = {"message": str(self._ex), "messageFormat": "", "messageParameters": {}}
+ result.update(self.error_code_recursed)
+
+ if include_debug_info:
+ result["debugInfo"] = self.debug_info
+
+ return result
+
+
+class PromptflowExceptionPresenter(ExceptionPresenter):
+ @property
+ def error_codes(self):
+ """The hierarchy of the error codes.
+
+ We follow the "Microsoft REST API Guidelines" to define error codes in a hierarchy style.
+ See the below link for details:
+ https://github.com/microsoft/api-guidelines/blob/vNext/Guidelines.md#7102-error-condition-responses
+
+ For subclass of PromptflowException, use the ex.error_codes directly.
+
+ For PromptflowException (not a subclass), the ex.error_code is None.
+ The result should be:
+ ["SystemError", {inner_exception type name if it exists}]
+ """
+ if self._ex.error_codes:
+ return self._ex.error_codes
+
+ # For PromptflowException (not a subclass), the ex.error_code is None.
+ # Handle this case specifically.
+ error_codes = [infer_error_code_from_class(SystemErrorException)]
+ if self._ex.inner_exception:
+ error_codes.append(infer_error_code_from_class(self._ex.inner_exception.__class__))
+ return error_codes
+
def to_dict(self, *, include_debug_info=False):
result = {
"message": self._ex.message,
@@ -263,20 +302,7 @@ def to_dict(self, *, include_debug_info=False):
"referenceCode": self._ex.reference_code,
}
- if self.error_code_recursed:
- result.update(self.error_code_recursed)
- else:
- # For PromptflowException (not a subclass), the error_code_recursed is None.
- # Handle this case specifically.
- result["code"] = infer_error_code_from_class(SystemErrorException)
- if self._ex.inner_exception:
- # Set the type of inner_exception as the inner error
- result["innerError"] = {
- "code": infer_error_code_from_class(self._ex.inner_exception.__class__),
- "innerError": None,
- }
- else:
- result["innerError"] = None
+ result.update(self.error_code_recursed)
if self._ex.additional_info:
result["additionalInfo"] = [{"type": k, "info": v} for k, v in self._ex.additional_info.items()]
if include_debug_info:
diff --git a/src/promptflow/promptflow/azure/_pf_client.py b/src/promptflow/promptflow/azure/_pf_client.py
index 70c8ca9c805..f82d8a2eac9 100644
--- a/src/promptflow/promptflow/azure/_pf_client.py
+++ b/src/promptflow/promptflow/azure/_pf_client.py
@@ -11,6 +11,7 @@
from pandas import DataFrame
from promptflow._sdk._constants import MAX_SHOW_DETAILS_RESULTS
+from promptflow._sdk._errors import RunOperationParameterError
from promptflow._sdk._user_agent import USER_AGENT
from promptflow._sdk.entities import Run
from promptflow.azure._load_functions import load_flow
@@ -48,6 +49,7 @@ def __init__(
workspace_name: Optional[str] = None,
**kwargs,
):
+ self._validate_config_information(subscription_id, resource_group_name, workspace_name, kwargs)
self._add_user_agent(kwargs)
self._ml_client = kwargs.pop("ml_client", None) or MLClient(
credential=credential,
@@ -94,6 +96,23 @@ def __init__(
**kwargs,
)
+ @staticmethod
+ def _validate_config_information(subscription_id, resource_group_name, workspace_name, kwargs):
+ """Validate the config information in case wrong parameter name is passed into the constructor."""
+ sub_name, wrong_sub_name = "subscription_id", "subscription"
+ rg_name, wrong_rg_name = "resource_group_name", "resource_group"
+ ws_name, wrong_ws_name = "workspace_name", "workspace"
+
+ error_message = (
+ "You have passed in the wrong parameter name to initialize the PFClient, please use {0!r} instead of {1!r}."
+ )
+ if not subscription_id and kwargs.get(wrong_sub_name, None) is not None:
+ raise RunOperationParameterError(error_message.format(sub_name, wrong_sub_name))
+ if not resource_group_name and kwargs.get(wrong_rg_name, None) is not None:
+ raise RunOperationParameterError(error_message.format(rg_name, wrong_rg_name))
+ if not workspace_name and kwargs.get(wrong_ws_name, None) is not None:
+ raise RunOperationParameterError(error_message.format(ws_name, wrong_ws_name))
+
@property
def ml_client(self):
"""Return a client to interact with Azure ML services."""
diff --git a/src/promptflow/promptflow/azure/_restclient/flow_service_caller.py b/src/promptflow/promptflow/azure/_restclient/flow_service_caller.py
index 02432db24b4..1893c538f22 100644
--- a/src/promptflow/promptflow/azure/_restclient/flow_service_caller.py
+++ b/src/promptflow/promptflow/azure/_restclient/flow_service_caller.py
@@ -13,6 +13,7 @@
from azure.core.exceptions import HttpResponseError, ResourceExistsError
from azure.core.pipeline.policies import RetryPolicy
+from promptflow._telemetry.telemetry import TelemetryMixin
from promptflow.azure._constants._flow import AUTOMATIC_RUNTIME
from promptflow.azure._restclient.flow import AzureMachineLearningDesignerServiceClient
from promptflow.exceptions import ValidationException, UserErrorException, PromptflowException
@@ -27,16 +28,6 @@ def __init__(self, message, **kwargs):
super().__init__(message, **kwargs)
-class TelemetryMixin(object):
-
- def __init__(self):
- # Need to call init for potential parent, otherwise it won't be initialized.
- super().__init__()
-
- def _get_telemetry_values(self, *args, **kwargs):
- return {}
-
-
class RequestTelemetryMixin(TelemetryMixin):
def __init__(self):
diff --git a/src/promptflow/promptflow/azure/operations/_run_operations.py b/src/promptflow/promptflow/azure/operations/_run_operations.py
index e0b74345aca..a3f678751b5 100644
--- a/src/promptflow/promptflow/azure/operations/_run_operations.py
+++ b/src/promptflow/promptflow/azure/operations/_run_operations.py
@@ -43,6 +43,8 @@
from promptflow._sdk._logger_factory import LoggerFactory
from promptflow._sdk._utils import in_jupyter_notebook, incremental_print
from promptflow._sdk.entities import Run
+from promptflow._telemetry.activity import ActivityType, monitor_operation
+from promptflow._telemetry.telemetry import TelemetryMixin
from promptflow._utils.flow_utils import get_flow_lineage_id
from promptflow.azure._constants._flow import (
AUTOMATIC_RUNTIME,
@@ -69,7 +71,7 @@ def __init__(self, message):
super().__init__(message)
-class RunOperations(_ScopeDependentOperations):
+class RunOperations(_ScopeDependentOperations, TelemetryMixin):
"""RunOperations that can manage runs.
You should not instantiate this class directly. Instead, you should
@@ -124,6 +126,18 @@ def _run_history_endpoint_url(self):
endpoint = self._service_caller._service_endpoint
return endpoint + "history/v1.0" + self._common_azure_url_pattern
+ def _get_telemetry_values(self, *args, **kwargs): # pylint: disable=unused-argument
+ """Return the telemetry values of run operations.
+
+ :return: The telemetry values
+ :rtype: Dict
+ """
+ return {
+ "subscription_id": self._operation_scope.subscription_id,
+ "resource_group_name": self._operation_scope.resource_group_name,
+ "workspace_name": self._operation_scope.workspace_name,
+ }
+
def _get_run_portal_url(self, run_id: str):
"""Get the portal url for the run."""
url = f"https://ml.azure.com/prompts/flow/bulkrun/run/{run_id}/details?wsid={self._common_azure_url_pattern}"
@@ -135,18 +149,17 @@ def _get_input_portal_url_from_input_uri(self, input_uri):
if not input_uri:
return None
if input_uri.startswith("azureml://"):
- # input uri is a datastore path
- match = self._DATASTORE_PATH_PATTERN.match(input_uri)
- if not match or len(match.groups()) != 2:
+ res = self._get_portal_url_from_asset_id(input_uri)
+ if res is None:
+ res = self._get_portal_url_from_datastore_path(input_uri)
+ if res is None:
+ error_msg = (
+ f"Failed to get portal url: {input_uri!r} is not a valid azureml asset id or datastore path."
+ )
logger.warning(error_msg)
- return None
- datastore, path = match.groups()
- return (
- f"https://ml.azure.com/data/datastore/{datastore}/edit?wsid={self._common_azure_url_pattern}"
- f"&activeFilePath={path}#browseTab"
- )
+ return res
elif input_uri.startswith("azureml:/"):
- # when input uri is an asset id, leverage the asset id pattern to get the portal url
+ # some asset id could start with "azureml:/"
return self._get_portal_url_from_asset_id(input_uri)
elif input_uri.startswith("azureml:"):
# named asset id
@@ -156,14 +169,33 @@ def _get_input_portal_url_from_input_uri(self, input_uri):
logger.warning(error_msg)
return None
- def _get_portal_url_from_asset_id(self, output_uri):
- """Get the portal url for the data output."""
- error_msg = f"Failed to get portal url: {output_uri!r} is not a valid azureml asset id."
- if not output_uri:
+ def _get_portal_url_from_datastore_path(self, datastore_path, log_warning=False):
+ """Get the portal url from the datastore path."""
+ error_msg = (
+ f"Failed to get portal url: Datastore path {datastore_path!r} is not a valid azureml datastore path."
+ )
+ if not datastore_path:
return None
- match = self._ASSET_ID_PATTERN.match(output_uri)
+ match = self._DATASTORE_PATH_PATTERN.match(datastore_path)
if not match or len(match.groups()) != 2:
- logger.warning(error_msg)
+ if log_warning:
+ logger.warning(error_msg)
+ return None
+ datastore, path = match.groups()
+ return (
+ f"https://ml.azure.com/data/datastore/{datastore}/edit?wsid={self._common_azure_url_pattern}"
+ f"&activeFilePath={path}#browseTab"
+ )
+
+ def _get_portal_url_from_asset_id(self, asset_id, log_warning=False):
+ """Get the portal url from asset id."""
+ error_msg = f"Failed to get portal url: {asset_id!r} is not a valid azureml asset id."
+ if not asset_id:
+ return None
+ match = self._ASSET_ID_PATTERN.match(asset_id)
+ if not match or len(match.groups()) != 2:
+ if log_warning:
+ logger.warning(error_msg)
return None
name, version = match.groups()
return f"https://ml.azure.com/data/{name}/{version}/details?wsid={self._common_azure_url_pattern}"
@@ -176,6 +208,7 @@ def _get_headers(self):
}
return custom_header
+ @monitor_operation(activity_name="pfazure.runs.create_or_update", activity_type=ActivityType.PUBLICAPI)
def create_or_update(self, run: Run, **kwargs) -> Run:
"""Create or update a run.
@@ -201,6 +234,7 @@ def create_or_update(self, run: Run, **kwargs) -> Run:
self.stream(run=run.name)
return self.get(run=run.name)
+ @monitor_operation(activity_name="pfazure.runs.list", activity_type=ActivityType.PUBLICAPI)
def list(
self, max_results: int = MAX_RUN_LIST_RESULTS, list_view_type: ListViewType = ListViewType.ACTIVE_ONLY, **kwargs
) -> List[Run]:
@@ -270,6 +304,7 @@ def list(
refined_runs.append(Run._from_index_service_entity(run))
return refined_runs
+ @monitor_operation(activity_name="pfazure.runs.get_metrics", activity_type=ActivityType.PUBLICAPI)
def get_metrics(self, run: Union[str, Run], **kwargs) -> dict:
"""Get the metrics from the run.
@@ -283,6 +318,7 @@ def get_metrics(self, run: Union[str, Run], **kwargs) -> dict:
metrics = self._get_metrics_from_metric_service(run)
return metrics
+ @monitor_operation(activity_name="pfazure.runs.get_details", activity_type=ActivityType.PUBLICAPI)
def get_details(
self, run: Union[str, Run], max_results: int = MAX_SHOW_DETAILS_RESULTS, all_results: bool = False, **kwargs
) -> DataFrame:
@@ -419,6 +455,7 @@ def _is_system_metric(metric: str) -> bool:
or metric.endswith(".is_completed")
)
+ @monitor_operation(activity_name="pfazure.runs.get", activity_type=ActivityType.PUBLICAPI)
def get(self, run: str, **kwargs) -> Run:
"""Get a run.
@@ -479,7 +516,7 @@ def _refine_run_data_from_run_history(self, run_data: dict) -> dict:
# get portal urls
run_data[RunDataKeys.DATA_PORTAL_URL] = self._get_input_portal_url_from_input_uri(input_data)
run_data[RunDataKeys.INPUT_RUN_PORTAL_URL] = self._get_run_portal_url(run_id=input_run_id)
- run_data[RunDataKeys.OUTPUT_PORTAL_URL] = self._get_portal_url_from_asset_id(output_data)
+ run_data[RunDataKeys.OUTPUT_PORTAL_URL] = self._get_portal_url_from_asset_id(output_data, log_warning=True)
return run_data
def _get_run_from_index_service(self, flow_run_id, **kwargs):
@@ -512,6 +549,7 @@ def _get_run_from_index_service(self, flow_run_id, **kwargs):
f"Failed to get run metrics from service. Code: {response.status_code}, text: {response.text}"
)
+ @monitor_operation(activity_name="pfazure.runs.archive", activity_type=ActivityType.PUBLICAPI)
def archive(self, run: str) -> Run:
"""Archive a run.
@@ -522,6 +560,7 @@ def archive(self, run: str) -> Run:
"""
pass
+ @monitor_operation(activity_name="pfazure.runs.restore", activity_type=ActivityType.PUBLICAPI)
def restore(self, run: str) -> Run:
"""Restore a run.
@@ -541,6 +580,7 @@ def _get_log(self, flow_run_id: str) -> str:
headers=self._get_headers(),
)
+ @monitor_operation(activity_name="pfazure.runs.stream", activity_type=ActivityType.PUBLICAPI)
def stream(self, run: Union[str, Run]) -> Run:
"""Stream the logs of a run."""
run = self.get(run=run)
@@ -676,6 +716,7 @@ def _get_inputs_outputs_from_child_runs(self, runs: List[Dict[str, Any]]):
outputs[LINE_NUMBER].append(index)
return inputs, outputs
+ @monitor_operation(activity_name="pfazure.runs.visualize", activity_type=ActivityType.PUBLICAPI)
def visualize(self, runs: Union[str, Run, List[str], List[Run]], **kwargs) -> None:
"""Visualize run(s) using Azure AI portal.
diff --git a/src/promptflow/promptflow/contracts/tool.py b/src/promptflow/promptflow/contracts/tool.py
index 30decfea8fb..edb9fde065c 100644
--- a/src/promptflow/promptflow/contracts/tool.py
+++ b/src/promptflow/promptflow/contracts/tool.py
@@ -216,6 +216,7 @@ class InputDefinition:
default: str = None
description: str = None
enum: List[str] = None
+ custom_type: List[str] = None
def serialize(self) -> dict:
"""Serialize input definition to dict.
@@ -234,6 +235,8 @@ def serialize(self) -> dict:
data["description"] = self.description
if self.enum:
data["enum"] = self.enum
+ if self.custom_type:
+ data["custom_type"] = self.custom_type
return data
@staticmethod
diff --git a/src/promptflow/promptflow/executor/_errors.py b/src/promptflow/promptflow/executor/_errors.py
index 9048638a51d..85a79e925b8 100644
--- a/src/promptflow/promptflow/executor/_errors.py
+++ b/src/promptflow/promptflow/executor/_errors.py
@@ -2,7 +2,14 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
-from promptflow.exceptions import ErrorTarget, SystemErrorException, UserErrorException, ValidationException
+from promptflow._utils.exception_utils import ExceptionPresenter, infer_error_code_from_class
+from promptflow.exceptions import (
+ ErrorTarget,
+ PromptflowException,
+ SystemErrorException,
+ UserErrorException,
+ ValidationException,
+)
class InvalidCustomLLMTool(ValidationException):
@@ -180,3 +187,56 @@ def __init__(self, line_number, timeout):
super().__init__(
message=f"Line {line_number} execution timeout for exceeding {timeout} seconds", target=ErrorTarget.EXECUTOR
)
+
+
+class ResolveToolError(PromptflowException):
+ """Exception raised when tool load failed.
+
+ It is used to append the name of the failed node to the error message to improve the user experience.
+ It simply wraps the error thrown by the Resolve Tool phase.
+ It has the same additional_info and error_codes as inner error.
+ """
+
+ def __init__(self, *, node_name: str, target: ErrorTarget = ErrorTarget.EXECUTOR, module: str = None):
+ self._node_name = node_name
+ super().__init__(target=target, module=module)
+
+ @property
+ def message_format(self):
+ if self.inner_exception:
+ return "Tool load failed in '{node_name}': {error_type_and_message}"
+ else:
+ return "Tool load failed in '{node_name}'."
+
+ @property
+ def message_parameters(self):
+ error_type_and_message = None
+ if self.inner_exception:
+ error_type_and_message = f"({self.inner_exception.__class__.__name__}) {self.inner_exception}"
+
+ return {
+ "node_name": self._node_name,
+ "error_type_and_message": error_type_and_message,
+ }
+
+ @property
+ def additional_info(self):
+ """Get additional info from innererror when the innererror is PromptflowException"""
+ if isinstance(self.inner_exception, PromptflowException):
+ return self.inner_exception.additional_info
+ return None
+
+ @property
+ def error_codes(self):
+ """The hierarchy of the error codes.
+
+ We follow the "Microsoft REST API Guidelines" to define error codes in a hierarchy style.
+ See the below link for details:
+ https://github.com/microsoft/api-guidelines/blob/vNext/Guidelines.md#7102-error-condition-responses
+
+ Because ResolveToolError has no classification of its own,
+ its error_codes respect the inner_error.
+ """
+ if self.inner_exception:
+ return ExceptionPresenter.create(self.inner_exception).error_codes
+ return [infer_error_code_from_class(SystemErrorException), self.__class__.__name__]
diff --git a/src/promptflow/promptflow/executor/_line_execution_process_pool.py b/src/promptflow/promptflow/executor/_line_execution_process_pool.py
index 8bb4803a0ed..f668d8944f2 100644
--- a/src/promptflow/promptflow/executor/_line_execution_process_pool.py
+++ b/src/promptflow/promptflow/executor/_line_execution_process_pool.py
@@ -1,5 +1,6 @@
import contextvars
import math
+import time
import multiprocessing
import os
import queue
@@ -41,6 +42,79 @@ def persist_flow_run(self, run_info: FlowRunInfo):
self.queue.put(run_info)
+class HealthyEnsuredProcess:
+ def __init__(self, executor_creation_func):
+ self.process = None
+ self.input_queue = None
+ self.output_queue = None
+ self.is_ready = False
+ self._executor_creation_func = executor_creation_func
+
+ def start_new(self):
+ input_queue = Queue()
+ output_queue = Queue()
+ self.input_queue = input_queue
+ self.output_queue = output_queue
+
+ # Put a start message and wait for the subprocess to be ready.
+ # Test if the subprocess can receive the message.
+ input_queue.put("start")
+
+ current_log_context = LogContext.get_current()
+ process = Process(
+ target=_process_wrapper,
+ args=(
+ self._executor_creation_func,
+ input_queue,
+ output_queue,
+ current_log_context.get_initializer() if current_log_context else None,
+ OperationContext.get_instance().get_context_dict(),
+ ),
+ # Set the process as a daemon process so it is automatically terminated and releases system resources
+ # when the main process exits.
+ daemon=True
+ )
+
+ self.process = process
+ process.start()
+
+ try:
+ # Wait for subprocess send a ready message.
+ ready_msg = output_queue.get(timeout=30)
+ logger.info(f"Process {process.pid} get ready_msg: {ready_msg}")
+ self.is_ready = True
+ except queue.Empty:
+ logger.info(f"Process {process.pid} did not send ready message, exit.")
+ self.end()
+ self.start_new()
+
+ def end(self):
+ # If the process failed to start and the task_queue is empty,
+ # the process will no longer be re-created, and the process is None.
+ if self.process is None:
+ return
+ if self.process.is_alive():
+ self.process.kill()
+
+ def put(self, args):
+ self.input_queue.put(args)
+
+ def get(self):
+ return self.output_queue.get(timeout=1)
+
+ def format_current_process(self, line_number: int, is_completed=False):
+ process_name = self.process.name if self.process else None
+ process_pid = self.process.pid if self.process else None
+ if is_completed:
+ logger.info(
+ f"Process name: {process_name}, Process id: {process_pid}, Line number: {line_number} completed.")
+ else:
+ logger.info(
+ f"Process name: {process_name}, Process id: {process_pid}, Line number: {line_number} start execution.")
+
+ return f"Process name({process_name})-Process id({process_pid})"
+
+
class LineExecutionProcessPool:
def __init__(
self,
@@ -93,8 +167,10 @@ def __enter__(self):
if not self._use_fork:
available_max_worker_count = get_available_max_worker_count()
self._n_process = min(self._worker_count, self._nlines, available_max_worker_count)
+ bulk_logger.info(f"Not using fork, process count: {self._n_process}")
else:
self._n_process = min(self._worker_count, self._nlines)
+ bulk_logger.info(f"Using fork, process count: {self._n_process}")
pool = ThreadPool(self._n_process, initializer=set_context, initargs=(contextvars.copy_context(),))
self._pool = pool
@@ -105,41 +181,23 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._pool.close()
self._pool.join()
- def _new_process(self):
- input_queue = Queue()
- output_queue = Queue()
- current_log_context = LogContext.get_current()
- process = Process(
- target=_process_wrapper,
- args=(
- self._executor_creation_func,
- input_queue,
- output_queue,
- current_log_context.get_initializer() if current_log_context else None,
- OperationContext.get_instance().get_context_dict(),
- ),
- )
- process.start()
- return process, input_queue, output_queue
-
- def end_process(self, process):
- if process.is_alive():
- process.kill()
-
def _timeout_process_wrapper(self, task_queue: Queue, idx: int, timeout_time, result_list):
- process, input_queue, output_queue = self._new_process()
+ healthy_ensured_process = HealthyEnsuredProcess(self._executor_creation_func)
+ healthy_ensured_process.start_new()
+
while True:
try:
+ while not healthy_ensured_process.is_ready and not task_queue.empty():
+ time.sleep(1)
args = task_queue.get(timeout=1)
except queue.Empty:
logger.info(f"Process {idx} queue empty, exit.")
- self.end_process(process)
+ healthy_ensured_process.end()
return
- input_queue.put(args)
+ healthy_ensured_process.put(args)
inputs, line_number, run_id = args[:3]
-
- self._processing_idx[line_number] = process.name
+ self._processing_idx[line_number] = healthy_ensured_process.format_current_process(line_number)
start_time = datetime.now()
completed = False
@@ -148,7 +206,7 @@ def _timeout_process_wrapper(self, task_queue: Queue, idx: int, timeout_time, re
try:
# Responsible for checking the output queue messages and
# processing them within a specified timeout period.
- message = output_queue.get(timeout=1)
+ message = healthy_ensured_process.get()
if isinstance(message, LineResult):
completed = True
result_list.append(message)
@@ -160,7 +218,7 @@ def _timeout_process_wrapper(self, task_queue: Queue, idx: int, timeout_time, re
except queue.Empty:
continue
- self._completed_idx[line_number] = process.name
+ self._completed_idx[line_number] = healthy_ensured_process.format_current_process(line_number, True)
# Handling the timeout of a line execution process.
if not completed:
logger.warning(f"Line {line_number} timeout after {timeout_time} seconds.")
@@ -169,8 +227,10 @@ def _timeout_process_wrapper(self, task_queue: Queue, idx: int, timeout_time, re
inputs, run_id, line_number, self._flow_id, start_time, ex
)
result_list.append(result)
- self.end_process(process)
- process, input_queue, output_queue = self._new_process()
+ self._completed_idx[line_number] = healthy_ensured_process.format_current_process(line_number, True)
+ healthy_ensured_process.end()
+ healthy_ensured_process.start_new()
+
self._processing_idx.pop(line_number)
log_progress(
logger=bulk_logger,
@@ -180,6 +240,7 @@ def _timeout_process_wrapper(self, task_queue: Queue, idx: int, timeout_time, re
)
def _generate_line_result_for_exception(self, inputs, run_id, line_number, flow_id, start_time, ex) -> LineResult:
+ logger.error(f"Line {line_number}, Process {os.getpid()} failed with exception: {ex}")
run_info = FlowRunInfo(
run_id=f"{run_id}_{line_number}",
status=Status.Failed,
@@ -280,9 +341,12 @@ def _exec_line(
line_result.output = {}
return line_result
except Exception as e:
+ logger.error(f"Line {index}, Process {os.getpid()} failed with exception: {e}")
if executor._run_tracker.flow_run_list:
+ logger.info(f"Line {index}, Process {os.getpid()} have been added to flow run list.")
run_info = executor._run_tracker.flow_run_list[0]
else:
+ logger.info(f"Line {index}, Process {os.getpid()} have not been added to flow run list.")
run_info = executor._run_tracker.end_run(f"{run_id}_{index}", ex=e)
output_queue.put(run_info)
result = LineResult(
@@ -301,6 +365,7 @@ def _process_wrapper(
log_context_initialization_func,
operation_contexts_dict: dict,
):
+ logger.info(f"Process {os.getpid()} started.")
OperationContext.get_instance().update(operation_contexts_dict) # Update the operation context for the new process.
if log_context_initialization_func:
with log_context_initialization_func():
@@ -326,9 +391,18 @@ def create_executor_fork(*, flow_executor: FlowExecutor, storage: AbstractRunSto
def exec_line_for_queue(executor_creation_func, input_queue: Queue, output_queue: Queue):
run_storage = QueueRunStorage(output_queue)
executor: FlowExecutor = executor_creation_func(storage=run_storage)
+
+ # Wait for the start signal message
+ start_msg = input_queue.get()
+ logger.info(f"Process {os.getpid()} received start signal message: {start_msg}")
+
+ # Send a ready signal message
+ output_queue.put("ready")
+ logger.info(f"Process {os.getpid()} sent ready signal message.")
+
while True:
try:
- args = input_queue.get(1)
+ args = input_queue.get(timeout=1)
inputs, line_number, run_id, variant_id, validate_inputs = args[:5]
result = _exec_line(
executor=executor,
diff --git a/src/promptflow/promptflow/executor/_tool_resolver.py b/src/promptflow/promptflow/executor/_tool_resolver.py
index 91c34f94d7b..62ce9db3f84 100644
--- a/src/promptflow/promptflow/executor/_tool_resolver.py
+++ b/src/promptflow/promptflow/executor/_tool_resolver.py
@@ -16,13 +16,14 @@
from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSourceType
from promptflow.contracts.tool import ConnectionType, Tool, ToolType, ValueType
from promptflow.contracts.types import PromptTemplate
-from promptflow.exceptions import ErrorTarget, UserErrorException
+from promptflow.exceptions import ErrorTarget, PromptflowException, UserErrorException
from promptflow.executor._errors import (
ConnectionNotFound,
InvalidConnectionType,
InvalidCustomLLMTool,
InvalidSource,
NodeInputValidationError,
+ ResolveToolError,
ValueTypeUnresolved,
)
@@ -98,26 +99,33 @@ def _convert_node_literal_input_types(self, node: Node, tool: Tool):
return updated_node
def resolve_tool_by_node(self, node: Node, convert_input_types=True) -> ResolvedTool:
- if node.source is None:
- raise UserErrorException(f"Node {node.name} does not have source defined.")
-
- if node.type is ToolType.PYTHON:
- if node.source.type == ToolSourceType.Package:
- return self._resolve_package_node(node, convert_input_types=convert_input_types)
- elif node.source.type == ToolSourceType.Code:
- return self._resolve_script_node(node, convert_input_types=convert_input_types)
- raise NotImplementedError(f"Tool source type {node.source.type} for python tool is not supported yet.")
- elif node.type is ToolType.PROMPT:
- return self._resolve_prompt_node(node)
- elif node.type is ToolType.LLM:
- return self._resolve_llm_node(node, convert_input_types=convert_input_types)
- elif node.type is ToolType.CUSTOM_LLM:
- if node.source.type == ToolSourceType.PackageWithPrompt:
- resolved_tool = self._resolve_package_node(node, convert_input_types=convert_input_types)
- return self._integrate_prompt_in_package_node(node, resolved_tool)
- raise NotImplementedError(f"Tool source type {node.source.type} for custom_llm tool is not supported yet.")
- else:
- raise NotImplementedError(f"Tool type {node.type} is not supported yet.")
+ try:
+ if node.source is None:
+ raise UserErrorException(f"Node {node.name} does not have source defined.")
+
+ if node.type is ToolType.PYTHON:
+ if node.source.type == ToolSourceType.Package:
+ return self._resolve_package_node(node, convert_input_types=convert_input_types)
+ elif node.source.type == ToolSourceType.Code:
+ return self._resolve_script_node(node, convert_input_types=convert_input_types)
+ raise NotImplementedError(f"Tool source type {node.source.type} for python tool is not supported yet.")
+ elif node.type is ToolType.PROMPT:
+ return self._resolve_prompt_node(node)
+ elif node.type is ToolType.LLM:
+ return self._resolve_llm_node(node, convert_input_types=convert_input_types)
+ elif node.type is ToolType.CUSTOM_LLM:
+ if node.source.type == ToolSourceType.PackageWithPrompt:
+ resolved_tool = self._resolve_package_node(node, convert_input_types=convert_input_types)
+ return self._integrate_prompt_in_package_node(node, resolved_tool)
+ raise NotImplementedError(
+ f"Tool source type {node.source.type} for custom_llm tool is not supported yet."
+ )
+ else:
+ raise NotImplementedError(f"Tool type {node.type} is not supported yet.")
+ except Exception as e:
+ if isinstance(e, PromptflowException) and e.target != ErrorTarget.UNKNOWN:
+ raise ResolveToolError(node_name=node.name, target=e.target, module=e.module) from e
+ raise ResolveToolError(node_name=node.name) from e
def _load_source_content(self, node: Node) -> str:
source = node.source
diff --git a/src/promptflow/tests/conftest.py b/src/promptflow/tests/conftest.py
index 2e07dd77c5c..cceb1ed95d2 100644
--- a/src/promptflow/tests/conftest.py
+++ b/src/promptflow/tests/conftest.py
@@ -107,13 +107,15 @@ def prepare_symbolic_flow() -> str:
return target_folder
-@pytest.fixture
-def is_custom_tool_pkg_installed() -> bool:
+@pytest.fixture(scope="session")
+def install_custom_tool_pkg():
+ # Leave the pkg installed since multiple tests rely on it and the tests may run concurrently
try:
import my_tool_package # noqa: F401
- pkg_installed = True
except ImportError:
- pkg_installed = False
+ import subprocess
+ import sys
- return pkg_installed
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "test-custom-tools==0.0.1"])
+ yield
diff --git a/src/promptflow/tests/executor/e2etests/test_executor_happypath.py b/src/promptflow/tests/executor/e2etests/test_executor_happypath.py
index 5c2874bf73a..2ea0cbdae00 100644
--- a/src/promptflow/tests/executor/e2etests/test_executor_happypath.py
+++ b/src/promptflow/tests/executor/e2etests/test_executor_happypath.py
@@ -9,7 +9,7 @@
from promptflow.contracts.run_info import Status
from promptflow.exceptions import UserErrorException
from promptflow.executor import FlowExecutor
-from promptflow.executor._errors import ConnectionNotFound, InputTypeError
+from promptflow.executor._errors import ConnectionNotFound, InputTypeError, ResolveToolError
from promptflow.executor.flow_executor import BulkResult, LineResult
from promptflow.storage import AbstractRunStorage
@@ -265,13 +265,14 @@ def test_executor_node_overrides(self, dev_connections):
assert type(e.value).__name__ == "WrappedOpenAIError"
assert "The API deployment for this resource does not exist." in str(e.value)
- with pytest.raises(ConnectionNotFound) as e:
+ with pytest.raises(ResolveToolError) as e:
executor = FlowExecutor.create(
get_yaml_file(SAMPLE_FLOW),
dev_connections,
node_override={"classify_with_llm.connection": "dummy_connection"},
raise_ex=True,
)
+ assert isinstance(e.value.inner_exception, ConnectionNotFound)
assert "Connection 'dummy_connection' not found" in str(e.value)
@pytest.mark.parametrize(
diff --git a/src/promptflow/tests/executor/e2etests/test_executor_validation.py b/src/promptflow/tests/executor/e2etests/test_executor_validation.py
index cedb11342e2..f2b3f9eb082 100644
--- a/src/promptflow/tests/executor/e2etests/test_executor_validation.py
+++ b/src/promptflow/tests/executor/e2etests/test_executor_validation.py
@@ -24,6 +24,7 @@
NodeInputValidationError,
NodeReferenceNotFound,
OutputReferenceNotFound,
+ ResolveToolError,
SingleNodeValidationError,
)
from promptflow.executor.flow_executor import BulkResult
@@ -35,12 +36,13 @@
@pytest.mark.e2etest
class TestValidation:
@pytest.mark.parametrize(
- "flow_folder, yml_file, error_class, error_msg",
+ "flow_folder, yml_file, error_class, inner_class, error_msg",
[
(
"nodes_names_duplicated",
"flow.dag.yaml",
DuplicateNodeName,
+ None,
(
"Invalid node definitions found in the flow graph. Node with name 'stringify_num' appears more "
"than once in the node definitions in your flow, which is not allowed. To "
@@ -51,16 +53,19 @@ class TestValidation:
(
"source_file_missing",
"flow.dag.jinja.yaml",
+ ResolveToolError,
InvalidSource,
(
- "Node source path 'summarize_text_content__variant_1.jinja2' is invalid on "
- "node 'summarize_text_content'."
+ "Tool load failed in 'summarize_text_content': (InvalidSource) "
+ "Node source path 'summarize_text_content__variant_1.jinja2' is invalid on node "
+ "'summarize_text_content'."
),
),
(
"node_reference_not_found",
"flow.dag.yaml",
NodeReferenceNotFound,
+ None,
(
"Invalid node definitions found in the flow graph. Node 'divide_num_2' references a non-existent "
"node 'divide_num_3' in your flow. Please review your flow to ensure that the "
@@ -71,6 +76,7 @@ class TestValidation:
"node_circular_dependency",
"flow.dag.yaml",
NodeCircularDependency,
+ None,
(
"Invalid node definitions found in the flow graph. Node circular dependency has been detected "
"among the nodes in your flow. Kindly review the reference relationships for "
@@ -82,6 +88,7 @@ class TestValidation:
"flow_input_reference_invalid",
"flow.dag.yaml",
InputReferenceNotFound,
+ None,
(
"Invalid node definitions found in the flow graph. Node 'divide_num' references flow input 'num_1' "
"which is not defined in your flow. To resolve this issue, please review your "
@@ -93,6 +100,7 @@ class TestValidation:
"flow_output_reference_invalid",
"flow.dag.yaml",
EmptyOutputReference,
+ None,
(
"The output 'content' for flow is incorrect. The reference is not specified for the output "
"'content' in the flow. To rectify this, ensure that you accurately specify "
@@ -103,6 +111,7 @@ class TestValidation:
"outputs_reference_not_valid",
"flow.dag.yaml",
OutputReferenceNotFound,
+ None,
(
"The output 'content' for flow is incorrect. The output 'content' references non-existent "
"node 'another_stringify_num' in your flow. To resolve this issue, please "
@@ -114,6 +123,7 @@ class TestValidation:
"outputs_with_invalid_flow_inputs_ref",
"flow.dag.yaml",
OutputReferenceNotFound,
+ None,
(
"The output 'num' for flow is incorrect. The output 'num' references non-existent flow "
"input 'num11' in your flow. Please carefully review your flow and correct "
@@ -123,21 +133,25 @@ class TestValidation:
],
)
def test_executor_create_failure_type_and_message(
- self, flow_folder, yml_file, error_class, error_msg, dev_connections
+ self, flow_folder, yml_file, error_class, inner_class, error_msg, dev_connections
):
with pytest.raises(error_class) as exc_info:
FlowExecutor.create(get_yaml_file(flow_folder, WRONG_FLOW_ROOT, yml_file), dev_connections)
+ if isinstance(exc_info.value, ResolveToolError):
+ assert isinstance(exc_info.value.inner_exception, inner_class)
assert error_msg == exc_info.value.message
@pytest.mark.parametrize(
- "flow_folder, yml_file, error_class",
+ "flow_folder, yml_file, error_class, inner_class",
[
- ("source_file_missing", "flow.dag.python.yaml", PythonParsingError),
+ ("source_file_missing", "flow.dag.python.yaml", ResolveToolError, PythonParsingError),
],
)
- def test_executor_create_failure_type(self, flow_folder, yml_file, error_class, dev_connections):
- with pytest.raises(error_class):
+ def test_executor_create_failure_type(self, flow_folder, yml_file, error_class, inner_class, dev_connections):
+ with pytest.raises(error_class) as e:
FlowExecutor.create(get_yaml_file(flow_folder, WRONG_FLOW_ROOT, yml_file), dev_connections)
+ if isinstance(e.value, ResolveToolError):
+ assert isinstance(e.value.inner_exception, inner_class)
@pytest.mark.parametrize(
"ordered_flow_folder, unordered_flow_folder",
@@ -153,18 +167,20 @@ def test_node_topology_in_order(self, ordered_flow_folder, unordered_flow_folder
assert node1.name == node2.name
@pytest.mark.parametrize(
- "flow_folder, error_class",
+ "flow_folder, error_class, inner_class",
[
- ("invalid_connection", ConnectionNotFound),
- ("tool_type_missing", NotImplementedError),
- ("wrong_module", FailedToImportModule),
- ("wrong_api", APINotFound),
- ("wrong_provider", APINotFound),
+ ("invalid_connection", ResolveToolError, ConnectionNotFound),
+ ("tool_type_missing", ResolveToolError, NotImplementedError),
+ ("wrong_module", FailedToImportModule, None),
+ ("wrong_api", ResolveToolError, APINotFound),
+ ("wrong_provider", ResolveToolError, APINotFound),
],
)
- def test_invalid_flow_dag(self, flow_folder, error_class, dev_connections):
- with pytest.raises(error_class):
+ def test_invalid_flow_dag(self, flow_folder, error_class, inner_class, dev_connections):
+ with pytest.raises(error_class) as e:
FlowExecutor.create(get_yaml_file(flow_folder, WRONG_FLOW_ROOT), dev_connections)
+ if isinstance(e.value, ResolveToolError):
+ assert isinstance(e.value.inner_exception, inner_class)
@pytest.mark.parametrize(
"flow_folder, line_input, error_class",
@@ -342,8 +358,9 @@ def test_single_node_input_type_invalid(
],
)
def test_flow_run_with_duplicated_inputs(self, flow_folder, msg, dev_connections):
- with pytest.raises(NodeInputValidationError, match=msg):
+ with pytest.raises(ResolveToolError, match=msg) as e:
FlowExecutor.create(get_yaml_file(flow_folder, FLOW_ROOT), dev_connections)
+ assert isinstance(e.value.inner_exception, NodeInputValidationError)
@pytest.mark.parametrize(
"flow_folder, batch_input, raise_on_line_failure, error_class",
diff --git a/src/promptflow/tests/executor/e2etests/test_package_tool.py b/src/promptflow/tests/executor/e2etests/test_package_tool.py
index 73491cf5e50..72ae0372a44 100644
--- a/src/promptflow/tests/executor/e2etests/test_package_tool.py
+++ b/src/promptflow/tests/executor/e2etests/test_package_tool.py
@@ -4,10 +4,10 @@
import pytest
-from promptflow._core._errors import PackageToolNotFoundError
+from promptflow._core._errors import PackageToolNotFoundError, ToolLoadError
from promptflow.contracts.run_info import Status
from promptflow.executor import FlowExecutor
-from promptflow.executor._errors import NodeInputValidationError
+from promptflow.executor._errors import NodeInputValidationError, ResolveToolError
from promptflow.executor.flow_executor import LineResult
from ..utils import WRONG_FLOW_ROOT, get_flow_package_tool_definition, get_flow_sample_inputs, get_yaml_file
@@ -83,15 +83,18 @@ def test_custom_llm_tool_with_duplicated_inputs(self, dev_connections, mocker):
"Invalid inputs {'api'} in prompt template of node custom_llm_tool_with_duplicated_inputs. "
"These inputs are duplicated with the inputs of custom llm tool."
)
- with pytest.raises(NodeInputValidationError, match=msg):
+ with pytest.raises(ResolveToolError, match=msg) as e:
FlowExecutor.create(get_yaml_file(flow_folder), dev_connections)
+ assert isinstance(e.value.inner_exception, NodeInputValidationError)
@pytest.mark.parametrize(
- "flow_folder, error_class, error_message",
+ "flow_folder, error_class, inner_class, error_message",
[
(
"wrong_tool_in_package_tools",
+ ResolveToolError,
PackageToolNotFoundError,
+ "Tool load failed in 'search_by_text': (PackageToolNotFoundError) "
"Package tool 'promptflow.tools.serpapi.SerpAPI.search_11' is not found in the current environment. "
"All available package tools are: "
"['promptflow.tools.azure_content_safety.AzureContentSafety.analyze_text', "
@@ -99,7 +102,9 @@ def test_custom_llm_tool_with_duplicated_inputs(self, dev_connections, mocker):
),
(
"wrong_package_in_package_tools",
+ ResolveToolError,
PackageToolNotFoundError,
+ "Tool load failed in 'search_by_text': (PackageToolNotFoundError) "
"Package tool 'promptflow.tools.serpapi11.SerpAPI.search' is not found in the current environment. "
"All available package tools are: "
"['promptflow.tools.azure_content_safety.AzureContentSafety.analyze_text', "
@@ -107,7 +112,7 @@ def test_custom_llm_tool_with_duplicated_inputs(self, dev_connections, mocker):
),
],
)
- def test_package_tool_execution(self, flow_folder, error_class, error_message, dev_connections):
+ def test_package_tool_execution(self, flow_folder, error_class, inner_class, error_message, dev_connections):
def mock_collect_package_tools(keys=None):
return {
"promptflow.tools.azure_content_safety.AzureContentSafety.analyze_text": None,
@@ -117,4 +122,25 @@ def mock_collect_package_tools(keys=None):
with patch(PACKAGE_TOOL_ENTRY, side_effect=mock_collect_package_tools):
with pytest.raises(error_class) as exce_info:
FlowExecutor.create(get_yaml_file(flow_folder, WRONG_FLOW_ROOT), dev_connections)
+ if isinstance(exce_info.value, ResolveToolError):
+ assert isinstance(exce_info.value.inner_exception, inner_class)
assert error_message == exce_info.value.message
+
+ @pytest.mark.parametrize(
+ "flow_folder, error_message",
+ [
+ (
+ "tool_with_init_error",
+ "Tool load failed in 'tool_with_init_error': "
+ "(ToolLoadError) Failed to load package tool 'Tool with init error': (Exception) Tool load error.",
+ )
+ ],
+ )
+ def test_package_tool_load_error(self, flow_folder, error_message, dev_connections, mocker):
+ flow_folder = PACKAGE_TOOL_BASE / flow_folder
+ package_tool_definition = get_flow_package_tool_definition(flow_folder)
+ with mocker.patch(PACKAGE_TOOL_ENTRY, return_value=package_tool_definition):
+ with pytest.raises(ResolveToolError) as exce_info:
+ FlowExecutor.create(get_yaml_file(flow_folder), dev_connections)
+ assert isinstance(exce_info.value.inner_exception, ToolLoadError)
+ assert exce_info.value.message == error_message
diff --git a/src/promptflow/tests/executor/package_tools/tool_with_init_error.py b/src/promptflow/tests/executor/package_tools/tool_with_init_error.py
new file mode 100644
index 00000000000..d3e1f0b34dc
--- /dev/null
+++ b/src/promptflow/tests/executor/package_tools/tool_with_init_error.py
@@ -0,0 +1,10 @@
+from promptflow import ToolProvider, tool
+
+
+class TestLoadErrorTool(ToolProvider):
+ def __init__(self):
+ raise Exception("Tool load error.")
+
+ @tool
+ def tool(self, name: str):
+ return name
diff --git a/src/promptflow/tests/executor/package_tools/tool_with_init_error/flow.dag.yaml b/src/promptflow/tests/executor/package_tools/tool_with_init_error/flow.dag.yaml
new file mode 100644
index 00000000000..2c77f95e339
--- /dev/null
+++ b/src/promptflow/tests/executor/package_tools/tool_with_init_error/flow.dag.yaml
@@ -0,0 +1,10 @@
+inputs: {}
+outputs: {}
+nodes:
+- name: tool_with_init_error
+ type: python
+ source:
+ type: package
+ tool: tool_with_init_error
+ inputs:
+ name: test_name
diff --git a/src/promptflow/tests/executor/package_tools/tool_with_init_error/package_tool_definition.json b/src/promptflow/tests/executor/package_tools/tool_with_init_error/package_tool_definition.json
new file mode 100644
index 00000000000..ed47b820f80
--- /dev/null
+++ b/src/promptflow/tests/executor/package_tools/tool_with_init_error/package_tool_definition.json
@@ -0,0 +1,12 @@
+{
+ "tool_with_init_error": {
+ "class_name": "TestLoadErrorTool",
+ "function": "tool",
+ "inputs": {
+ "name": {"type": ["string"]}
+ },
+ "module": "tool_with_init_error",
+ "name": "Tool with init error",
+ "type": "python"
+ }
+}
diff --git a/src/promptflow/tests/executor/unittests/_core/test_tools_manager.py b/src/promptflow/tests/executor/unittests/_core/test_tools_manager.py
index 11627b45f62..eb43c3bcc98 100644
--- a/src/promptflow/tests/executor/unittests/_core/test_tools_manager.py
+++ b/src/promptflow/tests/executor/unittests/_core/test_tools_manager.py
@@ -136,35 +136,42 @@ def test_gen_tool_by_source_error(self, tool_source, tool_type, error_code, erro
gen_tool_by_source("fake_name", tool_source, tool_type, working_dir),
assert str(ex.value) == error_message
- @pytest.mark.skip("test package not installed")
- def test_collect_package_tools_and_connections(self):
+ @pytest.mark.skip("TODO: need to fix random package not found error")
+ def test_collect_package_tools_and_connections(self, install_custom_tool_pkg):
+ # Need to reload pkg_resources to get the latest installed packages
+ import importlib
+
+ import pkg_resources
+
+ importlib.reload(pkg_resources)
+
keys = ["my_tool_package.tools.my_tool_2.MyTool.my_tool"]
tools, specs, templates = collect_package_tools_and_connections(keys)
assert len(tools) == 1
assert specs == {
- "my_tool_package.connections.MySecondConnection": {
+ "my_tool_package.connections.MyFirstConnection": {
"connectionCategory": "CustomKeys",
"flowValueType": "CustomConnection",
- "connectionType": "MySecondConnection",
- "ConnectionTypeDisplayName": "MySecondConnection",
+ "connectionType": "MyFirstConnection",
+ "ConnectionTypeDisplayName": "MyFirstConnection",
"configSpecs": [
{"name": "api_key", "displayName": "Api Key", "configValueType": "Secret", "isOptional": False},
{"name": "api_base", "displayName": "Api Base", "configValueType": "str", "isOptional": True},
],
"module": "my_tool_package.connections",
- "package": "my-tools-package-with-cstc",
- "package_version": "0.0.6",
+ "package": "test-custom-tools",
+ "package_version": "0.0.1",
}
}
expected_template = {
"name": "",
"type": "custom",
- "custom_type": "MySecondConnection",
+ "custom_type": "MyFirstConnection",
"module": "my_tool_package.connections",
- "package": "my-tools-package-with-cstc",
- "package_version": "0.0.6",
+ "package": "test-custom-tools",
+ "package_version": "0.0.1",
"configs": {"api_base": ""},
"secrets": {"api_key": ""},
}
- loaded_yaml = yaml.safe_load(templates["my_tool_package.connections.MySecondConnection"])
+ loaded_yaml = yaml.safe_load(templates["my_tool_package.connections.MyFirstConnection"])
assert loaded_yaml == expected_template
diff --git a/src/promptflow/tests/executor/unittests/_utils/test_exception_utils.py b/src/promptflow/tests/executor/unittests/_utils/test_exception_utils.py
index 593dd87525d..08515b9d698 100644
--- a/src/promptflow/tests/executor/unittests/_utils/test_exception_utils.py
+++ b/src/promptflow/tests/executor/unittests/_utils/test_exception_utils.py
@@ -229,6 +229,22 @@ def test_to_dict_for_tool_execution_error(self):
},
}
+ @pytest.mark.parametrize(
+ "raise_exception_func, error_class, expected_error_codes",
+ [
+ (raise_general_exception, CustomizedException, ["SystemError", "CustomizedException"]),
+ (raise_tool_execution_error, ToolExecutionError, ["UserError", "ToolExecutionError"]),
+ (raise_promptflow_exception, PromptflowException, ["SystemError", "ZeroDivisionError"]),
+ (raise_promptflow_exception_without_inner_exception, PromptflowException, ["SystemError"]),
+ ],
+ )
+ def test_error_codes(self, raise_exception_func, error_class, expected_error_codes):
+ with pytest.raises(error_class) as e:
+ raise_exception_func()
+
+ presenter = ExceptionPresenter.create(e.value)
+ assert presenter.error_codes == expected_error_codes
+
@pytest.mark.unittest
class TestErrorResponse:
diff --git a/src/promptflow/tests/executor/unittests/executor/test_errors.py b/src/promptflow/tests/executor/unittests/executor/test_errors.py
new file mode 100644
index 00000000000..c87e581ee23
--- /dev/null
+++ b/src/promptflow/tests/executor/unittests/executor/test_errors.py
@@ -0,0 +1,63 @@
+import pytest
+
+from promptflow._core.tool_meta_generator import PythonLoadError
+from promptflow.exceptions import ErrorTarget
+from promptflow.executor._errors import ResolveToolError
+
+
+def code_with_bug():
+ 1 / 0
+
+
+def raise_resolve_tool_error(func, target=None, module=None):
+ try:
+ func()
+ except Exception as e:
+ if target:
+ raise ResolveToolError(node_name="MyTool", target=target, module=module) from e
+ raise ResolveToolError(node_name="MyTool") from e
+
+
+def raise_python_load_error():
+ try:
+ code_with_bug()
+ except Exception as e:
+ raise PythonLoadError(message="Test PythonLoadError.") from e
+
+
+def test_resolve_tool_error():
+ with pytest.raises(ResolveToolError) as e:
+ raise_resolve_tool_error(raise_python_load_error, ErrorTarget.TOOL, "__pf_main__")
+
+ exception = e.value
+ inner_exception = exception.inner_exception
+
+ assert isinstance(inner_exception, PythonLoadError)
+ assert exception.message == "Tool load failed in 'MyTool': (PythonLoadError) Test PythonLoadError."
+ assert exception.additional_info == inner_exception.additional_info
+ assert exception.error_codes == ["UserError", "ToolValidationError", "PythonParsingError", "PythonLoadError"]
+ assert exception.reference_code == "Tool/__pf_main__"
+
+
+def test_resolve_tool_error_with_none_inner():
+ with pytest.raises(ResolveToolError) as e:
+ raise ResolveToolError(node_name="MyTool")
+
+ exception = e.value
+ assert exception.inner_exception is None
+ assert exception.message == "Tool load failed in 'MyTool'."
+ assert exception.additional_info is None
+ assert exception.error_codes == ["SystemError", "ResolveToolError"]
+ assert exception.reference_code == "Executor"
+
+
+def test_resolve_tool_error_with_no_PromptflowException_inner():
+ with pytest.raises(ResolveToolError) as e:
+ raise_resolve_tool_error(code_with_bug)
+
+ exception = e.value
+ assert isinstance(exception.inner_exception, ZeroDivisionError)
+ assert exception.message == "Tool load failed in 'MyTool': (ZeroDivisionError) division by zero"
+ assert exception.additional_info is None
+ assert exception.error_codes == ["SystemError", "ZeroDivisionError"]
+ assert exception.reference_code == "Executor"
diff --git a/src/promptflow/tests/executor/unittests/executor/test_tool_resolver.py b/src/promptflow/tests/executor/unittests/executor/test_tool_resolver.py
index 98a94301008..b300410ed35 100644
--- a/src/promptflow/tests/executor/unittests/executor/test_tool_resolver.py
+++ b/src/promptflow/tests/executor/unittests/executor/test_tool_resolver.py
@@ -13,9 +13,10 @@
InvalidConnectionType,
InvalidSource,
NodeInputValidationError,
+ ResolveToolError,
ValueTypeUnresolved,
)
-from promptflow.executor._tool_resolver import ToolResolver, ResolvedTool
+from promptflow.executor._tool_resolver import ResolvedTool, ToolResolver
TEST_ROOT = Path(__file__).parent.parent.parent
REQUESTS_PATH = TEST_ROOT / "test_configs/executor_api_requests"
@@ -86,39 +87,48 @@ def test_resolve_tool_by_node_with_invalid_type(self, resolver, mocker):
node = mocker.Mock(name="node", tool=None, inputs={})
node.source = mocker.Mock(type=None)
- with pytest.raises(NotImplementedError) as exec_info:
+ with pytest.raises(ResolveToolError) as exec_info:
resolver.resolve_tool_by_node(node)
- assert "Tool type" in exec_info.value.args[0]
+
+ assert isinstance(exec_info.value.inner_exception, NotImplementedError)
+ assert "Tool type" in exec_info.value.message
def test_resolve_tool_by_node_with_invalid_source_type(self, resolver, mocker):
node = mocker.Mock(name="node", tool=None, inputs={})
node.type = ToolType.PYTHON
node.source = mocker.Mock(type=None)
- with pytest.raises(NotImplementedError) as exec_info:
+ with pytest.raises(ResolveToolError) as exec_info:
resolver.resolve_tool_by_node(node)
- assert "Tool source type" in exec_info.value.args[0]
+
+ assert isinstance(exec_info.value.inner_exception, NotImplementedError)
+ assert "Tool source type" in exec_info.value.message
node.type = ToolType.CUSTOM_LLM
node.source = mocker.Mock(type=None)
- with pytest.raises(NotImplementedError) as exec_info:
+ with pytest.raises(ResolveToolError) as exec_info:
resolver.resolve_tool_by_node(node)
- assert "Tool source type" in exec_info.value.args[0]
+
+ assert isinstance(exec_info.value.inner_exception, NotImplementedError)
+ assert "Tool source type" in exec_info.value.message
def test_resolve_tool_by_node_with_no_source(self, resolver, mocker):
node = mocker.Mock(name="node", tool=None, inputs={})
node.source = None
- with pytest.raises(UserErrorException):
+ with pytest.raises(ResolveToolError) as ex:
resolver.resolve_tool_by_node(node)
+ assert isinstance(ex.value.inner_exception, UserErrorException)
def test_resolve_tool_by_node_with_no_source_path(self, resolver, mocker):
node = mocker.Mock(name="node", tool=None, inputs={})
node.type = ToolType.PROMPT
node.source = mocker.Mock(type=ToolSourceType.Package, path=None)
- with pytest.raises(InvalidSource) as exec_info:
+ with pytest.raises(ResolveToolError) as exec_info:
resolver.resolve_tool_by_node(node)
+
+ assert isinstance(exec_info.value.inner_exception, InvalidSource)
assert "Node source path" in exec_info.value.message
def test_resolve_tool_by_node_with_duplicated_inputs(self, resolver, mocker):
@@ -126,9 +136,11 @@ def test_resolve_tool_by_node_with_duplicated_inputs(self, resolver, mocker):
node.type = ToolType.PROMPT
mocker.patch.object(resolver, "_load_source_content", return_value="{{template}}")
- with pytest.raises(NodeInputValidationError) as exec_info:
+ with pytest.raises(ResolveToolError) as exec_info:
resolver.resolve_tool_by_node(node)
- assert "These inputs are duplicated" in exec_info.value.args[0]
+
+ assert isinstance(exec_info.value.inner_exception, NodeInputValidationError)
+ assert "These inputs are duplicated" in exec_info.value.message
def test_ensure_node_inputs_type(self):
# Case 1: conn_name not in connections, should raise conn_name not found error
@@ -262,7 +274,7 @@ def mock_llm_api_func(prompt: PromptTemplate, **kwargs):
mocker.patch(
"promptflow._core.tools_manager.BuiltinsManager._load_package_tool",
- return_value=(mock_llm_api_func, {"conn": AzureOpenAIConnection})
+ return_value=(mock_llm_api_func, {"conn": AzureOpenAIConnection}),
)
connections = {"conn_name": {"type": "AzureOpenAIConnection", "value": {"api_key": "mock", "api_base": "mock"}}}
@@ -327,7 +339,7 @@ def mock_package_func(prompt: PromptTemplate, **kwargs):
mocker.patch(
"promptflow._core.tools_manager.BuiltinsManager._load_package_tool",
- return_value=(mock_package_func, {"conn": AzureOpenAIConnection})
+ return_value=(mock_package_func, {"conn": AzureOpenAIConnection}),
)
connections = {"conn_name": {"type": "AzureOpenAIConnection", "value": {"api_key": "mock", "api_base": "mock"}}}
@@ -357,7 +369,11 @@ def mock_package_func(prompt: PromptTemplate, **kwargs):
return render_template_jinja2(prompt, **kwargs)
tool_resolver = ToolResolver(working_dir=None, connections={})
- mocker.patch.object(tool_resolver, "_load_source_content", return_value="{{text}}",)
+ mocker.patch.object(
+ tool_resolver,
+ "_load_source_content",
+ return_value="{{text}}",
+ )
tool = Tool(name="mock", type=ToolType.CUSTOM_LLM, inputs={"prompt": InputDefinition(type=["PromptTemplate"])})
node = Node(
diff --git a/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_cli_with_azure.py b/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_cli_with_azure.py
index c552b9d7297..847e12e4cdc 100644
--- a/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_cli_with_azure.py
+++ b/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_cli_with_azure.py
@@ -56,6 +56,24 @@ def test_basic_flow_run_bulk_without_env(self, pf, runtime) -> None:
run = pf.runs.get(run=name)
assert isinstance(run, Run)
+ @pytest.mark.skip("Custom tool pkg and promptflow pkg with CustomStrongTypeConnection not installed on runtime.")
+ def test_basic_flow_run_with_custom_strong_type_connection(self, pf, runtime) -> None:
+ name = str(uuid.uuid4())
+ run_pf_command(
+ "run",
+ "create",
+ "--flow",
+ f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow",
+ "--data",
+ f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow/data.jsonl",
+ "--name",
+ name,
+ pf=pf,
+ runtime=runtime,
+ )
+ run = pf.runs.get(run=name)
+ assert isinstance(run, Run)
+
def test_run_with_remote_data(self, pf, runtime, remote_web_classification_data, temp_output_dir: str):
# run with arm id
name = str(uuid.uuid4())
diff --git a/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_run_operations.py b/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_run_operations.py
index 944336549c4..971da06f499 100644
--- a/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_run_operations.py
+++ b/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_run_operations.py
@@ -11,7 +11,7 @@
import pytest
from promptflow._sdk._constants import RunStatus
-from promptflow._sdk._errors import InvalidRunError, RunNotFoundError
+from promptflow._sdk._errors import InvalidRunError, RunNotFoundError, RunOperationParameterError
from promptflow._sdk._load_functions import load_run
from promptflow._sdk.entities import Run
from promptflow._utils.flow_utils import get_flow_lineage_id
@@ -609,6 +609,48 @@ def fake_submit(*args, **kwargs):
# request id should be included in FlowRequestException
assert f"request id: {remote_client.runs._service_caller._request_id}" in str(e.value)
+ def test_input_output_portal_url_parser(self, remote_client):
+ runs_op = remote_client.runs
+
+ # test input with datastore path
+ input_datastore_path = (
+ "azureml://datastores/workspaceblobstore/paths/LocalUpload/312cca2af474e5f895013392b6b38f45/data.jsonl"
+ )
+ expected_input_portal_url = (
+ f"https://ml.azure.com/data/datastore/workspaceblobstore/edit?wsid={runs_op._common_azure_url_pattern}"
+ f"&activeFilePath=LocalUpload/312cca2af474e5f895013392b6b38f45/data.jsonl#browseTab"
+ )
+ assert runs_op._get_input_portal_url_from_input_uri(input_datastore_path) == expected_input_portal_url
+
+ # test input with asset id
+ input_asset_id = (
+ "azureml://locations/eastus/workspaces/f40fcfba-ed15-4c0c-a522-6798d8d89094/data/hod-qa-sample/versions/1"
+ )
+ expected_input_portal_url = (
+ f"https://ml.azure.com/data/hod-qa-sample/1/details?wsid={runs_op._common_azure_url_pattern}"
+ )
+ assert runs_op._get_input_portal_url_from_input_uri(input_asset_id) == expected_input_portal_url
+
+ # test output with asset id
+ output_asset_id = (
+ "azureml://locations/eastus/workspaces/f40fcfba-ed15-4c0c-a522-6798d8d89094/data"
+ "/azureml_d360affb-c01f-460f-beca-db9a8b88b625_output_data_flow_outputs/versions/1"
+ )
+ expected_output_portal_url = (
+ "https://ml.azure.com/data/azureml_d360affb-c01f-460f-beca-db9a8b88b625_output_data_flow_outputs/1/details"
+ f"?wsid={runs_op._common_azure_url_pattern}"
+ )
+ assert runs_op._get_portal_url_from_asset_id(output_asset_id) == expected_output_portal_url
+
+ def test_wrong_client_parameters(self):
+ # test wrong client parameters
+ with pytest.raises(RunOperationParameterError, match="You have passed in the wrong parameter name"):
+ PFClient(
+ subscription_id="fake_subscription_id",
+ resource_group="fake_resource_group",
+ workspace_name="fake_workspace_name",
+ )
+
def test_get_detail_against_partial_fail_run(self, remote_client, pf, runtime) -> None:
run = pf.run(
flow=f"{FLOWS_DIR}/partial_fail",
diff --git a/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_telemetry.py b/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_telemetry.py
new file mode 100644
index 00000000000..b55c643a2d5
--- /dev/null
+++ b/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_telemetry.py
@@ -0,0 +1,94 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import contextlib
+import os
+from unittest.mock import patch
+
+import pytest
+
+from promptflow._sdk._configuration import Configuration
+from promptflow._sdk._utils import call_from_extension
+from promptflow._telemetry.logging_handler import PromptFlowSDKLogHandler, get_appinsights_log_handler
+
+
+@contextlib.contextmanager
+def environment_variable_overwrite(key, val):
+ if key in os.environ.keys():
+ backup_value = os.environ[key]
+ else:
+ backup_value = None
+ os.environ[key] = val
+
+ try:
+ yield
+ finally:
+ if backup_value:
+ os.environ[key] = backup_value
+ else:
+ os.environ.pop(key)
+
+
+@contextlib.contextmanager
+def cli_consent_config_overwrite(val):
+ config = Configuration.get_instance()
+ original_consent = config.get_telemetry_consent()
+ config.set_telemetry_consent(val)
+ try:
+ yield
+ finally:
+ if original_consent:
+ config.set_telemetry_consent(original_consent)
+ else:
+ config.set_telemetry_consent(False)
+
+
+@pytest.mark.e2etest
+class TestTelemetry:
+ def test_logging_handler(self):
+ # override environment variable
+ with environment_variable_overwrite("TELEMETRY_ENABLED", "true"):
+ handler = get_appinsights_log_handler()
+ assert isinstance(handler, PromptFlowSDKLogHandler)
+ assert handler._is_telemetry_enabled is True
+
+ with environment_variable_overwrite("TELEMETRY_ENABLED", "false"):
+ handler = get_appinsights_log_handler()
+ assert isinstance(handler, PromptFlowSDKLogHandler)
+ assert handler._is_telemetry_enabled is False
+
+ def test_call_from_extension(self):
+ assert call_from_extension() is False
+ with environment_variable_overwrite("USER_AGENT", "prompt-flow-extension/1.0.0"):
+ assert call_from_extension() is True
+
+ def test_custom_event(self, pf):
+ from opencensus.ext.azure.log_exporter import AzureEventHandler
+
+ def log_event(*args, **kwargs):
+ record = kwargs.get("record", None)
+ assert record.custom_dimensions is not None
+ assert isinstance(record.custom_dimensions, dict)
+ assert record.custom_dimensions.keys() == {
+ "request_id",
+ "activity_name",
+ "activity_type",
+ "subscription_id",
+ "resource_group_name",
+ "workspace_name",
+ "completion_status",
+ "duration_ms",
+ "level",
+ "python_version",
+ "user_agent",
+ "installation_id",
+ }
+ assert record.msg.startswith("pfazure.runs.get")
+
+ with patch.object(AzureEventHandler, "log_record_to_envelope") as mock_log:
+ mock_log.side_effect = log_event
+ with cli_consent_config_overwrite(True):
+ try:
+ pf.runs.get("not_exist")
+ except Exception:
+ pass
diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py
index 14f314b1cbb..1bcd46c835c 100644
--- a/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py
+++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py
@@ -45,7 +45,7 @@ def run_pf_command(*args, cwd=None):
os.chdir(origin_cwd)
-@pytest.mark.usefixtures("use_secrets_config_file", "setup_local_connection")
+@pytest.mark.usefixtures("use_secrets_config_file", "setup_local_connection", "install_custom_tool_pkg")
@pytest.mark.cli_test
@pytest.mark.e2etest
class TestCli:
@@ -1016,6 +1016,8 @@ def get_node_settings(_flow_dag_path: Path):
# "api_base": "This is my first connection.",
# "promptflow.connection.custom_type": "MyFirstConnection",
# "promptflow.connection.module": "my_tool_package.connections",
+ # "promptflow.connection.package": "test-custom-tools",
+ # "promptflow.connection.package_version": "0.0.1",
# },
# "secrets": {"api_key": SCRUBBED_VALUE},
# },
@@ -1023,7 +1025,9 @@ def get_node_settings(_flow_dag_path: Path):
# ),
],
)
- def test_connection_create_update(self, file_name, expected, update_item, capfd, local_client):
+ def test_connection_create_update(
+ self, install_custom_tool_pkg, file_name, expected, update_item, capfd, local_client
+ ):
name = f"Connection_{str(uuid.uuid4())[:4]}"
run_pf_command("connection", "create", "--file", f"{CONNECTIONS_DIR}/{file_name}", "--name", f"{name}")
out, err = capfd.readouterr()
diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_custom_strong_type_connection.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_custom_strong_type_connection.py
index ae6cd96f26a..9f3670947b3 100644
--- a/src/promptflow/tests/sdk_cli_test/e2etests/test_custom_strong_type_connection.py
+++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_custom_strong_type_connection.py
@@ -4,7 +4,6 @@
import pydash
import pytest
-from promptflow._core.tools_manager import register_connections
from promptflow._sdk._constants import SCRUBBED_VALUE, CustomStrongTypeConnectionConfigs
from promptflow._sdk._pf_client import PFClient
from promptflow._sdk.entities import CustomStrongTypeConnection
@@ -16,8 +15,6 @@ class MyCustomConnection(CustomStrongTypeConnection):
api_base: str
-register_connections([MyCustomConnection])
-
_client = PFClient()
TEST_ROOT = Path(__file__).parent.parent.parent
@@ -124,14 +121,31 @@ def test_connection_get_and_update(self):
_client.connections.create_or_update(result)
assert "secrets ['api_key'] value invalid, please fill them" in str(e.value)
- @pytest.mark.skip("test package not installed")
+ def test_connection_get_and_update_with_key(self):
+ # Test api key not updated
+ name = f"Connection_{str(uuid.uuid4())[:4]}"
+ conn = MyCustomConnection(name=name, secrets={"api_key": "test"}, configs={"api_base": "test"})
+ assert conn.api_base == "test"
+ assert conn.configs["api_base"] == "test"
+
+ result = _client.connections.create_or_update(conn)
+ converted_conn = result._convert_to_custom_strong_type()
+
+ assert converted_conn.api_base == "test"
+ converted_conn.api_base = "test2"
+ assert converted_conn.api_base == "test2"
+ assert converted_conn.configs["api_base"] == "test2"
+
+ @pytest.mark.skip("TODO: need to fix random package not found error")
@pytest.mark.parametrize(
"file_name, expected_updated_item, expected_secret_item",
[
("custom_strong_type_connection.yaml", ("api_base", "new_value"), ("api_key", "")),
],
)
- def test_upsert_connection_from_file(self, file_name, expected_updated_item, expected_secret_item):
+ def test_upsert_connection_from_file(
+ self, install_custom_tool_pkg, file_name, expected_updated_item, expected_secret_item
+ ):
from promptflow._cli._pf._connection import _upsert_connection_from_file
name = f"Connection_{str(uuid.uuid4())[:4]}"
diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_run.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_run.py
index 42ec34e9140..bb38a8e0eaf 100644
--- a/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_run.py
+++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_run.py
@@ -50,7 +50,7 @@ def create_run_against_run(client, run: Run) -> Run:
)
-@pytest.mark.usefixtures("use_secrets_config_file", "setup_local_connection")
+@pytest.mark.usefixtures("use_secrets_config_file", "setup_local_connection", "install_custom_tool_pkg")
@pytest.mark.sdk_test
@pytest.mark.e2etest
class TestFlowRun:
@@ -247,17 +247,15 @@ def test_custom_connection_overwrite(self, local_client, local_custom_connection
)
assert "Connection with name new_connection not found" in str(e.value)
- def test_custom_strong_type_connection_basic_flow(self, local_client, pf, is_custom_tool_pkg_installed):
- if is_custom_tool_pkg_installed:
- result = pf.run(
- flow=f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow",
- data=f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow/data.jsonl",
- connections={"My_First_Tool_00f8": {"connection": "custom_strong_type_connection"}},
- )
- run = local_client.runs.get(name=result.name)
- assert run.status == "Completed"
- else:
- pytest.skip("Custom tool package 'my_tools_package_with_cstc' not installed.")
+ @pytest.mark.skip("TODO: need to fix random package not found error")
+ def test_custom_strong_type_connection_basic_flow(self, install_custom_tool_pkg, local_client, pf):
+ result = pf.run(
+ flow=f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow",
+ data=f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow/data.jsonl",
+ connections={"My_First_Tool_00f8": {"connection": "custom_strong_type_connection"}},
+ )
+ run = local_client.runs.get(name=result.name)
+ assert run.status == "Completed"
def test_run_with_connection_overwrite_non_exist(self, local_client, local_aoai_connection, pf):
# overwrite non_exist connection
@@ -379,10 +377,11 @@ def test_create_run_with_tags(self, pf):
environment_variables={"API_BASE": "${azure_open_ai_connection.api_base}"},
)
assert run.name == name
- assert f"{display_name}-default-" in run.display_name
+ assert "test_run_with_tags" == run.display_name
assert run.tags == tags
def test_run_display_name(self, pf):
+ # use folder name if not specify display_name
run = pf.runs.create_or_update(
run=Run(
flow=Path(f"{FLOWS_DIR}/print_env_var"),
@@ -390,7 +389,9 @@ def test_run_display_name(self, pf):
environment_variables={"API_BASE": "${azure_open_ai_connection.api_base}"},
)
)
- assert "print_env_var-default-" in run.display_name
+ assert run.display_name == "print_env_var"
+
+ # will respect if specified in run
base_run = pf.runs.create_or_update(
run=Run(
flow=Path(f"{FLOWS_DIR}/print_env_var"),
@@ -399,18 +400,29 @@ def test_run_display_name(self, pf):
display_name="my_run",
)
)
- assert "my_run-default-" in base_run.display_name
+ assert base_run.display_name == "my_run"
run = pf.runs.create_or_update(
run=Run(
flow=Path(f"{FLOWS_DIR}/print_env_var"),
data=f"{DATAS_DIR}/env_var_names.jsonl",
environment_variables={"API_BASE": "${azure_open_ai_connection.api_base}"},
- display_name="my_run",
+ display_name="my_run_${variant_id}_${run}",
+ run=base_run,
+ )
+ )
+ assert run.display_name == f"my_run_default_{base_run.name}"
+
+ run = pf.runs.create_or_update(
+ run=Run(
+ flow=Path(f"{FLOWS_DIR}/print_env_var"),
+ data=f"{DATAS_DIR}/env_var_names.jsonl",
+ environment_variables={"API_BASE": "${azure_open_ai_connection.api_base}"},
+ display_name="my_run_${timestamp}",
run=base_run,
)
)
- assert f"{base_run.display_name}-my_run-" in run.display_name
+ assert "${timestamp}" not in run.display_name
def test_run_dump(self, azure_open_ai_connection: AzureOpenAIConnection, pf: PFClient) -> None:
data_path = f"{DATAS_DIR}/webClassification3.jsonl"
diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_test.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_test.py
index 3b19ef0c3ed..1c73eddc8c9 100644
--- a/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_test.py
+++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_test.py
@@ -19,7 +19,7 @@
_client = PFClient()
-@pytest.mark.usefixtures("use_secrets_config_file", "setup_local_connection")
+@pytest.mark.usefixtures("use_secrets_config_file", "setup_local_connection", "install_custom_tool_pkg")
@pytest.mark.sdk_test
@pytest.mark.e2etest
class TestFlowTest:
@@ -33,15 +33,18 @@ def test_pf_test_flow(self):
result = _client.test(flow=f"{FLOWS_DIR}/web_classification")
assert all([key in FLOW_RESULT_KEYS for key in result])
- def test_pf_test_flow_with_custom_strong_type_connection(self, is_custom_tool_pkg_installed):
- if is_custom_tool_pkg_installed:
- inputs = {"text": "Hello World!"}
- flow_path = Path(f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow").absolute()
+ @pytest.mark.skip("TODO: need to fix random package not found error")
+ def test_pf_test_flow_with_custom_strong_type_connection(self, install_custom_tool_pkg):
+ inputs = {"text": "Hello World!"}
+ flow_path = Path(f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow").absolute()
- result = _client.test(flow=flow_path, inputs=inputs)
- assert result == {"out": "connection_value is MyFirstConnection: True"}
- else:
- pytest.skip("Custom tool package 'my_tools_package_with_cstc' not installed.")
+ # Test that connection would be custom strong type in flow
+ result = _client.test(flow=flow_path, inputs=inputs)
+ assert result == {"out": "connection_value is MyFirstConnection: True"}
+
+ # Test that the connection also works when testing a single node
+ result = _client.test(flow=flow_path, inputs=inputs, node="My_Second_Tool_usi3")
+ assert result == "Hello World!This is my first custom connection."
def test_pf_test_with_streaming_output(self):
flow_path = Path(f"{FLOWS_DIR}/chat_flow_with_stream_output")
diff --git a/src/promptflow/tests/sdk_cli_test/unittests/test_config.py b/src/promptflow/tests/sdk_cli_test/unittests/test_config.py
index 18992c6e18d..229dcb471f4 100644
--- a/src/promptflow/tests/sdk_cli_test/unittests/test_config.py
+++ b/src/promptflow/tests/sdk_cli_test/unittests/test_config.py
@@ -21,20 +21,13 @@ class TestConfig:
def test_set_config(self, config):
config.set_config("a.b.c.test_key", "test_value")
assert config.get_config("a.b.c.test_key") == "test_value"
- assert config._config == {"a": {"b": {"c": {"test_key": "test_value"}}}}
+ # global config may contain other keys
+ assert config.config["a"] == {"b": {"c": {"test_key": "test_value"}}}
def test_get_config(self, config):
config.set_config("test_key", "test_value")
assert config.get_config("test_key") == "test_value"
- def test_get_telemetry_consent(self, config):
- config.set_telemetry_consent(True)
- assert config.get_telemetry_consent() is True
-
- def test_set_telemetry_consent(self, config):
- config.set_telemetry_consent(True)
- assert config.get_telemetry_consent() is True
-
def test_get_or_set_installation_id(self, config):
user_id = config.get_or_set_installation_id()
assert user_id is not None
diff --git a/src/promptflow/tests/sdk_cli_test/unittests/test_pf_client.py b/src/promptflow/tests/sdk_cli_test/unittests/test_pf_client.py
index ad299726092..fb22e9594f4 100644
--- a/src/promptflow/tests/sdk_cli_test/unittests/test_pf_client.py
+++ b/src/promptflow/tests/sdk_cli_test/unittests/test_pf_client.py
@@ -63,3 +63,9 @@ def test_local_azure_connection_extract_workspace(self):
with pytest.raises(ValueError) as e:
LocalAzureConnectionOperations._extract_workspace("azureml:xx")
assert "Malformed connection provider string" in str(e.value)
+
+ with pytest.raises(ValueError) as e:
+ LocalAzureConnectionOperations._extract_workspace(
+ "azureml:/subscriptions/123/resourceGroups/456/providers/Microsoft.MachineLearningServices/workspaces/"
+ )
+ assert "Malformed connection provider string" in str(e.value)
diff --git a/src/promptflow/tests/sdk_cli_test/unittests/test_utils.py b/src/promptflow/tests/sdk_cli_test/unittests/test_utils.py
index 47654bf50e4..35e6f33a0c6 100644
--- a/src/promptflow/tests/sdk_cli_test/unittests/test_utils.py
+++ b/src/promptflow/tests/sdk_cli_test/unittests/test_utils.py
@@ -5,21 +5,25 @@
import argparse
import os
import shutil
+import sys
import tempfile
from pathlib import Path
+from unittest.mock import patch
import mock
import pandas as pd
import pytest
+from promptflow._cli._params import AppendToDictAction
from promptflow._cli._utils import (
_build_sorted_column_widths_tuple_list,
_calculate_column_widths,
list_of_dict_to_nested_dict,
)
-from promptflow._cli._params import AppendToDictAction
+from promptflow._sdk._constants import HOME_PROMPT_FLOW_DIR
from promptflow._sdk._errors import GenerateFlowToolsJsonError
from promptflow._sdk._utils import (
+ _generate_connections_dir,
decrypt_secret_value,
encrypt_secret_value,
generate_flow_tools_json,
@@ -114,6 +118,19 @@ def test_generate_flow_tools_json_expecting_fail(self) -> None:
flow_tools_json = generate_flow_tools_json(flow_path, dump=False, raise_error=False)
assert len(flow_tools_json["code"]) == 0
+ @pytest.mark.parametrize(
+ "python_path, env_hash",
+ [
+ ("D:\\Tools\\Anaconda3\\envs\\pf\\python.exe", ("a9620c3cdb7ccf3ec9f4005e5b19c12d1e1fef80")),
+ ("/Users/fake_user/anaconda3/envs/pf/bin/python3.10", ("e3f33eadd9be376014eb75a688930930ca83c056")),
+ ],
+ )
+ def test_generate_connections_dir(self, python_path, env_hash):
+ expected_result = (HOME_PROMPT_FLOW_DIR / "envs" / env_hash / "connections").resolve()
+ with patch.object(sys, "executable", python_path):
+ result = _generate_connections_dir()
+ assert result == expected_result
+
@pytest.mark.unittest
class TestCLIUtils:
@@ -128,7 +145,7 @@ def test_list_of_dict_to_nested_dict(self):
def test_append_to_dict_action(self):
parser = argparse.ArgumentParser(prog="test_dict_action")
parser.add_argument("--dict", action=AppendToDictAction, nargs="+")
- args = ["--dict", "key1=val1", "\'key2=val2\'", "\"key3=val3\"", "key4=\'val4\'", "key5=\"val5'"]
+ args = ["--dict", "key1=val1", "'key2=val2'", '"key3=val3"', "key4='val4'", "key5=\"val5'"]
args = parser.parse_args(args)
expect_dict = {
"key1": "val1",
diff --git a/src/promptflow/tests/test_configs/connections/custom_strong_type_connection.yaml b/src/promptflow/tests/test_configs/connections/custom_strong_type_connection.yaml
index e1f94b8b12d..276419d4c73 100644
--- a/src/promptflow/tests/test_configs/connections/custom_strong_type_connection.yaml
+++ b/src/promptflow/tests/test_configs/connections/custom_strong_type_connection.yaml
@@ -2,6 +2,8 @@ name: my_custom_strong_type_connection
type: custom
custom_type: MyFirstConnection
module: my_tool_package.connections
+package: test-custom-tools
+package_version: 0.0.1
configs:
api_base: "This is my first connection."
secrets: # must-have
diff --git a/src/promptflow/tests/test_configs/connections/update_custom_strong_type_connection.yaml b/src/promptflow/tests/test_configs/connections/update_custom_strong_type_connection.yaml
index 0449a3f0f3e..6974d7cff10 100644
--- a/src/promptflow/tests/test_configs/connections/update_custom_strong_type_connection.yaml
+++ b/src/promptflow/tests/test_configs/connections/update_custom_strong_type_connection.yaml
@@ -2,6 +2,8 @@ name: my_custom_strong_type_connection
type: custom
custom_type: MyFirstConnection
module: my_tool_package.connections
+package: test-custom-tools
+package_version: 0.0.1
configs:
api_base: "new_value"
secrets: # must-have
diff --git a/src/promptflow/tests/test_configs/flows/custom_strong_type_connection_basic_flow/flow.dag.yaml b/src/promptflow/tests/test_configs/flows/custom_strong_type_connection_basic_flow/flow.dag.yaml
index 5e7a519c6fb..c71127c50bc 100644
--- a/src/promptflow/tests/test_configs/flows/custom_strong_type_connection_basic_flow/flow.dag.yaml
+++ b/src/promptflow/tests/test_configs/flows/custom_strong_type_connection_basic_flow/flow.dag.yaml
@@ -1,16 +1,29 @@
+id: template_standard_flow
+name: Template Standard Flow
+environment:
+ python_requirements_txt: requirements.txt
inputs:
text:
type: string
+ default: Hello!
outputs:
out:
type: string
reference: ${My_First_Tool_00f8.output}
nodes:
+- name: My_Second_Tool_usi3
+ type: python
+ source:
+ type: package
+ tool: my_tool_package.tools.my_tool_2.MyTool.my_tool
+ inputs:
+ connection: custom_strong_type_connection
+ input_text: ${inputs.text}
- name: My_First_Tool_00f8
type: python
source:
type: package
tool: my_tool_package.tools.my_tool_1.my_tool
inputs:
- connection: my_custom_strong_type_connection
- input_text: ${inputs.text}
+ connection: custom_strong_type_connection
+ input_text: ${My_Second_Tool_usi3.output}
diff --git a/src/promptflow/tests/test_configs/flows/custom_strong_type_connection_basic_flow/requirements.txt b/src/promptflow/tests/test_configs/flows/custom_strong_type_connection_basic_flow/requirements.txt
new file mode 100644
index 00000000000..3aa3a8efd89
--- /dev/null
+++ b/src/promptflow/tests/test_configs/flows/custom_strong_type_connection_basic_flow/requirements.txt
@@ -0,0 +1 @@
+test-custom-tools==0.0.1
\ No newline at end of file