diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index 45cb180e76..213562b833 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -15,6 +15,9 @@ mod resource; #[cfg(feature = "typed-routing")] mod typed; +#[cfg(feature = "typed-routing")] +use crate::routing::typed::TypedMethod; + pub use self::resource::Resource; #[cfg(feature = "typed-routing")] @@ -79,6 +82,14 @@ macro_rules! vpath { /// Extension trait that adds additional methods to [`Router`]. #[allow(clippy::return_self_not_must_use)] pub trait RouterExt: sealed::Sealed { + // TODO: comments + #[cfg(feature = "typed-routing")] + fn typed(self, handler: H) -> Self + where + H: axum::handler::Handler, + T: SecondElementIs

+ 'static, + P: TypedMethod; + /// Add a typed `GET` route to the router. /// /// The path will be inferred from the first argument to the handler function which must @@ -240,6 +251,16 @@ impl RouterExt for Router where S: Clone + Send + Sync + 'static, { + #[cfg(feature = "typed-routing")] + fn typed(self, handler: H) -> Self + where + H: axum::handler::Handler, + T: SecondElementIs

+ 'static, + P: TypedMethod, + { + self.route(P::PATH, axum::routing::on(P::METHOD, handler)) + } + #[cfg(feature = "typed-routing")] fn typed_get(self, handler: H) -> Self where diff --git a/axum-extra/src/routing/typed.rs b/axum-extra/src/routing/typed.rs index c0659c0340..bcca1093fa 100644 --- a/axum-extra/src/routing/typed.rs +++ b/axum-extra/src/routing/typed.rs @@ -4,6 +4,11 @@ use super::sealed::Sealed; use http::Uri; use serde_core::Serialize; +// TODO: comments +pub trait TypedMethod: TypedPath { + const METHOD: axum::routing::MethodFilter; +} + /// A type safe path. /// /// This is used to statically connect a path to its corresponding handler using diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index 0d143d79ba..6ac5071ccf 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -17,6 +17,7 @@ mod axum_test; mod debug_handler; mod from_ref; mod from_request; +mod typed_method; mod typed_path; mod with_position; @@ -655,6 +656,16 @@ pub fn __private_axum_test(_attr: TokenStream, input: TokenStream) -> TokenStrea expand_attr_with(_attr, input, axum_test::expand) } +/// Derive an implementation of [`axum_extra::routing::TypedMethod`]. +/// +/// See that trait for more details. +/// +/// [`axum_extra::routing::TypedMethod`]: https://docs.rs/axum-extra/latest/axum_extra/routing/trait.TypedMethod.html +#[proc_macro_derive(TypedMethod, attributes(typed_method))] +pub fn derive_typed_method(input: TokenStream) -> TokenStream { + expand_with(input, |item_struct| typed_method::expand(&item_struct)) +} + /// Derive an implementation of [`axum_extra::routing::TypedPath`]. /// /// See that trait for more details. diff --git a/axum-macros/src/typed_method.rs b/axum-macros/src/typed_method.rs new file mode 100644 index 0000000000..36c191efc7 --- /dev/null +++ b/axum-macros/src/typed_method.rs @@ -0,0 +1,74 @@ +use proc_macro2::{Span, TokenStream}; +use quote::{quote, quote_spanned}; +use syn::{parse::Parse, spanned::Spanned, ItemStruct}; + +use super::attr_parsing::Combine; + +pub(crate) fn expand(item_struct: &ItemStruct) -> syn::Result { + let ItemStruct { + attrs, + ident, + generics, + .. + } = &item_struct; + + if !generics.params.is_empty() || generics.where_clause.is_some() { + return Err(syn::Error::new_spanned( + generics, + "`#[derive(TypedMethod)]` doesn't support generics", + )); + } + + let Attrs { method_filter } = super::attr_parsing::parse_attrs("typed_method", attrs)?; + + let method_filter = method_filter.ok_or_else(|| { + syn::Error::new( + Span::call_site(), + "Missing method filter: `#[typed_method(\"GET\")]`", + ) + })?; + + let typed_path_impl = quote_spanned! {method_filter.span()=> + #[automatically_derived] + impl ::axum_typed_method::TypedMethod for #ident { + const METHOD: ::axum::routing::MethodFilter = #method_filter; + } + }; + + Ok(quote! (#typed_path_impl)) +} + +#[derive(Default)] +struct Attrs { + method_filter: Option, +} + +impl Parse for Attrs { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result { + Ok(Self { + method_filter: Some(input.parse()?), + }) + } +} + +impl Combine for Attrs { + fn combine(mut self, other: Self) -> syn::Result { + let Self { method_filter } = other; + if let Some(method_filter) = method_filter { + if self.method_filter.is_some() { + return Err(syn::Error::new_spanned( + method_filter, + "method filter specified more than once", + )); + } + self.method_filter = Some(method_filter); + } + + Ok(self) + } +} + +#[test] +fn ui() { + crate::run_ui_tests("typed_method"); +}