From 35c9e937341a82e494be678bc56e518ab71fff2e Mon Sep 17 00:00:00 2001 From: Twice Date: Sun, 31 Mar 2024 18:25:01 +0900 Subject: [PATCH] Add IR visitor and boolean simplification pass (#2211) --- .github/workflows/kvrocks.yaml | 4 +- src/search/ir.h | 11 +- src/search/ir_pass.h | 144 +++++++++++++++++++++++++++ src/search/passes/simplify_boolean.h | 92 +++++++++++++++++ tests/cppunit/ir_pass_test.cc | 64 ++++++++++++ 5 files changed, 308 insertions(+), 7 deletions(-) create mode 100644 src/search/ir_pass.h create mode 100644 src/search/passes/simplify_boolean.h create mode 100644 tests/cppunit/ir_pass_test.cc diff --git a/.github/workflows/kvrocks.yaml b/.github/workflows/kvrocks.yaml index 675a23b6f25..6c2b51c911a 100644 --- a/.github/workflows/kvrocks.yaml +++ b/.github/workflows/kvrocks.yaml @@ -175,7 +175,7 @@ jobs: compiler: gcc with_openssl: -DENABLE_OPENSSL=ON - name: Ubuntu Clang with OpenSSL - os: ubuntu-20.04 + os: ubuntu-22.04 compiler: clang with_openssl: -DENABLE_OPENSSL=ON - name: Ubuntu GCC without luaJIT @@ -191,7 +191,7 @@ jobs: compiler: gcc new_encoding: -DENABLE_NEW_ENCODING=TRUE - name: Ubuntu Clang with new encoding - os: ubuntu-20.04 + os: ubuntu-22.04 compiler: clang new_encoding: -DENABLE_NEW_ENCODING=TRUE - name: Ubuntu GCC with speedb enabled diff --git a/src/search/ir.h b/src/search/ir.h index 02f3766ab07..6e3dea28150 100644 --- a/src/search/ir.h +++ b/src/search/ir.h @@ -55,16 +55,17 @@ struct Node { return std::unique_ptr(new T(std::forward(args)...)); } - template - static std::unique_ptr MustAs(std::unique_ptr &&original) { + template + static std::unique_ptr MustAs(std::unique_ptr &&original) { auto casted = As(std::move(original)); CHECK(casted != nullptr); return casted; } - template - static std::unique_ptr As(std::unique_ptr &&original) { - auto casted = dynamic_cast(original.release()); + template + static std::unique_ptr As(std::unique_ptr &&original) { + auto casted = dynamic_cast(original.get()); + if (casted) original.release(); return std::unique_ptr(casted); } }; diff --git a/src/search/ir_pass.h b/src/search/ir_pass.h new file mode 100644 index 00000000000..9a67530a488 --- /dev/null +++ b/src/search/ir_pass.h @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include "ir.h" + +namespace kqir { + +struct Pass { + virtual std::unique_ptr Transform(std::unique_ptr node) = 0; +}; + +struct Visitor : Pass { + std::unique_ptr Transform(std::unique_ptr node) override { + if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } + + __builtin_unreachable(); + } + + template + std::unique_ptr VisitAs(std::unique_ptr n) { + return Node::MustAs(Visit(std::move(n))); + } + + template + std::unique_ptr TransformAs(std::unique_ptr n) { + return Node::MustAs(Transform(std::move(n))); + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->index = VisitAs(std::move(node->index)); + node->select_expr = VisitAs(std::move(node->select_expr)); + if (node->query_expr) node->query_expr = TransformAs(std::move(node->query_expr)); + if (node->sort_by) node->sort_by = VisitAs(std::move(node->sort_by)); + if (node->limit) node->limit = VisitAs(std::move(node->limit)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + for (auto &n : node->fields) { + n = VisitAs(std::move(n)); + } + + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->field = VisitAs(std::move(node->field)); + node->num = VisitAs(std::move(node->num)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->field = VisitAs(std::move(node->field)); + node->tag = VisitAs(std::move(node->tag)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + for (auto &n : node->inners) { + n = TransformAs(std::move(n)); + } + + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + for (auto &n : node->inners) { + n = TransformAs(std::move(n)); + } + + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->inner = TransformAs(std::move(node->inner)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->field = VisitAs(std::move(node->field)); + return node; + } +}; + +} // namespace kqir diff --git a/src/search/passes/simplify_boolean.h b/src/search/passes/simplify_boolean.h new file mode 100644 index 00000000000..99229c2fb22 --- /dev/null +++ b/src/search/passes/simplify_boolean.h @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include +#include + +#include "search/ir.h" +#include "search/ir_pass.h" + +namespace kqir { + +struct SimplifyBoolean : Visitor { + std::unique_ptr Visit(std::unique_ptr node) override { + node = Node::MustAs(Visitor::Visit(std::move(node))); + + for (auto iter = node->inners.begin(); iter != node->inners.end();) { + if (auto v = Node::As(std::move(*iter))) { + if (!v->val) { + iter = node->inners.erase(iter); + } else { + return v; + } + } else { + ++iter; + } + } + + if (node->inners.size() == 0) { + return std::make_unique(false); + } else if (node->inners.size() == 1) { + return std::move(node->inners[0]); + } + + return node; + } + + std::unique_ptr Visit(std::unique_ptr node) override { + node = Node::MustAs(Visitor::Visit(std::move(node))); + + for (auto iter = node->inners.begin(); iter != node->inners.end();) { + if (auto v = Node::As(std::move(*iter))) { + if (v->val) { + iter = node->inners.erase(iter); + } else { + return v; + } + } else { + ++iter; + } + } + + if (node->inners.size() == 0) { + return std::make_unique(true); + } else if (node->inners.size() == 1) { + return std::move(node->inners[0]); + } + + return node; + } + + std::unique_ptr Visit(std::unique_ptr node) override { + node = Node::MustAs(Visitor::Visit(std::move(node))); + + if (auto v = Node::As(std::move(node->inner))) { + v->val = !v->val; + return v; + } + + return node; + } +}; + +} // namespace kqir diff --git a/tests/cppunit/ir_pass_test.cc b/tests/cppunit/ir_pass_test.cc new file mode 100644 index 00000000000..9f9af8f6767 --- /dev/null +++ b/tests/cppunit/ir_pass_test.cc @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "search/ir_pass.h" + +#include "gtest/gtest.h" +#include "search/passes/simplify_boolean.h" +#include "search/sql_transformer.h" + +using namespace kqir; + +static auto Parse(const std::string& in) { return sql::ParseToIR(peg::string_input(in, "test")); } + +TEST(IRPassTest, Simple) { + auto ir = *Parse("select a from b where not c = 1 or d hastag \"x\" and 2 <= e order by e asc limit 0, 10"); + + auto original = ir->Dump(); + + Visitor visitor; + auto ir2 = visitor.Transform(std::move(ir)); + ASSERT_EQ(original, ir2->Dump()); +} + +TEST(IRPassTest, SimplifyBoolean) { + SimplifyBoolean sb; + ASSERT_EQ(sb.Transform(*Parse("select a from b where not false"))->Dump(), "select a from b where true"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where not not false"))->Dump(), "select a from b where false"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where true and true"))->Dump(), "select a from b where true"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where true and false"))->Dump(), "select a from b where false"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where false and true"))->Dump(), "select a from b where false"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where true and false and true"))->Dump(), + "select a from b where false"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where true and true and true"))->Dump(), "select a from b where true"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where x > 1 and false"))->Dump(), "select a from b where false"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where x > 1 and true"))->Dump(), "select a from b where x > 1"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where x > 1 and true and y < 10"))->Dump(), + "select a from b where (and x > 1, y < 10)"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where not (false and (not true))"))->Dump(), + "select a from b where true"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where true or true"))->Dump(), "select a from b where true"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where true or false"))->Dump(), "select a from b where true"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where false or true"))->Dump(), "select a from b where true"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where true or false or true"))->Dump(), "select a from b where true"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where true or false or true"))->Dump(), "select a from b where true"); + ASSERT_EQ(sb.Transform(*Parse("select a from b where not ((x < 1 or true) and (y > 2 and true))"))->Dump(), + "select a from b where not y > 2"); +}