|
| 1 | +import pytest |
| 2 | +from openai import OpenAI |
| 3 | +import sqlite3 |
| 4 | +import sqlite_vec |
| 5 | +import struct |
| 6 | +from typing import List |
| 7 | +from unittest.mock import MagicMock, patch |
| 8 | + |
| 9 | + |
| 10 | +def serialize(vector: List[float]) -> bytes: |
| 11 | + """Helper function to serialize a list of floats into bytes""" |
| 12 | + return struct.pack("%sf" % len(vector), *vector) |
| 13 | + |
| 14 | + |
| 15 | +@pytest.fixture |
| 16 | +def mock_db(): |
| 17 | + """Fixture that sets up an in-memory SQLite database with the vector extension""" |
| 18 | + db = sqlite3.connect(":memory:") |
| 19 | + db.enable_load_extension(True) |
| 20 | + sqlite_vec.load(db) |
| 21 | + db.enable_load_extension(False) |
| 22 | + |
| 23 | + # Create tables |
| 24 | + db.execute(""" |
| 25 | + CREATE TABLE sentences( |
| 26 | + id INTEGER PRIMARY KEY, |
| 27 | + sentence TEXT |
| 28 | + ) |
| 29 | + """) |
| 30 | + |
| 31 | + db.execute(""" |
| 32 | + CREATE VIRTUAL TABLE vec_sentences USING vec0( |
| 33 | + id INTEGER PRIMARY KEY, |
| 34 | + sentence_embedding FLOAT[1536] |
| 35 | + ) |
| 36 | + """) |
| 37 | + |
| 38 | + yield db |
| 39 | + db.close() |
| 40 | + |
| 41 | + |
| 42 | +def test_database_setup(mock_db): |
| 43 | + """Test that the database tables are created correctly""" |
| 44 | + # Verify tables exist |
| 45 | + tables = mock_db.execute( |
| 46 | + "SELECT name FROM sqlite_master WHERE type='table' AND name IN ('sentences', 'vec_sentences')" |
| 47 | + ).fetchall() |
| 48 | + |
| 49 | + assert len(tables) == 2 |
| 50 | + assert ('sentences',) in tables |
| 51 | + assert ('vec_sentences',) in tables |
| 52 | + |
| 53 | + |
| 54 | +def test_embedding_storage(mock_db, monkeypatch): |
| 55 | + """Test that embeddings can be stored and retrieved""" |
| 56 | + # Mock the OpenAI client and its response |
| 57 | + mock_embedding = [0.1] * 1536 # Mock embedding vector |
| 58 | + |
| 59 | + # Insert test data |
| 60 | + test_sentence = "This is a test sentence" |
| 61 | + with mock_db: |
| 62 | + mock_db.execute("INSERT INTO sentences(id, sentence) VALUES(?, ?)", [1, test_sentence]) |
| 63 | + |
| 64 | + # Store the embedding directly without making API calls |
| 65 | + mock_db.execute( |
| 66 | + "INSERT INTO vec_sentences(id, sentence_embedding) VALUES(?, ?)", |
| 67 | + [1, serialize(mock_embedding)] |
| 68 | + ) |
| 69 | + |
| 70 | + # Verify data was inserted |
| 71 | + result = mock_db.execute("SELECT id, sentence FROM sentences WHERE id = 1").fetchone() |
| 72 | + assert result is not None |
| 73 | + assert result[1] == test_sentence |
| 74 | + |
| 75 | + # Verify embedding was stored |
| 76 | + vec_result = mock_db.execute("SELECT id FROM vec_sentences WHERE id = 1").fetchone() |
| 77 | + assert vec_result is not None |
| 78 | + assert vec_result[0] == 1 |
| 79 | + |
| 80 | + |
| 81 | +def test_similarity_search(mock_db): |
| 82 | + """Test that similarity search works with mock embeddings""" |
| 83 | + # Insert test data with known embeddings |
| 84 | + test_sentences = [ |
| 85 | + (1, "I love programming"), |
| 86 | + (2, "Programming is fun"), |
| 87 | + (3, "The weather is nice today") |
| 88 | + ] |
| 89 | + |
| 90 | + # Create 1536-dimensional mock embeddings |
| 91 | + def create_mock_embedding(values): |
| 92 | + # Create a 1536-dim vector with the first few values set |
| 93 | + embedding = [0.0] * 1536 |
| 94 | + for i, val in enumerate(values): |
| 95 | + if i < len(embedding): |
| 96 | + embedding[i] = val |
| 97 | + return embedding |
| 98 | + |
| 99 | + # Mock embeddings (first few dimensions set, rest are 0) |
| 100 | + test_embeddings = { |
| 101 | + 1: create_mock_embedding([0.9, 0.1, 0.1]), |
| 102 | + 2: create_mock_embedding([0.8, 0.2, 0.1]), |
| 103 | + 3: create_mock_embedding([0.1, 0.1, 0.9]) |
| 104 | + } |
| 105 | + |
| 106 | + with mock_db: |
| 107 | + # Insert test sentences |
| 108 | + for id, sentence in test_sentences: |
| 109 | + mock_db.execute( |
| 110 | + "INSERT INTO sentences(id, sentence) VALUES(?, ?)", |
| 111 | + [id, sentence] |
| 112 | + ) |
| 113 | + |
| 114 | + # Insert mock embeddings |
| 115 | + for id, embedding in test_embeddings.items(): |
| 116 | + mock_db.execute( |
| 117 | + "INSERT INTO vec_sentences(id, sentence_embedding) VALUES(?, ?)", |
| 118 | + [id, serialize(embedding)] |
| 119 | + ) |
| 120 | + |
| 121 | + # Test similarity search with a query similar to the first two sentences |
| 122 | + # Create a 1536-dim query embedding |
| 123 | + query_embedding = [0.0] * 1536 |
| 124 | + query_embedding[0] = 0.85 |
| 125 | + query_embedding[1] = 0.15 |
| 126 | + query_embedding[2] = 0.1 |
| 127 | + |
| 128 | + results = mock_db.execute( |
| 129 | + """ |
| 130 | + SELECT vec_sentences.id, distance, sentence |
| 131 | + FROM vec_sentences |
| 132 | + LEFT JOIN sentences ON sentences.id = vec_sentences.id |
| 133 | + WHERE sentence_embedding MATCH ? |
| 134 | + AND k = 2 |
| 135 | + ORDER BY distance |
| 136 | + """, |
| 137 | + [serialize(query_embedding)] |
| 138 | + ).fetchall() |
| 139 | + |
| 140 | + # Verify we get the two most similar sentences |
| 141 | + assert len(results) == 2 |
| 142 | + # The first result should be the most similar (smallest distance) |
| 143 | + assert results[0][0] == 1 # ID of first sentence |
| 144 | + assert results[0][2] == "I love programming" |
| 145 | + assert results[1][0] == 2 # ID of second sentence |
| 146 | + assert results[1][2] == "Programming is fun" |
0 commit comments