Skip to content

Commit

Permalink
Fix: Require contiguous rows
Browse files Browse the repository at this point in the history
Closes #543

Co-authored-by: Michelangiolo Mazzeschi <[email protected]>
Co-authored-by: Michelangiolo Mazzeschi
 <[email protected]>
  • Loading branch information
3 people committed Dec 29, 2024
1 parent 1330170 commit 4973e37
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 2 deletions.
5 changes: 4 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,13 @@
"xtr1common": "cpp",
"xtree": "cpp",
"xutility": "cpp",
"execution": "cpp"
"execution": "cpp",
"text_encoding": "cpp"
},
"cSpell.words": [
"allclose",
"arange",
"ascontiguousarray",
"ashvardanian",
"astype",
"autovec",
Expand Down Expand Up @@ -182,6 +184,7 @@
"pytest",
"Quickstart",
"relock",
"relwithdebinfo",
"repr",
"rtype",
"SIMD",
Expand Down
13 changes: 12 additions & 1 deletion python/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,15 @@ static void add_many_to_index( //

if (keys_info.itemsize != sizeof(dense_key_t))
throw std::invalid_argument("Incompatible key type!");
if (keys_info.strides[0] != static_cast<Py_ssize_t>(keys_info.itemsize))
throw std::invalid_argument("Keys array must be C-contiguous.");

if (keys_info.ndim != 1)
throw std::invalid_argument("Keys must be placed in a single-dimensional array!");

if (vectors_info.ndim != 2)
throw std::invalid_argument("Expects a matrix of vectors to add!");
if (vectors_info.strides[1] != static_cast<Py_ssize_t>(vectors_info.itemsize))
throw std::invalid_argument("Matrix rows must be contiguous, try `ascontiguousarray`.");

Py_ssize_t keys_count = keys_info.shape[0];
Py_ssize_t vectors_count = vectors_info.shape[0];
Expand Down Expand Up @@ -428,6 +431,8 @@ static py::tuple search_many_in_index( //
Py_ssize_t vectors_dimensions = vectors_info.shape[1];
if (vectors_dimensions != static_cast<Py_ssize_t>(index.scalar_words()))
throw std::invalid_argument("The number of vector dimensions doesn't match!");
if (vectors_info.strides[1] != static_cast<Py_ssize_t>(vectors_info.itemsize))
throw std::invalid_argument("Matrix rows must be contiguous, try `ascontiguousarray`.");

py::array_t<dense_key_t> keys_py({vectors_count, static_cast<Py_ssize_t>(wanted)});
py::array_t<distance_t> distances_py({vectors_count, static_cast<Py_ssize_t>(wanted)});
Expand Down Expand Up @@ -474,6 +479,10 @@ static py::tuple search_many_brute_force( //
py::buffer_info queries_info = queries.request();
if (dataset_info.ndim != 2 || queries_info.ndim != 2)
throw std::invalid_argument("Expects a matrix of dataset to add!");
if (dataset_info.strides[1] != static_cast<Py_ssize_t>(dataset_info.itemsize))
throw std::invalid_argument("Dataset rows must be contiguous, try `ascontiguousarray`.");
if (queries_info.strides[1] != static_cast<Py_ssize_t>(queries_info.itemsize))
throw std::invalid_argument("Queries rows must be contiguous, try `ascontiguousarray`.");

std::size_t dataset_count = static_cast<std::size_t>(dataset_info.shape[0]);
std::size_t dataset_dimensions = static_cast<std::size_t>(dataset_info.shape[1]);
Expand Down Expand Up @@ -570,6 +579,8 @@ static py::tuple cluster_many_brute_force( //
py::buffer_info dataset_info = dataset.request();
if (dataset_info.ndim != 2)
throw std::invalid_argument("Expects a matrix (rank-2 tensor) of dataset to cluster!");
if (dataset_info.strides[1] != static_cast<Py_ssize_t>(dataset_info.itemsize))
throw std::invalid_argument("Dataset rows must be contiguous, try `ascontiguousarray`.");

std::size_t dataset_count = static_cast<std::size_t>(dataset_info.shape[0]);
std::size_t dataset_dimensions = static_cast<std::size_t>(dataset_info.shape[1]);
Expand Down
26 changes: 26 additions & 0 deletions python/scripts/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,27 @@ def test_index_retrieval(ndim, metric, quantization, dtype, batch_size):
vectors_batch_retrieved = vectors_batch_retrieved[vectors_reordering]
assert np.allclose(vectors_batch_retrieved, vectors, atol=0.1)

if quantization != ScalarKind.I8 and batch_size > 1:
# When dealing with non-contiguous data, it's important to check that
# the native bindings access them with correct strides or normalize
# similar to `np.ascontiguousarray`:
index = Index(ndim=ndim, metric=metric, dtype=quantization, multi=False)
vectors = random_vectors(count=batch_size, ndim=ndim + 1, dtype=dtype)
# Let's skip the first dimension of each vector:
vectors = vectors[:, 1:]
index.add(keys, vectors, threads=threads)
vectors_retrieved = np.vstack(index.get(keys, dtype))
assert np.allclose(vectors_retrieved, vectors, atol=0.1)

# Try a transposed version of the same vectors, that is not C-contiguous
# and should raise an exception!
index = Index(ndim=ndim, metric=metric, dtype=quantization, multi=False)
vectors = random_vectors(count=ndim, ndim=batch_size, dtype=dtype) #! reversed dims
assert vectors.strides == (batch_size * dtype().itemsize, dtype().itemsize)
assert vectors.T.strides == (dtype().itemsize, batch_size * dtype().itemsize)
with pytest.raises(Exception):
index.add(keys, vectors.T, threads=threads)


@pytest.mark.parametrize("ndim", [3, 97, 256])
@pytest.mark.parametrize("metric", [MetricKind.Cos, MetricKind.L2sq])
Expand All @@ -104,13 +125,18 @@ def test_index_search(ndim, metric, quantization, dtype, batch_size):

if batch_size == 1:
matches: Matches = index.search(vectors, 10, threads=threads)
assert isinstance(matches, Matches)
assert isinstance(matches[0], Match)
assert matches.keys.ndim == 1
assert matches.keys.shape[0] == matches.distances.shape[0]
assert len(matches) == batch_size
assert np.all(np.sort(index.keys) == np.sort(keys))

else:
matches: BatchMatches = index.search(vectors, 10, threads=threads)
assert isinstance(matches, BatchMatches)
assert isinstance(matches[0], Matches)
assert isinstance(matches[0][0], Match)
assert matches.keys.ndim == 2
assert matches.keys.shape[0] == matches.distances.shape[0]
assert len(matches) == batch_size
Expand Down

0 comments on commit 4973e37

Please sign in to comment.