
Commit c0661c2

Merge remote-tracking branch 'origin/main' into 98-shift-to-litellm

2 parents a9e5be2 + 3ea31c1

5 files changed (+42, -15 lines)

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -67,3 +67,4 @@ spring_ai/drop.sql
 src/client/spring_ai/target/classes/*
 api_server_key
 .env
+
src/client/content/config/tabs/settings.py

Lines changed: 15 additions & 9 deletions

@@ -159,7 +159,9 @@ def spring_ai_conf_check(ll_model: dict, embed_model: dict) -> str:
 
     ll_provider = ll_model.get("provider", "")
     embed_provider = embed_model.get("provider", "")
-
+    logger.info(f"llm chat:{ll_provider} - embeddings:{embed_provider}")
+    if all("openai_compatible" in p for p in (ll_provider, embed_provider)):
+        return "openai_compatible"
     if all("openai" in p for p in (ll_provider, embed_provider)):
         return "openai"
     if all("ollama" in p for p in (ll_provider, embed_provider)):
@@ -342,6 +344,8 @@ def display_settings():
     embed_config = {}
     spring_ai_conf = spring_ai_conf_check(ll_config, embed_config)
 
+    logger.info(f"config found:{spring_ai_conf}")
+
     if spring_ai_conf == "hybrid":
         st.markdown(f"""
             The current configuration combination of embedding and language models
@@ -352,21 +356,23 @@ def display_settings():
     else:
         col_left, col_centre, _ = st.columns([3, 4, 3])
         with col_left:
-            st.download_button(
-                label="Download SpringAI",
-                data=spring_ai_zip(spring_ai_conf, ll_config, embed_config), # Generate zip on the fly
-                file_name="spring_ai.zip", # Zip file name
-                mime="application/zip", # Mime type for zip file
-                disabled=spring_ai_conf == "hybrid",
-            )
-        with col_centre:
             st.download_button(
                 label="Download LangchainMCP",
                 data=langchain_mcp_zip(settings), # Generate zip on the fly
                 file_name="langchain_mcp.zip", # Zip file name
                 mime="application/zip", # Mime type for zip file
                 disabled=spring_ai_conf == "hybrid",
             )
+        with col_centre:
+            if (spring_ai_conf != "openai_compatible"):
+                st.download_button(
+                    label="Download SpringAI",
+                    data=spring_ai_zip(spring_ai_conf, ll_config, embed_config), # Generate zip on the fly
+                    file_name="spring_ai.zip", # Zip file name
+                    mime="application/zip", # Mime type for zip file
+                    disabled=spring_ai_conf == "hybrid",
+                )
+
 
 
 if __name__ == "__main__":
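
Note: the order of the new checks in spring_ai_conf_check matters, because the substring "openai" also matches "openai_compatible". Testing "openai_compatible" first keeps OpenAI-compatible providers from being classified as plain OpenAI. A minimal sketch of the logic (the "hybrid" fallback for mixed providers is an assumption based on how display_settings treats that value):

    def conf_check(ll_provider: str, embed_provider: str) -> str:
        providers = (ll_provider, embed_provider)
        # Longer token first: "openai" in "openai_compatible" is True, so
        # reversing these two branches would misclassify compatible providers.
        if all("openai_compatible" in p for p in providers):
            return "openai_compatible"
        if all("openai" in p for p in providers):
            return "openai"
        if all("ollama" in p for p in providers):
            return "ollama"
        return "hybrid"  # assumed fallback for mixed providers

    assert conf_check("openai_compatible", "openai_compatible") == "openai_compatible"
    assert conf_check("ollama", "openai") == "hybrid"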

src/client/mcp/rag/optimizer_utils/config.py

Lines changed: 10 additions & 3 deletions

@@ -39,8 +39,11 @@ def get_llm(data):
         llm = OllamaLLM(model=model, base_url=url)
         logger.info("Ollama LLM created")
     elif provider == "openai":
-        llm = llm = ChatOpenAI(model=model, api_key=api_key)
+        llm = ChatOpenAI(model=model, api_key=api_key)
         logger.info("OpenAI LLM created")
+    elif provider == "openai_compatible":
+        llm = ChatOpenAI(model=model, api_key=api_key, base_url=url)
+        logger.info("OpenAI compatible LLM created")
     return llm
 
 
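Note: the new "openai_compatible" branch reuses langchain-openai's ChatOpenAI client and simply points it at the self-hosted endpoint via base_url. A hedged usage sketch; the model name and "EMPTY" api_key are illustrative, while the URL matches the vLLM default added in models.py below:

    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(
        model="Llama-3.2-1B-Instruct",  # example model served by vLLM
        api_key="EMPTY",                # self-hosted servers often accept any key
        base_url="http://gpu:8000/v1",  # the server's OpenAI-compatible route
    )
    print(llm.invoke("Say hello").content)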

@@ -60,9 +63,13 @@ def get_embeddings(data):
     if provider == "ollama":
         embeddings = OllamaEmbeddings(model=model, base_url=url)
         logger.info("Ollama Embeddings connection successful")
-    elif (provider == "openai") or (provider == "openai_compatible"):
+    elif (provider == "openai"):
         embeddings = OpenAIEmbeddings(model=model, api_key=api_key)
         logger.info("OpenAI embeddings connection successful")
+    elif (provider == "openai_compatible"):
+        embeddings = OpenAIEmbeddings(model=model, api_key=api_key, base_url=url, check_embedding_ctx_length=False)
+        logger.info("OpenAI compatible embeddings connection successful")
+
     return embeddings
 
 
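Note: check_embedding_ctx_length=False is the significant addition here. With the default (True), OpenAIEmbeddings tokenizes inputs with tiktoken and posts token arrays, which many OpenAI-compatible servers cannot decode; disabling it posts raw strings instead. A sketch under the same assumptions as above:

    from langchain_openai import OpenAIEmbeddings

    embeddings = OpenAIEmbeddings(
        model="nomic-ai/nomic-embed-text-v1",  # example model id
        api_key="EMPTY",
        base_url="http://gpu:8000/v1",
        check_embedding_ctx_length=False,  # send raw strings, not token arrays
    )
    vector = embeddings.embed_query("hello world")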

@@ -80,7 +87,7 @@ def get_vectorstore(data, embeddings):
     distance_metric=data["client_settings"]["vector_search"]["distance_metric"]
     index_type=data["client_settings"]["vector_search"]["index_type"]
 
-    db_table=(table_alias+"_"+model+"_"+chunk_size+"_"+chunk_overlap+"_"+distance_metric+"_"+index_type).upper().replace("-", "_")
+    db_table=(table_alias+"_"+model+"_"+chunk_size+"_"+chunk_overlap+"_"+distance_metric+"_"+index_type).upper().replace("-", "_").replace("/", "_")
     logger.info(f"db_table:{db_table}")
 
 
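Note: the extra .replace("/", "_") pairs with the embedding model rename in models.py below: HuggingFace-style ids such as "nomic-ai/nomic-embed-text-v1" contain a slash, which is not valid in an unquoted Oracle table name. For example (alias and sizes are illustrative):

    model = "nomic-ai/nomic-embed-text-v1"
    db_table = ("RAG_" + model + "_512_128").upper().replace("-", "_").replace("/", "_")
    print(db_table)  # RAG_NOMIC_AI_NOMIC_EMBED_TEXT_V1_512_128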

src/client/mcp/rag/optimizer_utils/rag.py

Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ def rag_tool_base(question: str) -> str:
 
     logger.info("rag_prompt:")
     logger.info(rag_prompt)
-    template = """DOCUMENTS: {context} \n"""+rag_prompt+"""\nQuestion: {question} """
+    template = rag_prompt+"""\n# DOCUMENTS :\n {context} \n"""+"""\n # Question: {question} """
     logger.info(template)
     logger.info(f"user_question: {user_question}")
     prompt = PromptTemplate.from_template(template)
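Note: the reordered template now leads with the system prompt and moves the retrieved documents and question into labelled sections below it. Rendered with a one-line stand-in prompt (illustrative; the real rag_prompt comes from client settings):

    from langchain_core.prompts import PromptTemplate

    rag_prompt = "Answer using only the documents below."  # stand-in
    template = rag_prompt + """\n# DOCUMENTS :\n {context} \n""" + """\n # Question: {question} """
    prompt = PromptTemplate.from_template(template)
    print(prompt.format(context="(retrieved chunks)", question="What is vLLM?"))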

src/server/bootstrap/models.py

Lines changed: 15 additions & 2 deletions

@@ -6,7 +6,7 @@
 added via the APIs
 """
 # spell-checker:ignore configfile genai ollama pplx docos mxbai nomic thenlper
-# spell-checker:ignore huggingface
+# spell-checker:ignore huggingface vllm
 
 import os
 
@@ -97,6 +97,18 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str):
         "max_completion_tokens": 2048,
         "frequency_penalty": 0.0,
     },
+    {
+        "id": "Llama-3.2-1B-Instruct",
+        "enabled": os.getenv("ON_PREM_VLLM_URL") is not None,
+        "type": "ll",
+        "provider": "meta-llama",
+        "api_key": "",
+        "url": os.environ.get("ON_PREM_VLLM_URL", default="http://gpu:8000/v1"),
+        "context_length": 131072,
+        "temperature": 1.0,
+        "max_completion_tokens": 2048,
+        "frequency_penalty": 0.0,
+    },
     {
         # This is intentionally last to line up with docos
         "id": "llama3.1",
@@ -138,7 +150,7 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str):
         "max_chunk_size": 512,
     },
     {
-        "id": "text-embedding-nomic-embed-text-v1.5",
+        "id": "nomic-ai/nomic-embed-text-v1",
         "enabled": False,
         "type": "embed",
         "provider": "huggingface",
@@ -212,6 +224,7 @@ def values_differ(a, b):
     update_env_var(model, "oci", "api_base", "OCI_GENAI_SERVICE_ENDPOINT")
     update_env_var(model, "ollama", "api_base", "ON_PREM_OLLAMA_URL")
     update_env_var(model, "huggingface", "api_base", "ON_PREM_HF_URL")
+    update_env_var(model, "meta-llama", "api_base", "ON_PREM_VLLM_URL")
 
     # Check URL accessible for enabled models and disable if not:
     url_access_cache = {}
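
Note: the new model entry follows this file's enable-on-env-var pattern: it only becomes enabled when ON_PREM_VLLM_URL is exported, and the update_env_var call keeps the provider's api_base synced to the same variable. A minimal sketch of the gating (variable name and default URL are from the diff; the rest is illustrative):

    import os

    os.environ["ON_PREM_VLLM_URL"] = "http://gpu:8000/v1"
    entry = {
        "id": "Llama-3.2-1B-Instruct",
        "enabled": os.getenv("ON_PREM_VLLM_URL") is not None,
        "url": os.environ.get("ON_PREM_VLLM_URL", default="http://gpu:8000/v1"),
    }
    print(entry["enabled"])  # True once the variable is set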
