From 7e063030519b1b5a5367a5921a0a5ee2b389d0b7 Mon Sep 17 00:00:00 2001 From: Twice Date: Fri, 29 Mar 2024 18:22:47 +0900 Subject: [PATCH] Add support for dumping DOT graphs from KQIR (#2205) --- src/common/type_util.h | 8 +++ src/search/ir.h | 79 ++++++++++++++++++++++-- src/search/ir_dot_dumper.h | 54 ++++++++++++++++ src/search/ir_iterator.h | 92 ++++++++++++++++++++++++++++ src/search/redis_query_transformer.h | 30 ++++----- src/search/sql_transformer.h | 25 ++++---- tests/cppunit/ir_dot_dumper_test.cc | 55 +++++++++++++++++ 7 files changed, 311 insertions(+), 32 deletions(-) create mode 100644 src/search/ir_dot_dumper.h create mode 100644 src/search/ir_iterator.h create mode 100644 tests/cppunit/ir_dot_dumper_test.cc diff --git a/src/common/type_util.h b/src/common/type_util.h index d55019f77f7..98439243622 100644 --- a/src/common/type_util.h +++ b/src/common/type_util.h @@ -39,3 +39,11 @@ using RemoveCVRef = typename std::remove_cv_t constexpr bool AlwaysFalse = false; + +template +struct GetClassFromMember; + +template +struct GetClassFromMember { + using type = C; // NOLINT +}; diff --git a/src/search/ir.h b/src/search/ir.h index 87ed6488466..02f3766ab07 100644 --- a/src/search/ir.h +++ b/src/search/ir.h @@ -22,22 +22,32 @@ #include +#include #include #include #include #include +#include #include #include #include #include "fmt/core.h" +#include "ir_iterator.h" #include "string_util.h" +#include "type_util.h" // kqir stands for Kvorcks Query Intermediate Representation namespace kqir { struct Node { virtual std::string Dump() const = 0; + virtual std::string_view Name() const = 0; + virtual std::string Content() const { return {}; } + + virtual NodeIterator ChildBegin() { return {}; }; + virtual NodeIterator ChildEnd() { return {}; }; + virtual ~Node() = default; template @@ -45,10 +55,16 @@ struct Node { return std::unique_ptr(new T(std::forward(args)...)); } + template + static std::unique_ptr MustAs(std::unique_ptr &&original) { + auto casted = As(std::move(original)); + CHECK(casted != nullptr); + return casted; + } + template static std::unique_ptr As(std::unique_ptr &&original) { auto casted = dynamic_cast(original.release()); - CHECK(casted); return std::unique_ptr(casted); } }; @@ -58,7 +74,9 @@ struct FieldRef : Node { explicit FieldRef(std::string name) : name(std::move(name)) {} + std::string_view Name() const override { return "FieldRef"; } std::string Dump() const override { return name; } + std::string Content() const override { return Dump(); } }; struct StringLiteral : Node { @@ -66,7 +84,9 @@ struct StringLiteral : Node { explicit StringLiteral(std::string val) : val(std::move(val)) {} + std::string_view Name() const override { return "StringLiteral"; } std::string Dump() const override { return fmt::format("\"{}\"", util::EscapeString(val)); } + std::string Content() const override { return Dump(); } }; struct QueryExpr : Node {}; @@ -80,7 +100,11 @@ struct TagContainExpr : BoolAtomExpr { TagContainExpr(std::unique_ptr &&field, std::unique_ptr &&tag) : field(std::move(field)), tag(std::move(tag)) {} + std::string_view Name() const override { return "TagContainExpr"; } std::string Dump() const override { return fmt::format("{} hastag {}", field->Dump(), tag->Dump()); } + + NodeIterator ChildBegin() override { return {field.get(), tag.get()}; }; + NodeIterator ChildEnd() override { return {}; }; }; struct NumericLiteral : Node { @@ -88,7 +112,9 @@ struct NumericLiteral : Node { explicit NumericLiteral(double val) : val(val) {} + std::string_view Name() const override { return "NumericLiteral"; } std::string Dump() const override { return fmt::format("{}", val); } + std::string Content() const override { return Dump(); } }; // NOLINTNEXTLINE @@ -156,7 +182,12 @@ struct NumericCompareExpr : BoolAtomExpr { __builtin_unreachable(); } + std::string_view Name() const override { return "NumericCompareExpr"; } std::string Dump() const override { return fmt::format("{} {} {}", field->Dump(), ToOperator(op), num->Dump()); }; + std::string Content() const override { return ToOperator(op); } + + NodeIterator ChildBegin() override { return {field.get(), num.get()}; }; + NodeIterator ChildEnd() override { return {}; }; }; struct BoolLiteral : BoolAtomExpr { @@ -164,7 +195,9 @@ struct BoolLiteral : BoolAtomExpr { explicit BoolLiteral(bool val) : val(val) {} + std::string_view Name() const override { return "BoolLiteral"; } std::string Dump() const override { return val ? "true" : "false"; } + std::string Content() const override { return Dump(); } }; struct QueryExpr; @@ -174,7 +207,11 @@ struct NotExpr : QueryExpr { explicit NotExpr(std::unique_ptr &&inner) : inner(std::move(inner)) {} + std::string_view Name() const override { return "NotExpr"; } std::string Dump() const override { return fmt::format("not {}", inner->Dump()); } + + NodeIterator ChildBegin() override { return NodeIterator{inner.get()}; }; + NodeIterator ChildEnd() override { return {}; }; }; struct AndExpr : QueryExpr { @@ -182,9 +219,13 @@ struct AndExpr : QueryExpr { explicit AndExpr(std::vector> &&inners) : inners(std::move(inners)) {} + std::string_view Name() const override { return "AndExpr"; } std::string Dump() const override { return fmt::format("(and {})", util::StringJoin(inners, [](const auto &v) { return v->Dump(); })); } + + NodeIterator ChildBegin() override { return NodeIterator(inners.begin()); }; + NodeIterator ChildEnd() override { return NodeIterator(inners.end()); }; }; struct OrExpr : QueryExpr { @@ -192,9 +233,13 @@ struct OrExpr : QueryExpr { explicit OrExpr(std::vector> &&inners) : inners(std::move(inners)) {} + std::string_view Name() const override { return "OrExpr"; } std::string Dump() const override { return fmt::format("(or {})", util::StringJoin(inners, [](const auto &v) { return v->Dump(); })); } + + NodeIterator ChildBegin() override { return NodeIterator(inners.begin()); }; + NodeIterator ChildEnd() override { return NodeIterator(inners.end()); }; }; struct Limit : Node { @@ -203,7 +248,9 @@ struct Limit : Node { Limit(size_t offset, size_t count) : offset(offset), count(count) {} + std::string_view Name() const override { return "Limit"; } std::string Dump() const override { return fmt::format("limit {}, {}", offset, count); } + std::string Content() const override { return fmt::format("{}, {}", offset, count); } }; struct SortBy : Node { @@ -213,7 +260,13 @@ struct SortBy : Node { SortBy(Order order, std::unique_ptr &&field) : order(order), field(std::move(field)) {} static constexpr const char *OrderToString(Order order) { return order == ASC ? "asc" : "desc"; } + + std::string_view Name() const override { return "SortBy"; } std::string Dump() const override { return fmt::format("sortby {}, {}", field->Dump(), OrderToString(order)); } + std::string Content() const override { return OrderToString(order); } + + NodeIterator ChildBegin() override { return NodeIterator(field.get()); }; + NodeIterator ChildEnd() override { return {}; }; }; struct SelectExpr : Node { @@ -221,10 +274,14 @@ struct SelectExpr : Node { explicit SelectExpr(std::vector> &&fields) : fields(std::move(fields)) {} + std::string_view Name() const override { return "SelectExpr"; } std::string Dump() const override { if (fields.empty()) return "select *"; return fmt::format("select {}", util::StringJoin(fields, [](const auto &v) { return v->Dump(); })); } + + NodeIterator ChildBegin() override { return NodeIterator(fields.begin()); }; + NodeIterator ChildEnd() override { return NodeIterator(fields.end()); }; }; struct IndexRef : Node { @@ -232,24 +289,27 @@ struct IndexRef : Node { explicit IndexRef(std::string name) : name(std::move(name)) {} + std::string_view Name() const override { return "IndexRef"; } std::string Dump() const override { return name; } + std::string Content() const override { return Dump(); } }; struct SearchStmt : Node { + std::unique_ptr select_expr; std::unique_ptr index; std::unique_ptr query_expr; // optional std::unique_ptr limit; // optional std::unique_ptr sort_by; // optional - std::unique_ptr select_expr; SearchStmt(std::unique_ptr &&index, std::unique_ptr &&query_expr, std::unique_ptr &&limit, std::unique_ptr &&sort_by, std::unique_ptr &&select_expr) - : index(std::move(index)), + : select_expr(std::move(select_expr)), + index(std::move(index)), query_expr(std::move(query_expr)), limit(std::move(limit)), - sort_by(std::move(sort_by)), - select_expr(std::move(select_expr)) {} + sort_by(std::move(sort_by)) {} + std::string_view Name() const override { return "SearchStmt"; } std::string Dump() const override { std::string opt; if (query_expr) opt += " where " + query_expr->Dump(); @@ -257,6 +317,15 @@ struct SearchStmt : Node { if (limit) opt += " " + limit->Dump(); return fmt::format("{} from {}{}", select_expr->Dump(), index->Dump(), opt); } + + static inline const std::vector> ChildMap = { + NodeIterator::MemFn<&SearchStmt::select_expr>, NodeIterator::MemFn<&SearchStmt::index>, + NodeIterator::MemFn<&SearchStmt::query_expr>, NodeIterator::MemFn<&SearchStmt::limit>, + NodeIterator::MemFn<&SearchStmt::sort_by>, + }; + + NodeIterator ChildBegin() override { return NodeIterator(this, ChildMap.begin()); }; + NodeIterator ChildEnd() override { return NodeIterator(this, ChildMap.end()); }; }; } // namespace kqir diff --git a/src/search/ir_dot_dumper.h b/src/search/ir_dot_dumper.h new file mode 100644 index 00000000000..5bb6dc7b728 --- /dev/null +++ b/src/search/ir_dot_dumper.h @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include "ir.h" +#include "string_util.h" + +namespace kqir { + +struct DotDumper { + std::ostream &os; + + void Dump(Node *node) { + os << "digraph {\n"; + dump(node); + os << "}\n"; + } + + private: + static std::string nodeId(Node *node) { return fmt::format("x{:x}", (uint64_t)node); } + + void dump(Node *node) { + os << " " << nodeId(node) << " [ label = \"" << node->Name(); + if (auto content = node->Content(); !content.empty()) { + os << " (" << util::EscapeString(content) << ")\" ];\n"; + } else { + os << "\" ];\n"; + } + for (auto i = node->ChildBegin(); i != node->ChildEnd(); ++i) { + os << " " << nodeId(node) << " -> " << nodeId(*i) << ";\n"; + dump(*i); + } + } +}; + +} // namespace kqir diff --git a/src/search/ir_iterator.h b/src/search/ir_iterator.h new file mode 100644 index 00000000000..2ead473c190 --- /dev/null +++ b/src/search/ir_iterator.h @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include +#include +#include +#include + +#include "type_util.h" + +namespace kqir { + +struct Node; + +struct NodeIterator { + std::variant, + std::pair>::const_iterator>, + std::vector>::iterator> + val; + + NodeIterator() : val(nullptr) {} + explicit NodeIterator(Node *node) : val(node) {} + NodeIterator(Node *n1, Node *n2) : val(std::array{n1, n2}) {} + explicit NodeIterator(Node *parent, std::vector>::const_iterator iter) + : val(std::make_pair(parent, iter)) {} + template , int> = 0> + explicit NodeIterator(Iterator iter) : val(*CastToNodeIter(&iter)) {} + + template + static auto CastToNodeIter(Iterator *iter) { + auto res __attribute__((__may_alias__)) = reinterpret_cast>::iterator *>(iter); + return res; + } + + template + static Node *MemFn(Node *parent) { + return (reinterpret_cast::type *>(parent)->*F).get(); + } + + friend bool operator==(NodeIterator l, NodeIterator r) { return l.val == r.val; } + + friend bool operator!=(NodeIterator l, NodeIterator r) { return l.val != r.val; } + + Node *operator*() { + if (val.index() == 0) { + return std::get<0>(val); + } else if (val.index() == 1) { + return std::get<1>(val)[0]; + } else if (val.index() == 2) { + auto &[parent, iter] = std::get<2>(val); + return (*iter)(parent); + } else { + return std::get<3>(val)->get(); + } + } + + NodeIterator &operator++() { + if (val.index() == 0) { + val = nullptr; + } else if (val.index() == 1) { + val = std::get<1>(val)[1]; + } else if (val.index() == 2) { + ++std::get<2>(val).second; + } else { + ++std::get<3>(val); + } + + return *this; + } +}; + +} // namespace kqir diff --git a/src/search/redis_query_transformer.h b/src/search/redis_query_transformer.h index ebd051430c3..45734b3182a 100644 --- a/src/search/redis_query_transformer.h +++ b/src/search/redis_query_transformer.h @@ -83,13 +83,13 @@ struct Transformer : ir::TreeTransformer { const auto& rhs = query->children[1]; if (Is(lhs)) { - exprs.push_back( - std::make_unique(NumericCompareExpr::GT, std::make_unique(field), - Node::As(GET_OR_RET(Transform(lhs->children[0]))))); + exprs.push_back(std::make_unique( + NumericCompareExpr::GT, std::make_unique(field), + Node::MustAs(GET_OR_RET(Transform(lhs->children[0]))))); } else if (Is(lhs)) { - exprs.push_back(std::make_unique(NumericCompareExpr::GET, - std::make_unique(field), - Node::As(GET_OR_RET(Transform(lhs))))); + exprs.push_back( + std::make_unique(NumericCompareExpr::GET, std::make_unique(field), + Node::MustAs(GET_OR_RET(Transform(lhs))))); } else { // Inf if (lhs->string_view() == "+inf") { return {Status::NotOK, "it's not allowed to set the lower bound as positive infinity"}; @@ -97,13 +97,13 @@ struct Transformer : ir::TreeTransformer { } if (Is(rhs)) { - exprs.push_back( - std::make_unique(NumericCompareExpr::LT, std::make_unique(field), - Node::As(GET_OR_RET(Transform(rhs->children[0]))))); + exprs.push_back(std::make_unique( + NumericCompareExpr::LT, std::make_unique(field), + Node::MustAs(GET_OR_RET(Transform(rhs->children[0]))))); } else if (Is(rhs)) { - exprs.push_back(std::make_unique(NumericCompareExpr::LET, - std::make_unique(field), - Node::As(GET_OR_RET(Transform(rhs))))); + exprs.push_back( + std::make_unique(NumericCompareExpr::LET, std::make_unique(field), + Node::MustAs(GET_OR_RET(Transform(rhs))))); } else { // Inf if (rhs->string_view() == "-inf") { return {Status::NotOK, "it's not allowed to set the upper bound as negative infinity"}; @@ -121,12 +121,12 @@ struct Transformer : ir::TreeTransformer { } else if (Is(node)) { CHECK(node->children.size() == 1); - return Node::Create(Node::As(GET_OR_RET(Transform(node->children[0])))); + return Node::Create(Node::MustAs(GET_OR_RET(Transform(node->children[0])))); } else if (Is(node)) { std::vector> exprs; for (const auto& child : node->children) { - exprs.push_back(Node::As(GET_OR_RET(Transform(child)))); + exprs.push_back(Node::MustAs(GET_OR_RET(Transform(child)))); } return Node::Create(std::move(exprs)); @@ -134,7 +134,7 @@ struct Transformer : ir::TreeTransformer { std::vector> exprs; for (const auto& child : node->children) { - exprs.push_back(Node::As(GET_OR_RET(Transform(child)))); + exprs.push_back(Node::MustAs(GET_OR_RET(Transform(child)))); } return Node::Create(std::move(exprs)); diff --git a/src/search/sql_transformer.h b/src/search/sql_transformer.h index 5c300654c28..63397e403d0 100644 --- a/src/search/sql_transformer.h +++ b/src/search/sql_transformer.h @@ -63,8 +63,9 @@ struct Transformer : ir::TreeTransformer { } else if (Is(node)) { CHECK(node->children.size() == 2); - return Node::Create(std::make_unique(node->children[0]->string()), - Node::As(GET_OR_RET(Transform(node->children[1])))); + return Node::Create( + std::make_unique(node->children[0]->string()), + Node::MustAs(GET_OR_RET(Transform(node->children[1])))); } else if (Is(node)) { CHECK(node->children.size() == 3); @@ -74,23 +75,23 @@ struct Transformer : ir::TreeTransformer { auto op = ir::NumericCompareExpr::FromOperator(node->children[1]->string_view()).value(); if (Is(lhs) && Is(rhs)) { return Node::Create(op, std::make_unique(lhs->string()), - Node::As(GET_OR_RET(Transform(rhs)))); + Node::MustAs(GET_OR_RET(Transform(rhs)))); } else if (Is(lhs) && Is(rhs)) { return Node::Create(ir::NumericCompareExpr::Flip(op), std::make_unique(rhs->string()), - Node::As(GET_OR_RET(Transform(lhs)))); + Node::MustAs(GET_OR_RET(Transform(lhs)))); } else { return {Status::NotOK, "the left and right side of numeric comparison should be an identifier and a number"}; } } else if (Is(node)) { CHECK(node->children.size() == 1); - return Node::Create(Node::As(GET_OR_RET(Transform(node->children[0])))); + return Node::Create(Node::MustAs(GET_OR_RET(Transform(node->children[0])))); } else if (Is(node)) { std::vector> exprs; for (const auto& child : node->children) { - exprs.push_back(Node::As(GET_OR_RET(Transform(child)))); + exprs.push_back(Node::MustAs(GET_OR_RET(Transform(child)))); } return Node::Create(std::move(exprs)); @@ -98,7 +99,7 @@ struct Transformer : ir::TreeTransformer { std::vector> exprs; for (const auto& child : node->children) { - exprs.push_back(Node::As(GET_OR_RET(Transform(child)))); + exprs.push_back(Node::MustAs(GET_OR_RET(Transform(child)))); } return Node::Create(std::move(exprs)); @@ -146,8 +147,8 @@ struct Transformer : ir::TreeTransformer { } else if (Is(node)) { // root node CHECK(node->children.size() >= 2 && node->children.size() <= 5); - auto index = Node::As(GET_OR_RET(Transform(node->children[1]))); - auto select = Node::As(GET_OR_RET(Transform(node->children[0]))); + auto index = Node::MustAs(GET_OR_RET(Transform(node->children[1]))); + auto select = Node::MustAs(GET_OR_RET(Transform(node->children[0]))); std::unique_ptr query_expr; std::unique_ptr limit; @@ -155,11 +156,11 @@ struct Transformer : ir::TreeTransformer { for (size_t i = 2; i < node->children.size(); ++i) { if (Is(node->children[i])) { - query_expr = Node::As(GET_OR_RET(Transform(node->children[i]))); + query_expr = Node::MustAs(GET_OR_RET(Transform(node->children[i]))); } else if (Is(node->children[i])) { - limit = Node::As(GET_OR_RET(Transform(node->children[i]))); + limit = Node::MustAs(GET_OR_RET(Transform(node->children[i]))); } else if (Is(node->children[i])) { - sort_by = Node::As(GET_OR_RET(Transform(node->children[i]))); + sort_by = Node::MustAs(GET_OR_RET(Transform(node->children[i]))); } } diff --git a/tests/cppunit/ir_dot_dumper_test.cc b/tests/cppunit/ir_dot_dumper_test.cc new file mode 100644 index 00000000000..0310eb212e0 --- /dev/null +++ b/tests/cppunit/ir_dot_dumper_test.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "search/ir_dot_dumper.h" + +#include +#include + +#include "gtest/gtest.h" +#include "search/sql_transformer.h" + +using namespace kqir; + +static auto Parse(const std::string& in) { return sql::ParseToIR(peg::string_input(in, "test")); } + +TEST(DotDumperTest, Simple) { + auto ir = *Parse("select a from b where c = 1 or d hastag \"x\" and 2 <= e order by e asc limit 0, 10"); + + std::stringstream ss; + DotDumper dumper{ss}; + + dumper.Dump(ir.get()); + + std::string dot = ss.str(); + std::smatch matches; + + std::regex_search(dot, matches, std::regex(R"((\w+) \[ label = "SearchStmt)")); + auto search_stmt = matches[1].str(); + + std::regex_search(dot, matches, std::regex(R"((\w+) \[ label = "OrExpr)")); + auto or_expr = matches[1].str(); + + std::regex_search(dot, matches, std::regex(R"((\w+) \[ label = "AndExpr)")); + auto and_expr = matches[1].str(); + + ASSERT_NE(dot.find(fmt::format("{} -> {}", search_stmt, or_expr)), std::string::npos); + ASSERT_NE(dot.find(fmt::format("{} -> {}", or_expr, and_expr)), std::string::npos); +}