Skip to content

Commit

Permalink
fix[chat_bubble]: always place bubble at the end of the sentence (#38)
Browse files Browse the repository at this point in the history
* fix[chat_bubble]: always place bubble at the end of the sentence

* fix[chat_bubble]: remove extra print statements

* fix[chat_bubble]: refactor code to optimum aglo for finding index

* fix(chat_bubble): adding test cases for chat method

* fix(chat_bubble): adding test cases for chat method

* fix(chat_bubble): adding test cases for chat method
  • Loading branch information
ArslanSaleem authored and gventuri committed Oct 24, 2024
1 parent 44e1e1f commit 00e255d
Show file tree
Hide file tree
Showing 5 changed files with 263 additions and 5 deletions.
9 changes: 7 additions & 2 deletions backend/app/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
user_repository,
)
from app.requests import chat_query
from app.utils import clean_text
from app.utils import clean_text, find_following_sentence_ending, find_sentence_endings
from app.vectorstore.chroma import ChromaDB
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
Expand Down Expand Up @@ -94,6 +94,7 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d
content = response["response"]
content_length = len(content)
clean_content = clean_text(content)
context_sentence_endings = find_sentence_endings(content)
text_references = []
not_exact_matched_refs = []

Expand Down Expand Up @@ -121,8 +122,12 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d
metadata = doc_metadata[best_match_index]
sent = doc_sent[best_match_index]

# find sentence start index of reference in the context
index = clean_content.find(clean_text(sentence))

# Find the following sentence end from the end index
reference_ending_index = find_following_sentence_ending(context_sentence_endings, index + len(sentence))

if index != -1:
text_reference = {
"asset_id": metadata["asset_id"],
Expand All @@ -131,7 +136,7 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d
"filename": original_filename,
"source": [sent],
"start": index,
"end": index + len(sentence),
"end": reference_ending_index,
}
text_references.append(text_reference)
else:
Expand Down
27 changes: 27 additions & 0 deletions backend/app/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import hashlib
from typing import List
from urllib.parse import urlparse
import uuid
import requests
import re
import string

from bisect import bisect_right


def generate_unique_filename(url, extension=".html"):
url_hash = hashlib.sha256(url.encode("utf-8")).hexdigest()
Expand Down Expand Up @@ -61,3 +64,27 @@ def fetch_html_and_save(url, file_path):
# Save the content to a file
with open(file_path, "wb") as file:
file.write(response.content)


def find_sentence_endings(text: str) -> List[int]:
# Regex to find periods, exclamation marks, and question marks followed by a space or the end of the text
sentence_endings = [match.end() for match in re.finditer(r'[.!?](?:\s|$)', text)]

# Add the last index of the text as an additional sentence ending
sentence_endings.append(len(text))

return sentence_endings

def find_following_sentence_ending(sentence_endings: List[int], index: int) -> int:
"""
Find the closest sentence ending that follows the given index.
Args:
sentence_endings: Sorted list of sentence ending positions
index: Current position in text
Returns:
Next sentence ending position or original index if none found
"""
pos = bisect_right(sentence_endings, index)
return sentence_endings[pos] if pos < len(sentence_endings) else index
148 changes: 145 additions & 3 deletions backend/tests/api/v1/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def mock_db():

@pytest.fixture
def mock_vectorstore():
with patch("app.vectorstore.chroma.ChromaDB") as mock:
with patch("app.api.v1.chat.ChromaDB") as mock:
yield mock


Expand All @@ -24,15 +24,38 @@ def mock_chat_query():
with patch("app.api.v1.chat.chat_query") as mock:
yield mock


@pytest.fixture
def mock_user_repository():
with patch("app.api.v1.chat.user_repository.get_user_api_key") as mock:
yield mock


@pytest.fixture
def mock_get_users_repository():
with patch("app.api.v1.chat.user_repository.get_users") as mock:
yield mock

@pytest.fixture
def mock_conversation_repository():
with patch("app.api.v1.chat.conversation_repository") as mock:
yield mock

@pytest.fixture
def mock_get_assets_filename():
with patch("app.api.v1.chat.project_repository.get_assets_filename") as mock:
yield mock


def test_chat_status_endpoint(mock_db):
# Arrange
project_id = 1

