Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,19 @@ To create and register a custom AI persona:
```python
from jupyter_ai_persona_manager import BasePersona, PersonaDefaults
from jupyterlab_chat.models import Message
import os

# Path to avatar file in your package
AVATAR_PATH = os.path.join(os.path.dirname(__file__), "assets", "avatar.svg")


class MyCustomPersona(BasePersona):
@property
def defaults(self):
return PersonaDefaults(
name="MyPersona",
description="A helpful custom assistant",
avatar_path="/api/ai/static/custom-avatar.svg",
avatar_path=AVATAR_PATH, # Absolute path to avatar file
system_prompt="You are a helpful assistant specialized in...",
)

Expand All @@ -39,6 +44,8 @@ class MyCustomPersona(BasePersona):
self.send_message(response)
```

**Avatar Path**: The `avatar_path` should be an absolute path to an image file (SVG, PNG, or JPG) within your package. The avatar will be automatically served at `/api/ai/avatars/{filename}`. If multiple personas use the same filename, the first one found will be served.

### 2. Register via Entry Points

Add to your package's `pyproject.toml`:
Expand Down Expand Up @@ -85,21 +92,28 @@ For development and local customization, personas can be loaded from the `.jupyt
```python
from jupyter_ai_persona_manager import BasePersona, PersonaDefaults
from jupyterlab_chat.models import Message
import os

# Path to avatar file (in same directory as persona file)
AVATAR_PATH = os.path.join(os.path.dirname(__file__), "avatar.svg")


class MyLocalPersona(BasePersona):
@property
def defaults(self):
return PersonaDefaults(
name="Local Dev Assistant",
description="A persona for local development",
avatar_path="/api/ai/static/jupyternaut.svg",
avatar_path=AVATAR_PATH,
system_prompt="You help with local development tasks.",
)

