Skip to content

Commit

Permalink
feat: percent-decode incoming path before routing
Browse files Browse the repository at this point in the history
  • Loading branch information
mladedav committed May 3, 2024
1 parent c18cb84 commit 0feb657
Show file tree
Hide file tree
Showing 9 changed files with 311 additions and 92 deletions.
21 changes: 21 additions & 0 deletions axum/src/extract/matched_path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,4 +391,25 @@ mod tests {
let res = client.get("/foo").await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
}

#[crate::test]
async fn matching_braces() {
let app = Router::new().route(
// Double braces are interpreted by matchit as single literal brace
"/{{foo}}",
get(|path: MatchedPath| async move { path.as_str().to_owned() }),
);

let client = TestClient::new(app);

let res = client.get("/{foo}").await;
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "/{{foo}}");

let res = client.get("/foo").await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);

let res = client.get("/{{foo}}").await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
}
}
18 changes: 18 additions & 0 deletions axum/src/extract/nested_path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,4 +262,22 @@ mod tests {
let res = client.get("/api/users").await;
assert_eq!(res.status(), StatusCode::OK);
}

#[crate::test]
async fn nesting_with_braces() {
let api = Router::new().route(
"/users",
get(|nested_path: NestedPath| {
assert_eq!(nested_path.as_str(), "/{{api}}");
async {}
}),
);

let app = Router::new().nest("/{{api}}", api);

let client = TestClient::new(app);

let res = client.get("/{api}/users").await;
assert_eq!(res.status(), StatusCode::OK);
}
}
32 changes: 16 additions & 16 deletions axum/src/extract/path/de.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use super::{ErrorKind, PathDeserializationError};
use crate::util::PercentDecodedStr;
use serde::{
de::{self, DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor},
forward_to_deserialize_any, Deserializer,
Expand Down Expand Up @@ -33,7 +32,7 @@ macro_rules! parse_single_value {

let value = self.url_params[0].1.parse().map_err(|_| {
PathDeserializationError::new(ErrorKind::ParseError {
value: self.url_params[0].1.as_str().to_owned(),
value: self.url_params[0].1.as_ref().to_owned(),
expected_type: $ty,
})
})?;
Expand All @@ -43,12 +42,12 @@ macro_rules! parse_single_value {
}

pub(crate) struct PathDeserializer<'de> {
url_params: &'de [(Arc<str>, PercentDecodedStr)],
url_params: &'de [(Arc<str>, Arc<str>)],
}

impl<'de> PathDeserializer<'de> {
#[inline]
pub(crate) fn new(url_params: &'de [(Arc<str>, PercentDecodedStr)]) -> Self {
pub(crate) fn new(url_params: &'de [(Arc<str>, Arc<str>)]) -> Self {
PathDeserializer { url_params }
}
}
Expand Down Expand Up @@ -216,9 +215,9 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
}

struct MapDeserializer<'de> {
params: &'de [(Arc<str>, PercentDecodedStr)],
params: &'de [(Arc<str>, Arc<str>)],
key: Option<KeyOrIdx<'de>>,
value: Option<&'de PercentDecodedStr>,
value: Option<&'de Arc<str>>,
}

impl<'de> MapAccess<'de> for MapDeserializer<'de> {
Expand Down Expand Up @@ -300,19 +299,19 @@ macro_rules! parse_value {
let kind = match key {
KeyOrIdx::Key(key) => ErrorKind::ParseErrorAtKey {
key: key.to_owned(),
value: self.value.as_str().to_owned(),
value: self.value.as_ref().to_owned(),
expected_type: $ty,
},
KeyOrIdx::Idx { idx: index, key: _ } => ErrorKind::ParseErrorAtIndex {
index,
value: self.value.as_str().to_owned(),
value: self.value.as_ref().to_owned(),
expected_type: $ty,
},
};
PathDeserializationError::new(kind)
} else {
PathDeserializationError::new(ErrorKind::ParseError {
value: self.value.as_str().to_owned(),
value: self.value.as_ref().to_owned(),
expected_type: $ty,
})
}
Expand All @@ -325,7 +324,7 @@ macro_rules! parse_value {
#[derive(Debug)]
struct ValueDeserializer<'de> {
key: Option<KeyOrIdx<'de>>,
value: &'de PercentDecodedStr,
value: &'de Arc<str>,
}

impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
Expand Down Expand Up @@ -414,7 +413,7 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
{
struct PairDeserializer<'de> {
key: Option<KeyOrIdx<'de>>,
value: Option<&'de PercentDecodedStr>,
value: Option<&'de Arc<str>>,
}

impl<'de> SeqAccess<'de> for PairDeserializer<'de> {
Expand Down Expand Up @@ -576,7 +575,7 @@ impl<'de> VariantAccess<'de> for UnitVariant {
}

struct SeqDeserializer<'de> {
params: &'de [(Arc<str>, PercentDecodedStr)],
params: &'de [(Arc<str>, Arc<str>)],
idx: usize,
}

Expand Down Expand Up @@ -629,15 +628,15 @@ mod tests {
a: i32,
}

fn create_url_params<I, K, V>(values: I) -> Vec<(Arc<str>, PercentDecodedStr)>
fn create_url_params<I, K, V>(values: I) -> Vec<(Arc<str>, Arc<str>)>
where
I: IntoIterator<Item = (K, V)>,
K: AsRef<str>,
V: AsRef<str>,
{
values
.into_iter()
.map(|(k, v)| (Arc::from(k.as_ref()), PercentDecodedStr::new(v).unwrap()))
.map(|(k, v)| (Arc::from(k.as_ref()), Arc::from(v.as_ref())))
.collect()
}

