Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG FIX] Fix bwc failure in neural sparse search #696

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix async actions are left in neural_sparse query ([#438](https://github.com/opensearch-project/neural-search/pull/438))
- Fix typo for sparse encoding processor factory([#578](https://github.com/opensearch-project/neural-search/pull/578))
- Add non-null check for queryBuilder in NeuralQueryEnricherProcessor ([#615](https://github.com/opensearch-project/neural-search/pull/615))
- Add max_token_score field placeholder in NeuralSparseQueryBuilder to fix the rolling-upgrade from 2.x nodes bwc tests. ([#696](https://github.com/opensearch-project/neural-search/pull/696))
zhichao-aws marked this conversation as resolved.
Show resolved Hide resolved
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
*/
package org.opensearch.neuralsearch.bwc;

import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import org.junit.Before;
import org.opensearch.common.settings.Settings;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
*/
package org.opensearch.neuralsearch.bwc;

import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import org.junit.Before;
import org.opensearch.common.settings.Settings;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
import lombok.extern.log4j.Log4j2;

/**
* SparseEncodingQueryBuilder is responsible for handling "neural_sparse" query types. It uses an ML SPARSE_ENCODING model
* SparseEncodingQueryBuilder is responsible for handling "neural_sparse" query types. It uses an ML NEURAL_SPARSE model
* or SPARSE_TOKENIZE model to produce a Map with String keys and Float values for input text. Then it will be transformed
* to Lucene FeatureQuery wrapped by Lucene BooleanQuery.
*/
Expand All @@ -63,6 +63,11 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQ
static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text");
@VisibleForTesting
static final ParseField MODEL_ID_FIELD = new ParseField("model_id");
// We use max_token_score field to help WAND scorer prune query clause in lucene 9.7. But in lucene 9.8 the inner
// logics change, this field is not needed any more.
@VisibleForTesting
@Deprecated
static final ParseField MAX_TOKEN_SCORE_FIELD = new ParseField("max_token_score").withAllDeprecated();

private static MLCommonsClientAccessor ML_CLIENT;

Expand All @@ -73,6 +78,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
private String fieldName;
private String queryText;
private String modelId;
private Float maxTokenScore;
private Supplier<Map<String, Float>> queryTokensSupplier;
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0;

Expand All @@ -91,6 +97,7 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
} else {
this.modelId = in.readString();
}
this.maxTokenScore = in.readOptionalFloat();
if (in.readBoolean()) {
Map<String, Float> queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat);
this.queryTokensSupplier = () -> queryTokens;
Expand All @@ -106,6 +113,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
} else {
out.writeString(this.modelId);
}
out.writeOptionalFloat(maxTokenScore);
if (!Objects.isNull(queryTokensSupplier) && !Objects.isNull(queryTokensSupplier.get())) {
out.writeBoolean(true);
out.writeMap(queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat);
Expand All @@ -122,6 +130,7 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
if (Objects.nonNull(modelId)) {
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
}
if (maxTokenScore != null) xContentBuilder.field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), maxTokenScore);
printBoostAndQueryName(xContentBuilder);
xContentBuilder.endObject();
xContentBuilder.endObject();
Expand All @@ -131,7 +140,8 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
* The expected parsing form looks like:
* "SAMPLE_FIELD": {
* "query_text": "string",
* "model_id": "string"
* "model_id": "string",
* "max_token_score": float (optional)
* }
*
* @param parser XContentParser
Expand Down Expand Up @@ -189,6 +199,8 @@ private static void parseQueryParams(XContentParser parser, NeuralSparseQueryBui
sparseEncodingQueryBuilder.queryText(parser.text());
} else if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
sparseEncodingQueryBuilder.modelId(parser.text());
} else if (MAX_TOKEN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
sparseEncodingQueryBuilder.maxTokenScore(parser.floatValue());
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand Down Expand Up @@ -227,6 +239,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
return new NeuralSparseQueryBuilder().fieldName(fieldName)
.queryText(queryText)
.modelId(modelId)
.maxTokenScore(maxTokenScore)
.queryTokensSupplier(queryTokensSetOnce::get);
}

Expand Down Expand Up @@ -280,22 +293,23 @@ private static void validateQueryTokens(Map<String, Float> queryTokens) {
@Override
protected boolean doEquals(NeuralSparseQueryBuilder obj) {
if (this == obj) return true;
if (Objects.isNull(obj) || getClass() != obj.getClass()) return false;
if (Objects.isNull(queryTokensSupplier) && !Objects.isNull(obj.queryTokensSupplier)) return false;
if (!Objects.isNull(queryTokensSupplier) && Objects.isNull(obj.queryTokensSupplier)) return false;
if (obj == null || getClass() != obj.getClass()) return false;
if (queryTokensSupplier == null && obj.queryTokensSupplier != null) return false;
if (queryTokensSupplier != null && obj.queryTokensSupplier == null) return false;
EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName)
.append(queryText, obj.queryText)
.append(modelId, obj.modelId);
if (!Objects.isNull(queryTokensSupplier)) {
.append(modelId, obj.modelId)
.append(maxTokenScore, obj.maxTokenScore);
if (queryTokensSupplier != null) {
equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get());
}
return equalsBuilder.isEquals();
}

@Override
protected int doHashCode() {
HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId);
if (!Objects.isNull(queryTokensSupplier)) {
HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).append(maxTokenScore);
if (queryTokensSupplier != null) {
builder.append(queryTokensSupplier.get());
}
return builder.toHashCode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD;
import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.MAX_TOKEN_SCORE_FIELD;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.MODEL_ID_FIELD;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.NAME;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.QUERY_TEXT_FIELD;
Expand All @@ -22,6 +23,9 @@
import java.util.function.BiConsumer;
import java.util.function.Supplier;

import org.apache.lucene.document.FeatureField;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.junit.Before;
import org.opensearch.Version;
import org.opensearch.client.Client;
Expand All @@ -37,9 +41,11 @@
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
Expand All @@ -54,6 +60,7 @@ public class NeuralSparseQueryBuilderTests extends OpenSearchTestCase {
private static final String MODEL_ID = "mfgfgdsfgfdgsde";
private static final float BOOST = 1.8f;
private static final String QUERY_NAME = "queryName";
private static final Float MAX_TOKEN_SCORE = 123f;
private static final Supplier<Map<String, Float>> QUERY_TOKENS_SUPPLIER = () -> Map.of("hello", 1.f, "world", 2.f);

@Before
Expand Down Expand Up @@ -121,6 +128,32 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() {
assertEquals(QUERY_NAME, sparseEncodingQueryBuilder.queryName());
}

@SneakyThrows
public void testFromXContent_whenBuiltWithMaxTokenScore_thenThrowWarning() {
/*
{
"VECTOR_FIELD": {
"query_text": "string",
"model_id": "string",
"max_token_score": 123.0
}
}
*/
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject(FIELD_NAME)
.field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT)
.field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID)
.field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), MAX_TOKEN_SCORE)
.endObject()
.endObject();

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser);
assertWarnings("Deprecated field [max_token_score] used, this field is unused and will be removed entirely");
}

