diff --git a/src/search/ir.h b/src/search/ir.h
index d7da716ac50..3acc5ae8b7e 100644
--- a/src/search/ir.h
+++ b/src/search/ir.h
@@ -85,6 +85,8 @@ struct FieldRef : Ref {
   const FieldInfo *info = nullptr;
 
   explicit FieldRef(std::string name) : name(std::move(name)) {}
+  // Builds a reference that already carries its field metadata (`info` may be nullptr, as by default).
+  FieldRef(std::string name, const FieldInfo *info) : name(std::move(name)), info(info) {}
 
   std::string_view Name() const override { return "FieldRef"; }
   std::string Dump() const override { return name; }
diff --git a/src/search/ir_pass.h b/src/search/ir_pass.h
index 5fa57b1a6be..2068a45a4f4 100644
--- a/src/search/ir_pass.h
+++ b/src/search/ir_pass.h
@@ -28,6 +28,10 @@ namespace kqir {
 struct Pass {
   virtual std::unique_ptr<Node> Transform(std::unique_ptr<Node> node) = 0;
 
+  // Drops any state kept from a previous run. PassManager::Execute calls it before every
+  // Transform, so passes that keep state override it; the default does nothing.
+  virtual void Reset() {}
+
   virtual ~Pass() = default;
 };
 
diff --git a/src/search/passes/interval_analysis.h b/src/search/passes/interval_analysis.h
index 3010ca73fcb..59959979560 100644
--- a/src/search/passes/interval_analysis.h
+++ b/src/search/passes/interval_analysis.h
@@ -21,6 +21,7 @@
 #pragma once
 
 #include <algorithm>
+#include <cmath>
 #include <memory>
 #include <set>
 #include <type_traits>
@@ -33,61 +34,162 @@
 namespace kqir {
 
 struct IntervalAnalysis : Visitor {
-  std::map<Node *, std::pair<std::string, IntervalSet>> result;
+  // What a visited And/Or node is known to mean: the node holds exactly when the numeric field
+  // `field_name` takes a value in `intervals`.
+  struct IntervalInfo {
+    std::string field_name;
+    const FieldInfo *field_info;
+    IntervalSet intervals;
+  };
+
+  // Keyed by raw node address, so it is cleared by Reset() rather than carried over to another tree.
+  std::map<Node *, IntervalInfo> result;
+
+  void Reset() override { result.clear(); }
 
+  // Shared implementation for AndExpr and OrExpr. After visiting the children, merges the numeric
+  // comparisons on each field into one interval set (`&` for And, `|` for Or) and replaces those
+  // children with a single canonical expression per field built by GenerateFromInterval.
   template <typename T>
   std::unique_ptr<Node> VisitImpl(std::unique_ptr<T> node) {
     node = Node::MustAs<T>(Visitor::Visit(std::move(node)));
 
-    std::map<std::string, std::pair<IntervalSet, std::set<Node *>>> interval_map;
+    // Per-field accumulator: the merged interval set, the children it came from, the field's metadata.
+    struct LocalIntervalInfo {
+      IntervalSet intervals;
+      std::set<Node *> nodes;
+      const FieldInfo *field;
+    };
+
+    std::map<std::string, LocalIntervalInfo> interval_map;
     for (const auto &n : node->inners) {
       IntervalSet new_interval;
+      const FieldInfo *new_field_info = nullptr;
       std::string new_field;
 
       if (auto v = dynamic_cast<NumericCompareExpr *>(n.get())) {
         new_interval = IntervalSet(v->op, v->num->val);
         new_field = v->field->name;
+        new_field_info = v->field->info;
       } else if (auto iter = result.find(n.get()); iter != result.end()) {
-        new_interval = iter->second.second;
-        new_field = iter->second.first;
+        new_interval = iter->second.intervals;
+        new_field = iter->second.field_name;
+        new_field_info = iter->second.field_info;
       } else {
         continue;
       }
 
       if (auto iter = interval_map.find(new_field); iter != interval_map.end()) {
         if constexpr (std::is_same_v<T, OrExpr>) {
-          iter->second.first = iter->second.first | new_interval;
+          iter->second.intervals = iter->second.intervals | new_interval;
         } else if constexpr (std::is_same_v<T, AndExpr>) {
-          iter->second.first = iter->second.first & new_interval;
+          iter->second.intervals = iter->second.intervals & new_interval;
         } else {
           static_assert(AlwaysFalse<T>);
         }
-        iter->second.second.emplace(n.get());
+        iter->second.nodes.emplace(n.get());
+        iter->second.field = new_field_info;
       } else {
-        interval_map.emplace(new_field, std::make_pair(new_interval, std::set<Node *>{n.get()}));
+        interval_map.emplace(new_field, LocalIntervalInfo{new_interval, std::set<Node *>{n.get()}, new_field_info});
       }
     }
 
-    if (interval_map.size() == 1) {
+    // Publish a summary for this node only if every child contributed to the single field's
+    // interval set. When a child was skipped above, the set describes just part of this node, and
+    // a parent relying on it would discard the skipped condition while rewriting.
+    if (interval_map.size() == 1 && interval_map.begin()->second.nodes.size() == node->inners.size()) {
       const auto &elem = *interval_map.begin();
-      result.emplace(node.get(), std::make_pair(elem.first, elem.second.first));
+      result.emplace(node.get(), IntervalInfo{elem.first, elem.second.field, elem.second.intervals});
     }
 
     for (const auto &[field, info] : interval_map) {
-      if (info.first.IsEmpty() || info.first.IsFull()) {
-        auto iter = std::remove_if(node->inners.begin(), node->inners.end(),
-                                   [&info = info](const auto &n) { return info.second.count(n.get()) == 1; });
-        node->inners.erase(iter, node->inners.end());
+      auto iter = std::remove_if(node->inners.begin(), node->inners.end(),
+                                 [&info = info](const auto &n) { return info.nodes.count(n.get()) == 1; });
+      node->inners.erase(iter, node->inners.end());
+
+      auto field_node = std::make_unique<FieldRef>(field, info.field);
+      node->inners.emplace_back(GenerateFromInterval(info.intervals, field_node.get()));
+    }
+
+    return node;
+  }
+
+  // Builds the expression equivalent to "`field` is within `intervals`"; `field` is only cloned.
+  //   - empty set -> false literal, full set -> true literal;
+  //   - first interval starting at infinity, last one ending at infinity, and every interval starting
+  //     at NextNum of the previous end -> `field != v` for each such end `v` (And-ed if more than one);
+  //   - otherwise one term per interval (l, r): `< r`, `>= l`, `= l` or `(and >= l, < r)`, Or-ed if
+  //     more than one.
+  static std::unique_ptr<QueryExpr> GenerateFromInterval(const IntervalSet &intervals, FieldRef *field) {
+    if (intervals.IsEmpty()) {
+      return std::make_unique<BoolLiteral>(false);
+    }
+
+    if (intervals.IsFull()) {
+      return std::make_unique<BoolLiteral>(true);
+    }
+
+    std::vector<std::unique_ptr<QueryExpr>> exprs;
+
+    // Detect the "everything except a few isolated values" shape, which is emitted as `!=`
+    // comparisons instead of a list of ranges.
+    if (intervals.intervals.size() > 1 && std::isinf(intervals.intervals.front().first) &&
+        std::isinf(intervals.intervals.back().second)) {
+      bool is_all_ne = true;
+      auto iter = intervals.intervals.begin();
+      auto last = iter->second;
+      ++iter;
+      while (iter != intervals.intervals.end()) {
+        if (iter->first != IntervalSet::NextNum(last)) {
+          is_all_ne = false;
+          break;
+        }
+
+        last = iter->second;
+        ++iter;
       }
 
-      if (info.first.IsEmpty()) {
-        node->inners.emplace_back(std::make_unique<BoolLiteral>(false));
-      } else if (info.first.IsFull()) {
-        node->inners.emplace_back(std::make_unique<BoolLiteral>(true));
+      if (is_all_ne) {
+        for (auto i = intervals.intervals.begin(); i != intervals.intervals.end() && !std::isinf(i->second); ++i) {
+          exprs.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::NE, field->CloneAs<FieldRef>(),
+                                                                  std::make_unique<NumericLiteral>(i->second)));
+        }
+
+        // Mirror the tail of this function: a lone comparison needs no conjunction around it.
+        if (exprs.size() == 1) {
+          return std::move(exprs.front());
+        }
+
+        return std::make_unique<AndExpr>(std::move(exprs));
       }
     }
 
-    return node;
+    for (auto [l, r] : intervals.intervals) {
+      if (std::isinf(l)) {
+        exprs.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::LT, field->CloneAs<FieldRef>(),
+                                                                std::make_unique<NumericLiteral>(r)));
+      } else if (std::isinf(r)) {
+        exprs.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::GET, field->CloneAs<FieldRef>(),
+                                                                std::make_unique<NumericLiteral>(l)));
+      } else if (r == IntervalSet::NextNum(l)) {
+        exprs.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::EQ, field->CloneAs<FieldRef>(),
+                                                                std::make_unique<NumericLiteral>(l)));
+      } else {
+        std::vector<std::unique_ptr<QueryExpr>> sub_expr;
+        sub_expr.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::GET, field->CloneAs<FieldRef>(),
+                                                                   std::make_unique<NumericLiteral>(l)));
+        sub_expr.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::LT, field->CloneAs<FieldRef>(),
+                                                                   std::make_unique<NumericLiteral>(r)));
+
+        exprs.emplace_back(std::make_unique<AndExpr>(std::move(sub_expr)));
+      }
+    }
+
+    if (exprs.size() == 1) {
+      return std::move(exprs.front());
+    } else {
+      return std::make_unique<OrExpr>(std::move(exprs));
+    }
   }
 
   std::unique_ptr<Node> Visit(std::unique_ptr<OrExpr> node) override { return VisitImpl(std::move(node)); }
