Skip to content

Commit aa4d83d

Browse files
authored
Merge pull request #1 from 3coins/add-litellm
Add LiteLLM dependency and handlers
2 parents 036c734 + cdfe2fb commit aa4d83d

File tree

7 files changed

+110
-5
lines changed

7 files changed

+110
-5
lines changed

README.md

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,29 @@
22

33
[![Github Actions Status](https://github.com/jupyter-ai-contrib/jupyter-ai-litellm/workflows/Build/badge.svg)](https://github.com/jupyter-ai-contrib/jupyter-ai-litellm/actions/workflows/build.yml)
44

5-
A JupyterLab extension that provides LiteLLM model abstraction
5+
A JupyterLab extension that provides LiteLLM model abstraction for Jupyter AI
66

7-
This extension is composed of a Python package named `jupyter_ai_litellm`.
7+
This extension is composed of a Python package named `jupyter_ai_litellm` that exposes LiteLLM's extensive catalog of language models through a standardized API.
8+
9+
## Features
10+
11+
- **Comprehensive Model Support**: Access to hundreds of chat and embedding models from various providers (OpenAI, Anthropic, Google, Cohere, Azure, AWS, and more) through LiteLLM's unified interface
12+
- **Standardized API**: Consistent REST API endpoints for model discovery and interaction
13+
- **Easy Integration**: Seamlessly integrates with Jupyter AI to expand available model options
14+
15+
## API Endpoints
16+
17+
### Chat Models
18+
19+
- `GET /api/ai/models/chat` - Returns a list of all available chat models
20+
21+
The response includes model IDs in LiteLLM format (e.g., `openai/gpt-4` or `anthropic/claude-3-sonnet`)
22+
23+
### Model Lists
24+
25+
The extension automatically discovers and categorizes models from LiteLLM's supported providers:
26+
- Chat models for conversational AI
27+
- Embedding models for vector representations
828

929
## Requirements
1030

jupyter_ai_litellm/_version.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = '0.0.0'
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from jupyter_server.base.handlers import APIHandler as BaseAPIHandler
2+
from pydantic import BaseModel
3+
from tornado import web
4+
5+
from .model_list import CHAT_MODELS
6+
7+
8+
class ChatModelsRestAPI(BaseAPIHandler):
    """Tornado handler serving the REST API on `/api/ai/models/chat`.

    - `GET /api/ai/models/chat`: returns list of all chat models.

    - `GET /api/ai/models/chat?id=<model_id>`: returns info on that model (TODO)
    """

    @web.authenticated
    def get(self):
        # Serialize the model list through the pydantic response schema so the
        # JSON shape stays in sync with ListChatModelsResponse.
        payload = ListChatModelsResponse(chat_models=CHAT_MODELS)
        self.finish(payload.model_dump_json())
22+
23+
24+
class ListChatModelsResponse(BaseModel):
    """Response body for `GET /api/ai/models/chat`."""

    chat_models: list[str]
26+
27+
28+
class ListEmbeddingModelsResponse(BaseModel):
    """Response schema listing embedding model IDs.

    NOTE(review): no handler in this diff serves this schema yet — presumably
    reserved for a future embeddings endpoint; confirm.
    """

    embedding_models: list[str]

jupyter_ai_litellm/handlers.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from jupyter_server.utils import url_path_join
55
import tornado
66

7+
from .chat_models_rest_api import ChatModelsRestAPI
8+
79
class RouteHandler(APIHandler):
810
# The following decorator should be present on all verb methods (head, get, post,
911
# patch, put, delete, options) to ensure only authorized user can request the
def setup_handlers(web_app):
    """Register this extension's request handlers on the Jupyter server.

    Routes added (all under the server's `base_url`):
    - `jupyter-ai-litellm/get-example` -> RouteHandler
    - `api/ai/models/chat` -> ChatModelsRestAPI
    """
    host_pattern = ".*$"

    base_url = web_app.settings["base_url"]
    route_pattern = url_path_join(base_url, "jupyter-ai-litellm", "get-example")
    handlers = [
        (route_pattern, RouteHandler),
        # Tornado matches route patterns against the URL path only; the query
        # string never participates, so no `(?:\?.*)?` suffix is needed.
        (url_path_join(base_url, "api", "ai", "models", "chat"), ChatModelsRestAPI),
    ]
    web_app.add_handlers(host_pattern, handlers)

jupyter_ai_litellm/model_list.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from litellm import all_embedding_models, models_by_provider


def _partition_models(provider_models, embedding_models):
    """Split a provider->models mapping into chat and embedding model IDs.

    Args:
        provider_models: mapping of provider name to an iterable of model
            names, as in `litellm.models_by_provider`.
        embedding_models: iterable of model names LiteLLM knows to be
            embedding models.

    Returns:
        A `(chat_ids, embedding_ids)` tuple of sorted lists. Every ID follows
        the LiteLLM `"<provider>/<model>"` syntax; names that already carry
        their provider prefix are kept as-is (no double prefix).
    """
    embedding_set = set(embedding_models)
    chat_ids: list[str] = []
    embedding_ids: list[str] = []

    for provider, model_names in provider_models.items():
        for name in model_names:
            # Normalize to "<provider>/<model>" without doubling the prefix.
            if name.startswith(f"{provider}/"):
                model_id = name
            else:
                model_id = f"{provider}/{name}"

            # Heuristic: anything LiteLLM lists as an embedding model (by bare
            # name or full ID), or whose ID mentions "embed", is an embedding
            # model; everything else is treated as a chat model.
            is_embedding = (
                name in embedding_set
                or model_id in embedding_set
                or "embed" in model_id
            )

            if is_embedding:
                embedding_ids.append(model_id)
            else:
                chat_ids.append(model_id)

    return sorted(chat_ids), sorted(embedding_ids)


CHAT_MODELS, EMBEDDING_MODELS = _partition_models(
    models_by_provider, all_embedding_models
)

CHAT_MODELS = CHAT_MODELS
"""
List of chat model IDs, following the `litellm` syntax.
"""

EMBEDDING_MODELS = EMBEDDING_MODELS
"""
List of embedding model IDs, following the `litellm` syntax.
"""

jupyter_ai_litellm/tests/test_handlers.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,16 @@ async def test_get_example(jp_fetch):
1010
payload = json.loads(response.body)
1111
assert payload == {
1212
"data": "This is /jupyter-ai-litellm/get-example endpoint!"
13-
}
13+
}
14+
15+
async def test_get_chat_models(jp_fetch):
    """The chat-models endpoint responds 200 with a non-empty ID list."""
    # When
    response = await jp_fetch("api", "ai", "models", "chat")

    # Then
    assert response.code == 200
    payload = json.loads(response.body)
    chat_models = payload.get("chat_models")

    # A truthy list is necessarily present and non-empty; the original's extra
    # `len(chat_models) > 0` check was redundant. Assert the type instead.
    assert chat_models
    assert isinstance(chat_models, list)

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ authors = [
2727
{ name = "Project Jupyter", email = "[email protected]" },
2828
]
2929
dependencies = [
30-
"jupyter_server>=2.4.0,<3"
30+
"jupyter_server>=2.4.0,<3",
31+
"litellm>=1.73,<2",
3132
]
3233
dynamic = ["version"]
3334

0 commit comments

Comments
 (0)