Skip to content

Commit

Permalink
Add plan lowering pass for KQIR (#2247)
Browse files Browse the repository at this point in the history
  • Loading branch information
PragmaTwice committed Apr 14, 2024
1 parent a238053 commit 91509c7
Show file tree
Hide file tree
Showing 9 changed files with 234 additions and 24 deletions.
2 changes: 2 additions & 0 deletions src/search/index_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ struct FieldInfo {

FieldInfo(std::string name, std::unique_ptr<redis::SearchFieldMetadata> &&metadata)
: name(std::move(name)), metadata(std::move(metadata)) {}

// Reports whether the metadata held by this field is (or derives from)
// redis::SearchSortableFieldMetadata, as determined by a dynamic_cast.
bool IsSortable() const {
  auto *sortable = dynamic_cast<redis::SearchSortableFieldMetadata *>(metadata.get());
  return sortable != nullptr;
}
};

struct IndexInfo {
Expand Down
17 changes: 11 additions & 6 deletions src/search/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ struct Node {

virtual std::unique_ptr<Node> Clone() const = 0;

template <typename T>
std::unique_ptr<T> CloneAs() const {
return Node::MustAs<T>(Clone());
}

virtual ~Node() = default;

template <typename T, typename U = Node, typename... Args>
Expand Down Expand Up @@ -361,14 +366,14 @@ struct IndexRef : Ref {
std::unique_ptr<Node> Clone() const override { return std::make_unique<IndexRef>(*this); }
};

struct SearchStmt : Node {
struct SearchExpr : Node {
std::unique_ptr<SelectClause> select;
std::unique_ptr<IndexRef> index;
std::unique_ptr<QueryExpr> query_expr;
std::unique_ptr<LimitClause> limit; // optional
std::unique_ptr<SortByClause> sort_by; // optional

SearchStmt(std::unique_ptr<IndexRef> &&index, std::unique_ptr<QueryExpr> &&query_expr,
SearchExpr(std::unique_ptr<IndexRef> &&index, std::unique_ptr<QueryExpr> &&query_expr,
std::unique_ptr<LimitClause> &&limit, std::unique_ptr<SortByClause> &&sort_by,
std::unique_ptr<SelectClause> &&select)
: select(std::move(select)),
Expand All @@ -386,16 +391,16 @@ struct SearchStmt : Node {
}

static inline const std::vector<std::function<Node *(Node *)>> ChildMap = {
NodeIterator::MemFn<&SearchStmt::select>, NodeIterator::MemFn<&SearchStmt::index>,
NodeIterator::MemFn<&SearchStmt::query_expr>, NodeIterator::MemFn<&SearchStmt::limit>,
NodeIterator::MemFn<&SearchStmt::sort_by>,
NodeIterator::MemFn<&SearchExpr::select>, NodeIterator::MemFn<&SearchExpr::index>,
NodeIterator::MemFn<&SearchExpr::query_expr>, NodeIterator::MemFn<&SearchExpr::limit>,
NodeIterator::MemFn<&SearchExpr::sort_by>,
};

NodeIterator ChildBegin() override { return NodeIterator(this, ChildMap.begin()); };
NodeIterator ChildEnd() override { return NodeIterator(this, ChildMap.end()); };

std::unique_ptr<Node> Clone() const override {
return std::make_unique<SearchStmt>(
return std::make_unique<SearchExpr>(
Node::MustAs<IndexRef>(index->Clone()), Node::MustAs<QueryExpr>(query_expr->Clone()),
Node::MustAs<LimitClause>(limit->Clone()), Node::MustAs<SortByClause>(sort_by->Clone()),
Node::MustAs<SelectClause>(select->Clone()));
Expand Down
72 changes: 70 additions & 2 deletions src/search/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#pragma once

#include "ir.h"
#include "search/ir_plan.h"

namespace kqir {

Expand All @@ -30,7 +31,7 @@ struct Pass {

struct Visitor : Pass {
std::unique_ptr<Node> Transform(std::unique_ptr<Node> node) override {
if (auto v = Node::As<SearchStmt>(std::move(node))) {
if (auto v = Node::As<SearchExpr>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<SelectClause>(std::move(node))) {
return Visit(std::move(v));
Expand Down Expand Up @@ -58,6 +59,26 @@ struct Visitor : Pass {
return Visit(std::move(v));
} else if (auto v = Node::As<BoolLiteral>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<FullIndexScan>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<NumericFieldScan>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<TagFieldScan>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<Filter>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<Limit>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<Merge>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<Sort>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<TopNSort>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<Projection>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<Noop>(std::move(node))) {
return Visit(std::move(v));
}

__builtin_unreachable();
Expand All @@ -73,7 +94,7 @@ struct Visitor : Pass {
return Node::MustAs<T>(Transform(std::move(n)));
}

virtual std::unique_ptr<Node> Visit(std::unique_ptr<SearchStmt> node) {
virtual std::unique_ptr<Node> Visit(std::unique_ptr<SearchExpr> node) {
node->index = VisitAs<IndexRef>(std::move(node->index));
node->select = VisitAs<SelectClause>(std::move(node->select));
node->query_expr = TransformAs<QueryExpr>(std::move(node->query_expr));
Expand Down Expand Up @@ -139,6 +160,53 @@ struct Visitor : Pass {
node->field = VisitAs<FieldRef>(std::move(node->field));
return node;
}

// Default handling for Noop: the node is handed back untouched.
virtual std::unique_ptr<Node> Visit(std::unique_ptr<Noop> node) {
  return node;
}

// Default handling for FullIndexScan: the node is handed back untouched
// (its index reference is not visited here).
virtual std::unique_ptr<Node> Visit(std::unique_ptr<FullIndexScan> node) {
  return node;
}

// Default handling for NumericFieldScan: the node is handed back untouched
// (its field reference is not visited here).
virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericFieldScan> node) {
  return node;
}

// Default handling for TagFieldScan: the node is handed back untouched
// (its field reference is not visited here).
virtual std::unique_ptr<Node> Visit(std::unique_ptr<TagFieldScan> node) {
  return node;
}

// Default handling for Filter: both children are sent through Transform
// (source operator first, then the filter expression) and the results are
// stored back into the node before it is returned.
virtual std::unique_ptr<Node> Visit(std::unique_ptr<Filter> node) {
  auto new_source = TransformAs<PlanOperator>(std::move(node->source));
  node->source = std::move(new_source);

  auto new_filter_expr = TransformAs<QueryExpr>(std::move(node->filter_expr));
  node->filter_expr = std::move(new_filter_expr);

  return node;
}

// Default handling for Limit: the input operator goes through TransformAs,
// then the limit clause goes through VisitAs; both results replace the
// original children.
virtual std::unique_ptr<Node> Visit(std::unique_ptr<Limit> node) {
  auto new_op = TransformAs<PlanOperator>(std::move(node->op));
  node->op = std::move(new_op);

  auto new_limit = VisitAs<LimitClause>(std::move(node->limit));
  node->limit = std::move(new_limit);

  return node;
}

// Default handling for Sort: the input operator goes through TransformAs,
// then the sort-by clause goes through VisitAs; both results replace the
// original children.
virtual std::unique_ptr<Node> Visit(std::unique_ptr<Sort> node) {
  auto new_op = TransformAs<PlanOperator>(std::move(node->op));
  node->op = std::move(new_op);

  auto new_order = VisitAs<SortByClause>(std::move(node->order));
  node->order = std::move(new_order);

  return node;
}

// Default handling for TopNSort: children are processed in the order
// op (TransformAs), limit (VisitAs), order (VisitAs), each result replacing
// the original child.
virtual std::unique_ptr<Node> Visit(std::unique_ptr<TopNSort> node) {
  auto new_op = TransformAs<PlanOperator>(std::move(node->op));
  node->op = std::move(new_op);

  auto new_limit = VisitAs<LimitClause>(std::move(node->limit));
  node->limit = std::move(new_limit);

  auto new_order = VisitAs<SortByClause>(std::move(node->order));
  node->order = std::move(new_order);

  return node;
}

// Default handling for Projection: the source operator goes through
// TransformAs, then the select clause goes through VisitAs; both results
// replace the original children.
virtual std::unique_ptr<Node> Visit(std::unique_ptr<Projection> node) {
  auto new_source = TransformAs<PlanOperator>(std::move(node->source));
  node->source = std::move(new_source);

  auto new_select = VisitAs<SelectClause>(std::move(node->select));
  node->select = std::move(new_select);

  return node;
}

// Default handling for Merge: every operator in `ops` is replaced, in place
// and in container order, by the result of running it through TransformAs.
virtual std::unique_ptr<Node> Visit(std::unique_ptr<Merge> node) {
  for (auto &child : node->ops) {
    auto replaced = TransformAs<PlanOperator>(std::move(child));
    child = std::move(replaced);
  }
  return node;
}
};

} // namespace kqir
91 changes: 78 additions & 13 deletions src/search/ir_plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,30 @@ namespace kqir {

struct PlanOperator : Node {};

// Plan operator that declares no data members of its own; it reports the
// name "Noop" and dumps as "noop".
struct Noop : PlanOperator {
  std::string_view Name() const override {
    return "Noop";
  }

  std::string Dump() const override {
    return "noop";
  }

  // Copy-constructs a new Noop from this one.
  std::unique_ptr<Node> Clone() const override {
    return std::make_unique<Noop>(*this);
  }
};

struct FullIndexScan : PlanOperator {
IndexInfo *index;
std::unique_ptr<IndexRef> index;

explicit FullIndexScan(IndexInfo *index) : index(index) {}
explicit FullIndexScan(std::unique_ptr<IndexRef> index) : index(std::move(index)) {}

std::string_view Name() const override { return "FullIndexScan"; };
std::string Content() const override { return index->name; };
std::string Dump() const override { return fmt::format("full-scan {}", Content()); }
std::string Dump() const override { return fmt::format("full-scan {}", index->name); }

std::unique_ptr<Node> Clone() const override { return std::make_unique<FullIndexScan>(*this); }
std::unique_ptr<Node> Clone() const override {
return std::make_unique<FullIndexScan>(Node::MustAs<IndexRef>(index->Clone()));
}
};

struct FieldScan : PlanOperator {
FieldInfo *field;
std::unique_ptr<FieldRef> field;

explicit FieldScan(FieldInfo *field) : field(field) {}
explicit FieldScan(std::unique_ptr<FieldRef> field) : field(std::move(field)) {}
};

struct Interval {
Expand All @@ -60,25 +68,29 @@ struct Interval {
struct NumericFieldScan : FieldScan {
Interval range;

NumericFieldScan(FieldInfo *field, Interval range) : FieldScan(field), range(range) {}
NumericFieldScan(std::unique_ptr<FieldRef> field, Interval range) : FieldScan(std::move(field)), range(range) {}

std::string_view Name() const override { return "NumericFieldScan"; };
std::string Content() const override { return fmt::format("{}, {}", field->name, range.ToString()); };
std::string Dump() const override { return fmt::format("numeric-scan {}", Content()); }

std::unique_ptr<Node> Clone() const override { return std::make_unique<NumericFieldScan>(*this); }
std::unique_ptr<Node> Clone() const override {
return std::make_unique<NumericFieldScan>(field->CloneAs<FieldRef>(), range);
}
};

struct TagFieldScan : FieldScan {
std::string tag;

TagFieldScan(FieldInfo *field, std::string tag) : FieldScan(field), tag(std::move(tag)) {}
TagFieldScan(std::unique_ptr<FieldRef> field, std::string tag) : FieldScan(std::move(field)), tag(std::move(tag)) {}

std::string_view Name() const override { return "TagFieldScan"; };
std::string Content() const override { return fmt::format("{}, {}", field->name, tag); };
std::string Dump() const override { return fmt::format("tag-scan {}", Content()); }

std::unique_ptr<Node> Clone() const override { return std::make_unique<TagFieldScan>(*this); }
std::unique_ptr<Node> Clone() const override {
return std::make_unique<TagFieldScan>(field->CloneAs<FieldRef>(), tag);
}
};

struct Filter : PlanOperator {
Expand All @@ -89,7 +101,7 @@ struct Filter : PlanOperator {
: source(std::move(source)), filter_expr(std::move(filter_expr)) {}

std::string_view Name() const override { return "Filter"; };
std::string Dump() const override { return fmt::format("(filter {}, {})", source->Dump(), Content()); }
std::string Dump() const override { return fmt::format("(filter {}: {})", filter_expr->Dump(), source->Dump()); }

NodeIterator ChildBegin() override { return {source.get(), filter_expr.get()}; }
NodeIterator ChildEnd() override { return {}; }
Expand Down Expand Up @@ -143,6 +155,55 @@ struct Limit : PlanOperator {
}
};

struct Sort : PlanOperator {
std::unique_ptr<PlanOperator> op;
std::unique_ptr<SortByClause> order;

Sort(std::unique_ptr<PlanOperator> &&op, std::unique_ptr<SortByClause> &&order)
: op(std::move(op)), order(std::move(order)) {}

std::string_view Name() const override { return "Sort"; };
std::string Dump() const override {
return fmt::format("(sort {}, {}: {})", order->field->Dump(), order->OrderToString(order->order), op->Dump());
}

NodeIterator ChildBegin() override { return NodeIterator{op.get(), order.get()}; }
NodeIterator ChildEnd() override { return {}; }

std::unique_ptr<Node> Clone() const override {
return std::make_unique<Sort>(Node::MustAs<PlanOperator>(op->Clone()), Node::MustAs<SortByClause>(order->Clone()));
}
};

// Operator fusion of Sort + Limit: a single plan operator carrying the input
// operator together with both the sort-by clause and the limit clause.
struct TopNSort : PlanOperator {
  std::unique_ptr<PlanOperator> op;     // input operator
  std::unique_ptr<SortByClause> order;  // sort field and direction
  std::unique_ptr<LimitClause> limit;   // offset and count

  TopNSort(std::unique_ptr<PlanOperator> &&op, std::unique_ptr<SortByClause> &&order,
           std::unique_ptr<LimitClause> &&limit)
      : op(std::move(op)), order(std::move(order)), limit(std::move(limit)) {}

  std::string_view Name() const override { return "TopNSort"; }

  // Renders as "(top-n sort <field>, <direction>, <offset>, <count>: <input operator>)".
  std::string Dump() const override {
    const auto field_str = order->field->Dump();
    const auto direction = order->OrderToString(order->order);
    return fmt::format("(top-n sort {}, {}, {}, {}: {})", field_str, direction, limit->offset, limit->count,
                       op->Dump());
  }

  // Child accessors used by the iterator, in the order: op, order, limit.
  static inline const std::vector<std::function<Node *(Node *)>> ChildMap = {
      NodeIterator::MemFn<&TopNSort::op>,
      NodeIterator::MemFn<&TopNSort::order>,
      NodeIterator::MemFn<&TopNSort::limit>,
  };

  NodeIterator ChildBegin() override { return NodeIterator(this, ChildMap.begin()); }
  NodeIterator ChildEnd() override { return NodeIterator(this, ChildMap.end()); }

  // Deep copy: all three children are cloned into a new TopNSort.
  std::unique_ptr<Node> Clone() const override {
    auto op_copy = op->CloneAs<PlanOperator>();
    auto order_copy = order->CloneAs<SortByClause>();
    auto limit_copy = limit->CloneAs<LimitClause>();
    return std::make_unique<TopNSort>(std::move(op_copy), std::move(order_copy), std::move(limit_copy));
  }
};

struct Projection : PlanOperator {
std::unique_ptr<PlanOperator> source;
std::unique_ptr<SelectClause> select;
Expand All @@ -151,7 +212,11 @@ struct Projection : PlanOperator {
: source(std::move(source)), select(std::move(select)) {}

std::string_view Name() const override { return "Projection"; };
std::string Dump() const override { return fmt::format("(project {}: {})", select, source); }
std::string Dump() const override {
auto select_str =
select->fields.empty() ? "*" : util::StringJoin(select->fields, [](const auto &v) { return v->Dump(); });
return fmt::format("project {}: {}", select_str, source->Dump());
}

NodeIterator ChildBegin() override { return {source.get(), select.get()}; }
NodeIterator ChildEnd() override { return {}; }
Expand Down
5 changes: 4 additions & 1 deletion src/search/ir_sema_checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <map>
#include <memory>

#include "fmt/core.h"
#include "index_info.h"
#include "ir.h"
#include "search_encoding.h"
Expand All @@ -38,7 +39,7 @@ struct SemaChecker {
explicit SemaChecker(const IndexMap &index_map) : index_map(index_map) {}

Status Check(Node *node) {
if (auto v = dynamic_cast<SearchStmt *>(node)) {
if (auto v = dynamic_cast<SearchExpr *>(node)) {
auto index_name = v->index->name;
if (auto iter = index_map.find(index_name); iter != index_map.end()) {
current_index = &iter->second;
Expand All @@ -56,6 +57,8 @@ struct SemaChecker {
} else if (auto v = dynamic_cast<SortByClause *>(node)) {
if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) {
return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)};
} else if (!iter->second.IsSortable()) {
return {Status::NotOK, fmt::format("field `{}` is not sortable", v->field->name)};
} else {
v->field->info = &iter->second;
}
Expand Down
51 changes: 51 additions & 0 deletions src/search/passes/lower_to_plan.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/

#pragma once

#include <memory>

#include "search/ir.h"
#include "search/ir_pass.h"
#include "search/ir_plan.h"

namespace kqir {

// Pass that lowers a SearchExpr into a tree of plan operators shaped as
//
//   Projection(Limit?(Sort?(Filter(FullIndexScan(index), query_expr))))
//
// where Sort and Limit are only emitted when the corresponding optional
// clause (`sort_by` / `limit`) is present on the SearchExpr.
struct LowerToPlan : Visitor {
  // Consumes `node` and returns the root (a Projection) of the lowered plan.
  std::unique_ptr<Node> Visit(std::unique_ptr<SearchExpr> node) override {
    // `node` is owned by this function and discarded on return, so every
    // child (including the index reference) is moved out instead of cloned.
    auto scan = std::make_unique<FullIndexScan>(std::move(node->index));
    auto filter = std::make_unique<Filter>(std::move(scan), std::move(node->query_expr));

    std::unique_ptr<PlanOperator> op = std::move(filter);

    // order is important here, since limit(sort(op)) is different from sort(limit(op))
    if (node->sort_by) {
      op = std::make_unique<Sort>(std::move(op), std::move(node->sort_by));
    }

    if (node->limit) {
      op = std::make_unique<Limit>(std::move(op), std::move(node->limit));
    }

    return std::make_unique<Projection>(std::move(op), std::move(node->select));
  }
};

} // namespace kqir
Loading

0 comments on commit 91509c7

Please sign in to comment.