Expand Down Expand Up @@ -669,9 +668,10 @@ mod tests {
check_single_value!(f32, "123", 123.0);
check_single_value!(f64, "123", 123.0);
check_single_value!(String, "abc", "abc");
check_single_value!(String, "one%20two", "one two");
check_single_value!(String, "one%20two", "one%20two");
check_single_value!(String, "one two", "one two");
check_single_value!(&str, "abc", "abc");
check_single_value!(&str, "one%20two", "one two");
check_single_value!(&str, "one two", "one two");
check_single_value!(char, "a", 'a');

let url_params = create_url_params(vec![("a", "B")]);
Expand Down
79 changes: 60 additions & 19 deletions axum/src/extract/path/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ mod de;
use crate::{
extract::{rejection::*, FromRequestParts},
routing::url_params::UrlParams,
util::PercentDecodedStr,
};
use async_trait::async_trait;
use axum_core::response::{IntoResponse, Response};
Expand Down Expand Up @@ -156,15 +155,6 @@ where
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let params = match parts.extensions.get::<UrlParams>() {
Some(UrlParams::Params(params)) => params,
Some(UrlParams::InvalidUtf8InPathParam { key }) => {
let err = PathDeserializationError {
kind: ErrorKind::InvalidUtf8InPathParam {
key: key.to_string(),
},
};
let err = FailedToDeserializePathParams(err);
return Err(err.into());
}
None => {
return Err(MissingPathParams.into());
}
Expand Down Expand Up @@ -444,7 +434,7 @@ impl std::error::Error for FailedToDeserializePathParams {}
/// # let _: Router = app;
/// ```
#[derive(Debug)]
pub struct RawPathParams(Vec<(Arc<str>, PercentDecodedStr)>);
pub struct RawPathParams(Vec<(Arc<str>, Arc<str>)>);

#[async_trait]
impl<S> FromRequestParts<S> for RawPathParams
Expand All @@ -456,12 +446,6 @@ where
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let params = match parts.extensions.get::<UrlParams>() {
Some(UrlParams::Params(params)) => params,
Some(UrlParams::InvalidUtf8InPathParam { key }) => {
return Err(InvalidUtf8InPathParam {
key: Arc::clone(key),
}
.into());
}
None => {
return Err(MissingPathParams.into());
}
Expand Down Expand Up @@ -491,14 +475,14 @@ impl<'a> IntoIterator for &'a RawPathParams {
///
/// Created with [`RawPathParams::iter`].
#[derive(Debug)]
pub struct RawPathParamsIter<'a>(std::slice::Iter<'a, (Arc<str>, PercentDecodedStr)>);
pub struct RawPathParamsIter<'a>(std::slice::Iter<'a, (Arc<str>, Arc<str>)>);

impl<'a> Iterator for RawPathParamsIter<'a> {
type Item = (&'a str, &'a str);

fn next(&mut self) -> Option<Self::Item> {
let (key, value) = self.0.next()?;
Some((&**key, value.as_str()))
Some((&**key, &**value))
}
}

Expand Down Expand Up @@ -890,4 +874,61 @@ mod tests {
let body = res.text().await;
assert_eq!(body, "a=foo b=bar c=baz");
}

#[tokio::test]
async fn percent_encoding_path() {
let app = Router::new().route(
"/{capture}",
get(|Path(path): Path<String>| async move { path }),
);

let client = TestClient::new(app);

let res = client.get("/%61pi").await;
assert_eq!(res.status(), StatusCode::OK);
let body = res.text().await;
assert_eq!(body, "api");

let res = client.get("/%2561pi").await;
assert_eq!(res.status(), StatusCode::OK);
let body = res.text().await;
assert_eq!(body, "%61pi");
}

#[tokio::test]
async fn percent_encoding_slash_in_path() {
let app = Router::new().route(
"/{capture}",
get(|Path(path): Path<String>| async move { path })
.fallback(|| async { panic!("not matched") }),
);

let client = TestClient::new(app);

// `%2f` decodes to `/`
// Slashes are treated specially in the router
let res = client.get("/%2flash").await;
assert_eq!(res.status(), StatusCode::OK);
let body = res.text().await;
assert_eq!(body, "/lash");

let res = client.get("/%2Flash").await;
assert_eq!(res.status(), StatusCode::OK);
let body = res.text().await;
assert_eq!(body, "/lash");

// TODO FIXME
// This is not the correct behavior but should be so exceedingly rare that we can live with this for now.
let res = client.get("/%252flash").await;
assert_eq!(res.status(), StatusCode::OK);
let body = res.text().await;
// Should be
// assert_eq!(body, "%2flash");
assert_eq!(body, "/lash");

let res = client.get("/%25252flash").await;
assert_eq!(res.status(), StatusCode::OK);
let body = res.text().await;
assert_eq!(body, "%252flash");
}
}
18 changes: 15 additions & 3 deletions axum/src/routing/path_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,21 @@ where
}
}

let path = req.uri().path().to_owned();

match self.node.at(&path) {
// Double encode any percent-encoded `/`s so that they're not
// interpreted by matchit. Additionally, percent-encode `%`s so that we
// can differentiate between `%2f` we have encoded to `%252f` and
// `%252f` the user might have sent us.
let path = req
.uri()
.path()
.replace("%2f", "%252f")
.replace("%2F", "%252F");
let decode = percent_encoding::percent_decode_str(&path);

match self.node.at(&decode
.decode_utf8()
.unwrap_or(Cow::Owned(req.uri().path().to_owned())))
{
Ok(match_) => {
let id = *match_.value;

Expand Down
Loading

0 comments on commit 0feb657

Please sign in to comment.