diff --git a/inference_gateway/models.py b/inference_gateway/models.py index 6283df89..fd201664 100644 --- a/inference_gateway/models.py +++ b/inference_gateway/models.py @@ -48,6 +48,7 @@ class InferenceToolCallArgument(BaseModel): name: str value: Any class InferenceToolCall(BaseModel): + id: str name: str arguments: List[InferenceToolCallArgument] @@ -62,6 +63,7 @@ def openai_tool_calls_to_inference_tool_calls(openai_tool_calls: List[ChatComple arguments_dict = {} inference_tool_calls.append(InferenceToolCall( + id=openai_tool_call.id, name=openai_tool_call.function.name, arguments=[InferenceToolCallArgument(name=name, value=value) for name, value in arguments_dict.items()] )) @@ -96,6 +98,31 @@ class EmbeddingResult(BaseModel): class InferenceMessage(BaseModel): role: str content: str + tool_calls: Optional[List[InferenceToolCall]] = None + tool_call_id: Optional[str] = None + +def inference_message_to_openai_message(message: InferenceMessage) -> dict: + openai_message = { + "role": message.role, + "content": message.content + } + + if message.tool_calls: + openai_message["tool_calls"] = [ + { + "id": tool_call.id, + "type": "function", + "function": { + "name": tool_call.name, + "arguments": json.dumps({arg.name: arg.value for arg in tool_call.arguments}) + } + } for tool_call in message.tool_calls + ] + + if message.tool_call_id: + openai_message["tool_call_id"] = message.tool_call_id + + return openai_message diff --git a/inference_gateway/providers/chutes.py b/inference_gateway/providers/chutes.py index 8d8bc92b..3297ea79 100644 --- a/inference_gateway/providers/chutes.py +++ b/inference_gateway/providers/chutes.py @@ -7,7 +7,7 @@ from typing import List, Optional from openai import AsyncOpenAI, APIStatusError from inference_gateway.providers.provider import Provider -from inference_gateway.models import InferenceTool, EmbeddingResult, InferenceResult, InferenceMessage, InferenceToolMode, EmbeddingModelInfo, InferenceModelInfo, EmbeddingModelPricingMode, inference_tools_to_openai_tools, inference_tool_mode_to_openai_tool_choice, openai_tool_calls_to_inference_tool_calls +from inference_gateway.models import InferenceTool, EmbeddingResult, InferenceResult, InferenceMessage, InferenceToolMode, EmbeddingModelInfo, InferenceModelInfo, EmbeddingModelPricingMode, inference_tools_to_openai_tools, inference_tool_mode_to_openai_tool_choice, openai_tool_calls_to_inference_tool_calls, inference_message_to_openai_message @@ -153,7 +153,7 @@ async def _inference( chat_completion = await self.chutes_inference_client.chat.completions.create( model=model_info.external_name, temperature=temperature, - messages=messages, + messages=[inference_message_to_openai_message(message) for message in messages], tool_choice=inference_tool_mode_to_openai_tool_choice(tool_mode), tools=inference_tools_to_openai_tools(tools) if tools else None, stream=False diff --git a/inference_gateway/providers/targon.py b/inference_gateway/providers/targon.py index efd8edf6..4fc2fb1a 100644 --- a/inference_gateway/providers/targon.py +++ b/inference_gateway/providers/targon.py @@ -7,7 +7,7 @@ from typing import List, Optional from openai import AsyncOpenAI, APIStatusError from inference_gateway.providers.provider import Provider -from inference_gateway.models import InferenceTool, EmbeddingResult, InferenceResult, InferenceMessage, InferenceToolMode, EmbeddingModelInfo, InferenceModelInfo, inference_tools_to_openai_tools, inference_tool_mode_to_openai_tool_choice, openai_tool_calls_to_inference_tool_calls +from inference_gateway.models import InferenceTool, EmbeddingResult, InferenceResult, InferenceMessage, InferenceToolMode, EmbeddingModelInfo, InferenceModelInfo, inference_tools_to_openai_tools, inference_tool_mode_to_openai_tool_choice, openai_tool_calls_to_inference_tool_calls, inference_message_to_openai_message @@ -145,7 +145,7 @@ async def _inference( chat_completion = await self.targon_client.chat.completions.create( model=model_info.external_name, temperature=temperature, - messages=messages, + messages=[inference_message_to_openai_message(message) for message in messages], tool_choice=inference_tool_mode_to_openai_tool_choice(tool_mode), tools=inference_tools_to_openai_tools(tools) if tools else None, stream=False