Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pf tool] Support tool init with advance info, such as icon, tags and category #838

Merged
merged 10 commits into from
Oct 24, 2023
10 changes: 10 additions & 0 deletions src/promptflow/promptflow/_cli/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ def add_param_columns_mapping(parser):
)


def add_param_set_tool_extra_info(parser):
parser.add_argument(
"--set",
dest="extra_info",
action=AppendToDictAction,
help="Set extra information about the tool. Example: --set <key>=<value>.",
nargs="+",
)


def add_param_inputs(parser):
parser.add_argument(
"--inputs",
Expand Down
33 changes: 31 additions & 2 deletions src/promptflow/promptflow/_cli/_pf/_init_entry_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import shutil
from abc import ABC, abstractmethod
from ast import literal_eval
from enum import Enum
from pathlib import Path

Expand Down Expand Up @@ -303,16 +304,44 @@ def copy_extra_files(flow_path, extra_files):


class ToolPackageGenerator(BaseGenerator):
def __init__(self, tool_name):
def __init__(self, tool_name, icon=None, extra_info=None):
self.tool_name = tool_name
self._extra_info = extra_info
self.icon = icon

@property
def extra_info(self):
if self._extra_info:
extra_info = {}
for k, v in self._extra_info.items():
try:
extra_info[k] = literal_eval(v)
except Exception:
extra_info[k] = repr(v)
return extra_info
else:
return {}

@property
def tpl_file(self):
return TOOL_TEMPLATE_PATH / "tool.py.jinja2"

@property
def entry_template_keys(self):
return ["tool_name"]
return ["tool_name", "extra_info", "icon"]


class ManifestGenerator(BaseGenerator):
def __init__(self, package_name):
self.package_name = package_name

@property
def tpl_file(self):
return TOOL_TEMPLATE_PATH / "MANIFEST.in.jinja2"

@property
def entry_template_keys(self):
return ["package_name"]


class SetupGenerator(BaseGenerator):
Expand Down
31 changes: 24 additions & 7 deletions src/promptflow/promptflow/_cli/_pf/_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@
# ---------------------------------------------------------

import argparse
import json
import logging
import re
import json
import shutil
from pathlib import Path

from promptflow._cli._params import logging_params
from promptflow._cli._params import add_param_set_tool_extra_info, logging_params
from promptflow._cli._pf._init_entry_generators import (
InitGenerator,
ManifestGenerator,
SetupGenerator,
ToolPackageGenerator,
ToolPackageUtilsGenerator,
ToolReadmeGenerator,
)
from promptflow._cli._utils import activate_action, exception_handler
from promptflow._cli._utils import activate_action, exception_handler, list_of_dict_to_dict
from promptflow._sdk._constants import LOGGER_NAME
from promptflow._sdk._pf_client import PFClient
from promptflow.exceptions import UserErrorException
Expand Down Expand Up @@ -44,6 +46,8 @@ def add_parser_init_tool(subparsers):

# Creating a package tool from scratch:
pf tool init --package package_tool --tool tool_name
# Creating a package tool with extra info:
pf tool init --package package_tool --tool tool_name --set icon=<icon-path> category=<category>
# Creating a python tool from scratch:
pf tool init --tool tool_name
""" # noqa: E501
Expand All @@ -56,6 +60,7 @@ def add_parser_init_tool(subparsers):
add_params = [
add_param_package,
add_param_tool,
add_param_set_tool_extra_info,
] + logging_params
return activate_action(
name="init",
Expand All @@ -78,9 +83,7 @@ def add_parser_list_tool(subparsers):
# List all package tool and code tool in the flow:
pf tool list --flow flow-path
""" # noqa: E501
add_param_flow = lambda parser: parser.add_argument( # noqa: E731
"--flow", type=str, help="the flow directory"
)
add_param_flow = lambda parser: parser.add_argument("--flow", type=str, help="the flow directory") # noqa: E731
add_params = [
add_param_flow,
] + logging_params
Expand Down Expand Up @@ -111,20 +114,34 @@ def init_tool(args):
if not re.match(pattern, args.tool):
raise UserErrorException(f"The tool name {args.tool} is a invalid identifier.")
print("Creating tool from scratch...")
extra_info = list_of_dict_to_dict(args.extra_info)
icon_path = extra_info.pop("icon", None)
if icon_path:
if not Path(icon_path).exists():
raise UserErrorException(f"Cannot find the icon path {icon_path}.")
if args.package:
package_path = Path(args.package)
package_name = package_path.stem
script_code_path = package_path / package_name
script_code_path.mkdir(parents=True, exist_ok=True)
if icon_path:
package_icon_path = package_path / "icon"
package_icon_path.mkdir(exist_ok=True)
dst = shutil.copy2(icon_path, package_icon_path)
icon_path = f'Path(__file__).parent.parent / "icon" / "{Path(dst).name}"'
# Generate package setup.py
SetupGenerator(package_name=package_name, tool_name=args.tool).generate_to_file(package_path / "setup.py")
# Generate manifest file
ManifestGenerator(package_name=package_name).generate_to_file(package_path / "MANIFEST.in")
# Generate utils.py to list meta data of tools.
ToolPackageUtilsGenerator(package_name=package_name).generate_to_file(script_code_path / "utils.py")
ToolReadmeGenerator(package_name=package_name, tool_name=args.tool).generate_to_file(package_path / "README.md")
else:
script_code_path = Path(".")
# Generate tool script
ToolPackageGenerator(tool_name=args.tool).generate_to_file(script_code_path / f"{args.tool}.py")
ToolPackageGenerator(tool_name=args.tool, icon=icon_path, extra_info=extra_info).generate_to_file(
script_code_path / f"{args.tool}.py"
)
InitGenerator().generate_to_file(script_code_path / "__init__.py")
print(f'Done. Created the tool "{args.tool}" in {script_code_path.resolve()}.')

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include {{ package_name }}/icons
wangchao1230 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
{% if icon %}
from pathlib import Path

{% endif %}
from promptflow import tool
from promptflow.connections import CustomConnection


@tool(name="{{ tool_name }}", description="This is {{ tool_name }} tool")
@tool(
name="{{ tool_name }}",
description="This is {{ tool_name }} tool",
{% if icon %}
icon={{ icon }},
{% endif %}
{% for key, value in extra_info.items() %}
{{ key }}={{ value }},
{% endfor %}
)
def {{ tool_name }}(connection: CustomConnection, input_text: str) -> str:
# Replace with your tool code.
# Usually connection contains configs to connect to an API.
Expand Down
2 changes: 2 additions & 0 deletions src/promptflow/promptflow/_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def tool(
name: str = None,
description: str = None,
type: str = None,
**kwargs,
) -> Callable:
"""Decorator for tool functions. The decorated function will be registered as a tool and can be used in a flow.

Expand Down Expand Up @@ -73,6 +74,7 @@ def new_f(*args, **kwargs):
new_f.__name = name
new_f.__description = description
new_f.__type = type
new_f.__extra_info = kwargs
return new_f

# enable use decorator without "()" if all arguments are default values
Expand Down
69 changes: 59 additions & 10 deletions src/promptflow/promptflow/_sdk/operations/_tool_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import inspect
import io
import json
from dataclasses import asdict
from os import PathLike
from pathlib import Path
from typing import Union

from promptflow._core.tool_meta_generator import is_tool
from promptflow._core.tools_manager import collect_package_tools
from promptflow._utils.multimedia_utils import convert_multimedia_data_to_base64
from promptflow._utils.tool_utils import function_to_interface
from promptflow.contracts.multimedia import Image
from promptflow.contracts.tool import Tool, ToolType
from promptflow.exceptions import UserErrorException

Expand All @@ -20,15 +24,13 @@ class ToolOperations:
def generate_tool_meta(self, tool_module):
tool_functions = self._collect_tool_functions_in_module(tool_module)
tool_methods = self._collect_tool_class_methods_in_module(tool_module)
tools = [self._parse_tool_from_function(f) for f in tool_functions] + [
self._parse_tool_from_function(f, initialize_inputs) for (f, initialize_inputs) in tool_methods
]
construct_tools = {
f"{t.module}.{t.class_name}.{t.function}"
if t.class_name is not None
else f"{t.module}.{t.function}": asdict(t, dict_factory=lambda x: {k: v for (k, v) in x if v})
for t in tools
}
construct_tools = {}
for f in tool_functions:
tool_name, construct_tool = self._serialize_tool(f)
construct_tools[tool_name] = construct_tool
for (f, initialize_inputs) in tool_methods:
tool_name, construct_tool = self._serialize_tool(f, initialize_inputs)
construct_tools[tool_name] = construct_tool
# The generated dict cannot be dumped as yaml directly since yaml cannot handle string enum.
return json.loads(json.dumps(construct_tools))

Expand Down Expand Up @@ -62,6 +64,7 @@ def _parse_tool_from_function(f, initialize_inputs=None):
tool_type = getattr(f, "__type") or ToolType.PYTHON
tool_name = getattr(f, "__name")
description = getattr(f, "__description")
extra_info = getattr(f, "__extra_info")
if getattr(f, "__tool", None) and isinstance(f.__tool, Tool):
return getattr(f, "__tool")
if hasattr(f, "__original_function"):
Expand All @@ -74,7 +77,7 @@ def _parse_tool_from_function(f, initialize_inputs=None):
if "." in f.__qualname__:
class_name = f.__qualname__.replace(f".{f.__name__}", "")
# Construct the Tool structure
return Tool(
tool = Tool(
name=tool_name or f.__qualname__,
description=description or inspect.getdoc(f),
inputs=inputs,
Expand All @@ -83,6 +86,52 @@ def _parse_tool_from_function(f, initialize_inputs=None):
function=f.__name__,
module=f.__module__,
)
return tool, extra_info

def _serialize_tool(self, tool_func, initialize_inputs=None):
"""
Serialize tool obj to dict.

:param tool_func: Package tool function
:type tool_func: callable
:param initialize_inputs: Initialize inputs of package tool
:type initialize_inputs: Dict[str, obj]
:return: package tool name, serialized tool
:rtype: str, Dict[str, str]
"""
tool, extra_info = self._parse_tool_from_function(tool_func, initialize_inputs)
tool_name = (
f"{tool.module}.{tool.class_name}.{tool.function}"
if tool.class_name is not None
else f"{tool.module}.{tool.function}"
)
construct_tool = asdict(tool, dict_factory=lambda x: {k: v for (k, v) in x if v})
if extra_info:
if "icon" in extra_info:
if not Path(extra_info["icon"]).exists():
raise UserErrorException(f"Cannot find the icon path {extra_info['icon']}.")
extra_info["icon"] = self._serialize_image_data(extra_info["icon"])
construct_tool.update(extra_info)
return tool_name, construct_tool

@staticmethod
def _serialize_image_data(image_path):
"""Serialize image to base64."""
from PIL import Image as PIL_Image

with open(image_path, "rb") as image_file:
# Create a BytesIO object from the image file
image_data = io.BytesIO(image_file.read())

# Open the image and resize it
img = PIL_Image.open(image_data)
if img.size != (16, 16):
img = img.resize((16, 16), PIL_Image.Resampling.LANCZOS)
buffered = io.BytesIO()
img.save(buffered, format="PNG")
icon_image = Image(buffered.getvalue(), mime_type="image/png")
image_url = convert_multimedia_data_to_base64(icon_image, with_type=True)
return image_url

def list(
self,
Expand Down
1 change: 1 addition & 0 deletions src/promptflow/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"opencensus-ext-azure<2.0.0", # configure opencensus to send telemetry to azure monitor
"ruamel.yaml>=0.17.35,<0.18.0", # used to generate connection templates with preserved comments
"pyarrow>=13.0.0,<14.0.0", # used to read parquet file with pandas.read_parquet
"pillow>=10.1.0,<11.0.0", # used to generate icon data URI for package tool
]

setup(
Expand Down
47 changes: 47 additions & 0 deletions src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,6 +1345,53 @@ def test_tool_init(self, capsys):
outerr = capsys.readouterr()
assert f"The tool name {invalid_tool_name} is a invalid identifier." in outerr.out

# Test init package tool with extra info
package_name = "tool_with_extra_info"
package_folder = Path(temp_dir) / package_name
icon_path = Path(DATAS_DIR) / "logo.jpg"
category = "test_category"
tags = {'tag1': 'value1', 'tag2': 'value2'}
run_pf_command(
"tool",
"init",
"--package",
package_name,
"--tool",
func_name,
"--set",
f"icon={icon_path.absolute()}",
f"category={category}",
f"tags={tags}",
cwd=temp_dir
)
spec = importlib.util.spec_from_file_location(
f"{package_name}.utils", package_folder / package_name / "utils.py")
utils = importlib.util.module_from_spec(spec)
spec.loader.exec_module(utils)

assert hasattr(utils, "list_package_tools")
tools_meta = utils.list_package_tools()
meta = tools_meta[f"{package_name}.{func_name}.{func_name}"]
assert meta["category"] == category
assert meta["tags"] == tags
assert meta["icon"].startswith("data:image")

# icon doesn't exist
with pytest.raises(SystemExit):
run_pf_command(
"tool",
"init",
"--package",
package_name,
"--tool",
func_name,
"--set",
"icon=invalid_icon_path",
cwd=temp_dir
)
outerr = capsys.readouterr()
assert "Cannot find the icon path" in outerr.out

def test_tool_list(self, capsys):
# List package tools in environment
run_pf_command("tool", "list")
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading