Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions backend/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ async fn main() -> io::Result<()> {
let session_registry = Arc::new(SessionRegistry::new());
let address_book = Arc::new(WsAddressBook::new());

// Initialize Auth Services for Realtime
let jwt_config = crate::auth::jwt_service::JwtConfig::default();
let jwt_service = Arc::new(crate::auth::jwt_service::JwtService::new(jwt_config, redis_conn.clone()));
let auth_guard = Arc::new(crate::realtime::auth::RealtimeAuth::new(db_pool.clone()));

// Start Redis Pub/Sub subscriber (broadcasts to local WebSocket actors)
let broadcaster = WsBroadcaster::new(
config.redis.url.clone(),
Expand All @@ -103,6 +108,8 @@ async fn main() -> io::Result<()> {
.app_data(web::Data::new(event_bus.clone()))
.app_data(web::Data::new(session_registry.clone()))
.app_data(web::Data::new(address_book.clone()))
.app_data(web::Data::new(jwt_service.clone()))
.app_data(web::Data::new(auth_guard.clone()))
.app_data(web::Data::new(matchmaker_service.clone()))
.app_data(web::Data::new(elo_engine.clone()))
.wrap(cors_middleware())
Expand Down
120 changes: 120 additions & 0 deletions backend/src/realtime/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
use crate::auth::jwt_service::Claims;
use crate::db::DbPool;
use uuid::Uuid;
use thiserror::Error;
use tracing::{warn, debug};

#[derive(Debug, Error)]
pub enum AuthError {
#[error("Unauthorized: {0}")]
Unauthorized(String),
#[error("Invalid channel format: {0}")]
InvalidChannel(String),
#[error("Database error: {0}")]
Database(String),
}

/// Centralized authorization guard for real-time channels.
pub struct RealtimeAuth {
db_pool: DbPool,
}

impl RealtimeAuth {
pub fn new(db_pool: DbPool) -> Self {
Self { db_pool }
}

/// Authorize a subscription to a channel.
pub async fn authorize_subscription(
&self,
claims: &Claims,
channel: &str,
) -> Result<(), AuthError> {
let user_id = Uuid::parse_str(&claims.sub)
.map_err(|_| AuthError::Unauthorized("Invalid user ID in claims".to_string()))?;

if channel.starts_with("user:") {
self.authorize_user_channel(user_id, channel)
} else if channel.starts_with("match:") {
self.authorize_match_channel(user_id, channel).await
} else {
Err(AuthError::InvalidChannel(format!("Unknown channel prefix: {}", channel)))
}
}

fn authorize_user_channel(&self, user_id: Uuid, channel: &str) -> Result<(), AuthError> {
let target_id_str = channel.strip_prefix("user:").unwrap();
let target_id = Uuid::parse_str(target_id_str)
.map_err(|_| AuthError::InvalidChannel("Invalid user ID in channel name".to_string()))?;

if user_id == target_id {
Ok(())
} else {
warn!(user_id = %user_id, target_id = %target_id, "Unauthorized attempt to subscribe to foreign user channel");
Err(AuthError::Unauthorized("Cannot subscribe to another user's private channel".to_string()))
}
}

async fn authorize_match_channel(&self, user_id: Uuid, channel: &str) -> Result<(), AuthError> {
let match_id_str = channel.strip_prefix("match:").unwrap();
let match_id = Uuid::parse_str(match_id_str)
.map_err(|_| AuthError::InvalidChannel("Invalid match ID in channel name".to_string()))?;

// Check if user is a participant in the match
let is_participant = sqlx::query!(
r#"
SELECT EXISTS (
SELECT 1 FROM matches
WHERE id = $1 AND (player1_id = $2 OR player2_id = $2)
) as "exists!"
"#,
match_id,
user_id
)
.fetch_one(&self.db_pool)
.await
.map_err(|e| AuthError::Database(e.to_string()))?
.exists;

if is_participant {
debug!(user_id = %user_id, match_id = %match_id, "Authorized subscription to match channel");
Ok(())
} else {
// Also check match_authority table just in case it's a newer match type
let is_participant_auth = sqlx::query!(
r#"
SELECT EXISTS (
SELECT 1 FROM match_authority
WHERE id = $1 AND (player_a = $2 OR player_b = $2)
) as "exists!"
"#,
match_id,
user_id.to_string()
)
.fetch_one(&self.db_pool)
.await
.map_err(|e| AuthError::Database(e.to_string()))?
.exists;

if is_participant_auth {
debug!(user_id = %user_id, match_id = %match_id, "Authorized subscription to match channel via match_authority");
Ok(())
} else {
warn!(user_id = %user_id, match_id = %match_id, "Unauthorized attempt to subscribe to match channel");
Err(AuthError::Unauthorized("Not a participant in this match".to_string()))
}
}
}

/// Authorize publishing to a channel (if clients are allowed to publish).
pub async fn authorize_publish(
&self,
claims: &Claims,
channel: &str,
) -> Result<(), AuthError> {
// Currently, we don't allow clients to publish to any channel.
// If we did, we'd add similar logic here.
warn!(user_id = %claims.sub, channel = %channel, "Rejecting client-side publish attempt (not allowed)");
Err(AuthError::Unauthorized("Publishing to channels is restricted to internal services".to_string()))
}
}
10 changes: 10 additions & 0 deletions backend/src/realtime/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ pub struct WsEnvelope {
pub enum ClientMessage {
Ping,
Pong,
Subscribe {
channel: String,
},
Unsubscribe {
channel: String,
},
Publish {
channel: String,
event: RealtimeEvent,
},
}

/// Actix message for delivering a realtime event to an actor.
Expand Down
1 change: 1 addition & 0 deletions backend/src/realtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod event_bus;
pub mod ws_broadcaster;
pub mod user_ws;
pub mod session_registry;
pub mod auth;

pub use events::*;
pub use event_bus::EventBus;
Expand Down
79 changes: 67 additions & 12 deletions backend/src/realtime/session_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,107 @@ use std::collections::{HashMap, HashSet};
use std::sync::RwLock;
use uuid::Uuid;

/// Thread-safe registry mapping user IDs to their active WebSocket session IDs.
/// Thread-safe registry mapping user IDs to their active WebSocket session IDs,
/// and tracking channel subscriptions for event routing.
pub struct SessionRegistry {
inner: RwLock<HashMap<Uuid, HashSet<Uuid>>>,
user_to_sessions: RwLock<HashMap<Uuid, HashSet<Uuid>>>,
channel_to_sessions: RwLock<HashMap<String, HashSet<Uuid>>>,
session_to_channels: RwLock<HashMap<Uuid, HashSet<String>>>,
}

impl SessionRegistry {
pub fn new() -> Self {
Self {
inner: RwLock::new(HashMap::new()),
user_to_sessions: RwLock::new(HashMap::new()),
channel_to_sessions: RwLock::new(HashMap::new()),
session_to_channels: RwLock::new(HashMap::new()),
}
}

/// Register a new session for a user. Returns true if newly added.
/// Register a new session for a user.
pub fn register(&self, user_id: Uuid, session_id: Uuid) -> bool {
let mut map = self.inner.write().unwrap();
let mut map = self.user_to_sessions.write().unwrap();
map.entry(user_id).or_default().insert(session_id)
}

/// Remove a session for a user. Cleans up user entry if no sessions remain.
/// Remove a session for a user and clean up all its subscriptions.
pub fn unregister(&self, user_id: Uuid, session_id: Uuid) {
let mut map = self.inner.write().unwrap();
if let Some(sessions) = map.get_mut(&user_id) {
// Remove from user_to_sessions
{
let mut map = self.user_to_sessions.write().unwrap();
if let Some(sessions) = map.get_mut(&user_id) {
sessions.remove(&session_id);
if sessions.is_empty() {
map.remove(&user_id);
}
}
}

// Clean up subscriptions
let mut session_map = self.session_to_channels.write().unwrap();
if let Some(channels) = session_map.remove(&session_id) {
let mut channel_map = self.channel_to_sessions.write().unwrap();
for channel in channels {
if let Some(sessions) = channel_map.get_mut(&channel) {
sessions.remove(&session_id);
if sessions.is_empty() {
channel_map.remove(&channel);
}
}
}
}
}

/// Subscribe a session to a channel.
pub fn subscribe(&self, session_id: Uuid, channel: String) {
let mut channel_map = self.channel_to_sessions.write().unwrap();
channel_map.entry(channel.clone()).or_default().insert(session_id);

let mut session_map = self.session_to_channels.write().unwrap();
session_map.entry(session_id).or_default().insert(channel);
}

/// Unsubscribe a session from a channel.
pub fn unsubscribe(&self, session_id: Uuid, channel: &str) {
let mut channel_map = self.channel_to_sessions.write().unwrap();
if let Some(sessions) = channel_map.get_mut(channel) {
sessions.remove(&session_id);
if sessions.is_empty() {
map.remove(&user_id);
channel_map.remove(channel);
}
}

let mut session_map = self.session_to_channels.write().unwrap();
if let Some(channels) = session_map.get_mut(&session_id) {
channels.remove(channel);
}
}

/// Get all session IDs for a user.
pub fn get_sessions(&self, user_id: &Uuid) -> Vec<Uuid> {
let map = self.inner.read().unwrap();
let map = self.user_to_sessions.read().unwrap();
map.get(user_id)
.map(|s| s.iter().copied().collect())
.unwrap_or_default()
}

/// Get all session IDs subscribed to a channel.
pub fn get_subscribers(&self, channel: &str) -> Vec<Uuid> {
let map = self.channel_to_sessions.read().unwrap();
map.get(channel)
.map(|s| s.iter().copied().collect())
.unwrap_or_default()
}

/// Check if a user has any active sessions.
pub fn has_user(&self, user_id: &Uuid) -> bool {
let map = self.inner.read().unwrap();
let map = self.user_to_sessions.read().unwrap();
map.contains_key(user_id)
}

/// Number of distinct connected users.
pub fn connected_user_count(&self) -> usize {
let map = self.inner.read().unwrap();
let map = self.user_to_sessions.read().unwrap();
map.len()
}
}
Loading
Loading