Skip to content

Commit

Permalink
Merge pull request #28 from glific/26-organization-api-keys
Browse files Browse the repository at this point in the history
some cleanup; middlewares, proper api and error response. Refactoring…
  • Loading branch information
Ishankoradia committed Mar 9, 2024
2 parents 523a8b2 + 9be58f6 commit 1889832
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 84 deletions.
160 changes: 78 additions & 82 deletions llm/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import os
import django
import secrets
import string
import json
from logging import basicConfig, INFO, getLogger

Expand All @@ -10,12 +8,12 @@
from django.http import JsonResponse
from rest_framework import status
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.parsers import MultiPartParser
from rest_framework.views import APIView
import openai

from llm.utils import context_prompt_messages, evaluate_criteria_score
from llm.utils.prompt import context_prompt_messages, evaluate_criteria_score
from llm.utils.general import generate_session_id
from llm.models import Organization, Embedding, Message


Expand All @@ -32,16 +30,12 @@
@api_view(["POST"])
def create_chat(request):
try:
organization = current_organization(request)
if not organization:
return Response(
f"Invalid API key",
status=status.HTTP_404_NOT_FOUND,
)
organization: Organization = request.org
logger.info(f"processing chat prompt request for org {organization.name}")

prompt = request.data.get("prompt").strip()
gpt_model = request.data.get("gpt_model", "gpt-3.5-turbo").strip()
session_id = (request.data.get("session_id") or generate_short_id()).strip()
session_id = (request.data.get("session_id") or generate_session_id()).strip()

# 1. Function calling to do language detection of the user's question (1st call to OpenAI)
response = openai.ChatCompletion.create(
Expand Down Expand Up @@ -166,7 +160,7 @@ def create_chat(request):
)
logger.info("Stored messages in django db")

