diff --git a/russh/examples/echoserver.rs b/russh/examples/echoserver.rs index 2e76faa2..91a6180c 100644 --- a/russh/examples/echoserver.rs +++ b/russh/examples/echoserver.rs @@ -76,6 +76,7 @@ impl server::Server for Server { impl server::Handler for Server { type Error = russh::Error; + type Data = (); async fn channel_open_session( &mut self, diff --git a/russh/examples/ratatui_app.rs b/russh/examples/ratatui_app.rs index f0ad38ed..49b29acc 100644 --- a/russh/examples/ratatui_app.rs +++ b/russh/examples/ratatui_app.rs @@ -143,6 +143,7 @@ impl Server for AppServer { impl Handler for AppServer { type Error = anyhow::Error; + type Data = (); async fn channel_open_session( &mut self, diff --git a/russh/examples/ratatui_shared_app.rs b/russh/examples/ratatui_shared_app.rs index d27daa69..a97793ce 100644 --- a/russh/examples/ratatui_shared_app.rs +++ b/russh/examples/ratatui_shared_app.rs @@ -145,6 +145,7 @@ impl Server for AppServer { impl Handler for AppServer { type Error = anyhow::Error; + type Data = (); async fn channel_open_session( &mut self, diff --git a/russh/examples/sftp_server.rs b/russh/examples/sftp_server.rs index f5b09768..6d9435f3 100644 --- a/russh/examples/sftp_server.rs +++ b/russh/examples/sftp_server.rs @@ -42,6 +42,7 @@ impl SshSession { impl russh::server::Handler for SshSession { type Error = anyhow::Error; + type Data = (); async fn auth_password(&mut self, user: &str, password: &str) -> Result { info!("credentials: {user}, {password}"); diff --git a/russh/examples/test.rs b/russh/examples/test.rs index be139da0..2c59d9e3 100644 --- a/russh/examples/test.rs +++ b/russh/examples/test.rs @@ -48,6 +48,7 @@ impl server::Server for Server { impl server::Handler for Server { type Error = anyhow::Error; + type Data = (); async fn channel_open_session( &mut self, diff --git a/russh/src/client/test.rs b/russh/src/client/test.rs index 566f898c..487c145c 100644 --- a/russh/src/client/test.rs +++ b/russh/src/client/test.rs @@ -33,6 +33,7 @@ mod tests { impl ServerHandler for TestServer { type Error = Error; + type Data = (); async fn channel_open_session( &mut self, diff --git a/russh/src/server/mod.rs b/russh/src/server/mod.rs index b6a1a2d9..a2141d7d 100644 --- a/russh/src/server/mod.rs +++ b/russh/src/server/mod.rs @@ -209,6 +209,7 @@ impl Auth { #[cfg_attr(feature = "async-trait", async_trait::async_trait)] pub trait Handler: Sized { type Error: From + Send; + type Data: Send; /// Check authentication using the "none" method. Russh makes /// sure rejection happens in time `config.auth_rejection_time`, @@ -792,6 +793,58 @@ pub trait Handler: Sized { Ok(Some(best_group.clone())) } } + + /// Called when the handler needs to be updated. + /// ['trigger'] should be used with ['process'] + /// + /// # Cancel safety + /// + /// The safety of this method depends entirely on how you implement it; + /// it provides no inherent security guarantees. + /// + /// # Example + /// + /// ``` + /// use tokio::sync::mpsc::Receiver; + /// use russh::server::{Handler, Session}; + /// + /// struct App{ + /// foo: String, + /// recv: Receiver, + /// trigger: Receiver, + /// } + /// + /// impl Handler for App { + /// type Error = russh::Error; + /// type Data = String; + /// async fn trigger(&mut self) -> Result { + /// match self.trigger.recv().await { + /// Some(d) => Ok(d), + /// None => std::future::pending().await, + /// } + /// } + /// + /// async fn process(&mut self, s: Self::Data, session: &mut Session) -> Result<(), Self::Error> { + /// let s = self.recv.recv().await.unwrap(); + /// self.foo = s; + /// Ok(()) + /// } + /// } + /// ``` + /// + fn trigger(&mut self) -> impl Future> + Send { + std::future::pending() + } + + /// Called after [`trigger`], See [`trigger`] for more. + #[allow(unused_variables)] + fn process( + &mut self, + data: Self::Data, + session: &mut Session, + ) -> impl Future> + Send { + async { Ok(()) } + } } pub struct RunningServerHandle { diff --git a/russh/src/server/session.rs b/russh/src/server/session.rs index e20d18ff..48912e0f 100644 --- a/russh/src/server/session.rs +++ b/russh/src/server/session.rs @@ -542,6 +542,13 @@ impl Session { } reading.set(start_reading(stream_read, buffer, opening_cipher)); } + t = handler.trigger() => { + debug!("handler trigger is invoked"); + match t { + Ok(d) => handler.process(d,&mut self).await?, + Err(e) => return Err(e) + } + } () = &mut keepalive_timer => { self.common.alive_timeouts = self.common.alive_timeouts.saturating_add(1); if self.common.config.keepalive_max != 0 && self.common.alive_timeouts > self.common.config.keepalive_max { diff --git a/russh/src/tests.rs b/russh/src/tests.rs index 6241f4c4..a7e402ef 100644 --- a/russh/src/tests.rs +++ b/russh/src/tests.rs @@ -91,6 +91,7 @@ mod compress { impl server::Handler for Server { type Error = super::Error; + type Data = (); async fn channel_open_session( &mut self, @@ -256,6 +257,7 @@ mod channels { impl server::Handler for ServerHandle { type Error = crate::Error; + type Data = (); async fn auth_publickey( &mut self, @@ -327,6 +329,7 @@ mod channels { impl server::Handler for ServerHandle { type Error = crate::Error; + type Data = (); async fn auth_publickey( &mut self, @@ -411,6 +414,7 @@ mod channels { impl server::Handler for ServerHandle { type Error = crate::Error; + type Data = (); async fn auth_publickey( &mut self, @@ -496,6 +500,7 @@ mod channels { impl server::Handler for ServerHandle { type Error = crate::Error; + type Data = (); async fn auth_publickey( &mut self, @@ -615,5 +620,6 @@ mod server_kex_junk { impl server::Handler for Server { type Error = super::Error; + type Data = (); } } diff --git a/russh/tests/test_backpressure.rs b/russh/tests/test_backpressure.rs index c9b9946c..4c4cf2f2 100644 --- a/russh/tests/test_backpressure.rs +++ b/russh/tests/test_backpressure.rs @@ -118,6 +118,7 @@ impl russh::server::Server for Server { impl russh::server::Handler for Server { type Error = anyhow::Error; + type Data = (); async fn auth_publickey( &mut self, diff --git a/russh/tests/test_data_stream.rs b/russh/tests/test_data_stream.rs index feb65c22..a067b4bd 100644 --- a/russh/tests/test_data_stream.rs +++ b/russh/tests/test_data_stream.rs @@ -187,6 +187,7 @@ impl russh::server::Server for Server { impl russh::server::Handler for Server { type Error = anyhow::Error; + type Data = (); async fn auth_publickey( &mut self, diff --git a/russh/tests/test_kex_shared_secret.rs b/russh/tests/test_kex_shared_secret.rs index 37d06658..bf41323d 100644 --- a/russh/tests/test_kex_shared_secret.rs +++ b/russh/tests/test_kex_shared_secret.rs @@ -292,6 +292,7 @@ struct TestServer {} impl server::Handler for TestServer { type Error = russh::Error; + type Data = (); async fn auth_publickey( &mut self, diff --git a/russh/tests/test_mlkem_kex.rs b/russh/tests/test_mlkem_kex.rs index c2f9b0cd..65dc7f98 100644 --- a/russh/tests/test_mlkem_kex.rs +++ b/russh/tests/test_mlkem_kex.rs @@ -334,6 +334,7 @@ struct TestServer {} impl server::Handler for TestServer { type Error = russh::Error; + type Data = (); async fn auth_publickey( &mut self, diff --git a/russh/tests/test_rekey_strict_kex.rs b/russh/tests/test_rekey_strict_kex.rs index a30802f3..173dcdf4 100644 --- a/russh/tests/test_rekey_strict_kex.rs +++ b/russh/tests/test_rekey_strict_kex.rs @@ -117,6 +117,7 @@ struct TestServer {} // Insecure server that accepts any public key and echos back data it receives; ONLY FOR TESTS impl server::Handler for TestServer { type Error = russh::Error; + type Data = (); async fn auth_publickey( &mut self,