From 4e48439899f5537dec3f1d70865321f691753a2c Mon Sep 17 00:00:00 2001
From: Ishankoradia
Date: Sat, 9 Mar 2024 14:38:25 +0530
Subject: [PATCH] Ensure the token limit stays under 7000 for all API chat
 calls

---
 llm/api.py                                  | 77 ++++++++++++++-------
 llm/migrations/0012_embedding_num_tokens.py | 18 +++++
 llm/models.py                               |  1 +
 llm/utils/prompt.py                         | 20 +++++-
 4 files changed, 90 insertions(+), 26 deletions(-)
 create mode 100644 llm/migrations/0012_embedding_num_tokens.py

diff --git a/llm/api.py b/llm/api.py
index a7fa468..2d8e977 100644
--- a/llm/api.py
+++ b/llm/api.py
@@ -6,13 +6,19 @@
 from pypdf import PdfReader
 from pgvector.django import L2Distance
 from django.http import JsonResponse
+from django.forms.models import model_to_dict
+from django.db.models import Sum
 from rest_framework import status
 from rest_framework.decorators import api_view
 from rest_framework.parsers import MultiPartParser
 from rest_framework.views import APIView
 import openai
 
-from llm.utils.prompt import context_prompt_messages, evaluate_criteria_score
+from llm.utils.prompt import (
+    context_prompt_messages,
+    evaluate_criteria_score,
+    count_tokens_for_text,
+)
 from llm.utils.general import generate_session_id
 from llm.models import Organization, Embedding, Message
 
@@ -20,6 +26,8 @@
 basicConfig(level=INFO)
 logger = getLogger()
 
+TOKEN_LIMIT = 7000
+
 os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm.settings")
 django.setup()
 
@@ -106,17 +114,35 @@ def create_chat(request):
         model="text-embedding-ada-002", input=prompt
     )["data"][0]["embedding"]
 
-    embedding_results = Embedding.objects.alias(
-        distance=L2Distance("text_vectors", prompt_embeddings)
-    ).filter(distance__gt=0.9)
-
-    relevant_english_context = "".join(
-        result.original_text for result in embedding_results
+    embedding_results = (
+        Embedding.objects.alias(
+            distance=L2Distance("text_vectors", prompt_embeddings),
+        )
+        .filter(distance__lt=0.7)
+        .order_by("distance")
     )
 
     logger.info(
         f"retrieved {len(embedding_results)} relevant document context from db"
     )
 
+    # Keep only as many relevant docs as fit under the token limit
+    final_embeddings: list[Embedding] = []
+    token_count = 0
+    for embedding in embedding_results:
+        # stop before the running total crosses the limit
+        if token_count + embedding.num_tokens >= TOKEN_LIMIT:
+            break
+        token_count += embedding.num_tokens
+        final_embeddings.append(embedding)
+
+    logger.info(
+        f"Using {len(final_embeddings)}/{len(embedding_results)} relevant docs to keep the token count under {TOKEN_LIMIT}. Token count: {token_count}"
+    )
+
+    relevant_english_context = "".join(
+        result.original_text for result in final_embeddings
+    )
+
     # 3. Fetch the chat history from our message store to send to openai and back in the response
     historical_chats = Message.objects.filter(session_id=session_id).all()
 
@@ -130,7 +156,6 @@
             language_results["english_translation"],
             historical_chats,
         ),
-        # max_tokens=150,
     )
 
     logger.info("received response from the ai bot for the current prompt")
 
@@ -208,22 +233,26 @@ def post(self, request, format=None):
         for page in pdf_reader.pages:
             page_text = page.extract_text().replace("\n", " ")
 
-            response = openai.Embedding.create(
-                model="text-embedding-ada-002", input=page_text
-            )
-
-            embeddings = response["data"][0]["embedding"]
-            if len(embeddings) != 1536:
-                raise ValueError(f"Invalid embedding length: #{len(embeddings)}")
-
-            Embedding.objects.create(
-                source_name=file.name,
-                original_text=page_text,
-                text_vectors=embeddings,
-                organization=org,
-            )
-
-        return JsonResponse({"msg": "file upload successful"})
+            if len(page_text) > 0:
+                response = openai.Embedding.create(
+                    model="text-embedding-ada-002", input=page_text
+                )
+
+                embeddings = response["data"][0]["embedding"]
+                if len(embeddings) != 1536:
+                    raise ValueError(
+                        f"Invalid embedding length: #{len(embeddings)}"
+                    )
+
+                Embedding.objects.create(
+                    source_name=file.name,
+                    original_text=page_text,
+                    text_vectors=embeddings,
+                    organization=org,
+                    num_tokens=count_tokens_for_text(page_text),
+                )
+
+        return JsonResponse({"msg": f"Uploaded file {file.name} successfully"})
     except ValueError as error:
         logger.error(f"Error: {error}")
         return JsonResponse(
diff --git a/llm/migrations/0012_embedding_num_tokens.py b/llm/migrations/0012_embedding_num_tokens.py
new file mode 100644
index 0000000..1c7e2ed
--- /dev/null
+++ b/llm/migrations/0012_embedding_num_tokens.py
@@ -0,0 +1,18 @@
+# Generated by Django 4.2.6 on 2024-03-09 08:28
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('llm', '0011_organization_openai_key'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='embedding',
+            name='num_tokens',
+            field=models.IntegerField(default=0),
+        ),
+    ]
diff --git a/llm/models.py b/llm/models.py
index 5db79b0..c64b0e8 100644
--- a/llm/models.py
+++ b/llm/models.py
@@ -43,6 +43,7 @@ class Embedding(models.Model):
     original_text = models.TextField()
     text_vectors = VectorField(dimensions=1536, null=True)
     organization = models.ForeignKey(Organization, on_delete=models.CASCADE)
+    num_tokens = models.IntegerField(default=0)
 
     class Meta:
         db_table = "embedding"
diff --git a/llm/utils/prompt.py b/llm/utils/prompt.py
index 42a5566..eab53af 100644
--- a/llm/utils/prompt.py
+++ b/llm/utils/prompt.py
@@ -1,6 +1,11 @@
-from llm.models import Message, Organization
+from llm.models import Message, Organization, Embedding
 import openai
 from typing import Union
+import tiktoken
+from logging import basicConfig, INFO, getLogger
+
+basicConfig(level=INFO)
+logger = getLogger()
 
 
 def context_prompt_messages(
@@ -31,6 +36,7 @@
     Question: {question}
     Chatbot Answer in {language}: """,
     }
+
     chat_prompt_messages = (
         [system_message_prompt]
         + [{"role": chat.role, "content": chat.message} for chat in historical_chats]
@@ -55,7 +61,17 @@
     )
     response_text = response.choices[0].message.content
-    print(f"response_text: {response_text}")
+    logger.info(f"response_text: {response_text}")
 
     evaluation_score = int(response_text)
 
     return evaluation_score
+
+
+def count_tokens_for_text(
+    prompt_text: str, model: str = "text-embedding-ada-002"
+) -> int:
+    encoding = tiktoken.encoding_for_model(model)
+
+    tokens_arr = encoding.encode(prompt_text)
+
+    return len(tokens_arr)
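
A note on the token counting: tiktoken maps both "text-embedding-ada-002"
and "gpt-3.5-turbo" to the same cl100k_base encoding, so the counts stored
on each Embedding row match what the embedding model sees. A minimal
standalone sketch of the same counting (the sample string is only an
illustration, not from the patch):

import tiktoken

# resolves to cl100k_base, the same encoding count_tokens_for_text uses
encoding = tiktoken.encoding_for_model("text-embedding-ada-002")

sample_text = "making sure the token limit is under 7000"  # illustrative only
num_tokens = len(encoding.encode(sample_text))
print(num_tokens)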
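
The selection loop in create_chat is a greedy prefix over the
distance-ordered results: take chunks in order and stop at the first one
that would push the running total past the budget; a smaller chunk later
in the list is never considered. Pulled out on its own (the function and
variable names here are illustrative, not part of the patch), it behaves
like this:

def select_within_budget(chunks: list[tuple[str, int]], limit: int = 7000):
    """Keep (text, num_tokens) chunks, in order, while the running
    total stays under `limit`; stop at the first chunk that won't fit."""
    selected: list[str] = []
    total = 0
    for text, num_tokens in chunks:
        if total + num_tokens >= limit:
            break
        total += num_tokens
        selected.append(text)
    return selected, total

# e.g. three chunks of 3000/3000/2000 tokens: only the first two fit
texts, used = select_within_budget([("a", 3000), ("b", 3000), ("c", 2000)])
assert texts == ["a", "b"] and used == 6000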