diff --git a/backend/apps/northbound_app.py b/backend/apps/northbound_app.py index a39877ded..e6aaf4eb6 100644 --- a/backend/apps/northbound_app.py +++ b/backend/apps/northbound_app.py @@ -1,12 +1,12 @@ import logging from http import HTTPStatus -from typing import Optional, Dict +from typing import Optional, Dict, Any import uuid -from fastapi import APIRouter, Body, Header, Request, HTTPException +from fastapi import APIRouter, Body, Header, Request, HTTPException, Query from fastapi.responses import JSONResponse -from consts.exceptions import UnauthorizedError, LimitExceededError, SignatureValidationError +from consts.exceptions import LimitExceededError, UnauthorizedError from services.northbound_service import ( NorthboundContext, get_conversation_history, @@ -14,86 +14,89 @@ start_streaming_chat, stop_chat, get_agent_info_list, - update_conversation_title + update_conversation_title, ) -from utils.auth_utils import get_current_user_id, validate_aksk_authentication +from utils.auth_utils import validate_bearer_token, get_user_and_tenant_by_access_key router = APIRouter(prefix="/nb/v1", tags=["northbound"]) -def _get_header(headers: Dict[str, str], name: str) -> Optional[str]: - for k, v in headers.items(): - if k.lower() == name.lower(): - return v - return None +async def _get_northbound_context(request: Request) -> NorthboundContext: + """ + Build northbound context from request. + Authentication: Bearer Token (API Key) in Authorization header + - Authorization: Bearer -async def _parse_northbound_context(request: Request) -> NorthboundContext: - """ - Build northbound context from headers. + The user_id and tenant_id are derived from the access_key by querying + user_token_info_t and user_tenant_t tables. - - X-Access-Key: Access key for AK/SK authentication - - X-Timestamp: Timestamp for signature validation - - X-Signature: HMAC-SHA256 signature signed with secret key - - Authorization: Bearer , jwt contains sub (user_id) - - X-Request-Id: optional, generated if not provided + Optional headers: + - X-Request-Id: Request ID, generated if not provided """ - # 1. Verify AK/SK signature + # 1. Validate Bearer Token and extract access_key try: - # Get request body for signature verification - request_body = "" - if request.method in ["POST", "PUT", "PATCH"]: - try: - body_bytes = await request.body() - request_body = body_bytes.decode('utf-8') if body_bytes else "" - except Exception as e: - logging.warning( - f"Cannot read request body for signature verification: {e}") - request_body = "" - - validate_aksk_authentication(request.headers, request_body) - except (UnauthorizedError, LimitExceededError, SignatureValidationError) as e: - raise e + auth_header = request.headers.get("Authorization") + is_valid, token_info = validate_bearer_token(auth_header) + + if not is_valid or not token_info: + raise HTTPException( + status_code=HTTPStatus.UNAUTHORIZED, + detail="Invalid or missing bearer token" + ) + + # Extract access_key from the token + access_key = auth_header.replace("Bearer ", "") if auth_header.startswith("Bearer ") else auth_header + + # Get user_id and tenant_id from access_key + user_tenant_info = get_user_and_tenant_by_access_key(access_key) + resolved_user_id = user_tenant_info.get("user_id") + resolved_tenant_id = user_tenant_info.get("tenant_id") + token_id = user_tenant_info.get("token_id") + + except HTTPException: + raise + except LimitExceededError as e: + logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) + raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, + detail="Too Many Requests: rate limit exceeded") + except UnauthorizedError as e: + raise HTTPException( + status_code=HTTPStatus.UNAUTHORIZED, + detail=str(e) + ) except Exception as e: - logging.error(f"Failed to parse northbound context: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail="Internal Server Error: cannot parse northbound context") - - # 2. Parse JWT token - auth_header = _get_header(request.headers, "Authorization") - if not auth_header: - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: No authorization header found") + logging.error(f"Failed to validate bearer token: {str(e)}", exc_info=e) + raise HTTPException( + status_code=HTTPStatus.UNAUTHORIZED, + detail="Unauthorized: invalid API key" + ) - # Use auth_utils to parse JWT token - try: - user_id, tenant_id = get_current_user_id(auth_header) + if not resolved_user_id: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="Missing user information for this access key" + ) - if not user_id: - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: missing user_id in JWT token") - if not tenant_id: - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: unregistered user_id in JWT token") + if not resolved_tenant_id: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="Missing tenant information for this access key" + ) - except HTTPException as e: - # Preserve explicit HTTP errors raised during JWT parsing - raise e - except Exception as e: - logging.error(f"Failed to parse JWT token: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail="Internal Server Error: cannot parse JWT token") + request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4()) - request_id = _get_header( - request.headers, "X-Request-Id") or str(uuid.uuid4()) + # Get authorization header if present, otherwise use a placeholder + auth_header_value = request.headers.get("Authorization", "Bearer placeholder") return NorthboundContext( request_id=request_id, - tenant_id=tenant_id, - user_id=str(user_id), - authorization=auth_header, + tenant_id=resolved_tenant_id, + user_id=resolved_user_id, + authorization=auth_header_value, + token_id=token_id, ) @@ -105,34 +108,27 @@ async def health_check(): @router.post("/chat/run") async def run_chat( request: Request, - conversation_id: str = Body(..., embed=True), + conversation_id: Optional[int] = Body(None, embed=True), agent_name: str = Body(..., embed=True), query: str = Body(..., embed=True), + meta_data: Optional[Dict[str, Any]] = Body(None, embed=True), idempotency_key: Optional[str] = Header(None, alias="Idempotency-Key"), ): try: - ctx: NorthboundContext = await _parse_northbound_context(request) + ctx: NorthboundContext = await _get_northbound_context(request) return await start_streaming_chat( ctx=ctx, - external_conversation_id=conversation_id, + conversation_id=conversation_id, agent_name=agent_name, query=query, + meta_data=meta_data, idempotency_key=idempotency_key, ) - except UnauthorizedError as e: - logging.error(f"Unauthorized: AK/SK authentication failed: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: AK/SK authentication failed") except LimitExceededError as e: logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, detail="Too Many Requests: rate limit exceeded") - except SignatureValidationError as e: - logging.error(f"Unauthorized: invalid signature: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: invalid signature") except HTTPException as e: - # Propagate HTTP errors from context parsing without altering status/detail raise e except Exception as e: logging.error(f"Failed to run chat: {str(e)}", exc_info=e) @@ -141,22 +137,25 @@ async def run_chat( @router.get("/chat/stop/{conversation_id}") -async def stop_chat_stream(request: Request, conversation_id: str): +async def stop_chat_stream( + request: Request, + conversation_id: int, + meta_data: Optional[str] = Query(None, description="Optional metadata as JSON string"), +): + import json + parsed_meta_data = None + if meta_data: + try: + parsed_meta_data = json.loads(meta_data) + except json.JSONDecodeError: + pass try: - ctx: NorthboundContext = await _parse_northbound_context(request) - return await stop_chat(ctx=ctx, external_conversation_id=conversation_id) - except UnauthorizedError as e: - logging.error(f"Unauthorized: AK/SK authentication failed: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: AK/SK authentication failed") + ctx: NorthboundContext = await _get_northbound_context(request) + return await stop_chat(ctx=ctx, conversation_id=conversation_id, meta_data=parsed_meta_data) except LimitExceededError as e: logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, detail="Too Many Requests: rate limit exceeded") - except SignatureValidationError as e: - logging.error(f"Unauthorized: invalid signature: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: invalid signature") except HTTPException as e: raise e except Exception as e: @@ -166,22 +165,17 @@ async def stop_chat_stream(request: Request, conversation_id: str): @router.get("/conversations/{conversation_id}") -async def get_history(request: Request, conversation_id: str): +async def get_history( + request: Request, + conversation_id: int, +): try: - ctx: NorthboundContext = await _parse_northbound_context(request) - return await get_conversation_history(ctx=ctx, external_conversation_id=conversation_id) - except UnauthorizedError as e: - logging.error(f"Unauthorized: AK/SK authentication failed: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: AK/SK authentication failed") + ctx: NorthboundContext = await _get_northbound_context(request) + return await get_conversation_history(ctx=ctx, conversation_id=conversation_id) except LimitExceededError as e: logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, detail="Too Many Requests: rate limit exceeded") - except SignatureValidationError as e: - logging.error(f"Unauthorized: invalid signature: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: invalid signature") except HTTPException as e: raise e except Exception as e: @@ -193,20 +187,12 @@ async def get_history(request: Request, conversation_id: str): @router.get("/agents") async def list_agents(request: Request): try: - ctx: NorthboundContext = await _parse_northbound_context(request) + ctx: NorthboundContext = await _get_northbound_context(request) return await get_agent_info_list(ctx=ctx) - except UnauthorizedError as e: - logging.error(f"Unauthorized: AK/SK authentication failed: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: AK/SK authentication failed") except LimitExceededError as e: logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, detail="Too Many Requests: rate limit exceeded") - except SignatureValidationError as e: - logging.error(f"Unauthorized: invalid signature: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: invalid signature") except HTTPException as e: raise e except Exception as e: @@ -218,20 +204,12 @@ async def list_agents(request: Request): @router.get("/conversations") async def list_convs(request: Request): try: - ctx: NorthboundContext = await _parse_northbound_context(request) + ctx: NorthboundContext = await _get_northbound_context(request) return await list_conversations(ctx=ctx) - except UnauthorizedError as e: - logging.error(f"Unauthorized: AK/SK authentication failed: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: AK/SK authentication failed") except LimitExceededError as e: logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, detail="Too Many Requests: rate limit exceeded") - except SignatureValidationError as e: - logging.error(f"Unauthorized: invalid signature: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: invalid signature") except HTTPException as e: raise e except Exception as e: @@ -243,34 +221,35 @@ async def list_convs(request: Request): @router.put("/conversations/{conversation_id}/title") async def update_convs_title( request: Request, - conversation_id: str, - title: str, + conversation_id: int, + title: str = Query(..., description="New title"), + meta_data: Optional[str] = Query(None, description="Optional metadata as JSON string"), idempotency_key: Optional[str] = Header(None, alias="Idempotency-Key"), ): + import json + parsed_meta_data = None + if meta_data: + try: + parsed_meta_data = json.loads(meta_data) + except json.JSONDecodeError: + pass try: - ctx: NorthboundContext = await _parse_northbound_context(request) + ctx: NorthboundContext = await _get_northbound_context(request) result = await update_conversation_title( ctx=ctx, - external_conversation_id=conversation_id, + conversation_id=conversation_id, title=title, + meta_data=parsed_meta_data, idempotency_key=idempotency_key, ) headers_out = { "Idempotency-Key": result.get("idempotency_key", ""), "X-Request-Id": ctx.request_id} return JSONResponse(content=result, headers=headers_out) - except UnauthorizedError as e: - logging.error(f"Unauthorized: AK/SK authentication failed: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: AK/SK authentication failed") except LimitExceededError as e: logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, detail="Too Many Requests: rate limit exceeded") - except SignatureValidationError as e: - logging.error(f"Unauthorized: invalid signature: {str(e)}", exc_info=e) - raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, - detail="Unauthorized: invalid signature") except HTTPException as e: raise e except Exception as e: diff --git a/backend/apps/user_management_app.py b/backend/apps/user_management_app.py index c38b4e73c..956832f52 100644 --- a/backend/apps/user_management_app.py +++ b/backend/apps/user_management_app.py @@ -1,9 +1,10 @@ import logging from dotenv import load_dotenv -from fastapi import APIRouter, Request, HTTPException +from fastapi import APIRouter, Header, Query, Request, HTTPException from fastapi.responses import JSONResponse from http import HTTPStatus +from typing import Optional from supabase_auth.errors import AuthApiError, AuthWeakPasswordError @@ -11,7 +12,7 @@ from consts.exceptions import NoInviteCodeException, IncorrectInviteCodeException, UserRegistrationException from services.user_management_service import get_authorized_client, validate_token, \ check_auth_service_health, signup_user_with_invitation, signin_user, refresh_user_token, \ - get_session_by_authorization, get_user_info + get_session_by_authorization, get_user_info, create_token, list_tokens_by_user, delete_token from services.user_service import delete_user_and_cleanup from consts.exceptions import UnauthorizedError from utils.auth_utils import get_current_user_id @@ -273,3 +274,107 @@ async def revoke_user_account(request: Request): logging.error(f"User revoke failed: {str(e)}") raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="User revoke failed") + +@router.post("/tokens") +async def create_token_endpoint( + authorization: Optional[str] = Header(None) +): + """Create a new token for the authenticated user. + + The user_id is extracted from the Authorization header (JWT token). + Returns the complete token including the secret key. + """ + try: + if not authorization: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, + detail="Unauthorized: No authorization header found") + + user_id, _ = get_current_user_id(authorization) + if not user_id: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, + detail="Unauthorized: missing user_id in JWT token") + + result = create_token(str(user_id)) + return JSONResponse( + status_code=HTTPStatus.OK, + content={"message": "success", "data": result} + ) + except HTTPException as e: + raise e + except Exception as e: + logging.error(f"Failed to create token: {str(e)}", exc_info=e) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Internal Server Error") + + +@router.get("/tokens") +async def list_tokens_endpoint( + user_id: str = Query(..., description="User ID to query tokens for"), + authorization: Optional[str] = Header(None) +): + """List all tokens for the specified user. + + Returns token information with masked access keys (middle part replaced with *). + """ + try: + if not authorization: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, + detail="Unauthorized: No authorization header found") + + request_user_id, _ = get_current_user_id(authorization) + if not request_user_id: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, + detail="Unauthorized: missing user_id in JWT token") + + # Only allow users to list their own tokens + if str(request_user_id) != user_id: + raise HTTPException(status_code=HTTPStatus.FORBIDDEN, + detail="Forbidden: cannot list tokens for other users") + + tokens = list_tokens_by_user(user_id) + return JSONResponse( + status_code=HTTPStatus.OK, + content={"message": "success", "data": tokens} + ) + except HTTPException as e: + raise e + except Exception as e: + logging.error(f"Failed to list tokens: {str(e)}", exc_info=e) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Internal Server Error") + + +@router.delete("/tokens/{token_id}") +async def delete_token_endpoint( + token_id: int, + authorization: Optional[str] = Header(None) +): + """Soft delete a token. + + Only the owner of the token can delete it. + """ + try: + if not authorization: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, + detail="Unauthorized: No authorization header found") + + user_id, _ = get_current_user_id(authorization) + if not user_id: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, + detail="Unauthorized: missing user_id in JWT token") + + success = delete_token(token_id, str(user_id)) + if not success: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, + detail="Token not found or not owned by user") + + return JSONResponse( + status_code=HTTPStatus.OK, + content={"message": "success", "data": {"token_id": token_id}} + ) + except HTTPException as e: + raise e + except Exception as e: + logging.error(f"Failed to delete token: {str(e)}", exc_info=e) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Internal Server Error") diff --git a/backend/database/conversation_db.py b/backend/database/conversation_db.py index 0267d77c2..18c0ee9fc 100644 --- a/backend/database/conversation_db.py +++ b/backend/database/conversation_db.py @@ -4,7 +4,7 @@ from sqlalchemy import asc, desc, func, insert, select, update -from .client import as_dict, get_db_session +from .client import as_dict, db_client, get_db_session from .db_models import ( ConversationMessage, ConversationMessageUnit, @@ -328,11 +328,12 @@ def rename_conversation(conversation_id: int, new_title: str, user_id: Optional[ # Ensure conversation_id is of integer type conversation_id = int(conversation_id) - # Prepare update data + # Prepare update data with UTF-8 encoding for title update_data = { "conversation_title": new_title, "update_time": func.current_timestamp() } + update_data = db_client.clean_string_values(update_data) if user_id: update_data = add_update_tracking(update_data, user_id) diff --git a/backend/database/db_models.py b/backend/database/db_models.py index 36f475f53..80dcc87eb 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -1,4 +1,5 @@ from sqlalchemy import BigInteger, Boolean, Column, Integer, JSON, Numeric, PrimaryKeyConstraint, Sequence, String, Text, TIMESTAMP +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import DeclarativeBase from sqlalchemy.sql import func @@ -483,3 +484,31 @@ class AgentVersion(TableBase): source_version_no = Column(Integer, doc="Source version number. If this version is a rollback, record the source version") source_type = Column(String(30), doc="Source type: NORMAL (normal publish) / ROLLBACK (rollback and republish)") status = Column(String(30), default="RELEASED", doc="Version status: RELEASED / DISABLED / ARCHIVED") + + +class UserTokenInfo(TableBase): + """ + User token (AK/SK) information table + """ + __tablename__ = "user_token_info_t" + __table_args__ = {"schema": SCHEMA} + + token_id = Column(Integer, Sequence("user_token_info_t_token_id_seq", schema=SCHEMA), + primary_key=True, nullable=False, doc="Token ID, unique primary key") + access_key = Column(String(100), nullable=False, doc="Access Key (AK)") + user_id = Column(String(100), nullable=False, doc="User ID who owns this token") + + +class UserTokenUsageLog(TableBase): + """ + User token usage log table + """ + __tablename__ = "user_token_usage_log_t" + __table_args__ = {"schema": SCHEMA} + + token_usage_id = Column(Integer, Sequence("user_token_usage_log_t_token_usage_id_seq", schema=SCHEMA), + primary_key=True, nullable=False, doc="Token usage log ID, unique primary key") + token_id = Column(Integer, nullable=False, doc="Foreign key to user_token_info_t.token_id") + call_function_name = Column(String(100), doc="API function name being called") + related_id = Column(Integer, doc="Related resource ID (e.g., conversation_id)") + meta_data = Column(JSONB, doc="Additional metadata for this usage log entry, stored as JSON") diff --git a/backend/database/token_db.py b/backend/database/token_db.py new file mode 100644 index 000000000..70d53a42e --- /dev/null +++ b/backend/database/token_db.py @@ -0,0 +1,189 @@ +""" +Database operations for user API token (API Key) management. +""" +import secrets +from typing import Any, Dict, List, Optional + +from database.client import get_db_session +from database.db_models import UserTokenInfo, UserTokenUsageLog + + +def generate_access_key() -> str: + """Generate a random access key with format nexent-xxxxx...""" + random_part = secrets.token_hex(12) # 24 hex characters for more entropy + return f"nexent-{random_part}" + + +def create_token(access_key: str, user_id: str) -> Dict[str, Any]: + """Create a new token record in the database. + + Args: + access_key: The access key (API Key). + user_id: The user ID who owns this token. + + Returns: + Dictionary containing the created token information. + """ + with get_db_session() as session: + token = UserTokenInfo( + access_key=access_key, + user_id=user_id, + created_by=user_id, + updated_by=user_id, + delete_flag='N' + ) + session.add(token) + session.flush() + + return { + "token_id": token.token_id, + "access_key": token.access_key, + "user_id": token.user_id + } + + +def list_tokens_by_user(user_id: str) -> List[Dict[str, Any]]: + """List all active tokens for the specified user. + + Args: + user_id: The user ID to query tokens for. + + Returns: + List of token information with masked access keys. + """ + with get_db_session() as session: + tokens = session.query(UserTokenInfo).filter( + UserTokenInfo.user_id == user_id, + UserTokenInfo.delete_flag == 'N' + ).order_by(UserTokenInfo.create_time.desc()).all() + + return [ + { + "token_id": token.token_id, + "access_key": token.access_key, + "user_id": token.user_id, + "create_time": token.create_time.isoformat() if token.create_time else None + } + for token in tokens + ] + + +def get_token_by_id(token_id: int) -> UserTokenInfo: + """Get a token by its ID. + + Args: + token_id: The token ID to query. + + Returns: + UserTokenInfo object if found and active, None otherwise. + """ + with get_db_session() as session: + return session.query(UserTokenInfo).filter( + UserTokenInfo.token_id == token_id, + UserTokenInfo.delete_flag == 'N' + ).first() + + +def get_token_by_access_key(access_key: str) -> Optional[Dict[str, Any]]: + """Get a token by its access key. + + Args: + access_key: The access key to query. + + Returns: + Token information dict if found and active, None otherwise. + """ + with get_db_session() as session: + token = session.query(UserTokenInfo).filter( + UserTokenInfo.access_key == access_key, + UserTokenInfo.delete_flag == 'N' + ).first() + + if token: + return { + "token_id": token.token_id, + "access_key": token.access_key, + "user_id": token.user_id, + "delete_flag": token.delete_flag + } + return None + + +def delete_token(token_id: int, user_id: str) -> bool: + """Soft delete a token by setting delete_flag to 'Y'. + + Args: + token_id: The token ID to delete. + user_id: The user ID who owns this token (for authorization). + + Returns: + True if the token was deleted, False if not found or not owned by user. + """ + with get_db_session() as session: + token = session.query(UserTokenInfo).filter( + UserTokenInfo.token_id == token_id, + UserTokenInfo.user_id == user_id, + UserTokenInfo.delete_flag == 'N' + ).first() + + if not token: + return False + + token.delete_flag = 'Y' + token.updated_by = user_id + return True + + +def log_token_usage( + token_id: int, + call_function_name: str, + related_id: Optional[int], + created_by: str, + metadata: Optional[Dict[str, Any]] = None +) -> int: + """Log token usage to the database. + + Args: + token_id: The token ID used. + call_function_name: The API function name being called. + related_id: Related resource ID (e.g., conversation_id). + created_by: User ID who initiated the call. + metadata: Optional additional metadata for this usage log entry. + + Returns: + The created token_usage_id. + """ + with get_db_session() as session: + usage_log = UserTokenUsageLog( + token_id=token_id, + call_function_name=call_function_name, + related_id=related_id, + created_by=created_by, + meta_data=metadata + ) + session.add(usage_log) + session.flush() + return usage_log.token_usage_id + + +def get_latest_usage_metadata(token_id: int, related_id: int, call_function_name: str) -> Optional[Dict[str, Any]]: + """Get the latest metadata for a given token, related_id and function name. + + Args: + token_id: The token ID used. + related_id: Related resource ID (e.g., conversation_id). + call_function_name: The API function name. + + Returns: + The metadata dict if found, None otherwise. + """ + with get_db_session() as session: + usage_log = session.query(UserTokenUsageLog).filter( + UserTokenUsageLog.token_id == token_id, + UserTokenUsageLog.related_id == related_id, + UserTokenUsageLog.call_function_name == call_function_name + ).order_by(UserTokenUsageLog.create_time.desc()).first() + + if usage_log and usage_log.meta_data: + return usage_log.meta_data + return None diff --git a/backend/services/northbound_service.py b/backend/services/northbound_service.py index 6f9164269..a6eaed77d 100644 --- a/backend/services/northbound_service.py +++ b/backend/services/northbound_service.py @@ -13,11 +13,7 @@ ) from consts.model import AgentRequest from database.conversation_db import get_conversation_messages -from database.partner_db import ( - add_mapping_id, - get_external_id_by_internal, - get_internal_id_by_external -) +from database.token_db import log_token_usage, get_latest_usage_metadata from services.agent_service import ( run_agent_stream, stop_agent_tasks, @@ -40,6 +36,7 @@ class NorthboundContext: tenant_id: str user_id: str authorization: str + token_id: int = 0 # ----------------------------- @@ -114,26 +111,6 @@ def _build_idempotency_key(*parts: Any) -> str: return ":".join(processed) -# ----------------------------- -# ID mapping helpers -# ----------------------------- -async def to_external_conversation_id(internal_id: int) -> str: - if not internal_id: - raise Exception("invalid internal conversation id") - external_id = get_external_id_by_internal(internal_id=internal_id, mapping_type="CONVERSATION") - if not external_id: - logger.error(f"cannot find external id for conversation_id: {internal_id}") - raise Exception("cannot find external id") - return external_id - - -async def to_internal_conversation_id(external_id: str) -> int: - if not external_id: - raise Exception("invalid external conversation id") - internal_id = get_internal_id_by_external(external_id=external_id, mapping_type="CONVERSATION") - return internal_id - - # ----------------------------- # Agent resolver # ----------------------------- @@ -146,30 +123,30 @@ async def get_agent_info_by_name(agent_name: str, tenant_id: str) -> int: async def start_streaming_chat( ctx: NorthboundContext, - external_conversation_id: str, + conversation_id: Optional[int], agent_name: str, query: str, + meta_data: Optional[Dict[str, Any]] = None, idempotency_key: Optional[str] = None ) -> StreamingResponse: try: # Simple rate limit await check_and_consume_rate_limit(ctx.tenant_id) - internal_conversation_id = await to_internal_conversation_id(external_conversation_id) - # Add mapping to postgres database - if internal_conversation_id is None: - logging.info(f"Conversation {external_conversation_id} not found, creating a new conversation") - # Create a new conversation and get its internal ID + # If conversation_id is not provided, create a new conversation + if conversation_id is None: + logging.info("No conversation_id provided, creating a new conversation") new_conversation = create_new_conversation(title="New Conversation", user_id=ctx.user_id) - internal_conversation_id = new_conversation["conversation_id"] - # Add the new mapping to the database - add_mapping_id(internal_id=internal_conversation_id, external_id=external_conversation_id, tenant_id=ctx.tenant_id, user_id=ctx.user_id) + conversation_id = new_conversation["conversation_id"] + logging.info(f"Created new conversation with id: {conversation_id}") + + internal_conversation_id = conversation_id # Get history according to internal_conversation_id - history_resp = await get_conversation_history(ctx, external_conversation_id) + history_resp = await get_conversation_history_internal(ctx, internal_conversation_id) agent_id = await get_agent_id_by_name(agent_name=agent_name, tenant_id=ctx.tenant_id) # Idempotency: only prevent concurrent duplicate starts - composed_key = idempotency_key or _build_idempotency_key(ctx.tenant_id, external_conversation_id, agent_id, query) + composed_key = idempotency_key or _build_idempotency_key(ctx.tenant_id, str(conversation_id), agent_id, query) await idempotency_start(composed_key) agent_request = AgentRequest( conversation_id=internal_conversation_id, @@ -192,7 +169,7 @@ async def start_streaming_chat( except UnauthorizedError as _: raise UnauthorizedError("Cannot authenticate.") except Exception as e: - raise Exception(f"Failed to start streaming chat for external conversation id {external_conversation_id}: {str(e)}") + raise Exception(f"Failed to start streaming chat for conversation_id {conversation_id}: {str(e)}") try: response = await run_agent_stream( @@ -207,34 +184,82 @@ async def start_streaming_chat( if composed_key: asyncio.create_task(_release_idempotency_after_delay(composed_key)) - # Attach request id header + # Log token usage + if ctx.token_id > 0: + try: + log_token_usage( + token_id=ctx.token_id, + call_function_name="run_chat", + related_id=conversation_id, + created_by=ctx.user_id, + metadata=meta_data + ) + except Exception as e: + logger.warning(f"Failed to log token usage: {str(e)}") + + # Attach request id header and conversation_id (internal id) response.headers["X-Request-Id"] = ctx.request_id - response.headers["conversation_id"] = external_conversation_id + response.headers["conversation_id"] = str(conversation_id) return response -async def stop_chat(ctx: NorthboundContext, external_conversation_id: str) -> Dict[str, Any]: +async def stop_chat(ctx: NorthboundContext, conversation_id: int, meta_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: try: - internal_id = await to_internal_conversation_id(external_conversation_id) - - stop_result = stop_agent_tasks(internal_id, ctx.user_id) - return {"message": stop_result.get("message", "success"), "data": external_conversation_id, "requestId": ctx.request_id} + stop_result = stop_agent_tasks(conversation_id, ctx.user_id) + + # Log token usage + if ctx.token_id > 0: + try: + log_token_usage( + token_id=ctx.token_id, + call_function_name="stop_chat_stream", + related_id=conversation_id, + created_by=ctx.user_id, + metadata=meta_data + ) + except Exception as e: + logger.warning(f"Failed to log token usage: {str(e)}") + + return {"message": stop_result.get("message", "success"), "data": conversation_id, "requestId": ctx.request_id} except Exception as e: - raise Exception(f"Failed to stop chat for external conversation id {external_conversation_id}: {str(e)}") + raise Exception(f"Failed to stop chat for conversation_id {conversation_id}: {str(e)}") async def list_conversations(ctx: NorthboundContext) -> Dict[str, Any]: conversations = get_conversation_list_service(ctx.user_id) # get_conversation_list_service is sync - for item in conversations: - item["conversation_id"] = await to_external_conversation_id(int(item["conversation_id"])) - return {"message": "success", "data": conversations, "requestId": ctx.request_id} + # Add meta_data from token usage log if available + if ctx.token_id > 0: + for item in conversations: + # Ensure we do not leak empty meta_data keys + if "meta_data" in item and not item.get("meta_data"): + item.pop("meta_data", None) + + conversation_id = item.get("conversation_id") + if conversation_id: + try: + meta_data = get_latest_usage_metadata( + token_id=ctx.token_id, + related_id=int(conversation_id), + call_function_name="run_chat" + ) + # Only return meta_data when there is a usage log record and meta_data is non-empty + if meta_data: + item["meta_data"] = meta_data + else: + item.pop("meta_data", None) + except Exception as e: + logger.warning(f"Failed to get meta_data for conversation {conversation_id}: {str(e)}") + item.pop("meta_data", None) + + # Now return internal conversation_id directly + return {"message": "success", "data": conversations, "requestId": ctx.request_id} -async def get_conversation_history(ctx: NorthboundContext, external_conversation_id: str) -> Dict[str, Any]: - internal_id = await to_internal_conversation_id(external_conversation_id) - history = get_conversation_messages(internal_id) +async def get_conversation_history_internal(ctx: NorthboundContext, conversation_id: int) -> Dict[str, Any]: + """Internal helper to get conversation history without logging.""" + history = get_conversation_messages(conversation_id) # Remove unnecessary fields result = [] for message in history: @@ -244,44 +269,63 @@ async def get_conversation_history(ctx: NorthboundContext, external_conversation }) response = { - "conversation_id": external_conversation_id, + "conversation_id": conversation_id, "history": result } - # Ensure external id in response return {"message": "success", "data": response, "requestId": ctx.request_id} +async def get_conversation_history(ctx: NorthboundContext, conversation_id: int) -> Dict[str, Any]: + try: + return await get_conversation_history_internal(ctx, conversation_id) + except Exception as e: + raise Exception(f"Failed to get conversation history for conversation_id {conversation_id}: {str(e)}") + + async def get_agent_info_list(ctx: NorthboundContext) -> Dict[str, Any]: try: - agent_info_list = await list_all_agent_info_impl(tenant_id=ctx.tenant_id) + agent_info_list = await list_all_agent_info_impl(tenant_id=ctx.tenant_id, user_id=ctx.user_id) # Remove internal information that partner don't need for agent_info in agent_info_list: agent_info.pop("agent_id", None) + return {"message": "success", "data": agent_info_list, "requestId": ctx.request_id} except Exception as e: raise Exception(f"Failed to get agent info list for tenant {ctx.tenant_id}: {str(e)}") -async def update_conversation_title(ctx: NorthboundContext, external_conversation_id: str, title: str, idempotency_key: Optional[str] = None) -> Dict[str, Any]: +async def update_conversation_title(ctx: NorthboundContext, conversation_id: int, title: str, meta_data: Optional[Dict[str, Any]] = None, idempotency_key: Optional[str] = None) -> Dict[str, Any]: composed_key: Optional[str] = None try: - internal_id = await to_internal_conversation_id(external_conversation_id) - # Idempotency: avoid concurrent duplicate title update for same conversation - composed_key = idempotency_key or _build_idempotency_key(ctx.tenant_id, external_conversation_id, title) + composed_key = idempotency_key or _build_idempotency_key(ctx.tenant_id, str(conversation_id), title) await idempotency_start(composed_key) - update_conversation_title_service(internal_id, title, ctx.user_id) + update_conversation_title_service(conversation_id, title, ctx.user_id) + + # Log token usage + if ctx.token_id > 0: + try: + log_token_usage( + token_id=ctx.token_id, + call_function_name="update_conversation_title", + related_id=conversation_id, + created_by=ctx.user_id, + metadata=meta_data + ) + except Exception as e: + logger.warning(f"Failed to log token usage: {str(e)}") + return { "message": "success", - "data": external_conversation_id, + "data": conversation_id, "requestId": ctx.request_id, "idempotency_key": composed_key, } except LimitExceededError as _: raise LimitExceededError("Duplicate request is still running, please wait.") except Exception as e: - raise Exception(f"Failed to update conversation title for external conversation id {external_conversation_id}: {str(e)}") + raise Exception(f"Failed to update conversation title for conversation_id {conversation_id}: {str(e)}") finally: if composed_key: asyncio.create_task(_release_idempotency_after_delay(composed_key)) diff --git a/backend/services/user_management_service.py b/backend/services/user_management_service.py index 792887ec5..3499d3170 100644 --- a/backend/services/user_management_service.py +++ b/backend/services/user_management_service.py @@ -1,6 +1,13 @@ import logging from typing import Optional, Any, Tuple, Dict, List +from database.token_db import ( + create_token as create_token_record, + generate_access_key, + list_tokens_by_user as list_tokens_by_user_record, + delete_token as delete_token_record, +) + import aiohttp from fastapi import Header from supabase import Client @@ -472,3 +479,45 @@ def format_role_permissions(permissions: List[Dict[str, Any]]) -> Dict[str, List "permissions": formatted_permissions, "accessibleRoutes": accessible_routes } + + +# ----------------------------- +# Token Management +# ----------------------------- + +def create_token(user_id: str) -> Dict[str, Any]: + """Create a new API token for the specified user. + + Args: + user_id: The user ID who owns this token. + + Returns: + Dictionary containing the API token information including token_id. + """ + access_key = generate_access_key() + return create_token_record(access_key, user_id) + + +def list_tokens_by_user(user_id: str) -> List[Dict[str, Any]]: + """List all tokens for the specified user. + + Args: + user_id: The user ID to query token pairs for. + + Returns: + List of token information with masked access keys. + """ + return list_tokens_by_user_record(user_id) + + +def delete_token(token_id: int, user_id: str) -> bool: + """Soft delete a token. + + Args: + token_id: The token ID to delete. + user_id: The user ID who owns this token (for authorization). + + Returns: + True if the token was deleted, False if not found or not owned by user. + """ + return delete_token_record(token_id, user_id) diff --git a/backend/utils/auth_utils.py b/backend/utils/auth_utils.py index a27a48b38..7b40576e2 100644 --- a/backend/utils/auth_utils.py +++ b/backend/utils/auth_utils.py @@ -1,9 +1,9 @@ import logging -import hashlib -import hmac import time +import hmac +import hashlib from datetime import datetime, timedelta -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple import jwt from fastapi import Request @@ -20,189 +20,195 @@ DEBUG_JWT_EXPIRE_SECONDS, LANGUAGE, ) -from consts.exceptions import LimitExceededError, SignatureValidationError, UnauthorizedError +from consts.exceptions import LimitExceededError, UnauthorizedError from database.user_tenant_db import get_user_tenant_by_user_id +from database.token_db import get_token_by_access_key # Module logger logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- -# AK/SK authentication helpers (merged from aksk_auth_utils.py) +# Shared test constants # --------------------------------------------------------------------------- -# Mock AK/SK configuration (replace with DB/config lookup in production) -MOCK_ACCESS_KEY = "mock_access_key_12345" -MOCK_SECRET_KEY = "mock_secret_key_67890abcdef" -MOCK_JWT_SECRET_KEY = "mock_jwt_secret_key_67890abcdef" +# Fixed test secret used by generate_test_jwt and unit tests. +MOCK_JWT_SECRET_KEY = "nexent-mock-jwt-secret" -# Timestamp validity window in seconds (prevent replay attacks) -TIMESTAMP_VALIDITY_WINDOW = 300 +# --------------------------------------------------------------------------- +# AK/SK (Access Key / Secret Key) authentication helpers +# --------------------------------------------------------------------------- +# Validity window in seconds for X-Timestamp header. +TIMESTAMP_VALIDITY_WINDOW = 5 * 60 -def get_aksk_config(tenant_id: str) -> Tuple[str, str]: - """ - Get AK/SK configuration according to tenant_id - Returns: - Tuple[str, str]: (access_key, secret_key) +def calculate_hmac_signature(secret_key: str, access_key: str, timestamp: str, body: str) -> str: """ + Calculate HMAC-SHA256 signature for AK/SK authentication. - # TODO: get ak/sk according to tenant_id from DB - return MOCK_ACCESS_KEY, MOCK_SECRET_KEY + Returns a lowercase hex digest of length 64. + """ + message = f"{access_key}\n{timestamp}\n{body}".encode("utf-8") + return hmac.new(secret_key.encode("utf-8"), message, hashlib.sha256).hexdigest() def validate_timestamp(timestamp: str) -> bool: - """ - Validate timestamp is within validity window + """Validate that timestamp is within allowed window.""" + try: + ts = int(timestamp) + except (TypeError, ValueError): + return False - Args: - timestamp: timestamp string + now = int(time.time()) + return abs(now - ts) <= TIMESTAMP_VALIDITY_WINDOW - Returns: - bool: whether timestamp is valid + +def extract_aksk_headers(headers: Dict[str, str]) -> Tuple[str, str, str]: + """Extract AK/SK headers or raise UnauthorizedError when missing.""" + access_key = headers.get("X-Access-Key") if headers else None + timestamp = headers.get("X-Timestamp") if headers else None + signature = headers.get("X-Signature") if headers else None + + if not access_key or not timestamp or not signature: + raise UnauthorizedError("Missing AK/SK authentication headers") + + return access_key, timestamp, signature + + +def get_aksk_config(tenant_id: str) -> Tuple[str, str]: """ - try: - timestamp_int = int(timestamp) - current_time = int(time.time()) + Get (access_key, secret_key) configuration for a tenant. - if abs(current_time - timestamp_int) > TIMESTAMP_VALIDITY_WINDOW: - logger.warning( - f"Timestamp validation failed: current={current_time}, provided={timestamp_int}" - ) - return False + This is intentionally a thin indirection so tests can monkeypatch it. + """ + raise UnauthorizedError("AK/SK authentication is not configured") - return True - except (ValueError, TypeError) as e: - logger.error(f"Invalid timestamp format: {timestamp}, error: {e}") + +def verify_aksk_signature(access_key: str, timestamp: str, signature: str, body: str, tenant_id: str = None) -> bool: + """Verify AK/SK signature; returns False instead of raising on mismatch.""" + tenant = tenant_id or DEFAULT_TENANT_ID + try: + expected_access_key, secret_key = get_aksk_config(tenant) + except Exception: return False + if access_key != expected_access_key: + return False -def calculate_hmac_signature(secret_key: str, access_key: str, timestamp: str, request_body: str = "") -> str: - """ - Calculate HMAC-SHA256 signature + expected_sig = calculate_hmac_signature(secret_key, access_key, timestamp, body) + return hmac.compare_digest(expected_sig, signature) - Args: - secret_key: secret key - access_key: access key - timestamp: timestamp - request_body: request body (optional) - Returns: - str: HMAC-SHA256 signature (hex string) +def validate_aksk_authentication(headers: Dict[str, str], body: str, tenant_id: str = None) -> bool: """ - string_to_sign = f"{access_key}{timestamp}{request_body}" - signature = hmac.new( - secret_key.encode("utf-8"), - string_to_sign.encode("utf-8"), - hashlib.sha256, - ).hexdigest() - return signature - - -def verify_aksk_signature( - access_key: str, timestamp: str, signature: str, request_body: str = "" -) -> bool: - """ - Validate AK/SK signature - - Args: - access_key: access key - timestamp: timestamp - signature: provided signature - request_body: request body (optional) + Validate AK/SK authentication. - Returns: - bool: whether signature is valid + Returns True when valid, otherwise raises domain exceptions. """ - try: - if not validate_timestamp(timestamp): - raise SignatureValidationError("Timestamp is invalid or expired") + from consts.exceptions import SignatureValidationError # imported lazily for test-time stubbing - # TODO: get ak/sk according to tenant_id from DB - mock_access_key, mock_secret_key = get_aksk_config( - tenant_id="tenant_id") + try: + access_key, ts, sig = extract_aksk_headers(headers) - if access_key != mock_access_key: - logger.warning(f"Invalid access key: {access_key}") - return False + if not validate_timestamp(ts): + raise UnauthorizedError("Invalid or expired timestamp") - expected_signature = calculate_hmac_signature( - mock_secret_key, access_key, timestamp, request_body - ) + # Call with positional args so tests can monkeypatch with simple lambdas. + if tenant_id is None: + ok = verify_aksk_signature(access_key, ts, sig, body) + else: + ok = verify_aksk_signature(access_key, ts, sig, body, tenant_id) - if not hmac.compare_digest(signature, expected_signature): - logger.warning( - f"Signature mismatch: expected={expected_signature}, provided={signature}" - ) - return False + if not ok: + raise SignatureValidationError("Invalid signature") return True - except Exception as e: - logger.error(f"Error during signature verification: {e}") - return False + except (UnauthorizedError, SignatureValidationError): + raise + except Exception as exc: + logger.exception("Unexpected error during AK/SK authentication") + raise UnauthorizedError("Authentication failed") from exc +# --------------------------------------------------------------------------- +# Bearer Token (API Key) authentication +# --------------------------------------------------------------------------- -def extract_aksk_headers(headers: dict) -> Tuple[str, str, str]: + +def validate_bearer_token(authorization: Optional[str]) -> Tuple[bool, Optional[dict]]: """ - Extract AK/SK related information from request headers + Validate Bearer token (API Key) from Authorization header. Args: - headers: request headers dictionary + authorization: Authorization header value (e.g., "Bearer nexent-xxxxx") Returns: - Tuple[str, str, str]: (access_key, timestamp, signature) - - Raises: - UnauthorizedError: when required headers are missing + Tuple of (is_valid, token_info_dict) + - is_valid: True if token exists and is active + - token_info: Token information dict if valid, None otherwise """ + if not authorization: + logger.warning("No authorization header provided") + return False, None - def get_header(headers: dict, name: str) -> Optional[str]: - for k, v in headers.items(): - if k.lower() == name.lower(): - return v - return None - - access_key = get_header(headers, "X-Access-Key") - timestamp = get_header(headers, "X-Timestamp") - signature = get_header(headers, "X-Signature") + # Extract token from "Bearer " format + token = authorization.replace("Bearer ", "") if authorization.startswith("Bearer ") else authorization - if not access_key: - raise UnauthorizedError("Missing X-Access-Key header") - if not timestamp: - raise UnauthorizedError("Missing X-Timestamp header") - if not signature: - raise UnauthorizedError("Missing X-Signature header") + if not token: + logger.warning("Empty bearer token") + return False, None - return access_key, timestamp, signature + # Look up token in database + try: + token_info = get_token_by_access_key(token) + if token_info and token_info.get("delete_flag") != "Y": + logger.debug(f"Token validated successfully for user {token_info.get('user_id')}") + return True, token_info + else: + logger.warning(f"Invalid or inactive token: {token[:20]}...") + return False, None + except Exception as e: + logger.error(f"Error validating bearer token: {str(e)}") + return False, None -def validate_aksk_authentication(headers: dict, request_body: str = "") -> bool: +def get_user_and_tenant_by_access_key(access_key: str) -> Dict[str, str]: """ - Validate AK/SK authentication + Get user_id and tenant_id from access_key by querying user_token_info_t and user_tenant_t. Args: - headers: request headers dictionary - request_body: request body (optional) + access_key: The access key (API Key) from the Authorization header. Returns: - bool: whether authentication is successful + Dict containing user_id and tenant_id. Raises: - UnauthorizedError: when authentication fails - SignatureValidationError: when signature verification fails + UnauthorizedError: If the access key is not found or invalid. """ - try: - access_key, timestamp, signature = extract_aksk_headers(headers) - - if not verify_aksk_signature(access_key, timestamp, signature, request_body): - raise SignatureValidationError("Invalid signature") - - return True - except (UnauthorizedError, SignatureValidationError, LimitExceededError) as e: - raise e - except Exception as e: - logger.error(f"Unexpected error during AK/SK authentication: {e}") - raise UnauthorizedError("Authentication failed") + if not access_key: + raise UnauthorizedError("Invalid access key") + + # Query token from user_token_info_t + token_info = get_token_by_access_key(access_key) + if not token_info or token_info.get("delete_flag") == "Y": + raise UnauthorizedError("Invalid or inactive access key") + + user_id = token_info.get("user_id") + if not user_id: + raise UnauthorizedError("No user associated with this access key") + + # Query tenant from user_tenant_t + user_tenant_record = get_user_tenant_by_user_id(user_id) + if user_tenant_record and user_tenant_record.get("tenant_id"): + tenant_id = user_tenant_record["tenant_id"] + else: + tenant_id = DEFAULT_TENANT_ID + logger.warning(f"No tenant relationship found for user {user_id}, using default tenant") + + return { + "user_id": user_id, + "tenant_id": tenant_id, + "token_id": token_info.get("token_id") + } def get_supabase_client(): diff --git a/docker/sql/v1.8.1_0306_add_user_token_info.sql b/docker/sql/v1.8.1_0306_add_user_token_info.sql new file mode 100644 index 000000000..040530334 --- /dev/null +++ b/docker/sql/v1.8.1_0306_add_user_token_info.sql @@ -0,0 +1,118 @@ +-- Migration: Add user_token_info_t and user_token_usage_log_t tables +-- Date: 2026-03-06 +-- Description: Create user token (AK/SK) management tables with audit fields + +-- Set search path to nexent schema +SET search_path TO nexent; + +-- Create the user_token_info_t table in the nexent schema +CREATE TABLE IF NOT EXISTS nexent.user_token_info_t ( + token_id SERIAL4 PRIMARY KEY NOT NULL, + access_key VARCHAR(100) NOT NULL, + user_id VARCHAR(100) NOT NULL, + create_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, + created_by VARCHAR(100), + updated_by VARCHAR(100), + delete_flag VARCHAR(1) DEFAULT 'N' +); + +ALTER TABLE "user_token_info_t" OWNER TO "root"; + +-- Add comment to the table +COMMENT ON TABLE nexent.user_token_info_t IS 'User token (AK/SK) information table'; + +-- Add comments to the columns +COMMENT ON COLUMN nexent.user_token_info_t.token_id IS 'Token ID, unique primary key'; +COMMENT ON COLUMN nexent.user_token_info_t.access_key IS 'Access Key (AK)'; +COMMENT ON COLUMN nexent.user_token_info_t.user_id IS 'User ID who owns this token'; +COMMENT ON COLUMN nexent.user_token_info_t.create_time IS 'Creation time, audit field'; +COMMENT ON COLUMN nexent.user_token_info_t.update_time IS 'Update time, audit field'; +COMMENT ON COLUMN nexent.user_token_info_t.created_by IS 'Creator ID, audit field'; +COMMENT ON COLUMN nexent.user_token_info_t.updated_by IS 'Last updater ID, audit field'; +COMMENT ON COLUMN nexent.user_token_info_t.delete_flag IS 'Soft delete flag, Y means deleted'; + +-- Create unique index on access_key to ensure uniqueness +CREATE UNIQUE INDEX IF NOT EXISTS idx_user_token_info_access_key ON nexent.user_token_info_t(access_key) WHERE delete_flag = 'N'; + +-- Create index on user_id for query performance +CREATE INDEX IF NOT EXISTS idx_user_token_info_user_id ON nexent.user_token_info_t(user_id) WHERE delete_flag = 'N'; + +-- Create a function to update the update_time column +CREATE OR REPLACE FUNCTION update_user_token_info_update_time() +RETURNS TRIGGER AS $$ +BEGIN + NEW.update_time = CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Add comment to the function +COMMENT ON FUNCTION update_user_token_info_update_time() IS 'Function to update the update_time column when a record in user_token_info_t is updated'; + +-- Create a trigger to call the function before each update +DROP TRIGGER IF EXISTS update_user_token_info_update_time_trigger ON nexent.user_token_info_t; +CREATE TRIGGER update_user_token_info_update_time_trigger +BEFORE UPDATE ON nexent.user_token_info_t +FOR EACH ROW +EXECUTE FUNCTION update_user_token_info_update_time(); + +-- Add comment to the trigger +COMMENT ON TRIGGER update_user_token_info_update_time_trigger ON nexent.user_token_info_t IS 'Trigger to call update_user_token_info_update_time function before each update on user_token_info_t table'; + + +-- Create the user_token_usage_log_t table in the nexent schema +CREATE TABLE IF NOT EXISTS nexent.user_token_usage_log_t ( + token_usage_id SERIAL4 PRIMARY KEY NOT NULL, + token_id INT4 NOT NULL, + call_function_name VARCHAR(100), + related_id INT4, + meta_data JSONB, + create_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, + created_by VARCHAR(100), + updated_by VARCHAR(100), + delete_flag VARCHAR(1) DEFAULT 'N' +); + +ALTER TABLE "user_token_usage_log_t" OWNER TO "root"; + +-- Add comment to the table +COMMENT ON TABLE nexent.user_token_usage_log_t IS 'User token usage log table'; + +-- Add comments to the columns +COMMENT ON COLUMN nexent.user_token_usage_log_t.token_usage_id IS 'Token usage log ID, unique primary key'; +COMMENT ON COLUMN nexent.user_token_usage_log_t.token_id IS 'Foreign key to user_token_info_t.token_id'; +COMMENT ON COLUMN nexent.user_token_usage_log_t.call_function_name IS 'API function name being called'; +COMMENT ON COLUMN nexent.user_token_usage_log_t.related_id IS 'Related resource ID (e.g., conversation_id)'; +COMMENT ON COLUMN nexent.user_token_usage_log_t.meta_data IS 'Additional metadata for this usage log entry, stored as JSON'; +COMMENT ON COLUMN nexent.user_token_usage_log_t.create_time IS 'Creation time, audit field'; +COMMENT ON COLUMN nexent.user_token_usage_log_t.update_time IS 'Update time, audit field'; +COMMENT ON COLUMN nexent.user_token_usage_log_t.created_by IS 'Creator ID, audit field'; +COMMENT ON COLUMN nexent.user_token_usage_log_t.updated_by IS 'Last updater ID, audit field'; +COMMENT ON COLUMN nexent.user_token_usage_log_t.delete_flag IS 'Soft delete flag, Y means deleted'; + +-- Create index on token_id for query performance +CREATE INDEX IF NOT EXISTS idx_user_token_usage_log_token_id ON nexent.user_token_usage_log_t(token_id); + +-- Create index on call_function_name for query performance +CREATE INDEX IF NOT EXISTS idx_user_token_usage_log_function_name ON nexent.user_token_usage_log_t(call_function_name); + +-- Add foreign key constraint +ALTER TABLE nexent.user_token_usage_log_t +ADD CONSTRAINT fk_user_token_usage_log_token_id +FOREIGN KEY (token_id) +REFERENCES nexent.user_token_info_t(token_id) +ON DELETE CASCADE; + + +-- Migration: Remove partner_mapping_id_t table for northbound conversation ID mapping +-- Date: 2026-03-10 +-- Description: Remove the external-internal conversation ID mapping table as northbound APIs now use internal conversation IDs directly +-- Note: This table is no longer needed after refactoring northbound authentication logic + +-- Drop the partner_mapping_id_t table if it exists +DROP TABLE IF EXISTS nexent.partner_mapping_id_t CASCADE; + +-- Drop the associated sequence if it exists +DROP SEQUENCE IF EXISTS nexent.partner_mapping_id_t_id_seq; diff --git a/frontend/app/[locale]/users/components/UserProfileComp.tsx b/frontend/app/[locale]/users/components/UserProfileComp.tsx index 6d45b4db0..2a66bd89e 100644 --- a/frontend/app/[locale]/users/components/UserProfileComp.tsx +++ b/frontend/app/[locale]/users/components/UserProfileComp.tsx @@ -1,6 +1,6 @@ "use client"; -import React, { useState } from "react"; +import React, { useState, useEffect } from "react"; import { Button, Typography, @@ -25,6 +25,9 @@ import { Edit, Key, ChevronRight, + KeySquare, + KeyRound, + Copy, } from "lucide-react"; import { USER_ROLES } from "@/const/modelConfig"; import { useAuthorizationContext } from "@/components/providers/AuthorizationProvider"; @@ -32,6 +35,12 @@ import { useAuthenticationContext } from "@/components/providers/AuthenticationP import { useGroupList } from "@/hooks/group/useGroupList"; import { useMemo } from "react"; import { DeleteAccountModal } from "@/components/auth/DeleteAccountModal"; +import log from "@/lib/logger"; +import { + getUserTokens, + deleteUserToken, + createUserToken, +} from "@/services/tokenService"; /** * UserProfileComp - User profile and account settings component @@ -77,6 +86,12 @@ export default function UserProfileComp() { const [isPasswordModalOpen, setIsPasswordModalOpen] = useState(false); const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false); + // AK/SK state + const [akInfo, setAkInfo] = useState(null); + const [existingTokenIds, setExistingTokenIds] = useState([]); + const [isLoadingAkSk, setIsLoadingAkSk] = useState(false); + const [isGeneratingAkSk, setIsGeneratingAkSk] = useState(false); + // Form instances const [editForm] = Form.useForm(); const [passwordForm] = Form.useForm(); @@ -121,6 +136,58 @@ export default function UserProfileComp() { } }; + // Fetch AK/SK info on mount + useEffect(() => { + const fetchAkSkInfo = async () => { + if (!user?.id) return; + setIsLoadingAkSk(true); + try { + const tokens = await getUserTokens(user.id); + if (tokens.length > 0) { + setAkInfo(tokens[0].access_key); + setExistingTokenIds(tokens.map((t) => t.token_id)); + } + } catch (error) { + log.error("Failed to fetch AK/SK info:", error); + } finally { + setIsLoadingAkSk(false); + } + }; + + fetchAkSkInfo(); + }, [user?.id]); + + // Handle generate AK/SK: delete existing tokens first, then create a new one + const handleGenerateAkSk = async () => { + setIsGeneratingAkSk(true); + try { + for (const tokenId of existingTokenIds) { + await deleteUserToken(tokenId); + } + + const newToken = await createUserToken(); + setAkInfo(newToken.access_key); + setExistingTokenIds([newToken.token_id]); + antdMessage.success(t("profile.generateAkSkSuccess") || "Access key generated successfully"); + } catch (error) { + antdMessage.error(t("profile.generateAkSkFailed") || "Failed to generate access key"); + } finally { + setIsGeneratingAkSk(false); + } + }; + + // Handle copy AK to clipboard + const handleCopyAk = async () => { + if (akInfo) { + try { + await navigator.clipboard.writeText(akInfo); + antdMessage.success(t("profile.copyAkSuccess") || "Access key copied to clipboard"); + } catch (error) { + antdMessage.error(t("profile.copyAkFailed") || "Failed to copy access key"); + } + } + }; + // Open edit modal // const openEditModal = () => { // editForm.setFieldsValue({ @@ -272,7 +339,7 @@ export default function UserProfileComp() { >
- +
@@ -286,6 +353,86 @@ export default function UserProfileComp() {
+ {/* Generate Access Token Option */} +
{ + if (akInfo) { + Modal.confirm({ + title: t("profile.generateAkSkConfirmTitle") || "Generate New Access Key", + content: t("profile.generateAkSkConfirmContent") || "You already have an access key. Generating a new one will overwrite the existing key. Continue?", + okText: t("common.confirm") || "Confirm", + cancelText: t("common.cancel") || "Cancel", + onOk: handleGenerateAkSk, + okButtonProps: { loading: isGeneratingAkSk }, + }); + } else { + handleGenerateAkSk(); + } + }} + > +
+
+ +
+
+
+ {t("profile.generateAkSk") || "Generate Access Token"} +
+ {akInfo ? ( +
+ + {akInfo} + +
+ ) : ( +
+ {t("profile.generateAkSkDesc") || "Create or regenerate your API access key"} +
+ )} +
+
+ +
+