Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions python/api/model_groups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import uuid
from datetime import datetime
from typing import Any

from python.helpers import settings
from python.helpers.settings import ModelGroupConfig


# Fields needed in model group config
MODEL_GROUP_FIELDS = [
"chat_model_provider", "chat_model_name", "chat_model_api_base",
"util_model_provider", "util_model_name", "util_model_api_base",
"browser_model_provider", "browser_model_name", "browser_model_api_base",
"embed_model_provider", "embed_model_name", "embed_model_api_base",
]


def create_model_group_from_current() -> ModelGroupConfig:
"""Create model group config from current settings"""
current = settings.get_settings()
return ModelGroupConfig(
id=str(uuid.uuid4()),
name="",
description="",
created_at=datetime.now().isoformat(),
chat_model_provider=current["chat_model_provider"],
chat_model_name=current["chat_model_name"],
chat_model_api_base=current["chat_model_api_base"],
util_model_provider=current["util_model_provider"],
util_model_name=current["util_model_name"],
util_model_api_base=current["util_model_api_base"],
browser_model_provider=current["browser_model_provider"],
browser_model_name=current["browser_model_name"],
browser_model_api_base=current["browser_model_api_base"],
embed_model_provider=current["embed_model_provider"],
embed_model_name=current["embed_model_name"],
embed_model_api_base=current["embed_model_api_base"],
)


def apply_model_group_to_settings(group: ModelGroupConfig) -> None:
"""Apply model group config to current settings"""
current = settings.get_settings()
for field in MODEL_GROUP_FIELDS:
if field in group:
current[field] = group[field]
current["active_model_group_id"] = group["id"]
settings.set_settings(current)
42 changes: 42 additions & 0 deletions python/api/model_groups_create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import uuid
from datetime import datetime
from python.helpers.api import ApiHandler, Request, Response
from python.helpers import settings
from python.helpers.settings import ModelGroupConfig


class ModelGroupsCreate(ApiHandler):
async def process(self, input: dict, request: Request) -> dict | Response:
name = input.get("name", "").strip()
description = input.get("description", "").strip()

if not name:
return {"ok": False, "error": "Name is required"}

# Create new model group
group = ModelGroupConfig(
id=str(uuid.uuid4()),
name=name,
description=description,
created_at=datetime.now().isoformat(),
chat_model_provider=input.get("chat_model_provider", ""),
chat_model_name=input.get("chat_model_name", ""),
chat_model_api_base=input.get("chat_model_api_base", ""),
util_model_provider=input.get("util_model_provider", ""),
util_model_name=input.get("util_model_name", ""),
util_model_api_base=input.get("util_model_api_base", ""),
browser_model_provider=input.get("browser_model_provider", ""),
browser_model_name=input.get("browser_model_name", ""),
browser_model_api_base=input.get("browser_model_api_base", ""),
embed_model_provider=input.get("embed_model_provider", ""),
embed_model_name=input.get("embed_model_name", ""),
embed_model_api_base=input.get("embed_model_api_base", ""),
)

current = settings.get_settings()
groups = current.get("model_groups", [])
groups.append(group)
current["model_groups"] = groups
settings.set_settings(current, apply=False)

return {"ok": True, "group": group}
27 changes: 27 additions & 0 deletions python/api/model_groups_delete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from python.helpers.api import ApiHandler, Request, Response
from python.helpers import settings


class ModelGroupsDelete(ApiHandler):
async def process(self, input: dict, request: Request) -> dict | Response:
group_id = input.get("group_id")
if not group_id:
return {"ok": False, "error": "Missing group_id"}

current = settings.get_settings()
groups = current.get("model_groups", [])

new_groups = [g for g in groups if g["id"] != group_id]

if len(new_groups) == len(groups):
return {"ok": False, "error": "Model group not found"}

current["model_groups"] = new_groups

# If deleted group was active, clear active state
if current.get("active_model_group_id") == group_id:
current["active_model_group_id"] = ""

settings.set_settings(current, apply=False)

return {"ok": True, "message": "Model group deleted"}
36 changes: 36 additions & 0 deletions python/api/model_groups_duplicate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import uuid
import copy
from datetime import datetime
from python.helpers.api import ApiHandler, Request, Response
from python.helpers import settings


class ModelGroupsDuplicate(ApiHandler):
async def process(self, input: dict, request: Request) -> dict | Response:
group_id = input.get("group_id")
if not group_id:
return {"ok": False, "error": "Missing group_id"}

current = settings.get_settings()
groups = current.get("model_groups", [])

source_group = None
for group in groups:
if group["id"] == group_id:
source_group = group
break

if not source_group:
return {"ok": False, "error": "Model group not found"}

# Create copy
new_group = copy.deepcopy(source_group)
new_group["id"] = str(uuid.uuid4())
new_group["name"] = f"{source_group['name']} (Copy)"
new_group["created_at"] = datetime.now().isoformat()

groups.append(new_group)
current["model_groups"] = groups
settings.set_settings(current, apply=False)

return {"ok": True, "group": new_group}
15 changes: 15 additions & 0 deletions python/api/model_groups_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from python.helpers.api import ApiHandler, Request, Response
from python.helpers import settings


class ModelGroupsList(ApiHandler):
async def process(self, input: dict, request: Request) -> dict | Response:
current = settings.get_settings()
groups = current.get("model_groups", [])
active_group_id = current.get("active_model_group_id", "")

return {
"ok": True,
"groups": groups,
"active_group_id": active_group_id
}
18 changes: 18 additions & 0 deletions python/api/model_groups_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from python.helpers.api import ApiHandler, Request, Response
from python.helpers.providers import get_providers


