From d0cb6da9f3b22559404b94cfe905c3b3fb4b06d0 Mon Sep 17 00:00:00 2001 From: gngpp Date: Fri, 3 May 2024 20:20:09 +0800 Subject: [PATCH] feat(socks5): Support binding IP-CDIR when connecting (#29) * refactor(serve): Refactor the validation module * feat(socks5): Support binding IP-CDIR when connecting * Update --- src/daemon.rs | 8 +-- src/proxy/connect.rs | 141 ++++++++++++++++++++++++++++++++++++++++ src/proxy/http/mod.rs | 109 +++---------------------------- src/proxy/mod.rs | 19 +++--- src/proxy/socks5/mod.rs | 39 +++++------ 5 files changed, 181 insertions(+), 135 deletions(-) create mode 100644 src/proxy/connect.rs diff --git a/src/daemon.rs b/src/daemon.rs index 5a7f068..e5f3c7b 100644 --- a/src/daemon.rs +++ b/src/daemon.rs @@ -7,15 +7,15 @@ use std::{ }; #[cfg(target_family = "unix")] -pub(crate) const PID_PATH: &str = "/var/run/vproxy.pid"; +const PID_PATH: &str = "/var/run/vproxy.pid"; #[cfg(target_family = "unix")] -pub(crate) const DEFAULT_STDOUT_PATH: &str = "/var/run/vproxy.out"; +const DEFAULT_STDOUT_PATH: &str = "/var/run/vproxy.out"; #[cfg(target_family = "unix")] -pub(crate) const DEFAULT_STDERR_PATH: &str = "/var/run/vproxy.err"; +const DEFAULT_STDERR_PATH: &str = "/var/run/vproxy.err"; /// Get the pid of the daemon #[cfg(target_family = "unix")] -pub(crate) fn get_pid() -> Option { +fn get_pid() -> Option { if let Ok(data) = std::fs::read(PID_PATH) { let binding = String::from_utf8(data).expect("pid file is not utf8"); return Some(binding.trim().to_string()); diff --git a/src/proxy/connect.rs b/src/proxy/connect.rs new file mode 100644 index 0000000..7c9f64a --- /dev/null +++ b/src/proxy/connect.rs @@ -0,0 +1,141 @@ +use cidr::Ipv6Cidr; +use hyper_util::client::legacy::connect::HttpConnector; +use rand::Rng; +use std::net::{IpAddr, Ipv6Addr, SocketAddr}; +use tokio::net::{lookup_host, TcpSocket, TcpStream}; + +#[derive(Clone)] +pub struct Connector { + cidr: Option, + fallback: Option, +} + +impl Connector { + pub(super) fn new(cidr: Option, fallback: Option) -> Self { + Connector { cidr, fallback } + } + + pub fn new_http_connector(&self) -> HttpConnector { + let mut connector = HttpConnector::new(); + + match (self.cidr, self.fallback) { + (Some(v6), Some(IpAddr::V4(v4))) => { + let v6 = get_rand_ipv6(v6.first_address().into(), v6.network_length()); + connector.set_local_addresses(v4, v6); + } + (Some(v6), None) => { + let v6 = get_rand_ipv6(v6.first_address().into(), v6.network_length()); + connector.set_local_address(Some(v6.into())); + } + // ipv4 or ipv6 + (None, Some(ip)) => connector.set_local_address(Some(ip)), + _ => {} + } + + connector + } + + /// Attempts to establish a connection to a given SocketAddr. + /// If an IPv6 subnet and a fallback IP are provided, it will attempt to + /// connect using them. If no IPv6 subnet is provided but a fallback IP + /// is, it will attempt to connect using the fallback IP. If neither are + /// provided, it will attempt to connect directly to the given SocketAddr. + pub async fn try_connect(&self, addr: SocketAddr) -> std::io::Result { + match (self.cidr, self.fallback) { + (Some(ipv6_cidr), ip_addr) => { + try_connect_with_ipv6_and_fallback(addr, ipv6_cidr, ip_addr).await + } + (None, Some(ip)) => try_connect_with_fallback(addr, ip).await, + _ => TcpStream::connect(addr).await, + } + } + + /// Attempts to establish a connection to a given domain and port. + /// It first resolves the domain, then tries to connect to each resolved + /// address, until it successfully connects to an address or has tried + /// all addresses. If all connection attempts fail, it will return the + /// error from the last attempt. If no connection attempts were made, it + /// will return a new `Error` object. + pub async fn try_connect_for_domain( + &self, + domain: String, + port: u16, + ) -> std::io::Result { + let mut last_err = None; + + for target_addr in lookup_host((domain, port)).await? { + match self.try_connect(target_addr).await { + Ok(stream) => return Ok(stream), + Err(e) => last_err = Some(e), + }; + } + + match last_err { + Some(e) => Err(e), + None => Err(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + "Failed to connect to any resolved address", + )), + } + } +} + +/// Try to connect with ipv6 and fallback to ipv4/ipv6 +async fn try_connect_with_ipv6_and_fallback( + target_addr: SocketAddr, + cidr: Ipv6Cidr, + fallback: Option, +) -> std::io::Result { + let socket = TcpSocket::new_v6()?; + let bind_addr = SocketAddr::new( + get_rand_ipv6(cidr.first_address().into(), cidr.network_length()).into(), + 0, + ); + socket.bind(bind_addr)?; + + // Try to connect with ipv6 + match socket.connect(target_addr).await { + Ok(first) => Ok(first), + Err(err) => { + tracing::debug!("try connect with ipv6 failed: {}", err); + if let Some(ip) = fallback { + // Try to connect with fallback ip (ipv4 or ipv6) + let socket = create_socket_for_ip(ip)?; + let bind_addr = SocketAddr::new(ip, 0); + socket.bind(bind_addr)?; + socket.connect(target_addr).await + } else { + // Try to connect with system default ip + TcpStream::connect(target_addr).await + } + } + } +} + +/// Try to connect with fallback to ipv4/ipv6 +async fn try_connect_with_fallback( + target_addr: SocketAddr, + ip: IpAddr, +) -> std::io::Result { + let socket = create_socket_for_ip(ip)?; + let bind_addr = SocketAddr::new(ip, 0); + socket.bind(bind_addr)?; + socket.connect(target_addr).await +} + +/// Create a socket for ip +fn create_socket_for_ip(ip: IpAddr) -> std::io::Result { + match ip { + IpAddr::V4(_) => TcpSocket::new_v4(), + IpAddr::V6(_) => TcpSocket::new_v6(), + } +} + +/// Get a random ipv6 address +fn get_rand_ipv6(mut ipv6: u128, prefix_len: u8) -> Ipv6Addr { + let rand: u128 = rand::thread_rng().gen(); + let net_part = (ipv6 >> (128 - prefix_len)) << (128 - prefix_len); + let host_part = (rand << prefix_len) >> prefix_len; + ipv6 = net_part | host_part; + ipv6.into() +} diff --git a/src/proxy/http/mod.rs b/src/proxy/http/mod.rs index afb3b8f..cf10680 100644 --- a/src/proxy/http/mod.rs +++ b/src/proxy/http/mod.rs @@ -2,25 +2,23 @@ mod auth; pub mod error; use self::{auth::Authenticator, error::ProxyError}; -use super::ProxyContext; +use super::{connect::Connector, ProxyContext}; use bytes::Bytes; -use cidr::Ipv6Cidr; use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full}; use hyper::{ server::conn::http1, service::service_fn, upgrade::Upgraded, Method, Request, Response, }; use hyper_util::{ - client::legacy::{connect::HttpConnector, Client}, + client::legacy::Client, rt::{TokioExecutor, TokioIo}, }; -use rand::Rng; use std::{ - net::{IpAddr, Ipv6Addr, SocketAddr, ToSocketAddrs}, + net::{SocketAddr, ToSocketAddrs}, sync::Arc, }; -use tokio::net::{TcpSocket, TcpStream}; +use tokio::net::TcpStream; -pub async fn run(ctx: ProxyContext) -> crate::Result<()> { +pub async fn proxy(ctx: ProxyContext) -> crate::Result<()> { tracing::info!("Http server listening on {}", ctx.bind); let socket = if ctx.bind.is_ipv4() { @@ -63,10 +61,8 @@ pub async fn run(ctx: ProxyContext) -> crate::Result<()> { struct HttpProxy { /// Authentication type auth: Authenticator, - /// Ipv6 subnet, e.g. 2001:db8::/32 - ipv6_subnet: Option, - /// Fallback address - fallback: Option, + /// Connecetor + connector: Connector, } impl From for HttpProxy { @@ -77,8 +73,7 @@ impl From for HttpProxy { _ => Authenticator::None, }, - ipv6_subnet: ctx.ipv6_subnet, - fallback: ctx.fallback, + connector: ctx.connector, } } } @@ -129,26 +124,10 @@ impl HttpProxy { Ok(resp) } } else { - let mut connector = HttpConnector::new(); - - match (self.ipv6_subnet, self.fallback) { - (Some(v6), Some(IpAddr::V4(v4))) => { - let v6 = get_rand_ipv6(v6.first_address().into(), v6.network_length()); - connector.set_local_addresses(v4, v6); - } - (Some(v6), None) => { - let v6 = get_rand_ipv6(v6.first_address().into(), v6.network_length()); - connector.set_local_address(Some(v6.into())); - } - // ipv4 or ipv6 - (None, Some(ip)) => connector.set_local_address(Some(ip)), - _ => {} - } - let resp = Client::builder(TokioExecutor::new()) .http1_title_case_headers(true) .http1_preserve_header_case(true) - .build(connector) + .build(self.connector.new_http_connector()) .request(req) .await?; @@ -160,7 +139,7 @@ impl HttpProxy { // and the upgraded connection async fn tunnel(&self, upgraded: Upgraded, addr_str: String) -> std::io::Result<()> { for addr in addr_str.to_socket_addrs()? { - match self.try_connect(addr).await { + match self.connector.try_connect(addr).await { Ok(mut server) => { tracing::info!("tunnel: {} via {}", addr_str, server.local_addr()?); return tunnel_proxy(upgraded, &mut server).await; @@ -176,65 +155,6 @@ impl HttpProxy { Ok(()) } - - /// Get a socket and a bind address - async fn try_connect(&self, addr: SocketAddr) -> std::io::Result { - match (self.ipv6_subnet, self.fallback) { - (Some(ipv6_cidr), ip_addr) => { - try_connect_with_ipv6_and_fallback(addr, ipv6_cidr, ip_addr).await - } - (None, Some(ip)) => try_connect_with_fallback(addr, ip).await, - _ => TcpStream::connect(addr).await, - } - } -} - -/// Try to connect with ipv6 and fallback to ipv4/ipv6 -async fn try_connect_with_ipv6_and_fallback( - addr: SocketAddr, - v6: Ipv6Cidr, - ip: Option, -) -> std::io::Result { - let socket = TcpSocket::new_v6()?; - let bind_addr = SocketAddr::new( - get_rand_ipv6(v6.first_address().into(), v6.network_length()).into(), - 0, - ); - socket.bind(bind_addr)?; - - // Try to connect with ipv6 - match socket.connect(addr).await { - Ok(first) => Ok(first), - Err(err) => { - tracing::debug!("try connect with ipv6 failed: {}", err); - if let Some(ip) = ip { - // Try to connect with fallback ip (ipv4 or ipv6) - let socket = create_socket_for_ip(ip)?; - let bind_addr = SocketAddr::new(ip, 0); - socket.bind(bind_addr)?; - socket.connect(addr).await - } else { - // Try to connect with system default ip - TcpStream::connect(addr).await - } - } - } -} - -/// Try to connect with fallback to ipv4/ipv6 -async fn try_connect_with_fallback(addr: SocketAddr, ip: IpAddr) -> std::io::Result { - let socket = create_socket_for_ip(ip)?; - let bind_addr = SocketAddr::new(ip, 0); - socket.bind(bind_addr)?; - socket.connect(addr).await -} - -/// Create a socket for ip -fn create_socket_for_ip(ip: IpAddr) -> std::io::Result { - match ip { - IpAddr::V4(_) => TcpSocket::new_v4(), - IpAddr::V6(_) => TcpSocket::new_v6(), - } } /// Proxy data between upgraded connection and server @@ -249,15 +169,6 @@ async fn tunnel_proxy(upgraded: Upgraded, server: &mut TcpStream) -> std::io::Re Ok(()) } -/// Get a random ipv6 address -fn get_rand_ipv6(mut ipv6: u128, prefix_len: u8) -> Ipv6Addr { - let rand: u128 = rand::thread_rng().gen(); - let net_part = (ipv6 >> (128 - prefix_len)) << (128 - prefix_len); - let host_part = (rand << prefix_len) >> prefix_len; - ipv6 = net_part | host_part; - ipv6.into() -} - fn host_addr(uri: &http::Uri) -> Option { uri.authority().map(|auth| auth.to_string()) } diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index d2bc0fb..cc115b3 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -1,10 +1,11 @@ mod auth; +mod connect; mod http; mod socks5; use crate::{AuthMode, BootArgs, Proxy}; pub use socks5::Error; -use std::net::{IpAddr, SocketAddr}; +use std::net::SocketAddr; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; struct ProxyContext { @@ -14,10 +15,8 @@ struct ProxyContext { pub concurrent: usize, /// Authentication type pub auth: AuthMode, - /// Ipv6 subnet, e.g. 2001:db8::/32 - pub ipv6_subnet: Option, - /// Fallback address - pub fallback: Option, + /// Connector + pub connector: connect::Connector, } #[tokio::main(flavor = "multi_thread")] @@ -53,22 +52,20 @@ pub async fn run(args: BootArgs) -> crate::Result<()> { match args.proxy { Proxy::Http { auth } => { - http::run(ProxyContext { + http::proxy(ProxyContext { bind: args.bind, concurrent: args.concurrent, auth, - ipv6_subnet: args.ipv6_subnet, - fallback: args.fallback, + connector: connect::Connector::new(args.ipv6_subnet, args.fallback), }) .await } Proxy::Socks5 { auth } => { - socks5::run(ProxyContext { + socks5::proxy(ProxyContext { bind: args.bind, concurrent: args.concurrent, auth, - ipv6_subnet: args.ipv6_subnet, - fallback: args.fallback, + connector: connect::Connector::new(args.ipv6_subnet, args.fallback), }) .await } diff --git a/src/proxy/socks5/mod.rs b/src/proxy/socks5/mod.rs index 8a3bda8..378735a 100644 --- a/src/proxy/socks5/mod.rs +++ b/src/proxy/socks5/mod.rs @@ -10,30 +10,27 @@ use self::{ ClientConnection, IncomingConnection, Server, UdpAssociate, }, }; -use super::ProxyContext; +use super::{connect, ProxyContext}; use as_any::AsAny; pub use error::Error; use std::{ net::{SocketAddr, ToSocketAddrs}, sync::Arc, }; -use tokio::{ - net::{TcpStream, UdpSocket}, - sync::Mutex, -}; +use tokio::{net::UdpSocket, sync::Mutex}; -pub async fn run(ctx: ProxyContext) -> crate::Result<()> { +pub async fn proxy(ctx: ProxyContext) -> crate::Result<()> { tracing::info!("Socks5 server listening on {}", ctx.bind); - match (ctx.auth.username, ctx.auth.password) { + match (&ctx.auth.username, &ctx.auth.password) { (Some(username), Some(password)) => { - let auth = Arc::new(auth::Password::new(&username, &password)); - event_loop(auth, ctx.bind, ctx.concurrent as u32).await?; + let auth = Arc::new(auth::Password::new(username, password)); + event_loop(auth, ctx).await?; } _ => { let auth = Arc::new(auth::NoAuth); - event_loop(auth, ctx.bind, ctx.concurrent as u32).await?; + event_loop(auth, ctx).await?; } } @@ -45,19 +42,16 @@ const MAX_UDP_RELAY_PACKET_SIZE: usize = 1500; /// The library's `Result` type alias. pub type Result = std::result::Result; -async fn event_loop( - auth: auth::AuthAdaptor, - listen_addr: SocketAddr, - concurrent: u32, -) -> Result<()> +async fn event_loop(auth: auth::AuthAdaptor, ctx: ProxyContext) -> Result<()> where S: Send + Sync + 'static, { - let server = Server::bind_with_concurrency(listen_addr, auth, concurrent).await?; - + let server = Server::bind_with_concurrency(ctx.bind, auth, ctx.concurrent as u32).await?; + let connector = Arc::new(ctx.connector); while let Ok((conn, _)) = server.accept().await { + let connector = connector.clone(); tokio::spawn(async move { - if let Err(err) = handle(conn).await { + if let Err(err) = handle(conn, connector).await { tracing::error!("{err}"); } }); @@ -65,7 +59,7 @@ where Ok(()) } -async fn handle(conn: IncomingConnection) -> Result<()> +async fn handle(conn: IncomingConnection, connector: Arc) -> Result<()> where S: Send + Sync + 'static, { @@ -91,8 +85,10 @@ where } ClientConnection::Connect(connect, addr) => { let target = match addr { - Address::DomainAddress(domain, port) => TcpStream::connect((domain, port)).await, - Address::SocketAddress(addr) => TcpStream::connect(addr).await, + Address::DomainAddress(domain, port) => { + connector.try_connect_for_domain(domain, port).await + } + Address::SocketAddress(addr) => connector.try_connect(addr).await, }; if let Ok(mut target) = target { @@ -188,3 +184,4 @@ async fn handle_s5_upd_associate(associate: UdpAssociate) } } } +