From 0feb657818a514652ee40ccac9a06ef4122a4ca8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Ml=C3=A1dek?= Date: Thu, 2 May 2024 22:52:38 +0200 Subject: [PATCH] feat: percent-decode incoming path before routing --- axum/src/extract/matched_path.rs | 21 ++++++ axum/src/extract/nested_path.rs | 18 +++++ axum/src/extract/path/de.rs | 32 ++++----- axum/src/extract/path/mod.rs | 79 +++++++++++++++------ axum/src/routing/path_router.rs | 18 ++++- axum/src/routing/strip_prefix.rs | 62 ++++++++++++++++- axum/src/routing/tests/mod.rs | 114 +++++++++++++++++++++++++++++++ axum/src/routing/url_params.rs | 30 +++----- axum/src/util.rs | 29 -------- 9 files changed, 311 insertions(+), 92 deletions(-) diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index d51d36c2fe..b7e0791cb6 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -391,4 +391,25 @@ mod tests { let res = client.get("/foo").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); } + + #[crate::test] + async fn matching_braces() { + let app = Router::new().route( + // Double braces are interpreted by matchit as single literal brace + "/{{foo}}", + get(|path: MatchedPath| async move { path.as_str().to_owned() }), + ); + + let client = TestClient::new(app); + + let res = client.get("/{foo}").await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "/{{foo}}"); + + let res = client.get("/foo").await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + let res = client.get("/{{foo}}").await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + } } diff --git a/axum/src/extract/nested_path.rs b/axum/src/extract/nested_path.rs index 72712a4e9a..c326da28e7 100644 --- a/axum/src/extract/nested_path.rs +++ b/axum/src/extract/nested_path.rs @@ -262,4 +262,22 @@ mod tests { let res = client.get("/api/users").await; assert_eq!(res.status(), StatusCode::OK); } + + #[crate::test] + async fn nesting_with_braces() { + let api = Router::new().route( + "/users", + get(|nested_path: NestedPath| { + assert_eq!(nested_path.as_str(), "/{{api}}"); + async {} + }), + ); + + let app = Router::new().nest("/{{api}}", api); + + let client = TestClient::new(app); + + let res = client.get("/{api}/users").await; + assert_eq!(res.status(), StatusCode::OK); + } } diff --git a/axum/src/extract/path/de.rs b/axum/src/extract/path/de.rs index 8ba8a431e9..0b0ab53356 100644 --- a/axum/src/extract/path/de.rs +++ b/axum/src/extract/path/de.rs @@ -1,5 +1,4 @@ use super::{ErrorKind, PathDeserializationError}; -use crate::util::PercentDecodedStr; use serde::{ de::{self, DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor}, forward_to_deserialize_any, Deserializer, @@ -33,7 +32,7 @@ macro_rules! parse_single_value { let value = self.url_params[0].1.parse().map_err(|_| { PathDeserializationError::new(ErrorKind::ParseError { - value: self.url_params[0].1.as_str().to_owned(), + value: self.url_params[0].1.as_ref().to_owned(), expected_type: $ty, }) })?; @@ -43,12 +42,12 @@ macro_rules! parse_single_value { } pub(crate) struct PathDeserializer<'de> { - url_params: &'de [(Arc, PercentDecodedStr)], + url_params: &'de [(Arc, Arc)], } impl<'de> PathDeserializer<'de> { #[inline] - pub(crate) fn new(url_params: &'de [(Arc, PercentDecodedStr)]) -> Self { + pub(crate) fn new(url_params: &'de [(Arc, Arc)]) -> Self { PathDeserializer { url_params } } } @@ -216,9 +215,9 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { } struct MapDeserializer<'de> { - params: &'de [(Arc, PercentDecodedStr)], + params: &'de [(Arc, Arc)], key: Option>, - value: Option<&'de PercentDecodedStr>, + value: Option<&'de Arc>, } impl<'de> MapAccess<'de> for MapDeserializer<'de> { @@ -300,19 +299,19 @@ macro_rules! parse_value { let kind = match key { KeyOrIdx::Key(key) => ErrorKind::ParseErrorAtKey { key: key.to_owned(), - value: self.value.as_str().to_owned(), + value: self.value.as_ref().to_owned(), expected_type: $ty, }, KeyOrIdx::Idx { idx: index, key: _ } => ErrorKind::ParseErrorAtIndex { index, - value: self.value.as_str().to_owned(), + value: self.value.as_ref().to_owned(), expected_type: $ty, }, }; PathDeserializationError::new(kind) } else { PathDeserializationError::new(ErrorKind::ParseError { - value: self.value.as_str().to_owned(), + value: self.value.as_ref().to_owned(), expected_type: $ty, }) } @@ -325,7 +324,7 @@ macro_rules! parse_value { #[derive(Debug)] struct ValueDeserializer<'de> { key: Option>, - value: &'de PercentDecodedStr, + value: &'de Arc, } impl<'de> Deserializer<'de> for ValueDeserializer<'de> { @@ -414,7 +413,7 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> { { struct PairDeserializer<'de> { key: Option>, - value: Option<&'de PercentDecodedStr>, + value: Option<&'de Arc>, } impl<'de> SeqAccess<'de> for PairDeserializer<'de> { @@ -576,7 +575,7 @@ impl<'de> VariantAccess<'de> for UnitVariant { } struct SeqDeserializer<'de> { - params: &'de [(Arc, PercentDecodedStr)], + params: &'de [(Arc, Arc)], idx: usize, } @@ -629,7 +628,7 @@ mod tests { a: i32, } - fn create_url_params(values: I) -> Vec<(Arc, PercentDecodedStr)> + fn create_url_params(values: I) -> Vec<(Arc, Arc)> where I: IntoIterator, K: AsRef, @@ -637,7 +636,7 @@ mod tests { { values .into_iter() - .map(|(k, v)| (Arc::from(k.as_ref()), PercentDecodedStr::new(v).unwrap())) + .map(|(k, v)| (Arc::from(k.as_ref()), Arc::from(v.as_ref()))) .collect() } @@ -669,9 +668,10 @@ mod tests { check_single_value!(f32, "123", 123.0); check_single_value!(f64, "123", 123.0); check_single_value!(String, "abc", "abc"); - check_single_value!(String, "one%20two", "one two"); + check_single_value!(String, "one%20two", "one%20two"); + check_single_value!(String, "one two", "one two"); check_single_value!(&str, "abc", "abc"); - check_single_value!(&str, "one%20two", "one two"); + check_single_value!(&str, "one two", "one two"); check_single_value!(char, "a", 'a'); let url_params = create_url_params(vec![("a", "B")]); diff --git a/axum/src/extract/path/mod.rs b/axum/src/extract/path/mod.rs index 07acf0884a..df3d20443e 100644 --- a/axum/src/extract/path/mod.rs +++ b/axum/src/extract/path/mod.rs @@ -6,7 +6,6 @@ mod de; use crate::{ extract::{rejection::*, FromRequestParts}, routing::url_params::UrlParams, - util::PercentDecodedStr, }; use async_trait::async_trait; use axum_core::response::{IntoResponse, Response}; @@ -156,15 +155,6 @@ where async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let params = match parts.extensions.get::() { Some(UrlParams::Params(params)) => params, - Some(UrlParams::InvalidUtf8InPathParam { key }) => { - let err = PathDeserializationError { - kind: ErrorKind::InvalidUtf8InPathParam { - key: key.to_string(), - }, - }; - let err = FailedToDeserializePathParams(err); - return Err(err.into()); - } None => { return Err(MissingPathParams.into()); } @@ -444,7 +434,7 @@ impl std::error::Error for FailedToDeserializePathParams {} /// # let _: Router = app; /// ``` #[derive(Debug)] -pub struct RawPathParams(Vec<(Arc, PercentDecodedStr)>); +pub struct RawPathParams(Vec<(Arc, Arc)>); #[async_trait] impl FromRequestParts for RawPathParams @@ -456,12 +446,6 @@ where async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let params = match parts.extensions.get::() { Some(UrlParams::Params(params)) => params, - Some(UrlParams::InvalidUtf8InPathParam { key }) => { - return Err(InvalidUtf8InPathParam { - key: Arc::clone(key), - } - .into()); - } None => { return Err(MissingPathParams.into()); } @@ -491,14 +475,14 @@ impl<'a> IntoIterator for &'a RawPathParams { /// /// Created with [`RawPathParams::iter`]. #[derive(Debug)] -pub struct RawPathParamsIter<'a>(std::slice::Iter<'a, (Arc, PercentDecodedStr)>); +pub struct RawPathParamsIter<'a>(std::slice::Iter<'a, (Arc, Arc)>); impl<'a> Iterator for RawPathParamsIter<'a> { type Item = (&'a str, &'a str); fn next(&mut self) -> Option { let (key, value) = self.0.next()?; - Some((&**key, value.as_str())) + Some((&**key, &**value)) } } @@ -890,4 +874,61 @@ mod tests { let body = res.text().await; assert_eq!(body, "a=foo b=bar c=baz"); } + + #[tokio::test] + async fn percent_encoding_path() { + let app = Router::new().route( + "/{capture}", + get(|Path(path): Path| async move { path }), + ); + + let client = TestClient::new(app); + + let res = client.get("/%61pi").await; + assert_eq!(res.status(), StatusCode::OK); + let body = res.text().await; + assert_eq!(body, "api"); + + let res = client.get("/%2561pi").await; + assert_eq!(res.status(), StatusCode::OK); + let body = res.text().await; + assert_eq!(body, "%61pi"); + } + + #[tokio::test] + async fn percent_encoding_slash_in_path() { + let app = Router::new().route( + "/{capture}", + get(|Path(path): Path| async move { path }) + .fallback(|| async { panic!("not matched") }), + ); + + let client = TestClient::new(app); + + // `%2f` decodes to `/` + // Slashes are treated specially in the router + let res = client.get("/%2flash").await; + assert_eq!(res.status(), StatusCode::OK); + let body = res.text().await; + assert_eq!(body, "/lash"); + + let res = client.get("/%2Flash").await; + assert_eq!(res.status(), StatusCode::OK); + let body = res.text().await; + assert_eq!(body, "/lash"); + + // TODO FIXME + // This is not the correct behavior but should be so exceedingly rare that we can live with this for now. + let res = client.get("/%252flash").await; + assert_eq!(res.status(), StatusCode::OK); + let body = res.text().await; + // Should be + // assert_eq!(body, "%2flash"); + assert_eq!(body, "/lash"); + + let res = client.get("/%25252flash").await; + assert_eq!(res.status(), StatusCode::OK); + let body = res.text().await; + assert_eq!(body, "%252flash"); + } } diff --git a/axum/src/routing/path_router.rs b/axum/src/routing/path_router.rs index 8cb8f122dc..835c77f043 100644 --- a/axum/src/routing/path_router.rs +++ b/axum/src/routing/path_router.rs @@ -363,9 +363,21 @@ where } } - let path = req.uri().path().to_owned(); - - match self.node.at(&path) { + // Double encode any percent-encoded `/`s so that they're not + // interpreted by matchit. Additionally, percent-encode `%`s so that we + // can differentiate between `%2f` we have encoded to `%252f` and + // `%252f` the user might have sent us. + let path = req + .uri() + .path() + .replace("%2f", "%252f") + .replace("%2F", "%252F"); + let decode = percent_encoding::percent_decode_str(&path); + + match self.node.at(&decode + .decode_utf8() + .unwrap_or(Cow::Owned(req.uri().path().to_owned()))) + { Ok(match_) => { let id = *match_.value; diff --git a/axum/src/routing/strip_prefix.rs b/axum/src/routing/strip_prefix.rs index 7209607af7..53788f7399 100644 --- a/axum/src/routing/strip_prefix.rs +++ b/axum/src/routing/strip_prefix.rs @@ -1,5 +1,6 @@ use http::{Request, Uri}; use std::{ + borrow::Cow, sync::Arc, task::{Context, Poll}, }; @@ -60,13 +61,13 @@ fn strip_prefix(uri: &Uri, prefix: &str) -> Option { // path = /api/v0/users // ^^^^^^^ this much is matched and the length is 7. let mut matching_prefix_length = Some(0); - for item in zip_longest(segments(path_and_query.path()), segments(prefix)) { + for item in zip_longest(segments(path_and_query.path()), unescaped_segments(prefix)) { // count the `/` *matching_prefix_length.as_mut().unwrap() += 1; match item { Item::Both(path_segment, prefix_segment) => { - if is_capture(prefix_segment) || path_segment == prefix_segment { + if is_capture(&prefix_segment) || path_segment == prefix_segment { // the prefix segment is either a param, which matches anything, or // it actually matches the path segment *matching_prefix_length.as_mut().unwrap() += path_segment.len(); @@ -121,7 +122,7 @@ fn strip_prefix(uri: &Uri, prefix: &str) -> Option { Some(Uri::from_parts(parts).unwrap()) } -fn segments(s: &str) -> impl Iterator { +fn segments(s: &str) -> impl Iterator> { assert!( s.starts_with('/'), "path didn't start with '/'. axum should have caught this higher up." @@ -131,6 +132,19 @@ fn segments(s: &str) -> impl Iterator { // skip one because paths always start with `/` so `/a/b` would become ["", "a", "b"] // otherwise .skip(1) + .map(Cow::Borrowed) +} + +/// This unescapes anything handled specially by `matchit`. +/// Currently, that means only `{{` and `}}` to mean literal `{` and `}` respectively. +fn unescaped_segments(s: &str) -> impl Iterator> { + segments(s).map(|segment| { + if segment.contains("{{") || segment.contains("}}") { + Cow::Owned(segment.replace("{{", "{").replace("}}", "}")) + } else { + segment + } + }) } fn zip_longest(a: I, b: I2) -> impl Iterator> @@ -380,6 +394,48 @@ mod tests { expected = Some("/a"), ); + test!( + braces_1, + uri = "/{a}/a", + prefix = "/{{a}}/", + expected = Some("/a"), + ); + + test!( + braces_2, + uri = "/{a}/b", + prefix = "/{param}", + expected = Some("/b"), + ); + + test!( + braces_3, + uri = "/{a}/{b}", + prefix = "/{{a}}/{{b}}", + expected = Some("/"), + ); + + test!( + braces_4, + uri = "/{a}/{b}", + prefix = "/{{a}}/{b}", + expected = Some("/"), + ); + + test!( + braces_5, + uri = "/a/{b}", + prefix = "/a", + expected = Some("/{b}"), + ); + + test!( + braces_6, + uri = "/a/{b}", + prefix = "/{a}/{{b}}", + expected = Some("/"), + ); + #[quickcheck] fn does_not_panic(uri_and_prefix: UriAndPrefix) -> bool { let UriAndPrefix { uri, prefix } = uri_and_prefix; diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index e3a9d238a7..25c5e4fbf6 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -1118,3 +1118,117 @@ async fn colon_in_route() { async fn asterisk_in_route() { _ = Router::<()>::new().route("/*foo", get(|| async move {})); } + +#[crate::test] +async fn colon_in_route_allowed() { + let app = Router::<()>::new() + .without_v07_checks() + .route("/:foo", get(|| async move {})); + + let client = TestClient::new(app); + + let res = client.get("/:foo").await; + assert_eq!(res.status(), StatusCode::OK); + + let res = client.get("/foo").await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); +} + +#[crate::test] +async fn asterisk_in_route_allowed() { + let app = Router::<()>::new() + .without_v07_checks() + .route("/*foo", get(|| async move {})); + + let client = TestClient::new(app); + + let res = client.get("/*foo").await; + assert_eq!(res.status(), StatusCode::OK); + + let res = client.get("/foo").await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); +} + +#[crate::test] +async fn percent_encoding() { + let app = Router::new().route("/api", get(|| async { "api" })); + + let client = TestClient::new(app); + + let res = client.get("/%61pi").await; + assert_eq!(res.status(), StatusCode::OK); + let body = res.text().await; + assert_eq!(body, "api"); +} + +#[crate::test] +async fn percent_encoding_slash() { + let app = Router::new() + .route("/slash/%2flash", get(|| async { "lower" })) + .route("/slash/%2Flash", get(|| async { "upper" })) + .route("/slash//lash", get(|| async { "/" })) + .route("/api/user", get(|| async { "user" })) + .route( + "/{capture}", + get(|Path(capture): Path| { + assert_eq!(capture, "api/user"); + ready("capture") + }), + ); + + let client = TestClient::new(app); + + // %2f encodes `/` + let res = client.get("/api%2fuser").await; + assert_eq!(res.status(), StatusCode::OK); + let body = res.text().await; + assert_eq!(body, "capture"); + + let res = client.get("/slash/%2flash").await; + assert_eq!(res.status(), StatusCode::OK); + let body = res.text().await; + assert_eq!(body, "lower"); + + let res = client.get("/slash/%2Flash").await; + assert_eq!(res.status(), StatusCode::OK); + let body = res.text().await; + assert_eq!(body, "upper"); + + // `%25` encodes `%` + // This must not be decoded twice + let res = client.get("/slash/%252flash").await; + assert_eq!(res.status(), StatusCode::OK); + let body = res.text().await; + assert_eq!(body, "lower"); + + let res = client.get("/slash/%252Flash").await; + assert_eq!(res.status(), StatusCode::OK); + let body = res.text().await; + assert_eq!(body, "upper"); +} + +#[crate::test] +async fn percent_encoding_percent() { + let app = Router::new() + .route("/%61pi", get(|| async { "percent" })) + .route("/api", get(|| async { "api" })); + + let client = TestClient::new(app); + + let res = client.get("/api").await; + assert_eq!(res.status(), StatusCode::OK); + let body = res.text().await; + assert_eq!(body, "api"); + + let res = client.get("/%61pi").await; + assert_eq!(res.status(), StatusCode::OK); + let body = res.text().await; + assert_eq!(body, "api"); + + // `%25` encodes `%` + // This must not be decoded twice, otherwise it will become `/api` + let res = client.get("/%2561pi").await; + assert_eq!(res.status(), StatusCode::OK); + let body = res.text().await; + assert_eq!(body, "percent"); +} diff --git a/axum/src/routing/url_params.rs b/axum/src/routing/url_params.rs index eb5a08a330..64ca1cd6d4 100644 --- a/axum/src/routing/url_params.rs +++ b/axum/src/routing/url_params.rs @@ -1,46 +1,32 @@ -use crate::util::PercentDecodedStr; use http::Extensions; use matchit::Params; use std::sync::Arc; #[derive(Clone)] pub(crate) enum UrlParams { - Params(Vec<(Arc, PercentDecodedStr)>), - InvalidUtf8InPathParam { key: Arc }, + Params(Vec<(Arc, Arc)>), } pub(super) fn insert_url_params(extensions: &mut Extensions, params: Params) { let current_params = extensions.get_mut(); - if let Some(UrlParams::InvalidUtf8InPathParam { .. }) = current_params { - // nothing to do here since an error was stored earlier - return; - } - let params = params .iter() .filter(|(key, _)| !key.starts_with(super::NEST_TAIL_PARAM)) .filter(|(key, _)| !key.starts_with(super::FALLBACK_PARAM)) .map(|(k, v)| { - if let Some(decoded) = PercentDecodedStr::new(v) { - Ok((Arc::from(k), decoded)) - } else { - Err(Arc::from(k)) - } + ( + Arc::from(k), + Arc::from(v.replace("%2f", "/").replace("%2F", "/")), + ) }) - .collect::, _>>(); + .collect::>(); match (current_params, params) { - (Some(UrlParams::InvalidUtf8InPathParam { .. }), _) => { - unreachable!("we check for this state earlier in this method") - } - (_, Err(invalid_key)) => { - extensions.insert(UrlParams::InvalidUtf8InPathParam { key: invalid_key }); - } - (Some(UrlParams::Params(current)), Ok(params)) => { + (Some(UrlParams::Params(current)), params) => { current.extend(params); } - (None, Ok(params)) => { + (None, params) => { extensions.insert(UrlParams::Params(params)); } } diff --git a/axum/src/util.rs b/axum/src/util.rs index bae803db88..aee7d2d3ad 100644 --- a/axum/src/util.rs +++ b/axum/src/util.rs @@ -1,36 +1,7 @@ use pin_project_lite::pin_project; -use std::{ops::Deref, sync::Arc}; pub(crate) use self::mutex::*; -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub(crate) struct PercentDecodedStr(Arc); - -impl PercentDecodedStr { - pub(crate) fn new(s: S) -> Option - where - S: AsRef, - { - percent_encoding::percent_decode(s.as_ref().as_bytes()) - .decode_utf8() - .ok() - .map(|decoded| Self(decoded.as_ref().into())) - } - - pub(crate) fn as_str(&self) -> &str { - &self.0 - } -} - -impl Deref for PercentDecodedStr { - type Target = str; - - #[inline] - fn deref(&self) -> &Self::Target { - self.as_str() - } -} - pin_project! { #[project = EitherProj] pub(crate) enum Either {