Adding Support to Enable/Disable Shard Level Rescoring and Update Oversampling Factor #2172

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Optimize reduceToTopK in ResultUtil by removing pre-filling and reducing peek calls [#2146](https://github.com/opensearch-project/k-NN/pull/2146)
* Update Default Rescore Context based on Dimension [#2149](https://github.com/opensearch-project/k-NN/pull/2149)
* KNNIterators should support with and without filters [#2155](https://github.com/opensearch-project/k-NN/pull/2155)
* Adding Support to Enable/Disable Shard Level Rescoring and Update Oversampling Factor [#2172](https://github.com/opensearch-project/k-NN/pull/2172)
### Bug Fixes
* KNN80DocValues should only be considered for BinaryDocValues fields [#2147](https://github.com/opensearch-project/k-NN/pull/2147)
### Infrastructure
36 changes: 35 additions & 1 deletion src/main/java/org/opensearch/knn/index/KNNSettings.java
@@ -88,6 +88,7 @@ public class KNNSettings {
public static final String QUANTIZATION_STATE_CACHE_SIZE_LIMIT = "knn.quantization.cache.size.limit";
public static final String QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = "knn.quantization.cache.expiry.minutes";
public static final String KNN_FAISS_AVX512_DISABLED = "knn.faiss.avx512.disabled";
public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED = "index.knn.disk.vector.shard_level_rescoring_disabled";
Contributor review comment: nit: A `disabled` suffix might cause confusion here if we ever expose this to customers. Consider `index.knn.shard_level_rescoring` with a default of `false`.


/**
* Default setting values
@@ -112,11 +113,31 @@ public class KNNSettings {
public static final Integer KNN_MAX_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE = 10; // Quantization state cache limit cannot exceed
// 10% of the JVM heap
public static final Integer KNN_DEFAULT_QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = 60;
public static final boolean KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE = true;

/**
* Settings Definition
*/

/**
* This setting controls whether shard-level re-scoring for KNN disk-based vectors is turned off.
* The setting uses:
* <ul>
* <li><b>KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED:</b> The name of the setting.</li>
* <li><b>KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE:</b> The default value (true or false).</li>
* <li><b>IndexScope:</b> The setting works at the index level.</li>
* <li><b>Dynamic:</b> This setting can be changed without restarting the cluster.</li>
* </ul>
*
* @see Setting#boolSetting(String, boolean, Setting.Property...)
*/
public static final Setting<Boolean> KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING = Setting.boolSetting(
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED,
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE,
IndexScope,
Dynamic
);

// This setting controls how much memory should be used to transfer vectors from Java to JNI Layer. The default
// 1% of the JVM heap
public static final Setting<ByteSizeValue> KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING = Setting.memorySizeSetting(
@@ -454,6 +475,10 @@ private Setting<?> getSetting(String key) {
return QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING;
}

if (KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED.equals(key)) {
return KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING;
}

throw new IllegalArgumentException("Cannot find setting by key [" + key + "]");
}

@@ -475,7 +500,8 @@ public List<Setting<?>> getSettings() {
KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING,
KNN_FAISS_AVX512_DISABLED_SETTING,
QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING,
QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING
QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING,
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING
);
return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream()))
.collect(Collectors.toList());
@@ -528,6 +554,14 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) {
.getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE);
}

public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String indexName) {
return KNNSettings.state().clusterService.state()
.getMetadata()
.index(indexName)
.getSettings()
.getAsBoolean(KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, true);
}

public void initialize(Client client, ClusterService clusterService) {
this.client = client;
this.clusterService = clusterService;
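A note on using the new setting: because `index.knn.disk.vector.shard_level_rescoring_disabled` is dynamic and index-scoped, it can be flipped at runtime with an ordinary index-settings update. The sketch below is a minimal illustration built on the same request pattern the tests in this PR use; the `ShardLevelRescoringToggle` class, its method name, and the `client`/`indexName` arguments are invented for the example, not part of the change.

```java
import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
import org.opensearch.knn.index.KNNSettings;

// Hypothetical helper: toggles shard-level rescoring for disk-based vectors on one index.
public final class ShardLevelRescoringToggle {

    public static void setShardLevelRescoringDisabled(Client client, String indexName, boolean disabled) {
        // The setting is dynamic, so no index close/reopen or node restart is needed.
        Settings settings = Settings.builder()
            .put(KNNSettings.KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, disabled)
            .build();
        client.admin()
            .indices()
            .updateSettings(new UpdateSettingsRequest(settings, indexName))
            .actionGet();
    }
}
```

With the default of `true`, per-segment first-pass results skip the shard-wide reduction before rescoring; setting it to `false` restores the previous behavior of reducing to the top `firstPassK` across the shard first (see the query change further down).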
@@ -97,32 +97,35 @@ public static boolean isConfigured(CompressionLevel compressionLevel) {
/**
* Returns the appropriate {@link RescoreContext} based on the given {@code mode} and {@code dimension}.
*
* <p>If the {@code mode} is present in the valid {@code modesForRescore} set, the method checks the value of
* {@code dimension}:
* <p>If the {@code mode} is present in the valid {@code modesForRescore} set, the method adjusts the oversample factor based on the
* {@code dimension} value:
* <ul>
* <li>If {@code dimension} is less than or equal to 1000, it returns a {@link RescoreContext} with an
* oversample factor of 5.0f.</li>
* <li>If {@code dimension} is greater than 1000, it returns the default {@link RescoreContext} associated with
* the {@link CompressionLevel}. If no default is set, it falls back to {@link RescoreContext#getDefault()}.</li>
* <li>If {@code dimension} is greater than or equal to 1000, no oversampling is applied (oversample factor = 1.0).</li>
* <li>If {@code dimension} is greater than or equal to 768 but less than 1000, a 2x oversample factor is applied (oversample factor = 2.0).</li>
* <li>If {@code dimension} is less than 768, a 3x oversample factor is applied (oversample factor = 3.0).</li>
* </ul>
* If the {@code mode} is not valid, the method returns {@code null}.
* If the {@code mode} is not present in the {@code modesForRescore} set, the method returns {@code null}.
*
* @param mode The {@link Mode} for which to retrieve the {@link RescoreContext}.
* @param dimension The dimensional value that determines the {@link RescoreContext} behavior.
* @return A {@link RescoreContext} with an oversample factor of 5.0f if {@code dimension} is less than
* or equal to 1000, the default {@link RescoreContext} if greater, or {@code null} if the mode
* is invalid.
* @return A {@link RescoreContext} with the appropriate oversample factor based on the dimension, or {@code null} if the mode
* is not valid.
*/
public RescoreContext getDefaultRescoreContext(Mode mode, int dimension) {
if (modesForRescore.contains(mode)) {
// Adjust RescoreContext based on dimension
if (dimension <= RescoreContext.DIMENSION_THRESHOLD) {
// For dimensions <= 1000, return a RescoreContext with 5.0f oversample factor
return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_BELOW_DIMENSION_THRESHOLD).build();
if (dimension >= RescoreContext.DIMENSION_THRESHOLD_1000) {
// No oversampling for dimensions >= 1000
return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_1000).build();
} else if (dimension >= RescoreContext.DIMENSION_THRESHOLD_768) {
// 2x oversampling for dimensions >= 768 but < 1000
return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_768).build();
} else {
return defaultRescoreContext;
// 3x oversampling for dimensions < 768
return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_BELOW_768).build();
}
}
return null;
}

}
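To make the new defaults concrete, the following standalone sketch exercises the dimension thresholds described above. It mirrors the unit tests at the bottom of this PR; the package names in the imports are assumed from the class names in this diff.

```java
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;
import org.opensearch.knn.index.query.rescore.RescoreContext;

public final class DefaultRescoreContextExample {
    public static void main(String[] args) {
        // dimension < 768 -> 3x oversampling
        RescoreContext low = CompressionLevel.x32.getDefaultRescoreContext(Mode.ON_DISK, 512);
        // 768 <= dimension < 1000 -> 2x oversampling
        RescoreContext mid = CompressionLevel.x32.getDefaultRescoreContext(Mode.ON_DISK, 768);
        // dimension >= 1000 -> no oversampling
        RescoreContext high = CompressionLevel.x32.getDefaultRescoreContext(Mode.ON_DISK, 1024);

        System.out.println(low.getOversampleFactor());   // 3.0
        System.out.println(mid.getOversampleFactor());   // 2.0
        System.out.println(high.getOversampleFactor());  // 1.0
    }
}
```

The same factors apply for x16 and x8; x4, x2, x1, and `NOT_CONFIGURED` still return `null`, as the tests below verify.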
@@ -20,6 +20,7 @@
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.query.ExactSearcher;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.KNNWeight;
@@ -54,7 +55,6 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
final IndexReader reader = indexSearcher.getIndexReader();
final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, ScoreMode.COMPLETE, 1);
List<LeafReaderContext> leafReaderContexts = reader.leaves();

List<Map<Integer, Float>> perLeafResults;
RescoreContext rescoreContext = knnQuery.getRescoreContext();
int finalK = knnQuery.getK();
@@ -63,7 +63,9 @@
} else {
int firstPassK = rescoreContext.getFirstPassK(finalK);
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK);
ResultUtil.reduceToTopK(perLeafResults, firstPassK);
if (KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName()) == false) {
ResultUtil.reduceToTopK(perLeafResults, firstPassK);
}

StopWatch stopWatch = new StopWatch().start();
perLeafResults = doRescore(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, finalK);
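For context on what the guarded call does: `ResultUtil.reduceToTopK` trims the per-segment result maps so that only the shard-wide top `firstPassK` candidates survive into the rescore pass. The snippet below is a conceptual re-implementation for illustration only, not the plugin's `ResultUtil` code; it assumes per-segment results are doc-id-to-score maps, as in the query above.

```java
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

// Conceptual sketch: keep only the shard-wide top-k scores across all segments.
final class TopKReduceSketch {

    static void reduceToTopK(List<Map<Integer, Float>> perLeafResults, int k) {
        // Min-heap whose head is the k-th best score seen so far.
        PriorityQueue<Float> topScores = new PriorityQueue<>(Comparator.naturalOrder());
        for (Map<Integer, Float> leafResults : perLeafResults) {
            for (float score : leafResults.values()) {
                topScores.offer(score);
                if (topScores.size() > k) {
                    topScores.poll();
                }
            }
        }
        if (topScores.size() < k) {
            return; // fewer than k candidates overall; nothing to trim
        }
        float cutoff = topScores.peek();
        // Drop candidates below the shard-wide cut-off (ties are kept, which is fine for a sketch).
        for (Map<Integer, Float> leafResults : perLeafResults) {
            leafResults.values().removeIf(score -> score < cutoff);
        }
    }
}
```

When the new setting leaves shard-level rescoring disabled (the default), this trimming is skipped, so every segment feeds its own `firstPassK` candidates into `doRescore`, which costs more exact-distance work but can improve recall.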
@@ -24,6 +24,15 @@ public final class RescoreContext {
public static final int DIMENSION_THRESHOLD = 1000;
public static final float OVERSAMPLE_FACTOR_BELOW_DIMENSION_THRESHOLD = 5.0f;

// Dimension thresholds for adjusting oversample factor
public static final int DIMENSION_THRESHOLD_1000 = 1000;
public static final int DIMENSION_THRESHOLD_768 = 768;

// Oversample factors based on dimension thresholds
public static final float OVERSAMPLE_FACTOR_1000 = 1.0f; // No oversampling for dimensions >= 1000
public static final float OVERSAMPLE_FACTOR_768 = 2.0f; // 2x oversampling for dimensions >= 768 and < 1000
public static final float OVERSAMPLE_FACTOR_BELOW_768 = 3.0f; // 3x oversampling for dimensions < 768

// Todo:- We will improve this in upcoming releases
public static final int MIN_FIRST_PASS_RESULTS = 100;

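Tying the constants back to the query path: the oversample factor selected from these thresholds drives the first-pass candidate count via `rescoreContext.getFirstPassK(finalK)` in the query change above. The back-of-the-envelope below assumes first-pass k is simply the factor times the requested k, floored at `MIN_FIRST_PASS_RESULTS`; the exact formula lives in `getFirstPassK`, which is not part of this diff, and the import path is assumed.

```java
import org.opensearch.knn.index.query.rescore.RescoreContext;

// Illustrative only: an assumed relationship between oversample factor and
// first-pass k, not the plugin's actual getFirstPassK implementation.
final class OversamplingMath {

    static int approxFirstPassK(int finalK, float oversampleFactor) {
        return Math.max(RescoreContext.MIN_FIRST_PASS_RESULTS, (int) Math.ceil(finalK * oversampleFactor));
    }

    public static void main(String[] args) {
        int k = 100;
        System.out.println(approxFirstPassK(k, 3.0f)); // dimension < 768    -> ~300 candidates
        System.out.println(approxFirstPassK(k, 2.0f)); // 768 <= dim < 1000  -> ~200 candidates
        System.out.println(approxFirstPassK(k, 1.0f)); // dimension >= 1000  -> 100 (the floor)
    }
}
```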
35 changes: 35 additions & 0 deletions src/test/java/org/opensearch/knn/index/KNNSettingsTests.java
@@ -158,6 +158,41 @@ public void testGetEfSearch_whenEFSearchValueSetByUser_thenReturnValue() {
assertEquals(userProvidedEfSearch, efSearchValue);
}

@SneakyThrows
public void testShardLevelRescoringDisabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() {
Node mockNode = createMockNode(Collections.emptyMap());
mockNode.start();
ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class);
mockNode.client().admin().cluster().state(new ClusterStateRequest()).actionGet();
mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet();
KNNSettings.state().setClusterService(clusterService);

boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME);
mockNode.close();
assertTrue(shardLevelRescoringDisabled);
}

@SneakyThrows
public void testShardLevelRescoringDisabled_whenValueProvidedByUser_thenSettingApplied() {
boolean userDefinedRescoringDisabled = false;
Node mockNode = createMockNode(Collections.emptyMap());
mockNode.start();
ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class);
mockNode.client().admin().cluster().state(new ClusterStateRequest()).actionGet();
mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet();
KNNSettings.state().setClusterService(clusterService);

final Settings rescoringDisabledSetting = Settings.builder()
.put(KNNSettings.KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, userDefinedRescoringDisabled)
.build();

mockNode.client().admin().indices().updateSettings(new UpdateSettingsRequest(rescoringDisabledSetting, INDEX_NAME)).actionGet();

boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME);
mockNode.close();
assertEquals(userDefinedRescoringDisabled, shardLevelRescoringDisabled);
}

@SneakyThrows
public void testGetFaissAVX2DisabledSettingValueFromConfig_enableSetting_thenValidateAndSucceed() {
boolean expectedKNNFaissAVX2Disabled = true;
@@ -44,65 +44,84 @@ public void testIsConfigured() {
public void testGetDefaultRescoreContext() {
// Test rescore context for ON_DISK mode
Mode mode = Mode.ON_DISK;
int belowThresholdDimension = 500; // A dimension below the threshold
int aboveThresholdDimension = 1500; // A dimension above the threshold

// x32 with dimension <= 1000 should have an oversample factor of 5.0f
// Test various dimensions based on the updated oversampling logic
int belowThresholdDimension = 500; // A dimension below 768
int between768and1000Dimension = 800; // A dimension between 768 and 1000
int above1000Dimension = 1500; // A dimension above 1000

// Compression level x32 with dimension < 768 should have an oversample factor of 3.0f
RescoreContext rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNotNull(rescoreContext);
assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f);
assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f);

// x32 with dimension > 1000 should have an oversample factor of 3.0f
rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, aboveThresholdDimension);
// Compression level x32 with dimension between 768 and 1000 should have an oversample factor of 2.0f
rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, between768and1000Dimension);
assertNotNull(rescoreContext);
assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f);
assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f);

// x16 with dimension <= 1000 should have an oversample factor of 5.0f
rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, belowThresholdDimension);
// Compression level x32 with dimension > 1000 should have no oversampling (1.0f)
rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, above1000Dimension);
assertNotNull(rescoreContext);
assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f);
assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f);

// x16 with dimension > 1000 should have an oversample factor of 3.0f
rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, aboveThresholdDimension);
// Compression level x16 with dimension < 768 should have an oversample factor of 3.0f
rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNotNull(rescoreContext);
assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f);

// x8 with dimension <= 1000 should have an oversample factor of 5.0f
// Compression level x16 with dimension between 768 and 1000 should have an oversample factor of 2.0f
rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, between768and1000Dimension);
assertNotNull(rescoreContext);
assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f);

// Compression level x16 with dimension > 1000 should have no oversampling (1.0f)
rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, above1000Dimension);
assertNotNull(rescoreContext);
assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f);

// Compression level x8 with dimension < 768 should have an oversample factor of 3.0f
rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNotNull(rescoreContext);
assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f);
assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f);

// x8 with dimension > 1000 should have an oversample factor of 2.0f
rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, aboveThresholdDimension);
// Compression level x8 with dimension between 768 and 1000 should have an oversample factor of 2.0f
rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, between768and1000Dimension);
assertNotNull(rescoreContext);
assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f);

// x4 with dimension <= 1000 should have an oversample factor of 5.0f (though it doesn't have its own RescoreContext)
// Compression level x8 with dimension > 1000 should have no oversampling (1.0f)
rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, above1000Dimension);
assertNotNull(rescoreContext);
assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f);

// Compression level x4 with dimension < 768 should return null (no RescoreContext)
rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNull(rescoreContext);
// x4 with dimension > 1000 should return null (no RescoreContext is configured for x4)
rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, aboveThresholdDimension);
assertNull(rescoreContext);

// Other compression levels should behave similarly with respect to dimension
// Compression level x4 with dimension > 1000 should return null (no RescoreContext)
rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, above1000Dimension);
assertNull(rescoreContext);

// Compression level x2 with dimension < 768 should return null
rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNull(rescoreContext);

// x2 with dimension > 1000 should return null
rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, aboveThresholdDimension);
// Compression level x2 with dimension > 1000 should return null
rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, above1000Dimension);
assertNull(rescoreContext);

// Compression level x1 with dimension < 768 should return null
rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNull(rescoreContext);

// x1 with dimension > 1000 should return null
rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, aboveThresholdDimension);
// Compression level x1 with dimension > 1000 should return null
rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, above1000Dimension);
assertNull(rescoreContext);

// NOT_CONFIGURED with dimension <= 1000 should return a RescoreContext with an oversample factor of 5.0f
// NOT_CONFIGURED mode should return null for any dimension
rescoreContext = CompressionLevel.NOT_CONFIGURED.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNull(rescoreContext);

}

}