diff --git a/tracing-attributes/src/attr.rs b/tracing-attributes/src/attr.rs index 902ca1450b..89dfb9445a 100644 --- a/tracing-attributes/src/attr.rs +++ b/tracing-attributes/src/attr.rs @@ -1,4 +1,5 @@ use std::collections::HashSet; +use syn::parse::discouraged::Speculative; use syn::{punctuated::Punctuated, Expr, Ident, LitInt, LitStr, Path, Token}; use proc_macro2::TokenStream; @@ -268,7 +269,7 @@ pub(crate) struct Fields(pub(crate) Punctuated); #[derive(Clone, Debug)] pub(crate) struct Field { - pub(crate) name: Punctuated, + pub(crate) name: FieldName, pub(crate) value: Option, pub(crate) kind: FieldKind, } @@ -306,7 +307,7 @@ impl Parse for Field { input.parse::()?; kind = FieldKind::Debug; }; - let name = Punctuated::parse_separated_nonempty_with(input, Ident::parse_any)?; + let name = FieldName::parse(input)?; let value = if input.peek(Token![=]) { input.parse::()?; if input.peek(Token![%]) { @@ -357,6 +358,37 @@ impl ToTokens for FieldKind { } } +#[derive(Clone, Debug)] +pub(crate) enum FieldName { + Ident(Punctuated), + Literal(LitStr), +} + +impl Parse for FieldName { + fn parse(input: ParseStream<'_>) -> syn::Result { + let ahead = input.fork(); + if let Ok(ident) = Punctuated::parse_separated_nonempty_with(&ahead, Ident::parse_any) { + input.advance_to(&ahead); + return Ok(Self::Ident(ident)); + } + if let Ok(lit) = input.parse::() { + return Ok(Self::Literal(lit)); + } + Err(ahead.error( + "expected \"field.name\" (string literal) or field.name (punctuated identifier)", + )) + } +} + +impl ToTokens for FieldName { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + FieldName::Ident(ident) => ident.to_tokens(tokens), + FieldName::Literal(lit) => lit.to_tokens(tokens), + } + } +} + #[derive(Clone, Debug)] pub(crate) enum Level { Trace, diff --git a/tracing-attributes/src/expand.rs b/tracing-attributes/src/expand.rs index b52cb12aba..46d4055035 100644 --- a/tracing-attributes/src/expand.rs +++ b/tracing-attributes/src/expand.rs @@ -2,6 +2,7 @@ use std::iter; use proc_macro2::TokenStream; use quote::{quote, quote_spanned, ToTokens}; +use syn::parse::Parse; use syn::visit_mut::VisitMut; use syn::{ punctuated::Punctuated, spanned::Spanned, Block, Expr, ExprAsync, ExprCall, FieldPat, FnArg, @@ -10,7 +11,7 @@ use syn::{ }; use crate::{ - attr::{Field, Fields, FormatMode, InstrumentArgs, Level}, + attr::{Field, FieldName, Fields, FormatMode, InstrumentArgs, Level}, MaybeItemFn, MaybeItemFnRef, }; @@ -189,9 +190,16 @@ fn gen_block( // If any parameters have the same name as a custom field, skip // and allow them to be formatted by the custom field. if let Some(ref fields) = args.fields { - fields.0.iter().all(|Field { ref name, .. }| { - let first = name.first(); - first != name.last() || !first.iter().any(|name| name == ¶m) + fields.0.iter().all(|Field { ref name, .. }| match name { + FieldName::Ident(name) => { + let first = name.first(); + first != name.last() || !first.iter().any(|name| name == ¶m) + } + FieldName::Literal(name) => { + // If the literal string would be a valid ident, apply the same overwrite logic. + let literal_ident = name.parse_with(syn::Ident::parse); + !literal_ident.iter().any(|name| name == param) + } }) } else { true diff --git a/tracing-attributes/tests/instrument.rs b/tracing-attributes/tests/instrument.rs index 957567dcf9..e65d07e831 100644 --- a/tracing-attributes/tests/instrument.rs +++ b/tracing-attributes/tests/instrument.rs @@ -239,3 +239,32 @@ fn impl_trait_return_type() { handle.assert_finished(); } + +#[test] +fn keywords_in_fields() { + #[instrument(fields("d.type" = "test", "x" = ?x))] + fn my_fn(x: u64) { + tracing::event!(Level::TRACE, "r.type" = "test", "event name"); + } + + let span = expect::span().named("my_fn"); + + let (subscriber, handle) = collector::mock() + .new_span( + span.clone().with_fields( + expect::field("d.type") + .with_value(&"test") + .and(expect::field("x").with_value(&tracing::field::display(10))), + ), + ) + .enter(span.clone()) + .event(expect::event().with_fields(expect::field("r.type").with_value(&"test"))) + .exit(span.clone()) + .drop_span(span) + .only() + .run_with_handle(); + + with_default(subscriber, || my_fn(10)); + + handle.assert_finished(); +} diff --git a/tracing-attributes/tests/ui.rs b/tracing-attributes/tests/ui.rs index 73d7fdcef8..3b974cb7fe 100644 --- a/tracing-attributes/tests/ui.rs +++ b/tracing-attributes/tests/ui.rs @@ -12,3 +12,10 @@ fn const_instrument() { let t = trybuild::TestCases::new(); t.compile_fail("tests/ui/const_instrument.rs"); } + +#[rustversion::stable] +#[test] +fn invalid_keyword_instrument() { + let t = trybuild::TestCases::new(); + t.compile_fail("tests/ui/invalid_keyword_instrument.rs"); +} diff --git a/tracing-attributes/tests/ui/invalid_keyword_instrument.rs b/tracing-attributes/tests/ui/invalid_keyword_instrument.rs new file mode 100644 index 0000000000..2f146ab313 --- /dev/null +++ b/tracing-attributes/tests/ui/invalid_keyword_instrument.rs @@ -0,0 +1,6 @@ +#![allow(unreachable_code)] + +#[tracing::instrument(level = "trace", fields(() = "test"))] +fn test_fn() -> &'static str {} + +fn main() {} diff --git a/tracing-attributes/tests/ui/invalid_keyword_instrument.stderr b/tracing-attributes/tests/ui/invalid_keyword_instrument.stderr new file mode 100644 index 0000000000..0bcd89adba --- /dev/null +++ b/tracing-attributes/tests/ui/invalid_keyword_instrument.stderr @@ -0,0 +1,5 @@ +error: expected "field.name" (string literal) or field.name (punctuated identifier) + --> tests/ui/invalid_keyword_instrument.rs:3:47 + | +3 | #[tracing::instrument(level = "trace", fields(() = "test"))] + | ^