Skip to content

Commit

Permalink
[tool] Tool supports input_settings (#929)
Browse files Browse the repository at this point in the history
# Description
Using tool function to configure input settings
- is_multi_select
- allow_manual_entry
- enabled_by
- enabled_by_value
- dynamic_list

```python
from promptflow import InputSettings, DynamicList, tool

@def my_list_func(prefix: str = "", size: int = 10, **kwargs):
    pass


dynamic_list_setting = DynamicList(function=my_list_func, input_mapping={"prefix": "input_prefix"})
input_settings = {
    "input_text": InputSettings(
        dynamic_list=dynamic_list_setting,
        allow_manual_entry=True,
        is_multi_select=True
    )
}

@tool(
    name="My Tool with Dynamic List Input",
    description="This is my tool with dynamic list input",
    input_settings=input_settings
)
def my_tool(input_text: list, input_prefix: str) -> str:
    return f"Hello {input_prefix} {','.join(input_text)}"
```
Generated tool meta:
```json
{
    "name": "My Tool with Dynamic List Input",
    "type": "python",
    "inputs": {
        "input_text": {
            "type": [
                "list"
            ],
            "is_multi_select": true,
            "allow_manual_entry": true,
            "dynamic_list": {
                "func_path": "cli_tool_package.tool_with_dynamic_list_input.my_list_func",
                "func_kwargs": [
                    {
                        "name": "prefix",
                        "type": [
                            "string"
                        ],
                        "reference": "${inputs.input_prefix}",
                        "optional": true,
                        "default": ""
                    },
                    {
                        "name": "size",
                        "type": [
                            "int"
                        ],
                        "optional": true,
                        "default": 10
                    }
                ]
            }
        },
        "input_prefix": {
            "type": [
                "string"
            ]
        }
}
```

Please add an informative description that covers that changes made by
the pull request and link all relevant issues.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
  • Loading branch information
lalala123123 authored Nov 8, 2023
1 parent 4f43297 commit be75ef0
Show file tree
Hide file tree
Showing 12 changed files with 630 additions and 44 deletions.
84 changes: 83 additions & 1 deletion src/promptflow/promptflow/_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

import functools
import inspect
import importlib
import logging
from abc import ABC
from enum import Enum
from typing import Callable, Optional
from typing import Callable, Optional, List, Dict, Union, get_args, get_origin
from dataclasses import dataclass, InitVar, field

module_logger = logging.getLogger(__name__)

Expand All @@ -18,6 +20,7 @@ class ToolType(str, Enum):
PYTHON = "python"
PROMPT = "prompt"
_ACTION = "action"
CUSTOM_LLM = "custom_llm"


class ToolInvoker(ABC):
Expand Down Expand Up @@ -45,6 +48,7 @@ def tool(
name: str = None,
description: str = None,
type: str = None,
input_settings=None,
**kwargs,
) -> Callable:
"""Decorator for tool functions. The decorated function will be registered as a tool and can be used in a flow.
Expand All @@ -55,11 +59,15 @@ def tool(
:type description: str
:param type: The tool type.
:type type: str
:param input_settings: Dict of input setting.
:type input_settings: Dict[str, promptflow.entities.InputSetting]
:return: The decorated function.
:rtype: Callable
"""

def tool_decorator(func: Callable) -> Callable:
from promptflow.exceptions import UserErrorException

@functools.wraps(func)
def new_f(*args, **kwargs):
tool_invoker = ToolInvoker.active_instance()
Expand All @@ -68,13 +76,18 @@ def new_f(*args, **kwargs):
return func(*args, **kwargs)
return tool_invoker.invoke_tool(func, *args, **kwargs)

if type is not None and type not in [k.value for k in ToolType]:
raise UserErrorException(f"Tool type {type} is not supported yet.")

new_f.__original_function = func
func.__wrapped_function = new_f
new_f.__tool = None # This will be set when generating the tool definition.
new_f.__name = name
new_f.__description = description
new_f.__type = type
new_f.__input_settings = input_settings
new_f.__extra_info = kwargs

return new_f

# enable use decorator without "()" if all arguments are default values
Expand Down Expand Up @@ -127,3 +140,72 @@ def get_required_initialize_inputs(cls):
if k != "self" and v.default is inspect.Parameter.empty
}
return cls._required_initialize_inputs


