diff --git a/comps/cores/mega/keycloak.py b/comps/cores/mega/keycloak.py
new file mode 100644
index 0000000000..2d01aad36f
--- /dev/null
+++ b/comps/cores/mega/keycloak.py
@@ -0,0 +1,99 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import logging
+import os
+from typing import Any, Dict, Optional
+
+import jwt
+import requests
+from jwt import ExpiredSignatureError, InvalidTokenError
+
+logger = logging.getLogger(__name__)
+
+
+class Keycloak:
+    """Verifies Keycloak-issued JWTs against the public keys published by a realm."""
+
+    def __init__(self, realm_url: Optional[str] = None, algorithm: str = "RS256", timeout: float = 10.0):
+        """Initializes the Keycloak JWT Interface with the realm URL and algorithm.
+
+        :param realm_url: Keycloak realm URL to fetch public keys for token verification.
+            Falls back to the ``REALM_URL`` environment variable when omitted.
+        :param algorithm: Algorithm used for the token, usually 'RS256' for Keycloak.
+        :param timeout: Seconds to wait for the realm's certificate endpoint.
+        """
+        # Read the environment at construction time; a default argument would be
+        # evaluated once at import and ignore a REALM_URL set afterwards.
+        self.realm_url = realm_url if realm_url is not None else os.getenv("REALM_URL")
+        self.algorithm = algorithm
+        self.timeout = timeout
+        self.public_keys = self.fetch_public_keys()
+
+    def fetch_public_keys(self) -> Dict[str, Any]:
+        """Fetches and returns Keycloak public keys for token verification.
+
+        :return: Dictionary mapping key IDs to their corresponding public keys, or an
+            empty dictionary if the certificate endpoint could not be queried.
+        """
+        try:
+            response = requests.get(f"{self.realm_url}/protocol/openid-connect/certs", timeout=self.timeout)
+            response.raise_for_status()
+            certs = response.json()
+            return {key["kid"]: jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key)) for key in certs["keys"]}
+        except requests.RequestException as e:
+            # Best effort: keep the service up, but record why no keys are available.
+            logger.error("Failed to fetch Keycloak public keys from %s: %s", self.realm_url, e)
+            return {}
+
+    def decode_token(self, token: str) -> Optional[Dict]:
+        """Decodes a Keycloak JWT token and verifies its signature and expiration.
+
+        :param token: JWT token as a string.
+        :return: Decoded payload as a dictionary if valid, None if the key ID is unknown.
+        :raises ExpiredSignatureError: If the token has expired.
+        :raises InvalidTokenError: If the token is malformed or fails verification.
+        """
+        kid = jwt.get_unverified_header(token).get("kid")
+        key = self.public_keys.get(kid)
+        if not key:
+            # Keys may have been rotated, or the initial fetch may have failed, since
+            # this instance was created: refresh once before rejecting the token.
+            refreshed_keys = self.fetch_public_keys()
+            if refreshed_keys:
+                self.public_keys = refreshed_keys
+            key = self.public_keys.get(kid)
+        if not key:
+            logger.warning("Invalid token header: key ID not found.")
+            return None
+        # PyJWT errors propagate unchanged so callers keep the message and traceback.
+        return jwt.decode(token, key=key, algorithms=[self.algorithm])
+
+    def verify_token(self, token: str) -> bool:
+        """Verifies if the token is valid and not expired.
+
+        :param token: JWT token as a string.
+        :return: True if valid, False otherwise.
+        """
+        try:
+            return self.decode_token(token) is not None
+        except (ExpiredSignatureError, InvalidTokenError):
+            return False
+
+    def get_user_info(self, token: str) -> Optional[Dict]:
+        """Extracts user information from the JWT token payload.
+
+        :param token: JWT token as a string.
+        :return: Dictionary of user information if available, None otherwise.
+        :raises InvalidTokenError: If the token is expired, malformed or fails verification.
+        """
+        decoded = self.decode_token(token)
+        if not decoded:
+            return None
+        return {
+            "username": decoded.get("preferred_username"),
+            "email": decoded.get("email"),
+            # "realm_access" may be absent or null; either way report no roles.
+            "roles": (decoded.get("realm_access") or {}).get("roles", []),
+        }
diff --git a/comps/cores/mega/utils.py b/comps/cores/mega/utils.py
index e5b2df4f5f..60462d89db 100644
--- a/comps/cores/mega/utils.py
+++ b/comps/cores/mega/utils.py
@@ -9,8 +9,12 @@
 from socket import AF_INET, SOCK_STREAM, socket
 from typing import List, Optional, Union
 
