|
27 | 27 |
|
28 | 28 |
|
29 | 29 | class Memory:
|
30 |
| - def __init__(self, embedding_provider: str, model: str, **embdding_kwargs: Any): |
| 30 | + def __init__(self, embedding_provider: str, model: str, **embedding_kwargs: Any): |
31 | 31 | _embeddings = None
|
| 32 | + |
| 33 | + # Get base URL from kwargs or environment |
| 34 | + base_url = embedding_kwargs.pop('base_url', None) or os.environ.get('EMBEDDING_ENDPOINT') |
| 35 | + |
32 | 36 | match embedding_provider:
|
33 |
| - case "custom": |
| 37 | + case "custom" | "openai": |
34 | 38 | from langchain_openai import OpenAIEmbeddings
|
35 |
| - |
| 39 | + |
| 40 | + # For custom endpoints, use a dummy key if none provided |
| 41 | + api_key = os.getenv("OPENAI_API_KEY", "dummy") |
| 42 | + if embedding_provider == "custom" and not base_url: |
| 43 | + base_url = os.getenv("OPENAI_BASE_URL", "http://localhost:1234/v1") |
| 44 | + |
36 | 45 | _embeddings = OpenAIEmbeddings(
|
37 | 46 | model=model,
|
38 |
| - openai_api_key=os.getenv("OPENAI_API_KEY", "custom"), |
39 |
| - openai_api_base=os.getenv( |
40 |
| - "OPENAI_BASE_URL", "http://localhost:1234/v1" |
41 |
| - ), # default for lmstudio |
| 47 | + openai_api_key=api_key, |
| 48 | + openai_api_base=base_url, |
42 | 49 | check_embedding_ctx_length=False,
|
43 |
| - **embdding_kwargs, |
44 |
| - ) # quick fix for lmstudio |
45 |
| - case "openai": |
46 |
| - from langchain_openai import OpenAIEmbeddings |
47 |
| - |
48 |
| - _embeddings = OpenAIEmbeddings(model=model, **embdding_kwargs) |
| 50 | + **embedding_kwargs, |
| 51 | + ) |
49 | 52 | case "azure_openai":
|
50 | 53 | from langchain_openai import AzureOpenAIEmbeddings
|
51 | 54 |
|
52 | 55 | _embeddings = AzureOpenAIEmbeddings(
|
53 | 56 | model=model,
|
54 |
| - azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"], |
55 |
| - openai_api_key=os.environ["AZURE_OPENAI_API_KEY"], |
56 |
| - openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"], |
57 |
| - **embdding_kwargs, |
| 57 | + openai_api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2023-05-15"), |
| 58 | + azure_deployment=model, |
| 59 | + **embedding_kwargs, |
58 | 60 | )
|
59 | 61 | case "cohere":
|
60 | 62 | from langchain_cohere import CohereEmbeddings
|
61 |
| - |
62 |
| - _embeddings = CohereEmbeddings(model=model, **embdding_kwargs) |
| 63 | + _embeddings = CohereEmbeddings(model=model, **embedding_kwargs) |
| 64 | + |
63 | 65 | case "google_vertexai":
|
64 | 66 | from langchain_google_vertexai import VertexAIEmbeddings
|
65 |
| - |
66 |
| - _embeddings = VertexAIEmbeddings(model=model, **embdding_kwargs) |
| 67 | + _embeddings = VertexAIEmbeddings(model=model, **embedding_kwargs) |
| 68 | + |
67 | 69 | case "google_genai":
|
68 | 70 | from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
69 |
| - |
70 | 71 | _embeddings = GoogleGenerativeAIEmbeddings(
|
71 |
| - model=model, **embdding_kwargs |
| 72 | + model=model, |
| 73 | + **embedding_kwargs, |
72 | 74 | )
|
| 75 | + |
73 | 76 | case "fireworks":
|
74 | 77 | from langchain_fireworks import FireworksEmbeddings
|
75 |
| - |
76 |
| - _embeddings = FireworksEmbeddings(model=model, **embdding_kwargs) |
| 78 | + _embeddings = FireworksEmbeddings(model=model, **embedding_kwargs) |
| 79 | + |
77 | 80 | case "gigachat":
|
78 | 81 | from langchain_gigachat import GigaChatEmbeddings
|
79 |
| - |
80 |
| - _embeddings = GigaChatEmbeddings(model=model, **embdding_kwargs) |
| 82 | + _embeddings = GigaChatEmbeddings(model=model, **embedding_kwargs) |
| 83 | + |
81 | 84 | case "ollama":
|
82 | 85 | from langchain_ollama import OllamaEmbeddings
|
83 |
| - |
| 86 | + # Use provided base_url or fall back to environment variable |
| 87 | + ollama_base = base_url or os.getenv("OLLAMA_BASE_URL", "http://localhost:11434") |
84 | 88 | _embeddings = OllamaEmbeddings(
|
85 | 89 | model=model,
|
86 |
| - base_url=os.environ["OLLAMA_BASE_URL"], |
87 |
| - **embdding_kwargs, |
| 90 | + base_url=ollama_base, |
| 91 | + **embedding_kwargs, |
88 | 92 | )
|
| 93 | + |
89 | 94 | case "together":
|
90 | 95 | from langchain_together import TogetherEmbeddings
|
91 |
| - |
92 |
| - _embeddings = TogetherEmbeddings(model=model, **embdding_kwargs) |
| 96 | + _embeddings = TogetherEmbeddings(model=model, **embedding_kwargs) |
| 97 | + |
93 | 98 | case "mistralai":
|
94 | 99 | from langchain_mistralai import MistralAIEmbeddings
|
95 |
| - |
96 |
| - _embeddings = MistralAIEmbeddings(model=model, **embdding_kwargs) |
| 100 | + _embeddings = MistralAIEmbeddings(model=model, **embedding_kwargs) |
| 101 | + |
97 | 102 | case "huggingface":
|
98 | 103 | from langchain_huggingface import HuggingFaceEmbeddings
|
99 |
| - |
100 |
| - _embeddings = HuggingFaceEmbeddings(model_name=model, **embdding_kwargs) |
| 104 | + _embeddings = HuggingFaceEmbeddings( |
| 105 | + model_name=model, **embedding_kwargs |
| 106 | + ) |
| 107 | + |
101 | 108 | case "nomic":
|
102 | 109 | from langchain_nomic import NomicEmbeddings
|
103 |
| - |
104 |
| - _embeddings = NomicEmbeddings(model=model, **embdding_kwargs) |
| 110 | + _embeddings = NomicEmbeddings(model=model, **embedding_kwargs) |
| 111 | + |
105 | 112 | case "voyageai":
|
106 | 113 | from langchain_voyageai import VoyageAIEmbeddings
|
107 |
| - |
108 | 114 | _embeddings = VoyageAIEmbeddings(
|
109 |
| - voyage_api_key=os.environ["VOYAGE_API_KEY"], |
110 | 115 | model=model,
|
111 |
| - **embdding_kwargs, |
| 116 | + voyage_api_key=os.getenv("VOYAGE_API_KEY"), |
| 117 | + **embedding_kwargs, |
112 | 118 | )
|
| 119 | + |
113 | 120 | case "dashscope":
|
114 | 121 | from langchain_community.embeddings import DashScopeEmbeddings
|
115 |
| - |
116 |
| - _embeddings = DashScopeEmbeddings(model=model, **embdding_kwargs) |
| 122 | + _embeddings = DashScopeEmbeddings(model=model, **embedding_kwargs) |
| 123 | + |
117 | 124 | case "bedrock":
|
118 | 125 | from langchain_aws.embeddings import BedrockEmbeddings
|
119 |
| - |
120 |
| - _embeddings = BedrockEmbeddings(model_id=model, **embdding_kwargs) |
| 126 | + _embeddings = BedrockEmbeddings(model_id=model, **embedding_kwargs) |
| 127 | + |
121 | 128 | case "aimlapi":
|
122 | 129 | from langchain_openai import OpenAIEmbeddings
|
123 |
| - |
124 | 130 | _embeddings = OpenAIEmbeddings(
|
125 | 131 | model=model,
|
126 |
| - openai_api_key=os.getenv("AIMLAPI_API_KEY"), |
127 |
| - openai_api_base=os.getenv("AIMLAPI_BASE_URL", "https://api.aimlapi.com/v1"), |
128 |
| - **embdding_kwargs, |
| 132 | + openai_api_key=os.getenv("OPENAI_API_KEY", "custom"), |
| 133 | + openai_api_base=base_url or os.getenv("AIMLAPI_BASE_URL"), |
| 134 | + **embedding_kwargs, |
129 | 135 | )
|
130 | 136 | case _:
|
131 | 137 | raise Exception("Embedding not found.")
|
|
0 commit comments