return Response(
return JsonResponse(
{
"question": prompt,
"answer": prompt_response.content,
Expand All @@ -182,9 +176,9 @@ def create_chat(request):
status=status.HTTP_201_CREATED,
)
except Exception as error:
print(f"Error: {error}")
return Response(
f"Something went wrong",
logger.error(f"Error: {error}")
return JsonResponse(
{"error": f"Something went wrong"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)

Expand All @@ -194,12 +188,8 @@ class FileUploadView(APIView):

def post(self, request, format=None):
try:
org = current_organization(request)
if not org:
return Response(
f"Invalid API key",
status=status.HTTP_404_NOT_FOUND,
)
org: Organization = request.org
logger.info(f"processing file upload for org {org.name}")

if "file" not in request.data:
raise ValueError("Empty content")
Expand All @@ -225,87 +215,67 @@ def post(self, request, format=None):
organization=org,
)

return JsonResponse({"status": "file upload successful"})
return JsonResponse({"msg": "file upload successful"})
except ValueError as error:
return Response(
f"Invalid file: {error}",
logger.error(f"Error: {error}")
return JsonResponse(
{"error": f"Invalid file"},
status=status.HTTP_400_BAD_REQUEST,
)
except Exception as error:
print(f"Error: {error}")
return Response(
f"Something went wrong",
logger.error(f"Error: {error}")
return JsonResponse(
{"error": f"Something went wrong"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)


@api_view(["POST"])
def set_system_prompt(request):
try:
org: Organization = request.org
logger.info(f"processing set system prompt for org {org.name}")

system_prompt = request.data.get("system_prompt").strip()
org = current_organization(request)
if not org:
return Response(
f"Invalid API key",
status=status.HTTP_404_NOT_FOUND,
)

Organization.objects.filter(id=org.id).update(system_prompt=system_prompt)

return Response(
f"Updated System Prompt",
return JsonResponse(
{"msg": f"Updated System Prompt"},
status=status.HTTP_201_CREATED,
)
except Exception as error:
print(f"Error: {error}")
return Response(
f"Something went wrong",
logger.error(f"Error: {error}")
return JsonResponse(
{"error": f"Something went wrong"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)


def current_organization(request):
api_key = request.headers.get("Authorization")
if not api_key:
return None

try:
return Organization.objects.get(api_key=api_key)
except Organization.DoesNotExist:
return None


def generate_short_id(length=6):
alphanumeric = string.ascii_letters + string.digits
return "".join(secrets.choice(alphanumeric) for _ in range(length))


@api_view(["POST"])
def set_evaluator_prompt(request):
try:
org: Organization = request.org
logger.info(f"processing set evaluator prompt request for org {org.name}")

evaluator_prompts = request.data.get("evaluator_prompts")
org = current_organization(request)
if not org:
return Response(
f"Invalid API key",
status=status.HTTP_404_NOT_FOUND,
)

Organization.objects.filter(id=org.id).update(
evaluator_prompts=evaluator_prompts
)

return Response(
f"Updated Evaluator Prompt",
return JsonResponse(
{"msg": f"Updated Evaluator Prompt"},
status=status.HTTP_201_CREATED,
)
except Exception as error:
print(f"Error: {error}")
return Response(
f"Something went wrong",
logger.info(f"Error: {error}")
return JsonResponse(
{"error": f"Something went wrong"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)



@api_view(["POST"])
def set_examples_text(request):
"""
Expand All @@ -317,26 +287,52 @@ def set_examples_text(request):
'
"""
try:
org: Organization = request.org
logger.info(f"processing set examples text request for org {org.name}")

examples_text = request.data.get("examples_text")
org = current_organization(request)
if not org:
return Response(
f"Invalid API key",
status=status.HTTP_404_NOT_FOUND,
)

Organization.objects.filter(id=org.id).update(
examples_text=examples_text
)
Organization.objects.filter(id=org.id).update(examples_text=examples_text)

return Response(
f"Updated Examples Text",
return JsonResponse(
{"msg": f"Updated Examples Text"},
status=status.HTTP_201_CREATED,
)


except Exception as error:
logger.error(f"Error: {error}")
return JsonResponse(
{"error": f"Something went wrong"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)


@api_view(["POST"])
def set_openai_key(request):
"""
Example request body:
'
Question: Peshab ki jagah se kharash ho rahi hai
Chatbot Answer in Hindi: aapakee samasya ke lie dhanyavaad. yah peshaab ke samay kharaash kee samasya ho sakatee hai. ise yoorinaree traikt inphekshan (uti) kaha jaata hai. yoorinaree traikt imphekshan utpann hone ka mukhy kaaran aantarik inphekshan ho sakata hai.
'
"""
try:
org: Organization = request.org
logger.info(f"processing set openai key request for org {org.name}")

openai_key = request.data.get("key")

Organization.objects.filter(id=org.id).update(openai_key=openai_key)

return JsonResponse(
{"msg": f"Updated openai key"},
status=status.HTTP_200_OK,
)

except Exception as error:
print(f"Error: {error}")
return Response(
f"Something went wrong",
logger.error(f"Error: {error}")
return JsonResponse(
{"error": f"Something went wrong"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
)
18 changes: 18 additions & 0 deletions llm/migrations/0011_organization_openai_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 4.2.6 on 2024-03-09 05:44

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('llm', '0010_organization_examples_text'),
]

operations = [
migrations.AddField(
model_name='organization',
name='openai_key',
field=models.CharField(max_length=255, null=True, unique=True),
),
]
1 change: 1 addition & 0 deletions llm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Organization(models.Model):
null=True
) # { "confidence": "Your task is to...", "friendliness": "Your task is to..." }
examples_text = models.TextField(null=True)
openai_key = models.CharField(max_length=255, unique=True, null=True)

class Meta:
db_table = "organization"
Expand Down
2 changes: 2 additions & 0 deletions llm/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
For the full list of settings and their values, see
https://docs.djangoproject.com/en/4.2/ref/settings/
"""

import os
from pathlib import Path

Expand Down Expand Up @@ -56,6 +57,7 @@
"django.contrib.auth.middleware.AuthenticationMiddleware",
"django.contrib.messages.middleware.MessageMiddleware",
"django.middleware.clickjacking.XFrameOptionsMiddleware",
"llm.utils.middleware.CustomMiddleware",
]

ROOT_URLCONF = "llm.urls"
Expand Down
2 changes: 2 additions & 0 deletions llm/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
FileUploadView,
set_evaluator_prompt,
set_examples_text,
set_openai_key,
)

urlpatterns = [
Expand All @@ -33,4 +34,5 @@
path("api/system_prompt", set_system_prompt, name="set_system_prompt"),
path("api/evaluator_prompt", set_evaluator_prompt, name="set_evaluator_prompt"),
path("api/examples_text", set_examples_text, name="set_examples_text"),
path("api/openai_key", set_openai_key, name="set_openai_key"),
]
Empty file added llm/utils/__init__.py
Empty file.
6 changes: 6 additions & 0 deletions llm/utils/general.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import string, secrets


def generate_session_id(length=6):
alphanumeric = string.ascii_letters + string.digits
return "".join(secrets.choice(alphanumeric) for _ in range(length))
46 changes: 46 additions & 0 deletions llm/utils/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from django.http import Http404
from django.http import JsonResponse
from rest_framework.response import Response
from rest_framework import status
from llm.models import Organization
from logging import basicConfig, INFO, getLogger

basicConfig(level=INFO)
logger = getLogger()


class CustomMiddleware:
@staticmethod
def current_organization(request):
api_key = request.headers.get("Authorization")
if not api_key:
return None

try:
return Organization.objects.get(api_key=api_key)
except Organization.DoesNotExist:
return None

def __init__(self, get_response):
self.get_response = get_response

def __call__(self, request):
# Code to be executed for each request before
# the view (and later middleware) are called.
logger.info("routing request via the middleware")

org = CustomMiddleware.current_organization(request)

if not org:
return JsonResponse(
{"error": "Invalid API key"},
status=status.HTTP_404_NOT_FOUND,
)

request.org = org
response = self.get_response(request)

# Code to be executed for each request/response after
# the view is called.

return response
4 changes: 2 additions & 2 deletions llm/utils.py → llm/utils/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ def context_prompt_messages(
historical_chats: list[Message],
) -> list[dict]:
org = Organization.objects.filter(id=organization_id).first()

org_system_prompt = org.system_prompt
examples_text = org.examples_text

system_message_prompt = {"role": "system", "content": org_system_prompt}
human_message_prompt = {
"role": "user",
Expand Down

0 comments on commit 1889832

Please sign in to comment.