diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java index db78abb2a8..9a73b0f364 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/SparkSubmitParameters.java @@ -31,6 +31,7 @@ public class SparkSubmitParameters { public static final String SPACE = " "; public static final String EQUALS = "="; + public static final String FLINT_BASIC_AUTH = "basic"; private final String className; private final Map config; @@ -114,7 +115,7 @@ private void setFlintIndexStoreAuthProperties( Supplier password, Supplier region) { if (AuthenticationType.get(authType).equals(AuthenticationType.BASICAUTH)) { - config.put(FLINT_INDEX_STORE_AUTH_KEY, authType); + config.put(FLINT_INDEX_STORE_AUTH_KEY, FLINT_BASIC_AUTH); config.put(FLINT_INDEX_STORE_AUTH_USERNAME, userName.get()); config.put(FLINT_INDEX_STORE_AUTH_PASSWORD, password.get()); } else if (AuthenticationType.get(authType).equals(AuthenticationType.AWSSIGV4AUTH)) { diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java index 24ea1528c8..52cc2efbe2 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -39,7 +39,7 @@ protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob Statement statement = getStatementByQueryId(asyncQueryJobMetadata.getSessionId(), queryId); StatementState statementState = statement.getStatementState(); result.put(STATUS_FIELD, statementState.getState()); - result.put(ERROR_FIELD, ""); + result.put(ERROR_FIELD, Optional.of(statement.getStatementModel().getError()).orElse("")); return result; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java index ca2b2b4867..b2201fbd01 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java @@ -21,14 +21,40 @@ public class CreateSessionRequest { private final String datasourceName; public StartJobRequest getStartJobRequest() { - return new StartJobRequest( + return new InteractiveSessionStartJobRequest( "select 1", jobName, applicationId, executionRoleArn, sparkSubmitParametersBuilder.build().toString(), tags, - false, resultIndex); } + + static class InteractiveSessionStartJobRequest extends StartJobRequest { + public InteractiveSessionStartJobRequest( + String query, + String jobName, + String applicationId, + String executionRoleArn, + String sparkSubmitParams, + Map tags, + String resultIndex) { + super( + query, + jobName, + applicationId, + executionRoleArn, + sparkSubmitParams, + tags, + false, + resultIndex); + } + + /** Interactive query keep running. */ + @Override + public Long executionTimeout() { + return 0L; + } + } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index 1ee119df78..19edd53eae 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -17,6 +17,7 @@ import static org.opensearch.sql.spark.execution.statement.StatementModel.STATEMENT_DOC_TYPE; import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX; import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; @@ -26,7 +27,9 @@ import com.google.common.collect.ImmutableSet; import java.util.Arrays; import java.util.Collection; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import lombok.Getter; import org.junit.After; @@ -109,7 +112,7 @@ public void setup() { "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole", "glue.indexstore.opensearch.uri", - "http://ec2-18-237-133-156.us-west-2.compute.amazonaws" + ".com:9200", + "http://localhost:9200", "glue.indexstore.opensearch.auth", "noauth"), null)); @@ -269,8 +272,114 @@ public void reuseSessionWhenCreateAsyncQuery() { assertEquals(second.getQueryId(), secondModel.get().getQueryId()); } + @Test + public void batchQueryHasTimeout() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + enableSession(false); + CreateAsyncQueryResponse response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + + assertEquals(120L, (long) emrsClient.getJobRequest().executionTimeout()); + } + + @Test + public void interactiveQueryNoTimeout() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + assertEquals(0L, (long) emrsClient.getJobRequest().executionTimeout()); + } + + @Test + public void datasourceWithBasicAuth() { + Map properties = new HashMap<>(); + properties.put("glue.auth.type", "iam_role"); + properties.put( + "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole"); + properties.put("glue.indexstore.opensearch.uri", "http://localhost:9200"); + properties.put("glue.indexstore.opensearch.auth", "basicauth"); + properties.put("glue.indexstore.opensearch.auth.username", "username"); + properties.put("glue.indexstore.opensearch.auth.password", "password"); + + dataSourceService.createDataSource( + new DataSourceMetadata( + "mybasicauth", DataSourceType.S3GLUE, ImmutableList.of(), properties, null)); + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", "mybasicauth", LangType.SQL, null)); + String params = emrsClient.getJobRequest().getSparkSubmitParams(); + assertTrue(params.contains(String.format("--conf spark.datasource.flint.auth=basic"))); + assertTrue( + params.contains(String.format("--conf spark.datasource.flint.auth.username=username"))); + assertTrue( + params.contains(String.format("--conf spark.datasource.flint.auth.password=password"))); + } + + @Test + public void withSessionCreateAsyncQueryFailed() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + + // 1. create async query. + CreateAsyncQueryResponse response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("myselect 1", DATASOURCE, LangType.SQL, null)); + assertNotNull(response.getSessionId()); + Optional statementModel = + getStatement(stateStore, DATASOURCE).apply(response.getQueryId()); + assertTrue(statementModel.isPresent()); + assertEquals(StatementState.WAITING, statementModel.get().getStatementState()); + + // 2. fetch async query result. not result write to SPARK_RESPONSE_BUFFER_INDEX_NAME yet. + // mock failed statement. + StatementModel submitted = statementModel.get(); + StatementModel mocked = + StatementModel.builder() + .version("1.0") + .statementState(submitted.getStatementState()) + .statementId(submitted.getStatementId()) + .sessionId(submitted.getSessionId()) + .applicationId(submitted.getApplicationId()) + .jobId(submitted.getJobId()) + .langType(submitted.getLangType()) + .datasourceName(submitted.getDatasourceName()) + .query(submitted.getQuery()) + .queryId(submitted.getQueryId()) + .submitTime(submitted.getSubmitTime()) + .error("mock error") + .seqNo(submitted.getSeqNo()) + .primaryTerm(submitted.getPrimaryTerm()) + .build(); + updateStatementState(stateStore, DATASOURCE).apply(mocked, StatementState.FAILED); + + AsyncQueryExecutionResponse asyncQueryResults = + asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId()); + assertEquals(StatementState.FAILED.getState(), asyncQueryResults.getStatus()); + assertEquals("mock error", asyncQueryResults.getError()); + } + private DataSourceServiceImpl createDataSourceService() { - String masterKey = "1234567890"; + String masterKey = "a57d991d9b573f75b9bba1df"; DataSourceMetadataStorage dataSourceMetadataStorage = new OpenSearchDataSourceMetadataStorage( client, clusterService, new EncryptorImpl(masterKey)); diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 700acb973e..a69c6e2b1a 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -99,7 +99,8 @@ public class SparkQueryDispatcherTest { @Mock(answer = RETURNS_DEEP_STUBS) private Session session; - @Mock private Statement statement; + @Mock(answer = RETURNS_DEEP_STUBS) + private Statement statement; private SparkQueryDispatcher sparkQueryDispatcher; @@ -184,7 +185,7 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { String query = "select * from my_glue.default.http_logs"; String sparkSubmitParameters = constructExpectedSparkSubmitParameterString( - "basicauth", + "basic", new HashMap<>() { { put(FLINT_INDEX_STORE_AUTH_USERNAME, "username"); @@ -783,6 +784,7 @@ void testGetQueryResponse() { void testGetQueryResponseWithSession() { doReturn(Optional.of(session)).when(sessionManager).getSession(new SessionId(MOCK_SESSION_ID)); doReturn(Optional.of(statement)).when(session).get(any()); + when(statement.getStatementModel().getError()).thenReturn("mock error"); doReturn(StatementState.WAITING).when(statement).getStatementState(); doReturn(new JSONObject())