diff --git a/src/promptflow-tools/promptflow/tools/common.py b/src/promptflow-tools/promptflow/tools/common.py index 1737c114328..4fd95a08a6e 100644 --- a/src/promptflow-tools/promptflow/tools/common.py +++ b/src/promptflow-tools/promptflow/tools/common.py @@ -16,7 +16,6 @@ from promptflow.exceptions import SystemErrorException, UserErrorException from promptflow.tools.exception import ( ToolValidationError, - ChatAPIAssistantRoleInvalidFormat, ChatAPIFunctionRoleInvalidFormat, ChatAPIToolRoleInvalidFormat, ChatAPIInvalidFunctions, @@ -242,38 +241,21 @@ def try_parse_tool_call_id_and_content(role_prompt): def try_parse_tool_calls(role_prompt): # customer can add ## in front of tool_calls for markdown highlight. # and we still support tool_calls without ## prefix for backward compatibility. - pattern = r"\n*#{0,2}\s*tool_calls:\n*\s*(\[.*?\])" + pattern = r"\n*#{0,2}\s*tool_calls\s*:\s*\n+\s*(\[.*?\])" match = re.search(pattern, role_prompt, re.DOTALL) if match: - return match.group(1) + try: + parsed_array = eval(match.group(1)) + return parsed_array + except Exception: + None return None -def is_tools_chunk(last_message): +def is_tool_chunk(last_message): return last_message and "role" in last_message and last_message["role"] == "tool" and "content" not in last_message -def is_assistant_tool_calls_chunk(last_message, chunk): - return last_message and "role" in last_message and last_message["role"] == "assistant" and "tool_calls" in chunk - - -def parse_tool_calls_for_assistant(last_message, chunk): - parsed_result = try_parse_tool_calls(chunk) - error_msg = "Failed to parse assistant role prompt with tool_calls. Please make sure the prompt follows the format:" - " 'tool_calls:\\n[{ id: tool_call_id, type: tool_type, function: {name: function_name, arguments: function_args }]'" - "See more details in https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages" - - if parsed_result is None: - raise ChatAPIAssistantRoleInvalidFormat(message=error_msg) - else: - parsed_array = None - try: - parsed_array = eval(parsed_result) - last_message["tool_calls"] = parsed_array - except Exception: - raise ChatAPIAssistantRoleInvalidFormat(message=error_msg) - - def parse_tools(last_message, chunk, hash2images, image_detail): parsed_result = try_parse_tool_call_id_and_content(chunk) if parsed_result is None: @@ -311,13 +293,15 @@ def parse_chat( for chunk in chunks: last_message = chat_list[-1] if len(chat_list) > 0 else None - if is_tools_chunk(last_message): + if is_tool_chunk(last_message): parse_tools(last_message, chunk, hash2images, image_detail) continue - if is_assistant_tool_calls_chunk(last_message, chunk): - parse_tool_calls_for_assistant(last_message, chunk) - continue + if last_message and "role" in last_message and last_message["role"] == "assistant": + parsed_result = try_parse_tool_calls(chunk) + if parsed_result is not None: + last_message["tool_calls"] = parsed_result + continue if ( last_message diff --git a/src/promptflow-tools/promptflow/tools/exception.py b/src/promptflow-tools/promptflow/tools/exception.py index c3e00fad8da..199fa909398 100644 --- a/src/promptflow-tools/promptflow/tools/exception.py +++ b/src/promptflow-tools/promptflow/tools/exception.py @@ -140,11 +140,6 @@ class ChatAPIToolRoleInvalidFormat(ToolValidationError): pass -class ChatAPIAssistantRoleInvalidFormat(ToolValidationError): - """Base exception raised when failed to validate chat api assistant role format.""" - pass - - class ChatAPIInvalidFunctions(ToolValidationError): """Base exception raised when failed to validate functions when call chat api.""" pass diff --git a/src/promptflow-tools/tests/test_common.py b/src/promptflow-tools/tests/test_common.py index 97373babdaa..7fd1a9b2171 100644 --- a/src/promptflow-tools/tests/test_common.py +++ b/src/promptflow-tools/tests/test_common.py @@ -6,12 +6,11 @@ from promptflow.tools.common import ChatAPIInvalidFunctions, validate_functions, process_function_call, \ parse_chat, find_referenced_image_set, preprocess_template_string, convert_to_chat_list, ChatInputList, \ ParseConnectionError, _parse_resource_id, list_deployment_connections, normalize_connection_config, \ - parse_tool_calls_for_assistant, validate_tools, process_tool_choice, init_azure_openai_client, \ + validate_tools, process_tool_choice, init_azure_openai_client, try_parse_tool_calls, \ Escaper, PromptResult, render_jinja_template, build_messages from promptflow.tools.exception import ( ListDeploymentsError, ChatAPIInvalidTools, - ChatAPIAssistantRoleInvalidFormat, ChatAPIToolRoleInvalidFormat, ) @@ -227,10 +226,6 @@ def test_parse_chat_with_name_in_role_prompt(self, chat_str, expected_result): @pytest.mark.parametrize( "chat_str, error_message, exception_type", [(""" - # assistant: - ## tool_calls: - """, "Failed to parse assistant role prompt with tool_calls", ChatAPIAssistantRoleInvalidFormat), - (""" # tool: ## tool_call_id: """, "Failed to parse tool role prompt.", ChatAPIToolRoleInvalidFormat,)]) @@ -243,6 +238,23 @@ def test_try_parse_chat_with_tools(self, example_prompt_template_with_tool, pars actual_result = parse_chat(example_prompt_template_with_tool) assert actual_result == parsed_chat_with_tools + @pytest.mark.parametrize( + "role_prompt, expected_result", + [("## tool_calls:\n[]", []), + ("## tool_calls:\r\n[]", []), + ("## tool_calls: \n[]", []), + ("## tool_calls :\r\n[]", []), + ("tool_calls:\r\n[]", []), + ("some text", None), + ("tool_calls:\r\n[", None), + ("tool_calls:\r\n[{'id': 'tool_call_id', 'type': 'function', 'function': {'name': 'func1', 'arguments': ''}}]", + [{'id': 'tool_call_id', 'type': 'function', 'function': {'name': 'func1', 'arguments': ''}}]), + ("tool_calls:\n[{'id': 'tool_call_id', 'type': 'function', 'function': {'name': 'func1', 'arguments': ''}}]", + [{'id': 'tool_call_id', 'type': 'function', 'function': {'name': 'func1', 'arguments': ''}}])]) + def test_try_parse_tool_calls(self, role_prompt, expected_result): + actual = try_parse_tool_calls(role_prompt) + assert actual == expected_result + @pytest.mark.parametrize( "chat_str, expected_result", [ @@ -257,70 +269,6 @@ def test_parse_tool_call_id_and_content(self, chat_str, expected_result): actual_result = parse_chat(chat_str) assert actual_result == expected_result - @pytest.mark.parametrize("chunk, error_msg, success", [ - (""" - ## tool_calls: - """, "Failed to parse assistant role prompt with tool_calls", False), - (""" - ## tool_calls: - tool_calls_str - """, "Failed to parse assistant role prompt with tool_calls", False), - (""" - ## tool_calls: - [{"id": "tool_call_id", "type": "function", "function": {"name": "func1", "arguments": ""}}] - """, "", True), - (""" - ## tool_calls: - - [{"id": "tool_call_id", "type": "function", "function": {"name": "func1", "arguments": ""}}] - """, "", True), - (""" - ## tool_calls:[{"id": "tool_call_id", "type": "function", "function": {"name": "func1", "arguments": ""}}] - """, "", True), - (""" - ## tool_calls: - [{ - "id": "tool_call_id", - "type": "function", - "function": {"name": "func1", "arguments": ""} - }] - """, "", True), - (""" - ## tool_calls: - [{ - "id": "tool_call_id", "type": "function", - "function": {"name": "func1", "arguments": ""} - }] - """, "", True), - # portal may add extra \r to new line character. - (""" - ## tool_calls:\r - [{ - "id": "tool_call_id", "type": "function", - "function": {"name": "func1", "arguments": ""} - }] - """, "", True), - ]) - def test_parse_tool_calls_for_assistant(self, chunk: str, error_msg: str, success: bool): - last_message = {'role': 'assistant'} - if success: - expected_res = [ - { - "id": "tool_call_id", - "type": "function", - "function": { - "name": "func1", - "arguments": "", - }, - } - ] - parse_tool_calls_for_assistant(last_message, chunk) - assert last_message["tool_calls"] == expected_res - else: - with pytest.raises(ChatAPIAssistantRoleInvalidFormat) as exc_info: - parse_tool_calls_for_assistant(last_message, chunk) - assert error_msg in exc_info.value.message - @pytest.mark.parametrize( "kwargs, expected_result", [