Skip to content

Commit

Permalink
Add Support for Nested Function in Order By Clause (#1789)
Browse files Browse the repository at this point in the history
* Add Support for Nested Function in Order By Clause (#280)

* Adding order by clause support for nested function.

Signed-off-by: forestmvey <[email protected]>

* Adding test coverage for nested in ORDER BY clause.

Signed-off-by: forestmvey <[email protected]>

* Added nested function validation to NestedAnalyzer.

Signed-off-by: forestmvey <[email protected]>

---------

Signed-off-by: forestmvey <[email protected]>

* Adding semantic check for missing arguments in function and unit test.

Signed-off-by: forestmvey <[email protected]>

---------

Signed-off-by: forestmvey <[email protected]>
(cherry picked from commit 3302ec8)
  • Loading branch information
forestmvey committed Jun 27, 2023
1 parent e9a009b commit 078e884
Show file tree
Hide file tree
Showing 12 changed files with 278 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import org.opensearch.sql.ast.expression.Function;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.NamedExpression;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
Expand Down Expand Up @@ -105,7 +107,18 @@ private void validateArgs(List<UnresolvedExpression> args) {
* @param field : Nested field to generate path of.
* @return : Path of field derived from last level of nesting.
*/
private ReferenceExpression generatePath(String field) {
// NOTE(review): the line above looks like the pre-change signature kept by the diff view;
// the public static signature below appears to be its replacement - confirm.
public static ReferenceExpression generatePath(String field) {
// The path is everything before the last '.' in the field name, typed as STRING.
// If field contains no '.', lastIndexOf returns -1 and substring(0, -1) throws
// StringIndexOutOfBoundsException.
return new ReferenceExpression(field.substring(0, field.lastIndexOf(".")), STRING);
}

/**
 * Tell whether an expression is a call to the nested function.
 * @param expr Expression to inspect.
 * @return True if expr is a FunctionExpression whose function name equals the NESTED
 *     builtin name, ignoring case; false otherwise.
 */
public static Boolean isNestedFunction(Expression expr) {
  // Anything that is not a function call (including null) cannot be nested().
  if (!(expr instanceof FunctionExpression)) {
    return false;
  }
  FunctionExpression function = (FunctionExpression) expr;
  return function.getFunctionName().getFunctionName()
      .equalsIgnoreCase(BuiltinFunctionName.NESTED.name());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
import static java.util.Collections.emptyList;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME;
import static org.opensearch.sql.analysis.NestedAnalyzer.isNestedFunction;
import static org.opensearch.sql.ast.dsl.AstDSL.aggregate;
import static org.opensearch.sql.ast.dsl.AstDSL.alias;
import static org.opensearch.sql.ast.dsl.AstDSL.argument;
Expand Down Expand Up @@ -39,6 +41,7 @@
import static org.opensearch.sql.data.type.ExprCoreType.LONG;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP;
import static org.opensearch.sql.expression.DSL.literal;
import static org.opensearch.sql.utils.MLCommonsConstants.ACTION;
import static org.opensearch.sql.utils.MLCommonsConstants.ALGO;
import static org.opensearch.sql.utils.MLCommonsConstants.ASYNC;
Expand Down Expand Up @@ -574,6 +577,10 @@ public void project_nested_field_arg() {
function("nested", qualifiedName("message", "info")), null)
)
);

assertTrue(isNestedFunction(DSL.nested(DSL.ref("message.info", STRING))));
assertFalse(isNestedFunction(DSL.literal("fieldA")));
assertFalse(isNestedFunction(DSL.match(DSL.namedArgument("field", literal("message")))));
}

@Test
Expand Down
11 changes: 11 additions & 0 deletions docs/user/dql/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4469,6 +4469,17 @@ Example with ``field`` and ``path`` parameters in the SELECT and WHERE clause::
| b |
+---------------------------------+

Example with ``field`` and ``path`` parameters in the SELECT and ORDER BY clause::

    os> SELECT nested(message.info, message) FROM nested ORDER BY nested(message.info, message) DESC;
    fetched rows / total rows = 2/2
    +---------------------------------+
    | nested(message.info, message)   |
    |---------------------------------|
    | b                               |
    | a                               |
    +---------------------------------+


System Functions
================
Expand Down
34 changes: 34 additions & 0 deletions integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,40 @@ public void nested_function_with_order_by_clause() {
rows("zz"));
}

@Test
public void nested_function_with_order_by_clause_desc() {
  // Descending sort on the nested field, with the nested path given explicitly in ORDER BY.
  JSONObject response = executeJdbcRequest(
      "SELECT nested(message.info) FROM " + TEST_INDEX_NESTED_TYPE
          + " ORDER BY nested(message.info, message) DESC");

  assertEquals(6, response.getInt("total"));
  // NOTE(review): the rows listed are not in strictly descending order (a, b, a);
  // presumably verifyDataRows does not enforce row order - confirm.
  verifyDataRows(
      response,
      rows("zz"),
      rows("c"),
      rows("c"),
      rows("a"),
      rows("b"),
      rows("a"));
}

@Test
public void nested_function_and_field_with_order_by_clause() {
  // ORDER BY mixes a nested function with a plain field (myNum).
  JSONObject response = executeJdbcRequest(
      "SELECT nested(message.info), myNum FROM " + TEST_INDEX_NESTED_TYPE
          + " ORDER BY nested(message.info, message), myNum");

  assertEquals(6, response.getInt("total"));
  verifyDataRows(
      response,
      rows("a", 1),
      rows("c", 4),
      rows("a", 4),
      rows("b", 2),
      rows("c", 3),
      // The last expected row carries an array value for myNum.
      rows("zz", new JSONArray(List.of(3, 4))));
}

// Nested function in GROUP BY clause is not yet implemented for JDBC format. This test ensures
// that the V2 engine falls back to legacy implementation.
// TODO Fix the test when NESTED is supported in GROUP BY in the V2 engine.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.sql.opensearch.storage.scan;

import static org.opensearch.sql.analysis.NestedAnalyzer.isNestedFunction;

import java.util.function.Function;
import lombok.EqualsAndHashCode;
import org.opensearch.sql.expression.ReferenceExpression;
Expand Down Expand Up @@ -113,9 +115,15 @@ public boolean pushDownNested(LogicalNested nested) {
return delegate.pushDownNested(nested);
}

/**
 * Check that every sort item is either a ReferenceExpression or a nested function call.
 * @param sort Logical sort
 * @return True if every sort item qualifies; also true when the sort list is empty
 */
private boolean sortByFieldsOnly(LogicalSort sort) {
return sort.getSortList().stream()
// NOTE(review): the next .map line looks like the pre-change mapper kept by the diff view;
// the two lines after it appear to be its replacement, which also accepts nested functions.
.map(sortItem -> sortItem.getRight() instanceof ReferenceExpression)
.map(sortItem -> sortItem.getRight() instanceof ReferenceExpression
|| isNestedFunction(sortItem.getRight()))
// AND the per-item results together, starting from true.
.reduce(true, Boolean::logicalAnd);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

package org.opensearch.sql.opensearch.storage.script.filter.lucene;

import static org.opensearch.sql.analysis.NestedAnalyzer.isNestedFunction;

import com.google.common.collect.ImmutableMap;
import java.util.Map;
import java.util.function.Function;
Expand Down Expand Up @@ -62,10 +64,7 @@ public boolean canSupport(FunctionExpression func) {
* @return return true if function has supported nested function expression.
*/
public boolean isNestedPredicate(FunctionExpression func) {
// NOTE(review): the parenthesized return below looks like the pre-change body kept by the
// diff view; the final return delegating to isNestedFunction appears to replace it - confirm.
return ((func.getArguments().get(0) instanceof FunctionExpression
&& ((FunctionExpression)func.getArguments().get(0))
.getFunctionName().getFunctionName().equalsIgnoreCase(BuiltinFunctionName.NESTED.name()))
);
// Only the first argument of func is inspected.
return isNestedFunction(func.getArguments().get(0));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,21 @@

package org.opensearch.sql.opensearch.storage.script.sort;

import static org.opensearch.sql.analysis.NestedAnalyzer.generatePath;
import static org.opensearch.sql.analysis.NestedAnalyzer.isNestedFunction;

import com.google.common.collect.ImmutableMap;
import java.util.Map;
import org.opensearch.search.sort.FieldSortBuilder;
import org.opensearch.search.sort.NestedSortBuilder;
import org.opensearch.search.sort.SortBuilder;
import org.opensearch.search.sort.SortBuilders;
import org.opensearch.search.sort.SortOrder;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.opensearch.data.type.OpenSearchTextType;

/**
Expand Down Expand Up @@ -53,11 +59,44 @@ public SortBuilder<?> build(Expression expression, Sort.SortOption option) {
return SortBuilders.scoreSort().order(sortOrderMap.get(option.getSortOrder()));
}
return fieldBuild((ReferenceExpression) expression, option);
} else if (isNestedFunction(expression)) {

validateNestedArgs((FunctionExpression) expression);
String orderByName = ((FunctionExpression)expression).getArguments().get(0).toString();
// Generate path if argument not supplied in function.
ReferenceExpression path = ((FunctionExpression)expression).getArguments().size() == 2
? (ReferenceExpression) ((FunctionExpression)expression).getArguments().get(1)
: generatePath(orderByName);
return SortBuilders.fieldSort(orderByName)
.order(sortOrderMap.get(option.getSortOrder()))
.setNestedSort(new NestedSortBuilder(path.toString()));
} else {
throw new IllegalStateException("unsupported expression " + expression.getClass());
}
}

/**
 * Check that a nested function call has a usable argument list.
 * @param nestedFunc Nested function expression.
 * @throws IllegalArgumentException if the argument count is not 1 or 2, or if any
 *     argument is not a ReferenceExpression.
 */
private void validateNestedArgs(FunctionExpression nestedFunc) {
  final int argCount = nestedFunc.getArguments().size();
  if (argCount != 1 && argCount != 2) {
    throw new IllegalArgumentException(
        "nested function supports 2 parameters (field, path) or 1 parameter (field)");
  }

  // Both the field and the optional path must be plain references.
  nestedFunc.getArguments().forEach(argument -> {
    if (argument instanceof ReferenceExpression) {
      return;
    }
    throw new IllegalArgumentException(
        String.format("Illegal nested field name: %s", argument.toString()));
  });
}

private FieldSortBuilder fieldBuild(ReferenceExpression ref, Sort.SortOption option) {
return SortBuilders.fieldSort(
OpenSearchTextType.convertTextToKeyword(ref.getAttr(), ref.type()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
import static org.opensearch.sql.data.type.ExprCoreType.LONG;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.expression.DSL.literal;
import static org.opensearch.sql.planner.logical.LogicalPlanDSL.aggregation;
import static org.opensearch.sql.planner.logical.LogicalPlanDSL.filter;
import static org.opensearch.sql.planner.logical.LogicalPlanDSL.highlight;
Expand Down Expand Up @@ -58,6 +59,7 @@
import org.opensearch.search.aggregations.AggregationBuilders;
import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder;
import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder;
import org.opensearch.search.sort.NestedSortBuilder;
import org.opensearch.search.sort.SortBuilder;
import org.opensearch.search.sort.SortBuilders;
import org.opensearch.search.sort.SortOrder;
Expand Down Expand Up @@ -574,6 +576,76 @@ void only_one_project_should_be_push() {
);
}

@Test
void test_nested_sort_filter_push_down() {
// A sort on nested(message.info) above a filter: the first plan has both the filter and
// the sort pushed into the index scan builder, the second is the plain logical plan.
// Presumably the first argument is the expected plan after optimization - confirm.
assertEqualsAfterOptimization(
project(
indexScanBuilder(
withFilterPushedDown(QueryBuilders.termQuery("intV", 1)),
withSortPushedDown(
// Field sort on "message.info" carrying a nested sort whose path is "message".
SortBuilders.fieldSort("message.info")
.order(SortOrder.ASC)
.setNestedSort(new NestedSortBuilder("message")))),
DSL.named("intV", DSL.ref("intV", INTEGER))
),
project(
sort(
filter(
relation("schema", table),
DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1)))
),
Pair.of(
// nested() is called with the field only; no explicit path argument.
SortOption.DEFAULT_ASC, DSL.nested(DSL.ref("message.info", STRING))
)
),
DSL.named("intV", DSL.ref("intV", INTEGER))
)
);
}

@Test
void test_function_expression_sort_returns_optimized_logical_sort() {
// Covers the case where OpenSearchIndexScanBuilder::sortByFieldsOnly returns false:
// the sort key is a match() function expression rather than a field or nested() call.
// In the first plan the sort stays above the index scan builder instead of being pushed down.
assertEqualsAfterOptimization(
sort(
indexScanBuilder(),
Pair.of(
SortOption.DEFAULT_ASC,
DSL.match(DSL.namedArgument("field", literal("message")))
)
),
sort(
relation("schema", table),
Pair.of(
SortOption.DEFAULT_ASC,
DSL.match(DSL.namedArgument("field", literal("message"))
)
)
)
);
}

@Test
void test_non_field_sort_returns_optimized_logical_sort() {
// Covers the case where OpenSearchIndexScanBuilder::sortByFieldsOnly returns false:
// the sort key is a literal, which is neither a field reference nor a nested() call.
// In the first plan the sort stays above the index scan builder instead of being pushed down.
assertEqualsAfterOptimization(
sort(
indexScanBuilder(),
Pair.of(
SortOption.DEFAULT_ASC,
DSL.literal("field")
)
),
sort(
relation("schema", table),
Pair.of(
SortOption.DEFAULT_ASC,
DSL.literal("field")
)
)
);
}

@Test
void sort_with_expression_cannot_merge_with_relation() {
assertEqualsAfterOptimization(
Expand Down
Loading

0 comments on commit 078e884

Please sign in to comment.