|
1 | 1 | import asyncio
|
2 | 2 | import json
|
3 | 3 | import uuid
|
| 4 | +from jose import jwt, JWTError |
4 | 5 | from typing import List
|
5 | 6 | from redis.asyncio import from_url as create_redis
|
6 |
| -from fastapi import WebSocket, WebSocketDisconnect |
| 7 | +from fastapi import WebSocket, WebSocketDisconnect, Depends |
| 8 | +from sqlalchemy.orm import Session |
| 9 | +from app.database import get_db |
7 | 10 |
|
8 | 11 | from app.config import get_config
|
| 12 | +from app.models.user import User |
9 | 13 |
|
10 | 14 | redis_pub = None
|
11 | 15 | redis_sub = None
|
@@ -82,17 +86,35 @@ async def publish_message(event: str, data: dict):
|
82 | 86 |
|
83 | 87 | def register_ws_routes(app):
|
84 | 88 | @app.websocket("/ws")
|
85 |
| - async def websocket_endpoint(websocket: WebSocket): |
| 89 | + async def websocket_endpoint(websocket: WebSocket, db: Session = Depends(get_db)): |
| 90 | + token = websocket.query_params.get('token') |
| 91 | + if not token: |
| 92 | + await websocket.close(code=1008, reason="Missing token") |
| 93 | + return |
| 94 | + |
| 95 | + try: |
| 96 | + from app.api.auth import ALGORITHM, SECRET_KEY |
| 97 | + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) |
| 98 | + user_email = payload.get("sub") |
| 99 | + if not user_email: |
| 100 | + raise ValueError("Invalid token") |
| 101 | + except JWTError: |
| 102 | + await websocket.close(code=1008, reason="Invalid token") |
| 103 | + return |
| 104 | + |
| 105 | + user = db.query(User).filter(User.email == user_email).first() |
| 106 | + if user is None: |
| 107 | + await websocket.close(code=1008, reason="User not found") |
| 108 | + return |
| 109 | + |
86 | 110 | await manager.connect(websocket)
|
87 | 111 | try:
|
88 | 112 | while True:
|
89 | 113 | data = await websocket.receive_text()
|
90 |
| - # Handle incoming messages |
91 |
| - await manager.send_personal_message(f"You said: {data}", websocket) |
| 114 | + # Optionally handle incoming messages |
| 115 | + await manager.send_personal_message("You said: " + data, websocket) |
92 | 116 | except WebSocketDisconnect:
|
93 | 117 | await manager.disconnect(websocket)
|
94 |
| - except Exception as e: |
95 |
| - print(f"Error: {e}") |
96 |
| - await manager.disconnect(websocket) |
| 118 | + |
97 | 119 |
|
98 | 120 | return manager
|
0 commit comments