Skip to content

Commit add322c

Browse files
authored
feat: Add Gemini support for language models (#3632)
Introduce support for the Gemini model in the language model configuration, allowing users to utilize Gemini alongside existing models. Update the configuration to include Gemini-specific settings and ensure compatibility with the overall architecture. This change also adds the langchain-xai dependency with a ChatXAI instantiation branch, and registers new model configurations (o4-mini, claude-3-7-sonnet, and additional Groq-hosted Llama/DeepSeek models).
1 parent 4f0fb6f commit add322c

File tree

4 files changed

+62
-1
lines changed

4 files changed

+62
-1
lines changed

core/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ dependencies = [
2323
"markupsafe>=2.1.5",
2424
"megaparse-sdk>=0.1.11",
2525
"langchain-mistralai>=0.2.3",
26+
"langchain-google-genai>=2.1.3",
27+
"langchain-xai>=0.2.3",
2628
"fasttext-langdetect>=1.0.5",
2729
"langfuse>=2.57.0",
2830
]

core/quivr_core/llm/llm_endpoint.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
import tiktoken
88
from langchain_anthropic import ChatAnthropic
99
from langchain_core.language_models.chat_models import BaseChatModel
10+
from langchain_google_genai import ChatGoogleGenerativeAI
1011
from langchain_mistralai import ChatMistralAI
1112
from langchain_openai import AzureChatOpenAI, ChatOpenAI
13+
from langchain_xai import ChatXAI
1214
from pydantic import SecretStr
1315

1416
from quivr_core.brain.info import LLMInfo
@@ -206,7 +208,14 @@ def get_config(self):
206208

207209
@classmethod
208210
def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
209-
_llm: Union[AzureChatOpenAI, ChatOpenAI, ChatAnthropic, ChatMistralAI]
211+
_llm: Union[
212+
AzureChatOpenAI,
213+
ChatOpenAI,
214+
ChatAnthropic,
215+
ChatMistralAI,
216+
ChatGoogleGenerativeAI,
217+
ChatXAI,
218+
]
210219
try:
211220
if config.supplier == DefaultModelSuppliers.AZURE:
212221
# Parse the URL
@@ -255,6 +264,27 @@ def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
255264
base_url=config.llm_base_url,
256265
temperature=config.temperature,
257266
)
267+
elif config.supplier == DefaultModelSuppliers.GEMINI:
268+
_llm = ChatGoogleGenerativeAI(
269+
model=config.model,
270+
api_key=SecretStr(config.llm_api_key)
271+
if config.llm_api_key
272+
else None,
273+
base_url=config.llm_base_url,
274+
max_tokens=config.max_output_tokens,
275+
temperature=config.temperature,
276+
)
277+
elif config.supplier == DefaultModelSuppliers.GROQ:
278+
_llm = ChatXAI(
279+
model=config.model,
280+
api_key=SecretStr(config.llm_api_key)
281+
if config.llm_api_key
282+
else None,
283+
base_url=config.llm_base_url,
284+
max_tokens=config.max_output_tokens,
285+
temperature=config.temperature,
286+
)
287+
258288
else:
259289
_llm = ChatOpenAI(
260290
model=config.model,

core/quivr_core/rag/entities/config.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class DefaultModelSuppliers(str, Enum):
7272
META = "meta"
7373
MISTRAL = "mistral"
7474
GROQ = "groq"
75+
GEMINI = "gemini"
7576

7677

7778
class LLMConfig(QuivrBaseConfig):
@@ -98,6 +99,11 @@ class LLMModelConfig:
9899
max_output_tokens=100000,
99100
tokenizer_hub="Quivr/gpt-4o",
100101
),
102+
"o4-mini": LLMConfig(
103+
max_context_tokens=200000,
104+
max_output_tokens=100000,
105+
tokenizer_hub="Quivr/gpt-4o",
106+
),
101107
"o1-mini": LLMConfig(
102108
max_context_tokens=128000,
103109
max_output_tokens=65536,
@@ -139,6 +145,11 @@ class LLMModelConfig:
139145
),
140146
},
141147
DefaultModelSuppliers.ANTHROPIC: {
148+
"claude-3-7-sonnet": LLMConfig(
149+
max_context_tokens=200000,
150+
max_output_tokens=8192,
151+
tokenizer_hub="Quivr/claude-tokenizer",
152+
),
142153
"claude-3-5-sonnet": LLMConfig(
143154
max_context_tokens=200000,
144155
max_output_tokens=8192,
@@ -209,6 +220,16 @@ class LLMModelConfig:
209220
"code-llama": LLMConfig(
210221
max_context_tokens=16384, tokenizer_hub="Quivr/llama-code-tokenizer"
211222
),
223+
"deepseek-r1-distill-llama-70b": LLMConfig(
224+
max_context_tokens=128000,
225+
max_output_tokens=32768,
226+
tokenizer_hub="Quivr/Meta-Llama-3.1-Tokenizer",
227+
),
228+
"meta-llama/llama-4-maverick-17b-128e-instruct": LLMConfig(
229+
max_context_tokens=128000,
230+
max_output_tokens=32768,
231+
tokenizer_hub="Quivr/Meta-Llama-3.1-Tokenizer",
232+
),
212233
},
213234
DefaultModelSuppliers.MISTRAL: {
214235
"mistral-large": LLMConfig(
@@ -230,6 +251,13 @@ class LLMModelConfig:
230251
max_context_tokens=32000, tokenizer_hub="Quivr/mistral-tokenizer-v3"
231252
),
232253
},
254+
DefaultModelSuppliers.GEMINI: {
255+
"gemini-2.5": LLMConfig(
256+
max_context_tokens=128000,
257+
max_output_tokens=4096,
258+
tokenizer_hub="Quivr/gemini-tokenizer",
259+
),
260+
},
233261
}
234262

235263
@classmethod

core/quivr_core/rag/quivr_rag_langgraph.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,7 @@ def generate_zendesk_rag(self, state: AgentState) -> AgentState:
945945

946946
msg = prompt_template.format_prompt(**inputs)
947947
llm = self.bind_tools_to_llm(self.generate_zendesk_rag.__name__)
948+
948949
response = llm.invoke(msg)
949950

950951
return {**state, "messages": [response]}

0 commit comments

Comments (0)