diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java index 148f43028604..11857d3fcbac 100644 --- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java +++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java @@ -2503,9 +2503,9 @@ private RelBuilder pruneAggregateInputFieldsAndDeduplicateAggCalls( builder.add(project.getRowType().getFieldList().get(i)); } - r = - project.copy(project.getTraitSet().apply(targetMapping), project.getInput(), - newProjects, builder.build()); + r = project.copy(project.getTraitSet().apply(targetMapping), project.getInput(), newProjects, builder.build()); + r = project.copy(project.getTraitSet(), project.getInput(), newProjects, builder.build()); + //r = project.copy(cluster.traitSet(), project.getInput(), newProjects, builder.build()); } else { groupSetAfterPruning = groupSet; groupSetsAfterPruning = groupSets; diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java index 24b06aedee62..f6cddadd5555 100644 --- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java @@ -22,10 +22,13 @@ import org.apache.calcite.plan.Contexts; import org.apache.calcite.plan.Convention; import org.apache.calcite.plan.RelOptTable; +import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.plan.RelTraitDef; import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelCollation; import org.apache.calcite.rel.RelCollations; import org.apache.calcite.rel.RelDistributions; +import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.Correlate; @@ -548,27 +551,105 @@ private void checkSimplify(UnaryOperator transform, } /** Test case for - * - * [CALCITE-6340] RelBuilder always creates Project with Convention.NONE during aggregate_.. + * [CALCITE-6340] + * RelBuilder always creates Project with Convention.NONE during aggregate_.. + */ + @Test void testPruneProjectInputOfAggregatePreservesConventionAndCollations() { + final RelBuilder builder = createBuilder(config -> config.withPruneInputOfAggregate(true)); + + RelNode node = builder + .scan("EMP") + .sort(builder.nullsLast(builder.desc(builder.field(0))), + builder.field(1)) + .project(builder.alias(builder.field(0), "a"), + builder.alias(builder.field(1), "b"), + builder.alias(builder.field(0), "c"), + builder.alias(builder.field(1), "d")) + .build(); + System.out.println("TRAITS: " + node.getTraitSet()); + System.out.println(RelOptUtil.toString(node)); + + RelTraitSet desiredTraits = builder.getCluster().traitSet() + .replace(EnumerableConvention.INSTANCE); + + RuleSet prepareRules = + RuleSets.ofList(EnumerableRules.ENUMERABLE_PROJECT_RULE, + EnumerableRules.ENUMERABLE_SORT_RULE, + EnumerableRules.ENUMERABLE_TABLE_SCAN_RULE); + Program program = Programs.of(prepareRules); + node = program.run(node.getCluster().getPlanner(), node, + desiredTraits, ImmutableList.of(), ImmutableList.of()); + System.out.println("TRAITS: " + node.getTraitSet()); + System.out.println(RelOptUtil.toString(node)); + + node = builder.push(node) + .aggregate(builder.groupKey(0), builder.aggregateCall( + SqlStdOperatorTable.SUM, builder.field(0))) + .build(); + + RelTraitSet relTraitSet = node.getInput(0).getTraitSet(); + System.out.println(relTraitSet); + System.out.println(RelOptUtil.toString(node)); + + final RelCollation collation1 = RelCollations.of(new RelFieldCollation(1, + RelFieldCollation.Direction.DESCENDING, RelFieldCollation.NullDirection.LAST), + new RelFieldCollation(0)); + + assertTrue(relTraitSet.contains(EnumerableConvention.INSTANCE)); + assertTrue(relTraitSet.getTrait(1).satisfies(collation1)); + } + + /** Test case for + * [CALCITE-6340] + * RelBuilder always creates Project with Convention.NONE during aggregate_.. + */ + @Test void testPruneProjectInputOfAggregatePreservesConventionAndDistribution() { + final RelBuilder builder = createBuilder(config -> config.withPruneInputOfAggregate(true)); + + RelNode node = builder + .scan("EMP") + .project(builder.alias(builder.field(0), "a"), + builder.alias(builder.field(0), "b"), + builder.alias(builder.field(1), "c")) + .build(); + RelTraitSet desiredTraits = builder.getCluster().traitSet() + .replace(EnumerableConvention.INSTANCE); + + RuleSet prepareRules = + RuleSets.ofList(EnumerableRules.ENUMERABLE_PROJECT_RULE, + EnumerableRules.ENUMERABLE_TABLE_SCAN_RULE); + Program program = Programs.of(prepareRules); + node = program.run(node.getCluster().getPlanner(), node, + desiredTraits, ImmutableList.of(), ImmutableList.of()); + + node = node.copy(desiredTraits.plus(RelDistributions.BROADCAST_DISTRIBUTED), node.getInputs()); + + node = builder.push(node) + .aggregate(builder.groupKey(0), builder.aggregateCall( + SqlStdOperatorTable.SUM, builder.field(0))) + .build(); + + RelTraitSet relTraitSet = node.getInput(0).getTraitSet(); + assertTrue(relTraitSet.contains(EnumerableConvention.INSTANCE)); + assertTrue(relTraitSet.contains(RelDistributions.BROADCAST_DISTRIBUTED)); + } + + /** Test case for + * + * [CALCITE-6340] RelBuilder drops set conventions when aggregating over duplicate + * projected fields.. */ @Test void testPruneProjectInputOfAggregatePreservesTraitSet() { final RelBuilder builder = createBuilder(config -> config.withPruneInputOfAggregate(true)); - // This issue only occurs when projecting more columns than there are fields and putting - // an aggregate over that projection. - RelNode root = - builder.scan("DEPT") - .adoptConvention(EnumerableConvention.INSTANCE) - .project(builder.alias(builder.field(0), "a"), - builder.alias(builder.field(1), "b"), - builder.alias(builder.field(2), "c"), - builder.alias(builder.field(1), "d")) - .aggregate(builder.groupKey(0, 1, 2, 1), - builder.aggregateCall(SqlStdOperatorTable.SUM, - builder.field(0))) - .build(); + final RelNode root = builder.scan("DEPT") + .adoptConvention(EnumerableConvention.INSTANCE) + .project(builder.alias(builder.field(0), "a"), + builder.alias(builder.field(0), "b")) + .aggregate( + builder.groupKey(0), builder.aggregateCall( + SqlStdOperatorTable.SUM, builder.field(0))).build(); - // Verify that the project under the aggregate kept the EnumerableConvention.INSTANCE trait. assertTrue(root.getInput(0).getTraitSet().contains(EnumerableConvention.INSTANCE)); }