Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add cli option to limit number of connections per ip #22

Merged
merged 10 commits into from
May 30, 2024
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,19 @@ All notable changes to this project will be documented in this file.

## [Unreleased]

### Added

- Command line option `--connections-per-ip` that allows limiting the number of connections per ip address. Default is unlimited ([#22])

### Fixed

- Raise `ffmpeg` errors as early as possible, e.g. when the `ffmpeg` command is not found

[#22]: https://github.com/sbernauer/breakwater/pull/22

## [0.13.0] - 2024-05-15

## Added
### Added

- Also release binary for `aarch64-apple-darwin`

Expand Down
4 changes: 4 additions & 0 deletions breakwater/src/cli_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,8 @@ pub struct CliArgs {
#[cfg(feature = "vnc")]
#[clap(short, long, default_value_t = 5900)]
pub vnc_port: u16,

/// Allow only a certain number of connections per ip address
#[clap(short, long)]
pub connections_per_ip: Option<u64>,
}
3 changes: 2 additions & 1 deletion breakwater/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ async fn main() -> Result<(), Error> {
statistics_save_mode,
);

let server = Server::new(
let mut server = Server::new(
&args.listen_address,
Arc::clone(&fb),
statistics_tx.clone(),
Expand All @@ -124,6 +124,7 @@ async fn main() -> Result<(), Error> {
.context(InvalidNetworkBufferSizeSnafu {
network_buffer_size: args.network_buffer_size,
})?,
args.connections_per_ip,
)
.await
.context(StartPixelflutServerSnafu)?;
Expand Down
45 changes: 43 additions & 2 deletions breakwater/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::{cmp::min, net::IpAddr, sync::Arc, time::Duration};

use breakwater_core::framebuffer::FrameBuffer;
Expand Down Expand Up @@ -42,6 +44,8 @@ pub struct Server {
fb: Arc<FrameBuffer>,
statistics_tx: mpsc::Sender<StatisticsEvent>,
network_buffer_size: usize,
connections_per_ip: HashMap<IpAddr, u64>,
max_connections_per_ip: Option<u64>,
}

impl Server {
Expand All @@ -50,6 +54,7 @@ impl Server {
fb: Arc<FrameBuffer>,
statistics_tx: mpsc::Sender<StatisticsEvent>,
network_buffer_size: usize,
max_connections_per_ip: Option<u64>,
) -> Result<Self, Error> {
let listener = TcpListener::bind(listen_address)
.await
Expand All @@ -61,30 +66,60 @@ impl Server {
fb,
statistics_tx,
network_buffer_size,
connections_per_ip: HashMap::new(),
max_connections_per_ip,
})
}

pub async fn start(&self) -> Result<(), Error> {
pub async fn start(&mut self) -> Result<(), Error> {
let (connection_dropped_tx, mut connection_dropped_rx) =
mpsc::unbounded_channel::<IpAddr>();
let connection_dropped_tx = self.max_connections_per_ip.map(|_| connection_dropped_tx);
loop {
let (socket, socket_addr) = self
let (mut socket, socket_addr) = self
.listener
.accept()
.await
.context(AcceptNewClientConnectionSnafu)?;

// If connections are unlimited, will execute one try_recv per new connection
while let Ok(ip) = connection_dropped_rx.try_recv() {
if let Entry::Occupied(mut o) = self.connections_per_ip.entry(ip) {
let connections = o.get_mut();
*connections -= 1;
if *connections == 0 {
o.remove_entry();
}
}
}

// If you connect via IPv4 you often show up as embedded inside an IPv6 address
// Extracting the embedded information here, so we get the real (TM) address
let ip = socket_addr.ip().to_canonical();

if let Some(limit) = self.max_connections_per_ip {
let current_connections = self.connections_per_ip.entry(ip).or_default();
if *current_connections < limit {
*current_connections += 1;
} else {
// Errors if session is dropped prematurely
let _ = socket.shutdown();
continue;
}
};

let fb_for_thread = Arc::clone(&self.fb);
let statistics_tx_for_thread = self.statistics_tx.clone();
let network_buffer_size = self.network_buffer_size;
let connection_dropped_tx_clone = connection_dropped_tx.clone();
tokio::spawn(async move {
handle_connection(
socket,
ip,
fb_for_thread,
statistics_tx_for_thread,
network_buffer_size,
connection_dropped_tx_clone,
)
.await
});
Expand All @@ -98,6 +133,7 @@ pub async fn handle_connection(
fb: Arc<FrameBuffer>,
statistics_tx: mpsc::Sender<StatisticsEvent>,
network_buffer_size: usize,
connection_dropped_tx: Option<mpsc::UnboundedSender<IpAddr>>,
) -> Result<(), Error> {
debug!("Handling connection from {ip}");

Expand Down Expand Up @@ -195,5 +231,10 @@ pub async fn handle_connection(
.await
.context(WriteToStatisticsChannelSnafu)?;

if let Some(tx) = connection_dropped_tx {
// Will fail if the server thread ends before the client thread
let _ = tx.send(ip);
}

Ok(())
}
Loading