diff --git a/rust-toolchain b/rust-toolchain index f288d111..7f229af9 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -1.85.0 +1.92.0 diff --git a/secio/src/dh_compat/openssl_impl.rs b/secio/src/dh_compat/openssl_impl.rs index 4ecabf5d..07381ddd 100644 --- a/secio/src/dh_compat/openssl_impl.rs +++ b/secio/src/dh_compat/openssl_impl.rs @@ -19,14 +19,14 @@ struct Algorithm { static ECDH_P256: Algorithm = Algorithm { _private_len: 256 / 8, - pubkey_len: 1 + (2 * ((256 + 7) / 8)), + pubkey_len: 1 + (2 * 256_usize.div_ceil(8)), pairs_generate: p256_generate, from_pubkey: p256_from_pubkey, }; static ECDH_P384: Algorithm = Algorithm { _private_len: 384 / 8, - pubkey_len: 1 + (2 * ((384 + 7) / 8)), + pubkey_len: 1 + (2 * 384_usize.div_ceil(8)), pairs_generate: p384_generate, from_pubkey: p384_from_pubkey, }; diff --git a/tentacle/src/builder.rs b/tentacle/src/builder.rs index e6b58818..d9441c4a 100644 --- a/tentacle/src/builder.rs +++ b/tentacle/src/builder.rs @@ -147,6 +147,34 @@ where self } + /// Set trusted proxy addresses for HAProxy PROXY protocol and X-Forwarded-For header parsing. + /// + /// When a connection comes from one of these addresses, tentacle will extract the real client IP from: + /// - PROXY protocol v1/v2 headers (for TCP connections) + /// - X-Forwarded-For headers (for WebSocket connections) + /// + /// By default, loopback addresses (127.0.0.1 and ::1) are trusted. This method will **replace** + /// the default list with the provided addresses. + /// + /// # Example + /// + /// ```ignore + /// use std::net::IpAddr; + /// use tentacle::builder::ServiceBuilder; + /// + /// // Replace default loopback with custom proxy addresses + /// let builder = ServiceBuilder::new() + /// .trusted_proxies(vec![ + /// "192.168.1.100".parse().unwrap(), + /// "10.0.0.1".parse().unwrap(), + /// ]); + /// ``` + #[cfg(not(target_family = "wasm"))] + pub fn trusted_proxies(mut self, proxies: Vec) -> Self { + self.config.trusted_proxies = proxies; + self + } + /// Whether to allow tentative registration upnp, default is disable(false) /// /// upnp: https://en.wikipedia.org/wiki/Universal_Plug_and_Play diff --git a/tentacle/src/channel/bound.rs b/tentacle/src/channel/bound.rs index 8b1711ef..a4a4e181 100644 --- a/tentacle/src/channel/bound.rs +++ b/tentacle/src/channel/bound.rs @@ -696,6 +696,7 @@ impl Stream for Receiver { } impl Drop for Receiver { + #[allow(clippy::unnecessary_unwrap)] fn drop(&mut self) { // Drain the channel of all pending messages self.close(); diff --git a/tentacle/src/channel/unbound.rs b/tentacle/src/channel/unbound.rs index c091236c..547ccd3b 100644 --- a/tentacle/src/channel/unbound.rs +++ b/tentacle/src/channel/unbound.rs @@ -500,6 +500,7 @@ impl Stream for UnboundedReceiver { } impl Drop for UnboundedReceiver { + #[allow(clippy::unnecessary_unwrap)] fn drop(&mut self) { // Drain the channel of all pending messages self.close(); diff --git a/tentacle/src/service.rs b/tentacle/src/service.rs index 01203036..a707b173 100644 --- a/tentacle/src/service.rs +++ b/tentacle/src/service.rs @@ -173,7 +173,11 @@ where let transport = MultiTransport::new(config.timeout.timeout); #[allow(clippy::let_and_return)] #[cfg(not(target_family = "wasm"))] - let transport = MultiTransport::new(config.timeout, config.tcp_config.clone()); + let transport = MultiTransport::new( + config.timeout, + config.tcp_config.clone(), + config.trusted_proxies.clone(), + ); #[cfg(feature = "tls")] let transport = transport.tls_config(config.tls_config.clone()); transport @@ -256,7 +260,7 @@ where || extract_peer_id(&address) .map(|peer_id| { inner.dial_protocols.keys().any(|addr| { - if let Some(addr_peer_id) = extract_peer_id(&addr) { + if let Some(addr_peer_id) = extract_peer_id(addr) { addr_peer_id == peer_id } else { false @@ -1161,7 +1165,7 @@ where || extract_peer_id(&address) .map(|peer_id| { self.dial_protocols.keys().any(|addr| { - if let Some(addr_peer_id) = extract_peer_id(&addr) { + if let Some(addr_peer_id) = extract_peer_id(addr) { addr_peer_id == peer_id } else { false diff --git a/tentacle/src/service/config.rs b/tentacle/src/service/config.rs index 4be8cb1f..37a092fc 100644 --- a/tentacle/src/service/config.rs +++ b/tentacle/src/service/config.rs @@ -13,7 +13,11 @@ use std::os::{ fd::AsFd, unix::io::{AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd}, }; -use std::{net::SocketAddr, sync::Arc, time::Duration}; +use std::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + sync::Arc, + time::Duration, +}; #[cfg(feature = "tls")] use tokio_rustls::rustls::{ClientConfig, ServerConfig}; @@ -46,6 +50,11 @@ pub(crate) struct ServiceConfig { pub tcp_config: TcpConfig, #[cfg(feature = "tls")] pub tls_config: Option, + /// Trusted proxy addresses for HAProxy PROXY protocol and X-Forwarded-For header parsing. + /// When a connection comes from one of these addresses, the real client IP will be extracted + /// from PROXY protocol headers (for TCP) or X-Forwarded-For headers (for WebSocket). + /// By default, loopback addresses (127.0.0.1 and ::1) are included in this list. + pub trusted_proxies: Vec, } impl Default for ServiceConfig { @@ -61,6 +70,11 @@ impl Default for ServiceConfig { tcp_config: Default::default(), #[cfg(feature = "tls")] tls_config: None, + // Default: trust loopback addresses + trusted_proxies: vec![ + IpAddr::V4(Ipv4Addr::LOCALHOST), + IpAddr::V6(Ipv6Addr::LOCALHOST), + ], } } } diff --git a/tentacle/src/transports/mod.rs b/tentacle/src/transports/mod.rs index 26af2769..8fa771bd 100644 --- a/tentacle/src/transports/mod.rs +++ b/tentacle/src/transports/mod.rs @@ -18,6 +18,8 @@ mod memory; #[cfg(not(target_family = "wasm"))] mod onion; #[cfg(not(target_family = "wasm"))] +pub(crate) mod proxy_protocol; +#[cfg(not(target_family = "wasm"))] mod tcp; #[cfg(not(target_family = "wasm"))] pub(crate) mod tcp_base_listen; @@ -141,16 +143,23 @@ mod os { pub(crate) listens_upgrade_modes: Arc>>, #[cfg(feature = "tls")] pub(crate) tls_config: Option, + /// Trusted proxy addresses for HAProxy PROXY protocol and X-Forwarded-For header parsing. + pub(crate) trusted_proxies: Arc>, } impl MultiTransport { - pub fn new(timeout: ServiceTimeout, tcp_config: TcpConfig) -> Self { + pub fn new( + timeout: ServiceTimeout, + tcp_config: TcpConfig, + trusted_proxies: Vec, + ) -> Self { MultiTransport { timeout, tcp_config, listens_upgrade_modes: Arc::new(crate::lock::Mutex::new(Default::default())), #[cfg(feature = "tls")] tls_config: None, + trusted_proxies: Arc::new(trusted_proxies), } } diff --git a/tentacle/src/transports/proxy_protocol.rs b/tentacle/src/transports/proxy_protocol.rs new file mode 100644 index 00000000..01c4ddee --- /dev/null +++ b/tentacle/src/transports/proxy_protocol.rs @@ -0,0 +1,628 @@ +//! HAProxy PROXY Protocol v1 and v2 parser +//! +//! This module provides parsing capabilities for the HAProxy PROXY protocol, +//! which allows proxies to convey the original client IP address to backend servers. +//! +//! Reference: https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt + +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + +use log::debug; +use tokio::io::AsyncReadExt; + +use crate::runtime::TcpStream; + +/// PROXY protocol v2 signature (12 bytes) +const PROXY_V2_SIGNATURE: [u8; 12] = [ + 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, +]; + +/// Maximum length of PROXY protocol v1 header (107 chars + CRLF) +const PROXY_V1_MAX_LENGTH: usize = 108; + +/// PROXY protocol v2 header size (16 bytes) +const PROXY_V2_HEADER_SIZE: usize = 16; + +/// Maximum allowed address length for PROXY protocol v2 +/// IPv4: 12 bytes, IPv6: 36 bytes, Unix: 216 bytes +/// We allow some extra for TLV extensions, but cap at 512 to prevent DoS +const PROXY_V2_MAX_ADDR_LEN: usize = 512; + +/// Result of parsing PROXY protocol +#[derive(Debug)] +pub enum ProxyProtocolResult { + /// Successfully parsed, returns the real client address + Success(SocketAddr), + /// Not a PROXY protocol header (data should be processed as-is) + NotProxyProtocol, + /// Parse error + Error(String), +} + +/// Try to parse PROXY protocol from a TCP stream +/// +/// This function will: +/// 1. Peek at the stream to determine if PROXY protocol is present +/// 2. If PROXY protocol v1 or v2 is detected, read and parse it +/// 3. Return the real client address if successful +/// +/// The stream will have the PROXY protocol header consumed after this call. +pub async fn parse_proxy_protocol(stream: &mut TcpStream) -> ProxyProtocolResult { + // First, peek to detect protocol version + let mut peek_buf = [0u8; PROXY_V2_HEADER_SIZE]; + match stream.peek(&mut peek_buf).await { + Ok(n) if n >= 5 => { + // Check for v2 signature first (needs at least 13 bytes to confirm) + if n >= 13 && peek_buf[..12] == PROXY_V2_SIGNATURE { + // Verify version is 2 + if (peek_buf[12] & 0xF0) == 0x20 { + return parse_proxy_protocol_v2(stream).await; + } + } + + // Check for v1 header ("PROXY ") + if &peek_buf[..5] == b"PROXY" { + return parse_proxy_protocol_v1(stream).await; + } + + ProxyProtocolResult::NotProxyProtocol + } + Ok(_) => ProxyProtocolResult::NotProxyProtocol, + Err(e) => ProxyProtocolResult::Error(format!("Failed to peek stream: {}", e)), + } +} + +/// Parse PROXY protocol v1 header from a string line +/// +/// This is a pure function that parses a single line without I/O. +/// Used by both the async parser and unit tests. +fn parse_proxy_v1_line(line: &str) -> ProxyProtocolResult { + let line = line.trim_end_matches('\n').trim_end_matches('\r'); + let parts: Vec<&str> = line.split(' ').collect(); + + if parts.is_empty() || parts[0] != "PROXY" { + return ProxyProtocolResult::Error("Invalid PROXY v1 header".into()); + } + + if parts.len() < 2 { + return ProxyProtocolResult::Error("PROXY v1 header too short".into()); + } + + match parts[1] { + "UNKNOWN" => { + debug!("PROXY v1 UNKNOWN protocol, using socket address"); + ProxyProtocolResult::NotProxyProtocol + } + "TCP4" | "TCP6" => { + if parts.len() != 6 { + return ProxyProtocolResult::Error(format!( + "Invalid PROXY v1 header, expected 6 parts, got {}", + parts.len() + )); + } + + let src_ip: IpAddr = match parts[2].parse() { + Ok(ip) => ip, + Err(_) => { + return ProxyProtocolResult::Error(format!("Invalid source IP: {}", parts[2])); + } + }; + + let src_port: u16 = match parts[4].parse() { + Ok(port) => port, + Err(_) => { + return ProxyProtocolResult::Error(format!( + "Invalid source port: {}", + parts[4] + )); + } + }; + + let src_addr = SocketAddr::new(src_ip, src_port); + debug!("PROXY v1 parsed: src={}", src_addr); + ProxyProtocolResult::Success(src_addr) + } + proto => ProxyProtocolResult::Error(format!("Unsupported PROXY v1 protocol: {}", proto)), + } +} + +/// Parse PROXY protocol version 1 (text format) from a TCP stream +/// +/// Format: "PROXY \r\n" +/// Example: "PROXY TCP4 192.168.0.1 192.168.0.11 56324 443\r\n" +/// +/// This implementation reads byte-by-byte to avoid buffering beyond the header, +/// ensuring no business data is lost. +async fn parse_proxy_protocol_v1(stream: &mut TcpStream) -> ProxyProtocolResult { + // Use a stack-allocated buffer for better performance + let mut buf = [0u8; PROXY_V1_MAX_LENGTH]; + let mut pos = 0; + + // Read byte-by-byte until we find \r\n or reach max length + loop { + if pos >= PROXY_V1_MAX_LENGTH { + return ProxyProtocolResult::Error("PROXY v1 header too long".into()); + } + + match stream.read_exact(&mut buf[pos..pos + 1]).await { + Ok(_) => { + pos += 1; + // Check for CRLF (\r\n) - the required line terminator per spec + if pos >= 2 && buf[pos - 2] == b'\r' && buf[pos - 1] == b'\n' { + break; + } + } + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + return ProxyProtocolResult::Error( + "Connection closed while reading PROXY header".into(), + ); + } + Err(e) => { + return ProxyProtocolResult::Error(format!("Failed to read PROXY header: {}", e)); + } + } + } + + // Convert to string and parse + match std::str::from_utf8(&buf[..pos]) { + Ok(line) => parse_proxy_v1_line(line), + Err(_) => ProxyProtocolResult::Error("PROXY v1 header contains invalid UTF-8".into()), + } +} + +/// Parse PROXY protocol version 2 (binary format) from a TCP stream +/// +/// This implementation: +/// - Uses stack-allocated buffer for the fixed 16-byte header +/// - Validates address length to prevent DoS attacks (max 512 bytes) +/// - Handles LOCAL command correctly (address data is ignored per spec) +async fn parse_proxy_protocol_v2(stream: &mut TcpStream) -> ProxyProtocolResult { + // Read the 16-byte header using stack buffer + let mut header = [0u8; PROXY_V2_HEADER_SIZE]; + if let Err(e) = stream.read_exact(&mut header).await { + return ProxyProtocolResult::Error(format!("Failed to read PROXY v2 header: {}", e)); + } + + // Parse and validate address length from header + let addr_len = u16::from_be_bytes([header[14], header[15]]) as usize; + + // DoS protection: reject excessively large address lengths + // IPv4 needs 12 bytes, IPv6 needs 36 bytes, Unix needs 216 bytes + // We allow up to 512 bytes for TLV extensions + if addr_len > PROXY_V2_MAX_ADDR_LEN { + return ProxyProtocolResult::Error(format!( + "PROXY v2 address length {} exceeds maximum {}", + addr_len, PROXY_V2_MAX_ADDR_LEN + )); + } + + // Read address data if present + let addr_data = if addr_len > 0 { + let mut buf = vec![0u8; addr_len]; + if let Err(e) = stream.read_exact(&mut buf).await { + return ProxyProtocolResult::Error(format!("Failed to read PROXY v2 address: {}", e)); + } + buf + } else { + Vec::new() + }; + + parse_proxy_v2_bytes(&header, &addr_data) +} + +/// Parse PROXY protocol v2 from header and address data +/// +/// This is a pure function that parses bytes without I/O. +/// Used by both the async parser and unit tests. +fn parse_proxy_v2_bytes( + header: &[u8; PROXY_V2_HEADER_SIZE], + addr_data: &[u8], +) -> ProxyProtocolResult { + // Verify signature + if header[..12] != PROXY_V2_SIGNATURE { + return ProxyProtocolResult::Error("Invalid PROXY v2 signature".into()); + } + + let ver_cmd = header[12]; + let version = (ver_cmd & 0xF0) >> 4; + let command = ver_cmd & 0x0F; + + if version != 2 { + return ProxyProtocolResult::Error(format!("Unsupported PROXY version: {}", version)); + } + + let fam_proto = header[13]; + let family = (fam_proto & 0xF0) >> 4; + + match command { + 0x00 => { + // LOCAL: connection was established by proxy itself (health check) + // Address data is ignored for LOCAL command per spec + debug!("PROXY v2 LOCAL command, using socket address"); + ProxyProtocolResult::NotProxyProtocol + } + 0x01 => { + // PROXY: connection on behalf of another node + parse_proxy_v2_address(family, addr_data) + } + _ => ProxyProtocolResult::Error(format!("Unsupported PROXY v2 command: {}", command)), + } +} + +/// Parse address from PROXY v2 PROXY command +fn parse_proxy_v2_address(family: u8, addr_data: &[u8]) -> ProxyProtocolResult { + match family { + 0x00 => { + // AF_UNSPEC: unknown/unsupported + debug!("PROXY v2 AF_UNSPEC, using socket address"); + ProxyProtocolResult::NotProxyProtocol + } + 0x01 => { + // AF_INET (IPv4): 4 + 4 + 2 + 2 = 12 bytes + if addr_data.len() < 12 { + return ProxyProtocolResult::Error("PROXY v2 IPv4 address data too short".into()); + } + let src_ip = Ipv4Addr::new(addr_data[0], addr_data[1], addr_data[2], addr_data[3]); + let src_port = u16::from_be_bytes([addr_data[8], addr_data[9]]); + let src_addr = SocketAddr::new(IpAddr::V4(src_ip), src_port); + debug!("PROXY v2 parsed: src={}", src_addr); + ProxyProtocolResult::Success(src_addr) + } + 0x02 => { + // AF_INET6 (IPv6): 16 + 16 + 2 + 2 = 36 bytes + if addr_data.len() < 36 { + return ProxyProtocolResult::Error("PROXY v2 IPv6 address data too short".into()); + } + let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_data[0..16]).unwrap()); + let src_port = u16::from_be_bytes([addr_data[32], addr_data[33]]); + let src_addr = SocketAddr::new(IpAddr::V6(src_ip), src_port); + debug!("PROXY v2 parsed: src={}", src_addr); + ProxyProtocolResult::Success(src_addr) + } + 0x03 => { + // AF_UNIX: 108 + 108 = 216 bytes, no IP address available + debug!("PROXY v2 AF_UNIX, cannot extract IP address, using socket address"); + ProxyProtocolResult::NotProxyProtocol + } + _ => { + debug!("PROXY v2 unknown address family: {:#x}", family); + ProxyProtocolResult::NotProxyProtocol + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Build a PROXY protocol v2 header and address data for testing + /// Returns (header, addr_data) tuple + fn build_proxy_v2_parts( + command: u8, + family: u8, + protocol: u8, + addr_data: &[u8], + ) -> ([u8; PROXY_V2_HEADER_SIZE], Vec) { + let mut header = [0u8; PROXY_V2_HEADER_SIZE]; + // Signature + header[..12].copy_from_slice(&PROXY_V2_SIGNATURE); + // Version (2) and command + header[12] = 0x20 | (command & 0x0F); + // Family and protocol + header[13] = (family << 4) | (protocol & 0x0F); + // Address length + let addr_len = addr_data.len() as u16; + header[14..16].copy_from_slice(&addr_len.to_be_bytes()); + + (header, addr_data.to_vec()) + } + + // =================== + // PROXY v1 tests + // =================== + + #[test] + fn test_proxy_v1_tcp4() { + let line = "PROXY TCP4 192.168.0.1 192.168.0.11 56324 443\r\n"; + match parse_proxy_v1_line(line) { + ProxyProtocolResult::Success(addr) => { + assert_eq!(addr.ip(), "192.168.0.1".parse::().unwrap()); + assert_eq!(addr.port(), 56324); + } + other => panic!("Expected Success, got {:?}", other), + } + } + + #[test] + fn test_proxy_v1_tcp6() { + let line = "PROXY TCP6 2001:db8::1 2001:db8::2 56324 443\r\n"; + match parse_proxy_v1_line(line) { + ProxyProtocolResult::Success(addr) => { + assert_eq!(addr.ip(), "2001:db8::1".parse::().unwrap()); + assert_eq!(addr.port(), 56324); + } + other => panic!("Expected Success, got {:?}", other), + } + } + + #[test] + fn test_proxy_v1_unknown() { + let line = "PROXY UNKNOWN\r\n"; + match parse_proxy_v1_line(line) { + ProxyProtocolResult::NotProxyProtocol => {} + other => panic!("Expected NotProxyProtocol, got {:?}", other), + } + } + + #[test] + fn test_proxy_v1_unknown_with_addresses() { + // HAProxy spec allows UNKNOWN with optional addresses + let line = "PROXY UNKNOWN ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\n"; + match parse_proxy_v1_line(line) { + ProxyProtocolResult::NotProxyProtocol => {} + other => panic!("Expected NotProxyProtocol, got {:?}", other), + } + } + + #[test] + fn test_proxy_v1_invalid_header() { + let line = "NOT_PROXY TCP4 192.168.0.1 192.168.0.11 56324 443\r\n"; + match parse_proxy_v1_line(line) { + ProxyProtocolResult::Error(_) => {} + other => panic!("Expected Error, got {:?}", other), + } + } + + #[test] + fn test_proxy_v1_missing_fields() { + let line = "PROXY TCP4 192.168.0.1\r\n"; + match parse_proxy_v1_line(line) { + ProxyProtocolResult::Error(msg) => { + assert!(msg.contains("expected 6 parts")); + } + other => panic!("Expected Error, got {:?}", other), + } + } + + #[test] + fn test_proxy_v1_invalid_ip() { + let line = "PROXY TCP4 not.an.ip 192.168.0.11 56324 443\r\n"; + match parse_proxy_v1_line(line) { + ProxyProtocolResult::Error(msg) => { + assert!(msg.contains("Invalid source IP")); + } + other => panic!("Expected Error, got {:?}", other), + } + } + + #[test] + fn test_proxy_v1_invalid_port() { + let line = "PROXY TCP4 192.168.0.1 192.168.0.11 notaport 443\r\n"; + match parse_proxy_v1_line(line) { + ProxyProtocolResult::Error(msg) => { + assert!(msg.contains("Invalid source port")); + } + other => panic!("Expected Error, got {:?}", other), + } + } + + #[test] + fn test_proxy_v1_unsupported_protocol() { + let line = "PROXY UDP4 192.168.0.1 192.168.0.11 56324 443\r\n"; + match parse_proxy_v1_line(line) { + ProxyProtocolResult::Error(msg) => { + assert!(msg.contains("Unsupported PROXY v1 protocol")); + } + other => panic!("Expected Error, got {:?}", other), + } + } + + // =================== + // PROXY v2 tests + // =================== + + #[test] + fn test_proxy_v2_signature() { + assert_eq!(PROXY_V2_SIGNATURE.len(), 12); + assert_eq!(PROXY_V2_SIGNATURE[4], 0x00); // Contains null byte + } + + #[test] + fn test_proxy_v2_tcp4() { + // Build address data: src_ip (4) + dst_ip (4) + src_port (2) + dst_port (2) = 12 bytes + let mut addr_data = Vec::new(); + addr_data.extend_from_slice(&[192, 168, 1, 100]); // src IP + addr_data.extend_from_slice(&[192, 168, 1, 1]); // dst IP + addr_data.extend_from_slice(&12345u16.to_be_bytes()); // src port + addr_data.extend_from_slice(&443u16.to_be_bytes()); // dst port + + // command=PROXY(0x01), family=AF_INET(0x01), protocol=STREAM(0x01) + let (header, addr) = build_proxy_v2_parts(0x01, 0x01, 0x01, &addr_data); + + match parse_proxy_v2_bytes(&header, &addr) { + ProxyProtocolResult::Success(addr) => { + assert_eq!(addr.ip(), "192.168.1.100".parse::().unwrap()); + assert_eq!(addr.port(), 12345); + } + other => panic!("Expected Success, got {:?}", other), + } + } + + #[test] + fn test_proxy_v2_tcp6() { + // Build address data: src_ip (16) + dst_ip (16) + src_port (2) + dst_port (2) = 36 bytes + let mut addr_data = Vec::new(); + // src IP: 2001:db8::1 + addr_data.extend_from_slice(&[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); + // dst IP: 2001:db8::2 + addr_data.extend_from_slice(&[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]); + addr_data.extend_from_slice(&54321u16.to_be_bytes()); // src port + addr_data.extend_from_slice(&8080u16.to_be_bytes()); // dst port + + // command=PROXY(0x01), family=AF_INET6(0x02), protocol=STREAM(0x01) + let (header, addr) = build_proxy_v2_parts(0x01, 0x02, 0x01, &addr_data); + + match parse_proxy_v2_bytes(&header, &addr) { + ProxyProtocolResult::Success(addr) => { + assert_eq!(addr.ip(), "2001:db8::1".parse::().unwrap()); + assert_eq!(addr.port(), 54321); + } + other => panic!("Expected Success, got {:?}", other), + } + } + + #[test] + fn test_proxy_v2_local_command() { + // LOCAL command (0x00) - health check from proxy itself + let (header, addr) = build_proxy_v2_parts(0x00, 0x00, 0x00, &[]); + + match parse_proxy_v2_bytes(&header, &addr) { + ProxyProtocolResult::NotProxyProtocol => {} + other => panic!("Expected NotProxyProtocol, got {:?}", other), + } + } + + #[test] + fn test_proxy_v2_af_unspec() { + // AF_UNSPEC (0x00) - unknown address family + let (header, addr) = build_proxy_v2_parts(0x01, 0x00, 0x00, &[]); + + match parse_proxy_v2_bytes(&header, &addr) { + ProxyProtocolResult::NotProxyProtocol => {} + other => panic!("Expected NotProxyProtocol, got {:?}", other), + } + } + + #[test] + fn test_proxy_v2_invalid_signature() { + let mut header = [0u8; PROXY_V2_HEADER_SIZE]; + // Wrong signature + header[..12].copy_from_slice(b"WRONG_SIGNAT"); + + match parse_proxy_v2_bytes(&header, &[]) { + ProxyProtocolResult::Error(msg) => { + assert!(msg.contains("Invalid PROXY v2 signature")); + } + other => panic!("Expected Error, got {:?}", other), + } + } + + #[test] + fn test_proxy_v2_ipv4_addr_too_short() { + // Only 8 bytes of address data, but IPv4 needs 12 + let addr_data = vec![0u8; 8]; + let (header, addr) = build_proxy_v2_parts(0x01, 0x01, 0x01, &addr_data); + + match parse_proxy_v2_bytes(&header, &addr) { + ProxyProtocolResult::Error(msg) => { + assert!(msg.contains("IPv4 address data too short")); + } + other => panic!("Expected Error, got {:?}", other), + } + } + + #[test] + fn test_proxy_v2_ipv6_addr_too_short() { + // Only 20 bytes of address data, but IPv6 needs 36 + let addr_data = vec![0u8; 20]; + let (header, addr) = build_proxy_v2_parts(0x01, 0x02, 0x01, &addr_data); + + match parse_proxy_v2_bytes(&header, &addr) { + ProxyProtocolResult::Error(msg) => { + assert!(msg.contains("IPv6 address data too short")); + } + other => panic!("Expected Error, got {:?}", other), + } + } + + #[test] + fn test_proxy_v2_unsupported_command() { + // Command 0x02 is not defined + let (header, addr) = build_proxy_v2_parts(0x02, 0x01, 0x01, &[0u8; 12]); + + match parse_proxy_v2_bytes(&header, &addr) { + ProxyProtocolResult::Error(msg) => { + assert!(msg.contains("Unsupported PROXY v2 command")); + } + other => panic!("Expected Error, got {:?}", other), + } + } + + #[test] + fn test_proxy_v2_af_unix() { + // AF_UNIX (0x03) - Unix socket addresses don't have IP:port + // We correctly consume the data but return NotProxyProtocol since we can't extract IP + let (header, addr) = build_proxy_v2_parts(0x01, 0x03, 0x00, &[0u8; 216]); // AF_UNIX needs 216 bytes + + match parse_proxy_v2_bytes(&header, &addr) { + ProxyProtocolResult::NotProxyProtocol => {} + other => panic!("Expected NotProxyProtocol, got {:?}", other), + } + } + + #[test] + fn test_proxy_v2_unknown_family() { + // Unknown address family (0x04 and above) + let (header, addr) = build_proxy_v2_parts(0x01, 0x04, 0x01, &[0u8; 12]); + + match parse_proxy_v2_bytes(&header, &addr) { + ProxyProtocolResult::NotProxyProtocol => {} + other => panic!("Expected NotProxyProtocol, got {:?}", other), + } + } + + // =================== + // Real-world examples + // =================== + + #[test] + fn test_proxy_v1_haproxy_example() { + // Example from HAProxy documentation + let line = "PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\n"; + match parse_proxy_v1_line(line) { + ProxyProtocolResult::Success(addr) => { + assert_eq!(addr.ip(), "255.255.255.255".parse::().unwrap()); + assert_eq!(addr.port(), 65535); + } + other => panic!("Expected Success, got {:?}", other), + } + } + + #[test] + fn test_proxy_v2_with_tlv_extensions() { + // v2 can have TLV extensions after the address + // Our parser should ignore them (they're included in addr_len) + let mut addr_data = Vec::new(); + addr_data.extend_from_slice(&[10, 0, 0, 1]); // src IP: 10.0.0.1 + addr_data.extend_from_slice(&[10, 0, 0, 2]); // dst IP: 10.0.0.2 + addr_data.extend_from_slice(&8080u16.to_be_bytes()); // src port + addr_data.extend_from_slice(&80u16.to_be_bytes()); // dst port + // Add some TLV data (type=0x20 PP2_TYPE_UNIQUE_ID, length=4, value) + addr_data.extend_from_slice(&[0x20, 0x00, 0x04, 0x01, 0x02, 0x03, 0x04]); + + let (header, addr) = build_proxy_v2_parts(0x01, 0x01, 0x01, &addr_data); + + match parse_proxy_v2_bytes(&header, &addr) { + ProxyProtocolResult::Success(addr) => { + assert_eq!(addr.ip(), "10.0.0.1".parse::().unwrap()); + assert_eq!(addr.port(), 8080); + } + other => panic!("Expected Success, got {:?}", other), + } + } + + #[test] + fn test_proxy_v2_max_addr_len() { + // Test that we reject excessively large address lengths + let mut header = [0u8; PROXY_V2_HEADER_SIZE]; + header[..12].copy_from_slice(&PROXY_V2_SIGNATURE); + header[12] = 0x21; // Version 2, PROXY command + header[13] = 0x11; // AF_INET, STREAM + // Set addr_len to exceed max (e.g., 65535) + header[14..16].copy_from_slice(&65535u16.to_be_bytes()); + + // This test verifies the constant is defined correctly + assert!(PROXY_V2_MAX_ADDR_LEN < 65535); + assert_eq!(PROXY_V2_MAX_ADDR_LEN, 512); + } +} diff --git a/tentacle/src/transports/tcp.rs b/tentacle/src/transports/tcp.rs index 082b17b4..e2f99247 100644 --- a/tentacle/src/transports/tcp.rs +++ b/tentacle/src/transports/tcp.rs @@ -43,6 +43,8 @@ pub struct TcpTransport { global: Arc>>, #[cfg(feature = "tls")] tls_config: TlsConfig, + /// Trusted proxy addresses for HAProxy PROXY protocol and X-Forwarded-For header parsing. + trusted_proxies: Arc>, } impl TcpTransport { @@ -54,6 +56,7 @@ impl TcpTransport { global: Arc::new(crate::lock::Mutex::new(Default::default())), #[cfg(feature = "tls")] tls_config: Default::default(), + trusted_proxies: Arc::new(Vec::new()), } } @@ -74,6 +77,7 @@ impl TcpTransport { global: multi_transport.listens_upgrade_modes, #[cfg(feature = "tls")] tls_config: multi_transport.tls_config.unwrap_or_default(), + trusted_proxies: multi_transport.trusted_proxies, } } } @@ -100,6 +104,7 @@ impl TransportListen for TcpTransport { self.tls_config, self.global, self.timeout, + self.trusted_proxies, ); Ok(TransportFuture::new(Box::pin(task))) } @@ -112,6 +117,7 @@ impl TransportListen for TcpTransport { self.tls_config, self.global, self.timeout, + self.trusted_proxies, ); Ok(TransportFuture::new(Box::pin(task))) } diff --git a/tentacle/src/transports/tcp_base_listen.rs b/tentacle/src/transports/tcp_base_listen.rs index fb1ed643..47b5bb27 100644 --- a/tentacle/src/transports/tcp_base_listen.rs +++ b/tentacle/src/transports/tcp_base_listen.rs @@ -2,7 +2,7 @@ use std::{ collections::{HashMap, hash_map::Entry}, future::Future, io, - net::SocketAddr, + net::{IpAddr, SocketAddr}, pin::Pin, sync::{ Arc, @@ -36,6 +36,7 @@ use crate::{ multiaddr::Multiaddr, runtime::{TcpListener, TcpStream}, service::config::TcpSocketConfig, + transports::proxy_protocol::{ProxyProtocolResult, parse_proxy_protocol}, transports::{MultiStream, Result, TcpListenMode, TransportErrorKind, tcp_listen}, utils::{multiaddr_to_socketaddr, socketaddr_to_multiaddr}, }; @@ -53,6 +54,7 @@ pub async fn bind( #[cfg(feature = "tls")] config: TlsConfig, global: Arc>>, timeout: Duration, + trusted_proxies: Arc>, ) -> Result<(Multiaddr, TcpBaseListenerEnum)> { let addr = address.await?; let upgrade_mode: UpgradeMode = listen_mode.into(); @@ -132,8 +134,14 @@ pub async fn bind( Ok(( listen_addr, TcpBaseListenerEnum::New({ - let tcp_listen = - TcpBaseListener::new(timeout, tcp, local_addr, upgrade_mode, global); + let tcp_listen = TcpBaseListener::new( + timeout, + tcp, + local_addr, + upgrade_mode, + global, + trusted_proxies, + ); #[cfg(feature = "tls")] let tcp_listen = tcp_listen.tls_config(tls_server_config); tcp_listen @@ -237,6 +245,8 @@ pub struct TcpBaseListener { global: Arc>>, #[cfg(feature = "tls")] tls_config: Arc, + /// Trusted proxy addresses for HAProxy PROXY protocol and X-Forwarded-For header parsing. + trusted_proxies: Arc>, } impl Drop for TcpBaseListener { @@ -252,6 +262,7 @@ impl TcpBaseListener { local_addr: SocketAddr, upgrade_mode: UpgradeMode, global: Arc>>, + trusted_proxies: Arc>, ) -> Self { let (tx, rx) = mpsc::channel(128); @@ -269,6 +280,7 @@ impl TcpBaseListener { .with_no_client_auth() .with_cert_resolver(Arc::new(ResolvesServerCertUsingSni::new())), ), + trusted_proxies, } } @@ -296,6 +308,7 @@ impl TcpBaseListener { let timeout = self.timeout; let sender = self.sender.clone(); let upgrade_mode = self.upgrade_mode.to_enum(); + let trusted_proxies = Arc::clone(&self.trusted_proxies); #[cfg(feature = "tls")] let acceptor = TlsAcceptor::from(Arc::clone(&self.tls_config)); crate::runtime::spawn(protocol_select( @@ -304,6 +317,7 @@ impl TcpBaseListener { upgrade_mode, sender, remote_address, + trusted_proxies, #[cfg(feature = "tls")] acceptor, )); @@ -343,11 +357,12 @@ impl Stream for TcpBaseListener { } async fn protocol_select( - stream: TcpStream, + mut stream: TcpStream, timeout: Duration, #[allow(unused_mut)] mut upgrade_mode: UpgradeModeEnum, mut sender: Sender<(Multiaddr, MultiStream)>, - remote_address: SocketAddr, + #[allow(unused_mut)] mut remote_address: SocketAddr, + trusted_proxies: Arc>, #[cfg(feature = "tls")] acceptor: TlsAcceptor, ) { let mut peek_buf = [0u8; 16]; @@ -375,9 +390,25 @@ async fn protocol_select( } } + // Track whether PROXY protocol has been parsed (to avoid double parsing when + // TcpAndTls continues to OnlyTcp or OnlyTls) + #[allow(unused_mut)] + let mut proxy_parsed = false; + + #[allow(clippy::never_loop)] loop { match upgrade_mode { UpgradeModeEnum::OnlyTcp => { + // Check if connection is from trusted proxy and try to parse PROXY protocol + if !proxy_parsed + && trusted_proxies.contains(&remote_address.ip()) + && try_parse_proxy_protocol(&mut stream, timeout, &mut remote_address) + .await + .is_err() + { + return; + } + if sender .send(( socketaddr_to_multiaddr(remote_address), @@ -392,8 +423,16 @@ async fn protocol_select( } #[cfg(feature = "ws")] UpgradeModeEnum::OnlyWs => { + // Check if connection is from trusted proxy and try to extract X-Forwarded-For + if trusted_proxies.contains(&remote_address.ip()) { + remote_address = + extract_forwarded_for_from_ws_handshake(&stream, remote_address).await; + } + match crate::runtime::timeout(timeout, accept_async(stream)).await { - Err(_) => debug!("accept websocket stream timeout"), + Err(_) => { + debug!("accept websocket stream timeout"); + } Ok(res) => match res { Ok(stream) => { let mut addr = socketaddr_to_multiaddr(remote_address); @@ -415,6 +454,16 @@ async fn protocol_select( } #[cfg(feature = "tls")] UpgradeModeEnum::OnlyTls => { + // Check if connection is from trusted proxy and try to parse PROXY protocol + if !proxy_parsed + && trusted_proxies.contains(&remote_address.ip()) + && try_parse_proxy_protocol(&mut stream, timeout, &mut remote_address) + .await + .is_err() + { + return; + } + match crate::runtime::timeout(timeout, acceptor.accept(stream)).await { Err(_) => debug!("accept tls server stream timeout"), Ok(res) => match res { @@ -438,6 +487,65 @@ async fn protocol_select( } #[cfg(feature = "tls")] UpgradeModeEnum::TcpAndTls => { + // Parse PROXY protocol first if from trusted proxy, then re-peek for protocol detection + let current_peek_buf = if trusted_proxies.contains(&remote_address.ip()) { + match crate::runtime::timeout(timeout, parse_proxy_protocol(&mut stream)).await + { + Ok(ProxyProtocolResult::Success(addr)) => { + debug!( + "PROXY protocol parsed successfully: {} -> {}", + remote_address, addr + ); + proxy_parsed = true; + remote_address = addr; + // After parsing PROXY protocol, we need to peek fresh data + let mut new_peek_buf = [0u8; 16]; + let peek_now = std::time::Instant::now(); + loop { + match stream.peek(&mut new_peek_buf).await { + Ok(n) if n == 16 => break, + Ok(_) => { + if peek_now.elapsed() > timeout { + debug!( + "Failed to peek 16 bytes after PROXY protocol parsing" + ); + return; + } + continue; + } + Err(e) => { + debug!("stream encountered err after PROXY parsing: {}", e); + return; + } + } + } + new_peek_buf + } + Ok(ProxyProtocolResult::NotProxyProtocol) => { + debug!("Not a PROXY protocol connection from {}", remote_address); + proxy_parsed = true; + peek_buf + } + Ok(ProxyProtocolResult::Error(e)) => { + log::warn!( + "PROXY protocol parse error from trusted proxy {}: {}", + remote_address, + e + ); + return; + } + Err(_) => { + log::warn!( + "PROXY protocol parse timeout from trusted proxy {}", + remote_address + ); + return; + } + } + } else { + peek_buf + }; + // The first sixteen bytes of secio's Propose message's mode is fixed // it's bytes like follow: // @@ -448,14 +556,18 @@ async fn protocol_select( // LengthDelimitedCodec header is big-end total len // molecule propose header is little-end total len // rand start offset is 24 = (5(feild count) + 1(total len))* 4 - let length_delimited_header = - u32::from_be_bytes(TryInto::<[u8; 4]>::try_into(&peek_buf[..4]).unwrap()); - let molecule_header = - u32::from_le_bytes(TryInto::<[u8; 4]>::try_into(&peek_buf[4..8]).unwrap()); - let rand_start = - u32::from_le_bytes(TryInto::<[u8; 4]>::try_into(&peek_buf[8..12]).unwrap()); - let rand_end = - u32::from_le_bytes(TryInto::<[u8; 4]>::try_into(&peek_buf[12..16]).unwrap()); + let length_delimited_header = u32::from_be_bytes( + TryInto::<[u8; 4]>::try_into(¤t_peek_buf[..4]).unwrap(), + ); + let molecule_header = u32::from_le_bytes( + TryInto::<[u8; 4]>::try_into(¤t_peek_buf[4..8]).unwrap(), + ); + let rand_start = u32::from_le_bytes( + TryInto::<[u8; 4]>::try_into(¤t_peek_buf[8..12]).unwrap(), + ); + let rand_end = u32::from_le_bytes( + TryInto::<[u8; 4]>::try_into(¤t_peek_buf[12..16]).unwrap(), + ); // The first twelve bytes of yamux's message's mode is fixed // it's bytes like follow: @@ -469,12 +581,14 @@ async fn protocol_select( // open window message stream id = 0x1(client standard implementation, but does not check, but can't be zero), ping message steam id = 0x0 // header_len is not a fixed value. // It may be the ping_id or the window length value expressed in windowupdate. - let yamux_version = peek_buf[0]; - let yamux_ty = peek_buf[1]; - let yamux_flags = - u16::from_be_bytes(TryInto::<[u8; 2]>::try_into(&peek_buf[2..4]).unwrap()); - let yamux_stream_id = - u32::from_be_bytes(TryInto::<[u8; 4]>::try_into(&peek_buf[4..8]).unwrap()); + let yamux_version = current_peek_buf[0]; + let yamux_ty = current_peek_buf[1]; + let yamux_flags = u16::from_be_bytes( + TryInto::<[u8; 2]>::try_into(¤t_peek_buf[2..4]).unwrap(), + ); + let yamux_stream_id = u32::from_be_bytes( + TryInto::<[u8; 4]>::try_into(¤t_peek_buf[4..8]).unwrap(), + ); if (length_delimited_header == molecule_header && rand_start == 24 @@ -527,3 +641,116 @@ async fn protocol_select( } } } + +/// Try to parse PROXY protocol from stream. +/// Returns Ok is successful. +async fn try_parse_proxy_protocol( + stream: &mut TcpStream, + timeout: Duration, + remote_address: &mut SocketAddr, +) -> std::result::Result<(), ()> { + match crate::runtime::timeout(timeout, parse_proxy_protocol(stream)).await { + Ok(ProxyProtocolResult::Success(addr)) => { + debug!( + "PROXY protocol parsed successfully: {} -> {}", + remote_address, addr + ); + *remote_address = addr; + Ok(()) + } + Ok(ProxyProtocolResult::NotProxyProtocol) => { + debug!("Not a PROXY protocol connection from {}", remote_address); + Ok(()) + } + Ok(ProxyProtocolResult::Error(e)) => { + log::warn!( + "PROXY protocol parse error from trusted proxy {}: {}", + remote_address, + e + ); + Err(()) + } + Err(_) => { + log::warn!( + "PROXY protocol parse timeout from trusted proxy {}", + remote_address + ); + Err(()) + } + } +} + +/// Extract X-Forwarded-For from WebSocket HTTP upgrade request using peek +/// This function peeks the HTTP headers without consuming them, so the WebSocket +/// handshake can proceed normally afterwards. +/// +/// # Security Warning +/// +/// This function takes the FIRST IP from the X-Forwarded-For header chain. +/// In a multi-proxy setup (Client -> Proxy1 -> Proxy2 -> Server), a malicious +/// client could forge the first IP. For maximum security with multiple proxies, +/// you should know how many trusted proxies are in front of your server and +/// count from the right side of the chain. +/// +/// Current behavior is suitable for single-proxy setups where only one trusted +/// proxy directly connects to the server. +#[cfg(feature = "ws")] +async fn extract_forwarded_for_from_ws_handshake( + stream: &TcpStream, + fallback_address: SocketAddr, +) -> SocketAddr { + use std::net::IpAddr; + + // Peek enough bytes to read HTTP headers (4KB should be enough for most cases) + let mut peek_buf = [0u8; 4096]; + let n = match stream.peek(&mut peek_buf).await { + Ok(n) => n, + Err(_) => return fallback_address, + }; + + // Parse HTTP request headers + let mut headers = [httparse::EMPTY_HEADER; 32]; + let mut req = httparse::Request::new(&mut headers); + + if req.parse(&peek_buf[..n]).is_err() { + return fallback_address; + } + + // Look for X-Forwarded-For and X-Forwarded-Port headers + let mut forwarded_ip: Option = None; + let mut forwarded_port: Option = None; + + for header in req.headers.iter() { + if header.name.eq_ignore_ascii_case("x-forwarded-for") { + if let Ok(value_str) = std::str::from_utf8(header.value) { + // X-Forwarded-For can contain multiple IPs: "client, proxy1, proxy2" + // We want the first one (the original client) + if let Some(first_ip) = value_str.split(',').next() { + let ip_str = first_ip.trim(); + if let Ok(ip) = ip_str.parse::() { + forwarded_ip = Some(ip); + } + } + } + } else if header.name.eq_ignore_ascii_case("x-forwarded-port") { + if let Ok(value_str) = std::str::from_utf8(header.value) { + // X-Forwarded-Port can also contain multiple ports, take the first one + if let Some(first_port) = value_str.split(',').next() { + if let Ok(port) = first_port.trim().parse::() { + forwarded_port = Some(port); + } + } + } + } + } + + match forwarded_ip { + Some(ip) => { + // Use X-Forwarded-Port if available, otherwise fallback to the original port + let port = forwarded_port.unwrap_or(fallback_address.port()); + debug!("X-Forwarded-For header found: {}:{}", ip, port); + SocketAddr::new(ip, port) + } + None => fallback_address, + } +} diff --git a/tentacle/tests/test_proxy_protocol.rs b/tentacle/tests/test_proxy_protocol.rs new file mode 100644 index 00000000..1f0f8a34 --- /dev/null +++ b/tentacle/tests/test_proxy_protocol.rs @@ -0,0 +1,1056 @@ +//! Integration tests for HAProxy PROXY protocol and X-Forwarded-For header support +//! +//! These tests verify that when connections come from loopback addresses, +//! the real client IP is correctly extracted from: +//! - PROXY protocol v1/v2 headers for TCP connections +//! - X-Forwarded-For headers for WebSocket connections + +use std::{ + net::{IpAddr, SocketAddr}, + sync::{Arc, Mutex}, + thread, + time::Duration, +}; + +use futures::channel; +use tentacle::{ + ProtocolId, async_trait, + builder::{MetaBuilder, ServiceBuilder}, + context::{ProtocolContext, ProtocolContextMutRef}, + multiaddr::Multiaddr, + secio::SecioKeyPair, + service::{ProtocolHandle, ProtocolMeta, Service, ServiceEvent}, + traits::{ServiceHandle, ServiceProtocol}, +}; +#[cfg(feature = "ws")] +use tokio::io::AsyncReadExt; +use tokio::{io::AsyncWriteExt, net::TcpStream}; + +/// Build PROXY protocol v1 header +fn build_proxy_v1_header(src_ip: &str, dst_ip: &str, src_port: u16, dst_port: u16) -> String { + let protocol = if src_ip.contains(':') { "TCP6" } else { "TCP4" }; + format!( + "PROXY {} {} {} {} {}\r\n", + protocol, src_ip, dst_ip, src_port, dst_port + ) +} + +/// Build PROXY protocol v2 header for IPv4 +fn build_proxy_v2_header_ipv4( + src_ip: [u8; 4], + dst_ip: [u8; 4], + src_port: u16, + dst_port: u16, +) -> Vec { + let mut header = Vec::new(); + + // Signature (12 bytes) + header.extend_from_slice(&[ + 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, + ]); + + // Version (2) and command (PROXY = 1) + header.push(0x21); + + // Family (AF_INET = 1) and protocol (STREAM = 1) + header.push(0x11); + + // Address length: 4 + 4 + 2 + 2 = 12 bytes + header.extend_from_slice(&12u16.to_be_bytes()); + + // Source IP + header.extend_from_slice(&src_ip); + // Destination IP + header.extend_from_slice(&dst_ip); + // Source port + header.extend_from_slice(&src_port.to_be_bytes()); + // Destination port + header.extend_from_slice(&dst_port.to_be_bytes()); + + header +} + +/// Build PROXY protocol v2 header for IPv6 +fn build_proxy_v2_header_ipv6( + src_ip: [u8; 16], + dst_ip: [u8; 16], + src_port: u16, + dst_port: u16, +) -> Vec { + let mut header = Vec::new(); + + // Signature (12 bytes) + header.extend_from_slice(&[ + 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, + ]); + + // Version (2) and command (PROXY = 1) + header.push(0x21); + + // Family (AF_INET6 = 2) and protocol (STREAM = 1) + header.push(0x21); + + // Address length: 16 + 16 + 2 + 2 = 36 bytes + header.extend_from_slice(&36u16.to_be_bytes()); + + // Source IP + header.extend_from_slice(&src_ip); + // Destination IP + header.extend_from_slice(&dst_ip); + // Source port + header.extend_from_slice(&src_port.to_be_bytes()); + // Destination port + header.extend_from_slice(&dst_port.to_be_bytes()); + + header +} + +/// Collected session addresses from the server +#[derive(Clone, Default)] +struct CollectedAddresses { + inner: Arc>>, +} + +impl CollectedAddresses { + fn push(&self, addr: Multiaddr) { + self.inner.lock().unwrap().push(addr); + } + + fn get_all(&self) -> Vec { + self.inner.lock().unwrap().clone() + } +} + +/// Service handle that collects session addresses +struct AddressCollectorHandle { + collected: CollectedAddresses, + sender: crossbeam_channel::Sender<()>, +} + +#[async_trait] +impl ServiceHandle for AddressCollectorHandle { + async fn handle_event( + &mut self, + _context: &mut tentacle::context::ServiceContext, + event: ServiceEvent, + ) { + if let ServiceEvent::SessionOpen { session_context } = event { + self.collected.push(session_context.address.clone()); + self.sender.try_send(()).unwrap(); + } + } +} + +/// Protocol handle for testing +struct TestProtocol; + +#[async_trait] +impl ServiceProtocol for TestProtocol { + async fn init(&mut self, _context: &mut ProtocolContext) {} + async fn connected(&mut self, _context: ProtocolContextMutRef<'_>, _version: &str) {} + async fn disconnected(&mut self, _context: ProtocolContextMutRef<'_>) {} +} + +fn create_meta(id: ProtocolId) -> ProtocolMeta { + MetaBuilder::new() + .id(id) + .service_handle(move || { + let handle = Box::new(TestProtocol); + ProtocolHandle::Callback(handle) + }) + .build() +} + +fn create_service( + collected: CollectedAddresses, + sender: crossbeam_channel::Sender<()>, +) -> Service { + let meta = create_meta(1.into()); + ServiceBuilder::default() + .insert_protocol(meta) + .forever(false) + .build(AddressCollectorHandle { collected, sender }) +} + +/// Extract IP from multiaddr (e.g., "/ip4/192.168.1.100/tcp/12345" -> "192.168.1.100") +fn extract_ip_from_multiaddr(addr: &Multiaddr) -> Option { + use tentacle::multiaddr::Protocol; + + for proto in addr.iter() { + match proto { + Protocol::Ip4(ip) => return Some(IpAddr::V4(ip)), + Protocol::Ip6(ip) => return Some(IpAddr::V6(ip)), + _ => continue, + } + } + None +} + +/// Test PROXY protocol v1 with IPv4 +#[test] +fn test_proxy_protocol_v1_ipv4() { + let collected = CollectedAddresses::default(); + let (sender, receiver) = crossbeam_channel::bounded(1); + let (addr_sender, addr_receiver) = channel::oneshot::channel::(); + + let collected_clone = collected.clone(); + thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + let mut service = create_service(collected_clone, sender); + rt.block_on(async move { + let listen_addr = service + .listen("/ip4/127.0.0.1/tcp/0".parse().unwrap()) + .await + .unwrap(); + addr_sender.send(listen_addr).unwrap(); + service.run().await + }); + }); + + // Wait for server to start and get listen address + let listen_addr = futures::executor::block_on(addr_receiver).unwrap(); + let socket_addr: SocketAddr = { + use tentacle::multiaddr::Protocol; + let mut ip = None; + let mut port = None; + for proto in listen_addr.iter() { + match proto { + Protocol::Ip4(i) => ip = Some(IpAddr::V4(i)), + Protocol::Tcp(p) => port = Some(p), + _ => {} + } + } + SocketAddr::new(ip.unwrap(), port.unwrap()) + }; + + // Connect and send PROXY protocol v1 header + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let mut stream = TcpStream::connect(socket_addr).await.unwrap(); + + // Send PROXY protocol v1 header with fake source IP + let proxy_header = build_proxy_v1_header("203.0.113.50", "192.168.1.1", 54321, 80); + stream.write_all(proxy_header.as_bytes()).await.unwrap(); + + // Keep connection open briefly + tokio::time::sleep(Duration::from_millis(200)).await; + }); + + // Wait for session to be established + receiver.recv_timeout(Duration::from_secs(5)).unwrap(); + + // Give server a moment to process + thread::sleep(Duration::from_millis(100)); + + // Check collected addresses + let addresses = collected.get_all(); + assert!( + !addresses.is_empty(), + "Should have collected at least one address" + ); + + // The first address should have the PROXY protocol source IP + let first_addr = &addresses[0]; + let ip = extract_ip_from_multiaddr(first_addr); + assert!(ip.is_some(), "Should be able to extract IP from address"); + assert_eq!( + ip.unwrap().to_string(), + "203.0.113.50", + "Should use the IP from PROXY protocol header" + ); +} + +/// Test PROXY protocol v2 with IPv4 +#[test] +fn test_proxy_protocol_v2_ipv4() { + let collected = CollectedAddresses::default(); + let (sender, receiver) = crossbeam_channel::bounded(1); + let (addr_sender, addr_receiver) = channel::oneshot::channel::(); + + let collected_clone = collected.clone(); + thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + let mut service = create_service(collected_clone, sender); + rt.block_on(async move { + let listen_addr = service + .listen("/ip4/127.0.0.1/tcp/0".parse().unwrap()) + .await + .unwrap(); + addr_sender.send(listen_addr).unwrap(); + service.run().await + }); + }); + + // Wait for server to start and get listen address + let listen_addr = futures::executor::block_on(addr_receiver).unwrap(); + let socket_addr: SocketAddr = { + use tentacle::multiaddr::Protocol; + let mut ip = None; + let mut port = None; + for proto in listen_addr.iter() { + match proto { + Protocol::Ip4(i) => ip = Some(IpAddr::V4(i)), + Protocol::Tcp(p) => port = Some(p), + _ => {} + } + } + SocketAddr::new(ip.unwrap(), port.unwrap()) + }; + + // Connect and send PROXY protocol v2 header + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let mut stream = TcpStream::connect(socket_addr).await.unwrap(); + + // Send PROXY protocol v2 header with fake source IP 10.20.30.40 + let proxy_header = build_proxy_v2_header_ipv4( + [10, 20, 30, 40], // Source IP + [192, 168, 1, 1], // Destination IP + 12345, // Source port + 80, // Destination port + ); + stream.write_all(&proxy_header).await.unwrap(); + + // Keep connection open briefly + tokio::time::sleep(Duration::from_millis(200)).await; + }); + + // Wait for session to be established + receiver.recv_timeout(Duration::from_secs(5)).unwrap(); + + // Give server a moment to process + thread::sleep(Duration::from_millis(100)); + + // Check collected addresses + let addresses = collected.get_all(); + assert!( + !addresses.is_empty(), + "Should have collected at least one address" + ); + + // The first address should have the PROXY protocol source IP + let first_addr = &addresses[0]; + let ip = extract_ip_from_multiaddr(first_addr); + assert!(ip.is_some(), "Should be able to extract IP from address"); + assert_eq!( + ip.unwrap().to_string(), + "10.20.30.40", + "Should use the IP from PROXY protocol v2 header" + ); +} + +/// Test PROXY protocol v1 with IPv6 +#[test] +fn test_proxy_protocol_v1_ipv6() { + let collected = CollectedAddresses::default(); + let (sender, receiver) = crossbeam_channel::bounded(1); + let (addr_sender, addr_receiver) = channel::oneshot::channel::(); + + let collected_clone = collected.clone(); + thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + let mut service = create_service(collected_clone, sender); + rt.block_on(async move { + // Listen on IPv6 loopback + let listen_addr = service + .listen("/ip6/::1/tcp/0".parse().unwrap()) + .await + .unwrap(); + addr_sender.send(listen_addr).unwrap(); + service.run().await + }); + }); + + // Wait for server to start and get listen address + let listen_addr = futures::executor::block_on(addr_receiver).unwrap(); + let socket_addr: SocketAddr = { + use tentacle::multiaddr::Protocol; + let mut ip = None; + let mut port = None; + for proto in listen_addr.iter() { + match proto { + Protocol::Ip6(i) => ip = Some(IpAddr::V6(i)), + Protocol::Tcp(p) => port = Some(p), + _ => {} + } + } + SocketAddr::new(ip.unwrap(), port.unwrap()) + }; + + // Connect and send PROXY protocol v1 header with IPv6 + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let mut stream = TcpStream::connect(socket_addr).await.unwrap(); + + // Send PROXY protocol v1 header with IPv6 source + let proxy_header = build_proxy_v1_header("2001:db8::1", "2001:db8::2", 54321, 80); + stream.write_all(proxy_header.as_bytes()).await.unwrap(); + + // Keep connection open briefly + tokio::time::sleep(Duration::from_millis(200)).await; + }); + + // Wait for session to be established + receiver.recv_timeout(Duration::from_secs(5)).unwrap(); + + // Give server a moment to process + thread::sleep(Duration::from_millis(100)); + + // Check collected addresses + let addresses = collected.get_all(); + assert!( + !addresses.is_empty(), + "Should have collected at least one address" + ); + + // The first address should have the PROXY protocol source IP + let first_addr = &addresses[0]; + let ip = extract_ip_from_multiaddr(first_addr); + assert!(ip.is_some(), "Should be able to extract IP from address"); + assert_eq!( + ip.unwrap().to_string(), + "2001:db8::1", + "Should use the IPv6 from PROXY protocol header" + ); +} + +/// Test PROXY protocol v2 with IPv6 +#[test] +fn test_proxy_protocol_v2_ipv6() { + let collected = CollectedAddresses::default(); + let (sender, receiver) = crossbeam_channel::bounded(1); + let (addr_sender, addr_receiver) = channel::oneshot::channel::(); + + let collected_clone = collected.clone(); + thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + let mut service = create_service(collected_clone, sender); + rt.block_on(async move { + // Listen on IPv6 loopback + let listen_addr = service + .listen("/ip6/::1/tcp/0".parse().unwrap()) + .await + .unwrap(); + addr_sender.send(listen_addr).unwrap(); + service.run().await + }); + }); + + // Wait for server to start and get listen address + let listen_addr = futures::executor::block_on(addr_receiver).unwrap(); + let socket_addr: SocketAddr = { + use tentacle::multiaddr::Protocol; + let mut ip = None; + let mut port = None; + for proto in listen_addr.iter() { + match proto { + Protocol::Ip6(i) => ip = Some(IpAddr::V6(i)), + Protocol::Tcp(p) => port = Some(p), + _ => {} + } + } + SocketAddr::new(ip.unwrap(), port.unwrap()) + }; + + // Connect and send PROXY protocol v2 header with IPv6 + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let mut stream = TcpStream::connect(socket_addr).await.unwrap(); + + // 2001:db8:85a3::8a2e:370:7334 + let src_ip: [u8; 16] = [ + 0x20, 0x01, 0x0d, 0xb8, 0x85, 0xa3, 0x00, 0x00, 0x00, 0x00, 0x8a, 0x2e, 0x03, 0x70, + 0x73, 0x34, + ]; + // 2001:db8::1 + let dst_ip: [u8; 16] = [ + 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + ]; + + let proxy_header = build_proxy_v2_header_ipv6(src_ip, dst_ip, 12345, 80); + stream.write_all(&proxy_header).await.unwrap(); + + // Keep connection open briefly + tokio::time::sleep(Duration::from_millis(200)).await; + }); + + // Wait for session to be established + receiver.recv_timeout(Duration::from_secs(5)).unwrap(); + + // Give server a moment to process + thread::sleep(Duration::from_millis(100)); + + // Check collected addresses + let addresses = collected.get_all(); + assert!( + !addresses.is_empty(), + "Should have collected at least one address" + ); + + // The first address should have the PROXY protocol source IP + let first_addr = &addresses[0]; + let ip = extract_ip_from_multiaddr(first_addr); + assert!(ip.is_some(), "Should be able to extract IP from address"); + assert_eq!( + ip.unwrap().to_string(), + "2001:db8:85a3::8a2e:370:7334", + "Should use the IPv6 from PROXY protocol v2 header" + ); +} + +/// Test that non-PROXY protocol connections still work (fallback to socket address) +#[test] +fn test_normal_connection_without_proxy_protocol() { + let collected = CollectedAddresses::default(); + let (sender, receiver) = crossbeam_channel::bounded(1); + let (addr_sender, addr_receiver) = channel::oneshot::channel::(); + + let collected_clone = collected.clone(); + thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + let mut service = create_service(collected_clone, sender); + rt.block_on(async move { + let listen_addr = service + .listen("/ip4/127.0.0.1/tcp/0".parse().unwrap()) + .await + .unwrap(); + addr_sender.send(listen_addr).unwrap(); + service.run().await + }); + }); + + // Wait for server to start and get listen address + let listen_addr = futures::executor::block_on(addr_receiver).unwrap(); + let socket_addr: SocketAddr = { + use tentacle::multiaddr::Protocol; + let mut ip = None; + let mut port = None; + for proto in listen_addr.iter() { + match proto { + Protocol::Ip4(i) => ip = Some(IpAddr::V4(i)), + Protocol::Tcp(p) => port = Some(p), + _ => {} + } + } + SocketAddr::new(ip.unwrap(), port.unwrap()) + }; + + // Connect without PROXY protocol - send at least 16 bytes of non-PROXY data + // The server requires at least 16 bytes before processing the connection + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let mut stream = TcpStream::connect(socket_addr).await.unwrap(); + + // Send 16+ bytes of non-PROXY data + // This data does NOT start with "PROXY" or the v2 signature + // Simulates a normal protocol message (e.g., yamux/secio handshake) + let non_proxy_data = [ + 0x00, 0x01, 0x00, 0x01, // 4 bytes + 0x00, 0x00, 0x00, 0x01, // 4 bytes + 0x00, 0x00, 0x00, 0x01, // 4 bytes + 0x00, 0x00, 0x00, 0x01, // 4 bytes + 0x00, 0x00, 0x00, 0x01, // 4 bytes extra for safety + ]; + stream.write_all(&non_proxy_data).await.unwrap(); + + // Keep connection open briefly + tokio::time::sleep(Duration::from_millis(500)).await; + }); + + // Wait for session to be established + receiver.recv_timeout(Duration::from_secs(5)).unwrap(); + + // Give server a moment to process + thread::sleep(Duration::from_millis(100)); + + // Check collected addresses + let addresses = collected.get_all(); + assert!( + !addresses.is_empty(), + "Should have collected at least one address" + ); + + // The address should be the local loopback since no PROXY protocol was used + let first_addr = &addresses[0]; + let ip = extract_ip_from_multiaddr(first_addr); + assert!(ip.is_some(), "Should be able to extract IP from address"); + assert_eq!( + ip.unwrap().to_string(), + "127.0.0.1", + "Should use the socket address when no PROXY protocol is present" + ); +} + +/// Build a WebSocket upgrade request with X-Forwarded-For header +#[cfg(feature = "ws")] +fn build_ws_upgrade_request_with_forwarded_for(host: &str, forwarded_ip: &str) -> String { + // Use a fixed WebSocket key for testing (this is valid base64) + let ws_key = "dGhlIHNhbXBsZSBub25jZQ=="; + + format!( + "GET / HTTP/1.1\r\n\ + Host: {}\r\n\ + Upgrade: websocket\r\n\ + Connection: Upgrade\r\n\ + Sec-WebSocket-Key: {}\r\n\ + Sec-WebSocket-Version: 13\r\n\ + X-Forwarded-For: {}\r\n\ + \r\n", + host, ws_key, forwarded_ip + ) +} + +/// Build a WebSocket upgrade request with X-Forwarded-For and X-Forwarded-Port headers +#[cfg(feature = "ws")] +fn build_ws_upgrade_request_with_forwarded_for_and_port( + host: &str, + forwarded_ip: &str, + forwarded_port: u16, +) -> String { + let ws_key = "dGhlIHNhbXBsZSBub25jZQ=="; + + format!( + "GET / HTTP/1.1\r\n\ + Host: {}\r\n\ + Upgrade: websocket\r\n\ + Connection: Upgrade\r\n\ + Sec-WebSocket-Key: {}\r\n\ + Sec-WebSocket-Version: 13\r\n\ + X-Forwarded-For: {}\r\n\ + X-Forwarded-Port: {}\r\n\ + \r\n", + host, ws_key, forwarded_ip, forwarded_port + ) +} + +/// Test WebSocket connection with X-Forwarded-For header +#[cfg(feature = "ws")] +#[test] +fn test_ws_x_forwarded_for() { + let collected = CollectedAddresses::default(); + let (sender, receiver) = crossbeam_channel::bounded(1); + let (addr_sender, addr_receiver) = channel::oneshot::channel::(); + + let collected_clone = collected.clone(); + thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + let mut service = create_service(collected_clone, sender); + rt.block_on(async move { + // Listen on WebSocket address + let listen_addr = service + .listen("/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()) + .await + .unwrap(); + addr_sender.send(listen_addr).unwrap(); + service.run().await + }); + }); + + // Wait for server to start and get listen address + let listen_addr = futures::executor::block_on(addr_receiver).unwrap(); + let socket_addr: SocketAddr = { + use tentacle::multiaddr::Protocol; + let mut ip = None; + let mut port = None; + for proto in listen_addr.iter() { + match proto { + Protocol::Ip4(i) => ip = Some(IpAddr::V4(i)), + Protocol::Tcp(p) => port = Some(p), + _ => {} + } + } + SocketAddr::new(ip.unwrap(), port.unwrap()) + }; + + // Connect and send WebSocket upgrade request with X-Forwarded-For + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let mut stream = TcpStream::connect(socket_addr).await.unwrap(); + + // Send WebSocket upgrade request with X-Forwarded-For header + let ws_request = build_ws_upgrade_request_with_forwarded_for( + &format!("127.0.0.1:{}", socket_addr.port()), + "198.51.100.178", + ); + stream.write_all(ws_request.as_bytes()).await.unwrap(); + + // Read the response (we need to complete the handshake) + let mut response = vec![0u8; 1024]; + stream.read_buf(&mut response).await.unwrap(); + + // Keep connection open briefly + tokio::time::sleep(Duration::from_millis(500)).await; + }); + + // Wait for session to be established + receiver.recv_timeout(Duration::from_secs(5)).unwrap(); + + // Give server a moment to process + thread::sleep(Duration::from_millis(100)); + + // Check collected addresses + let addresses = collected.get_all(); + assert!( + !addresses.is_empty(), + "Should have collected at least one address" + ); + + // The address should have the X-Forwarded-For IP + let first_addr = &addresses[0]; + let ip = extract_ip_from_multiaddr(first_addr); + assert!(ip.is_some(), "Should be able to extract IP from address"); + assert_eq!( + ip.unwrap().to_string(), + "198.51.100.178", + "Should use the IP from X-Forwarded-For header" + ); +} + +/// Test WebSocket connection without X-Forwarded-For header (fallback) +#[cfg(feature = "ws")] +#[test] +fn test_ws_without_x_forwarded_for() { + let collected = CollectedAddresses::default(); + let (sender, receiver) = crossbeam_channel::bounded(1); + let (addr_sender, addr_receiver) = channel::oneshot::channel::(); + + let collected_clone = collected.clone(); + thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + let mut service = create_service(collected_clone, sender); + rt.block_on(async move { + // Listen on WebSocket address + let listen_addr = service + .listen("/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()) + .await + .unwrap(); + addr_sender.send(listen_addr).unwrap(); + service.run().await + }); + }); + + // Wait for server to start and get listen address + let listen_addr = futures::executor::block_on(addr_receiver).unwrap(); + let socket_addr: SocketAddr = { + use tentacle::multiaddr::Protocol; + let mut ip = None; + let mut port = None; + for proto in listen_addr.iter() { + match proto { + Protocol::Ip4(i) => ip = Some(IpAddr::V4(i)), + Protocol::Tcp(p) => port = Some(p), + _ => {} + } + } + SocketAddr::new(ip.unwrap(), port.unwrap()) + }; + + // Connect with a WebSocket client without X-Forwarded-For + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + use tokio_tungstenite::connect_async; + + let ws_url = format!("ws://127.0.0.1:{}/", socket_addr.port()); + connect_async(&ws_url).await.unwrap(); + + // Keep connection open briefly + tokio::time::sleep(Duration::from_millis(500)).await; + }); + + // Wait for session to be established + receiver.recv_timeout(Duration::from_secs(5)).unwrap(); + + // Give server a moment to process + thread::sleep(Duration::from_millis(100)); + + // Check collected addresses + let addresses = collected.get_all(); + assert!( + !addresses.is_empty(), + "Should have collected at least one address" + ); + + // The address should be the local loopback since no X-Forwarded-For was sent + let first_addr = &addresses[0]; + let ip = extract_ip_from_multiaddr(first_addr); + assert!(ip.is_some(), "Should be able to extract IP from address"); + assert_eq!( + ip.unwrap().to_string(), + "127.0.0.1", + "Should use the socket address when no X-Forwarded-For is present" + ); +} + +/// Test WebSocket connection with multiple IPs in X-Forwarded-For (should use first) +#[cfg(feature = "ws")] +#[test] +fn test_ws_x_forwarded_for_multiple_ips() { + let collected = CollectedAddresses::default(); + let (sender, receiver) = crossbeam_channel::bounded(1); + let (addr_sender, addr_receiver) = channel::oneshot::channel::(); + + let collected_clone = collected.clone(); + thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + let mut service = create_service(collected_clone, sender); + rt.block_on(async move { + // Listen on WebSocket address + let listen_addr = service + .listen("/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()) + .await + .unwrap(); + addr_sender.send(listen_addr).unwrap(); + service.run().await + }); + }); + + // Wait for server to start and get listen address + let listen_addr = futures::executor::block_on(addr_receiver).unwrap(); + let socket_addr: SocketAddr = { + use tentacle::multiaddr::Protocol; + let mut ip = None; + let mut port = None; + for proto in listen_addr.iter() { + match proto { + Protocol::Ip4(i) => ip = Some(IpAddr::V4(i)), + Protocol::Tcp(p) => port = Some(p), + _ => {} + } + } + SocketAddr::new(ip.unwrap(), port.unwrap()) + }; + + // Connect and send WebSocket upgrade request with multiple IPs in X-Forwarded-For + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let mut stream = TcpStream::connect(socket_addr).await.unwrap(); + + // X-Forwarded-For with multiple IPs: client, proxy1, proxy2 + // Should use the first one (the original client) + let ws_request = build_ws_upgrade_request_with_forwarded_for( + &format!("127.0.0.1:{}", socket_addr.port()), + "203.0.113.195, 70.41.3.18, 150.172.238.178", + ); + stream.write_all(ws_request.as_bytes()).await.unwrap(); + + // Read the response + let mut response = vec![0u8; 1024]; + stream.read_buf(&mut response).await.unwrap(); + + // Keep connection open briefly + tokio::time::sleep(Duration::from_millis(500)).await; + }); + + // Wait for session to be established + receiver.recv_timeout(Duration::from_secs(5)).unwrap(); + + // Give server a moment to process + thread::sleep(Duration::from_millis(100)); + + // Check collected addresses + let addresses = collected.get_all(); + assert!( + !addresses.is_empty(), + "Should have collected at least one address" + ); + + // The address should have the first IP from X-Forwarded-For chain + let first_addr = &addresses[0]; + let ip = extract_ip_from_multiaddr(first_addr); + assert!(ip.is_some(), "Should be able to extract IP from address"); + assert_eq!( + ip.unwrap().to_string(), + "203.0.113.195", + "Should use the first IP from X-Forwarded-For header chain" + ); +} + +/// Test WebSocket connection with X-Forwarded-For header containing IPv6 +#[cfg(feature = "ws")] +#[test] +fn test_ws_x_forwarded_for_ipv6() { + let collected = CollectedAddresses::default(); + let (sender, receiver) = crossbeam_channel::bounded(1); + let (addr_sender, addr_receiver) = channel::oneshot::channel::(); + + let collected_clone = collected.clone(); + thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + let mut service = create_service(collected_clone, sender); + rt.block_on(async move { + // Listen on WebSocket address (IPv4 loopback for simplicity) + let listen_addr = service + .listen("/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()) + .await + .unwrap(); + addr_sender.send(listen_addr).unwrap(); + service.run().await + }); + }); + + // Wait for server to start and get listen address + let listen_addr = futures::executor::block_on(addr_receiver).unwrap(); + let socket_addr: SocketAddr = { + use tentacle::multiaddr::Protocol; + let mut ip = None; + let mut port = None; + for proto in listen_addr.iter() { + match proto { + Protocol::Ip4(i) => ip = Some(IpAddr::V4(i)), + Protocol::Tcp(p) => port = Some(p), + _ => {} + } + } + SocketAddr::new(ip.unwrap(), port.unwrap()) + }; + + // Connect and send WebSocket upgrade request with IPv6 in X-Forwarded-For + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let mut stream = TcpStream::connect(socket_addr).await.unwrap(); + + // Send WebSocket upgrade request with IPv6 X-Forwarded-For header + let ws_request = build_ws_upgrade_request_with_forwarded_for( + &format!("127.0.0.1:{}", socket_addr.port()), + "2001:db8:cafe::17", + ); + stream.write_all(ws_request.as_bytes()).await.unwrap(); + + // Read the response (we need to complete the handshake) + let mut response = vec![0u8; 1024]; + stream.read_buf(&mut response).await.unwrap(); + + // Keep connection open briefly + tokio::time::sleep(Duration::from_millis(500)).await; + }); + + // Wait for session to be established + receiver.recv_timeout(Duration::from_secs(5)).unwrap(); + + // Give server a moment to process + thread::sleep(Duration::from_millis(100)); + + // Check collected addresses + let addresses = collected.get_all(); + assert!( + !addresses.is_empty(), + "Should have collected at least one address" + ); + + // The address should have the IPv6 from X-Forwarded-For + let first_addr = &addresses[0]; + let ip = extract_ip_from_multiaddr(first_addr); + assert!(ip.is_some(), "Should be able to extract IP from address"); + assert_eq!( + ip.unwrap().to_string(), + "2001:db8:cafe::17", + "Should use the IPv6 from X-Forwarded-For header" + ); +} + +/// Extract port from multiaddr +#[cfg(feature = "ws")] +fn extract_port_from_multiaddr(addr: &Multiaddr) -> Option { + use tentacle::multiaddr::Protocol; + + for proto in addr.iter() { + if let Protocol::Tcp(port) = proto { + return Some(port); + } + } + None +} + +/// Test WebSocket connection with X-Forwarded-For and X-Forwarded-Port headers +#[cfg(feature = "ws")] +#[test] +fn test_ws_x_forwarded_for_with_port() { + let collected = CollectedAddresses::default(); + let (sender, receiver) = crossbeam_channel::bounded(1); + let (addr_sender, addr_receiver) = channel::oneshot::channel::(); + + let collected_clone = collected.clone(); + thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + let mut service = create_service(collected_clone, sender); + rt.block_on(async move { + let listen_addr = service + .listen("/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()) + .await + .unwrap(); + addr_sender.send(listen_addr).unwrap(); + service.run().await + }); + }); + + // Wait for server to start and get listen address + let listen_addr = futures::executor::block_on(addr_receiver).unwrap(); + let socket_addr: SocketAddr = { + use tentacle::multiaddr::Protocol; + let mut ip = None; + let mut port = None; + for proto in listen_addr.iter() { + match proto { + Protocol::Ip4(i) => ip = Some(IpAddr::V4(i)), + Protocol::Tcp(p) => port = Some(p), + _ => {} + } + } + SocketAddr::new(ip.unwrap(), port.unwrap()) + }; + + // Connect and send WebSocket upgrade request with X-Forwarded-For and X-Forwarded-Port + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let mut stream = TcpStream::connect(socket_addr).await.unwrap(); + + // Send WebSocket upgrade request with both headers + let ws_request = build_ws_upgrade_request_with_forwarded_for_and_port( + &format!("127.0.0.1:{}", socket_addr.port()), + "198.51.100.50", + 54321, + ); + stream.write_all(ws_request.as_bytes()).await.unwrap(); + + // Read the response + let mut response = vec![0u8; 1024]; + stream.read_buf(&mut response).await.unwrap(); + + // Keep connection open briefly + tokio::time::sleep(Duration::from_millis(500)).await; + }); + + // Wait for session to be established + receiver.recv_timeout(Duration::from_secs(5)).unwrap(); + + // Give server a moment to process + thread::sleep(Duration::from_millis(100)); + + // Check collected addresses + let addresses = collected.get_all(); + assert!( + !addresses.is_empty(), + "Should have collected at least one address" + ); + + // The address should have both the IP and port from X-Forwarded headers + let first_addr = &addresses[0]; + let ip = extract_ip_from_multiaddr(first_addr); + let port = extract_port_from_multiaddr(first_addr); + assert!(ip.is_some(), "Should be able to extract IP from address"); + assert!( + port.is_some(), + "Should be able to extract port from address" + ); + assert_eq!( + ip.unwrap().to_string(), + "198.51.100.50", + "Should use the IP from X-Forwarded-For header" + ); + assert_eq!( + port.unwrap(), + 54321, + "Should use the port from X-Forwarded-Port header" + ); +}