Change function call type to object (#815)
# Description

**Issue:**
The `function_call` parameter can be either a `str` or a `dict`. However, because the chat API declared it as `str`, passing the parameter was unsafe: converting a `dict` input to `str` while passing it along could raise exceptions.

**Solution:**
The type of `function_call` was changed from `str` to `object` to make the parameter safe to pass. The `object` type accepts both string values (such as `"auto"`) and `dict` values, reducing the risk of exceptions.
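As an illustrative sketch (simplified, not the actual repository code), the widened parameter now accepts either form without any lossy conversion:

```python
# Toy stand-in for the promptflow chat tool, showing why `object` is safer
# than `str` for function_call: both a mode string and a dict pass through
# unchanged. The function name and payload shape are assumptions for
# illustration only.
def chat(prompt: str, function_call: object = None, **kwargs) -> dict:
    payload = {"prompt": prompt}
    if function_call is not None:
        # No str() conversion is attempted; the value is forwarded as-is,
        # so both "auto" and {"name": "get_current_weather"} are safe.
        payload["function_call"] = function_call
    return payload

print(chat("hi", function_call="auto"))
print(chat("hi", function_call={"name": "get_current_weather"}))
```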

**Tests:**
Unit and end-to-end tests were added to verify the change. All tests pass, indicating the fix is effective and introduces no regressions.

# All Promptflow Contribution checklist:
- [X] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [X] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [X] Title of the pull request is clear and informative.
- [X] There are a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [x] Pull request includes test coverage for the included changes.

---------

Co-authored-by: cs_lucky <[email protected]>
chenslucky and cs_lucky authored Oct 23, 2023
1 parent 6df5507 commit 4f439c8
Showing 9 changed files with 195 additions and 24 deletions.
5 changes: 3 additions & 2 deletions src/promptflow-tools/promptflow/tools/aoai.py
@@ -112,7 +112,8 @@ def chat(
frequency_penalty: float = 0,
logit_bias: dict = {},
user: str = "",
function_call: str = None,
# function_call can be of type str or dict.
function_call: object = None,
functions: list = None,
**kwargs,
) -> [str, dict]:
@@ -218,7 +219,7 @@ def chat(
frequency_penalty: float = 0,
logit_bias: dict = {},
user: str = "",
function_call: str = None,
function_call: object = None,
functions: list = None,
**kwargs,
) -> str:
17 changes: 4 additions & 13 deletions src/promptflow-tools/promptflow/tools/common.py
@@ -218,24 +218,15 @@ def process_function_call(function_call):
common_tsg = f"Here is a valid example: {function_call_example}. See the guide at " \
"https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call " \
"or view sample 'How to call functions with chat models' in our gallery."
try:
param = json.loads(function_call)
except json.JSONDecodeError:
raise ChatAPIInvalidFunctions(
message=f"function_call parameter '{function_call}' is an invalid json. {common_tsg}")
except TypeError:
raise ChatAPIInvalidFunctions(
message=f"function_call parameter '{function_call}' must be str, bytes or bytearray"
f", but not {type(function_call)}. {common_tsg}"
)
param = function_call
if not isinstance(param, dict):
raise ChatAPIInvalidFunctions(
message=f"function_call parameter '{function_call}' must be a dict, but not {type(param)}. {common_tsg}"
message=f"function_call parameter '{param}' must be a dict, but not {type(function_call)}. {common_tsg}"
)
else:
if "name" not in param:
if "name" not in function_call:
raise ChatAPIInvalidFunctions(
message=f'function_call parameter {function_call} must contain "name" field. {common_tsg}'
message=f'function_call parameter {json.dumps(param)} must contain "name" field. {common_tsg}'
)
return param

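Under the new contract, `process_function_call` validates an already-parsed `dict` instead of JSON-decoding a string. A minimal sketch of the validation shown in the hunk, using a stand-in exception class (string modes such as `"auto"` are assumed to be handled earlier in the real function, outside this hunk):

```python
# Stand-in for promptflow's real error type; the class body here is an
# assumption for illustration.
class ChatAPIInvalidFunctions(Exception):
    pass


def process_function_call(function_call):
    # Sketch of the dict validation from the diff: the value is no longer
    # json.loads()-ed, it is checked directly.
    param = function_call
    if not isinstance(param, dict):
        raise ChatAPIInvalidFunctions(
            f"function_call parameter '{param}' must be a dict, "
            f"but not {type(function_call)}.")
    if "name" not in function_call:
        raise ChatAPIInvalidFunctions(
            f'function_call parameter {param} must contain "name" field.')
    return param


print(process_function_call({"name": "get_current_weather"}))
```

This matches the updated tests below: a well-formed `dict` passes through, while a non-dict or a `dict` missing the `"name"` field raises `ChatAPIInvalidFunctions`.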
5 changes: 3 additions & 2 deletions src/promptflow-tools/promptflow/tools/openai.py
@@ -107,7 +107,8 @@ def chat(
frequency_penalty: float = 0,
logit_bias: dict = {},
user: str = "",
function_call: str = None,
# function_call can be of type str or dict.
function_call: object = None,
functions: list = None,
**kwargs,
) -> [str, dict]:
@@ -211,7 +212,7 @@ def chat(
frequency_penalty: float = 0,
logit_bias: dict = {},
user: str = "",
function_call: str = None,
function_call: object = None,
functions: list = None,
**kwargs,
) -> [str, dict]:
11 changes: 9 additions & 2 deletions src/promptflow-tools/tests/test_aoai.py
@@ -48,8 +48,15 @@ def test_aoai_chat_api(self, azure_open_ai_connection, example_prompt_template,
)
assert "Product X".lower() in result.lower()

@pytest.mark.parametrize(
"function_call",
[
"auto",
{"name": "get_current_weather"},
],
)
def test_aoai_chat_with_function(
self, azure_open_ai_connection, example_prompt_template, chat_history, functions):
self, azure_open_ai_connection, example_prompt_template, chat_history, functions, function_call):
result = chat(
connection=azure_open_ai_connection,
prompt=example_prompt_template,
@@ -59,7 +66,7 @@ def test_aoai_chat_with_function(
user_input="What is the weather in Boston?",
chat_history=chat_history,
functions=functions,
function_call="auto"
function_call=function_call
)
assert "function_call" in result
assert result["function_call"]["name"] == "get_current_weather"
8 changes: 3 additions & 5 deletions src/promptflow-tools/tests/test_common.py
@@ -30,12 +30,10 @@ def test_chat_api_invalid_functions(self, functions, error_message):
@pytest.mark.parametrize(
"function_call, error_message",
[
({"name": "get_current_weather"}, "must be str, bytes or bytearray"),
("{'name': 'get_current_weather'}", "is an invalid json"),
("get_current_weather", "is an invalid json"),
("123", "function_call parameter '123' must be a dict"),
('{"name1": "get_current_weather"}', 'function_call parameter {"name1": "get_current_weather"} must '
'contain "name" field'),
({"name1": "get_current_weather"},
'function_call parameter {"name1": "get_current_weather"} must '
'contain "name" field'),
],
)
def test_chat_api_invalid_function_call(self, function_call, error_message):
@@ -105,6 +105,7 @@ def test_executor_storage(self, dev_connections):
"prompt_tools",
"script_with___file__",
"connection_as_input",
"sample_flow_with_functions"
],
)
def test_executor_exec_bulk(self, flow_folder, dev_connections):
@@ -0,0 +1,104 @@
id: use_functions_with_chat_models
name: Use Functions with Chat Models
inputs:
chat_history:
type: list
default:
- inputs:
question: What is the weather like in Boston?
outputs:
answer: '{"forecast":["sunny","windy"],"location":"Boston","temperature":"72","unit":"fahrenheit"}'
llm_output:
content: null
function_call:
name: get_current_weather
arguments: |-
{
"location": "Boston"
}
role: assistant
is_chat_input: false
question:
type: string
default: How about London next week?
is_chat_input: true
outputs:
answer:
type: string
reference: ${run_function.output}
is_chat_output: true
llm_output:
type: object
reference: ${use_functions_with_chat_models.output}
nodes:
- name: run_function
type: python
source:
type: code
path: run_function.py
inputs:
response_message: ${use_functions_with_chat_models.output}
use_variants: false
- name: use_functions_with_chat_models
type: llm
source:
type: code
path: use_functions_with_chat_models.jinja2
inputs:
deployment_name: gpt-35-turbo
temperature: 0.7
top_p: 1
stop: ""
max_tokens: 256
presence_penalty: 0
frequency_penalty: 0
logit_bias: ""
functions:
- name: get_current_weather
description: Get the current weather in a given location
parameters:
type: object
properties:
location:
type: string
description: The city and state, e.g. San Francisco, CA
unit:
type: string
enum:
- celsius
- fahrenheit
required:
- location
- name: get_n_day_weather_forecast
description: Get an N-day weather forecast
parameters:
type: object
properties:
location:
type: string
description: The city and state, e.g. San Francisco, CA
format:
type: string
enum:
- celsius
- fahrenheit
description: The temperature unit to use. Infer this from the users location.
num_days:
type: integer
description: The number of days to forecast
required:
- location
- format
- num_days
function_call:
name: get_current_weather
chat_history: ${inputs.chat_history}
question: ${inputs.question}
provider: AzureOpenAI
connection: azure_open_ai_connection
api: chat
module: promptflow.tools.aoai
use_variants: false
node_variants: {}
environment:
python_requirements_txt: requirements.txt
@@ -0,0 +1,41 @@
from promptflow import tool
import json


def get_current_weather(location, unit="fahrenheit"):
"""Get the current weather in a given location"""
weather_info = {
"location": location,
"temperature": "72",
"unit": unit,
"forecast": ["sunny", "windy"],
}
return weather_info


def get_n_day_weather_forecast(location, format, num_days):
"""Get next num_days weather in a given location"""
weather_info = {
"location": location,
"temperature": "60",
"format": format,
"forecast": ["rainy"],
"num_days": num_days,
}
return weather_info


@tool
def run_function(response_message: dict) -> str:
if "function_call" in response_message:
function_name = response_message["function_call"]["name"]
function_args = json.loads(response_message["function_call"]["arguments"])
print(function_args)
result = globals()[function_name](**function_args)
else:
print("No function call")
if isinstance(response_message, dict):
result = response_message["content"]
else:
result = response_message
return result
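For illustration, the dispatch in `run_function` above can be exercised with a mocked LLM response message. This self-contained sketch reproduces the sample's logic minus the `@tool` decorator; the mock response content is an assumption:

```python
import json


def get_current_weather(location, unit="fahrenheit"):
    # Stand-in for the sample's weather helper.
    return {"location": location, "temperature": "72", "unit": unit,
            "forecast": ["sunny", "windy"]}


def run_function(response_message: dict):
    # Same dispatch as the sample's run_function: if the LLM requested a
    # function call, look the function up by name and invoke it with the
    # JSON-decoded arguments; otherwise return the message content.
    if "function_call" in response_message:
        name = response_message["function_call"]["name"]
        args = json.loads(response_message["function_call"]["arguments"])
        return globals()[name](**args)
    return response_message.get("content")


mock_response = {
    "function_call": {
        "name": "get_current_weather",
        "arguments": '{"location": "Boston"}',
    }
}
print(run_function(mock_response))
```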
@@ -0,0 +1,27 @@
system:
Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.

{% for item in chat_history %}
user:
{{item.inputs.question}}

{% if 'function_call' in item.outputs.llm_output %}
assistant:
Function generation requested, function = {{item.outputs.llm_output.function_call.name}}, args = {{item.outputs.llm_output.function_call.arguments}}

function:
name:
{{item.outputs.llm_output.function_call.name}}
content:
{{item.outputs.answer}}

{% else %}
assistant:
{{item.outputs.llm_output}}

{% endif %}

{% endfor %}

user:
{{question}}
