Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 41 additions & 21 deletions src/optimizer/join_order/relation_manager.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "duckdb/optimizer/join_order/relation_manager.hpp"

#include "duckdb/common/enums/join_type.hpp"
#include "duckdb/common/enums/logical_operator_type.hpp"
#include "duckdb/common/string_util.hpp"
#include "duckdb/optimizer/join_order/join_order_optimizer.hpp"
#include "duckdb/optimizer/join_order/relation_statistics_helper.hpp"
Expand Down Expand Up @@ -46,30 +47,49 @@ void RelationManager::AddAggregateOrWindowRelation(LogicalOperator &op, optional
void RelationManager::AddRelation(LogicalOperator &op, optional_ptr<LogicalOperator> parent,
const RelationStats &stats) {

// if parent is null, then this is a root relation
// if parent is not null, it should have multiple children
D_ASSERT(!parent || parent->children.size() >= 2);
// // if parent is null, then this is a root relation
// // if parent is not null, it should have multiple children
// D_ASSERT(!parent || parent->children.size() >= 2);
// auto relation = make_uniq<SingleJoinRelation>(op, parent, stats);
// auto relation_id = relations.size();
//
// auto table_indexes = op.GetTableIndex();
// if (table_indexes.empty()) {
// // relation represents a non-reorderable relation, most likely a join relation
// // Get the tables referenced in the non-reorderable relation and add them to the relation mapping
// // This should return all table references, even if there are nested non-reorderable joins.
// unordered_set<idx_t> table_references;
// LogicalJoin::GetTableReferences(op, table_references);
// D_ASSERT(table_references.size() > 0);
// for (auto &reference : table_references) {
// D_ASSERT(relation_mapping.find(reference) == relation_mapping.end());
// relation_mapping[reference] = relation_id;
// }
// } else if (op.type == LogicalOperatorType::LOGICAL_UNNEST) {
// // logical unnest has a logical_unnest index, but other bindings can refer to
// // columns that are not unnested.
// auto bindings = op.GetColumnBindings();
// for (auto &binding : bindings) {
// relation_mapping[binding.table_index] = relation_id;
// }
// } else {
// // Relations should never return more than 1 table index
// D_ASSERT(table_indexes.size() == 1);
// idx_t table_index = table_indexes.at(0);
// D_ASSERT(relation_mapping.find(table_index) == relation_mapping.end());
// relation_mapping[table_index] = relation_id;
// }
// relations.push_back(std::move(relation));
// op.estimated_cardinality = stats.cardinality;
// op.has_estimated_cardinality = true;
auto relation = make_uniq<SingleJoinRelation>(op, parent, stats);
auto relation_id = relations.size();

auto table_indexes = op.GetTableIndex();
if (table_indexes.empty()) {
// relation represents a non-reorderable relation, most likely a join relation
// Get the tables referenced in the non-reorderable relation and add them to the relation mapping
// This should return all table references, even if there are nested non-reorderable joins.
unordered_set<idx_t> table_references;
LogicalJoin::GetTableReferences(op, table_references);
D_ASSERT(table_references.size() > 0);
for (auto &reference : table_references) {
D_ASSERT(relation_mapping.find(reference) == relation_mapping.end());
relation_mapping[reference] = relation_id;
auto op_bindings = op.GetColumnBindings();
for (auto &binding : op_bindings) {
if (relation_mapping.find(binding.table_index) == relation_mapping.end()) {
relation_mapping[binding.table_index] = relation_id;
}
} else {
// Relations should never return more than 1 table index
D_ASSERT(table_indexes.size() == 1);
idx_t table_index = table_indexes.at(0);
D_ASSERT(relation_mapping.find(table_index) == relation_mapping.end());
relation_mapping[table_index] = relation_id;
}
relations.push_back(std::move(relation));
op.estimated_cardinality = stats.cardinality;
Expand Down Expand Up @@ -319,8 +339,8 @@ bool RelationManager::ExtractJoinRelations(JoinOrderOptimizer &optimizer, Logica
return can_reorder_left && can_reorder_right;
}
case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: {
bool can_reorder_right = ExtractJoinRelations(optimizer, *op->children[1], filter_operators, op);
bool can_reorder_left = ExtractJoinRelations(optimizer, *op->children[0], filter_operators, op);
bool can_reorder_right = ExtractJoinRelations(optimizer, *op->children[1], filter_operators, op);
return can_reorder_left && can_reorder_right;
}
case LogicalOperatorType::LOGICAL_DUMMY_SCAN: {
Expand Down
98 changes: 73 additions & 25 deletions test/optimizer/join_reorder_optimizer.test
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# description: Make sure we can emit a valid join order by DPhyp if hypergraph is connected
# group: [optimizer]

statement ok
pragma enable_verification

statement ok
CREATE TABLE t1(c1 int, c2 int, c3 int, c4 int)

Expand Down Expand Up @@ -33,35 +36,80 @@ statement ok
PRAGMA debug_force_no_cross_product=true

statement ok
EXPLAIN
SELECT
COUNT(*)
FROM
t1, t2, t3, t4
WHERE
t1.c1 = t2.c1 AND
t2.c2 = t3.c2 AND
EXPLAIN
SELECT
COUNT(*)
FROM
t1, t2, t3, t4
WHERE
t1.c1 = t2.c1 AND
t2.c2 = t3.c2 AND
t3.c3 = t4.c3

statement ok
EXPLAIN
SELECT
COUNT(*)
FROM
t1, t2, t3, t4
WHERE
t1.c1 = t2.c1 AND
t2.c2 = t3.c2 AND
t3.c3 = t4.c3 AND
EXPLAIN
SELECT
COUNT(*)
FROM
t1, t2, t3, t4
WHERE
t1.c1 = t2.c1 AND
t2.c2 = t3.c2 AND
t3.c3 = t4.c3 AND
t4.c4 = t1.c4

statement ok
EXPLAIN
SELECT
COUNT(*)
FROM
t1, t2, t3, t4
WHERE
t1.c1 = t2.c1 AND
t2.c2 = t3.c2 AND
EXPLAIN
SELECT
COUNT(*)
FROM
t1, t2, t3, t4
WHERE
t1.c1 = t2.c1 AND
t2.c2 = t3.c2 AND
t1.c1 + t2.c2 + t3.c3= 3 * t4.c4

statement ok
PRAGMA debug_force_no_cross_product=false

statement ok
with
grid as (
from (values ('ABC'), ('DEF')) as v(data)
select
unnest(split(data, '')) as letter,
row_number() over () as row_id,
generate_subscripts(split(data, ''), 1) AS col_id,
),
search(row_i, col_i, letter_to_match) as (
values (0, 0, 'A'), (0, 1, 'B'),
)
from (from grid cross join search) as grid_searches
select exists(
from grid as grid_to_search
where 1=1
and grid_searches.row_id = grid_to_search.row_id + grid_searches.row_i
and grid_searches.col_id = grid_to_search.col_id + grid_searches.col_i
and grid_searches.letter_to_match = grid_to_search.letter
)

statement ok
with
grid as (
from (values ('ABC', 39), ('DEF', 50)) as v(data, row_id)
select
unnest(split(data, '')) as letter,
row_id,
generate_subscripts(split(data, ''), 1) AS col_id,
),
search(row_i, col_i, letter_to_match) as (
values (0, 0, 'A'), (0, 1, 'B'),
)
from (from grid cross join search) as grid_searches
select exists(
from grid as grid_to_search
where 1=1
and grid_searches.row_id = grid_to_search.row_id + grid_searches.row_i
and grid_searches.col_id = grid_to_search.col_id + grid_searches.col_i
and grid_searches.letter_to_match = grid_to_search.letter
)
Loading