Skip to content

Commit

Permalink
Refactor async executor service dependencies using guice framework
Browse files Browse the repository at this point in the history
Signed-off-by: Vamsi Manohar <[email protected]>
  • Loading branch information
vamsi-amazon committed Feb 1, 2024
1 parent e59bf75 commit a64b3f3
Show file tree
Hide file tree
Showing 20 changed files with 533 additions and 185 deletions.
101 changes: 3 additions & 98 deletions plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,14 @@

package org.opensearch.sql.plugin;

import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;
import static java.util.Collections.singletonList;
import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata;
import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE;

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.services.emrserverless.AWSEMRServerless;
import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.Clock;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;
Expand Down Expand Up @@ -68,7 +61,6 @@
import org.opensearch.sql.datasources.transport.*;
import org.opensearch.sql.legacy.esdomain.LocalClusterState;
import org.opensearch.sql.legacy.executor.AsyncRestExecutor;
import org.opensearch.sql.legacy.metrics.GaugeMetric;
import org.opensearch.sql.legacy.metrics.Metrics;
import org.opensearch.sql.legacy.plugin.RestSqlAction;
import org.opensearch.sql.legacy.plugin.RestSqlStatsAction;
Expand All @@ -86,22 +78,7 @@
import org.opensearch.sql.plugin.transport.TransportPPLQueryAction;
import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse;
import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorService;
import org.opensearch.sql.spark.asyncquery.AsyncQueryExecutorServiceImpl;
import org.opensearch.sql.spark.asyncquery.AsyncQueryJobMetadataStorageService;
import org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryJobMetadataStorageService;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.EmrServerlessClientImpl;
import org.opensearch.sql.spark.cluster.ClusterManagerEventListener;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplierImpl;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl;
import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager;
import org.opensearch.sql.spark.response.JobExecutionResponseReader;
import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction;
import org.opensearch.sql.spark.storage.SparkStorageFactory;
import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction;
Expand All @@ -127,7 +104,6 @@ public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin {

private NodeClient client;
private DataSourceServiceImpl dataSourceService;
private AsyncQueryExecutorService asyncQueryExecutorService;
private Injector injector;

public String name() {
Expand Down Expand Up @@ -223,23 +199,6 @@ public Collection<Object> createComponents(
dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata());
LocalClusterState.state().setClusterService(clusterService);
LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings);
SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier =
new SparkExecutionEngineConfigSupplierImpl(pluginSettings);
SparkExecutionEngineConfig sparkExecutionEngineConfig =
sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig();
if (StringUtils.isEmpty(sparkExecutionEngineConfig.getRegion())) {
LOGGER.warn(
String.format(
"Async Query APIs are disabled as %s is not configured properly in cluster settings. "
+ "Please configure and restart the domain to enable Async Query APIs",
SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue()));
this.asyncQueryExecutorService = new AsyncQueryExecutorServiceImpl();
} else {
this.asyncQueryExecutorService =
createAsyncQueryExecutorService(
sparkExecutionEngineConfigSupplier, sparkExecutionEngineConfig);
}

ModulesBuilder modules = new ModulesBuilder();
modules.add(new OpenSearchPluginModule());
modules.add(
Expand All @@ -260,13 +219,12 @@ public Collection<Object> createComponents(
OpenSearchSettings.RESULT_INDEX_TTL_SETTING,
OpenSearchSettings.AUTO_INDEX_MANAGEMENT_ENABLED_SETTING,
environment.settings());
return ImmutableList.of(
dataSourceService, asyncQueryExecutorService, clusterManagerEventListener, pluginSettings);
return ImmutableList.of(dataSourceService, clusterManagerEventListener, pluginSettings);
}

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
return Collections.singletonList(
return singletonList(
new FixedExecutorBuilder(
settings,
AsyncRestExecutor.SQL_WORKER_THREAD_POOL_NAME,
Expand Down Expand Up @@ -318,57 +276,4 @@ private DataSourceServiceImpl createDataSourceService() {
dataSourceMetadataStorage,
dataSourceUserAuthorizationHelper);
}

/**
 * Hand-wires the async query execution stack: state store (plus its gauge metrics), EMR
 * Serverless client for the configured region, response reader, query dispatcher, session
 * manager and lease manager, then assembles them into an AsyncQueryExecutorServiceImpl.
 * NOTE(review): this method is removed by this commit — per the commit title, this manual
 * wiring is being replaced by Guice-provided dependencies.
 */
private AsyncQueryExecutorService createAsyncQueryExecutorService(
SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier,
SparkExecutionEngineConfig sparkExecutionEngineConfig) {
StateStore stateStore = new StateStore(client, clusterService);
registerStateStoreMetrics(stateStore);
AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService =
new OpensearchAsyncQueryJobMetadataStorageService(stateStore);
EMRServerlessClient emrServerlessClient =
createEMRServerlessClient(sparkExecutionEngineConfig.getRegion());
JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client);
SparkQueryDispatcher sparkQueryDispatcher =
new SparkQueryDispatcher(
emrServerlessClient,
this.dataSourceService,
new DataSourceUserAuthorizationHelperImpl(client),
jobExecutionResponseReader,
new FlintIndexMetadataReaderImpl(client),
client,
new SessionManager(stateStore, emrServerlessClient, pluginSettings),
new DefaultLeaseManager(pluginSettings, stateStore),
stateStore);
return new AsyncQueryExecutorServiceImpl(
asyncQueryJobMetadataStorageService,
sparkQueryDispatcher,
sparkExecutionEngineConfigSupplier);
}

/**
 * Registers two gauge metrics exposing the number of active async-query sessions and
 * statements across all datasources, backed by live counts from the {@link StateStore}.
 * NOTE(review): removed by this commit together with the manual async-executor wiring.
 */
private void registerStateStoreMetrics(StateStore stateStore) {
GaugeMetric<Long> activeSessionMetric =
new GaugeMetric<>(
"active_async_query_sessions_count",
StateStore.activeSessionsCount(stateStore, ALL_DATASOURCE));
GaugeMetric<Long> activeStatementMetric =
new GaugeMetric<>(
"active_async_query_statements_count",
StateStore.activeStatementsCount(stateStore, ALL_DATASOURCE));
Metrics.getInstance().registerMetric(activeSessionMetric);
Metrics.getInstance().registerMetric(activeStatementMetric);
}

/**
 * Builds the AWS EMR Serverless client for the given region, using the default AWS credentials
 * chain. The SDK call is wrapped in AccessController.doPrivileged so it executes in a
 * privileged context.
 * NOTE(review): removed by this commit — superseded by EMRServerlessClientFactoryImpl, which
 * adds lazy creation and region-change handling.
 */
private EMRServerlessClient createEMRServerlessClient(String region) {
return AccessController.doPrivileged(
(PrivilegedAction<EMRServerlessClient>)
() -> {
AWSEMRServerless awsemrServerless =
AWSEMRServerlessClientBuilder.standard()
.withRegion(region)
.withCredentials(new DefaultAWSCredentialsProviderChain())
.build();
return new EmrServerlessClientImpl(awsemrServerless);
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package org.opensearch.sql.spark.client;

/** Factory abstraction for obtaining an {@link EMRServerlessClient}. */
public interface EMRServerlessClientFactory {

/**
 * Supplies the EMR Serverless client to use for the current Spark execution engine
 * configuration. Implementations may cache and reuse a single client instance.
 *
 * @return {@link EMRServerlessClient}
 */
EMRServerlessClient getClient();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package org.opensearch.sql.spark.client;

import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.services.emrserverless.AWSEMRServerless;
import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder;
import java.security.AccessController;
import java.security.PrivilegedAction;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfigSupplier;

@RequiredArgsConstructor
public class EMRServerlessClientFactoryImpl implements EMRServerlessClientFactory {

  private final SparkExecutionEngineConfigSupplier sparkExecutionEngineConfigSupplier;
  // Cached client and the region it was built for. NOTE(review): these mutable fields are
  // not synchronized — if getClient() can be reached from concurrent request threads, the
  // class needs synchronization or a volatile/immutable (region, client) pair; confirm the
  // calling context.
  private EMRServerlessClient emrServerlessClient;
  private String region;

  /**
   * Returns the EMR Serverless client for the currently configured region, lazily creating
   * the underlying AWS client on first use and re-creating it whenever the configured region
   * changes.
   *
   * @return cached or freshly created {@link EMRServerlessClient}
   * @throws IllegalArgumentException if the Spark execution engine config or its region is
   *     missing from cluster settings
   */
  @Override
  public EMRServerlessClient getClient() {
    SparkExecutionEngineConfig sparkExecutionEngineConfig =
        this.sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig();
    validateSparkExecutionEngineConfig(sparkExecutionEngineConfig);
    String configuredRegion = sparkExecutionEngineConfig.getRegion();
    if (isNewClientCreationRequired(configuredRegion)) {
      // Create the client before recording the region. The previous order assigned the new
      // region first, so a failed creation left the new region paired with the stale client
      // from the old region, and later calls would return that wrong client.
      this.emrServerlessClient = createEMRServerlessClient(configuredRegion);
      this.region = configuredRegion;
    }
    return this.emrServerlessClient;
  }

  /** A new client is required on first use or when the configured region differs from the cached one. */
  private boolean isNewClientCreationRequired(String region) {
    return !region.equals(this.region) || emrServerlessClient == null;
  }

  /**
   * Fails fast with an actionable message when async query support is unconfigured.
   *
   * @throws IllegalArgumentException if the config or its region is null
   */
  private void validateSparkExecutionEngineConfig(
      SparkExecutionEngineConfig sparkExecutionEngineConfig) {
    if (sparkExecutionEngineConfig == null || sparkExecutionEngineConfig.getRegion() == null) {
      throw new IllegalArgumentException(
          String.format(
              "Async Query APIs are disabled as %s is not configured in cluster settings. Please"
                  + " configure the setting and restart the domain to enable Async Query APIs",
              SPARK_EXECUTION_ENGINE_CONFIG.getKeyValue()));
    }
  }

  /**
   * Builds the AWS SDK client for the given region with the default credentials chain, inside
   * AccessController.doPrivileged so the SDK call runs in a privileged context.
   */
  private EMRServerlessClient createEMRServerlessClient(String awsRegion) {
    return AccessController.doPrivileged(
        (PrivilegedAction<EMRServerlessClient>)
            () -> {
              AWSEMRServerless awsemrServerless =
                  AWSEMRServerlessClientBuilder.standard()
                      .withRegion(awsRegion)
                      .withCredentials(new DefaultAWSCredentialsProviderChain())
                      .build();
              return new EmrServerlessClientImpl(awsemrServerless);
            });
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import java.util.HashMap;
import java.util.Map;
import lombok.AllArgsConstructor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.json.JSONObject;
import org.opensearch.client.Client;
import org.opensearch.sql.datasource.DataSourceService;
Expand All @@ -18,6 +16,8 @@
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
Expand All @@ -35,13 +35,12 @@
@AllArgsConstructor
public class SparkQueryDispatcher {

private static final Logger LOG = LogManager.getLogger();
public static final String INDEX_TAG_KEY = "index";
public static final String DATASOURCE_TAG_KEY = "datasource";
public static final String CLUSTER_NAME_TAG_KEY = "domain_ident";
public static final String JOB_TYPE_TAG_KEY = "type";

private EMRServerlessClient emrServerlessClient;
private EMRServerlessClientFactory emrServerlessClientFactory;

private DataSourceService dataSourceService;

Expand All @@ -60,10 +59,10 @@ public class SparkQueryDispatcher {
private StateStore stateStore;

public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) {
EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
DataSourceMetadata dataSourceMetadata =
this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource());
dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata);

AsyncQueryHandler asyncQueryHandler =
sessionManager.isEnabled()
? new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager)
Expand All @@ -83,7 +82,7 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest)
contextBuilder.indexQueryDetails(indexQueryDetails);

if (IndexQueryActionType.DROP.equals(indexQueryDetails.getIndexQueryActionType())) {
asyncQueryHandler = createIndexDMLHandler();
asyncQueryHandler = createIndexDMLHandler(emrServerlessClient);
} else if (IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType())
&& indexQueryDetails.isAutoRefresh()) {
asyncQueryHandler =
Expand All @@ -98,33 +97,37 @@ public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest)
return asyncQueryHandler.submit(dispatchQueryRequest, contextBuilder.build());
}

// NOTE(review): dead code — empty private method with no callers visible in this change;
// appears to be a leftover from moving client creation into EMRServerlessClientFactory and
// should be deleted.
private void createEmrServerlessClient(SparkExecutionEngineConfig sparkExecutionEngineConfig) {}

/**
 * Fetches the result of a previously submitted async query, routing to the handler that
 * matches how the query ran: interactive session (session id present), index DML (job id
 * marks a DML query), otherwise batch.
 *
 * @param asyncQueryJobMetadata stored metadata identifying the query's job/session
 * @return the query response as JSON
 */
public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) {
  EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
  if (asyncQueryJobMetadata.getSessionId() != null) {
    return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager)
        .getQueryResponse(asyncQueryJobMetadata);
  } else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) {
    // The pasted diff carried both the old zero-arg createIndexDMLHandler() call and this
    // updated one; only the factory-supplied-client form is valid.
    return createIndexDMLHandler(emrServerlessClient).getQueryResponse(asyncQueryJobMetadata);
  } else {
    return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager)
        .getQueryResponse(asyncQueryJobMetadata);
  }
}