@SneakyThrows
public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() {
/*
Expand Down Expand Up @@ -248,7 +281,8 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() {
public void testToXContent() {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.modelId(MODEL_ID)
.queryText(QUERY_TEXT);
.queryText(QUERY_TEXT)
.maxTokenScore(MAX_TOKEN_SCORE);

XContentBuilder builder = XContentFactory.jsonBuilder();
builder = sparseEncodingQueryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);
Expand All @@ -273,6 +307,7 @@ public void testToXContent() {

assertEquals(MODEL_ID, secondInnerMap.get(MODEL_ID_FIELD.getPreferredName()));
assertEquals(QUERY_TEXT, secondInnerMap.get(QUERY_TEXT_FIELD.getPreferredName()));
assertEquals(MAX_TOKEN_SCORE, (Double) secondInnerMap.get(MAX_TOKEN_SCORE_FIELD.getPreferredName()), 0.0);
}

public void testStreams_whenMinVersionIsBeforeDefaultModelId_thenSuccess() {
Expand All @@ -285,6 +320,7 @@ public void testStreams() {
NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder();
original.fieldName(FIELD_NAME);
original.queryText(QUERY_TEXT);
original.maxTokenScore(MAX_TOKEN_SCORE);
original.modelId(MODEL_ID);
original.boost(BOOST);
original.queryName(QUERY_NAME);
Expand All @@ -306,11 +342,11 @@ public void testStreams() {
queryTokensSetOnce.set(Map.of("hello", 1.0f, "world", 2.0f));
original.queryTokensSupplier(queryTokensSetOnce::get);

BytesStreamOutput streamOutput2 = new BytesStreamOutput();
original.writeTo(streamOutput2);
streamOutput = new BytesStreamOutput();
original.writeTo(streamOutput);

filterStreamInput = new NamedWriteableAwareStreamInput(
streamOutput2.bytes().streamInput(),
streamOutput.bytes().streamInput(),
new NamedWriteableRegistry(
List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new))
)
Expand All @@ -327,6 +363,8 @@ public void testHashAndEquals() {
String queryText2 = "query text 2";
String modelId1 = "model-1";
String modelId2 = "model-2";
float maxTokenScore1 = 1.1f;
float maxTokenScore2 = 2.2f;
float boost1 = 1.8f;
float boost2 = 3.8f;
String queryName1 = "query-1";
Expand All @@ -337,60 +375,77 @@ public void testHashAndEquals() {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baseline = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baselineCopy = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except default boost and query name
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1);
.modelId(modelId1)
.maxTokenScore(maxTokenScore1);

// Identical to sparseEncodingQueryBuilder_baseline except diff field name
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffFieldName = new NeuralSparseQueryBuilder().fieldName(fieldName2)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except diff query text
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText2)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except diff model ID
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffModelId = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId2)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except diff boost
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffBoost = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost2)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except diff query name
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName2);

// Identical to sparseEncodingQueryBuilder_baseline except diff max token score
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffMaxTokenScore = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore2)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nonNullQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1)
.queryTokensSupplier(() -> queryTokens1);
Expand All @@ -399,6 +454,7 @@ public void testHashAndEquals() {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1)
.queryTokensSupplier(() -> queryTokens2);
Expand Down Expand Up @@ -427,6 +483,9 @@ public void testHashAndEquals() {
assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffQueryName);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffQueryName.hashCode());

assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffMaxTokenScore);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffMaxTokenScore.hashCode());

assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nonNullQueryTokens);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode());

Expand Down Expand Up @@ -486,4 +545,23 @@ private void setUpClusterService(Version version) {
ClusterService clusterService = NeuralSearchClusterTestUtils.mockClusterService(version);
NeuralSearchClusterUtil.instance().initialize(clusterService);
}

@SneakyThrows
public void testDoToQuery_successfulDoToQuery() {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.maxTokenScore(MAX_TOKEN_SCORE)
.queryText(QUERY_TEXT)
.modelId(MODEL_ID)
.queryTokensSupplier(QUERY_TOKENS_SUPPLIER);
QueryShardContext mockedQueryShardContext = mock(QueryShardContext.class);
MappedFieldType mockedMappedFieldType = mock(MappedFieldType.class);
doAnswer(invocation -> "rank_features").when(mockedMappedFieldType).typeName();
doAnswer(invocation -> mockedMappedFieldType).when(mockedQueryShardContext).fieldMapper(any());

BooleanQuery.Builder targetQueryBuilder = new BooleanQuery.Builder();
targetQueryBuilder.add(FeatureField.newLinearQuery(FIELD_NAME, "hello", 1.f), BooleanClause.Occur.SHOULD);
targetQueryBuilder.add(FeatureField.newLinearQuery(FIELD_NAME, "world", 2.f), BooleanClause.Occur.SHOULD);

assertEquals(sparseEncodingQueryBuilder.doToQuery(mockedQueryShardContext), targetQueryBuilder.build());
}
}
Loading