diff --git a/src/promptflow/dev-connections.json.example b/src/promptflow/dev-connections.json.example index 1c12f7a74ef..ecb5b239847 100644 --- a/src/promptflow/dev-connections.json.example +++ b/src/promptflow/dev-connections.json.example @@ -63,6 +63,19 @@ "key1" ] }, + "custom_strong_type_connection": { + "type": "CustomConnection", + "value": { + "api_key": "", + "api_base": "This is my first custom connection.", + "promptflow.connection.custom_type": "MyFirstConnection", + "promptflow.connection.module": "my_tool_package.connections" + }, + "module": "promptflow.connections", + "secret_keys": [ + "api_key" + ] + }, "open_ai_connection": { "type": "OpenAIConnection", "value": { diff --git a/src/promptflow/promptflow/_cli/_pf/_connection.py b/src/promptflow/promptflow/_cli/_pf/_connection.py index 8988dac2a09..d596fe3abae 100644 --- a/src/promptflow/promptflow/_cli/_pf/_connection.py +++ b/src/promptflow/promptflow/_cli/_pf/_connection.py @@ -8,7 +8,7 @@ from functools import partial from promptflow._cli._params import add_param_set, logging_params -from promptflow._cli._utils import activate_action, confirm, exception_handler, print_yellow_warning, get_secret_input +from promptflow._cli._utils import activate_action, confirm, exception_handler, get_secret_input, print_yellow_warning from promptflow._sdk._constants import LOGGER_NAME from promptflow._sdk._load_functions import load_connection from promptflow._sdk._pf_client import PFClient diff --git a/src/promptflow/promptflow/_core/connection_manager.py b/src/promptflow/promptflow/_core/connection_manager.py index 3bf672e5dfa..8d2bcf12ff7 100644 --- a/src/promptflow/promptflow/_core/connection_manager.py +++ b/src/promptflow/promptflow/_core/connection_manager.py @@ -10,6 +10,7 @@ from typing import Any, Dict, List from promptflow._constants import CONNECTION_NAME_PROPERTY, CONNECTION_SECRET_KEYS, PROMPTFLOW_CONNECTIONS +from promptflow._sdk._constants import CustomStrongTypeConnectionConfigs from promptflow._utils.utils import try_import from promptflow.contracts.tool import ConnectionType from promptflow.contracts.types import Secret @@ -55,6 +56,8 @@ def _build_connections(cls, _dict: Dict[str, dict]): secrets = {k: v for k, v in value.items() if k in secret_keys} configs = {k: v for k, v in value.items() if k not in secrets} connection_value = connection_class(configs=configs, secrets=secrets) + if CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY in configs: + connection_value.custom_type = configs[CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY] else: """ Note: Ignore non exists keys of connection class, diff --git a/src/promptflow/promptflow/_core/tools_manager.py b/src/promptflow/promptflow/_core/tools_manager.py index 72f65bcacd7..1062622811b 100644 --- a/src/promptflow/promptflow/_core/tools_manager.py +++ b/src/promptflow/promptflow/_core/tools_manager.py @@ -4,6 +4,7 @@ import importlib import importlib.util +import inspect import logging import traceback from functools import partial @@ -21,9 +22,13 @@ generate_python_tool, load_python_module_from_file, ) +from promptflow._utils.connection_utils import ( + generate_custom_strong_type_connection_spec, + generate_custom_strong_type_connection_template, +) from promptflow._utils.tool_utils import function_to_tool_definition, get_prompt_param_name_from_func from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSource, ToolSourceType -from promptflow.contracts.tool import Tool, ToolType +from promptflow.contracts.tool import ConnectionType, Tool, ToolType from promptflow.exceptions import ErrorTarget, SystemErrorException, UserErrorException, ValidationException module_logger = logging.getLogger(__name__) @@ -67,6 +72,54 @@ def collect_package_tools(keys: Optional[List[str]] = None) -> dict: return all_package_tools +def collect_package_tools_and_connections(keys: Optional[List[str]] = None) -> dict: + """Collect all tools and custom strong type connections from all installed packages.""" + all_package_tools = {} + all_package_connection_specs = {} + all_package_connection_templates = {} + if keys is not None: + keys = set(keys) + for entry_point in pkg_resources.iter_entry_points(group=PACKAGE_TOOLS_ENTRY): + try: + list_tool_func = entry_point.resolve() + package_tools = list_tool_func() + for identifier, tool in package_tools.items(): + # Only load required tools to avoid unnecessary loading when keys is provided + if isinstance(keys, set) and identifier not in keys: + continue + m = tool["module"] + module = importlib.import_module(m) # Import the module to make sure it is valid + tool["package"] = entry_point.dist.project_name + tool["package_version"] = entry_point.dist.version + all_package_tools[identifier] = tool + + # Get custom strong type connection definition + custom_strong_type_connections_classes = [ + obj + for name, obj in inspect.getmembers(module) + if inspect.isclass(obj) and ConnectionType.is_custom_strong_type(obj) + ] + + if custom_strong_type_connections_classes: + for cls in custom_strong_type_connections_classes: + identifier = f"{cls.__module__}.{cls.__name__}" + connection_spec = generate_custom_strong_type_connection_spec( + cls, entry_point.dist.project_name, entry_point.dist.version + ) + all_package_connection_specs[identifier] = connection_spec + all_package_connection_templates[identifier] = generate_custom_strong_type_connection_template( + cls, connection_spec, entry_point.dist.project_name, entry_point.dist.version + ) + except Exception as e: + msg = ( + f"Failed to load tools from package {entry_point.dist.project_name}: {e}," + + f" traceback: {traceback.format_exc()}" + ) + module_logger.warning(msg) + + return all_package_tools, all_package_connection_specs, all_package_connection_templates + + def gen_tool_by_source(name, source: ToolSource, tool_type: ToolType, working_dir: Path) -> Tool: if source.type == ToolSourceType.Package: package_tools = collect_package_tools() diff --git a/src/promptflow/promptflow/_sdk/_constants.py b/src/promptflow/promptflow/_sdk/_constants.py index 6cd83b5a5f2..18d47b61c61 100644 --- a/src/promptflow/promptflow/_sdk/_constants.py +++ b/src/promptflow/promptflow/_sdk/_constants.py @@ -30,6 +30,8 @@ RUN_INFO_CREATED_ON_INDEX_NAME = "idx_run_info_created_on" CONNECTION_TABLE_NAME = "connection" BASE_PATH_CONTEXT_KEY = "base_path" +SCHEMA_KEYS_CONTEXT_CONFIG_KEY = "schema_configs_keys" +SCHEMA_KEYS_CONTEXT_SECRET_KEY = "schema_secrets_keys" PARAMS_OVERRIDE_KEY = "params_override" FILE_PREFIX = "file:" KEYRING_SYSTEM = "promptflow" @@ -51,6 +53,21 @@ LOCAL_STORAGE_BATCH_SIZE = 1 +class CustomStrongTypeConnectionConfigs: + PREFIX = "promptflow.connection." + TYPE = "custom_type" + MODULE = "module" + PROMPTFLOW_TYPE_KEY = PREFIX + TYPE + PROMPTFLOW_MODULE_KEY = PREFIX + MODULE + + @staticmethod + def is_custom_key(key): + return key not in [ + CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY, + CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY, + ] + + class RunTypes: BATCH = "batch" EVALUATION = "evaluation" diff --git a/src/promptflow/promptflow/_sdk/_load_functions.py b/src/promptflow/promptflow/_sdk/_load_functions.py index 413542f53e1..dde327ddd77 100644 --- a/src/promptflow/promptflow/_sdk/_load_functions.py +++ b/src/promptflow/promptflow/_sdk/_load_functions.py @@ -50,7 +50,12 @@ def load_common( cls, type_str = cls._resolve_cls_and_type(data=yaml_dict, params_override=params_override) try: - return cls._load(data=yaml_dict, yaml_path=relative_origin, params_override=params_override, **kwargs) + return cls._load( + data=yaml_dict, + yaml_path=relative_origin, + params_override=params_override, + **kwargs, + ) except Exception as e: raise Exception(f"Load entity error: {e}") from e diff --git a/src/promptflow/promptflow/_sdk/entities/__init__.py b/src/promptflow/promptflow/_sdk/entities/__init__.py index 541bc062baa..6a90a464a0f 100644 --- a/src/promptflow/promptflow/_sdk/entities/__init__.py +++ b/src/promptflow/promptflow/_sdk/entities/__init__.py @@ -15,6 +15,7 @@ QdrantConnection, WeaviateConnection, FormRecognizerConnection, + CustomStrongTypeConnection, ) from ._run import Run from ._validation import ValidationResult @@ -25,6 +26,7 @@ "AzureOpenAIConnection", "OpenAIConnection", "CustomConnection", + "CustomStrongTypeConnection", "CognitiveSearchConnection", "SerpConnection", "QdrantConnection", diff --git a/src/promptflow/promptflow/_sdk/entities/_connection.py b/src/promptflow/promptflow/_sdk/entities/_connection.py index 2a5f67b5374..b47375843aa 100644 --- a/src/promptflow/promptflow/_sdk/entities/_connection.py +++ b/src/promptflow/promptflow/_sdk/entities/_connection.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- import abc +import importlib import json from os import PathLike from pathlib import Path @@ -10,11 +11,14 @@ from promptflow._sdk._constants import ( BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, + SCHEMA_KEYS_CONTEXT_CONFIG_KEY, + SCHEMA_KEYS_CONTEXT_SECRET_KEY, SCRUBBED_VALUE, SCRUBBED_VALUE_NO_CHANGE, SCRUBBED_VALUE_USER_INPUT, ConfigValueType, ConnectionType, + CustomStrongTypeConnectionConfigs, ) from promptflow._sdk._errors import UnsecureConnectionError from promptflow._sdk._logger_factory import LoggerFactory @@ -33,14 +37,17 @@ AzureOpenAIConnectionSchema, CognitiveSearchConnectionSchema, CustomConnectionSchema, + CustomStrongTypeConnectionSchema, FormRecognizerConnectionSchema, OpenAIConnectionSchema, QdrantConnectionSchema, SerpConnectionSchema, WeaviateConnectionSchema, ) +from promptflow.contracts.types import Secret logger = LoggerFactory.get_logger(name=__name__) +PROMPTFLOW_CONNECTIONS = "promptflow.connections" class _Connection(YAMLTranslatableMixin): @@ -618,6 +625,115 @@ def _get_schema_cls(cls): return FormRecognizerConnectionSchema +class CustomStrongTypeConnection(_Connection): + """Custom strong type connection. + + .. note:: + + This connection type should not be used directly. Below is an example of how to use CustomStrongTypeConnection: + + .. code-block:: python + + class MyCustomConnection(CustomStrongTypeConnection): + api_key: Secret + api_base: str + + :param configs: The configs kv pairs. + :type configs: Dict[str, str] + :param secrets: The secrets kv pairs. + :type secrets: Dict[str, str] + :param name: Connection name + :type name: str + """ + + def __init__( + self, + secrets: Dict[str, str], + configs: Dict[str, str] = None, + **kwargs, + ): + # There are two cases to init a Custom strong type connection: + # 1. The connection is created through SDK PFClient, custom_type and custom_module are not in the kwargs. + # 2. The connection is loaded from template file, custom_type and custom_module are in the kwargs. + custom_type = kwargs.get(CustomStrongTypeConnectionConfigs.TYPE, None) + custom_module = kwargs.get(CustomStrongTypeConnectionConfigs.MODULE, None) + if custom_type: + configs.update({CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY: custom_type}) + if custom_module: + configs.update({CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY: custom_module}) + self.kwargs = kwargs + super().__init__(configs=configs, secrets=secrets, **kwargs) + self.module = kwargs.get("module", self.__class__.__module__) + self.custom_type = custom_type or self.__class__.__name__ + + def _to_orm_object(self) -> ORMConnection: + custom_connection = self._convert_to_custom() + return custom_connection._to_orm_object() + + def _convert_to_custom(self): + # update configs + self.configs.update({CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY: self.custom_type}) + self.configs.update({CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY: self.module}) + + custom_connection = CustomConnection(configs=self.configs, secrets=self.secrets, **self.kwargs) + return custom_connection + + @classmethod + def _get_custom_keys(cls, data): + # The data could be either from yaml or from DB. + # If from yaml, 'custom_type' and 'module' are outside the configs of data. + # If from DB, 'custom_type' and 'module' are within the configs of data. + if not data.get(CustomStrongTypeConnectionConfigs.TYPE) or not data.get( + CustomStrongTypeConnectionConfigs.MODULE + ): + if ( + not data["configs"][CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY] + or not data["configs"][CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY] + ): + raise ValueError("custom_type and module are required for custom strong type connections.") + else: + m = data["configs"][CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY] + custom_cls = data["configs"][CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY] + else: + m = data[CustomStrongTypeConnectionConfigs.MODULE] + custom_cls = data[CustomStrongTypeConnectionConfigs.TYPE] + + try: + module = importlib.import_module(m) + cls = getattr(module, custom_cls) + except ImportError: + raise ValueError( + f"Can't find module {m} in current environment. Please check the module is correctly configured." + ) + except AttributeError: + raise ValueError( + f"Can't find class {custom_cls} in module {m}. Please check the custom_type is correctly configured." + ) + + schema_configs = {} + schema_secrets = {} + + for k, v in cls.__annotations__.items(): + if v == Secret: + schema_secrets[k] = v + else: + schema_configs[k] = v + + return schema_configs, schema_secrets + + @classmethod + def _get_schema_cls(cls): + return CustomStrongTypeConnectionSchema + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str = None, **kwargs): + schema_config_keys, schema_secret_keys = cls._get_custom_keys(data) + context[SCHEMA_KEYS_CONTEXT_CONFIG_KEY] = schema_config_keys + context[SCHEMA_KEYS_CONTEXT_SECRET_KEY] = schema_secret_keys + + return (super()._load_from_dict(data, context, additional_message, **kwargs))._convert_to_custom() + + class CustomConnection(_Connection): """Custom connection. @@ -631,13 +747,34 @@ class CustomConnection(_Connection): TYPE = ConnectionType.CUSTOM - def __init__(self, secrets: Dict[str, str], configs: Dict[str, str] = None, **kwargs): + def __init__( + self, + secrets: Dict[str, str], + configs: Dict[str, str] = None, + **kwargs, + ): super().__init__(secrets=secrets, configs=configs, **kwargs) @classmethod def _get_schema_cls(cls): return CustomConnectionSchema + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str = None, **kwargs): + # If context has params_override, it means the data would be updated by overridden values. + # Provide CustomStrongTypeConnectionSchema if 'custom_type' in params_override, else CustomConnectionSchema. + # For example: + # If a user updates an existing connection by re-upserting a connection file, + # the 'data' from DB is CustomConnection, + # but 'params_override' would actually contain custom strong type connection data. + is_custom_strong_type = data.get(CustomStrongTypeConnectionConfigs.TYPE) or any( + CustomStrongTypeConnectionConfigs.TYPE in d for d in context.get(PARAMS_OVERRIDE_KEY, []) + ) + if is_custom_strong_type: + return CustomStrongTypeConnection._load_from_dict(data, context, additional_message, **kwargs) + + return super()._load_from_dict(data, context, additional_message, **kwargs) + def __getattr__(self, item): # Note: This is added for compatibility with promptflow.connections custom connection usage. if item == "secrets": @@ -676,6 +813,7 @@ def _to_orm_object(self): custom_configs.update( {k: {"configValueType": ConfigValueType.SECRET.value, "value": v} for k, v in encrypted_secrets.items()} ) + return ORMConnection( connectionName=self.name, connectionType=self.type.value, @@ -698,7 +836,11 @@ def _from_orm_object_with_secrets(cls, orm_object: ORMConnection): secrets = {} unsecure_connection = False + custom_type = None for k, v in json.loads(orm_object.customConfigs).items(): + if k == CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY: + custom_type = v["value"] + continue if not v["configValueType"] == ConfigValueType.SECRET.value: continue try: @@ -717,6 +859,7 @@ def _from_orm_object_with_secrets(cls, orm_object: ORMConnection): name=orm_object.connectionName, configs=configs, secrets=secrets, + custom_type=custom_type, expiry_time=orm_object.expiryTime, created_date=orm_object.createdDate, last_modified_date=orm_object.lastModifiedDate, @@ -748,6 +891,23 @@ def _from_mt_rest_object(cls, mt_rest_obj): last_modified_date=mt_rest_obj.last_modified_date, ) + def _is_custom_strong_type(self): + return ( + CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY in self.configs + and self.configs[CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY] + ) + + def _convert_to_custom_strong_type(self): + module_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY) + custom_type_class_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY) + import importlib + + module = importlib.import_module(module_name) + custom_defined_connection_class = getattr(module, custom_type_class_name) + connection_instance = custom_defined_connection_class(configs=self.configs, secrets=self.secrets) + + return connection_instance + _supported_types = { v.TYPE.value: v diff --git a/src/promptflow/promptflow/_sdk/schemas/_connection.py b/src/promptflow/promptflow/_sdk/schemas/_connection.py index 2d8608de91e..cc868d87293 100644 --- a/src/promptflow/promptflow/_sdk/schemas/_connection.py +++ b/src/promptflow/promptflow/_sdk/schemas/_connection.py @@ -3,9 +3,14 @@ # --------------------------------------------------------- import copy -from marshmallow import fields, pre_dump - -from promptflow._sdk._constants import ConnectionType +from marshmallow import ValidationError, fields, pre_dump, validates + +from promptflow._sdk._constants import ( + SCHEMA_KEYS_CONTEXT_CONFIG_KEY, + SCHEMA_KEYS_CONTEXT_SECRET_KEY, + ConnectionType, + CustomStrongTypeConnectionConfigs, +) from promptflow._sdk.schemas._base import YamlFileSchema from promptflow._sdk.schemas._fields import StringTransformedEnum from promptflow._utils.utils import camel_to_snake @@ -111,3 +116,26 @@ class CustomConnectionSchema(ConnectionSchema): configs = fields.Dict(keys=fields.Str(), values=fields.Str()) # Secrets is a must-have field for CustomConnection secrets = fields.Dict(keys=fields.Str(), values=fields.Str(), required=True) + + +class CustomStrongTypeConnectionSchema(CustomConnectionSchema): + name = fields.Str(attribute="name") + module = fields.Str(required=True) + custom_type = fields.Str(required=True) + + # TODO: validate configs and secrets + @validates("configs") + def validate_configs(self, value): + schema_config_keys = self.context.get(SCHEMA_KEYS_CONTEXT_CONFIG_KEY, None) + if schema_config_keys: + for key in value: + if CustomStrongTypeConnectionConfigs.is_custom_key(key) and key not in schema_config_keys: + raise ValidationError(f"Invalid config key {key}, please check the schema.") + + @validates("secrets") + def validate_secrets(self, value): + schema_secret_keys = self.context.get(SCHEMA_KEYS_CONTEXT_SECRET_KEY, None) + if schema_secret_keys: + for key in value: + if key not in schema_secret_keys: + raise ValidationError(f"Invalid secret key {key}, please check the schema.") diff --git a/src/promptflow/promptflow/_utils/connection_utils.py b/src/promptflow/promptflow/_utils/connection_utils.py new file mode 100644 index 00000000000..f9574a2830b --- /dev/null +++ b/src/promptflow/promptflow/_utils/connection_utils.py @@ -0,0 +1,68 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from jinja2 import Template + + +def generate_custom_strong_type_connection_spec(cls, package, package_version): + connection_spec = { + "connectionCategory": "CustomKeys", + "flowValueType": "CustomConnection", + "connectionType": cls.__name__, + "ConnectionTypeDisplayName": cls.__name__, + "configSpecs": [], + "module": cls.__module__, + "package": package, + "package_version": package_version, + } + + for k, typ in cls.__annotations__.items(): + spec = { + "name": k, + "displayName": k.replace("_", " ").title(), + "configValueType": typ.__name__, + } + if hasattr(cls, k): + spec["isOptional"] = getattr(cls, k, None) is not None + else: + spec["isOptional"] = False + connection_spec["configSpecs"].append(spec) + + return connection_spec + + +def generate_custom_strong_type_connection_template(cls, connection_spec, package, package_version): + connection_template_str = """ + name: + type: custom + custom_type: {{ custom_type }} + module: {{ module }} + package: {{ package }} + package_version: {{ package_version }} + configs: + {% for key, value in configs.items() %} + {{ key }}: "{{ value -}}"{% endfor %} + secrets: # must-have{% for key, value in secrets.items() %} + {{ key }}: "{{ value -}}"{% endfor %} + """ + + configs = {} + secrets = {} + connection_template = Template(connection_template_str) + for spec in connection_spec["configSpecs"]: + if spec["configValueType"] == "Secret": + secrets[spec["name"]] = "<" + spec["name"].replace("_", "-") + ">" + else: + configs[spec["name"]] = "<" + spec["name"].replace("_", "-") + ">" + + data = { + "custom_type": cls.__name__, + "module": cls.__module__, + "package": package, + "package_version": package_version, + "configs": configs, + "secrets": secrets, + } + + return connection_template.render(data) diff --git a/src/promptflow/promptflow/connections/__init__.py b/src/promptflow/promptflow/connections/__init__.py index 2dda16e7630..c2614192274 100644 --- a/src/promptflow/promptflow/connections/__init__.py +++ b/src/promptflow/promptflow/connections/__init__.py @@ -12,6 +12,7 @@ FormRecognizerConnection, OpenAIConnection, SerpConnection, + CustomStrongTypeConnection, ) from promptflow._sdk.entities._connection import _Connection from promptflow.contracts.types import Secret @@ -35,6 +36,7 @@ class BingConnection: "CognitiveSearchConnection", "FormRecognizerConnection", "CustomConnection", + "CustomStrongTypeConnection", ] register_connections( diff --git a/src/promptflow/promptflow/contracts/tool.py b/src/promptflow/promptflow/contracts/tool.py index 5fa15043843..30decfea8fb 100644 --- a/src/promptflow/promptflow/contracts/tool.py +++ b/src/promptflow/promptflow/contracts/tool.py @@ -173,7 +173,15 @@ def is_connection_value(val: Any) -> bool: from promptflow._core.tools_manager import connections val = type(val) if not isinstance(val, type) else val - return val in connections.values() + return val in connections.values() or ConnectionType.is_custom_strong_type(val) + + @staticmethod + def is_custom_strong_type(val): + """Check if the given value is a custom strong type connection.""" + + from promptflow._sdk.entities import CustomStrongTypeConnection + + return issubclass(val, CustomStrongTypeConnection) @staticmethod def serialize_conn(connection: Any) -> dict: diff --git a/src/promptflow/promptflow/executor/_tool_resolver.py b/src/promptflow/promptflow/executor/_tool_resolver.py index 8490f9161a6..91c34f94d7b 100644 --- a/src/promptflow/promptflow/executor/_tool_resolver.py +++ b/src/promptflow/promptflow/executor/_tool_resolver.py @@ -10,11 +10,8 @@ from typing import Callable, List, Optional from promptflow._core.connection_manager import ConnectionManager -from promptflow._core.tools_manager import ( - BuiltinsManager, - ToolLoader, - connection_type_to_api_mapping, -) +from promptflow._core.tools_manager import BuiltinsManager, ToolLoader, connection_type_to_api_mapping +from promptflow._sdk.entities import CustomConnection from promptflow._utils.tool_utils import get_inputs_for_prompt_template, get_prompt_param_name_from_func from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSourceType from promptflow.contracts.tool import ConnectionType, Tool, ToolType, ValueType @@ -55,6 +52,10 @@ def _convert_to_connection_value(self, k: str, v: InputAssignment, node: Node, c connection_value = self._connection_manager.get(v.value) if not connection_value: raise ConnectionNotFound(f"Connection {v.value} not found for node {node.name!r} input {k!r}.") + + if isinstance(connection_value, CustomConnection) and connection_value._is_custom_strong_type(): + return connection_value._convert_to_custom_strong_type() + # Check if type matched if not any(type(connection_value).__name__ == typ for typ in conn_types): msg = ( diff --git a/src/promptflow/tests/conftest.py b/src/promptflow/tests/conftest.py index 856dce8af11..2e07dd77c5c 100644 --- a/src/promptflow/tests/conftest.py +++ b/src/promptflow/tests/conftest.py @@ -105,3 +105,15 @@ def prepare_symbolic_flow() -> str: if not Path(file_name).exists(): os.symlink(source_folder / file_name, file_name) return target_folder + + +@pytest.fixture +def is_custom_tool_pkg_installed() -> bool: + try: + import my_tool_package # noqa: F401 + + pkg_installed = True + except ImportError: + pkg_installed = False + + return pkg_installed diff --git a/src/promptflow/tests/executor/unittests/_core/test_tools_manager.py b/src/promptflow/tests/executor/unittests/_core/test_tools_manager.py index 0302cd3074e..11627b45f62 100644 --- a/src/promptflow/tests/executor/unittests/_core/test_tools_manager.py +++ b/src/promptflow/tests/executor/unittests/_core/test_tools_manager.py @@ -1,10 +1,17 @@ from pathlib import Path import pytest +import yaml from promptflow import tool from promptflow._core._errors import NotSupported, PackageToolNotFoundError -from promptflow._core.tools_manager import NodeSourcePathEmpty, ToolLoader, collect_package_tools, gen_tool_by_source +from promptflow._core.tools_manager import ( + NodeSourcePathEmpty, + ToolLoader, + collect_package_tools, + collect_package_tools_and_connections, + gen_tool_by_source, +) from promptflow.contracts.flow import Node, ToolSource, ToolSourceType from promptflow.contracts.tool import Tool, ToolType from promptflow.exceptions import UserErrorException @@ -128,3 +135,36 @@ def test_gen_tool_by_source_error(self, tool_source, tool_type, error_code, erro with pytest.raises(error_code) as ex: gen_tool_by_source("fake_name", tool_source, tool_type, working_dir), assert str(ex.value) == error_message + + @pytest.mark.skip("test package not installed") + def test_collect_package_tools_and_connections(self): + keys = ["my_tool_package.tools.my_tool_2.MyTool.my_tool"] + tools, specs, templates = collect_package_tools_and_connections(keys) + assert len(tools) == 1 + assert specs == { + "my_tool_package.connections.MySecondConnection": { + "connectionCategory": "CustomKeys", + "flowValueType": "CustomConnection", + "connectionType": "MySecondConnection", + "ConnectionTypeDisplayName": "MySecondConnection", + "configSpecs": [ + {"name": "api_key", "displayName": "Api Key", "configValueType": "Secret", "isOptional": False}, + {"name": "api_base", "displayName": "Api Base", "configValueType": "str", "isOptional": True}, + ], + "module": "my_tool_package.connections", + "package": "my-tools-package-with-cstc", + "package_version": "0.0.6", + } + } + expected_template = { + "name": "", + "type": "custom", + "custom_type": "MySecondConnection", + "module": "my_tool_package.connections", + "package": "my-tools-package-with-cstc", + "package_version": "0.0.6", + "configs": {"api_base": ""}, + "secrets": {"api_key": ""}, + } + loaded_yaml = yaml.safe_load(templates["my_tool_package.connections.MySecondConnection"]) + assert loaded_yaml == expected_template diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py index a241a745060..14f314b1cbb 100644 --- a/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py +++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py @@ -1007,6 +1007,20 @@ def get_node_settings(_flow_dag_path: Path): }, ("configs.key1", "new_value"), ), + # ( + # "custom_strong_type_connection.yaml", + # { + # "module": "promptflow.connections", + # "type": "custom", + # "configs": { + # "api_base": "This is my first connection.", + # "promptflow.connection.custom_type": "MyFirstConnection", + # "promptflow.connection.module": "my_tool_package.connections", + # }, + # "secrets": {"api_key": SCRUBBED_VALUE}, + # }, + # ("configs.api_base", "new_value"), + # ), ], ) def test_connection_create_update(self, file_name, expected, update_item, capfd, local_client): diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_custom_strong_type_connection.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_custom_strong_type_connection.py new file mode 100644 index 00000000000..ae6cd96f26a --- /dev/null +++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_custom_strong_type_connection.py @@ -0,0 +1,153 @@ +import uuid +from pathlib import Path + +import pydash +import pytest + +from promptflow._core.tools_manager import register_connections +from promptflow._sdk._constants import SCRUBBED_VALUE, CustomStrongTypeConnectionConfigs +from promptflow._sdk._pf_client import PFClient +from promptflow._sdk.entities import CustomStrongTypeConnection +from promptflow.contracts.types import Secret + + +class MyCustomConnection(CustomStrongTypeConnection): + api_key: Secret + api_base: str + + +register_connections([MyCustomConnection]) + +_client = PFClient() + +TEST_ROOT = Path(__file__).parent.parent.parent +CONNECTION_ROOT = TEST_ROOT / "test_configs/connections" + + +@pytest.mark.cli_test +@pytest.mark.e2etest +class TestCustomStrongTypeConnection: + def test_connection_operations(self): + name = f"Connection_{str(uuid.uuid4())[:4]}" + conn = MyCustomConnection(name=name, secrets={"api_key": "test"}, configs={"api_base": "test"}) + # Create + _client.connections.create_or_update(conn) + # Get + result = _client.connections.get(name) + assert pydash.omit(result._to_dict(), ["created_date", "last_modified_date", "name"]) == { + "module": "promptflow.connections", + "type": "custom", + "configs": { + "api_base": "test", + "promptflow.connection.custom_type": "MyCustomConnection", + "promptflow.connection.module": "sdk_cli_test.e2etests.test_custom_strong_type_connection", + }, + "secrets": {"api_key": "******"}, + } + # Update + conn.configs["api_base"] = "test2" + result = _client.connections.create_or_update(conn) + assert pydash.omit(result._to_dict(), ["created_date", "last_modified_date", "name"]) == { + "module": "promptflow.connections", + "type": "custom", + "configs": { + "api_base": "test2", + "promptflow.connection.custom_type": "MyCustomConnection", + "promptflow.connection.module": "sdk_cli_test.e2etests.test_custom_strong_type_connection", + }, + "secrets": {"api_key": "******"}, + } + # List + result = _client.connections.list() + assert len(result) > 0 + # Delete + _client.connections.delete(name) + with pytest.raises(Exception) as e: + _client.connections.get(name) + assert "is not found." in str(e.value) + + def test_connection_update(self): + name = f"Connection_{str(uuid.uuid4())[:4]}" + conn = MyCustomConnection(name=name, secrets={"api_key": "test"}, configs={"api_base": "test"}) + # Create + _client.connections.create_or_update(conn) + # Get + custom_conn = _client.connections.get(name) + assert pydash.omit(custom_conn._to_dict(), ["created_date", "last_modified_date", "name"]) == { + "module": "promptflow.connections", + "type": "custom", + "configs": { + "api_base": "test", + "promptflow.connection.custom_type": "MyCustomConnection", + "promptflow.connection.module": "sdk_cli_test.e2etests.test_custom_strong_type_connection", + }, + "secrets": {"api_key": "******"}, + } + # Update + custom_conn.configs["api_base"] = "test2" + result = _client.connections.create_or_update(custom_conn) + assert pydash.omit(result._to_dict(), ["created_date", "last_modified_date", "name"]) == { + "module": "promptflow.connections", + "type": "custom", + "configs": { + "api_base": "test2", + "promptflow.connection.custom_type": "MyCustomConnection", + "promptflow.connection.module": "sdk_cli_test.e2etests.test_custom_strong_type_connection", + }, + "secrets": {"api_key": "******"}, + } + # List + result = _client.connections.list() + assert len(result) > 0 + # Delete + _client.connections.delete(name) + with pytest.raises(Exception) as e: + _client.connections.get(name) + assert "is not found." in str(e.value) + + def test_connection_get_and_update(self): + # Test api key not updated + name = f"Connection_{str(uuid.uuid4())[:4]}" + conn = MyCustomConnection(name=name, secrets={"api_key": "test"}, configs={"api_base": "test"}) + result = _client.connections.create_or_update(conn) + assert result.secrets["api_key"] == SCRUBBED_VALUE + # Update api_base only Assert no exception + result.configs["api_base"] = "test2" + result = _client.connections.create_or_update(result) + assert result._to_dict()["configs"]["api_base"] == "test2" + # Assert value not scrubbed + assert result._secrets["api_key"] == "test" + _client.connections.delete(name) + # Invalid update + with pytest.raises(Exception) as e: + result._secrets = {} + _client.connections.create_or_update(result) + assert "secrets ['api_key'] value invalid, please fill them" in str(e.value) + + @pytest.mark.skip("test package not installed") + @pytest.mark.parametrize( + "file_name, expected_updated_item, expected_secret_item", + [ + ("custom_strong_type_connection.yaml", ("api_base", "new_value"), ("api_key", "")), + ], + ) + def test_upsert_connection_from_file(self, file_name, expected_updated_item, expected_secret_item): + from promptflow._cli._pf._connection import _upsert_connection_from_file + + name = f"Connection_{str(uuid.uuid4())[:4]}" + result = _upsert_connection_from_file(file=CONNECTION_ROOT / file_name, params_override=[{"name": name}]) + assert result is not None + assert result.configs[CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY] == "my_tool_package.connections" + update_file_name = f"update_{file_name}" + result = _upsert_connection_from_file(file=CONNECTION_ROOT / update_file_name, params_override=[{"name": name}]) + # Test secrets not updated, and configs updated + assert ( + result.configs[expected_updated_item[0]] == expected_updated_item[1] + ), "Assert configs updated failed, expected: {}, actual: {}".format( + expected_updated_item[1], result.configs[expected_updated_item[0]] + ) + assert ( + result._secrets[expected_secret_item[0]] == expected_secret_item[1] + ), "Assert secrets not updated failed, expected: {}, actual: {}".format( + expected_secret_item[1], result._secrets[expected_secret_item[0]] + ) diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_run.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_run.py index 9f907a79852..42ec34e9140 100644 --- a/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_run.py +++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_run.py @@ -247,6 +247,18 @@ def test_custom_connection_overwrite(self, local_client, local_custom_connection ) assert "Connection with name new_connection not found" in str(e.value) + def test_custom_strong_type_connection_basic_flow(self, local_client, pf, is_custom_tool_pkg_installed): + if is_custom_tool_pkg_installed: + result = pf.run( + flow=f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow", + data=f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow/data.jsonl", + connections={"My_First_Tool_00f8": {"connection": "custom_strong_type_connection"}}, + ) + run = local_client.runs.get(name=result.name) + assert run.status == "Completed" + else: + pytest.skip("Custom tool package 'my_tools_package_with_cstc' not installed.") + def test_run_with_connection_overwrite_non_exist(self, local_client, local_aoai_connection, pf): # overwrite non_exist connection with pytest.raises(Exception) as e: diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_test.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_test.py index baec3b2331a..3b19ef0c3ed 100644 --- a/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_test.py +++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_test.py @@ -1,11 +1,11 @@ +import logging from pathlib import Path from types import GeneratorType -import logging import pytest -from promptflow._sdk._pf_client import PFClient from promptflow._sdk._constants import LOGGER_NAME +from promptflow._sdk._pf_client import PFClient from promptflow.exceptions import UserErrorException PROMOTFLOW_ROOT = Path(__file__) / "../../../.." @@ -33,6 +33,16 @@ def test_pf_test_flow(self): result = _client.test(flow=f"{FLOWS_DIR}/web_classification") assert all([key in FLOW_RESULT_KEYS for key in result]) + def test_pf_test_flow_with_custom_strong_type_connection(self, is_custom_tool_pkg_installed): + if is_custom_tool_pkg_installed: + inputs = {"text": "Hello World!"} + flow_path = Path(f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow").absolute() + + result = _client.test(flow=flow_path, inputs=inputs) + assert result == {"out": "connection_value is MyFirstConnection: True"} + else: + pytest.skip("Custom tool package 'my_tools_package_with_cstc' not installed.") + def test_pf_test_with_streaming_output(self): flow_path = Path(f"{FLOWS_DIR}/chat_flow_with_stream_output") result = _client.test(flow=flow_path) diff --git a/src/promptflow/tests/test_configs/connections/custom_strong_type_connection.yaml b/src/promptflow/tests/test_configs/connections/custom_strong_type_connection.yaml new file mode 100644 index 00000000000..e1f94b8b12d --- /dev/null +++ b/src/promptflow/tests/test_configs/connections/custom_strong_type_connection.yaml @@ -0,0 +1,8 @@ +name: my_custom_strong_type_connection +type: custom +custom_type: MyFirstConnection +module: my_tool_package.connections +configs: + api_base: "This is my first connection." +secrets: # must-have + api_key: "" \ No newline at end of file diff --git a/src/promptflow/tests/test_configs/connections/update_custom_strong_type_connection.yaml b/src/promptflow/tests/test_configs/connections/update_custom_strong_type_connection.yaml new file mode 100644 index 00000000000..0449a3f0f3e --- /dev/null +++ b/src/promptflow/tests/test_configs/connections/update_custom_strong_type_connection.yaml @@ -0,0 +1,8 @@ +name: my_custom_strong_type_connection +type: custom +custom_type: MyFirstConnection +module: my_tool_package.connections +configs: + api_base: "new_value" +secrets: # must-have + api_key: "******" \ No newline at end of file diff --git a/src/promptflow/tests/test_configs/flows/custom_strong_type_connection_basic_flow/data.jsonl b/src/promptflow/tests/test_configs/flows/custom_strong_type_connection_basic_flow/data.jsonl new file mode 100644 index 00000000000..15e3aa54262 --- /dev/null +++ b/src/promptflow/tests/test_configs/flows/custom_strong_type_connection_basic_flow/data.jsonl @@ -0,0 +1 @@ +{"text": "Hello World!"} diff --git a/src/promptflow/tests/test_configs/flows/custom_strong_type_connection_basic_flow/flow.dag.yaml b/src/promptflow/tests/test_configs/flows/custom_strong_type_connection_basic_flow/flow.dag.yaml new file mode 100644 index 00000000000..5e7a519c6fb --- /dev/null +++ b/src/promptflow/tests/test_configs/flows/custom_strong_type_connection_basic_flow/flow.dag.yaml @@ -0,0 +1,16 @@ +inputs: + text: + type: string +outputs: + out: + type: string + reference: ${My_First_Tool_00f8.output} +nodes: +- name: My_First_Tool_00f8 + type: python + source: + type: package + tool: my_tool_package.tools.my_tool_1.my_tool + inputs: + connection: my_custom_strong_type_connection + input_text: ${inputs.text}