/**
 * Cancels a running async query, selecting the handler the same way submission/result
 * retrieval does: interactive session, index DML, or batch.
 *
 * @param asyncQueryJobMetadata stored metadata identifying the query's job/session
 * @return the cancelled query/job identifier reported by the handler
 */
public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
  EMRServerlessClient emrServerlessClient = emrServerlessClientFactory.getClient();
  AsyncQueryHandler queryHandler;
  if (asyncQueryJobMetadata.getSessionId() != null) {
    queryHandler =
        new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager);
  } else if (IndexDMLHandler.isIndexDMLQuery(asyncQueryJobMetadata.getJobId())) {
    // The pasted diff carried a duplicate assignment from the old zero-arg
    // createIndexDMLHandler(); only the factory-supplied-client form is valid.
    queryHandler = createIndexDMLHandler(emrServerlessClient);
  } else {
    queryHandler =
        new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader, leaseManager);
  }
  return queryHandler.cancelJob(asyncQueryJobMetadata);
}

private IndexDMLHandler createIndexDMLHandler() {
private IndexDMLHandler createIndexDMLHandler(EMRServerlessClient emrServerlessClient) {
return new IndexDMLHandler(
emrServerlessClient,
dataSourceService,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import java.util.Optional;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.spark.client.EMRServerlessClient;
import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
import org.opensearch.sql.spark.execution.statestore.StateStore;
import org.opensearch.sql.spark.utils.RealTimeProvider;

Expand All @@ -21,13 +21,15 @@
*/
public class SessionManager {
private final StateStore stateStore;
private final EMRServerlessClient emrServerlessClient;
private final EMRServerlessClientFactory emrServerlessClientFactory;
private Settings settings;

public SessionManager(
StateStore stateStore, EMRServerlessClient emrServerlessClient, Settings settings) {
StateStore stateStore,
EMRServerlessClientFactory emrServerlessClientFactory,
Settings settings) {
this.stateStore = stateStore;
this.emrServerlessClient = emrServerlessClient;
this.emrServerlessClientFactory = emrServerlessClientFactory;
this.settings = settings;
}

Expand All @@ -36,7 +38,7 @@ public Session createSession(CreateSessionRequest request) {
InteractiveSession.builder()
.sessionId(newSessionId(request.getDatasourceName()))
.stateStore(stateStore)
.serverlessClient(emrServerlessClient)
.serverlessClient(emrServerlessClientFactory.getClient())
.build();
session.open(request);
return session;
Expand Down Expand Up @@ -68,7 +70,7 @@ public Optional<Session> getSession(SessionId sid, String dataSourceName) {
InteractiveSession.builder()
.sessionId(sid)
.stateStore(stateStore)
.serverlessClient(emrServerlessClient)
.serverlessClient(emrServerlessClientFactory.getClient())
.sessionModel(model.get())
.sessionInactivityTimeoutMilli(
settings.getSettingValue(SESSION_INACTIVITY_TIMEOUT_MILLIS))
Expand Down
Loading

0 comments on commit a64b3f3

Please sign in to comment.