diff --git a/crates/core/src/http/request.rs b/crates/core/src/http/request.rs index d13dbbdcf..507fe6f1a 100644 --- a/crates/core/src/http/request.rs +++ b/crates/core/src/http/request.rs @@ -14,7 +14,6 @@ pub use http::request::Parts; use http::uri::{Scheme, Uri}; use http::Extensions; use http_body_util::{BodyExt, Limited}; -use indexmap::IndexMap; use multimap::MultiMap; use parking_lot::RwLock; use serde::de::Deserialize; @@ -25,6 +24,7 @@ use crate::fuse::TransProto; use crate::http::body::ReqBody; use crate::http::form::{FilePart, FormData}; use crate::http::{Mime, ParseError, Version}; +use crate::routing::PathParams; use crate::serde::{from_request, from_str_map, from_str_multi_map, from_str_multi_val, from_str_val}; use crate::Error; @@ -61,7 +61,7 @@ pub struct Request { #[cfg(feature = "cookie")] pub(crate) cookies: CookieJar, - pub(crate) params: IndexMap, + pub(crate) params: PathParams, // accept: Option>, pub(crate) queries: OnceLock>, @@ -110,7 +110,7 @@ impl Request { method: Method::default(), #[cfg(feature = "cookie")] cookies: CookieJar::default(), - params: IndexMap::new(), + params: PathParams::new(), queries: OnceLock::new(), form_data: tokio::sync::OnceCell::new(), payload: tokio::sync::OnceCell::new(), @@ -171,7 +171,7 @@ impl Request { #[cfg(feature = "cookie")] cookies, // accept: None, - params: IndexMap::new(), + params: PathParams::new(), form_data: tokio::sync::OnceCell::new(), payload: tokio::sync::OnceCell::new(), // multipart: OnceLock::new(), @@ -567,12 +567,12 @@ impl Request { } /// Get params reference. #[inline] - pub fn params(&self) -> &IndexMap { + pub fn params(&self) -> &PathParams { &self.params } /// Get params mutable reference. #[inline] - pub fn params_mut(&mut self) -> &mut IndexMap { + pub fn params_mut(&mut self) -> &mut PathParams { &mut self.params } diff --git a/crates/core/src/routing/filters/path.rs b/crates/core/src/routing/filters/path.rs index e4f4aa1e4..6b0559132 100644 --- a/crates/core/src/routing/filters/path.rs +++ b/crates/core/src/routing/filters/path.rs @@ -254,13 +254,13 @@ impl PathWisp for CharsWisp { } if chars.len() == max_width { state.forward(max_width); - state.params.insert(self.name.clone(), chars.into_iter().collect()); + state.params.insert(&self.name, chars.into_iter().collect()); return true; } } if chars.len() >= self.min_width { state.forward(chars.len()); - state.params.insert(self.name.clone(), chars.into_iter().collect()); + state.params.insert(&self.name, chars.into_iter().collect()); true } else { false @@ -274,7 +274,7 @@ impl PathWisp for CharsWisp { } if chars.len() >= self.min_width { state.forward(chars.len()); - state.params.insert(self.name.clone(), chars.into_iter().collect()); + state.params.insert(&self.name, chars.into_iter().collect()); true } else { false @@ -298,7 +298,7 @@ impl CombWisp { impl PathWisp for CombWisp { #[inline] fn detect<'a>(&self, state: &mut PathState) -> bool { - let mut offline = if let Some(part) = state.parts.get_mut(state.cursor.0) { + let mut offline = if let Some(part) = state.parts.get(state.cursor.0) { part.clone() } else { return false; @@ -403,7 +403,7 @@ impl PathWisp for NamedWisp { } if !rest.is_empty() || !self.0.starts_with("*+") { let rest = rest.to_string(); - state.params.insert(self.0.clone(), rest); + state.params.insert(&self.0, rest); state.cursor.0 = state.parts.len(); true } else { @@ -416,7 +416,7 @@ impl PathWisp for NamedWisp { } let picked = picked.expect("picked should not be `None`").to_owned(); state.forward(picked.len()); - state.params.insert(self.0.clone(), picked); + state.params.insert(&self.0, picked); true } } @@ -456,7 +456,7 @@ impl PathWisp for RegexWisp { if let Some(cap) = cap { let cap = cap.as_str().to_owned(); state.forward(cap.len()); - state.params.insert(self.name.clone(), cap); + state.params.insert(&self.name, cap); true } else { false @@ -472,7 +472,7 @@ impl PathWisp for RegexWisp { if let Some(cap) = cap { let cap = cap.as_str().to_owned(); state.forward(cap.len()); - state.params.insert(self.name.clone(), cap); + state.params.insert(&self.name, cap); true } else { false @@ -930,11 +930,6 @@ mod tests { let segments = PathParser::new("/").parse().unwrap(); assert!(segments.is_empty()); } - #[test] - fn test_parse_rest_without_name() { - let segments = PathParser::new("/hello/<**>").parse().unwrap(); - assert_eq!(format!("{:?}", segments), r#"[ConstWisp("hello"), NamedWisp("**")]"#); - } #[test] fn test_parse_single_const() { diff --git a/crates/core/src/routing/mod.rs b/crates/core/src/routing/mod.rs index e5f146102..cbafb092e 100644 --- a/crates/core/src/routing/mod.rs +++ b/crates/core/src/routing/mod.rs @@ -301,7 +301,7 @@ //! //! #[handler] //! fn serve_file(req: &mut Request) { -//! let rest_path = req.param::("**rest_path"); +//! let rest_path = req.param::("rest_path"); //! } //! ``` //! @@ -375,6 +375,7 @@ mod router; pub use router::Router; use std::borrow::Cow; +use std::ops::Deref; use std::sync::Arc; use indexmap::IndexMap; @@ -388,8 +389,57 @@ pub struct DetectMatched { pub goal: Arc, } -#[doc(hidden)] -pub type PathParams = IndexMap; +/// The path parameters. +#[derive(Clone, Default, Debug, Eq, PartialEq)] +pub struct PathParams { + inner: IndexMap, + greedy: bool, +} +impl Deref for PathParams { + type Target = IndexMap; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} +impl PathParams { + /// Create new `PathParams`. + pub fn new() -> Self { + PathParams::default() + } + /// If there is a wildcard param, it's value is `true`. + pub fn greedy(&self) -> bool { + self.greedy + } + /// Get the last param starts with '*', for example: <**rest>, <*?rest>. + pub fn tail(&self) -> Option<&str> { + if self.greedy { + self.inner.last().map(|(_, v)| &**v) + } else { + None + } + } + + /// Insert new param. + pub fn insert(&mut self, name: &str, value: String) { + #[cfg(debug_assertions)] + { + if self.greedy { + panic!("only one wildcard param is allowed and it must be the last one."); + } + } + if name.starts_with("*+") || name.starts_with("*?") || name.starts_with("**") { + self.inner.insert(name[2..].to_owned(), value); + self.greedy = true; + } else if let Some(name) = name.strip_prefix('*') { + self.inner.insert(name.to_owned(), value); + self.greedy = true; + } else { + self.inner.insert(name.to_owned(), value); + } + } +} + #[doc(hidden)] #[derive(Clone, Debug, Eq, PartialEq)] pub struct PathState { diff --git a/crates/core/src/routing/router.rs b/crates/core/src/routing/router.rs index 5915b2bb6..03b4d3495 100644 --- a/crates/core/src/routing/router.rs +++ b/crates/core/src/routing/router.rs @@ -354,11 +354,11 @@ impl fmt::Debug for Router { } else { format!("{prefix}{SYMBOL_TEE}{SYMBOL_RIGHT}{SYMBOL_RIGHT}") }; - let hd = if let Some(goal) = &router.goal { - format!(" -> {}", goal.type_name()) - } else { - "".into() - }; + let hd = router + .goal + .as_ref() + .map(|goal| format!(" -> {}", goal.type_name())) + .unwrap_or_default(); if !others.is_empty() { writeln!(f, "{cp}{path}[{}]{hd}", others.join(","))?; } else { diff --git a/crates/core/src/serde/request.rs b/crates/core/src/serde/request.rs index 79752a7b0..5c18c6fa8 100644 --- a/crates/core/src/serde/request.rs +++ b/crates/core/src/serde/request.rs @@ -206,16 +206,16 @@ impl<'de> RequestDeserializer<'de> { return false; }; - let field_name: Cow<'_, str> = if let Some(rename) = field.rename { - Cow::from(rename) + let field_name = if let Some(rename) = field.rename { + rename } else if let Some(serde_rename) = field.serde_rename { - Cow::from(serde_rename) + serde_rename } else if let Some(rename_all) = self.metadata.rename_all { - rename_all.apply_to_field(field.decl_name).into() + &*rename_all.apply_to_field(field.decl_name) } else if let Some(serde_rename_all) = self.metadata.serde_rename_all { - serde_rename_all.apply_to_field(field.decl_name).into() + &*serde_rename_all.apply_to_field(field.decl_name) } else { - field.decl_name.into() + field.decl_name }; for source in sources { @@ -237,7 +237,7 @@ impl<'de> RequestDeserializer<'de> { } } SourceFrom::Query => { - let mut value = self.queries.get_vec(field_name.as_ref()); + let mut value = self.queries.get_vec(field_name); if value.is_none() { for alias in &field.aliases { value = self.queries.get_vec(*alias); @@ -254,8 +254,8 @@ impl<'de> RequestDeserializer<'de> { } SourceFrom::Header => { let mut value = None; - if self.headers.contains_key(field_name.as_ref()) { - value = Some(self.headers.get_all(field_name.as_ref())) + if self.headers.contains_key(field_name) { + value = Some(self.headers.get_all(field_name)) } else { for alias in &field.aliases { if self.headers.contains_key(*alias) { @@ -301,7 +301,7 @@ impl<'de> RequestDeserializer<'de> { if let Some(payload) = &self.payload { match payload { Payload::FormData(form_data) => { - let mut value = form_data.fields.get(field_name.as_ref()); + let mut value = form_data.fields.get(field_name); if value.is_none() { for alias in &field.aliases { value = form_data.fields.get(*alias); @@ -318,7 +318,7 @@ impl<'de> RequestDeserializer<'de> { return false; } Payload::JsonMap(ref map) => { - let mut value = map.get(field_name.as_ref()); + let mut value = map.get(field_name); if value.is_none() { for alias in &field.aliases { value = map.get(alias); @@ -346,7 +346,7 @@ impl<'de> RequestDeserializer<'de> { } SourceParser::MultiMap => { if let Some(Payload::FormData(form_data)) = self.payload { - let mut value = form_data.fields.get_vec(field_name.as_ref()); + let mut value = form_data.fields.get_vec(field_name); if value.is_none() { for alias in &field.aliases { value = form_data.fields.get_vec(*alias); diff --git a/crates/oapi/src/extract/parameter/path.rs b/crates/oapi/src/extract/parameter/path.rs index f921b14cb..399f5984c 100644 --- a/crates/oapi/src/extract/parameter/path.rs +++ b/crates/oapi/src/extract/parameter/path.rs @@ -158,7 +158,7 @@ mod tests { let req = TestClient::get("http://127.0.0.1:5801").build_hyper(); let schema = req.uri().scheme().cloned().unwrap(); let mut req = Request::from_hyper(req, schema); - req.params_mut().insert("param".to_string(), "param".to_string()); + req.params_mut().insert("param", "param".to_string()); let result = PathParam::::extract_with_arg(&mut req, "param").await; assert_eq!(result.unwrap().0, "param"); } diff --git a/crates/oapi/src/swagger_ui/mod.rs b/crates/oapi/src/swagger_ui/mod.rs index d21950574..fcb9d278b 100644 --- a/crates/oapi/src/swagger_ui/mod.rs +++ b/crates/oapi/src/swagger_ui/mod.rs @@ -233,12 +233,16 @@ pub(crate) fn redirect_to_dir_url(req_uri: &Uri, res: &mut Response) { #[async_trait] impl Handler for SwaggerUi { async fn handle(&self, req: &mut Request, _depot: &mut Depot, res: &mut Response, _ctrl: &mut FlowCtrl) { - let path = req.params().get("**").map(|s| &**s).unwrap_or_default(); - // Redirect to dir url if path is empty and not end with '/' - if path.is_empty() && !req.uri().path().ends_with('/') { + // Redirect to dir url if path is not end with '/' + if !req.uri().path().ends_with('/') { redirect_to_dir_url(req.uri(), res); return; } + let Some(path) = req.params().tail() else { + res.render(StatusError::not_found().detail("The router params is incorrect. The params should ending with a wildcard.")); + return; + }; + let keywords = self .keywords .as_ref() diff --git a/crates/proxy/src/lib.rs b/crates/proxy/src/lib.rs index 85864ebba..9fb25b868 100644 --- a/crates/proxy/src/lib.rs +++ b/crates/proxy/src/lib.rs @@ -132,14 +132,12 @@ where /// Url part getter. You can use this to get the proxied url path or query. pub type UrlPartGetter = Box Option + Send + Sync + 'static>; -/// Default url path getter. This getter will get the url path from request wildcard param, like `<**rest>`, `<*+rest>`. +/// Default url path getter. +/// +/// This getter will get the last param as the rest url path from request. +/// In most case you should use wildcard param, like `<**rest>`, `<*+rest>`. pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option { - let param = req.params().iter().find(|(key, _)| key.starts_with('*')); - if let Some((_, rest)) = param { - Some(encode_url_path(rest)) - } else { - None - } + req.params().tail().map(encode_url_path) } /// Default url query getter. This getter just return the query string from request uri. pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option { diff --git a/crates/serve-static/src/dir.rs b/crates/serve-static/src/dir.rs index 2b44f5a77..29cbde2be 100644 --- a/crates/serve-static/src/dir.rs +++ b/crates/serve-static/src/dir.rs @@ -287,14 +287,13 @@ impl DirInfo { #[async_trait] impl Handler for StaticDir { async fn handle(&self, req: &mut Request, _depot: &mut Depot, res: &mut Response, _ctrl: &mut FlowCtrl) { - let param = req.params().iter().find(|(key, _)| key.starts_with('*')); let req_path = req.uri().path(); - let rel_path = if let Some((_, value)) = param { - value.clone() + let rel_path = if let Some(rest) = req.params().tail() { + rest } else { - decode_url_path_safely(req_path) + &*decode_url_path_safely(req_path) }; - let rel_path = format_url_path_safely(&rel_path); + let rel_path = format_url_path_safely(rel_path); let mut files: HashMap = HashMap::new(); let mut dirs: HashMap = HashMap::new(); let is_dot_file = Path::new(&rel_path) diff --git a/crates/serve-static/src/embed.rs b/crates/serve-static/src/embed.rs index a7b96bad7..c7895d679 100644 --- a/crates/serve-static/src/embed.rs +++ b/crates/serve-static/src/embed.rs @@ -114,11 +114,10 @@ where T: RustEmbed + Send + Sync + 'static, { async fn handle(&self, req: &mut Request, _depot: &mut Depot, res: &mut Response, _ctrl: &mut FlowCtrl) { - let param = req.params().iter().find(|(key, _)| key.starts_with('*')); - let req_path = if let Some((_, value)) = param { - value.clone() + let req_path = if let Some(rest) = req.params().tail() { + rest } else { - decode_url_path_safely(req.uri().path()) + &*decode_url_path_safely(req.uri().path()) }; let req_path = format_url_path_safely(&req_path); let mut key_path = Cow::Borrowed(&*req_path); diff --git a/crates/serve-static/src/lib.rs b/crates/serve-static/src/lib.rs index f9f6cbc55..09b7d8497 100644 --- a/crates/serve-static/src/lib.rs +++ b/crates/serve-static/src/lib.rs @@ -190,21 +190,21 @@ mod tests { let router = Router::new() .push(Router::with_path("test1.txt").get(Assets::get("test1.txt").unwrap().into_handler())) - .push(Router::with_path("files/<*path>").get(serve_file)) + .push(Router::with_path("files/<**path>").get(serve_file)) .push( - Router::with_path("dir/<*path>").get( + Router::with_path("dir/<**path>").get( static_embed::() .defaults("index.html") .fallback("fallback.html"), ), ) - .push(Router::with_path("dir2/<*path>").get(static_embed::())) - .push(Router::with_path("dir3/<*path>").get(static_embed::().fallback("notexist.html"))); + .push(Router::with_path("dir2/<**path>").get(static_embed::())) + .push(Router::with_path("dir3/<**path>").get(static_embed::().fallback("notexist.html"))); let service = Service::new(router); #[handler] async fn serve_file(req: &mut Request, res: &mut Response) { - let path = req.param::("*path").unwrap(); + let path = req.param::("path").unwrap(); if let Some(file) = Assets::get(&path) { file.render(req, res); } diff --git a/examples/static-embed-file/src/main.rs b/examples/static-embed-file/src/main.rs index 5ecdbf0c5..6c58bc91a 100644 --- a/examples/static-embed-file/src/main.rs +++ b/examples/static-embed-file/src/main.rs @@ -10,7 +10,7 @@ struct Assets; async fn main() { tracing_subscriber::fmt().init(); - let router = Router::with_path("<*path>").get(serve_file); + let router = Router::with_path("<**rest>").get(serve_file); let acceptor = TcpListener::new("0.0.0.0:5800").bind().await; Server::new(acceptor).serve(router).await; @@ -18,7 +18,7 @@ async fn main() { #[handler] async fn serve_file(req: &mut Request, res: &mut Response) { - let path = req.param::("*path").unwrap(); + let path = req.param::("rest").unwrap(); if let Some(file) = Assets::get(&path) { file.render(req, res); } else {