diff --git a/build.gradle b/build.gradle
index ffd0153d04..6d3950d483 100644
--- a/build.gradle
+++ b/build.gradle
@@ -89,6 +89,8 @@ spotless {
                 'prometheus/**/*.java',
                 'sql/**/*.java',
                 'common/**/*.java',
+                'spark/**/*.java',
+                'plugin/**/*.java',
                 'ppl/**/*.java',
                 'integ-test/**/*java'
         exclude '**/build/**', '**/build-*/**'
diff --git a/plugin/build.gradle b/plugin/build.gradle
index 11f97ea857..8ec6844bfd 100644
--- a/plugin/build.gradle
+++ b/plugin/build.gradle
@@ -85,6 +85,9 @@ publishing {
     }
 }
 
+checkstyleTest.ignoreFailures = true
+checkstyleMain.ignoreFailures = true
+
 javadoc.enabled = false
 loggerUsageCheck.enabled = false
 dependencyLicenses.enabled = false
diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
index 5e156c2f5d..f20de87d61 100644
--- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
+++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
@@ -94,10 +94,10 @@ public class SQLPlugin extends Plugin implements ActionPlugin, ScriptPlugin {
   private static final Logger LOGGER = LogManager.getLogger(SQLPlugin.class);
 
   private ClusterService clusterService;
-  /**
-   * Settings should be inited when bootstrap the plugin.
-   */
+
+  /** Settings should be inited when bootstrap the plugin. */
   private org.opensearch.sql.common.setting.Settings pluginSettings;
+
   private NodeClient client;
   private DataSourceServiceImpl dataSourceService;
   private Injector injector;
@@ -134,23 +134,28 @@ public List getRestHandlers(
         new RestDataSourceQueryAction());
   }
 
-  /**
-   * Register action and handler so that transportClient can find proxy for action.
-   */
+  /** Register action and handler so that transportClient can find proxy for action. */
   @Override
   public List> getActions() {
     return Arrays.asList(
         new ActionHandler<>(
             new ActionType<>(PPLQueryAction.NAME, TransportPPLQueryResponse::new),
             TransportPPLQueryAction.class),
-        new ActionHandler<>(new ActionType<>(TransportCreateDataSourceAction.NAME,
-            CreateDataSourceActionResponse::new), TransportCreateDataSourceAction.class),
-        new ActionHandler<>(new ActionType<>(TransportGetDataSourceAction.NAME,
-            GetDataSourceActionResponse::new), TransportGetDataSourceAction.class),
-        new ActionHandler<>(new ActionType<>(TransportUpdateDataSourceAction.NAME,
-            UpdateDataSourceActionResponse::new), TransportUpdateDataSourceAction.class),
-        new ActionHandler<>(new ActionType<>(TransportDeleteDataSourceAction.NAME,
-            DeleteDataSourceActionResponse::new), TransportDeleteDataSourceAction.class));
+        new ActionHandler<>(
+            new ActionType<>(
+                TransportCreateDataSourceAction.NAME, CreateDataSourceActionResponse::new),
+            TransportCreateDataSourceAction.class),
+        new ActionHandler<>(
+            new ActionType<>(TransportGetDataSourceAction.NAME, GetDataSourceActionResponse::new),
+            TransportGetDataSourceAction.class),
+        new ActionHandler<>(
+            new ActionType<>(
+                TransportUpdateDataSourceAction.NAME, UpdateDataSourceActionResponse::new),
+            TransportUpdateDataSourceAction.class),
+        new ActionHandler<>(
+            new ActionType<>(
+                TransportDeleteDataSourceAction.NAME, DeleteDataSourceActionResponse::new),
+            TransportDeleteDataSourceAction.class));
   }
 
   @Override
@@ -176,11 +181,12 @@ public Collection createComponents(
     ModulesBuilder modules = new ModulesBuilder();
     modules.add(new OpenSearchPluginModule());
-    modules.add(b -> {
-      b.bind(NodeClient.class).toInstance((NodeClient) client);
-      b.bind(org.opensearch.sql.common.setting.Settings.class).toInstance(pluginSettings);
-      b.bind(DataSourceService.class).toInstance(dataSourceService);
-    });
+    modules.add(
+        b -> {
+          b.bind(NodeClient.class).toInstance((NodeClient) client);
+          b.bind(org.opensearch.sql.common.setting.Settings.class).toInstance(pluginSettings);
+          b.bind(DataSourceService.class).toInstance(dataSourceService);
+        });
 
     injector = modules.createInjector();
     return ImmutableList.of(dataSourceService);
@@ -212,30 +218,31 @@ public ScriptEngine getScriptEngine(Settings settings, Collection()
-            .add(new OpenSearchDataSourceFactory(
-                new OpenSearchNodeClient(this.client), pluginSettings))
+            .add(
+                new OpenSearchDataSourceFactory(
+                    new OpenSearchNodeClient(this.client), pluginSettings))
             .add(new PrometheusStorageFactory(pluginSettings))
             .add(new SparkStorageFactory(this.client, pluginSettings))
             .build(),
         dataSourceMetadataStorage,
         dataSourceUserAuthorizationHelper);
   }
-
 }
diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java b/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java
index f301a242fb..33a785c498 100644
--- a/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java
+++ b/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java
@@ -45,8 +45,7 @@ public class OpenSearchPluginModule extends AbstractModule {
       BuiltinFunctionRepository.getInstance();
 
   @Override
-  protected void configure() {
-  }
+  protected void configure() {}
 
   @Provides
   public OpenSearchClient openSearchClient(NodeClient nodeClient) {
@@ -59,8 +58,8 @@ public StorageEngine storageEngine(OpenSearchClient client, Settings settings) {
   }
 
   @Provides
-  public ExecutionEngine executionEngine(OpenSearchClient client, ExecutionProtector protector,
-      PlanSerializer planSerializer) {
+  public ExecutionEngine executionEngine(
+      OpenSearchClient client, ExecutionProtector protector, PlanSerializer planSerializer) {
     return new OpenSearchExecutionEngine(client, protector, planSerializer);
   }
 
   @Provides
@@ -95,18 +94,15 @@ public SQLService sqlService(QueryManager queryManager, QueryPlanFactory queryPl
     return new SQLService(new SQLSyntaxParser(), queryManager, queryPlanFactory);
   }
 
-  /**
-   * {@link QueryPlanFactory}.
-   */
+  /** {@link QueryPlanFactory}. */
   @Provides
-  public QueryPlanFactory queryPlanFactory(DataSourceService dataSourceService,
-      ExecutionEngine executionEngine) {
+  public QueryPlanFactory queryPlanFactory(
+      DataSourceService dataSourceService, ExecutionEngine executionEngine) {
     Analyzer analyzer =
         new Analyzer(
             new ExpressionAnalyzer(functionRepository), dataSourceService, functionRepository);
     Planner planner = new Planner(LogicalPlanOptimizer.create());
-    QueryService queryService = new QueryService(
-        analyzer, executionEngine, planner);
+    QueryService queryService = new QueryService(analyzer, executionEngine, planner);
     return new QueryPlanFactory(queryService);
   }
 }
diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/request/PPLQueryRequestFactory.java b/plugin/src/main/java/org/opensearch/sql/plugin/request/PPLQueryRequestFactory.java
index 730da0e923..ad734bf150 100644
--- a/plugin/src/main/java/org/opensearch/sql/plugin/request/PPLQueryRequestFactory.java
+++ b/plugin/src/main/java/org/opensearch/sql/plugin/request/PPLQueryRequestFactory.java
@@ -3,7 +3,6 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-
 package org.opensearch.sql.plugin.request;
 
 import java.util.Map;
@@ -15,9 +14,7 @@ import org.opensearch.sql.protocol.response.format.Format;
 import org.opensearch.sql.protocol.response.format.JsonResponseFormatter;
 
-/**
- * Factory of {@link PPLQueryRequest}.
- */
+/** Factory of {@link PPLQueryRequest}. */
 public class PPLQueryRequestFactory {
   private static final String PPL_URL_PARAM_KEY = "ppl";
   private static final String PPL_FIELD_NAME = "query";
@@ -28,6 +25,7 @@ public class PPLQueryRequestFactory {
 
   /**
    * Build {@link PPLQueryRequest} from {@link RestRequest}.
+   *
    * @param request {@link PPLQueryRequest}
   * @return {@link RestRequest}
   */
@@ -63,8 +61,12 @@ private static PPLQueryRequest parsePPLRequestFromPayload(RestRequest restReques
     } catch (JSONException e) {
       throw new IllegalArgumentException("Failed to parse request payload", e);
     }
-    PPLQueryRequest pplRequest = new PPLQueryRequest(jsonContent.getString(PPL_FIELD_NAME),
-        jsonContent, restRequest.path(), format.getFormatName());
+    PPLQueryRequest pplRequest =
+        new PPLQueryRequest(
+            jsonContent.getString(PPL_FIELD_NAME),
+            jsonContent,
+            restRequest.path(),
+            format.getFormatName());
     // set sanitize option if csv format
     if (format.equals(Format.CSV)) {
       pplRequest.sanitize(getSanitizeOption(restRequest.params()));
diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java
index 55f8dfdfef..996ae8c700 100644
--- a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java
+++ b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java
@@ -102,14 +102,17 @@ protected Set responseParams() {
   protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient nodeClient) {
     // TODO: need move to transport Action
     if (!pplEnabled.get()) {
-      return channel -> reportError(channel, new IllegalAccessException(
-          "Either plugins.ppl.enabled or rest.action.multi.allow_explicit_index setting is false"),
-          BAD_REQUEST);
+      return channel ->
+          reportError(
+              channel,
+              new IllegalAccessException(
+                  "Either plugins.ppl.enabled or rest.action.multi.allow_explicit_index setting is"
+                      + " false"),
+              BAD_REQUEST);
     }
 
-    TransportPPLQueryRequest transportPPLQueryRequest = new TransportPPLQueryRequest(
-        PPLQueryRequestFactory.getPPLRequest(request)
-    );
+    TransportPPLQueryRequest transportPPLQueryRequest =
+        new
TransportPPLQueryRequest(PPLQueryRequestFactory.getPPLRequest(request)); return channel -> nodeClient.execute( diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java index ef9f68a2a7..7a51fc282b 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.plugin.rest; import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; @@ -26,17 +25,14 @@ import org.opensearch.sql.legacy.executor.format.ErrorMessageFactory; import org.opensearch.sql.legacy.metrics.Metrics; -/** - * PPL Node level status. - */ +/** PPL Node level status. */ public class RestPPLStatsAction extends BaseRestHandler { private static final Logger LOG = LogManager.getLogger(RestPPLStatsAction.class); - /** - * API endpoint path. - */ + /** API endpoint path. */ public static final String PPL_STATS_API_ENDPOINT = "/_plugins/_ppl/stats"; + public static final String PPL_LEGACY_STATS_API_ENDPOINT = "/_opendistro/_ppl/stats"; public RestPPLStatsAction(Settings settings, RestController restController) { @@ -70,13 +66,18 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli QueryContext.addRequestId(); try { - return channel -> channel.sendResponse(new BytesRestResponse(RestStatus.OK, - Metrics.getInstance().collectToJSON())); + return channel -> + channel.sendResponse( + new BytesRestResponse(RestStatus.OK, Metrics.getInstance().collectToJSON())); } catch (Exception e) { LOG.error("Failed during Query PPL STATS Action.", e); - return channel -> channel.sendResponse(new BytesRestResponse(SERVICE_UNAVAILABLE, - ErrorMessageFactory.createErrorMessage(e, SERVICE_UNAVAILABLE.getStatus()).toString())); + return channel -> + channel.sendResponse( + new BytesRestResponse( + SERVICE_UNAVAILABLE, + ErrorMessageFactory.createErrorMessage(e, SERVICE_UNAVAILABLE.getStatus()) + .toString())); } } diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestQuerySettingsAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestQuerySettingsAction.java index b15b4dddd6..885c953c17 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestQuerySettingsAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestQuerySettingsAction.java @@ -39,9 +39,14 @@ public class RestQuerySettingsAction extends BaseRestHandler { private static final String LEGACY_SQL_SETTINGS_PREFIX = "opendistro.sql."; private static final String LEGACY_PPL_SETTINGS_PREFIX = "opendistro.ppl."; private static final String LEGACY_COMMON_SETTINGS_PREFIX = "opendistro.query."; - private static final List SETTINGS_PREFIX = ImmutableList.of( - SQL_SETTINGS_PREFIX, PPL_SETTINGS_PREFIX, COMMON_SETTINGS_PREFIX, - LEGACY_SQL_SETTINGS_PREFIX, LEGACY_PPL_SETTINGS_PREFIX, LEGACY_COMMON_SETTINGS_PREFIX); + private static final List SETTINGS_PREFIX = + ImmutableList.of( + SQL_SETTINGS_PREFIX, + PPL_SETTINGS_PREFIX, + COMMON_SETTINGS_PREFIX, + LEGACY_SQL_SETTINGS_PREFIX, + LEGACY_PPL_SETTINGS_PREFIX, + LEGACY_COMMON_SETTINGS_PREFIX); public static final String SETTINGS_API_ENDPOINT = "/_plugins/_query/settings"; public static final String LEGACY_SQL_SETTINGS_API_ENDPOINT = "/_opendistro/_sql/settings"; @@ -75,10 +80,11 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli 
QueryContext.addRequestId(); final ClusterUpdateSettingsRequest clusterUpdateSettingsRequest = Requests.clusterUpdateSettingsRequest(); - clusterUpdateSettingsRequest.timeout(request.paramAsTime( - "timeout", clusterUpdateSettingsRequest.timeout())); - clusterUpdateSettingsRequest.clusterManagerNodeTimeout(request.paramAsTime( - "cluster_manager_timeout", clusterUpdateSettingsRequest.clusterManagerNodeTimeout())); + clusterUpdateSettingsRequest.timeout( + request.paramAsTime("timeout", clusterUpdateSettingsRequest.timeout())); + clusterUpdateSettingsRequest.clusterManagerNodeTimeout( + request.paramAsTime( + "cluster_manager_timeout", clusterUpdateSettingsRequest.clusterManagerNodeTimeout())); Map source; try (XContentParser parser = request.contentParser()) { source = parser.map(); @@ -86,20 +92,27 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli try { if (source.containsKey(TRANSIENT)) { - clusterUpdateSettingsRequest.transientSettings(getAndFilterSettings( - (Map) source.get(TRANSIENT))); + clusterUpdateSettingsRequest.transientSettings( + getAndFilterSettings((Map) source.get(TRANSIENT))); } if (source.containsKey(PERSISTENT)) { - clusterUpdateSettingsRequest.persistentSettings(getAndFilterSettings( - (Map) source.get(PERSISTENT))); + clusterUpdateSettingsRequest.persistentSettings( + getAndFilterSettings((Map) source.get(PERSISTENT))); } - return channel -> client.admin().cluster().updateSettings( - clusterUpdateSettingsRequest, new RestToXContentListener<>(channel)); + return channel -> + client + .admin() + .cluster() + .updateSettings(clusterUpdateSettingsRequest, new RestToXContentListener<>(channel)); } catch (Exception e) { LOG.error("Error changing OpenSearch SQL plugin cluster settings", e); - return channel -> channel.sendResponse(new BytesRestResponse(INTERNAL_SERVER_ERROR, - ErrorMessageFactory.createErrorMessage(e, INTERNAL_SERVER_ERROR.getStatus()).toString())); + return channel -> + channel.sendResponse( + new BytesRestResponse( + INTERNAL_SERVER_ERROR, + ErrorMessageFactory.createErrorMessage(e, INTERNAL_SERVER_ERROR.getStatus()) + .toString())); } } @@ -107,16 +120,19 @@ private Settings getAndFilterSettings(Map source) { try { XContentBuilder builder = XContentFactory.jsonBuilder(); builder.map(source); - Settings.Builder settingsBuilder = Settings.builder() - .loadFromSource(builder.toString(), builder.contentType()); - settingsBuilder.keys().removeIf(key -> { - for (String prefix : SETTINGS_PREFIX) { - if (key.startsWith(prefix)) { - return false; - } - } - return true; - }); + Settings.Builder settingsBuilder = + Settings.builder().loadFromSource(builder.toString(), builder.contentType()); + settingsBuilder + .keys() + .removeIf( + key -> { + for (String prefix : SETTINGS_PREFIX) { + if (key.startsWith(prefix)) { + return false; + } + } + return true; + }); return settingsBuilder.build(); } catch (IOException e) { throw new OpenSearchGenerationException("Failed to generate [" + source + "]", e); diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java index 8a9d276673..fde9e24f75 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java @@ -139,8 +139,8 @@ private ResponseListener createListener( @Override public void onResponse(ExecutionEngine.QueryResponse response) { 
         String responseContent =
-            formatter.format(new QueryResult(response.getSchema(), response.getResults(),
-                response.getCursor()));
+            formatter.format(
+                new QueryResult(response.getSchema(), response.getResults(), response.getCursor()));
         listener.onResponse(new TransportPPLQueryResponse(responseContent));
       }
diff --git a/plugin/src/test/java/org/opensearch/sql/plugin/transport/TransportPPLQueryRequestTest.java b/plugin/src/test/java/org/opensearch/sql/plugin/transport/TransportPPLQueryRequestTest.java
index 0e5d99ae35..286ac20fed 100644
--- a/plugin/src/test/java/org/opensearch/sql/plugin/transport/TransportPPLQueryRequestTest.java
+++ b/plugin/src/test/java/org/opensearch/sql/plugin/transport/TransportPPLQueryRequestTest.java
@@ -59,9 +59,7 @@ public void writeTo(StreamOutput out) throws IOException {
 
   @Test
   public void testCustomizedNullJSONContentActionRequestFromActionRequest() {
-    TransportPPLQueryRequest request = new TransportPPLQueryRequest(
-        "source=t a=1", null, null
-    );
+    TransportPPLQueryRequest request = new TransportPPLQueryRequest("source=t a=1", null, null);
     ActionRequest actionRequest =
         new ActionRequest() {
           @Override
diff --git a/spark/build.gradle b/spark/build.gradle
index 89842e5ea8..2608b88ced 100644
--- a/spark/build.gradle
+++ b/spark/build.gradle
@@ -13,6 +13,9 @@ repositories {
     mavenCentral()
 }
 
+checkstyleTest.ignoreFailures = true
+checkstyleMain.ignoreFailures = true
+
 dependencies {
     api project(':core')
     implementation project(':datasources')
diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java b/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java
index 1e2475c196..1a3304994b 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/client/EmrClientImpl.java
@@ -36,12 +36,16 @@ public class EmrClientImpl implements SparkClient {
 
   /**
    * Constructor for EMR Client Implementation.
* - * @param emr EMR helper - * @param flint Opensearch args for flint integration jar + * @param emr EMR helper + * @param flint Opensearch args for flint integration jar * @param sparkResponse Response object to help with retrieving results from Opensearch index */ - public EmrClientImpl(AmazonElasticMapReduce emr, String emrCluster, FlintHelper flint, - SparkResponse sparkResponse, String sparkApplicationJar) { + public EmrClientImpl( + AmazonElasticMapReduce emr, + String emrCluster, + FlintHelper flint, + SparkResponse sparkResponse, + String sparkApplicationJar) { this.emr = emr; this.emrCluster = emrCluster; this.flint = flint; @@ -59,38 +63,39 @@ public JSONObject sql(String query) throws IOException { @VisibleForTesting void runEmrApplication(String query) { - HadoopJarStepConfig stepConfig = new HadoopJarStepConfig() - .withJar("command-runner.jar") - .withArgs("spark-submit", - "--class","org.opensearch.sql.SQLJob", - "--jars", - flint.getFlintIntegrationJar(), - sparkApplicationJar, - query, - SPARK_INDEX_NAME, - flint.getFlintHost(), - flint.getFlintPort(), - flint.getFlintScheme(), - flint.getFlintAuth(), - flint.getFlintRegion() - ); + HadoopJarStepConfig stepConfig = + new HadoopJarStepConfig() + .withJar("command-runner.jar") + .withArgs( + "spark-submit", + "--class", + "org.opensearch.sql.SQLJob", + "--jars", + flint.getFlintIntegrationJar(), + sparkApplicationJar, + query, + SPARK_INDEX_NAME, + flint.getFlintHost(), + flint.getFlintPort(), + flint.getFlintScheme(), + flint.getFlintAuth(), + flint.getFlintRegion()); - StepConfig emrstep = new StepConfig() - .withName("Spark Application") - .withActionOnFailure(ActionOnFailure.CONTINUE) - .withHadoopJarStep(stepConfig); + StepConfig emrstep = + new StepConfig() + .withName("Spark Application") + .withActionOnFailure(ActionOnFailure.CONTINUE) + .withHadoopJarStep(stepConfig); - AddJobFlowStepsRequest request = new AddJobFlowStepsRequest() - .withJobFlowId(emrCluster) - .withSteps(emrstep); + AddJobFlowStepsRequest request = + new AddJobFlowStepsRequest().withJobFlowId(emrCluster).withSteps(emrstep); AddJobFlowStepsResult result = emr.addJobFlowSteps(request); logger.info("EMR step ID: " + result.getStepIds()); String stepId = result.getStepIds().get(0); - DescribeStepRequest stepRequest = new DescribeStepRequest() - .withClusterId(emrCluster) - .withStepId(stepId); + DescribeStepRequest stepRequest = + new DescribeStepRequest().withClusterId(emrCluster).withStepId(stepId); waitForStepExecution(stepRequest); sparkResponse.setValue(stepId); @@ -117,5 +122,4 @@ private void waitForStepExecution(DescribeStepRequest stepRequest) { } } } - } diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/SparkClient.java b/spark/src/main/java/org/opensearch/sql/spark/client/SparkClient.java index 99d8600dd0..b38f04680b 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/SparkClient.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/SparkClient.java @@ -8,15 +8,13 @@ import java.io.IOException; import org.json.JSONObject; -/** - * Interface class for Spark Client. - */ +/** Interface class for Spark Client. */ public interface SparkClient { /** * This method executes spark sql query. 
* * @param query spark sql query - * @return spark query response + * @return spark query response */ JSONObject sql(String query) throws IOException; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/implementation/SparkSqlFunctionImplementation.java b/spark/src/main/java/org/opensearch/sql/spark/functions/implementation/SparkSqlFunctionImplementation.java index 1936c266de..914aa80085 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/functions/implementation/SparkSqlFunctionImplementation.java +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/implementation/SparkSqlFunctionImplementation.java @@ -24,9 +24,7 @@ import org.opensearch.sql.spark.storage.SparkTable; import org.opensearch.sql.storage.Table; -/** - * Spark SQL function implementation. - */ +/** Spark SQL function implementation. */ public class SparkSqlFunctionImplementation extends FunctionExpression implements TableFunctionImplementation { @@ -38,8 +36,8 @@ public class SparkSqlFunctionImplementation extends FunctionExpression * Constructor for spark sql function. * * @param functionName name of the function - * @param arguments a list of expressions - * @param sparkClient spark client + * @param arguments a list of expressions + * @param sparkClient spark client */ public SparkSqlFunctionImplementation( FunctionName functionName, List arguments, SparkClient sparkClient) { @@ -51,9 +49,11 @@ public SparkSqlFunctionImplementation( @Override public ExprValue valueOf(Environment valueEnv) { - throw new UnsupportedOperationException(String.format( - "Spark defined function [%s] is only " - + "supported in SOURCE clause with spark connector catalog", functionName)); + throw new UnsupportedOperationException( + String.format( + "Spark defined function [%s] is only " + + "supported in SOURCE clause with spark connector catalog", + functionName)); } @Override @@ -63,11 +63,15 @@ public ExprType type() { @Override public String toString() { - List args = arguments.stream() - .map(arg -> String.format("%s=%s", - ((NamedArgumentExpression) arg).getArgName(), - ((NamedArgumentExpression) arg).getValue().toString())) - .collect(Collectors.toList()); + List args = + arguments.stream() + .map( + arg -> + String.format( + "%s=%s", + ((NamedArgumentExpression) arg).getArgName(), + ((NamedArgumentExpression) arg).getValue().toString())) + .collect(Collectors.toList()); return String.format("%s(%s)", functionName, String.join(", ", args)); } @@ -80,23 +84,23 @@ public Table applyArguments() { * This method builds a spark query request. 
* * @param arguments spark sql function arguments - * @return spark query request + * @return spark query request */ private SparkQueryRequest buildQueryFromSqlFunction(List arguments) { SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); - arguments.forEach(arg -> { - String argName = ((NamedArgumentExpression) arg).getArgName(); - Expression argValue = ((NamedArgumentExpression) arg).getValue(); - ExprValue literalValue = argValue.valueOf(); - if (argName.equals(QUERY)) { - sparkQueryRequest.setSql((String) literalValue.value()); - } else { - throw new ExpressionEvaluationException( - String.format("Invalid Function Argument:%s", argName)); - } - }); + arguments.forEach( + arg -> { + String argName = ((NamedArgumentExpression) arg).getArgName(); + Expression argValue = ((NamedArgumentExpression) arg).getValue(); + ExprValue literalValue = argValue.valueOf(); + if (argName.equals(QUERY)) { + sparkQueryRequest.setSql((String) literalValue.value()); + } else { + throw new ExpressionEvaluationException( + String.format("Invalid Function Argument:%s", argName)); + } + }); return sparkQueryRequest; } - } diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/resolver/SparkSqlTableFunctionResolver.java b/spark/src/main/java/org/opensearch/sql/spark/functions/resolver/SparkSqlTableFunctionResolver.java index 624600e1a8..a4f2a6c0fe 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/functions/resolver/SparkSqlTableFunctionResolver.java +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/resolver/SparkSqlTableFunctionResolver.java @@ -22,9 +22,7 @@ import org.opensearch.sql.spark.client.SparkClient; import org.opensearch.sql.spark.functions.implementation.SparkSqlFunctionImplementation; -/** - * Function resolver for sql function of spark connector. - */ +/** Function resolver for sql function of spark connector. 
*/ @RequiredArgsConstructor public class SparkSqlTableFunctionResolver implements FunctionResolver { private final SparkClient sparkClient; @@ -35,35 +33,44 @@ public class SparkSqlTableFunctionResolver implements FunctionResolver { @Override public Pair resolve(FunctionSignature unresolvedSignature) { FunctionName functionName = FunctionName.of(SQL); - FunctionSignature functionSignature = - new FunctionSignature(functionName, List.of(STRING)); + FunctionSignature functionSignature = new FunctionSignature(functionName, List.of(STRING)); final List argumentNames = List.of(QUERY); - FunctionBuilder functionBuilder = (functionProperties, arguments) -> { - Boolean argumentsPassedByName = arguments.stream() - .noneMatch(arg -> StringUtils.isEmpty(((NamedArgumentExpression) arg).getArgName())); - Boolean argumentsPassedByPosition = arguments.stream() - .allMatch(arg -> StringUtils.isEmpty(((NamedArgumentExpression) arg).getArgName())); - if (!(argumentsPassedByName || argumentsPassedByPosition)) { - throw new SemanticCheckException("Arguments should be either passed by name or position"); - } + FunctionBuilder functionBuilder = + (functionProperties, arguments) -> { + Boolean argumentsPassedByName = + arguments.stream() + .noneMatch( + arg -> StringUtils.isEmpty(((NamedArgumentExpression) arg).getArgName())); + Boolean argumentsPassedByPosition = + arguments.stream() + .allMatch( + arg -> StringUtils.isEmpty(((NamedArgumentExpression) arg).getArgName())); + if (!(argumentsPassedByName || argumentsPassedByPosition)) { + throw new SemanticCheckException( + "Arguments should be either passed by name or position"); + } - if (arguments.size() != argumentNames.size()) { - throw new SemanticCheckException( - String.format("Missing arguments:[%s]", - String.join(",", argumentNames.subList(arguments.size(), argumentNames.size())))); - } + if (arguments.size() != argumentNames.size()) { + throw new SemanticCheckException( + String.format( + "Missing arguments:[%s]", + String.join( + ",", argumentNames.subList(arguments.size(), argumentNames.size())))); + } - if (argumentsPassedByPosition) { - List namedArguments = new ArrayList<>(); - for (int i = 0; i < arguments.size(); i++) { - namedArguments.add(new NamedArgumentExpression(argumentNames.get(i), - ((NamedArgumentExpression) arguments.get(i)).getValue())); - } - return new SparkSqlFunctionImplementation(functionName, namedArguments, sparkClient); - } - return new SparkSqlFunctionImplementation(functionName, arguments, sparkClient); - }; + if (argumentsPassedByPosition) { + List namedArguments = new ArrayList<>(); + for (int i = 0; i < arguments.size(); i++) { + namedArguments.add( + new NamedArgumentExpression( + argumentNames.get(i), + ((NamedArgumentExpression) arguments.get(i)).getValue())); + } + return new SparkSqlFunctionImplementation(functionName, namedArguments, sparkClient); + } + return new SparkSqlFunctionImplementation(functionName, arguments, sparkClient); + }; return Pair.of(functionSignature, functionBuilder); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java b/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java index cb2b31ddc1..823ad2da29 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java @@ -29,9 +29,7 @@ import 
org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.executor.ExecutionEngine; -/** - * Default implementation of SparkSqlFunctionResponseHandle. - */ +/** Default implementation of SparkSqlFunctionResponseHandle. */ public class DefaultSparkSqlFunctionResponseHandle implements SparkSqlFunctionResponseHandle { private Iterator responseIterator; private ExecutionEngine.Schema schema; @@ -54,8 +52,8 @@ private void constructIteratorAndSchema(JSONObject responseObject) { logger.info("Spark Application ID: " + items.getString("applicationId")); columnList = getColumnList(items.getJSONArray("schema")); for (int i = 0; i < items.getJSONArray("result").length(); i++) { - JSONObject row = new JSONObject( - items.getJSONArray("result").get(i).toString().replace("'", "\"")); + JSONObject row = + new JSONObject(items.getJSONArray("result").get(i).toString().replace("'", "\"")); LinkedHashMap linkedHashMap = extractRow(row, columnList); result.add(new ExprTupleValue(linkedHashMap)); } @@ -85,8 +83,8 @@ private static LinkedHashMap extractRow( } else if (type == ExprCoreType.DATE) { linkedHashMap.put(column.getName(), new ExprDateValue(row.getString(column.getName()))); } else if (type == ExprCoreType.TIMESTAMP) { - linkedHashMap.put(column.getName(), - new ExprTimestampValue(row.getString(column.getName()))); + linkedHashMap.put( + column.getName(), new ExprTimestampValue(row.getString(column.getName()))); } else if (type == ExprCoreType.STRING) { linkedHashMap.put(column.getName(), new ExprStringValue(row.getString(column.getName()))); } else { @@ -101,10 +99,11 @@ private List getColumnList(JSONArray schema) { List columnList = new ArrayList<>(); for (int i = 0; i < schema.length(); i++) { JSONObject column = new JSONObject(schema.get(i).toString().replace("'", "\"")); - columnList.add(new ExecutionEngine.Schema.Column( - column.get("column_name").toString(), - column.get("column_name").toString(), - getDataType(column.get("data_type").toString()))); + columnList.add( + new ExecutionEngine.Schema.Column( + column.get("column_name").toString(), + column.get("column_name").toString(), + getDataType(column.get("data_type").toString()))); } return columnList; } diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/response/SparkSqlFunctionResponseHandle.java b/spark/src/main/java/org/opensearch/sql/spark/functions/response/SparkSqlFunctionResponseHandle.java index da68b591eb..a9be484712 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/functions/response/SparkSqlFunctionResponseHandle.java +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/response/SparkSqlFunctionResponseHandle.java @@ -8,24 +8,18 @@ import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.executor.ExecutionEngine; -/** - * Handle Spark response. - */ +/** Handle Spark response. */ public interface SparkSqlFunctionResponseHandle { - /** - * Return true if Spark response has more result. - */ + /** Return true if Spark response has more result. */ boolean hasNext(); /** - * Return Spark response as {@link ExprValue}. Attention, the method must been called when - * hasNext return true. + * Return Spark response as {@link ExprValue}. Attention, the method must been called when hasNext + * return true. */ ExprValue next(); - /** - * Return ExecutionEngine.Schema of the Spark response. - */ + /** Return ExecutionEngine.Schema of the Spark response. 
*/ ExecutionEngine.Schema schema(); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanBuilder.java b/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanBuilder.java index 28ce7dd19a..aea8f72f36 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanBuilder.java +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanBuilder.java @@ -12,9 +12,7 @@ import org.opensearch.sql.storage.TableScanOperator; import org.opensearch.sql.storage.read.TableScanBuilder; -/** - * TableScanBuilder for sql function of spark connector. - */ +/** TableScanBuilder for sql function of spark connector. */ @AllArgsConstructor public class SparkSqlFunctionTableScanBuilder extends TableScanBuilder { diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanOperator.java b/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanOperator.java index 85e854e422..a2e44affd5 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanOperator.java +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanOperator.java @@ -21,9 +21,7 @@ import org.opensearch.sql.spark.request.SparkQueryRequest; import org.opensearch.sql.storage.TableScanOperator; -/** - * This a table scan operator to handle sql table function. - */ +/** This a table scan operator to handle sql table function. */ @RequiredArgsConstructor public class SparkSqlFunctionTableScanOperator extends TableScanOperator { private final SparkClient sparkClient; @@ -34,17 +32,19 @@ public class SparkSqlFunctionTableScanOperator extends TableScanOperator { @Override public void open() { super.open(); - this.sparkResponseHandle = AccessController.doPrivileged( - (PrivilegedAction) () -> { - try { - JSONObject responseObject = sparkClient.sql(request.getSql()); - return new DefaultSparkSqlFunctionResponseHandle(responseObject); - } catch (IOException e) { - LOG.error(e.getMessage()); - throw new RuntimeException( - String.format("Error fetching data from spark server: %s", e.getMessage())); - } - }); + this.sparkResponseHandle = + AccessController.doPrivileged( + (PrivilegedAction) + () -> { + try { + JSONObject responseObject = sparkClient.sql(request.getSql()); + return new DefaultSparkSqlFunctionResponseHandle(responseObject); + } catch (IOException e) { + LOG.error(e.getMessage()); + throw new RuntimeException( + String.format("Error fetching data from spark server: %s", e.getMessage())); + } + }); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/helper/FlintHelper.java b/spark/src/main/java/org/opensearch/sql/spark/helper/FlintHelper.java index b3c3c0871a..10d880187f 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/helper/FlintHelper.java +++ b/spark/src/main/java/org/opensearch/sql/spark/helper/FlintHelper.java @@ -15,25 +15,20 @@ import lombok.Getter; public class FlintHelper { - @Getter - private final String flintIntegrationJar; - @Getter - private final String flintHost; - @Getter - private final String flintPort; - @Getter - private final String flintScheme; - @Getter - private final String flintAuth; - @Getter - private final String flintRegion; + @Getter private final String flintIntegrationJar; + @Getter private final String flintHost; + @Getter private final String flintPort; + @Getter private final String flintScheme; + @Getter 
private final String flintAuth; + @Getter private final String flintRegion; - /** Arguments required to write data to opensearch index using flint integration. + /** + * Arguments required to write data to opensearch index using flint integration. * - * @param flintHost Opensearch host for flint - * @param flintPort Opensearch port for flint integration + * @param flintHost Opensearch host for flint + * @param flintPort Opensearch port for flint integration * @param flintScheme Opensearch scheme for flint integration - * @param flintAuth Opensearch auth for flint integration + * @param flintAuth Opensearch auth for flint integration * @param flintRegion Opensearch region for flint integration */ public FlintHelper( diff --git a/spark/src/main/java/org/opensearch/sql/spark/request/SparkQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/request/SparkQueryRequest.java index bc0944a784..94c9795161 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/request/SparkQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/request/SparkQueryRequest.java @@ -7,15 +7,10 @@ import lombok.Data; -/** - * Spark query request. - */ +/** Spark query request. */ @Data public class SparkQueryRequest { - /** - * SQL. - */ + /** SQL. */ private String sql; - } diff --git a/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java b/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java index f30072eb3f..3edb541384 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java +++ b/spark/src/main/java/org/opensearch/sql/spark/response/SparkResponse.java @@ -36,8 +36,8 @@ public class SparkResponse { * Response for spark sql query. * * @param client Opensearch client - * @param value Identifier field value - * @param field Identifier field name + * @param value Identifier field value + * @param field Identifier field name */ public SparkResponse(Client client, String value, String field) { this.client = client; @@ -64,8 +64,10 @@ private JSONObject searchInSparkIndex(QueryBuilder query) { SearchResponse searchResponse = searchResponseActionFuture.actionGet(); if (searchResponse.status().getStatus() != 200) { throw new RuntimeException( - "Fetching result from " + SPARK_INDEX_NAME + " index failed with status : " - + searchResponse.status()); + "Fetching result from " + + SPARK_INDEX_NAME + + " index failed with status : " + + searchResponse.status()); } else { JSONObject data = new JSONObject(); for (SearchHit searchHit : searchResponse.getHits().getHits()) { @@ -90,11 +92,11 @@ void deleteInSparkIndex(String id) { if (deleteResponse.getResult().equals(DocWriteResponse.Result.DELETED)) { LOG.debug("Spark result successfully deleted ", id); } else if (deleteResponse.getResult().equals(DocWriteResponse.Result.NOT_FOUND)) { - throw new ResourceNotFoundException("Spark result with id " - + id + " doesn't exist"); + throw new ResourceNotFoundException("Spark result with id " + id + " doesn't exist"); } else { - throw new RuntimeException("Deleting spark result information failed with : " - + deleteResponse.getResult().getLowercase()); + throw new RuntimeException( + "Deleting spark result information failed with : " + + deleteResponse.getResult().getLowercase()); } } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkScan.java b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkScan.java index 3897e8690e..395e1685a6 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkScan.java +++ 
b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkScan.java @@ -14,21 +14,14 @@ import org.opensearch.sql.spark.request.SparkQueryRequest; import org.opensearch.sql.storage.TableScanOperator; -/** - * Spark scan operator. - */ +/** Spark scan operator. */ @EqualsAndHashCode(onlyExplicitlyIncluded = true, callSuper = false) @ToString(onlyExplicitlyIncluded = true) public class SparkScan extends TableScanOperator { private final SparkClient sparkClient; - @EqualsAndHashCode.Include - @Getter - @Setter - @ToString.Include - private SparkQueryRequest request; - + @EqualsAndHashCode.Include @Getter @Setter @ToString.Include private SparkQueryRequest request; /** * Constructor. @@ -54,5 +47,4 @@ public ExprValue next() { public String explain() { return getRequest().toString(); } - } diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageEngine.java b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageEngine.java index a5e35ecc4c..84c9c05e79 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageEngine.java +++ b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageEngine.java @@ -15,17 +15,14 @@ import org.opensearch.sql.storage.StorageEngine; import org.opensearch.sql.storage.Table; -/** - * Spark storage engine implementation. - */ +/** Spark storage engine implementation. */ @RequiredArgsConstructor public class SparkStorageEngine implements StorageEngine { private final SparkClient sparkClient; @Override public Collection getFunctions() { - return Collections.singletonList( - new SparkSqlTableFunctionResolver(sparkClient)); + return Collections.singletonList(new SparkSqlTableFunctionResolver(sparkClient)); } @Override diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java index 937679b50e..467bacbaea 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java +++ b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java @@ -30,9 +30,7 @@ import org.opensearch.sql.storage.DataSourceFactory; import org.opensearch.sql.storage.StorageEngine; -/** - * Storage factory implementation for spark connector. - */ +/** Storage factory implementation for spark connector. 
*/ @RequiredArgsConstructor public class SparkStorageFactory implements DataSourceFactory { private final Client client; @@ -66,9 +64,7 @@ public DataSourceType getDataSourceType() { @Override public DataSource createDataSource(DataSourceMetadata metadata) { return new DataSource( - metadata.getName(), - DataSourceType.SPARK, - getStorageEngine(metadata.getProperties())); + metadata.getName(), DataSourceType.SPARK, getStorageEngine(metadata.getProperties())); } /** @@ -81,24 +77,26 @@ StorageEngine getStorageEngine(Map requiredConfig) { SparkClient sparkClient; if (requiredConfig.get(CONNECTOR_TYPE).equals(EMR)) { sparkClient = - AccessController.doPrivileged((PrivilegedAction) () -> { - validateEMRConfigProperties(requiredConfig); - return new EmrClientImpl( - getEMRClient( - requiredConfig.get(EMR_ACCESS_KEY), - requiredConfig.get(EMR_SECRET_KEY), - requiredConfig.get(EMR_REGION)), - requiredConfig.get(EMR_CLUSTER), - new FlintHelper( - requiredConfig.get(FLINT_INTEGRATION), - requiredConfig.get(FLINT_HOST), - requiredConfig.get(FLINT_PORT), - requiredConfig.get(FLINT_SCHEME), - requiredConfig.get(FLINT_AUTH), - requiredConfig.get(FLINT_REGION)), - new SparkResponse(client, null, STEP_ID_FIELD), - requiredConfig.get(SPARK_SQL_APPLICATION)); - }); + AccessController.doPrivileged( + (PrivilegedAction) + () -> { + validateEMRConfigProperties(requiredConfig); + return new EmrClientImpl( + getEMRClient( + requiredConfig.get(EMR_ACCESS_KEY), + requiredConfig.get(EMR_SECRET_KEY), + requiredConfig.get(EMR_REGION)), + requiredConfig.get(EMR_CLUSTER), + new FlintHelper( + requiredConfig.get(FLINT_INTEGRATION), + requiredConfig.get(FLINT_HOST), + requiredConfig.get(FLINT_PORT), + requiredConfig.get(FLINT_SCHEME), + requiredConfig.get(FLINT_AUTH), + requiredConfig.get(FLINT_REGION)), + new SparkResponse(client, null, STEP_ID_FIELD), + requiredConfig.get(SPARK_SQL_APPLICATION)); + }); } else { throw new InvalidParameterException("Spark connector type is invalid."); } @@ -110,12 +108,14 @@ private void validateEMRConfigProperties(Map dataSourceMetadataC if (dataSourceMetadataConfig.get(EMR_CLUSTER) == null || dataSourceMetadataConfig.get(EMR_AUTH_TYPE) == null) { throw new IllegalArgumentException("EMR config properties are missing."); - } else if (dataSourceMetadataConfig.get(EMR_AUTH_TYPE) - .equals(AuthenticationType.AWSSIGV4AUTH.getName()) + } else if (dataSourceMetadataConfig + .get(EMR_AUTH_TYPE) + .equals(AuthenticationType.AWSSIGV4AUTH.getName()) && (dataSourceMetadataConfig.get(EMR_ACCESS_KEY) == null - || dataSourceMetadataConfig.get(EMR_SECRET_KEY) == null)) { + || dataSourceMetadataConfig.get(EMR_SECRET_KEY) == null)) { throw new IllegalArgumentException("EMR auth keys are missing."); - } else if (!dataSourceMetadataConfig.get(EMR_AUTH_TYPE) + } else if (!dataSourceMetadataConfig + .get(EMR_AUTH_TYPE) .equals(AuthenticationType.AWSSIGV4AUTH.getName())) { throw new IllegalArgumentException("Invalid auth type."); } @@ -124,8 +124,8 @@ private void validateEMRConfigProperties(Map dataSourceMetadataC private AmazonElasticMapReduce getEMRClient( String emrAccessKey, String emrSecretKey, String emrRegion) { return AmazonElasticMapReduceClientBuilder.standard() - .withCredentials(new AWSStaticCredentialsProvider( - new BasicAWSCredentials(emrAccessKey, emrSecretKey))) + .withCredentials( + new AWSStaticCredentialsProvider(new BasicAWSCredentials(emrAccessKey, emrSecretKey))) .withRegion(emrRegion) .build(); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkTable.java 
b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkTable.java index 5151405db9..731c3df672 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkTable.java +++ b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkTable.java @@ -18,20 +18,14 @@ import org.opensearch.sql.storage.Table; import org.opensearch.sql.storage.read.TableScanBuilder; -/** - * Spark table implementation. - * This can be constructed from SparkQueryRequest. - */ +/** Spark table implementation. This can be constructed from SparkQueryRequest. */ public class SparkTable implements Table { private final SparkClient sparkClient; - @Getter - private final SparkQueryRequest sparkQueryRequest; + @Getter private final SparkQueryRequest sparkQueryRequest; - /** - * Constructor for entire Sql Request. - */ + /** Constructor for entire Sql Request. */ public SparkTable(SparkClient sparkService, SparkQueryRequest sparkQueryRequest) { this.sparkClient = sparkService; this.sparkQueryRequest = sparkQueryRequest; @@ -56,8 +50,7 @@ public Map getFieldTypes() { @Override public PhysicalPlan implement(LogicalPlan plan) { - SparkScan metricScan = - new SparkScan(sparkClient); + SparkScan metricScan = new SparkScan(sparkClient); metricScan.setRequest(sparkQueryRequest); return plan.accept(new DefaultImplementor(), metricScan); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java index a94ac01f2f..93dc0d6bc8 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/client/EmrClientImplTest.java @@ -29,12 +29,9 @@ @ExtendWith(MockitoExtension.class) public class EmrClientImplTest { - @Mock - private AmazonElasticMapReduce emr; - @Mock - private FlintHelper flint; - @Mock - private SparkResponse sparkResponse; + @Mock private AmazonElasticMapReduce emr; + @Mock private FlintHelper flint; + @Mock private SparkResponse sparkResponse; @Test @SneakyThrows @@ -50,8 +47,8 @@ void testRunEmrApplication() { describeStepResult.setStep(step); when(emr.describeStep(any())).thenReturn(describeStepResult); - EmrClientImpl emrClientImpl = new EmrClientImpl( - emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + EmrClientImpl emrClientImpl = + new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); emrClientImpl.runEmrApplication(QUERY); } @@ -69,12 +66,12 @@ void testRunEmrApplicationFailed() { describeStepResult.setStep(step); when(emr.describeStep(any())).thenReturn(describeStepResult); - EmrClientImpl emrClientImpl = new EmrClientImpl( - emr, EMR_CLUSTER_ID, flint, sparkResponse, null); - RuntimeException exception = Assertions.assertThrows(RuntimeException.class, - () -> emrClientImpl.runEmrApplication(QUERY)); - Assertions.assertEquals("Spark SQL application failed.", - exception.getMessage()); + EmrClientImpl emrClientImpl = + new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + RuntimeException exception = + Assertions.assertThrows( + RuntimeException.class, () -> emrClientImpl.runEmrApplication(QUERY)); + Assertions.assertEquals("Spark SQL application failed.", exception.getMessage()); } @Test @@ -91,12 +88,12 @@ void testRunEmrApplicationCancelled() { describeStepResult.setStep(step); when(emr.describeStep(any())).thenReturn(describeStepResult); - EmrClientImpl emrClientImpl = new EmrClientImpl( - emr, EMR_CLUSTER_ID, flint, sparkResponse, null); - RuntimeException exception = 
Assertions.assertThrows(RuntimeException.class, - () -> emrClientImpl.runEmrApplication(QUERY)); - Assertions.assertEquals("Spark SQL application failed.", - exception.getMessage()); + EmrClientImpl emrClientImpl = + new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + RuntimeException exception = + Assertions.assertThrows( + RuntimeException.class, () -> emrClientImpl.runEmrApplication(QUERY)); + Assertions.assertEquals("Spark SQL application failed.", exception.getMessage()); } @Test @@ -119,11 +116,12 @@ void testRunEmrApplicationRunnning() { DescribeStepResult completedDescribeStepResult = new DescribeStepResult(); completedDescribeStepResult.setStep(completedStep); - when(emr.describeStep(any())).thenReturn(runningDescribeStepResult) + when(emr.describeStep(any())) + .thenReturn(runningDescribeStepResult) .thenReturn(completedDescribeStepResult); - EmrClientImpl emrClientImpl = new EmrClientImpl( - emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + EmrClientImpl emrClientImpl = + new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); emrClientImpl.runEmrApplication(QUERY); } @@ -147,14 +145,14 @@ void testSql() { DescribeStepResult completedDescribeStepResult = new DescribeStepResult(); completedDescribeStepResult.setStep(completedStep); - when(emr.describeStep(any())).thenReturn(runningDescribeStepResult) + when(emr.describeStep(any())) + .thenReturn(runningDescribeStepResult) .thenReturn(completedDescribeStepResult); when(sparkResponse.getResultFromOpensearchIndex()) .thenReturn(new JSONObject(getJson("select_query_response.json"))); - EmrClientImpl emrClientImpl = new EmrClientImpl( - emr, EMR_CLUSTER_ID, flint, sparkResponse, null); + EmrClientImpl emrClientImpl = + new EmrClientImpl(emr, EMR_CLUSTER_ID, flint, sparkResponse, null); emrClientImpl.sql(QUERY); - } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java index 18db5b9471..120747e0d3 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java @@ -27,51 +27,52 @@ @ExtendWith(MockitoExtension.class) public class SparkSqlFunctionImplementationTest { - @Mock - private SparkClient client; + @Mock private SparkClient client; @Test void testValueOfAndTypeToString() { FunctionName functionName = new FunctionName("sql"); - List namedArgumentExpressionList - = List.of(DSL.namedArgument("query", DSL.literal(QUERY))); - SparkSqlFunctionImplementation sparkSqlFunctionImplementation - = new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client); - UnsupportedOperationException exception = assertThrows(UnsupportedOperationException.class, - () -> sparkSqlFunctionImplementation.valueOf()); - assertEquals("Spark defined function [sql] is only " - + "supported in SOURCE clause with spark connector catalog", exception.getMessage()); - assertEquals("sql(query=\"select 1\")", - sparkSqlFunctionImplementation.toString()); + List namedArgumentExpressionList = + List.of(DSL.namedArgument("query", DSL.literal(QUERY))); + SparkSqlFunctionImplementation sparkSqlFunctionImplementation = + new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client); + UnsupportedOperationException exception = + assertThrows( + UnsupportedOperationException.class, () -> 
sparkSqlFunctionImplementation.valueOf()); + assertEquals( + "Spark defined function [sql] is only " + + "supported in SOURCE clause with spark connector catalog", + exception.getMessage()); + assertEquals("sql(query=\"select 1\")", sparkSqlFunctionImplementation.toString()); assertEquals(ExprCoreType.STRUCT, sparkSqlFunctionImplementation.type()); } @Test void testApplyArguments() { FunctionName functionName = new FunctionName("sql"); - List namedArgumentExpressionList - = List.of(DSL.namedArgument("query", DSL.literal(QUERY))); - SparkSqlFunctionImplementation sparkSqlFunctionImplementation - = new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client); - SparkTable sparkTable - = (SparkTable) sparkSqlFunctionImplementation.applyArguments(); + List namedArgumentExpressionList = + List.of(DSL.namedArgument("query", DSL.literal(QUERY))); + SparkSqlFunctionImplementation sparkSqlFunctionImplementation = + new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client); + SparkTable sparkTable = (SparkTable) sparkSqlFunctionImplementation.applyArguments(); assertNotNull(sparkTable.getSparkQueryRequest()); - SparkQueryRequest sparkQueryRequest - = sparkTable.getSparkQueryRequest(); + SparkQueryRequest sparkQueryRequest = sparkTable.getSparkQueryRequest(); assertEquals(QUERY, sparkQueryRequest.getSql()); } @Test void testApplyArgumentsException() { FunctionName functionName = new FunctionName("sql"); - List namedArgumentExpressionList - = List.of(DSL.namedArgument("query", DSL.literal(QUERY)), - DSL.namedArgument("tmp", DSL.literal(12345))); - SparkSqlFunctionImplementation sparkSqlFunctionImplementation - = new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client); - ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, - () -> sparkSqlFunctionImplementation.applyArguments()); + List namedArgumentExpressionList = + List.of( + DSL.namedArgument("query", DSL.literal(QUERY)), + DSL.namedArgument("tmp", DSL.literal(12345))); + SparkSqlFunctionImplementation sparkSqlFunctionImplementation = + new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client); + ExpressionEvaluationException exception = + assertThrows( + ExpressionEvaluationException.class, + () -> sparkSqlFunctionImplementation.applyArguments()); assertEquals("Invalid Function Argument:tmp", exception.getMessage()); } - } diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java index 94c87602b7..212056eb15 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java @@ -18,23 +18,20 @@ import org.opensearch.sql.storage.TableScanOperator; public class SparkSqlFunctionTableScanBuilderTest { - @Mock - private SparkClient sparkClient; + @Mock private SparkClient sparkClient; - @Mock - private LogicalProject logicalProject; + @Mock private LogicalProject logicalProject; @Test void testBuild() { SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); sparkQueryRequest.setSql(QUERY); - SparkSqlFunctionTableScanBuilder sparkSqlFunctionTableScanBuilder - = new SparkSqlFunctionTableScanBuilder(sparkClient, sparkQueryRequest); - TableScanOperator sqlFunctionTableScanOperator - = 
sparkSqlFunctionTableScanBuilder.build();
-    Assertions.assertTrue(sqlFunctionTableScanOperator
-        instanceof SparkSqlFunctionTableScanOperator);
+    SparkSqlFunctionTableScanBuilder sparkSqlFunctionTableScanBuilder =
+        new SparkSqlFunctionTableScanBuilder(sparkClient, sparkQueryRequest);
+    TableScanOperator sqlFunctionTableScanOperator = sparkSqlFunctionTableScanBuilder.build();
+    Assertions.assertTrue(
+        sqlFunctionTableScanOperator instanceof SparkSqlFunctionTableScanOperator);
   }
 
   @Test
@@ -42,8 +39,8 @@ void testPushProject() {
     SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
     sparkQueryRequest.setSql(QUERY);
 
-    SparkSqlFunctionTableScanBuilder sparkSqlFunctionTableScanBuilder
-        = new SparkSqlFunctionTableScanBuilder(sparkClient, sparkQueryRequest);
+    SparkSqlFunctionTableScanBuilder sparkSqlFunctionTableScanBuilder =
+        new SparkSqlFunctionTableScanBuilder(sparkClient, sparkQueryRequest);
 
     Assertions.assertTrue(sparkSqlFunctionTableScanBuilder.pushDownProject(logicalProject));
   }
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java
index f6807f9913..586f0ef2d8 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java
@@ -43,8 +43,7 @@
 
 @ExtendWith(MockitoExtension.class)
 public class SparkSqlFunctionTableScanOperatorTest {
-  @Mock
-  private SparkClient sparkClient;
+  @Mock private SparkClient sparkClient;
 
   @Test
   @SneakyThrows
@@ -52,15 +51,14 @@ void testEmptyQueryWithException() {
     SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
     sparkQueryRequest.setSql(QUERY);
 
-    SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator
-        = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest);
+    SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator =
+        new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest);
 
-    when(sparkClient.sql(any()))
-        .thenThrow(new IOException("Error Message"));
-    RuntimeException runtimeException
-        = assertThrows(RuntimeException.class, sparkSqlFunctionTableScanOperator::open);
-    assertEquals("Error fetching data from spark server: Error Message",
-        runtimeException.getMessage());
+    when(sparkClient.sql(any())).thenThrow(new IOException("Error Message"));
+    RuntimeException runtimeException =
+        assertThrows(RuntimeException.class, sparkSqlFunctionTableScanOperator::open);
+    assertEquals(
+        "Error fetching data from spark server: Error Message", runtimeException.getMessage());
   }
 
   @Test
@@ -69,8 +67,8 @@ void testClose() {
     SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
     sparkQueryRequest.setSql(QUERY);
 
-    SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator
-        = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest);
+    SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator =
+        new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest);
     sparkSqlFunctionTableScanOperator.close();
   }
 
@@ -80,11 +78,10 @@ void testExplain() {
     SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
     sparkQueryRequest.setSql(QUERY);
 
-    SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator
-        = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest);
+    SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator =
+        new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest);
 
-    Assertions.assertEquals("sql(select 1)",
-        sparkSqlFunctionTableScanOperator.explain());
+    Assertions.assertEquals("sql(select 1)", sparkSqlFunctionTableScanOperator.explain());
   }
 
   @Test
@@ -93,18 +90,19 @@ void testQueryResponseIterator() {
     SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
     sparkQueryRequest.setSql(QUERY);
 
-    SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator
-        = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest);
+    SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator =
+        new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest);
 
-    when(sparkClient.sql(any()))
-        .thenReturn(new JSONObject(getJson("select_query_response.json")));
+    when(sparkClient.sql(any())).thenReturn(new JSONObject(getJson("select_query_response.json")));
     sparkSqlFunctionTableScanOperator.open();
     assertTrue(sparkSqlFunctionTableScanOperator.hasNext());
-    ExprTupleValue firstRow = new ExprTupleValue(new LinkedHashMap<>() {
-      {
-        put("1", new ExprIntegerValue(1));
-      }
-    });
+    ExprTupleValue firstRow =
+        new ExprTupleValue(
+            new LinkedHashMap<>() {
+              {
+                put("1", new ExprIntegerValue(1));
+              }
+            });
     assertEquals(firstRow, sparkSqlFunctionTableScanOperator.next());
     Assertions.assertFalse(sparkSqlFunctionTableScanOperator.hasNext());
   }
 
@@ -115,28 +113,29 @@ void testQueryResponseAllTypes() {
     SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
     sparkQueryRequest.setSql(QUERY);
 
-    SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator
-        = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest);
+    SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator =
+        new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest);
 
-    when(sparkClient.sql(any()))
-        .thenReturn(new JSONObject(getJson("all_data_type.json")));
+    when(sparkClient.sql(any())).thenReturn(new JSONObject(getJson("all_data_type.json")));
     sparkSqlFunctionTableScanOperator.open();
     assertTrue(sparkSqlFunctionTableScanOperator.hasNext());
-    ExprTupleValue firstRow = new ExprTupleValue(new LinkedHashMap<>() {
-      {
-        put("boolean", ExprBooleanValue.of(true));
-        put("long", new ExprLongValue(922337203));
-        put("integer", new ExprIntegerValue(2147483647));
-        put("short", new ExprShortValue(32767));
-        put("byte", new ExprByteValue(127));
-        put("double", new ExprDoubleValue(9223372036854.775807));
-        put("float", new ExprFloatValue(21474.83647));
-        put("timestamp", new ExprDateValue("2023-07-01 10:31:30"));
-        put("date", new ExprTimestampValue("2023-07-01 10:31:30"));
-        put("string", new ExprStringValue("ABC"));
-        put("char", new ExprStringValue("A"));
-      }
-    });
+    ExprTupleValue firstRow =
+        new ExprTupleValue(
+            new LinkedHashMap<>() {
+              {
+                put("boolean", ExprBooleanValue.of(true));
+                put("long", new ExprLongValue(922337203));
+                put("integer", new ExprIntegerValue(2147483647));
+                put("short", new ExprShortValue(32767));
+                put("byte", new ExprByteValue(127));
+                put("double", new ExprDoubleValue(9223372036854.775807));
+                put("float", new ExprFloatValue(21474.83647));
+                put("timestamp", new ExprDateValue("2023-07-01 10:31:30"));
+                put("date", new ExprTimestampValue("2023-07-01 10:31:30"));
+                put("string", new ExprStringValue("ABC"));
+                put("char", new ExprStringValue("A"));
+              }
+            });
     assertEquals(firstRow, sparkSqlFunctionTableScanOperator.next());
     Assertions.assertFalse(sparkSqlFunctionTableScanOperator.hasNext());
   }
@@ -147,16 +146,15 @@ void testQueryResponseInvalidDataType() {
     SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
     sparkQueryRequest.setSql(QUERY);
 
-    SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator
-        = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest);
+    SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator =
+        new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest);
 
-    when(sparkClient.sql(any()))
-        .thenReturn(new JSONObject(getJson("invalid_data_type.json")));
+    when(sparkClient.sql(any())).thenReturn(new JSONObject(getJson("invalid_data_type.json")));
 
-    RuntimeException exception = Assertions.assertThrows(RuntimeException.class,
-        () -> sparkSqlFunctionTableScanOperator.open());
-    Assertions.assertEquals("Result contains invalid data type",
-        exception.getMessage());
+    RuntimeException exception =
+        Assertions.assertThrows(
+            RuntimeException.class, () -> sparkSqlFunctionTableScanOperator.open());
+    Assertions.assertEquals("Result contains invalid data type", exception.getMessage());
   }
 
   @Test
@@ -165,17 +163,14 @@ void testQuerySchema() {
     SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
     sparkQueryRequest.setSql(QUERY);
 
-    SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator
-        = new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest);
+    SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator =
+        new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest);
 
-    when(sparkClient.sql(any()))
-        .thenReturn(
-            new JSONObject(getJson("select_query_response.json")));
+    when(sparkClient.sql(any())).thenReturn(new JSONObject(getJson("select_query_response.json")));
     sparkSqlFunctionTableScanOperator.open();
     ArrayList<ExecutionEngine.Schema.Column> columns = new ArrayList<>();
     columns.add(new ExecutionEngine.Schema.Column("1", "1", ExprCoreType.INTEGER));
     ExecutionEngine.Schema expectedSchema = new ExecutionEngine.Schema(columns);
     assertEquals(expectedSchema, sparkSqlFunctionTableScanOperator.schema());
   }
-
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java
index e18fac36de..a828ac76c4 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java
@@ -35,107 +35,106 @@
 
 @ExtendWith(MockitoExtension.class)
 public class SparkSqlTableFunctionResolverTest {
-  @Mock
-  private SparkClient client;
+  @Mock private SparkClient client;
 
-  @Mock
-  private FunctionProperties functionProperties;
+  @Mock private FunctionProperties functionProperties;
 
   @Test
   void testResolve() {
-    SparkSqlTableFunctionResolver sqlTableFunctionResolver
-        = new SparkSqlTableFunctionResolver(client);
+    SparkSqlTableFunctionResolver sqlTableFunctionResolver =
+        new SparkSqlTableFunctionResolver(client);
     FunctionName functionName = FunctionName.of("sql");
-    List<Expression> expressions
-        = List.of(DSL.namedArgument("query", DSL.literal(QUERY)));
-    FunctionSignature functionSignature = new FunctionSignature(functionName, expressions
-        .stream().map(Expression::type).collect(Collectors.toList()));
-    Pair<FunctionSignature, FunctionBuilder> resolution
-        = sqlTableFunctionResolver.resolve(functionSignature);
+    List<Expression> expressions = List.of(DSL.namedArgument("query", DSL.literal(QUERY)));
+    FunctionSignature functionSignature =
+        new FunctionSignature(
+            functionName, expressions.stream().map(Expression::type).collect(Collectors.toList()));
+    Pair<FunctionSignature, FunctionBuilder> resolution =
+        sqlTableFunctionResolver.resolve(functionSignature);
     assertEquals(functionName, resolution.getKey().getFunctionName());
     assertEquals(functionName, sqlTableFunctionResolver.getFunctionName());
     assertEquals(List.of(STRING), resolution.getKey().getParamTypeList());
     FunctionBuilder functionBuilder = resolution.getValue();
-    TableFunctionImplementation functionImplementation
-        = (TableFunctionImplementation) functionBuilder.apply(functionProperties, expressions);
+    TableFunctionImplementation functionImplementation =
+        (TableFunctionImplementation) functionBuilder.apply(functionProperties, expressions);
     assertTrue(functionImplementation instanceof SparkSqlFunctionImplementation);
-    SparkTable sparkTable
-        = (SparkTable) functionImplementation.applyArguments();
+    SparkTable sparkTable = (SparkTable) functionImplementation.applyArguments();
     assertNotNull(sparkTable.getSparkQueryRequest());
-    SparkQueryRequest sparkQueryRequest =
-        sparkTable.getSparkQueryRequest();
+    SparkQueryRequest sparkQueryRequest = sparkTable.getSparkQueryRequest();
     assertEquals(QUERY, sparkQueryRequest.getSql());
   }
 
   @Test
   void testArgumentsPassedByPosition() {
-    SparkSqlTableFunctionResolver sqlTableFunctionResolver
-        = new SparkSqlTableFunctionResolver(client);
+    SparkSqlTableFunctionResolver sqlTableFunctionResolver =
+        new SparkSqlTableFunctionResolver(client);
     FunctionName functionName = FunctionName.of("sql");
-    List<Expression> expressions
-        = List.of(DSL.namedArgument(null, DSL.literal(QUERY)));
-    FunctionSignature functionSignature = new FunctionSignature(functionName, expressions
-        .stream().map(Expression::type).collect(Collectors.toList()));
+    List<Expression> expressions = List.of(DSL.namedArgument(null, DSL.literal(QUERY)));
+    FunctionSignature functionSignature =
+        new FunctionSignature(
+            functionName, expressions.stream().map(Expression::type).collect(Collectors.toList()));
 
-    Pair<FunctionSignature, FunctionBuilder> resolution
-        = sqlTableFunctionResolver.resolve(functionSignature);
+    Pair<FunctionSignature, FunctionBuilder> resolution =
+        sqlTableFunctionResolver.resolve(functionSignature);
 
     assertEquals(functionName, resolution.getKey().getFunctionName());
     assertEquals(functionName, sqlTableFunctionResolver.getFunctionName());
     assertEquals(List.of(STRING), resolution.getKey().getParamTypeList());
     FunctionBuilder functionBuilder = resolution.getValue();
-    TableFunctionImplementation functionImplementation
-        = (TableFunctionImplementation) functionBuilder.apply(functionProperties, expressions);
+    TableFunctionImplementation functionImplementation =
+        (TableFunctionImplementation) functionBuilder.apply(functionProperties, expressions);
     assertTrue(functionImplementation instanceof SparkSqlFunctionImplementation);
-    SparkTable sparkTable
-        = (SparkTable) functionImplementation.applyArguments();
+    SparkTable sparkTable = (SparkTable) functionImplementation.applyArguments();
     assertNotNull(sparkTable.getSparkQueryRequest());
-    SparkQueryRequest sparkQueryRequest =
-        sparkTable.getSparkQueryRequest();
+    SparkQueryRequest sparkQueryRequest = sparkTable.getSparkQueryRequest();
     assertEquals(QUERY, sparkQueryRequest.getSql());
   }
 
   @Test
   void testMixedArgumentTypes() {
-    SparkSqlTableFunctionResolver sqlTableFunctionResolver
-        = new SparkSqlTableFunctionResolver(client);
+    SparkSqlTableFunctionResolver sqlTableFunctionResolver =
+        new SparkSqlTableFunctionResolver(client);
     FunctionName functionName = FunctionName.of("sql");
-    List<Expression> expressions
-        = List.of(DSL.namedArgument("query", DSL.literal(QUERY)),
-        DSL.namedArgument(null, DSL.literal(12345)));
-    FunctionSignature functionSignature = new FunctionSignature(functionName, expressions
-        .stream().map(Expression::type).collect(Collectors.toList()));
-    Pair<FunctionSignature, FunctionBuilder> resolution
-        = sqlTableFunctionResolver.resolve(functionSignature);
+    List<Expression> expressions =
+        List.of(
+            DSL.namedArgument("query", DSL.literal(QUERY)),
+            DSL.namedArgument(null, DSL.literal(12345)));
+    FunctionSignature functionSignature =
+        new FunctionSignature(
+            functionName, expressions.stream().map(Expression::type).collect(Collectors.toList()));
+    Pair<FunctionSignature, FunctionBuilder> resolution =
+        sqlTableFunctionResolver.resolve(functionSignature);
 
     assertEquals(functionName, resolution.getKey().getFunctionName());
     assertEquals(functionName, sqlTableFunctionResolver.getFunctionName());
     assertEquals(List.of(STRING), resolution.getKey().getParamTypeList());
 
-    SemanticCheckException exception = assertThrows(SemanticCheckException.class,
-        () -> resolution.getValue().apply(functionProperties, expressions));
+    SemanticCheckException exception =
+        assertThrows(
+            SemanticCheckException.class,
+            () -> resolution.getValue().apply(functionProperties, expressions));
     assertEquals("Arguments should be either passed by name or position", exception.getMessage());
   }
 
   @Test
   void testWrongArgumentsSizeWhenPassedByName() {
-    SparkSqlTableFunctionResolver sqlTableFunctionResolver
-        = new SparkSqlTableFunctionResolver(client);
+    SparkSqlTableFunctionResolver sqlTableFunctionResolver =
+        new SparkSqlTableFunctionResolver(client);
     FunctionName functionName = FunctionName.of("sql");
-    List<Expression> expressions
-        = List.of();
-    FunctionSignature functionSignature = new FunctionSignature(functionName, expressions
-        .stream().map(Expression::type).collect(Collectors.toList()));
-    Pair<FunctionSignature, FunctionBuilder> resolution
-        = sqlTableFunctionResolver.resolve(functionSignature);
+    List<Expression> expressions = List.of();
+    FunctionSignature functionSignature =
+        new FunctionSignature(
+            functionName, expressions.stream().map(Expression::type).collect(Collectors.toList()));
+    Pair<FunctionSignature, FunctionBuilder> resolution =
+        sqlTableFunctionResolver.resolve(functionSignature);
     assertEquals(functionName, resolution.getKey().getFunctionName());
     assertEquals(functionName, sqlTableFunctionResolver.getFunctionName());
     assertEquals(List.of(STRING), resolution.getKey().getParamTypeList());
 
-    SemanticCheckException exception = assertThrows(SemanticCheckException.class,
-        () -> resolution.getValue().apply(functionProperties, expressions));
+    SemanticCheckException exception =
+        assertThrows(
+            SemanticCheckException.class,
+            () -> resolution.getValue().apply(functionProperties, expressions));
     assertEquals("Missing arguments:[query]", exception.getMessage());
   }
-
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java b/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java
index abc4c81626..211561ac72 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/response/SparkResponseTest.java
@@ -32,18 +32,12 @@
 
 @ExtendWith(MockitoExtension.class)
 public class SparkResponseTest {
-  @Mock
-  private Client client;
-  @Mock
-  private SearchResponse searchResponse;
-  @Mock
-  private DeleteResponse deleteResponse;
-  @Mock
-  private SearchHit searchHit;
-  @Mock
-  private ActionFuture<SearchResponse> searchResponseActionFuture;
-  @Mock
-  private ActionFuture<DeleteResponse> deleteResponseActionFuture;
+  @Mock private Client client;
+  @Mock private SearchResponse searchResponse;
+  @Mock private DeleteResponse deleteResponse;
+  @Mock private SearchHit searchHit;
+  @Mock private ActionFuture<SearchResponse> searchResponseActionFuture;
+  @Mock private ActionFuture<DeleteResponse> deleteResponseActionFuture;
 
   @Test
   public void testGetResultFromOpensearchIndex() {
@@ -53,12 +47,8 @@ public void testGetResultFromOpensearchIndex() {
     when(searchResponse.getHits())
         .thenReturn(
             new SearchHits(
-                new SearchHit[] {searchHit},
-                new TotalHits(1, TotalHits.Relation.EQUAL_TO),
-                1.0F));
-    Mockito.when(searchHit.getSourceAsMap())
-        .thenReturn(Map.of("stepId", EMR_CLUSTER_ID));
-
+                new SearchHit[] {searchHit}, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F));
+    Mockito.when(searchHit.getSourceAsMap()).thenReturn(Map.of("stepId", EMR_CLUSTER_ID));
     when(client.delete(any())).thenReturn(deleteResponseActionFuture);
     when(deleteResponseActionFuture.actionGet()).thenReturn(deleteResponse);
 
@@ -75,11 +65,13 @@ public void testInvalidSearchResponse() {
     when(searchResponse.status()).thenReturn(RestStatus.NO_CONTENT);
 
     SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId");
-    RuntimeException exception = assertThrows(RuntimeException.class,
-        () -> sparkResponse.getResultFromOpensearchIndex());
+    RuntimeException exception =
+        assertThrows(RuntimeException.class, () -> sparkResponse.getResultFromOpensearchIndex());
     Assertions.assertEquals(
-        "Fetching result from " + SPARK_INDEX_NAME
-            + " index failed with status : " + RestStatus.NO_CONTENT,
+        "Fetching result from "
+            + SPARK_INDEX_NAME
+            + " index failed with status : "
+            + RestStatus.NO_CONTENT,
         exception.getMessage());
   }
 
@@ -104,8 +96,9 @@ public void testNotFoundDeleteResponse() {
     when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.NOT_FOUND);
 
     SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId");
-    RuntimeException exception = assertThrows(ResourceNotFoundException.class,
-        () -> sparkResponse.deleteInSparkIndex("123"));
+    RuntimeException exception =
+        assertThrows(
+            ResourceNotFoundException.class, () -> sparkResponse.deleteInSparkIndex("123"));
     Assertions.assertEquals("Spark result with id 123 doesn't exist", exception.getMessage());
   }
 
@@ -116,8 +109,8 @@ public void testInvalidDeleteResponse() {
     when(deleteResponse.getResult()).thenReturn(DocWriteResponse.Result.NOOP);
 
     SparkResponse sparkResponse = new SparkResponse(client, EMR_CLUSTER_ID, "stepId");
-    RuntimeException exception = assertThrows(RuntimeException.class,
-        () -> sparkResponse.deleteInSparkIndex("123"));
+    RuntimeException exception =
+        assertThrows(RuntimeException.class, () -> sparkResponse.deleteInSparkIndex("123"));
     Assertions.assertEquals(
         "Deleting spark result information failed with : noop", exception.getMessage());
   }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkScanTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkScanTest.java
index c57142f580..971db3c33c 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkScanTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkScanTest.java
@@ -19,8 +19,7 @@
 
 @ExtendWith(MockitoExtension.class)
 public class SparkScanTest {
-  @Mock
-  private SparkClient sparkClient;
+  @Mock private SparkClient sparkClient;
 
   @Test
   @SneakyThrows
@@ -36,8 +35,6 @@ void testQueryResponseIteratorForQueryRangeFunction() {
   void testExplain() {
     SparkScan sparkScan = new SparkScan(sparkClient);
     sparkScan.getRequest().setSql(QUERY);
-    assertEquals(
-        "SparkQueryRequest(sql=select 1)",
-        sparkScan.explain());
+    assertEquals("SparkQueryRequest(sql=select 1)", sparkScan.explain());
   }
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageEngineTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageEngineTest.java
index d42e123678..5e7ec76cdb 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageEngineTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageEngineTest.java
@@ -22,14 +22,12 @@
 
 @ExtendWith(MockitoExtension.class)
 public class SparkStorageEngineTest {
-  @Mock
-  private SparkClient client;
+  @Mock private SparkClient client;
 
   @Test
   public void getFunctions() {
     SparkStorageEngine engine = new SparkStorageEngine(client);
-    Collection<FunctionResolver> functionResolverCollection
-        = engine.getFunctions();
+    Collection<FunctionResolver> functionResolverCollection = engine.getFunctions();
     assertNotNull(functionResolverCollection);
     assertEquals(1, functionResolverCollection.size());
     assertTrue(
@@ -39,8 +37,10 @@ public void getFunctions() {
   @Test
   public void getTable() {
     SparkStorageEngine engine = new SparkStorageEngine(client);
-    RuntimeException exception = assertThrows(RuntimeException.class,
-        () -> engine.getTable(new DataSourceSchemaName("spark", "default"), ""));
+    RuntimeException exception =
+        assertThrows(
+            RuntimeException.class,
+            () -> engine.getTable(new DataSourceSchemaName("spark", "default"), ""));
     assertEquals("Unable to get table from storage engine.", exception.getMessage());
   }
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageFactoryTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageFactoryTest.java
index c68adf2039..eb93cdabfe 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageFactoryTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageFactoryTest.java
@@ -24,17 +24,14 @@
 
 @ExtendWith(MockitoExtension.class)
 public class SparkStorageFactoryTest {
-  @Mock
-  private Settings settings;
+  @Mock private Settings settings;
 
-  @Mock
-  private Client client;
+  @Mock private Client client;
 
   @Test
   void testGetConnectorType() {
     SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings);
-    Assertions.assertEquals(
-        DataSourceType.SPARK, sparkStorageFactory.getDataSourceType());
+    Assertions.assertEquals(DataSourceType.SPARK, sparkStorageFactory.getDataSourceType());
   }
 
   @Test
@@ -48,8 +45,7 @@ void testGetStorageEngine() {
     properties.put("emr.auth.secret_key", "secret_key");
     properties.put("emr.auth.region", "region");
     SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings);
-    StorageEngine storageEngine
-        = sparkStorageFactory.getStorageEngine(properties);
+    StorageEngine storageEngine = sparkStorageFactory.getStorageEngine(properties);
     Assertions.assertTrue(storageEngine instanceof SparkStorageEngine);
   }
 
@@ -59,10 +55,11 @@ void testInvalidConnectorType() {
     HashMap<String, String> properties = new HashMap<>();
     properties.put("spark.connector", "random");
     SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings);
-    InvalidParameterException exception = Assertions.assertThrows(InvalidParameterException.class,
-        () -> sparkStorageFactory.getStorageEngine(properties));
-    Assertions.assertEquals("Spark connector type is invalid.",
-        exception.getMessage());
+    InvalidParameterException exception =
+        Assertions.assertThrows(
+            InvalidParameterException.class,
+            () -> sparkStorageFactory.getStorageEngine(properties));
+    Assertions.assertEquals("Spark connector type is invalid.", exception.getMessage());
   }
 
   @Test
@@ -72,10 +69,10 @@ void testMissingAuth() {
     properties.put("spark.connector", "emr");
     properties.put("emr.cluster", EMR_CLUSTER_ID);
     SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings);
-    IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class,
-        () -> sparkStorageFactory.getStorageEngine(properties));
-    Assertions.assertEquals("EMR config properties are missing.",
-        exception.getMessage());
+    IllegalArgumentException exception =
+        Assertions.assertThrows(
+            IllegalArgumentException.class, () -> sparkStorageFactory.getStorageEngine(properties));
+    Assertions.assertEquals("EMR config properties are missing.", exception.getMessage());
   }
 
   @Test
@@ -86,10 +83,10 @@ void testUnsupportedEmrAuth() {
     properties.put("emr.cluster", EMR_CLUSTER_ID);
     properties.put("emr.auth.type", "basic");
     SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings);
-    IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class,
-        () -> sparkStorageFactory.getStorageEngine(properties));
-    Assertions.assertEquals("Invalid auth type.",
-        exception.getMessage());
+    IllegalArgumentException exception =
+        Assertions.assertThrows(
+            IllegalArgumentException.class, () -> sparkStorageFactory.getStorageEngine(properties));
+    Assertions.assertEquals("Invalid auth type.", exception.getMessage());
   }
 
   @Test
@@ -99,10 +96,10 @@ void testMissingCluster() {
     properties.put("spark.connector", "emr");
     properties.put("emr.auth.type", "awssigv4");
     SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings);
-    IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class,
-        () -> sparkStorageFactory.getStorageEngine(properties));
-    Assertions.assertEquals("EMR config properties are missing.",
-        exception.getMessage());
+    IllegalArgumentException exception =
+        Assertions.assertThrows(
+            IllegalArgumentException.class, () -> sparkStorageFactory.getStorageEngine(properties));
+    Assertions.assertEquals("EMR config properties are missing.", exception.getMessage());
   }
 
   @Test
@@ -113,10 +110,10 @@ void testMissingAuthKeys() {
     properties.put("emr.cluster", EMR_CLUSTER_ID);
     properties.put("emr.auth.type", "awssigv4");
     SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings);
-    IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class,
-        () -> sparkStorageFactory.getStorageEngine(properties));
-    Assertions.assertEquals("EMR auth keys are missing.",
-        exception.getMessage());
+    IllegalArgumentException exception =
+        Assertions.assertThrows(
+            IllegalArgumentException.class, () -> sparkStorageFactory.getStorageEngine(properties));
+    Assertions.assertEquals("EMR auth keys are missing.", exception.getMessage());
  }
 
   @Test
@@ -128,10 +125,10 @@ void testMissingAuthSecretKey() {
     properties.put("emr.cluster", EMR_CLUSTER_ID);
     properties.put("emr.auth.type", "awssigv4");
     properties.put("emr.auth.access_key", "test");
     SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings);
-    IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class,
-        () -> sparkStorageFactory.getStorageEngine(properties));
-    Assertions.assertEquals("EMR auth keys are missing.",
-        exception.getMessage());
+    IllegalArgumentException exception =
+        Assertions.assertThrows(
+            IllegalArgumentException.class, () -> sparkStorageFactory.getStorageEngine(properties));
+    Assertions.assertEquals("EMR auth keys are missing.", exception.getMessage());
   }
   @Test
@@ -178,5 +175,4 @@ void testSetSparkJars() {
     DataSource dataSource = new SparkStorageFactory(client, settings).createDataSource(metadata);
     Assertions.assertTrue(dataSource.getStorageEngine() instanceof SparkStorageEngine);
   }
-
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java
index 39bd2eb199..a70d4ba69e 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java
@@ -31,26 +31,23 @@
 
 @ExtendWith(MockitoExtension.class)
 public class SparkTableTest {
-  @Mock
-  private SparkClient client;
+  @Mock private SparkClient client;
 
   @Test
   void testUnsupportedOperation() {
     SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
-    SparkTable sparkTable =
-        new SparkTable(client, sparkQueryRequest);
+    SparkTable sparkTable = new SparkTable(client, sparkQueryRequest);
 
     assertThrows(UnsupportedOperationException.class, sparkTable::exists);
-    assertThrows(UnsupportedOperationException.class,
-        () -> sparkTable.create(Collections.emptyMap()));
+    assertThrows(
+        UnsupportedOperationException.class, () -> sparkTable.create(Collections.emptyMap()));
   }
 
   @Test
   void testCreateScanBuilderWithSqlTableFunction() {
     SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
     sparkQueryRequest.setSql(QUERY);
-    SparkTable sparkTable =
-        new SparkTable(client, sparkQueryRequest);
+    SparkTable sparkTable = new SparkTable(client, sparkQueryRequest);
     TableScanBuilder tableScanBuilder = sparkTable.createScanBuilder();
     Assertions.assertNotNull(tableScanBuilder);
     Assertions.assertTrue(tableScanBuilder instanceof SparkSqlFunctionTableScanBuilder);
@@ -59,8 +56,7 @@ void testCreateScanBuilderWithSqlTableFunction() {
   @Test
   @SneakyThrows
   void testGetFieldTypesFromSparkQueryRequest() {
-    SparkTable sparkTable
-        = new SparkTable(client, new SparkQueryRequest());
+    SparkTable sparkTable = new SparkTable(client, new SparkQueryRequest());
     Map<String, ExprType> expectedFieldTypes = new HashMap<>();
 
     Map<String, ExprType> fieldTypes = sparkTable.getFieldTypes();
@@ -73,10 +69,9 @@ void testGetFieldTypesFromSparkQueryRequest() {
   void testImplementWithSqlFunction() {
     SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
     sparkQueryRequest.setSql(QUERY);
-    SparkTable sparkMetricTable =
-        new SparkTable(client, sparkQueryRequest);
-    PhysicalPlan plan = sparkMetricTable.implement(
-        new SparkSqlFunctionTableScanBuilder(client, sparkQueryRequest));
+    SparkTable sparkMetricTable = new SparkTable(client, sparkQueryRequest);
+    PhysicalPlan plan =
+        sparkMetricTable.implement(new SparkSqlFunctionTableScanBuilder(client, sparkQueryRequest));
     assertTrue(plan instanceof SparkSqlFunctionTableScanOperator);
   }
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/utils/TestUtils.java b/spark/src/test/java/org/opensearch/sql/spark/utils/TestUtils.java
index b480e6d9d9..ca77006d9c 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/utils/TestUtils.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/utils/TestUtils.java
@@ -12,6 +12,7 @@
 public class TestUtils {
   /**
    * Get Json document from the files in resources folder.
+   *
    * @param filename filename.
    * @return String.
    * @throws IOException IOException.
@@ -21,5 +22,4 @@ public static String getJson(String filename) throws IOException {
     return new String(
         Objects.requireNonNull(classLoader.getResourceAsStream(filename)).readAllBytes());
   }
-
 }