Skip to content

Commit

Permalink
Add IR visitor and boolean simplification pass (#2211)
Browse files Browse the repository at this point in the history
  • Loading branch information
PragmaTwice committed Mar 31, 2024
1 parent bb665f6 commit 35c9e93
Show file tree
Hide file tree
Showing 5 changed files with 308 additions and 7 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/kvrocks.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ jobs:
compiler: gcc
with_openssl: -DENABLE_OPENSSL=ON
- name: Ubuntu Clang with OpenSSL
os: ubuntu-20.04
os: ubuntu-22.04
compiler: clang
with_openssl: -DENABLE_OPENSSL=ON
- name: Ubuntu GCC without luaJIT
Expand All @@ -191,7 +191,7 @@ jobs:
compiler: gcc
new_encoding: -DENABLE_NEW_ENCODING=TRUE
- name: Ubuntu Clang with new encoding
os: ubuntu-20.04
os: ubuntu-22.04
compiler: clang
new_encoding: -DENABLE_NEW_ENCODING=TRUE
- name: Ubuntu GCC with speedb enabled
Expand Down
11 changes: 6 additions & 5 deletions src/search/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,17 @@ struct Node {
return std::unique_ptr<U>(new T(std::forward<Args>(args)...));
}

template <typename T>
static std::unique_ptr<T> MustAs(std::unique_ptr<Node> &&original) {
template <typename T, typename U>
static std::unique_ptr<T> MustAs(std::unique_ptr<U> &&original) {
auto casted = As<T>(std::move(original));
CHECK(casted != nullptr);
return casted;
}

template <typename T>
static std::unique_ptr<T> As(std::unique_ptr<Node> &&original) {
auto casted = dynamic_cast<T *>(original.release());
template <typename T, typename U>
static std::unique_ptr<T> As(std::unique_ptr<U> &&original) {
auto casted = dynamic_cast<T *>(original.get());
if (casted) original.release();
return std::unique_ptr<T>(casted);
}
};
Expand Down
144 changes: 144 additions & 0 deletions src/search/ir_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/

#pragma once

#include "ir.h"

namespace kqir {

struct Pass {
virtual std::unique_ptr<Node> Transform(std::unique_ptr<Node> node) = 0;
};

struct Visitor : Pass {
std::unique_ptr<Node> Transform(std::unique_ptr<Node> node) override {
if (auto v = Node::As<SearchStmt>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<SelectExpr>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<IndexRef>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<Limit>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<SortBy>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<AndExpr>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<OrExpr>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<NotExpr>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<NumericCompareExpr>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<NumericLiteral>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<FieldRef>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<TagContainExpr>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<StringLiteral>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<BoolLiteral>(std::move(node))) {
return Visit(std::move(v));
}

__builtin_unreachable();
}

template <typename T>
std::unique_ptr<T> VisitAs(std::unique_ptr<T> n) {
return Node::MustAs<T>(Visit(std::move(n)));
}

template <typename T>
std::unique_ptr<T> TransformAs(std::unique_ptr<Node> n) {
return Node::MustAs<T>(Transform(std::move(n)));
}

virtual std::unique_ptr<Node> Visit(std::unique_ptr<SearchStmt> node) {
node->index = VisitAs<IndexRef>(std::move(node->index));
node->select_expr = VisitAs<SelectExpr>(std::move(node->select_expr));
if (node->query_expr) node->query_expr = TransformAs<QueryExpr>(std::move(node->query_expr));
if (node->sort_by) node->sort_by = VisitAs<SortBy>(std::move(node->sort_by));
if (node->limit) node->limit = VisitAs<Limit>(std::move(node->limit));
return node;
}

virtual std::unique_ptr<Node> Visit(std::unique_ptr<SelectExpr> node) {
for (auto &n : node->fields) {
n = VisitAs<FieldRef>(std::move(n));
}

return node;
}

virtual std::unique_ptr<Node> Visit(std::unique_ptr<IndexRef> node) { return node; }

virtual std::unique_ptr<Node> Visit(std::unique_ptr<FieldRef> node) { return node; }

virtual std::unique_ptr<Node> Visit(std::unique_ptr<BoolLiteral> node) { return node; }

virtual std::unique_ptr<Node> Visit(std::unique_ptr<StringLiteral> node) { return node; }

virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericLiteral> node) { return node; }

virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericCompareExpr> node) {
node->field = VisitAs<FieldRef>(std::move(node->field));
node->num = VisitAs<NumericLiteral>(std::move(node->num));
return node;
}

virtual std::unique_ptr<Node> Visit(std::unique_ptr<TagContainExpr> node) {
node->field = VisitAs<FieldRef>(std::move(node->field));
node->tag = VisitAs<StringLiteral>(std::move(node->tag));
return node;
}

virtual std::unique_ptr<Node> Visit(std::unique_ptr<AndExpr> node) {
for (auto &n : node->inners) {
n = TransformAs<QueryExpr>(std::move(n));
}

return node;
}

virtual std::unique_ptr<Node> Visit(std::unique_ptr<OrExpr> node) {
for (auto &n : node->inners) {
n = TransformAs<QueryExpr>(std::move(n));
}

return node;
}

virtual std::unique_ptr<Node> Visit(std::unique_ptr<NotExpr> node) {
node->inner = TransformAs<QueryExpr>(std::move(node->inner));
return node;
}

virtual std::unique_ptr<Node> Visit(std::unique_ptr<Limit> node) { return node; }

virtual std::unique_ptr<Node> Visit(std::unique_ptr<SortBy> node) {
node->field = VisitAs<FieldRef>(std::move(node->field));
return node;
}
};

} // namespace kqir
92 changes: 92 additions & 0 deletions src/search/passes/simplify_boolean.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/

