diff --git a/derive/examples/hook.pest b/derive/examples/hook.pest new file mode 100644 index 00000000..48695d2c --- /dev/null +++ b/derive/examples/hook.pest @@ -0,0 +1,5 @@ +WHITESPACE = _{ " " | "\t" | NEWLINE } +int = @{ (ASCII_NONZERO_DIGIT ~ ASCII_DIGIT+ | ASCII_DIGIT) } +__HOOK_INT = _{ int } +ints = { SOI ~ __HOOK_INT* ~ EOI } +ints2 = { SOI ~ (__HOOK_INT ~ "yes" | int ~ "no")* ~ EOI } diff --git a/derive/examples/hook.rs b/derive/examples/hook.rs new file mode 100644 index 00000000..42f32cc8 --- /dev/null +++ b/derive/examples/hook.rs @@ -0,0 +1,101 @@ +use pest::StateParser; + +mod parser { + use pest::{Span, StateCheckpoint}; + use pest_derive::Parser; + + #[derive(Parser)] + #[grammar = "../examples/hook.pest"] + #[custom_state(crate::parser::CustomState)] + pub struct Parser; + + pub struct CustomState { + pub max_int_visited: usize, + is_snapshot_cleared: bool, + } + + impl StateCheckpoint for CustomState { + fn snapshot(&mut self) {} + fn clear_snapshot(&mut self) { + self.is_snapshot_cleared = true; + } + fn restore(&mut self) {} + } + + impl CustomState { + pub fn create() -> Self { + Self { + max_int_visited: 0, + is_snapshot_cleared: true, + } + } + } + + impl Parser { + #[allow(non_snake_case)] + fn hook__HOOK_INT<'a>(state: &mut CustomState, span: Span<'a>) -> bool { + if !state.is_snapshot_cleared { + println!("this state cannot operate with snapshot, please check your grammar to avoid hook in unexpected location",); + return false; + } + let val: usize = span.as_str().parse().unwrap(); + println!("hook called with val={}", val); + if val >= state.max_int_visited { + state.is_snapshot_cleared = false; + state.max_int_visited = val; + true + } else { + false + } + } + } +} + +fn main() { + // parser::Rule::ints parses a non-decreasing sequence of integers. + + println!("parser::Rule::ints"); + + // should parse successfully. + let (state, _) = parser::Parser::parse_with_state( + parser::Rule::ints, + "1\n2\n3\n4\n", + parser::CustomState::create(), + ) + .unwrap(); + assert_eq!(state.max_int_visited, 4); + + println!("parser::Rule::ints"); + + // custom state hook will reject this case + assert!(parser::Parser::parse_with_state( + parser::Rule::ints, + "1\n2\n2\n0\n", + parser::CustomState::create() + ) + .is_err()); + + // parser::Rule::ints2 passes a non-decreasing sequence of integers to HOOK_INT, while allowing + // other numbers in the sequence. + + println!("parser::Rule::ints2"); + + // should parse successfully. + let (state, _) = parser::Parser::parse_with_state( + parser::Rule::ints2, + "1 yes\n2 yes\n3 yes\n4 yes\n", + parser::CustomState::create(), + ) + .unwrap(); + assert_eq!(state.max_int_visited, 4); + + println!("parser::Rule::ints2"); + + // custom state hook will still be called with val = 3, but it will be restored. + assert!(parser::Parser::parse_with_state( + parser::Rule::ints2, + "1 yes\n2 yes\n3 no\n4 yes\n", + parser::CustomState::create() + ) + .is_err()); +} diff --git a/derive/src/lib.rs b/derive/src/lib.rs index a908897f..f490a994 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -319,7 +319,7 @@ use proc_macro::TokenStream; /// The main method that's called by the proc macro /// (a wrapper around `pest_generator::derive_parser`) -#[proc_macro_derive(Parser, attributes(grammar, grammar_inline))] +#[proc_macro_derive(Parser, attributes(grammar, grammar_inline, custom_state))] pub fn derive_parser(input: TokenStream) -> TokenStream { pest_generator::derive_parser(input.into(), true).into() } diff --git a/generator/src/generator.rs b/generator/src/generator.rs index 0dbcaa31..ee28a09c 100644 --- a/generator/src/generator.rs +++ b/generator/src/generator.rs @@ -27,10 +27,17 @@ pub(crate) fn generate( defaults: Vec<&str>, doc_comment: &DocComment, include_grammar: bool, + custom_state: Option, ) -> TokenStream { let uses_eoi = defaults.iter().any(|name| *name == "EOI"); - let builtins = generate_builtin_rules(); + let custom_state = if let Some(custom_state) = custom_state { + quote! { #custom_state } + } else { + quote! { () } + }; + + let builtins = generate_builtin_rules(custom_state.clone()); let include_fix = if include_grammar { generate_include(&name, paths) } else { @@ -38,9 +45,12 @@ pub(crate) fn generate( }; let rule_enum = generate_enum(&rules, doc_comment, uses_eoi); let patterns = generate_patterns(&rules, uses_eoi); - let skip = generate_skip(&rules); + let skip = generate_skip(&rules, custom_state.clone()); - let mut rules: Vec<_> = rules.into_iter().map(generate_rule).collect(); + let mut rules: Vec<_> = rules + .into_iter() + .map(|rule| generate_rule(&name, custom_state.clone(), rule)) + .collect(); rules.extend(builtins.into_iter().filter_map(|(builtin, tokens)| { if defaults.contains(&builtin) { Some(tokens) @@ -55,12 +65,13 @@ pub(crate) fn generate( let parser_impl = quote! { #[allow(clippy::all)] - impl #impl_generics ::pest::Parser for #name #ty_generics #where_clause { - fn parse<'i>( + impl #impl_generics ::pest::StateParser for #name #ty_generics #where_clause { + fn parse_with_state<'i>( rule: Rule, - input: &'i str + input: &'i str, + state: #custom_state ) -> #result< - ::pest::iterators::Pairs<'i, Rule>, + (#custom_state, ::pest::iterators::Pairs<'i, Rule>), ::pest::error::Error > { mod rules { @@ -78,11 +89,11 @@ pub(crate) fn generate( pub use self::visible::*; } - ::pest::state(input, |state| { + ::pest::state_custom(input, |state| { match rule { #patterns } - }) + }, state) } } }; @@ -96,42 +107,75 @@ pub(crate) fn generate( // Note: All builtin rules should be validated as pest builtins in meta/src/validator.rs. // Some should also be keywords. -fn generate_builtin_rules() -> Vec<(&'static str, TokenStream)> { +fn generate_builtin_rules(custom_state: TokenStream) -> Vec<(&'static str, TokenStream)> { let mut builtins = Vec::new(); - insert_builtin!(builtins, ANY, state.skip(1)); + insert_builtin!(builtins, ANY, state.skip(1), custom_state); insert_builtin!( builtins, EOI, - state.rule(Rule::EOI, |state| state.end_of_input()) + state.rule(Rule::EOI, |state| state.end_of_input()), + custom_state + ); + insert_builtin!(builtins, SOI, state.start_of_input(), custom_state); + insert_builtin!(builtins, PEEK, state.stack_peek(), custom_state); + insert_builtin!(builtins, PEEK_ALL, state.stack_match_peek(), custom_state); + insert_builtin!(builtins, POP, state.stack_pop(), custom_state); + insert_builtin!(builtins, POP_ALL, state.stack_match_pop(), custom_state); + insert_builtin!(builtins, DROP, state.stack_drop(), custom_state); + + insert_builtin!( + builtins, + ASCII_DIGIT, + state.match_range('0'..'9'), + custom_state + ); + insert_builtin!( + builtins, + ASCII_NONZERO_DIGIT, + state.match_range('1'..'9'), + custom_state + ); + insert_builtin!( + builtins, + ASCII_BIN_DIGIT, + state.match_range('0'..'1'), + custom_state + ); + insert_builtin!( + builtins, + ASCII_OCT_DIGIT, + state.match_range('0'..'7'), + custom_state ); - insert_builtin!(builtins, SOI, state.start_of_input()); - insert_builtin!(builtins, PEEK, state.stack_peek()); - insert_builtin!(builtins, PEEK_ALL, state.stack_match_peek()); - insert_builtin!(builtins, POP, state.stack_pop()); - insert_builtin!(builtins, POP_ALL, state.stack_match_pop()); - insert_builtin!(builtins, DROP, state.stack_drop()); - - insert_builtin!(builtins, ASCII_DIGIT, state.match_range('0'..'9')); - insert_builtin!(builtins, ASCII_NONZERO_DIGIT, state.match_range('1'..'9')); - insert_builtin!(builtins, ASCII_BIN_DIGIT, state.match_range('0'..'1')); - insert_builtin!(builtins, ASCII_OCT_DIGIT, state.match_range('0'..'7')); insert_builtin!( builtins, ASCII_HEX_DIGIT, state .match_range('0'..'9') .or_else(|state| state.match_range('a'..'f')) - .or_else(|state| state.match_range('A'..'F')) + .or_else(|state| state.match_range('A'..'F')), + custom_state + ); + insert_builtin!( + builtins, + ASCII_ALPHA_LOWER, + state.match_range('a'..'z'), + custom_state + ); + insert_builtin!( + builtins, + ASCII_ALPHA_UPPER, + state.match_range('A'..'Z'), + custom_state ); - insert_builtin!(builtins, ASCII_ALPHA_LOWER, state.match_range('a'..'z')); - insert_builtin!(builtins, ASCII_ALPHA_UPPER, state.match_range('A'..'Z')); insert_builtin!( builtins, ASCII_ALPHA, state .match_range('a'..'z') - .or_else(|state| state.match_range('A'..'Z')) + .or_else(|state| state.match_range('A'..'Z')), + custom_state ); insert_builtin!( builtins, @@ -139,16 +183,23 @@ fn generate_builtin_rules() -> Vec<(&'static str, TokenStream)> { state .match_range('a'..'z') .or_else(|state| state.match_range('A'..'Z')) - .or_else(|state| state.match_range('0'..'9')) + .or_else(|state| state.match_range('0'..'9')), + custom_state + ); + insert_builtin!( + builtins, + ASCII, + state.match_range('\x00'..'\x7f'), + custom_state ); - insert_builtin!(builtins, ASCII, state.match_range('\x00'..'\x7f')); insert_builtin!( builtins, NEWLINE, state .match_string("\n") .or_else(|state| state.match_string("\r\n")) - .or_else(|state| state.match_string("\r")) + .or_else(|state| state.match_string("\r")), + custom_state ); let box_ty = box_type(); @@ -159,7 +210,7 @@ fn generate_builtin_rules() -> Vec<(&'static str, TokenStream)> { builtins.push((property, quote! { #[inline] #[allow(dead_code, non_snake_case, unused_variables)] - fn #property_ident(state: #box_ty<::pest::ParserState<'_, Rule>>) -> ::pest::ParseResult<#box_ty<::pest::ParserState<'_, Rule>>> { + fn #property_ident(state: #box_ty<::pest::ParserState<'_, Rule, #custom_state>>) -> ::pest::ParseResult<#box_ty<::pest::ParserState<'_, Rule, #custom_state>>> { state.match_char_by(::pest::unicode::#property_ident) } })); @@ -258,7 +309,31 @@ fn generate_patterns(rules: &[OptimizedRule], uses_eoi: bool) -> TokenStream { } } -fn generate_rule(rule: OptimizedRule) -> TokenStream { +fn generate_expr_hooked(parser_name: &Ident, name: Ident, expr: OptimizedExpr) -> TokenStream { + if let OptimizedExpr::Ident(ident) = expr { + let name = format_ident!("r#hook{}", name); + let ident = format_ident!("r#{}", ident); + quote! { + let start = *state.position(); + let mut state = self::#ident(state)?; + let end = *state.position(); + let span = start.span(&end); + if super::super::#parser_name::#name(state.state_mut(), span) { + Ok(state) + } else { + Err(state) + } + } + } else { + unreachable!("__HOOK can be only applied in grammars with a single ident"); + } +} + +fn generate_rule( + parser_name: &Ident, + custom_state: TokenStream, + rule: OptimizedRule, +) -> TokenStream { let name = format_ident!("r#{}", rule.name); let expr = if rule.ty == RuleType::Atomic || rule.ty == RuleType::CompoundAtomic { generate_expr_atomic(rule.expr) @@ -270,6 +345,8 @@ fn generate_rule(rule: OptimizedRule) -> TokenStream { #atomic }) } + } else if rule.name.starts_with("__HOOK") { + generate_expr_hooked(parser_name, name.clone(), rule.expr) } else { generate_expr(rule.expr) }; @@ -280,7 +357,7 @@ fn generate_rule(rule: OptimizedRule) -> TokenStream { RuleType::Normal => quote! { #[inline] #[allow(non_snake_case, unused_variables)] - pub fn #name(state: #box_ty<::pest::ParserState<'_, Rule>>) -> ::pest::ParseResult<#box_ty<::pest::ParserState<'_, Rule>>> { + pub fn #name(state: #box_ty<::pest::ParserState<'_, Rule, #custom_state>>) -> ::pest::ParseResult<#box_ty<::pest::ParserState<'_, Rule, #custom_state>>> { state.rule(Rule::#name, |state| { #expr }) @@ -289,14 +366,14 @@ fn generate_rule(rule: OptimizedRule) -> TokenStream { RuleType::Silent => quote! { #[inline] #[allow(non_snake_case, unused_variables)] - pub fn #name(state: #box_ty<::pest::ParserState<'_, Rule>>) -> ::pest::ParseResult<#box_ty<::pest::ParserState<'_, Rule>>> { + pub fn #name(state: #box_ty<::pest::ParserState<'_, Rule, #custom_state>>) -> ::pest::ParseResult<#box_ty<::pest::ParserState<'_, Rule, #custom_state>>> { #expr } }, RuleType::Atomic => quote! { #[inline] #[allow(non_snake_case, unused_variables)] - pub fn #name(state: #box_ty<::pest::ParserState<'_, Rule>>) -> ::pest::ParseResult<#box_ty<::pest::ParserState<'_, Rule>>> { + pub fn #name(state: #box_ty<::pest::ParserState<'_, Rule, #custom_state>>) -> ::pest::ParseResult<#box_ty<::pest::ParserState<'_, Rule, #custom_state>>> { state.rule(Rule::#name, |state| { state.atomic(::pest::Atomicity::Atomic, |state| { #expr @@ -307,7 +384,7 @@ fn generate_rule(rule: OptimizedRule) -> TokenStream { RuleType::CompoundAtomic => quote! { #[inline] #[allow(non_snake_case, unused_variables)] - pub fn #name(state: #box_ty<::pest::ParserState<'_, Rule>>) -> ::pest::ParseResult<#box_ty<::pest::ParserState<'_, Rule>>> { + pub fn #name(state: #box_ty<::pest::ParserState<'_, Rule, #custom_state>>) -> ::pest::ParseResult<#box_ty<::pest::ParserState<'_, Rule, #custom_state>>> { state.atomic(::pest::Atomicity::CompoundAtomic, |state| { state.rule(Rule::#name, |state| { #expr @@ -318,7 +395,7 @@ fn generate_rule(rule: OptimizedRule) -> TokenStream { RuleType::NonAtomic => quote! { #[inline] #[allow(non_snake_case, unused_variables)] - pub fn #name(state: #box_ty<::pest::ParserState<'_, Rule>>) -> ::pest::ParseResult<#box_ty<::pest::ParserState<'_, Rule>>> { + pub fn #name(state: #box_ty<::pest::ParserState<'_, Rule, #custom_state>>) -> ::pest::ParseResult<#box_ty<::pest::ParserState<'_, Rule, #custom_state>>> { state.atomic(::pest::Atomicity::NonAtomic, |state| { state.rule(Rule::#name, |state| { #expr @@ -329,19 +406,20 @@ fn generate_rule(rule: OptimizedRule) -> TokenStream { } } -fn generate_skip(rules: &[OptimizedRule]) -> TokenStream { +fn generate_skip(rules: &[OptimizedRule], custom_state: TokenStream) -> TokenStream { let whitespace = rules.iter().any(|rule| rule.name == "WHITESPACE"); let comment = rules.iter().any(|rule| rule.name == "COMMENT"); match (whitespace, comment) { - (false, false) => generate_rule!(skip, Ok(state)), + (false, false) => generate_rule!(skip, Ok(state), custom_state), (true, false) => generate_rule!( skip, if state.atomicity() == ::pest::Atomicity::NonAtomic { state.repeat(|state| super::visible::WHITESPACE(state)) } else { Ok(state) - } + }, + custom_state ), (false, true) => generate_rule!( skip, @@ -349,7 +427,8 @@ fn generate_skip(rules: &[OptimizedRule]) -> TokenStream { state.repeat(|state| super::visible::COMMENT(state)) } else { Ok(state) - } + }, + custom_state ), (true, true) => generate_rule!( skip, @@ -369,7 +448,8 @@ fn generate_skip(rules: &[OptimizedRule]) -> TokenStream { }) } else { Ok(state) - } + }, + custom_state ), } } @@ -1035,7 +1115,7 @@ mod tests { let test_path = current_dir.join("test.pest").to_str().unwrap().to_string(); assert_eq!( - generate(name, &generics, vec![PathBuf::from("base.pest"), PathBuf::from("test.pest")], rules, defaults, doc_comment, true).to_string(), + generate(name, &generics, vec![PathBuf::from("base.pest"), PathBuf::from("test.pest")], rules, defaults, doc_comment, true, None).to_string(), quote! { #[allow(non_upper_case_globals)] const _PEST_GRAMMAR_MyParser: [&'static str; 2usize] = [include_str!(#base_path), include_str!(#test_path)]; @@ -1050,7 +1130,7 @@ mod tests { } #[allow(clippy::all)] - impl ::pest::Parser for MyParser { + impl ::pest::StateParser for MyParser { fn parse<'i>( rule: Rule, input: &'i str diff --git a/generator/src/lib.rs b/generator/src/lib.rs index 41129ef6..4fb3ab86 100644 --- a/generator/src/lib.rs +++ b/generator/src/lib.rs @@ -27,6 +27,7 @@ use std::io::{self, Read}; use std::path::Path; use proc_macro2::TokenStream; +use quote::ToTokens; use syn::{Attribute, DeriveInput, Generics, Ident, Lit, Meta}; #[macro_use] @@ -42,7 +43,7 @@ use pest_meta::{optimizer, unwrap_or_report, validator}; /// "include_str" statement (done in pest_derive, but turned off in the local bootstrap). pub fn derive_parser(input: TokenStream, include_grammar: bool) -> TokenStream { let ast: DeriveInput = syn::parse2(input).unwrap(); - let (name, generics, contents) = parse_derive(ast); + let (name, generics, contents, custom_state) = parse_derive(ast); let mut data = String::new(); let mut paths = vec![]; @@ -105,6 +106,7 @@ pub fn derive_parser(input: TokenStream, include_grammar: bool) -> TokenStream { defaults, &doc_comment, include_grammar, + custom_state, ) } @@ -121,7 +123,9 @@ enum GrammarSource { Inline(String), } -fn parse_derive(ast: DeriveInput) -> (Ident, Generics, Vec) { +fn parse_derive( + mut ast: DeriveInput, +) -> (Ident, Generics, Vec, Option) { let name = ast.ident; let generics = ast.generics; @@ -145,7 +149,27 @@ fn parse_derive(ast: DeriveInput) -> (Ident, Generics, Vec) { grammar_sources.push(get_attribute(attr)) } - (name, generics, grammar_sources) + let custom_state = { + let attrs: Vec<(usize, &Attribute)> = ast + .attrs + .iter() + .enumerate() + .filter(|(_, attr)| match attr.parse_meta() { + Ok(Meta::List(list)) => list.path.is_ident("custom_state"), + _ => false, + }) + .collect(); + if attrs.len() == 1 { + let (id, attr) = attrs[0]; + let attr = get_attribute_custom_state(attr); + ast.attrs.remove(id); + Some(attr) + } else { + None + } + }; + + (name, generics, grammar_sources, custom_state) } fn get_attribute(attr: &Attribute) -> GrammarSource { @@ -164,6 +188,18 @@ fn get_attribute(attr: &Attribute) -> GrammarSource { } } +fn get_attribute_custom_state(attr: &Attribute) -> TokenStream { + match attr.parse_meta() { + Ok(Meta::List(list)) => match list.nested.first() { + Some(x) => { + x.to_token_stream() + } + _ => panic!(), + }, + _ => panic!(), + } +} + #[cfg(test)] mod tests { use super::parse_derive; @@ -177,7 +213,7 @@ mod tests { pub struct MyParser<'a, T>; "; let ast = syn::parse_str(definition).unwrap(); - let (_, _, filenames) = parse_derive(ast); + let (_, _, filenames, _) = parse_derive(ast); assert_eq!(filenames, [GrammarSource::Inline("GRAMMAR".to_string())]); } @@ -189,7 +225,7 @@ mod tests { pub struct MyParser<'a, T>; "; let ast = syn::parse_str(definition).unwrap(); - let (_, _, filenames) = parse_derive(ast); + let (_, _, filenames, _) = parse_derive(ast); assert_eq!(filenames, [GrammarSource::File("myfile.pest".to_string())]); } @@ -202,7 +238,7 @@ mod tests { pub struct MyParser<'a, T>; "; let ast = syn::parse_str(definition).unwrap(); - let (_, _, filenames) = parse_derive(ast); + let (_, _, filenames, _) = parse_derive(ast); assert_eq!( filenames, [ diff --git a/generator/src/macros.rs b/generator/src/macros.rs index 377f66e6..18cb5cf1 100644 --- a/generator/src/macros.rs +++ b/generator/src/macros.rs @@ -8,18 +8,21 @@ // modified, or distributed except according to those terms. macro_rules! insert_builtin { - ($builtin: expr, $name: ident, $pattern: expr) => { - $builtin.push((stringify!($name), generate_rule!($name, $pattern))); + ($builtin: expr, $name: ident, $pattern: expr, $custom_state: ident) => { + $builtin.push(( + stringify!($name), + generate_rule!($name, $pattern, $custom_state), + )); }; } #[cfg(feature = "std")] macro_rules! generate_rule { - ($name: ident, $pattern: expr) => { + ($name: ident, $pattern: expr, $custom_state: ident) => { quote! { #[inline] #[allow(dead_code, non_snake_case, unused_variables)] - pub fn $name(state: ::std::boxed::Box<::pest::ParserState<'_, Rule>>) -> ::pest::ParseResult<::std::boxed::Box<::pest::ParserState<'_, Rule>>> { + pub fn $name(state: ::std::boxed::Box<::pest::ParserState<'_, Rule, #$custom_state>>) -> ::pest::ParseResult<::std::boxed::Box<::pest::ParserState<'_, Rule, #$custom_state>>> { $pattern } } diff --git a/meta/src/optimizer/restorer.rs b/meta/src/optimizer/restorer.rs index e128e03f..22636643 100644 --- a/meta/src/optimizer/restorer.rs +++ b/meta/src/optimizer/restorer.rs @@ -62,6 +62,7 @@ fn child_modifies_state( ) -> bool { expr.iter_top_down().any(|expr| match expr { OptimizedExpr::Push(_) => true, + OptimizedExpr::Ident(ref name) if name.starts_with("__HOOK") => true, OptimizedExpr::Ident(ref name) if name == "DROP" => true, OptimizedExpr::Ident(ref name) if name == "POP" => true, OptimizedExpr::Ident(ref name) => match cache.get(name).cloned() { diff --git a/pest/src/lib.rs b/pest/src/lib.rs index fa4df200..5b7681f7 100644 --- a/pest/src/lib.rs +++ b/pest/src/lib.rs @@ -334,9 +334,10 @@ extern crate alloc; #[cfg(feature = "std")] extern crate std; -pub use crate::parser::Parser; +pub use crate::parser::{Parser, StateParser}; pub use crate::parser_state::{ - set_call_limit, state, Atomicity, Lookahead, MatchDir, ParseResult, ParserState, + set_call_limit, state, state_custom, Atomicity, Lookahead, MatchDir, ParseResult, ParserState, + StateCheckpoint, }; pub use crate::position::Position; pub use crate::span::{Lines, LinesSpan, Span}; diff --git a/pest/src/parser.rs b/pest/src/parser.rs index 1c83a066..42da15ab 100644 --- a/pest/src/parser.rs +++ b/pest/src/parser.rs @@ -9,11 +9,25 @@ use crate::error::Error; use crate::iterators::Pairs; -use crate::RuleType; +use crate::{RuleType, StateCheckpoint}; /// A trait with a single method that parses strings. -pub trait Parser { +pub trait Parser { /// Parses a `&str` starting from `rule`. #[allow(clippy::perf)] fn parse(rule: R, input: &str) -> Result, Error>; } + +/// A trait with a single method that parses strings. +pub trait StateParser { + /// Parses a `&str` starting from `rule`. + #[allow(clippy::perf)] + fn parse_with_state(rule: R, input: &str, state: S) -> Result<(S, Pairs<'_, R>), Error>; +} + +impl> Parser for T { + fn parse(rule: R, input: &str) -> Result, Error> { + let (_, pairs) = Self::parse_with_state(rule, input, S::default())?; + Ok(pairs) + } +} diff --git a/pest/src/parser_state.rs b/pest/src/parser_state.rs index f58de00c..7edfe115 100644 --- a/pest/src/parser_state.rs +++ b/pest/src/parser_state.rs @@ -126,7 +126,10 @@ impl CallLimitTracker { /// /// [`Parser`]: trait.Parser.html #[derive(Debug)] -pub struct ParserState<'i, R: RuleType> { +pub struct ParserState<'i, R: RuleType, S = ()> +where + S: StateCheckpoint, +{ position: Position<'i>, queue: Vec>, lookahead: Lookahead, @@ -136,6 +139,23 @@ pub struct ParserState<'i, R: RuleType> { atomicity: Atomicity, stack: Stack>, call_tracker: CallLimitTracker, + custom_state: Box, +} + +/// Trait for custom state that can be stored in a [`ParserState`]. +pub trait StateCheckpoint { + /// Saves the current state of the parser. + fn snapshot(&mut self); + /// Clears the saved state of the parser. + fn clear_snapshot(&mut self); + /// Restores the saved state of the parser. + fn restore(&mut self); +} + +impl StateCheckpoint for () { + fn snapshot(&mut self) {} + fn clear_snapshot(&mut self) {} + fn restore(&mut self) {} } /// Creates a `ParserState` from a `&str`, supplying it to a closure `f`. @@ -150,14 +170,40 @@ pub struct ParserState<'i, R: RuleType> { #[allow(clippy::perf)] pub fn state<'i, R: RuleType, F>(input: &'i str, f: F) -> Result, Error> where - F: FnOnce(Box>) -> ParseResult>>, + F: FnOnce(Box>) -> ParseResult>>, +{ + let (_, pairs) = state_custom(input, f, ())?; + Ok(pairs) +} + +/// Creates a `ParserState` from a `&str`, supplying it to a closure `f`. +/// +/// # Examples +/// +/// ``` +/// # use pest; +/// let input = ""; +/// pest::state_custom::<(), _, ()>(input, |s| Ok(s)).unwrap(); +/// ``` +#[allow(clippy::perf)] +pub fn state_custom<'i, R: RuleType, F, S>( + input: &'i str, + f: F, + custom_state: S, +) -> Result<(S, pairs::Pairs<'i, R>), Error> +where + F: FnOnce(Box>) -> ParseResult>>, + S: StateCheckpoint, { - let state = ParserState::new(input); + let state = ParserState::new_with_state(input, custom_state); match f(state) { Ok(state) => { let len = state.queue.len(); - Ok(pairs::new(Rc::new(state.queue), input, None, 0, len)) + Ok(( + *state.custom_state, + pairs::new(Rc::new(state.queue), input, None, 0, len), + )) } Err(mut state) => { let variant = if state.reached_call_limit() { @@ -184,7 +230,10 @@ where } } -impl<'i, R: RuleType> ParserState<'i, R> { +impl<'i, R: RuleType, S> ParserState<'i, R, S> +where + S: Default + StateCheckpoint, +{ /// Allocates a fresh `ParserState` object to the heap and returns the owned `Box`. This `Box` /// will be passed from closure to closure based on the needs of the specified `Parser`. /// @@ -196,6 +245,22 @@ impl<'i, R: RuleType> ParserState<'i, R> { /// let state: Box> = pest::ParserState::new(input); /// ``` pub fn new(input: &'i str) -> Box { + Self::new_with_state(input, S::default()) + } +} + +impl<'i, R: RuleType, S: StateCheckpoint> ParserState<'i, R, S> { + /// Allocates a fresh `ParserState` object to the heap and returns the owned `Box`. This `Box` + /// will be passed from closure to closure based on the needs of the specified `Parser`. + /// + /// # Examples + /// + /// ``` + /// # use pest; + /// let input = ""; + /// let state: Box> = pest::ParserState::new(input); + /// ``` + pub fn new_with_state(input: &'i str, custom_state: S) -> Box { Box::new(ParserState { position: Position::from_start(input), queue: vec![], @@ -206,9 +271,20 @@ impl<'i, R: RuleType> ParserState<'i, R> { atomicity: Atomicity::NonAtomic, stack: Stack::new(), call_tracker: Default::default(), + custom_state: Box::new(custom_state), }) } + /// Get the current custom state. + pub fn state_mut(&mut self) -> &mut S { + &mut self.custom_state + } + + /// Get the current custom state. + pub fn state(&self) -> &S { + &self.custom_state + } + /// Returns a reference to the current `Position` of the `ParserState`. /// /// # Examples @@ -1206,6 +1282,7 @@ impl<'i, R: RuleType> ParserState<'i, R> { #[inline] pub(crate) fn checkpoint(mut self: Box) -> Box { self.stack.snapshot(); + self.custom_state.snapshot(); self } @@ -1214,6 +1291,7 @@ impl<'i, R: RuleType> ParserState<'i, R> { #[inline] pub(crate) fn checkpoint_ok(mut self: Box) -> Box { self.stack.clear_snapshot(); + self.custom_state.clear_snapshot(); self } @@ -1221,6 +1299,7 @@ impl<'i, R: RuleType> ParserState<'i, R> { #[inline] pub(crate) fn restore(mut self: Box) -> Box { self.stack.restore(); + self.custom_state.restore(); self } }