Skip to content

Commit

Permalink
Adding protos for search classes
Browse files Browse the repository at this point in the history
Signed-off-by: Vacha Shah <[email protected]>
  • Loading branch information
VachaShah committed Apr 16, 2024
1 parent 07d447b commit f0a62c5
Show file tree
Hide file tree
Showing 27 changed files with 1,349 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,15 @@ private boolean inspectable(ExecutableElement executable) {
*/
private boolean inspectable(Element element) {
final PackageElement pckg = processingEnv.getElementUtils().getPackageOf(element);
return pckg.getQualifiedName().toString().startsWith(OPENSEARCH_PACKAGE);
return pckg.getQualifiedName().toString().startsWith(OPENSEARCH_PACKAGE)
&& !element.getEnclosingElement()
.getAnnotationMirrors()
.stream()
.anyMatch(
m -> m.getAnnotationType()
.toString() /* ClassSymbol.toString() returns class name */
.equalsIgnoreCase("javax.annotation.Generated")
);
}

/**
Expand Down
27 changes: 27 additions & 0 deletions server/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,32 @@ tasks.named("dependencyLicenses").configure {
}
}

tasks.named("missingJavadoc").configure {
/*
* annotate_code in L210 does not add the Generated annotation to nested code generated using protobuf.
* TODO: Add support to missingJavadoc task to ignore all such nested classes.
* https://github.com/opensearch-project/OpenSearch/issues/11913
*/
dependsOn("generateProto")
javadocMissingIgnore = [
"org.opensearch.server.proto.QuerySearchResultProto.QuerySearchResult.RescoreDocIds.setIntegerOrBuilder",
"org.opensearch.server.proto.QuerySearchResultProto.QuerySearchResult.RescoreDocIdsOrBuilder",
"org.opensearch.server.proto.QuerySearchResultProto.QuerySearchResult.TopDocs.ScoreDocOrBuilder",
"org.opensearch.server.proto.QuerySearchResultProto.QuerySearchResult.TopDocsOrBuilder",
"org.opensearch.server.proto.QuerySearchResultProto.QuerySearchResult.TopDocsAndMaxScoreOrBuilder",
"org.opensearch.server.proto.FetchSearchResultProto.SearchHit.SearchSortValuesOrBuilder",
"org.opensearch.server.proto.FetchSearchResultProto.SearchHit.HighlightFieldOrBuilder",
"org.opensearch.server.proto.FetchSearchResultProto.SearchHit.DocumentFieldOrBuilder",
"org.opensearch.server.proto.FetchSearchResultProto.SearchHit.NestedIdentityOrBuilder",
"org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.MessageCase",
"org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.ResponseHandlersListOrBuilder",
"org.opensearch.server.proto.NodeToNodeMessageProto.NodeToNodeMessage.HeaderOrBuilder",
"org.opensearch.server.proto.FetchSearchResultProto.SearchHit.Explanation.ExplanationValueCase",
"org.opensearch.server.proto.FetchSearchResultProto.SearchHit.ExplanationOrBuilder",
"org.opensearch.server.proto.ShardSearchRequestProto.OriginalIndices.IndicesOptionsOrBuilder",
]
}

tasks.named("filepermissions").configure {
mustRunAfter("generateProto")
}
Expand All @@ -364,6 +390,7 @@ tasks.named("licenseHeaders").configure {
excludes << 'org/opensearch/client/documentation/placeholder.txt'
// Ignore for protobuf generated code
excludes << 'org/opensearch/extensions/proto/*'
excludes << 'org/opensearch/server/proto/*'
}

