From 0c0d04e68b2a54ecf4a415f71b8936e58e49c31d Mon Sep 17 00:00:00 2001 From: Forest Vey Date: Wed, 28 Jun 2023 14:22:29 -0700 Subject: [PATCH] Add Support for Field Star in Nested Function (#1773) * Add Support for Field Star in Nested Function. Signed-off-by: forestmvey * Removing toString for NestedAllTupleFields. Signed-off-by: forestmvey * Adding IT test for nested all fields in invalid clause of SQL statement. Signed-off-by: forestmvey * Use utility function for checking is nested in NestedAnalyzer. Signed-off-by: forestmvey * Formatting fixes. Signed-off-by: forestmvey --------- Signed-off-by: forestmvey (cherry picked from commit fa840e0cc5f764b124ccfe58cb190f1f89b2650a) Signed-off-by: forestmvey --- .../org/opensearch/sql/analysis/Analyzer.java | 10 +- .../sql/analysis/ExpressionAnalyzer.java | 11 +- .../sql/analysis/NestedAnalyzer.java | 49 +++- .../analysis/SelectExpressionAnalyzer.java | 32 +++ .../sql/analysis/TypeEnvironment.java | 11 + .../sql/analysis/symbol/SymbolTable.java | 15 ++ .../sql/ast/AbstractNodeVisitor.java | 5 + .../org/opensearch/sql/ast/dsl/AstDSL.java | 5 + .../ast/expression/NestedAllTupleFields.java | 41 ++++ .../opensearch/sql/analysis/AnalyzerTest.java | 220 ++++++++++++++++-- docs/user/dql/functions.rst | 11 + .../sql/legacy/PrettyFormatResponseIT.java | 5 +- .../java/org/opensearch/sql/sql/NestedIT.java | 158 ++++++++++++- sql/src/main/antlr/OpenSearchSQLParser.g4 | 7 +- .../sql/sql/parser/AstExpressionBuilder.java | 12 +- .../sql/sql/parser/AstBuilderTest.java | 14 ++ 16 files changed, 577 insertions(+), 29 deletions(-) create mode 100644 core/src/main/java/org/opensearch/sql/ast/expression/NestedAllTupleFields.java diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 6418f92686..2c4647004c 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -29,6 +29,7 @@ import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.math3.analysis.function.Exp; import org.opensearch.sql.DataSourceSchemaName; import org.opensearch.sql.analysis.symbol.Namespace; import org.opensearch.sql.analysis.symbol.Symbol; @@ -469,8 +470,13 @@ public LogicalPlan visitSort(Sort node, AnalysisContext context) { node.getSortList().stream() .map( sortField -> { - Expression expression = optimizer.optimize( - expressionAnalyzer.analyze(sortField.getField(), context), context); + var analyzed = expressionAnalyzer.analyze(sortField.getField(), context); + if (analyzed == null) { + throw new UnsupportedOperationException( + String.format("Invalid use of expression %s", sortField.getField()) + ); + } + Expression expression = optimizer.optimize(analyzed, context); return ImmutablePair.of(analyzeSortOption(sortField.getFieldArgs()), expression); }) .collect(Collectors.toList()); diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index 43155a868a..601e3e00cc 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -186,7 +186,16 @@ public Expression visitFunction(Function node, AnalysisContext context) { FunctionName functionName = FunctionName.of(node.getFuncName()); List arguments = node.getFuncArgs().stream() - .map(unresolvedExpression -> analyze(unresolvedExpression, context)) + .map(unresolvedExpression -> { + var ret = analyze(unresolvedExpression, context); + if (ret == null) { + throw new UnsupportedOperationException( + String.format("Invalid use of expression %s", unresolvedExpression) + ); + } else { + return ret; + } + }) .collect(Collectors.toList()); return (Expression) repository.compile(context.getFunctionProperties(), functionName, arguments); diff --git a/core/src/main/java/org/opensearch/sql/analysis/NestedAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/NestedAnalyzer.java index 4e3939bb14..f050824557 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/NestedAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/NestedAnalyzer.java @@ -15,6 +15,7 @@ import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.NestedAllTupleFields; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.expression.Expression; @@ -45,6 +46,28 @@ public LogicalPlan visitAlias(Alias node, AnalysisContext context) { return node.getDelegated().accept(this, context); } + @Override + public LogicalPlan visitNestedAllTupleFields(NestedAllTupleFields node, AnalysisContext context) { + List> args = new ArrayList<>(); + for (NamedExpression namedExpr : namedExpressions) { + if (isNestedFunction(namedExpr.getDelegated())) { + ReferenceExpression field = + (ReferenceExpression) ((FunctionExpression) namedExpr.getDelegated()) + .getArguments().get(0); + + // If path is same as NestedAllTupleFields path + if (field.getAttr().substring(0, field.getAttr().lastIndexOf(".")) + .equalsIgnoreCase(node.getPath())) { + args.add(Map.of( + "field", field, + "path", new ReferenceExpression(node.getPath(), STRING))); + } + } + } + + return mergeChildIfLogicalNested(args); + } + @Override public LogicalPlan visitFunction(Function node, AnalysisContext context) { if (node.getFuncName().equalsIgnoreCase(BuiltinFunctionName.NESTED.name())) { @@ -54,6 +77,8 @@ public LogicalPlan visitFunction(Function node, AnalysisContext context) { ReferenceExpression nestedField = (ReferenceExpression)expressionAnalyzer.analyze(expressions.get(0), context); Map args; + + // Path parameter is supplied if (expressions.size() == 2) { args = Map.of( "field", nestedField, @@ -65,16 +90,28 @@ public LogicalPlan visitFunction(Function node, AnalysisContext context) { "path", generatePath(nestedField.toString()) ); } - if (child instanceof LogicalNested) { - ((LogicalNested)child).addFields(args); - return child; - } else { - return new LogicalNested(child, new ArrayList<>(Arrays.asList(args)), namedExpressions); - } + + return mergeChildIfLogicalNested(new ArrayList<>(Arrays.asList(args))); } return null; } + /** + * NestedAnalyzer visits all functions in SELECT clause, creates logical plans for each and + * merges them. This is to avoid another merge rule in LogicalPlanOptimizer:create(). + * @param args field and path params to add to logical plan. + * @return child of logical nested with added args, or new LogicalNested. + */ + private LogicalPlan mergeChildIfLogicalNested(List> args) { + if (child instanceof LogicalNested) { + for (var arg : args) { + ((LogicalNested) child).addFields(arg); + } + return child; + } + return new LogicalNested(child, args, namedExpressions); + } + /** * Validate each parameter used in nested function in SELECT clause. Any supplied parameter * for a nested function in a SELECT statement must be a valid qualified name, and the field diff --git a/core/src/main/java/org/opensearch/sql/analysis/SelectExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/SelectExpressionAnalyzer.java index 3593488f46..734f37378b 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/SelectExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/SelectExpressionAnalyzer.java @@ -10,13 +10,17 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.regex.Pattern; import java.util.stream.Collectors; import lombok.RequiredArgsConstructor; import org.opensearch.sql.analysis.symbol.Namespace; +import org.opensearch.sql.analysis.symbol.Symbol; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.NestedAllTupleFields; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.data.type.ExprType; @@ -58,6 +62,11 @@ public List visitField(Field node, AnalysisContext context) { @Override public List visitAlias(Alias node, AnalysisContext context) { + // Expand all nested fields if used in SELECT clause + if (node.getDelegated() instanceof NestedAllTupleFields) { + return node.getDelegated().accept(this, context); + } + Expression expr = referenceIfSymbolDefined(node, context); return Collections.singletonList(DSL.named( unqualifiedNameIfFieldOnly(node, context), @@ -100,6 +109,29 @@ public List visitAllFields(AllFields node, new ReferenceExpression(entry.getKey(), entry.getValue()))).collect(Collectors.toList()); } + @Override + public List visitNestedAllTupleFields(NestedAllTupleFields node, + AnalysisContext context) { + TypeEnvironment environment = context.peek(); + Map lookupAllTupleFields = + environment.lookupAllTupleFields(Namespace.FIELD_NAME); + environment.resolve(new Symbol(Namespace.FIELD_NAME, node.getPath())); + + // Match all fields with same path as used in nested function. + Pattern p = Pattern.compile(node.getPath() + "\\.[^\\.]+$"); + return lookupAllTupleFields.entrySet().stream() + .filter(field -> p.matcher(field.getKey()).find()) + .map(entry -> { + Expression nestedFunc = new Function( + "nested", + List.of( + new QualifiedName(List.of(entry.getKey().split("\\.")))) + ).accept(expressionAnalyzer, context); + return DSL.named("nested(" + entry.getKey() + ")", nestedFunc); + }) + .collect(Collectors.toList()); + } + /** * Get unqualified name if select item is just a field. For example, suppose an index * named "accounts", return "age" for "SELECT accounts.age". But do nothing for expression diff --git a/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java b/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java index c9fd8030e0..17d203f66f 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java +++ b/core/src/main/java/org/opensearch/sql/analysis/TypeEnvironment.java @@ -85,6 +85,17 @@ public Map lookupAllFields(Namespace namespace) { return result; } + /** + * Resolve all fields in the current environment. + * @param namespace a namespace + * @return all symbols in the namespace + */ + public Map lookupAllTupleFields(Namespace namespace) { + Map result = new LinkedHashMap<>(); + symbolTable.lookupAllTupleFields(namespace).forEach(result::putIfAbsent); + return result; + } + /** * Define symbol with the type. * diff --git a/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java b/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java index 45f77915f2..be7435c288 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java +++ b/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java @@ -128,6 +128,21 @@ public Map lookupAllFields(Namespace namespace) { return results; } + /** + * Look up all top level symbols in the namespace. + * + * @param namespace a namespace + * @return all symbols in the namespace map + */ + public Map lookupAllTupleFields(Namespace namespace) { + final LinkedHashMap allSymbols = + orderedTable.getOrDefault(namespace, new LinkedHashMap<>()); + final LinkedHashMap result = new LinkedHashMap<>(); + allSymbols.entrySet().stream() + .forEach(entry -> result.put(entry.getKey(), entry.getValue())); + return result; + } + /** * Check if namespace map in empty (none definition). * diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 3e81509fae..f02bc07ccc 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -25,6 +25,7 @@ import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Map; +import org.opensearch.sql.ast.expression.NestedAllTupleFields; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.QualifiedName; @@ -238,6 +239,10 @@ public T visitAllFields(AllFields node, C context) { return visitChildren(node, context); } + public T visitNestedAllTupleFields(NestedAllTupleFields node, C context) { + return visitChildren(node, context); + } + public T visitInterval(Interval node, C context) { return visitChildren(node, context); } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index de2ab5404a..d5f10fcfd4 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -30,6 +30,7 @@ import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Map; +import org.opensearch.sql.ast.expression.NestedAllTupleFields; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.ParseMethod; @@ -377,6 +378,10 @@ public Alias alias(String name, UnresolvedExpression expr, String alias) { return new Alias(name, expr, alias); } + public NestedAllTupleFields nestedAllTupleFields(String path) { + return new NestedAllTupleFields(path); + } + public static List exprList(UnresolvedExpression... exprList) { return Arrays.asList(exprList); } diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/NestedAllTupleFields.java b/core/src/main/java/org/opensearch/sql/ast/expression/NestedAllTupleFields.java new file mode 100644 index 0000000000..adf2025e6c --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/expression/NestedAllTupleFields.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.ast.expression; + +import java.util.Collections; +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; + +/** + * Represents all tuple fields used in nested function. + */ +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class NestedAllTupleFields extends UnresolvedExpression { + @Getter + private final String path; + + @Override + public List getChild() { + return Collections.emptyList(); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitNestedAllTupleFields(this, context); + } + + @Override + public String toString() { + return String.format("nested(%s.*)", path); + } +} diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 6d83ee53a8..100cfd67af 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -24,6 +24,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.filteredAggregate; import static org.opensearch.sql.ast.dsl.AstDSL.function; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; +import static org.opensearch.sql.ast.dsl.AstDSL.nestedAllTupleFields; import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; import static org.opensearch.sql.ast.dsl.AstDSL.relation; import static org.opensearch.sql.ast.dsl.AstDSL.span; @@ -556,7 +557,7 @@ public void project_nested_field_arg() { List projectList = List.of( new NamedExpression( - "message.info", + "nested(message.info)", DSL.nested(DSL.ref("message.info", STRING)), null) ); @@ -567,13 +568,13 @@ public void project_nested_field_arg() { LogicalPlanDSL.relation("schema", table), nestedArgs, projectList), - DSL.named("message.info", + DSL.named("nested(message.info)", DSL.nested(DSL.ref("message.info", STRING))) ), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("message.info", + AstDSL.alias("nested(message.info)", function("nested", qualifiedName("message", "info")), null) ) ); @@ -583,6 +584,195 @@ public void project_nested_field_arg() { assertFalse(isNestedFunction(DSL.match(DSL.namedArgument("field", literal("message"))))); } + @Test + public void sort_with_nested_all_tuple_fields_throws_exception() { + assertThrows(UnsupportedOperationException.class, () -> analyze( + AstDSL.project( + AstDSL.sort( + AstDSL.relation("schema"), + field(nestedAllTupleFields("message")) + ), + AstDSL.alias("nested(message.*)", + nestedAllTupleFields("message")) + ) + )); + } + + @Test + public void filter_with_nested_all_tuple_fields_throws_exception() { + assertThrows(UnsupportedOperationException.class, () -> analyze( + AstDSL.project( + AstDSL.filter( + AstDSL.relation("schema"), + AstDSL.function("=", nestedAllTupleFields("message"), AstDSL.intLiteral(1))), + AstDSL.alias("nested(message.*)", + nestedAllTupleFields("message")) + ) + )); + } + + + @Test + public void project_nested_field_star_arg() { + List> nestedArgs = + List.of( + Map.of( + "field", new ReferenceExpression("message.info", STRING), + "path", new ReferenceExpression("message", STRING) + ) + ); + + List projectList = + List.of( + new NamedExpression("nested(message.info)", + DSL.nested(DSL.ref("message.info", STRING))) + ); + + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.nested( + LogicalPlanDSL.relation("schema", table), + nestedArgs, + projectList), + DSL.named("nested(message.info)", + DSL.nested(DSL.ref("message.info", STRING))) + ), + AstDSL.projectWithArg( + AstDSL.relation("schema"), + AstDSL.defaultFieldsArgs(), + AstDSL.alias("nested(message.*)", + nestedAllTupleFields("message")) + ) + ); + } + + @Test + public void project_nested_field_star_arg_with_another_nested_function() { + List> nestedArgs = + List.of( + Map.of( + "field", new ReferenceExpression("message.info", STRING), + "path", new ReferenceExpression("message", STRING) + ), + Map.of( + "field", new ReferenceExpression("comment.data", STRING), + "path", new ReferenceExpression("comment", STRING) + ) + ); + + List projectList = + List.of( + new NamedExpression("nested(message.info)", + DSL.nested(DSL.ref("message.info", STRING))), + new NamedExpression("nested(comment.data)", + DSL.nested(DSL.ref("comment.data", STRING))) + ); + + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.nested( + LogicalPlanDSL.relation("schema", table), + nestedArgs, + projectList), + DSL.named("nested(message.info)", + DSL.nested(DSL.ref("message.info", STRING))), + DSL.named("nested(comment.data)", + DSL.nested(DSL.ref("comment.data", STRING))) + ), + AstDSL.projectWithArg( + AstDSL.relation("schema"), + AstDSL.defaultFieldsArgs(), + AstDSL.alias("nested(message.*)", + nestedAllTupleFields("message")), + AstDSL.alias("nested(comment.*)", + nestedAllTupleFields("comment")) + ) + ); + } + + @Test + public void project_nested_field_star_arg_with_another_field() { + List> nestedArgs = + List.of( + Map.of( + "field", new ReferenceExpression("message.info", STRING), + "path", new ReferenceExpression("message", STRING) + ) + ); + + List projectList = + List.of( + new NamedExpression("nested(message.info)", + DSL.nested(DSL.ref("message.info", STRING))), + new NamedExpression("comment.data", + DSL.ref("comment.data", STRING)) + ); + + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.nested( + LogicalPlanDSL.relation("schema", table), + nestedArgs, + projectList), + DSL.named("nested(message.info)", + DSL.nested(DSL.ref("message.info", STRING))), + DSL.named("comment.data", + DSL.ref("comment.data", STRING)) + ), + AstDSL.projectWithArg( + AstDSL.relation("schema"), + AstDSL.defaultFieldsArgs(), + AstDSL.alias("nested(message.*)", + nestedAllTupleFields("message")), + AstDSL.alias("comment.data", + field("comment.data")) + ) + ); + } + + @Test + public void project_nested_field_star_arg_with_highlight() { + List> nestedArgs = + List.of( + Map.of( + "field", new ReferenceExpression("message.info", STRING), + "path", new ReferenceExpression("message", STRING) + ) + ); + + List projectList = + List.of( + new NamedExpression("nested(message.info)", + DSL.nested(DSL.ref("message.info", STRING))), + DSL.named("highlight(fieldA)", + new HighlightExpression(DSL.literal("fieldA"))) + ); + + Map highlightArgs = new HashMap<>(); + + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.nested( + LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema", table), + DSL.literal("fieldA"), highlightArgs), + nestedArgs, + projectList), + DSL.named("nested(message.info)", + DSL.nested(DSL.ref("message.info", STRING))), + DSL.named("highlight(fieldA)", + new HighlightExpression(DSL.literal("fieldA"))) + ), + AstDSL.projectWithArg( + AstDSL.relation("schema"), + AstDSL.defaultFieldsArgs(), + AstDSL.alias("nested(message.*)", + nestedAllTupleFields("message")), + AstDSL.alias("highlight(fieldA)", + new HighlightFunction(AstDSL.stringLiteral("fieldA"), highlightArgs)) + ) + ); + } + @Test public void project_nested_field_and_path_args() { List> nestedArgs = @@ -596,7 +786,7 @@ public void project_nested_field_and_path_args() { List projectList = List.of( new NamedExpression( - "message.info", + "nested(message.info)", DSL.nested(DSL.ref("message.info", STRING), DSL.ref("message", STRING)), null) ); @@ -607,13 +797,13 @@ public void project_nested_field_and_path_args() { LogicalPlanDSL.relation("schema", table), nestedArgs, projectList), - DSL.named("message.info", + DSL.named("nested(message.info)", DSL.nested(DSL.ref("message.info", STRING), DSL.ref("message", STRING))) ), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("message.info", + AstDSL.alias("nested(message.info)", function( "nested", qualifiedName("message", "info"), @@ -638,7 +828,7 @@ public void project_nested_deep_field_arg() { List projectList = List.of( new NamedExpression( - "message.info.id", + "nested(message.info.id)", DSL.nested(DSL.ref("message.info.id", STRING)), null) ); @@ -649,13 +839,13 @@ public void project_nested_deep_field_arg() { LogicalPlanDSL.relation("schema", table), nestedArgs, projectList), - DSL.named("message.info.id", + DSL.named("nested(message.info.id)", DSL.nested(DSL.ref("message.info.id", STRING))) ), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("message.info.id", + AstDSL.alias("nested(message.info.id)", function("nested", qualifiedName("message", "info", "id")), null) ) ); @@ -678,11 +868,11 @@ public void project_multiple_nested() { List projectList = List.of( new NamedExpression( - "message.info", + "nested(message.info)", DSL.nested(DSL.ref("message.info", STRING)), null), new NamedExpression( - "comment.data", + "nested(comment.data)", DSL.nested(DSL.ref("comment.data", STRING)), null) ); @@ -693,17 +883,17 @@ public void project_multiple_nested() { LogicalPlanDSL.relation("schema", table), nestedArgs, projectList), - DSL.named("message.info", + DSL.named("nested(message.info)", DSL.nested(DSL.ref("message.info", STRING))), - DSL.named("comment.data", + DSL.named("nested(comment.data)", DSL.nested(DSL.ref("comment.data", STRING))) ), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("message.info", + AstDSL.alias("nested(message.info)", function("nested", qualifiedName("message", "info")), null), - AstDSL.alias("comment.data", + AstDSL.alias("nested(comment.data)", function("nested", qualifiedName("comment", "data")), null) ) ); diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index cef87624a5..19260e8bea 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -4458,6 +4458,17 @@ Example with ``field`` and ``path`` parameters:: | b | +---------------------------------+ +Example with ``field.*`` used in SELECT clause:: + + os> SELECT nested(message.*) FROM nested; + fetched rows / total rows = 2/2 + +--------------------------+-----------------------------+------------------------+ + | nested(message.author) | nested(message.dayOfWeek) | nested(message.info) | + |--------------------------+-----------------------------+------------------------| + | e | 1 | a | + | f | 2 | b | + +--------------------------+-----------------------------+------------------------+ + Example with ``field`` and ``path`` parameters in the SELECT and WHERE clause:: diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/PrettyFormatResponseIT.java b/integ-test/src/test/java/org/opensearch/sql/legacy/PrettyFormatResponseIT.java index 200c300f3b..ef80098df6 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/PrettyFormatResponseIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/PrettyFormatResponseIT.java @@ -53,6 +53,9 @@ public class PrettyFormatResponseIT extends SQLIntegTestCase { private static final Set messageFields = Sets.newHashSet( "message.dayOfWeek", "message.info", "message.author"); + private static final Set messageFieldsWithNestedFunction = Sets.newHashSet( + "nested(message.dayOfWeek)", "nested(message.info)", "nested(message.author)"); + private static final Set commentFields = Sets.newHashSet("comment.data", "comment.likes"); private static final List nameFields = Arrays.asList("firstname", "lastname"); @@ -211,7 +214,7 @@ public void selectNestedFieldWithWildcard() throws IOException { String.format(Locale.ROOT, "SELECT nested(message.*) FROM %s", TestsConstants.TEST_INDEX_NESTED_TYPE)); - assertContainsColumnsInAnyOrder(getSchema(response), messageFields); + assertContainsColumnsInAnyOrder(getSchema(response), messageFieldsWithNestedFunction); assertContainsData(getDataRows(response), messageFields); } diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java index 69b54cfc4f..d3230188b7 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java @@ -89,8 +89,7 @@ public void nested_function_with_arrays_in_an_aggregate_function_in_select_test( verifyDataRows(result, rows(19)); } - // TODO not currently supported by legacy, should we add implementation in AstBuilder? - @Disabled + @Test public void nested_function_in_a_function_in_select_test() { String query = "SELECT upper(nested(message.info)) FROM " + TEST_INDEX_NESTED_TYPE_WITHOUT_ARRAYS; @@ -104,6 +103,41 @@ public void nested_function_in_a_function_in_select_test() { rows("ZZ")); } + @Test + public void nested_all_function_in_a_function_in_select_test() { + String query = "SELECT nested(message.*) FROM " + + TEST_INDEX_NESTED_TYPE_WITHOUT_ARRAYS + " WHERE nested(message.info) = 'a'"; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, rows("e", 1, "a")); + } + + @Test + public void invalid_multiple_nested_all_function_in_a_function_in_select_test() { + String query = "SELECT nested(message.*), nested(message.info) FROM " + + TEST_INDEX_NESTED_TYPE_WITHOUT_ARRAYS; + RuntimeException result = assertThrows( + RuntimeException.class, + () -> executeJdbcRequest(query) + ); + assertTrue( + result.getMessage().contains("IllegalArgumentException") + && result.getMessage().contains("Multiple entries with same key") + ); + } + + @Test + public void nested_all_function_with_limit_test() { + String query = "SELECT nested(message.*) FROM " + + TEST_INDEX_NESTED_TYPE_WITHOUT_ARRAYS + " LIMIT 3"; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, + rows("e", 1, "a"), + rows("f", 2, "b"), + rows("g", 1, "c") + ); + } + + @Test public void nested_function_with_array_of_multi_nested_field_test() { String query = "SELECT nested(message.author.name) FROM " + TEST_INDEX_MULTI_NESTED_TYPE; @@ -403,6 +437,107 @@ public void test_nested_in_where_as_predicate_expression_with_relevance_query() verifyDataRows(result, rows(10, "a")); } + @Test + public void nested_function_all_subfields() { + String query = "SELECT nested(message.*) FROM " + TEST_INDEX_NESTED_TYPE; + JSONObject result = executeJdbcRequest(query); + + assertEquals(6, result.getInt("total")); + verifySchema(result, + schema("nested(message.author)", null, "keyword"), + schema("nested(message.dayOfWeek)", null, "long"), + schema("nested(message.info)", null, "keyword")); + verifyDataRows(result, + rows("e", 1, "a"), + rows("f", 2, "b"), + rows("g", 1, "c"), + rows("h", 4, "c"), + rows("i", 5, "a"), + rows("zz", 6, "zz")); + } + + @Test + public void nested_function_all_subfields_and_specified_subfield() { + String query = "SELECT nested(message.*), nested(comment.data) FROM " + + TEST_INDEX_NESTED_TYPE; + JSONObject result = executeJdbcRequest(query); + + assertEquals(6, result.getInt("total")); + verifySchema(result, + schema("nested(message.author)", null, "keyword"), + schema("nested(message.dayOfWeek)", null, "long"), + schema("nested(message.info)", null, "keyword"), + schema("nested(comment.data)", null, "keyword")); + verifyDataRows(result, + rows("e", 1, "a", "ab"), + rows("f", 2, "b", "aa"), + rows("g", 1, "c", "aa"), + rows("h", 4, "c", "ab"), + rows("i", 5, "a", "ab"), + rows("zz", 6, "zz", new JSONArray(List.of("aa", "bb")))); + } + + @Test + public void nested_function_all_deep_nested_subfields() { + String query = "SELECT nested(message.author.address.*) FROM " + + TEST_INDEX_MULTI_NESTED_TYPE; + JSONObject result = executeJdbcRequest(query); + + assertEquals(6, result.getInt("total")); + verifySchema(result, + schema("nested(message.author.address.number)", null, "integer"), + schema("nested(message.author.address.street)", null, "keyword")); + verifyDataRows(result, + rows(1, "bc"), + rows(2, "ab"), + rows(3, "sk"), + rows(4, "mb"), + rows(5, "on"), + rows(6, "qc")); + } + + @Test + public void nested_function_all_subfields_for_two_nested_fields() { + String query = "SELECT nested(message.*), nested(comment.*) FROM " + + TEST_INDEX_NESTED_TYPE; + JSONObject result = executeJdbcRequest(query); + + assertEquals(6, result.getInt("total")); + verifySchema(result, + schema("nested(message.author)", null, "keyword"), + schema("nested(message.dayOfWeek)", null, "long"), + schema("nested(message.info)", null, "keyword"), + schema("nested(comment.data)", null, "keyword"), + schema("nested(comment.likes)", null, "long")); + verifyDataRows(result, + rows("e", 1, "a", "ab", 3), + rows("f", 2, "b", "aa", 2), + rows("g", 1, "c", "aa", 3), + rows("h", 4, "c", "ab", 1), + rows("i", 5, "a", "ab", 1), + rows("zz", 6, "zz", new JSONArray(List.of("aa", "bb")), 10)); + } + + @Test + public void nested_function_all_subfields_and_non_nested_field() { + String query = "SELECT nested(message.*), myNum FROM " + TEST_INDEX_NESTED_TYPE; + JSONObject result = executeJdbcRequest(query); + + assertEquals(6, result.getInt("total")); + verifySchema(result, + schema("nested(message.author)", null, "keyword"), + schema("nested(message.dayOfWeek)", null, "long"), + schema("nested(message.info)", null, "keyword"), + schema("myNum", null, "long")); + verifyDataRows(result, + rows("e", 1, "a", 1), + rows("f", 2, "b", 2), + rows("g", 1, "c", 3), + rows("h", 4, "c", 4), + rows("i", 5, "a", 4), + rows("zz", 6, "zz", new JSONArray(List.of(3, 4)))); + } + @Test public void nested_function_with_date_types_as_object_arrays_within_arrays_test() { String query = "SELECT nested(address.moveInDate) FROM " + TEST_INDEX_NESTED_SIMPLE; @@ -435,4 +570,23 @@ public void nested_function_with_date_types_as_object_arrays_within_arrays_test( ) ); } + + @Test + public void nested_function_all_subfields_in_wrong_clause() { + String query = "SELECT * FROM " + TEST_INDEX_NESTED_TYPE + " ORDER BY nested(message.*)"; + + Exception exception = assertThrows(RuntimeException.class, () -> + executeJdbcRequest(query)); + + assertTrue(exception.getMessage().contains("" + + "{\n" + + " \"error\": {\n" + + " \"reason\": \"There was internal problem at backend\",\n" + + " \"details\": \"Invalid use of expression nested(message.*)\",\n" + + " \"type\": \"UnsupportedOperationException\"\n" + + " },\n" + + " \"status\": 503\n" + + "}" + )); + } } diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index a0727ec01a..20df7a62b9 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -333,7 +333,8 @@ nullNotnull ; functionCall - : scalarFunctionName LR_BRACKET functionArgs RR_BRACKET #scalarFunctionCall + : nestedFunctionName LR_BRACKET allTupleFields RR_BRACKET #nestedAllFunctionCall + | scalarFunctionName LR_BRACKET functionArgs RR_BRACKET #scalarFunctionCall | specificFunction #specificFunctionCall | windowFunctionClause #windowFunctionCall | aggregateFunction #aggregateFunctionCall @@ -818,6 +819,10 @@ columnName : qualifiedName ; +allTupleFields + : path=qualifiedName DOT STAR + ; + alias : ident ; diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index dda2aa592c..81e8b910dd 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -9,7 +9,6 @@ import static org.opensearch.sql.ast.dsl.AstDSL.between; import static org.opensearch.sql.ast.dsl.AstDSL.not; import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; -import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.LIKE; @@ -41,6 +40,7 @@ import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.LikePredicateContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.MathExpressionAtomContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.MultiFieldRelevanceFunctionContext; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.NestedAllFunctionCallContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.NoFieldRelevanceFunctionContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.NotExpressionContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.NullLiteralContext; @@ -90,6 +90,7 @@ import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IntervalUnit; import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.NestedAllTupleFields; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.QualifiedName; @@ -102,6 +103,7 @@ import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.AlternateMultiMatchQueryContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.AndExpressionContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ColumnNameContext; @@ -150,6 +152,14 @@ public UnresolvedExpression visitNestedExpressionAtom(NestedExpressionAtomContex return visit(ctx.expression()); // Discard parenthesis around } + @Override + public UnresolvedExpression visitNestedAllFunctionCall( + NestedAllFunctionCallContext ctx) { + return new NestedAllTupleFields( + visitQualifiedName(ctx.allTupleFields().path).toString() + ); + } + @Override public UnresolvedExpression visitScalarFunctionCall(ScalarFunctionCallContext ctx) { return buildFunction(ctx.scalarFunctionName().getText(), ctx.functionArgs().functionArg()); diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java index 4fee3ff414..4b44c0344c 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java @@ -39,6 +39,7 @@ import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.NestedAllTupleFields; import org.opensearch.sql.common.antlr.SyntaxCheckException; class AstBuilderTest extends AstBuilderTestBase { @@ -86,6 +87,19 @@ public void can_build_select_all_from_index() { assertThrows(SyntaxCheckException.class, () -> buildAST("SELECT *")); } + @Test + public void can_build_nested_select_all() { + assertEquals( + project( + relation("test"), + alias("nested(field.*)", + new NestedAllTupleFields("field") + ) + ), + buildAST("SELECT nested(field.*) FROM test") + ); + } + @Test public void can_build_select_all_and_fields_from_index() { assertEquals(