with patch("app.repositories.project_repository.get_assets_without_content", return_value=[]):
with patch("app.repositories.project_repository.get_assets_content_pending", return_value=[]):
# Act
response = client.get(f"/v1/chat/project/{project_id}/status")

# Assert
assert response.status_code == 200
assert response.json()["status"] == "success"
Expand Down Expand Up @@ -281,3 +304,122 @@ def test_group_by_start_end_large_values():
assert len(result) == 1
assert result[0]["start"] == 1000000 and result[0]["end"] == 1000010
assert len(result[0]["references"]) == 2


def test_chat_endpoint_success(mock_db, mock_get_users_repository, mock_vectorstore, mock_chat_query, mock_user_repository, mock_conversation_repository, mock_get_assets_filename):
# Arrange
project_id = 1
chat_request = {
"query": "Tell me about sustainability.",
"conversation_id": None
}

# Mocking dependencies
mock_vectorstore.return_value.get_relevant_segments.return_value = (["Quote 1", "Quote 2"], [1, 2], {})
mock_user_repository.return_value.key = "test_api_key"
mock_chat_query.return_value = {"response": "Here's a response", "references": []}
mock_conversation_repository.create_new_conversation.return_value = MagicMock(id=123)
mock_get_users_repository.return_value = MagicMock(id=1)
mock_get_assets_filename.return_value = ["file1.pdf", "file2.pdf"]

# Act
response = client.post(f"/v1/chat/project/{project_id}", json=chat_request)

# Assert
assert response.status_code == 200
assert response.json()["status"] == "success"
assert response.json()["data"]["conversation_id"] == "123"
assert response.json()["data"]["response"] == "Here's a response"

def test_chat_endpoint_creates_conversation(mock_db, mock_get_users_repository, mock_vectorstore, mock_chat_query, mock_user_repository, mock_conversation_repository, mock_get_assets_filename):
# Arrange
project_id = 1
chat_request = {
"query": "What's the latest on climate change?",
"conversation_id": None
}

# Set up mock responses
mock_vectorstore.return_value.get_relevant_segments.return_value = (["Quote 1"], [1], {})
mock_user_repository.return_value.key = "test_api_key"
mock_chat_query.return_value = {"response": "Latest news on climate change", "references": []}

# Explicitly set the mock to return 456 as the conversation ID
mock_conversation_repository.create_new_conversation.return_value = MagicMock(id=456)
mock_get_users_repository.return_value = MagicMock(id=1)
mock_get_assets_filename.return_value = ["file1.pdf"]

# Act
response = client.post(f"/v1/chat/project/{project_id}", json=chat_request)

# Assert
assert response.status_code == 200
assert response.json()["data"]["conversation_id"] == "456"
assert mock_conversation_repository.create_new_conversation.called

def test_chat_endpoint_error_handling(mock_db, mock_vectorstore, mock_chat_query):
# Arrange
project_id = 1
chat_request = {
"query": "An error should occur.",
"conversation_id": None
}

mock_vectorstore.return_value.get_relevant_segments.side_effect = Exception("Database error")

# Act
response = client.post(f"/v1/chat/project/{project_id}", json=chat_request)

# Assert
assert response.status_code == 400
assert "Unable to process the chat query" in response.json()["detail"]

def test_chat_endpoint_reference_processing(mock_db, mock_get_users_repository, mock_vectorstore, mock_chat_query, mock_user_repository, mock_conversation_repository, mock_get_assets_filename):
# Arrange
project_id = 1
chat_request = {
"query": "Reference query.",
"conversation_id": None
}

mock_vectorstore.return_value.get_relevant_segments.return_value = (["Reference Quote"], [1], [{"asset_id":1, "project_id": project_id,"filename": "test.pdf","page_number": 1}])
mock_user_repository.return_value.key = "test_api_key"
mock_chat_query.return_value = {
"response": "Response with references",
"references": [
{
"sentence": "Reference Quote",
"references": [{"file": "file1.pdf", "sentence": "Original sentence"}]
}
]
}
mock_conversation_repository.create_new_conversation.return_value.id = 789
mock_get_users_repository.return_value = MagicMock(id=1)
mock_get_assets_filename.return_value = ["file1.pdf"]

