diff --git a/src/promptflow/promptflow/_core/tool.py b/src/promptflow/promptflow/_core/tool.py index f77521d0287..76fe800a998 100644 --- a/src/promptflow/promptflow/_core/tool.py +++ b/src/promptflow/promptflow/_core/tool.py @@ -4,10 +4,12 @@ import functools import inspect +import importlib import logging from abc import ABC from enum import Enum -from typing import Callable, Optional +from typing import Callable, Optional, List, Dict, Union, get_args, get_origin +from dataclasses import dataclass, InitVar, field module_logger = logging.getLogger(__name__) @@ -18,6 +20,7 @@ class ToolType(str, Enum): PYTHON = "python" PROMPT = "prompt" _ACTION = "action" + CUSTOM_LLM = "custom_llm" class ToolInvoker(ABC): @@ -45,6 +48,7 @@ def tool( name: str = None, description: str = None, type: str = None, + input_settings=None, **kwargs, ) -> Callable: """Decorator for tool functions. The decorated function will be registered as a tool and can be used in a flow. @@ -55,11 +59,15 @@ def tool( :type description: str :param type: The tool type. :type type: str + :param input_settings: Dict of input setting. + :type input_settings: Dict[str, promptflow.entities.InputSetting] :return: The decorated function. :rtype: Callable """ def tool_decorator(func: Callable) -> Callable: + from promptflow.exceptions import UserErrorException + @functools.wraps(func) def new_f(*args, **kwargs): tool_invoker = ToolInvoker.active_instance() @@ -68,13 +76,18 @@ def new_f(*args, **kwargs): return func(*args, **kwargs) return tool_invoker.invoke_tool(func, *args, **kwargs) + if type is not None and type not in [k.value for k in ToolType]: + raise UserErrorException(f"Tool type {type} is not supported yet.") + new_f.__original_function = func func.__wrapped_function = new_f new_f.__tool = None # This will be set when generating the tool definition. new_f.__name = name new_f.__description = description new_f.__type = type + new_f.__input_settings = input_settings new_f.__extra_info = kwargs + return new_f # enable use decorator without "()" if all arguments are default values @@ -127,3 +140,72 @@ def get_required_initialize_inputs(cls): if k != "self" and v.default is inspect.Parameter.empty } return cls._required_initialize_inputs + + +@dataclass +class DynamicList: + + function: InitVar[Union[str, Callable]] + """The dynamic list function.""" + + input_mapping: InitVar[Dict] = None + """The mapping between dynamic list function inputs and tool inputs.""" + + func_path: str = field(init=False) + func_kwargs: List = field(init=False) + + def __post_init__(self, function, input_mapping): + from promptflow.exceptions import UserErrorException + from promptflow.contracts.tool import ValueType + + # Validate function exist + if isinstance(function, str): + func = importlib.import_module(tool["module"]) + func_path = function + elif isinstance(function, Callable): + func = function + func_path = f"{function.__module__}.{function.__name__}" + else: + raise UserErrorException( + "Function has invalid type, please provide callable or function name for function.") + self.func_path = func_path + self._func_obj = func + self._input_mapping = input_mapping or {} + + # Get function input info + self.func_kwargs = [] + inputs = inspect.signature(self._func_obj).parameters + for name, value in inputs.items(): + if value.kind != value.VAR_KEYWORD and value.kind != value.VAR_POSITIONAL: + input_info = {"name": name} + if value.annotation is not inspect.Parameter.empty: + if get_origin(value.annotation): + input_info["type"] = [annotation.__name__ for annotation in get_args(value.annotation)] + else: + input_info["type"] = [ValueType.from_type(value.annotation)] + if name in self._input_mapping: + input_info["reference"] = f"${{inputs.{self._input_mapping[name]}}}" + input_info["optional"] = value.default is not inspect.Parameter.empty + if input_info["optional"]: + input_info["default"] = value.default + self.func_kwargs.append(input_info) + + +@dataclass +class InputSetting: + """Settings of the tool input""" + + is_multi_select: bool = None + """Allow user to select multiple values.""" + + allow_manual_entry: bool = None + """Allow user to enter input value manually.""" + + enabled_by: str = None + """The input field which must be an enum type, that controls the visibility of the dependent input field.""" + + enabled_by_value: List = None + """Defines the accepted enum values from the enabled_by field that will make this dependent input field visible.""" + + dynamic_list: DynamicList = None + """Settings of dynamic list function.""" diff --git a/src/promptflow/promptflow/_core/tool_meta_generator.py b/src/promptflow/promptflow/_core/tool_meta_generator.py index d8f4ae599b2..d751397b932 100644 --- a/src/promptflow/promptflow/_core/tool_meta_generator.py +++ b/src/promptflow/promptflow/_core/tool_meta_generator.py @@ -130,13 +130,21 @@ def collect_tool_methods_with_init_inputs_in_module(m): return tools -def _parse_tool_from_function(f, initialize_inputs=None, gen_custom_type_conn=False): +def _parse_tool_from_function(f, initialize_inputs=None, gen_custom_type_conn=False, skip_prompt_template=False): + try: + tool_type = getattr(f, "__type") or ToolType.PYTHON + except Exception as e: + raise e + tool_name = getattr(f, "__name") + description = getattr(f, "__description") if hasattr(f, "__tool") and isinstance(f.__tool, Tool): return f.__tool if hasattr(f, "__original_function"): f = f.__original_function try: - inputs, _, _ = function_to_interface(f, initialize_inputs, gen_custom_type_conn=gen_custom_type_conn) + inputs, _, _ = function_to_interface( + f, initialize_inputs=initialize_inputs, gen_custom_type_conn=gen_custom_type_conn, + skip_prompt_template=skip_prompt_template) except Exception as e: error_type_and_message = f"({e.__class__.__name__}) {e}" raise BadFunctionInterface( @@ -149,10 +157,10 @@ def _parse_tool_from_function(f, initialize_inputs=None, gen_custom_type_conn=Fa class_name = f.__qualname__.replace(f".{f.__name__}", "") # Construct the Tool structure return Tool( - name=f.__qualname__, - description=inspect.getdoc(f), + name=tool_name or f.__qualname__, + description=description or inspect.getdoc(f), inputs=inputs, - type=ToolType.PYTHON, + type=tool_type, class_name=class_name, function=f.__name__, module=f.__module__, @@ -239,7 +247,7 @@ def collect_tool_function_in_module(m): def generate_python_tool(name, content, source=None): m = load_python_module(content, source) f, initialize_inputs = collect_tool_function_in_module(m) - tool = _parse_tool_from_function(f, initialize_inputs) + tool = _parse_tool_from_function(f, initialize_inputs=initialize_inputs) tool.module = None if name is not None: tool.name = name diff --git a/src/promptflow/promptflow/_sdk/operations/_tool_operations.py b/src/promptflow/promptflow/_sdk/operations/_tool_operations.py index ec971546534..47fd6df87f0 100644 --- a/src/promptflow/promptflow/_sdk/operations/_tool_operations.py +++ b/src/promptflow/promptflow/_sdk/operations/_tool_operations.py @@ -9,12 +9,10 @@ from pathlib import Path from typing import Union -from promptflow._core.tool_meta_generator import is_tool +from promptflow._core.tool_meta_generator import _parse_tool_from_function, asdict_without_none, is_tool from promptflow._core.tools_manager import collect_package_tools from promptflow._utils.multimedia_utils import convert_multimedia_data_to_base64 -from promptflow._utils.tool_utils import function_to_interface from promptflow.contracts.multimedia import Image -from promptflow.contracts.tool import Tool, ToolType from promptflow.exceptions import UserErrorException @@ -48,7 +46,7 @@ def _collect_tool_functions_in_module(tool_module): @staticmethod def _collect_tool_class_methods_in_module(tool_module): - from promptflow import ToolProvider + from promptflow._core.tool import ToolProvider tools = [] for _, obj in inspect.getmembers(tool_module): @@ -59,34 +57,35 @@ def _collect_tool_class_methods_in_module(tool_module): tools.append((method, initialize_inputs)) return tools - @staticmethod - def _parse_tool_from_function(f, initialize_inputs=None): - tool_type = getattr(f, "__type") or ToolType.PYTHON - tool_name = getattr(f, "__name") - description = getattr(f, "__description") - extra_info = getattr(f, "__extra_info") - if getattr(f, "__tool", None) and isinstance(f.__tool, Tool): - return getattr(f, "__tool") - if hasattr(f, "__original_function"): - f = getattr(f, "__original_function") - try: - inputs, _, _ = function_to_interface(f, initialize_inputs=initialize_inputs) - except Exception as e: - raise UserErrorException(f"Failed to parse interface for tool {f.__name__}, reason: {e}") from e - class_name = None - if "." in f.__qualname__: - class_name = f.__qualname__.replace(f".{f.__name__}", "") - # Construct the Tool structure - tool = Tool( - name=tool_name or f.__qualname__, - description=description or inspect.getdoc(f), - inputs=inputs, - type=tool_type, - class_name=class_name, - function=f.__name__, - module=f.__module__, - ) - return tool, extra_info + def _validate_input_settings(self, tool_inputs, input_settings): + for input_name, settings in input_settings.items(): + if input_name not in tool_inputs: + raise UserErrorException(f"Cannot find {input_name} in tool inputs.") + if settings.enabled_by and settings.enabled_by not in tool_inputs: + raise UserErrorException( + f"Cannot find the input \"{settings.enabled_by}\" for the enabled_by of {input_name}.") + if settings.dynamic_list: + dynamic_func_inputs = inspect.signature(settings.dynamic_list._func_obj).parameters + has_kwargs = any([param.kind == param.VAR_KEYWORD for param in dynamic_func_inputs.values()]) + required_inputs = [k for k, v in dynamic_func_inputs.items() if + v.default is inspect.Parameter.empty and v.kind != v.VAR_KEYWORD] + if settings.dynamic_list._input_mapping: + # Validate input mapping in dynamic_list + for func_input, reference_input in settings.dynamic_list._input_mapping.items(): + # Check invalid input name of dynamic list function + if not has_kwargs and func_input not in dynamic_func_inputs: + raise UserErrorException( + f"Cannot find {func_input} in the inputs of " + f"dynamic_list func {settings.dynamic_list.func_path}" + ) + # Check invalid input name of tool + if reference_input not in tool_inputs: + raise UserErrorException(f"Cannot find {reference_input} in the tool inputs.") + if func_input in required_inputs: + required_inputs.remove(func_input) + # Check required input of dynamic_list function + if len(required_inputs) != 0: + raise UserErrorException(f"Missing required input(s) of dynamic_list function: {required_inputs}") def _serialize_tool(self, tool_func, initialize_inputs=None): """ @@ -99,7 +98,9 @@ def _serialize_tool(self, tool_func, initialize_inputs=None): :return: package tool name, serialized tool :rtype: str, Dict[str, str] """ - tool, extra_info = self._parse_tool_from_function(tool_func, initialize_inputs) + tool = _parse_tool_from_function(tool_func, initialize_inputs=initialize_inputs, + gen_custom_type_conn=True, skip_prompt_template=True) + extra_info = getattr(tool_func, "__extra_info") tool_name = ( f"{tool.module}.{tool.class_name}.{tool.function}" if tool.class_name is not None @@ -112,6 +113,14 @@ def _serialize_tool(self, tool_func, initialize_inputs=None): raise UserErrorException(f"Cannot find the icon path {extra_info['icon']}.") extra_info["icon"] = self._serialize_image_data(extra_info["icon"]) construct_tool.update(extra_info) + + # Update tool input settings + input_settings = getattr(tool_func, "__input_settings") + if input_settings: + tool_inputs = construct_tool.get("inputs", {}) + self._validate_input_settings(tool_inputs, input_settings) + for input_name, settings in input_settings.items(): + tool_inputs[input_name].update(asdict_without_none(settings)) return tool_name, construct_tool @staticmethod @@ -134,8 +143,8 @@ def _serialize_image_data(image_path): return image_url def list( - self, - flow: Union[str, PathLike] = None, + self, + flow: Union[str, PathLike] = None, ): """ List all package tools in the environment and code tools in the flow. diff --git a/src/promptflow/promptflow/_utils/tool_utils.py b/src/promptflow/promptflow/_utils/tool_utils.py index c1807d78f5d..b1b1a716944 100644 --- a/src/promptflow/promptflow/_utils/tool_utils.py +++ b/src/promptflow/promptflow/_utils/tool_utils.py @@ -107,7 +107,8 @@ def param_to_definition(param, gen_custom_type_conn=False) -> (InputDefinition, ) -def function_to_interface(f: Callable, initialize_inputs=None, gen_custom_type_conn=False) -> tuple: +def function_to_interface(f: Callable, initialize_inputs=None, gen_custom_type_conn=False, + skip_prompt_template=False) -> tuple: sign = inspect.signature(f) all_inputs = {} input_defs = {} @@ -126,6 +127,11 @@ def function_to_interface(f: Callable, initialize_inputs=None, gen_custom_type_c ) # Resolve inputs to definitions. for k, v in all_inputs.items(): + # Get value type from annotation + value_type = resolve_annotation(v.annotation) + if skip_prompt_template and value_type is PromptTemplate: + # custom llm tool has prompt template as input, skip it + continue input_def, is_connection = param_to_definition(v, gen_custom_type_conn=gen_custom_type_conn) input_defs[k] = input_def if is_connection: diff --git a/src/promptflow/promptflow/entities/__init__.py b/src/promptflow/promptflow/entities/__init__.py index 8ea3e488928..e6568ef9505 100644 --- a/src/promptflow/promptflow/entities/__init__.py +++ b/src/promptflow/promptflow/entities/__init__.py @@ -17,6 +17,7 @@ FormRecognizerConnection, ) from promptflow._sdk.entities._run import Run +from promptflow._core.tool import InputSetting, DynamicList from promptflow._sdk.entities._flow import FlowContext __all__ = [ @@ -33,6 +34,10 @@ # region Run "Run", # endregion + # region Tool + "InputSetting", + "DynamicList", + # endregion # region Flow "FlowContext", # endregion diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_tool.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_tool.py new file mode 100644 index 00000000000..1141edd1593 --- /dev/null +++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_tool.py @@ -0,0 +1,316 @@ +import importlib.util +from pathlib import Path + +import pytest +from promptflow._core.tool import tool +from promptflow.entities import DynamicList, InputSetting +from promptflow._sdk._pf_client import PFClient +from promptflow.exceptions import UserErrorException + +PROMOTFLOW_ROOT = Path(__file__) / "../../../.." +TEST_ROOT = Path(__file__).parent.parent.parent +TOOL_ROOT = TEST_ROOT / "test_configs/tools" + +_client = PFClient() + + +@pytest.mark.e2etest +class TestCli: + def get_tool_meta(self, tool_path): + module_name = f"test_tool.{Path(tool_path).stem}" + + # Load the module from the file path + spec = importlib.util.spec_from_file_location(module_name, tool_path) + module = importlib.util.module_from_spec(spec) + + # Load the module's code + spec.loader.exec_module(module) + return _client._tools.generate_tool_meta(module) + + def test_python_tool_meta(self): + tool_path = TOOL_ROOT / "python_tool.py" + tool_meta = self.get_tool_meta(tool_path) + expect_tool_meta = { + "test_tool.python_tool.PythonTool.python_tool": { + "class_name": "PythonTool", + "function": "python_tool", + "inputs": { + "connection": {"type": ["AzureOpenAIConnection"]}, + "input1": {"type": ["string"]} + }, + "module": "test_tool.python_tool", + "name": "PythonTool.python_tool", + "type": "python", + }, + "test_tool.python_tool.my_python_tool": { + "function": "my_python_tool", + "inputs": { + "input1": {"type": ["string"]} + }, + "module": "test_tool.python_tool", + "name": "python_tool", + "type": "python", + }, + "test_tool.python_tool.my_python_tool_without_name": { + "function": "my_python_tool_without_name", + "inputs": { + "input1": {"type": ["string"]} + }, + "module": "test_tool.python_tool", + "name": "my_python_tool_without_name", + "type": "python", + }, + } + assert tool_meta == expect_tool_meta + + def test_llm_tool_meta(self): + tool_path = TOOL_ROOT / "custom_llm_tool.py" + tool_meta = self.get_tool_meta(tool_path) + expect_tool_meta = { + 'test_tool.custom_llm_tool.my_tool': { + 'name': 'My Custom LLM Tool', + 'type': 'custom_llm', + 'inputs': { + 'connection': {'type': ['CustomConnection']} + }, + 'description': 'This is a tool to demonstrate the custom_llm tool type', + 'module': 'test_tool.custom_llm_tool', + 'function': 'my_tool' + }, + 'test_tool.custom_llm_tool.TestCustomLLMTool.tool_func': { + 'name': 'My Custom LLM Tool', + 'type': 'custom_llm', + 'inputs': { + 'connection': {'type': ['AzureOpenAIConnection']}, + 'api': {'type': ['string']} + }, + 'description': 'This is a tool to demonstrate the custom_llm tool type', + 'module': 'test_tool.custom_llm_tool', + 'class_name': 'TestCustomLLMTool', + 'function': 'tool_func' + } + } + assert tool_meta == expect_tool_meta + + def test_invalid_tool_type(self): + with pytest.raises(UserErrorException) as exception: + @tool(name="invalid_tool_type", type="invalid_type") + def invalid_tool_type(): + pass + + assert exception.value.message == "Tool type invalid_type is not supported yet." + + def test_tool_with_custom_connection(self): + tool_path = TOOL_ROOT / "tool_with_custom_connection.py" + tool_meta = self.get_tool_meta(tool_path) + expect_tool_meta = { + "test_tool.tool_with_custom_connection.MyTool.my_tool": { + "name": "My Second Tool", + "type": "python", + "inputs": { + "connection": {"type": ["CustomConnection"]}, + "input_text": {"type": ["string"]} + }, + "description": "This is my second tool", + "module": "test_tool.tool_with_custom_connection", + "class_name": "MyTool", + "function": "my_tool", + } + } + assert tool_meta == expect_tool_meta + + tool_path = TOOL_ROOT / "tool_with_custom_strong_type_connection.py" + tool_meta = self.get_tool_meta(tool_path) + expect_tool_meta = { + "test_tool.tool_with_custom_strong_type_connection.my_tool": { + "name": "Tool With Custom Strong Type Connection", + "type": "python", + "inputs": { + "connection": {"type": ["CustomConnection"], "custom_type": ["MyCustomConnection"]}, + "input_text": {"type": ["string"]}, + }, + "description": "This is my tool with custom strong type connection.", + "module": "test_tool.tool_with_custom_strong_type_connection", + "function": "my_tool", + } + } + assert tool_meta == expect_tool_meta + + def test_tool_with_input_settings(self): + tool_path = TOOL_ROOT / "tool_with_dynamic_list_input.py" + tool_meta = self.get_tool_meta(tool_path) + expect_tool_meta = { + "test_tool.tool_with_dynamic_list_input.my_tool": { + "name": "My Tool with Dynamic List Input", + "type": "python", + "inputs": { + "input_text": { + "type": ["list"], + "is_multi_select": True, + "allow_manual_entry": True, + "dynamic_list": { + "func_path": "test_tool.tool_with_dynamic_list_input.my_list_func", + "func_kwargs": [ + { + "name": "prefix", + "type": ["string"], + "reference": "${inputs.input_prefix}", + "optional": True, + "default": "", + }, + { + "name": "size", + "type": ["int"], + "optional": True, + "default": 10 + }, + ], + }, + }, + "input_prefix": {"type": ["string"]}, + }, + "description": "This is my tool with dynamic list input", + "module": "test_tool.tool_with_dynamic_list_input", + "function": "my_tool", + } + } + assert tool_meta == expect_tool_meta + + tool_path = TOOL_ROOT / "tool_with_enabled_by_value.py" + tool_meta = self.get_tool_meta(tool_path) + expect_tool_meta = { + "test_tool.tool_with_enabled_by_value.my_tool": { + "name": "My Tool with Enabled By Value", + "type": "python", + "inputs": { + "user_type": { + "type": [ + "string" + ], + "enum": [ + "student", + "teacher" + ] + }, + "student_id": { + "type": [ + "string" + ], + "enabled_by": "user_type", + "enabled_by_value": [ + "student" + ] + }, + "teacher_id": { + "type": [ + "string" + ], + "enabled_by": "user_type", + "enabled_by_value": [ + "teacher" + ] + } + }, + "description": "This is my tool with enabled by value", + "module": "test_tool.tool_with_enabled_by_value", + "function": "my_tool" + } + } + assert tool_meta == expect_tool_meta + + def test_dynamic_list_with_invalid_reference(self): + def my_list_func(prefix: str, size: int = 10): + pass + + # value in reference doesn't exist in tool inputs + invalid_dynamic_list_setting = DynamicList(function=my_list_func, input_mapping={"prefix": "invalid_input"}) + input_settings = { + "input_text": InputSetting( + dynamic_list=invalid_dynamic_list_setting, + allow_manual_entry=True, + is_multi_select=True + ) + } + + @tool( + name="My Tool with Dynamic List Input", + description="This is my tool with dynamic list input", + input_settings=input_settings + ) + def my_tool(input_text: list, input_prefix: str) -> str: + return f"Hello {input_prefix} {','.join(input_text)}" + + with pytest.raises(UserErrorException) as exception: + _client._tools._serialize_tool(my_tool) + assert "Cannot find invalid_input in the tool inputs." in exception.value.message + + # invalid dynamic func input + invalid_dynamic_list_setting = DynamicList( + function=my_list_func, input_mapping={"invalid_input": "input_prefix"}) + input_settings = { + "input_text": InputSetting( + dynamic_list=invalid_dynamic_list_setting, + allow_manual_entry=True, + is_multi_select=True + ) + } + + @tool( + name="My Tool with Dynamic List Input", + description="This is my tool with dynamic list input", + input_settings=input_settings + ) + def my_tool(input_text: list, input_prefix: str) -> str: + return f"Hello {input_prefix} {','.join(input_text)}" + + with pytest.raises(UserErrorException) as exception: + _client._tools._serialize_tool(my_tool) + assert "Cannot find invalid_input in the inputs of dynamic_list func" in exception.value.message + + # check required inputs of dynamic list func + invalid_dynamic_list_setting = DynamicList(function=my_list_func, input_mapping={"size": "input_prefix"}) + input_settings = { + "input_text": InputSetting(dynamic_list=invalid_dynamic_list_setting, ) + } + + @tool( + name="My Tool with Dynamic List Input", + description="This is my tool with dynamic list input", + input_settings=input_settings + ) + def my_tool(input_text: list, input_prefix: str) -> str: + return f"Hello {input_prefix} {','.join(input_text)}" + + with pytest.raises(UserErrorException) as exception: + _client._tools._serialize_tool(my_tool) + assert "Missing required input(s) of dynamic_list function: ['prefix']" in exception.value.message + + def test_enabled_by_with_invalid_input(self): + # value in enabled_by_value doesn't exist in tool inputs + input1_settings = InputSetting(enabled_by="invalid_input") + + @tool(name="enabled_by_with_invalid_input", input_settings={"input1": input1_settings}) + def enabled_by_with_invalid_input(input1: str, input2: str): + pass + + with pytest.raises(UserErrorException) as exception: + _client._tools._serialize_tool(enabled_by_with_invalid_input) + assert "Cannot find the input \"invalid_input\"" in exception.value.message + + def test_tool_with_file_path_input(self): + tool_path = TOOL_ROOT / "tool_with_file_path_input.py" + tool_meta = self.get_tool_meta(tool_path) + expect_tool_meta = { + 'test_tool.tool_with_file_path_input.my_tool': { + 'name': 'Tool with FilePath Input', + 'type': 'python', + 'inputs': { + 'input_file': {'type': ['file_path']}, + 'input_text': {'type': ['string']} + }, + 'description': 'This is a tool to demonstrate the usage of FilePath input', + 'module': 'test_tool.tool_with_file_path_input', + 'function': 'my_tool' + } + } + assert expect_tool_meta == tool_meta diff --git a/src/promptflow/tests/test_configs/tools/custom_llm_tool.py b/src/promptflow/tests/test_configs/tools/custom_llm_tool.py index 379e3081809..b82180a304f 100644 --- a/src/promptflow/tests/test_configs/tools/custom_llm_tool.py +++ b/src/promptflow/tests/test_configs/tools/custom_llm_tool.py @@ -1,3 +1,6 @@ +from jinja2 import Template +from promptflow.connections import CustomConnection + from promptflow import ToolProvider, tool from promptflow.connections import AzureOpenAIConnection from promptflow.contracts.types import PromptTemplate @@ -8,6 +11,23 @@ def __init__(self, connection: AzureOpenAIConnection): super().__init__() self.connection = connection - @tool(name="custom_llm_tool", type="custom_llm") + @tool( + name="My Custom LLM Tool", + type="custom_llm", + description="This is a tool to demonstrate the custom_llm tool type", + ) def tool_func(self, api: str, template: PromptTemplate, **kwargs): pass + + +@tool( + name="My Custom LLM Tool", + type="custom_llm", + description="This is a tool to demonstrate the custom_llm tool type", +) +def my_tool(connection: CustomConnection, prompt: PromptTemplate, **kwargs) -> str: + # Replace with your tool code, customise your own code to handle and use the prompt here. + # Usually connection contains configs to connect to an API. + # Not all tools need a connection. You can remove it if you don't need it. + rendered_prompt = Template(prompt, trim_blocks=True, keep_trailing_newline=True).render(**kwargs) + return rendered_prompt diff --git a/src/promptflow/tests/test_configs/tools/tool_with_custom_connection.py b/src/promptflow/tests/test_configs/tools/tool_with_custom_connection.py new file mode 100644 index 00000000000..b02a045e986 --- /dev/null +++ b/src/promptflow/tests/test_configs/tools/tool_with_custom_connection.py @@ -0,0 +1,20 @@ +from promptflow._core.tool import ToolProvider, tool +from promptflow.connections import CustomConnection + + +class MyTool(ToolProvider): + """ + Doc reference : + """ + + def __init__(self, connection: CustomConnection): + super().__init__() + self.connection = connection + + @tool(name="My Second Tool", description="This is my second tool") + def my_tool(self, input_text: str) -> str: + # Replace with your tool code. + # Usually connection contains configs to connect to an API. + # Use CustomConnection is a dict. You can use it like: connection.api_key, connection.api_base + # Not all tools need a connection. You can remove it if you don't need it. + return "Hello " + input_text diff --git a/src/promptflow/tests/test_configs/tools/tool_with_custom_strong_type_connection.py b/src/promptflow/tests/test_configs/tools/tool_with_custom_strong_type_connection.py new file mode 100644 index 00000000000..c60aaf79e84 --- /dev/null +++ b/src/promptflow/tests/test_configs/tools/tool_with_custom_strong_type_connection.py @@ -0,0 +1,22 @@ +from promptflow._core.tool import tool +from promptflow.connections import CustomStrongTypeConnection +from promptflow.contracts.types import Secret + + +class MyCustomConnection(CustomStrongTypeConnection): + """My custom strong type connection. + + :param api_key: The api key get from "https://xxx.com". + :type api_key: Secret + :param api_base: The api base. + :type api_base: String + """ + api_key: Secret + api_base: str = "This is a fake api base." + + +@tool(name="Tool With Custom Strong Type Connection", description="This is my tool with custom strong type connection.") +def my_tool(connection: MyCustomConnection, input_text: str) -> str: + # Replace with your tool code. + # Use custom strong type connection like: connection.api_key, connection.api_base + return "Hello " + input_text diff --git a/src/promptflow/tests/test_configs/tools/tool_with_dynamic_list_input.py b/src/promptflow/tests/test_configs/tools/tool_with_dynamic_list_input.py new file mode 100644 index 00000000000..6920f26927b --- /dev/null +++ b/src/promptflow/tests/test_configs/tools/tool_with_dynamic_list_input.py @@ -0,0 +1,51 @@ +from promptflow._core.tool import tool +from promptflow.entities import InputSetting, DynamicList +from typing import List, Union, Dict + + +def my_list_func(prefix: str = "", size: int = 10, **kwargs) -> List[Dict[str, Union[str, int, float, list, Dict]]]: + """This is a dummy function to generate a list of items. + + :param prefix: prefix to add to each item. + :param size: number of items to generate. + :param kwargs: other parameters. + :return: a list of items. Each item is a dict with the following keys: + - value: for backend use. Required. + - display_value: for UI display. Optional. + - hyperlink: external link. Optional. + - description: information icon tip. Optional. + """ + import random + + words = ["apple", "banana", "cherry", "date", "elderberry", "fig", "grape", "honeydew", "kiwi", "lemon"] + result = [] + for i in range(size): + random_word = f"{random.choice(words)}{i}" + cur_item = { + "value": random_word, + "display_value": f"{prefix}_{random_word}", + "hyperlink": f'https://www.google.com/search?q={random_word}', + "description": f"this is {i} item", + } + result.append(cur_item) + + return result + + +dynamic_list_setting = DynamicList(function=my_list_func, input_mapping={"prefix": "input_prefix"}) +input_settings = { + "input_text": InputSetting( + dynamic_list=dynamic_list_setting, + allow_manual_entry=True, + is_multi_select=True + ) +} + + +@tool( + name="My Tool with Dynamic List Input", + description="This is my tool with dynamic list input", + input_settings=input_settings +) +def my_tool(input_text: list, input_prefix: str) -> str: + return f"Hello {input_prefix} {','.join(input_text)}" diff --git a/src/promptflow/tests/test_configs/tools/tool_with_enabled_by_value.py b/src/promptflow/tests/test_configs/tools/tool_with_enabled_by_value.py new file mode 100644 index 00000000000..c74f242446d --- /dev/null +++ b/src/promptflow/tests/test_configs/tools/tool_with_enabled_by_value.py @@ -0,0 +1,35 @@ +from enum import Enum + +from promptflow.entities import InputSetting +from promptflow import tool + + +class UserType(str, Enum): + STUDENT = "student" + TEACHER = "teacher" + + +@tool( + name="My Tool with Enabled By Value", + description="This is my tool with enabled by value", + input_settings={ + "teacher_id": InputSetting(enabled_by="user_type", enabled_by_value=[UserType.TEACHER]), + "student_id": InputSetting(enabled_by="user_type", enabled_by_value=[UserType.STUDENT]), + } +) +def my_tool(user_type: UserType, student_id: str = "", teacher_id: str = "") -> str: + """This is a dummy function to support enabled by feature. + + :param user_type: user type, student or teacher. + :param student_id: student id. + :param teacher_id: teacher id. + :return: id of the user. + If user_type is student, return student_id. + If user_type is teacher, return teacher_id. + """ + if user_type == UserType.STUDENT: + return student_id + elif user_type == UserType.TEACHER: + return teacher_id + else: + raise Exception("Invalid user.") diff --git a/src/promptflow/tests/test_configs/tools/tool_with_file_path_input.py b/src/promptflow/tests/test_configs/tools/tool_with_file_path_input.py new file mode 100644 index 00000000000..538ae74a75a --- /dev/null +++ b/src/promptflow/tests/test_configs/tools/tool_with_file_path_input.py @@ -0,0 +1,12 @@ +import importlib +from pathlib import Path +from promptflow._core.tool import tool +from promptflow.contracts.types import FilePath + + +@tool(name="Tool with FilePath Input", description="This is a tool to demonstrate the usage of FilePath input") +def my_tool(input_file: FilePath, input_text: str) -> str: + # customise your own code to handle and use the input_file here + new_module = importlib.import_module(Path(input_file).stem) + + return new_module.hello(input_text)