Skip to content

Commit

Permalink
Support custom strong type connection (#319)
Browse files Browse the repository at this point in the history
# Description

**This PR does the following things:**
1. Provide generate connection spec and template method when gen tool
meta (for extension)
2. Connection CRUD:
    a. create through file
    b. create through PFClient
3. Submit local flow can succeed with custom strong type connection
4. Convert connection in flow back to strong type connection when
executing user tool scripts in tool resolver


How user writes their own custom strong type connection:
```python
from promptflow.connections import CustomStrongTypeConnection
from promptflow.contracts.types import Secret

class MyCustomConnection(CustomStrongTypeConnection):
    api_key: Secret
    api_base: str
```
Connection spec example:
```
{
    'my_tool_package.connections.MyFirstConnection': {
        'connectionCategory': 'CustomKeys',
        'flowValueType': 'CustomConnection',
        'connectionType': 'MyFirstConnection',
        'ConnectionTypeDisplayName': 'MyFirstConnection',
        'configSpecs': [{
                'name': 'api_key',
                'displayName': 'Api Key',
                'configValueType': 'Secret',
                'isOptional': True
            }, {
                'name': 'api_base',
                'displayName': 'Api Base',
                'configValueType': 'str',
                'isOptional': True
            }
        ],
        'module': 'my_tool_package.connections',
        'package': 'my-tools-package-with-cstc',
        'package_version': '0.0.2'
    }
}
```
Connection template example:
```
{
    "name": "<connection-name>",
    "type": "custom",
    "custom_type": "MyFirstConnection",
    "module": "my_tool_package.connections",
    "package": "my-tools-package-with-cstc",
    "package_version": "0.0.2",
    "configs": {
        "api_base": "<api-base>"
    },
    "secrets": {
        "api_key": "<api-key>"
    },
}
```
Create connection through file command:
```
pf connection create -f <path-to-my-first-connection>
```
CustomConfigs of created connection in DB:
```
{
    "api_base": {
        "configValueType": "String",
        "value": "This is my first custom connection."
    },
    "promptflow.custom.connection.custom_type": {
        "configValueType": "String",
        "value": "MyFirstConnection"
    },
    "promptflow.custom.connection.module": {
        "configValueType": "String",
        "value": "my_tool_package.connections"
    },
    "api_key": {
        "configValueType": "Secret",
        "value": "XXX" // real key endswith "=="
    }
}

```
Create connection through PFClient:
```
client = PFClient()
conn = MyFirstConnection(name=name, secrets={"api_key": "test"}, configs={"api_base": "test"})
client.connections.create_or_update(conn)
```

**Things to refine:**
1. Functions marked with TODO
2. Exception handling and error message refine

**Work items yet to do:**
1. [Task
2687123](https://msdata.visualstudio.com/Vienna/_workitems/edit/2687123):
Promptflow Cli support generate custom strong type connection template
2. [Task
2691873](https://msdata.visualstudio.com/Vienna/_workitems/edit/2691873):
Gen custom tool yaml with extended contract
3. [Task
2691944](https://msdata.visualstudio.com/Vienna/_workitems/edit/2691944):
Enable schema check for custom strong type connection
4. [Task
2692760](https://msdata.visualstudio.com/Vienna/_workitems/edit/2692760):
Support to use custom strong type connection in script tool
5. [Task
2692758](https://msdata.visualstudio.com/Vienna/_workitems/edit/2692758):
When gen connection template, also move the comments of connection class
over

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes]**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

---------

Co-authored-by: yalu4 <[email protected]>
  • Loading branch information
16oeahr and yalu4 authored Sep 21, 2023
1 parent 54a29fb commit 9b9db37
Show file tree
Hide file tree
Showing 23 changed files with 650 additions and 16 deletions.
13 changes: 13 additions & 0 deletions src/promptflow/dev-connections.json.example
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,19 @@
"key1"
]
},
"custom_strong_type_connection": {
"type": "CustomConnection",
"value": {
"api_key": "<your-key>",
"api_base": "This is my first custom connection.",
"promptflow.connection.custom_type": "MyFirstConnection",
"promptflow.connection.module": "my_tool_package.connections"
},
"module": "promptflow.connections",
"secret_keys": [
"api_key"
]
},
"open_ai_connection": {
"type": "OpenAIConnection",
"value": {
Expand Down
2 changes: 1 addition & 1 deletion src/promptflow/promptflow/_cli/_pf/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from functools import partial

from promptflow._cli._params import add_param_set, logging_params
from promptflow._cli._utils import activate_action, confirm, exception_handler, print_yellow_warning, get_secret_input
from promptflow._cli._utils import activate_action, confirm, exception_handler, get_secret_input, print_yellow_warning
from promptflow._sdk._constants import LOGGER_NAME
from promptflow._sdk._load_functions import load_connection
from promptflow._sdk._pf_client import PFClient
Expand Down
3 changes: 3 additions & 0 deletions src/promptflow/promptflow/_core/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Any, Dict, List

from promptflow._constants import CONNECTION_NAME_PROPERTY, CONNECTION_SECRET_KEYS, PROMPTFLOW_CONNECTIONS
from promptflow._sdk._constants import CustomStrongTypeConnectionConfigs
from promptflow._utils.utils import try_import
from promptflow.contracts.tool import ConnectionType
from promptflow.contracts.types import Secret
Expand Down Expand Up @@ -55,6 +56,8 @@ def _build_connections(cls, _dict: Dict[str, dict]):
secrets = {k: v for k, v in value.items() if k in secret_keys}
configs = {k: v for k, v in value.items() if k not in secrets}
connection_value = connection_class(configs=configs, secrets=secrets)
if CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY in configs:
connection_value.custom_type = configs[CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY]
else:
"""
Note: Ignore non exists keys of connection class,
Expand Down
55 changes: 54 additions & 1 deletion src/promptflow/promptflow/_core/tools_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import importlib
import importlib.util
import inspect
import logging
import traceback
from functools import partial
Expand All @@ -21,9 +22,13 @@
generate_python_tool,
load_python_module_from_file,
)
from promptflow._utils.connection_utils import (
generate_custom_strong_type_connection_spec,
generate_custom_strong_type_connection_template,
)
from promptflow._utils.tool_utils import function_to_tool_definition, get_prompt_param_name_from_func
from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSource, ToolSourceType
from promptflow.contracts.tool import Tool, ToolType
from promptflow.contracts.tool import ConnectionType, Tool, ToolType
from promptflow.exceptions import ErrorTarget, SystemErrorException, UserErrorException, ValidationException

module_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -67,6 +72,54 @@ def collect_package_tools(keys: Optional[List[str]] = None) -> dict:
return all_package_tools


def collect_package_tools_and_connections(keys: Optional[List[str]] = None) -> dict:
"""Collect all tools and custom strong type connections from all installed packages."""
all_package_tools = {}
all_package_connection_specs = {}
all_package_connection_templates = {}
if keys is not None:
keys = set(keys)
for entry_point in pkg_resources.iter_entry_points(group=PACKAGE_TOOLS_ENTRY):
try:
list_tool_func = entry_point.resolve()
package_tools = list_tool_func()
for identifier, tool in package_tools.items():
# Only load required tools to avoid unnecessary loading when keys is provided
if isinstance(keys, set) and identifier not in keys:
continue
m = tool["module"]
module = importlib.import_module(m) # Import the module to make sure it is valid
tool["package"] = entry_point.dist.project_name
tool["package_version"] = entry_point.dist.version
all_package_tools[identifier] = tool

# Get custom strong type connection definition
custom_strong_type_connections_classes = [
obj
for name, obj in inspect.getmembers(module)
if inspect.isclass(obj) and ConnectionType.is_custom_strong_type(obj)
]

if custom_strong_type_connections_classes:
for cls in custom_strong_type_connections_classes:
identifier = f"{cls.__module__}.{cls.__name__}"
connection_spec = generate_custom_strong_type_connection_spec(
cls, entry_point.dist.project_name, entry_point.dist.version
)
all_package_connection_specs[identifier] = connection_spec
all_package_connection_templates[identifier] = generate_custom_strong_type_connection_template(
cls, connection_spec, entry_point.dist.project_name, entry_point.dist.version
)
except Exception as e:
msg = (
f"Failed to load tools from package {entry_point.dist.project_name}: {e},"
+ f" traceback: {traceback.format_exc()}"
)
module_logger.warning(msg)

return all_package_tools, all_package_connection_specs, all_package_connection_templates


def gen_tool_by_source(name, source: ToolSource, tool_type: ToolType, working_dir: Path) -> Tool:
if source.type == ToolSourceType.Package:
package_tools = collect_package_tools()
Expand Down
17 changes: 17 additions & 0 deletions src/promptflow/promptflow/_sdk/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
RUN_INFO_CREATED_ON_INDEX_NAME = "idx_run_info_created_on"
CONNECTION_TABLE_NAME = "connection"
BASE_PATH_CONTEXT_KEY = "base_path"
SCHEMA_KEYS_CONTEXT_CONFIG_KEY = "schema_configs_keys"
SCHEMA_KEYS_CONTEXT_SECRET_KEY = "schema_secrets_keys"
PARAMS_OVERRIDE_KEY = "params_override"
FILE_PREFIX = "file:"
KEYRING_SYSTEM = "promptflow"
Expand All @@ -51,6 +53,21 @@
LOCAL_STORAGE_BATCH_SIZE = 1


class CustomStrongTypeConnectionConfigs:
PREFIX = "promptflow.connection."
TYPE = "custom_type"
MODULE = "module"
PROMPTFLOW_TYPE_KEY = PREFIX + TYPE
PROMPTFLOW_MODULE_KEY = PREFIX + MODULE

@staticmethod
def is_custom_key(key):
return key not in [
CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY,
CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY,
]


class RunTypes:
BATCH = "batch"
EVALUATION = "evaluation"
Expand Down
7 changes: 6 additions & 1 deletion src/promptflow/promptflow/_sdk/_load_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ def load_common(
cls, type_str = cls._resolve_cls_and_type(data=yaml_dict, params_override=params_override)

try:
return cls._load(data=yaml_dict, yaml_path=relative_origin, params_override=params_override, **kwargs)
return cls._load(
data=yaml_dict,
yaml_path=relative_origin,
params_override=params_override,
**kwargs,
)
except Exception as e:
raise Exception(f"Load entity error: {e}") from e

Expand Down
2 changes: 2 additions & 0 deletions src/promptflow/promptflow/_sdk/entities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
QdrantConnection,
WeaviateConnection,
FormRecognizerConnection,
CustomStrongTypeConnection,
)
from ._run import Run
from ._validation import ValidationResult
Expand All @@ -25,6 +26,7 @@
"AzureOpenAIConnection",
"OpenAIConnection",
"CustomConnection",
"CustomStrongTypeConnection",
"CognitiveSearchConnection",
"SerpConnection",
"QdrantConnection",
Expand Down
Loading

0 comments on commit 9b9db37

Please sign in to comment.