Skip to content

Commit 953e3cf

Browse files
committed
Added litellm, models handler
1 parent 036c734 commit 953e3cf

File tree

5 files changed

+87
-3
lines changed

5 files changed

+87
-3
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from jupyter_server.base.handlers import APIHandler as BaseAPIHandler
2+
from pydantic import BaseModel
3+
from tornado import web
4+
5+
from .model_list import CHAT_MODELS
6+
7+
8+
class ChatModelsRestAPI(BaseAPIHandler):
    """
    A Tornado handler that defines the REST API served on the
    `/api/models/chat` endpoint (relative to the server's base URL).

    - `GET /api/models/chat`: returns list of all chat models.

    - `GET /api/models/chat?id=<model_id>`: returns info on that model (TODO)
    """

    @web.authenticated
    def get(self):
        """
        Respond with the JSON serialization of a `ListChatModelsResponse`
        built from every ID in `CHAT_MODELS`.
        """
        # NOTE: the `id` query argument described in the class docstring is
        # not read yet; every request receives the full list.
        response = ListChatModelsResponse(chat_models=CHAT_MODELS)
        self.finish(response.model_dump_json())
22+
23+
24+
class ListChatModelsResponse(BaseModel):
    """
    Pydantic model with a single `chat_models` field.
    """

    # Chat model IDs, as a list of strings.
    chat_models: list[str]
26+
27+
28+
class ListEmbeddingModelsResponse(BaseModel):
    """
    Pydantic model with a single `embedding_models` field.
    """

    # Embedding model IDs, as a list of strings.
    embedding_models: list[str]

jupyter_ai_litellm/handlers.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from jupyter_server.utils import url_path_join
55
import tornado
66

7+
from .chat_models_rest_api import ChatModelsRestAPI
8+
79
class RouteHandler(APIHandler):
810
# The following decorator should be present on all verb methods (head, get, post,
911
# patch, put, delete, options) to ensure only authorized user can request the
@@ -19,6 +21,10 @@ def setup_handlers(web_app):
1921
host_pattern = ".*$"
2022

2123
base_url = web_app.settings["base_url"]
24+
print(f"Base url is {base_url}")
2225
route_pattern = url_path_join(base_url, "jupyter-ai-litellm", "get-example")
23-
handlers = [(route_pattern, RouteHandler)]
26+
handlers = [
27+
(route_pattern, RouteHandler),
28+
(url_path_join(base_url, "api/models/chat") + r"(?:\?.*)?", ChatModelsRestAPI)
29+
]
2430
web_app.add_handlers(host_pattern, handlers)

jupyter_ai_litellm/model_list.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from litellm import all_embedding_models, models_by_provider
2+
3+
chat_model_ids = []
embedding_model_ids = []
embedding_model_set = set(all_embedding_models)

# Walk every provider's model list and file each model under either the chat
# or the embedding bucket, keyed by its provider-qualified ID.
for provider_name, provider_models in models_by_provider.items():
    provider_prefix = f"{provider_name}/"

    for model_name in provider_models:
        # Qualify the name with its provider unless it already is qualified.
        if model_name.startswith(provider_prefix):
            model_id = model_name
        else:
            model_id = f"{provider_prefix}{model_name}"

        # A model is treated as an embedding model when `litellm` lists it in
        # `all_embedding_models` (under its bare or qualified name) or, as a
        # fallback heuristic, when its ID contains "embed".
        is_embedding = (
            model_name in embedding_model_set
            or model_id in embedding_model_set
            or "embed" in model_id
        )

        if is_embedding:
            embedding_model_ids.append(model_id)
        else:
            chat_model_ids.append(model_id)


# `set()` drops repeats: a provider that lists one model both bare and
# provider-qualified would otherwise contribute the same ID twice.
CHAT_MODELS = sorted(set(chat_model_ids))
"""
List of chat model IDs, following the `litellm` syntax.
"""

EMBEDDING_MODELS = sorted(set(embedding_model_ids))
"""
List of embedding model IDs, following the `litellm` syntax.
"""

jupyter_ai_litellm/tests/test_handlers.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,16 @@ async def test_get_example(jp_fetch):
1010
payload = json.loads(response.body)
1111
assert payload == {
1212
"data": "This is /jupyter-ai-litellm/get-example endpoint!"
13-
}
13+
}
14+
15+
async def test_get_chat_models(jp_fetch):
    """The chat-models endpoint returns a non-empty JSON list of model IDs."""
    # When
    response = await jp_fetch("api", "models", "chat")

    # Then
    assert response.code == 200
    payload = json.loads(response.body)
    chat_models = payload.get("chat_models")

    # Check the shape as well as non-emptiness, so that a truthy non-list
    # value (or a list of non-strings) cannot satisfy the test.
    assert isinstance(chat_models, list)
    assert chat_models
    assert all(isinstance(model_id, str) for model_id in chat_models)

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ authors = [
2727
{ name = "Project Jupyter", email = "[email protected]" },
2828
]
2929
dependencies = [
30-
"jupyter_server>=2.4.0,<3"
30+
"jupyter_server>=2.4.0,<3",
31+
"litellm>=1.73,<2",
3132
]
3233
dynamic = ["version"]
3334

0 commit comments

Comments
 (0)