Merged
209 changes: 159 additions & 50 deletions rust/index/src/spann/types.rs
@@ -766,6 +766,61 @@ impl SpannIndexWriter {
Ok(false)
}

    async fn try_delete_posting_list(&self, head_id: u32) -> Result<(), SpannIndexWriterError> {
        let _write_guard = self.posting_list_partitioned_mutex.lock(&head_id).await;
        if self.is_head_deleted(head_id as usize).await? {
            return Ok(());
        }
        let result = self
            .posting_list_writer
            .get_owned::<u32, &SpannPostingList<'_>>("", head_id)
            .await;
        // If the posting list is not found, return Ok.
        match result {
            Ok(Some((doc_offset_ids, doc_versions, _))) => {
                let mut outdated_count = 0;
                for (doc_offset_id, doc_version) in doc_offset_ids.iter().zip(doc_versions.iter()) {
                    if self.is_outdated(*doc_offset_id, *doc_version).await? {
                        outdated_count += 1;
                    }
                }
Comment on lines +782 to +786
Recommended

[Performance] The loop starting on this line calls self.is_outdated for each item, which acquires a read lock on self.versions_map in every iteration. For large posting lists, this can be inefficient due to repeated lock acquisition overhead.

To optimize this, consider acquiring the read lock once before the loop and performing the checks within that single locked scope. This would reduce lock contention and improve performance.

This would involve replacing lines 781-786 with something like:

```rust
let version_map_guard = self.versions_map.read().await;
let mut outdated_count = 0;
for (doc_offset_id, doc_version) in doc_offset_ids.iter().zip(doc_versions.iter()) {
    let current_version = version_map_guard
        .versions_map
        .get(doc_offset_id)
        .ok_or(SpannIndexWriterError::VersionNotFound)?;
    if Self::is_deleted(*current_version) || *doc_version < *current_version {
        outdated_count += 1;
    }
}
```
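
For reference, here is a self-contained sketch of the same lock-hoisting pattern. The types are hypothetical stand-ins (`VersionsMap` as a plain `HashMap`, a sentinel version of `0` marking deletion); they are not the crate's actual representations.

```rust
// Minimal sketch of hoisting a read lock out of a loop, assuming a tokio
// RwLock and hypothetical types (not the crate's actual VersionsMap).
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

// Hypothetical stand-in for the index's versions map: offset id -> version.
type VersionsMap = HashMap<u32, u32>;

// Assumption for this sketch: version 0 marks a deleted point.
fn is_deleted(version: u32) -> bool {
    version == 0
}

async fn count_outdated(
    versions: Arc<RwLock<VersionsMap>>,
    doc_offset_ids: &[u32],
    doc_versions: &[u32],
) -> Result<usize, &'static str> {
    // Acquire the read lock once for the whole scan instead of per element.
    let guard = versions.read().await;
    let mut outdated = 0;
    for (id, version) in doc_offset_ids.iter().zip(doc_versions.iter()) {
        let current = guard.get(id).ok_or("version not found")?;
        if is_deleted(*current) || *version < *current {
            outdated += 1;
        }
    }
    Ok(outdated)
}

#[tokio::main]
async fn main() {
    let versions = Arc::new(RwLock::new(HashMap::from([(1, 2), (2, 1), (3, 0)])));
    // Point 1 is outdated (stored 1 < current 2) and point 3 is deleted.
    let n = count_outdated(versions, &[1, 2, 3], &[1, 1, 1]).await.unwrap();
    assert_eq!(n, 2);
}
```

Besides avoiding per-iteration lock overhead, holding a single read guard across the scan keeps the count consistent with one snapshot of the map.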


                if outdated_count == doc_offset_ids.len() {
                    {
                        let hnsw_write_guard = self.hnsw_index.inner.write();
                        hnsw_write_guard
                            .hnsw_index
                            .delete(head_id as usize)
                            .map_err(|e| {
                                tracing::error!(
                                    "Error deleting head {} from hnsw index: {}",
                                    head_id,
                                    e
                                );
                                SpannIndexWriterError::HnswIndexMutateError(e)
                            })?;
                    }
                    self.posting_list_writer
                        .delete::<u32, &SpannPostingList<'_>>("", head_id)
                        .await
                        .map_err(|e| {
                            tracing::error!(
                                "Error deleting posting list for head {}: {}",
                                head_id,
                                e
                            );
                            SpannIndexWriterError::PostingListSetError(e)
                        })?;
                }
            }
            Ok(None) => {}
            Err(e) => {
                tracing::error!("Error getting posting list for head {}: {}", head_id, e);
                return Err(SpannIndexWriterError::PostingListGetError(e));
            }
        }
        Ok(())
    }

#[allow(clippy::too_many_arguments)]
async fn collect_and_reassign_split_points(
&self,
@@ -814,6 +869,8 @@ impl SpannIndexWriter {
.await?;
}
}
// Delete head if all points were moved out.
self.try_delete_posting_list(new_head_ids[k] as u32).await?;
}
Ok(assigned_ids)
}
@@ -946,17 +1003,20 @@ impl SpannIndexWriter {
let doc_versions;
let doc_embeddings;
{
// TODO(Sanket): Check if head is deleted, can happen if another concurrent thread
// deletes it.
(doc_offset_ids, doc_versions, doc_embeddings) = self
let result = self
.posting_list_writer
.get_owned::<u32, &SpannPostingList<'_>>("", head_id as u32)
.await
.map_err(|e| {
tracing::error!("Error getting posting list for head {}: {}", head_id, e);
SpannIndexWriterError::PostingListGetError(e)
})?
.ok_or(SpannIndexWriterError::PostingListNotFound)?;
.await;
match result {
Ok(Some((offset_ids, versions, embeddings))) => {
doc_offset_ids = offset_ids;
doc_versions = versions;
doc_embeddings = embeddings;
}
// Posting list can be concurrently deleted, so bail out early if not found.
Ok(None) => return Ok(()),
Err(e) => return Err(SpannIndexWriterError::PostingListGetError(e)),
}
}
for (index, doc_offset_id) in doc_offset_ids.iter().enumerate() {
if assigned_ids.contains(doc_offset_id)
Expand Down Expand Up @@ -1004,6 +1064,8 @@ impl SpannIndexWriter {
)
.await?;
}
// Delete head if all points were moved out.
self.try_delete_posting_list(head_id as u32).await?;
Ok(())
}

@@ -1264,6 +1326,7 @@ impl SpannIndexWriter {
if !same_head
&& distance_function
.distance(&clustering_output.cluster_centers[k], &head_embedding)
.abs()
< 1e-6
{
same_head = true;
@@ -1350,17 +1413,32 @@ impl SpannIndexWriter {
}
if !same_head {
// Delete the old head
let hnsw_write_guard = self.hnsw_index.inner.write();
hnsw_write_guard
.hnsw_index
.delete(head_id as usize)
// First delete from hnsw, then from the posting list. This order
// ensures that the head is never dangling.
{
let hnsw_write_guard = self.hnsw_index.inner.write();
hnsw_write_guard
.hnsw_index
.delete(head_id as usize)
.map_err(|e| {
tracing::error!(
"Error deleting head {} from hnsw index: {}",
head_id,
e
);
SpannIndexWriterError::HnswIndexMutateError(e)
})?;
}
self.posting_list_writer
.delete::<u32, &SpannPostingList<'_>>("", head_id)
.await
.map_err(|e| {
tracing::error!(
"Error deleting head {} from hnsw index: {}",
"Error deleting posting list for head {}: {}",
head_id,
e
);
SpannIndexWriterError::HnswIndexMutateError(e)
SpannIndexWriterError::PostingListSetError(e)
})?;
self.stats
.num_heads_deleted
@@ -1755,12 +1833,29 @@ impl SpannIndexWriter {
self.stats
.num_pl_modified
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
// Delete from hnsw.
let hnsw_write_guard = self.hnsw_index.inner.write();
hnsw_write_guard.hnsw_index.delete(head_id).map_err(|e| {
tracing::error!("Error deleting head {} from hnsw index: {}", head_id, e);
SpannIndexWriterError::HnswIndexMutateError(e)
})?;
{
// Delete from hnsw.
let hnsw_write_guard = self.hnsw_index.inner.write();
hnsw_write_guard.hnsw_index.delete(head_id).map_err(|e| {
tracing::error!(
"Error deleting head {} from hnsw index: {}",
head_id,
e
);
SpannIndexWriterError::HnswIndexMutateError(e)
})?;
}
self.posting_list_writer
.delete::<u32, &SpannPostingList<'_>>("", head_id as u32)
.await
.map_err(|e| {
tracing::error!(
"Error deleting posting list for head {}: {}",
head_id,
e
);
SpannIndexWriterError::PostingListSetError(e)
})?;
self.stats
.num_heads_deleted
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
@@ -1779,18 +1874,31 @@ impl SpannIndexWriter {
self.stats
.num_pl_modified
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
// Delete from hnsw.
let hnsw_write_guard = self.hnsw_index.inner.write();
hnsw_write_guard
.hnsw_index
.delete(nearest_head_id)
{
// Delete from hnsw.
let hnsw_write_guard = self.hnsw_index.inner.write();
hnsw_write_guard
.hnsw_index
.delete(nearest_head_id)
.map_err(|e| {
tracing::error!(
"Error deleting head {} from hnsw index: {}",
nearest_head_id,
e
);
SpannIndexWriterError::HnswIndexMutateError(e)
})?;
}
self.posting_list_writer
.delete::<u32, &SpannPostingList<'_>>("", nearest_head_id as u32)
.await
.map_err(|e| {
tracing::error!(
"Error deleting head {} from hnsw index: {}",
"Error deleting posting list for head {}: {}",
nearest_head_id,
e
);
SpannIndexWriterError::HnswIndexMutateError(e)
SpannIndexWriterError::PostingListSetError(e)
})?;
self.stats
.num_heads_deleted
@@ -3583,7 +3691,7 @@ mod tests {
}

#[tokio::test]
async fn test_reassign() {
async fn test_reassign_and_delete_center() {
let tmp_dir = tempfile::tempdir().unwrap();
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = new_cache_for_test();
@@ -3775,16 +3883,19 @@ mod tests {
.expect("Expected reassign to succeed");
// See the reassigned points.
{
// Center 1 should remain unchanged.
// Center 1 should get 100 points: original 50 + 50 reassigned from center 3.
// Points 51-100 from center 3 (near 1000,1000) get reassigned because center 2
// was deleted, and center 1 is the only remaining nearby center.
let pl = writer
.posting_list_writer
.get_owned::<u32, &SpannPostingList<'_>>("", 1)
.await
.expect("Error getting posting list")
.unwrap();
assert_eq!(pl.0.len(), 50);
assert_eq!(pl.1.len(), 50);
assert_eq!(pl.2.len(), 100);
assert_eq!(pl.0.len(), 100);
assert_eq!(pl.1.len(), 100);
assert_eq!(pl.2.len(), 200);
// First 50 are original points 1-50 at version 1
for i in 1..=50 {
assert_eq!(pl.0[i - 1], i as u32);
assert_eq!(pl.1[i - 1], 1);
Expand All @@ -3794,28 +3905,26 @@ mod tests {
split_doc_embeddings1[(i - 1) * 2 + 1]
);
}
// Center 2 should get 50 points, all with version 2 migrating from center 3.
let pl = writer
.posting_list_writer
.get_owned::<u32, &SpannPostingList<'_>>("", 2)
.await
.expect("Error getting posting list")
.unwrap();
assert_eq!(pl.0.len(), 50);
assert_eq!(pl.1.len(), 50);
assert_eq!(pl.2.len(), 100);
for i in 1..=50 {
assert_eq!(pl.0[i - 1], 50 + i as u32);
// Next 50 are reassigned points 51-100 at version 2 (from center 3)
for i in 51..=100 {
assert_eq!(pl.0[i - 1], i as u32);
assert_eq!(pl.1[i - 1], 2);
assert_eq!(pl.2[(i - 1) * 2], split_doc_embeddings3[(i - 1) * 2]);
assert_eq!(pl.2[(i - 1) * 2], split_doc_embeddings3[(i - 51) * 2]);
assert_eq!(
pl.2[(i - 1) * 2 + 1],
split_doc_embeddings3[(i - 1) * 2 + 1]
split_doc_embeddings3[(i - 51) * 2 + 1]
);
}
// Center 3 should get 100 points. 50 points with version 1 which were
// originally in center 3 and 50 points with version 2 which were originally
// in center 2.
// Center 2 should be deleted (all its original points were reassigned out).
let pl = writer
.posting_list_writer
.get_owned::<u32, &SpannPostingList<'_>>("", 2)
.await
.expect("Error getting posting list");
assert!(pl.is_none());
// Center 3 should get 100 points. 50 points with version 1 which were
// originally in center 3 (now outdated since reassigned to center 1) and
// 50 points with version 2 which were originally in center 2.
let pl = writer
.posting_list_writer
.get_owned::<u32, &SpannPostingList<'_>>("", 3)