Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/user rate limits #1342

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
20 changes: 13 additions & 7 deletions django_app/redbox_app/redbox_core/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from websockets import ConnectionClosedError

from redbox import Redbox
from redbox.models.settings import get_settings
from redbox_app.redbox_core import error_messages
from redbox_app.redbox_core.models import (
Chat,
ChatLLMBackend,
ChatMessage,
File,
)
from redbox_app.redbox_core.ratelimit import UserRateLimiter
from redbox_app.redbox_core.utils import sanitise_string

User = get_user_model()
Expand All @@ -43,6 +45,7 @@ class ChatConsumer(AsyncWebsocketConsumer):
full_reply: ClassVar = []
route = None
redbox = Redbox(debug=settings.DEBUG)
user_ratelimiter: UserRateLimiter = UserRateLimiter(token_ratelimit=get_settings().user_token_ratelimit)

async def receive(self, text_data=None, bytes_data=None):
"""Receive & respond to message from browser websocket."""
Expand Down Expand Up @@ -108,16 +111,19 @@ async def llm_conversation(self, session: Chat) -> None:
state = await sync_to_async(session.to_langchain)()

try:
await self.redbox.run(
state,
response_tokens_callback=self.handle_text,
)

message = await self.save_message(session, "".join(self.full_reply), ChatMessage.Role.ai)
if await self.user_ratelimiter.is_allowed(state):
await self.redbox.run(
state,
response_tokens_callback=self.handle_text,
)
message = await self.save_message(session, "".join(self.full_reply), ChatMessage.Role.ai)
else:
await self.send_to_client(
"text", "You have exceeded your rate limit for the last minute, please wait before trying again"
)
await self.send_to_client(
"end", {"message_id": message.id, "title": session.name, "session_id": session.id}
)

except RateLimitError as e:
logger.exception("Rate limit error", exc_info=e)
await self.send_to_client("error", error_messages.RATE_LIMITED)
Expand Down
6 changes: 6 additions & 0 deletions django_app/redbox_app/redbox_core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ def to_langchain(self) -> redbox.models.chain.RedboxState:
)