async def process_message(self, message: Message):
self.send_message(f"Local persona received: {message.body}")
```

**Note**: Place your avatar file (e.g., `avatar.svg`) in the same directory as your persona file.

### Refreshing Personas

Use the `/refresh-personas` slash command in any chat to reload personas without restarting JupyterLab:
Expand Down
5 changes: 3 additions & 2 deletions jupyter_ai_persona_manager/base_persona.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class PersonaDefaults(BaseModel):
################################################
name: str # e.g. "Jupyternaut"
description: str # e.g. "..."
avatar_path: str # e.g. /avatars/jupyternaut.svg
avatar_path: str # e.g. "/path/to/package/avatars/jupyternaut.svg" - absolute path to avatar file
system_prompt: str # e.g. "You are a language model named..."

################################################
Expand Down Expand Up @@ -179,7 +179,8 @@ def avatar_path(self) -> str:
This is set here because we may require this field to be configurable
for all personas in the future.
"""
return self.defaults.avatar_path
filename = os.path.basename(self.defaults.avatar_path)
return f"/api/ai/avatars/{filename}"

@property
def system_prompt(self) -> str:
Expand Down
21 changes: 11 additions & 10 deletions jupyter_ai_persona_manager/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import time
from asyncio import get_event_loop_policy
from typing import TYPE_CHECKING

from jupyter_server.extension.application import ExtensionApp
Expand All @@ -10,7 +11,7 @@
from traitlets import Type
from traitlets.config import Config

from jupyter_ai_persona_manager.handlers import RouteHandler
from jupyter_ai_persona_manager.handlers import RouteHandler, AvatarHandler

from .persona_manager import PersonaManager

Expand All @@ -30,8 +31,9 @@ class PersonaManagerExtension(ExtensionApp):

name = "jupyter_ai_persona_manager"
handlers = [
(r"jupyter-ai-persona-manager/health/?", RouteHandler)
] # No direct HTTP handlers, works through router integration
(r"jupyter-ai-persona-manager/health/?", RouteHandler),
(r"/api/ai/avatars/(.*)", AvatarHandler),
]

persona_manager_class = Type(
klass=PersonaManager,
Expand All @@ -48,28 +50,27 @@ def event_loop(self) -> AbstractEventLoop:
"""
Returns a reference to the asyncio event loop.
"""
from asyncio import get_event_loop_policy
return get_event_loop_policy().get_event_loop()

def initialize_settings(self):
"""Initialize persona manager settings and router integration."""
start = time.time()

# Ensure 'jupyter-ai.persona-manager' is in `self.settings`, which gets
# copied to `self.serverapp.web_app.settings` after this method returns
if 'jupyter-ai' not in self.settings:
self.settings['jupyter-ai'] = {}
if 'persona-manager' not in self.settings['jupyter-ai']:
self.settings['jupyter-ai']['persona-managers'] = {}

# Set up router integration task
self.event_loop.create_task(self._setup_router_integration())

# Log server extension startup time
self.log.info(f"Registered {self.name} server extension")
startup_time = round((time.time() - start) * 1000)
self.log.info(f"Initialized Persona Manager server extension in {startup_time} ms.")

async def _setup_router_integration(self) -> None:
"""
Set up integration with jupyter-ai-router.
Expand Down Expand Up @@ -110,7 +111,7 @@ def _on_router_chat_init(self, room_id: str, ychat: "YChat") -> None:
This initializes persona manager for the new chat room.
"""
self.log.info(f"Router detected new chat room, initializing persona manager: {room_id}")

# Initialize persona manager for this chat
persona_manager = self._init_persona_manager(room_id, ychat)
if not persona_manager:
Expand All @@ -119,7 +120,7 @@ def _on_router_chat_init(self, room_id: str, ychat: "YChat") -> None:
+ "Please verify your configuration and open a new issue on GitHub if this error persists."
)
return

# Cache the persona manager in server settings dictionary.
#
# NOTE: This must be added to `self.serverapp.web_app.settings`, not
Expand All @@ -128,7 +129,7 @@ def _on_router_chat_init(self, room_id: str, ychat: "YChat") -> None:
# `self.initialize_settings` returns.
persona_managers_by_room = self.serverapp.web_app.settings['jupyter-ai']['persona-managers']
persona_managers_by_room[room_id] = persona_manager

# Register persona manager callbacks with router
self.router.observe_chat_msg(room_id, persona_manager.on_chat_message)

Expand Down
72 changes: 70 additions & 2 deletions jupyter_ai_persona_manager/handlers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,82 @@
import json
import mimetypes
import os
from functools import lru_cache
from typing import Optional

from jupyter_server.base.handlers import APIHandler
from jupyter_server.base.handlers import JupyterHandler
import tornado

class RouteHandler(APIHandler):

class RouteHandler(JupyterHandler):
# The following decorator should be present on all verb methods (head, get, post,
# patch, put, delete, options) to ensure only authorized user can request the
# Jupyter server
@tornado.web.authenticated
def get(self):
self.set_header("Content-Type", "application/json")
self.finish(json.dumps({
"data": "This is /jupyter-ai-persona-manager/get-example endpoint!"
}))


class AvatarHandler(JupyterHandler):
"""
Handler for serving persona avatar files.
Looks up avatar files through the PersonaManager to find the correct file path,
then serves the file with appropriate content-type headers.
"""

@tornado.web.authenticated
async def get(self, filename: str):
"""Serve an avatar file by filename."""
# Get the avatar file path from persona managers
avatar_path = self._find_avatar_file(filename)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because this is on the handler, it will walk all personas on each request. Any way to remove this from the per-request logic?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe update all of this on persona loading?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed! Implemented a module-level cache that's built once at initialization and rebuilt on persona refresh. Avatar requests now do O(1) dictionary lookup instead of iterating all personas. See the latest commit for details.


if avatar_path is None:
raise tornado.web.HTTPError(404, f"Avatar file not found: {filename}")

# Serve the file
try:
# Set content type based on file extension
content_type, _ = mimetypes.guess_type(avatar_path)
if content_type:
self.set_header("Content-Type", content_type)

# Read and serve the file
with open(avatar_path, 'rb') as f:
content = f.read()
self.write(content)

await self.finish()
except Exception as e:
self.log.error(f"Error serving avatar file {filename}: {e}")
raise tornado.web.HTTPError(500, f"Error serving avatar file: {str(e)}")

@lru_cache(maxsize=128)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This cache will be keyed on both self and filename. Because this is on the handler you will get many of these caches rather than a small number that will do the caching you want. Can you extract into a function that does the caching you want?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And maybe invalidate the cache on reloading personas?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent catch! You're absolutely right - the lru_cache on an instance method would create a new cache for each request handler instance, making it useless.

I've implemented a module-level cache instead:

  • _avatar_cache: Dict[str, str] dictionary at module level
  • build_avatar_cache() function called when personas are initialized or refreshed
  • _find_avatar_file() now does O(1) dictionary lookup instead of iterating
  • Cache is automatically rebuilt when /refresh-personas is called

This addresses your performance concern about walking all personas on each request.

def _find_avatar_file(self, filename: str) -> Optional[str]:
"""
Find the avatar file path by searching through all persona managers.
Uses LRU cache to avoid repeated lookups for the same filename.
"""
# Get all persona managers from settings
persona_managers = self.settings.get('jupyter-ai', {}).get('persona-managers', {})

for room_id, persona_manager in persona_managers.items():
# Check each persona's avatar path
for persona in persona_manager.personas.values():
try:
avatar_path = persona.defaults.avatar_path
if avatar_path and os.path.basename(avatar_path) == filename:
# Found a match, return the absolute path
if os.path.exists(avatar_path):
return avatar_path
else:
self.log.warning(f"Avatar file not found at path: {avatar_path}")
except Exception as e:
self.log.warning(f"Error checking avatar for persona {persona.name}: {e}")
continue

return None
1 change: 1 addition & 0 deletions jupyter_ai_persona_manager/persona_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ def _init_personas(self) -> dict[str, BasePersona]:
self.log.info(
f"SUCCESS: Initialized {len(personas)} AI personas for chat room '{self.ychat.get_id()}'. Time elapsed: {elapsed_time_ms}ms."
)

return personas

def _display_persona_error_message(self, persona_item: dict) -> None:
Expand Down
106 changes: 105 additions & 1 deletion jupyter_ai_persona_manager/tests/test_handlers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import json
import os
import tempfile
from pathlib import Path
from unittest.mock import Mock

import pytest


async def test_health(jp_fetch):
Expand All @@ -10,4 +16,102 @@ async def test_health(jp_fetch):
payload = json.loads(response.body)
assert payload == {
"data": "This is /jupyter-ai-persona-manager/get-example endpoint!"
}
}


@pytest.fixture
def mock_persona_with_avatar(tmp_path):
"""Create a mock persona with an avatar file."""
# Create avatar file
avatar_file = tmp_path / "test_avatar.svg"
avatar_file.write_text('<svg><circle r="10"/></svg>')

# Create mock persona
mock_persona = Mock()
mock_persona.defaults.avatar_path = str(avatar_file)
mock_persona.name = "TestPersona"

return mock_persona, str(avatar_file)


async def test_avatar_handler_serves_file(jp_fetch, jp_serverapp, tmp_path):
"""Test that the avatar handler can serve avatar files."""
# Create avatar file
avatar_file = tmp_path / "test.svg"
avatar_file.write_text('<svg><circle r="10"/></svg>')

# Create mock persona with avatar
mock_persona = Mock()
mock_persona.defaults.avatar_path = str(avatar_file)
mock_persona.name = "TestPersona"

# Create mock persona manager
mock_pm = Mock()
mock_pm.personas = {"test-persona": mock_persona}

# Add to settings
if 'jupyter-ai' not in jp_serverapp.web_app.settings:
jp_serverapp.web_app.settings['jupyter-ai'] = {}
jp_serverapp.web_app.settings['jupyter-ai']['persona-managers'] = {
'room1': mock_pm
}

# Fetch the avatar
response = await jp_fetch("api", "ai", "avatars", "test.svg")

# Verify response
assert response.code == 200
assert b'<svg><circle r="10"/></svg>' in response.body
assert 'image/svg+xml' in response.headers.get('Content-Type', '')


async def test_avatar_handler_404_for_missing_file(jp_fetch, jp_serverapp):
"""Test that the avatar handler returns 404 for missing files."""
# Create mock persona manager with no matching avatar
mock_pm = Mock()
mock_pm.personas = {}

# Add to settings
if 'jupyter-ai' not in jp_serverapp.web_app.settings:
jp_serverapp.web_app.settings['jupyter-ai'] = {}
jp_serverapp.web_app.settings['jupyter-ai']['persona-managers'] = {
'room1': mock_pm
}

# Try to fetch a non-existent avatar
with pytest.raises(Exception) as exc_info:
await jp_fetch("api", "ai", "avatars", "nonexistent.svg")

# Verify 404 response
assert '404' in str(exc_info.value) or 'Not Found' in str(exc_info.value)


async def test_avatar_handler_serves_png(jp_fetch, jp_serverapp, tmp_path):
"""Test that the avatar handler can serve PNG files."""
# Create PNG file
avatar_file = tmp_path / "test.png"
avatar_file.write_bytes(b'\x89PNG\r\n\x1a\n')

# Create mock persona with avatar
mock_persona = Mock()
mock_persona.defaults.avatar_path = str(avatar_file)
mock_persona.name = "TestPersona"

# Create mock persona manager
mock_pm = Mock()
mock_pm.personas = {"test-persona": mock_persona}

# Add to settings
if 'jupyter-ai' not in jp_serverapp.web_app.settings:
jp_serverapp.web_app.settings['jupyter-ai'] = {}
jp_serverapp.web_app.settings['jupyter-ai']['persona-managers'] = {
'room1': mock_pm
}

# Fetch the avatar
response = await jp_fetch("api", "ai", "avatars", "test.png")

# Verify response
assert response.code == 200
assert response.body.startswith(b'\x89PNG')
assert 'image/png' in response.headers.get('Content-Type', '')
Loading