diff --git a/Cargo.lock b/Cargo.lock index 9d96292..ad78b97 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1153,7 +1153,7 @@ dependencies = [ [[package]] name = "redproxy-rs" -version = "0.7.0" +version = "0.8.0" dependencies = [ "async-trait", "axum", diff --git a/Cargo.toml b/Cargo.toml index 801df43..4cff0f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "redproxy-rs" -version = "0.7.0" +version = "0.8.0" authors = ["Bearice Ren "] edition = "2021" default-run = "redproxy-rs" diff --git a/config.yaml b/config.yaml index c6fbd3b..a768229 100644 --- a/config.yaml +++ b/config.yaml @@ -24,6 +24,12 @@ listeners: type: tproxy bind: 0.0.0.0:8080 protocol: udp + # default: false, whether use full cone NAT or restrictive NAT, + # full cone NAT is slower and unable to filter with dst address, + # restrictive NAT is a little faster and able to filter with dst address + udp_full_cone: false + # default: 128, only applies for full cone NAT, controls how many sockets to be cached for sending udp replies. + udp_max_socket: 256 - name: udp-reverse type: reverse diff --git a/src/common/udp.rs b/src/common/udp.rs index 5653875..f285b23 100644 --- a/src/common/udp.rs +++ b/src/common/udp.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use std::io::Result as IoResult; use std::net::SocketAddr; use std::sync::Arc; -use tokio::net::UdpSocket; +use tokio::{net::UdpSocket, sync::mpsc}; use super::frames::{Frame, FrameIO, FrameReader, FrameWriter}; use crate::context::TargetAddress; @@ -75,17 +75,20 @@ pub fn udp_socket( UdpSocket::from_std(socket) } +pub type Receiver = mpsc::Receiver; +pub type Sender = mpsc::Sender; + pub fn setup_udp_session( target: TargetAddress, local: SocketAddr, remote: SocketAddr, - first_frame: Option, + extra_frame: Receiver, transparent: bool, ) -> IoResult { let socket = udp_socket(local, Some(remote), transparent)?; let socket = Arc::new(socket); Ok(( - UdpFrameReader::new(target, socket.clone(), first_frame), + UdpFrameReader::new(target, socket.clone(), extra_frame), UdpFrameWriter::new(socket), )) } @@ -93,15 +96,15 @@ pub fn setup_udp_session( struct UdpFrameReader { socket: Arc, target: TargetAddress, - first_frame: Option, + extra_frame: Receiver, } impl UdpFrameReader { - fn new(target: TargetAddress, socket: Arc, first_frame: Option) -> Box { + fn new(target: TargetAddress, socket: Arc, extra_frame: Receiver) -> Box { Self { target, socket, - first_frame, + extra_frame, } .into() } @@ -110,13 +113,14 @@ impl UdpFrameReader { #[async_trait] impl FrameReader for UdpFrameReader { async fn read(&mut self) -> IoResult> { - if self.first_frame.is_some() { - return Ok(self.first_frame.take()); - } let mut buf = Frame::new(); - buf.recv_from(&self.socket).await?; - buf.addr = Some(self.target.clone()); - Ok(Some(buf)) + tokio::select! { + Some(f) = self.extra_frame.recv() => Ok(Some(f)), + _ = buf.recv_from(&self.socket) => { + buf.addr = Some(self.target.clone()); + Ok(Some(buf)) + } + } } } diff --git a/src/connectors/direct.rs b/src/connectors/direct.rs index e348037..e98dbf4 100644 --- a/src/connectors/direct.rs +++ b/src/connectors/direct.rs @@ -177,16 +177,16 @@ struct DirectFrames { #[async_trait] impl FrameReader for DirectFrames { async fn read(&mut self) -> IoResult> { - loop { - let mut frame = Frame::new(); - let (_, source) = frame.recv_from(&self.socket).await?; - log::trace!("read udp frame: {:?}", frame); - if self.target.ip().is_unspecified() || self.target == source { - return Ok(Some(frame)); - } else { - log::debug!("received unexpected udp frame from {:?}, dropping", source) - } - } + // loop { + let mut frame = Frame::new(); + let (_, _source) = frame.recv_from(&self.socket).await?; + log::trace!("read udp frame: {:?}", frame); + // if self.target.ip().is_unspecified() || self.target == source { + return Ok(Some(frame)); + // } else { + // log::debug!("received unexpected udp frame from {:?}, dropping", source) + // } + // } } } diff --git a/src/listeners/reverse.rs b/src/listeners/reverse.rs index c08b619..d4016d8 100644 --- a/src/listeners/reverse.rs +++ b/src/listeners/reverse.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use chashmap_async::CHashMap; use easy_error::{Error, ResultExt}; use log::{debug, error, info}; use serde::{Deserialize, Serialize}; @@ -6,23 +7,25 @@ use serde_yaml::Value; use std::net::SocketAddr; use std::sync::Arc; use tokio::net::{TcpListener, UdpSocket}; -use tokio::sync::mpsc::Sender; +use tokio::sync::mpsc::{channel, Sender}; use super::Listener; use crate::common::frames::Frame; use crate::common::set_keepalive; -use crate::common::udp::{setup_udp_session, udp_socket}; -use crate::context::ContextRefOps; -use crate::context::{make_buffered_stream, ContextRef, Feature, TargetAddress}; +use crate::common::udp::{self, setup_udp_session, udp_socket}; +use crate::context::{make_buffered_stream, Context, ContextRef, Feature, TargetAddress}; +use crate::context::{ContextCallback, ContextRefOps}; use crate::GlobalState; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize)] pub struct ReverseProxyListener { name: String, bind: SocketAddr, target: TargetAddress, #[serde(default = "default_protocol")] protocol: Protocol, + #[serde(skip)] + sessions: Arc>, } #[derive(Serialize, Deserialize, Debug, Clone, Copy)] @@ -120,19 +123,47 @@ impl ReverseProxyListener { let source = crate::common::try_map_v4_addr(source); debug!("{}: recv from {:?} length: {}", self.name, source, size); - let ctx = state - .contexts - .create_context(self.name.to_owned(), source) - .await; - let frames = setup_udp_session(self.target.clone(), self.bind, source, Some(buf), false) - .context("setup session")?; - ctx.write() - .await - .set_target(self.target.clone()) - .set_feature(Feature::UdpForward) - .set_idle_timeout(state.timeouts.udp) - .set_client_frames(frames); - ctx.enqueue(queue).await?; + if let Some(tx) = self.sessions.get(&source).await { + tx.send(buf).await.context("send")?; + } else { + let (tx, rx) = channel(100); + let io = setup_udp_session(self.target.clone(), self.bind, source, rx, false) + .context("setup session")?; + self.sessions.insert(source, tx).await; + let ctx = state + .contexts + .create_context(self.name.to_owned(), source) + .await; + ctx.write() + .await + .set_target(self.target.clone()) + .set_feature(Feature::UdpForward) + .set_idle_timeout(state.timeouts.udp) + .set_callback(ReverseCallback::new(source, self.sessions.clone())) + .set_client_frames(io); + ctx.enqueue(queue).await?; + } Ok(()) } } + +struct ReverseCallback { + client: SocketAddr, + sessions: Arc>, +} + +impl ReverseCallback { + fn new(client: SocketAddr, sessions: Arc>) -> Self { + Self { client, sessions } + } +} + +#[async_trait] +impl ContextCallback for ReverseCallback { + async fn on_error(&self, _ctx: &mut Context, _error: Error) { + self.sessions.remove(&self.client).await; + } + async fn on_finish(&self, _ctx: &mut Context) { + self.sessions.remove(&self.client).await; + } +} diff --git a/src/listeners/tproxy.rs b/src/listeners/tproxy.rs index 5ae9eb8..c34f444 100644 --- a/src/listeners/tproxy.rs +++ b/src/listeners/tproxy.rs @@ -39,7 +39,7 @@ use crate::{ common::{ frames::{Frame, FrameReader, FrameWriter}, set_keepalive, try_map_v4_addr, - udp::udp_socket, + udp::{setup_udp_session, udp_socket}, }, context::{ make_buffered_stream, Context, ContextCallback, ContextRef, ContextRefOps, Feature, @@ -58,6 +58,8 @@ pub struct TProxyListener { protocol: Protocol, #[serde(default = "default_max_udp_socket")] max_udp_socket: usize, + #[serde(default)] + udp_full_cone: bool, #[serde(skip)] inner: Option>, } @@ -70,7 +72,7 @@ pub enum Protocol { } struct Internals { - sessions: CHashMap, + sessions: CHashMap<(SocketAddr, SocketAddr), Session>, sockets: Mutex>, } @@ -185,6 +187,7 @@ impl TProxyListener { ctx.enqueue(queue).await?; Ok(()) } + async fn udp_accept( self: &Arc, listener: &TproxyUdpSocket, @@ -201,37 +204,54 @@ impl TProxyListener { trace!("{}: recv from {:?} length: {}", self.name, src, size); let inner = self.inner.as_ref().unwrap(); - if !inner.sessions.contains_key(&src).await { - let (tx, rx) = channel(100); - let r = TproxyReader::new(rx); - let w = TproxyWriter::new(src, inner.clone()); - inner.sessions.insert(src, Session::new(src, tx)).await; - let target = if dst.is_ipv4() { - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)) - } else { - SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)) - } - .into(); + let key = if self.udp_full_cone { + (src, src) + } else { + (src, dst) + }; + if let Some(mut session) = inner.sessions.get_mut(&key).await { + session + .add_frame(buf) + .await + .context("setup session failed")?; + } else { let ctx = state .contexts .create_context(self.name.to_owned(), src) .await; + let (tx, rx) = channel(100); + inner.sessions.insert(key, Session::new(src, tx)).await; + if self.udp_full_cone { + let r = TproxyReader::new(rx); + let w = TproxyWriter::new(src, inner.clone()); + let target = if dst.is_ipv4() { + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)) + } else { + SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)) + } + .into(); + ctx.write() + .await + .set_target(target) + .set_feature(Feature::UdpBind) + .set_extra("udp-bind-source", src) + .set_client_frames((r, w)); + } else { + let frames = + setup_udp_session(dst.into(), dst, src, rx, true).context("setup session")?; + ctx.write() + .await + .set_target(dst.into()) + .set_feature(Feature::UdpForward) + .set_client_frames(frames); + } ctx.write() .await - .set_target(target) - .set_feature(Feature::UdpBind) - .set_idle_timeout(state.timeouts.udp) - .set_callback(TproxyCallback::new(src, inner.clone())) - .set_extra("udp-bind-source", src) - .set_client_frames((r, w)); + .set_callback(TproxyCallback::new(key, inner.clone())) + .set_idle_timeout(state.timeouts.udp); ctx.enqueue(queue).await?; } - if let Some(mut session) = inner.sessions.get_mut(&src).await { - session - .add_frame(buf) - .await - .context("setup session failed")?; - } + Ok(()) } } @@ -254,7 +274,7 @@ pub fn set_nonblocking(fd: i32) -> std::io::Result<()> { Ok(()) } -pub struct TproxyUdpSocket { +struct TproxyUdpSocket { inner: AsyncFd, } @@ -414,28 +434,29 @@ impl FrameWriter for TproxyWriter { Ok(frame.len()) } async fn shutdown(&mut self) -> IoResult<()> { - self.inner.sessions.remove(&self.client).await; + let key = (self.client, self.client); + self.inner.sessions.remove(&key).await; Ok(()) } } struct TproxyCallback { - client: SocketAddr, + key: (SocketAddr, SocketAddr), inner: Arc, } impl TproxyCallback { - fn new(client: SocketAddr, inner: Arc) -> Self { - Self { client, inner } + fn new(key: (SocketAddr, SocketAddr), inner: Arc) -> Self { + Self { key, inner } } } #[async_trait] impl ContextCallback for TproxyCallback { async fn on_error(&self, _ctx: &mut Context, _error: Error) { - self.inner.sessions.remove(&self.client).await; + self.inner.sessions.remove(&self.key).await; } async fn on_finish(&self, _ctx: &mut Context) { - self.inner.sessions.remove(&self.client).await; + self.inner.sessions.remove(&self.key).await; } }