diff --git a/common/build.gradle b/common/build.gradle
index 0561468d1f..507ad6c0d6 100644
--- a/common/build.gradle
+++ b/common/build.gradle
@@ -63,4 +63,4 @@ configurations.all {
   resolutionStrategy.force "org.apache.httpcomponents:httpcore:4.4.13"
   resolutionStrategy.force "joda-time:joda-time:2.10.12"
   resolutionStrategy.force "org.slf4j:slf4j-api:1.7.36"
-}
\ No newline at end of file
+}
diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
index f714a8366b..3d9740d84c 100644
--- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
+++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
@@ -306,8 +306,9 @@ private DataSourceServiceImpl createDataSourceService() {
   private AsyncQueryExecutorService createAsyncQueryExecutorService(
       SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier,
       SparkExecutionEngineConfig sparkExecutionEngineConfig) {
+    StateStore stateStore = new StateStore(client, clusterService);
     AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService =
-        new OpensearchAsyncQueryJobMetadataStorageService(client, clusterService);
+        new OpensearchAsyncQueryJobMetadataStorageService(stateStore);
     EMRServerlessClient emrServerlessClient =
         createEMRServerlessClient(sparkExecutionEngineConfig.getRegion());
     JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);
@@ -319,8 +320,7 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService(
             jobExecutionResponseReader,
             new FlintIndexMetadataReaderImpl(client),
             client,
-            new SessionManager(
-                new StateStore(client, clusterService), emrServerlessClient, pluginSettings));
+            new SessionManager(stateStore, emrServerlessClient, pluginSettings));
     return new AsyncQueryExecutorServiceImpl(
         asyncQueryJobMetadataStorageService,
         sparkQueryDispatcher,
diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java
index 7cba2757cc..18ae47c2b9 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java
@@ -69,13 +69,14 @@ public CreateAsyncQueryResponse createAsyncQuery(
             createAsyncQueryRequest.getSessionId()));
     asyncQueryJobMetadataStorageService.storeJobMetadata(
         new AsyncQueryJobMetadata(
+            dispatchQueryResponse.getQueryId(),
             sparkExecutionEngineConfig.getApplicationId(),
             dispatchQueryResponse.getJobId(),
             dispatchQueryResponse.isDropIndexQuery(),
             dispatchQueryResponse.getResultIndex(),
             dispatchQueryResponse.getSessionId()));
     return new CreateAsyncQueryResponse(
-        dispatchQueryResponse.getJobId(), dispatchQueryResponse.getSessionId());
+        dispatchQueryResponse.getQueryId().getId(), dispatchQueryResponse.getSessionId());
   }

   @Override
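[Review note] The handle returned to clients changes here: CreateAsyncQueryResponse now carries AsyncQueryId.getId() instead of the EMR-S job id, and the same handle is what getAsyncQueryResults/cancelQuery expect. A minimal caller-side sketch, assuming a wired-up service (the request construction is elided because its exact shape is not part of this hunk):

    CreateAsyncQueryRequest request = ...; // built from the REST payload
    CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery(request);
    // The handle is the encoded AsyncQueryId, not the EMR-S jobId.
    String queryId = response.getQueryId();
    asyncQueryExecutorService.getAsyncQueryResults(queryId); // poll for results
    asyncQueryExecutorService.cancelQuery(queryId);          // cancel by the same handle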
diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java
index a95a6ffe45..6de8c35f03 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java
@@ -7,166 +7,31 @@

 package org.opensearch.sql.spark.asyncquery;

-import java.io.IOException;
-import java.io.InputStream;
-import java.nio.charset.StandardCharsets;
-import java.util.ArrayList;
-import java.util.List;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.createJobMetaData;
+
 import java.util.Optional;
-import org.apache.commons.io.IOUtils;
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
-import org.opensearch.action.DocWriteRequest;
-import org.opensearch.action.DocWriteResponse;
-import org.opensearch.action.admin.indices.create.CreateIndexRequest;
-import org.opensearch.action.admin.indices.create.CreateIndexResponse;
-import org.opensearch.action.index.IndexRequest;
-import org.opensearch.action.index.IndexResponse;
-import org.opensearch.action.search.SearchRequest;
-import org.opensearch.action.search.SearchResponse;
-import org.opensearch.action.support.WriteRequest;
-import org.opensearch.client.Client;
-import org.opensearch.cluster.service.ClusterService;
-import org.opensearch.common.action.ActionFuture;
-import org.opensearch.common.util.concurrent.ThreadContext;
-import org.opensearch.common.xcontent.XContentType;
-import org.opensearch.index.query.QueryBuilder;
-import org.opensearch.index.query.QueryBuilders;
-import org.opensearch.search.SearchHit;
-import org.opensearch.search.builder.SearchSourceBuilder;
+import lombok.RequiredArgsConstructor;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
+import org.opensearch.sql.spark.execution.statestore.StateStore;

 /** Opensearch implementation of {@link AsyncQueryJobMetadataStorageService} */
+@RequiredArgsConstructor
 public class OpensearchAsyncQueryJobMetadataStorageService
     implements AsyncQueryJobMetadataStorageService {

-  public static final String JOB_METADATA_INDEX = ".ql-job-metadata";
-  private static final String JOB_METADATA_INDEX_MAPPING_FILE_NAME =
-      "job-metadata-index-mapping.yml";
-  private static final String JOB_METADATA_INDEX_SETTINGS_FILE_NAME =
-      "job-metadata-index-settings.yml";
-  private static final Logger LOG = LogManager.getLogger();
-  private final Client client;
-  private final ClusterService clusterService;
-
-  /**
-   * This class implements JobMetadataStorageService interface using OpenSearch as underlying
-   * storage.
-   *
-   * @param client opensearch NodeClient.
-   * @param clusterService ClusterService.
-   */
-  public OpensearchAsyncQueryJobMetadataStorageService(
-      Client client, ClusterService clusterService) {
-    this.client = client;
-    this.clusterService = clusterService;
-  }
+  private final StateStore stateStore;

   @Override
   public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) {
-    if (!this.clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) {
-      createJobMetadataIndex();
-    }
-    IndexRequest indexRequest = new IndexRequest(JOB_METADATA_INDEX);
-    indexRequest.id(asyncQueryJobMetadata.getJobId());
-    indexRequest.opType(DocWriteRequest.OpType.CREATE);
-    indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
-    ActionFuture<IndexResponse> indexResponseActionFuture;
-    IndexResponse indexResponse;
-    try (ThreadContext.StoredContext storedContext =
-        client.threadPool().getThreadContext().stashContext()) {
-      indexRequest.source(AsyncQueryJobMetadata.convertToXContent(asyncQueryJobMetadata));
-      indexResponseActionFuture = client.index(indexRequest);
-      indexResponse = indexResponseActionFuture.actionGet();
-    } catch (Exception e) {
-      throw new RuntimeException(e);
-    }
-
-    if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) {
-      LOG.debug("JobMetadata : {} successfully created", asyncQueryJobMetadata.getJobId());
-    } else {
-      throw new RuntimeException(
-          "Saving job metadata information failed with result : "
-              + indexResponse.getResult().getLowercase());
-    }
+    AsyncQueryId queryId = asyncQueryJobMetadata.getQueryId();
+    createJobMetaData(stateStore, queryId.getDataSourceName()).apply(asyncQueryJobMetadata);
   }

   @Override
-  public Optional<AsyncQueryJobMetadata> getJobMetadata(String jobId) {
-    if (!this.clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX)) {
-      createJobMetadataIndex();
-      return Optional.empty();
-    }
-    return searchInJobMetadataIndex(QueryBuilders.termQuery("jobId.keyword", jobId)).stream()
-        .findFirst();
-  }
-
-  private void createJobMetadataIndex() {
-    try {
-      InputStream mappingFileStream =
-          OpensearchAsyncQueryJobMetadataStorageService.class
-              .getClassLoader()
-              .getResourceAsStream(JOB_METADATA_INDEX_MAPPING_FILE_NAME);
-      InputStream settingsFileStream =
-          OpensearchAsyncQueryJobMetadataStorageService.class
-              .getClassLoader()
-              .getResourceAsStream(JOB_METADATA_INDEX_SETTINGS_FILE_NAME);
-      CreateIndexRequest createIndexRequest = new CreateIndexRequest(JOB_METADATA_INDEX);
-      createIndexRequest
-          .mapping(IOUtils.toString(mappingFileStream, StandardCharsets.UTF_8), XContentType.YAML)
-          .settings(
-              IOUtils.toString(settingsFileStream, StandardCharsets.UTF_8), XContentType.YAML);
-      ActionFuture<CreateIndexResponse> createIndexResponseActionFuture;
-      try (ThreadContext.StoredContext ignored =
-          client.threadPool().getThreadContext().stashContext()) {
-        createIndexResponseActionFuture = client.admin().indices().create(createIndexRequest);
-      }
-      CreateIndexResponse createIndexResponse = createIndexResponseActionFuture.actionGet();
-      if (createIndexResponse.isAcknowledged()) {
-        LOG.info("Index: {} creation Acknowledged", JOB_METADATA_INDEX);
-      } else {
-        throw new RuntimeException("Index creation is not acknowledged.");
-      }
-    } catch (Throwable e) {
-      throw new RuntimeException(
-          "Internal server error while creating"
-              + JOB_METADATA_INDEX
-              + " index:: "
-              + e.getMessage());
-    }
-  }
-
-  private List<AsyncQueryJobMetadata> searchInJobMetadataIndex(QueryBuilder query) {
-    SearchRequest searchRequest = new SearchRequest();
-    searchRequest.indices(JOB_METADATA_INDEX);
-    SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
-    searchSourceBuilder.query(query);
-    searchSourceBuilder.size(1);
-    searchRequest.source(searchSourceBuilder);
-    // https://github.com/opensearch-project/sql/issues/1801.
-    searchRequest.preference("_primary_first");
-    ActionFuture<SearchResponse> searchResponseActionFuture;
-    try (ThreadContext.StoredContext ignored =
-        client.threadPool().getThreadContext().stashContext()) {
-      searchResponseActionFuture = client.search(searchRequest);
-    }
-    SearchResponse searchResponse = searchResponseActionFuture.actionGet();
-    if (searchResponse.status().getStatus() != 200) {
-      throw new RuntimeException(
-          "Fetching job metadata information failed with status : " + searchResponse.status());
-    } else {
-      List<AsyncQueryJobMetadata> list = new ArrayList<>();
-      for (SearchHit searchHit : searchResponse.getHits().getHits()) {
-        String sourceAsString = searchHit.getSourceAsString();
-        AsyncQueryJobMetadata asyncQueryJobMetadata;
-        try {
-          asyncQueryJobMetadata = AsyncQueryJobMetadata.toJobMetadata(sourceAsString);
-        } catch (IOException e) {
-          throw new RuntimeException(e);
-        }
-        list.add(asyncQueryJobMetadata);
-      }
-      return list;
-    }
+  public Optional<AsyncQueryJobMetadata> getJobMetadata(String qid) {
+    AsyncQueryId queryId = new AsyncQueryId(qid);
+    return StateStore.getJobMetaData(stateStore, queryId.getDataSourceName())
+        .apply(queryId.docId());
   }
 }
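[Review note] The storage service is now a thin facade over StateStore: writes go through the curried createJobMetaData helper into the per-datasource request index, and reads key off the "qid"-prefixed doc id, so the dedicated .ql-job-metadata index (and its lazy creation) disappears entirely. A round-trip sketch, assuming a StateStore built from a real client/cluster service (variable values are illustrative):

    StateStore stateStore = new StateStore(client, clusterService);
    OpensearchAsyncQueryJobMetadataStorageService service =
        new OpensearchAsyncQueryJobMetadataStorageService(stateStore);

    AsyncQueryId queryId = AsyncQueryId.newAsyncQueryId("mys3");
    // 4-arg constructor: (queryId, applicationId, jobId, resultIndex).
    service.storeJobMetadata(new AsyncQueryJobMetadata(queryId, "app-id", "job-id", null));

    // Lookup takes the raw encoded id; docId() adds the "qid" prefix internally.
    Optional<AsyncQueryJobMetadata> stored = service.getJobMetadata(queryId.getId());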
diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryId.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryId.java
new file mode 100644
index 0000000000..b99ebe0e8c
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryId.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.asyncquery.model;
+
+import static org.opensearch.sql.spark.utils.IDUtils.decode;
+import static org.opensearch.sql.spark.utils.IDUtils.encode;
+
+import lombok.Data;
+
+/** Async query id. */
+@Data
+public class AsyncQueryId {
+  private final String id;
+
+  public static AsyncQueryId newAsyncQueryId(String datasourceName) {
+    return new AsyncQueryId(encode(datasourceName));
+  }
+
+  public String getDataSourceName() {
+    return decode(id);
+  }
+
+  /** OpenSearch DocId. */
+  public String docId() {
+    return "qid" + id;
+  }
+
+  @Override
+  public String toString() {
+    return "asyncQueryId=" + id;
+  }
+}
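[Review note] The id embeds the datasource name (via IDUtils.encode, added later in this patch), which is what lets getJobMetadata recover the right per-datasource index from the id alone. Round-trip behavior, as a sketch:

    AsyncQueryId queryId = AsyncQueryId.newAsyncQueryId("mys3");
    // decode() strips the 10-char random prefix and returns the datasource name.
    assert "mys3".equals(queryId.getDataSourceName());
    // The OpenSearch doc id is the encoded id with a fixed "qid" prefix.
    String docId = queryId.docId(); // "qid" + queryId.getId()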
diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java
index b80fefa173..3c59403661 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java
@@ -8,37 +8,83 @@
 package org.opensearch.sql.spark.asyncquery.model;

 import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
+import static org.opensearch.sql.spark.execution.statement.StatementModel.QUERY_ID;

 import com.google.gson.Gson;
 import java.io.IOException;
-import lombok.AllArgsConstructor;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
-import org.opensearch.common.xcontent.XContentFactory;
-import org.opensearch.common.xcontent.XContentType;
-import org.opensearch.core.xcontent.DeprecationHandler;
-import org.opensearch.core.xcontent.NamedXContentRegistry;
+import lombok.SneakyThrows;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.index.seqno.SequenceNumbers;
+import org.opensearch.sql.spark.execution.statestore.StateModel;

 /** This class models all the metadata required for a job. */
 @Data
-@AllArgsConstructor
-@EqualsAndHashCode
-public class AsyncQueryJobMetadata {
-  private String applicationId;
-  private String jobId;
-  private boolean isDropIndexQuery;
-  private String resultIndex;
+@EqualsAndHashCode(callSuper = false)
+public class AsyncQueryJobMetadata extends StateModel {
+  public static final String TYPE_JOBMETA = "jobmeta";
+
+  private final AsyncQueryId queryId;
+  private final String applicationId;
+  private final String jobId;
+  private final boolean isDropIndexQuery;
+  private final String resultIndex;
   // optional sessionId.
-  private String sessionId;
+  private final String sessionId;
+
+  @EqualsAndHashCode.Exclude private final long seqNo;
+  @EqualsAndHashCode.Exclude private final long primaryTerm;

-  public AsyncQueryJobMetadata(String applicationId, String jobId, String resultIndex) {
+  public AsyncQueryJobMetadata(
+      AsyncQueryId queryId, String applicationId, String jobId, String resultIndex) {
+    this(
+        queryId,
+        applicationId,
+        jobId,
+        false,
+        resultIndex,
+        null,
+        SequenceNumbers.UNASSIGNED_SEQ_NO,
+        SequenceNumbers.UNASSIGNED_PRIMARY_TERM);
+  }
+
+  public AsyncQueryJobMetadata(
+      AsyncQueryId queryId,
+      String applicationId,
+      String jobId,
+      boolean isDropIndexQuery,
+      String resultIndex,
+      String sessionId) {
+    this(
+        queryId,
+        applicationId,
+        jobId,
+        isDropIndexQuery,
+        resultIndex,
+        sessionId,
+        SequenceNumbers.UNASSIGNED_SEQ_NO,
+        SequenceNumbers.UNASSIGNED_PRIMARY_TERM);
+  }
+
+  public AsyncQueryJobMetadata(
+      AsyncQueryId queryId,
+      String applicationId,
+      String jobId,
+      boolean isDropIndexQuery,
+      String resultIndex,
+      String sessionId,
+      long seqNo,
+      long primaryTerm) {
+    this.queryId = queryId;
     this.applicationId = applicationId;
     this.jobId = jobId;
-    this.isDropIndexQuery = false;
+    this.isDropIndexQuery = isDropIndexQuery;
     this.resultIndex = resultIndex;
-    this.sessionId = null;
+    this.sessionId = sessionId;
+    this.seqNo = seqNo;
+    this.primaryTerm = primaryTerm;
   }

   @Override
@@ -49,39 +95,36 @@ public String toString() {

   /**
    * Converts JobMetadata to XContentBuilder.
    *
-   * @param metadata metadata.
    * @return XContentBuilder {@link XContentBuilder}
    * @throws Exception Exception.
    */
-  public static XContentBuilder convertToXContent(AsyncQueryJobMetadata metadata) throws Exception {
-    XContentBuilder builder = XContentFactory.jsonBuilder();
-    builder.startObject();
-    builder.field("jobId", metadata.getJobId());
-    builder.field("applicationId", metadata.getApplicationId());
-    builder.field("isDropIndexQuery", metadata.isDropIndexQuery());
-    builder.field("resultIndex", metadata.getResultIndex());
-    builder.field("sessionId", metadata.getSessionId());
-    builder.endObject();
+  @Override
+  public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+    builder
+        .startObject()
+        .field(QUERY_ID, queryId.getId())
+        .field("type", TYPE_JOBMETA)
+        .field("jobId", jobId)
+        .field("applicationId", applicationId)
+        .field("isDropIndexQuery", isDropIndexQuery)
+        .field("resultIndex", resultIndex)
+        .field("sessionId", sessionId)
+        .endObject();
     return builder;
   }

-  /**
-   * Converts json string to DataSourceMetadata.
-   *
-   * @param json jsonstring.
-   * @return jobmetadata {@link AsyncQueryJobMetadata}
-   * @throws java.io.IOException IOException.
-   */
-  public static AsyncQueryJobMetadata toJobMetadata(String json) throws IOException {
-    try (XContentParser parser =
-        XContentType.JSON
-            .xContent()
-            .createParser(
-                NamedXContentRegistry.EMPTY,
-                DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
-                json)) {
-      return toJobMetadata(parser);
-    }
+  /** copy builder. update seqNo and primaryTerm */
+  public static AsyncQueryJobMetadata copy(
+      AsyncQueryJobMetadata copy, long seqNo, long primaryTerm) {
+    return new AsyncQueryJobMetadata(
+        copy.getQueryId(),
+        copy.getApplicationId(),
+        copy.getJobId(),
+        copy.isDropIndexQuery(),
+        copy.getResultIndex(),
+        copy.getSessionId(),
+        seqNo,
+        primaryTerm);
   }

   /**
@@ -91,17 +134,23 @@ public static AsyncQueryJobMetadata toJobMetadata(String json) throws IOException
    * @return JobMetadata {@link AsyncQueryJobMetadata}
    * @throws IOException IOException.
    */
-  public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws IOException {
+  @SneakyThrows
+  public static AsyncQueryJobMetadata fromXContent(
+      XContentParser parser, long seqNo, long primaryTerm) {
+    AsyncQueryId queryId = null;
     String jobId = null;
     String applicationId = null;
     boolean isDropIndexQuery = false;
     String resultIndex = null;
     String sessionId = null;
-    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
-    while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
+    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
+    while (!XContentParser.Token.END_OBJECT.equals(parser.nextToken())) {
       String fieldName = parser.currentName();
       parser.nextToken();
       switch (fieldName) {
+        case QUERY_ID:
+          queryId = new AsyncQueryId(parser.textOrNull());
+          break;
         case "jobId":
           jobId = parser.textOrNull();
           break;
@@ -117,6 +166,8 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws
         case "sessionId":
           sessionId = parser.textOrNull();
           break;
+        case "type":
+          break;
         default:
           throw new IllegalArgumentException("Unknown field: " + fieldName);
       }
@@ -125,6 +176,18 @@ public static AsyncQueryJobMetadata toJobMetadata(XContentParser parser) throws
       throw new IllegalArgumentException("jobId and applicationId are required fields.");
     }
     return new AsyncQueryJobMetadata(
-        applicationId, jobId, isDropIndexQuery, resultIndex, sessionId);
+        queryId,
+        applicationId,
+        jobId,
+        isDropIndexQuery,
+        resultIndex,
+        sessionId,
+        seqNo,
+        primaryTerm);
+  }
+
+  @Override
+  public String getId() {
+    return queryId.docId();
   }
 }
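[Review note] AsyncQueryJobMetadata becomes a StateModel with seqNo/primaryTerm for optimistic concurrency; both fields are excluded from equals/hashCode so an instance built with the UNASSIGNED_* defaults still compares equal to the stamped copy read back from the index. Illustrative sketch of how copy() participates (original, seqNo and primaryTerm stand in for values StateStore.create obtains from the IndexResponse):

    // StateStore.create(st, AsyncQueryJobMetadata::copy, index) indexes 'st', then:
    AsyncQueryJobMetadata stamped = AsyncQueryJobMetadata.copy(original, seqNo, primaryTerm);
    // equals() ignores seqNo/primaryTerm, so:
    assert stamped.equals(original);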
diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java
new file mode 100644
index 0000000000..77a0e1cd09
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.dispatcher;
+
+import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD;
+
+import com.amazonaws.services.emrserverless.model.JobRunState;
+import org.json.JSONObject;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
+
+/** Process async query request. */
+public abstract class AsyncQueryHandler {
+
+  public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) {
+    if (asyncQueryJobMetadata.isDropIndexQuery()) {
+      return SparkQueryDispatcher.DropIndexResult.fromJobId(asyncQueryJobMetadata.getJobId())
+          .result();
+    }
+
+    JSONObject result = getResponseFromResultIndex(asyncQueryJobMetadata);
+    if (result.has(DATA_FIELD)) {
+      JSONObject items = result.getJSONObject(DATA_FIELD);
+
+      // If items have STATUS_FIELD, use it; otherwise, mark failed
+      String status = items.optString(STATUS_FIELD, JobRunState.FAILED.toString());
+      result.put(STATUS_FIELD, status);
+
+      // If items have ERROR_FIELD, use it; otherwise, set empty string
+      String error = items.optString(ERROR_FIELD, "");
+      result.put(ERROR_FIELD, error);
+      return result;
+    } else {
+      return getResponseFromExecutor(asyncQueryJobMetadata);
+    }
+  }
+
+  protected abstract JSONObject getResponseFromResultIndex(
+      AsyncQueryJobMetadata asyncQueryJobMetadata);
+
+  protected abstract JSONObject getResponseFromExecutor(
+      AsyncQueryJobMetadata asyncQueryJobMetadata);
+
+  abstract String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata);
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java
new file mode 100644
index 0000000000..8a582278e1
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java
@@ -0,0 +1,50 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.dispatcher;
+
+import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD;
+
+import com.amazonaws.services.emrserverless.model.GetJobRunResult;
+import lombok.RequiredArgsConstructor;
+import org.json.JSONObject;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
+import org.opensearch.sql.spark.client.EMRServerlessClient;
+import org.opensearch.sql.spark.response.JobExecutionResponseReader;
+
+@RequiredArgsConstructor
+public class BatchQueryHandler extends AsyncQueryHandler {
+  private final EMRServerlessClient emrServerlessClient;
+  private final JobExecutionResponseReader jobExecutionResponseReader;
+
+  @Override
+  protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) {
+    // either empty json when the result is not available or data with status
+    // Fetch from Result Index
+    return jobExecutionResponseReader.getResultFromOpensearchIndex(
+        asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex());
+  }
+
+  @Override
+  protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJobMetadata) {
+    JSONObject result = new JSONObject();
+    // make call to EMR Serverless when related result index documents are not available
+    GetJobRunResult getJobRunResult =
+        emrServerlessClient.getJobRunResult(
+            asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId());
+    String jobState = getJobRunResult.getJobRun().getState();
+    result.put(STATUS_FIELD, jobState);
+    result.put(ERROR_FIELD, "");
+    return result;
+  }
+
+  @Override
+  public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
+    emrServerlessClient.cancelJobRun(
+        asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId());
+    return asyncQueryJobMetadata.getQueryId().getId();
+  }
+}
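[Review note] getQueryResponse is a template method: the drop-index shortcut and the status/error normalization live in the base class, while subclasses supply only the result-index lookup, the engine fallback, and cancellation. A hypothetical third handler (not part of this patch) would only need:

    public class NoopQueryHandler extends AsyncQueryHandler {
      @Override
      protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata m) {
        // No DATA_FIELD in the result -> base class falls through to the executor.
        return new JSONObject();
      }

      @Override
      protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata m) {
        return new JSONObject().put(STATUS_FIELD, JobRunState.CANCELLED.toString());
      }

      @Override
      String cancelJob(AsyncQueryJobMetadata m) {
        return m.getQueryId().getId();
      }
    }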
diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java
new file mode 100644
index 0000000000..24ea1528c8
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java
@@ -0,0 +1,69 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.dispatcher;
+
+import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD;
+
+import java.util.Optional;
+import lombok.RequiredArgsConstructor;
+import org.json.JSONObject;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
+import org.opensearch.sql.spark.execution.session.Session;
+import org.opensearch.sql.spark.execution.session.SessionId;
+import org.opensearch.sql.spark.execution.session.SessionManager;
+import org.opensearch.sql.spark.execution.statement.Statement;
+import org.opensearch.sql.spark.execution.statement.StatementId;
+import org.opensearch.sql.spark.execution.statement.StatementState;
+import org.opensearch.sql.spark.response.JobExecutionResponseReader;
+
+@RequiredArgsConstructor
+public class InteractiveQueryHandler extends AsyncQueryHandler {
+  private final SessionManager sessionManager;
+  private final JobExecutionResponseReader jobExecutionResponseReader;
+
+  @Override
+  protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) {
+    String queryId = asyncQueryJobMetadata.getQueryId().getId();
+    return jobExecutionResponseReader.getResultWithQueryId(
+        queryId, asyncQueryJobMetadata.getResultIndex());
+  }
+
+  @Override
+  protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJobMetadata) {
+    JSONObject result = new JSONObject();
+    String queryId = asyncQueryJobMetadata.getQueryId().getId();
+    Statement statement = getStatementByQueryId(asyncQueryJobMetadata.getSessionId(), queryId);
+    StatementState statementState = statement.getStatementState();
+    result.put(STATUS_FIELD, statementState.getState());
+    result.put(ERROR_FIELD, "");
+    return result;
+  }
+
+  @Override
+  public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
+    String queryId = asyncQueryJobMetadata.getQueryId().getId();
+    getStatementByQueryId(asyncQueryJobMetadata.getSessionId(), queryId).cancel();
+    return queryId;
+  }
+
+  private Statement getStatementByQueryId(String sid, String qid) {
+    SessionId sessionId = new SessionId(sid);
+    Optional<Session> session = sessionManager.getSession(sessionId);
+    if (session.isPresent()) {
+      // todo, statementId == jobId if statement running in session.
+      StatementId statementId = new StatementId(qid);
+      Optional<Statement> statement = session.get().get(statementId);
+      if (statement.isPresent()) {
+        return statement.get();
+      } else {
+        throw new IllegalArgumentException("no statement found. " + statementId);
+      }
+    } else {
+      throw new IllegalArgumentException("no session found. " + sessionId);
+    }
+  }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
index 2bd1ae67b9..882f2663d9 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
@@ -10,8 +10,6 @@
 import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SESSION_CLASS_NAME;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD;

-import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
-import com.amazonaws.services.emrserverless.model.GetJobRunResult;
 import com.amazonaws.services.emrserverless.model.JobRunState;
 import java.nio.charset.StandardCharsets;
 import java.util.Base64;
@@ -33,6 +31,7 @@
 import org.opensearch.sql.datasource.DataSourceService;
 import org.opensearch.sql.datasource.model.DataSourceMetadata;
 import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters;
 import org.opensearch.sql.spark.client.EMRServerlessClient;
@@ -46,9 +45,6 @@
 import org.opensearch.sql.spark.execution.session.SessionId;
 import org.opensearch.sql.spark.execution.session.SessionManager;
 import org.opensearch.sql.spark.execution.statement.QueryRequest;
-import org.opensearch.sql.spark.execution.statement.Statement;
-import org.opensearch.sql.spark.execution.statement.StatementId;
-import org.opensearch.sql.spark.execution.statement.StatementState;
 import org.opensearch.sql.spark.flint.FlintIndexMetadata;
 import org.opensearch.sql.spark.flint.FlintIndexMetadataReader;
 import org.opensearch.sql.spark.response.JobExecutionResponseReader;
@@ -92,97 +88,22 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest)
   }

   public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) {
-    // todo. refactor query process logic in plugin.
-    if (asyncQueryJobMetadata.isDropIndexQuery()) {
-      return DropIndexResult.fromJobId(asyncQueryJobMetadata.getJobId()).result();
-    }
-
-    JSONObject result;
-    if (asyncQueryJobMetadata.getSessionId() == null) {
-      // either empty json when the result is not available or data with status
-      // Fetch from Result Index
-      result =
-          jobExecutionResponseReader.getResultFromOpensearchIndex(
-              asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex());
+    if (asyncQueryJobMetadata.getSessionId() != null) {
+      return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader)
+          .getQueryResponse(asyncQueryJobMetadata);
     } else {
-      // when session enabled, jobId in asyncQueryJobMetadata is actually queryId.
-      result =
-          jobExecutionResponseReader.getResultWithQueryId(
-              asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex());
+      return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader)
+          .getQueryResponse(asyncQueryJobMetadata);
     }
-    // if result index document has a status, we are gonna use the status directly; otherwise, we
-    // will use emr-s job status.
-    // That a job is successful does not mean there is no error in execution. For example, even if
-    // result
-    // index mapping is incorrect, we still write query result and let the job finish.
-    // That a job is running does not mean the status is running. For example, index/streaming Query
-    // is a
-    // long-running job which runs forever. But we need to return success from the result index
-    // immediately.
-    if (result.has(DATA_FIELD)) {
-      JSONObject items = result.getJSONObject(DATA_FIELD);
-
-      // If items have STATUS_FIELD, use it; otherwise, mark failed
-      String status = items.optString(STATUS_FIELD, JobRunState.FAILED.toString());
-      result.put(STATUS_FIELD, status);
-
-      // If items have ERROR_FIELD, use it; otherwise, set empty string
-      String error = items.optString(ERROR_FIELD, "");
-      result.put(ERROR_FIELD, error);
-    } else {
-      if (asyncQueryJobMetadata.getSessionId() != null) {
-        SessionId sessionId = new SessionId(asyncQueryJobMetadata.getSessionId());
-        Optional<Session> session = sessionManager.getSession(sessionId);
-        if (session.isPresent()) {
-          // todo, statementId == jobId if statement running in session.
-          StatementId statementId = new StatementId(asyncQueryJobMetadata.getJobId());
-          Optional<Statement> statement = session.get().get(statementId);
-          if (statement.isPresent()) {
-            StatementState statementState = statement.get().getStatementState();
-            result.put(STATUS_FIELD, statementState.getState());
-            result.put(ERROR_FIELD, "");
-          } else {
-            throw new IllegalArgumentException("no statement found. " + statementId);
-          }
-        } else {
-          throw new IllegalArgumentException("no session found. " + sessionId);
-        }
-      } else {
-        // make call to EMR Serverless when related result index documents are not available
-        GetJobRunResult getJobRunResult =
-            emrServerlessClient.getJobRunResult(
-                asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId());
-        String jobState = getJobRunResult.getJobRun().getState();
-        result.put(STATUS_FIELD, jobState);
-        result.put(ERROR_FIELD, "");
-      }
-    }
-
-    return result;
   }

   public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
     if (asyncQueryJobMetadata.getSessionId() != null) {
-      SessionId sessionId = new SessionId(asyncQueryJobMetadata.getSessionId());
-      Optional<Session> session = sessionManager.getSession(sessionId);
-      if (session.isPresent()) {
-        // todo, statementId == jobId if statement running in session.
-        StatementId statementId = new StatementId(asyncQueryJobMetadata.getJobId());
-        Optional<Statement> statement = session.get().get(statementId);
-        if (statement.isPresent()) {
-          statement.get().cancel();
-          return statementId.getId();
-        } else {
-          throw new IllegalArgumentException("no statement found. " + statementId);
-        }
-      } else {
-        throw new IllegalArgumentException("no session found. " + sessionId);
-      }
+      return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader)
+          .cancelJob(asyncQueryJobMetadata);
     } else {
-      CancelJobRunResult cancelJobRunResult =
-          emrServerlessClient.cancelJobRun(
-              asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId());
-      return cancelJobRunResult.getJobRunId();
+      return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader)
+          .cancelJob(asyncQueryJobMetadata);
     }
   }

@@ -229,12 +150,18 @@ private DispatchQueryResponse handleIndexQuery(
             indexDetails.getAutoRefresh(),
             dataSourceMetadata.getResultIndex());
     String jobId = emrServerlessClient.startJobRun(startJobRequest);
-    return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex(), null);
+    return new DispatchQueryResponse(
+        AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()),
+        jobId,
+        false,
+        dataSourceMetadata.getResultIndex(),
+        null);
   }

   private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQueryRequest) {
     DataSourceMetadata dataSourceMetadata =
         this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource());
+    AsyncQueryId queryId = AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName());
     dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata);
     String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query";
     Map<String, String> tags = getDefaultTagsForJobSubmission(dispatchQueryRequest);
@@ -267,12 +194,12 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ
               dataSourceMetadata.getResultIndex(),
               dataSourceMetadata.getName()));
     }
-      StatementId statementId =
-          session.submit(
-              new QueryRequest(
-                  dispatchQueryRequest.getLangType(), dispatchQueryRequest.getQuery()));
+      session.submit(
+          new QueryRequest(
+              queryId, dispatchQueryRequest.getLangType(), dispatchQueryRequest.getQuery()));
       return new DispatchQueryResponse(
-          statementId.getId(),
+          queryId,
+          session.getSessionModel().getJobId(),
           false,
           dataSourceMetadata.getResultIndex(),
           session.getSessionId().getSessionId());
@@ -294,7 +221,8 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ
           false,
           dataSourceMetadata.getResultIndex());
       String jobId = emrServerlessClient.startJobRun(startJobRequest);
-      return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex(), null);
+      return new DispatchQueryResponse(
+          queryId, jobId, false, dataSourceMetadata.getResultIndex(), null);
     }
   }

@@ -325,7 +253,11 @@ private DispatchQueryResponse handleDropIndexQuery(
       }
     }
     return new DispatchQueryResponse(
-        new DropIndexResult(status).toJobId(), true, dataSourceMetadata.getResultIndex(), null);
+        AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()),
+        new DropIndexResult(status).toJobId(),
+        true,
+        dataSourceMetadata.getResultIndex(),
+        null);
   }

   private static Map<String, String> getDefaultTagsForJobSubmission(
diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java
index 893446c617..e44379daff 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryResponse.java
@@ -2,10 +2,12 @@

 import lombok.AllArgsConstructor;
 import lombok.Data;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;

 @Data
 @AllArgsConstructor
 public class DispatchQueryResponse {
+  private AsyncQueryId queryId;
   private String jobId;
   private boolean isDropIndexQuery;
   private String resultIndex;
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java
index 4428c3b83d..a2e7cfe6ee 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java
@@ -81,7 +81,8 @@ public StatementId submit(QueryRequest request) {
     } else {
       sessionModel = model.get();
       if (!END_STATE.contains(sessionModel.getSessionState())) {
-        StatementId statementId = newStatementId();
+        String qid = request.getQueryId().getId();
+        StatementId statementId = newStatementId(qid);
         Statement st =
             Statement.builder()
                 .sessionId(sessionId)
@@ -92,7 +93,7 @@ public StatementId submit(QueryRequest request) {
                 .langType(LangType.SQL)
                 .datasourceName(sessionModel.getDatasourceName())
                 .query(request.getQuery())
-                .queryId(statementId.getId())
+                .queryId(qid)
                 .build();
         st.open();
         return statementId;
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java
index b3bd716925..c85e4dd35c 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java
@@ -5,10 +5,10 @@

 package org.opensearch.sql.spark.execution.session;

-import java.nio.charset.StandardCharsets;
-import java.util.Base64;
+import static org.opensearch.sql.spark.utils.IDUtils.decode;
+import static org.opensearch.sql.spark.utils.IDUtils.encode;
+
 import lombok.Data;
-import org.apache.commons.lang3.RandomStringUtils;

 @Data
 public class SessionId {
@@ -24,15 +24,6 @@ public String getDataSourceName() {
     return decode(sessionId);
   }

-  private static String decode(String sessionId) {
-    return new String(Base64.getDecoder().decode(sessionId)).substring(PREFIX_LEN);
-  }
-
-  private static String encode(String datasourceName) {
-    String randomId = RandomStringUtils.randomAlphanumeric(PREFIX_LEN) + datasourceName;
-    return Base64.getEncoder().encodeToString(randomId.getBytes(StandardCharsets.UTF_8));
-  }
-
   @Override
   public String toString() {
     return "sessionId=" + sessionId;
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java
index 10061404ca..c365265224 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java
@@ -6,10 +6,12 @@

 package org.opensearch.sql.spark.execution.statement;

 import lombok.Data;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
 import org.opensearch.sql.spark.rest.model.LangType;

 @Data
 public class QueryRequest {
+  private final AsyncQueryId queryId;
   private final LangType langType;
   private final String query;
 }
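[Review note] QueryRequest now carries the AsyncQueryId so the externally visible query id doubles as the statement id inside a session (the StatementId change below makes that explicit). A sketch of the interactive path, assuming an open Session:

    AsyncQueryId queryId = AsyncQueryId.newAsyncQueryId("mys3");
    StatementId statementId =
        session.submit(new QueryRequest(queryId, LangType.SQL, "select 1"));
    // After this patch the two handles coincide:
    assert statementId.getId().equals(queryId.getId());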
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java
index d9381ad45f..33284c4b3d 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java
@@ -6,14 +6,14 @@
 package org.opensearch.sql.spark.execution.statement;

 import lombok.Data;
-import org.apache.commons.lang3.RandomStringUtils;

 @Data
 public class StatementId {
   private final String id;

-  public static StatementId newStatementId() {
-    return new StatementId(RandomStringUtils.randomAlphanumeric(16));
+  // construct statementId from queryId.
+  public static StatementId newStatementId(String qid) {
+    return new StatementId(qid);
   }

   @Override
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java
index a36ee3ef45..6546d303fb 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java
@@ -38,6 +38,7 @@
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.ToXContent;
 import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.execution.session.SessionModel;
 import org.opensearch.sql.spark.execution.session.SessionState;
 import org.opensearch.sql.spark.execution.statement.StatementModel;
@@ -53,7 +54,6 @@ public class StateStore {
   public static String MAPPING_FILE_NAME = "query_execution_request_mapping.yml";
   public static Function<String, String> DATASOURCE_TO_REQUEST_INDEX =
       datasourceName -> String.format("%s_%s", SPARK_REQUEST_BUFFER_INDEX_NAME, datasourceName);
-  public static String ALL_REQUEST_INDEX = String.format("%s_*", SPARK_REQUEST_BUFFER_INDEX_NAME);

   private static final Logger LOG = LogManager.getLogger();

@@ -77,7 +77,6 @@ protected <T extends StateModel> T create(
     try (ThreadContext.StoredContext ignored =
         client.threadPool().getThreadContext().stashContext()) {
       IndexResponse indexResponse = client.index(indexRequest).actionGet();
-      ;
       if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) {
         LOG.debug("Successfully created doc. id: {}", st.getId());
         return builder.of(st, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm());
@@ -227,10 +226,6 @@ public static Function<String, Optional<SessionModel>> getSession(
         docId, SessionModel::fromXContent, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
   }

-  public static Function<String, Optional<SessionModel>> searchSession(StateStore stateStore) {
-    return (docId) -> stateStore.get(docId, SessionModel::fromXContent, ALL_REQUEST_INDEX);
-  }
-
   public static BiFunction<SessionModel, SessionState, SessionModel> updateSessionState(
       StateStore stateStore, String datasourceName) {
     return (old, state) ->
@@ -241,8 +236,21 @@ public static BiFunction<SessionModel, SessionState, SessionModel> updateSession
             DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
   }

-  public static Runnable createStateStoreIndex(StateStore stateStore, String datasourceName) {
-    String indexName = String.format("%s_%s", SPARK_REQUEST_BUFFER_INDEX_NAME, datasourceName);
-    return () -> stateStore.createIndex(indexName);
+  public static Function<AsyncQueryJobMetadata, AsyncQueryJobMetadata> createJobMetaData(
+      StateStore stateStore, String datasourceName) {
+    return (jobMetadata) ->
+        stateStore.create(
+            jobMetadata,
+            AsyncQueryJobMetadata::copy,
+            DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
+  }
+
+  public static Function<String, Optional<AsyncQueryJobMetadata>> getJobMetaData(
+      StateStore stateStore, String datasourceName) {
+    return (docId) ->
+        stateStore.get(
+            docId,
+            AsyncQueryJobMetadata::fromXContent,
+            DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
   }
 }
diff --git a/spark/src/main/java/org/opensearch/sql/spark/utils/IDUtils.java b/spark/src/main/java/org/opensearch/sql/spark/utils/IDUtils.java
new file mode 100644
index 0000000000..438d2342b4
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/utils/IDUtils.java
@@ -0,0 +1,25 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.utils;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+import lombok.experimental.UtilityClass;
+import org.apache.commons.lang3.RandomStringUtils;
+
+@UtilityClass
+public class IDUtils {
+  public static final int PREFIX_LEN = 10;
+
+  public static String decode(String id) {
+    return new String(Base64.getDecoder().decode(id)).substring(PREFIX_LEN);
+  }
+
+  public static String encode(String datasourceName) {
+    String randomId = RandomStringUtils.randomAlphanumeric(PREFIX_LEN) + datasourceName;
+    return Base64.getEncoder().encodeToString(randomId.getBytes(StandardCharsets.UTF_8));
+  }
+}
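[Review note] encode() prepends PREFIX_LEN (10) random alphanumeric characters before Base64-encoding, so ids stay unique while the datasource name stays recoverable; decode() reverses it by stripping the prefix after decoding. A self-contained round-trip check (illustrative):

    import org.opensearch.sql.spark.utils.IDUtils;

    public class IdRoundTrip {
      public static void main(String[] args) {
        String id = IDUtils.encode("mys3");     // Base64("<10 random chars>" + "mys3")
        String datasource = IDUtils.decode(id); // strips the random prefix
        System.out.println(datasource);         // prints "mys3"
      }
    }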
diff --git a/spark/src/main/resources/job-metadata-index-mapping.yml b/spark/src/main/resources/job-metadata-index-mapping.yml
deleted file mode 100644
index 3a39b989a2..0000000000
--- a/spark/src/main/resources/job-metadata-index-mapping.yml
+++ /dev/null
@@ -1,25 +0,0 @@
----
-##
-# Copyright OpenSearch Contributors
-# SPDX-License-Identifier: Apache-2.0
-##
-
-# Schema file for the .ql-job-metadata index
-# Also "dynamic" is set to "false" so that other fields can be added.
-dynamic: false
-properties:
-  jobId:
-    type: text
-    fields:
-      keyword:
-        type: keyword
-  applicationId:
-    type: text
-    fields:
-      keyword:
-        type: keyword
-  resultIndex:
-    type: text
-    fields:
-      keyword:
-        type: keyword
\ No newline at end of file
diff --git a/spark/src/main/resources/job-metadata-index-settings.yml b/spark/src/main/resources/job-metadata-index-settings.yml
deleted file mode 100644
index be93f4645c..0000000000
--- a/spark/src/main/resources/job-metadata-index-settings.yml
+++ /dev/null
@@ -1,11 +0,0 @@
----
-##
-# Copyright OpenSearch Contributors
-# SPDX-License-Identifier: Apache-2.0
-##
-
-# Settings file for the .ql-job-metadata index
-index:
-  number_of_shards: "1"
-  auto_expand_replicas: "0-2"
-  number_of_replicas: "0"
\ No newline at end of file
diff --git a/spark/src/main/resources/query_execution_request_mapping.yml b/spark/src/main/resources/query_execution_request_mapping.yml
index 87bd927e6e..fbe90a1cba 100644
--- a/spark/src/main/resources/query_execution_request_mapping.yml
+++ b/spark/src/main/resources/query_execution_request_mapping.yml
@@ -8,6 +8,8 @@
 # Also "dynamic" is set to "false" so that other fields can be added.
 dynamic: false
 properties:
+  version:
+    type: keyword
   type:
     type: keyword
   state:
diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java
index 3eb8958eb2..1ee119df78 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java
@@ -284,8 +284,9 @@ private DataSourceServiceImpl createDataSourceService() {
   private AsyncQueryExecutorService createAsyncQueryExecutorService(
       EMRServerlessClient emrServerlessClient) {
+    StateStore stateStore = new StateStore(client, clusterService);
     AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService =
-        new OpensearchAsyncQueryJobMetadataStorageService(client, clusterService);
+        new OpensearchAsyncQueryJobMetadataStorageService(stateStore);
     JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);
     SparkQueryDispatcher sparkQueryDispatcher =
         new SparkQueryDispatcher(
@@ -295,8 +296,7 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService(
             jobExecutionResponseReader,
             new FlintIndexMetadataReaderImpl(client),
             client,
-            new SessionManager(
-                new StateStore(client, clusterService), emrServerlessClient, pluginSettings));
+            new SessionManager(stateStore, emrServerlessClient, pluginSettings));
     return new AsyncQueryExecutorServiceImpl(
         asyncQueryJobMetadataStorageService,
         sparkQueryDispatcher,
diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java
index 0d4e280b61..2ed316795f 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java
@@ -11,6 +11,7 @@
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.when;
+import static org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.DS_NAME;
 import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID;
 import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID;
 import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME;
@@ -29,6 +30,7 @@
 import org.mockito.junit.jupiter.MockitoExtension;
 import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
 import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
@@ -47,6 +49,7 @@ public class AsyncQueryExecutorServiceImplTest {
   private AsyncQueryExecutorService jobExecutorService;

   @Mock private SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier;
+  private final AsyncQueryId QUERY_ID = AsyncQueryId.newAsyncQueryId(DS_NAME);

   @BeforeEach
   void setUp() {
@@ -78,11 +81,12 @@ void testCreateAsyncQuery() {
                 LangType.SQL,
                 "arn:aws:iam::270824043731:role/emr-job-execution-role",
                 TEST_CLUSTER_NAME)))
-        .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null, null));
+        .thenReturn(new DispatchQueryResponse(QUERY_ID, EMR_JOB_ID, false, null, null));
     CreateAsyncQueryResponse createAsyncQueryResponse =
         jobExecutorService.createAsyncQuery(createAsyncQueryRequest);
     verify(asyncQueryJobMetadataStorageService, times(1))
-        .storeJobMetadata(new AsyncQueryJobMetadata("00fd775baqpu4g0p", EMR_JOB_ID, null));
+        .storeJobMetadata(
+            new AsyncQueryJobMetadata(QUERY_ID, "00fd775baqpu4g0p", EMR_JOB_ID, null));
     verify(sparkExecutionEngineConfigSupplier, times(1)).getSparkExecutionEngineConfig();
     verify(sparkQueryDispatcher, times(1))
         .dispatch(
@@ -93,7 +97,7 @@ void testCreateAsyncQuery() {
                 LangType.SQL,
                 "arn:aws:iam::270824043731:role/emr-job-execution-role",
                 TEST_CLUSTER_NAME));
-    Assertions.assertEquals(EMR_JOB_ID, createAsyncQueryResponse.getQueryId());
+    Assertions.assertEquals(QUERY_ID.getId(), createAsyncQueryResponse.getQueryId());
   }

   @Test
@@ -107,7 +111,7 @@ void testCreateAsyncQueryWithExtraSparkSubmitParameter() {
             "--conf spark.dynamicAllocation.enabled=false",
             TEST_CLUSTER_NAME));
     when(sparkQueryDispatcher.dispatch(any()))
-        .thenReturn(new DispatchQueryResponse(EMR_JOB_ID, false, null, null));
+        .thenReturn(new DispatchQueryResponse(QUERY_ID, EMR_JOB_ID, false, null, null));

     jobExecutorService.createAsyncQuery(
         new CreateAsyncQueryRequest(
@@ -139,11 +143,13 @@ void testGetAsyncQueryResultsWithJobNotFoundException() {
   @Test
   void testGetAsyncQueryResultsWithInProgressJob() {
     when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
-        .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
+        .thenReturn(
+            Optional.of(
+                new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
     JSONObject jobResult = new JSONObject();
     jobResult.put("status", JobRunState.PENDING.toString());
     when(sparkQueryDispatcher.getQueryResponse(
-            new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
+            new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
        .thenReturn(jobResult);

     AsyncQueryExecutionResponse asyncQueryExecutionResponse =
         jobExecutorService.getAsyncQueryResults(EMR_JOB_ID);
@@ -157,11 +163,13 @@ void testGetAsyncQueryResultsWithInProgressJob() {

   @Test
   void testGetAsyncQueryResultsWithSuccessJob() throws IOException {
     when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
-        .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
+        .thenReturn(
+            Optional.of(
+                new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
     JSONObject jobResult = new JSONObject(getJson("select_query_response.json"));
     jobResult.put("status", JobRunState.SUCCESS.toString());
     when(sparkQueryDispatcher.getQueryResponse(
-            new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
+            new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
         .thenReturn(jobResult);

     AsyncQueryExecutionResponse asyncQueryExecutionResponse =
@@ -208,9 +216,11 @@ void testCancelJobWithJobNotFound() {
   @Test
   void testCancelJob() {
     when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
-        .thenReturn(Optional.of(new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
+        .thenReturn(
+            Optional.of(
+                new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
     when(sparkQueryDispatcher.cancelJob(
-            new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
+            new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
         .thenReturn(EMR_JOB_ID);
     String jobId = jobExecutorService.cancelQuery(EMR_JOB_ID);
     Assertions.assertEquals(EMR_JOB_ID, jobId);
diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java
index 7288fd3fc2..de0caf5589 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java
@@ -5,242 +5,70 @@

 package org.opensearch.sql.spark.asyncquery;

-import static org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService.JOB_METADATA_INDEX;
 import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID;
 import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID;

 import java.util.Optional;
-import org.apache.lucene.search.TotalHits;
-import org.junit.jupiter.api.Assertions;
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.extension.ExtendWith;
-import org.mockito.Answers;
-import org.mockito.ArgumentMatchers;
-import org.mockito.InjectMocks;
-import org.mockito.Mock;
-import org.mockito.Mockito;
-import org.mockito.junit.jupiter.MockitoExtension;
-import org.opensearch.action.DocWriteResponse;
-import org.opensearch.action.admin.indices.create.CreateIndexResponse;
-import org.opensearch.action.index.IndexResponse;
-import org.opensearch.action.search.SearchResponse;
-import org.opensearch.client.Client;
-import org.opensearch.cluster.service.ClusterService;
-import org.opensearch.common.action.ActionFuture;
-import org.opensearch.core.rest.RestStatus;
-import org.opensearch.search.SearchHit;
-import org.opensearch.search.SearchHits;
+import org.junit.Before;
+import org.junit.Test;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
+import org.opensearch.sql.spark.execution.statestore.StateStore;
+import org.opensearch.test.OpenSearchIntegTestCase;

-@ExtendWith(MockitoExtension.class)
-public class OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest {
+public class OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest
+    extends OpenSearchIntegTestCase {

-  @Mock(answer = Answers.RETURNS_DEEP_STUBS)
-  private Client client;
-
-  @Mock(answer = Answers.RETURNS_DEEP_STUBS)
-  private ClusterService clusterService;
-
-  @Mock(answer = Answers.RETURNS_DEEP_STUBS)
-  private SearchResponse searchResponse;
-
-  @Mock private ActionFuture<SearchResponse> searchResponseActionFuture;
-  @Mock private ActionFuture<CreateIndexResponse> createIndexResponseActionFuture;
-  @Mock private ActionFuture<IndexResponse> indexResponseActionFuture;
-  @Mock private IndexResponse indexResponse;
-  @Mock private SearchHit searchHit;
-
-  @InjectMocks
+  public static final String DS_NAME = "mys3";
+  private static final String MOCK_SESSION_ID = "sessionId";
+  private static final String MOCK_RESULT_INDEX = "resultIndex";
   private OpensearchAsyncQueryJobMetadataStorageService opensearchJobMetadataStorageService;

-  @Test
-  public void testStoreJobMetadata() {
-
-    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
-        .thenReturn(Boolean.FALSE);
-    Mockito.when(client.admin().indices().create(ArgumentMatchers.any()))
-        .thenReturn(createIndexResponseActionFuture);
-    Mockito.when(createIndexResponseActionFuture.actionGet())
-        .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX));
-    Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture);
-    Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse);
-    Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED);
-    AsyncQueryJobMetadata asyncQueryJobMetadata =
-        new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null);
-
-    this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata);
-
-    Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any());
-    Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any());
-    Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext();
-  }
-
-  @Test
-  public void testStoreJobMetadataWithOutCreatingIndex() {
-    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
-        .thenReturn(Boolean.TRUE);
-    Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture);
-    Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse);
-    Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.CREATED);
-    AsyncQueryJobMetadata asyncQueryJobMetadata =
-        new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null);
-
-    this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata);
-
-    Mockito.verify(client.admin().indices(), Mockito.times(0)).create(ArgumentMatchers.any());
-    Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any());
-    Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(1)).stashContext();
-  }
-
-  @Test
-  public void testStoreJobMetadataWithException() {
-
-    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
-        .thenReturn(Boolean.FALSE);
-    Mockito.when(client.admin().indices().create(ArgumentMatchers.any()))
-        .thenReturn(createIndexResponseActionFuture);
-    Mockito.when(createIndexResponseActionFuture.actionGet())
-        .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX));
-    Mockito.when(client.index(ArgumentMatchers.any()))
-        .thenThrow(new RuntimeException("error while indexing"));
-
-    AsyncQueryJobMetadata asyncQueryJobMetadata =
-        new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null);
-    RuntimeException runtimeException =
-        Assertions.assertThrows(
-            RuntimeException.class,
-            () -> this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata));
-    Assertions.assertEquals(
-        "java.lang.RuntimeException: error while indexing", runtimeException.getMessage());
-
-    Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any());
-    Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any());
-    Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext();
-  }
-
-  @Test
-  public void testStoreJobMetadataWithIndexCreationFailed() {
-
-    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
-        .thenReturn(Boolean.FALSE);
-    Mockito.when(client.admin().indices().create(ArgumentMatchers.any()))
-        .thenReturn(createIndexResponseActionFuture);
-    Mockito.when(createIndexResponseActionFuture.actionGet())
-        .thenReturn(new CreateIndexResponse(false, false, JOB_METADATA_INDEX));
-
-    AsyncQueryJobMetadata asyncQueryJobMetadata =
-        new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null);
-    RuntimeException runtimeException =
-        Assertions.assertThrows(
-            RuntimeException.class,
-            () -> this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata));
-    Assertions.assertEquals(
-        "Internal server error while creating.ql-job-metadata index:: "
-            + "Index creation is not acknowledged.",
-        runtimeException.getMessage());
-
-    Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any());
-    Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(1)).stashContext();
-  }
-
-  @Test
-  public void testStoreJobMetadataFailedWithNotFoundResponse() {
-
-    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
-        .thenReturn(Boolean.FALSE);
-    Mockito.when(client.admin().indices().create(ArgumentMatchers.any()))
-        .thenReturn(createIndexResponseActionFuture);
-    Mockito.when(createIndexResponseActionFuture.actionGet())
-        .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX));
-    Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture);
-    Mockito.when(indexResponseActionFuture.actionGet()).thenReturn(indexResponse);
-    Mockito.when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND);
-
-    AsyncQueryJobMetadata asyncQueryJobMetadata =
-        new AsyncQueryJobMetadata(EMR_JOB_ID, EMRS_APPLICATION_ID, null);
-    RuntimeException runtimeException =
-        Assertions.assertThrows(
-            RuntimeException.class,
-            () -> this.opensearchJobMetadataStorageService.storeJobMetadata(asyncQueryJobMetadata));
-    Assertions.assertEquals(
-        "Saving job metadata information failed with result : not_found",
-        runtimeException.getMessage());
-
-    Mockito.verify(client.admin().indices(), Mockito.times(1)).create(ArgumentMatchers.any());
-    Mockito.verify(client, Mockito.times(1)).index(ArgumentMatchers.any());
-    Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext();
-  }
-
-  @Test
-  public void testGetJobMetadata() {
-    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
-        .thenReturn(true);
-    Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture);
-    Mockito.when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse);
-    Mockito.when(searchResponse.status()).thenReturn(RestStatus.OK);
-    Mockito.when(searchResponse.getHits())
-        .thenReturn(
-            new SearchHits(
-                new SearchHit[] {searchHit}, new TotalHits(21, TotalHits.Relation.EQUAL_TO), 1.0F));
-    AsyncQueryJobMetadata asyncQueryJobMetadata =
-        new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null);
-    Mockito.when(searchHit.getSourceAsString()).thenReturn(asyncQueryJobMetadata.toString());
-
-    Optional<AsyncQueryJobMetadata> jobMetadataOptional =
-        opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID);
-    Assertions.assertTrue(jobMetadataOptional.isPresent());
-    Assertions.assertEquals(EMR_JOB_ID, jobMetadataOptional.get().getJobId());
-    Assertions.assertEquals(EMRS_APPLICATION_ID, jobMetadataOptional.get().getApplicationId());
+  @Before
+  public void setup() {
+    opensearchJobMetadataStorageService =
+        new OpensearchAsyncQueryJobMetadataStorageService(
+            new StateStore(client(), clusterService()));
   }

   @Test
-  public void testGetJobMetadataWith404SearchResponse() {
-    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
-        .thenReturn(true);
-    Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture);
-    Mockito.when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse);
-    Mockito.when(searchResponse.status()).thenReturn(RestStatus.NOT_FOUND);
-
-    RuntimeException runtimeException =
-        Assertions.assertThrows(
-            RuntimeException.class,
-            () -> opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID));
-    Assertions.assertEquals(
-        "Fetching job metadata information failed with status : NOT_FOUND",
-        runtimeException.getMessage());
-  }
-
-  @Test
-  public void testGetJobMetadataWithParsingFailed() {
-    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
-        .thenReturn(true);
-    Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture);
-    Mockito.when(searchResponseActionFuture.actionGet()).thenReturn(searchResponse);
-    Mockito.when(searchResponse.status()).thenReturn(RestStatus.OK);
-    Mockito.when(searchResponse.getHits())
-        .thenReturn(
-            new SearchHits(
-                new SearchHit[] {searchHit}, new TotalHits(21, TotalHits.Relation.EQUAL_TO), 1.0F));
-    Mockito.when(searchHit.getSourceAsString()).thenReturn("..tesJOBs");
-
-    Assertions.assertThrows(
-        RuntimeException.class,
-        () -> opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID));
+  public void testStoreJobMetadata() {
+    AsyncQueryJobMetadata expected =
+        new AsyncQueryJobMetadata(
+            AsyncQueryId.newAsyncQueryId(DS_NAME),
+            EMR_JOB_ID,
+            EMRS_APPLICATION_ID,
+            MOCK_RESULT_INDEX);
+
+    opensearchJobMetadataStorageService.storeJobMetadata(expected);
+    Optional<AsyncQueryJobMetadata> actual =
+        opensearchJobMetadataStorageService.getJobMetadata(expected.getQueryId().getId());
+
+    assertTrue(actual.isPresent());
+    assertEquals(expected, actual.get());
+    assertFalse(actual.get().isDropIndexQuery());
+    assertNull(actual.get().getSessionId());
   }

   @Test
-  public void testGetJobMetadataWithNoIndex() {
-    Mockito.when(clusterService.state().routingTable().hasIndex(JOB_METADATA_INDEX))
-        .thenReturn(Boolean.FALSE);
-    Mockito.when(client.admin().indices().create(ArgumentMatchers.any()))
-        .thenReturn(createIndexResponseActionFuture);
-    Mockito.when(createIndexResponseActionFuture.actionGet())
-        .thenReturn(new CreateIndexResponse(true, true, JOB_METADATA_INDEX));
-    Mockito.when(client.index(ArgumentMatchers.any())).thenReturn(indexResponseActionFuture);
-
-    Optional<AsyncQueryJobMetadata> jobMetadata =
opensearchJobMetadataStorageService.getJobMetadata(EMR_JOB_ID); - - Assertions.assertFalse(jobMetadata.isPresent()); + public void testStoreJobMetadataWithResultExtraData() { + AsyncQueryJobMetadata expected = + new AsyncQueryJobMetadata( + AsyncQueryId.newAsyncQueryId(DS_NAME), + EMR_JOB_ID, + EMRS_APPLICATION_ID, + true, + MOCK_RESULT_INDEX, + MOCK_SESSION_ID); + + opensearchJobMetadataStorageService.storeJobMetadata(expected); + Optional actual = + opensearchJobMetadataStorageService.getJobMetadata(expected.getQueryId().getId()); + + assertTrue(actual.isPresent()); + assertEquals(expected, actual.get()); + assertTrue(actual.get().isDropIndexQuery()); + assertEquals("resultIndex", actual.get().getResultIndex()); + assertEquals(MOCK_SESSION_ID, actual.get().getSessionId()); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 15211dec01..4acccae0e2 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -5,6 +5,7 @@ package org.opensearch.sql.spark.dispatcher; +import static org.mockito.Answers.RETURNS_DEEP_STUBS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.argThat; @@ -19,6 +20,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; +import static org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.DS_NAME; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; @@ -47,7 +49,6 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Answers; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; @@ -58,6 +59,7 @@ import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; @@ -86,19 +88,22 @@ public class SparkQueryDispatcherTest { @Mock private DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper; @Mock private FlintIndexMetadataReader flintIndexMetadataReader; - @Mock(answer = Answers.RETURNS_DEEP_STUBS) + @Mock(answer = RETURNS_DEEP_STUBS) private Client openSearchClient; @Mock private FlintIndexMetadata flintIndexMetadata; @Mock private SessionManager sessionManager; - @Mock private Session session; + @Mock(answer = RETURNS_DEEP_STUBS) + private Session session; @Mock private Statement statement; private SparkQueryDispatcher sparkQueryDispatcher; + private final AsyncQueryId QUERY_ID = AsyncQueryId.newAsyncQueryId(DS_NAME); + @Captor ArgumentCaptor startJobRequestArgumentCaptor; @BeforeEach @@ -285,6 +290,7 @@ void testDispatchSelectQueryCreateNewSession() { 
     doReturn(session).when(sessionManager).createSession(any());
     doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId();
     doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any());
+    when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID);
     DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
     when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata);
     doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata);
@@ -292,7 +298,7 @@ void testDispatchSelectQueryCreateNewSession() {
 
     verifyNoInteractions(emrServerlessClient);
     verify(sessionManager, never()).getSession(any());
-    Assertions.assertEquals(MOCK_STATEMENT_ID, dispatchQueryResponse.getJobId());
+    Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
     Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId());
   }
 
@@ -307,6 +313,7 @@ void testDispatchSelectQueryReuseSession() {
         .getSession(eq(new SessionId(MOCK_SESSION_ID)));
     doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId();
     doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any());
+    when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID);
     DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
     when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata);
     doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata);
@@ -314,7 +321,7 @@ void testDispatchSelectQueryReuseSession() {
 
     verifyNoInteractions(emrServerlessClient);
     verify(sessionManager, never()).createSession(any());
-    Assertions.assertEquals(MOCK_STATEMENT_ID, dispatchQueryResponse.getJobId());
+    Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId());
     Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId());
   }
 
@@ -636,10 +643,8 @@ void testCancelJob() {
             new CancelJobRunResult()
                 .withJobRunId(EMR_JOB_ID)
                 .withApplicationId(EMRS_APPLICATION_ID));
-    String jobId =
-        sparkQueryDispatcher.cancelJob(
-            new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null));
-    Assertions.assertEquals(EMR_JOB_ID, jobId);
+    String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata());
+    Assertions.assertEquals(QUERY_ID.getId(), queryId);
   }
 
   @Test
@@ -698,10 +703,8 @@ void testCancelQueryWithNoSessionId() {
             new CancelJobRunResult()
                 .withJobRunId(EMR_JOB_ID)
                 .withApplicationId(EMRS_APPLICATION_ID));
-    String jobId =
-        sparkQueryDispatcher.cancelJob(
-            new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null));
-    Assertions.assertEquals(EMR_JOB_ID, jobId);
+    String queryId = sparkQueryDispatcher.cancelJob(asyncQueryJobMetadata());
+    Assertions.assertEquals(QUERY_ID.getId(), queryId);
   }
 
   @Test
@@ -712,9 +715,7 @@ void testGetQueryResponse() {
     // simulate result index is not created yet
     when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null))
         .thenReturn(new JSONObject());
-    JSONObject result =
-        sparkQueryDispatcher.getQueryResponse(
-            new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null));
+    JSONObject result = sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata());
 
     Assertions.assertEquals("PENDING", result.get("status"));
   }
@@ -790,9 +791,7 @@ void testGetQueryResponseWithSuccess() {
     queryResult.put(DATA_FIELD, resultMap);
     when(jobExecutionResponseReader.getResultFromOpensearchIndex(EMR_JOB_ID, null))
         .thenReturn(queryResult);
-    JSONObject result =
-        sparkQueryDispatcher.getQueryResponse(
-            new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, EMR_JOB_ID, null));
+    JSONObject result = sparkQueryDispatcher.getQueryResponse(asyncQueryJobMetadata());
     verify(jobExecutionResponseReader, times(1)).getResultFromOpensearchIndex(EMR_JOB_ID, null);
     Assertions.assertEquals(
         new HashSet<>(Arrays.asList(DATA_FIELD, STATUS_FIELD, ERROR_FIELD)), result.keySet());
@@ -827,7 +826,13 @@ void testGetQueryResponseOfDropIndex() {
 
     JSONObject result =
         sparkQueryDispatcher.getQueryResponse(
-            new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, jobId, true, null, null));
+            new AsyncQueryJobMetadata(
+                AsyncQueryId.newAsyncQueryId(DS_NAME),
+                EMRS_APPLICATION_ID,
+                jobId,
+                true,
+                null,
+                null));
     verify(jobExecutionResponseReader, times(0))
         .getResultFromOpensearchIndex(anyString(), anyString());
     Assertions.assertEquals("SUCCESS", result.get(STATUS_FIELD));
@@ -1210,8 +1215,13 @@ private DispatchQueryRequest dispatchQueryRequestWithSessionId(String query, Str
         sessionId);
   }
 
+  private AsyncQueryJobMetadata asyncQueryJobMetadata() {
+    return new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null);
+  }
+
   private AsyncQueryJobMetadata asyncQueryJobMetadataWithSessionId(
-      String queryId, String sessionId) {
-    return new AsyncQueryJobMetadata(EMRS_APPLICATION_ID, queryId, false, null, sessionId);
+      String statementId, String sessionId) {
+    return new AsyncQueryJobMetadata(
+        new AsyncQueryId(statementId), EMRS_APPLICATION_ID, EMR_JOB_ID, false, null, sessionId);
   }
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
index ff3ddd1bef..1e33c8a6b9 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
@@ -22,6 +22,7 @@
 import org.junit.Test;
 import org.opensearch.action.admin.indices.delete.DeleteIndexRequest;
 import org.opensearch.action.delete.DeleteRequest;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
 import org.opensearch.sql.spark.execution.session.InteractiveSessionTest;
 import org.opensearch.sql.spark.execution.session.Session;
 import org.opensearch.sql.spark.execution.session.SessionId;
@@ -208,7 +209,7 @@ public void submitStatementInRunningSession() {
     // App change state to running
     updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING);
 
-    StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1"));
+    StatementId statementId = session.submit(queryRequest());
     assertFalse(statementId.getId().isEmpty());
   }
 
@@ -218,7 +219,7 @@ public void submitStatementInNotStartedState() {
         new SessionManager(stateStore, emrsClient, sessionSetting(false))
             .createSession(createSessionRequest());
 
-    StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1"));
+    StatementId statementId = session.submit(queryRequest());
     assertFalse(statementId.getId().isEmpty());
   }
 
@@ -231,9 +232,7 @@ public void failToSubmitStatementInDeadState() {
     updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.DEAD);
 
    IllegalStateException exception =
-        assertThrows(
-            IllegalStateException.class,
-            () -> session.submit(new QueryRequest(LangType.SQL, "select 1")));
+        assertThrows(IllegalStateException.class, () -> session.submit(queryRequest()));
     assertEquals(
         "can't submit statement, session should not be in end state, current session state is:"
             + " dead",
@@ -249,9 +248,7 @@ public void failToSubmitStatementInFailState() {
     updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.FAIL);
 
     IllegalStateException exception =
-        assertThrows(
-            IllegalStateException.class,
-            () -> session.submit(new QueryRequest(LangType.SQL, "select 1")));
+        assertThrows(IllegalStateException.class, () -> session.submit(queryRequest()));
     assertEquals(
         "can't submit statement, session should not be in end state, current session state is:"
            + " fail",
@@ -263,7 +260,7 @@ public void newStatementFieldAssert() {
     Session session =
         new SessionManager(stateStore, emrsClient, sessionSetting(false))
            .createSession(createSessionRequest());
-    StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1"));
+    StatementId statementId = session.submit(queryRequest());
     Optional<Statement> statement = session.get(statementId);
 
     assertTrue(statement.isPresent());
@@ -288,9 +285,7 @@ public void failToSubmitStatementInDeletedSession() {
         .actionGet();
 
     IllegalStateException exception =
-        assertThrows(
-            IllegalStateException.class,
-            () -> session.submit(new QueryRequest(LangType.SQL, "select 1")));
+        assertThrows(IllegalStateException.class, () -> session.submit(queryRequest()));
     assertEquals("session does not exist. " + session.getSessionId(), exception.getMessage());
   }
 
@@ -301,7 +296,7 @@ public void getStatementSuccess() {
         .createSession(createSessionRequest());
     // App change state to running
     updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING);
-    StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1"));
+    StatementId statementId = session.submit(queryRequest());
 
     Optional<Statement> statement = session.get(statementId);
     assertTrue(statement.isPresent());
@@ -317,7 +312,7 @@ public void getStatementNotExist() {
     // App change state to running
     updateSessionState(stateStore, DS_NAME).apply(session.getSessionModel(), SessionState.RUNNING);
 
-    Optional<Statement> statement = session.get(StatementId.newStatementId());
+    Optional<Statement> statement = session.get(StatementId.newStatementId("not-exist-id"));
     assertFalse(statement.isPresent());
   }
 
@@ -361,4 +356,8 @@ public TestStatement cancel() {
       return this;
     }
   }
+
+  private QueryRequest queryRequest() {
+    return new QueryRequest(AsyncQueryId.newAsyncQueryId(DS_NAME), LangType.SQL, "select 1");
+  }
 }
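
Taken together, the hunks above finish moving async-query job metadata off the bespoke .ql-job-metadata index and onto the shared StateStore, with queries addressed by a datasource-scoped AsyncQueryId rather than the EMR-S job id. The following is a minimal sketch of the resulting round trip, not part of the change itself: it uses only constructors and methods visible in this diff, the AsyncQueryJobMetadata argument order follows the dispatcher test's asyncQueryJobMetadata() helper, and the class name, applicationId, and jobId below are placeholder values for illustration.

    import java.util.Optional;
    import org.opensearch.client.Client;
    import org.opensearch.cluster.service.ClusterService;
    import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService;
    import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
    import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
    import org.opensearch.sql.spark.execution.statestore.StateStore;

    // Sketch only: mirrors what OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest exercises.
    class JobMetadataRoundTripSketch {
      static Optional<AsyncQueryJobMetadata> roundTrip(
          Client client, ClusterService clusterService, String applicationId, String jobId) {
        // A single StateStore now backs both job metadata and session state (see SQLPlugin above).
        StateStore stateStore = new StateStore(client, clusterService);
        OpensearchAsyncQueryJobMetadataStorageService storageService =
            new OpensearchAsyncQueryJobMetadataStorageService(stateStore);

        // Query ids are minted per datasource ("mys3" here) and travel with the metadata.
        AsyncQueryId queryId = AsyncQueryId.newAsyncQueryId("mys3");
        storageService.storeJobMetadata(
            new AsyncQueryJobMetadata(queryId, applicationId, jobId, /* resultIndex */ null));

        // Lookups are keyed by the query id string, which createAsyncQuery now returns to callers.
        return storageService.getJobMetadata(queryId.getId());
      }
    }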