From 91509c724270bbff4742d66fc72d20a0aec0be74 Mon Sep 17 00:00:00 2001 From: Twice Date: Sun, 14 Apr 2024 18:59:33 +0900 Subject: [PATCH] Add plan lowering pass for KQIR (#2247) --- src/search/index_info.h | 2 + src/search/ir.h | 17 ++++-- src/search/ir_pass.h | 72 +++++++++++++++++++++++- src/search/ir_plan.h | 91 ++++++++++++++++++++++++++----- src/search/ir_sema_checker.h | 5 +- src/search/passes/lower_to_plan.h | 51 +++++++++++++++++ src/search/search_encoding.h | 4 +- src/search/sql_transformer.h | 2 +- tests/cppunit/ir_pass_test.cc | 14 +++++ 9 files changed, 234 insertions(+), 24 deletions(-) create mode 100644 src/search/passes/lower_to_plan.h diff --git a/src/search/index_info.h b/src/search/index_info.h index e6e95b72348..df059918774 100644 --- a/src/search/index_info.h +++ b/src/search/index_info.h @@ -37,6 +37,8 @@ struct FieldInfo { FieldInfo(std::string name, std::unique_ptr &&metadata) : name(std::move(name)), metadata(std::move(metadata)) {} + + bool IsSortable() const { return dynamic_cast(metadata.get()) != nullptr; } }; struct IndexInfo { diff --git a/src/search/ir.h b/src/search/ir.h index 2b85b458034..d7da716ac50 100644 --- a/src/search/ir.h +++ b/src/search/ir.h @@ -51,6 +51,11 @@ struct Node { virtual std::unique_ptr Clone() const = 0; + template + std::unique_ptr CloneAs() const { + return Node::MustAs(Clone()); + } + virtual ~Node() = default; template @@ -361,14 +366,14 @@ struct IndexRef : Ref { std::unique_ptr Clone() const override { return std::make_unique(*this); } }; -struct SearchStmt : Node { +struct SearchExpr : Node { std::unique_ptr select; std::unique_ptr index; std::unique_ptr query_expr; std::unique_ptr limit; // optional std::unique_ptr sort_by; // optional - SearchStmt(std::unique_ptr &&index, std::unique_ptr &&query_expr, + SearchExpr(std::unique_ptr &&index, std::unique_ptr &&query_expr, std::unique_ptr &&limit, std::unique_ptr &&sort_by, std::unique_ptr &&select) : select(std::move(select)), @@ -386,16 +391,16 @@ struct SearchStmt : Node { } static inline const std::vector> ChildMap = { - NodeIterator::MemFn<&SearchStmt::select>, NodeIterator::MemFn<&SearchStmt::index>, - NodeIterator::MemFn<&SearchStmt::query_expr>, NodeIterator::MemFn<&SearchStmt::limit>, - NodeIterator::MemFn<&SearchStmt::sort_by>, + NodeIterator::MemFn<&SearchExpr::select>, NodeIterator::MemFn<&SearchExpr::index>, + NodeIterator::MemFn<&SearchExpr::query_expr>, NodeIterator::MemFn<&SearchExpr::limit>, + NodeIterator::MemFn<&SearchExpr::sort_by>, }; NodeIterator ChildBegin() override { return NodeIterator(this, ChildMap.begin()); }; NodeIterator ChildEnd() override { return NodeIterator(this, ChildMap.end()); }; std::unique_ptr Clone() const override { - return std::make_unique( + return std::make_unique( Node::MustAs(index->Clone()), Node::MustAs(query_expr->Clone()), Node::MustAs(limit->Clone()), Node::MustAs(sort_by->Clone()), Node::MustAs(select->Clone())); diff --git a/src/search/ir_pass.h b/src/search/ir_pass.h index a252216dd35..924e8c752eb 100644 --- a/src/search/ir_pass.h +++ b/src/search/ir_pass.h @@ -21,6 +21,7 @@ #pragma once #include "ir.h" +#include "search/ir_plan.h" namespace kqir { @@ -30,7 +31,7 @@ struct Pass { struct Visitor : Pass { std::unique_ptr Transform(std::unique_ptr node) override { - if (auto v = Node::As(std::move(node))) { + if (auto v = Node::As(std::move(node))) { return Visit(std::move(v)); } else if (auto v = Node::As(std::move(node))) { return Visit(std::move(v)); @@ -58,6 +59,26 @@ struct Visitor : Pass { return Visit(std::move(v)); } else if (auto v = Node::As(std::move(node))) { return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); + } else if (auto v = Node::As(std::move(node))) { + return Visit(std::move(v)); } __builtin_unreachable(); @@ -73,7 +94,7 @@ struct Visitor : Pass { return Node::MustAs(Transform(std::move(n))); } - virtual std::unique_ptr Visit(std::unique_ptr node) { + virtual std::unique_ptr Visit(std::unique_ptr node) { node->index = VisitAs(std::move(node->index)); node->select = VisitAs(std::move(node->select)); node->query_expr = TransformAs(std::move(node->query_expr)); @@ -139,6 +160,53 @@ struct Visitor : Pass { node->field = VisitAs(std::move(node->field)); return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { return node; } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->source = TransformAs(std::move(node->source)); + node->filter_expr = TransformAs(std::move(node->filter_expr)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->op = TransformAs(std::move(node->op)); + node->limit = VisitAs(std::move(node->limit)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->op = TransformAs(std::move(node->op)); + node->order = VisitAs(std::move(node->order)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->op = TransformAs(std::move(node->op)); + node->limit = VisitAs(std::move(node->limit)); + node->order = VisitAs(std::move(node->order)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + node->source = TransformAs(std::move(node->source)); + node->select = VisitAs(std::move(node->select)); + return node; + } + + virtual std::unique_ptr Visit(std::unique_ptr node) { + for (auto &n : node->ops) { + n = TransformAs(std::move(n)); + } + + return node; + } }; } // namespace kqir diff --git a/src/search/ir_plan.h b/src/search/ir_plan.h index 9d40d948b39..da8058464d8 100644 --- a/src/search/ir_plan.h +++ b/src/search/ir_plan.h @@ -31,22 +31,30 @@ namespace kqir { struct PlanOperator : Node {}; +struct Noop : PlanOperator { + std::string_view Name() const override { return "Noop"; }; + std::string Dump() const override { return "noop"; } + + std::unique_ptr Clone() const override { return std::make_unique(*this); } +}; + struct FullIndexScan : PlanOperator { - IndexInfo *index; + std::unique_ptr index; - explicit FullIndexScan(IndexInfo *index) : index(index) {} + explicit FullIndexScan(std::unique_ptr index) : index(std::move(index)) {} std::string_view Name() const override { return "FullIndexScan"; }; - std::string Content() const override { return index->name; }; - std::string Dump() const override { return fmt::format("full-scan {}", Content()); } + std::string Dump() const override { return fmt::format("full-scan {}", index->name); } - std::unique_ptr Clone() const override { return std::make_unique(*this); } + std::unique_ptr Clone() const override { + return std::make_unique(Node::MustAs(index->Clone())); + } }; struct FieldScan : PlanOperator { - FieldInfo *field; + std::unique_ptr field; - explicit FieldScan(FieldInfo *field) : field(field) {} + explicit FieldScan(std::unique_ptr field) : field(std::move(field)) {} }; struct Interval { @@ -60,25 +68,29 @@ struct Interval { struct NumericFieldScan : FieldScan { Interval range; - NumericFieldScan(FieldInfo *field, Interval range) : FieldScan(field), range(range) {} + NumericFieldScan(std::unique_ptr field, Interval range) : FieldScan(std::move(field)), range(range) {} std::string_view Name() const override { return "NumericFieldScan"; }; std::string Content() const override { return fmt::format("{}, {}", field->name, range.ToString()); }; std::string Dump() const override { return fmt::format("numeric-scan {}", Content()); } - std::unique_ptr Clone() const override { return std::make_unique(*this); } + std::unique_ptr Clone() const override { + return std::make_unique(field->CloneAs(), range); + } }; struct TagFieldScan : FieldScan { std::string tag; - TagFieldScan(FieldInfo *field, std::string tag) : FieldScan(field), tag(std::move(tag)) {} + TagFieldScan(std::unique_ptr field, std::string tag) : FieldScan(std::move(field)), tag(std::move(tag)) {} std::string_view Name() const override { return "TagFieldScan"; }; std::string Content() const override { return fmt::format("{}, {}", field->name, tag); }; std::string Dump() const override { return fmt::format("tag-scan {}", Content()); } - std::unique_ptr Clone() const override { return std::make_unique(*this); } + std::unique_ptr Clone() const override { + return std::make_unique(field->CloneAs(), tag); + } }; struct Filter : PlanOperator { @@ -89,7 +101,7 @@ struct Filter : PlanOperator { : source(std::move(source)), filter_expr(std::move(filter_expr)) {} std::string_view Name() const override { return "Filter"; }; - std::string Dump() const override { return fmt::format("(filter {}, {})", source->Dump(), Content()); } + std::string Dump() const override { return fmt::format("(filter {}: {})", filter_expr->Dump(), source->Dump()); } NodeIterator ChildBegin() override { return {source.get(), filter_expr.get()}; } NodeIterator ChildEnd() override { return {}; } @@ -143,6 +155,55 @@ struct Limit : PlanOperator { } }; +struct Sort : PlanOperator { + std::unique_ptr op; + std::unique_ptr order; + + Sort(std::unique_ptr &&op, std::unique_ptr &&order) + : op(std::move(op)), order(std::move(order)) {} + + std::string_view Name() const override { return "Sort"; }; + std::string Dump() const override { + return fmt::format("(sort {}, {}: {})", order->field->Dump(), order->OrderToString(order->order), op->Dump()); + } + + NodeIterator ChildBegin() override { return NodeIterator{op.get(), order.get()}; } + NodeIterator ChildEnd() override { return {}; } + + std::unique_ptr Clone() const override { + return std::make_unique(Node::MustAs(op->Clone()), Node::MustAs(order->Clone())); + } +}; + +// operator fusion: Sort + Limit +struct TopNSort : PlanOperator { + std::unique_ptr op; + std::unique_ptr order; + std::unique_ptr limit; + + TopNSort(std::unique_ptr &&op, std::unique_ptr &&order, + std::unique_ptr &&limit) + : op(std::move(op)), order(std::move(order)), limit(std::move(limit)) {} + + std::string_view Name() const override { return "TopNSort"; }; + std::string Dump() const override { + return fmt::format("(top-n sort {}, {}, {}, {}: {})", order->field->Dump(), order->OrderToString(order->order), + limit->offset, limit->count, op->Dump()); + } + + static inline const std::vector> ChildMap = { + NodeIterator::MemFn<&TopNSort::op>, NodeIterator::MemFn<&TopNSort::order>, NodeIterator::MemFn<&TopNSort::limit>}; + + NodeIterator ChildBegin() override { return NodeIterator(this, ChildMap.begin()); } + NodeIterator ChildEnd() override { return NodeIterator(this, ChildMap.end()); } + + std::unique_ptr Clone() const override { + return std::make_unique(Node::MustAs(op->Clone()), + Node::MustAs(order->Clone()), + Node::MustAs(limit->Clone())); + } +}; + struct Projection : PlanOperator { std::unique_ptr source; std::unique_ptr select; @@ -151,7 +212,11 @@ struct Projection : PlanOperator { : source(std::move(source)), select(std::move(select)) {} std::string_view Name() const override { return "Projection"; }; - std::string Dump() const override { return fmt::format("(project {}: {})", select, source); } + std::string Dump() const override { + auto select_str = + select->fields.empty() ? "*" : util::StringJoin(select->fields, [](const auto &v) { return v->Dump(); }); + return fmt::format("project {}: {}", select_str, source->Dump()); + } NodeIterator ChildBegin() override { return {source.get(), select.get()}; } NodeIterator ChildEnd() override { return {}; } diff --git a/src/search/ir_sema_checker.h b/src/search/ir_sema_checker.h index 471c3a3f714..40f316cba0b 100644 --- a/src/search/ir_sema_checker.h +++ b/src/search/ir_sema_checker.h @@ -23,6 +23,7 @@ #include #include +#include "fmt/core.h" #include "index_info.h" #include "ir.h" #include "search_encoding.h" @@ -38,7 +39,7 @@ struct SemaChecker { explicit SemaChecker(const IndexMap &index_map) : index_map(index_map) {} Status Check(Node *node) { - if (auto v = dynamic_cast(node)) { + if (auto v = dynamic_cast(node)) { auto index_name = v->index->name; if (auto iter = index_map.find(index_name); iter != index_map.end()) { current_index = &iter->second; @@ -56,6 +57,8 @@ struct SemaChecker { } else if (auto v = dynamic_cast(node)) { if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) { return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)}; + } else if (!iter->second.IsSortable()) { + return {Status::NotOK, fmt::format("field `{}` is not sortable", v->field->name)}; } else { v->field->info = &iter->second; } diff --git a/src/search/passes/lower_to_plan.h b/src/search/passes/lower_to_plan.h new file mode 100644 index 00000000000..dad1db39c23 --- /dev/null +++ b/src/search/passes/lower_to_plan.h @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include + +#include "search/ir.h" +#include "search/ir_pass.h" +#include "search/ir_plan.h" + +namespace kqir { + +struct LowerToPlan : Visitor { + std::unique_ptr Visit(std::unique_ptr node) override { + auto scan = std::make_unique(node->index->CloneAs()); + auto filter = std::make_unique(std::move(scan), std::move(node->query_expr)); + + std::unique_ptr op = std::move(filter); + + // order is important here, since limit(sort(op)) is different from sort(limit(op)) + if (node->sort_by) { + op = std::make_unique(std::move(op), std::move(node->sort_by)); + } + + if (node->limit) { + op = std::make_unique(std::move(op), std::move(node->limit)); + } + + return std::make_unique(std::move(op), std::move(node->select)); + } +}; + +} // namespace kqir diff --git a/src/search/search_encoding.h b/src/search/search_encoding.h index 1637a504c28..14bf2923911 100644 --- a/src/search/search_encoding.h +++ b/src/search/search_encoding.h @@ -125,7 +125,9 @@ inline std::string ConstructNumericFieldMetadataSubkey(std::string_view field_na return res; } -struct SearchNumericFieldMetadata : SearchFieldMetadata {}; +struct SearchSortableFieldMetadata : SearchFieldMetadata {}; + +struct SearchNumericFieldMetadata : SearchSortableFieldMetadata {}; inline std::string ConstructTagFieldSubkey(std::string_view field_name, std::string_view tag, std::string_view key) { std::string res = {(char)SearchSubkeyType::TAG_FIELD}; diff --git a/src/search/sql_transformer.h b/src/search/sql_transformer.h index 871949fdcce..8386a66e021 100644 --- a/src/search/sql_transformer.h +++ b/src/search/sql_transformer.h @@ -168,7 +168,7 @@ struct Transformer : ir::TreeTransformer { query_expr = std::make_unique(true); } - return Node::Create(std::move(index), std::move(query_expr), std::move(limit), std::move(sort_by), + return Node::Create(std::move(index), std::move(query_expr), std::move(limit), std::move(sort_by), std::move(select)); } else if (IsRoot(node)) { CHECK(node->children.size() == 1); diff --git a/tests/cppunit/ir_pass_test.cc b/tests/cppunit/ir_pass_test.cc index 70d39f96aca..0f0952bfea0 100644 --- a/tests/cppunit/ir_pass_test.cc +++ b/tests/cppunit/ir_pass_test.cc @@ -21,6 +21,7 @@ #include "search/ir_pass.h" #include "gtest/gtest.h" +#include "search/passes/lower_to_plan.h" #include "search/passes/manager.h" #include "search/passes/push_down_not_expr.h" #include "search/passes/simplify_and_or_expr.h" @@ -103,3 +104,16 @@ TEST(IRPassTest, Manager) { PassManager::Default(*Parse("select * from a where not (x > 1 or (y < 2 or z = 3)) and (true or x = 1)"))->Dump(), "select * from a where (and x <= 1, y >= 2, z != 3)"); } + +TEST(IRPassTest, LowerToPlan) { + LowerToPlan ltp; + + ASSERT_EQ(ltp.Transform(*Parse("select * from a"))->Dump(), "project *: (filter true: full-scan a)"); + ASSERT_EQ(ltp.Transform(*Parse("select * from a where b > 1"))->Dump(), "project *: (filter b > 1: full-scan a)"); + ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 order by d"))->Dump(), + "project a: (sort d, asc: (filter c = 1: full-scan b))"); + ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 limit 1"))->Dump(), + "project a: (limit 0, 1: (filter c = 1: full-scan b))"); + ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 order by d limit 1"))->Dump(), + "project a: (limit 0, 1: (sort d, asc: (filter c = 1: full-scan b)))"); +}