Skip to content

Commit 0dbeeb6

Browse files
committed
conversation history
1 parent 61ba579 commit 0dbeeb6

11 files changed

+149
-29
lines changed

.env.example

+7
Original file line numberDiff line numberDiff line change
@@ -1 +1,8 @@
1+
# Copy this file to .env and replace the values with your own
12
OPENAI_API_KEY=XXXXXXXXXX
3+
DATABASE_URL=postgres://llm_agent:llm_password@localhost:5432/llm_db
4+
DB_NAME=llm_db
5+
DB_USER=llm_agent
6+
DB_PASSWORD=llm_password
7+
DB_HOST=localhost
8+
DB_PORT=5432

Pipfile

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ tiktoken = "*"
1111
djangorestframework = "*"
1212
pypdf = "*"
1313
openai = "*"
14+
python-dotenv = "*"
15+
psycopg = "*"
1416

1517
[dev-packages]
1618

Pipfile.lock

+40-14
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

llm/api.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import os
22
import django
3+
import secrets
4+
import string
5+
36
from rest_framework import status
47
from rest_framework.decorators import api_view
58
from rest_framework.response import Response
@@ -16,12 +19,16 @@
1619

1720
@api_view(["POST"])
1821
def create_chat(request):
19-
question = request.data.get("question").strip()
22+
prompt = request.data.get("prompt").strip()
23+
session_id = (request.data.get("session_id") or generate_short_id()).strip()
2024

21-
answer = run_chat_chain(question)
25+
answer = run_chat_chain(prompt=prompt, session_id=session_id)
2226

2327
return Response(
24-
{"answer": answer},
28+
{
29+
"answer": answer,
30+
"session_id": session_id,
31+
},
2532
status=status.HTTP_201_CREATED,
2633
)
2734

@@ -36,4 +43,9 @@ def create_embeddings(request):
3643
return Response(
3744
f"Created {document_file_name} index",
3845
status=status.HTTP_201_CREATED,
39-
)
46+
)
47+
48+
49+
def generate_short_id(length=6):
50+
alphanumeric = string.ascii_letters + string.digits
51+
return ''.join(secrets.choice(alphanumeric) for _ in range(length))

llm/apps.py

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from django.apps import AppConfig
2+
3+
4+
class LlmConfig(AppConfig):
5+
default_auto_field = "django.db.models.BigAutoField"
6+
name = "llm"

llm/chains/chat.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
1-
from llm.chains import config, embeddings
1+
from llm.chains import config, embeddings, memory
22

33
from langchain.chat_models import ChatOpenAI
44
from langchain.chains import RetrievalQA
55

66

7-
def run_chat_chain(prompt: str):
7+
def run_chat_chain(prompt: str, session_id: str):
88
llm = get_llm()
99
retriever = embeddings.get_retriever()
10-
qa_chain = RetrievalQA.from_chain_type(
11-
llm=llm, retriever=retriever, chain_type="stuff", return_source_documents=True
10+
chain_memory = memory.conversation_history(session_id=session_id)
11+
chain = RetrievalQA.from_chain_type(
12+
llm=llm,
13+
retriever=retriever,
14+
chain_type="stuff",
15+
return_source_documents=True,
16+
memory=chain_memory,
1217
)
13-
result = qa_chain({"query": prompt})
14-
15-
print(result)
18+
result = chain({"query": prompt})
1619

1720
return result
1821

1922

2023
def get_llm():
2124
llm = ChatOpenAI(model_name="gpt-3.5-turbo-16k", temperature=config.BOT_TEMPATURE)
22-
return llm
25+
return llm

llm/chains/memory.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import os
2+
3+
from langchain.memory import PostgresChatMessageHistory, ConversationBufferMemory
4+
5+
6+
def conversation_history(session_id: str):
7+
history = PostgresChatMessageHistory(
8+
connection_string=os.getenv("DATABASE_URL"),
9+
session_id=session_id,
10+
)
11+
memory = ConversationBufferMemory(
12+
memory_key="chat_history",
13+
return_messages=True,
14+
input_key="query",
15+
output_key="result",
16+
chat_memory=history
17+
)
18+
return memory
19+

llm/migrations/0001_initial.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Generated by Django 4.2.4 on 2023-08-30 20:46
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
8+
initial = True
9+
10+
dependencies = [
11+
]
12+
13+
operations = [
14+
migrations.CreateModel(
15+
name='MessageStore',
16+
fields=[
17+
('id', models.AutoField(primary_key=True, serialize=False)),
18+
('session_id', models.TextField()),
19+
('message', models.JSONField()),
20+
],
21+
options={
22+
'db_table': 'message_store',
23+
},
24+
),
25+
]

llm/migrations/__init__.py

Whitespace-only changes.

llm/models.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from django.db import models
2+
3+
class MessageStore(models.Model):
4+
"""
5+
See Langchain source code for the table schema we match here: https://github.com/langchain-ai/langchain/blob/7fa82900cb15d9c41099ad7dbb8aaa66941f6905/libs/langchain/langchain/memory/chat_message_histories/postgres.py#L39-L42
6+
"""
7+
id = models.AutoField(primary_key=True)
8+
session_id = models.TextField()
9+
message = models.JSONField()
10+
11+
class Meta:
12+
db_table = "message_store"

llm/settings.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99
For the full list of settings and their values, see
1010
https://docs.djangoproject.com/en/4.2/ref/settings/
1111
"""
12-
12+
import os
1313
from pathlib import Path
1414

15+
from dotenv import load_dotenv
16+
load_dotenv()
17+
1518
# Build paths inside the project like this: BASE_DIR / 'subdir'.
1619
BASE_DIR = Path(__file__).resolve().parent.parent
1720

@@ -31,6 +34,7 @@
3134
# Application definition
3235

3336
INSTALLED_APPS = [
37+
'llm.apps.LlmConfig',
3438
'django.contrib.admin',
3539
'django.contrib.auth',
3640
'django.contrib.contenttypes',
@@ -75,8 +79,12 @@
7579

7680
DATABASES = {
7781
'default': {
78-
'ENGINE': 'django.db.backends.sqlite3',
79-
'NAME': BASE_DIR / 'db.sqlite3',
82+
'ENGINE': 'django.db.backends.postgresql_psycopg2',
83+
'NAME': os.getenv("DB_NAME"),
84+
'USER': os.getenv("DB_USER"),
85+
'PASSWORD': os.getenv("DB_PASSWORD"),
86+
'HOST': os.getenv("DB_HOST"),
87+
'PORT': os.getenv("DB_PORT"),
8088
}
8189
}
8290

0 commit comments

Comments
 (0)