diff --git a/crates/core/src/handler.rs b/crates/core/src/handler.rs index ae5129ae7..cb6db08df 100644 --- a/crates/core/src/handler.rs +++ b/crates/core/src/handler.rs @@ -117,6 +117,8 @@ //! } //! } //! ``` +use std::sync::Arc; + use crate::http::StatusCode; use crate::{async_trait, Depot, FlowCtrl, Request, Response}; @@ -160,6 +162,11 @@ pub struct WhenHoop { pub inner: H, pub filter: F, } +impl WhenHoop { + pub fn new(inner: H, filter: F) -> Self { + Self { inner, filter } + } +} #[async_trait] impl Handler for WhenHoop where @@ -189,6 +196,61 @@ where } } +/// Handler that wrap [`Handler`] to let it use middlwares. +#[non_exhaustive] +pub struct HoopedHandler { + inner: Arc, + hoops: Vec>, +} +impl HoopedHandler { + /// Create new `HoopedHandler`. + pub fn new(inner: H) -> Self { + Self { + inner: Arc::new(inner), + hoops: vec![], + } + } + + /// Get current catcher's middlewares reference. + #[inline] + pub fn hoops(&self) -> &Vec> { + &self.hoops + } + /// Get current catcher's middlewares mutable reference. + #[inline] + pub fn hoops_mut(&mut self) -> &mut Vec> { + &mut self.hoops + } + + /// Add a handler as middleware, it will run the handler when error catched. + #[inline] + pub fn hoop(mut self, hoop: H) -> Self { + self.hoops.push(Arc::new(hoop)); + self + } + + /// Add a handler as middleware, it will run the handler when error catched. + /// + /// This middleware only effective when the filter return true. + #[inline] + pub fn hoop_when(mut self, hoop: H, filter: F) -> Self + where + H: Handler, + F: Fn(&Request, &Depot) -> bool + Send + Sync + 'static, + { + self.hoops.push(Arc::new(WhenHoop::new(hoop, filter))); + self + } +} +#[async_trait] +impl Handler for HoopedHandler { + async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) { + let inner: Arc = self.inner.clone(); + ctrl.handlers.extend(self.hoops.iter().chain([&inner]).cloned()); + ctrl.call_next(req, depot, res).await; + } +} + /// `none_skipper` will skipper nothing. /// /// It can be used as default `Skipper` in middleware. diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 7ebc4312b..c3cea5c8f 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -166,6 +166,7 @@ where } } +#[doc(hidden)] #[macro_export] macro_rules! for_each_tuple { ($callback:ident) => { diff --git a/crates/serve-static/src/dir.rs b/crates/serve-static/src/dir.rs index cb093d106..cd65ac5ed 100644 --- a/crates/serve-static/src/dir.rs +++ b/crates/serve-static/src/dir.rs @@ -9,10 +9,11 @@ use std::str::FromStr; use std::time::SystemTime; use salvo_core::fs::NamedFile; +use salvo_core::handler::{Handler, HoopedHandler}; use salvo_core::http::header::ACCEPT_ENCODING; use salvo_core::http::{self, HeaderValue, Request, Response, StatusCode, StatusError}; use salvo_core::writing::Text; -use salvo_core::{async_trait, Depot, FlowCtrl, Handler, IntoVecString}; +use salvo_core::{async_trait, Depot, FlowCtrl, IntoVecString}; use serde::{Deserialize, Serialize}; use serde_json::json; use time::{macros::format_description, OffsetDateTime}; @@ -157,7 +158,7 @@ impl StaticDir { compressed_variations.insert(CompressionAlgo::Gzip, vec!["gz".to_owned()]); compressed_variations.insert(CompressionAlgo::Deflate, vec!["deflate".to_owned()]); - StaticDir { + Self { roots: roots.collect(), chunk_size: None, include_dot_files: false, @@ -240,6 +241,30 @@ impl StaticDir { } false } + + /// Wrap to `HoopedHandler`. + #[inline] + pub fn hooped(self) -> HoopedHandler { + HoopedHandler::new(self) + } + + /// Add a handler as middleware, it will run the handler when error catched. + #[inline] + pub fn hoop(self, hoop: H) -> HoopedHandler { + HoopedHandler::new(self).hoop(hoop) + } + + /// Add a handler as middleware, it will run the handler when error catched. + /// + /// This middleware only effective when the filter return true. + #[inline] + pub fn hoop_when(self, hoop: H, filter: F) -> HoopedHandler + where + H: Handler, + F: Fn(&Request, &Depot) -> bool + Send + Sync + 'static, + { + HoopedHandler::new(self).hoop_when(hoop, filter) + } } #[derive(Serialize, Deserialize, Debug)] struct CurrentInfo { diff --git a/crates/serve-static/src/embed.rs b/crates/serve-static/src/embed.rs index c7895d679..db88403ce 100644 --- a/crates/serve-static/src/embed.rs +++ b/crates/serve-static/src/embed.rs @@ -4,7 +4,8 @@ use std::marker::PhantomData; use rust_embed::{EmbeddedFile, Metadata, RustEmbed}; use salvo_core::http::header::{CONTENT_TYPE, ETAG, IF_NONE_MATCH}; use salvo_core::http::{HeaderValue, Mime, Request, Response, StatusCode}; -use salvo_core::{async_trait, Depot, FlowCtrl, Handler, IntoVecString}; +use salvo_core::{async_trait, Depot, FlowCtrl, IntoVecString}; +use salvo_core::handler::{HoopedHandler, Handler}; use super::{decode_url_path_safely, format_url_path_safely, join_path, redirect_to_dir_url}; @@ -107,6 +108,30 @@ where self.fallback = Some(fallback.into()); self } + + /// Wrap to `HoopedHandler`. + #[inline] + pub fn hooped(self) -> HoopedHandler { + HoopedHandler::new(self) + } + + /// Add a handler as middleware, it will run the handler when error catched. + #[inline] + pub fn hoop(self, hoop: H) -> HoopedHandler { + HoopedHandler::new(self).hoop(hoop) + } + + /// Add a handler as middleware, it will run the handler when error catched. + /// + /// This middleware only effective when the filter return true. + #[inline] + pub fn hoop_when(self, hoop: H, filter: F) -> HoopedHandler + where + H: Handler, + F: Fn(&Request, &Depot) -> bool + Send + Sync + 'static, + { + HoopedHandler::new(self).hoop_when(hoop, filter) + } } #[async_trait] impl Handler for StaticEmbed