diff --git a/src/classify.rs b/src/classify.rs index 3e41192a4c..7b1bbc324a 100644 --- a/src/classify.rs +++ b/src/classify.rs @@ -6,8 +6,15 @@ use crate::ty::{ReturnType, Type}; use proc_macro2::{Delimiter, TokenStream, TokenTree}; use std::ops::ControlFlow; -pub(crate) fn requires_terminator(expr: &Expr) -> bool { - // see https://github.com/rust-lang/rust/blob/9a19e7604/compiler/rustc_ast/src/util/classify.rs#L7-L26 +#[cfg(feature = "parsing")] +pub(crate) fn requires_semi_to_be_stmt(expr: &Expr) -> bool { + match expr { + Expr::Macro(expr) => !expr.mac.delimiter.is_brace(), + _ => requires_comma_to_be_match_arm(expr), + } +} + +pub(crate) fn requires_comma_to_be_match_arm(expr: &Expr) -> bool { match expr { Expr::If(_) | Expr::Match(_) diff --git a/src/expr.rs b/src/expr.rs index 7d5fa4c238..980a740599 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -2873,7 +2873,7 @@ pub(crate) mod parsing { fat_arrow_token: input.parse()?, body: { let body = Expr::parse_with_earlier_boundary_rule(input)?; - requires_comma = classify::requires_terminator(&body); + requires_comma = classify::requires_comma_to_be_match_arm(&body); Box::new(body) }, comma: { @@ -3312,7 +3312,10 @@ pub(crate) mod printing { // Ensure that we have a comma after a non-block arm, except // for the last one. let is_last = i == self.arms.len() - 1; - if !is_last && classify::requires_terminator(&arm.body) && arm.comma.is_none() { + if !is_last + && classify::requires_comma_to_be_match_arm(&arm.body) + && arm.comma.is_none() + { ::default().to_tokens(tokens); } } diff --git a/src/stmt.rs b/src/stmt.rs index 735947af2f..fa6cf022d4 100644 --- a/src/stmt.rs +++ b/src/stmt.rs @@ -159,7 +159,7 @@ pub(crate) mod parsing { } let stmt = parse_stmt(input, AllowNoSemi(true))?; let requires_semicolon = match &stmt { - Stmt::Expr(stmt, None) => classify::requires_terminator(stmt), + Stmt::Expr(stmt, None) => classify::requires_semi_to_be_stmt(stmt), Stmt::Macro(stmt) => { stmt.semi_token.is_none() && !stmt.mac.delimiter.is_brace() } @@ -401,7 +401,7 @@ pub(crate) mod parsing { if semi_token.is_some() { Ok(Stmt::Expr(e, semi_token)) - } else if allow_nosemi.0 || !classify::requires_terminator(&e) { + } else if allow_nosemi.0 || !classify::requires_semi_to_be_stmt(&e) { Ok(Stmt::Expr(e, None)) } else { Err(input.error("expected semicolon"))