Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions inference_gateway/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class InferenceToolCallArgument(BaseModel):
name: str
value: Any
class InferenceToolCall(BaseModel):
    # Provider-agnostic representation of one tool invocation requested by a model.
    id: str  # provider-assigned call id; echoed back later as InferenceMessage.tool_call_id
    name: str  # tool/function name
    arguments: List[InferenceToolCallArgument]  # parsed name/value pairs (see openai_tool_calls_to_inference_tool_calls)

Expand All @@ -62,6 +63,7 @@ def openai_tool_calls_to_inference_tool_calls(openai_tool_calls: List[ChatComple
arguments_dict = {}

inference_tool_calls.append(InferenceToolCall(
id=openai_tool_call.id,
name=openai_tool_call.function.name,
arguments=[InferenceToolCallArgument(name=name, value=value) for name, value in arguments_dict.items()]
))
Expand Down Expand Up @@ -96,6 +98,31 @@ class EmbeddingResult(BaseModel):
class InferenceMessage(BaseModel):
    # One chat message in the gateway's provider-agnostic format.
    role: str  # chat role string; presumably "user"/"assistant"/"tool" — not validated here
    content: str
    tool_calls: Optional[List[InferenceToolCall]] = None  # set on assistant messages that request tool invocations
    tool_call_id: Optional[str] = None  # set on tool-result messages; matches InferenceToolCall.id

def inference_message_to_openai_message(message: InferenceMessage) -> dict:
    """Convert a gateway InferenceMessage into the dict shape the OpenAI
    chat-completions API expects.

    The result always carries "role" and "content". "tool_calls" and
    "tool_call_id" are attached only when present (and non-empty) on the
    message, matching the OpenAI wire format where these keys are optional.
    """
    result: dict = {"role": message.role, "content": message.content}

    # Assistant tool invocations: each call becomes an OpenAI "function" entry
    # whose arguments are re-serialized from name/value pairs into a JSON string.
    if message.tool_calls:
        serialized_calls = []
        for call in message.tool_calls:
            argument_map = {argument.name: argument.value for argument in call.arguments}
            serialized_calls.append(
                {
                    "id": call.id,
                    "type": "function",
                    "function": {
                        "name": call.name,
                        "arguments": json.dumps(argument_map),
                    },
                }
            )
        result["tool_calls"] = serialized_calls

    # Tool-result messages reference the originating call by id.
    if message.tool_call_id:
        result["tool_call_id"] = message.tool_call_id

    return result



Expand Down
4 changes: 2 additions & 2 deletions inference_gateway/providers/chutes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import List, Optional
from openai import AsyncOpenAI, APIStatusError
from inference_gateway.providers.provider import Provider
from inference_gateway.models import InferenceTool, EmbeddingResult, InferenceResult, InferenceMessage, InferenceToolMode, EmbeddingModelInfo, InferenceModelInfo, EmbeddingModelPricingMode, inference_tools_to_openai_tools, inference_tool_mode_to_openai_tool_choice, openai_tool_calls_to_inference_tool_calls
from inference_gateway.models import InferenceTool, EmbeddingResult, InferenceResult, InferenceMessage, InferenceToolMode, EmbeddingModelInfo, InferenceModelInfo, EmbeddingModelPricingMode, inference_tools_to_openai_tools, inference_tool_mode_to_openai_tool_choice, openai_tool_calls_to_inference_tool_calls, inference_message_to_openai_message



Expand Down Expand Up @@ -153,7 +153,7 @@ async def _inference(
chat_completion = await self.chutes_inference_client.chat.completions.create(
model=model_info.external_name,
temperature=temperature,
messages=messages,
messages=[inference_message_to_openai_message(message) for message in messages],
tool_choice=inference_tool_mode_to_openai_tool_choice(tool_mode),
tools=inference_tools_to_openai_tools(tools) if tools else None,
stream=False
Expand Down
4 changes: 2 additions & 2 deletions inference_gateway/providers/targon.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import List, Optional
from openai import AsyncOpenAI, APIStatusError
from inference_gateway.providers.provider import Provider
from inference_gateway.models import InferenceTool, EmbeddingResult, InferenceResult, InferenceMessage, InferenceToolMode, EmbeddingModelInfo, InferenceModelInfo, inference_tools_to_openai_tools, inference_tool_mode_to_openai_tool_choice, openai_tool_calls_to_inference_tool_calls
from inference_gateway.models import InferenceTool, EmbeddingResult, InferenceResult, InferenceMessage, InferenceToolMode, EmbeddingModelInfo, InferenceModelInfo, inference_tools_to_openai_tools, inference_tool_mode_to_openai_tool_choice, openai_tool_calls_to_inference_tool_calls, inference_message_to_openai_message



Expand Down Expand Up @@ -145,7 +145,7 @@ async def _inference(
chat_completion = await self.targon_client.chat.completions.create(
model=model_info.external_name,
temperature=temperature,
messages=messages,
messages=[inference_message_to_openai_message(message) for message in messages],
tool_choice=inference_tool_mode_to_openai_tool_choice(tool_mode),
tools=inference_tools_to_openai_tools(tools) if tools else None,
stream=False
Expand Down