diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 6ac9d3a34..7eba00c94 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -184,22 +184,7 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex //add group by fields to context context.getGroupingParseExpressions().addAll(groupExpList); } - - // set sort direction according to command type - List sortDirections = new ArrayList<>(); - if (node instanceof RareAggregation) { - sortDirections.add(Ascending$.MODULE$); - } else if(node instanceof TopAggregation) { - sortDirections.add(Descending$.MODULE$); - } - if (!sortExpList.isEmpty()) { - visitExpressionList(node.getSortExprList(), context); - Seq sortElements = context.retainAllNamedParseExpressions(exp -> - new SortOrder((NamedExpression) exp, sortDirections.get(0) , sortDirections.get(0).defaultNullOrdering(), seq(new ArrayList()))); - context.apply(p -> (LogicalPlan) new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, p)); - } - UnresolvedExpression span = node.getSpan(); if (!Objects.isNull(span)) { span.accept(this, context); @@ -207,8 +192,24 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex context.getGroupingParseExpressions().add(context.getNamedParseExpressions().peek()); } // build the aggregation logical step - return extractedAggregation(context); -} +// context.apply(p -> extractedAggregation(context)); TODO remove + LogicalPlan logicalPlan = extractedAggregation(context); + + // set sort direction according to command type (`rare` is Asc, `top` is Desc, default to Asc) + List sortDirections = new ArrayList<>(); + sortDirections.add(node instanceof RareAggregation ? Ascending$.MODULE$ : node instanceof TopAggregation ? Descending$.MODULE$ : Ascending$.MODULE$); + + if (!sortExpList.isEmpty()) { + visitExpressionList(node.getSortExprList(), context); + Seq sortElements = context.retainAllNamedParseExpressions(exp -> + new SortOrder(exp, + sortDirections.get(0), + sortDirections.get(0).defaultNullOrdering(), + seq(new ArrayList()))); + context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, logicalPlan)); + } + return logicalPlan; + } private static LogicalPlan extractedAggregation(CatalystPlanContext context) { Seq groupingExpression = context.retainAllGroupingNamedParseExpressions(p -> p);