Skip to content

Commit

Permalink
support script score when doc value is disabled
Browse files Browse the repository at this point in the history
Signed-off-by: panguixin <[email protected]>
  • Loading branch information
bugmakerrrrrr committed Mar 21, 2024
1 parent e8c9ced commit a86cc0a
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

package org.opensearch.knn.index;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.index.fielddata.LeafFieldData;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.index.fielddata.SortedBinaryDocValues;
Expand Down Expand Up @@ -39,10 +41,21 @@ public long ramBytesUsed() {
@Override
public ScriptDocValues<float[]> getScriptValues() {
try {
BinaryDocValues values = DocValues.getBinary(reader, fieldName);
return new KNNVectorScriptDocValues(values, fieldName, vectorDataType);
DocIdSetIterator values = null;
FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(fieldName);
System.out.println(fieldInfo);
if (fieldInfo.hasVectorValues()) {
values = fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32
? reader.getFloatVectorValues(fieldName)
: reader.getByteVectorValues(fieldName);
System.out.println("use vector values");
} else {
values = DocValues.getBinary(reader, fieldName);
System.out.println("use binary values");
}
return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType);
} catch (IOException e) {
throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e);
throw new IllegalStateException("Cannot load values for knn vector field: " + fieldName, e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,29 @@

package org.opensearch.knn.index;

import java.io.IOException;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.ExceptionsHelper;
import org.opensearch.index.fielddata.ScriptDocValues;

import java.io.IOException;

@RequiredArgsConstructor
public final class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {
@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
public abstract class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {

private final BinaryDocValues binaryDocValues;
private final DocIdSetIterator vectorValues;
private final String fieldName;
@Getter
private final VectorDataType vectorDataType;
private boolean docExists = false;

@Override
public void setNextDocId(int docId) throws IOException {
if (binaryDocValues.advanceExact(docId)) {
docExists = true;
return;
}
docExists = false;
docExists = vectorValues.docID() == docId || vectorValues.advance(docId) == docId;
}

public float[] getValue() {
Expand All @@ -43,12 +42,14 @@ public float[] getValue() {
throw new IllegalStateException(errorMessage);
}
try {
return vectorDataType.getVectorFromDocValues(binaryDocValues.binaryValue());
return doGetValue();
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}

protected abstract float[] doGetValue() throws IOException;

@Override
public int size() {
return docExists ? 1 : 0;
Expand All @@ -58,4 +59,63 @@ public int size() {
public float[] get(int i) {
throw new UnsupportedOperationException("knn vector does not support this operation");
}

public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) {
if (values instanceof ByteVectorValues) {
return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType);
} else if (values instanceof FloatVectorValues) {
return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType);
} else if (values instanceof BinaryDocValues) {
return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType);
} else {
throw new IllegalArgumentException("Unsupported values type: " + values.getClass());
}
}

private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues {
private final ByteVectorValues values;

KNNByteVectorScriptDocValues(ByteVectorValues values, String field, VectorDataType type) {
super(values, field, type);
this.values = values;
}

@Override
protected float[] doGetValue() throws IOException {
byte[] bytes = values.vectorValue();
float[] value = new float[bytes.length];
for (int i = 0; i < bytes.length; i++) {
value[i] = (float) bytes[i];
}
return value;
}
}

private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues {
private final FloatVectorValues values;

KNNFloatVectorScriptDocValues(FloatVectorValues values, String field, VectorDataType type) {
super(values, field, type);
this.values = values;
}

@Override
protected float[] doGetValue() throws IOException {
return values.vectorValue();
}
}

private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues {
private final BinaryDocValues values;

KNNNativeVectorScriptDocValues(BinaryDocValues values, String field, VectorDataType type) {
super(values, field, type);
this.values = values;
}

@Override
protected float[] doGetValue() throws IOException {
return getVectorDataType().getVectorFromDocValues(values.binaryValue());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public void setUp() throws Exception {
createKNNVectorDocument(directory);
reader = DirectoryReader.open(directory);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
scriptDocValues = new KNNVectorScriptDocValues(
scriptDocValues = KNNVectorScriptDocValues.create(
leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME),
MOCK_INDEX_FIELD_NAME,
VectorDataType.FLOAT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() {
createKNNFloatVectorDocument(directory);
reader = DirectoryReader.open(directory);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
return new KNNVectorScriptDocValues(
return KNNVectorScriptDocValues.create(
leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME),
VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME,
VectorDataType.FLOAT
Expand All @@ -70,7 +70,7 @@ private KNNVectorScriptDocValues getKNNByteVectorScriptDocValues() {
createKNNByteVectorDocument(directory);
reader = DirectoryReader.open(directory);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
return new KNNVectorScriptDocValues(
return KNNVectorScriptDocValues.create(
leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME),
VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME,
VectorDataType.BYTE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ public KNNVectorScriptDocValues getScriptDocValues(String fieldName) throws IOEx
if (scriptDocValues == null) {
reader = DirectoryReader.open(directory);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
scriptDocValues = new KNNVectorScriptDocValues(
scriptDocValues = KNNVectorScriptDocValues.create(
leafReaderContext.reader().getBinaryDocValues(fieldName),
fieldName,
VectorDataType.FLOAT
Expand Down

0 comments on commit a86cc0a

Please sign in to comment.