Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 54 additions & 17 deletions sqlite-vec.c
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,58 @@ static double distance_l1_f32(const void *a, const void *b, const void *d) {
return l1_f32(a, b, d);
}

// https://github.com/facebookresearch/faiss/blob/77e2e79cd0a680adc343b9840dd865da724c579e/faiss/utils/hamming_distance/common.h#L34
static u8 hamdist_table[256] = {
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4,
2, 3, 3, 4, 3, 4, 4, 5, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 1, 2, 2, 3, 2, 3, 3, 4,
2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6,
4, 5, 5, 6, 5, 6, 6, 7, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6,
4, 5, 5, 6, 5, 6, 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8};

static f32 distance_cosine_bit_u64(u64 *a, u64 *b, size_t n) {
f32 dot = 0;
f32 aMag = 0;
f32 bMag = 0;

for (size_t i = 0; i < n; i++) {
dot += __builtin_popcountl(a[i] & b[i]);
aMag += __builtin_popcountl(a[i]);
bMag += __builtin_popcountl(b[i]);
}

return 1 - (dot / (sqrt(aMag) * sqrt(bMag)));
}

static f32 distance_cosine_bit_u8(u8 *a, u8 *b, size_t n) {
f32 dot = 0;
f32 aMag = 0;
f32 bMag = 0;

for (size_t i = 0; i < n; i++) {
dot += hamdist_table[a[i] & b[i]];
aMag += hamdist_table[a[i]];
bMag += hamdist_table[b[i]];
}

return 1 - (dot / (sqrt(aMag) * sqrt(bMag)));
}

static f32 distance_cosine_bit(const void *pA, const void *pB,
const void *pD) {
size_t dim = *((size_t *)pD);

if ((dim % 64) == 0) {
return distance_cosine_bit_u64((u64 *)pA, (u64 *)pB, dim / 8 / CHAR_BIT);
}
return distance_cosine_bit_u8((u8 *)pA, (u8 *)pB, dim / CHAR_BIT);
}

static f32 distance_cosine_float(const void *pVect1v, const void *pVect2v,
const void *qty_ptr) {
f32 *pVect1 = (f32 *)pVect1v;
Expand Down Expand Up @@ -497,20 +549,6 @@ static f32 distance_cosine_int8(const void *pA, const void *pB,
return 1 - (dot / (sqrt(aMag) * sqrt(bMag)));
}

// https://github.com/facebookresearch/faiss/blob/77e2e79cd0a680adc343b9840dd865da724c579e/faiss/utils/hamming_distance/common.h#L34
static u8 hamdist_table[256] = {
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4,
2, 3, 3, 4, 3, 4, 4, 5, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 1, 2, 2, 3, 2, 3, 3, 4,
2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6,
4, 5, 5, 6, 5, 6, 6, 7, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 2, 3, 3, 4, 3, 4, 4, 5,
3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6,
4, 5, 5, 6, 5, 6, 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8};

static f32 distance_hamming_u8(u8 *a, u8 *b, size_t n) {
int same = 0;
for (unsigned long i = 0; i < n; i++) {
Expand Down Expand Up @@ -1167,9 +1205,8 @@ static void vec_distance_cosine(sqlite3_context *context, int argc,

switch (elementType) {
case SQLITE_VEC_ELEMENT_TYPE_BIT: {
sqlite3_result_error(
context, "Cannot calculate cosine distance between two bitvectors.",
-1);
f32 result = distance_cosine_bit(a, b, &dimensions);
sqlite3_result_double(context, result);
goto finish;
}
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: {
Expand Down
19 changes: 19 additions & 0 deletions tests/test-loadable.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,25 @@ def check(a, b, dtype=np.float32):
check([1, 2, 3], [-9, -8, -7], dtype=np.int8)
assert vec_distance_cosine("[1.1, 1.0]", "[1.2, 1.2]") == 0.001131898257881403

vec_distance_cosine_bit = lambda *args: db.execute(
"select vec_distance_cosine(vec_bit(?), vec_bit(?))", args
).fetchone()[0]
assert isclose(
vec_distance_cosine_bit(b"\xff", b"\x01"),
npy_cosine([1,1,1,1,1,1,1,1], [0,0,0,0,0,0,0,1]),
abs_tol=1e-6
)
assert isclose(
vec_distance_cosine_bit(b"\xab", b"\xab"),
npy_cosine([1,0,1,0,1,0,1,1], [1,0,1,0,1,0,1,1]),
abs_tol=1e-6
)
# test 64-bit
assert isclose(
vec_distance_cosine_bit(b"\xaa" * 8, b"\xff" * 8),
npy_cosine([1,0] * 32, [1] * 64),
abs_tol=1e-6
)

def test_vec_distance_hamming():
vec_distance_hamming = lambda *args: db.execute(
Expand Down