Skip to content

Commit

Permalink
[feat(udp)] allow select nat type in tproxy connector
Browse files Browse the repository at this point in the history
Signed-off-by: Bearice Ren <[email protected]>
  • Loading branch information
bearice committed Sep 20, 2022
1 parent 49b2c81 commit 679bb15
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 74 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "redproxy-rs"
version = "0.7.0"
version = "0.8.0"
authors = ["Bearice Ren <[email protected]>"]
edition = "2021"
default-run = "redproxy-rs"
Expand Down
6 changes: 6 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ listeners:
type: tproxy
bind: 0.0.0.0:8080
protocol: udp
# default: false, whether use full cone NAT or restrictive NAT,
# full cone NAT is slower and unable to filter with dst address,
# restrictive NAT is a little faster and able to filter with dst address
udp_full_cone: false
# default: 128, only applies for full cone NAT, controls how many sockets to be cached for sending udp replies.
udp_max_socket: 256

- name: udp-reverse
type: reverse
Expand Down
28 changes: 16 additions & 12 deletions src/common/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use async_trait::async_trait;
use std::io::Result as IoResult;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::UdpSocket;
use tokio::{net::UdpSocket, sync::mpsc};

use super::frames::{Frame, FrameIO, FrameReader, FrameWriter};
use crate::context::TargetAddress;
Expand Down Expand Up @@ -75,33 +75,36 @@ pub fn udp_socket(
UdpSocket::from_std(socket)
}

pub type Receiver = mpsc::Receiver<Frame>;
pub type Sender = mpsc::Sender<Frame>;

pub fn setup_udp_session(
target: TargetAddress,
local: SocketAddr,
remote: SocketAddr,
first_frame: Option<Frame>,
extra_frame: Receiver,
transparent: bool,
) -> IoResult<FrameIO> {
let socket = udp_socket(local, Some(remote), transparent)?;
let socket = Arc::new(socket);
Ok((
UdpFrameReader::new(target, socket.clone(), first_frame),
UdpFrameReader::new(target, socket.clone(), extra_frame),
UdpFrameWriter::new(socket),
))
}

struct UdpFrameReader {
socket: Arc<UdpSocket>,
target: TargetAddress,
first_frame: Option<Frame>,
extra_frame: Receiver,
}

