diff --git a/api.py b/api.py index 692f5a1..16473a2 100644 --- a/api.py +++ b/api.py @@ -803,7 +803,19 @@ async def process_video(request: Request, video_url: str = Form(...)): @limiter.limit("2/30minute") async def run_risk_investment_agent(request:Request,stock_selection: str = Form("AAPL"), risk_tolerance : str = Form("Medium"), - trading_strategy_preference: str = Form("Day Trading")): + trading_strategy_preference: str = Form("Day Trading"), + token: str = Depends(oauth2_scheme)): + try: + payload = jwt.decode(token, os.getenv("TOKEN_SECRET_KEY"), algorithms=[settings.ALGORITHM]) + email = payload.get("sub") + if email is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") + user = users_collection.find_one({"email": email}) + if user is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found") + except JWTError: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") + try: input_data = {"stock_selection": stock_selection, "risk_tolerance": risk_tolerance,