diff --git a/src/include/duckdb/optimizer/join_order/join_relation.hpp b/src/include/duckdb/optimizer/join_order/join_relation.hpp index 7b040c1b5ef4..d5a898fc05d2 100644 --- a/src/include/duckdb/optimizer/join_order/join_relation.hpp +++ b/src/include/duckdb/optimizer/join_order/join_relation.hpp @@ -14,46 +14,47 @@ namespace duckdb { -//! Set of relations, used in the join graph. struct JoinRelationSet { - JoinRelationSet(unsafe_unique_array relations, idx_t count) : relations(std::move(relations)), count(count) { + JoinRelationSet() { } + JoinRelationSet(unsafe_unique_array &relations_, idx_t count) { + for (idx_t i = 0; i < count; i++) { + relations[relations_[i]] = true; + } + } + static void EnumerateRelations(std::bitset<12> relations, const std::function &callback); string ToString() const; - - unsafe_unique_array relations; - idx_t count; + idx_t Count() const; + idx_t NextNeighbor(idx_t i); + std::bitset<12> relations; static bool IsSubset(JoinRelationSet &super, JoinRelationSet &sub); + JoinRelationSet Copy() const; }; //! The JoinRelationTree is a structure holding all the created JoinRelationSet objects and allowing fast lookup on to //! them class JoinRelationSetManager { -public: - //! Contains a node with a JoinRelationSet and child relations - // FIXME: this structure is inefficient, could use a bitmap for lookup instead (todo: profile) - struct JoinRelationTreeNode { - unique_ptr relation; - unordered_map> children; - }; public: //! Create or get a JoinRelationSet from a single node with the given index - JoinRelationSet &GetJoinRelation(idx_t index); + reference GetJoinRelation(idx_t index); //! Create or get a JoinRelationSet from a set of relation bindings - JoinRelationSet &GetJoinRelation(const unordered_set &bindings); + reference GetJoinRelation(const unordered_set &bindings); //! Create or get a JoinRelationSet from a (sorted, duplicate-free!) list of relations - JoinRelationSet &GetJoinRelation(unsafe_unique_array relations, idx_t count); + reference GetJoinRelation(unsafe_unique_array relations, idx_t count); + //! Create or get a JoinRelationSet from another JoinRelation Set + reference GetJoinRelation(unique_ptr set); //! Union two sets of relations together and create a new relation set - JoinRelationSet &Union(JoinRelationSet &left, JoinRelationSet &right); + reference Union(JoinRelationSet &left, JoinRelationSet &right); // //! Create the set difference of left \ right (i.e. all elements in left that are not in right) // JoinRelationSet *Difference(JoinRelationSet *left, JoinRelationSet *right); string ToString() const; void Print(); private: - JoinRelationTreeNode root; + unordered_map, unique_ptr> active_relation_sets; }; } // namespace duckdb diff --git a/src/include/duckdb/optimizer/join_order/query_graph.hpp b/src/include/duckdb/optimizer/join_order/query_graph.hpp index e5ac67138978..671604f392dd 100644 --- a/src/include/duckdb/optimizer/join_order/query_graph.hpp +++ b/src/include/duckdb/optimizer/join_order/query_graph.hpp @@ -37,6 +37,10 @@ struct NeighborInfo { class QueryGraphEdges { public: //! Contains a node with info about neighboring relations and child edge infos + //! The root is a top level QueryEdge with no neighbors, then each child represents a single + //! relation node. Neighbors with these single nodes are in the neighbors vector. + //! If the edge is complex (like a+b = c), then the children structure is used to capture + //! the presence of [a, b]. struct QueryEdge { vector> neighbors; unordered_map> children; diff --git a/src/include/duckdb/optimizer/join_order/query_graph_manager.hpp b/src/include/duckdb/optimizer/join_order/query_graph_manager.hpp index e98868115af3..e523281e3be0 100644 --- a/src/include/duckdb/optimizer/join_order/query_graph_manager.hpp +++ b/src/include/duckdb/optimizer/join_order/query_graph_manager.hpp @@ -97,7 +97,7 @@ class QueryGraphManager { //! A map to store the optimal join plan found for a specific JoinRelationSet* optional_ptr>> plans; -private: +// private: vector> filter_operators; //! Filter information including the column_bindings that join filters diff --git a/src/optimizer/join_order/cardinality_estimator.cpp b/src/optimizer/join_order/cardinality_estimator.cpp index 07cbf1dcd466..578c85725e35 100644 --- a/src/optimizer/join_order/cardinality_estimator.cpp +++ b/src/optimizer/join_order/cardinality_estimator.cpp @@ -21,7 +21,7 @@ bool CardinalityEstimator::EmptyFilter(FilterInfo &filter_info) { } void CardinalityEstimator::AddRelationTdom(FilterInfo &filter_info) { - D_ASSERT(filter_info.set.get().count >= 1); + D_ASSERT(filter_info.set.get().Count() >= 1); for (const RelationsToTDom &r2tdom : relations_to_tdoms) { auto &i_set = r2tdom.equivalent_relations; if (i_set.find(filter_info.left_binding) != i_set.end()) { @@ -37,7 +37,7 @@ void CardinalityEstimator::AddRelationTdom(FilterInfo &filter_info) { } bool CardinalityEstimator::SingleColumnFilter(duckdb::FilterInfo &filter_info) { - if (filter_info.left_set && filter_info.right_set && filter_info.set.get().count > 1) { + if (filter_info.left_set && filter_info.right_set && filter_info.set.get().Count() > 1) { // Both set and are from different relations return false; } @@ -111,8 +111,8 @@ void CardinalityEstimator::InitEquivalentRelations(const vectorleft_set->count >= 1); - D_ASSERT(filter->right_set->count >= 1); + D_ASSERT(filter->left_set->Count() >= 1); + D_ASSERT(filter->right_set->Count() >= 1); auto matching_equivalent_sets = DetermineMatchingEquivalentSets(filter.get()); AddToEquivalenceSets(filter.get(), matching_equivalent_sets); @@ -128,9 +128,9 @@ void CardinalityEstimator::RemoveEmptyTotalDomains() { double CardinalityEstimator::GetNumerator(JoinRelationSet &set) { double numerator = 1; - for (idx_t i = 0; i < set.count; i++) { - auto &single_node_set = set_manager.GetJoinRelation(set.relations[i]); - auto card_helper = relation_set_2_cardinality[single_node_set.ToString()]; + for (idx_t i = 0; i < set.Count(); i++) { + auto single_node_set = set_manager.GetJoinRelation(set.relations[i]); + auto card_helper = relation_set_2_cardinality[single_node_set.get().ToString()]; numerator *= card_helper.cardinality_before_filters == 0 ? 1 : card_helper.cardinality_before_filters; } return numerator; @@ -333,7 +333,7 @@ DenomInfo CardinalityEstimator::GetDenominator(JoinRelationSet &set) { continue; } left_subgraph->numerator_relations = &UpdateNumeratorRelations(*left_subgraph, right_subgraph, edge); - left_subgraph->relations = &set_manager.Union(*left_subgraph->relations, *right_subgraph.relations); + left_subgraph->relations = set_manager.Union(*left_subgraph->relations, *right_subgraph.relations).get(); left_subgraph->denom = CalculateUpdatedDenom(*left_subgraph, right_subgraph, edge); } else if (subgraph_connections.size() == 2) { // The two subgraphs in the subgraph_connections can be merged by this edge. @@ -341,7 +341,7 @@ DenomInfo CardinalityEstimator::GetDenominator(JoinRelationSet &set) { auto subgraph_to_merge_into = &subgraphs.at(subgraph_connections.at(0)); auto subgraph_to_delete = &subgraphs.at(subgraph_connections.at(1)); subgraph_to_merge_into->relations = - &set_manager.Union(*subgraph_to_merge_into->relations, *subgraph_to_delete->relations); + set_manager.Union(*subgraph_to_merge_into->relations, *subgraph_to_delete->relations).get(); subgraph_to_merge_into->numerator_relations = &UpdateNumeratorRelations(*subgraph_to_merge_into, *subgraph_to_delete, edge); subgraph_to_merge_into->denom = CalculateUpdatedDenom(*subgraph_to_merge_into, *subgraph_to_delete, edge); @@ -361,10 +361,10 @@ DenomInfo CardinalityEstimator::GetDenominator(JoinRelationSet &set) { auto final_subgraph = subgraphs.at(0); for (auto merge_with = subgraphs.begin() + 1; merge_with != subgraphs.end(); merge_with++) { D_ASSERT(final_subgraph.relations && merge_with->relations); - final_subgraph.relations = &set_manager.Union(*final_subgraph.relations, *merge_with->relations); + final_subgraph.relations = set_manager.Union(*final_subgraph.relations, *merge_with->relations).get(); D_ASSERT(final_subgraph.numerator_relations && merge_with->numerator_relations); final_subgraph.numerator_relations = - &set_manager.Union(*final_subgraph.numerator_relations, *merge_with->numerator_relations); + set_manager.Union(*final_subgraph.numerator_relations, *merge_with->numerator_relations).get(); final_subgraph.denom *= merge_with->denom; } } @@ -431,7 +431,7 @@ void CardinalityEstimator::InitCardinalityEstimatorProps(optional_ptr set, RelationStats &stats) { - D_ASSERT(set->count == 1); + D_ASSERT(set->Count() == 1); auto relation_id = set->relations[0]; //! Initialize the distinct count for all columns used in joins with the current relation. // D_ASSERT(stats.column_distinct_count.size() >= 1); diff --git a/src/optimizer/join_order/cost_model.cpp b/src/optimizer/join_order/cost_model.cpp index bfe64412f053..a7ae137c6458 100644 --- a/src/optimizer/join_order/cost_model.cpp +++ b/src/optimizer/join_order/cost_model.cpp @@ -9,7 +9,7 @@ CostModel::CostModel(QueryGraphManager &query_graph_manager) } double CostModel::ComputeCost(DPJoinNode &left, DPJoinNode &right) { - auto &combination = query_graph_manager.set_manager.Union(left.set, right.set); + auto combination = query_graph_manager.set_manager.Union(left.set, right.set); auto join_card = cardinality_estimator.EstimateCardinalityWithSet(combination); auto join_cost = join_card; return join_cost + left.cost + right.cost; diff --git a/src/optimizer/join_order/join_relation_set.cpp b/src/optimizer/join_order/join_relation_set.cpp index aa5767427ae7..4135709b9ce3 100644 --- a/src/optimizer/join_order/join_relation_set.cpp +++ b/src/optimizer/join_order/join_relation_set.cpp @@ -1,64 +1,95 @@ #include "duckdb/optimizer/join_order/join_relation.hpp" -#include "duckdb/common/printer.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/to_string.hpp" #include namespace duckdb { -using JoinRelationTreeNode = JoinRelationSetManager::JoinRelationTreeNode; - // LCOV_EXCL_START string JoinRelationSet::ToString() const { string result = "["; - result += StringUtil::Join(relations, count, ", ", [](const idx_t &relation) { return to_string(relation); }); + EnumerateRelations(relations, [&](idx_t relation) { result += to_string(relation) + ", "; }); result += "]"; return result; } // LCOV_EXCL_STOP //! Returns true if sub is a subset of super +// bool JoinRelationSetOld::IsSubset(JoinRelationSetOld &super, JoinRelationSetOld &sub) { +// D_ASSERT(sub.count > 0); +// if (sub.count > super.count) { +// return false; +// } +// idx_t j = 0; +// for (idx_t i = 0; i < super.count; i++) { +// if (sub.relations[j] == super.relations[i]) { +// j++; +// if (j == sub.count) { +// return true; +// } +// } +// } +// return false; +// } + bool JoinRelationSet::IsSubset(JoinRelationSet &super, JoinRelationSet &sub) { - D_ASSERT(sub.count > 0); - if (sub.count > super.count) { - return false; + std::bitset<12> sub_copy = sub.relations; + sub_copy &= super.relations; + return sub_copy == sub.relations; +} + +void JoinRelationSet::EnumerateRelations(std::bitset<12> relations, + const std::function &callback) { + for (idx_t i = 0; i < PlanEnumerator::THRESHOLD_TO_SWAP_TO_APPROXIMATE; i++) { + if (relations[i]) { + callback(i); + } } - idx_t j = 0; - for (idx_t i = 0; i < super.count; i++) { - if (sub.relations[j] == super.relations[i]) { - j++; - if (j == sub.count) { - return true; - } +} + +idx_t JoinRelationSet::Count() const { + idx_t count = 0; + for (idx_t i = 0; i < PlanEnumerator::THRESHOLD_TO_SWAP_TO_APPROXIMATE; i++) { + if (relations[i]) { + count++; } } - return false; + return count; } -JoinRelationSet &JoinRelationSetManager::GetJoinRelation(unsafe_unique_array relations, idx_t count) { - // now look it up in the tree - reference info(root); - for (idx_t i = 0; i < count; i++) { - auto entry = info.get().children.find(relations[i]); - if (entry == info.get().children.end()) { - // node not found, create it - auto insert_it = info.get().children.insert(make_pair(relations[i], make_uniq())); - entry = insert_it.first; +idx_t JoinRelationSet::NextNeighbor(idx_t i) { + for (idx_t j = 0; j < i; j++) { + if (relations[j]) { + return j; } - // move to the next node - info = *entry->second; } - // now check if the JoinRelationSet has already been created - if (!info.get().relation) { - // if it hasn't we need to create it - info.get().relation = make_uniq(std::move(relations), count); + return DConstants::INVALID_INDEX; +} + +JoinRelationSet JoinRelationSet::Copy() const { + JoinRelationSet result; + result.relations = relations; + return result; +} + +reference JoinRelationSetManager::GetJoinRelation(unsafe_unique_array relations, idx_t count) { + auto ret = make_uniq(relations, count); + return GetJoinRelation(std::move(ret)); +} + +reference JoinRelationSetManager::GetJoinRelation(unique_ptr set) { + auto existing = active_relation_sets.find(set->relations); + if (existing == active_relation_sets.end()) { + auto copy = make_uniq(set->Copy()); + active_relation_sets[set->relations] = std::move(set); + set = std::move(copy); } - return *info.get().relation; + auto ret = active_relation_sets.find(set->relations); + auto &wat = *ret->second; + return wat; } //! Create or get a JoinRelationSet from a single node with the given index -JoinRelationSet &JoinRelationSetManager::GetJoinRelation(idx_t index) { +reference JoinRelationSetManager::GetJoinRelation(idx_t index) { // create a sorted vector of the relations auto relations = make_unsafe_uniq_array(1); relations[0] = index; @@ -66,7 +97,7 @@ JoinRelationSet &JoinRelationSetManager::GetJoinRelation(idx_t index) { return GetJoinRelation(std::move(relations), count); } -JoinRelationSet &JoinRelationSetManager::GetJoinRelation(const unordered_set &bindings) { +reference JoinRelationSetManager::GetJoinRelation(const unordered_set &bindings) { // create a sorted vector of the relations unsafe_unique_array relations = bindings.empty() ? nullptr : make_unsafe_uniq_array(bindings.size()); idx_t count = 0; @@ -77,40 +108,11 @@ JoinRelationSet &JoinRelationSetManager::GetJoinRelation(const unordered_set(left.count + right.count); - idx_t count = 0; - // move through the left and right relations, eliminating duplicates - idx_t i = 0, j = 0; - while (true) { - if (i == left.count) { - // exhausted left relation, add remaining of right relation - for (; j < right.count; j++) { - relations[count++] = right.relations[j]; - } - break; - } else if (j == right.count) { - // exhausted right relation, add remaining of left - for (; i < left.count; i++) { - relations[count++] = left.relations[i]; - } - break; - } else if (left.relations[i] < right.relations[j]) { - // left is smaller, progress left and add it to the set - relations[count++] = left.relations[i]; - i++; - } else if (left.relations[i] > right.relations[j]) { - // right is smaller, progress right and add it to the set - relations[count++] = right.relations[j]; - j++; - } else { - D_ASSERT(left.relations[i] == right.relations[j]); - relations[count++] = left.relations[i]; - i++; - j++; - } - } - return GetJoinRelation(std::move(relations), count); +reference JoinRelationSetManager::Union(JoinRelationSet &left, JoinRelationSet &right) { + auto left_copy = make_uniq(left.Copy()); + auto right_copy = right.Copy(); + left_copy->relations |= right_copy.relations; + return GetJoinRelation(std::move(left_copy)); } // JoinRelationSet *JoinRelationSetManager::Difference(JoinRelationSet *left, JoinRelationSet *right) { @@ -144,23 +146,23 @@ JoinRelationSet &JoinRelationSetManager::Union(JoinRelationSet &left, JoinRelati // return GetJoinRelation(std::move(relations), count); // } -static string JoinRelationTreeNodeToString(const JoinRelationTreeNode *node) { - string result = ""; - if (node->relation) { - result += node->relation.get()->ToString() + "\n"; - } - for (auto &child : node->children) { - result += JoinRelationTreeNodeToString(child.second.get()); - } - return result; -} - -string JoinRelationSetManager::ToString() const { - return JoinRelationTreeNodeToString(&root); -} +// static string JoinRelationTreeNodeToString(const JoinRelationTreeNode *node) { +// string result = ""; +// if (node->relation) { +// result += node->relation.get()->ToString() + "\n"; +// } +// for (auto &child : node->children) { +// result += JoinRelationTreeNodeToString(child.second.get()); +// } +// return result; +// } -void JoinRelationSetManager::Print() { - Printer::Print(ToString()); -} +// string JoinRelationSetManagerOld::ToString() const { +// return JoinRelationTreeNodeToString(&root); +// } +// +// void JoinRelationSetManagerOld::Print() { +// Printer::Print(ToString()); +// } } // namespace duckdb diff --git a/src/optimizer/join_order/plan_enumerator.cpp b/src/optimizer/join_order/plan_enumerator.cpp index 04b396b97800..4d4248b37ad4 100644 --- a/src/optimizer/join_order/plan_enumerator.cpp +++ b/src/optimizer/join_order/plan_enumerator.cpp @@ -34,8 +34,10 @@ static vector> AddSuperSets(const vector node, unordered_set &exclusion_set) { - for (idx_t i = 0; i < node->count; i++) { - exclusion_set.insert(node->relations[i]); + for (idx_t i = 0; i < PlanEnumerator::THRESHOLD_TO_SWAP_TO_APPROXIMATE; i++) { + if (node->relations[i]) { + exclusion_set.insert(i); + } } } @@ -78,12 +80,12 @@ void PlanEnumerator::GenerateCrossProducts() { // generate a set of cross products to combine the currently available plans into a full join plan // we create edges between every relation with a high cost for (idx_t i = 0; i < query_graph_manager.relation_manager.NumRelations(); i++) { - auto &left = query_graph_manager.set_manager.GetJoinRelation(i); + auto left = query_graph_manager.set_manager.GetJoinRelation(i); for (idx_t j = 0; j < query_graph_manager.relation_manager.NumRelations(); j++) { auto cross_product_allowed = query_graph_manager.relation_manager.CrossProductWithRelationAllowed(i) && query_graph_manager.relation_manager.CrossProductWithRelationAllowed(j); if (i != j && cross_product_allowed) { - auto &right = query_graph_manager.set_manager.GetJoinRelation(j); + auto right = query_graph_manager.set_manager.GetJoinRelation(j); query_graph_manager.CreateQueryGraphCrossProduct(left, right); } } @@ -146,7 +148,7 @@ DPJoinNode &PlanEnumerator::EmitPair(JoinRelationSet &left, JoinRelationSet &rig if (left_plan == plans.end() || right_plan == plans.end()) { throw InternalException("No left or right plan: internal error in join order optimizer"); } - auto &new_set = query_graph_manager.set_manager.Union(left, right); + auto new_set = query_graph_manager.set_manager.Union(left, right); // create the join tree based on combining the two plans auto new_plan = CreateJoinTree(new_set, info, *left_plan->second, *right_plan->second); // check if this plan is the optimal plan we found for this set of relations @@ -182,14 +184,20 @@ bool PlanEnumerator::TryEmitPair(JoinRelationSet &left, JoinRelationSet &right, } bool PlanEnumerator::EmitCSG(JoinRelationSet &node) { - if (node.count == query_graph_manager.relation_manager.NumRelations()) { + if (node.Count() == query_graph_manager.relation_manager.NumRelations()) { return true; } // create the exclusion set as everything inside the subgraph AND anything with members BELOW it unordered_set exclusion_set; - for (idx_t i = 0; i < node.relations[0]; i++) { - exclusion_set.insert(i); + for (idx_t j = 0; j < PlanEnumerator::THRESHOLD_TO_SWAP_TO_APPROXIMATE; ++j) { + if (node.relations[j]) { + for (idx_t i = 0; i < j; i++) { + exclusion_set.insert(i); + } + break; + } } + UpdateExclusionSet(&node, exclusion_set); // find the neighbors given this exclusion set auto neighbors = query_graph.GetNeighbors(node, exclusion_set); @@ -216,7 +224,7 @@ bool PlanEnumerator::EmitCSG(JoinRelationSet &node) { for (auto neighbor : neighbors) { // since the GetNeighbors only returns the smallest element in a list, the entry might not be connected to // (only!) this neighbor, hence we have to do a connectedness check before we can emit it - auto &neighbor_relation = query_graph_manager.set_manager.GetJoinRelation(neighbor); + auto neighbor_relation = query_graph_manager.set_manager.GetJoinRelation(neighbor); auto connections = query_graph.GetConnections(node, neighbor_relation); if (!connections.empty()) { if (!TryEmitPair(node, neighbor_relation, connections)) { @@ -245,12 +253,12 @@ bool PlanEnumerator::EnumerateCmpRecursive(JoinRelationSet &left, JoinRelationSe vector> union_sets; union_sets.reserve(all_subset.size()); for (const auto &rel_set : all_subset) { - auto &neighbor = query_graph_manager.set_manager.GetJoinRelation(rel_set); + auto neighbor = query_graph_manager.set_manager.GetJoinRelation(rel_set); // emit the combinations of this node and its neighbors - auto &combined_set = query_graph_manager.set_manager.Union(right, neighbor); + auto combined_set = query_graph_manager.set_manager.Union(right, neighbor); // If combined_set.count == right.count, This means we found a neighbor that has been present before // This means we didn't set exclusion_set correctly. - D_ASSERT(combined_set.count > right.count); + D_ASSERT(combined_set.get().Count() > right.Count()); if (plans.find(combined_set) != plans.end()) { auto connections = query_graph.GetConnections(left, combined_set); if (!connections.empty()) { @@ -288,10 +296,13 @@ bool PlanEnumerator::EnumerateCSGRecursive(JoinRelationSet &node, unordered_set< vector> union_sets; union_sets.reserve(all_subset.size()); for (const auto &rel_set : all_subset) { - auto &neighbor = query_graph_manager.set_manager.GetJoinRelation(rel_set); + auto neighbor = query_graph_manager.set_manager.GetJoinRelation(rel_set); // emit the combinations of this node and its neighbors - auto &new_set = query_graph_manager.set_manager.Union(node, neighbor); - D_ASSERT(new_set.count > node.count); + auto new_set = query_graph_manager.set_manager.Union(node, neighbor); + if (new_set.get().Count() <= node.Count()) { + auto break_here = 0; + } + D_ASSERT(new_set.get().Count() > node.Count()); if (plans.find(new_set) != plans.end()) { if (!EmitCSG(new_set)) { return false; @@ -319,15 +330,16 @@ bool PlanEnumerator::SolveJoinOrderExactly() { // now we perform the actual dynamic programming to compute the final result // we enumerate over all the possible pairs in the neighborhood for (idx_t i = query_graph_manager.relation_manager.NumRelations(); i > 0; i--) { + auto relation_id = i - 1; // for every node in the set, we consider it as the start node once - auto &start_node = query_graph_manager.set_manager.GetJoinRelation(i - 1); + auto start_node = query_graph_manager.set_manager.GetJoinRelation(relation_id); // emit the start node if (!EmitCSG(start_node)) { return false; } // initialize the set of exclusion_set as all the nodes with a number below this unordered_set exclusion_set; - for (idx_t j = 0; j < i; j++) { + for (idx_t j = 0; j < relation_id; j++) { exclusion_set.insert(j); } // then we recursively search for neighbors that do not belong to the banned entries @@ -434,8 +446,8 @@ void PlanEnumerator::SolveJoinOrderApproximately() { // important to erase the biggest element first // if we erase the smallest element first the index of the biggest element changes - auto &new_set = query_graph_manager.set_manager.Union(join_relations.at(best_left).get(), - join_relations.at(best_right).get()); + auto new_set = query_graph_manager.set_manager.Union(join_relations.at(best_left).get(), + join_relations.at(best_right).get()); D_ASSERT(best_right > best_left); join_relations.erase(join_relations.begin() + (int64_t)best_right); join_relations.erase(join_relations.begin() + (int64_t)best_left); @@ -456,13 +468,13 @@ void PlanEnumerator::InitLeafPlans() { // then update the total domains based on the cardinalities of each relation. for (idx_t i = 0; i < relation_stats.size(); i++) { auto stats = relation_stats.at(i); - auto &relation_set = query_graph_manager.set_manager.GetJoinRelation(i); + auto relation_set = query_graph_manager.set_manager.GetJoinRelation(i); auto join_node = make_uniq(relation_set); join_node->cost = 0; join_node->cardinality = stats.cardinality; - D_ASSERT(join_node->set.count == 1); + D_ASSERT(join_node->set.Count() == 1); plans[relation_set] = std::move(join_node); - cost_model.cardinality_estimator.InitCardinalityEstimatorProps(&relation_set, stats); + cost_model.cardinality_estimator.InitCardinalityEstimatorProps(relation_set.get(), stats); } } @@ -485,7 +497,7 @@ void PlanEnumerator::SolveJoinOrder() { for (idx_t i = 0; i < query_graph_manager.relation_manager.NumRelations(); i++) { bindings.insert(i); } - auto &total_relation = query_graph_manager.set_manager.GetJoinRelation(bindings); + auto total_relation = query_graph_manager.set_manager.GetJoinRelation(bindings); auto final_plan = plans.find(total_relation); if (final_plan == plans.end()) { // could not find the final plan diff --git a/src/optimizer/join_order/query_graph.cpp b/src/optimizer/join_order/query_graph.cpp index beb9e1521a7b..70fcd3524adb 100644 --- a/src/optimizer/join_order/query_graph.cpp +++ b/src/optimizer/join_order/query_graph.cpp @@ -4,6 +4,8 @@ #include "duckdb/common/string_util.hpp" #include "duckdb/common/assert.hpp" +#include + namespace duckdb { using QueryEdge = QueryGraphEdges::QueryEdge; @@ -37,24 +39,23 @@ void QueryGraphEdges::Print() { // LCOV_EXCL_STOP optional_ptr QueryGraphEdges::GetQueryEdge(JoinRelationSet &left) { - D_ASSERT(left.count > 0); // find the EdgeInfo corresponding to the left set optional_ptr info(&root); - for (idx_t i = 0; i < left.count; i++) { - auto entry = info.get()->children.find(left.relations[i]); + JoinRelationSet::EnumerateRelations(left.relations, [&](idx_t relation_id) { + auto entry = info.get()->children.find(relation_id); if (entry == info.get()->children.end()) { // node not found, create it - auto insert_it = info.get()->children.insert(make_pair(left.relations[i], make_uniq())); + auto insert_it = info.get()->children.insert(make_pair(relation_id, make_uniq())); entry = insert_it.first; } // move to the next node info = entry->second; - } + }); return info; } void QueryGraphEdges::CreateEdge(JoinRelationSet &left, JoinRelationSet &right, optional_ptr filter_info) { - D_ASSERT(left.count > 0 && right.count > 0); + D_ASSERT(left.Count() > 0 && right.Count() > 0); // find the EdgeInfo corresponding to the left set auto info = GetQueryEdge(left); // now insert the edge to the right relation, if it does not exist @@ -86,29 +87,37 @@ void QueryGraphEdges::EnumerateNeighborsDFS(JoinRelationSet &node, reference new_info = *iter->second; - EnumerateNeighborsDFS(node, new_info, node_index + 1, callback); + for (idx_t node_index = index; node_index < PlanEnumerator::THRESHOLD_TO_SWAP_TO_APPROXIMATE; ++node_index) { + if (node.relations[node_index]) { + auto iter = info.get().children.find(node_index); + if (iter != info.get().children.end()) { + reference new_info = *iter->second; + EnumerateNeighborsDFS(node, new_info, node_index + 1, callback); + } } } } void QueryGraphEdges::EnumerateNeighbors(JoinRelationSet &node, const std::function &callback) const { - for (idx_t j = 0; j < node.count; j++) { - auto iter = root.children.find(node.relations[j]); - if (iter != root.children.end()) { - reference new_info = *iter->second; - EnumerateNeighborsDFS(node, new_info, j + 1, callback); + for (idx_t j = 0; j < PlanEnumerator::THRESHOLD_TO_SWAP_TO_APPROXIMATE; j++) { + if (node.relations[j]) { + auto iter = root.children.find(j); + if (iter != root.children.end()) { + reference new_info = *iter->second; + EnumerateNeighborsDFS(node, new_info, j + 1, callback); + } } } } //! Returns true if a JoinRelationSet is banned by the list of exclusion_set, false otherwise static bool JoinRelationSetIsExcluded(optional_ptr node, unordered_set &exclusion_set) { - return exclusion_set.find(node->relations[0]) != exclusion_set.end(); + bool is_excluded = false; + JoinRelationSet::EnumerateRelations(node->relations, [&](idx_t relation_id) { + is_excluded |= exclusion_set.find(relation_id) != exclusion_set.end(); + }); + return is_excluded; } const vector QueryGraphEdges::GetNeighbors(JoinRelationSet &node, unordered_set &exclusion_set) const { @@ -116,7 +125,11 @@ const vector QueryGraphEdges::GetNeighbors(JoinRelationSet &node, unorder EnumerateNeighbors(node, [&](NeighborInfo &info) -> bool { if (!JoinRelationSetIsExcluded(info.neighbor, exclusion_set)) { // add the smallest node of the neighbor to the set - result.insert(info.neighbor->relations[0]); + JoinRelationSet::EnumerateRelations(info.neighbor->relations, [&](idx_t relation_id) { + if (result.size() == 0) { + result.insert(relation_id); + } + }); } return false; }); diff --git a/src/optimizer/join_order/query_graph_manager.cpp b/src/optimizer/join_order/query_graph_manager.cpp index 3a5214d2c206..9b68ec53c73e 100644 --- a/src/optimizer/join_order/query_graph_manager.cpp +++ b/src/optimizer/join_order/query_graph_manager.cpp @@ -97,10 +97,10 @@ void QueryGraphManager::CreateHyperGraphEdges() { // both the left and the right side have bindings // first create the relation sets, if they do not exist if (!filter_info->left_set) { - filter_info->left_set = &set_manager.GetJoinRelation(left_bindings); + filter_info->left_set = set_manager.GetJoinRelation(left_bindings).get(); } if (!filter_info->right_set) { - filter_info->right_set = &set_manager.GetJoinRelation(right_bindings); + filter_info->right_set = set_manager.GetJoinRelation(right_bindings).get(); } // we can only create a meaningful edge if the sets are not exactly the same if (filter_info->left_set != filter_info->right_set) { @@ -179,7 +179,7 @@ unique_ptr QueryGraphManager::Reconstruct(unique_ptr> extracted_relations; @@ -310,10 +310,10 @@ GenerateJoinRelation QueryGraphManager::GenerateJoins(vectorset.count == 1); + D_ASSERT(node->set.Count() == 1); D_ASSERT(extracted_relations[node->set.relations[0]]); result_relation = &node->set; result_operator = std::move(extracted_relations[result_relation->relations[0]]); @@ -333,7 +333,7 @@ GenerateJoinRelation QueryGraphManager::GenerateJoins(vectorfilter) { // now check if the filter is a subset of the current relation // note that infos with an empty relation set are a special case and we do not push them down - if (info.set.get().count > 0 && JoinRelationSet::IsSubset(*result_relation, info.set)) { + if (info.set.get().Count() > 0 && JoinRelationSet::IsSubset(*result_relation, info.set)) { auto &filter_and_binding = filters_and_bindings[info.filter_index]; auto filter = std::move(filter_and_binding->filter); // if it is, we can push the filter diff --git a/src/optimizer/join_order/relation_manager.cpp b/src/optimizer/join_order/relation_manager.cpp index de75659bc377..e0e198a02357 100644 --- a/src/optimizer/join_order/relation_manager.cpp +++ b/src/optimizer/join_order/relation_manager.cpp @@ -518,20 +518,20 @@ vector> RelationManager::ExtractEdges(LogicalOperator &op ExtractBindings(*comp.left, left_bindings); if (!left_set) { - left_set = set_manager.GetJoinRelation(left_bindings); + left_set = set_manager.GetJoinRelation(left_bindings).get(); } else { - left_set = set_manager.Union(set_manager.GetJoinRelation(left_bindings), *left_set); + left_set = set_manager.Union(set_manager.GetJoinRelation(left_bindings), *left_set).get(); } if (!right_set) { - right_set = set_manager.GetJoinRelation(right_bindings); + right_set = set_manager.GetJoinRelation(right_bindings).get(); } else { - right_set = set_manager.Union(set_manager.GetJoinRelation(right_bindings), *right_set); + right_set = set_manager.Union(set_manager.GetJoinRelation(right_bindings), *right_set).get(); } } - full_set = set_manager.Union(*left_set, *right_set); - D_ASSERT(left_set && left_set->count > 0); - D_ASSERT(right_set && right_set->count == 1); - D_ASSERT(full_set && full_set->count > 0); + full_set = set_manager.Union(*left_set, *right_set).get(); + D_ASSERT(left_set && left_set->Count() > 0); + D_ASSERT(right_set && right_set->Count() == 1); + D_ASSERT(full_set && full_set->Count() > 0); // now we push the conjunction expressions // In QueryGraphManager::GenerateJoins we extract each condition again and create a standalone join @@ -551,7 +551,7 @@ vector> RelationManager::ExtractEdges(LogicalOperator &op filter_set.insert(*comparison); unordered_set bindings; ExtractBindings(*comparison, bindings); - auto &set = set_manager.GetJoinRelation(bindings); + auto set = set_manager.GetJoinRelation(bindings); auto filter_info = make_uniq(std::move(comparison), set, filters_and_bindings.size(), join.join_type); filters_and_bindings.push_back(std::move(filter_info)); @@ -572,7 +572,7 @@ vector> RelationManager::ExtractEdges(LogicalOperator &op leftover_expressions.push_back(std::move(expression)); continue; } - auto &set = set_manager.GetJoinRelation(bindings); + auto set = set_manager.GetJoinRelation(bindings); auto filter_info = make_uniq(std::move(expression), set, filters_and_bindings.size()); filters_and_bindings.push_back(std::move(filter_info)); } diff --git a/test/optimizer/joins/test_simple_joins.test b/test/optimizer/joins/test_simple_joins.test new file mode 100644 index 000000000000..fa311e631568 --- /dev/null +++ b/test/optimizer/joins/test_simple_joins.test @@ -0,0 +1,18 @@ +# name: test/optimizer/joins/test_simple_joins.test +# description: just test simple joins +# group: [joins] + +statement ok +create table t1 as select range a from range(10); + +statement ok +create table t2 as select range b from range(100); + +statement ok +create table t3 as select range c from range(1000); + +statement ok +select * from t1, t2, t3 where a = b and b = c; + + +