Skip to content

Commit

Permalink
Optimize numeric comparison via interval analysis in KQIR (#2257)
Browse files Browse the repository at this point in the history
Co-authored-by: hulk <[email protected]>
  • Loading branch information
PragmaTwice and git-hulk committed Apr 21, 2024
1 parent 0f2de7d commit 899216d
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 18 deletions.
1 change: 1 addition & 0 deletions src/search/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ struct FieldRef : Ref {
const FieldInfo *info = nullptr;

explicit FieldRef(std::string name) : name(std::move(name)) {}
FieldRef(std::string name, const FieldInfo *info) : name(std::move(name)), info(info) {}

std::string_view Name() const override { return "FieldRef"; }
std::string Dump() const override { return name; }
Expand Down
2 changes: 2 additions & 0 deletions src/search/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ namespace kqir {
struct Pass {
virtual std::unique_ptr<Node> Transform(std::unique_ptr<Node> node) = 0;

virtual void Reset() {}

virtual ~Pass() = default;
};

Expand Down
115 changes: 97 additions & 18 deletions src/search/passes/interval_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#pragma once

#include <algorithm>
#include <cmath>
#include <memory>
#include <set>
#include <type_traits>
Expand All @@ -33,61 +34,139 @@
namespace kqir {

struct IntervalAnalysis : Visitor {
std::map<Node *, std::pair<std::string, IntervalSet>> result;
struct IntervalInfo {
std::string field_name;
const FieldInfo *field_info;
IntervalSet intervals;
};

std::map<Node *, IntervalInfo> result;

void Reset() override { result.clear(); }

template <typename T>
std::unique_ptr<Node> VisitImpl(std::unique_ptr<T> node) {
node = Node::MustAs<T>(Visitor::Visit(std::move(node)));

std::map<std::string, std::pair<IntervalSet, std::set<Node *>>> interval_map;
struct LocalIntervalInfo {
IntervalSet intervals;
std::set<Node *> nodes;
const FieldInfo *field;
};

std::map<std::string, LocalIntervalInfo> interval_map;
for (const auto &n : node->inners) {
IntervalSet new_interval;
const FieldInfo *new_field_info = nullptr;
std::string new_field;

if (auto v = dynamic_cast<NumericCompareExpr *>(n.get())) {
new_interval = IntervalSet(v->op, v->num->val);
new_field = v->field->name;
new_field_info = v->field->info;
} else if (auto iter = result.find(n.get()); iter != result.end()) {
new_interval = iter->second.second;
new_field = iter->second.first;
new_interval = iter->second.intervals;
new_field = iter->second.field_name;
new_field_info = iter->second.field_info;
} else {
continue;
}

if (auto iter = interval_map.find(new_field); iter != interval_map.end()) {
if constexpr (std::is_same_v<T, OrExpr>) {
iter->second.first = iter->second.first | new_interval;
iter->second.intervals = iter->second.intervals | new_interval;
} else if constexpr (std::is_same_v<T, AndExpr>) {
iter->second.first = iter->second.first & new_interval;
iter->second.intervals = iter->second.intervals & new_interval;
} else {
static_assert(AlwaysFalse<T>);
}
iter->second.second.emplace(n.get());
iter->second.nodes.emplace(n.get());
iter->second.field = new_field_info;
} else {
interval_map.emplace(new_field, std::make_pair(new_interval, std::set<Node *>{n.get()}));
interval_map.emplace(new_field, LocalIntervalInfo{new_interval, std::set<Node *>{n.get()}, new_field_info});
}
}

if (interval_map.size() == 1) {
const auto &elem = *interval_map.begin();
result.emplace(node.get(), std::make_pair(elem.first, elem.second.first));
result.emplace(node.get(), IntervalInfo{elem.first, elem.second.field, elem.second.intervals});
}

for (const auto &[field, info] : interval_map) {
if (info.first.IsEmpty() || info.first.IsFull()) {
auto iter = std::remove_if(node->inners.begin(), node->inners.end(),
[&info = info](const auto &n) { return info.second.count(n.get()) == 1; });
node->inners.erase(iter, node->inners.end());
auto iter = std::remove_if(node->inners.begin(), node->inners.end(),
[&info = info](const auto &n) { return info.nodes.count(n.get()) == 1; });
node->inners.erase(iter, node->inners.end());

auto field_node = std::make_unique<FieldRef>(field, info.field);
node->inners.emplace_back(GenerateFromInterval(info.intervals, field_node.get()));
}

return node;
}

static std::unique_ptr<QueryExpr> GenerateFromInterval(const IntervalSet &intervals, FieldRef *field) {
if (intervals.IsEmpty()) {
return std::make_unique<BoolLiteral>(false);
}

if (intervals.IsFull()) {
return std::make_unique<BoolLiteral>(true);
}

std::vector<std::unique_ptr<QueryExpr>> exprs;

if (intervals.intervals.size() > 1 && std::isinf(intervals.intervals.front().first) &&
std::isinf(intervals.intervals.back().second)) {
bool is_all_ne = true;
auto iter = intervals.intervals.begin();
auto last = iter->second;
++iter;
while (iter != intervals.intervals.end()) {
if (iter->first != IntervalSet::NextNum(last)) {
is_all_ne = false;
break;
}

last = iter->second;
++iter;
}

if (info.first.IsEmpty()) {
node->inners.emplace_back(std::make_unique<BoolLiteral>(false));
} else if (info.first.IsFull()) {
node->inners.emplace_back(std::make_unique<BoolLiteral>(true));
if (is_all_ne) {
for (auto i = intervals.intervals.begin(); i != intervals.intervals.end() && !std::isinf(i->second); ++i) {
exprs.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::NE, field->CloneAs<FieldRef>(),
std::make_unique<NumericLiteral>(i->second)));
}

return std::make_unique<AndExpr>(std::move(exprs));
}
}

return node;
for (auto [l, r] : intervals.intervals) {
if (std::isinf(l)) {
exprs.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::LT, field->CloneAs<FieldRef>(),
std::make_unique<NumericLiteral>(r)));
} else if (std::isinf(r)) {
exprs.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::GET, field->CloneAs<FieldRef>(),
std::make_unique<NumericLiteral>(l)));
} else if (r == IntervalSet::NextNum(l)) {
exprs.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::EQ, field->CloneAs<FieldRef>(),
std::make_unique<NumericLiteral>(l)));
} else {
std::vector<std::unique_ptr<QueryExpr>> sub_expr;
sub_expr.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::GET, field->CloneAs<FieldRef>(),
std::make_unique<NumericLiteral>(l)));
sub_expr.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::LT, field->CloneAs<FieldRef>(),
std::make_unique<NumericLiteral>(r)));

