diff --git a/axum/src/docs/routing/route.md b/axum/src/docs/routing/route.md index 528f8e1ab0..db0294688b 100644 --- a/axum/src/docs/routing/route.md +++ b/axum/src/docs/routing/route.md @@ -73,6 +73,23 @@ async fn handler(Path(path): Path) -> String { Note that the leading slash is not included, i.e. for the route `/foo/{*rest}` and the path `/foo/bar/baz` the value of `rest` will be `bar/baz`. +The captured segments can also be extracted as a sequence: + +```rust +use axum::{ + Router, + routing::get, + extract::Path, +}; + +let app: Router = Router::new().route("/files/{*path}", get(handler)); + +async fn handler(Path(segments): Path>) -> String { + segments.join(", ") +} +``` +For the path `/files/foo/bar/baz`, `segments` will be `["foo", "bar", "baz"]`. + # Accepting multiple methods To accept multiple methods for the same route you can add all handlers at the @@ -121,7 +138,8 @@ let app = Router::new() .route("/users", get(list_users).post(create_user)) .route("/users/{id}", get(show_user)) .route("/api/{version}/users/{id}/action", delete(do_users_action)) - .route("/assets/{*path}", get(serve_asset)); + .route("/assets/{*path}", get(serve_asset)) + .route("/batch/{*ids}", get(batch_process)); async fn root() {} @@ -134,6 +152,8 @@ async fn show_user(Path(id): Path) {} async fn do_users_action(Path((version, id)): Path<(String, u64)>) {} async fn serve_asset(Path(path): Path) {} + +async fn batch_process(Path(ids): Path>) {} # let _: Router = app; ``` diff --git a/axum/src/extract/path/de.rs b/axum/src/extract/path/de.rs index 929f0c1f53..b8418d2860 100644 --- a/axum/src/extract/path/de.rs +++ b/axum/src/extract/path/de.rs @@ -4,7 +4,7 @@ use serde_core::{ de::{self, DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor}, forward_to_deserialize_any, Deserializer, }; -use std::{any::type_name, sync::Arc}; +use std::{any::type_name, str::Split, sync::Arc}; macro_rules! unsupported_type { ($trait_fn:ident) => { @@ -144,6 +144,12 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { where V: Visitor<'de>, { + if let [(_, s)] = self.url_params { + return visitor.visit_seq(ValueSeqDeserializer { + split: s.split('/'), + }); + } + visitor.visit_seq(SeqDeserializer { params: self.url_params, idx: 0, @@ -486,13 +492,13 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> { } } - fn deserialize_seq(self, _visitor: V) -> Result + fn deserialize_seq(self, visitor: V) -> Result where V: Visitor<'de>, { - Err(PathDeserializationError::unsupported_type(type_name::< - V::Value, - >())) + visitor.visit_seq(ValueSeqDeserializer { + split: self.value.split('/'), + }) } fn deserialize_tuple_struct( @@ -630,6 +636,124 @@ impl<'de> SeqAccess<'de> for SeqDeserializer<'de> { } } +macro_rules! parse_raw_str { + ($trait_fn:ident, $visit_fn:ident, $ty:literal) => { + fn $trait_fn(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let v = self.value.parse().map_err(|_| { + PathDeserializationError::new(ErrorKind::ParseError { + value: self.value.to_owned(), + expected_type: $ty, + }) + })?; + visitor.$visit_fn(v) + } + }; +} + +#[derive(Debug)] +struct RawStrValueDeserializer<'de> { + value: &'de str, +} + +impl<'de> Deserializer<'de> for RawStrValueDeserializer<'de> { + type Error = PathDeserializationError; + + parse_raw_str!(deserialize_bool, visit_bool, "bool"); + parse_raw_str!(deserialize_i8, visit_i8, "i8"); + parse_raw_str!(deserialize_i16, visit_i16, "i16"); + parse_raw_str!(deserialize_i32, visit_i32, "i32"); + parse_raw_str!(deserialize_i64, visit_i64, "i64"); + parse_raw_str!(deserialize_i128, visit_i128, "i128"); + parse_raw_str!(deserialize_u8, visit_u8, "u8"); + parse_raw_str!(deserialize_u16, visit_u16, "u16"); + parse_raw_str!(deserialize_u32, visit_u32, "u32"); + parse_raw_str!(deserialize_u64, visit_u64, "u64"); + parse_raw_str!(deserialize_u128, visit_u128, "u128"); + parse_raw_str!(deserialize_f32, visit_f32, "f32"); + parse_raw_str!(deserialize_f64, visit_f64, "f64"); + parse_raw_str!(deserialize_string, visit_string, "String"); + parse_raw_str!(deserialize_byte_buf, visit_string, "String"); + parse_raw_str!(deserialize_char, visit_char, "char"); + + fn deserialize_str(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_borrowed_str(self.value) + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_borrowed_bytes(self.value.as_bytes()) + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_enum(EnumDeserializer { value: self.value }) + } + + fn deserialize_any(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(PathDeserializationError::unsupported_type(type_name::< + V::Value, + >())) + } + + forward_to_deserialize_any! { + option unit unit_struct seq tuple identifier + tuple_struct map struct ignored_any + } +} + +struct ValueSeqDeserializer<'de> { + split: Split<'de, char>, +} + +impl<'de> SeqAccess<'de> for ValueSeqDeserializer<'de> { + type Error = PathDeserializationError; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: DeserializeSeed<'de>, + { + for s in self.split.by_ref() { + // skip empty segments from trailing or consecutive slashes + if !s.is_empty() { + return Ok(Some( + seed.deserialize(RawStrValueDeserializer { value: s })?, + )); + } + } + + Ok(None) + } +} + #[derive(Debug, Clone)] enum KeyOrIdx<'de> { Key(&'de str), @@ -821,6 +945,63 @@ mod tests { ); } + #[test] + fn test_parse_seq_wildcard() { + let url_params = create_url_params(vec![("a", "x/y/z")]); + assert_eq!( + >::deserialize(PathDeserializer::new(&url_params)).unwrap(), + ["x".to_owned(), "y".to_owned(), "z".to_owned()] + ); + + let url_params = create_url_params(vec![("a", "1/-2/3")]); + assert_eq!( + >::deserialize(PathDeserializer::new(&url_params)).unwrap(), + [1, -2, 3] + ); + } + + #[test] + fn test_parse_seq_wildcard_empty() { + let url_params = create_url_params(vec![("a", "x")]); + assert_eq!( + >::deserialize(PathDeserializer::new(&url_params)).unwrap(), + ["x".to_owned()] + ); + + let url_params = create_url_params(vec![("a", "x/")]); + assert_eq!( + >::deserialize(PathDeserializer::new(&url_params)).unwrap(), + ["x".to_owned()] + ); + + let url_params = create_url_params(vec![("a", "x///y")]); + assert_eq!( + >::deserialize(PathDeserializer::new(&url_params)).unwrap(), + ["x".to_owned(), "y".to_owned()] + ); + } + + #[test] + fn test_parse_seq_wildcard_multiple_segments() { + let url_params = create_url_params(vec![("a", "test"), ("b", "x")]); + assert_eq!( + <(String, Vec)>::deserialize(PathDeserializer::new(&url_params)).unwrap(), + ("test".to_owned(), vec!["x".to_owned()]) + ); + + let url_params = create_url_params(vec![("a", "test"), ("b", "x/")]); + assert_eq!( + <(String, Vec)>::deserialize(PathDeserializer::new(&url_params)).unwrap(), + ("test".to_owned(), vec!["x".to_owned()]) + ); + + let url_params = create_url_params(vec![("a", "test"), ("b", "x/y")]); + assert_eq!( + <(String, Vec)>::deserialize(PathDeserializer::new(&url_params)).unwrap(), + ("test".to_owned(), vec!["x".to_owned(), "y".to_owned()]) + ); + } + macro_rules! test_parse_error { ( $params:expr, @@ -937,6 +1118,14 @@ mod tests { test_parse_error!( vec![("a", "false")], Vec<(u32, String)>, + ErrorKind::UnsupportedType { + name: "(u32, alloc::string::String)" + } + ); + + test_parse_error!( + vec![("a", "false"), ("b", "true")], + Vec<(u32, String)>, ErrorKind::Message("Unexpected key type".to_owned()) ); } @@ -975,4 +1164,28 @@ mod tests { } ); } + + #[test] + fn test_parse_seq_wildcard_error() { + test_parse_error!( + vec![("a", "1/notanumber/3")], + Vec, + ErrorKind::ParseError { + value: "notanumber".to_owned(), + expected_type: "i32", + } + ); + } + + #[test] + fn test_parse_seq_wildcard_tuple_error() { + test_parse_error!( + vec![("a", "test"), ("b", "x/y")], + (String, Vec), + ErrorKind::ParseError { + value: "x".to_owned(), + expected_type: "i32", + } + ); + } }