class ModelGroupsProviders(ApiHandler):
async def process(self, input: dict, request: Request) -> dict | Response:
chat_providers = get_providers("chat")
embed_providers = get_providers("embedding")

return {
"ok": True,
"chat_providers": chat_providers,
"embed_providers": embed_providers,
}

@classmethod
def get_methods(cls) -> list[str]:
return ["GET", "POST"]
26 changes: 26 additions & 0 deletions python/api/model_groups_save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from python.helpers.api import ApiHandler, Request, Response
from python.helpers import settings
from python.api.model_groups import create_model_group_from_current


class ModelGroupsSave(ApiHandler):
async def process(self, input: dict, request: Request) -> dict | Response:
name = input.get("name", "").strip()
description = input.get("description", "").strip()

if not name:
return {"ok": False, "error": "Name is required"}

# Create model group from current settings
group = create_model_group_from_current()
group["name"] = name
group["description"] = description

current = settings.get_settings()
groups = current.get("model_groups", [])
groups.append(group)
current["model_groups"] = groups
current["active_model_group_id"] = group["id"]
settings.set_settings(current, apply=False)

return {"ok": True, "group": group}
37 changes: 37 additions & 0 deletions python/api/model_groups_switch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from python.helpers.api import ApiHandler, Request, Response
from python.helpers import settings
from python.api.model_groups import MODEL_GROUP_FIELDS


class ModelGroupsSwitch(ApiHandler):
async def process(self, input: dict, request: Request) -> dict | Response:
group_id = input.get("group_id") # Can be empty or null, meaning switch to default/manual config

current = settings.get_settings()

if not group_id:
# Switch to default config (clear model group activation)
current["active_model_group_id"] = ""
settings.set_settings(current, apply=False)
return {"ok": True, "message": "Switched to default configuration"}

# Find model group
groups = current.get("model_groups", [])
target_group = None
for group in groups:
if group["id"] == group_id:
target_group = group
break

if not target_group:
return {"ok": False, "error": "Model group not found"}

# Apply model group config
for field in MODEL_GROUP_FIELDS:
if field in target_group:
current[field] = target_group[field]

current["active_model_group_id"] = group_id
settings.set_settings(current) # apply=True triggers model reinitialization

return {"ok": True, "message": f"Switched to model group: {target_group['name']}"}
42 changes: 42 additions & 0 deletions python/api/model_groups_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from python.helpers.api import ApiHandler, Request, Response
from python.helpers import settings
from python.api.model_groups import MODEL_GROUP_FIELDS


class ModelGroupsUpdate(ApiHandler):
async def process(self, input: dict, request: Request) -> dict | Response:
group_id = input.get("group_id")
if not group_id:
return {"ok": False, "error": "Missing group_id"}

current = settings.get_settings()
groups = current.get("model_groups", [])
active_group_id = current.get("active_model_group_id", "")

for i, group in enumerate(groups):
if group["id"] == group_id:
# Update fields
if "name" in input:
group["name"] = input["name"].strip()
if "description" in input:
group["description"] = input["description"].strip()

for field in MODEL_GROUP_FIELDS:
if field in input:
group[field] = input[field]

groups[i] = group
current["model_groups"] = groups

# If this is the active group, also update current settings
if group_id == active_group_id:
for field in MODEL_GROUP_FIELDS:
if field in group:
current[field] = group[field]
settings.set_settings(current, apply=True)
else:
settings.set_settings(current, apply=False)

return {"ok": True, "group": group}

return {"ok": False, "error": "Model group not found"}
33 changes: 33 additions & 0 deletions python/helpers/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,33 @@ def get_default_value(name: str, value: T) -> T:
)
return value

class ModelGroupConfig(TypedDict):
"""Model group configuration, including provider, name, and api_base for four types of models"""
id: str # Unique ID (UUID)
name: str # Display name
description: str # Description (optional)
created_at: str # Creation time

# Chat Model
chat_model_provider: str
chat_model_name: str
chat_model_api_base: str

# Utility Model
util_model_provider: str
util_model_name: str
util_model_api_base: str

# Browser Model
browser_model_provider: str
browser_model_name: str
browser_model_api_base: str

# Embedding Model
embed_model_provider: str
embed_model_name: str
embed_model_api_base: str


class Settings(TypedDict):
version: str
Expand Down Expand Up @@ -148,6 +175,10 @@ class Settings(TypedDict):

update_check_enabled: bool

# Model groups
model_groups: list[ModelGroupConfig]
active_model_group_id: str

class PartialSettings(Settings, total=False):
pass

Expand Down Expand Up @@ -1568,6 +1599,8 @@ def get_default_settings() -> Settings:
secrets="",
litellm_global_kwargs=get_default_value("litellm_global_kwargs", {}),
update_check_enabled=get_default_value("update_check_enabled", True),
model_groups=get_default_value("model_groups", []),
active_model_group_id=get_default_value("active_model_group_id", ""),
)


Expand Down
6 changes: 6 additions & 0 deletions webui/components/chat/top-section/chat-top.html
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
<!-- Import the alpine store -->
<script type="module">
import { store } from "/components/chat/top-section/chat-top-store.js";
import { store as modelGroupsStore } from "/components/model-groups/model-group-store.js";
</script>
</head>

Expand Down Expand Up @@ -33,11 +34,16 @@
<x-component path="notifications/notification-icons.html"></x-component>
<!-- Project Selector -->
<x-component path="projects/project-selector.html"></x-component>
<!-- Model Groups Switcher -->
<x-component path="model-groups/model-group-switcher.html"></x-component>
</div>

</template>
</div>

<!-- Model Groups Modal -->
<x-component path="model-groups/model-group-modal.html"></x-component>

</body>

</html>
Loading