diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index b52a37702..8c7011e1c 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -61,10 +61,10 @@ jobs: if lscpu | grep -i avx2 then echo "avx2 available on system" - su `id -un 1000` -c "whoami && java -version && ./gradlew build -Dsimd.enabled=true" + su `id -un 1000` -c "whoami && java -version && ./gradlew build" else echo "avx2 not available on system" - su `id -un 1000` -c "whoami && java -version && ./gradlew build" + su `id -un 1000` -c "whoami && java -version && ./gradlew build -Dsimd.enabled=false" fi @@ -101,10 +101,10 @@ jobs: if sysctl -n machdep.cpu.features machdep.cpu.leaf7_features | grep -i AVX2 then echo "avx2 available on system" - ./gradlew build -Dsimd.enabled=true + ./gradlew build else echo "avx2 not available on system" - ./gradlew build + ./gradlew build -Dsimd.enabled=false fi Build-k-NN-Windows: @@ -158,5 +158,5 @@ jobs: - name: Run build run: | - ./gradlew.bat build + ./gradlew.bat build -D'simd.enabled=false' diff --git a/CHANGELOG.md b/CHANGELOG.md index 02962c7b2..ef95bcbca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,9 +16,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements * Optize Faiss Query With Filters: Reduce iteration and memory for id filter [#1402](https://github.com/opensearch-project/k-NN/pull/1402) +* Detect AVX2 Dynamically on the System [#1502](https://github.com/opensearch-project/k-NN/pull/1502) +* Validate zero vector when using cosine metric [#1501](https://github.com/opensearch-project/k-NN/pull/1501) * Persist model definition in model metadata [#1527] (https://github.com/opensearch-project/k-NN/pull/1527) ### Bug Fixes * Disable sdc table for HNSWPQ read-only indices [#1518](https://github.com/opensearch-project/k-NN/pull/1518) +* Switch SpaceType.INNERPRODUCT's vector similarity function to MAXIMUM_INNER_PRODUCT 
[#1532](https://github.com/opensearch-project/k-NN/pull/1532) ### Infrastructure * Manually install zlib for win CI [#1513](https://github.com/opensearch-project/k-NN/pull/1513) ### Documentation diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md index c232e75a9..ca5e7cc52 100644 --- a/DEVELOPER_GUIDE.md +++ b/DEVELOPER_GUIDE.md @@ -238,9 +238,9 @@ If you want to make a custom patch on JNI library 4. Make a change in `jni/CmakeLists.txt`, `.github/workflows/CI.yml` to apply the patch during build ### Enable SIMD Optimization -SIMD(Single Instruction/Multiple Data) Optimization can be enabled by setting this optional parameter `simd.enabled` to `true` which boosts the performance +SIMD (Single Instruction/Multiple Data) Optimization is enabled by default on Linux and Mac, which boosts the performance by enabling `AVX2` on `x86 architecture` and `NEON` on `ARM64 architecture` while building the Faiss library. But to enable SIMD, the underlying processor -should support this (AVX2 or NEON). So, by default it is set to `false`. +should support it (AVX2 or NEON). SIMD can be disabled by setting the parameter `simd.enabled` to `false`. As of now, it is not supported on Windows OS. 
``` # While building OpenSearch k-NN diff --git a/benchmarks/osb/requirements.txt b/benchmarks/osb/requirements.txt index 2da38cfaa..3d41012f2 100644 --- a/benchmarks/osb/requirements.txt +++ b/benchmarks/osb/requirements.txt @@ -38,7 +38,7 @@ ijson==2.6.1 # via opensearch-benchmark importlib-metadata==4.11.3 # via jsonschema -jinja2==2.11.3 +jinja2==3.1.3 # via opensearch-benchmark jsonschema==3.1.1 # via opensearch-benchmark diff --git a/build.gradle b/build.gradle index f34c0fa3a..1c6ae5efd 100644 --- a/build.gradle +++ b/build.gradle @@ -17,7 +17,7 @@ buildscript { version_qualifier = System.getProperty("build.version_qualifier", "") opensearch_group = "org.opensearch" isSnapshot = "true" == System.getProperty("build.snapshot", "true") - simd_enabled = System.getProperty("simd.enabled", "false") + simd_enabled = System.getProperty("simd.enabled", "true") version_tokens = opensearch_version.tokenize('-') opensearch_build = version_tokens[0] + '.0' @@ -287,6 +287,10 @@ dependencies { testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.2' testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.14.7' testFixturesImplementation "org.opensearch:common-utils:${version}" + implementation 'com.github.oshi:oshi-core:6.4.13' + api "net.java.dev.jna:jna:5.13.0" + api "net.java.dev.jna:jna-platform:5.13.0" + implementation 'org.slf4j:slf4j-api:1.7.36' zipArchive group: 'org.opensearch.plugin', name:'opensearch-security', version: "${opensearch_build}" } diff --git a/jni/CMakeLists.txt b/jni/CMakeLists.txt index c06337338..30a77d095 100644 --- a/jni/CMakeLists.txt +++ b/jni/CMakeLists.txt @@ -112,8 +112,8 @@ if (${CONFIG_FAISS} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST} S set(BUILD_TESTING OFF) # Avoid building faiss tests set(BLA_STATIC ON) # Statically link BLAS - if(NOT SIMD_ENABLED) - set(SIMD_ENABLED false) # set default value as false if the argument is not set + if(NOT DEFINED SIMD_ENABLED) + 
set(SIMD_ENABLED true) # set default value as true if the argument is not set endif() if(${CMAKE_SYSTEM_NAME} STREQUAL Windows OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64" OR NOT ${SIMD_ENABLED}) @@ -122,6 +122,7 @@ if (${CONFIG_FAISS} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST} S else() set(FAISS_OPT_LEVEL avx2) # Keep optimization level as avx2 to improve performance on Linux and Mac. set(TARGET_LINK_FAISS_LIB faiss_avx2) + string(PREPEND LIB_EXT "_avx2") # Prepend "_avx2" to lib extension to create the library as "libopensearchknn_faiss_avx2.so" on linux and "libopensearchknn_faiss_avx2.jnilib" on mac endif() if (${CMAKE_SYSTEM_NAME} STREQUAL Darwin) @@ -160,6 +161,7 @@ if (${CONFIG_FAISS} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST} S if (EXISTS ${PATCH_FILE}) message(STATUS "Applying custom patches.") execute_process(COMMAND git apply --ignore-space-change --ignore-whitespace --3way ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE) + if(RESULT_CODE) message(FATAL_ERROR "Failed to apply patch:\n${ERROR_MSG}") endif() diff --git a/scripts/build.sh b/scripts/build.sh index f17d794a0..b2cdef687 100644 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -119,8 +119,14 @@ fi # Build k-NN lib and plugin through gradle tasks cd $work_dir -# Gradle build is used here to replace gradle assemble due to build will also call cmake and make before generating jars -./gradlew build --no-daemon --refresh-dependencies -x integTest -DskipTests=true -Dopensearch.version=$VERSION -Dbuild.snapshot=$SNAPSHOT -Dbuild.version_qualifier=$QUALIFIER +./gradlew build --no-daemon --refresh-dependencies -x integTest -x test -Dopensearch.version=$VERSION -Dbuild.snapshot=$SNAPSHOT -Dbuild.version_qualifier=$QUALIFIER +./gradlew :buildJniLib -Dsimd.enabled=false + +if [ "$PLATFORM" != "windows" ] && [ 
"$ARCHITECTURE" = "x64" ]; then + echo "Building k-NN library after enabling AVX2" + ./gradlew :buildJniLib -Dsimd.enabled=true +fi + ./gradlew publishPluginZipPublicationToZipStagingRepository -Dopensearch.version=$VERSION -Dbuild.snapshot=$SNAPSHOT -Dbuild.version_qualifier=$QUALIFIER ./gradlew publishPluginZipPublicationToMavenLocal -Dbuild.snapshot=$SNAPSHOT -Dbuild.version_qualifier=$QUALIFIER -Dopensearch.version=$VERSION @@ -150,20 +156,6 @@ cd $distributions zip -ur $zipPath lib cd $work_dir -if [ "$PLATFORM" != "windows" ]; then - echo "Building k-NN libraries after enabling SIMD" - ./gradlew :buildJniLib -Dsimd.enabled=true - mkdir $distributions/lib_simd - cp -v $ompPath $distributions/lib_simd - cp -v ./jni/release/${libPrefix}* $distributions/lib_simd - ls -l $distributions/lib_simd - - # Add lib_simd directory to the k-NN plugin zip - cd $distributions - zip -ur $zipPath lib_simd - cd $work_dir -fi - echo "COPY ${distributions}/*.zip" mkdir -p $OUTPUT/plugins cp -v ${distributions}/*.zip $OUTPUT/plugins diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 457bc2df3..269f774b5 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -114,6 +114,7 @@ public class KNNConstants { // Lib names private static final String JNI_LIBRARY_PREFIX = "opensearchknn_"; public static final String FAISS_JNI_LIBRARY_NAME = JNI_LIBRARY_PREFIX + FAISS_NAME; + public static final String FAISS_AVX2_JNI_LIBRARY_NAME = JNI_LIBRARY_PREFIX + FAISS_NAME + "_avx2"; public static final String NMSLIB_JNI_LIBRARY_NAME = JNI_LIBRARY_PREFIX + NMSLIB_NAME; // API Constants diff --git a/src/main/java/org/opensearch/knn/common/KNNValidationUtil.java b/src/main/java/org/opensearch/knn/common/KNNValidationUtil.java new file mode 100644 index 000000000..ca8e1459a --- /dev/null +++ 
b/src/main/java/org/opensearch/knn/common/KNNValidationUtil.java @@ -0,0 +1,83 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.common; + +import java.util.Locale; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.opensearch.knn.index.VectorDataType; + +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; + +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public class KNNValidationUtil { + /** + * Validate the float vector value and throw exception if it is not a number or not in the finite range. + * + * @param value float vector value + */ + public static void validateFloatVectorValue(float value) { + if (Float.isNaN(value)) { + throw new IllegalArgumentException("KNN vector values cannot be NaN"); + } + + if (Float.isInfinite(value)) { + throw new IllegalArgumentException("KNN vector values cannot be infinity"); + } + } + + /** + * Validate the float vector value in the byte range if it is a finite number, + * with no decimal values and in the byte range of [-128 to 127]. If not throw IllegalArgumentException. + * + * @param value float value in byte range + */ + public static void validateByteVectorValue(float value) { + validateFloatVectorValue(value); + if (value % 1 != 0) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue() + ) + + ); + } + if ((int) value < Byte.MIN_VALUE || (int) value > Byte.MAX_VALUE) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. 
But, KNN vector values are not within in the byte range [%d, %d]", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue(), + Byte.MIN_VALUE, + Byte.MAX_VALUE + ) + ); + } + } + + /** + * Validate if the given vector size matches with the dimension provided in mapping. + * + * @param dimension dimension of vector + * @param vectorSize size of the vector + */ + public static void validateVectorDimension(int dimension, int vectorSize) { + if (dimension != vectorSize) { + String errorMessage = String.format(Locale.ROOT, "Vector dimension mismatch. Expected: %d, Given: %d", dimension, vectorSize); + throw new IllegalArgumentException(errorMessage); + } + } +} diff --git a/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java b/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java new file mode 100644 index 000000000..fd9e5b6c2 --- /dev/null +++ b/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.common; + +import java.util.Objects; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; + +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public class KNNVectorUtil { + /** + * Check if all the elements of a given vector are zero + * + * @param vector the vector + * @return true if yes; otherwise false + */ + public static boolean isZeroVector(byte[] vector) { + Objects.requireNonNull(vector, "vector must not be null"); + for (byte e : vector) { + if (e != 0) { + return false; + } + } + return true; + } + + /** + * Check if all the elements of a given vector are zero + * + * @param vector the vector + * @return true if yes; otherwise false + */ + public static boolean isZeroVector(float[] vector) { + Objects.requireNonNull(vector, "vector must not be null"); + for (float e : vector) { + if (e != 0f) { + return false; + } + } + return true; + } +} diff --git 
a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 53ee17150..572c9220e 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -76,6 +76,7 @@ public class KNNSettings { public static final String MODEL_INDEX_NUMBER_OF_REPLICAS = "knn.model.index.number_of_replicas"; public static final String MODEL_CACHE_SIZE_LIMIT = "knn.model.cache.size.limit"; public static final String ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD = "index.knn.advanced.filtered_exact_search_threshold"; + public static final String KNN_FAISS_AVX2_DISABLED = "knn.faiss.avx2.disabled"; /** * Default setting values @@ -230,6 +231,9 @@ public class KNNSettings { NodeScope, Dynamic ); + + public static final Setting KNN_FAISS_AVX2_DISABLED_SETTING = Setting.boolSetting(KNN_FAISS_AVX2_DISABLED, false, NodeScope); + /** * Dynamic settings */ @@ -339,6 +343,10 @@ private Setting getSetting(String key) { return ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING; } + if (KNN_FAISS_AVX2_DISABLED.equals(key)) { + return KNN_FAISS_AVX2_DISABLED_SETTING; + } + throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -355,7 +363,8 @@ public List> getSettings() { MODEL_INDEX_NUMBER_OF_SHARDS_SETTING, MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING, MODEL_CACHE_SIZE_LIMIT_SETTING, - ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING + ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING, + KNN_FAISS_AVX2_DISABLED_SETTING ); return Stream.concat(settings.stream(), dynamicCacheSettings.values().stream()).collect(Collectors.toList()); } @@ -376,6 +385,10 @@ public static double getCircuitBreakerUnsetPercentage() { return KNNSettings.state().getSettingValue(KNNSettings.KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE); } + public static boolean isFaissAVX2Disabled() { + return KNNSettings.state().getSettingValue(KNNSettings.KNN_FAISS_AVX2_DISABLED); + } + public static 
Integer getFilteredExactSearchThreshold(final String indexName) { return KNNSettings.state().clusterService.state() .getMetadata() diff --git a/src/main/java/org/opensearch/knn/index/SpaceType.java b/src/main/java/org/opensearch/knn/index/SpaceType.java index efa0f1be3..50d8d352c 100644 --- a/src/main/java/org/opensearch/knn/index/SpaceType.java +++ b/src/main/java/org/opensearch/knn/index/SpaceType.java @@ -11,11 +11,14 @@ package org.opensearch.knn.index; +import java.util.Locale; import org.apache.lucene.index.VectorSimilarityFunction; import java.util.HashSet; import java.util.Set; +import static org.opensearch.knn.common.KNNVectorUtil.isZeroVector; + /** * Enum contains spaces supported for approximate nearest neighbor search in the k-NN plugin. Each engine's methods are * expected to support a subset of these spaces. Validation should be done in the jni layer and an exception should be @@ -44,6 +47,24 @@ public float scoreTranslation(float rawScore) { public VectorSimilarityFunction getVectorSimilarityFunction() { return VectorSimilarityFunction.COSINE; } + + @Override + public void validateVector(byte[] vector) { + if (isZeroVector(vector)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", getValue()) + ); + } + } + + @Override + public void validateVector(float[] vector) { + if (isZeroVector(vector)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", getValue()) + ); + } + } }, L1("l1") { @Override @@ -76,7 +97,7 @@ public float scoreTranslation(float rawScore) { @Override public VectorSimilarityFunction getVectorSimilarityFunction() { - return VectorSimilarityFunction.DOT_PRODUCT; + return VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; } }, HAMMING_BIT("hammingbit") { @@ -105,6 +126,24 @@ public VectorSimilarityFunction getVectorSimilarityFunction() { throw new 
UnsupportedOperationException(String.format("Space [%s] does not have a vector similarity function", getValue())); } + /** + * Validate if the given byte vector is supported by this space type + * + * @param vector the given vector + */ + public void validateVector(byte[] vector) { + // do nothing + } + + /** + * Validate if the given float vector is supported by this space type + * + * @param vector the given vector + */ + public void validateVector(float[] vector) { + // do nothing + } + /** * Get space type name in engine * diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 5b427517b..2369a6937 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -5,22 +5,29 @@ package org.opensearch.knn.index.mapper; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.function.Supplier; import lombok.Getter; import lombok.extern.log4j.Log4j2; -import org.opensearch.Version; -import org.opensearch.common.ValidationException; -import org.opensearch.knn.common.KNNConstants; - import org.apache.lucene.document.FieldType; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.IndexOptions; import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.Query; +import org.opensearch.Version; import org.opensearch.common.Explicit; +import org.opensearch.common.Nullable; +import org.opensearch.common.ValidationException; +import org.opensearch.common.xcontent.support.XContentMapValues; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import 
org.opensearch.common.xcontent.support.XContentMapValues; import org.opensearch.index.fielddata.IndexFieldData; import org.opensearch.index.mapper.FieldMapper; import org.opensearch.index.mapper.MappedFieldType; @@ -32,9 +39,11 @@ import org.opensearch.index.mapper.ValueFetcher; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.QueryShardException; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.KNNVectorIndexFieldData; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.util.KNNEngine; @@ -42,25 +51,16 @@ import org.opensearch.search.aggregations.support.CoreValuesSourceType; import org.opensearch.search.lookup.SearchLookup; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Optional; -import java.util.function.Supplier; - import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; +import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; +import static org.opensearch.knn.common.KNNValidationUtil.validateVectorDimension; import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine; import static 
org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVectorValue; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDimension; /** * Field Mapper for KNN vector type. @@ -313,7 +313,13 @@ public KNNVectorFieldMapper build(BuilderContext context) { return new LegacyFieldMapper( name, - new KNNVectorFieldType(buildFullName(context), metaValue, dimension.getValue(), vectorDataType.getValue()), + new KNNVectorFieldType( + buildFullName(context), + metaValue, + dimension.getValue(), + vectorDataType.getValue(), + SpaceType.getSpace(spaceType) + ), multiFieldsBuilder, copyToBuilder, ignoreMalformed, @@ -384,17 +390,24 @@ public static class KNNVectorFieldType extends MappedFieldType { String modelId; KNNMethodContext knnMethodContext; VectorDataType vectorDataType; + SpaceType spaceType; - public KNNVectorFieldType(String name, Map meta, int dimension, VectorDataType vectorDataType) { - this(name, meta, dimension, null, null, vectorDataType); + public KNNVectorFieldType( + String name, + Map meta, + int dimension, + VectorDataType vectorDataType, + SpaceType spaceType + ) { + this(name, meta, dimension, null, null, vectorDataType, spaceType); } public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext) { - this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD); + this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD, knnMethodContext.getSpaceType()); } public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext, String modelId) { - this(name, meta, dimension, knnMethodContext, 
modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD); + this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD, null); } public KNNVectorFieldType( @@ -404,22 +417,24 @@ public KNNVectorFieldType( KNNMethodContext knnMethodContext, VectorDataType vectorDataType ) { - this(name, meta, dimension, knnMethodContext, null, vectorDataType); + this(name, meta, dimension, knnMethodContext, null, vectorDataType, knnMethodContext.getSpaceType()); } public KNNVectorFieldType( String name, Map meta, int dimension, - KNNMethodContext knnMethodContext, - String modelId, - VectorDataType vectorDataType + @Nullable KNNMethodContext knnMethodContext, + @Nullable String modelId, + VectorDataType vectorDataType, + @Nullable SpaceType spaceType ) { super(name, false, false, true, TextSearchInfo.NONE, meta); this.dimension = dimension; this.modelId = modelId; this.knnMethodContext = knnMethodContext; this.vectorDataType = vectorDataType; + this.spaceType = spaceType; } @Override @@ -496,10 +511,10 @@ protected String contentType() { @Override protected void parseCreateField(ParseContext context) throws IOException { - parseCreateField(context, fieldType().getDimension()); + parseCreateField(context, fieldType().getDimension(), fieldType().getSpaceType()); } - protected void parseCreateField(ParseContext context, int dimension) throws IOException { + protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType) throws IOException { validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); @@ -507,10 +522,11 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx if (VectorDataType.BYTE == vectorDataType) { Optional bytesArrayOptional = getBytesFromContext(context, dimension); - if (!bytesArrayOptional.isPresent()) { + if (bytesArrayOptional.isEmpty()) { return; } final byte[] array = bytesArrayOptional.get(); + spaceType.validateVector(array); VectorField point = new VectorField(name(), array, 
fieldType); context.doc().add(point); @@ -518,12 +534,12 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx } else if (VectorDataType.FLOAT == vectorDataType) { Optional floatsArrayOptional = getFloatsFromContext(context, dimension); - if (!floatsArrayOptional.isPresent()) { + if (floatsArrayOptional.isEmpty()) { return; } final float[] array = floatsArrayOptional.get(); + spaceType.validateVector(array); VectorField point = new VectorField(name(), array, fieldType); - context.doc().add(point); addStoredFieldForVectorField(context, fieldType, name(), point.toString()); } else { diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index bf331eeb3..b525b9dc6 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -11,6 +11,8 @@ package org.opensearch.knn.index.mapper; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.StoredField; import org.apache.lucene.index.DocValuesType; @@ -25,69 +27,8 @@ import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +@NoArgsConstructor(access = AccessLevel.PRIVATE) public class KNNVectorFieldMapperUtil { - /** - * Validate the float vector value and throw exception if it is not a number or not in the finite range. 
- * - * @param value float vector value - */ - public static void validateFloatVectorValue(float value) { - if (Float.isNaN(value)) { - throw new IllegalArgumentException("KNN vector values cannot be NaN"); - } - - if (Float.isInfinite(value)) { - throw new IllegalArgumentException("KNN vector values cannot be infinity"); - } - } - - /** - * Validate the float vector value in the byte range if it is a finite number, - * with no decimal values and in the byte range of [-128 to 127]. If not throw IllegalArgumentException. - * - * @param value float value in byte range - */ - public static void validateByteVectorValue(float value) { - validateFloatVectorValue(value); - if (value % 1 != 0) { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers", - VECTOR_DATA_TYPE_FIELD, - VectorDataType.BYTE.getValue() - ) - - ); - } - if ((int) value < Byte.MIN_VALUE || (int) value > Byte.MAX_VALUE) { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]", - VECTOR_DATA_TYPE_FIELD, - VectorDataType.BYTE.getValue(), - Byte.MIN_VALUE, - Byte.MAX_VALUE - ) - ); - } - } - - /** - * Validate if the given vector size matches with the dimension provided in mapping. - * - * @param dimension dimension of vector - * @param vectorSize size of the vector - */ - public static void validateVectorDimension(int dimension, int vectorSize) { - if (dimension != vectorSize) { - String errorMessage = String.format(Locale.ROOT, "Vector dimension mismatch. 
Expected: %d, Given: %d", dimension, vectorSize); - throw new IllegalArgumentException(errorMessage); - } - - } - /** * Validates and throws exception if data_type field is set in the index mapping * using any VectorDataType (other than float, which is default) because other diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 831a23f4b..81c7216bf 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -5,6 +5,9 @@ package org.opensearch.knn.index.mapper; +import java.io.IOException; +import java.util.Locale; +import java.util.Optional; import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NonNull; @@ -15,14 +18,11 @@ import org.opensearch.common.Explicit; import org.opensearch.index.mapper.ParseContext; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.util.KNNEngine; -import java.io.IOException; -import java.util.Locale; -import java.util.Optional; - import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType; @@ -75,7 +75,7 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { } @Override - protected void parseCreateField(ParseContext context, int dimension) throws IOException { + protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType) throws IOException { validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); @@ -86,6 +86,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws 
IOEx return; } final byte[] array = bytesArrayOptional.get(); + spaceType.validateVector(array); KnnByteVectorField point = new KnnByteVectorField(name(), array, fieldType); context.doc().add(point); @@ -101,7 +102,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx return; } final float[] array = floatsArrayOptional.get(); - + spaceType.validateVector(array); KnnVectorField point = new KnnVectorField(name(), array, fieldType); context.doc().add(point); diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index 2f138dba6..2367d7422 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -61,6 +61,6 @@ protected void parseCreateField(ParseContext context) throws IOException { ); } - parseCreateField(context, modelMetadata.getDimension()); + parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType()); } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index eb43f67f7..2140487c5 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -5,36 +5,36 @@ package org.opensearch.knn.index.query; +import java.io.IOException; import java.util.Arrays; +import java.util.List; +import java.util.Objects; import lombok.extern.log4j.Log4j2; import org.apache.lucene.search.MatchNoDocsQuery; -import org.opensearch.core.common.Strings; -import org.opensearch.index.mapper.NumberFieldMapper; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.knn.index.KNNMethodContext; -import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; -import 
org.opensearch.knn.index.util.KNNEngine; -import org.opensearch.knn.indices.ModelDao; -import org.opensearch.knn.indices.ModelMetadata; -import org.opensearch.knn.plugin.stats.KNNCounter; import org.apache.lucene.search.Query; import org.opensearch.core.ParseField; import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.query.AbstractQueryBuilder; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.plugin.stats.KNNCounter; -import java.io.IOException; -import java.util.List; -import java.util.Objects; - +import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue; /** * Helper class to build the KNN query @@ -284,12 +284,17 @@ protected Query doToQuery(QueryShardContext context) { KNNMethodContext knnMethodContext = knnVectorFieldType.getKnnMethodContext(); KNNEngine knnEngine = KNNEngine.DEFAULT; VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType(); + SpaceType spaceType = knnVectorFieldType.getSpaceType(); if (fieldDimension == -1) { + 
if (spaceType != null) { + throw new IllegalStateException("Space type should be null when the field uses a model"); + } // If dimension is not set, the field uses a model and the information needs to be retrieved from there ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType); fieldDimension = modelMetadata.getDimension(); knnEngine = modelMetadata.getKnnEngine(); + spaceType = modelMetadata.getSpaceType(); } else if (knnMethodContext != null) { // If the dimension is set but the knnMethodContext is not then the field is using the legacy mapping knnEngine = knnMethodContext.getKnnEngine(); @@ -308,6 +313,9 @@ protected Query doToQuery(QueryShardContext context) { validateByteVectorValue(vector[i]); byteVector[i] = (byte) vector[i]; } + spaceType.validateVector(byteVector); + } else { + spaceType.validateVector(vector); } if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index f330352ec..a1a07547b 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -19,6 +19,9 @@ import java.security.PrivilegedAction; import java.util.Map; +import static org.opensearch.knn.index.KNNSettings.isFaissAVX2Disabled; +import static org.opensearch.knn.jni.PlatformUtils.isAVX2SupportedBySystem; + /** * Service to interact with faiss jni layer. 
Class dependencies should be minimal * @@ -31,7 +34,15 @@ class FaissService { static { AccessController.doPrivileged((PrivilegedAction) () -> { - System.loadLibrary(KNNConstants.FAISS_JNI_LIBRARY_NAME); + + // Even if the underlying system supports AVX2, users can override and disable it by using the + // 'knn.faiss.avx2.disabled' setting by setting it to true in the opensearch.yml configuration + if (!isFaissAVX2Disabled() && isAVX2SupportedBySystem()) { + System.loadLibrary(KNNConstants.FAISS_AVX2_JNI_LIBRARY_NAME); + } else { + System.loadLibrary(KNNConstants.FAISS_JNI_LIBRARY_NAME); + } + initLibrary(); KNNEngine.FAISS.setInitialized(true); return null; diff --git a/src/main/java/org/opensearch/knn/jni/PlatformUtils.java b/src/main/java/org/opensearch/knn/jni/PlatformUtils.java new file mode 100644 index 000000000..8a5549dec --- /dev/null +++ b/src/main/java/org/opensearch/knn/jni/PlatformUtils.java @@ -0,0 +1,83 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.jni; + +import com.sun.jna.Platform; +import org.apache.commons.lang.StringUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import oshi.util.platform.mac.SysctlUtil; + +import java.nio.file.Files; +import java.nio.file.Paths; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.util.Locale; + +public class PlatformUtils { + + private static final Logger logger = LogManager.getLogger(PlatformUtils.class); + + /** + * Verify if the underlying system supports AVX2 SIMD Optimization or not + * 1. If the architecture is not x86 return false. + * 2. If the operating system is not Mac or Linux return false(for example Windows). + * 3. 
If the operating system is macOS, use oshi library to verify if the cpu flags + * contains 'avx2' and return true if it exists else false. + * 4. If the operating system is linux, read the '/proc/cpuinfo' file path and verify if + * the flags contains 'avx2' and return true if it exists else false. + */ + public static boolean isAVX2SupportedBySystem() { + if (!Platform.isIntel()) { + return false; + } + + if (Platform.isMac()) { + + // sysctl or system control retrieves system info and allows processes with appropriate privileges + // to set system info. This system info contains the machine dependent cpu features that are supported by it. + // On MacOS, if the underlying processor supports AVX2 instruction set, it will be listed under the "leaf7" + // subset of instructions ("sysctl -a | grep machdep.cpu.leaf7_features"). + // https://developer.apple.com/library/archive/documentation/System/Conceptual/ManPages_iPhoneOS/man3/sysctl.3.html + try { + return AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + String flags = SysctlUtil.sysctl("machdep.cpu.leaf7_features", "empty"); + return (flags.toLowerCase(Locale.ROOT)).contains("avx2"); + }); + } catch (Exception e) { + logger.error("[KNN] Error fetching cpu flags info. [{}]", e.getMessage(), e); + } + + } else if (Platform.isLinux()) { + + // The "/proc/cpuinfo" is a virtual file which identifies and provides the processor details used + // by system. This info contains "flags" for each processor which determines the qualities of that processor + // and it's ability to process different instruction sets like mmx, avx, avx2 and so on. + // https://access.redhat.com/documentation/en-us/red_hat_enterprise_linux/6/html/deployment_guide/s2-proc-cpuinfo + // Here, we are trying to read the details of all processors used by system and find if any of the processor + // supports AVX2 instructions. 
Pentium and Celeron are a couple of examples which doesn't support AVX2 + // https://ark.intel.com/content/www/us/en/ark/products/199285/intel-pentium-gold-g6600-processor-4m-cache-4-20-ghz.html + String fileName = "/proc/cpuinfo"; + try { + return AccessController.doPrivileged( + (PrivilegedExceptionAction) () -> (Boolean) Files.lines(Paths.get(fileName)) + .filter(s -> s.startsWith("flags")) + .anyMatch(s -> StringUtils.containsIgnoreCase(s, "avx2")) + ); + + } catch (Exception e) { + logger.error("[KNN] Error reading file [{}]. [{}]", fileName, e.getMessage(), e); + } + } + return false; + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java index 0e4c9f815..5a8cdb036 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java @@ -6,6 +6,7 @@ package org.opensearch.knn.plugin.script; import org.apache.lucene.search.IndexSearcher; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.query.KNNWeight; import org.apache.lucene.index.LeafReaderContext; @@ -90,7 +91,7 @@ class CosineSimilarity implements KNNScoringSpace { */ public CosineSimilarity(Object query, MappedFieldType fieldType) { if (!isKNNVectorFieldType(fieldType)) { - throw new IllegalArgumentException("Incompatible field_type for cosine space. The field type must " + "be knn_vector."); + throw new IllegalArgumentException("Incompatible field_type for cosine space. 
The field type must be knn_vector."); } this.processedQuery = parseToFloatArray( @@ -98,6 +99,7 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) { ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); + SpaceType.COSINESIMIL.validateVector(processedQuery); float qVectorSquaredMagnitude = getVectorMagnitudeSquared(this.processedQuery); this.scoringMethod = (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude); } diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java index 3ec1a9941..c482413fb 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java @@ -5,6 +5,7 @@ package org.opensearch.knn.plugin.script; +import java.util.List; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.plugin.stats.KNNCounter; @@ -13,11 +14,10 @@ import org.opensearch.index.mapper.NumberFieldMapper; import java.math.BigInteger; -import java.util.ArrayList; import java.util.Base64; import static org.opensearch.index.mapper.NumberFieldMapper.NumberType.LONG; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue; +import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; public class KNNScoringSpaceUtil { @@ -108,7 +108,7 @@ public static float[] parseToFloatArray(Object object, int expectedDimensions, V public static float[] convertVectorToPrimitive(Object vector, VectorDataType vectorDataType) { float[] primitiveVector = null; if (vector != null) { - final ArrayList tmp = (ArrayList) vector; + final List tmp = (List) vector; primitiveVector = new float[tmp.size()]; for (int i = 0; i < 
primitiveVector.length; i++) { float value = tmp.get(i).floatValue(); diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java index 130c4d8e0..114499100 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java @@ -5,16 +5,16 @@ package org.opensearch.knn.plugin.script; -import org.opensearch.knn.index.KNNVectorScriptDocValues; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.knn.index.VectorDataType; - import java.math.BigInteger; import java.util.List; import java.util.Objects; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.knn.index.KNNVectorScriptDocValues; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue; +import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; public class KNNScoringUtil { private static Logger logger = LogManager.getLogger(KNNScoringUtil.class); @@ -134,11 +134,9 @@ public static float cosinesimilOptimized(float[] queryVector, float[] inputVecto * @return cosine score */ public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues, Number queryVectorMagnitude) { - return cosinesimilOptimized( - toFloat(queryVector, docValues.getVectorDataType()), - docValues.getValue(), - queryVectorMagnitude.floatValue() - ); + float[] inputVector = toFloat(queryVector, docValues.getVectorDataType()); + SpaceType.COSINESIMIL.validateVector(inputVector); + return cosinesimilOptimized(inputVector, docValues.getValue(), queryVectorMagnitude.floatValue()); } /** @@ -183,7 +181,9 @@ public static float cosinesimil(float[] queryVector, float[] inputVector) { * 
@return cosine score */ public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues) { - return cosinesimil(toFloat(queryVector, docValues.getVectorDataType()), docValues.getValue()); + float[] inputVector = toFloat(queryVector, docValues.getVectorDataType()); + SpaceType.COSINESIMIL.validateVector(inputVector); + return cosinesimil(inputVector, docValues.getValue()); } /** diff --git a/src/main/plugin-metadata/plugin-security.policy b/src/main/plugin-metadata/plugin-security.policy index c7345edcd..91624613c 100644 --- a/src/main/plugin-metadata/plugin-security.policy +++ b/src/main/plugin-metadata/plugin-security.policy @@ -1,5 +1,8 @@ grant { permission java.lang.RuntimePermission "loadLibrary.opensearchknn_nmslib"; permission java.lang.RuntimePermission "loadLibrary.opensearchknn_faiss"; + permission java.lang.RuntimePermission "loadLibrary.opensearchknn_faiss_avx2"; permission java.net.SocketPermission "*", "connect,resolve"; + permission java.lang.RuntimePermission "accessDeclaredMembers"; + permission java.io.FilePermission "/proc/cpuinfo", "read"; }; diff --git a/src/test/java/org/opensearch/knn/common/KNNVectorUtilTests.java b/src/test/java/org/opensearch/knn/common/KNNVectorUtilTests.java new file mode 100644 index 000000000..457ea8c5b --- /dev/null +++ b/src/test/java/org/opensearch/knn/common/KNNVectorUtilTests.java @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.knn.common; + +import org.opensearch.knn.KNNTestCase; + +public class KNNVectorUtilTests extends KNNTestCase { + public void testByteZeroVector() { + assertTrue(KNNVectorUtil.isZeroVector(new byte[] { 0, 0, 0 })); + assertFalse(KNNVectorUtil.isZeroVector(new byte[] { 1, 1, 1 })); + } + + public void testFloatZeroVector() { + assertTrue(KNNVectorUtil.isZeroVector(new float[] { 0.0f, 0.0f, 0.0f })); + assertFalse(KNNVectorUtil.isZeroVector(new float[] { 1.0f, 1.0f, 1.0f })); + } +} diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 9da067fde..ffbab463c 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -19,6 +19,7 @@ import org.apache.hc.core5.http.io.entity.EntityUtils; import org.junit.BeforeClass; import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.knn.KNNRestTestCase; @@ -714,6 +715,57 @@ public void testQueryWithFilter_withDifferentCombination_thenSuccess() { assertEquals(0, emptyKNNFilteredResultsFromResponse.size()); } + @SneakyThrows + public void testFiltering_whenUsingFaissExactSearchWithIP_thenMatchExpectedScore() { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", 2) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW).getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .endObject() + .endObject() + .endObject() + .endObject(); + final String mapping = builder.toString(); + 
createKnnIndex(INDEX_NAME, mapping); + + final List dataVectors = Arrays.asList(new Float[] { -2.0f, 2.0f }, new Float[] { 2.0f, -2.0f }); + final List ids = Arrays.asList(DOC_ID_1, DOC_ID_2); + + // Ingest all of the documents + for (int i = 0; i < dataVectors.size(); i++) { + addKnnDoc(INDEX_NAME, ids.get(i), FIELD_NAME, dataVectors.get(i)); + } + refreshIndex(INDEX_NAME); + + // Execute the search request with a match all query to ensure exact logic gets called + updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, 1000)); + float[] queryVector = new float[] { -2.0f, 2.0f }; + int k = 2; + final Response response = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder(FIELD_NAME, queryVector, k, QueryBuilders.matchAllQuery()), + k + ); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List knnResults = parseSearchResponseScore(responseBody, FIELD_NAME); + + // Check that the expected scores are returned + final List expectedScores = Arrays.asList( + KNNEngine.FAISS.score(8.0f, SpaceType.INNER_PRODUCT), + KNNEngine.FAISS.score(-8.0f, SpaceType.INNER_PRODUCT) + ); + assertEquals(expectedScores.size(), knnResults.size()); + for (int i = 0; i < expectedScores.size(); i++) { + assertEquals(expectedScores.get(i), knnResults.get(i), 0.0000001); + } + } + protected void setupKNNIndexForFilterQuery() throws Exception { // Create Mappings XContentBuilder builder = XContentFactory.jsonBuilder() diff --git a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java index dfb6923aa..75eb14713 100644 --- a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java @@ -158,6 +158,18 @@ public void testGetEfSearch_whenEFSearchValueSetByUser_thenReturnValue() { assertEquals(userProvidedEfSearch, efSearchValue); } + @SneakyThrows + public void 
testGetFaissAVX2DisabledSettingValueFromConfig_enableSetting_thenValidateAndSucceed() { + boolean expectedKNNFaissAVX2Disabled = true; + Node mockNode = createMockNode(Map.of(KNNSettings.KNN_FAISS_AVX2_DISABLED, expectedKNNFaissAVX2Disabled)); + mockNode.start(); + ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class); + KNNSettings.state().setClusterService(clusterService); + boolean actualKNNFaissAVX2Disabled = KNNSettings.state().getSettingValue(KNNSettings.KNN_FAISS_AVX2_DISABLED); + mockNode.close(); + assertEquals(expectedKNNFaissAVX2Disabled, actualKNNFaissAVX2Disabled); + } + private Node createMockNode(Map configSettings) throws IOException { Path configDir = createTempDir(); File configFile = configDir.resolve("opensearch.yml").toFile(); diff --git a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java index 4130e2df5..ff2d4518e 100644 --- a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java @@ -13,6 +13,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.primitives.Floats; +import java.util.Locale; +import lombok.SneakyThrows; import org.junit.BeforeClass; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; @@ -263,6 +265,35 @@ public void testIndexingVectorValidation_differentSizes() throws Exception { assertThat(EntityUtils.toString(ex.getResponse().getEntity()), containsString("Vector dimension mismatch. 
Expected: 4, Given: 5")); } + @SneakyThrows + public void testIndexingVectorValidation_zeroVector() { + Settings settings = Settings.builder().put(getKNNDefaultIndexSettings()).build(); + final boolean valid = randomBoolean(); + final String method = KNNConstants.METHOD_HNSW; + String engine; + String spaceType; + if (valid) { + engine = randomFrom(KNNEngine.values()).getName(); + spaceType = SpaceType.L2.getValue(); + } else { + engine = randomFrom(KNNConstants.LUCENE_NAME, KNNConstants.NMSLIB_NAME); + spaceType = SpaceType.COSINESIMIL.getValue(); + } + createKnnIndex(INDEX_NAME, settings, createKnnIndexMapping(FIELD_NAME, 4, method, engine, spaceType)); + Float[] zeroVector = { 0.0f, 0.0f, 0.0f, 0.0f }; + if (valid) { + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, zeroVector); + } else { + ResponseException ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, "1", FIELD_NAME, zeroVector)); + assertTrue( + EntityUtils.toString(ex.getResponse().getEntity()) + .contains( + String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", SpaceType.COSINESIMIL.getValue()) + ) + ); + } + } + public void testVectorMappingValidation_noDimension() throws Exception { Settings settings = Settings.builder().put(getKNNDefaultIndexSettings()).build(); diff --git a/src/test/java/org/opensearch/knn/index/SpaceTypeTests.java b/src/test/java/org/opensearch/knn/index/SpaceTypeTests.java index 5df9ec6f5..9d4fbfc2b 100644 --- a/src/test/java/org/opensearch/knn/index/SpaceTypeTests.java +++ b/src/test/java/org/opensearch/knn/index/SpaceTypeTests.java @@ -13,6 +13,10 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.util.KNNEngine; + +import java.util.Arrays; +import java.util.List; public class SpaceTypeTests extends KNNTestCase { @@ -23,4 +27,38 @@ public void testGetVectorSimilarityFunction_l2() { public void testGetVectorSimilarityFunction_invalid() { 
expectThrows(UnsupportedOperationException.class, SpaceType.L1::getVectorSimilarityFunction); } + + public void testGetVectorSimilarityFunction_whenInnerproduct_thenConsistentWithScoreTranslation() { + /* + For the innerproduct space type, we expect that negative dot product scores will be transformed as follows: + if (negativeDotProduct >= 0) { + return 1 / (1 + negativeDotProduct); + } + return -negativeDotProduct + 1; + + Internally, Lucene uses scaleMaxInnerProductScore to scale the raw dot product into a proper lucene score. + See: + 1. https://github.com/apache/lucene/blob/releases/lucene/9.10.0/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java#L195-L200 + 2. https://github.com/apache/lucene/blob/releases/lucene/9.10.0/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java#L90 + */ + final List dataVectors = Arrays.asList( + new float[] { 0.0f, 0.0f }, + new float[] { 0.25f, -0.25f }, + new float[] { 0.125f, -0.125f }, + new float[] { 25.0f, -25.0f }, + new float[] { -0.125f, 0.125f }, + new float[] { -0.25f, 0.25f }, + new float[] { -25.0f, 25.0f } + ); + float[] queryVector = new float[] { -2.0f, 2.0f }; + List dotProducts = List.of(0.0f, -1.0f, -0.5f, -100.0f, 0.5f, 1.0f, 100.0f); + + for (int i = 0; i < dataVectors.size(); i++) { + assertEquals( + KNNEngine.FAISS.score(dotProducts.get(i), SpaceType.INNER_PRODUCT), + SpaceType.INNER_PRODUCT.getVectorSimilarityFunction().compare(queryVector, dataVectors.get(i)), + 0.0000001 + ); + } + } } diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 95e90aa6a..992f24308 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -71,21 +71,21 @@ public class KNNVectorFieldMapperTests extends KNNTestCase { - private final static String TEST_FIELD_NAME = 
"test-field-name"; + private static final String TEST_FIELD_NAME = "test-field-name"; - private final static int TEST_DIMENSION = 17; + private static final int TEST_DIMENSION = 17; - private final static float TEST_VECTOR_VALUE = 1.5f; + private static final float TEST_VECTOR_VALUE = 1.5f; - private final static float[] TEST_VECTOR = createInitializedFloatArray(TEST_DIMENSION, TEST_VECTOR_VALUE); + private static final float[] TEST_VECTOR = createInitializedFloatArray(TEST_DIMENSION, TEST_VECTOR_VALUE); - private final static byte TEST_BYTE_VECTOR_VALUE = 10; - private final static byte[] TEST_BYTE_VECTOR = createInitializedByteArray(TEST_DIMENSION, TEST_BYTE_VECTOR_VALUE); + private static final byte TEST_BYTE_VECTOR_VALUE = 10; + private static final byte[] TEST_BYTE_VECTOR = createInitializedByteArray(TEST_DIMENSION, TEST_BYTE_VECTOR_VALUE); - private final static BytesRef TEST_VECTOR_BYTES_REF = new BytesRef( + private static final BytesRef TEST_VECTOR_BYTES_REF = new BytesRef( KNNVectorSerializerFactory.getDefaultSerializer().floatToByteArray(TEST_VECTOR) ); - private final static BytesRef TEST_BYTE_VECTOR_BYTES_REF = new BytesRef(TEST_BYTE_VECTOR); + private static final BytesRef TEST_BYTE_VECTOR_BYTES_REF = new BytesRef(TEST_BYTE_VECTOR); private static final String DIMENSION_FIELD_NAME = "dimension"; private static final String KNN_VECTOR_TYPE = "knn_vector"; private static final String TYPE_FIELD_NAME = "type"; @@ -759,8 +759,7 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { doReturn(Optional.of(TEST_VECTOR)).when(luceneFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION); doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - - luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION); + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType); // Document should have 
2 fields: one for VectorField (binary doc values) and one for KnnVectorField List fields = document.getFields(); @@ -798,7 +797,7 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION); + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType); // Document should have 1 field: one for KnnVectorField fields = document.getFields(); @@ -827,7 +826,7 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION); + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType); // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnByteVectorField List fields = document.getFields(); @@ -864,7 +863,7 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION); + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType); // Document should have 1 field: one for KnnByteVectorField fields = document.getFields(); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 3ea469ada..bcd784e23 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ 
b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.query; import com.google.common.collect.ImmutableMap; +import java.util.Locale; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; @@ -222,6 +223,7 @@ public void testDoToQuery_Normal() throws Exception { when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); assertEquals(knnQueryBuilder.getK(), query.getK()); @@ -238,6 +240,7 @@ public void testDoToQuery_KnnQueryWithFilter() throws Exception { when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); @@ -255,6 +258,7 @@ public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() { KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); MethodComponentContext methodComponentContext = new 
MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); @@ -272,6 +276,7 @@ public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException() KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.NMSLIB, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); @@ -298,6 +303,7 @@ public void testDoToQuery_FromModel() { ModelMetadata modelMetadata = mock(ModelMetadata.class); when(modelMetadata.getDimension()).thenReturn(4); when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); + when(modelMetadata.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); KNNQueryBuilder.initialize(modelDao); @@ -334,6 +340,48 @@ public void testDoToQuery_InvalidFieldType() throws IOException { expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } + public void testDoToQuery_InvalidZeroFloatVector() { + float[] queryVector = { 0.0f, 0.0f, 0.0f, 0.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = 
mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> knnQueryBuilder.doToQuery(mockQueryShardContext) + ); + assertEquals( + String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", SpaceType.COSINESIMIL.getValue()), + exception.getMessage() + ); + } + + public void testDoToQuery_InvalidZeroByteVector() { + float[] queryVector = { 0.0f, 0.0f, 0.0f, 0.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BYTE); + when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> knnQueryBuilder.doToQuery(mockQueryShardContext) + ); + assertEquals( + String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", SpaceType.COSINESIMIL.getValue()), + exception.getMessage() + ); + } + public void testSerialization() throws Exception { assertSerialization(Version.CURRENT, Optional.empty()); diff --git 
a/src/test/java/org/opensearch/knn/jni/PlatformUtilTests.java b/src/test/java/org/opensearch/knn/jni/PlatformUtilTests.java new file mode 100644 index 000000000..7816505de --- /dev/null +++ b/src/test/java/org/opensearch/knn/jni/PlatformUtilTests.java @@ -0,0 +1,127 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.jni; + +import com.sun.jna.Platform; +import org.mockito.MockedStatic; +import org.opensearch.knn.KNNTestCase; +import oshi.util.platform.mac.SysctlUtil; + +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.stream.Stream; + +import static org.mockito.Mockito.mockStatic; +import static org.opensearch.knn.jni.PlatformUtils.isAVX2SupportedBySystem; + +public class PlatformUtilTests extends KNNTestCase { + public static final String MAC_CPU_FEATURES = "machdep.cpu.leaf7_features"; + public static final String LINUX_PROC_CPU_INFO = "/proc/cpuinfo"; + + public void testIsAVX2SupportedBySystem_platformIsNotIntel_returnsFalse() { + try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { + mockedPlatform.when(Platform::isIntel).thenReturn(false); + assertFalse(isAVX2SupportedBySystem()); + } + } + + public void testIsAVX2SupportedBySystem_platformIsIntelWithOSAsWindows_returnsFalse() { + try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { + mockedPlatform.when(Platform::isIntel).thenReturn(true); + mockedPlatform.when(Platform::isWindows).thenReturn(true); + assertFalse(isAVX2SupportedBySystem()); + } + } + + public void testIsAVX2SupportedBySystem_platformIsMac_returnsTrue() { + try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { + mockedPlatform.when(Platform::isIntel).thenReturn(true); + 
mockedPlatform.when(Platform::isMac).thenReturn(true); + + try (MockedStatic mockedSysctlUtil = mockStatic(SysctlUtil.class)) { + mockedSysctlUtil.when(() -> SysctlUtil.sysctl(MAC_CPU_FEATURES, "empty")) + .thenReturn( + "RDWRFSGS TSC_THREAD_OFFSET SGX BMI1 AVX2 SMEP BMI2 ERMS INVPCID FPU_CSDS MPX RDSEED ADX SMAP CLFSOPT IPT SGXLC MDCLEAR TSXFA IBRS STIBP L1DF ACAPMSR SSBD" + ); + assertTrue(isAVX2SupportedBySystem()); + } + } + } + + public void testIsAVX2SupportedBySystem_platformIsMac_returnsFalse() { + try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { + mockedPlatform.when(Platform::isIntel).thenReturn(true); + mockedPlatform.when(Platform::isMac).thenReturn(true); + + try (MockedStatic mockedSysctlUtil = mockStatic(SysctlUtil.class)) { + mockedSysctlUtil.when(() -> SysctlUtil.sysctl(MAC_CPU_FEATURES, "empty")).thenReturn("NO Flags"); + assertFalse(isAVX2SupportedBySystem()); + } + } + + } + + public void testIsAVX2SupportedBySystem_platformIsMac_throwsExceptionReturnsFalse() { + try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { + mockedPlatform.when(Platform::isIntel).thenReturn(true); + mockedPlatform.when(Platform::isMac).thenReturn(true); + + try (MockedStatic mockedSysctlUtil = mockStatic(SysctlUtil.class)) { + mockedSysctlUtil.when(() -> SysctlUtil.sysctl(MAC_CPU_FEATURES, "empty")).thenThrow(RuntimeException.class); + assertFalse(isAVX2SupportedBySystem()); + } + } + + } + + public void testIsAVX2SupportedBySystem_platformIsLinux_returnsTrue() { + try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { + mockedPlatform.when(Platform::isIntel).thenReturn(true); + mockedPlatform.when(Platform::isMac).thenReturn(false); + mockedPlatform.when(Platform::isLinux).thenReturn(true); + + try (MockedStatic mockedFiles = mockStatic(Files.class)) { + mockedFiles.when(() -> Files.lines(Paths.get(LINUX_PROC_CPU_INFO))).thenReturn(Stream.of("flags: AVX2", "dummy string")); + assertTrue(isAVX2SupportedBySystem()); + } + } + } + + 
public void testIsAVX2SupportedBySystem_platformIsLinux_returnsFalse() { + try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { + mockedPlatform.when(Platform::isIntel).thenReturn(true); + mockedPlatform.when(Platform::isMac).thenReturn(false); + mockedPlatform.when(Platform::isLinux).thenReturn(true); + + try (MockedStatic mockedFiles = mockStatic(Files.class)) { + mockedFiles.when(() -> Files.lines(Paths.get(LINUX_PROC_CPU_INFO))).thenReturn(Stream.of("flags: ", "dummy string")); + assertFalse(isAVX2SupportedBySystem()); + } + } + + } + + public void testIsAVX2SupportedBySystem_platformIsLinux_throwsExceptionReturnsFalse() { + try (MockedStatic mockedPlatform = mockStatic(Platform.class)) { + mockedPlatform.when(Platform::isIntel).thenReturn(true); + mockedPlatform.when(Platform::isMac).thenReturn(false); + mockedPlatform.when(Platform::isLinux).thenReturn(true); + + try (MockedStatic mockedPaths = mockStatic(Paths.class)) { + mockedPaths.when(() -> Paths.get(LINUX_PROC_CPU_INFO)).thenThrow(RuntimeException.class); + assertFalse(isAVX2SupportedBySystem()); + } + } + + } + +} diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java index 24bc74ff4..52bc22eff 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java @@ -10,20 +10,21 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.index.mapper.NumberFieldMapper; -import java.util.ArrayList; import java.util.List; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class KNNScoringSpaceFactoryTests extends KNNTestCase { public void testValidSpaces() { KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + 
when(knnVectorFieldType.getDimension()).thenReturn(3); NumberFieldMapper.NumberFieldType numberFieldType = new NumberFieldMapper.NumberFieldType( "field", NumberFieldMapper.NumberType.LONG ); - List floatQueryObject = new ArrayList<>(); + List floatQueryObject = List.of(1.0f, 1.0f, 1.0f); Long longQueryObject = 0L; assertTrue( diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java index c80949b43..6b40f375c 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java @@ -5,8 +5,10 @@ package org.opensearch.knn.plugin.script; +import java.util.Locale; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.index.mapper.BinaryFieldMapper; import org.opensearch.index.mapper.NumberFieldMapper; @@ -59,11 +61,26 @@ public void testCosineSimilarity() { assertEquals(3F, cosineSimilarity.scoringMethod.apply(arrayFloat2, arrayFloat), 0.1F); + // invalid zero vector + final List queryZeroVector = List.of(0.0f, 0.0f, 0.0f); + IllegalArgumentException exception1 = expectThrows( + IllegalArgumentException.class, + () -> new KNNScoringSpace.CosineSimilarity(queryZeroVector, fieldType) + ); + assertEquals( + String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", SpaceType.COSINESIMIL.getValue()), + exception1.getMessage() + ); + NumberFieldMapper.NumberFieldType invalidFieldType = new NumberFieldMapper.NumberFieldType( "field", NumberFieldMapper.NumberType.INTEGER ); - expectThrows(IllegalArgumentException.class, () -> new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, invalidFieldType)); + IllegalArgumentException exception2 = expectThrows( + IllegalArgumentException.class, + () -> 
new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, invalidFieldType) + ); + assertEquals("Incompatible field_type for cosine space. The field type must be knn_vector.", exception2.getMessage()); } public void testInnerProdSimilarity() { diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 4a2bb7254..8c43a4acf 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -5,8 +5,10 @@ package org.opensearch.knn.plugin.script; +import java.util.Locale; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNVectorScriptDocValues; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.apache.lucene.tests.analysis.MockAnalyzer; @@ -22,17 +24,16 @@ import java.io.IOException; import java.math.BigInteger; -import java.util.ArrayList; import java.util.List; public class KNNScoringUtilTests extends KNNTestCase { private List getTestQueryVector() { - List queryVector = new ArrayList<>(); - queryVector.add(1.0f); - queryVector.add(1.0f); - queryVector.add(1.0f); - return queryVector; + return List.of(1.0f, 1.0f, 1.0f); + } + + private List getTestZeroVector() { + return List.of(0.0f, 0.0f, 0.0f); } public void testL2SquaredScoringFunction() { @@ -211,6 +212,24 @@ public void testScriptDocValuesFailsCosineSimilarity() throws IOException { dataset.close(); } + public void testZeroVectorFailsCosineSimilarity() throws IOException { + List queryVector = getTestZeroVector(); + TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); + KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); + 
scriptDocValues.setNextDocId(0); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues) + ); + assertEquals( + String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", SpaceType.COSINESIMIL.getValue()), + exception.getMessage() + ); + dataset.close(); + } + public void testCosineSimilarityOptimizedScoringFunction() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); @@ -227,7 +246,24 @@ public void testScriptDocValuesFailsCosineSimilarityOptimized() throws IOExcepti TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); - expectThrows(IllegalStateException.class, () -> KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues, 3.0f)); + dataset.close(); + } + + public void testZeroVectorFailsCosineSimilarityOptimized() throws IOException { + List queryVector = getTestZeroVector(); + TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); + dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); + KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); + scriptDocValues.setNextDocId(0); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues, 3.0f) + ); + assertEquals( + String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", SpaceType.COSINESIMIL.getValue()), + exception.getMessage() + ); dataset.close(); } diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 
1d984927f..f0a6b1c3e 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -17,6 +17,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; @@ -338,6 +339,28 @@ protected String createKnnIndexMapping(String fieldName, Integer dimensions, Str .toString(); } + /** + * Utility to create a Knn Index Mapping with specific algorithm, engine and spaceType + */ + protected String createKnnIndexMapping(String fieldName, Integer dimensions, String algoName, String knnEngine, String spaceType) + throws IOException { + return XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field(KNNConstants.TYPE, KNNConstants.TYPE_KNN_VECTOR) + .field(KNNConstants.DIMENSION, dimensions.toString()) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, algoName) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType) + .field(KNNConstants.KNN_ENGINE, knnEngine) + .endObject() + .endObject() + .endObject() + .endObject() + .toString(); + } + /** * Utility to create a Knn Index Mapping with multiple k-NN fields */