diff --git a/backend/src/main.rs b/backend/src/main.rs index fcfa862..bcfd1ba 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -83,6 +83,11 @@ async fn main() -> io::Result<()> { let session_registry = Arc::new(SessionRegistry::new()); let address_book = Arc::new(WsAddressBook::new()); + // Initialize Auth Services for Realtime + let jwt_config = crate::auth::jwt_service::JwtConfig::default(); + let jwt_service = Arc::new(crate::auth::jwt_service::JwtService::new(jwt_config, redis_conn.clone())); + let auth_guard = Arc::new(crate::realtime::auth::RealtimeAuth::new(db_pool.clone())); + // Start Redis Pub/Sub subscriber (broadcasts to local WebSocket actors) let broadcaster = WsBroadcaster::new( config.redis.url.clone(), @@ -103,6 +108,8 @@ async fn main() -> io::Result<()> { .app_data(web::Data::new(event_bus.clone())) .app_data(web::Data::new(session_registry.clone())) .app_data(web::Data::new(address_book.clone())) + .app_data(web::Data::new(jwt_service.clone())) + .app_data(web::Data::new(auth_guard.clone())) .app_data(web::Data::new(matchmaker_service.clone())) .app_data(web::Data::new(elo_engine.clone())) .wrap(cors_middleware()) diff --git a/backend/src/realtime/auth.rs b/backend/src/realtime/auth.rs new file mode 100644 index 0000000..a54f0bc --- /dev/null +++ b/backend/src/realtime/auth.rs @@ -0,0 +1,120 @@ +use crate::auth::jwt_service::Claims; +use crate::db::DbPool; +use uuid::Uuid; +use thiserror::Error; +use tracing::{warn, debug}; + +#[derive(Debug, Error)] +pub enum AuthError { + #[error("Unauthorized: {0}")] + Unauthorized(String), + #[error("Invalid channel format: {0}")] + InvalidChannel(String), + #[error("Database error: {0}")] + Database(String), +} + +/// Centralized authorization guard for real-time channels. +pub struct RealtimeAuth { + db_pool: DbPool, +} + +impl RealtimeAuth { + pub fn new(db_pool: DbPool) -> Self { + Self { db_pool } + } + + /// Authorize a subscription to a channel. + pub async fn authorize_subscription( + &self, + claims: &Claims, + channel: &str, + ) -> Result<(), AuthError> { + let user_id = Uuid::parse_str(&claims.sub) + .map_err(|_| AuthError::Unauthorized("Invalid user ID in claims".to_string()))?; + + if channel.starts_with("user:") { + self.authorize_user_channel(user_id, channel) + } else if channel.starts_with("match:") { + self.authorize_match_channel(user_id, channel).await + } else { + Err(AuthError::InvalidChannel(format!("Unknown channel prefix: {}", channel))) + } + } + + fn authorize_user_channel(&self, user_id: Uuid, channel: &str) -> Result<(), AuthError> { + let target_id_str = channel.strip_prefix("user:").unwrap(); + let target_id = Uuid::parse_str(target_id_str) + .map_err(|_| AuthError::InvalidChannel("Invalid user ID in channel name".to_string()))?; + + if user_id == target_id { + Ok(()) + } else { + warn!(user_id = %user_id, target_id = %target_id, "Unauthorized attempt to subscribe to foreign user channel"); + Err(AuthError::Unauthorized("Cannot subscribe to another user's private channel".to_string())) + } + } + + async fn authorize_match_channel(&self, user_id: Uuid, channel: &str) -> Result<(), AuthError> { + let match_id_str = channel.strip_prefix("match:").unwrap(); + let match_id = Uuid::parse_str(match_id_str) + .map_err(|_| AuthError::InvalidChannel("Invalid match ID in channel name".to_string()))?; + + // Check if user is a participant in the match + let is_participant = sqlx::query!( + r#" + SELECT EXISTS ( + SELECT 1 FROM matches + WHERE id = $1 AND (player1_id = $2 OR player2_id = $2) + ) as "exists!" + "#, + match_id, + user_id + ) + .fetch_one(&self.db_pool) + .await + .map_err(|e| AuthError::Database(e.to_string()))? + .exists; + + if is_participant { + debug!(user_id = %user_id, match_id = %match_id, "Authorized subscription to match channel"); + Ok(()) + } else { + // Also check match_authority table just in case it's a newer match type + let is_participant_auth = sqlx::query!( + r#" + SELECT EXISTS ( + SELECT 1 FROM match_authority + WHERE id = $1 AND (player_a = $2 OR player_b = $2) + ) as "exists!" + "#, + match_id, + user_id.to_string() + ) + .fetch_one(&self.db_pool) + .await + .map_err(|e| AuthError::Database(e.to_string()))? + .exists; + + if is_participant_auth { + debug!(user_id = %user_id, match_id = %match_id, "Authorized subscription to match channel via match_authority"); + Ok(()) + } else { + warn!(user_id = %user_id, match_id = %match_id, "Unauthorized attempt to subscribe to match channel"); + Err(AuthError::Unauthorized("Not a participant in this match".to_string())) + } + } + } + + /// Authorize publishing to a channel (if clients are allowed to publish). + pub async fn authorize_publish( + &self, + claims: &Claims, + channel: &str, + ) -> Result<(), AuthError> { + // Currently, we don't allow clients to publish to any channel. + // If we did, we'd add similar logic here. + warn!(user_id = %claims.sub, channel = %channel, "Rejecting client-side publish attempt (not allowed)"); + Err(AuthError::Unauthorized("Publishing to channels is restricted to internal services".to_string())) + } +} diff --git a/backend/src/realtime/events.rs b/backend/src/realtime/events.rs index 025f3ce..05843e4 100644 --- a/backend/src/realtime/events.rs +++ b/backend/src/realtime/events.rs @@ -58,6 +58,16 @@ pub struct WsEnvelope { pub enum ClientMessage { Ping, Pong, + Subscribe { + channel: String, + }, + Unsubscribe { + channel: String, + }, + Publish { + channel: String, + event: RealtimeEvent, + }, } /// Actix message for delivering a realtime event to an actor. diff --git a/backend/src/realtime/mod.rs b/backend/src/realtime/mod.rs index 2762ed0..90747be 100644 --- a/backend/src/realtime/mod.rs +++ b/backend/src/realtime/mod.rs @@ -3,6 +3,7 @@ pub mod event_bus; pub mod ws_broadcaster; pub mod user_ws; pub mod session_registry; +pub mod auth; pub use events::*; pub use event_bus::EventBus; diff --git a/backend/src/realtime/session_registry.rs b/backend/src/realtime/session_registry.rs index f2ebc9f..d674364 100644 --- a/backend/src/realtime/session_registry.rs +++ b/backend/src/realtime/session_registry.rs @@ -2,52 +2,107 @@ use std::collections::{HashMap, HashSet}; use std::sync::RwLock; use uuid::Uuid; -/// Thread-safe registry mapping user IDs to their active WebSocket session IDs. +/// Thread-safe registry mapping user IDs to their active WebSocket session IDs, +/// and tracking channel subscriptions for event routing. pub struct SessionRegistry { - inner: RwLock>>, + user_to_sessions: RwLock>>, + channel_to_sessions: RwLock>>, + session_to_channels: RwLock>>, } impl SessionRegistry { pub fn new() -> Self { Self { - inner: RwLock::new(HashMap::new()), + user_to_sessions: RwLock::new(HashMap::new()), + channel_to_sessions: RwLock::new(HashMap::new()), + session_to_channels: RwLock::new(HashMap::new()), } } - /// Register a new session for a user. Returns true if newly added. + /// Register a new session for a user. pub fn register(&self, user_id: Uuid, session_id: Uuid) -> bool { - let mut map = self.inner.write().unwrap(); + let mut map = self.user_to_sessions.write().unwrap(); map.entry(user_id).or_default().insert(session_id) } - /// Remove a session for a user. Cleans up user entry if no sessions remain. + /// Remove a session for a user and clean up all its subscriptions. pub fn unregister(&self, user_id: Uuid, session_id: Uuid) { - let mut map = self.inner.write().unwrap(); - if let Some(sessions) = map.get_mut(&user_id) { + // Remove from user_to_sessions + { + let mut map = self.user_to_sessions.write().unwrap(); + if let Some(sessions) = map.get_mut(&user_id) { + sessions.remove(&session_id); + if sessions.is_empty() { + map.remove(&user_id); + } + } + } + + // Clean up subscriptions + let mut session_map = self.session_to_channels.write().unwrap(); + if let Some(channels) = session_map.remove(&session_id) { + let mut channel_map = self.channel_to_sessions.write().unwrap(); + for channel in channels { + if let Some(sessions) = channel_map.get_mut(&channel) { + sessions.remove(&session_id); + if sessions.is_empty() { + channel_map.remove(&channel); + } + } + } + } + } + + /// Subscribe a session to a channel. + pub fn subscribe(&self, session_id: Uuid, channel: String) { + let mut channel_map = self.channel_to_sessions.write().unwrap(); + channel_map.entry(channel.clone()).or_default().insert(session_id); + + let mut session_map = self.session_to_channels.write().unwrap(); + session_map.entry(session_id).or_default().insert(channel); + } + + /// Unsubscribe a session from a channel. + pub fn unsubscribe(&self, session_id: Uuid, channel: &str) { + let mut channel_map = self.channel_to_sessions.write().unwrap(); + if let Some(sessions) = channel_map.get_mut(channel) { sessions.remove(&session_id); if sessions.is_empty() { - map.remove(&user_id); + channel_map.remove(channel); } } + + let mut session_map = self.session_to_channels.write().unwrap(); + if let Some(channels) = session_map.get_mut(&session_id) { + channels.remove(channel); + } } /// Get all session IDs for a user. pub fn get_sessions(&self, user_id: &Uuid) -> Vec { - let map = self.inner.read().unwrap(); + let map = self.user_to_sessions.read().unwrap(); map.get(user_id) .map(|s| s.iter().copied().collect()) .unwrap_or_default() } + /// Get all session IDs subscribed to a channel. + pub fn get_subscribers(&self, channel: &str) -> Vec { + let map = self.channel_to_sessions.read().unwrap(); + map.get(channel) + .map(|s| s.iter().copied().collect()) + .unwrap_or_default() + } + /// Check if a user has any active sessions. pub fn has_user(&self, user_id: &Uuid) -> bool { - let map = self.inner.read().unwrap(); + let map = self.user_to_sessions.read().unwrap(); map.contains_key(user_id) } /// Number of distinct connected users. pub fn connected_user_count(&self) -> usize { - let map = self.inner.read().unwrap(); + let map = self.user_to_sessions.read().unwrap(); map.len() } } diff --git a/backend/src/realtime/user_ws.rs b/backend/src/realtime/user_ws.rs index 7eec98f..915e269 100644 --- a/backend/src/realtime/user_ws.rs +++ b/backend/src/realtime/user_ws.rs @@ -1,6 +1,8 @@ -use crate::realtime::events::{ClientMessage, DeliverEvent, WsEnvelope}; +use crate::auth::jwt_service::{Claims, JwtService}; +use crate::realtime::auth::RealtimeAuth; +use crate::realtime::events::{channels, ClientMessage, DeliverEvent, WsEnvelope}; use crate::realtime::session_registry::SessionRegistry; -use actix::{Actor, ActorContext, AsyncContext, Handler, StreamHandler}; +use actix::{Actor, ActorContext, AsyncContext, Handler, StreamHandler, ActorFutureExt}; use actix_web::{web, Error, HttpRequest, HttpResponse}; use actix_web_actors::ws; use std::sync::Arc; @@ -15,17 +17,26 @@ const CLIENT_TIMEOUT: Duration = Duration::from_secs(10); pub struct UserWebSocket { session_id: Uuid, user_id: Uuid, + claims: Claims, hb: Instant, registry: Arc, + auth: Arc, } impl UserWebSocket { - pub fn new(user_id: Uuid, registry: Arc) -> Self { + pub fn new( + user_id: Uuid, + claims: Claims, + registry: Arc, + auth: Arc, + ) -> Self { Self { session_id: Uuid::new_v4(), user_id, + claims, hb: Instant::now(), registry, + auth, } } @@ -45,6 +56,14 @@ impl UserWebSocket { ctx.ping(b""); }); } + + fn send_error(&self, ctx: &mut ::Context, message: &str) { + let error_msg = serde_json::json!({ + "type": "error", + "message": message + }); + ctx.text(error_msg.to_string()); + } } impl Actor for UserWebSocket { @@ -57,6 +76,11 @@ impl Actor for UserWebSocket { "WebSocket session started" ); self.registry.register(self.user_id, self.session_id); + + // Automatically subscribe to own user channel + let user_channel = channels::user_channel(self.user_id); + self.registry.subscribe(self.session_id, user_channel); + self.start_heartbeat(ctx); } @@ -109,6 +133,60 @@ impl StreamHandler> for UserWebSocket { Ok(ClientMessage::Pong) => { self.hb = Instant::now(); } + Ok(ClientMessage::Subscribe { channel }) => { + let auth = self.auth.clone(); + let claims = self.claims.clone(); + let session_id = self.session_id; + let registry = self.registry.clone(); + + let fut = async move { + auth.authorize_subscription(&claims, &channel).await + }; + + ctx.wait(actix::fut::wrap_future(fut).then( + move |res, _act, ctx| { + match res { + Ok(_) => { + registry.subscribe(session_id, channel.clone()); + info!(session_id = %session_id, channel = %channel, "Subscribed to channel"); + let success = serde_json::json!({ + "type": "subscribed", + "channel": channel + }); + ctx.text(success.to_string()); + } + Err(e) => { + warn!(session_id = %session_id, channel = %channel, error = %e, "Subscription denied"); + let error_msg = serde_json::json!({ + "type": "subscription_error", + "channel": channel, + "reason": e.to_string() + }); + ctx.text(error_msg.to_string()); + } + } + actix::fut::ready(()) + }, + )); + } + Ok(ClientMessage::Unsubscribe { channel }) => { + self.registry.unsubscribe(self.session_id, &channel); + info!(session_id = %self.session_id, channel = %channel, "Unsubscribed from channel"); + } + Ok(ClientMessage::Publish { channel, .. }) => { + // All publish attempts are currently rejected in our guard + let auth = self.auth.clone(); + let claims = self.claims.clone(); + let fut = async move { + auth.authorize_publish(&claims, &channel).await + }; + ctx.wait(actix::fut::wrap_future(fut).then(|res, act: &mut Self, ctx| { + if let Err(e) = res { + act.send_error(ctx, &e.to_string()); + } + actix::fut::ready(()) + })); + } Err(_) => { debug!( user_id = %self.user_id, @@ -166,43 +244,43 @@ impl Handler for UserWebSocket { } /// HTTP upgrade endpoint for WebSocket connections. -/// -/// Extracts user identity from query parameters. In a future task, this will -/// be replaced with proper JWT validation via JwtService. pub async fn ws_handler( req: HttpRequest, stream: web::Payload, registry: web::Data>, + jwt_service: web::Data>, + auth_guard: web::Data>, ) -> Result { let query_string = req.query_string(); - // Extract token from query string (reserved for future JWT validation) - let _token = query_string.split('&').find_map(|pair| { + // Extract token from query string + let token = query_string.split('&').find_map(|pair| { let mut parts = pair.splitn(2, '='); match (parts.next(), parts.next()) { (Some("token"), Some(value)) => Some(value.to_string()), _ => None, } - }); - - // TODO: In Task 6, we'll add JwtService validation here. - // For now, extract user_id from a simple "user_id" query param for testing. - let user_id_str = query_string - .split('&') - .find_map(|pair| { - let mut parts = pair.splitn(2, '='); - match (parts.next(), parts.next()) { - (Some("user_id"), Some(value)) => Some(value.to_string()), - _ => None, - } - }) - .ok_or_else(|| actix_web::error::ErrorUnauthorized("Missing user_id parameter"))?; + }).ok_or_else(|| actix_web::error::ErrorUnauthorized("Missing token parameter"))?; + + // Validate token + let claims = jwt_service.validate_token(&token).await.map_err(|e| { + warn!(error = %e, "WebSocket connection rejected: invalid token"); + actix_web::error::ErrorUnauthorized(format!("Invalid token: {}", e)) + })?; + + let user_id = Uuid::parse_str(&claims.sub).map_err(|_| { + actix_web::error::ErrorUnauthorized("Invalid user ID in token") + })?; - let user_id = Uuid::parse_str(&user_id_str) - .map_err(|_| actix_web::error::ErrorUnauthorized("Invalid user_id"))?; + info!(user_id = %user_id, "WebSocket upgrade request approved via JWT"); + + let ws_actor = UserWebSocket::new( + user_id, + claims, + registry.get_ref().clone(), + auth_guard.get_ref().clone() + ); - info!(user_id = %user_id, "WebSocket upgrade request"); - let ws_actor = UserWebSocket::new(user_id, registry.get_ref().clone()); ws::start(ws_actor, &req, stream) } diff --git a/backend/src/realtime/ws_broadcaster.rs b/backend/src/realtime/ws_broadcaster.rs index e76e380..1fda483 100644 --- a/backend/src/realtime/ws_broadcaster.rs +++ b/backend/src/realtime/ws_broadcaster.rs @@ -174,13 +174,20 @@ impl WsBroadcaster { } fn route_match_event( - _channel: &str, - _event: &RealtimeEvent, - _registry: &Arc, - _address_book: &Arc, + channel: &str, + event: &RealtimeEvent, + registry: &Arc, + address_book: &Arc, ) { - // Match events are primarily routed via user channels. - // This handler is for future spectator/admin features. - debug!("Match channel event received (spectator routing not yet implemented)"); + // Channel format: "match:" + let subscribers = registry.get_subscribers(channel); + + for session_id in subscribers { + if let Some(addr) = address_book.get(&session_id) { + addr.do_send(DeliverEvent(event.clone())); + } + } + + debug!(channel = %channel, subscriber_count = %registry.get_subscribers(channel).len(), "Routed event to match subscribers"); } } diff --git a/backend/tests/realtime_auth_integration_test.rs b/backend/tests/realtime_auth_integration_test.rs new file mode 100644 index 0000000..157a01d --- /dev/null +++ b/backend/tests/realtime_auth_integration_test.rs @@ -0,0 +1,65 @@ +use actix_web::{test, App, web}; +use actix_web_actors::ws; +use arenax_backend::realtime::user_ws::{ws_handler, UserWebSocket}; +use arenax_backend::realtime::session_registry::SessionRegistry; +use arenax_backend::realtime::auth::RealtimeAuth; +use arenax_backend::auth::jwt_service::{JwtService, JwtConfig}; +use arenax_backend::db::DbPool; +use std::sync::Arc; +use uuid::Uuid; + +#[tokio::test] +async fn test_ws_connection_requires_token() { + let registry = Arc::new(SessionRegistry::new()); + let jwt_config = JwtConfig::default(); + let redis_client = redis::Client::open("redis://127.0.0.1/").unwrap(); + let redis_conn = redis::aio::ConnectionManager::new(redis_client).await.unwrap(); + let jwt_service = Arc::new(JwtService::new(jwt_config, redis_conn)); + let db_pool = DbPool::default(); // Mock for test + let auth_guard = Arc::new(RealtimeAuth::new(db_pool)); + + let app = test::init_service( + App::new() + .app_data(web::Data::new(registry.clone())) + .app_data(web::Data::new(jwt_service.clone())) + .app_data(web::Data::new(auth_guard.clone())) + .route("/ws", web::get().to(ws_handler)) + ).await; + + // Test without token + let req = test::TestRequest::with_uri("/ws").to_request(); + let resp = test::call_service(&app, req).await; + assert_eq!(resp.status(), actix_web::http::StatusCode::UNAUTHORIZED); +} + +#[tokio::test] +async fn test_ws_connection_with_valid_token() { + let registry = Arc::new(SessionRegistry::new()); + let jwt_config = JwtConfig::default(); + let redis_client = redis::Client::open("redis://127.0.0.1/").unwrap(); + let redis_conn = redis::aio::ConnectionManager::new(redis_client).await.unwrap(); + let jwt_service = Arc::new(JwtService::new(jwt_config, redis_conn)); + let db_pool = DbPool::default(); + let auth_guard = Arc::new(RealtimeAuth::new(db_pool)); + + let user_id = Uuid::new_v4(); + let token = jwt_service.generate_access_token(user_id, vec!["user".to_string()], None).await.unwrap(); + + let app = test::init_service( + App::new() + .app_data(web::Data::new(registry.clone())) + .app_data(web::Data::new(jwt_service.clone())) + .app_data(web::Data::new(auth_guard.clone())) + .route("/ws", web::get().to(ws_handler)) + ).await; + + // Test with valid token + let uri = format="/ws?token={}", token; + let req = test::TestRequest::with_uri(&uri).to_request(); + // ws::start would be called here, but in a test environment we'd need more setup for actual WS + // For now, let's just assert it passes the upgrade check (which returns 101 Switching Protocols) + let resp = test::call_service(&app, req).await; + // Note: actix-web test::call_service for WS might return 101 or 400 depending on headers + // But it definitely shouldn't be 401 Unauthorized + assert_ne!(resp.status(), actix_web::http::StatusCode::UNAUTHORIZED); +} diff --git a/backend/tests/realtime_auth_test.rs b/backend/tests/realtime_auth_test.rs new file mode 100644 index 0000000..20e82c8 --- /dev/null +++ b/backend/tests/realtime_auth_test.rs @@ -0,0 +1,49 @@ +use crate::realtime::auth::RealtimeAuth; +use crate::auth::jwt_service::{Claims, TokenType}; +use uuid::Uuid; +use chrono::{Duration, Utc}; + +#[tokio::test] +async fn test_authorize_user_channel_success() { + let db_pool = crate::db::DbPool::new_test().await; // Assuming this exists or I'll mock it + let auth = RealtimeAuth::new(db_pool); + + let user_id = Uuid::new_v4(); + let claims = Claims { + sub: user_id.to_string(), + exp: (Utc::now() + Duration::minutes(15)).timestamp(), + iat: Utc::now().timestamp(), + jti: Uuid::new_v4().to_string(), + token_type: TokenType::Access, + device_id: None, + session_id: Uuid::new_v4().to_string(), + roles: vec!["user".to_string()], + }; + + let channel = format!("user:{}", user_id); + let result = auth.authorize_subscription(&claims, &channel).await; + assert!(result.is_ok()); +} + +#[tokio::test] +async fn test_authorize_user_channel_denied() { + let db_pool = crate::db::DbPool::new_test().await; + let auth = RealtimeAuth::new(db_pool); + + let user_id = Uuid::new_v4(); + let other_user_id = Uuid::new_v4(); + let claims = Claims { + sub: user_id.to_string(), + exp: (Utc::now() + Duration::minutes(15)).timestamp(), + iat: Utc::now().timestamp(), + jti: Uuid::new_v4().to_string(), + token_type: TokenType::Access, + device_id: None, + session_id: Uuid::new_v4().to_string(), + roles: vec!["user".to_string()], + }; + + let channel = format!("user:{}", other_user_id); + let result = auth.authorize_subscription(&claims, &channel).await; + assert!(result.is_err()); +}