|
7 | 7 | from rest_framework.decorators import api_view
|
8 | 8 | from rest_framework.response import Response
|
9 | 9 |
|
| 10 | +from llm.chains.embeddings import get_pgvector_idx |
| 11 | + |
10 | 12 | from .chains.functions import detect_languages_chain
|
11 | 13 | from .chains.chat import run_chat_chain
|
12 |
| -from .chains import embeddings |
13 | 14 | from .data import loader
|
14 | 15 | from .models import Organization
|
15 | 16 |
|
| 17 | +from django.http import JsonResponse |
| 18 | +from rest_framework.parsers import MultiPartParser |
| 19 | +from rest_framework.views import APIView |
| 20 | +from pypdf import PdfReader |
| 21 | +from llm.models import Embedding |
| 22 | + |
| 23 | +import openai |
16 | 24 |
|
17 | 25 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "llm.settings")
|
18 | 26 |
|
|
21 | 29 |
|
@api_view(["POST"])
def create_chat(request):
    """Answer a chat prompt on behalf of the organization making the request.

    Request body fields:
        prompt: required, non-blank string.
        gpt_model: optional model name; defaults to ``"gpt-3.5-turbo"``.
        session_id: optional; a short id is generated when absent.

    The prompt is first run through the language-detection chain, and the
    detected primary language plus English translation are forwarded to the
    chat chain together with the organization id.

    Responses:
        201: ``{"answer", "chat_history", "session_id"}``.
        400: ``prompt`` is missing, blank, or not a string.
        404: ``current_organization`` returned nothing for this request.
        500: any other failure; the error is printed and a generic message
            is returned so internals are not leaked to the client.
    """
    try:
        organization = current_organization(request)
        if not organization:
            return Response(
                "Invalid API key",
                status=status.HTTP_404_NOT_FOUND,
            )

        # Validate before calling .strip(): a missing prompt used to raise
        # AttributeError on None and surface as a misleading 500.
        prompt = request.data.get("prompt")
        if not isinstance(prompt, str) or not prompt.strip():
            return Response(
                "Missing prompt",
                status=status.HTTP_400_BAD_REQUEST,
            )
        prompt = prompt.strip()

        # `or` (rather than a .get default) also covers an explicit null.
        gpt_model = (request.data.get("gpt_model") or "gpt-3.5-turbo").strip()
        session_id = (request.data.get("session_id") or generate_short_id()).strip()

        language_detector = detect_languages_chain(gpt_model)
        languages = language_detector.run(prompt)

        print(f"Language detector chain result: {languages}")

        primary_language = languages["primary_detected_language"]
        english_translation_prompt = languages["translation_to_english"]

        response = run_chat_chain(
            prompt=prompt,
            session_id=session_id,
            primary_language=primary_language,
            english_translation_prompt=english_translation_prompt,
            organization_id=organization.id,
            gpt_model=gpt_model,
        )

        print(f"Chat chain result: {response}")

        # Source documents are not part of the API response; pop() tolerates
        # a chain result that did not include them.
        response.pop("source_documents", None)

        return Response(
            {
                "answer": response["result"],
                "chat_history": response["chat_history"],
                "session_id": session_id,
            },
            status=status.HTTP_201_CREATED,
        )
    except Exception as error:
        # Top-level boundary for the view: log and return a generic 500.
        print(f"Error: {error}")
        return Response(
            "Something went wrong",
            status=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )


