Skip to content

Commit 27c9f95

Browse files
feat: Add support for separate LLM and embedding model endpoints
- Added LLM_ENDPOINT and EMBEDDING_ENDPOINT configuration options
- Updated GenericLLMProvider to handle custom base URLs
- Enhanced embedding initialization to use separate endpoints
- Improved configuration handling for both LLM and embedding providers
1 parent 9f1a708 commit 27c9f95

File tree

5 files changed

+100
-54
lines changed

5 files changed

+100
-54
lines changed

gpt_researcher/actions/agent_creator.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -32,13 +32,13 @@ async def choose_agent(
3232

3333
try:
3434
response = await create_chat_completion(
35-
model=cfg.smart_llm_model,
35+
model=cfg.fast_llm_model,
3636
messages=[
3737
{"role": "system", "content": f"{prompt_family.auto_agent_instructions()}"},
3838
{"role": "user", "content": f"task: {query}"},
3939
],
4040
temperature=0.15,
41-
llm_provider=cfg.smart_llm_provider,
41+
llm_provider=cfg.fast_llm_provider,
4242
llm_kwargs=cfg.llm_kwargs,
4343
cost_callback=cost_callback,
4444
**kwargs

gpt_researcher/config/config.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -62,8 +62,29 @@ def _set_llm_attributes(self) -> None:
6262
self.smart_llm_provider, self.smart_llm_model = self.parse_llm(self.smart_llm)
6363
self.strategic_llm_provider, self.strategic_llm_model = self.parse_llm(self.strategic_llm)
6464
self.reasoning_effort = self.parse_reasoning_effort(os.getenv("REASONING_EFFORT"))
65+
66+
# Set base URLs for LLM and embedding endpoints
67+
self.llm_base_url = getattr(self, 'llm_endpoint', 'http://localhost:8080/v1')
68+
self.embedding_base_url = getattr(self, 'embedding_endpoint', 'http://localhost:8081/v1')
69+
70+
# Update LLM kwargs with the appropriate base URL
71+
if not self.llm_kwargs.get('base_url'):
72+
self.llm_kwargs['base_url'] = self.llm_base_url
73+
74+
# Update embedding kwargs with the appropriate base URL
75+
if not self.embedding_kwargs.get('base_url'):
76+
self.embedding_kwargs['base_url'] = self.embedding_base_url
6577

6678
def _handle_deprecated_attributes(self) -> None:
79+
# Handle environment variables for endpoints
80+
if os.getenv("LLM_ENDPOINT"):
81+
self.llm_base_url = os.environ["LLM_ENDPOINT"]
82+
self.llm_kwargs['base_url'] = self.llm_base_url
83+
84+
if os.getenv("EMBEDDING_ENDPOINT"):
85+
self.embedding_base_url = os.environ["EMBEDDING_ENDPOINT"]
86+
self.embedding_kwargs['base_url'] = self.embedding_base_url
87+
6788
if os.getenv("EMBEDDING_PROVIDER") is not None:
6889
warnings.warn(
6990
"EMBEDDING_PROVIDER is deprecated and will be removed soon. Use EMBEDDING instead.",

gpt_researcher/config/variables/default.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,4 +42,6 @@
4242
"MCP_ALLOWED_ROOT_PATHS": [], # List of allowed root paths for local file access
4343
"MCP_STRATEGY": "fast", # MCP execution strategy: "fast", "deep", "disabled"
4444
"REASONING_EFFORT": "medium",
45+
"LLM_ENDPOINT": "http://localhost:8080/v1", # LLM endpoint
46+
"EMBEDDING_ENDPOINT": "http://localhost:8081/v1", # Embedding endpoint
4547
}

gpt_researcher/llm_provider/generic/base.py

Lines changed: 20 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -90,10 +90,22 @@ def __init__(self, llm, chat_log: str | None = None, verbose: bool = True):
9090
self.verbose = verbose
9191
@classmethod
9292
def from_provider(cls, provider: str, chat_log: str | None = None, verbose: bool=True, **kwargs: Any):
93+
# Get the appropriate base URL from kwargs or environment
94+
base_url = kwargs.pop('base_url', None) or os.environ.get(f"{provider.upper()}_BASE_URL")
95+
96+
if base_url:
97+
kwargs['base_url'] = base_url
98+
if verbose:
99+
print(f"Using {provider} endpoint: {base_url}")
100+
93101
if provider == "openai":
94102
_check_pkg("langchain_openai")
95103
from langchain_openai import ChatOpenAI
96-
104+
105+
# Handle custom endpoints for OpenAI-compatible APIs
106+
if 'base_url' in kwargs and 'api_key' not in kwargs:
107+
kwargs['api_key'] = 'dummy' # Some APIs don't require a key
108+
97109
llm = ChatOpenAI(**kwargs)
98110
elif provider == "anthropic":
99111
_check_pkg("langchain_anthropic")
@@ -133,8 +145,13 @@ def from_provider(cls, provider: str, chat_log: str | None = None, verbose: bool
133145
_check_pkg("langchain_community")
134146
_check_pkg("langchain_ollama")
135147
from langchain_ollama import ChatOllama
136-
137-
llm = ChatOllama(base_url=os.environ["OLLAMA_BASE_URL"], **kwargs)
148+
149+
# Use provided base_url or fall back to environment variable
150+
base_url = kwargs.pop('base_url', os.environ.get("OLLAMA_BASE_URL"))
151+
if base_url:
152+
kwargs['base_url'] = base_url
153+
154+
llm = ChatOllama(**kwargs)
138155
elif provider == "together":
139156
_check_pkg("langchain_together")
140157
from langchain_together import ChatTogether

gpt_researcher/memory/embeddings.py

Lines changed: 55 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -27,105 +27,111 @@
2727

2828

2929
class Memory:
30-
def __init__(self, embedding_provider: str, model: str, **embdding_kwargs: Any):
30+
def __init__(self, embedding_provider: str, model: str, **embedding_kwargs: Any):
3131
_embeddings = None
32+
33+
# Get base URL from kwargs or environment
34+
base_url = embedding_kwargs.pop('base_url', None) or os.environ.get('EMBEDDING_ENDPOINT')
35+
3236
match embedding_provider:
33-
case "custom":
37+
case "custom" | "openai":
3438
from langchain_openai import OpenAIEmbeddings
35-
39+
40+
# For custom endpoints, use a dummy key if none provided
41+
api_key = os.getenv("OPENAI_API_KEY", "dummy")
42+
if embedding_provider == "custom" and not base_url:
43+
base_url = os.getenv("OPENAI_BASE_URL", "http://localhost:1234/v1")
44+
3645
_embeddings = OpenAIEmbeddings(
3746
model=model,
38-
openai_api_key=os.getenv("OPENAI_API_KEY", "custom"),
39-
openai_api_base=os.getenv(
40-
"OPENAI_BASE_URL", "http://localhost:1234/v1"
41-
), # default for lmstudio
47+
openai_api_key=api_key,
48+
openai_api_base=base_url,
4249
check_embedding_ctx_length=False,
43-
**embdding_kwargs,
44-
) # quick fix for lmstudio
45-
case "openai":
46-
from langchain_openai import OpenAIEmbeddings
47-
48-
_embeddings = OpenAIEmbeddings(model=model, **embdding_kwargs)
50+
**embedding_kwargs,
51+
)
4952
case "azure_openai":
5053
from langchain_openai import AzureOpenAIEmbeddings
5154

5255
_embeddings = AzureOpenAIEmbeddings(
5356
model=model,
54-
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
55-
openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
56-
openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
57-
**embdding_kwargs,
57+
openai_api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2023-05-15"),
58+
azure_deployment=model,
59+
**embedding_kwargs,
5860
)
5961
case "cohere":
6062
from langchain_cohere import CohereEmbeddings
61-
62-
_embeddings = CohereEmbeddings(model=model, **embdding_kwargs)
63+
_embeddings = CohereEmbeddings(model=model, **embedding_kwargs)
64+
6365
case "google_vertexai":
6466
from langchain_google_vertexai import VertexAIEmbeddings
65-
66-
_embeddings = VertexAIEmbeddings(model=model, **embdding_kwargs)
67+
_embeddings = VertexAIEmbeddings(model=model, **embedding_kwargs)
68+
6769
case "google_genai":
6870
from langchain_google_genai import GoogleGenerativeAIEmbeddings
69-
7071
_embeddings = GoogleGenerativeAIEmbeddings(
71-
model=model, **embdding_kwargs
72+
model=model,
73+
**embedding_kwargs,
7274
)
75+
7376
case "fireworks":
7477
from langchain_fireworks import FireworksEmbeddings
75-
76-
_embeddings = FireworksEmbeddings(model=model, **embdding_kwargs)
78+
_embeddings = FireworksEmbeddings(model=model, **embedding_kwargs)
79+
7780
case "gigachat":
7881
from langchain_gigachat import GigaChatEmbeddings
79-
80-
_embeddings = GigaChatEmbeddings(model=model, **embdding_kwargs)
82+
_embeddings = GigaChatEmbeddings(model=model, **embedding_kwargs)
83+
8184
case "ollama":
8285
from langchain_ollama import OllamaEmbeddings
83-
86+
# Use provided base_url or fall back to environment variable
87+
ollama_base = base_url or os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
8488
_embeddings = OllamaEmbeddings(
8589
model=model,
86-
base_url=os.environ["OLLAMA_BASE_URL"],
87-
**embdding_kwargs,
90+
base_url=ollama_base,
91+
**embedding_kwargs,
8892
)
93+
8994
case "together":
9095
from langchain_together import TogetherEmbeddings
91-
92-
_embeddings = TogetherEmbeddings(model=model, **embdding_kwargs)
96+
_embeddings = TogetherEmbeddings(model=model, **embedding_kwargs)
97+
9398
case "mistralai":
9499
from langchain_mistralai import MistralAIEmbeddings
95-
96-
_embeddings = MistralAIEmbeddings(model=model, **embdding_kwargs)
100+
_embeddings = MistralAIEmbeddings(model=model, **embedding_kwargs)
101+
97102
case "huggingface":
98103
from langchain_huggingface import HuggingFaceEmbeddings
99-
100-
_embeddings = HuggingFaceEmbeddings(model_name=model, **embdding_kwargs)
104+
_embeddings = HuggingFaceEmbeddings(
105+
model_name=model, **embedding_kwargs
106+
)
107+
101108
case "nomic":
102109
from langchain_nomic import NomicEmbeddings
103-
104-
_embeddings = NomicEmbeddings(model=model, **embdding_kwargs)
110+
_embeddings = NomicEmbeddings(model=model, **embedding_kwargs)
111+
105112
case "voyageai":
106113
from langchain_voyageai import VoyageAIEmbeddings
107-
108114
_embeddings = VoyageAIEmbeddings(
109-
voyage_api_key=os.environ["VOYAGE_API_KEY"],
110115
model=model,
111-
**embdding_kwargs,
116+
voyage_api_key=os.getenv("VOYAGE_API_KEY"),
117+
**embedding_kwargs,
112118
)
119+
113120
case "dashscope":
114121
from langchain_community.embeddings import DashScopeEmbeddings
115-
116-
_embeddings = DashScopeEmbeddings(model=model, **embdding_kwargs)
122+
_embeddings = DashScopeEmbeddings(model=model, **embedding_kwargs)
123+
117124
case "bedrock":
118125
from langchain_aws.embeddings import BedrockEmbeddings
119-
120-
_embeddings = BedrockEmbeddings(model_id=model, **embdding_kwargs)
126+
_embeddings = BedrockEmbeddings(model_id=model, **embedding_kwargs)
127+
121128
case "aimlapi":
122129
from langchain_openai import OpenAIEmbeddings
123-
124130
_embeddings = OpenAIEmbeddings(
125131
model=model,
126-
openai_api_key=os.getenv("AIMLAPI_API_KEY"),
127-
openai_api_base=os.getenv("AIMLAPI_BASE_URL", "https://api.aimlapi.com/v1"),
128-
**embdding_kwargs,
132+
openai_api_key=os.getenv("OPENAI_API_KEY", "custom"),
133+
openai_api_base=base_url or os.getenv("AIMLAPI_BASE_URL"),
134+
**embedding_kwargs,
129135
)
130136
case _:
131137
raise Exception("Embedding not found.")

0 commit comments

Comments
 (0)