diff --git a/examples/quick_start/openai_example_func_call.py b/examples/quick_start/openai_example_func_call.py
index 56a08078102..6f411e945dc 100644
--- a/examples/quick_start/openai_example_func_call.py
+++ b/examples/quick_start/openai_example_func_call.py
@@ -23,23 +23,25 @@ def get_current_weather(location: str, unit: str = "fahrenheit"):
 
 
 @sgl.function
-def multi_turn_question(s, question_1, functions=[]):
+def question(s, question, tools=[]):
     s += sgl.system("You are a helpful assistant.")
-    s += sgl.user(question_1)
-    s += sgl.func_call("func_call_1", tools=functions, tool_choice="auto")
-    s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
+    s += sgl.user(question)
+    s += sgl.assistant(
+        sgl.gen("answer_1", max_tokens=256, tools=tools, tool_choice="auto")
+    )
 
 
 def single():
-    state = multi_turn_question.run(
-        question_1="What's the weather like in San Francisco, Tokyo, Paris, and Beijing?",
-        functions=[get_current_weather],
+    state = question.run(
+        question="What's the weather like in San Francisco, Tokyo, Paris, and Beijing?",
+        tools=[get_current_weather],
     )
 
     for m in state.messages():
         print(m["role"], ":", m["content"])
 
     print("\n-- answer_1 --\n", state["answer_1"])
+    # TODO: do we need to add another check for function call results
 
 
 if __name__ == "__main__":
diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py
index ece67b727b2..556b9eb335b 100644
--- a/python/sglang/__init__.py
+++ b/python/sglang/__init__.py
@@ -8,7 +8,6 @@
     assistant_end,
     flush_cache,
     function,
-    func_call,
     gen,
     gen_int,
     gen_string,
@@ -59,5 +58,4 @@
     "user_end",
     "assistant_begin",
     "assistant_end",
-    "func_call",
 ]
diff --git a/python/sglang/api.py b/python/sglang/api.py
index b7da35392af..6a93e67f5ab 100644
--- a/python/sglang/api.py
+++ b/python/sglang/api.py
@@ -16,7 +16,6 @@
     SglRoleEnd,
     SglSelect,
     SglVideo,
-    SglFuncCall,
 )
@@ -133,6 +132,8 @@ def gen_string(
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
+    tools: Optional[List[str]] = None,
+    tool_choice: Optional[str] = "auto",
 ):
     return SglGen(
         name,
@@ -146,6 +147,8 @@
         ignore_eos,
         str,
         None,
+        tools,
+        tool_choice,
     )
@@ -199,11 +202,3 @@ def assistant_begin():
 
 def assistant_end():
     return SglRoleEnd("assistant")
-
-
-def func_call(
-    name: Optional[str] = None,
-    tools: Optional[List[str]] = None,
-    tool_choice: Optional[str] = "auto",
-):
-    return SglFuncCall(name, tools, tool_choice)
diff --git a/python/sglang/backend/openai.py b/python/sglang/backend/openai.py
index 167cb10b898..9bce156942a 100644
--- a/python/sglang/backend/openai.py
+++ b/python/sglang/backend/openai.py
@@ -44,6 +44,25 @@ def create_logit_bias_int(tokenizer):
     "gpt-3.5-turbo-instruct",
 ]
 
+PARALLEL_FUNC_CALL_ENABLED_MODEL_NAMES = [
+    "gpt-4o",
+    "gpt-4o-2024-05-13",
+    "gpt-4-turbo",
+    "gpt-4-turbo-2024-04-09",
+    "gpt-4-turbo-preview",
+    "gpt-4-0125-preview",
+    "gpt-4-1106-preview",
+    "gpt-3.5-turbo-0125",
+    "gpt-3.5-turbo-1106",
+]
+
+FUNC_CALL_ENABLED_MODEL_NAMES = PARALLEL_FUNC_CALL_ENABLED_MODEL_NAMES + [
+    "gpt-4",
+    "gpt-4-0613",
+    "gpt-3.5-turbo",
+    "gpt-3.5-turbo-0613",
+]
+
 
 @dataclasses.dataclass
 class TokenUsage:
@@ -142,6 +161,7 @@ def generate(
         self,
         s: StreamExecutor,
         sampling_params: SglSamplingParams,
+        function_call_messages: List = [],
         spec_var_name: str = None,
     ):
         if sampling_params.dtype is None:
                     "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
" "Example of adding api speculative execution: @function(num_api_spec_tokens=128)." ) - prompt = s.messages_ - # Open AI model requires function call information to be sent to the model - # along with the prompt. - for function_call in s.function_calls: - prompt.append(function_call) + prompt = s.messages_ + function_call_messages else: return self._prepare_spec_execution( sampling_params, s.num_api_spec_tokens, spec_var_name @@ -235,32 +251,20 @@ def spec_pattern_match(self, comp): return False return True - def function_calling( + def build_function_call_messages( self, s: StreamExecutor, tools: List[str], tool_choice: str, ): - assert self.is_chat_model, "function calling only supported on chat model" - # TODO: special handling for chat model vs. non chat model, stream vs non stream - if self.model_name not in [ - "gpt-4o", - "gpt-4o-2024-05-13", - "gpt-4-turbo", - "gpt-4-turbo-2024-04-09", - "gpt-4-turbo-preview", - "gpt-4-0125-preview", - "gpt-4-1106-preview", - "gpt-4", - "gpt-4-0613", - "gpt-3.5-turbo", - "gpt-3.5-turbo-0125", - "gpt-3.5-turbo-1106", - "gpt-3.5-turbo-0613", - ]: + # OpenAI chat models currently do not support function calling + if self.model_name not in FUNC_CALL_ENABLED_MODEL_NAMES: raise RuntimeError( "This model currently does not support function calling." ) + is_parallel_func_call_enabled_model = ( + self.model_name in PARALLEL_FUNC_CALL_ENABLED_MODEL_NAMES + ) def convert_param_type(type): if type == "int" or type == "integer": @@ -293,45 +297,49 @@ def function_to_json_schema(func): } return func_schema + def build_tool_choice_param(): + if tool_choice in ["auto", "required", "none"]: + return tool_choice + else: + assert ( + tool_choice in tools + ), "could not find a candidate function that matches the provided tool choice" + return {"type": "function", "function": {"name": tool_choice}} + tools_to_use = [] if tools: tools_to_use = [ function_to_json_schema(tool_to_use) for tool_to_use in tools ] - cur_tool_choice = "auto" if tool_choice: - cur_tool_choice = ( - tool_choice - if tool_choice in ["auto", "required", "none"] - else {"type": "function", "function": {"name": tool_choice}} - ) + tool_choice = build_tool_choice_param() - # TODO: "Never mention what tools you use." 
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=s.messages_,
             tools=tools_to_use,
-            tool_choice=cur_tool_choice,
+            tool_choice=tool_choice,
             **self.spec_kwargs,
         )
         response_message = response.choices[0].message
         tool_calls = response_message.tool_calls
         # Check if the model wanted to call a function
         ret_messages = []
+        single_tool_call = []
         if tool_calls:
             # Call the function
             # Note: the JSON response may not always be valid; be sure to handle errors
             available_functions = {}
             for tool in tools:
                 available_functions[tool.__name__] = tool
-            ret_messages.append(response_message)
+            single_tool_call.append(response_message)
             # Send the info for each function call and function response to the model
             for tool_call in tool_calls:
                 function_name = tool_call.function.name
                 function_to_call = available_functions[function_name]
                 function_args = json.loads(tool_call.function.arguments)
                 function_response = function_to_call(**function_args)
-                ret_messages.append(
+                single_tool_call.append(
                     {
                         "tool_call_id": tool_call.id,
                         "role": "tool",
@@ -339,6 +347,11 @@ def function_to_json_schema(func):
                         "content": str(function_response),
                     }
                 )
+                if not is_parallel_func_call_enabled_model:
+                    ret_messages.append(single_tool_call)
+                    single_tool_call = []
+            if is_parallel_func_call_enabled_model:
+                ret_messages.append(single_tool_call)
         return ret_messages
 
     def role_end_generate(
@@ -378,6 +391,7 @@ def generate_stream(
         self,
         s: StreamExecutor,
         sampling_params: SglSamplingParams,
+        function_call_messages: List = [],
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
                     "This use case is not supported. "
                     "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
                 )
-            prompt = s.messages_
+            prompt = s.messages_ + function_call_messages
         else:
             prompt = s.text_
diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py
index faf63725703..98c1e2d2679 100644
--- a/python/sglang/lang/interpreter.py
+++ b/python/sglang/lang/interpreter.py
@@ -23,7 +23,6 @@
     SglFunction,
     SglGen,
     SglImage,
-    SglFuncCall,
     SglRoleBegin,
     SglRoleEnd,
     SglSelect,
@@ -204,7 +203,7 @@ def __init__(
         self.cur_role_begin_pos = None
 
         # For function calling
-        self.function_calls = []  # The messages in the OpenAI API format
+        self.function_calls_messages = []  # The messages in the OpenAI API format
 
         # For vision
         self.images_ = []
@@ -371,8 +370,6 @@ def _execute(self, other):
         elif isinstance(other, SglExprList):
             for x in other.expr_list:
                 self._execute(x)
-        elif isinstance(other, SglFuncCall):
-            self._execute_func_call(other)
         elif isinstance(other, SglRoleBegin):
             self._execute_role_begin(other)
         elif isinstance(other, SglRoleEnd):
@@ -492,33 +489,71 @@ def find_stop():
         return comp, meta_info
 
     def _execute_gen(self, expr: SglGen):
+        if expr.tools and self.backend.is_chat_model:
+            # Previous function calls are not remembered; users are expected to
+            # provide all candidate functions in the current generate call.
+            self.function_calls_messages = self.backend.build_function_call_messages(
+                self, expr.tools, expr.tool_choice
+            )
+        self._execute_gen_helper(expr)
+
+    def _execute_gen_helper(self, expr: SglGen):
         sampling_params = self._resolve_sampling_params(expr.sampling_params)
         name = expr.name
         if not self.stream:
             if self.num_api_spec_tokens is None:
-                comp, meta_info = self.backend.generate(
-                    self,
-                    sampling_params=sampling_params,
-                )
+                if self.function_calls_messages:
+                    comp_list = []
+                    for function_call_messages in self.function_calls_messages:
+                        function_call_response, meta_info = self.backend.generate(
+                            self,
+                            sampling_params=sampling_params,
+                            function_call_messages=function_call_messages,
+                        )
+                        comp_list.append(function_call_response)
+                else:
+                    comp, meta_info = self.backend.generate(
+                        self,
+                        sampling_params=sampling_params,
+                        function_call_messages=[],
+                    )
             else:
                 if self.backend.is_chat_model:
                     # Speculative execution on models with only chat interface.
                     # Store the calls into a temporary list.
                     # They will be lazily executed later.
-                    comp, meta_info = self.backend.generate(
-                        self,
-                        sampling_params=sampling_params,
-                        spec_var_name=name,
-                    )
+                    if self.function_calls_messages:
+                        # Handles both models that do and do not support
+                        # parallel function calling.
+                        comp_list = []
+                        for function_call_messages in self.function_calls_messages:
+                            function_call_response, meta_info = self.backend.generate(
+                                self,
+                                sampling_params=sampling_params,
+                                function_call_messages=function_call_messages,
+                                spec_var_name=name,
+                            )
+                            comp_list.append(function_call_response)
+                    else:
+                        comp, meta_info = self.backend.generate(
+                            self,
+                            sampling_params=sampling_params,
+                            function_call_messages=[],
+                            spec_var_name=name,
+                        )
                     return
                 else:
                     # Speculative execution on models with completion interface
                     comp, meta_info = self._spec_gen(sampling_params)
 
-        self.text_ += comp
+        self.text_ += "".join(comp_list) if self.function_calls_messages else comp
 
-        self.variables[name] = comp
+        self.variables[name] = comp_list if self.function_calls_messages else comp
         self.meta_info[name] = meta_info
         self.variable_event[name].set()
     else:
@@ -526,7 +561,8 @@ def _execute_gen(self, expr: SglGen):
                 self.num_api_spec_tokens is None
             ), "stream is not supported with api speculative execution"
             generator = self.backend.generate_stream(
-                self, sampling_params=sampling_params
+                self,
+                sampling_params=sampling_params,
             )
 
             self.stream_var_event[name].set()
@@ -560,12 +596,6 @@ def _execute_select(self, expr: SglSelect):
         self.variable_event[name].set()
         self.text_ += decision
 
-    def _execute_func_call(self, expr: SglFuncCall):
-        # TODO: Should we clear the previous function call states for the next function call
-        self.function_calls = self.backend.function_calling(
-            self, expr.tools, expr.tool_choice
-        )
-
     def _execute_variable(self, expr: SglVariable):
         src_executor = expr.source_stream_executor
         value = src_executor.get_var(expr.name)
@@ -762,14 +792,7 @@ def text(self):
         return self.stream_executor.text()
 
     def messages(self):
-        # We do not want to expose tool use information to users in the final response,
-        # so removing the auxillary information from final messages.
-        filtered_list = [
-            item
-            for item in self.stream_executor.messages()
-            if item not in self.stream_executor.function_calls
-        ]
-        return filtered_list
+        return self.stream_executor.messages()
 
     def sync(self):
         return self.stream_executor.sync()
diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py
index a14a4623ed5..968050a0e8d 100644
--- a/python/sglang/lang/ir.py
+++ b/python/sglang/lang/ir.py
@@ -366,6 +366,8 @@ def __init__(
         ignore_eos,
         dtype,
         regex,
+        tools,
+        tool_choice,
     ):
         super().__init__()
         self.name = name
@@ -381,6 +383,8 @@ def __init__(
             dtype=dtype,
             regex=regex,
         )
+        self.tools = tools
+        self.tool_choice = tool_choice
 
     def __repr__(self):
         return f"Gen('{self.name}')"
@@ -424,19 +428,6 @@ def __repr__(self):
         return f"Select({self.name}, choices={self.choices})"
 
 
-class SglFuncCall(SglExpr):
-    def __init__(self, name, tools, tool_choice):
-        super().__init__()
-        self.name = name
-        self.tools = tools
-        self.tool_choice = tool_choice
-
-    def __repr__(self):
-        return (
-            f"FuncCall({self.name}, tools={self.tools}, tool_choice={self.tool_choice})"
-        )
-
-
 class SglFork(SglExpr):
     def __init__(self, number, position_ids_offset=None):
         super().__init__()
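
Reviewer note: a minimal end-to-end sketch of the new tools= path, not part of the patch. It assumes OPENAI_API_KEY is set and uses "gpt-3.5-turbo", which is in FUNC_CALL_ENABLED_MODEL_NAMES but not in PARALLEL_FUNC_CALL_ENABLED_MODEL_NAMES, so each tool call is sent back to the model as its own message batch; the weather helper is a hypothetical stub.

    import sglang as sgl

    # sgl.OpenAI and sgl.set_default_backend are the existing sglang backend APIs.
    sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo"))


    def get_current_weather(location: str, unit: str = "fahrenheit"):
        # Hypothetical stub; a real tool would look the weather up.
        return f"The weather in {location} is 72 degrees {unit}."


    @sgl.function
    def question(s, question, tools=[]):
        s += sgl.system("You are a helpful assistant.")
        s += sgl.user(question)
        # tools/tool_choice are carried on SglGen; the interpreter calls
        # backend.build_function_call_messages() before generating.
        s += sgl.assistant(
            sgl.gen("answer_1", max_tokens=256, tools=tools, tool_choice="auto")
        )


    state = question.run(
        question="What's the weather like in San Francisco and Tokyo?",
        tools=[get_current_weather],
    )
    # When tool calls fire, answer_1 holds one response per message batch (a list);
    # otherwise it is a plain string.
    print(state["answer_1"])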