class FileUploadView(APIView):
    """Accept a multipart PDF upload and index its text for the organization.

    Each non-blank page is flattened to a single line, sanity-checked by
    requesting an OpenAI embedding for it, and then all pages are added to
    the pgvector index in one call, tagged with the organization id.
    """

    parser_classes = (MultiPartParser,)

    # Model used for the per-page embedding check and the vector length the
    # check expects back from it.
    EMBEDDING_MODEL = "text-embedding-ada-002"
    EMBEDDING_DIMENSIONS = 1536

    def post(self, request, format=None):
        """Index the uploaded PDF found under the ``file`` form field.

        Responses:
            200: ``{"status": "file upload successful"}``.
            400: no ``file`` field, unreadable PDF, no extractable text, or
                an embedding of unexpected length.
            404: ``current_organization`` returned nothing for this request.
            500: any other failure; the error is printed, not returned.
        """
        # Function-scope import: pypdf is already a dependency of this module
        # and only this handler needs the error type.
        from pypdf.errors import PdfReadError

        try:
            org = current_organization(request)
            if not org:
                return Response(
                    "Invalid API key",
                    status=status.HTTP_404_NOT_FOUND,
                )

            if "file" not in request.data:
                raise ValueError("Empty content")

            uploaded_file = request.data["file"]

            page_texts = []
            for page in PdfReader(uploaded_file).pages:
                page_text = page.extract_text().replace("\n", " ")
                if not page_text.strip():
                    # Nothing to embed or index on a blank/image-only page.
                    continue

                response = openai.Embedding.create(
                    model=self.EMBEDDING_MODEL, input=page_text
                )

                embeddings = response["data"][0]["embedding"]
                if len(embeddings) != self.EMBEDDING_DIMENSIONS:
                    raise ValueError(f"Invalid embedding length: {len(embeddings)}")

                # TODO: uncomment after move away from langchain to simplify code
                # Embedding.objects.create(
                #     source_name=uploaded_file.name,
                #     original_text=page_text,
                #     text_vectors=embeddings,
                #     organization=org,
                # )

                page_texts.append(page_text)

            if not page_texts:
                raise ValueError("No extractable text")

            # Index only after every page passed validation, so a failure on
            # a later page cannot leave a partially ingested document, and
            # build the index handle once instead of once per page.
            # NOTE(review): add_texts presumably computes its own embeddings,
            # making the check above a second API call per page - confirm.
            pgvector_idx = get_pgvector_idx()
            pgvector_idx.add_texts(
                texts=page_texts,
                metadatas=[{"organization_id": org.id} for _ in page_texts],
            )

            return JsonResponse({"status": "file upload successful"})
        except (ValueError, PdfReadError) as error:
            # PdfReadError is not a ValueError; without it a corrupt upload
            # was reported as a server error instead of a bad request.
            return Response(
                f"Invalid file: {error}",
                status=status.HTTP_400_BAD_REQUEST,
            )
        except Exception as error:
            print(f"Error: {error}")
            return Response(
                "Something went wrong",
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
55 | 135 |
|
56 | 136 |
|
@api_view(["POST"])
def set_system_prompt(request):
    """Store a new system prompt on the organization making the request.

    Request body fields:
        system_prompt: required string; surrounding whitespace is stripped.
            An empty string is accepted and stored as-is.

    Responses:
        201: ``"Updated System Prompt"``.
        400: ``system_prompt`` is missing or not a string.
        404: ``current_organization`` returned nothing for this request.
        500: any other failure; the error is printed, not returned.
    """
    try:
        # Authenticate before touching the body, matching the other views.
        org = current_organization(request)
        if not org:
            return Response(
                "Invalid API key",
                status=status.HTTP_404_NOT_FOUND,
            )

        # A missing field used to raise AttributeError on None.strip() and
        # come back as a 500; report it as the client error it is.
        system_prompt = request.data.get("system_prompt")
        if not isinstance(system_prompt, str):
            return Response(
                "Missing system_prompt",
                status=status.HTTP_400_BAD_REQUEST,
            )
        system_prompt = system_prompt.strip()

        Organization.objects.filter(id=org.id).update(system_prompt=system_prompt)

        return Response(
            "Updated System Prompt",
            status=status.HTTP_201_CREATED,
        )
    except Exception as error:
        print(f"Error: {error}")
        return Response(
            "Something went wrong",
            status=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )


92 | 161 |
|
93 |
| -def current_org(request): |
| 162 | +def current_organization(request): |
94 | 163 | api_key = request.headers.get("Authorization")
|
95 | 164 | if not api_key:
|
96 | 165 | return None
|
|
0 commit comments