Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix unintended parsing tool calls #3001

Merged
merged 6 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 13 additions & 29 deletions src/promptflow-tools/promptflow/tools/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from promptflow.exceptions import SystemErrorException, UserErrorException
from promptflow.tools.exception import (
ToolValidationError,
ChatAPIAssistantRoleInvalidFormat,
ChatAPIFunctionRoleInvalidFormat,
ChatAPIToolRoleInvalidFormat,
ChatAPIInvalidFunctions,
Expand Down Expand Up @@ -242,38 +241,21 @@ def try_parse_tool_call_id_and_content(role_prompt):
def try_parse_tool_calls(role_prompt):
# customer can add ## in front of tool_calls for markdown highlight.
# and we still support tool_calls without ## prefix for backward compatibility.
pattern = r"\n*#{0,2}\s*tool_calls:\n*\s*(\[.*?\])"
pattern = r"\n*#{0,2}\s*tool_calls\s*:\s*\n+\s*(\[.*?\])"
match = re.search(pattern, role_prompt, re.DOTALL)
if match:
return match.group(1)
try:
parsed_array = eval(match.group(1))
return parsed_array
except Exception:
None
return None


def is_tools_chunk(last_message):
def is_tool_chunk(last_message):
return last_message and "role" in last_message and last_message["role"] == "tool" and "content" not in last_message


def is_assistant_tool_calls_chunk(last_message, chunk):
return last_message and "role" in last_message and last_message["role"] == "assistant" and "tool_calls" in chunk


def parse_tool_calls_for_assistant(last_message, chunk):
parsed_result = try_parse_tool_calls(chunk)
error_msg = "Failed to parse assistant role prompt with tool_calls. Please make sure the prompt follows the format:"
" 'tool_calls:\\n[{ id: tool_call_id, type: tool_type, function: {name: function_name, arguments: function_args }]'"
"See more details in https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages"

if parsed_result is None:
raise ChatAPIAssistantRoleInvalidFormat(message=error_msg)
else:
parsed_array = None
try:
parsed_array = eval(parsed_result)
last_message["tool_calls"] = parsed_array
except Exception:
raise ChatAPIAssistantRoleInvalidFormat(message=error_msg)


def parse_tools(last_message, chunk, hash2images, image_detail):
parsed_result = try_parse_tool_call_id_and_content(chunk)
if parsed_result is None:
Expand Down Expand Up @@ -311,13 +293,15 @@ def parse_chat(

for chunk in chunks:
last_message = chat_list[-1] if len(chat_list) > 0 else None
if is_tools_chunk(last_message):
if is_tool_chunk(last_message):
parse_tools(last_message, chunk, hash2images, image_detail)
continue

if is_assistant_tool_calls_chunk(last_message, chunk):
parse_tool_calls_for_assistant(last_message, chunk)
continue
if last_message and "role" in last_message and last_message["role"] == "assistant":
parsed_result = try_parse_tool_calls(chunk)
if parsed_result is not None:
last_message["tool_calls"] = parsed_result
continue

if (
last_message
Expand Down
5 changes: 0 additions & 5 deletions src/promptflow-tools/promptflow/tools/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,6 @@ class ChatAPIToolRoleInvalidFormat(ToolValidationError):
pass


class ChatAPIAssistantRoleInvalidFormat(ToolValidationError):
"""Base exception raised when failed to validate chat api assistant role format."""
pass


class ChatAPIInvalidFunctions(ToolValidationError):
"""Base exception raised when failed to validate functions when call chat api."""
pass
Expand Down
78 changes: 17 additions & 61 deletions src/promptflow-tools/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
from promptflow.tools.common import ChatAPIInvalidFunctions, validate_functions, process_function_call, \
parse_chat, find_referenced_image_set, preprocess_template_string, convert_to_chat_list, ChatInputList, \
ParseConnectionError, _parse_resource_id, list_deployment_connections, normalize_connection_config, \
parse_tool_calls_for_assistant, validate_tools, process_tool_choice, init_azure_openai_client, \
validate_tools, process_tool_choice, init_azure_openai_client, try_parse_tool_calls, \
Escaper, PromptResult, render_jinja_template, build_messages
from promptflow.tools.exception import (
ListDeploymentsError,
ChatAPIInvalidTools,
ChatAPIAssistantRoleInvalidFormat,
ChatAPIToolRoleInvalidFormat,
)

Expand Down Expand Up @@ -224,10 +223,6 @@ def test_parse_chat_with_name_in_role_prompt(self, chat_str, expected_result):
@pytest.mark.parametrize(
"chat_str, error_message, exception_type",
[("""
# assistant:
## tool_calls:
""", "Failed to parse assistant role prompt with tool_calls", ChatAPIAssistantRoleInvalidFormat),
("""
# tool:
## tool_call_id:
""", "Failed to parse tool role prompt.", ChatAPIToolRoleInvalidFormat,)])
Expand All @@ -240,61 +235,22 @@ def test_try_parse_chat_with_tools(self, example_prompt_template_with_tool, pars
actual_result = parse_chat(example_prompt_template_with_tool)
assert actual_result == parsed_chat_with_tools

@pytest.mark.parametrize("chunk, error_msg, success", [
("""
## tool_calls:
""", "Failed to parse assistant role prompt with tool_calls", False),
("""
## tool_calls:
tool_calls_str
""", "Failed to parse assistant role prompt with tool_calls", False),
("""
## tool_calls:
[{"id": "tool_call_id", "type": "function", "function": {"name": "func1", "arguments": ""}}]
""", "", True),
("""
## tool_calls:

[{"id": "tool_call_id", "type": "function", "function": {"name": "func1", "arguments": ""}}]
""", "", True),
("""
## tool_calls:[{"id": "tool_call_id", "type": "function", "function": {"name": "func1", "arguments": ""}}]
""", "", True),
("""
## tool_calls:
[{
"id": "tool_call_id",
"type": "function",
"function": {"name": "func1", "arguments": ""}
}]
""", "", True),
("""
## tool_calls:
[{
"id": "tool_call_id", "type": "function",
"function": {"name": "func1", "arguments": ""}
}]
""", "", True),
])
def test_parse_tool_calls_for_assistant(self, chunk: str, error_msg: str, success: bool):
last_message = {'role': 'assistant'}
if success:
expected_res = [
{
"id": "tool_call_id",
"type": "function",
"function": {
"name": "func1",
"arguments": "",
},
}
]
parse_tool_calls_for_assistant(last_message, chunk)
assert last_message["tool_calls"] == expected_res
else:
with pytest.raises(ChatAPIAssistantRoleInvalidFormat) as exc_info:
parse_tool_calls_for_assistant(last_message, chunk)
assert error_msg in exc_info.value.message
@pytest.mark.parametrize(
"role_prompt, expected_result",
[("## tool_calls:\n[]", []),
("## tool_calls:\r\n[]", []),
("## tool_calls: \n[]", []),
("## tool_calls :\r\n[]", []),
("tool_calls:\r\n[]", []),
("some text", None),
("tool_calls:\r\n[", None),
("tool_calls:\r\n[{'id': 'tool_call_id', 'type': 'function', 'function': {'name': 'func1', 'arguments': ''}}]",
[{'id': 'tool_call_id', 'type': 'function', 'function': {'name': 'func1', 'arguments': ''}}]),
("tool_calls:\n[{'id': 'tool_call_id', 'type': 'function', 'function': {'name': 'func1', 'arguments': ''}}]",
[{'id': 'tool_call_id', 'type': 'function', 'function': {'name': 'func1', 'arguments': ''}}])])
def test_try_parse_tool_calls(self, role_prompt, expected_result):
actual = try_parse_tool_calls(role_prompt)
assert actual == expected_result

@pytest.mark.parametrize(
"kwargs, expected_result",
Expand Down
Loading