Skip to content

Commit d9fbd23

Browse files
author
share121
committed
refactor: 重构代码
BREAKING CHANGE: 暂时移除 udp 支持
1 parent dbd0674 commit d9fbd23

File tree

8 files changed

+319
-273
lines changed

8 files changed

+319
-273
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ authors = ["share121 <[email protected]>"]
77
repository = "https://github.com/share121/port-mapping"
88
readme = "README.md"
99
exclude = ["/.github"]
10-
description = "简单的映射端口程序,有基础的负载均衡功能"
10+
description = "简单的映射端口程序"
1111
documentation = "https://docs.rs/port-mapping"
1212
homepage = "https://github.com/share121/port-mapping"
1313
keywords = ["port", "mapping", "tokio", "concurrency", "performance"]

README.md

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
[![Latest version](https://img.shields.io/crates/v/port-mapping.svg)](https://crates.io/crates/port-mapping)
66
![License](https://img.shields.io/crates/l/port-mapping.svg)
77

8-
简单的映射端口程序,有基础的负载均衡功能
8+
简单的映射端口程序
99

1010
> **注意:** 只有 TCP 端口映射经过测试
1111
@@ -14,9 +14,15 @@
1414
修改 mapping.txt 文件,格式如下:
1515

1616
```
17-
:80 :8080 tcp # 把 0.0.0.0:80 端口映射到 localhost:8080 端口,协议为 tcp
18-
:80 :8081 tcp # 把 0.0.0.0:80 端口映射到 localhost:8081 端口,协议为 tcp,负载均衡
19-
127.0.0.1:443 100.88.11.5:8080 tcp # 把 127.0.0.1:443 端口映射到 100.88.11.5:8080 端口,协议为 tcp
17+
# t+u 表示同时使用 tcp 和 udp 协议
18+
# 把本地端口 40000-49999 映射到服务器 100.123.151.117 的端口 0000-9999 上
19+
t+u 40000-49999 100.123.151.117:0000-9999
20+
21+
# 使用 tcp 协议,把本地端口 5666 映射到 localhost 的端口 80 上
22+
tcp 5666 :80
23+
24+
# 使用 udp 协议,把本地端口 5666 映射到 localhost 的端口 80 上
25+
udp 5666 :80
2026
```
2127

2228
然后运行 port-mapping 即可

mapping.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1-
:5561 :3389 tcp
1+
# t+u 表示同时使用 tcp 和 udp 协议
2+
# 把本地端口 40000-49999 映射到服务器 100.123.151.117 的端口 0000-9999 上
3+
t+u 40000-49999 100.123.151.117:0000-9999
4+
tcp 5666 :80

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
pub mod mapping_rule;
2+
pub mod tcp_proxy;
3+
pub mod udp_proxy;

src/main.rs

Lines changed: 12 additions & 267 deletions
Original file line numberDiff line numberDiff line change
@@ -1,285 +1,30 @@
1-
use std::{
2-
collections::HashMap,
3-
hash::{Hash, Hasher},
4-
net::SocketAddr,
5-
sync::{
6-
Arc,
7-
atomic::{AtomicUsize, Ordering},
8-
},
9-
time::{Duration, Instant},
1+
use port_mapping::{
2+
mapping_rule::{Protocol, read_mapping_file},
3+
tcp_proxy::TcpProxy,
104
};
11-
use tokio::{
12-
fs::File,
13-
io::{AsyncBufReadExt, BufReader},
14-
net::{TcpListener, TcpStream, UdpSocket},
15-
sync::Mutex,
16-
};
17-
18-
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
19-
enum Protocol {
20-
Tcp,
21-
Udp,
22-
}
23-
24-
#[derive(Debug)]
25-
struct MappingRuleEntry {
26-
listen: String,
27-
upstream: String,
28-
protocol: Protocol,
29-
}
30-
31-
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
32-
enum MappingRuleParseError {
33-
Empty,
34-
InvalidFormat(String),
35-
InvalidProtocol(String, String),
36-
}
37-
38-
impl MappingRuleEntry {
39-
fn parse(line: &str) -> Result<Self, MappingRuleParseError> {
40-
let parts = line.split('#').next().ok_or(MappingRuleParseError::Empty)?;
41-
let parts: Vec<&str> = parts.split_whitespace().collect();
42-
if parts.len() != 3 {
43-
return Err(MappingRuleParseError::InvalidFormat(line.to_string()));
44-
}
45-
let listen = if parts[0].starts_with(':') {
46-
format!("0.0.0.0{}", parts[0])
47-
} else {
48-
parts[0].to_string()
49-
};
50-
let upstream = if parts[1].starts_with(':') {
51-
format!("localhost{}", parts[1])
52-
} else {
53-
parts[1].to_string()
54-
};
55-
let protocol = match parts[2].to_lowercase().as_str() {
56-
"udp" => Protocol::Udp,
57-
"tcp" => Protocol::Tcp,
58-
_ => {
59-
return Err(MappingRuleParseError::InvalidProtocol(
60-
line.to_string(),
61-
parts[2].to_string(),
62-
));
63-
}
64-
};
65-
Ok(MappingRuleEntry {
66-
listen,
67-
upstream,
68-
protocol,
69-
})
70-
}
71-
}
5+
use std::sync::Arc;
6+
use tokio::{fs::File, io::BufReader};
727

73-
#[derive(Debug)]
74-
struct MappingRule {
75-
listen: String,
76-
upstreams: Vec<String>,
77-
protocol: Protocol,
78-
}
79-
80-
async fn read_mapping_file() -> Result<Vec<MappingRule>, std::io::Error> {
8+
#[tokio::main]
9+
async fn main() -> Result<(), std::io::Error> {
8110
let exe_path = std::env::current_exe()?;
8211
let dir = exe_path.parent().unwrap();
8312
let mapping_path = dir.join("mapping.txt");
8413
let file = File::open(&mapping_path).await?;
85-
let mut reader = BufReader::new(file);
86-
let mut rules: HashMap<(String, Protocol), Vec<String>> = HashMap::new();
87-
let mut line = String::new();
88-
while reader.read_line(&mut line).await? != 0 {
89-
match MappingRuleEntry::parse(&line) {
90-
Ok(entry) => {
91-
rules
92-
.entry((entry.listen, entry.protocol))
93-
.or_default()
94-
.push(entry.upstream);
95-
}
96-
Err(e) => match e {
97-
MappingRuleParseError::Empty => (),
98-
MappingRuleParseError::InvalidFormat(input) => {
99-
eprintln!("Invalid format: \"{}\"", input.trim())
100-
}
101-
MappingRuleParseError::InvalidProtocol(input, protocol) => {
102-
eprintln!("Invalid protocol: {protocol} in \"{}\"", input.trim())
103-
}
104-
},
105-
}
106-
line.clear();
107-
}
108-
Ok(rules
109-
.into_iter()
110-
.map(|((listen, protocol), upstreams)| MappingRule {
111-
listen,
112-
upstreams,
113-
protocol,
114-
})
115-
.collect())
116-
}
117-
118-
async fn handle_tcp_connection(
119-
upstream_addr: &str,
120-
mut downstream: TcpStream,
121-
) -> Result<(), std::io::Error> {
122-
let mut upstream = TcpStream::connect(upstream_addr).await?;
123-
tokio::io::copy_bidirectional(&mut downstream, &mut upstream).await?;
124-
Ok(())
125-
}
126-
127-
async fn run_tcp_proxy(listen_addr: &str, upstreams: Vec<String>) -> Result<(), std::io::Error> {
128-
let listener = TcpListener::bind(listen_addr).await?;
129-
println!("TCP proxy listening on {} -> {:?}", listen_addr, upstreams);
130-
let current = Arc::new(AtomicUsize::new(0));
131-
let upstreams = Arc::new(upstreams);
132-
loop {
133-
let (downstream, _) = listener.accept().await?;
134-
let current = current.clone();
135-
let upstreams = upstreams.clone();
136-
tokio::spawn(async move {
137-
let idx = current.fetch_add(1, Ordering::Relaxed) % upstreams.len();
138-
let upstream_addr = upstreams[idx].clone();
139-
if let Err(e) = handle_tcp_connection(&upstream_addr, downstream).await {
140-
eprintln!("TCP proxy error: {}", e);
141-
}
142-
});
143-
}
144-
}
145-
146-
#[derive(Debug, Clone)]
147-
struct UdpProxyState {
148-
client_map: Arc<Mutex<HashMap<SocketAddr, (Arc<UdpSocket>, SocketAddr, Instant)>>>,
149-
upstreams: Vec<String>,
150-
}
151-
152-
impl UdpProxyState {
153-
fn new(upstreams: Vec<String>) -> Self {
154-
UdpProxyState {
155-
client_map: Arc::new(Mutex::new(HashMap::new())),
156-
upstreams,
157-
}
158-
}
159-
160-
async fn get_upstream_socket(
161-
&self,
162-
client_addr: SocketAddr,
163-
) -> Result<(Arc<UdpSocket>, SocketAddr), std::io::Error> {
164-
let mut map = self.client_map.lock().await;
165-
166-
// 清理过期连接
167-
let now = Instant::now();
168-
map.retain(|_, (_, _, last_used)| now.duration_since(*last_used) < Duration::from_secs(30));
169-
170-
// 查找或创建socket
171-
if let Some((sock, upstream_addr, _)) = map.get(&client_addr) {
172-
return Ok((sock.clone(), *upstream_addr));
173-
}
174-
175-
// 选择上游服务器
176-
let mut hasher = std::collections::hash_map::DefaultHasher::new();
177-
client_addr.ip().hash(&mut hasher);
178-
let idx = hasher.finish() as usize % self.upstreams.len();
179-
let upstream_addr = self.upstreams[idx].parse().map_err(|e| {
180-
std::io::Error::new(
181-
std::io::ErrorKind::InvalidInput,
182-
format!("Invalid upstream address: {}", e),
183-
)
184-
})?;
185-
186-
// 创建新socket
187-
let sock = Arc::new(UdpSocket::bind("0.0.0.0:0").await?);
188-
sock.connect(upstream_addr).await?;
189-
190-
// 存储并返回
191-
let entry = (sock.clone(), upstream_addr, Instant::now());
192-
map.insert(client_addr, entry);
193-
Ok((sock, upstream_addr))
194-
}
195-
}
196-
197-
async fn run_udp_proxy(listen_addr: &str, upstreams: Vec<String>) -> Result<(), std::io::Error> {
198-
let socket = Arc::new(UdpSocket::bind(listen_addr).await?);
199-
println!("UDP proxy listening on {} -> {:?}", listen_addr, upstreams);
200-
201-
let state = Arc::new(UdpProxyState::new(upstreams));
202-
let mut buf = [0u8; 65536];
203-
204-
loop {
205-
let (len, client_addr) = socket.recv_from(&mut buf).await?;
206-
let data = buf[..len].to_vec();
207-
let socket_clone = socket.clone();
208-
let state_clone = state.clone();
209-
210-
tokio::spawn(async move {
211-
// 获取或创建专用socket
212-
let (upstream_sock, upstream_addr) =
213-
match state_clone.get_upstream_socket(client_addr).await {
214-
Ok(v) => v,
215-
Err(e) => {
216-
eprintln!("Failed to get upstream socket: {}", e);
217-
return;
218-
}
219-
};
220-
221-
// 发送到上游
222-
if let Err(e) = upstream_sock.send(&data).await {
223-
eprintln!("Send to {} failed: {}", upstream_addr, e);
224-
return;
225-
}
226-
227-
// 接收响应(带超时和多包支持)
228-
let mut total_responses = 0;
229-
let start_time = Instant::now();
230-
while start_time.elapsed() < Duration::from_secs(5) {
231-
let mut resp_buf = [0u8; 65536];
232-
let timeout = tokio::time::sleep(Duration::from_millis(500));
233-
tokio::pin!(timeout);
234-
235-
tokio::select! {
236-
result = upstream_sock.recv(&mut resp_buf) => {
237-
match result {
238-
Ok(len) => {
239-
total_responses += 1;
240-
if let Err(e) = socket_clone.send_to(&resp_buf[..len], client_addr).await {
241-
eprintln!("Send to client {} failed: {}", client_addr, e);
242-
}
243-
}
244-
Err(e) => {
245-
eprintln!("Receive from {} failed: {}", upstream_addr, e);
246-
break;
247-
}
248-
}
249-
}
250-
_ = &mut timeout => {
251-
// 超时后检查是否需要继续等待
252-
if total_responses > 0 {
253-
// 至少收到一个响应,认为完成
254-
break;
255-
}
256-
}
257-
}
258-
}
259-
});
260-
}
261-
}
262-
263-
#[tokio::main]
264-
async fn main() -> Result<(), std::io::Error> {
265-
let rules = read_mapping_file().await?;
14+
let reader = BufReader::new(file);
15+
let rules = read_mapping_file(reader).await?;
26616
let mut handles = vec![];
26717
for rule in rules {
26818
match rule.protocol {
26919
Protocol::Tcp => {
27020
handles.push(tokio::spawn(async move {
271-
if let Err(e) = run_tcp_proxy(&rule.listen, rule.upstreams).await {
21+
let proxy = Arc::new(TcpProxy::new(rule.listen.clone(), rule.upstream.clone()));
22+
if let Err(e) = proxy.run().await {
27223
eprintln!("TCP proxy failed: {}", e);
27324
}
27425
}));
27526
}
276-
Protocol::Udp => {
277-
handles.push(tokio::spawn(async move {
278-
if let Err(e) = run_udp_proxy(&rule.listen, rule.upstreams).await {
279-
eprintln!("UDP proxy failed: {}", e);
280-
}
281-
}));
282-
}
27+
Protocol::Udp => {}
28328
}
28429
}
28530
for handle in handles {

0 commit comments

Comments
 (0)