From 965dd93c933abdc8f01c438dad33eaefc33b34b3 Mon Sep 17 00:00:00 2001
From: crvernon
Date: Wed, 11 Sep 2024 11:36:14 -0400
Subject: [PATCH] additional utils tests

---
Review notes (text between the "---" marker and the diffstat is commentary
and is not part of the commit):

  * Adds TestReadPdf: writes a one-page blank PDF into a BytesIO with
    pypdf.PdfWriter, patches the PdfReader.pages property with a
    PropertyMock returning a single mock page whose extract_text() yields
    "Hello World", then checks the dict returned by hlt.read_pdf
    (n_pages, content, n_characters, n_words, n_tokens).
  * Adds TestReadText: feeds UTF-8 bytes to hlt.read_text and checks
    content, n_pages, n_characters and n_words.
  * NOTE(review): the in-test comment says extract_text is mocked; the
    patch target is actually the PdfReader.pages property.
  * NOTE(review): the `if __name__ == "__main__": unittest.main()` guard is
    removed, so the tests now need an external runner (for example
    `python -m unittest`).

 tests/test_utils.py | 52 +++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 50 insertions(+), 2 deletions(-)

diff --git a/tests/test_utils.py b/tests/test_utils.py
index be913a2..3af5c80 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,6 +1,9 @@
 import unittest
+from unittest.mock import patch
+from io import BytesIO
 
 import tiktoken
+from pypdf import PdfWriter, PdfReader
 
 import highlight as hlt
 
@@ -23,5 +26,50 @@ def test_get_token_count_empty_text(self):
         self.assertEqual(hlt.get_token_count(text), expected_token_count)
 
 
-if __name__ == "__main__":
-    unittest.main()
+class TestReadPdf(unittest.TestCase):
+    def test_read_pdf_without_reference_indicator(self):
+        # Create a sample PDF file using pypdf
+        buffer = BytesIO()
+
+        # Create a PDF writer
+        writer = PdfWriter()
+
+        # Add a blank page
+        writer.add_blank_page(width=612, height=792)
+
+        # Write the PDF to the buffer
+        writer.write(buffer)
+
+        # Reset buffer position to the beginning
+        buffer.seek(0)
+
+        # Mock the PdfReader's extract_text method to return "Hello World"
+        with patch.object(PdfReader, 'pages', new_callable=unittest.mock.PropertyMock) as mock_pages:
+            mock_page = unittest.mock.Mock()
+            mock_page.extract_text.return_value = "Hello World"
+            mock_pages.return_value = [mock_page]
+
+            # Test with a PDF without the reference indicator
+            result = hlt.read_pdf(buffer, reference_indicator="References\n")
+
+        # Validate the content and structure
+        self.assertEqual(result["n_pages"], 1)
+        self.assertIn("Hello World", result["content"])
+        self.assertGreater(result["n_characters"], 0)
+        self.assertGreater(result["n_words"], 0)
+        self.assertGreater(result["n_tokens"], 0)
+
+
+class TestReadText(unittest.TestCase):
+    def test_read_text(self):
+        # Simulate a text file using BytesIO
+        sample_text = "Hello World!\nThis is a test file."
+        text_file = BytesIO(sample_text.encode('utf-8'))
+
+        result = hlt.read_text(text_file)
+
+        # Assertions to validate the output
+        self.assertEqual(result["content"], sample_text)
+        self.assertEqual(result["n_pages"], 1)
+        self.assertEqual(result["n_characters"], len(sample_text))
+        self.assertEqual(result["n_words"], 7)