Skip to content

Commit

Permalink
update function call code structure
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiyun-Liang committed Jul 9, 2024
1 parent 071cedf commit 075b053
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 95 deletions.
16 changes: 9 additions & 7 deletions examples/quick_start/openai_example_func_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,25 @@ def get_current_weather(location: str, unit: str = "fahrenheit"):


@sgl.function
def question(s, question, tools=None):
    """Single-turn question program with optional OpenAI tool calling.

    Args:
        s: the sglang state object threaded through the program by @sgl.function.
        question: the user question to send to the model.
        tools: optional list of Python callables exposed to the model as
            OpenAI function-calling tools. Defaults to no tools.
    """
    # NOTE: default is None instead of a mutable [] so the list is not shared
    # across calls (classic mutable-default-argument pitfall).
    tools = [] if tools is None else tools
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user(question)
    s += sgl.assistant(
        sgl.gen("answer_1", max_tokens=256, tools=tools, tool_choice="auto")
    )


def single():
    """Run the `question` program once with a weather tool and print the chat.

    Side effects: issues an OpenAI API call via the sglang backend and prints
    every chat message plus the final generated answer to stdout.
    """
    state = question.run(
        question="What's the weather like in San Francisco, Tokyo, Paris, and Beijing?",
        tools=[get_current_weather],
    )

    # Dump the full conversation, including any tool-call messages the
    # backend appended on the model's behalf.
    for m in state.messages():
        print(m["role"], ":", m["content"])

    print("\n-- answer_1 --\n", state["answer_1"])
    # TODO: do we need to add another check for function call results?


if __name__ == "__main__":
Expand Down
2 changes: 0 additions & 2 deletions python/sglang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
assistant_end,
flush_cache,
function,
func_call,
gen,
gen_int,
gen_string,
Expand Down Expand Up @@ -59,5 +58,4 @@
"user_end",
"assistant_begin",
"assistant_end",
"func_call",
]
13 changes: 4 additions & 9 deletions python/sglang/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
SglRoleEnd,
SglSelect,
SglVideo,
SglFuncCall,
)


Expand Down Expand Up @@ -133,6 +132,8 @@ def gen_string(
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
tools: Optional[List[str]] = None,
tool_choice: Optional[str] = "auto",
):
return SglGen(
name,
Expand All @@ -146,6 +147,8 @@ def gen_string(
ignore_eos,
str,
None,
tools,
tool_choice,
)


Expand Down Expand Up @@ -199,11 +202,3 @@ def assistant_begin():

def assistant_end():
return SglRoleEnd("assistant")


def func_call(
name: Optional[str] = None,
tools: Optional[List[str]] = None,
tool_choice: Optional[str] = "auto",
):
return SglFuncCall(name, tools, tool_choice)
82 changes: 48 additions & 34 deletions python/sglang/backend/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@ def create_logit_bias_int(tokenizer):
"gpt-3.5-turbo-instruct",
]

# OpenAI chat models that support *parallel* function calling, i.e. a single
# assistant response may contain multiple tool calls at once.
# NOTE(review): this list is a point-in-time snapshot of OpenAI's model
# lineup — verify against OpenAI's function-calling docs when updating.
PARALLEL_FUNC_CALL_ENABLED_MODEL_NAMES = [
    "gpt-4o",
    "gpt-4o-2024-05-13",
    "gpt-4-turbo",
    "gpt-4-turbo-2024-04-09",
    "gpt-4-turbo-preview",
    "gpt-4-0125-preview",
    "gpt-4-1106-preview",
    "gpt-3.5-turbo-0125",
    "gpt-3.5-turbo-1106",
]

# All models that support function calling at all: the parallel-capable
# models above plus older models that return at most one tool call per
# response.
FUNC_CALL_ENABLED_MODEL_NAMES = PARALLEL_FUNC_CALL_ENABLED_MODEL_NAMES + [
    "gpt-4",
    "gpt-4-0613",
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-0613",
]


