diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 68b175c27..3b06ede72 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -9,3 +9,4 @@ add_subdirectory("Examples") add_subdirectory("GenJSON") add_subdirectory("Index") +add_subdirectory("Query") diff --git a/bin/Index/BuildAST.cpp b/bin/Index/BuildAST.cpp new file mode 100644 index 000000000..e5df35cb2 --- /dev/null +++ b/bin/Index/BuildAST.cpp @@ -0,0 +1,188 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. + +#include "BuildAST.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace indexer { + +static void SerializeAST(mx::Fragment fragment, ServerContext &ctx) { + auto &ast = ctx->ast; + std::unordered_map ctx_to_node_id; + + for (mx::Token tok : mx::Token::in(fragment)) { + // Skip whitespaces + switch (tok.kind()) { + case mx::TokenKind::UNKNOWN: + case mx::TokenKind::WHITESPACE: + case mx::TokenKind::COMMENT: + continue; + default: + if (tok.data().empty()) { + continue; + } + break; + } + ctx->spelling_to_token_kind.Set(tok.data(), tok.kind()); + + // Start with the token node + mx::ASTNode node{}; + node.kind = mx::SyntexNodeKind{tok.kind()}.Serialize(); + node.entity = tok.id(); + node.spelling = std::string(tok.data().data(), tok.data().size()); + node.prev = ast.GetNodeInIndex(fragment.id(), node.kind); + std::optional node_id = ast.AddNode(node); + ast.SetNodeInIndex(fragment.id(), node.kind, *node_id); + + for (auto ctx = mx::TokenContext::of(tok); ctx; ctx = ctx->parent()) { + auto it = ctx_to_node_id.find(ctx->id()); + + // Add to parent node's children if it already exists + + if (it != ctx_to_node_id.end()) { + ast.AddChild(it->second, *node_id); + node_id = std::nullopt; + break; + } + + // Otherwise we need to create a new parent node + + if (auto decl = mx::Decl::from(*ctx)) { + mx::ASTNode parent{}; + parent.kind = mx::SyntexNodeKind{decl->kind()}.Serialize(); + parent.entity = decl->id(); + parent.prev = ast.GetNodeInIndex(fragment.id(), parent.kind); + auto parent_id = ast.AddNode(parent); + // Add it to the index + ast.SetNodeInIndex(fragment.id(), parent.kind, parent_id); + ctx_to_node_id[ctx->id()] = parent_id; + ast.AddChild(parent_id, *node_id); + node_id = parent_id; + continue; + } + + if (auto stmt = mx::Stmt::from(*ctx)) { + mx::ASTNode parent{}; + parent.kind = mx::SyntexNodeKind{stmt->kind()}.Serialize(); + parent.entity = stmt->id(); + parent.prev = ast.GetNodeInIndex(fragment.id(), parent.kind); + auto parent_id = ast.AddNode(parent); + // Add it to the index + ast.SetNodeInIndex(fragment.id(), parent.kind, parent_id); + ctx_to_node_id[ctx->id()] = parent_id; + ast.AddChild(parent_id, *node_id); + node_id = parent_id; + continue; + } + } + + // If we didn't add the token to a pre-existing parent, add it to the root + + if (node_id.has_value()) { + ast.AddNodeToRoot(fragment.id(), *node_id); + } + } +} + +static void ImportGrammar(mx::Fragment fragment, ServerContext& ctx) { + auto &ast = ctx->ast; + auto &grammar = ctx->grammar; + auto nodes = ast.Root(fragment.id()); + + // Make a production rule for every node and its children. + while (!nodes.empty()) { + auto node_id = nodes.back(); + nodes.pop_back(); + + auto node = ast.GetNode(node_id); + auto node_kind = mx::SyntexNodeKind::Deserialize(node.kind); + + if (!node_kind.IsToken()) { + // This is an internal or root node. E.g. given the following: + // + // A + // / | \ + // B C D + // + // We want to make a rule of the form `B C D A`, i.e. if you match `B C D` + // then you have matched an `A`. This "backward" syntax enables us to prefix + // scan for left corners (`B` in this case) and find all rules starting with + // `B`. + + auto child_vector = ast.GetChildren(node_id); + assert(child_vector.size() >= 1); + + // FIXME: do something else with long grammar rules. PHP has + // some generated initializer lists with 100s of elements that + // blows up our stack when serializing a grammar. + if (child_vector.size() > 100) { + continue; + } + + // Add the child nodes to the work list. + nodes.insert(nodes.end(), child_vector.begin(), child_vector.end()); + + // Walk the trie + std::uint64_t leaves_id = 0; + for (auto child_id : child_vector) { + auto child = ast.GetNode(child_id); + leaves_id = grammar.GetChild(leaves_id, child.kind); + } + // Save pointer to rule head + auto head_id = grammar.GetChild(leaves_id, node.kind); + + // Avoid creating cyclic CFGs + bool allow_production = true; + + if (child_vector.size() == 1) { + std::vector queue = { node.kind }; + while (!queue.empty()) { + auto nt = queue.back(); + queue.pop_back(); + + // Check if we can reach our own left corner + auto child = ast.GetNode(child_vector[0]); + if (nt == child.kind) { + allow_production = false; + break; + } + + // Queue result of matching trivial productions + for(auto [left, rest] : grammar.GetChildLeaves(0, nt)) { + auto node = grammar.GetNode(rest); + if(node.is_production) { + queue.push_back(left); + } + } + } + } + + // Mark the head as a production if appropriate + grammar.UpdateNode(head_id, {allow_production}); + } + } +} + +void BuildAST(mx::Index index, ServerContext &context) { + for(auto file : mx::File::in(index)) { + for(auto fragment : mx::Fragment::in(file)) { + sqlite::Transaction tx(context.db); + std::scoped_lock lock(tx); + SerializeAST(fragment, context); + ImportGrammar(fragment, context); + } + } +} +} // namespace indexer \ No newline at end of file diff --git a/bin/Index/BuildAST.h b/bin/Index/BuildAST.h new file mode 100644 index 000000000..2cdabcf9e --- /dev/null +++ b/bin/Index/BuildAST.h @@ -0,0 +1,17 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. + +#pragma once + +#include +#include +#include "Context.h" + +namespace indexer { + +void BuildAST(mx::Index index, ServerContext& context); + +} // namespace indexer diff --git a/bin/Index/CMakeLists.txt b/bin/Index/CMakeLists.txt index 46be997d7..d156e3adc 100644 --- a/bin/Index/CMakeLists.txt +++ b/bin/Index/CMakeLists.txt @@ -9,6 +9,8 @@ set(exe_name "mx-index") add_executable("${exe_name}" + "BuildAST.cpp" + "BuildAST.h" "BuildPendingFragment.cpp" "Compress.cpp" "Compress.h" @@ -57,6 +59,7 @@ target_link_libraries("${exe_name}" PRIVATE ${MX_BEGIN_FORCE_LOAD_GROUP} "mx-util" + "mx-api" "concurrentqueue" ${MX_BEGIN_FORCE_LOAD_LIB} pasta::pasta ${MX_END_FORCE_LOAD_LIB} ${MX_END_FORCE_LOAD_GROUP} diff --git a/bin/Index/Main.cpp b/bin/Index/Main.cpp index 5f6b23180..89a0f1d53 100644 --- a/bin/Index/Main.cpp +++ b/bin/Index/Main.cpp @@ -24,6 +24,7 @@ #include "Context.h" #include "Parser.h" #include "Importer.h" +#include "BuildAST.h" // Should we show a help message? DECLARE_bool(help); @@ -158,5 +159,8 @@ extern "C" int main(int argc, char *argv[]) { executor.Start(); executor.Wait(); + auto index = mx::Index(mx::EntityProvider::from_database(FLAGS_db)); + indexer::BuildAST(index, ic->server_context[0]); + return EXIT_SUCCESS; } diff --git a/bin/Query/CMakeLists.txt b/bin/Query/CMakeLists.txt new file mode 100644 index 000000000..353e157c6 --- /dev/null +++ b/bin/Query/CMakeLists.txt @@ -0,0 +1,35 @@ +# +# Copyright (c) 2022-present, Trail of Bits, Inc. +# All rights reserved. +# +# This source code is licensed in accordance with the terms specified in +# the LICENSE file found in the root directory of this source tree. +# + +add_executable("syntex-query" "SyntexQuery.cpp") + +target_link_libraries("syntex-query" + PRIVATE + gflags + glog::glog + "mx-api" +) + +install( + TARGETS + "syntex-query" + EXPORT + "${PROJECT_NAME}Targets" + RUNTIME + DESTINATION + "${CMAKE_INSTALL_BINDIR}" +) + +add_executable("predicate-example" "PredicateExample.cpp") + +target_link_libraries("predicate-example" + PRIVATE + gflags + glog::glog + "mx-api" +) diff --git a/bin/Query/PredicateExample.cpp b/bin/Query/PredicateExample.cpp new file mode 100644 index 000000000..24662df7d --- /dev/null +++ b/bin/Query/PredicateExample.cpp @@ -0,0 +1,185 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. + +// +// Example utility that uses syntex predicates to locate float to integer casts +// + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#define ANSI_RED "\x1b[1;31m" +#define ANSI_RESET "\x1b[1;0m" + +DECLARE_bool(help); +DEFINE_string(db, "", "Path to Multiplier database."); +DEFINE_bool(long_is_32_bits, false, "Is 'long' a 32-bit type?"); + +static std::optional IntegralTypeWidth(const mx::Type &type) { + auto builtin_type = mx::BuiltinType::from(type); + if (!builtin_type) { + return std::nullopt; + } + + switch (builtin_type->builtin_kind()) { + case mx::BuiltinTypeKind::S_CHAR: + case mx::BuiltinTypeKind::CHARACTER_U: + case mx::BuiltinTypeKind::CHARACTER_S: + case mx::BuiltinTypeKind::BOOLEAN: + case mx::BuiltinTypeKind::CHAR8: + case mx::BuiltinTypeKind::U_CHAR: + return 8u; + case mx::BuiltinTypeKind::W_CHAR_S: + case mx::BuiltinTypeKind::W_CHAR_U: + case mx::BuiltinTypeKind::CHAR16: + case mx::BuiltinTypeKind::SHORT: + case mx::BuiltinTypeKind::U_SHORT: + return 16u; + case mx::BuiltinTypeKind::CHAR32: + case mx::BuiltinTypeKind::INT: + case mx::BuiltinTypeKind::U_INT: + return 32u; + case mx::BuiltinTypeKind::U_LONG: + case mx::BuiltinTypeKind::LONG: + return FLAGS_long_is_32_bits ? 32u : 64u; + case mx::BuiltinTypeKind::U_LONG_LONG: + case mx::BuiltinTypeKind::LONG_LONG: + return 64u; + case mx::BuiltinTypeKind::U_INT128: + case mx::BuiltinTypeKind::INT128: + return 128u; + default: + return std::nullopt; + } +} + +static std::optional IntegralTypeWidth(const mx::ValueDecl &decl) { + return IntegralTypeWidth(decl.type()); +} + +static void HighlightMatch(std::ostream &os, mx::Index index, mx::SyntexMatch m) { + auto stmt = std::get(index.entity(m.MetavarMatch(0).Entity())); + auto ref = mx::DeclRefExpr::from(stmt); + if (!ref) { + return; + } + + auto var = mx::VarDecl::from(ref->referenced_declaration()); + if (!var) { + return; + } + + auto type_size = IntegralTypeWidth(var.value()); + if (!type_size) { + return; + } + + auto lit = mx::IntegerLiteral::from( + std::get(index.entity(m.MetavarMatch(1).Entity()))); + if (!lit) { + return; + } + + auto lit_data = lit->token().data(); + + // Strip off the suffixes of things like `0ull`. + while (lit_data.ends_with('u') || lit_data.ends_with('U') || + lit_data.ends_with('l') || lit_data.ends_with('L') || + lit_data.ends_with(' ')) { + lit_data = lit_data.substr(0, lit_data.size() - 1u); + } + + int64_t lit_val{-1}; + + std::stringstream ss; + ss << lit_data; + ss >> lit_val; + + if (0 >= lit_val) { + return; + } + + if (type_size.value() > static_cast(lit_val)) { + return; + } + + auto entity = index.entity(m.Entity()); + auto fragment = index.fragment_containing(m.Entity()); + auto builtin_type = mx::BuiltinType::from(var->type()); + mx::TokenRange tok_range; + if(std::holds_alternative(entity)) { + tok_range = std::get(entity); + } else if(std::holds_alternative(entity)) { + tok_range = std::get(entity).tokens(); + } else if(std::holds_alternative(entity)) { + tok_range = std::get(entity).tokens(); + } + os << "File ID: " << mx::File::containing(*fragment).id() << '\n' + << "Fragment ID: " << fragment->id().Pack() << '\n' + << "Literal value: " << lit_val << '\n' + << "Type size: " << type_size.value() << '\n' + << "Type kind: " << mx::EnumeratorName(builtin_type->builtin_kind()) + << "\nExpression:"; + + for (mx::Token tok : tok_range) { + os << ' ' << tok.data(); + } + + os << "\n\n"; +} + +extern "C" int main(int argc, char *argv[]) { + std::stringstream ss; + ss + << "Usage: " << argv[0] + << " [--db DATABASE]\n"; + + google::SetUsageMessage(ss.str()); + google::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + if (FLAGS_help) { + std::cerr << google::ProgramUsage() << std::endl; + return EXIT_FAILURE; + } + + if (FLAGS_db.empty()) { + std::cerr << "Need to specify a database using --db" << std::endl; + return EXIT_FAILURE; + } + + // Setup index and grammar + + mx::Index index = mx::EntityProvider::in_memory_cache( + mx::EntityProvider::from_database(FLAGS_db)); + + // Setup query + + auto res = index.parse_syntex_query("$var:DECL_REF_EXPR << $num:INTEGER_LITERAL"); + if (!res.IsValid()) { + return EXIT_FAILURE; + } + + // Match fragments + res.ForEachMatch([&](auto match) { + HighlightMatch(std::cout, index, std::move(match)); + return true; + }); + + return EXIT_SUCCESS; +} diff --git a/bin/Query/SyntexQuery.cpp b/bin/Query/SyntexQuery.cpp new file mode 100644 index 000000000..ed4eae29e --- /dev/null +++ b/bin/Query/SyntexQuery.cpp @@ -0,0 +1,103 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include + +#define ANSI_RED "\x1b[1;31m" +#define ANSI_RESET "\x1b[1;0m" + +DECLARE_bool(help); +DEFINE_string(db, "", "Path to Multiplier database."); +DEFINE_string(query, "", "Use argument value as query"); +DEFINE_uint64(threads, 0, "Use this number of threads"); +DEFINE_bool(suppress_output, false, "Don't print matches to stdout"); + +static void PrintMatch(mx::Index index, const mx::SyntexMatch &match) +{ + if (FLAGS_suppress_output) { + return; + } + + auto entity = index.entity(match.Entity()); + auto fragment = *index.fragment_containing(match.Entity()); + mx::TokenRange tok_range; + if(std::holds_alternative(entity)) { + tok_range = std::get(entity); + } else if(std::holds_alternative(entity)) { + tok_range = std::get(entity).tokens(); + } else if(std::holds_alternative(entity)) { + tok_range = std::get(entity).tokens(); + } + + // Print matching fragment ID + std::cout << "Match in " << fragment.id() << ":\n"; + for (auto token : fragment.parsed_tokens()) { + if (token.id() == tok_range.front().id()) { + // Switch to ANSI red for the first matching token + std::cout << ANSI_RED; + } + + std::cout << token.data() << " "; + + if (token.id() == tok_range.back().id()) { + // Reset color after last matching token + std::cout << ANSI_RESET; + } + } + + std::cout << "\n"; + + for (auto &metavar : match.MetavarMatches()) { + std::cout << "Matching metavar " << metavar.Name() << "\n"; + } +} + +extern "C" int main(int argc, char *argv[]) { + std::stringstream ss; + ss + << "Usage: " << argv[0] + << " [--db DATABASE]\n"; + + google::SetUsageMessage(ss.str()); + google::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + if (FLAGS_help) { + std::cerr << google::ProgramUsage() << std::endl; + return EXIT_FAILURE; + } + + if (FLAGS_db.empty()) { + std::cerr << "Need to specify a database using --db" << std::endl; + return EXIT_FAILURE; + } + + // Setup index and grammar + + mx::Index index = mx::EntityProvider::from_database(FLAGS_db); + + // Parse query + + auto res = index.parse_syntex_query(FLAGS_query); + + if (!res.IsValid()) { + std::cerr << "Query `" << FLAGS_query << "` has no valid parses\n"; + return EXIT_FAILURE; + } + + res.ForEachMatch([&](auto match) { + PrintMatch(index, std::move(match)); + return true; + }); + + return EXIT_SUCCESS; +} diff --git a/include/multiplier/Index.h b/include/multiplier/Index.h index b046a56b8..ecc79b82f 100644 --- a/include/multiplier/Index.h +++ b/include/multiplier/Index.h @@ -43,6 +43,14 @@ class WeggliQueryMatch; class WeggliQueryResultIterator; class WeggliQueryResult; class WeggliQueryResultImpl; +struct ASTNode; + +class SyntexNodeKind; +class SyntexGrammarNode; +class SyntexQuery; +class SyntexQueryImpl; + +using SyntexGrammarLeaves = std::unordered_map; using DeclUse = Use; using StmtUse = Use; @@ -102,6 +110,8 @@ class EntityProvider { friend class UseIteratorImpl; friend class WeggliQueryResultImpl; friend class WeggliQueryResultIterator; + friend class SyntexQuery; + friend class SyntexQueryImpl; protected: @@ -172,6 +182,17 @@ class EntityProvider { virtual void FindSymbol(const Ptr &, std::string name, mx::DeclCategory category, std::vector &ids_out) = 0; + + virtual std::optional + TokenKindOf(std::string_view spelling) = 0; + + virtual void LoadGrammarRoot(SyntexGrammarLeaves &root) = 0; + + virtual std::vector GetFragmentsInAST(void) = 0; + virtual ASTNode GetASTNode(std::uint64_t id) = 0; + virtual std::vector GetASTNodeChildren(std::uint64_t id) = 0; + virtual std::vector GetASTNodesInFragment(RawEntityId frag) = 0; + virtual std::optional GetASTNodeWithKind(RawEntityId frag, unsigned short kind) = 0; }; using VariantEntity = std::variant prev; + unsigned short kind; + RawEntityId entity; + std::optional spelling; +}; + +class PersistentAST final { + struct Impl; + std::unique_ptr impl; + + public: + PersistentAST(sqlite::Connection &db); + + std::vector Root(RawEntityId fragment); + + std::uint64_t AddNode(const ASTNode& node); + + void AddNodeToRoot(RawEntityId fragment, std::uint64_t node_id); + + ASTNode GetNode(std::uint64_t node_id); + + std::optional GetNodeInIndex( + RawEntityId fragment, + unsigned short kind); + + void SetNodeInIndex( + RawEntityId fragment, + unsigned short kind, + std::uint64_t node_id); + + std::vector GetFragments(); + + std::vector GetChildren(std::uint64_t parent); + + void AddChild(std::uint64_t parent, std::uint64_t child); +}; + +struct GrammarNode { + bool is_production; +}; + +class PersistentGrammar final { + struct Impl; + std::unique_ptr impl; + + public: + PersistentGrammar(sqlite::Connection &db); + + std::vector> GetChildren(std::uint64_t parent); + + std::uint64_t GetChild(std::uint64_t parent, unsigned short kind); + + std::vector> GetChildLeaves(std::uint64_t parent, unsigned short kind); + + void UpdateNode(std::uint64_t id, const GrammarNode &node); + + GrammarNode GetNode(std::uint64_t id); +}; + class IndexStorage final { sqlite::Connection &db; @@ -160,6 +221,13 @@ class IndexStorage final { mx::PersistentSet entity_id_reference; + mx::PersistentMap + spelling_to_token_kind; + + PersistentAST ast; + + PersistentGrammar grammar; + // SQLite database. Used for things like symbol searches. SymbolDatabase database; diff --git a/include/multiplier/NodeKind.h b/include/multiplier/NodeKind.h new file mode 100644 index 000000000..0df85d852 --- /dev/null +++ b/include/multiplier/NodeKind.h @@ -0,0 +1,131 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. + +#pragma once + +#include +#include +#include + +namespace mx { + +// +// SyntexNodeKind: Core class of Syntex, represents the following things: +// - An entry in a grammar rule +// - Kind of node in a multiplier AST +// - Kind of node in a query AST +// + +class SyntexNodeKind { +private: + unsigned short val; + + SyntexNodeKind(unsigned short val_) : val(val_) {} + +public: + static SyntexNodeKind Any() { + return SyntexNodeKind(UpperLimit()); + } + + SyntexNodeKind(mx::DeclKind kind) + : val(static_cast(kind)) {} + + SyntexNodeKind(mx::StmtKind kind) + : val(static_cast(kind) + + mx::NumEnumerators(mx::DeclKind{})) {} + + SyntexNodeKind(mx::TokenKind kind) + : val(static_cast(kind) + + mx::NumEnumerators(mx::DeclKind{}) + + mx::NumEnumerators(mx::StmtKind{})) {} + + template + auto Visit(T visitor) const { + if (val < mx::NumEnumerators(mx::DeclKind{})) { + return visitor(static_cast(val)); + } else if (val < mx::NumEnumerators(mx::DeclKind{}) + + mx::NumEnumerators(mx::StmtKind{})) { + return visitor(static_cast(val + - mx::NumEnumerators(mx::DeclKind{}))); + } else if (val < UpperLimit()) { + return visitor(static_cast(val + - mx::NumEnumerators(mx::DeclKind{}) + - mx::NumEnumerators(mx::StmtKind{}))); + } else { + return visitor(); + } + } + + bool IsToken() const { + return val >= mx::NumEnumerators(mx::DeclKind{}) + + mx::NumEnumerators(mx::StmtKind{}); + } + + mx::TokenKind AsToken() const { + assert(IsToken()); + return static_cast(val + - mx::NumEnumerators(mx::DeclKind{}) + - mx::NumEnumerators(mx::StmtKind{})); + } + + bool operator==(const SyntexNodeKind &other) const { + return val == other.val; + } + + static SyntexNodeKind Deserialize(unsigned short val) { + return val; + } + + unsigned short Serialize() const { + return val; + } + + static constexpr unsigned short UpperLimit() { + return mx::NumEnumerators(mx::DeclKind{}) + + mx::NumEnumerators(mx::StmtKind{}) + + mx::NumEnumerators(mx::TokenKind{}); + } +}; + +// +// This template (and deduction guide) allows for the easy generation of nice +// looking visitors from a set of lambdas. See the operator<< implementation +// below as an example usecase. +// + +template +struct Visitor : F ... { + using F::operator() ...; +}; + +template Visitor(F...) -> Visitor; + +// +// Pretty print a NodeKind to an output stream +// + +inline std::ostream& operator<<(std::ostream &os, const SyntexNodeKind &kind) { + kind.Visit(Visitor { + [&] (mx::DeclKind kind) { os << "DeclKind::" << EnumeratorName(kind); }, + [&] (mx::StmtKind kind) { os << "StmtKind::" << EnumeratorName(kind); }, + [&] (mx::TokenKind kind) { os << "TokenKind::" << EnumeratorName(kind); }, + [&] () { os << "NodeKind::Any"; }, + }); + return os; +} + +} // namespace mx + +namespace std { + +template<> +struct hash { + size_t operator()(const mx::SyntexNodeKind &kind) const { + return kind.Serialize(); + } +}; + +} // namespace std \ No newline at end of file diff --git a/include/multiplier/PersistentMap.h b/include/multiplier/PersistentMap.h index e50c87293..79fc6ee7f 100644 --- a/include/multiplier/PersistentMap.h +++ b/include/multiplier/PersistentMap.h @@ -34,6 +34,7 @@ static constexpr const char* table_names[] = { "'mx::MangledNameToEntityId'", "'mx::EntityIdUseToFragmentId'", "'mx::EntityIdReference'", + "'mx::syntex::Tokens'", }; template @@ -106,7 +107,8 @@ class PersistentSet { db.Execute(ss.str()); ss = {}; - ss << "INSERT OR IGNORE INTO " << table_names[kId] << '(' << table_desc.str() << ") VALUES("; + ss << "INSERT OR IGNORE INTO " << table_names[kId] + << '(' << table_desc.str() << ") VALUES("; for(size_t i = 0; i < sizeof...(Keys); ++i) { ss << "?" << (i + 1); if(i != sizeof...(Keys) - 1) { @@ -139,7 +141,8 @@ class PersistentSet { for(size_t i = 0; i < sizeof...(Keys); ++i) { ss = {}; - ss << "SELECT " << table_desc.str() << " FROM " << table_names[kId] << " WHERE "; + ss << "SELECT " << table_desc.str() + << " FROM " << table_names[kId] << " WHERE "; for(size_t j = 0; j <= i; j++) { ss << "key" << j << " = ?" << (j + 1); if(j != i) { @@ -193,26 +196,37 @@ template class PersistentMap { private: sqlite::Connection &db; - std::shared_ptr set_stmt, get_stmt, get_or_set_stmt; + std::shared_ptr set_stmt; + std::shared_ptr get_stmt; + std::shared_ptr get_or_set_stmt; + std::shared_ptr enum_stmt; public: PersistentMap(sqlite::Connection &db) : db(db) { std::stringstream ss; - ss << "CREATE TABLE IF NOT EXISTS " << table_names[kId] << "(key, value, PRIMARY KEY(key))"; + ss << "CREATE TABLE IF NOT EXISTS " << table_names[kId] + << "(key, value, PRIMARY KEY(key))"; db.Execute(ss.str()); ss = {}; - ss << "INSERT OR REPLACE INTO " << table_names[kId] << "(key, value) VALUES (?1, ?2)"; + ss << "INSERT OR REPLACE INTO " << table_names[kId] + << "(key, value) VALUES (?1, ?2)"; set_stmt = db.Prepare(ss.str()); ss = {}; - ss << "SELECT key, value FROM " << table_names[kId] << " WHERE key = ?1"; + ss << "SELECT key, value FROM " << table_names[kId] + << " WHERE key = ?1"; get_stmt = db.Prepare(ss.str()); ss = {}; ss << "INSERT INTO " << table_names[kId] - << "(key, value) VALUES(?1, ?2) ON CONFLICT DO UPDATE SET value=value RETURNING key, value"; + << "(key, value) VALUES(?1, ?2) " + << "ON CONFLICT DO UPDATE SET value=value RETURNING key, value"; get_or_set_stmt = db.Prepare(ss.str()); + + ss = {}; + ss << "SELECT key, value FROM " << table_names[kId]; + enum_stmt = db.Prepare(ss.str()); } V GetOrSet(K key, V val) const { diff --git a/include/multiplier/Query.h b/include/multiplier/Query.h index 01fad218c..7e9514819 100644 --- a/include/multiplier/Query.h +++ b/include/multiplier/Query.h @@ -37,6 +37,10 @@ class WeggliQueryMatch; class WeggliQueryResultIterator; class WeggliQueryResult; class WeggliQueryResultImpl; +class SyntexQuery; +class SyntexQueryImpl; +class SyntexMatch; +class SyntexMetavarMatch; // The range of tokens of a match. class WeggliQueryMatch : public TokenRange { @@ -325,4 +329,25 @@ class RegexQueryResult { } }; +class SyntexQuery { + private: + std::shared_ptr impl; + SyntexQuery(void) = delete; + + public: + explicit SyntexQuery(std::shared_ptr ep, std::string_view query); + + bool IsValid() const; + + bool AddMetavarPredicate(const std::string_view &name, + std::function predicate); + + void ForEachMatch(mx::RawEntityId frag_id, + std::function pred) const; + void ForEachMatch(std::function pred) const; + + std::vector Find(mx::RawEntityId frag_id) const; + std::vector Find(void) const; +}; + } // namespace mx diff --git a/include/multiplier/SQLiteStore.h b/include/multiplier/SQLiteStore.h index 27655a1e7..92af9525a 100644 --- a/include/multiplier/SQLiteStore.h +++ b/include/multiplier/SQLiteStore.h @@ -34,6 +34,39 @@ class Error : public std::runtime_error { }; class QueryResult { + private: + void column_dispatcher(int& idx, std::string& arg) { + arg = getText(idx); + idx++; + } + + void column_dispatcher(int& idx, std::string_view& arg) { + arg = getBlob(idx); + idx++; + } + + void column_dispatcher(int& idx, std::nullopt_t& arg) { + idx++; + } + + template || std::is_enum_v>> + void column_dispatcher(int& idx, T& arg) { + arg = static_cast(getInt64(idx)); + idx++; + } + + template + void column_dispatcher(int& idx, std::optional& arg) { + if(isNull(idx)) { + arg = {}; + idx++; + return; + } + T value; + column_dispatcher(idx, value); + arg = value; + } + public: ~QueryResult() = default; @@ -50,23 +83,7 @@ class QueryResult { } int idx = 0; - auto column_dispatcher = [this, &idx] (auto &&arg) { - using arg_t = std::decay_t; - if constexpr (std::is_integral_v) { - arg = static_cast(getInt64(idx)); - } else if (std::is_same_v) { - arg = getText(idx); - } else if (std::is_same_v) { - arg = getBlob(idx); - } else if constexpr (std::is_same_v) { - ; - } else { - throw Error("Can't read column data; Type not supported!"); - } - idx++; - }; - - (column_dispatcher(std::forward(args)), ...); + (column_dispatcher(idx, std::forward(args)), ...); } private: @@ -81,6 +98,7 @@ class QueryResult { std::string getText(int32_t idx); std::string_view getBlob(int32_t idx); + bool isNull(int32_t idx); std::shared_ptr stmt; }; @@ -145,6 +163,20 @@ class Statement : public std::enable_shared_from_this { void bind(const size_t i, const std::string_view &value); + template + void bind(const size_t i, const std::optional &value) { + if(value.has_value()) { + bind(i, value.value()); + } else { + bind(i, nullptr); + } + } + + template>> + void bind(const size_t i, const T &value) { + bind(i, static_cast(value)); + } + void reset(); template diff --git a/include/multiplier/Syntex.h b/include/multiplier/Syntex.h new file mode 100644 index 000000000..a846786cc --- /dev/null +++ b/include/multiplier/Syntex.h @@ -0,0 +1,80 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. + +#pragma once + +#include +#include +#include +#include +#include "Entities/Attr.h" +#include "Entities/Decl.h" +#include "Entities/Designator.h" +#include "Entities/Stmt.h" +#include "Entities/Type.h" +#include "File.h" +#include "Index.h" +#include "Token.h" + +namespace mx { + +class SyntexQuery; +class SyntexQueryImpl; + +// +// Chunk of a fragment (potentially) matching a metavariable +// + +class SyntexMetavarMatch { +private: + std::string name; + mx::EntityId entity; + +public: + SyntexMetavarMatch(const std::string &name_, mx::EntityId entity_) + : name(std::move(name_)), + entity(std::move(entity_)) {} + + const std::string &Name(void) const { + return name; + } + + mx::EntityId Entity(void) const { + return entity; + } +}; + +// +// Chunk of a ParsedQuery that matched against a part of a fragment +// + +class SyntexMatch { +private: + friend class SyntexQuery; + + mx::EntityId entity; + + std::vector metavars; + +public: + SyntexMatch(mx::EntityId entity_, std::vector matevars_) + : entity(std::move(entity_)), + metavars(std::move(matevars_)) {} + + const mx::EntityId &Entity(void) const { + return entity; + } + + const std::vector &MetavarMatches(void) const { + return metavars; + } + + const SyntexMetavarMatch &MetavarMatch(size_t i) const { + return metavars[i]; + } +}; + +} // namespace mx diff --git a/lib/API/CMakeLists.txt b/lib/API/CMakeLists.txt index b0dfd7566..80f821325 100644 --- a/lib/API/CMakeLists.txt +++ b/lib/API/CMakeLists.txt @@ -47,12 +47,15 @@ add_library("mx-api" "Fragment.cpp" "FragmentImpl.cpp" "Fragment.h" + "Grammar.h" "Index.cpp" "InvalidEntityProvider.cpp" "InvalidEntityProvider.h" "PackedFileImpl.cpp" "PackedFragmentImpl.cpp" "PackedReaderState.cpp" + "Query.cpp" + "Query.h" "Re2.cpp" "Re2.h" "SQLiteEntityProvider.cpp" diff --git a/lib/API/CachingEntityProvider.cpp b/lib/API/CachingEntityProvider.cpp index 47865f110..ceb11af1c 100644 --- a/lib/API/CachingEntityProvider.cpp +++ b/lib/API/CachingEntityProvider.cpp @@ -4,8 +4,12 @@ // This source code is licensed in accordance with the terms specified in // the LICENSE file found in the root directory of this source tree. +#include +#include "Grammar.h" #include "CachingEntityProvider.h" +#include + #include #include #include @@ -40,6 +44,13 @@ void CachingEntityProvider::ClearCacheLocked(unsigned new_version_number) { references.clear(); has_file_list = false; version_number = new_version_number; + spelling_to_token_kind.clear(); + grammar_root.clear(); + fragments_in_ast.clear(); + node_contents.clear(); + node_children.clear(); + fragment_nodes.clear(); + node_index.clear(); next->VersionNumberChanged(new_version_number); } @@ -260,4 +271,79 @@ EntityProvider::Ptr EntityProvider::in_memory_cache( return ret; } +std::optional +CachingEntityProvider::TokenKindOf(std::string_view spelling) { + std::string str{spelling.data(), spelling.size()}; + auto it = spelling_to_token_kind.find(str); + if(it == spelling_to_token_kind.end()) { + auto kind = next->TokenKindOf(spelling); + if(kind) { + spelling_to_token_kind[str] = *kind; + } + return kind; + } + return it->second; +} + +void CachingEntityProvider::LoadGrammarRoot(SyntexGrammarLeaves& root) { + if(grammar_root.empty()) { + next->LoadGrammarRoot(grammar_root); + } + root = grammar_root; +} + +std::vector CachingEntityProvider::GetFragmentsInAST(void) { + if(fragments_in_ast.empty()) { + fragments_in_ast = next->GetFragmentsInAST(); + } + return fragments_in_ast; +} + +ASTNode CachingEntityProvider::GetASTNode(std::uint64_t id) { + auto it = node_contents.find(id); + if(it == node_contents.end()) { + node_contents[id] = next->GetASTNode(id); + } + return node_contents[id]; +} + +std::vector CachingEntityProvider::GetASTNodeChildren(std::uint64_t id) { + if(node_children.find(id) == node_children.end()) { + for(auto child : next->GetASTNodeChildren(id)) { + node_children.insert({id, child}); + } + } + std::vector res; + for(auto [it, end] = node_children.equal_range(id); it != end; ++it) { + res.push_back(it->second); + } + return res; +} + +std::vector CachingEntityProvider::GetASTNodesInFragment(RawEntityId frag) { + if(fragment_nodes.find(frag) == fragment_nodes.end()) { + for(auto child : next->GetASTNodesInFragment(frag)) { + fragment_nodes.insert({frag, child}); + } + } + std::vector res; + for(auto [it, end] = fragment_nodes.equal_range(frag); it != end; ++it) { + res.push_back(it->second); + } + return res; +} + +std::optional CachingEntityProvider::GetASTNodeWithKind(RawEntityId frag, unsigned short kind) { + auto it = node_index.find({frag, kind}); + if(it == node_index.end()) { + auto value = next->GetASTNodeWithKind(frag, kind); + if(value.has_value()) { + node_index[{frag, kind}] = value.value(); + return value; + } + return {}; + } + return it->second; +} + } // namespace mx diff --git a/lib/API/CachingEntityProvider.h b/lib/API/CachingEntityProvider.h index 59439eb2a..194263750 100644 --- a/lib/API/CachingEntityProvider.h +++ b/lib/API/CachingEntityProvider.h @@ -14,6 +14,23 @@ #include #include +template +inline void hash_combine(size_t &h, const T& v) +{ + std::hash hasher; + h ^= hasher(v) + 0x9e3779b9 + (h << 6) + (h >> 2); +} + +template<> +struct std::hash> { + size_t operator()(const std::pair &self) const { + size_t hash = 0; + hash_combine(hash, self.first); + hash_combine(hash, self.second); + return hash; + } +}; + namespace mx { class CachingEntityProvider final : public EntityProvider { @@ -50,6 +67,15 @@ class CachingEntityProvider final : public EntityProvider { std::unordered_map>> references; + std::unordered_map spelling_to_token_kind; + SyntexGrammarLeaves grammar_root; + + std::vector fragments_in_ast; + std::unordered_map node_contents; + std::unordered_multimap node_children; + std::unordered_multimap fragment_nodes; + std::unordered_map, std::uint64_t> node_index; + void ClearCacheLocked(unsigned new_version_number); inline CachingEntityProvider(Ptr next_) @@ -98,6 +124,16 @@ class CachingEntityProvider final : public EntityProvider { mx::DeclCategory category, std::vector &ids_out) final; + std::optional + TokenKindOf(std::string_view spelling) final; + + void LoadGrammarRoot(SyntexGrammarLeaves &root) final; + + std::vector GetFragmentsInAST(void) final; + ASTNode GetASTNode(std::uint64_t id) final; + std::vector GetASTNodeChildren(std::uint64_t id) final; + std::vector GetASTNodesInFragment(RawEntityId frag) final; + std::optional GetASTNodeWithKind(RawEntityId frag, unsigned short kind) final; }; } // namespace mx diff --git a/lib/API/Grammar.h b/lib/API/Grammar.h new file mode 100644 index 000000000..a49b6b0d0 --- /dev/null +++ b/lib/API/Grammar.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. + +#pragma once + +#include +#include + +namespace mx { + +struct SyntexGrammarNode; + +// +// One set of grammar leaves +// +using SyntexGrammarLeaves = std::unordered_map; + +// +// Node in the grammar tree +// + +struct SyntexGrammarNode { + // Does this node correspond to the head of a production + bool is_production; + // Further leaves + SyntexGrammarLeaves leaves; +}; + +} // namespace mx \ No newline at end of file diff --git a/lib/API/Index.cpp b/lib/API/Index.cpp index a2160ca73..463942942 100644 --- a/lib/API/Index.cpp +++ b/lib/API/Index.cpp @@ -6,6 +6,7 @@ #include "File.h" #include "Fragment.h" +#include "Query.h" #include #include #include @@ -329,4 +330,8 @@ NamedDeclList Index::query_entities( return decls; } +SyntexQuery Index::parse_syntex_query(std::string_view query) { + return SyntexQuery(impl, query); +} + } // namespace mx diff --git a/lib/API/InvalidEntityProvider.cpp b/lib/API/InvalidEntityProvider.cpp index 6764432b8..185ab7f04 100644 --- a/lib/API/InvalidEntityProvider.cpp +++ b/lib/API/InvalidEntityProvider.cpp @@ -6,6 +6,10 @@ #include "InvalidEntityProvider.h" +#include +#include "Grammar.h" +#include + namespace mx { InvalidEntityProvider::~InvalidEntityProvider(void) noexcept {} @@ -76,6 +80,33 @@ void InvalidEntityProvider::FindSymbol( ids_out.clear(); } +std::optional +InvalidEntityProvider::TokenKindOf(std::string_view spelling) { + return {}; +} + +void InvalidEntityProvider::LoadGrammarRoot(SyntexGrammarLeaves &) {} + +std::vector InvalidEntityProvider::GetFragmentsInAST(void) { + return {}; +} + +ASTNode InvalidEntityProvider::GetASTNode(std::uint64_t id) { + return {}; +} + +std::vector InvalidEntityProvider::GetASTNodeChildren(std::uint64_t id) { + return {}; +} + +std::vector InvalidEntityProvider::GetASTNodesInFragment(RawEntityId frag) { + return {}; +} + +std::optional InvalidEntityProvider::GetASTNodeWithKind(RawEntityId frag, unsigned short kind) { + return {}; +} + Index::Index(void) : impl(std::make_shared()) {} diff --git a/lib/API/InvalidEntityProvider.h b/lib/API/InvalidEntityProvider.h index 37a3087d8..76cbfd329 100644 --- a/lib/API/InvalidEntityProvider.h +++ b/lib/API/InvalidEntityProvider.h @@ -57,6 +57,17 @@ class InvalidEntityProvider final : public EntityProvider { void FindSymbol(const Ptr &, std::string name, mx::DeclCategory category, std::vector &ids_out) final; + + std::optional + TokenKindOf(std::string_view spelling) final; + + void LoadGrammarRoot(SyntexGrammarLeaves &root) final; + + std::vector GetFragmentsInAST(void) final; + ASTNode GetASTNode(std::uint64_t id) final; + std::vector GetASTNodeChildren(std::uint64_t id) final; + std::vector GetASTNodesInFragment(RawEntityId frag) final; + std::optional GetASTNodeWithKind(RawEntityId frag, unsigned short kind) final; }; } // namespace mx diff --git a/lib/API/Query.cpp b/lib/API/Query.cpp new file mode 100644 index 000000000..76bcb2203 --- /dev/null +++ b/lib/API/Query.cpp @@ -0,0 +1,877 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. + +#include "Query.h" +#include +#include +#include + +#include +#include +#include +#include + +namespace mx { + +template +void Tokenize(TokenCallback token_callback, MetavarCallback metavar_callback, + VarargCallback vararg_callback, std::string_view input, size_t index) { + size_t end = index; + + auto Look = [&] (size_t i) -> int { + if (end + i < input.size()) + return input[end + i]; + else + return -1; + }; + + auto Eat = [&] (size_t cnt) { + end += cnt; + }; + + auto Get = [&] () { + int ch = Look(0); + if (ch != -1) + Eat(1); + return ch; + }; + + auto Match = [&] (char ch) { + if (Look(0) == ch) { + Eat(1); + return true; + } + return false; + }; + + auto MatchSpace = [&] () { + switch (Look(0)) { + case ' ': + case '\f': + case '\n': + case '\r': + case '\t': + case '\v': + Eat(1); + return true; + default: + return false; + } + }; + + auto MatchIdent = [&] () { + switch (Look(0)) { + case '_': + case 'a' ... 'z': + case 'A' ... 'Z': + case '0' ... '9': + Eat(1); + return true; + default: + return false; + } + }; + + auto MatchDecimal = [&] () { + switch (Look(0)) { + case '0' ... '9': + Eat(1); + return true; + default: + return false; + } + }; + + auto MatchHex = [&] () { + switch (Look(0)) { + case '0' ... '9': + case 'a' ... 'f': + case 'A' ... 'F': + Eat(1); + return true; + default: + return false; + } + }; + + auto MatchOct = [&] () { + switch (Look(0)) { + case '0' ... '7': + Eat(1); + return true; + default: + return false; + } + }; + + auto MatchDecimalExponent = [&] () { + if (Match('e') || Match('E')) { + Match('+') || Match('-'); + return true; + } + return false; + }; + + auto MatchHexExponent = [&] () { + if (Match('p') || Match('P')) { + Match('+') || Match('-'); + return true; + } + return false; + }; + + auto MatchIntegerSuffix = [&] () { + if (Match('l') || Match('L')) { + if (Match('l') || Match('L')) { + if (Match('u') || Match('U')) { // llu + // unsigned long long + } else { // ll + // long long + } + } else if (Match('u') || Match('U')) { // lu + // unsigned long + } else { // l + // long + } + } else if (Match('u') || Match('U')) { + if (Match('l') || Match('L')) { + if (Match('l') || Match('L')) { // ull + // unsigned long long + } else { // ul + // unsigned long + } + } else { // u + // unsigned int + } + } else { // + // int + } + }; + + auto MatchFloatingSuffix = [&] () { + if (Match('f') || Match('F')) { // f + // float + } else if (Match('l') || Match('L')) { // l + // long double + } else { // + // double + } + }; + + // Skip all whitespaces that might preceed the token + while (MatchSpace()) + ; + + // The spelling starts after skipping preceding whitespaces + size_t begin = end; + + // Add a token to the output + auto AddToken = [&] (size_t len, mx::TokenKind kind) { + size_t next = begin + len; + for (;;) + switch (Look(next - end)) { + case ' ': + case '\f': + case '\n': + case '\r': + case '\t': + case '\v': + ++next; + break; + default: + token_callback(kind, input.substr(begin, len), next); + return; + } + }; + + auto AddMetavar = [&] (std::string_view name, SyntexNodeKind filter) { + size_t next = end; + for (;;) + switch (Look(next - end)) { + case ' ': + case '\f': + case '\n': + case '\r': + case '\t': + case '\v': + ++next; + break; + default: + metavar_callback(name, filter, next); + return; + } + }; + + auto AddVararg = [&] () { + size_t next = end; + for (;;) + switch (Look(next - end)) { + case ' ': + case '\f': + case '\n': + case '\r': + case '\t': + case '\v': + ++next; + break; + default: + vararg_callback(next); + return; + } + }; + + switch (Get()) { + // End of input + case -1: + break; + + // + // For identifiers and constants, the longest match is always consumed + // + + // Metavariable + case '$': + { + // Check for variable argument + if (Look(0) == '.' && Look(1) == '.' && Look(2) == '.') { + Eat(3); + AddVararg(); + break; + } + + // Skip over the name + while (MatchIdent()) + ; + + auto name = input.substr(begin + 1, end - begin - 1); + auto filter = SyntexNodeKind::Any(); + + // Skip over filter if present + if (Match(':')) { + size_t filter_begin = end; + while (MatchIdent()) + ; + + auto filter_str = input.substr(filter_begin, end - filter_begin); + + // Try to parse it as a DeclKind + // FIXME: this should probably be done with some kind of LUT over this + // slow mess + for (int i = 0; i < NumEnumerators(mx::DeclKind::TYPE); ++i) { + auto kind = static_cast(i); + if (EnumeratorName(kind) == filter_str) { + filter = SyntexNodeKind(kind); + goto done_filters; + } + } + for (int i = 0; i < NumEnumerators(mx::StmtKind::NULL_STMT); ++i) { + auto kind = static_cast(i); + if (EnumeratorName(kind) == filter_str) { + filter = SyntexNodeKind(kind); + goto done_filters; + } + } + for (int i = 0; i < NumEnumerators(mx::TokenKind::UNKNOWN); ++i) { + auto kind = static_cast(i); + if (EnumeratorName(kind) == filter_str) { + filter = SyntexNodeKind(kind); + goto done_filters; + } + } + + assert("FIXME: return proper error for invalid filter" && false); + +done_filters:; + } + + AddMetavar(name, filter); + break; + } + + // Identifiers + case '_': + case 'a' ... 'z': + case 'A' ... 'Z': + while (MatchIdent()) + ; + + AddToken(end - begin, mx::TokenKind::IDENTIFIER); + break; + + // Numeric constants + case '0': + if (Match('.')) { + while (MatchDecimal()) + ; + + if (MatchDecimalExponent()) { + while (MatchDecimal()) + ; + } + + MatchFloatingSuffix(); + } else if (Match('x') || Match('X')) { + while (MatchHex()) + ; + + if (Match('.')) { + while (MatchHex()) + ; + + if (MatchHexExponent()) { + while(MatchHex()) + ; + } + MatchFloatingSuffix(); + } else if (MatchHexExponent()) { + while (MatchHex()) + ; + + MatchFloatingSuffix(); + } else { + MatchIntegerSuffix(); + } + } else { + while (MatchOct()) + ; + + MatchIntegerSuffix(); + } + + AddToken(end - begin, mx::TokenKind::NUMERIC_CONSTANT); + break; + + case '1' ... '9': + while (MatchDecimal()) + ; + + if (Match('.')) { +FractionalConstant: + while (MatchDecimal()) + ; + + if (MatchDecimalExponent()) { + while (MatchDecimal()) + ; + } + + MatchFloatingSuffix(); + } else if (MatchDecimalExponent()) { + while (MatchDecimal()) + ; + MatchFloatingSuffix(); + } else { + MatchIntegerSuffix(); + } + + AddToken(end - begin, mx::TokenKind::NUMERIC_CONSTANT); + break; + + // Character constants + case '\'': + for (;;) { + auto ch = Get(); + if (ch == '\\') + Get(); + else if (ch == -1 || ch == '\'') + break; + } + + AddToken(end - begin, mx::TokenKind::CHARACTER_CONSTANT); + break; + + // String literals + case '"': + for (;;) { + auto ch = Get(); + if (ch == '\\') + Get(); + else if (ch == -1 || ch == '"') + break; + } + AddToken(end - begin, mx::TokenKind::STRING_LITERAL); + break; + + // + // For punctuators only the first character is consumed, and all possible + // matches at the current position are added + // + + case '[': + AddToken(1, mx::TokenKind::L_SQUARE); + break; + case ']': + AddToken(1, mx::TokenKind::R_SQUARE); + break; + case '(': + AddToken(1, mx::TokenKind::L_PARENTHESIS); + break; + case ')': + AddToken(1, mx::TokenKind::R_PARENTHESIS); + break; + case '{': + AddToken(1, mx::TokenKind::L_BRACE_TOKEN); + break; + case '}': + AddToken(1, mx::TokenKind::R_BRACE_TOKEN); + break; + case '~': + AddToken(1, mx::TokenKind::TILDE); + break; + case '?': + AddToken(1, mx::TokenKind::QUESTION); + break; + case ':': + AddToken(1, mx::TokenKind::COLON); + break; + case ';': + AddToken(1, mx::TokenKind::SEMI); + break; + case ',': + AddToken(1, mx::TokenKind::COMMA); + break; + case '.': + if (MatchDecimal()) { + goto FractionalConstant; + } else if (Look(0) == '.' && Look(1) == '.') { + AddToken(3, mx::TokenKind::ELLIPSIS); + } else { + AddToken(1, mx::TokenKind::PERIOD); + } + break; + case '-': + AddToken(1, mx::TokenKind::MINUS); + if (Look(0) == '>') + AddToken(2, mx::TokenKind::ARROW); + else if (Look(0) == '-') + AddToken(2, mx::TokenKind::MINUS_MINUS); + else if (Look(0) == '=') + AddToken(2, mx::TokenKind::MINUS_EQUAL); + break; + case '+': + AddToken(1, mx::TokenKind::PLUS); + if (Look(0) == '+') + AddToken(2, mx::TokenKind::PLUS_PLUS); + if (Look(0) == '=') + AddToken(2, mx::TokenKind::PLUS_EQUAL); + break; + case '&': + AddToken(1, mx::TokenKind::AMP); + if (Look(0) == '&') + AddToken(2, mx::TokenKind::AMP_AMP); + else if (Look(0) == '=') + AddToken(2, mx::TokenKind::AMP_EQUAL); + break; + case '*': + AddToken(1, mx::TokenKind::STAR); + if (Look(0) == '=') + AddToken(2, mx::TokenKind::STAR_EQUAL); + break; + case '!': + AddToken(1, mx::TokenKind::EXCLAIM); + if (Look(0) == '=') + AddToken(2, mx::TokenKind::EXCLAIM_EQUAL); + break; + case '/': + AddToken(1, mx::TokenKind::SLASH); + if (Look(0) == '=') + AddToken(2, mx::TokenKind::SLASH_EQUAL); + break; + case '%': + AddToken(1, mx::TokenKind::PERCENT); + if (Look(0) == '=') + AddToken(2, mx::TokenKind::PERCENT_EQUAL); + break; + case '<': + AddToken(1, mx::TokenKind::LESS); + if (Look(0) == '<') { + AddToken(2, mx::TokenKind::LESS_LESS); + if (Look(1) == '=') + AddToken(3, mx::TokenKind::LESS_LESS_EQUAL); + } else if (Look(0) == '=') { + AddToken(1, mx::TokenKind::LESS_EQUAL); + } + break; + case '>': + AddToken(1, mx::TokenKind::GREATER); + if (Look(0) == '>') { + AddToken(2, mx::TokenKind::GREATER_GREATER); + if (Look(1) == '=') + AddToken(3, mx::TokenKind::GREATER_GREATER_EQUAL); + } else if (Look(0) == '=') { + AddToken(2, mx::TokenKind::GREATER_EQUAL); + } + break; + case '=': + AddToken(1, mx::TokenKind::EQUAL); + if (Look(0) == '=') + AddToken(2, mx::TokenKind::EQUAL_EQUAL); + break; + case '^': + AddToken(1, mx::TokenKind::CARET); + if (Look(0) == '=') + AddToken(2, mx::TokenKind::CARET_EQUAL); + break; + case '|': + AddToken(1, mx::TokenKind::PIPE); + if (Look(0) == '|') + AddToken(2, mx::TokenKind::PIPE_PIPE); + else if (Look(0) == '=') + AddToken(2, mx::TokenKind::PIPE_EQUAL); + break; + case '#': + AddToken(1, mx::TokenKind::HASH); + if (Look(0) == '#') + AddToken(2, mx::TokenKind::HASH_HASH); + break; + default: + AddToken(1, mx::TokenKind::UNKNOWN); + } +} + +SyntexQueryImpl::SyntexQueryImpl(std::shared_ptr ep, std::string_view input) + : m_ep(ep), m_input(input) { + ep->LoadGrammarRoot(grammar_root); +} + +void SyntexQueryImpl::MatchGlob(TableEntry &result, + const std::unordered_set &follow, + Item &item, + size_t next) { + + for (auto &[left, rest] : *item.m_leaves) { + // If: + // a) we reach the end of a production and have an empty follow set + // b) or the next non-terminal is contained in the follow set + // we should continue parsing normally. + if ((rest.is_production && follow.empty()) || follow.contains(left)) { + MatchRule(result, item, next); + } + + // If the next entry is a terminator, we don't need to glob further + // NOTE: for most usecases of $... this makes sense and improves performance + // but if some weird case doesn't match it might be necessary to remove this + if (left == mx::TokenKind::R_PARENTHESIS || + left == mx::TokenKind::R_BRACE_TOKEN) { + continue; + } + + // Otherwise the rest of the grammar rule is a candiate for more globbing + if (rest.leaves.size() > 0) { + const SyntexGrammarLeaves *old_leaves = item.m_leaves; + item.m_leaves = &rest.leaves; + item.m_children.emplace_back(SyntexNodeKind::Any(), next, Glob::YES); + MatchGlob(result, follow, item, next); + item.m_leaves = old_leaves; + item.m_children.pop_back(); + } + } +} + +void SyntexQueryImpl::MatchRule(TableEntry &result, Item &item, size_t next) { + // Iterate shifts + for (auto &[key, _] : ParsesAtIndex(next)) { + SyntexNodeKind kind = key.first; + size_t next = key.second; + item.IterateShifts(kind, next, Glob::NO, [&] (Item &item) { + MatchRule(result, item, next); + }); + } + + // Iterate glob shifts + if (auto it = m_globs.find(next); it != m_globs.end()) { + // Compute set of node kinds that can follow $... + std::unordered_set follow; + for (auto &[key, _] : ParsesAtIndex(it->second)) { + follow.insert(key.first); + } + MatchGlob(result, follow, item, it->second); + } + + // Iterate reductions + item.IterateReductions([&] (SyntexNodeKind kind, const auto &children) { + result[{kind, next}].emplace(children); + MatchPrefix(result, kind, next); + }); +} + +void SyntexQueryImpl::MatchPrefix(TableEntry &result, SyntexNodeKind kind, size_t next) { + Item(&grammar_root).IterateShifts(kind, next, Glob::NO, [&] (Item &item) { + MatchRule(result, item, next); + }); +} + +const SyntexQueryImpl::TableEntry &SyntexQueryImpl::ParsesAtIndex(size_t index) { + // Lookup memoized parses at this index + auto it = m_parses.find(index); + if (it != m_parses.end()) { + return it->second; + } + + // And only do computation if the lookup found nothing + auto &result = m_parses[index]; + + auto TokenCallback = [&] (mx::TokenKind lex_kind, std::string_view spelling, size_t next) { + if (auto grm_kind = m_ep->TokenKindOf(spelling)) { + result[{*grm_kind, next}].emplace(spelling); + MatchPrefix(result, *grm_kind, next); + } else { + std::cerr << "Warning: token `" << spelling << "` not present in grammar\n"; + } + }; + + auto MetavarCallback = [&] (std::string_view name, SyntexNodeKind filter, size_t next) { + if (name == "") { + result[{filter, next}].emplace(nullptr); + } else { + auto [it, added] = m_metavars.emplace(name, Metavar(name, {})); + if (!added) { + std::cerr << "Error: duplicate metavariable name `" << name << "`\n"; + abort(); + } + result[{filter, next}].emplace(&it->second); + } + MatchPrefix(result, filter, next); + }; + + auto VarargCallback = [&] (size_t next) { + m_globs[index] = next; + }; + + Tokenize(TokenCallback, MetavarCallback, VarargCallback, m_input, index); + + return result; +} + +std::pair> SyntexQueryImpl::MatchMarker( + const TableEntry &entry, const SyntexParseMarker &marker, std::uint64_t node_id) { + + std::vector metavar_matches; + auto node = m_ep->GetASTNode(node_id); + auto kind = SyntexNodeKind::Deserialize(node.kind); + auto children = m_ep->GetASTNodeChildren(node_id); + + switch (marker.m_kind) { + case SyntexParseMarker::METAVAR: + if (marker.m_metavar) { + SyntexMetavarMatch mv_match( + {marker.m_metavar->m_name.data(), marker.m_metavar->m_name.size()}, + node.entity); + if (auto &predicate = marker.m_metavar->m_predicate) { + if (!(*predicate)(mv_match)) { + return {false, {}}; + } + } + metavar_matches.push_back(std::move(mv_match)); + } + return {true, metavar_matches}; + case SyntexParseMarker::TERMINAL: + return {kind.IsToken() && node.spelling == marker.m_spelling, {}}; + case SyntexParseMarker::NONTERMINAL: + if (kind.IsToken() || + children.size() != marker.m_children.size()) { + return {false, {}}; + } + + auto child_entry = &entry; + auto child_it = marker.m_children.begin(); + + for (std::uint64_t child_node_id : children) { + auto &[kind, next, glob] = *child_it; + auto child_node = m_ep->GetASTNode(child_node_id); + auto child_node_kind = SyntexNodeKind::Deserialize(child_node.kind); + + if (kind != SyntexNodeKind::Any() && kind != child_node_kind) { + return {false, {}}; + } + + if (glob == Glob::NO) { + auto markers = child_entry->find({ kind, next }); + if (markers == child_entry->end()) { + return {false, metavar_matches}; + } + for (auto &marker : markers->second) { + auto [child_ok, child_metavar_matches] = + MatchMarker(*child_entry, marker, child_node_id); + if (child_ok) { + metavar_matches.insert( + metavar_matches.end(), + child_metavar_matches.begin(), + child_metavar_matches.end()); + goto ok; + } + } + return {false, {}}; + } + ok: + child_entry = &ParsesAtIndex(std::get<1>(*child_it)); + ++child_it; + } + return {true, metavar_matches}; + } +} + +void SyntexQueryImpl::DebugParseTable(std::ostream &os) { + // Make sure the DP table was actually filled in + ParsesAtIndex(0); + + // Find all possible indices then sort them + std::vector indices; + for (auto &[index, _] : m_parses) { + indices.push_back(index); + } + std::sort(indices.begin(), indices.end()); + + // Then print all the parses at every index in the table + for (size_t index : indices) { + os << "Parses at " << index << ":\n"; + for (auto &[key, markers] : m_parses.at(index)) { + for (auto &marker : markers) { + // Print head + std::stringstream ss; + ss << " (" << key.first << ", " << key.second << ")"; + + std::cout << std::left << std::setw(40) << std::setfill(' ') + << ss.str() << " <- "; + + // Print body + switch (marker.m_kind) { + case SyntexParseMarker::METAVAR: + std::cout << "$" << (marker.m_metavar ? marker.m_metavar->m_name : ""); + break; + case SyntexParseMarker::TERMINAL: + std::cout << "`" << marker.m_spelling << "`"; + break; + case SyntexParseMarker::NONTERMINAL: + for (auto &[kind, next, glob] : marker.m_children) { + if (glob == Glob::YES) { + std::cout << "(" << kind << ", " << next << ", ..." << ") "; + } else { + std::cout << "(" << kind << ", " << next << ") "; + } + } + break; + } + + std::cout << "\n"; + } + } + } +} + +SyntexQuery::SyntexQuery(std::shared_ptr ep, std::string_view query) + : impl(std::make_shared(ep, query)) {} + +bool SyntexQuery::IsValid(void) const { + for (auto &[key, markers] : impl->ParsesAtIndex(0)) { + if (key.second == impl->m_input.size()) { + return true; + } + } + return false; +} + +bool SyntexQuery::AddMetavarPredicate( + const std::string_view &name, + std::function predicate) { + + // Find metavariable name + auto it = impl->m_metavars.find(name); + if (it == impl->m_metavars.end()) { + return false; + } + + // Overwrite predicate + if (it->second.m_predicate) { + it->second.m_predicate = + [old_pred = std::move(it->second.m_predicate.value()), + new_pred = std::move(predicate)] (const SyntexMetavarMatch &mvm) -> bool { + return old_pred(mvm) && new_pred(mvm); + }; + + } else { + it->second.m_predicate = std::move(predicate); + } + + return true; +} + +void SyntexQuery::ForEachMatch(std::function pred) const { + bool done = false; + auto real_pred = [sub_pred = std::move(pred), &done] (SyntexMatch m) -> bool { + if (sub_pred(std::move(m))) { + return true; + } else { + done = true; + return false; + } + }; + for(auto frag_id : impl->m_ep->GetFragmentsInAST()) { + ForEachMatch(frag_id, real_pred); + if (done) { + break; + } + } +} + +std::vector SyntexQuery::Find(mx::RawEntityId frag) const { + std::vector ret; + ForEachMatch(frag, [&ret] (SyntexMatch m) -> bool { + ret.emplace_back(std::move(m)); + return true; + }); + return ret; +} + +std::vector SyntexQuery::Find(void) const { + std::vector ret; + for (auto frag_id : impl->m_ep->GetFragmentsInAST()) { + ForEachMatch(frag_id, [&ret] (SyntexMatch m) -> bool { + ret.emplace_back(std::move(m)); + return true; + }); + } + return ret; +} + +void SyntexQuery::ForEachMatch(mx::RawEntityId frag_id, + std::function pred) const { + auto frag = impl->m_ep->FragmentFor(impl->m_ep, frag_id); + + // Find matching AST node + auto &entry = impl->ParsesAtIndex(0); + for (auto &[key, markers] : entry) { + if (key.second != impl->m_input.size()) { + continue; + } + if (key.first == SyntexNodeKind::Any()) { + for (auto ast_node_id : impl->m_ep->GetASTNodesInFragment(frag_id)) { + auto ast_node = impl->m_ep->GetASTNode(ast_node_id); + for (auto &marker : markers) { + auto [ok, metavar_matches] = impl->MatchMarker( + entry, marker, ast_node_id); + if (ok && !pred(SyntexMatch(ast_node.entity, metavar_matches))) { + return; + } + } + } + } else { + auto ast_node_id = impl->m_ep->GetASTNodeWithKind(frag_id, key.first.Serialize()); + while (ast_node_id.has_value()) { + auto ast_node = impl->m_ep->GetASTNode(*ast_node_id); + for (auto &marker : markers) { + auto [ok, metavar_matches] = impl->MatchMarker(entry, marker, *ast_node_id); + if (ok && !pred(SyntexMatch(ast_node.entity, metavar_matches))) { + return; + } + } + ast_node_id = ast_node.prev; + } + } + } +} + +} // namespace mx diff --git a/lib/API/Query.h b/lib/API/Query.h new file mode 100644 index 000000000..50cb1d7d5 --- /dev/null +++ b/lib/API/Query.h @@ -0,0 +1,250 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. + +#pragma once + +#include +#include "Grammar.h" +#include +#include +#include +#include +#include + +template +inline void hash_combine(size_t &h, const T& v) +{ + std::hash hasher; + h ^= hasher(v) + 0x9e3779b9 + (h << 6) + (h >> 2); +} + +template<> +struct std::hash> { + size_t operator()(const std::pair &self) const { + size_t hash = 0; + hash_combine(hash, self.first); + hash_combine(hash, self.second); + return hash; + } +}; + +namespace mx { + +struct Metavar { + std::string_view m_name; + std::optional> m_predicate; + + explicit Metavar(std::string_view name, + std::optional> predicate) + : m_name(name), m_predicate(std::move(predicate)) {} +}; + +enum class Glob { + NO, + YES +}; + +struct SyntexParseMarker { + + // Node category + enum { + METAVAR, + TERMINAL, + NONTERMINAL, + } m_kind; + + // Associated data + union { + Metavar *m_metavar; + std::string_view m_spelling; + std::vector> m_children; + }; + + explicit SyntexParseMarker(Metavar *metavar) + : m_kind(METAVAR), m_metavar(metavar) {} + + explicit SyntexParseMarker(std::string_view spelling) + : m_kind(TERMINAL), m_spelling(spelling) {} + + explicit SyntexParseMarker(const std::vector> &children) + : m_kind(NONTERMINAL), m_children(children) {} + + SyntexParseMarker(SyntexParseMarker &&other) + : m_kind(other.m_kind) + { + switch (m_kind) { + case METAVAR: + m_metavar = other.m_metavar; + break; + case TERMINAL: + new (&m_spelling) std::string_view(other.m_spelling); + break; + case NONTERMINAL: + new (&m_children) std::vector>(std::move(other.m_children)); + break; + } + } + + ~SyntexParseMarker() { + switch (m_kind) { + case METAVAR: + break; + case TERMINAL: + m_spelling.std::string_view::~string_view(); + break; + case NONTERMINAL: + m_children.std::vector>::~vector(); + break; + } + } + + bool operator==(const SyntexParseMarker &other) const { + if (m_kind != other.m_kind) { + return false; + } + + switch (m_kind) { + case METAVAR: + // NOTE: it's impossible to have two metavariables + // at the same input location, thus this is never called + assert(false); + abort(); + case TERMINAL: + return m_spelling == other.m_spelling; + case NONTERMINAL: + return m_children == other.m_children; + } + } +}; + +} // namespace mx + +template<> +struct std::hash { + size_t operator()(const mx::SyntexParseMarker &self) const { + size_t hash = 0; + hash_combine(hash, self.m_kind); + switch (self.m_kind) { + case mx::SyntexParseMarker::METAVAR: + break; + case mx::SyntexParseMarker::TERMINAL: + hash_combine(hash, self.m_spelling); + break; + case mx::SyntexParseMarker::NONTERMINAL: + for (auto &[kind, next, glob] : self.m_children) { + hash_combine(hash, kind); + hash_combine(hash, next); + hash_combine(hash, glob); + } + break; + } + return hash; + } +}; + +namespace mx { + +// +// Parser state (e.g. a pointer into the grammar trie) +// + +struct Item { + const SyntexGrammarLeaves *m_leaves; + std::vector> m_children; + + explicit Item(const SyntexGrammarLeaves *leaves) + : m_leaves(leaves) + {} + + template + void IterateShifts(SyntexNodeKind kind, size_t next, Glob glob, F cb) { + if (kind == SyntexNodeKind::Any()) { + const SyntexGrammarLeaves *old_leaves = m_leaves; + m_children.emplace_back(kind, next, glob); + + for (auto &[kind, rest] : *m_leaves) { + if (rest.leaves.empty()) { + continue; + } + + m_leaves = &rest.leaves; + cb(*this); + } + + m_leaves = old_leaves; + m_children.pop_back(); + } else { + auto it = m_leaves->find(kind); + if (it == m_leaves->end() || it->second.leaves.empty()) { + return; + } + + // Morph ourselves into the shifted state + const SyntexGrammarLeaves *old_leaves = m_leaves; + m_leaves = &it->second.leaves; + m_children.emplace_back(kind, next, glob); + + // Fire callback with morphed item + cb(*this); + + // Restore item + m_leaves = old_leaves; + m_children.pop_back(); + } + } + + template + void IterateReductions(F cb) const { + for (auto &[left, rest] : *m_leaves) { + if (rest.is_production) { + cb(left, m_children); + } + } + } +}; + +// +// Wrapper around parsing functions +// + +struct SyntexQueryImpl { + std::shared_ptr m_ep; + + // Input string + std::string_view m_input; + + SyntexGrammarLeaves grammar_root; + + // Main DP parse table + using TableEntry = std::unordered_map, + std::unordered_set>; + + std::unordered_map m_parses; + + // Metavariables + std::unordered_map m_metavars; + + // Globs + std::unordered_map m_globs; + + void MatchGlob(TableEntry &result, const std::unordered_set &follow, + Item &item, size_t next); + + void MatchRule(TableEntry &result, Item &item, size_t next); + + void MatchPrefix(TableEntry &result, SyntexNodeKind kind, size_t next); + + const TableEntry &ParsesAtIndex(size_t index); + + explicit SyntexQueryImpl(std::shared_ptr ep, std::string_view input); + + void DebugParseTable(std::ostream &os); + + std::pair> MatchMarker( + const TableEntry &entry, const SyntexParseMarker &marker, std::uint64_t node_id); +}; + +} // namespace mx \ No newline at end of file diff --git a/lib/API/SQLiteEntityProvider.cpp b/lib/API/SQLiteEntityProvider.cpp index 88db18ee4..4983c3652 100644 --- a/lib/API/SQLiteEntityProvider.cpp +++ b/lib/API/SQLiteEntityProvider.cpp @@ -5,9 +5,11 @@ // the LICENSE file found in the root directory of this source tree. #include "SQLiteEntityProvider.h" +#include #include "API.h" #include "Compress.h" #include "Re2.h" +#include "Grammar.h" #include #include #include @@ -334,6 +336,64 @@ void SQLiteEntityProvider::FindSymbol(const Ptr &, std::string symbol, }); } +std::optional +SQLiteEntityProvider::TokenKindOf(std::string_view spelling) { + auto &storage = d->GetStorage(); + return storage.spelling_to_token_kind.TryGet(spelling); +} + +void SQLiteEntityProvider::LoadGrammarRoot(SyntexGrammarLeaves &root) { + auto &storage = d->GetStorage(); + auto &grammar = storage.grammar; + std::vector> to_load; + + for(auto [kind, id] : grammar.GetChildren(0)) { + auto data = grammar.GetNode(id); + auto &node = root[SyntexNodeKind::Deserialize(kind)]; + node.is_production = data.is_production; + to_load.emplace_back(id, &node); + } + + while(!to_load.empty()) { + auto pair = to_load.back(); + to_load.pop_back(); + auto id = std::get<0>(pair); + auto &node = *std::get<1>(pair); + + for(auto [kind, child_id] : grammar.GetChildren(id)) { + auto data = grammar.GetNode(child_id); + auto &child_node = node.leaves[SyntexNodeKind::Deserialize(kind)]; + child_node.is_production = data.is_production; + to_load.emplace_back(child_id, &child_node); + } + } +} + +std::vector SQLiteEntityProvider::GetFragmentsInAST(void) { + auto &storage = d->GetStorage(); + return storage.ast.GetFragments(); +} + +ASTNode SQLiteEntityProvider::GetASTNode(std::uint64_t id) { + auto &storage = d->GetStorage(); + return storage.ast.GetNode(id); +} + +std::vector SQLiteEntityProvider::GetASTNodeChildren(std::uint64_t id) { + auto &storage = d->GetStorage(); + return storage.ast.GetChildren(id); +} + +std::vector SQLiteEntityProvider::GetASTNodesInFragment(RawEntityId frag) { + auto &storage = d->GetStorage(); + return storage.ast.Root(frag); +} + +std::optional SQLiteEntityProvider::GetASTNodeWithKind(RawEntityId frag, unsigned short kind) { + auto &storage = d->GetStorage(); + return storage.ast.GetNodeInIndex(frag, kind); +} + EntityProvider::Ptr EntityProvider::from_database(std::filesystem::path path) { return std::make_shared(path); } diff --git a/lib/API/SQLiteEntityProvider.h b/lib/API/SQLiteEntityProvider.h index 89cd5fbb2..3ae1e5346 100644 --- a/lib/API/SQLiteEntityProvider.h +++ b/lib/API/SQLiteEntityProvider.h @@ -64,6 +64,17 @@ class SQLiteEntityProvider final : public EntityProvider { void FindSymbol(const Ptr &, std::string name, mx::DeclCategory category, std::vector &ids_out) final; + + std::optional + TokenKindOf(std::string_view spelling) final; + + void LoadGrammarRoot(SyntexGrammarLeaves &root) final; + + std::vector GetFragmentsInAST(void) final; + ASTNode GetASTNode(std::uint64_t id) final; + std::vector GetASTNodeChildren(std::uint64_t id) final; + std::vector GetASTNodesInFragment(RawEntityId frag) final; + std::optional GetASTNodeWithKind(RawEntityId frag, unsigned short kind) final; }; } // namespace mx diff --git a/lib/Common/IndexStorage.cpp b/lib/Common/IndexStorage.cpp index 20e1cdce3..9db662369 100644 --- a/lib/Common/IndexStorage.cpp +++ b/lib/Common/IndexStorage.cpp @@ -12,6 +12,248 @@ #include namespace mx { +struct PersistentAST::Impl { + sqlite::Connection &db; + std::shared_ptr get_root_stmt; + std::shared_ptr create_node_stmt; + std::shared_ptr add_root_stmt; + std::shared_ptr get_node_stmt; + std::shared_ptr get_index_stmt; + std::shared_ptr set_index_stmt; + std::shared_ptr get_fragments_stmt; + std::shared_ptr get_children_stmt; + std::shared_ptr add_child_stmt; + + Impl(sqlite::Connection &db); +}; + +PersistentAST::Impl::Impl(sqlite::Connection &db) : db(db) { + db.Execute( + "CREATE TABLE IF NOT EXISTS " + "'mx::syntex::ASTNode'(prev, kind, entity, spelling)"); + db.Execute( + "CREATE TABLE IF NOT EXISTS " + "'mx::syntex::ASTChildren'(parent, child, PRIMARY KEY(parent, child))"); + db.Execute( + "CREATE TABLE IF NOT EXISTS " + "'mx::syntex::ASTIndex'(fragment, kind, node, PRIMARY KEY(fragment, kind))"); + db.Execute( + "CREATE TABLE IF NOT EXISTS 'mx::syntex::ASTRoot'(fragment, node)"); + + get_root_stmt = db.Prepare( + "SELECT node FROM 'mx::syntex::ASTRoot' WHERE fragment = ?1"); + create_node_stmt = db.Prepare( + "INSERT INTO 'mx::syntex::ASTNode'(prev, kind, entity, spelling) " + "VALUES (?1, ?2, ?3, ?4) RETURNING rowid"); + add_root_stmt = db.Prepare( + "INSERT INTO 'mx::syntex::ASTRoot'(fragment, node) VALUES (?1, ?2)"); + get_node_stmt = db.Prepare( + "SELECT prev, kind, entity, spelling " + "FROM 'mx::syntex::ASTNode' WHERE rowid = ?1"); + get_index_stmt = db.Prepare( + "SELECT node FROM 'mx::syntex::ASTIndex' " + "WHERE fragment = ?1 AND kind = ?2"); + set_index_stmt = db.Prepare( + "INSERT OR REPLACE INTO 'mx::syntex::ASTIndex'(fragment, kind, node) " + "VALUES(?1, ?2, ?3)" + ); + get_fragments_stmt = db.Prepare( + "SELECT DISTINCT fragment FROM 'mx::syntex::ASTRoot'" + ); + get_children_stmt = db.Prepare( + "SELECT child FROM 'mx::syntex::ASTChildren' WHERE parent = ?1" + ); + add_child_stmt = db.Prepare( + "INSERT INTO 'mx::syntex::ASTChildren'(parent, child) VALUES (?1, ?2)" + ); +} + +PersistentAST::PersistentAST(sqlite::Connection &db) + : impl(std::make_unique(db)) {} + +std::vector PersistentAST::Root(RawEntityId fragment) { + std::vector results; + impl->get_root_stmt->BindValues(fragment); + while(impl->get_root_stmt->ExecuteStep()) { + impl->get_root_stmt->GetResult().Columns(results.emplace_back()); + } + impl->get_root_stmt->Reset(); + return results; +} + +std::uint64_t PersistentAST::AddNode(const ASTNode& node) { + impl->create_node_stmt->BindValues(node.prev, node.kind, + node.entity, node.spelling); + impl->create_node_stmt->ExecuteStep(); + std::uint64_t rowid; + impl->create_node_stmt->GetResult().Columns(rowid); + impl->create_node_stmt->Reset(); + return rowid; +} + +void PersistentAST::AddNodeToRoot(RawEntityId fragment, std::uint64_t node_id) { + impl->add_root_stmt->BindValues(fragment, node_id); + impl->add_root_stmt->Execute(); +} + +ASTNode PersistentAST::GetNode(std::uint64_t node_id) { + ASTNode node; + impl->get_node_stmt->BindValues(node_id); + impl->get_node_stmt->ExecuteStep(); + impl->get_node_stmt->GetResult().Columns(node.prev, node.kind, + node.entity, node.spelling); + impl->get_node_stmt->Reset(); + return node; +} + +std::optional PersistentAST::GetNodeInIndex( + RawEntityId fragment, + unsigned short kind) { + impl->get_index_stmt->BindValues(fragment, kind); + if(impl->get_index_stmt->ExecuteStep()) { + std::uint64_t rowid; + impl->get_index_stmt->GetResult().Columns(rowid); + impl->get_root_stmt->Reset(); + return rowid; + } + return {}; +} + +void PersistentAST::SetNodeInIndex( + RawEntityId fragment, + unsigned short kind, + std::uint64_t node_id) { + impl->set_index_stmt->BindValues(fragment, kind, node_id); + impl->set_index_stmt->Execute(); +} + +std::vector PersistentAST::GetFragments() { + std::vector fragments; + while(impl->get_fragments_stmt->ExecuteStep()) { + impl->get_fragments_stmt->GetResult().Columns(fragments.emplace_back()); + } + return fragments; +} + +std::vector PersistentAST::GetChildren(std::uint64_t parent) { + std::vector children; + impl->get_children_stmt->BindValues(parent); + while(impl->get_children_stmt->ExecuteStep()) { + impl->get_children_stmt->GetResult().Columns(children.emplace_back()); + } + return children; +} + +void PersistentAST::AddChild(std::uint64_t parent, std::uint64_t child) { + impl->add_child_stmt->BindValues(parent, child); + impl->add_child_stmt->Execute(); +} + +struct PersistentGrammar::Impl { + sqlite::Connection &db; + std::shared_ptr get_children_stmt; + std::shared_ptr get_child_stmt; + std::shared_ptr get_child_leaves_stmt; + std::shared_ptr update_node_stmt; + std::shared_ptr get_node_stmt; + std::shared_ptr add_node_stmt; + std::shared_ptr add_child_stmt; + + Impl(sqlite::Connection &db); +}; + +PersistentGrammar::Impl::Impl(sqlite::Connection &db) + : db(db) { + db.Execute( + "CREATE TABLE IF NOT EXISTS 'mx::syntex::GrammarNodes'(is_production)" + ); + db.Execute( + "CREATE TABLE IF NOT EXISTS 'mx::syntex::GrammarChildren'(parent, kind, child, PRIMARY KEY(parent, kind))" + ); + + get_children_stmt = db.Prepare( + "SELECT kind, child FROM 'mx::syntex::GrammarChildren' WHERE parent = ?1" + ); + get_child_stmt = db.Prepare( + "SELECT child FROM 'mx::syntex::GrammarChildren' WHERE parent = ?1 AND kind = ?2" + ); + get_child_leaves_stmt = db.Prepare( + "SELECT child_node.kind, child_node.child" + " FROM 'mx::syntex::GrammarChildren' AS parent_node," + " 'mx::syntex::GrammarChildren' AS child_node" + " WHERE parent_node.parent = ?1" + " AND parent_node.kind = ?2" + " AND child_node.parent = parent_node.child" + ); + update_node_stmt = db.Prepare( + "UPDATE 'mx::syntex::GrammarNodes' SET is_production = ?2 WHERE rowid = ?1" + ); + get_node_stmt = db.Prepare( + "SELECT is_production FROM 'mx::syntex::GrammarNodes' WHERE rowid = ?1" + ); + add_node_stmt = db.Prepare( + "INSERT INTO 'mx::syntex::GrammarNodes'(is_production) VALUES (?1) RETURNING rowid" + ); + add_child_stmt = db.Prepare( + "INSERT INTO 'mx::syntex::GrammarChildren'(parent, kind, child) VALUES (?1, ?2, ?3)" + ); +} + +PersistentGrammar::PersistentGrammar(sqlite::Connection &db) : impl(std::make_unique(db)) {} + +std::vector> +PersistentGrammar::GetChildren(std::uint64_t parent) { + std::vector> result; + impl->get_children_stmt->BindValues(parent); + while(impl->get_children_stmt->ExecuteStep()) { + auto &[kind, id] = result.emplace_back(); + impl->get_children_stmt->GetResult().Columns(kind, id); + } + return result; +} + +std::uint64_t PersistentGrammar::GetChild(std::uint64_t parent, unsigned short kind) { + impl->get_child_stmt->BindValues(parent, kind); + std::uint64_t child; + if(impl->get_child_stmt->ExecuteStep()) { + impl->get_child_stmt->GetResult().Columns(child); + impl->get_child_stmt->Reset(); + return child; + } + impl->add_node_stmt->BindValues(0); + impl->add_node_stmt->ExecuteStep(); + impl->add_node_stmt->GetResult().Columns(child); + impl->add_node_stmt->Reset(); + impl->add_child_stmt->BindValues(parent, kind, child); + impl->add_child_stmt->Execute(); + return child; +} + +std::vector> +PersistentGrammar::GetChildLeaves(std::uint64_t parent, unsigned short kind) { + std::vector> result; + impl->get_child_leaves_stmt->BindValues(parent, kind); + while(impl->get_child_leaves_stmt->ExecuteStep()) { + auto &[kind, id] = result.emplace_back(); + impl->get_child_leaves_stmt->GetResult().Columns(kind, id); + } + return result; +} + +void PersistentGrammar::UpdateNode(std::uint64_t id, const GrammarNode &node) { + impl->update_node_stmt->BindValues(id, node.is_production); + impl->update_node_stmt->Execute(); +} + +GrammarNode PersistentGrammar::GetNode(std::uint64_t id) { + GrammarNode node; + impl->get_node_stmt->BindValues(id); + impl->get_node_stmt->ExecuteStep(); + impl->get_node_stmt->GetResult().Columns(node.is_production); + impl->get_node_stmt->Reset(); + return node; +} + IndexStorage::IndexStorage(sqlite::Connection& db) : db(db) , version_number(db) @@ -30,6 +272,9 @@ IndexStorage::IndexStorage(sqlite::Connection& db) , mangled_name_to_entity_id(db) , entity_id_use_to_fragment_id(db) , entity_id_reference(db) + , spelling_to_token_kind(db) + , ast(db) + , grammar(db) , database(db) {} IndexStorage::~IndexStorage() {} diff --git a/lib/Common/SQLiteStore.cpp b/lib/Common/SQLiteStore.cpp index ca7d638ce..5b4cb7de6 100644 --- a/lib/Common/SQLiteStore.cpp +++ b/lib/Common/SQLiteStore.cpp @@ -78,6 +78,11 @@ std::string_view QueryResult::getBlob(int32_t idx) { return std::string_view(ptr, len); } +bool QueryResult::isNull(int32_t idx) { + auto prepared_stmt = stmt->prepareStatement(); + return sqlite3_column_type(prepared_stmt, idx) == SQLITE_NULL; +} + Statement::Statement(Connection &conn, const std::string &stmt) : db(conn), query(stmt) { @@ -90,8 +95,7 @@ Statement::Statement(Connection &conn, const std::string &stmt) static_cast(query.size()), &stmt, const_cast(&tail)); if (SQLITE_OK != ret) { - assert(0); - throw Error("Failed to prepare statement"); + throw Error("Failed to prepare statement", db.GetHandler()); } return std::shared_ptr(