#pragma once

#include <iostream>
#include <memory>

#include "search/ir.h"
#include "search/ir_pass.h"

namespace kqir {

struct SimplifyBoolean : Visitor {
std::unique_ptr<Node> Visit(std::unique_ptr<OrExpr> node) override {
node = Node::MustAs<OrExpr>(Visitor::Visit(std::move(node)));

for (auto iter = node->inners.begin(); iter != node->inners.end();) {
if (auto v = Node::As<BoolLiteral>(std::move(*iter))) {
if (!v->val) {
iter = node->inners.erase(iter);
} else {
return v;
}
} else {
++iter;
}
}

if (node->inners.size() == 0) {
return std::make_unique<BoolLiteral>(false);
} else if (node->inners.size() == 1) {
return std::move(node->inners[0]);
}

return node;
}

std::unique_ptr<Node> Visit(std::unique_ptr<AndExpr> node) override {
node = Node::MustAs<AndExpr>(Visitor::Visit(std::move(node)));

for (auto iter = node->inners.begin(); iter != node->inners.end();) {
if (auto v = Node::As<BoolLiteral>(std::move(*iter))) {
if (v->val) {
iter = node->inners.erase(iter);
} else {
return v;
}
} else {
++iter;
}
}

if (node->inners.size() == 0) {
return std::make_unique<BoolLiteral>(true);
} else if (node->inners.size() == 1) {
return std::move(node->inners[0]);
}

return node;
}

std::unique_ptr<Node> Visit(std::unique_ptr<NotExpr> node) override {
node = Node::MustAs<NotExpr>(Visitor::Visit(std::move(node)));

if (auto v = Node::As<BoolLiteral>(std::move(node->inner))) {
v->val = !v->val;
return v;
}

return node;
}
};

} // namespace kqir
64 changes: 64 additions & 0 deletions tests/cppunit/ir_pass_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/

#include "search/ir_pass.h"

#include "gtest/gtest.h"
#include "search/passes/simplify_boolean.h"
#include "search/sql_transformer.h"

using namespace kqir;

static auto Parse(const std::string& in) { return sql::ParseToIR(peg::string_input(in, "test")); }

TEST(IRPassTest, Simple) {
auto ir = *Parse("select a from b where not c = 1 or d hastag \"x\" and 2 <= e order by e asc limit 0, 10");

auto original = ir->Dump();

Visitor visitor;
auto ir2 = visitor.Transform(std::move(ir));
ASSERT_EQ(original, ir2->Dump());
}

TEST(IRPassTest, SimplifyBoolean) {
SimplifyBoolean sb;
ASSERT_EQ(sb.Transform(*Parse("select a from b where not false"))->Dump(), "select a from b where true");
ASSERT_EQ(sb.Transform(*Parse("select a from b where not not false"))->Dump(), "select a from b where false");
ASSERT_EQ(sb.Transform(*Parse("select a from b where true and true"))->Dump(), "select a from b where true");
ASSERT_EQ(sb.Transform(*Parse("select a from b where true and false"))->Dump(), "select a from b where false");
ASSERT_EQ(sb.Transform(*Parse("select a from b where false and true"))->Dump(), "select a from b where false");
ASSERT_EQ(sb.Transform(*Parse("select a from b where true and false and true"))->Dump(),
"select a from b where false");
ASSERT_EQ(sb.Transform(*Parse("select a from b where true and true and true"))->Dump(), "select a from b where true");
ASSERT_EQ(sb.Transform(*Parse("select a from b where x > 1 and false"))->Dump(), "select a from b where false");
ASSERT_EQ(sb.Transform(*Parse("select a from b where x > 1 and true"))->Dump(), "select a from b where x > 1");
ASSERT_EQ(sb.Transform(*Parse("select a from b where x > 1 and true and y < 10"))->Dump(),
"select a from b where (and x > 1, y < 10)");
ASSERT_EQ(sb.Transform(*Parse("select a from b where not (false and (not true))"))->Dump(),
"select a from b where true");
ASSERT_EQ(sb.Transform(*Parse("select a from b where true or true"))->Dump(), "select a from b where true");
ASSERT_EQ(sb.Transform(*Parse("select a from b where true or false"))->Dump(), "select a from b where true");
ASSERT_EQ(sb.Transform(*Parse("select a from b where false or true"))->Dump(), "select a from b where true");
ASSERT_EQ(sb.Transform(*Parse("select a from b where true or false or true"))->Dump(), "select a from b where true");
ASSERT_EQ(sb.Transform(*Parse("select a from b where true or false or true"))->Dump(), "select a from b where true");
ASSERT_EQ(sb.Transform(*Parse("select a from b where not ((x < 1 or true) and (y > 2 and true))"))->Dump(),
"select a from b where not y > 2");
}

0 comments on commit 35c9e93

Please sign in to comment.