From c65d1e856040953b7c629b209eeca5ce8c81977c Mon Sep 17 00:00:00 2001 From: "Xiaotian (Jackie) Jiang" Date: Fri, 17 Jan 2025 00:03:23 -0800 Subject: [PATCH] Support distribution type hint to allow broadcast join --- .../calcite/rel/hint/PinotHintOptions.java | 65 ++++++++- .../PinotJoinExchangeNodeInsertRule.java | 130 ++++++++++++++++-- .../PinotJoinToDynamicBroadcastRule.java | 3 +- .../pinot/query/routing/WorkerManager.java | 29 ++-- .../src/test/resources/queries/JoinPlans.json | 55 ++++++++ .../test/resources/queries/QueryHints.json | 8 ++ 6 files changed, 266 insertions(+), 24 deletions(-) diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java index 4463b1fff176..0cc8eac4f8b4 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/hint/PinotHintOptions.java @@ -18,6 +18,9 @@ */ package org.apache.pinot.calcite.rel.hint; +import java.util.Map; +import javax.annotation.Nullable; +import org.apache.calcite.rel.RelDistribution; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.hint.RelHint; @@ -83,6 +86,9 @@ public static class JoinHintOptions { // "lookup" can be used when the right table is a dimension table replicated to all workers public static final String LOOKUP_JOIN_STRATEGY = "lookup"; + public static final String LEFT_DISTRIBUTION_TYPE = "left_distribution_type"; + public static final String RIGHT_DISTRIBUTION_TYPE = "right_distribution_type"; + /** * Max rows allowed to build the right table hash collection. */ @@ -105,11 +111,64 @@ public static class JoinHintOptions { */ public static final String APPEND_DISTINCT_TO_SEMI_JOIN_PROJECT = "append_distinct_to_semi_join_project"; + @Nullable + public static Map getJoinHintOptions(Join join) { + return PinotHintStrategyTable.getHintOptions(join.getHints(), JOIN_HINT_OPTIONS); + } + + @Nullable + public static String getJoinStrategyHint(Join join) { + return PinotHintStrategyTable.getHintOption(join.getHints(), JOIN_HINT_OPTIONS, JOIN_STRATEGY); + } + // TODO: Consider adding a Join implementation with join strategy. public static boolean useLookupJoinStrategy(Join join) { - return LOOKUP_JOIN_STRATEGY.equalsIgnoreCase( - PinotHintStrategyTable.getHintOption(join.getHints(), PinotHintOptions.JOIN_HINT_OPTIONS, - PinotHintOptions.JoinHintOptions.JOIN_STRATEGY)); + return LOOKUP_JOIN_STRATEGY.equalsIgnoreCase(getJoinStrategyHint(join)); + } + + @Nullable + public static DistributionType getLeftDistributionType(Map joinHintOptions) { + return DistributionType.fromHint(joinHintOptions.get(LEFT_DISTRIBUTION_TYPE)); + } + + @Nullable + public static DistributionType getRightDistributionType(Map joinHintOptions) { + return DistributionType.fromHint(joinHintOptions.get(RIGHT_DISTRIBUTION_TYPE)); + } + } + + /** + * Similar to {@link RelDistribution.Type}, it contains the distribution types to be used to shuffle data. + */ + public enum DistributionType { + LOCAL, // Distribute data locally without ser/de + HASH, // Distribute data by hash partitioning + BROADCAST, // Distribute data by broadcasting the data to all workers + RANDOM; // Distribute data randomly + + public static final String LOCAL_HINT = "local"; + public static final String HASH_HINT = "hash"; + public static final String BROADCAST_HINT = "broadcast"; + public static final String RANDOM_HINT = "random"; + + @Nullable + public static DistributionType fromHint(@Nullable String hint) { + if (hint == null) { + return null; + } + if (hint.equalsIgnoreCase(LOCAL_HINT)) { + return LOCAL; + } + if (hint.equalsIgnoreCase(HASH_HINT)) { + return HASH; + } + if (hint.equalsIgnoreCase(BROADCAST_HINT)) { + return BROADCAST; + } + if (hint.equalsIgnoreCase(RANDOM_HINT)) { + return RANDOM; + } + throw new IllegalArgumentException("Unsupported distribution type hint: " + hint); } } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotJoinExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotJoinExchangeNodeInsertRule.java index 5df0b20c54d4..517edf18011a 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotJoinExchangeNodeInsertRule.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotJoinExchangeNodeInsertRule.java @@ -18,6 +18,8 @@ */ package org.apache.pinot.calcite.rel.rules; +import com.google.common.base.Preconditions; +import java.util.Map; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.rel.RelDistributions; @@ -52,22 +54,132 @@ public void onMatch(RelOptRuleCall call) { RelNode left = PinotRuleUtils.unboxRel(join.getInput(0)); RelNode right = PinotRuleUtils.unboxRel(join.getInput(1)); JoinInfo joinInfo = join.analyzeCondition(); + Map joinHintOptions = PinotHintOptions.JoinHintOptions.getJoinHintOptions(join); + PinotHintOptions.DistributionType leftDistributionType; + PinotHintOptions.DistributionType rightDistributionType; + if (joinHintOptions != null) { + leftDistributionType = PinotHintOptions.JoinHintOptions.getLeftDistributionType(joinHintOptions); + rightDistributionType = PinotHintOptions.JoinHintOptions.getRightDistributionType(joinHintOptions); + } else { + leftDistributionType = null; + rightDistributionType = null; + } RelNode newLeft; RelNode newRight; if (PinotHintOptions.JoinHintOptions.useLookupJoinStrategy(join)) { - // Lookup join - add local exchange on the left side - newLeft = PinotLogicalExchange.create(left, RelDistributions.SINGLETON); + // Lookup join + if (leftDistributionType == null) { + // By default, use local distribution for the left side + newLeft = PinotLogicalExchange.create(left, RelDistributions.SINGLETON); + } else { + switch (leftDistributionType) { + case LOCAL: + newLeft = PinotLogicalExchange.create(left, RelDistributions.SINGLETON); + break; + case HASH: + Preconditions.checkArgument(!joinInfo.leftKeys.isEmpty(), "Hash distribution requires join keys"); + newLeft = PinotLogicalExchange.create(left, RelDistributions.hash(joinInfo.leftKeys)); + break; + case RANDOM: + newLeft = PinotLogicalExchange.create(left, RelDistributions.RANDOM_DISTRIBUTED); + break; + default: + throw new IllegalArgumentException( + "Unsupported left distribution type: " + leftDistributionType + " for lookup join"); + } + } + Preconditions.checkArgument(rightDistributionType == null, + "Right distribution type hint is not supported for lookup join"); newRight = right; } else { - // Regular join - add exchange on both sides + // Hash join + // TODO: Validate if the configured distribution types are valid if (joinInfo.leftKeys.isEmpty()) { - // Broadcast the right side if there is no join key - newLeft = PinotLogicalExchange.create(left, RelDistributions.RANDOM_DISTRIBUTED); - newRight = PinotLogicalExchange.create(right, RelDistributions.BROADCAST_DISTRIBUTED); + // No join key, cannot use hash distribution + if (leftDistributionType == null) { + // By default, randomly distribute the left side + newLeft = PinotLogicalExchange.create(left, RelDistributions.RANDOM_DISTRIBUTED); + } else { + switch (leftDistributionType) { + case LOCAL: + newLeft = PinotLogicalExchange.create(left, RelDistributions.SINGLETON); + break; + case RANDOM: + newLeft = PinotLogicalExchange.create(left, RelDistributions.RANDOM_DISTRIBUTED); + break; + case BROADCAST: + newLeft = PinotLogicalExchange.create(left, RelDistributions.BROADCAST_DISTRIBUTED); + break; + default: + throw new IllegalArgumentException( + "Unsupported left distribution type: " + leftDistributionType + " for hash join without join keys"); + } + } + if (rightDistributionType == null) { + // By default, broadcast the right side + newRight = PinotLogicalExchange.create(right, RelDistributions.BROADCAST_DISTRIBUTED); + } else { + switch (rightDistributionType) { + case LOCAL: + newRight = PinotLogicalExchange.create(right, RelDistributions.SINGLETON); + break; + case RANDOM: + newRight = PinotLogicalExchange.create(right, RelDistributions.RANDOM_DISTRIBUTED); + break; + case BROADCAST: + newRight = PinotLogicalExchange.create(right, RelDistributions.BROADCAST_DISTRIBUTED); + break; + default: + throw new IllegalStateException( + "Unsupported right distribution type: " + rightDistributionType + " for hash join without join keys"); + } + } } else { - // Use hash exchange when there are join keys - newLeft = PinotLogicalExchange.create(left, RelDistributions.hash(joinInfo.leftKeys)); - newRight = PinotLogicalExchange.create(right, RelDistributions.hash(joinInfo.rightKeys)); + // There are join keys, hash distribution is supported + if (leftDistributionType == null) { + // By default, hash distribute the left side + newLeft = PinotLogicalExchange.create(left, RelDistributions.hash(joinInfo.leftKeys)); + } else { + switch (leftDistributionType) { + case LOCAL: + newLeft = PinotLogicalExchange.create(left, RelDistributions.SINGLETON); + break; + case HASH: + newLeft = PinotLogicalExchange.create(left, RelDistributions.hash(joinInfo.leftKeys)); + break; + case BROADCAST: + newLeft = PinotLogicalExchange.create(left, RelDistributions.BROADCAST_DISTRIBUTED); + break; + case RANDOM: + newLeft = PinotLogicalExchange.create(left, RelDistributions.RANDOM_DISTRIBUTED); + break; + default: + throw new IllegalArgumentException( + "Unsupported left distribution type: " + leftDistributionType + " for hash join with join keys"); + } + } + if (rightDistributionType == null) { + // By default, hash distribute the right side + newRight = PinotLogicalExchange.create(right, RelDistributions.hash(joinInfo.rightKeys)); + } else { + switch (rightDistributionType) { + case LOCAL: + newRight = PinotLogicalExchange.create(right, RelDistributions.SINGLETON); + break; + case HASH: + newRight = PinotLogicalExchange.create(right, RelDistributions.hash(joinInfo.rightKeys)); + break; + case BROADCAST: + newRight = PinotLogicalExchange.create(right, RelDistributions.BROADCAST_DISTRIBUTED); + break; + case RANDOM: + newRight = PinotLogicalExchange.create(right, RelDistributions.RANDOM_DISTRIBUTED); + break; + default: + throw new IllegalStateException( + "Unsupported right distribution type: " + rightDistributionType + " for hash join with join keys"); + } + } } } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotJoinToDynamicBroadcastRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotJoinToDynamicBroadcastRule.java index c1924eb1d5b5..d8c7abf39e71 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotJoinToDynamicBroadcastRule.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotJoinToDynamicBroadcastRule.java @@ -124,8 +124,7 @@ public boolean matches(RelOptRuleCall call) { Join join = call.rel(0); // Do not apply this rule if join strategy is explicitly set to something other than dynamic broadcast - String joinStrategy = PinotHintStrategyTable.getHintOption(join.getHints(), PinotHintOptions.JOIN_HINT_OPTIONS, - PinotHintOptions.JoinHintOptions.JOIN_STRATEGY); + String joinStrategy = PinotHintOptions.JoinHintOptions.getJoinStrategyHint(join); if (joinStrategy != null && !joinStrategy.equals( PinotHintOptions.JoinHintOptions.DYNAMIC_BROADCAST_JOIN_STRATEGY)) { return false; diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java index 6af598bb3da7..7556993876a6 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java @@ -97,11 +97,14 @@ private void assignWorkersToNonRootFragment(PlanFragment fragment, DispatchableP Map metadataMap = context.getDispatchablePlanMetadataMap(); DispatchablePlanMetadata metadata = metadataMap.get(fragment.getFragmentId()); boolean leafPlan = isLeafPlan(metadata); - if (isLocalExchange(children)) { - // If it is a local exchange (single child with SINGLETON distribution), use the same worker assignment to avoid + int childIdWithLocalExchange = findLocalExchange(children); + if (childIdWithLocalExchange >= 0) { + // If there is a local exchange (child with SINGLETON distribution), use the same worker assignment to avoid // shuffling data. - // TODO: Support partition parallelism - DispatchablePlanMetadata childMetadata = metadataMap.get(children.get(0).getFragmentId()); + // TODO: + // 1. Support partition parallelism + // 2. Check if there are conflicts (multiple children with different local exchange) + DispatchablePlanMetadata childMetadata = metadataMap.get(children.get(childIdWithLocalExchange).getFragmentId()); metadata.setWorkerIdToServerInstanceMap(childMetadata.getWorkerIdToServerInstanceMap()); metadata.setPartitionFunction(childMetadata.getPartitionFunction()); if (leafPlan) { @@ -121,13 +124,19 @@ private void assignWorkersToNonRootFragment(PlanFragment fragment, DispatchableP } } - private boolean isLocalExchange(List children) { - if (children.size() != 1) { - return false; + /** + * Returns the index of the child fragment that has a local exchange (SINGLETON distribution), or -1 if none exists. + */ + private int findLocalExchange(List children) { + int numChildren = children.size(); + for (int i = 0; i < numChildren; i++) { + PlanNode childPlanNode = children.get(i).getFragmentRoot(); + if (childPlanNode instanceof MailboxSendNode + && ((MailboxSendNode) childPlanNode).getDistributionType() == RelDistribution.Type.SINGLETON) { + return i; + } } - PlanNode childPlanNode = children.get(0).getFragmentRoot(); - return childPlanNode instanceof MailboxSendNode - && ((MailboxSendNode) childPlanNode).getDistributionType() == RelDistribution.Type.SINGLETON; + return -1; } private static boolean isLeafPlan(DispatchablePlanMetadata metadata) { diff --git a/pinot-query-planner/src/test/resources/queries/JoinPlans.json b/pinot-query-planner/src/test/resources/queries/JoinPlans.json index f275eca72f4c..70c5ccfded5b 100644 --- a/pinot-query-planner/src/test/resources/queries/JoinPlans.json +++ b/pinot-query-planner/src/test/resources/queries/JoinPlans.json @@ -758,6 +758,61 @@ } ] }, + "broadcast_join_planning_tests": { + "queries": [ + { + "description": "Simple broadcast join", + "sql": "EXPLAIN PLAN FOR SELECT /*+ joinOptions(left_distribution_type = 'local', right_distribution_type = 'broadcast') */ a.col1, b.col2 FROM a JOIN b ON a.col1 = b.col1", + "output": [ + "Execution Plan", + "\nLogicalProject(col1=[$0], col2=[$2])", + "\n LogicalJoin(condition=[=($0, $1)], joinType=[inner])", + "\n PinotLogicalExchange(distribution=[single])", + "\n LogicalProject(col1=[$0])", + "\n LogicalTableScan(table=[[default, a]])", + "\n PinotLogicalExchange(distribution=[broadcast])", + "\n LogicalProject(col1=[$0], col2=[$1])", + "\n LogicalTableScan(table=[[default, b]])", + "\n" + ] + }, + { + "description": "Broadcast join with filter on both left and right table", + "sql": "EXPLAIN PLAN FOR SELECT /*+ joinOptions(left_distribution_type = 'local', right_distribution_type = 'broadcast') */ a.col1, b.col2 FROM a JOIN b ON a.col1 = b.col1 WHERE a.col2 = 'foo' AND b.col2 = 'bar'", + "output": [ + "Execution Plan", + "\nLogicalProject(col1=[$0], col2=[$2])", + "\n LogicalJoin(condition=[=($0, $1)], joinType=[inner])", + "\n PinotLogicalExchange(distribution=[single])", + "\n LogicalProject(col1=[$0])", + "\n LogicalFilter(condition=[=($1, _UTF-8'foo')])", + "\n LogicalTableScan(table=[[default, a]])", + "\n PinotLogicalExchange(distribution=[broadcast])", + "\n LogicalProject(col1=[$0], col2=[$1])", + "\n LogicalFilter(condition=[=($1, _UTF-8'bar')])", + "\n LogicalTableScan(table=[[default, b]])", + "\n" + ] + }, + { + "description": "Broadcast join with transformation on both left and right table joined key", + "sql": "EXPLAIN PLAN FOR SELECT /*+ joinOptions(left_distribution_type = 'local', right_distribution_type = 'broadcast') */ a.col1, b.col2 FROM a JOIN b ON upper(a.col1) = upper(b.col1)", + "output": [ + "Execution Plan", + "\nLogicalProject(col1=[$0], col2=[$2])", + "\n LogicalJoin(condition=[=($1, $3)], joinType=[inner])", + "\n PinotLogicalExchange(distribution=[single])", + "\n LogicalProject(col1=[$0], $f8=[UPPER($0)])", + "\n LogicalTableScan(table=[[default, a]])", + "\n PinotLogicalExchange(distribution=[broadcast])", + "\n LogicalProject(col2=[$1], $f8=[UPPER($0)])", + "\n LogicalTableScan(table=[[default, b]])", + "\n" + ] + } + ] + + }, "exception_throwing_join_planning_tests": { "queries": [ { diff --git a/pinot-query-runtime/src/test/resources/queries/QueryHints.json b/pinot-query-runtime/src/test/resources/queries/QueryHints.json index e8d30ed40905..e4645fa1747f 100644 --- a/pinot-query-runtime/src/test/resources/queries/QueryHints.json +++ b/pinot-query-runtime/src/test/resources/queries/QueryHints.json @@ -125,6 +125,14 @@ "description": "Colocated JOIN with partition column and group by non-partitioned column with stage parallelism", "sql": "SET stageParallelism=2; SELECT {tbl1}.name, SUM({tbl2}.num) FROM {tbl1} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ JOIN {tbl2} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ ON {tbl1}.num = {tbl2}.num GROUP BY {tbl1}.name" }, + { + "description": "Broadcast JOIN without partition hint", + "sql": "SELECT /*+ joinOptions(left_distribution_type = 'local', right_distribution_type = 'broadcast') */ {tbl1}.num, {tbl1}.name, {tbl2}.num, {tbl2}.val FROM {tbl1} JOIN {tbl2} ON {tbl1}.num = {tbl2}.num" + }, + { + "description": "Broadcast JOIN with partition hint", + "sql": "SELECT /*+ joinOptions(left_distribution_type = 'local', right_distribution_type = 'broadcast') */ {tbl1}.num, {tbl1}.name, {tbl2}.num, {tbl2}.val FROM {tbl1} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ JOIN {tbl2} ON {tbl1}.num = {tbl2}.num" + }, { "description": "Colocated, Dynamic broadcast SEMI-JOIN with partition column", "sql": "SELECT /*+ joinOptions(join_strategy='dynamic_broadcast') */ {tbl1}.num, {tbl1}.name FROM {tbl1} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ WHERE {tbl1}.num IN (SELECT {tbl2}.num FROM {tbl2} /*+ tableOptions(partition_function='hashcode', partition_key='num', partition_size='4') */ WHERE {tbl2}.val IN ('xxx', 'yyy'))"