diff --git a/CHANGELOG.md b/CHANGELOG.md index b858ede4b78cf..288515c1dffa4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Added - Add useCompoundFile index setting ([#13478](https://github.com/opensearch-project/OpenSearch/pull/13478)) - Make outbound side of transport protocol dependent ([#13293](https://github.com/opensearch-project/OpenSearch/pull/13293)) +- Add proto structures for Query and Fetch search result. ([#13178](https://github.com/opensearch-project/OpenSearch/pull/13178)) ### Dependencies - Bump `com.github.spullara.mustache.java:compiler` from 0.9.10 to 0.9.13 ([#13329](https://github.com/opensearch-project/OpenSearch/pull/13329), [#13559](https://github.com/opensearch-project/OpenSearch/pull/13559)) diff --git a/libs/common/src/main/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessor.java b/libs/common/src/main/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessor.java index 569f48a8465f3..6264d00f01887 100644 --- a/libs/common/src/main/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessor.java +++ b/libs/common/src/main/java/org/opensearch/common/annotation/processor/ApiAnnotationProcessor.java @@ -238,7 +238,15 @@ private boolean inspectable(ExecutableElement executable) { */ private boolean inspectable(Element element) { final PackageElement pckg = processingEnv.getElementUtils().getPackageOf(element); - return pckg.getQualifiedName().toString().startsWith(OPENSEARCH_PACKAGE); + return pckg.getQualifiedName().toString().startsWith(OPENSEARCH_PACKAGE) + && !element.getEnclosingElement() + .getAnnotationMirrors() + .stream() + .anyMatch( + m -> m.getAnnotationType() + .toString() /* ClassSymbol.toString() returns class name */ + .equalsIgnoreCase("javax.annotation.Generated") + ); } /** diff --git a/server/build.gradle b/server/build.gradle index 9714f13ec67d6..c325cbd746783 100644 --- a/server/build.gradle +++ b/server/build.gradle @@ -380,6 +380,7 @@ tasks.named("licenseHeaders").configure { excludes << 'org/opensearch/client/documentation/placeholder.txt' // Ignore for protobuf generated code excludes << 'org/opensearch/extensions/proto/*' + excludes << 'org/opensearch/server/proto/*' } tasks.test { diff --git a/server/src/main/java/org/opensearch/common/document/serializer/DocumentFieldDeserializer.java b/server/src/main/java/org/opensearch/common/document/serializer/DocumentFieldDeserializer.java new file mode 100644 index 0000000000000..fbd00eded893f --- /dev/null +++ b/server/src/main/java/org/opensearch/common/document/serializer/DocumentFieldDeserializer.java @@ -0,0 +1,22 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.document.serializer; + +import org.opensearch.common.document.DocumentField; + +import java.io.IOException; + +/** + * Deserializer for {@link DocumentField} which can be implemented for different types of serde mechanisms. + */ +public interface DocumentFieldDeserializer { + + DocumentField createDocumentField(T inputStream) throws IOException; + +} diff --git a/server/src/main/java/org/opensearch/common/document/serializer/package-info.java b/server/src/main/java/org/opensearch/common/document/serializer/package-info.java new file mode 100644 index 0000000000000..e8419ac59bb03 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/document/serializer/package-info.java @@ -0,0 +1,10 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** Serializer package for documents. */ +package org.opensearch.common.document.serializer; diff --git a/server/src/main/java/org/opensearch/common/document/serializer/protobuf/DocumentFieldProtobufDeserializer.java b/server/src/main/java/org/opensearch/common/document/serializer/protobuf/DocumentFieldProtobufDeserializer.java new file mode 100644 index 0000000000000..dd5cc9450bf99 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/document/serializer/protobuf/DocumentFieldProtobufDeserializer.java @@ -0,0 +1,144 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.document.serializer.protobuf; + +import com.google.protobuf.ByteString; +import org.opensearch.OpenSearchException; +import org.opensearch.common.document.DocumentField; +import org.opensearch.common.document.serializer.DocumentFieldDeserializer; +import org.opensearch.core.common.text.Text; +import org.opensearch.server.proto.FetchSearchResultProto; +import org.opensearch.server.proto.FetchSearchResultProto.DocumentFieldValue; +import org.opensearch.server.proto.FetchSearchResultProto.DocumentFieldValue.Builder; + +import java.io.IOException; +import java.io.InputStream; +import java.time.Instant; +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.util.ArrayList; +import java.util.Date; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * Deserializer for {@link DocumentField} to/from protobuf. + */ +public class DocumentFieldProtobufDeserializer implements DocumentFieldDeserializer { + + private FetchSearchResultProto.SearchHit.DocumentField documentField; + + @Override + public DocumentField createDocumentField(InputStream inputStream) throws IOException { + documentField = FetchSearchResultProto.SearchHit.DocumentField.parseFrom(inputStream); + String name = documentField.getName(); + List values = new ArrayList<>(); + for (FetchSearchResultProto.DocumentFieldValue value : documentField.getValuesList()) { + values.add(readDocumentFieldValueFromProtobuf(value)); + } + return new DocumentField(name, values); + } + + private Object readDocumentFieldValueFromProtobuf(FetchSearchResultProto.DocumentFieldValue documentFieldValue) throws IOException { + if (documentFieldValue.hasValueString()) { + return documentFieldValue.getValueString(); + } else if (documentFieldValue.hasValueInt()) { + return documentFieldValue.getValueInt(); + } else if (documentFieldValue.hasValueLong()) { + return documentFieldValue.getValueLong(); + } else if (documentFieldValue.hasValueFloat()) { + return documentFieldValue.getValueFloat(); + } else if (documentFieldValue.hasValueDouble()) { + return documentFieldValue.getValueDouble(); + } else if (documentFieldValue.hasValueBool()) { + return documentFieldValue.getValueBool(); + } else if (documentFieldValue.getValueByteArrayList().size() > 0) { + return documentFieldValue.getValueByteArrayList().toArray(); + } else if (documentFieldValue.getValueArrayListList().size() > 0) { + List list = new ArrayList<>(); + for (FetchSearchResultProto.DocumentFieldValue value : documentFieldValue.getValueArrayListList()) { + list.add(readDocumentFieldValueFromProtobuf(value)); + } + return list; + } else if (documentFieldValue.getValueMapMap().size() > 0) { + Map map = Map.of(); + for (Map.Entry entrySet : documentFieldValue.getValueMapMap().entrySet()) { + map.put(entrySet.getKey(), readDocumentFieldValueFromProtobuf(entrySet.getValue())); + } + return map; + } else if (documentFieldValue.hasValueDate()) { + return new Date(documentFieldValue.getValueDate()); + } else if (documentFieldValue.hasValueZonedDate() && documentFieldValue.hasValueZonedTime()) { + return ZonedDateTime.ofInstant( + Instant.ofEpochMilli(documentFieldValue.getValueZonedTime()), + ZoneId.of(documentFieldValue.getValueZonedDate()) + ); + } else if (documentFieldValue.hasValueText()) { + return new Text(documentFieldValue.getValueText()); + } else { + throw new IOException("Can't read generic value of type [" + documentFieldValue + "]"); + } + } + + public static DocumentFieldValue.Builder convertDocumentFieldValueToProto(Object value, Builder valueBuilder) { + if (value == null) { + // null is not allowed in protobuf, so we use a special string to represent null + return valueBuilder.setValueString("null"); + } + Class type = value.getClass(); + if (type == String.class) { + valueBuilder.setValueString((String) value); + } else if (type == Integer.class) { + valueBuilder.setValueInt((Integer) value); + } else if (type == Long.class) { + valueBuilder.setValueLong((Long) value); + } else if (type == Float.class) { + valueBuilder.setValueFloat((Float) value); + } else if (type == Double.class) { + valueBuilder.setValueDouble((Double) value); + } else if (type == Boolean.class) { + valueBuilder.setValueBool((Boolean) value); + } else if (type == byte[].class) { + valueBuilder.addValueByteArray(ByteString.copyFrom((byte[]) value)); + } else if (type == List.class) { + List list = (List) value; + for (Object listValue : list) { + valueBuilder.addValueArrayList(convertDocumentFieldValueToProto(listValue, valueBuilder)); + } + } else if (type == Map.class || type == HashMap.class || type == LinkedHashMap.class) { + Map map = (Map) value; + for (Map.Entry entry : map.entrySet()) { + valueBuilder.putValueMap(entry.getKey(), convertDocumentFieldValueToProto(entry.getValue(), valueBuilder).build()); + } + } else if (type == Date.class) { + valueBuilder.setValueDate(((Date) value).getTime()); + } else if (type == ZonedDateTime.class) { + valueBuilder.setValueZonedDate(((ZonedDateTime) value).getZone().getId()); + valueBuilder.setValueZonedTime(((ZonedDateTime) value).toInstant().toEpochMilli()); + } else if (type == Text.class) { + valueBuilder.setValueText(((Text) value).string()); + } else { + throw new OpenSearchException("Can't convert generic value of type [" + type + "] to protobuf"); + } + return valueBuilder; + } + + public static FetchSearchResultProto.SearchHit.DocumentField convertDocumentFieldToProto(DocumentField documentField) { + FetchSearchResultProto.SearchHit.DocumentField.Builder builder = FetchSearchResultProto.SearchHit.DocumentField.newBuilder(); + builder.setName(documentField.getName()); + for (Object value : documentField.getValues()) { + FetchSearchResultProto.DocumentFieldValue.Builder valueBuilder = FetchSearchResultProto.DocumentFieldValue.newBuilder(); + builder.addValues(convertDocumentFieldValueToProto(value, valueBuilder)); + } + return builder.build(); + } + +} diff --git a/server/src/main/java/org/opensearch/common/document/serializer/protobuf/package-info.java b/server/src/main/java/org/opensearch/common/document/serializer/protobuf/package-info.java new file mode 100644 index 0000000000000..690bc7642390d --- /dev/null +++ b/server/src/main/java/org/opensearch/common/document/serializer/protobuf/package-info.java @@ -0,0 +1,10 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** Protobuf Serializer package for documents. */ +package org.opensearch.common.document.serializer.protobuf; diff --git a/server/src/main/java/org/opensearch/common/util/FeatureFlags.java b/server/src/main/java/org/opensearch/common/util/FeatureFlags.java index a72583607ede0..154ec822d5a19 100644 --- a/server/src/main/java/org/opensearch/common/util/FeatureFlags.java +++ b/server/src/main/java/org/opensearch/common/util/FeatureFlags.java @@ -67,6 +67,11 @@ public class FeatureFlags { */ public static final String PLUGGABLE_CACHE = "opensearch.experimental.feature.pluggable.caching.enabled"; + /** + * Gates the functionality of integrating protobuf within search API and node-to-node communication. + */ + public static final String PROTOBUF = "opensearch.experimental.feature.search_with_protobuf.enabled"; + public static final Setting REMOTE_STORE_MIGRATION_EXPERIMENTAL_SETTING = Setting.boolSetting( REMOTE_STORE_MIGRATION_EXPERIMENTAL, false, @@ -93,6 +98,8 @@ public class FeatureFlags { public static final Setting PLUGGABLE_CACHE_SETTING = Setting.boolSetting(PLUGGABLE_CACHE, false, Property.NodeScope); + public static final Setting PROTOBUF_SETTING = Setting.boolSetting(PROTOBUF, false, Property.NodeScope, Property.Dynamic); + private static final List> ALL_FEATURE_FLAG_SETTINGS = List.of( REMOTE_STORE_MIGRATION_EXPERIMENTAL_SETTING, EXTENSIONS_SETTING, @@ -100,7 +107,8 @@ public class FeatureFlags { TELEMETRY_SETTING, DATETIME_FORMATTER_CACHING_SETTING, WRITEABLE_REMOTE_INDEX_SETTING, - PLUGGABLE_CACHE_SETTING + PLUGGABLE_CACHE_SETTING, + PROTOBUF_SETTING ); /** * Should store the settings from opensearch.yml. diff --git a/server/src/main/java/org/opensearch/search/SearchSortValues.java b/server/src/main/java/org/opensearch/search/SearchSortValues.java index cbc3900f72f79..1957c79c98649 100644 --- a/server/src/main/java/org/opensearch/search/SearchSortValues.java +++ b/server/src/main/java/org/opensearch/search/SearchSortValues.java @@ -67,6 +67,16 @@ public class SearchSortValues implements ToXContentFragment, Writeable { this.rawSortValues = EMPTY_ARRAY; } + public SearchSortValues(Object[] sortValues, Object[] rawSortValues) { + Objects.requireNonNull(rawSortValues); + Objects.requireNonNull(sortValues); + if (rawSortValues.length != sortValues.length) { + throw new IllegalArgumentException("formattedSortValues and sortValues must hold the same number of items"); + } + this.formattedSortValues = sortValues; + this.rawSortValues = rawSortValues; + } + public SearchSortValues(Object[] rawSortValues, DocValueFormat[] sortValueFormats) { Objects.requireNonNull(rawSortValues); Objects.requireNonNull(sortValueFormats); diff --git a/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java b/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java index 26fa90141c2a9..fd5e2b9cdd250 100644 --- a/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java +++ b/server/src/main/java/org/opensearch/search/fetch/FetchSearchResult.java @@ -33,6 +33,7 @@ package org.opensearch.search.fetch; import org.opensearch.common.annotation.PublicApi; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.search.SearchHit; @@ -41,8 +42,13 @@ import org.opensearch.search.SearchShardTarget; import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.serializer.protobuf.SearchHitsProtobufDeserializer; +import org.opensearch.server.proto.FetchSearchResultProto; +import org.opensearch.server.proto.ShardSearchRequestProto; +import java.io.ByteArrayInputStream; import java.io.IOException; +import java.io.InputStream; /** * Result from a fetch @@ -56,6 +62,8 @@ public final class FetchSearchResult extends SearchPhaseResult { // client side counter private transient int counter; + private FetchSearchResultProto.FetchSearchResult fetchSearchResultProto; + public FetchSearchResult() {} public FetchSearchResult(StreamInput in) throws IOException { @@ -64,9 +72,24 @@ public FetchSearchResult(StreamInput in) throws IOException { hits = new SearchHits(in); } + public FetchSearchResult(InputStream in) throws IOException { + this.fetchSearchResultProto = FetchSearchResultProto.FetchSearchResult.parseFrom(in); + contextId = new ShardSearchContextId( + this.fetchSearchResultProto.getContextId().getSessionId(), + this.fetchSearchResultProto.getContextId().getId() + ); + SearchHitsProtobufDeserializer protobufSerializer = new SearchHitsProtobufDeserializer(); + hits = protobufSerializer.createSearchHits(new ByteArrayInputStream(this.fetchSearchResultProto.getHits().toByteArray())); + } + public FetchSearchResult(ShardSearchContextId id, SearchShardTarget shardTarget) { this.contextId = id; setSearchShardTarget(shardTarget); + this.fetchSearchResultProto = FetchSearchResultProto.FetchSearchResult.newBuilder() + .setContextId( + ShardSearchRequestProto.ShardSearchContextId.newBuilder().setSessionId(id.getSessionId()).setId(id.getId()).build() + ) + .build(); } @Override @@ -82,6 +105,11 @@ public FetchSearchResult fetchResult() { public void hits(SearchHits hits) { assert assertNoSearchTarget(hits); this.hits = hits; + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING) && this.fetchSearchResultProto != null) { + this.fetchSearchResultProto = this.fetchSearchResultProto.toBuilder() + .setHits(SearchHitsProtobufDeserializer.convertHitsToProto(hits)) + .build(); + } } private boolean assertNoSearchTarget(SearchHits hits) { @@ -92,6 +120,16 @@ private boolean assertNoSearchTarget(SearchHits hits) { } public SearchHits hits() { + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING) && this.fetchSearchResultProto != null) { + SearchHits hits; + try { + SearchHitsProtobufDeserializer protobufSerializer = new SearchHitsProtobufDeserializer(); + hits = protobufSerializer.createSearchHits(new ByteArrayInputStream(this.fetchSearchResultProto.getHits().toByteArray())); + return hits; + } catch (IOException e) { + throw new RuntimeException(e); + } + } return hits; } diff --git a/server/src/main/java/org/opensearch/search/fetch/QueryFetchSearchResult.java b/server/src/main/java/org/opensearch/search/fetch/QueryFetchSearchResult.java index ce4c59fc77489..cd1b713fc6b56 100644 --- a/server/src/main/java/org/opensearch/search/fetch/QueryFetchSearchResult.java +++ b/server/src/main/java/org/opensearch/search/fetch/QueryFetchSearchResult.java @@ -38,8 +38,10 @@ import org.opensearch.search.SearchShardTarget; import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.server.proto.QueryFetchSearchResultProto; import java.io.IOException; +import java.io.InputStream; /** * Query fetch result @@ -51,12 +53,20 @@ public final class QueryFetchSearchResult extends SearchPhaseResult { private final QuerySearchResult queryResult; private final FetchSearchResult fetchResult; + private QueryFetchSearchResultProto.QueryFetchSearchResult queryFetchSearchResultProto; + public QueryFetchSearchResult(StreamInput in) throws IOException { super(in); queryResult = new QuerySearchResult(in); fetchResult = new FetchSearchResult(in); } + public QueryFetchSearchResult(InputStream in) throws IOException { + this.queryFetchSearchResultProto = QueryFetchSearchResultProto.QueryFetchSearchResult.parseFrom(in); + queryResult = new QuerySearchResult(in); + fetchResult = new FetchSearchResult(in); + } + public QueryFetchSearchResult(QuerySearchResult queryResult, FetchSearchResult fetchResult) { this.queryResult = queryResult; this.fetchResult = fetchResult; diff --git a/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/HighlightFieldDeserializer.java b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/HighlightFieldDeserializer.java new file mode 100644 index 0000000000000..8ede74986ff9f --- /dev/null +++ b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/HighlightFieldDeserializer.java @@ -0,0 +1,21 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.fetch.subphase.highlight.serializer; + +import org.opensearch.search.fetch.subphase.highlight.HighlightField; + +import java.io.IOException; + +/** + * Deserializer for {@link HighlightField} which can be implemented for different types of serde mechanisms. + */ +public interface HighlightFieldDeserializer { + + HighlightField createHighLightField(T inputStream) throws IOException; +} diff --git a/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/package-info.java b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/package-info.java new file mode 100644 index 0000000000000..dc08282a8954f --- /dev/null +++ b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/package-info.java @@ -0,0 +1,10 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** Serializer package for highlights. */ +package org.opensearch.search.fetch.subphase.highlight.serializer; diff --git a/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/protobuf/HighlightFieldProtobufDeserializer.java b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/protobuf/HighlightFieldProtobufDeserializer.java new file mode 100644 index 0000000000000..d3023a7342f7d --- /dev/null +++ b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/protobuf/HighlightFieldProtobufDeserializer.java @@ -0,0 +1,43 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.fetch.subphase.highlight.serializer.protobuf; + +import org.opensearch.core.common.text.Text; +import org.opensearch.search.fetch.subphase.highlight.HighlightField; +import org.opensearch.search.fetch.subphase.highlight.serializer.HighlightFieldDeserializer; +import org.opensearch.server.proto.FetchSearchResultProto; + +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.List; + +/** + * Deserializer for {@link HighlightField} to/from protobuf. + */ +public class HighlightFieldProtobufDeserializer implements HighlightFieldDeserializer { + + @Override + public HighlightField createHighLightField(InputStream inputStream) throws IOException { + FetchSearchResultProto.SearchHit.HighlightField highlightField = FetchSearchResultProto.SearchHit.HighlightField.parseFrom( + inputStream + ); + String name = highlightField.getName(); + Text[] fragments = Text.EMPTY_ARRAY; + if (highlightField.getFragmentsCount() > 0) { + List values = new ArrayList<>(); + for (String fragment : highlightField.getFragmentsList()) { + values.add(new Text(fragment)); + } + fragments = values.toArray(new Text[0]); + } + return new HighlightField(name, fragments); + } + +} diff --git a/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/protobuf/package-info.java b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/protobuf/package-info.java new file mode 100644 index 0000000000000..796383b48afda --- /dev/null +++ b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/serializer/protobuf/package-info.java @@ -0,0 +1,10 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** Protobuf Serializer package for highlights. */ +package org.opensearch.search.fetch.subphase.highlight.serializer.protobuf; diff --git a/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java b/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java index f3ac953ab9d1d..453d4aa9ad3ef 100644 --- a/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java +++ b/server/src/main/java/org/opensearch/search/query/QuerySearchResult.java @@ -33,10 +33,14 @@ package org.opensearch.search.query; import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.TotalHits.Relation; import org.opensearch.common.annotation.PublicApi; import org.opensearch.common.io.stream.DelayableWriteable; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.search.DocValueFormat; @@ -50,8 +54,14 @@ import org.opensearch.search.profile.NetworkTime; import org.opensearch.search.profile.ProfileShardResult; import org.opensearch.search.suggest.Suggest; +import org.opensearch.server.proto.QuerySearchResultProto; +import org.opensearch.server.proto.ShardSearchRequestProto; +import org.opensearch.server.proto.ShardSearchRequestProto.AliasFilter; +import org.opensearch.server.proto.ShardSearchRequestProto.ShardSearchRequest.SearchType; import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import static org.opensearch.common.lucene.Lucene.readTopDocs; import static org.opensearch.common.lucene.Lucene.writeTopDocs; @@ -90,6 +100,8 @@ public final class QuerySearchResult extends SearchPhaseResult { private final boolean isNull; + private QuerySearchResultProto.QuerySearchResult querySearchResultProto; + public QuerySearchResult() { this(false); } @@ -103,11 +115,89 @@ public QuerySearchResult(StreamInput in) throws IOException { } } + public QuerySearchResult(InputStream in) throws IOException { + this.querySearchResultProto = QuerySearchResultProto.QuerySearchResult.parseFrom(in); + isNull = this.querySearchResultProto.getIsNull(); + if (!isNull) { + this.contextId = new ShardSearchContextId( + this.querySearchResultProto.getContextId().getSessionId(), + this.querySearchResultProto.getContextId().getId() + ); + } + } + public QuerySearchResult(ShardSearchContextId contextId, SearchShardTarget shardTarget, ShardSearchRequest shardSearchRequest) { this.contextId = contextId; setSearchShardTarget(shardTarget); isNull = false; setShardSearchRequest(shardSearchRequest); + + ShardSearchRequestProto.ShardId shardIdProto = ShardSearchRequestProto.ShardId.newBuilder() + .setShardId(shardTarget.getShardId().getId()) + .setHashCode(shardTarget.getShardId().hashCode()) + .setIndexName(shardTarget.getShardId().getIndexName()) + .setIndexUUID(shardTarget.getShardId().getIndex().getUUID()) + .build(); + QuerySearchResultProto.SearchShardTarget.Builder searchShardTarget = QuerySearchResultProto.SearchShardTarget.newBuilder() + .setNodeId(shardTarget.getNodeId()) + .setShardId(shardIdProto); + ShardSearchRequestProto.ShardSearchContextId shardSearchContextId = ShardSearchRequestProto.ShardSearchContextId.newBuilder() + .setSessionId(contextId.getSessionId()) + .setId(contextId.getId()) + .build(); + ShardSearchRequestProto.ShardSearchRequest.Builder shardSearchRequestProto = ShardSearchRequestProto.ShardSearchRequest + .newBuilder(); + if (shardSearchRequest != null) { + ShardSearchRequestProto.OriginalIndices.Builder originalIndices = ShardSearchRequestProto.OriginalIndices.newBuilder(); + if (shardSearchRequest.indices() != null) { + for (String index : shardSearchRequest.indices()) { + originalIndices.addIndices(index); + } + originalIndices.setIndicesOptions( + ShardSearchRequestProto.OriginalIndices.IndicesOptions.newBuilder() + .setIgnoreUnavailable(shardSearchRequest.indicesOptions().ignoreUnavailable()) + .setAllowNoIndices(shardSearchRequest.indicesOptions().allowNoIndices()) + .setExpandWildcardsOpen(shardSearchRequest.indicesOptions().expandWildcardsOpen()) + .setExpandWildcardsClosed(shardSearchRequest.indicesOptions().expandWildcardsClosed()) + .setExpandWildcardsHidden(shardSearchRequest.indicesOptions().expandWildcardsHidden()) + .setAllowAliasesToMultipleIndices(shardSearchRequest.indicesOptions().allowAliasesToMultipleIndices()) + .setForbidClosedIndices(shardSearchRequest.indicesOptions().forbidClosedIndices()) + .setIgnoreAliases(shardSearchRequest.indicesOptions().ignoreAliases()) + .setIgnoreThrottled(shardSearchRequest.indicesOptions().ignoreThrottled()) + .build() + ); + } + AliasFilter.Builder aliasFilter = AliasFilter.newBuilder(); + if (shardSearchRequest.getAliasFilter() != null) { + for (int i = 0; i < shardSearchRequest.getAliasFilter().getAliases().length; i++) { + aliasFilter.addAliases(shardSearchRequest.getAliasFilter().getAliases()[i]); + } + } + shardSearchRequestProto.setInboundNetworkTime(shardSearchRequest.getInboundNetworkTime()) + .setOutboundNetworkTime(shardSearchRequest.getOutboundNetworkTime()) + .setShardId(shardIdProto) + .setAllowPartialSearchResults(shardSearchRequest.allowPartialSearchResults()) + .setNumberOfShards(shardSearchRequest.numberOfShards()) + .setReaderId(shardSearchContextId) + .setOriginalIndices(originalIndices) + .setSearchType(SearchType.QUERY_THEN_FETCH) + .setAliasFilter(aliasFilter); + if (shardSearchRequest.keepAlive() != null) { + shardSearchRequestProto.setTimeValue(shardSearchRequest.keepAlive().getStringRep()); + } + } + + if (shardTarget.getClusterAlias() != null) { + searchShardTarget.setClusterAlias(shardTarget.getClusterAlias()); + } + + this.querySearchResultProto = QuerySearchResultProto.QuerySearchResult.newBuilder() + .setContextId(shardSearchContextId) + .setSearchShardTarget(searchShardTarget.build()) + .setSearchShardRequest(shardSearchRequestProto.build()) + .setHasAggs(false) + .setIsNull(isNull) + .build(); } private QuerySearchResult(boolean isNull) { @@ -157,9 +247,33 @@ public Boolean terminatedEarly() { } public TopDocsAndMaxScore topDocs() { - if (topDocsAndMaxScore == null) { + if (topDocsAndMaxScore == null && this.querySearchResultProto.getTopDocsAndMaxScore() == null) { throw new IllegalStateException("topDocs already consumed"); } + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) { + ScoreDoc[] scoreDocs = new ScoreDoc[this.querySearchResultProto.getTopDocsAndMaxScore().getTopDocs().getScoreDocsCount()]; + for (int i = 0; i < scoreDocs.length; i++) { + org.opensearch.server.proto.QuerySearchResultProto.QuerySearchResult.TopDocs.ScoreDoc scoreDoc = this.querySearchResultProto + .getTopDocsAndMaxScore() + .getTopDocs() + .getScoreDocsList() + .get(i); + scoreDocs[i] = new ScoreDoc(scoreDoc.getDoc(), scoreDoc.getScore(), scoreDoc.getShardIndex()); + } + TopDocs topDocsFromProtobuf = new TopDocs( + new TotalHits( + this.querySearchResultProto.getTotalHits().getValue(), + Relation.valueOf(this.querySearchResultProto.getTotalHits().getRelation().toString()) + ), + scoreDocs + ); + + TopDocsAndMaxScore topDocsFromProtobufAndMaxScore = new TopDocsAndMaxScore( + topDocsFromProtobuf, + this.querySearchResultProto.getMaxScore() + ); + return topDocsFromProtobufAndMaxScore; + } return topDocsAndMaxScore; } @@ -289,6 +403,9 @@ public void suggest(Suggest suggest) { } public int from() { + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) { + return this.querySearchResultProto.getFrom(); + } return from; } @@ -301,6 +418,9 @@ public QuerySearchResult from(int from) { * Returns the maximum size of this results top docs. */ public int size() { + if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING)) { + return this.querySearchResultProto.getSize(); + } return size; } @@ -377,6 +497,12 @@ public void writeTo(StreamOutput out) throws IOException { } } + public void writeTo(OutputStream out) throws IOException { + if (!isNull) { + out.write(this.querySearchResultProto.toByteArray()); + } + } + public void writeToNoId(StreamOutput out) throws IOException { out.writeVInt(from); out.writeVInt(size); @@ -417,4 +543,8 @@ public TotalHits getTotalHits() { public float getMaxScore() { return maxScore; } + + public QuerySearchResultProto.QuerySearchResult response() { + return this.querySearchResultProto; + } } diff --git a/server/src/main/java/org/opensearch/search/serializer/NestedIdentityDeserializer.java b/server/src/main/java/org/opensearch/search/serializer/NestedIdentityDeserializer.java new file mode 100644 index 0000000000000..f707921b81ca8 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/NestedIdentityDeserializer.java @@ -0,0 +1,21 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.serializer; + +import org.opensearch.search.SearchHit.NestedIdentity; + +import java.io.IOException; + +/** + * Deserializer for {@link NestedIdentity} which can be implemented for different types of serde mechanisms. + */ +public interface NestedIdentityDeserializer { + + public NestedIdentity createNestedIdentity(T inputStream) throws IOException; +} diff --git a/server/src/main/java/org/opensearch/search/serializer/SearchHitDeserializer.java b/server/src/main/java/org/opensearch/search/serializer/SearchHitDeserializer.java new file mode 100644 index 0000000000000..6dbf3a8fd4e05 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/SearchHitDeserializer.java @@ -0,0 +1,22 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.serializer; + +import org.opensearch.search.SearchHit; + +import java.io.IOException; + +/** + * Deserializer for {@link SearchHit} which can be implemented for different types of serde mechanisms. + */ +public interface SearchHitDeserializer { + + SearchHit createSearchHit(T inputStream) throws IOException; + +} diff --git a/server/src/main/java/org/opensearch/search/serializer/SearchHitsDeserializer.java b/server/src/main/java/org/opensearch/search/serializer/SearchHitsDeserializer.java new file mode 100644 index 0000000000000..6457eb1d29ad2 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/SearchHitsDeserializer.java @@ -0,0 +1,22 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.serializer; + +import org.opensearch.search.SearchHits; + +import java.io.IOException; + +/** + * Deserializer for {@link SearchHits} which can be implemented for different types of serde mechanisms. + */ +public interface SearchHitsDeserializer { + + SearchHits createSearchHits(T inputStream) throws IOException; + +} diff --git a/server/src/main/java/org/opensearch/search/serializer/SearchSortValuesDeserializer.java b/server/src/main/java/org/opensearch/search/serializer/SearchSortValuesDeserializer.java new file mode 100644 index 0000000000000..2aafe485b9d59 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/SearchSortValuesDeserializer.java @@ -0,0 +1,22 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.serializer; + +import org.opensearch.search.SearchSortValues; + +import java.io.IOException; + +/** + * Deserializer for {@link SearchSortValues} which can be implemented for different types of serde mechanisms. + */ +public interface SearchSortValuesDeserializer { + + SearchSortValues createSearchSortValues(T inputStream) throws IOException; + +} diff --git a/server/src/main/java/org/opensearch/search/serializer/package-info.java b/server/src/main/java/org/opensearch/search/serializer/package-info.java new file mode 100644 index 0000000000000..25a4d1935016e --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/package-info.java @@ -0,0 +1,10 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** Serializer package for search. */ +package org.opensearch.search.serializer; diff --git a/server/src/main/java/org/opensearch/search/serializer/protobuf/NestedIdentityProtobufDeserializer.java b/server/src/main/java/org/opensearch/search/serializer/protobuf/NestedIdentityProtobufDeserializer.java new file mode 100644 index 0000000000000..a7e26e8c0d1ae --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/protobuf/NestedIdentityProtobufDeserializer.java @@ -0,0 +1,48 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.serializer.protobuf; + +import org.opensearch.search.SearchHit.NestedIdentity; +import org.opensearch.search.serializer.NestedIdentityDeserializer; +import org.opensearch.server.proto.FetchSearchResultProto; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; + +/** + * Deserializer for {@link NestedIdentity} to/from protobuf. + */ +public class NestedIdentityProtobufDeserializer implements NestedIdentityDeserializer { + + @Override + public NestedIdentity createNestedIdentity(InputStream inputStream) throws IOException { + FetchSearchResultProto.SearchHit.NestedIdentity proto = FetchSearchResultProto.SearchHit.NestedIdentity.parseFrom(inputStream); + String field; + int offset; + NestedIdentity child; + if (proto.hasField()) { + field = proto.getField(); + } else { + field = null; + } + if (proto.hasOffset()) { + offset = proto.getOffset(); + } else { + offset = -1; + } + if (proto.hasChild()) { + child = createNestedIdentity(new ByteArrayInputStream(proto.getChild().toByteArray())); + } else { + child = null; + } + return new NestedIdentity(field, offset, child); + } + +} diff --git a/server/src/main/java/org/opensearch/search/serializer/protobuf/SearchHitProtobufDeserializer.java b/server/src/main/java/org/opensearch/search/serializer/protobuf/SearchHitProtobufDeserializer.java new file mode 100644 index 0000000000000..29bc523473c5d --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/protobuf/SearchHitProtobufDeserializer.java @@ -0,0 +1,202 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.serializer.protobuf; + +import com.google.protobuf.ByteString; +import org.apache.lucene.search.Explanation; +import org.opensearch.OpenSearchParseException; +import org.opensearch.action.OriginalIndices; +import org.opensearch.common.document.DocumentField; +import org.opensearch.common.document.serializer.protobuf.DocumentFieldProtobufDeserializer; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHit.NestedIdentity; +import org.opensearch.search.SearchHits; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.SearchSortValues; +import org.opensearch.search.fetch.subphase.highlight.HighlightField; +import org.opensearch.search.fetch.subphase.highlight.serializer.protobuf.HighlightFieldProtobufDeserializer; +import org.opensearch.search.serializer.SearchHitDeserializer; +import org.opensearch.server.proto.FetchSearchResultProto; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Deserializer for {@link SearchHit} to/from protobuf. + */ +public class SearchHitProtobufDeserializer implements SearchHitDeserializer { + + private FetchSearchResultProto.SearchHit searchHitProto; + + @Override + public SearchHit createSearchHit(InputStream inputStream) throws IOException { + this.searchHitProto = FetchSearchResultProto.SearchHit.parseFrom(inputStream); + int docId = -1; + float score = this.searchHitProto.getScore(); + String id = this.searchHitProto.getId(); + NestedIdentity nestedIdentity; + if (!this.searchHitProto.hasNestedIdentity() && this.searchHitProto.getNestedIdentity().toByteArray().length > 0) { + NestedIdentityProtobufDeserializer protobufSerializer = new NestedIdentityProtobufDeserializer(); + nestedIdentity = protobufSerializer.createNestedIdentity( + new ByteArrayInputStream(this.searchHitProto.getNestedIdentity().toByteArray()) + ); + } else { + nestedIdentity = null; + } + long version = this.searchHitProto.getVersion(); + long seqNo = this.searchHitProto.getSeqNo(); + long primaryTerm = this.searchHitProto.getPrimaryTerm(); + BytesReference source = BytesReference.fromByteBuffer(ByteBuffer.wrap(this.searchHitProto.getSource().toByteArray())); + if (source.length() == 0) { + source = null; + } + Map documentFields = new HashMap<>(); + DocumentFieldProtobufDeserializer protobufSerializer = new DocumentFieldProtobufDeserializer(); + this.searchHitProto.getDocumentFieldsMap().forEach((k, v) -> { + try { + documentFields.put(k, protobufSerializer.createDocumentField(new ByteArrayInputStream(v.toByteArray()))); + } catch (IOException e) { + throw new OpenSearchParseException("failed to parse document field", e); + } + }); + Map metaFields = new HashMap<>(); + this.searchHitProto.getMetaFieldsMap().forEach((k, v) -> { + try { + metaFields.put(k, protobufSerializer.createDocumentField(new ByteArrayInputStream(v.toByteArray()))); + } catch (IOException e) { + throw new OpenSearchParseException("failed to parse document field", e); + } + }); + Map highlightFields = new HashMap<>(); + HighlightFieldProtobufDeserializer highlightFieldProtobufSerializer = new HighlightFieldProtobufDeserializer(); + this.searchHitProto.getHighlightFieldsMap().forEach((k, v) -> { + try { + highlightFields.put(k, highlightFieldProtobufSerializer.createHighLightField(new ByteArrayInputStream(v.toByteArray()))); + } catch (IOException e) { + throw new OpenSearchParseException("failed to parse highlight field", e); + } + }); + SearchSortValuesProtobufDeserializer sortValueProtobufSerializer = new SearchSortValuesProtobufDeserializer(); + SearchSortValues sortValues = sortValueProtobufSerializer.createSearchSortValues( + new ByteArrayInputStream(this.searchHitProto.getSortValues().toByteArray()) + ); + Map matchedQueries = new HashMap<>(); + if (this.searchHitProto.getMatchedQueriesCount() > 0) { + matchedQueries = new LinkedHashMap<>(this.searchHitProto.getMatchedQueriesCount()); + for (String query : this.searchHitProto.getMatchedQueriesList()) { + matchedQueries.put(query, Float.NaN); + } + } + if (this.searchHitProto.getMatchedQueriesWithScoresCount() > 0) { + Map tempMap = this.searchHitProto.getMatchedQueriesWithScoresMap() + .entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue())); + matchedQueries = tempMap.entrySet() + .stream() + .sorted(Map.Entry.comparingByKey()) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (oldValue, newValue) -> oldValue, LinkedHashMap::new)); + } + Explanation explanation = null; + if (this.searchHitProto.hasExplanation()) { + explanation = readExplanation(this.searchHitProto.getExplanation().toByteArray()); + } + SearchShardTarget searchShardTarget = new SearchShardTarget( + this.searchHitProto.getShard().getNodeId(), + new ShardId( + this.searchHitProto.getShard().getShardId().getIndexName(), + this.searchHitProto.getShard().getShardId().getIndexUUID(), + this.searchHitProto.getShard().getShardId().getShardId() + ), + this.searchHitProto.getShard().getClusterAlias(), + OriginalIndices.NONE + ); + Map innerHits; + if (this.searchHitProto.getInnerHitsCount() > 0) { + innerHits = new HashMap<>(); + this.searchHitProto.getInnerHitsMap().forEach((k, v) -> { + try { + SearchHitsProtobufDeserializer protobufHitsFactory = new SearchHitsProtobufDeserializer(); + innerHits.put(k, protobufHitsFactory.createSearchHits(new ByteArrayInputStream(v.toByteArray()))); + } catch (IOException e) { + throw new OpenSearchParseException("failed to parse inner hits", e); + } + }); + } else { + innerHits = null; + } + SearchHit searchHit = new SearchHit(docId, id, nestedIdentity, documentFields, metaFields); + searchHit.score(score); + searchHit.version(version); + searchHit.setSeqNo(seqNo); + searchHit.setPrimaryTerm(primaryTerm); + searchHit.sourceRef(source); + searchHit.highlightFields(highlightFields); + searchHit.sortValues(sortValues); + searchHit.matchedQueriesWithScores(matchedQueries); + searchHit.explanation(explanation); + searchHit.shard(searchShardTarget); + searchHit.setInnerHits(innerHits); + return searchHit; + } + + public static FetchSearchResultProto.SearchHit convertHitToProto(SearchHit hit) { + FetchSearchResultProto.SearchHit.Builder searchHitBuilder = FetchSearchResultProto.SearchHit.newBuilder(); + if (hit.getIndex() != null) { + searchHitBuilder.setIndex(hit.getIndex()); + } + searchHitBuilder.setId(hit.getId()); + searchHitBuilder.setScore(hit.getScore()); + searchHitBuilder.setSeqNo(hit.getSeqNo()); + searchHitBuilder.setPrimaryTerm(hit.getPrimaryTerm()); + searchHitBuilder.setVersion(hit.getVersion()); + searchHitBuilder.setDocId(hit.docId()); + if (hit.getSourceRef() != null) { + searchHitBuilder.setSource(ByteString.copyFrom(hit.getSourceRef().toBytesRef().bytes)); + } + for (Map.Entry entry : hit.getFields().entrySet()) { + searchHitBuilder.putDocumentFields( + entry.getKey(), + DocumentFieldProtobufDeserializer.convertDocumentFieldToProto(entry.getValue()) + ); + } + return searchHitBuilder.build(); + } + + public Explanation readExplanation(byte[] in) throws IOException { + FetchSearchResultProto.SearchHit.Explanation explanationProto = FetchSearchResultProto.SearchHit.Explanation.parseFrom(in); + boolean match = explanationProto.getMatch(); + String description = explanationProto.getDescription(); + final Explanation[] subExplanations = new Explanation[explanationProto.getSubExplanationsCount()]; + for (int i = 0; i < subExplanations.length; ++i) { + subExplanations[i] = readExplanation(explanationProto.getSubExplanations(i).toByteArray()); + } + Number explanationValue = null; + if (explanationProto.hasValue1()) { + explanationValue = explanationProto.getValue1(); + } else if (explanationProto.hasValue2()) { + explanationValue = explanationProto.getValue2(); + } else if (explanationProto.hasValue3()) { + explanationValue = explanationProto.getValue3(); + } + if (match) { + return Explanation.match(explanationValue, description, subExplanations); + } else { + return Explanation.noMatch(description, subExplanations); + } + } +} diff --git a/server/src/main/java/org/opensearch/search/serializer/protobuf/SearchHitsProtobufDeserializer.java b/server/src/main/java/org/opensearch/search/serializer/protobuf/SearchHitsProtobufDeserializer.java new file mode 100644 index 0000000000000..94ee61aa58621 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/protobuf/SearchHitsProtobufDeserializer.java @@ -0,0 +1,156 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.serializer.protobuf; + +import com.google.protobuf.ByteString; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.TotalHits.Relation; +import org.apache.lucene.util.BytesRef; +import org.opensearch.OpenSearchException; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.serializer.SearchHitsDeserializer; +import org.opensearch.server.proto.FetchSearchResultProto; +import org.opensearch.server.proto.QuerySearchResultProto; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.List; + +/** + * Deserializer for {@link SearchHits} to/from protobuf. + */ +public class SearchHitsProtobufDeserializer implements SearchHitsDeserializer { + + private FetchSearchResultProto.SearchHits searchHitsProto; + + @Override + public SearchHits createSearchHits(InputStream inputStream) throws IOException { + this.searchHitsProto = FetchSearchResultProto.SearchHits.parseFrom(inputStream); + SearchHit[] hits = new SearchHit[this.searchHitsProto.getHitsCount()]; + SearchHitProtobufDeserializer protobufSerializer = new SearchHitProtobufDeserializer(); + for (int i = 0; i < this.searchHitsProto.getHitsCount(); i++) { + hits[i] = protobufSerializer.createSearchHit(new ByteArrayInputStream(this.searchHitsProto.getHits(i).toByteArray())); + } + TotalHits totalHits = new TotalHits( + this.searchHitsProto.getTotalHits().getValue(), + Relation.valueOf(this.searchHitsProto.getTotalHits().getRelation().toString()) + ); + float maxScore = this.searchHitsProto.getMaxScore(); + SortField[] sortFields = this.searchHitsProto.getSortFieldsList() + .stream() + .map(sortField -> new SortField(sortField.getField(), SortField.Type.valueOf(sortField.getType().toString()))) + .toArray(SortField[]::new); + String collapseField = this.searchHitsProto.getCollapseField(); + Object[] collapseValues = new Object[this.searchHitsProto.getCollapseValuesCount()]; + for (int i = 0; i < this.searchHitsProto.getCollapseValuesCount(); i++) { + collapseValues[i] = readSortValueFromProtobuf(this.searchHitsProto.getCollapseValues(i)); + } + return new SearchHits(hits, totalHits, maxScore, sortFields, collapseField, collapseValues); + } + + public static Object readSortValueFromProtobuf(FetchSearchResultProto.SortValue collapseValue) throws IOException { + if (collapseValue.hasCollapseString()) { + return collapseValue.getCollapseString(); + } else if (collapseValue.hasCollapseInt()) { + return collapseValue.getCollapseInt(); + } else if (collapseValue.hasCollapseLong()) { + return collapseValue.getCollapseLong(); + } else if (collapseValue.hasCollapseFloat()) { + return collapseValue.getCollapseFloat(); + } else if (collapseValue.hasCollapseDouble()) { + return collapseValue.getCollapseDouble(); + } else if (collapseValue.hasCollapseBytes()) { + return new BytesRef(collapseValue.getCollapseBytes().toByteArray()); + } else if (collapseValue.hasCollapseBool()) { + return collapseValue.getCollapseBool(); + } else { + throw new IOException("Can't handle sort field value of type [" + collapseValue + "]"); + } + } + + public static FetchSearchResultProto.SearchHits convertHitsToProto(SearchHits hits) { + List searchHitList = new ArrayList<>(); + for (SearchHit hit : hits) { + searchHitList.add(SearchHitProtobufDeserializer.convertHitToProto(hit)); + } + QuerySearchResultProto.TotalHits.Builder totalHitsBuilder = QuerySearchResultProto.TotalHits.newBuilder(); + if (hits.getTotalHits() != null) { + totalHitsBuilder.setValue(hits.getTotalHits().value); + totalHitsBuilder.setRelation( + hits.getTotalHits().relation == Relation.EQUAL_TO + ? QuerySearchResultProto.TotalHits.Relation.EQUAL_TO + : QuerySearchResultProto.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + ); + } + FetchSearchResultProto.SearchHits.Builder searchHitsBuilder = FetchSearchResultProto.SearchHits.newBuilder(); + searchHitsBuilder.setMaxScore(hits.getMaxScore()); + searchHitsBuilder.addAllHits(searchHitList); + searchHitsBuilder.setTotalHits(totalHitsBuilder.build()); + if (hits.getSortFields() != null && hits.getSortFields().length > 0) { + for (SortField sortField : hits.getSortFields()) { + FetchSearchResultProto.SortField.Builder sortFieldBuilder = FetchSearchResultProto.SortField.newBuilder(); + if (sortField.getField() != null) { + sortFieldBuilder.setField(sortField.getField()); + } + sortFieldBuilder.setType(FetchSearchResultProto.SortField.Type.valueOf(sortField.getType().name())); + searchHitsBuilder.addSortFields(sortFieldBuilder.build()); + } + } + if (hits.getCollapseField() != null) { + searchHitsBuilder.setCollapseField(hits.getCollapseField()); + for (Object value : hits.getCollapseValues()) { + FetchSearchResultProto.SortValue.Builder collapseValueBuilder = FetchSearchResultProto.SortValue.newBuilder(); + try { + collapseValueBuilder = readSortValueForProtobuf(value, collapseValueBuilder); + } catch (IOException e) { + throw new OpenSearchException(e); + } + searchHitsBuilder.addCollapseValues(collapseValueBuilder.build()); + } + } + return searchHitsBuilder.build(); + } + + public static FetchSearchResultProto.SortValue.Builder readSortValueForProtobuf( + Object collapseValue, + FetchSearchResultProto.SortValue.Builder collapseValueBuilder + ) throws IOException { + Class type = collapseValue.getClass(); + if (type == String.class) { + collapseValueBuilder.setCollapseString((String) collapseValue); + } else if (type == Integer.class || type == Short.class) { + collapseValueBuilder.setCollapseInt((Integer) collapseValue); + } else if (type == Long.class) { + collapseValueBuilder.setCollapseLong((Long) collapseValue); + } else if (type == Float.class) { + collapseValueBuilder.setCollapseFloat((Float) collapseValue); + } else if (type == Double.class) { + collapseValueBuilder.setCollapseDouble((Double) collapseValue); + } else if (type == Byte.class) { + byte b = (Byte) collapseValue; + collapseValueBuilder.setCollapseBytes(ByteString.copyFrom(new byte[] { b })); + } else if (type == Boolean.class) { + collapseValueBuilder.setCollapseBool((Boolean) collapseValue); + } else if (type == BytesRef.class) { + collapseValueBuilder.setCollapseBytes(ByteString.copyFrom(((BytesRef) collapseValue).bytes)); + } else if (type == BigInteger.class) { + BigInteger bigInt = (BigInteger) collapseValue; + collapseValueBuilder.setCollapseBytes(ByteString.copyFrom(bigInt.toByteArray())); + } else { + throw new IOException("Can't handle sort field value of type [" + type + "]"); + } + return collapseValueBuilder; + } + +} diff --git a/server/src/main/java/org/opensearch/search/serializer/protobuf/SearchSortValuesProtobufDeserializer.java b/server/src/main/java/org/opensearch/search/serializer/protobuf/SearchSortValuesProtobufDeserializer.java new file mode 100644 index 0000000000000..63ac21a0e2fb9 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/protobuf/SearchSortValuesProtobufDeserializer.java @@ -0,0 +1,39 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.serializer.protobuf; + +import org.opensearch.search.SearchSortValues; +import org.opensearch.search.serializer.SearchSortValuesDeserializer; +import org.opensearch.server.proto.FetchSearchResultProto; + +import java.io.IOException; +import java.io.InputStream; + +/** + * Deserializer for {@link SearchSortValues} to/from protobuf. + */ +public class SearchSortValuesProtobufDeserializer implements SearchSortValuesDeserializer { + + @Override + public SearchSortValues createSearchSortValues(InputStream inputStream) throws IOException { + FetchSearchResultProto.SearchHit.SearchSortValues searchSortValues = FetchSearchResultProto.SearchHit.SearchSortValues.parseFrom( + inputStream + ); + Object[] formattedSortValues = new Object[searchSortValues.getFormattedSortValuesCount()]; + for (int i = 0; i < searchSortValues.getFormattedSortValuesCount(); i++) { + formattedSortValues[i] = SearchHitsProtobufDeserializer.readSortValueFromProtobuf(searchSortValues.getFormattedSortValues(i)); + } + Object[] rawSortValues = new Object[searchSortValues.getRawSortValuesCount()]; + for (int i = 0; i < searchSortValues.getRawSortValuesCount(); i++) { + rawSortValues[i] = SearchHitsProtobufDeserializer.readSortValueFromProtobuf(searchSortValues.getRawSortValues(i)); + } + return new SearchSortValues(formattedSortValues, rawSortValues); + } + +} diff --git a/server/src/main/java/org/opensearch/search/serializer/protobuf/package-info.java b/server/src/main/java/org/opensearch/search/serializer/protobuf/package-info.java new file mode 100644 index 0000000000000..6a1a564bf2664 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/serializer/protobuf/package-info.java @@ -0,0 +1,10 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** Protobuf serializer package for search. */ +package org.opensearch.search.serializer.protobuf; diff --git a/server/src/main/proto/server/search/FetchSearchResultProto.proto b/server/src/main/proto/server/search/FetchSearchResultProto.proto new file mode 100644 index 0000000000000..4983cdbd4ef4a --- /dev/null +++ b/server/src/main/proto/server/search/FetchSearchResultProto.proto @@ -0,0 +1,129 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +syntax = "proto3"; +package org.opensearch.server.proto; + +import "google/protobuf/any.proto"; +import "server/search/QuerySearchResultProto.proto"; +import "server/search/ShardSearchRequestProto.proto"; + +option java_outer_classname = "FetchSearchResultProto"; + +message FetchSearchResult { + ShardSearchContextId contextId = 1; + optional SearchHits hits = 2; +} + +message SearchHits { + TotalHits totalHits = 1; + float maxScore = 2; + int32 size = 3; + repeated SearchHit hits = 4; + repeated SortField sortFields = 5; + optional string collapseField = 6; + repeated SortValue collapseValues = 7; +} + +message SearchHit { + int32 docId = 1; + float score = 2; + string id = 3; + optional NestedIdentity nestedIdentity = 4; + int64 version = 5; + int64 seqNo = 6; + int64 primaryTerm = 7; + bytes source = 8; + map documentFields = 9; + map metaFields = 10; + map highlightFields = 11; + SearchSortValues sortValues = 12; + repeated string matchedQueries = 13; + optional Explanation explanation = 14; + SearchShardTarget shard = 15; + optional string index = 16; + optional string clusterAlias = 17; + map innerHits = 18; + map matchedQueriesWithScores = 19; + + message NestedIdentity { + optional string field = 1; + optional int32 offset = 2; + optional NestedIdentity child = 3; + } + + message DocumentField { + string name = 1; + repeated DocumentFieldValue values = 2; + } + + message HighlightField { + string name = 1; + repeated string fragments = 2; + } + + message SearchSortValues { + repeated SortValue formattedSortValues = 1; + repeated SortValue rawSortValues = 2; + } + + message Explanation { + bool match = 1; + string description = 2; + repeated Explanation subExplanations = 3; + oneof explanationValue { + float value1 = 4; + double value2 = 5; + int64 value3 = 6; + } + } +} + +message SortField { + Type type = 1; + string field = 2; + + enum Type { + SCORE = 0; + DOC = 1; + STRING = 2; + INT = 3; + FLOAT = 4; + LONG = 5; + DOUBLE = 6; + CUSTOM = 7; + STRING_VAL = 8; + REWRITEABLE = 9; + } +} + +message SortValue { + optional string collapseString = 1; + optional int32 collapseInt = 2; + optional int64 collapseLong = 3; + optional float collapseFloat = 4; + optional double collapseDouble = 5; + optional bytes collapseBytes = 6; + optional bool collapseBool = 7; +} + +message DocumentFieldValue { + optional string valueString = 1; + optional int32 valueInt = 2; + optional int64 valueLong = 3; + optional float valueFloat = 4; + optional double valueDouble = 5; + optional bool valueBool = 6; + repeated bytes valueByteArray = 7; + repeated DocumentFieldValue valueArrayList = 8; + map valueMap = 9; + optional int64 valueDate = 10; + optional string valueZonedDate = 11; + optional int64 valueZonedTime = 12; + optional string valueText = 13; +} diff --git a/server/src/main/proto/server/search/QueryFetchSearchResultProto.proto b/server/src/main/proto/server/search/QueryFetchSearchResultProto.proto new file mode 100644 index 0000000000000..deac135c0e3d0 --- /dev/null +++ b/server/src/main/proto/server/search/QueryFetchSearchResultProto.proto @@ -0,0 +1,20 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +syntax = "proto3"; +package org.opensearch.server.proto; + +import "server/search/QuerySearchResultProto.proto"; +import "server/search/FetchSearchResultProto.proto"; + +option java_outer_classname = "QueryFetchSearchResultProto"; + +message QueryFetchSearchResult { + QuerySearchResult queryResult = 1; + FetchSearchResult fetchResult = 2; +} diff --git a/server/src/main/proto/server/search/QuerySearchResultProto.proto b/server/src/main/proto/server/search/QuerySearchResultProto.proto new file mode 100644 index 0000000000000..c09b2e02498e5 --- /dev/null +++ b/server/src/main/proto/server/search/QuerySearchResultProto.proto @@ -0,0 +1,76 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +syntax = "proto3"; +package org.opensearch.server.proto; + +import "server/search/ShardSearchRequestProto.proto"; + +option java_outer_classname = "QuerySearchResultProto"; + +message QuerySearchResult { + ShardSearchContextId contextId = 1; + optional int32 from = 2; + optional int32 size = 3; + optional TopDocsAndMaxScore topDocsAndMaxScore = 4; + optional bool hasScoreDocs = 5; + optional TotalHits totalHits = 6; + optional float maxScore = 7; + optional TopDocs topDocs = 8; + optional bool hasAggs = 9; + optional bool hasSuggest = 10; + optional bool searchTimedOut = 11; + optional bool terminatedEarly = 12; + optional bytes profileShardResults = 13; + optional int64 serviceTimeEWMA = 14; + optional int32 nodeQueueSize = 15; + SearchShardTarget searchShardTarget = 17; + ShardSearchRequest searchShardRequest = 18; + bool isNull = 19; + + message TopDocsAndMaxScore { + TopDocs topDocs = 1; + float maxScore = 2; + } + + message TopDocs { + TotalHits totalHits = 1; + repeated ScoreDoc scoreDocs = 2; + + message ScoreDoc { + int32 doc = 1; + float score = 2; + int32 shardIndex = 3; + } + } + + message RescoreDocIds { + map docIds = 1; + + message setInteger { + repeated int32 values = 1; + } + } + +} + +message SearchShardTarget { + string nodeId = 1; + ShardId shardId = 2; + optional string clusterAlias = 3; +} + +message TotalHits { + int64 value = 1; + Relation relation = 2; + + enum Relation { + EQUAL_TO = 0; + GREATER_THAN_OR_EQUAL_TO = 1; + } +} diff --git a/server/src/main/proto/server/search/ShardSearchRequestProto.proto b/server/src/main/proto/server/search/ShardSearchRequestProto.proto new file mode 100644 index 0000000000000..0705283761ecd --- /dev/null +++ b/server/src/main/proto/server/search/ShardSearchRequestProto.proto @@ -0,0 +1,76 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +syntax = "proto3"; +package org.opensearch.server.proto; + +option java_outer_classname = "ShardSearchRequestProto"; + +message ShardSearchRequest { + OriginalIndices originalIndices = 1; + ShardId shardId = 2; + int32 numberOfShards = 3; + SearchType searchType = 4; + bytes source = 5; + bool requestCache = 6; + AliasFilter aliasFilter = 7; + float indexBoost = 8; + bool allowPartialSearchResults = 9; + repeated string indexRoutings = 10; + string preference = 11; + Scroll scroll = 12; + int64 nowInMillis = 13; + optional string clusterAlias = 14; + optional ShardSearchContextId readerId = 15; + optional string timeValue = 16; + int64 inboundNetworkTime = 17; + int64 outboundNetworkTime = 18; + bool canReturnNullResponseIfMatchNoDocs = 19; + + enum SearchType { + QUERY_THEN_FETCH = 0; + DFS_QUERY_THEN_FETCH = 1; + } +} + +message ShardSearchContextId { + string sessionId = 1; + int64 id = 2; +} + +message ShardId { + int32 shardId = 1; + int32 hashCode = 2; + string indexName = 3; + string indexUUID = 4; +} + +message Scroll { + string keepAlive = 1; +} + +message OriginalIndices { + repeated string indices = 1; + IndicesOptions indicesOptions = 2; + + message IndicesOptions { + bool ignoreUnavailable = 1; + bool allowNoIndices = 2; + bool expandWildcardsOpen = 3; + bool expandWildcardsClosed = 4; + bool expandWildcardsHidden = 5; + bool allowAliasesToMultipleIndices = 6; + bool forbidClosedIndices = 7; + bool ignoreAliases = 8; + bool ignoreThrottled = 9; + } +} + +message AliasFilter { + repeated string aliases = 1; +} diff --git a/server/src/test/java/org/opensearch/search/query/QuerySearchResultTests.java b/server/src/test/java/org/opensearch/search/query/QuerySearchResultTests.java index 41e4e1ae45a73..44b23e8bec2b3 100644 --- a/server/src/test/java/org/opensearch/search/query/QuerySearchResultTests.java +++ b/server/src/test/java/org/opensearch/search/query/QuerySearchResultTests.java @@ -39,9 +39,11 @@ import org.opensearch.action.OriginalIndices; import org.opensearch.action.OriginalIndicesTests; import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.SuppressForbidden; import org.opensearch.common.UUIDs; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.FeatureFlags; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.index.shard.ShardId; @@ -54,8 +56,13 @@ import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.internal.ShardSearchRequest; import org.opensearch.search.suggest.SuggestTests; +import org.opensearch.server.proto.QuerySearchResultProto; import org.opensearch.test.OpenSearchTestCase; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.InputStream; + import static java.util.Collections.emptyList; public class QuerySearchResultTests extends OpenSearchTestCase { @@ -120,6 +127,32 @@ public void testSerialization() throws Exception { assertEquals(querySearchResult.terminatedEarly(), deserialized.terminatedEarly()); } + @SuppressForbidden(reason = "manipulates system properties for testing") + public void testBytesSerialization() throws Exception { + System.setProperty(FeatureFlags.PROTOBUF, "true"); + QuerySearchResult querySearchResult = createTestInstance(); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + querySearchResult.writeTo(stream); + + InputStream inputStream = new ByteArrayInputStream(stream.toByteArray()); + QuerySearchResult deserialized = new QuerySearchResult(inputStream); + QuerySearchResultProto.QuerySearchResult querySearchResultProto = deserialized.response(); + assertNotNull(querySearchResultProto); + assertEquals(querySearchResult.getContextId().getId(), querySearchResultProto.getContextId().getId()); + assertEquals( + querySearchResult.getSearchShardTarget().getShardId().getIndex().getUUID(), + querySearchResultProto.getSearchShardTarget().getShardId().getIndexUUID() + ); + assertEquals(querySearchResult.topDocs().maxScore, querySearchResultProto.getTopDocsAndMaxScore().getMaxScore(), 0f); + assertEquals( + querySearchResult.topDocs().topDocs.totalHits.value, + querySearchResultProto.getTopDocsAndMaxScore().getTopDocs().getTotalHits().getValue() + ); + assertEquals(querySearchResult.from(), querySearchResultProto.getFrom()); + assertEquals(querySearchResult.size(), querySearchResultProto.getSize()); + System.setProperty(FeatureFlags.PROTOBUF, "false"); + } + public void testNullResponse() throws Exception { QuerySearchResult querySearchResult = QuerySearchResult.nullInstance(); QuerySearchResult deserialized = copyWriteable(querySearchResult, namedWriteableRegistry, QuerySearchResult::new, Version.CURRENT);