tasks.test {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.common.document.serializer;

import com.google.protobuf.ByteString;
import org.opensearch.OpenSearchException;
import org.opensearch.common.document.DocumentField;
import org.opensearch.core.common.text.Text;
import org.opensearch.server.proto.FetchSearchResultProto;
import org.opensearch.server.proto.FetchSearchResultProto.DocumentFieldValue;
import org.opensearch.server.proto.FetchSearchResultProto.DocumentFieldValue.Builder;

import java.io.IOException;
import java.io.InputStream;
import java.time.Instant;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
* Serializer for {@link DocumentField} to/from protobuf.
*/
public class DocumentFieldProtobufSerializer implements DocumentFieldSerializer<InputStream> {

private FetchSearchResultProto.SearchHit.DocumentField documentField;

@Override
public DocumentField createDocumentField(InputStream inputStream) throws IOException {
documentField = FetchSearchResultProto.SearchHit.DocumentField.parseFrom(inputStream);
String name = documentField.getName();
List<Object> values = new ArrayList<>();
for (FetchSearchResultProto.DocumentFieldValue value : documentField.getValuesList()) {
values.add(readDocumentFieldValueFromProtobuf(value));
}
return new DocumentField(name, values);
}

private Object readDocumentFieldValueFromProtobuf(FetchSearchResultProto.DocumentFieldValue documentFieldValue) throws IOException {
if (documentFieldValue.hasValueString()) {
return documentFieldValue.getValueString();
} else if (documentFieldValue.hasValueInt()) {
return documentFieldValue.getValueInt();
} else if (documentFieldValue.hasValueLong()) {
return documentFieldValue.getValueLong();
} else if (documentFieldValue.hasValueFloat()) {
return documentFieldValue.getValueFloat();
} else if (documentFieldValue.hasValueDouble()) {
return documentFieldValue.getValueDouble();
} else if (documentFieldValue.hasValueBool()) {
return documentFieldValue.getValueBool();
} else if (documentFieldValue.getValueByteArrayList().size() > 0) {
return documentFieldValue.getValueByteArrayList().toArray();
} else if (documentFieldValue.getValueArrayListList().size() > 0) {
List<Object> list = new ArrayList<>();
for (FetchSearchResultProto.DocumentFieldValue value : documentFieldValue.getValueArrayListList()) {
list.add(readDocumentFieldValueFromProtobuf(value));
}
return list;
} else if (documentFieldValue.getValueMapMap().size() > 0) {
Map<String, Object> map = Map.of();
for (Map.Entry<String, FetchSearchResultProto.DocumentFieldValue> entrySet : documentFieldValue.getValueMapMap().entrySet()) {
map.put(entrySet.getKey(), readDocumentFieldValueFromProtobuf(entrySet.getValue()));
}
return map;
} else if (documentFieldValue.hasValueDate()) {
return new Date(documentFieldValue.getValueDate());
} else if (documentFieldValue.hasValueZonedDate() && documentFieldValue.hasValueZonedTime()) {
return ZonedDateTime.ofInstant(
Instant.ofEpochMilli(documentFieldValue.getValueZonedTime()),
ZoneId.of(documentFieldValue.getValueZonedDate())
);
} else if (documentFieldValue.hasValueText()) {
return new Text(documentFieldValue.getValueText());
} else {
throw new IOException("Can't read generic value of type [" + documentFieldValue + "]");
}
}

public static DocumentFieldValue.Builder convertDocumentFieldValueToProto(Object value, Builder valueBuilder) {
if (value == null) {
// null is not allowed in protobuf, so we use a special string to represent null
return valueBuilder.setValueString("null");
}
Class type = value.getClass();
if (type == String.class) {
valueBuilder.setValueString((String) value);
} else if (type == Integer.class) {
valueBuilder.setValueInt((Integer) value);
} else if (type == Long.class) {
valueBuilder.setValueLong((Long) value);
} else if (type == Float.class) {
valueBuilder.setValueFloat((Float) value);
} else if (type == Double.class) {
valueBuilder.setValueDouble((Double) value);
} else if (type == Boolean.class) {
valueBuilder.setValueBool((Boolean) value);
} else if (type == byte[].class) {
valueBuilder.addValueByteArray(ByteString.copyFrom((byte[]) value));
} else if (type == List.class) {
List<Object> list = (List<Object>) value;
for (Object listValue : list) {
valueBuilder.addValueArrayList(convertDocumentFieldValueToProto(listValue, valueBuilder));
}
} else if (type == Map.class || type == HashMap.class || type == LinkedHashMap.class) {
Map<String, Object> map = (Map<String, Object>) value;
for (Map.Entry<String, Object> entry : map.entrySet()) {
valueBuilder.putValueMap(entry.getKey(), convertDocumentFieldValueToProto(entry.getValue(), valueBuilder).build());
}
} else if (type == Date.class) {
valueBuilder.setValueDate(((Date) value).getTime());
} else if (type == ZonedDateTime.class) {
valueBuilder.setValueZonedDate(((ZonedDateTime) value).getZone().getId());
valueBuilder.setValueZonedTime(((ZonedDateTime) value).toInstant().toEpochMilli());
} else if (type == Text.class) {
valueBuilder.setValueText(((Text) value).string());
} else {
throw new OpenSearchException("Can't convert generic value of type [" + type + "] to protobuf");
}
return valueBuilder;
}

public static FetchSearchResultProto.SearchHit.DocumentField convertDocumentFieldToProto(DocumentField documentField) {
FetchSearchResultProto.SearchHit.DocumentField.Builder builder = FetchSearchResultProto.SearchHit.DocumentField.newBuilder();
builder.setName(documentField.getName());
for (Object value : documentField.getValues()) {
FetchSearchResultProto.DocumentFieldValue.Builder valueBuilder = FetchSearchResultProto.DocumentFieldValue.newBuilder();
builder.addValues(DocumentFieldProtobufSerializer.convertDocumentFieldValueToProto(value, valueBuilder));
}
return builder.build();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.common.document.serializer;

import org.opensearch.common.document.DocumentField;

import java.io.IOException;

/**
* Serializer for {@link DocumentField} which can be implemented for different types of serialization.
*/
public interface DocumentFieldSerializer<T> {

DocumentField createDocumentField(T inputStream) throws IOException;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

/** Serializer package for documents. */
package org.opensearch.common.document.serializer;
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ public class FeatureFlags {
*/
public static final String PLUGGABLE_CACHE = "opensearch.experimental.feature.pluggable.caching.enabled";

/**
* Gates the functionality of integrating protobuf within search API and node-to-node communication.
*/
public static final String PROTOBUF = "opensearch.experimental.feature.search_with_protobuf.enabled";

public static final Setting<Boolean> REMOTE_STORE_MIGRATION_EXPERIMENTAL_SETTING = Setting.boolSetting(
REMOTE_STORE_MIGRATION_EXPERIMENTAL,
false,
Expand All @@ -93,14 +98,17 @@ public class FeatureFlags {

public static final Setting<Boolean> PLUGGABLE_CACHE_SETTING = Setting.boolSetting(PLUGGABLE_CACHE, false, Property.NodeScope);

public static final Setting<Boolean> PROTOBUF_SETTING = Setting.boolSetting(PROTOBUF, false, Property.NodeScope, Property.Dynamic);

private static final List<Setting<Boolean>> ALL_FEATURE_FLAG_SETTINGS = List.of(
REMOTE_STORE_MIGRATION_EXPERIMENTAL_SETTING,
EXTENSIONS_SETTING,
IDENTITY_SETTING,
TELEMETRY_SETTING,
DATETIME_FORMATTER_CACHING_SETTING,
WRITEABLE_REMOTE_INDEX_SETTING,
PLUGGABLE_CACHE_SETTING
PLUGGABLE_CACHE_SETTING,
PROTOBUF_SETTING
);
/**
* Should store the settings from opensearch.yml.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ public class SearchSortValues implements ToXContentFragment, Writeable {
this.rawSortValues = EMPTY_ARRAY;
}

public SearchSortValues(Object[] sortValues, Object[] rawSortValues) {
this.formattedSortValues = Objects.requireNonNull(sortValues, "sort values must not be empty");
this.rawSortValues = rawSortValues;
}

public SearchSortValues(Object[] rawSortValues, DocValueFormat[] sortValueFormats) {
Objects.requireNonNull(rawSortValues);
Objects.requireNonNull(sortValueFormats);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
package org.opensearch.search.fetch;

import org.opensearch.common.annotation.PublicApi;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.search.SearchHit;
Expand All @@ -41,8 +42,13 @@
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.internal.ShardSearchContextId;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.serializer.SearchHitsProtobufSerializer;
import org.opensearch.server.proto.FetchSearchResultProto;
import org.opensearch.server.proto.ShardSearchRequestProto;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

/**
* Result from a fetch
Expand All @@ -56,6 +62,8 @@ public final class FetchSearchResult extends SearchPhaseResult {
// client side counter
private transient int counter;

private FetchSearchResultProto.FetchSearchResult fetchSearchResultProto;

public FetchSearchResult() {}

public FetchSearchResult(StreamInput in) throws IOException {
Expand All @@ -64,9 +72,24 @@ public FetchSearchResult(StreamInput in) throws IOException {
hits = new SearchHits(in);
}

public FetchSearchResult(InputStream in) throws IOException {
this.fetchSearchResultProto = FetchSearchResultProto.FetchSearchResult.parseFrom(in);
contextId = new ShardSearchContextId(
this.fetchSearchResultProto.getContextId().getSessionId(),
this.fetchSearchResultProto.getContextId().getId()
);
SearchHitsProtobufSerializer protobufSerializer = new SearchHitsProtobufSerializer();
hits = protobufSerializer.createSearchHits(new ByteArrayInputStream(this.fetchSearchResultProto.getHits().toByteArray()));
}

public FetchSearchResult(ShardSearchContextId id, SearchShardTarget shardTarget) {
this.contextId = id;
setSearchShardTarget(shardTarget);
this.fetchSearchResultProto = FetchSearchResultProto.FetchSearchResult.newBuilder()
.setContextId(
ShardSearchRequestProto.ShardSearchContextId.newBuilder().setSessionId(id.getSessionId()).setId(id.getId()).build()
)
.build();
}

@Override
Expand All @@ -82,6 +105,11 @@ public FetchSearchResult fetchResult() {
public void hits(SearchHits hits) {
assert assertNoSearchTarget(hits);
this.hits = hits;
if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING) && this.fetchSearchResultProto != null) {
this.fetchSearchResultProto = this.fetchSearchResultProto.toBuilder()
.setHits(SearchHitsProtobufSerializer.convertHitsToProto(hits))
.build();
}
}

private boolean assertNoSearchTarget(SearchHits hits) {
Expand All @@ -92,6 +120,16 @@ private boolean assertNoSearchTarget(SearchHits hits) {
}

public SearchHits hits() {
if (FeatureFlags.isEnabled(FeatureFlags.PROTOBUF_SETTING) && this.fetchSearchResultProto != null) {
SearchHits hits;
try {
SearchHitsProtobufSerializer protobufSerializer = new SearchHitsProtobufSerializer();
hits = protobufSerializer.createSearchHits(new ByteArrayInputStream(this.fetchSearchResultProto.getHits().toByteArray()));
return hits;
} catch (IOException e) {
throw new RuntimeException(e);
}
}
return hits;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.internal.ShardSearchContextId;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.server.proto.QueryFetchSearchResultProto;

import java.io.IOException;
import java.io.InputStream;

/**
* Query fetch result
Expand All @@ -51,12 +53,20 @@ public final class QueryFetchSearchResult extends SearchPhaseResult {
private final QuerySearchResult queryResult;
private final FetchSearchResult fetchResult;

private QueryFetchSearchResultProto.QueryFetchSearchResult queryFetchSearchResultProto;

public QueryFetchSearchResult(StreamInput in) throws IOException {
super(in);
queryResult = new QuerySearchResult(in);
fetchResult = new FetchSearchResult(in);
}

public QueryFetchSearchResult(InputStream in) throws IOException {
this.queryFetchSearchResultProto = QueryFetchSearchResultProto.QueryFetchSearchResult.parseFrom(in);
queryResult = new QuerySearchResult(in);
fetchResult = new FetchSearchResult(in);
}

public QueryFetchSearchResult(QuerySearchResult queryResult, FetchSearchResult fetchResult) {
this.queryResult = queryResult;
this.fetchResult = fetchResult;
Expand Down
Loading

0 comments on commit f0a62c5

Please sign in to comment.