@dataclass
class DynamicList:

function: InitVar[Union[str, Callable]]
"""The dynamic list function."""

input_mapping: InitVar[Dict] = None
"""The mapping between dynamic list function inputs and tool inputs."""

func_path: str = field(init=False)
func_kwargs: List = field(init=False)

def __post_init__(self, function, input_mapping):
from promptflow.exceptions import UserErrorException
from promptflow.contracts.tool import ValueType

# Validate function exist
if isinstance(function, str):
func = importlib.import_module(tool["module"])
func_path = function
elif isinstance(function, Callable):
func = function
func_path = f"{function.__module__}.{function.__name__}"
else:
raise UserErrorException(
"Function has invalid type, please provide callable or function name for function.")
self.func_path = func_path
self._func_obj = func
self._input_mapping = input_mapping or {}

# Get function input info
self.func_kwargs = []
inputs = inspect.signature(self._func_obj).parameters
for name, value in inputs.items():
if value.kind != value.VAR_KEYWORD and value.kind != value.VAR_POSITIONAL:
input_info = {"name": name}
if value.annotation is not inspect.Parameter.empty:
if get_origin(value.annotation):
input_info["type"] = [annotation.__name__ for annotation in get_args(value.annotation)]
else:
input_info["type"] = [ValueType.from_type(value.annotation)]
if name in self._input_mapping:
input_info["reference"] = f"${{inputs.{self._input_mapping[name]}}}"
input_info["optional"] = value.default is not inspect.Parameter.empty
if input_info["optional"]:
input_info["default"] = value.default
self.func_kwargs.append(input_info)


@dataclass
class InputSetting:
"""Settings of the tool input"""

is_multi_select: bool = None
"""Allow user to select multiple values."""

allow_manual_entry: bool = None
"""Allow user to enter input value manually."""

enabled_by: str = None
"""The input field which must be an enum type, that controls the visibility of the dependent input field."""

enabled_by_value: List = None
"""Defines the accepted enum values from the enabled_by field that will make this dependent input field visible."""

dynamic_list: DynamicList = None
"""Settings of dynamic list function."""
20 changes: 14 additions & 6 deletions src/promptflow/promptflow/_core/tool_meta_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,21 @@ def collect_tool_methods_with_init_inputs_in_module(m):
return tools


