making sure the token limit is under 7000 for all api chat calls #29

Merged 1 commit on Mar 9, 2024
77 changes: 53 additions & 24 deletions llm/api.py
@@ -6,20 +6,28 @@
from pypdf import PdfReader
from pgvector.django import L2Distance
from django.http import JsonResponse
from django.forms.models import model_to_dict
from django.db.models import Sum
from rest_framework import status
from rest_framework.decorators import api_view
from rest_framework.parsers import MultiPartParser
from rest_framework.views import APIView
import openai

from llm.utils.prompt import context_prompt_messages, evaluate_criteria_score
from llm.utils.prompt import (
context_prompt_messages,
evaluate_criteria_score,
count_tokens_for_text,
)
from llm.utils.general import generate_session_id
from llm.models import Organization, Embedding, Message


basicConfig(level=INFO)
logger = getLogger()

TOKEN_LIMIT = 7000

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm.settings")

django.setup()
@@ -106,17 +114,35 @@ def create_chat(request):
model="text-embedding-ada-002", input=prompt
)["data"][0]["embedding"]

embedding_results = Embedding.objects.alias(
distance=L2Distance("text_vectors", prompt_embeddings)
).filter(distance__gt=0.9)

relevant_english_context = "".join(
result.original_text for result in embedding_results
embedding_results = (
Embedding.objects.alias(
distance=L2Distance("text_vectors", prompt_embeddings),
)
.filter(distance__gt=0.7)
.order_by("-distance")
)
logger.info(
    f"retrieved {len(embedding_results)} relevant context documents from the db"
)

# Keep the top-ranked embeddings while the cumulative token count stays under the limit
final_embeddings: list[Embedding] = []
token_count = 0
for embedding in embedding_results:
    if token_count + embedding.num_tokens >= TOKEN_LIMIT:
        break  # adding this embedding would exceed the budget
    token_count += embedding.num_tokens
    final_embeddings.append(embedding)

logger.info(
    f"Using {len(final_embeddings)}/{len(embedding_results)} relevant docs to keep the token count under {TOKEN_LIMIT}. Token count: {token_count}"
)

relevant_english_context = "".join(
result.original_text for result in final_embeddings
)

# 3. Fetch the chat history from our message store to send to openai and back in the response
historical_chats = Message.objects.filter(session_id=session_id).all()

@@ -130,7 +156,6 @@ def create_chat(request):
language_results["english_translation"],
historical_chats,
),
# max_tokens=150,
)
logger.info("received response from the ai bot for the current prompt")

@@ -208,22 +233,26 @@ def post(self, request, format=None):
for page in pdf_reader.pages:
page_text = page.extract_text().replace("\n", " ")

response = openai.Embedding.create(
model="text-embedding-ada-002", input=page_text
)

embeddings = response["data"][0]["embedding"]
if len(embeddings) != 1536:
raise ValueError(f"Invalid embedding length: #{len(embeddings)}")

Embedding.objects.create(
source_name=file.name,
original_text=page_text,
text_vectors=embeddings,
organization=org,
)

return JsonResponse({"msg": "file upload successful"})
if len(page_text) > 0:
response = openai.Embedding.create(
model="text-embedding-ada-002", input=page_text
)

embeddings = response["data"][0]["embedding"]
if len(embeddings) != 1536:
raise ValueError(
    f"Invalid embedding length: {len(embeddings)}"
)

Embedding.objects.create(
source_name=file.name,
original_text=page_text,
text_vectors=embeddings,
organization=org,
num_tokens=count_tokens_for_text(page_text),
)

return JsonResponse({"msg": f"Uploaded file {file.name} successfully"})
except ValueError as error:
logger.error(f"Error: {error}")
return JsonResponse(
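The heart of this change is a greedy cutoff over the ranked retrieval results: walk the embeddings in ranked order and stop as soon as the next document would push the running total past TOKEN_LIMIT. A minimal standalone sketch of that selection logic (Doc and select_within_budget are illustrative names, not from this PR):

from dataclasses import dataclass

TOKEN_LIMIT = 7000


@dataclass
class Doc:
    text: str
    num_tokens: int


def select_within_budget(ranked_docs: list[Doc], limit: int = TOKEN_LIMIT) -> list[Doc]:
    """Keep top-ranked docs while their cumulative token count stays under limit."""
    selected: list[Doc] = []
    total = 0
    for doc in ranked_docs:
        if total + doc.num_tokens >= limit:
            break  # the next doc would exceed the budget, so stop here
        total += doc.num_tokens
        selected.append(doc)
    return selected


# Only the first two docs fit under a budget of 10 tokens
docs = [Doc("a", 4), Doc("b", 5), Doc("c", 3)]
assert [d.text for d in select_within_budget(docs, limit=10)] == ["a", "b"]

One consequence of breaking (rather than continuing) at the first overflow is that a small document ranked after a large one is skipped even if it alone would still fit; that keeps the selection strictly in relevance order.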
18 changes: 18 additions & 0 deletions llm/migrations/0012_embedding_num_tokens.py
@@ -0,0 +1,18 @@
# Generated by Django 4.2.6 on 2024-03-09 08:28

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('llm', '0011_organization_openai_key'),
]

operations = [
migrations.AddField(
model_name='embedding',
name='num_tokens',
field=models.IntegerField(default=0),
),
]
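The migration defaults num_tokens to 0, so embeddings ingested before this change carry a zero count until their files are re-uploaded. If backfilling existing rows were needed, a one-off data migration roughly like this could do it (hypothetical, not part of this PR):

# llm/migrations/0013_backfill_num_tokens.py (hypothetical follow-up)
from django.db import migrations


def backfill_num_tokens(apps, schema_editor):
    # Imported lazily so the migration does not depend on app code at load time
    from llm.utils.prompt import count_tokens_for_text

    Embedding = apps.get_model("llm", "Embedding")
    for embedding in Embedding.objects.filter(num_tokens=0):
        embedding.num_tokens = count_tokens_for_text(embedding.original_text)
        embedding.save(update_fields=["num_tokens"])


class Migration(migrations.Migration):
    dependencies = [("llm", "0012_embedding_num_tokens")]
    operations = [
        migrations.RunPython(backfill_num_tokens, migrations.RunPython.noop),
    ]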
1 change: 1 addition & 0 deletions llm/models.py
@@ -43,6 +43,7 @@ class Embedding(models.Model):
original_text = models.TextField()
text_vectors = VectorField(dimensions=1536, null=True)
organization = models.ForeignKey(Organization, on_delete=models.CASCADE)
num_tokens = models.IntegerField(default=0)

class Meta:
db_table = "embedding"
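The api.py diff also adds from django.db.models import Sum, whose call site sits in a collapsed part of the diff; with num_tokens on the model, aggregating an organization's stored token total is the obvious use. A sketch of what such a query could look like (an assumption, since the actual usage isn't shown here):

from django.db.models import Sum

from llm.models import Embedding, Organization


def total_tokens_for_org(org: Organization) -> int:
    # aggregate() returns {"total": None} when the org has no embeddings,
    # hence the fallback to 0
    result = Embedding.objects.filter(organization=org).aggregate(
        total=Sum("num_tokens")
    )
    return result["total"] or 0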
20 changes: 18 additions & 2 deletions llm/utils/prompt.py
@@ -1,6 +1,11 @@
from llm.models import Message, Organization
from llm.models import Message, Organization, Embedding
import openai
from typing import Union
import tiktoken
from logging import basicConfig, INFO, getLogger

basicConfig(level=INFO)
logger = getLogger()


def context_prompt_messages(
@@ -31,6 +36,7 @@ def context_prompt_messages(
Question: {question}
Chatbot Answer in {language}: """,
}

chat_prompt_messages = (
[system_message_prompt]
+ [{"role": chat.role, "content": chat.message} for chat in historical_chats]
@@ -55,7 +61,17 @@ def evaluate_criteria_score(
)
response_text = response.choices[0].message.content

print(f"response_text: {response_text}")
logger.info(f"response_text: {response_text}")
evaluation_score = int(response_text)

return evaluation_score


def count_tokens_for_text(
prompt_text: str, model: str = "text-embedding-ada-002"
) -> int:
# Resolve the encoding from the model argument (text-embedding-ada-002
# and gpt-3.5-turbo both map to cl100k_base)
encoding = tiktoken.encoding_for_model(model)

tokens_arr = encoding.encode(prompt_text)

return len(tokens_arr)
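count_tokens_for_text leans on tiktoken's model-to-encoding lookup: text-embedding-ada-002 and gpt-3.5-turbo both resolve to the cl100k_base encoding, which is why counts computed at upload time line up with what the chat calls consume. A quick sanity check of both facts:

import tiktoken

from llm.utils.prompt import count_tokens_for_text

# Both models share the same tokenizer, so stored counts match chat usage
assert (
    tiktoken.encoding_for_model("text-embedding-ada-002").name
    == tiktoken.encoding_for_model("gpt-3.5-turbo").name
    == "cl100k_base"
)

# Counts like this are what get stored on Embedding.num_tokens at upload
# time and later summed against TOKEN_LIMIT in create_chat
print(count_tokens_for_text("The quick brown fox jumps over the lazy dog."))  # 10 tokens under cl100k_base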