impl UdpFrameReader {
fn new(target: TargetAddress, socket: Arc<UdpSocket>, first_frame: Option<Frame>) -> Box<Self> {
fn new(target: TargetAddress, socket: Arc<UdpSocket>, extra_frame: Receiver) -> Box<Self> {
Self {
target,
socket,
first_frame,
extra_frame,
}
.into()
}
Expand All @@ -110,13 +113,14 @@ impl UdpFrameReader {
#[async_trait]
impl FrameReader for UdpFrameReader {
async fn read(&mut self) -> IoResult<Option<Frame>> {
if self.first_frame.is_some() {
return Ok(self.first_frame.take());
}
let mut buf = Frame::new();
buf.recv_from(&self.socket).await?;
buf.addr = Some(self.target.clone());
Ok(Some(buf))
tokio::select! {
Some(f) = self.extra_frame.recv() => Ok(Some(f)),
_ = buf.recv_from(&self.socket) => {
buf.addr = Some(self.target.clone());
Ok(Some(buf))
}
}
}
}

Expand Down
20 changes: 10 additions & 10 deletions src/connectors/direct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,16 @@ struct DirectFrames {
#[async_trait]
impl FrameReader for DirectFrames {
async fn read(&mut self) -> IoResult<Option<Frame>> {
loop {
let mut frame = Frame::new();
let (_, source) = frame.recv_from(&self.socket).await?;
log::trace!("read udp frame: {:?}", frame);
if self.target.ip().is_unspecified() || self.target == source {
return Ok(Some(frame));
} else {
log::debug!("received unexpected udp frame from {:?}, dropping", source)
}
}
// loop {
let mut frame = Frame::new();
let (_, _source) = frame.recv_from(&self.socket).await?;
log::trace!("read udp frame: {:?}", frame);
// if self.target.ip().is_unspecified() || self.target == source {
return Ok(Some(frame));
// } else {
// log::debug!("received unexpected udp frame from {:?}, dropping", source)
// }
// }
}
}

Expand Down
67 changes: 49 additions & 18 deletions src/listeners/reverse.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
use async_trait::async_trait;
use chashmap_async::CHashMap;
use easy_error::{Error, ResultExt};
use log::{debug, error, info};
use serde::{Deserialize, Serialize};
use serde_yaml::Value;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::{TcpListener, UdpSocket};
use tokio::sync::mpsc::Sender;
use tokio::sync::mpsc::{channel, Sender};

use super::Listener;
use crate::common::frames::Frame;
use crate::common::set_keepalive;
use crate::common::udp::{setup_udp_session, udp_socket};
use crate::context::ContextRefOps;
use crate::context::{make_buffered_stream, ContextRef, Feature, TargetAddress};
use crate::common::udp::{self, setup_udp_session, udp_socket};
use crate::context::{make_buffered_stream, Context, ContextRef, Feature, TargetAddress};
use crate::context::{ContextCallback, ContextRefOps};
use crate::GlobalState;

#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize)]
pub struct ReverseProxyListener {
name: String,
bind: SocketAddr,
target: TargetAddress,
#[serde(default = "default_protocol")]
protocol: Protocol,
#[serde(skip)]
sessions: Arc<CHashMap<SocketAddr, udp::Sender>>,
}

#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
Expand Down Expand Up @@ -120,19 +123,47 @@ impl ReverseProxyListener {
let source = crate::common::try_map_v4_addr(source);
debug!("{}: recv from {:?} length: {}", self.name, source, size);

let ctx = state
.contexts
.create_context(self.name.to_owned(), source)
.await;
let frames = setup_udp_session(self.target.clone(), self.bind, source, Some(buf), false)
.context("setup session")?;
ctx.write()
.await
.set_target(self.target.clone())
.set_feature(Feature::UdpForward)
.set_idle_timeout(state.timeouts.udp)
.set_client_frames(frames);
ctx.enqueue(queue).await?;
if let Some(tx) = self.sessions.get(&source).await {
tx.send(buf).await.context("send")?;
} else {
let (tx, rx) = channel(100);
let io = setup_udp_session(self.target.clone(), self.bind, source, rx, false)
.context("setup session")?;
self.sessions.insert(source, tx).await;
let ctx = state
.contexts
.create_context(self.name.to_owned(), source)
.await;
ctx.write()
.await
.set_target(self.target.clone())
.set_feature(Feature::UdpForward)
.set_idle_timeout(state.timeouts.udp)
.set_callback(ReverseCallback::new(source, self.sessions.clone()))
.set_client_frames(io);
ctx.enqueue(queue).await?;
}
Ok(())
}
}

struct ReverseCallback {
client: SocketAddr,
sessions: Arc<CHashMap<SocketAddr, udp::Sender>>,
}

impl ReverseCallback {
fn new(client: SocketAddr, sessions: Arc<CHashMap<SocketAddr, udp::Sender>>) -> Self {
Self { client, sessions }
}
}

#[async_trait]
impl ContextCallback for ReverseCallback {
async fn on_error(&self, _ctx: &mut Context, _error: Error) {
self.sessions.remove(&self.client).await;
}
async fn on_finish(&self, _ctx: &mut Context) {
self.sessions.remove(&self.client).await;
}
}
85 changes: 53 additions & 32 deletions src/listeners/tproxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use crate::{
common::{
frames::{Frame, FrameReader, FrameWriter},
set_keepalive, try_map_v4_addr,
udp::udp_socket,
udp::{setup_udp_session, udp_socket},
},
context::{
make_buffered_stream, Context, ContextCallback, ContextRef, ContextRefOps, Feature,
Expand All @@ -58,6 +58,8 @@ pub struct TProxyListener {
protocol: Protocol,
#[serde(default = "default_max_udp_socket")]
max_udp_socket: usize,
#[serde(default)]
udp_full_cone: bool,
#[serde(skip)]
inner: Option<Arc<Internals>>,
}
Expand All @@ -70,7 +72,7 @@ pub enum Protocol {
}

struct Internals {
sessions: CHashMap<SocketAddr, Session>,
sessions: CHashMap<(SocketAddr, SocketAddr), Session>,
sockets: Mutex<LruCache<SocketAddr, UdpSocket>>,
}

Expand Down Expand Up @@ -185,6 +187,7 @@ impl TProxyListener {
ctx.enqueue(queue).await?;
Ok(())
}

async fn udp_accept(
self: &Arc<Self>,
listener: &TproxyUdpSocket,
Expand All @@ -201,37 +204,54 @@ impl TProxyListener {

trace!("{}: recv from {:?} length: {}", self.name, src, size);
let inner = self.inner.as_ref().unwrap();
if !inner.sessions.contains_key(&src).await {
let (tx, rx) = channel(100);
let r = TproxyReader::new(rx);
let w = TproxyWriter::new(src, inner.clone());
inner.sessions.insert(src, Session::new(src, tx)).await;
let target = if dst.is_ipv4() {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
} else {
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
}
.into();
let key = if self.udp_full_cone {
(src, src)
} else {
(src, dst)
};
if let Some(mut session) = inner.sessions.get_mut(&key).await {
session
.add_frame(buf)
.await
.context("setup session failed")?;
} else {
let ctx = state
.contexts
.create_context(self.name.to_owned(), src)
.await;
let (tx, rx) = channel(100);
inner.sessions.insert(key, Session::new(src, tx)).await;
if self.udp_full_cone {
let r = TproxyReader::new(rx);
let w = TproxyWriter::new(src, inner.clone());
let target = if dst.is_ipv4() {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
} else {
SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
}
.into();
ctx.write()
.await
.set_target(target)
.set_feature(Feature::UdpBind)
.set_extra("udp-bind-source", src)
.set_client_frames((r, w));
} else {
let frames =
setup_udp_session(dst.into(), dst, src, rx, true).context("setup session")?;
ctx.write()
.await
.set_target(dst.into())
.set_feature(Feature::UdpForward)
.set_client_frames(frames);
}
ctx.write()
.await
.set_target(target)
.set_feature(Feature::UdpBind)
.set_idle_timeout(state.timeouts.udp)
.set_callback(TproxyCallback::new(src, inner.clone()))
.set_extra("udp-bind-source", src)
.set_client_frames((r, w));
.set_callback(TproxyCallback::new(key, inner.clone()))
.set_idle_timeout(state.timeouts.udp);
ctx.enqueue(queue).await?;
}
if let Some(mut session) = inner.sessions.get_mut(&src).await {
session
.add_frame(buf)
.await
.context("setup session failed")?;
}

Ok(())
}
}
Expand All @@ -254,7 +274,7 @@ pub fn set_nonblocking(fd: i32) -> std::io::Result<()> {
Ok(())
}

pub struct TproxyUdpSocket {
struct TproxyUdpSocket {
inner: AsyncFd<RawFd>,
}

Expand Down Expand Up @@ -414,28 +434,29 @@ impl FrameWriter for TproxyWriter {
Ok(frame.len())
}
async fn shutdown(&mut self) -> IoResult<()> {
self.inner.sessions.remove(&self.client).await;
let key = (self.client, self.client);
self.inner.sessions.remove(&key).await;
Ok(())
}
}

struct TproxyCallback {
client: SocketAddr,
key: (SocketAddr, SocketAddr),
inner: Arc<Internals>,
}

impl TproxyCallback {
fn new(client: SocketAddr, inner: Arc<Internals>) -> Self {
Self { client, inner }
fn new(key: (SocketAddr, SocketAddr), inner: Arc<Internals>) -> Self {
Self { key, inner }
}
}

#[async_trait]
impl ContextCallback for TproxyCallback {
async fn on_error(&self, _ctx: &mut Context, _error: Error) {
self.inner.sessions.remove(&self.client).await;
self.inner.sessions.remove(&self.key).await;
}
async fn on_finish(&self, _ctx: &mut Context) {
self.inner.sessions.remove(&self.client).await;
self.inner.sessions.remove(&self.key).await;
}
}

0 comments on commit 679bb15

Please sign in to comment.