def _parse_tool_from_function(f, initialize_inputs=None, gen_custom_type_conn=False):
def _parse_tool_from_function(f, initialize_inputs=None, gen_custom_type_conn=False, skip_prompt_template=False):
try:
tool_type = getattr(f, "__type") or ToolType.PYTHON
except Exception as e:
raise e
tool_name = getattr(f, "__name")
description = getattr(f, "__description")
if hasattr(f, "__tool") and isinstance(f.__tool, Tool):
return f.__tool
if hasattr(f, "__original_function"):
f = f.__original_function
try:
inputs, _, _ = function_to_interface(f, initialize_inputs, gen_custom_type_conn=gen_custom_type_conn)
inputs, _, _ = function_to_interface(
f, initialize_inputs=initialize_inputs, gen_custom_type_conn=gen_custom_type_conn,
skip_prompt_template=skip_prompt_template)
except Exception as e:
error_type_and_message = f"({e.__class__.__name__}) {e}"
raise BadFunctionInterface(
Expand All @@ -149,10 +157,10 @@ def _parse_tool_from_function(f, initialize_inputs=None, gen_custom_type_conn=Fa
class_name = f.__qualname__.replace(f".{f.__name__}", "")
# Construct the Tool structure
return Tool(
name=f.__qualname__,
description=inspect.getdoc(f),
name=tool_name or f.__qualname__,
description=description or inspect.getdoc(f),
inputs=inputs,
type=ToolType.PYTHON,
type=tool_type,
class_name=class_name,
function=f.__name__,
module=f.__module__,
Expand Down Expand Up @@ -239,7 +247,7 @@ def collect_tool_function_in_module(m):
def generate_python_tool(name, content, source=None):
m = load_python_module(content, source)
f, initialize_inputs = collect_tool_function_in_module(m)
tool = _parse_tool_from_function(f, initialize_inputs)
tool = _parse_tool_from_function(f, initialize_inputs=initialize_inputs)
tool.module = None
if name is not None:
tool.name = name
Expand Down
79 changes: 44 additions & 35 deletions src/promptflow/promptflow/_sdk/operations/_tool_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@
from pathlib import Path
from typing import Union

from promptflow._core.tool_meta_generator import is_tool
from promptflow._core.tool_meta_generator import _parse_tool_from_function, asdict_without_none, is_tool
from promptflow._core.tools_manager import collect_package_tools
from promptflow._utils.multimedia_utils import convert_multimedia_data_to_base64
from promptflow._utils.tool_utils import function_to_interface
from promptflow.contracts.multimedia import Image
from promptflow.contracts.tool import Tool, ToolType
from promptflow.exceptions import UserErrorException


Expand Down Expand Up @@ -48,7 +46,7 @@ def _collect_tool_functions_in_module(tool_module):

@staticmethod
def _collect_tool_class_methods_in_module(tool_module):
from promptflow import ToolProvider
from promptflow._core.tool import ToolProvider

tools = []
for _, obj in inspect.getmembers(tool_module):
Expand All @@ -59,34 +57,35 @@ def _collect_tool_class_methods_in_module(tool_module):
tools.append((method, initialize_inputs))
return tools

@staticmethod
def _parse_tool_from_function(f, initialize_inputs=None):
tool_type = getattr(f, "__type") or ToolType.PYTHON
tool_name = getattr(f, "__name")
description = getattr(f, "__description")
extra_info = getattr(f, "__extra_info")
if getattr(f, "__tool", None) and isinstance(f.__tool, Tool):
return getattr(f, "__tool")
if hasattr(f, "__original_function"):
f = getattr(f, "__original_function")
try:
inputs, _, _ = function_to_interface(f, initialize_inputs=initialize_inputs)
except Exception as e:
raise UserErrorException(f"Failed to parse interface for tool {f.__name__}, reason: {e}") from e
class_name = None
if "." in f.__qualname__:
class_name = f.__qualname__.replace(f".{f.__name__}", "")
# Construct the Tool structure
tool = Tool(
name=tool_name or f.__qualname__,
description=description or inspect.getdoc(f),
inputs=inputs,
type=tool_type,
class_name=class_name,
function=f.__name__,
module=f.__module__,
)
return tool, extra_info
def _validate_input_settings(self, tool_inputs, input_settings):
for input_name, settings in input_settings.items():
if input_name not in tool_inputs:
raise UserErrorException(f"Cannot find {input_name} in tool inputs.")
if settings.enabled_by and settings.enabled_by not in tool_inputs:
raise UserErrorException(
f"Cannot find the input \"{settings.enabled_by}\" for the enabled_by of {input_name}.")
if settings.dynamic_list:
dynamic_func_inputs = inspect.signature(settings.dynamic_list._func_obj).parameters
has_kwargs = any([param.kind == param.VAR_KEYWORD for param in dynamic_func_inputs.values()])
required_inputs = [k for k, v in dynamic_func_inputs.items() if
v.default is inspect.Parameter.empty and v.kind != v.VAR_KEYWORD]
if settings.dynamic_list._input_mapping:
# Validate input mapping in dynamic_list
for func_input, reference_input in settings.dynamic_list._input_mapping.items():
# Check invalid input name of dynamic list function
if not has_kwargs and func_input not in dynamic_func_inputs:
raise UserErrorException(
f"Cannot find {func_input} in the inputs of "
f"dynamic_list func {settings.dynamic_list.func_path}"
)
# Check invalid input name of tool
if reference_input not in tool_inputs:
raise UserErrorException(f"Cannot find {reference_input} in the tool inputs.")
if func_input in required_inputs:
required_inputs.remove(func_input)
# Check required input of dynamic_list function
if len(required_inputs) != 0:
raise UserErrorException(f"Missing required input(s) of dynamic_list function: {required_inputs}")

def _serialize_tool(self, tool_func, initialize_inputs=None):
"""
Expand All @@ -99,7 +98,9 @@ def _serialize_tool(self, tool_func, initialize_inputs=None):
:return: package tool name, serialized tool
:rtype: str, Dict[str, str]
"""
tool, extra_info = self._parse_tool_from_function(tool_func, initialize_inputs)
tool = _parse_tool_from_function(tool_func, initialize_inputs=initialize_inputs,
gen_custom_type_conn=True, skip_prompt_template=True)
extra_info = getattr(tool_func, "__extra_info")
tool_name = (
f"{tool.module}.{tool.class_name}.{tool.function}"
if tool.class_name is not None
Expand All @@ -112,6 +113,14 @@ def _serialize_tool(self, tool_func, initialize_inputs=None):
raise UserErrorException(f"Cannot find the icon path {extra_info['icon']}.")
extra_info["icon"] = self._serialize_image_data(extra_info["icon"])
construct_tool.update(extra_info)

# Update tool input settings
input_settings = getattr(tool_func, "__input_settings")
if input_settings:
tool_inputs = construct_tool.get("inputs", {})
self._validate_input_settings(tool_inputs, input_settings)
for input_name, settings in input_settings.items():
tool_inputs[input_name].update(asdict_without_none(settings))
return tool_name, construct_tool

@staticmethod
Expand All @@ -134,8 +143,8 @@ def _serialize_image_data(image_path):
return image_url

def list(
self,
flow: Union[str, PathLike] = None,
self,
flow: Union[str, PathLike] = None,
):
"""
List all package tools in the environment and code tools in the flow.
Expand Down
8 changes: 7 additions & 1 deletion src/promptflow/promptflow/_utils/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def param_to_definition(param, gen_custom_type_conn=False) -> (InputDefinition,
)


def function_to_interface(f: Callable, initialize_inputs=None, gen_custom_type_conn=False) -> tuple:
def function_to_interface(f: Callable, initialize_inputs=None, gen_custom_type_conn=False,
skip_prompt_template=False) -> tuple:
sign = inspect.signature(f)
all_inputs = {}
input_defs = {}
Expand All @@ -126,6 +127,11 @@ def function_to_interface(f: Callable, initialize_inputs=None, gen_custom_type_c
)
# Resolve inputs to definitions.
for k, v in all_inputs.items():
# Get value type from annotation
value_type = resolve_annotation(v.annotation)
if skip_prompt_template and value_type is PromptTemplate:
# custom llm tool has prompt template as input, skip it
continue
input_def, is_connection = param_to_definition(v, gen_custom_type_conn=gen_custom_type_conn)
input_defs[k] = input_def
if is_connection:
Expand Down
5 changes: 5 additions & 0 deletions src/promptflow/promptflow/entities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
FormRecognizerConnection,
)
from promptflow._sdk.entities._run import Run
from promptflow._core.tool import InputSetting, DynamicList
from promptflow._sdk.entities._flow import FlowContext

__all__ = [
Expand All @@ -33,6 +34,10 @@
# region Run
"Run",
# endregion
# region Tool
"InputSetting",
"DynamicList",
# endregion
# region Flow
"FlowContext",
# endregion
Expand Down
Loading

0 comments on commit be75ef0

Please sign in to comment.