diff --git a/server/src/main/java/org/opensearch/action/search/SearchShardTask.java b/server/src/main/java/org/opensearch/action/search/SearchShardTask.java index dfecf4f462c4d..31be93546ad73 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchShardTask.java +++ b/server/src/main/java/org/opensearch/action/search/SearchShardTask.java @@ -38,6 +38,7 @@ import org.opensearch.search.fetch.ShardFetchSearchRequest; import org.opensearch.search.internal.ShardSearchRequest; import org.opensearch.tasks.CancellableTask; +import org.opensearch.tasks.ResourceLimitGroupTask; import org.opensearch.tasks.SearchBackpressureTask; import java.util.Map; @@ -50,9 +51,10 @@ * @opensearch.api */ @PublicApi(since = "1.0.0") -public class SearchShardTask extends CancellableTask implements SearchBackpressureTask { +public class SearchShardTask extends CancellableTask implements SearchBackpressureTask, ResourceLimitGroupTask { // generating metadata in a lazy way since source can be quite big private final MemoizedSupplier metadataSupplier; + private String resourceLimitGroupId; public SearchShardTask(long id, String type, String action, String description, TaskId parentTaskId, Map headers) { this(id, type, action, description, parentTaskId, headers, () -> ""); @@ -84,4 +86,12 @@ public boolean supportsResourceTracking() { public boolean shouldCancelChildrenOnCancellation() { return false; } + + public String getResourceLimitGroupName() { + return resourceLimitGroupId; + } + + public void setResourceLimitGroupName(String resourceLimitGroupId) { + this.resourceLimitGroupId = resourceLimitGroupId; + } } diff --git a/server/src/main/java/org/opensearch/action/search/SearchTask.java b/server/src/main/java/org/opensearch/action/search/SearchTask.java index d3c1043c50cce..e0e358a23e95b 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchTask.java +++ b/server/src/main/java/org/opensearch/action/search/SearchTask.java @@ -36,6 +36,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.core.tasks.TaskId; import org.opensearch.tasks.CancellableTask; +import org.opensearch.tasks.ResourceLimitGroupTask; import org.opensearch.tasks.SearchBackpressureTask; import java.util.Map; @@ -49,10 +50,11 @@ * @opensearch.api */ @PublicApi(since = "1.0.0") -public class SearchTask extends CancellableTask implements SearchBackpressureTask { +public class SearchTask extends CancellableTask implements SearchBackpressureTask, ResourceLimitGroupTask { // generating description in a lazy way since source can be quite big private final Supplier descriptionSupplier; private SearchProgressListener progressListener = SearchProgressListener.NOOP; + private String resourceLimitGroupId; public SearchTask( long id, @@ -106,4 +108,12 @@ public final SearchProgressListener getProgressListener() { public boolean shouldCancelChildrenOnCancellation() { return true; } + + public String getResourceLimitGroupName() { + return resourceLimitGroupId; + } + + public void setResourceLimitGroupName(String resourceLimitGroupId) { + this.resourceLimitGroupId = resourceLimitGroupId; + } } diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index 143b01af3f62f..e621b01467859 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -72,6 +72,7 @@ import org.opensearch.core.indices.breaker.CircuitBreakerService; import org.opensearch.core.tasks.TaskId; import org.opensearch.index.query.Rewriteable; +import org.opensearch.search.MultiTenantLabel; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.SearchService; import org.opensearch.search.SearchShardTarget; @@ -166,6 +167,7 @@ public class TransportSearchAction extends HandledTransportAction multiTenantLabels = searchRequest.source().multiTenantLabels(); + tenant = (String) multiTenantLabels.get(MultiTenantLabel.TENANT.name()); + } + task.setResourceLimitGroupName(tenant); + searchAsyncActionProvider.asyncSearchAction( task, searchRequest, diff --git a/server/src/main/java/org/opensearch/search/MultiTenantLabel.java b/server/src/main/java/org/opensearch/search/MultiTenantLabel.java new file mode 100644 index 0000000000000..9b69f3f00985e --- /dev/null +++ b/server/src/main/java/org/opensearch/search/MultiTenantLabel.java @@ -0,0 +1,56 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; + +import java.io.IOException; + +/** + * Enum to hold all multitenant labels in workloads + */ +public enum MultiTenantLabel implements Writeable { + // This label is basically used to define tenancy for multiple features e,g; Query Sandboxing, Query Insights + TENANT("tenant"); + + private final String value; + + MultiTenantLabel(String name) { + this.value = name; + } + + public String getValue() { + return value; + } + + public static MultiTenantLabel fromName(String name) { + for (MultiTenantLabel label : values()) { + if (label.getValue().equalsIgnoreCase(name)) { + return label; + } + } + throw new IllegalArgumentException("Illegal name + " + name); + } + + public static MultiTenantLabel fromName(StreamInput in) throws IOException { + return fromName(in.readString()); + } + + /** + * Write this into the {@linkplain StreamOutput}. + * + * @param out + */ + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(value); + } +} diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java index 744d3a19f1593..8b3dc068a4d05 100644 --- a/server/src/main/java/org/opensearch/search/SearchService.java +++ b/server/src/main/java/org/opensearch/search/SearchService.java @@ -158,6 +158,7 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.function.LongSupplier; +import static org.opensearch.action.search.TransportSearchAction.NOT_PROVIDED; import static org.opensearch.common.unit.TimeValue.timeValueHours; import static org.opensearch.common.unit.TimeValue.timeValueMillis; import static org.opensearch.common.unit.TimeValue.timeValueMinutes; @@ -568,6 +569,7 @@ public void executeQueryPhase( assert request.canReturnNullResponseIfMatchNoDocs() == false || request.numberOfShards() > 1 : "empty responses require more than one shard"; final IndexShard shard = getShard(request); + setTenantInTask(task, request); rewriteAndFetchShardRequest(shard, request, new ActionListener() { @Override public void onResponse(ShardSearchRequest orig) { @@ -598,6 +600,14 @@ public void onFailure(Exception exc) { }); } + private void setTenantInTask(SearchShardTask task, ShardSearchRequest request) { + String tenant = NOT_PROVIDED; + if (request.source() != null && request.source().multiTenantLabels() != null) { + tenant = (String) request.source().multiTenantLabels().get(MultiTenantLabel.TENANT.name()); + } + task.setResourceLimitGroupName(tenant); + } + private IndexShard getShard(ShardSearchRequest request) { if (request.readerId() != null) { return findReaderContext(request.readerId(), request).indexShard(); @@ -676,6 +686,7 @@ public void executeQueryPhase( } runAsync(getExecutor(readerContext.indexShard()), () -> { final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(null); + setTenantInTask(task, shardSearchRequest); try ( SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, false); SearchOperationListenerExecutor executor = new SearchOperationListenerExecutor(searchContext) @@ -778,6 +789,7 @@ public void executeFetchPhase( public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, ActionListener listener) { final ReaderContext readerContext = findReaderContext(request.contextId(), request); final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest()); + setTenantInTask(task, shardSearchRequest); final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest)); runAsync(getExecutor(readerContext.indexShard()), () -> { try (SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, false)) { diff --git a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java index 07248a0719c3a..e43fc9590824a 100644 --- a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java @@ -79,6 +79,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -136,6 +137,7 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R public static final ParseField SLICE = new ParseField("slice"); public static final ParseField POINT_IN_TIME = new ParseField("pit"); public static final ParseField SEARCH_PIPELINE = new ParseField("search_pipeline"); + public static final ParseField MULTI_TENANT_LABELS = new ParseField("multitenant_attrs"); public static SearchSourceBuilder fromXContent(XContentParser parser) throws IOException { return fromXContent(parser, true); @@ -223,6 +225,7 @@ public static HighlightBuilder highlight() { private PointInTimeBuilder pointInTimeBuilder = null; private Map searchPipelineSource = null; + private Map multiTenantLabels = new HashMap<>(); /** * Constructs a new search source builder. @@ -297,6 +300,10 @@ public SearchSourceBuilder(StreamInput in) throws IOException { derivedFields = in.readList(DerivedField::new); } } + + if (in.getVersion().onOrAfter(Version.V_2_14_0)) { + multiTenantLabels = in.readMap(); + } } @Override @@ -377,6 +384,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeList(derivedFields); } } + + if (out.getVersion().onOrAfter(Version.V_2_14_0)) { + out.writeMap(multiTenantLabels); + } } /** @@ -1088,6 +1099,14 @@ public SearchSourceBuilder searchPipelineSource(Map searchPipeli return this; } + /** + * + * @return {@code } pairs + */ + public Map multiTenantLabels() { + return multiTenantLabels; + } + /** * Rewrites this search source builder into its primitive form. e.g. by * rewriting the QueryBuilder. If the builder did not change the identity @@ -1334,6 +1353,8 @@ public void parseXContent(XContentParser parser, boolean checkTrailingTokens) th searchPipelineSource = parser.mapOrdered(); } else if (DERIVED_FIELDS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { derivedFieldsObject = parser.map(); + } else if (MULTI_TENANT_LABELS.match(currentFieldName, parser.getDeprecationHandler())) { + multiTenantLabels = parser.map(); } else { throw new ParsingException( parser.getTokenLocation(), diff --git a/server/src/main/java/org/opensearch/tasks/ResourceLimitGroupTask.java b/server/src/main/java/org/opensearch/tasks/ResourceLimitGroupTask.java new file mode 100644 index 0000000000000..4f16c24dfee64 --- /dev/null +++ b/server/src/main/java/org/opensearch/tasks/ResourceLimitGroupTask.java @@ -0,0 +1,18 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +/** + * Tasks which should be grouped + */ +public interface ResourceLimitGroupTask { + void setResourceLimitGroupName(String name); + + String getResourceLimitGroupName(); +} diff --git a/server/src/test/java/org/opensearch/search/MultiTenantLabelTests.java b/server/src/test/java/org/opensearch/search/MultiTenantLabelTests.java new file mode 100644 index 0000000000000..6d215a8c775fe --- /dev/null +++ b/server/src/test/java/org/opensearch/search/MultiTenantLabelTests.java @@ -0,0 +1,44 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search; + +import org.opensearch.core.common.io.stream.StreamInput; + +import java.io.IOException; + +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; + +public class MultiTenantLabelTests extends AbstractSearchTestCase { + + public void testValidMultiTenantLabel() { + MultiTenantLabel label = MultiTenantLabel.fromName("tenant"); + assertEquals(label.getValue(), "tenant"); + } + + public void testInvalidMultiTenantLabel() { + assertThrows(IllegalArgumentException.class, () -> MultiTenantLabel.fromName("foo")); + } + + public void testValidMultiTenantLabelWithStreamInput() throws IOException { + StreamInput streamInput = mock(StreamInput.class); + doReturn("tenant").when(streamInput).readString(); + + MultiTenantLabel label = MultiTenantLabel.fromName(streamInput); + assertEquals(label.getValue(), "tenant"); + } + + public void testInvalidMultiTenantLabelWithStreamInput() throws IOException { + StreamInput streamInput = mock(StreamInput.class); + doReturn("foo").when(streamInput).readString(); + + assertThrows(IllegalArgumentException.class, () -> MultiTenantLabel.fromName(streamInput)); + } + +}