From 00e255d213d56af968291934bca1016bf682a600 Mon Sep 17 00:00:00 2001 From: Arslan Saleem Date: Thu, 24 Oct 2024 16:06:32 +0200 Subject: [PATCH] fix[chat_bubble]: always place bubble at the end of the sentence (#38) * fix[chat_bubble]: always place bubble at the end of the sentence * fix[chat_bubble]: remove extra print statements * fix[chat_bubble]: refactor code to optimum aglo for finding index * fix(chat_bubble): adding test cases for chat method * fix(chat_bubble): adding test cases for chat method * fix(chat_bubble): adding test cases for chat method --- backend/app/api/v1/chat.py | 9 +- backend/app/utils.py | 27 ++++ backend/tests/api/v1/test_chat.py | 148 +++++++++++++++++- .../utils/test_following_sentence_ending.py | 45 ++++++ backend/tests/utils/test_sentence_endings.py | 39 +++++ 5 files changed, 263 insertions(+), 5 deletions(-) create mode 100644 backend/tests/utils/test_following_sentence_ending.py create mode 100644 backend/tests/utils/test_sentence_endings.py diff --git a/backend/app/api/v1/chat.py b/backend/app/api/v1/chat.py index f6854e0..357ee2a 100644 --- a/backend/app/api/v1/chat.py +++ b/backend/app/api/v1/chat.py @@ -11,7 +11,7 @@ user_repository, ) from app.requests import chat_query -from app.utils import clean_text +from app.utils import clean_text, find_following_sentence_ending, find_sentence_endings from app.vectorstore.chroma import ChromaDB from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel @@ -94,6 +94,7 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d content = response["response"] content_length = len(content) clean_content = clean_text(content) + context_sentence_endings = find_sentence_endings(content) text_references = [] not_exact_matched_refs = [] @@ -121,8 +122,12 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d metadata = doc_metadata[best_match_index] sent = doc_sent[best_match_index] + # find sentence start index of reference in the context index = clean_content.find(clean_text(sentence)) + # Find the following sentence end from the end index + reference_ending_index = find_following_sentence_ending(context_sentence_endings, index + len(sentence)) + if index != -1: text_reference = { "asset_id": metadata["asset_id"], @@ -131,7 +136,7 @@ def chat(project_id: int, chat_request: ChatRequest, db: Session = Depends(get_d "filename": original_filename, "source": [sent], "start": index, - "end": index + len(sentence), + "end": reference_ending_index, } text_references.append(text_reference) else: diff --git a/backend/app/utils.py b/backend/app/utils.py index 32c0454..786d140 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -1,10 +1,13 @@ import hashlib +from typing import List from urllib.parse import urlparse import uuid import requests import re import string +from bisect import bisect_right + def generate_unique_filename(url, extension=".html"): url_hash = hashlib.sha256(url.encode("utf-8")).hexdigest() @@ -61,3 +64,27 @@ def fetch_html_and_save(url, file_path): # Save the content to a file with open(file_path, "wb") as file: file.write(response.content) + + +def find_sentence_endings(text: str) -> List[int]: + # Regex to find periods, exclamation marks, and question marks followed by a space or the end of the text + sentence_endings = [match.end() for match in re.finditer(r'[.!?](?:\s|$)', text)] + + # Add the last index of the text as an additional sentence ending + sentence_endings.append(len(text)) + + return sentence_endings + +def find_following_sentence_ending(sentence_endings: List[int], index: int) -> int: + """ + Find the closest sentence ending that follows the given index. + + Args: + sentence_endings: Sorted list of sentence ending positions + index: Current position in text + + Returns: + Next sentence ending position or original index if none found + """ + pos = bisect_right(sentence_endings, index) + return sentence_endings[pos] if pos < len(sentence_endings) else index diff --git a/backend/tests/api/v1/test_chat.py b/backend/tests/api/v1/test_chat.py index 6a524b8..6437115 100644 --- a/backend/tests/api/v1/test_chat.py +++ b/backend/tests/api/v1/test_chat.py @@ -15,7 +15,7 @@ def mock_db(): @pytest.fixture def mock_vectorstore(): - with patch("app.vectorstore.chroma.ChromaDB") as mock: + with patch("app.api.v1.chat.ChromaDB") as mock: yield mock @@ -24,15 +24,38 @@ def mock_chat_query(): with patch("app.api.v1.chat.chat_query") as mock: yield mock + +@pytest.fixture +def mock_user_repository(): + with patch("app.api.v1.chat.user_repository.get_user_api_key") as mock: + yield mock + + +@pytest.fixture +def mock_get_users_repository(): + with patch("app.api.v1.chat.user_repository.get_users") as mock: + yield mock + +@pytest.fixture +def mock_conversation_repository(): + with patch("app.api.v1.chat.conversation_repository") as mock: + yield mock + +@pytest.fixture +def mock_get_assets_filename(): + with patch("app.api.v1.chat.project_repository.get_assets_filename") as mock: + yield mock + + def test_chat_status_endpoint(mock_db): # Arrange project_id = 1 - + with patch("app.repositories.project_repository.get_assets_without_content", return_value=[]): with patch("app.repositories.project_repository.get_assets_content_pending", return_value=[]): # Act response = client.get(f"/v1/chat/project/{project_id}/status") - + # Assert assert response.status_code == 200 assert response.json()["status"] == "success" @@ -281,3 +304,122 @@ def test_group_by_start_end_large_values(): assert len(result) == 1 assert result[0]["start"] == 1000000 and result[0]["end"] == 1000010 assert len(result[0]["references"]) == 2 + + +def test_chat_endpoint_success(mock_db, mock_get_users_repository, mock_vectorstore, mock_chat_query, mock_user_repository, mock_conversation_repository, mock_get_assets_filename): + # Arrange + project_id = 1 + chat_request = { + "query": "Tell me about sustainability.", + "conversation_id": None + } + + # Mocking dependencies + mock_vectorstore.return_value.get_relevant_segments.return_value = (["Quote 1", "Quote 2"], [1, 2], {}) + mock_user_repository.return_value.key = "test_api_key" + mock_chat_query.return_value = {"response": "Here's a response", "references": []} + mock_conversation_repository.create_new_conversation.return_value = MagicMock(id=123) + mock_get_users_repository.return_value = MagicMock(id=1) + mock_get_assets_filename.return_value = ["file1.pdf", "file2.pdf"] + + # Act + response = client.post(f"/v1/chat/project/{project_id}", json=chat_request) + + # Assert + assert response.status_code == 200 + assert response.json()["status"] == "success" + assert response.json()["data"]["conversation_id"] == "123" + assert response.json()["data"]["response"] == "Here's a response" + +def test_chat_endpoint_creates_conversation(mock_db, mock_get_users_repository, mock_vectorstore, mock_chat_query, mock_user_repository, mock_conversation_repository, mock_get_assets_filename): + # Arrange + project_id = 1 + chat_request = { + "query": "What's the latest on climate change?", + "conversation_id": None + } + + # Set up mock responses + mock_vectorstore.return_value.get_relevant_segments.return_value = (["Quote 1"], [1], {}) + mock_user_repository.return_value.key = "test_api_key" + mock_chat_query.return_value = {"response": "Latest news on climate change", "references": []} + + # Explicitly set the mock to return 456 as the conversation ID + mock_conversation_repository.create_new_conversation.return_value = MagicMock(id=456) + mock_get_users_repository.return_value = MagicMock(id=1) + mock_get_assets_filename.return_value = ["file1.pdf"] + + # Act + response = client.post(f"/v1/chat/project/{project_id}", json=chat_request) + + # Assert + assert response.status_code == 200 + assert response.json()["data"]["conversation_id"] == "456" + assert mock_conversation_repository.create_new_conversation.called + +def test_chat_endpoint_error_handling(mock_db, mock_vectorstore, mock_chat_query): + # Arrange + project_id = 1 + chat_request = { + "query": "An error should occur.", + "conversation_id": None + } + + mock_vectorstore.return_value.get_relevant_segments.side_effect = Exception("Database error") + + # Act + response = client.post(f"/v1/chat/project/{project_id}", json=chat_request) + + # Assert + assert response.status_code == 400 + assert "Unable to process the chat query" in response.json()["detail"] + +def test_chat_endpoint_reference_processing(mock_db, mock_get_users_repository, mock_vectorstore, mock_chat_query, mock_user_repository, mock_conversation_repository, mock_get_assets_filename): + # Arrange + project_id = 1 + chat_request = { + "query": "Reference query.", + "conversation_id": None + } + + mock_vectorstore.return_value.get_relevant_segments.return_value = (["Reference Quote"], [1], [{"asset_id":1, "project_id": project_id,"filename": "test.pdf","page_number": 1}]) + mock_user_repository.return_value.key = "test_api_key" + mock_chat_query.return_value = { + "response": "Response with references", + "references": [ + { + "sentence": "Reference Quote", + "references": [{"file": "file1.pdf", "sentence": "Original sentence"}] + } + ] + } + mock_conversation_repository.create_new_conversation.return_value.id = 789 + mock_get_users_repository.return_value = MagicMock(id=1) + mock_get_assets_filename.return_value = ["file1.pdf"] + + # Act + response = client.post(f"/v1/chat/project/{project_id}", json=chat_request) + + # Assert + assert response.status_code == 200 + assert len(response.json()["data"]["response_references"]) > 0 + +def test_chat_endpoint_with_conversation_id(mock_db, mock_vectorstore, mock_chat_query, mock_user_repository, mock_conversation_repository, mock_get_assets_filename): + # Arrange + project_id = 1 + chat_request = { + "query": "Chat with conversation.", + "conversation_id": "existing_convo_id" + } + + mock_vectorstore.return_value.get_relevant_segments.return_value = (["Quote"], [1], {}) + mock_user_repository.return_value.key = "test_api_key" + mock_chat_query.return_value = {"response": "Response with existing conversation", "references": []} + mock_get_assets_filename.return_value = ["file1.pdf"] + + # Act + response = client.post(f"/v1/chat/project/{project_id}", json=chat_request) + + # Assert + assert response.status_code == 200 + assert response.json()["data"]["conversation_id"] == "existing_convo_id" diff --git a/backend/tests/utils/test_following_sentence_ending.py b/backend/tests/utils/test_following_sentence_ending.py new file mode 100644 index 0000000..928ba20 --- /dev/null +++ b/backend/tests/utils/test_following_sentence_ending.py @@ -0,0 +1,45 @@ +import unittest + +from app.utils import find_following_sentence_ending + + +class TestFindFollowingSentenceEnding(unittest.TestCase): + def test_basic_case(self): + sentence_endings = [10, 20, 30, 40] + index = 15 + expected = 20 # Closest ending greater than 15 is 20 + self.assertEqual(find_following_sentence_ending(sentence_endings, index), expected) + + def test_no_greater_ending(self): + sentence_endings = [10, 20, 30] + index = 35 + expected = 35 # No greater ending than 35, so it returns the index itself + self.assertEqual(find_following_sentence_ending(sentence_endings, index), expected) + + def test_at_ending_boundary(self): + sentence_endings = [10, 20, 30, 40] + index = 30 + expected = 40 # The next greater ending after 30 is 40 + self.assertEqual(find_following_sentence_ending(sentence_endings, index), expected) + + def test_first_sentence(self): + sentence_endings = [10, 20, 30, 40] + index = 5 + expected = 10 # The closest ending greater than 5 is 10 + self.assertEqual(find_following_sentence_ending(sentence_endings, index), expected) + + def test_empty_sentence_endings(self): + sentence_endings = [] + index = 5 + expected = 5 # No sentence endings, so return the index itself + self.assertEqual(find_following_sentence_ending(sentence_endings, index), expected) + + def test_same_index_as_last_ending(self): + sentence_endings = [10, 20, 30] + index = 30 + expected = 30 # At the last sentence ending, return the index itself + self.assertEqual(find_following_sentence_ending(sentence_endings, index), expected) + +# Run the tests +if __name__ == "__main__": + unittest.main() diff --git a/backend/tests/utils/test_sentence_endings.py b/backend/tests/utils/test_sentence_endings.py new file mode 100644 index 0000000..a793d39 --- /dev/null +++ b/backend/tests/utils/test_sentence_endings.py @@ -0,0 +1,39 @@ +import unittest + +from app.utils import find_sentence_endings + + +class TestFindSentenceEndings(unittest.TestCase): + def test_basic_sentences(self): + text = "This is a sentence. This is another one!" + expected = [20, 40, len(text)] # Sentence endings at ".", "!", and the last index + self.assertEqual(find_sentence_endings(text), expected) + + def test_text_without_punctuation(self): + text = "This is a sentence without punctuation" + expected = [len(text)] # Only the last index is expected + self.assertEqual(find_sentence_endings(text), expected) + + def test_multiple_punctuation(self): + text = "Is this working? Yes! It seems so." + expected = [17, 22, 34, len(text)] # Endings after "?", "!", ".", and the last index + self.assertEqual(find_sentence_endings(text), expected) + + def test_trailing_whitespace(self): + text = "Trailing whitespace should be ignored. " + expected = [39, len(text)] # End at the period and the final index + self.assertEqual(find_sentence_endings(text), expected) + + def test_punctuation_in_middle_of_text(self): + text = "Sentence. Followed by an abbreviation e.g. and another sentence." + expected = [10, 43, 64, len(text)] # Endings after ".", abbreviation ".", and sentence "." + self.assertEqual(find_sentence_endings(text), expected) + + def test_empty_string(self): + text = "" + expected = [0] # Empty string should only have the 0th index as an "ending" + self.assertEqual(find_sentence_endings(text), expected) + +# Run the tests +if __name__ == "__main__": + unittest.main()