@dataclasses.dataclass
class TokenUsage:
Expand Down Expand Up @@ -142,6 +161,7 @@ def generate(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
function_call_messages: List = [],
spec_var_name: str = None,
):
if sampling_params.dtype is None:
Expand All @@ -153,11 +173,7 @@ def generate(
"For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
"Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
)
prompt = s.messages_
# Open AI model requires function call information to be sent to the model
# along with the prompt.
for function_call in s.function_calls:
prompt.append(function_call)
prompt = s.messages_ + function_call_messages
else:
return self._prepare_spec_execution(
sampling_params, s.num_api_spec_tokens, spec_var_name
Expand Down Expand Up @@ -235,32 +251,20 @@ def spec_pattern_match(self, comp):
return False
return True

def function_calling(
def build_function_call_messages(
self,
s: StreamExecutor,
tools: List[str],
tool_choice: str,
):
assert self.is_chat_model, "function calling only supported on chat model"
# TODO: special handling for chat model vs. non chat model, stream vs non stream
if self.model_name not in [
"gpt-4o",
"gpt-4o-2024-05-13",
"gpt-4-turbo",
"gpt-4-turbo-2024-04-09",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-4",
"gpt-4-0613",
"gpt-3.5-turbo",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-0613",
]:
# OpenAI chat models currently do not support function calling
if self.model_name not in FUNC_CALL_ENABLED_MODEL_NAMES:
raise RuntimeError(
"This model currently does not support function calling."
)
is_parallel_func_call_enabled_model = (
self.model_name in PARALLEL_FUNC_CALL_ENABLED_MODEL_NAMES
)

def convert_param_type(type):
if type == "int" or type == "integer":
Expand Down Expand Up @@ -293,52 +297,61 @@ def function_to_json_schema(func):
}
return func_schema

def build_tool_choice_param():
    """Translate the enclosing scope's `tool_choice` into the value the
    OpenAI API expects: one of the literal modes ("auto", "required",
    "none") is passed through unchanged; any other string is treated as a
    specific function name and wrapped in the
    {"type": "function", "function": {"name": ...}} form.
    """
    if tool_choice in ["auto", "required", "none"]:
        return tool_choice
    else:
        # NOTE(review): `assert` is stripped under `python -O`; a
        # ValueError would survive optimization — consider changing.
        assert (
            tool_choice in tools
        ), "could not find a candidate function that matches the provided tool choice"
        return {"type": "function", "function": {"name": tool_choice}}

tools_to_use = []
if tools:
tools_to_use = [
function_to_json_schema(tool_to_use) for tool_to_use in tools
]
cur_tool_choice = "auto"
if tool_choice:
cur_tool_choice = (
tool_choice
if tool_choice in ["auto", "required", "none"]
else {"type": "function", "function": {"name": tool_choice}}
)
tool_choice = build_tool_choice_param()

# TODO: "Never mention what tools you use." or provide a system prompt input argument
response = self.client.chat.completions.create(
model=self.model_name,
messages=s.messages_,
tools=tools_to_use,
tool_choice=cur_tool_choice,
tool_choice=tool_choice,
**self.spec_kwargs,
)
response_message = response.choices[0].message
tool_calls = response_message.tool_calls
# Check if the model wanted to call a function
ret_messages = []
single_tool_call = []
if tool_calls:
# Call the function
# Note: the JSON response may not always be valid; be sure to handle errors
available_functions = {}
for tool in tools:
available_functions[tool.__name__] = tool
ret_messages.append(response_message)
single_tool_call.append(response_message)
# Send the info for each function call and function response to the model
for tool_call in tool_calls:
function_name = tool_call.function.name
function_to_call = available_functions[function_name]
function_args = json.loads(tool_call.function.arguments)
function_response = function_to_call(**function_args)
ret_messages.append(
single_tool_call.append(
{
"tool_call_id": tool_call.id,
"role": "tool",
"name": function_name,
"content": str(function_response),
}
)
if not is_parallel_func_call_enabled_model:
ret_messages.append(single_tool_call)
single_tool_call = []
if is_parallel_func_call_enabled_model:
ret_messages.append(single_tool_call)
return ret_messages

def role_end_generate(
Expand Down Expand Up @@ -378,6 +391,7 @@ def generate_stream(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
function_call_messages: List = [],
):
if sampling_params.dtype is None:
if self.is_chat_model:
Expand All @@ -386,7 +400,7 @@ def generate_stream(
"This use case is not supported. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant"
)
prompt = s.messages_
prompt = s.messages_ + function_call_messages
else:
prompt = s.text_

Expand Down
Loading

0 comments on commit 075b053

Please sign in to comment.