Skip to content

[Feature] Cover JSON schema string format #266

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cpp/grammar_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ class GrammarBuilder {
/*! \brief Default constructor. Creates a new grammar object. */
GrammarBuilder() : grammar_(std::make_shared<Grammar::Impl>()) {}

/*!
 * \brief Constructor. Seeds the builder with a copy of an existing grammar.
 * \param grammar The grammar whose implementation object is copied. Every rule of the grammar is
 * also registered in the name-to-id lookup table under its index; if two rules share a name the
 * later index wins.
 */
GrammarBuilder(const Grammar& grammar)
    : grammar_(std::make_shared<Grammar::Impl>(*grammar.operator->())) {
  const int num_rules = static_cast<int>(grammar->NumRules());
  for (int rule_id = 0; rule_id < num_rules; ++rule_id) {
    rule_name_to_id_[grammar->GetRule(rule_id).name] = rule_id;
  }
}

/*!
* \brief Get the result grammar. This function will also set the root rule to the rule with the
* specified name. The rule should be already added to the grammar.
Expand Down
2 changes: 1 addition & 1 deletion cpp/grammar_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ bool GrammarMatcherForTokenMaskCache::IsTokenPassLookaheadAssertion(

// Find all positions that can come to an end. Then check if the suffix from that position
// can be accepted by the lookahead assertion.
for (int i = static_cast<int>(can_reach_end_stack.size()); i >= 0; --i) {
for (int i = static_cast<int>(can_reach_end_stack.size()) - 1; i >= 0; --i) {
if (!can_reach_end_stack[i]) {
continue;
}
Expand Down
286 changes: 280 additions & 6 deletions cpp/grammar_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <algorithm>
#include <queue>
#include <set>
#include <unordered_set>
#include <vector>

Expand Down Expand Up @@ -284,7 +285,17 @@ class NestedRuleUnwrapper : public GrammarMutator {
}
};

class ByteStringFuser : public GrammarMutator {
/*!
 * \brief Mutator that applies SingleElementExprEliminator to the grammar and then applies
 * NestedRuleUnwrapper to the result.
 */
class StructureNormalizerImpl : public GrammarMutator {
 public:
  using GrammarMutator::Apply;
  using GrammarMutator::GrammarMutator;

  /*!
   * \brief Run the two structural passes in sequence.
   * \param grammar The input grammar.
   * \return The grammar produced by the second pass.
   */
  Grammar Apply(const Grammar& grammar) final {
    Grammar single_elements_eliminated = SingleElementExprEliminator().Apply(grammar);
    return NestedRuleUnwrapper().Apply(single_elements_eliminated);
  }
};

class ByteStringFuserImpl : public GrammarMutator {
public:
using GrammarMutator::Apply;
using GrammarMutator::GrammarMutator;
Expand Down Expand Up @@ -316,6 +327,249 @@ class ByteStringFuser : public GrammarMutator {
return builder_.AddSequence(new_sequence_ids);
}
};

class RuleInlinerImpl : public GrammarMutator {
public:
using GrammarMutator::Apply;
using GrammarMutator::GrammarMutator;

private:
int32_t VisitChoices(const RuleExpr& rule_expr) final {
std::vector<int32_t> new_choice_ids;
for (int i : rule_expr) {
auto choice_expr = base_grammar_->GetRuleExpr(i);
if (choice_expr.type == RuleExprType::kEmptyStr) {
new_choice_ids.push_back(VisitExpr(i));
continue;
}
XGRAMMAR_ICHECK(choice_expr.type == RuleExprType::kSequence);
auto first_element = base_grammar_->GetRuleExpr(choice_expr[0]);
if (first_element.type != RuleExprType::kRuleRef) {
new_choice_ids.push_back(VisitExpr(choice_expr));
continue;
}
auto rule_ref_id = first_element[0];
if (can_rule_be_inlined_.count(rule_ref_id) == 0) {
can_rule_be_inlined_[rule_ref_id] = CheckIfRuleCanBeInlined(rule_ref_id);
}
if (!can_rule_be_inlined_[rule_ref_id]) {
new_choice_ids.push_back(VisitExpr(choice_expr));
continue;
}

// Do inlining
std::vector<int32_t> other_elements;
for (int i = 1; i < choice_expr.size(); ++i) {
other_elements.push_back(VisitExpr(choice_expr[i]));
}

auto ref_rule = base_grammar_->GetRule(rule_ref_id);
auto ref_rule_expr = base_grammar_->GetRuleExpr(ref_rule.body_expr_id);

for (auto ref_choice_id : ref_rule_expr) {
auto ref_choice_expr = base_grammar_->GetRuleExpr(ref_choice_id);
XGRAMMAR_ICHECK(ref_choice_expr.type == RuleExprType::kSequence);
std::vector<int32_t> choice_to_add;
for (auto ref_element_id : ref_choice_expr) {
choice_to_add.push_back(VisitExpr(ref_element_id));
}
choice_to_add.insert(choice_to_add.end(), other_elements.begin(), other_elements.end());
new_choice_ids.push_back(builder_.AddSequence(choice_to_add));
}
}
return builder_.AddChoices(new_choice_ids);
}

/**
* The rule should be: a sequence of choices, cannot be empty, cannot refer to other rules
*/
bool CheckIfRuleCanBeInlined(int32_t rule_id) {
auto rule = base_grammar_->GetRule(rule_id);
auto rule_expr = base_grammar_->GetRuleExpr(rule.body_expr_id);
if (rule_expr.type != RuleExprType::kChoices) {
return false;
}
if (rule_expr.size() == 0) {
return false;
}
for (auto choice_id : rule_expr) {
auto choice_expr = base_grammar_->GetRuleExpr(choice_id);
if (choice_expr.type == RuleExprType::kEmptyStr) {
return false;
}
XGRAMMAR_ICHECK(choice_expr.type == RuleExprType::kSequence);
for (auto element_id : choice_expr) {
auto element_expr = base_grammar_->GetRuleExpr(element_id);
if (element_expr.type == RuleExprType::kRuleRef) {
return false;
}
}
}
return true;
}

std::unordered_map<int32_t, bool> can_rule_be_inlined_;
};

/*!
 * \brief Collect the ids of all rules reachable from the root rule, following rule references
 * and tag dispatch targets found in rule bodies and in lookahead assertions. The result is
 * sorted in ascending order. This is useful for dead code elimination.
 */
class UsedRulesAnalyzer : public GrammarVisitor<std::vector<int32_t>> {
 public:
  UsedRulesAnalyzer() = default;

  /*!
   * \brief Run a breadth-first traversal of the rule reference graph starting at the root rule.
   * \param grammar The grammar to analyze.
   * \return The sorted ids of every rule reached by the traversal (always includes the root).
   */
  std::vector<int32_t> Apply(const Grammar& grammar) final {
    base_grammar_ = grammar;

    std::set<int32_t> visited;

    // Reset the work queue in case this analyzer instance is reused.
    std::queue<int32_t>().swap(visit_queue_);

    visit_queue_.push(base_grammar_->GetRootRuleId());
    while (!visit_queue_.empty()) {
      auto rule_id = visit_queue_.front();
      visit_queue_.pop();
      if (!visited.insert(rule_id).second) {
        continue;
      }
      auto rule = base_grammar_->GetRule(rule_id);
      VisitExpr(rule.body_expr_id);
      // A lookahead assertion can reference rules that appear nowhere else. They must be kept
      // alive too, otherwise a consumer that rewrites the assertion cannot remap those
      // references. -1 means the rule has no lookahead assertion.
      if (rule.lookahead_assertion_id != -1) {
        VisitExpr(rule.lookahead_assertion_id);
      }
    }

    return std::vector<int32_t>(visited.begin(), visited.end());
  }

  /*! \brief Enqueue the target rule of every (tag, rule id) pair of a tag dispatch. */
  void VisitTagDispatch(const RuleExpr& rule_expr) {
    for (int i = 0; i < rule_expr.size(); i += 2) {
      visit_queue_.push(rule_expr[i + 1]);
    }
  }

  /*! \brief Enqueue the rule referenced by a rule reference expression. */
  void VisitRuleRef(const RuleExpr& rule_expr) { visit_queue_.push(rule_expr[0]); }

 private:
  /*! \brief Ids of rules that have been discovered but not yet expanded. */
  std::queue<int32_t> visit_queue_;
};

/*!
 * \brief Mutator that rebuilds the grammar from only the rules reported by UsedRulesAnalyzer,
 * renumbering them and rewriting rule references and tag dispatch targets accordingly.
 */
class DeadCodeEliminatorImpl : public GrammarMutator {
 public:
  using GrammarMutator::Apply;
  using GrammarMutator::GrammarMutator;

  /*!
   * \brief Build a new grammar containing only the used rules.
   * \param grammar The input grammar.
   * \return The rebuilt grammar, rooted at the rule the input root rule was mapped to.
   */
  Grammar Apply(const Grammar& grammar) final {
    Init(grammar);
    auto kept_rule_ids = UsedRulesAnalyzer().Apply(grammar);

    // Pass 1: allocate every kept rule up front so that references can be remapped regardless
    // of the order in which the bodies are visited.
    rule_id_map_.clear();
    for (auto old_rule_id : kept_rule_ids) {
      rule_id_map_[old_rule_id] = builder_.AddEmptyRule(grammar->GetRule(old_rule_id).name);
    }

    // Pass 2: translate the body and the lookahead assertion of each kept rule.
    for (auto old_rule_id : kept_rule_ids) {
      auto old_rule = grammar->GetRule(old_rule_id);
      const auto new_rule_id = rule_id_map_[old_rule_id];
      builder_.UpdateRuleBody(new_rule_id, VisitExpr(old_rule.body_expr_id));
      auto new_lookahead_id = VisitLookaheadAssertion(old_rule.lookahead_assertion_id);
      builder_.AddLookaheadAssertion(new_rule_id, new_lookahead_id);
    }

    const auto old_root_id = grammar->GetRootRuleId();
    XGRAMMAR_CHECK(rule_id_map_.count(old_root_id) > 0);
    return builder_.Get(rule_id_map_[old_root_id]);
  }

  /*!
   * \brief Rebuild a tag dispatch; the expression stores (tag expr id, rule id) pairs flattened
   * into consecutive slots.
   */
  int32_t VisitTagDispatch(const RuleExpr& rule_expr) final {
    std::vector<std::pair<int32_t, int32_t>> remapped_pairs;
    for (int slot = 0; slot < rule_expr.size(); slot += 2) {
      const auto old_target_id = rule_expr[slot + 1];
      XGRAMMAR_DCHECK(rule_id_map_.count(old_target_id) > 0);
      auto new_target_id = rule_id_map_[old_target_id];
      remapped_pairs.emplace_back(VisitExpr(rule_expr[slot]), new_target_id);
    }
    return builder_.AddTagDispatch(remapped_pairs);
  }

  /*! \brief Rebuild a rule reference so that it points at the renumbered rule. */
  int32_t VisitRuleRef(const RuleExpr& rule_expr) final {
    const auto old_target_id = rule_expr[0];
    XGRAMMAR_DCHECK(rule_id_map_.count(old_target_id) > 0);
    return builder_.AddRuleRef(rule_id_map_[old_target_id]);
  }

 private:
  /*! \brief Maps a rule id in the input grammar to its id in the rebuilt grammar. */
  std::unordered_map<int32_t, int32_t> rule_id_map_;
};

/*!
 * \brief Mutator that attaches a lookahead assertion to each non-root rule that does not already
 * have one, when DetectLookaheadAssertion can derive it from the rule's single use site.
 */
class LookaheadAssertionAnalyzerImpl : public GrammarMutator {
 public:
  using GrammarMutator::GrammarMutator;

  /*!
   * \brief Add detected lookahead assertions to a copy of the grammar.
   * \param grammar The input grammar.
   * \return The input grammar itself when the root rule body is a tag dispatch; otherwise the
   * copy with the detected lookahead assertions added.
   */
  Grammar Apply(const Grammar& grammar) final {
    InitWithCopy(grammar);
    auto root_body = base_grammar_->GetRuleExpr(grammar->GetRootRule().body_expr_id);
    if (root_body.type == RuleExprType::kTagDispatch) {
      return grammar;
    }

    const auto root_rule_id = grammar->GetRootRuleId();
    const int num_rules = static_cast<int>(grammar->NumRules());
    for (int rule_id = 0; rule_id < num_rules; ++rule_id) {
      auto rule = grammar->GetRule(rule_id);
      // Only non-root rules without an existing assertion are candidates.
      const bool is_candidate = rule_id != root_rule_id && rule.lookahead_assertion_id == -1;
      if (!is_candidate) {
        continue;
      }
      auto lookahead_assertion_id = DetectLookaheadAssertion(rule_id);
      if (lookahead_assertion_id != -1) {
        builder_.AddLookaheadAssertion(rule_id, lookahead_assertion_id);
      }
    }
    return builder_.Get(root_rule_id);
  }

  /*!
   * \brief Try to derive a lookahead assertion for a rule from the places where it is referenced.
   *
   * Returns -1 when the rule is a tag dispatch target, when it is the last element of a sequence
   * in a different rule, when it is referenced in a non-last position more than once, or when it
   * is never referenced in a non-last position. Otherwise the elements following the single
   * non-last reference form the assertion.
   *
   * \param rule_id The id of the rule to analyze.
   * \return The id of the new sequence expression in the builder, or -1.
   */
  int32_t DetectLookaheadAssertion(int32_t rule_id) {
    std::vector<int32_t> following_element_ids;
    bool reference_seen = false;

    const int num_rules = static_cast<int>(base_grammar_->NumRules());
    for (int referrer_id = 0; referrer_id < num_rules; ++referrer_id) {
      auto referrer = base_grammar_->GetRule(referrer_id);
      auto referrer_body = base_grammar_->GetRuleExpr(referrer.body_expr_id);

      if (referrer_body.type == RuleExprType::kTagDispatch) {
        // Odd slots of a tag dispatch hold the target rule ids.
        for (int slot = 1; slot < referrer_body.size(); slot += 2) {
          if (referrer_body[slot] == rule_id) {
            return -1;
          }
        }
        continue;
      }

      XGRAMMAR_DCHECK(referrer_body.type == RuleExprType::kChoices);
      for (auto sequence_id : referrer_body) {
        auto sequence = base_grammar_->GetRuleExpr(sequence_id);
        if (sequence.type != RuleExprType::kSequence) {
          continue;
        }

        // A trailing reference from a different rule has no following elements to use.
        auto trailing = base_grammar_->GetRuleExpr(sequence.end()[-1]);
        const bool trailing_is_target =
            trailing.type == RuleExprType::kRuleRef && trailing[0] == rule_id;
        if (trailing_is_target && referrer_id != rule_id) {
          return -1;
        }

        for (int pos = 0; pos < sequence.size() - 1; ++pos) {
          auto element = base_grammar_->GetRuleExpr(sequence[pos]);
          const bool is_target =
              element.type == RuleExprType::kRuleRef && element[0] == rule_id;
          if (!is_target) {
            continue;
          }
          if (reference_seen) {
            return -1;
          }
          reference_seen = true;
          for (int next = pos + 1; next < sequence.size(); ++next) {
            following_element_ids.push_back(sequence[next]);
          }
        }
      }
    }

    if (!reference_seen) {
      return -1;
    }
    return builder_.AddSequence(following_element_ids);
  }
};

/*!
* \brief A class that normalizes a grammar by applying a series of transformations.
*
Expand All @@ -341,9 +595,11 @@ class GrammarNormalizerImpl : public GrammarMutator {
// Return the list of all normalizers in the class. The normalizers are applied one by one.
std::vector<std::unique_ptr<GrammarMutator>> GetNormalizerList() {
std::vector<std::unique_ptr<GrammarMutator>> normalizer_mutators;
normalizer_mutators.emplace_back(std::make_unique<SingleElementExprEliminator>());
normalizer_mutators.emplace_back(std::make_unique<NestedRuleUnwrapper>());
normalizer_mutators.emplace_back(std::make_unique<ByteStringFuser>());
normalizer_mutators.emplace_back(std::make_unique<StructureNormalizerImpl>());
normalizer_mutators.emplace_back(std::make_unique<ByteStringFuserImpl>());
normalizer_mutators.emplace_back(std::make_unique<RuleInlinerImpl>());
normalizer_mutators.emplace_back(std::make_unique<DeadCodeEliminatorImpl>());
normalizer_mutators.emplace_back(std::make_unique<LookaheadAssertionAnalyzerImpl>());
return normalizer_mutators;
}
};
Expand Down Expand Up @@ -515,8 +771,8 @@ class AllowEmptyRuleAnalyzerImpl : public GrammarVisitor<std::vector<int32_t>> {
std::unordered_set<int32_t> empty_rule_id_set;
FindExplicitEmptyRules(&empty_rule_id_set);

// Step 2: Find rules that indirectly allow empty string. Using the Bellman-Ford algorithm on
// the rule reference graph.
// Step 2: Find rules that indirectly allow empty string. Using the Bellman-Ford algorithm
// on the rule reference graph.
std::vector<std::vector<int32_t>> rule_ref_graph = RuleRefGraphFinder().Apply(grammar);
FindIndirectEmptyRules(&empty_rule_id_set, rule_ref_graph);

Expand Down Expand Up @@ -703,4 +959,22 @@ Grammar StructuralTagGrammarCreator::Apply(
return StructuralTagGrammarCreatorImpl().Apply(triggers, tag_groups);
}

/*! \brief Apply RuleInlinerImpl to the grammar and return the result. */
Grammar RuleInliner::Apply(const Grammar& grammar) {
  RuleInlinerImpl inliner{};
  return inliner.Apply(grammar);
}

/*! \brief Apply ByteStringFuserImpl to the grammar and return the result. */
Grammar ByteStringFuser::Apply(const Grammar& grammar) {
  ByteStringFuserImpl fuser{};
  return fuser.Apply(grammar);
}

/*! \brief Apply DeadCodeEliminatorImpl to the grammar and return the result. */
Grammar DeadCodeEliminator::Apply(const Grammar& grammar) {
  DeadCodeEliminatorImpl eliminator{};
  return eliminator.Apply(grammar);
}

/*! \brief Apply StructureNormalizerImpl to the grammar and return the result. */
Grammar StructureNormalizer::Apply(const Grammar& grammar) {
  StructureNormalizerImpl normalizer{};
  return normalizer.Apply(grammar);
}

/*! \brief Apply LookaheadAssertionAnalyzerImpl to the grammar and return the result. */
Grammar LookaheadAssertionAnalyzer::Apply(const Grammar& grammar) {
  LookaheadAssertionAnalyzerImpl analyzer{};
  return analyzer.Apply(grammar);
}

} // namespace xgrammar
Loading
Loading