exprs.emplace_back(std::make_unique<AndExpr>(std::move(sub_expr)));
}
}

if (exprs.size() == 1) {
return std::move(exprs.front());
} else {
return std::make_unique<OrExpr>(std::move(exprs));
}
}

std::unique_ptr<Node> Visit(std::unique_ptr<OrExpr> node) override { return VisitImpl(std::move(node)); }
Expand Down
1 change: 1 addition & 0 deletions src/search/passes/manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ using PassSequence = std::vector<std::unique_ptr<Pass>>;
struct PassManager {
static std::unique_ptr<Node> Execute(const PassSequence &seq, std::unique_ptr<Node> node) {
for (auto &pass : seq) {
pass->Reset();
node = pass->Transform(std::move(node));
}
return node;
Expand Down
32 changes: 32 additions & 0 deletions tests/cppunit/ir_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "search/ir_pass.h"

#include "gtest/gtest.h"
#include "search/interval.h"
#include "search/ir_sema_checker.h"
#include "search/passes/interval_analysis.h"
#include "search/passes/lower_to_plan.h"
#include "search/passes/manager.h"
Expand Down Expand Up @@ -130,4 +132,34 @@ TEST(IRPassTest, IntervalAnalysis) {
"select * from a where false");
ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where (a > 3 or a < 1) and a = 2"))->Dump(),
"select * from a where false");
ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where b = 1 and (a = 1 or a != 1)"))->Dump(),
"select * from a where b = 1");
ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 1 or b = 1 or a != 1"))->Dump(),
"select * from a where true");
ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where (a < 3 or a > 1) and b >= 1"))->Dump(),
"select * from a where b >= 1");
ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a != 1 or a != 2"))->Dump(),
"select * from a where true");
ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 1 and a = 2"))->Dump(),
"select * from a where false");

ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a < 1 and a < 3"))->Dump(),
"select * from a where a < 1");
ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a < 1 or a < 3"))->Dump(),
"select * from a where a < 3");
ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 1 and a < 3"))->Dump(),
"select * from a where a = 1");
ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 1 or a < 3"))->Dump(),
"select * from a where a < 3");
ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 1 or a = 3"))->Dump(),
"select * from a where (or a = 1, a = 3)");
ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a != 1"))->Dump(),
"select * from a where a != 1");
ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a != 1 and a != 2"))->Dump(),
"select * from a where (and a != 1, a != 2)");
ASSERT_EQ(
PassManager::Execute(ia_passes, *Parse("select * from a where a >= 0 and a >= 1 and a < 4 and a != 2"))->Dump(),
fmt::format("select * from a where (or (and a >= 1, a < 2), (and a >= {}, a < 4))", IntervalSet::NextNum(2)));
ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a != 1 and b > 1 and b = 2"))->Dump(),
"select * from a where (and a != 1, b = 2)");
}

0 comments on commit 899216d

Please sign in to comment.