# Act
response = client.post(f"/v1/chat/project/{project_id}", json=chat_request)

# Assert
assert response.status_code == 200
assert len(response.json()["data"]["response_references"]) > 0

def test_chat_endpoint_with_conversation_id(mock_db, mock_vectorstore, mock_chat_query, mock_user_repository, mock_conversation_repository, mock_get_assets_filename):
# Arrange
project_id = 1
chat_request = {
"query": "Chat with conversation.",
"conversation_id": "existing_convo_id"
}

mock_vectorstore.return_value.get_relevant_segments.return_value = (["Quote"], [1], {})
mock_user_repository.return_value.key = "test_api_key"
mock_chat_query.return_value = {"response": "Response with existing conversation", "references": []}
mock_get_assets_filename.return_value = ["file1.pdf"]

# Act
response = client.post(f"/v1/chat/project/{project_id}", json=chat_request)

# Assert
assert response.status_code == 200
assert response.json()["data"]["conversation_id"] == "existing_convo_id"
45 changes: 45 additions & 0 deletions backend/tests/utils/test_following_sentence_ending.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import unittest

from app.utils import find_following_sentence_ending


class TestFindFollowingSentenceEnding(unittest.TestCase):
def test_basic_case(self):
sentence_endings = [10, 20, 30, 40]
index = 15
expected = 20 # Closest ending greater than 15 is 20
self.assertEqual(find_following_sentence_ending(sentence_endings, index), expected)

def test_no_greater_ending(self):
sentence_endings = [10, 20, 30]
index = 35
expected = 35 # No greater ending than 35, so it returns the index itself
self.assertEqual(find_following_sentence_ending(sentence_endings, index), expected)

def test_at_ending_boundary(self):
sentence_endings = [10, 20, 30, 40]
index = 30
expected = 40 # The next greater ending after 30 is 40
self.assertEqual(find_following_sentence_ending(sentence_endings, index), expected)

def test_first_sentence(self):
sentence_endings = [10, 20, 30, 40]
index = 5
expected = 10 # The closest ending greater than 5 is 10
self.assertEqual(find_following_sentence_ending(sentence_endings, index), expected)

def test_empty_sentence_endings(self):
sentence_endings = []
index = 5
expected = 5 # No sentence endings, so return the index itself
self.assertEqual(find_following_sentence_ending(sentence_endings, index), expected)

def test_same_index_as_last_ending(self):
sentence_endings = [10, 20, 30]
index = 30
expected = 30 # At the last sentence ending, return the index itself
self.assertEqual(find_following_sentence_ending(sentence_endings, index), expected)

# Run the tests
if __name__ == "__main__":
unittest.main()
39 changes: 39 additions & 0 deletions backend/tests/utils/test_sentence_endings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import unittest

from app.utils import find_sentence_endings


class TestFindSentenceEndings(unittest.TestCase):
def test_basic_sentences(self):
text = "This is a sentence. This is another one!"
expected = [20, 40, len(text)] # Sentence endings at ".", "!", and the last index
self.assertEqual(find_sentence_endings(text), expected)

def test_text_without_punctuation(self):
text = "This is a sentence without punctuation"
expected = [len(text)] # Only the last index is expected
self.assertEqual(find_sentence_endings(text), expected)

def test_multiple_punctuation(self):
text = "Is this working? Yes! It seems so."
expected = [17, 22, 34, len(text)] # Endings after "?", "!", ".", and the last index
self.assertEqual(find_sentence_endings(text), expected)

def test_trailing_whitespace(self):
text = "Trailing whitespace should be ignored. "
expected = [39, len(text)] # End at the period and the final index
self.assertEqual(find_sentence_endings(text), expected)

def test_punctuation_in_middle_of_text(self):
text = "Sentence. Followed by an abbreviation e.g. and another sentence."
expected = [10, 43, 64, len(text)] # Endings after ".", abbreviation ".", and sentence "."
self.assertEqual(find_sentence_endings(text), expected)

def test_empty_string(self):
text = ""
expected = [0] # Empty string should only have the 0th index as an "ending"
self.assertEqual(find_sentence_endings(text), expected)

# Run the tests
if __name__ == "__main__":
unittest.main()

0 comments on commit 00e255d

Please sign in to comment.