diff --git a/src/search/passes/manager.h b/src/search/passes/manager.h
index 094faa23a83..480e27a7ae9 100644
--- a/src/search/passes/manager.h
+++ b/src/search/passes/manager.h
@@ -36,6 +36,8 @@ using PassSequence = std::vector<std::unique_ptr<Pass>>;
 struct PassManager {
   static std::unique_ptr<Node> Execute(const PassSequence &seq, std::unique_ptr<Node> node) {
     for (auto &pass : seq) {
+      // The same pass objects serve many Execute calls; drop whatever the previous run left behind.
+      pass->Reset();
       node = pass->Transform(std::move(node));
     }
     return node;
diff --git a/tests/cppunit/ir_pass_test.cc b/tests/cppunit/ir_pass_test.cc
index f40dd8a4139..bfb630a5f2a 100644
--- a/tests/cppunit/ir_pass_test.cc
+++ b/tests/cppunit/ir_pass_test.cc
@@ -21,6 +21,8 @@
 #include "search/ir_pass.h"
 
 #include "gtest/gtest.h"
+#include "search/interval.h"
+#include "search/ir_sema_checker.h"
 #include "search/passes/interval_analysis.h"
 #include "search/passes/lower_to_plan.h"
 #include "search/passes/manager.h"
@@ -130,4 +132,39 @@ TEST(IRPassTest, IntervalAnalysis) {
             "select * from a where false");
   ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where (a > 3 or a < 1) and a = 2"))->Dump(),
             "select * from a where false");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where b = 1 and (a = 1 or a != 1)"))->Dump(),
+            "select * from a where b = 1");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 1 or b = 1 or a != 1"))->Dump(),
+            "select * from a where true");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where (a < 3 or a > 1) and b >= 1"))->Dump(),
+            "select * from a where b >= 1");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a != 1 or a != 2"))->Dump(),
+            "select * from a where true");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 1 and a = 2"))->Dump(),
+            "select * from a where false");
+
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a < 1 and a < 3"))->Dump(),
+            "select * from a where a < 1");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a < 1 or a < 3"))->Dump(),
+            "select * from a where a < 3");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 1 and a < 3"))->Dump(),
+            "select * from a where a = 1");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 1 or a < 3"))->Dump(),
+            "select * from a where a < 3");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 1 or a = 3"))->Dump(),
+            "select * from a where (or a = 1, a = 3)");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a != 1"))->Dump(),
+            "select * from a where a != 1");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a != 1 and a != 2"))->Dump(),
+            "select * from a where (and a != 1, a != 2)");
+  ASSERT_EQ(
+      PassManager::Execute(ia_passes, *Parse("select * from a where a >= 0 and a >= 1 and a < 4 and a != 2"))->Dump(),
+      fmt::format("select * from a where (or (and a >= 1, a < 2), (and a >= {}, a < 4))", IntervalSet::NextNum(2)));
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a != 1 and b > 1 and b = 2"))->Dump(),
+            "select * from a where (and a != 1, b = 2)");
+
+  // A node that also holds conditions on other fields must not be summarised by `a` alone.
+  ASSERT_EQ(
+      PassManager::Execute(ia_passes, *Parse("select * from a where (a >= 1 and (b = 1 or c = 1)) or a < 1"))->Dump(),
+      "select * from a where (or (and (or b = 1, c = 1), a >= 1), a < 1)");
 }