Skip to content

Commit

Permalink
Add cli option to limit number of connections per ip
Browse files Browse the repository at this point in the history
  • Loading branch information
fabi321 committed May 30, 2024
1 parent d6021a6 commit 7b31bb2
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 3 deletions.
4 changes: 4 additions & 0 deletions breakwater/src/cli_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,8 @@ pub struct CliArgs {
#[cfg(feature = "vnc")]
#[clap(short, long, default_value_t = 5900)]
pub vnc_port: u16,

/// Allow only a certain number of connections per ip address
#[clap(short, long)]
pub connections_per_ip: Option<u64>,
}
3 changes: 2 additions & 1 deletion breakwater/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ async fn main() -> Result<(), Error> {
statistics_save_mode,
);

let server = Server::new(
let mut server = Server::new(
&args.listen_address,
Arc::clone(&fb),
statistics_tx.clone(),
Expand All @@ -124,6 +124,7 @@ async fn main() -> Result<(), Error> {
.context(InvalidNetworkBufferSizeSnafu {
network_buffer_size: args.network_buffer_size,
})?,
args.connections_per_ip,
)
.await
.context(StartPixelflutServerSnafu)?;
Expand Down
44 changes: 42 additions & 2 deletions breakwater/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use std::{cmp::min, net::IpAddr, sync::Arc, time::Duration};
use std::collections::hash_map::Entry;
use std::collections::HashMap;

use breakwater_core::framebuffer::FrameBuffer;
use breakwater_parser::{original::OriginalParser, Parser, ParserError};
Expand Down Expand Up @@ -42,6 +44,8 @@ pub struct Server {
fb: Arc<FrameBuffer>,
statistics_tx: mpsc::Sender<StatisticsEvent>,
network_buffer_size: usize,
connections_per_ip: HashMap<IpAddr, u64>,
max_connections_per_ip: Option<u64>,
}

impl Server {
Expand All @@ -50,6 +54,7 @@ impl Server {
fb: Arc<FrameBuffer>,
statistics_tx: mpsc::Sender<StatisticsEvent>,
network_buffer_size: usize,
max_connections_per_ip: Option<u64>,
) -> Result<Self, Error> {
let listener = TcpListener::bind(listen_address)
.await
Expand All @@ -61,30 +66,59 @@ impl Server {
fb,
statistics_tx,
network_buffer_size,
connections_per_ip: HashMap::new(),
max_connections_per_ip,
})
}

pub async fn start(&self) -> Result<(), Error> {
pub async fn start(&mut self) -> Result<(), Error> {
let (connection_dropped_tx, mut connection_dropped_rx) = mpsc::unbounded_channel::<IpAddr>();
let connection_dropped_tx = self.max_connections_per_ip.map(|_| connection_dropped_tx);
loop {
let (socket, socket_addr) = self
let (mut socket, socket_addr) = self
.listener
.accept()
.await
.context(AcceptNewClientConnectionSnafu)?;

// If connections are unlimited, will execute one try_recv per new connection
while let Ok(ip) = connection_dropped_rx.try_recv() {
if let Entry::Occupied(mut o) = self.connections_per_ip.entry(ip) {
let connections = o.get_mut();
*connections -= 1;
if *connections == 0 {
o.remove_entry();
}
}
}

// If you connect via IPv4 you often show up as embedded inside an IPv6 address
// Extracting the embedded information here, so we get the real (TM) address
let ip = socket_addr.ip().to_canonical();

if let Some(limit) = self.max_connections_per_ip {
let current_connections = self.connections_per_ip.entry(ip).or_default();
if *current_connections < limit {
*current_connections += 1;
} else {
// Errors if session is dropped prematurely
let _ = socket.shutdown();
continue
}
};

let fb_for_thread = Arc::clone(&self.fb);
let statistics_tx_for_thread = self.statistics_tx.clone();
let network_buffer_size = self.network_buffer_size;
let connection_dropped_tx_clone = connection_dropped_tx.clone();
tokio::spawn(async move {
handle_connection(
socket,
ip,
fb_for_thread,
statistics_tx_for_thread,
network_buffer_size,
connection_dropped_tx_clone,
)
.await
});
Expand All @@ -98,6 +132,7 @@ pub async fn handle_connection(
fb: Arc<FrameBuffer>,
statistics_tx: mpsc::Sender<StatisticsEvent>,
network_buffer_size: usize,
connection_dropped_tx: Option<mpsc::UnboundedSender<IpAddr>>,
) -> Result<(), Error> {
debug!("Handling connection from {ip}");

Expand Down Expand Up @@ -195,5 +230,10 @@ pub async fn handle_connection(
.await
.context(WriteToStatisticsChannelSnafu)?;

if let Some(tx) = connection_dropped_tx {
// Will fail if the server thread ends before the client thread
let _ = tx.send(ip);
}

Ok(())
}

0 comments on commit 7b31bb2

Please sign in to comment.