return redbox.models.chain.RedboxState(
user_uuid=self.user.id,
documents=[Document(str(f.text), metadata={"uri": f.original_file.name}) for f in self.file_set.all()],
messages=[message.to_langchain() for message in self.chatmessage_set.all()],
chat_backend=chat_backend,
Expand Down Expand Up @@ -661,6 +662,11 @@ def to_langchain(self) -> AnyMessage:
return AIMessage(content=escape_curly_brackets(self.text))
return HumanMessage(content=escape_curly_brackets(self.text))

@classmethod
def get_since(cls, user: User, since: datetime) -> Sequence["ChatMessage"]:
    """Return all of *user*'s chat messages created strictly after *since*.

    Filters through the owning chat (``chat__user``), so messages from every
    chat belonging to the user are included.  Returns a lazy Django QuerySet
    (evaluated on iteration), which satisfies the ``Sequence`` annotation.

    NOTE(review): callers that only need the token total could push the sum
    into the database with ``.aggregate(Sum("token_count"))`` — confirm.
    """
    return cls.objects.filter(chat__user=user, created_at__gt=since)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return cls.objects.filter(chat__user=user, created_at__gt=since)
cls.objects.filter(chat__user=user, created_at__gt=since).aggregate(Sum("token_count"))["token_count__sum"] or 0

you can get the database to do the sum here?

Copy link
Collaborator

@gecBurton gecBurton Jan 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and this could now be moved to be a method belonging to the User model?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but I think that there is a bigger problem here... the ChatMessage.token_count only includes what the user has typed, and not the included files. I think we need to make this change first?


def log(self):
elastic_log_msg = {
"@timestamp": self.created_at.isoformat(),
Expand Down
29 changes: 29 additions & 0 deletions django_app/redbox_app/redbox_core/ratelimit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from datetime import UTC, datetime, timedelta
from uuid import UUID

from asgiref.sync import sync_to_async

from redbox.models.chain import RedboxState
from redbox.models.settings import get_settings
from redbox_app.redbox_core.models import ChatMessage, User

_settings = get_settings()


class UserRateLimiter:
    """Per-user, per-minute token rate limiter.

    A request is allowed when the tokens it would consume, added to the tokens
    the user has already consumed over the last rolling minute, stay strictly
    below ``token_ratelimit``.
    """

    def __init__(self, token_ratelimit: int = _settings.user_token_ratelimit) -> None:
        # Maximum tokens a single user may consume per rolling minute.
        self.token_ratelimit = token_ratelimit

    async def is_allowed(self, state: RedboxState) -> bool:
        """Return True if the request described by *state* fits the user's remaining budget."""
        # The ORM lookup is synchronous; run it in a thread so the event loop is not blocked.
        consumed_ratelimit = await sync_to_async(self.get_tokens_for_user_in_last_minute)(state.user_uuid)
        request_tokens = self.calculate_request_credits(state)
        # Strict '<': a request that would exactly exhaust the budget is rejected.
        return request_tokens < self.token_ratelimit - consumed_ratelimit

    def get_tokens_for_user_in_last_minute(self, user_uuid: UUID) -> int:
        """Sum ``token_count`` across the user's messages created in the last 60 seconds."""
        user = User.objects.get(pk=user_uuid)
        since = datetime.now(UTC) - timedelta(minutes=1)
        recent_messages = ChatMessage.get_since(user, since)
        # 'or 0' guards against a NULL token_count raising TypeError in sum().
        return sum(m.token_count or 0 for m in recent_messages)

    def calculate_request_credits(self, state: RedboxState) -> int:
        """Tokens this request would consume; only attached documents are counted.

        NOTE(review): tokens in the typed message itself are not included here —
        confirm whether they should be.
        """
        return int(sum(d.metadata.get("token_count", 0) for d in state.documents))
1 change: 1 addition & 0 deletions django_app/tests/test_consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ async def test_chat_consumer_redbox_state(

# Then
expected_request = RedboxState(
user_uuid=alice.id,
documents=documents,
messages=[
HumanMessage(content="A question?"),
Expand Down
55 changes: 55 additions & 0 deletions django_app/tests/test_ratelimiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from dataclasses import dataclass
from uuid import UUID, uuid4

import pytest
from langchain_core.documents import Document

from redbox.models.chain import RedboxState
from redbox_app.redbox_core.ratelimit import UserRateLimiter

test_user_uuid = uuid4()


@dataclass
class UserActivity:
    """Summary of a user's document activity.

    NOTE(review): this dataclass is not referenced anywhere else in this
    module — confirm it is still needed, or remove it.
    """

    # presumably the combined token_count across the user's documents — verify
    total_document_tokens: int
    # how many documents contributed to that total
    number_documents: int


def request_state(number_of_documents: int, total_document_tokens: int) -> RedboxState:
    """Build a RedboxState with *number_of_documents* documents whose
    ``token_count`` metadata sums to exactly *total_document_tokens*.

    The previous truncating division (``int(total / n)``) silently lost tokens
    whenever the total did not divide evenly (e.g. 3 docs, 100 tokens -> 99).
    The remainder is now spread across the first documents so the total is
    always exact; behaviour is unchanged for evenly-divisible inputs.
    """
    per_document, remainder = divmod(total_document_tokens, number_of_documents)
    return RedboxState(
        user_uuid=test_user_uuid,
        documents=[
            # The first `remainder` documents carry one extra token.
            Document("", metadata={"token_count": per_document + (1 if i < remainder else 0)})
            for i in range(number_of_documents)
        ],
    )


@pytest.mark.parametrize(
    ("token_ratelimit", "users_consumed_tokens", "request_state", "expect_allowed"),
    [
        # Well under budget: 100 consumed + 200 requested < 1000.
        (1000, {test_user_uuid: 100}, request_state(2, 200), True),
        # Other users' consumption must not count against this user.
        (1000, {test_user_uuid: 100, uuid4(): 1000, uuid4(): 900}, request_state(1, 800), True),
        # Would exceed remaining budget (800 + 240 > 1000).
        (1000, {test_user_uuid: 800}, request_state(2, 240), False),
        (1000, {test_user_uuid: 800}, request_state(8, 400), False),
    ],
)
@pytest.mark.asyncio()
async def test_ratelimiter(
    token_ratelimit: int, users_consumed_tokens: dict[UUID, int], request_state: RedboxState, expect_allowed: bool
):
    """is_allowed honours the per-user budget and ignores other users' consumption."""
    ratelimiter = UserRateLimiter(token_ratelimit=token_ratelimit)

    # Stub out the DB-backed lookup with an in-memory map of per-user usage
    # (unknown users have consumed nothing).
    ratelimiter.get_tokens_for_user_in_last_minute = lambda user_uuid: users_consumed_tokens.get(user_uuid, 0)

    request_allowed = await ratelimiter.is_allowed(request_state)
    # Plain assert (not pytest.fail) so pytest's assertion introspection shows both values.
    assert request_allowed == expect_allowed, (
        f"Request allow status did not match: Expected [{expect_allowed}]. Received [{request_allowed}]"
    )
3 changes: 3 additions & 0 deletions redbox-core/redbox/models/chain.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from uuid import UUID

from langchain_core.documents import Document
from langchain_core.messages import AnyMessage
from pydantic import BaseModel, Field
Expand All @@ -6,6 +8,7 @@


class RedboxState(BaseModel):
    """Complete state for a single Redbox chat request (user, documents, history, backend)."""

    user_uuid: UUID = Field(description="UUID of the user making the request")
    documents: list[Document] = Field(description="List of files to process", default_factory=list)
    messages: list[AnyMessage] = Field(description="All previous messages in chat", default_factory=list)
    chat_backend: ChatLLMBackend = Field(description="User request AI settings", default_factory=ChatLLMBackend)
3 changes: 3 additions & 0 deletions redbox-core/redbox/models/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class Settings(BaseSettings):

model_config = SettingsConfigDict(env_file=".env", env_nested_delimiter="__", extra="allow", frozen=True)

# Rate Limiting
user_token_ratelimit: int = 360_000

@property
def elastic_chat_mesage_index(self):
return self.elastic_root_index + "-chat-mesage-log"
Expand Down
1 change: 1 addition & 0 deletions tests/.env.test
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ DEV_MODE=true

# === Database ===

ELASTIC__HOST=localhost
KIBANA_SYSTEM_PASSWORD=redboxpass
METRICBEAT_INTERNAL_PASSWORD=redboxpass
FILEBEAT_INTERNAL_PASSWORD=redboxpass
Expand Down