+import jwt
 import requests
+from fastapi import Depends, HTTPException, Request, status
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 
+from .keycloak import Keycloak
 from .logger import CustomLogger
 
 
@@ -244,6 +248,98 @@ def get_access_token(token_url: str, client_id: str, client_secret: str) -> str:
         return ""
 
 
+bearer_scheme = HTTPBearer(auto_error=False)
+
+# Values of JWT_AUTH that leave authentication switched off.
+_JWT_AUTH_DISABLED_VALUES = {"", "0", "false", "no", "off"}
+
+# Shared identity provider, created on first use so the realm's public keys are not
+# re-downloaded for every request.
+_identity_provider: Optional[Keycloak] = None
+
+
+def _jwt_auth_enabled() -> bool:
+    """Returns True when the JWT_AUTH environment variable requests authentication.
+
+    Environment values are strings, so "False" or "0" would be truthy if tested directly.
+    """
+    return os.getenv("JWT_AUTH", "").strip().lower() not in _JWT_AUTH_DISABLED_VALUES
+
+
+def _get_identity_provider() -> Keycloak:
+    """Returns the shared Keycloak client, creating it on first use."""
+    global _identity_provider
+    if _identity_provider is None:
+        _identity_provider = Keycloak()
+    return _identity_provider
+
+
+def token_validator(allowed_roles: Optional[List[str]] = None):
+    """Builds a FastAPI dependency that authenticates and authorizes a request.
+
+    :param allowed_roles: Realm roles permitted to call the endpoint. When empty or None,
+        any valid token is accepted.
+    :return: Coroutine function suitable for use with ``Depends``.
+    """
+
+    async def validate_token(
+        request: Request, credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme)
+    ):
+        """Validates the token, checks for allowed roles, and sets user details in request.state.user if valid.
+
+        Raises HTTPException with appropriate status code and message if validation fails.
+        """
+        # If JWT authentication is not enabled, skip validation
+        if not _jwt_auth_enabled():
+            request.state.user = None
+            return
+        if credentials is None:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail="Invalid authentication scheme or Missing Token",
+            )
+        # Authentication scheme names are case-insensitive (RFC 7235).
+        if credentials.scheme.lower() != "bearer":
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail="Invalid authentication scheme",
+            )
+        try:
+            decoded_token = _get_identity_provider().decode_token(credentials.credentials)
+
+            if not decoded_token:
+                raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token.")
+
+            # Extract roles from the token; "realm_access" may be absent or null
+            user_roles = (decoded_token.get("realm_access") or {}).get("roles", [])
+
+            # Check if user has any of the allowed roles
+            if allowed_roles and not any(role in user_roles for role in allowed_roles):
+                raise HTTPException(
+                    status_code=status.HTTP_403_FORBIDDEN, detail="User does not have required permissions."
+                )
+
+            # Set user details in request.state.user
+            request.state.user = {
+                "username": decoded_token.get("preferred_username"),
+                "email": decoded_token.get("email"),
+                "roles": user_roles,
+            }
+
+        except jwt.ExpiredSignatureError:
+            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired.")
+        except jwt.InvalidTokenError:
+            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token signature.")
+        except HTTPException:
+            raise
+        except Exception as e:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Token validation error: {str(e)}"
+            ) from e
+
+    return validate_token
+
+
 class SafeContextManager:
     """This context manager ensures that the `__exit__` method of the
     sub context is called, even when there is an Exception in the
diff --git a/comps/dataprep/redis/README.md b/comps/dataprep/redis/README.md
index 384d9018f4..acf22428e0 100644
--- a/comps/dataprep/redis/README.md
+++ b/comps/dataprep/redis/README.md
@@ -102,6 +102,18 @@ export INDEX_NAME=${your_index_name}
 export HUGGINGFACEHUB_API_TOKEN=${your_hf_api_token}
 ```
 
+If authorization with Keycloak is needed:
+
+```bash
+realm_name=productivitysuite
+export JWT_AUTH=True
+export REALM_URL="http://${your_ip}/realms/$realm_name"
+export ADMIN_ROLE="admin"
+export USER_ROLE="user"
+```
+
+If JWT_AUTH is enabled, make sure to follow the [Keycloak setup guide](https://github.com/opea-project/GenAIExamples/blob/main/ProductivitySuite/docker_compose/intel/cpu/xeon/keycloak_setup_guide.md).
+
 ### 2.3 Build Docker Image
 
 - Build docker image with langchain
diff --git a/comps/dataprep/redis/langchain/config.py b/comps/dataprep/redis/langchain/config.py
index 2d722a84a6..3077d21eba 100644
--- a/comps/dataprep/redis/langchain/config.py
+++ b/comps/dataprep/redis/langchain/config.py
@@ -64,3 +64,8 @@ def format_redis_conn_from_env():
 TIMEOUT_SECONDS = int(os.getenv("TIMEOUT_SECONDS", 600))
 
 SEARCH_BATCH_SIZE = int(os.getenv("SEARCH_BATCH_SIZE", 10))
+
+# Role names checked by the token validator. ADMIN_ROLE / USER_ROLE are the variables the
+# README exports; the *_KEY names are still honoured as fallbacks.
+ADMIN_ROLE = os.getenv("ADMIN_ROLE", os.getenv("ADMIN_ROLE_KEY", "admin"))
+USER_ROLE = os.getenv("USER_ROLE", os.getenv("USER_ROLE_KEY", "user"))
diff --git a/comps/dataprep/redis/langchain/prepare_doc_redis.py b/comps/dataprep/redis/langchain/prepare_doc_redis.py
index 6902117dcd..9bb0a638f2 100644
--- a/comps/dataprep/redis/langchain/prepare_doc_redis.py
+++ b/comps/dataprep/redis/langchain/prepare_doc_redis.py
@@ -8,8 +8,8 @@
 
 # from pyspark import SparkConf, SparkContext
 import redis
-from config import EMBED_MODEL, INDEX_NAME, KEY_INDEX_NAME, REDIS_URL, SEARCH_BATCH_SIZE
-from fastapi import Body, File, Form, HTTPException, UploadFile
+from config import ADMIN_ROLE, EMBED_MODEL, INDEX_NAME, KEY_INDEX_NAME, REDIS_URL, SEARCH_BATCH_SIZE, USER_ROLE
+from fastapi import Body, Depends, File, Form, HTTPException, UploadFile
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.embeddings import HuggingFaceBgeEmbeddings
 from langchain_community.vectorstores import Redis
@@ -19,6 +19,7 @@
 from redis.commands.search.indexDefinition import IndexDefinition, IndexType
 
 from comps import CustomLogger, DocPath, opea_microservices, register_microservice
+from comps.cores.mega.utils import token_validator
 from comps.dataprep.utils import (
     create_upload_folder,
     document_loader,
@@ -223,6 +224,7 @@ async def ingest_documents(
     chunk_overlap: int = Form(100),
     process_table: bool = Form(False),
     table_strategy: str = Form("fast"),
+    _: Optional[str] = Depends(token_validator([ADMIN_ROLE])),
 ):
     if logflag:
         logger.info(f"[ upload ] files:{files}")
@@ -341,7 +343,7 @@ async def ingest_documents(
 @register_microservice(
     name="opea_service@prepare_doc_redis", endpoint="/v1/dataprep/get_file", host="0.0.0.0", port=6007
 )
-async def rag_get_file_structure():
+async def rag_get_file_structure(_: Optional[str] = Depends(token_validator([USER_ROLE, ADMIN_ROLE]))):
     if logflag:
         logger.info("[ get ] start to get file structure")
 
@@ -375,7 +377,9 @@ async def rag_get_file_structure():
 @register_microservice(
     name="opea_service@prepare_doc_redis", endpoint="/v1/dataprep/delete_file", host="0.0.0.0", port=6007
 )
-async def delete_single_file(file_path: str = Body(..., embed=True)):
+async def delete_single_file(
+    file_path: str = Body(..., embed=True), _: Optional[str] = Depends(token_validator([ADMIN_ROLE]))
+):
     """Delete file according to `file_path`.
 
     `file_path`:
diff --git a/comps/dataprep/redis/langchain/requirements.txt b/comps/dataprep/redis/langchain/requirements.txt
index 8c3b116fa8..8bde8172d7 100644
--- a/comps/dataprep/redis/langchain/requirements.txt
+++ b/comps/dataprep/redis/langchain/requirements.txt
@@ -17,6 +17,7 @@ opentelemetry-sdk
 pandas
 Pillow
 prometheus-fastapi-instrumentator
+pyjwt[crypto]
 pymupdf
 pyspark
 pytesseract
diff --git a/requirements.txt b/requirements.txt
index 9c4b2d7706..3a1c1b08f0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,6 +12,7 @@ opentelemetry-exporter-otlp
 opentelemetry-sdk
 Pillow
 prometheus-fastapi-instrumentator
+pyjwt[crypto]
 pypdf
 python-multipart
 pyyaml