From fcfbdabc5decc9465d9083ec8d69357b7c5dd21a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 8 Nov 2024 04:50:17 +0000 Subject: [PATCH 1/2] Added Authorisation for dataprep api's using tokenvalidator middleware Signed-off-by: Ubuntu --- comps/cores/mega/keycloak.py | 76 +++++++++++++++++++ comps/cores/mega/utils.py | 60 ++++++++++++++- comps/dataprep/redis/README.md | 9 +++ comps/dataprep/redis/langchain/config.py | 3 + .../redis/langchain/prepare_doc_redis.py | 11 ++- .../dataprep/redis/langchain/requirements.txt | 1 + requirements.txt | 1 + 7 files changed, 156 insertions(+), 5 deletions(-) create mode 100644 comps/cores/mega/keycloak.py diff --git a/comps/cores/mega/keycloak.py b/comps/cores/mega/keycloak.py new file mode 100644 index 0000000000..32397f591e --- /dev/null +++ b/comps/cores/mega/keycloak.py @@ -0,0 +1,76 @@ +import jwt +import requests +from jwt import ExpiredSignatureError, InvalidTokenError +from typing import Dict, Optional +import json +import os + +class Keycloak: + def __init__(self, realm_url: str = os.getenv("REALM_URL"), algorithm: str = "RS256"): + """ + Initializes the Keycloak JWT Interface with the realm URL and algorithm. + :param realm_url: Keycloak realm URL to fetch public keys for token verification. + :param algorithm: Algorithm used for the token, usually 'RS256' for Keycloak. + """ + self.realm_url = realm_url + self.algorithm = algorithm + self.public_keys = self.fetch_public_keys() + + def fetch_public_keys(self) -> Dict[str, str]: + """ + Fetches and returns Keycloak public keys for token verification. + :return: Dictionary mapping key IDs to their corresponding public keys. 
+ """ + try: + response = requests.get(f"{self.realm_url}/protocol/openid-connect/certs") + response.raise_for_status() + certs = response.json() + return { + key["kid"]: jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key)) + for key in certs["keys"] + } + except requests.RequestException as e: + return {} + + def decode_token(self, token: str) -> Optional[Dict]: + """ + Decodes a Keycloak JWT token and verifies its signature and expiration. + :param token: JWT token as a string. + :return: Decoded payload as a dictionary if valid, None otherwise. + """ + try: + unverified_header = jwt.get_unverified_header(token) + key = self.public_keys.get(unverified_header.get("kid")) + if not key: + print("Invalid token header: key ID not found.") + return None + decoded = jwt.decode(token, key=key, algorithms=[self.algorithm]) + return decoded + except ExpiredSignatureError: + raise ExpiredSignatureError + except InvalidTokenError: + raise InvalidTokenError + + def verify_token(self, token: str) -> bool: + """ + Verifies if the token is valid and not expired. + :param token: JWT token as a string. + :return: True if valid, False otherwise. + """ + decoded = self.decode_token(token) + return decoded is not None + + def get_user_info(self, token: str) -> Optional[Dict]: + """ + Extracts user information from the JWT token payload. + :param token: JWT token as a string. + :return: Dictionary of user information if available, None otherwise. 
+ """ + decoded = self.decode_token(token) + if decoded: + return { + "username": decoded.get("preferred_username"), + "email": decoded.get("email"), + "roles": decoded.get("realm_access", {}).get("roles", []) + } + return None diff --git a/comps/cores/mega/utils.py b/comps/cores/mega/utils.py index db23f023ad..7ab733a7af 100644 --- a/comps/cores/mega/utils.py +++ b/comps/cores/mega/utils.py @@ -7,6 +7,11 @@ import random from socket import AF_INET, SOCK_STREAM, socket from typing import List, Optional, Union +from fastapi import Depends, HTTPException, Request, status +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from typing import Optional +from .keycloak import Keycloak +import jwt import requests @@ -204,7 +209,60 @@ def get_access_token(token_url: str, client_id: str, client_secret: str) -> str: logger.error(f"Failed to retrieve access token: {response.status_code}, {response.text}") return "" - +bearer_scheme = HTTPBearer(auto_error=False) +def token_validator(allowed_roles: Optional[List[str]] = None): + async def validate_token(request: Request, credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme)): + """ + Validates the token, checks for allowed roles, and sets user details in request.state.user if valid. + Raises HTTPException with appropriate status code and message if validation fails. 
+ """ + # If token is not provided, skip validation + JWT_AUTH = os.getenv("JWT_AUTH", False) + if not JWT_AUTH: + request.state.user = None + return + if credentials is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication scheme or Missing Token", + ) + if credentials.scheme != "Bearer": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication scheme", + ) + try: + token = credentials.credentials + identity_provider = Keycloak() + decoded_token = identity_provider.decode_token(token) + + if not decoded_token: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED , detail="Invalid token.") + + # Extract roles from the token + user_roles = decoded_token.get("realm_access", {}).get("roles", []) + + # Check if user has any of the allowed roles + if allowed_roles and not any(role in user_roles for role in allowed_roles): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN , detail="User does not have required permissions.") + + # Set user details in request.state.user + request.state.user = { + "username": decoded_token.get("preferred_username"), + "email": decoded_token.get("email"), + "roles": user_roles + } + + except jwt.ExpiredSignatureError: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired.") + + except jwt.InvalidTokenError: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token signature.") + except HTTPException as e: + raise e + except Exception as e: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Token validation error: {str(e)}") + return validate_token class SafeContextManager: """This context manager ensures that the `__exit__` method of the sub context is called, even when there is an Exception in the diff --git a/comps/dataprep/redis/README.md b/comps/dataprep/redis/README.md index 384d9018f4..f983ef99b5 100644 --- 
a/comps/dataprep/redis/README.md +++ b/comps/dataprep/redis/README.md @@ -101,6 +101,15 @@ export REDIS_URL="redis://${your_ip}:6379" export INDEX_NAME=${your_index_name} export HUGGINGFACEHUB_API_TOKEN=${your_hf_api_token} ``` +if Authorization is needed with keycloak +```bash +realm_name=productivitysuite +export JWT_AUTH=True +export REALM_URL="http://${your_ip}/realms/$realm_name" +export ADMIN_ROLE="admin" +export USER_ROLE="user" +``` +If JWT_AUTH is enabled make sure to follow [keycloak setup guide](https://github.com/opea-project/GenAIExamples/blob/main/ProductivitySuite/docker_compose/intel/cpu/xeon/keycloak_setup_guide.md) ### 2.3 Build Docker Image diff --git a/comps/dataprep/redis/langchain/config.py b/comps/dataprep/redis/langchain/config.py index 2d722a84a6..5e604ff0e6 100644 --- a/comps/dataprep/redis/langchain/config.py +++ b/comps/dataprep/redis/langchain/config.py @@ -64,3 +64,6 @@ def format_redis_conn_from_env(): TIMEOUT_SECONDS = int(os.getenv("TIMEOUT_SECONDS", 600)) SEARCH_BATCH_SIZE = int(os.getenv("SEARCH_BATCH_SIZE", 10)) + +ADMIN_ROLE=os.getenv("ADMIN_ROLE_KEY", "admin") +USER_ROLE=os.getenv("USER_ROLE_KEY", "user") diff --git a/comps/dataprep/redis/langchain/prepare_doc_redis.py b/comps/dataprep/redis/langchain/prepare_doc_redis.py index 4aa6bef483..a77d3cca58 100644 --- a/comps/dataprep/redis/langchain/prepare_doc_redis.py +++ b/comps/dataprep/redis/langchain/prepare_doc_redis.py @@ -8,8 +8,8 @@ # from pyspark import SparkConf, SparkContext import redis -from config import EMBED_MODEL, INDEX_NAME, KEY_INDEX_NAME, REDIS_URL, SEARCH_BATCH_SIZE -from fastapi import Body, File, Form, HTTPException, UploadFile +from config import EMBED_MODEL, INDEX_NAME, KEY_INDEX_NAME, REDIS_URL, SEARCH_BATCH_SIZE, ADMIN_ROLE, USER_ROLE +from fastapi import Body, File, Form, HTTPException, UploadFile, Depends from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.embeddings import HuggingFaceBgeEmbeddings from 
langchain_community.vectorstores import Redis @@ -18,6 +18,8 @@ from redis.commands.search.field import TextField from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from comps.cores.mega.utils import token_validator + from comps import CustomLogger, DocPath, opea_microservices, register_microservice from comps.dataprep.utils import ( create_upload_folder, @@ -223,6 +225,7 @@ async def ingest_documents( chunk_overlap: int = Form(100), process_table: bool = Form(False), table_strategy: str = Form("fast"), + _: Optional[str] = Depends(token_validator([ADMIN_ROLE])) ): if logflag: logger.info(f"[ upload ] files:{files}") @@ -341,7 +344,7 @@ async def ingest_documents( @register_microservice( name="opea_service@prepare_doc_redis", endpoint="/v1/dataprep/get_file", host="0.0.0.0", port=6007 ) -async def rag_get_file_structure(): +async def rag_get_file_structure(_: Optional[str] = Depends(token_validator([USER_ROLE,ADMIN_ROLE]))): if logflag: logger.info("[ get ] start to get file structure") @@ -375,7 +378,7 @@ async def rag_get_file_structure(): @register_microservice( name="opea_service@prepare_doc_redis", endpoint="/v1/dataprep/delete_file", host="0.0.0.0", port=6007 ) -async def delete_single_file(file_path: str = Body(..., embed=True)): +async def delete_single_file(file_path: str = Body(..., embed=True), _: Optional[str] = Depends(token_validator([ADMIN_ROLE]))): """Delete file according to `file_path`. 
`file_path`: diff --git a/comps/dataprep/redis/langchain/requirements.txt b/comps/dataprep/redis/langchain/requirements.txt index 8c3b116fa8..0708ca1aef 100644 --- a/comps/dataprep/redis/langchain/requirements.txt +++ b/comps/dataprep/redis/langchain/requirements.txt @@ -23,6 +23,7 @@ pytesseract python-bidi python-docx python-pptx +pyjwt redis sentence_transformers shortuuid diff --git a/requirements.txt b/requirements.txt index 9c4b2d7706..28f9951f2d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ docarray docx2txt fastapi httpx +pyjwt kubernetes langchain langchain-community From 0bef29bc4694c9500c03c7d29b2479375c85e0e5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 Nov 2024 04:55:35 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- comps/cores/mega/keycloak.py | 38 +++++++------- comps/cores/mega/utils.py | 49 ++++++++++++------- comps/dataprep/redis/README.md | 7 ++- comps/dataprep/redis/langchain/config.py | 4 +- .../redis/langchain/prepare_doc_redis.py | 15 +++--- .../dataprep/redis/langchain/requirements.txt | 2 +- requirements.txt | 2 +- 7 files changed, 67 insertions(+), 50 deletions(-) diff --git a/comps/cores/mega/keycloak.py b/comps/cores/mega/keycloak.py index 32397f591e..2d01aad36f 100644 --- a/comps/cores/mega/keycloak.py +++ b/comps/cores/mega/keycloak.py @@ -1,14 +1,19 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import json +import os +from typing import Dict, Optional + import jwt import requests from jwt import ExpiredSignatureError, InvalidTokenError -from typing import Dict, Optional -import json -import os + class Keycloak: def __init__(self, realm_url: str = os.getenv("REALM_URL"), algorithm: str = "RS256"): - """ - Initializes the Keycloak JWT Interface with the realm URL and algorithm. 
+ """Initializes the Keycloak JWT Interface with the realm URL and algorithm. + :param realm_url: Keycloak realm URL to fetch public keys for token verification. :param algorithm: Algorithm used for the token, usually 'RS256' for Keycloak. """ @@ -17,24 +22,21 @@ def __init__(self, realm_url: str = os.getenv("REALM_URL"), algorithm: str = "RS self.public_keys = self.fetch_public_keys() def fetch_public_keys(self) -> Dict[str, str]: - """ - Fetches and returns Keycloak public keys for token verification. + """Fetches and returns Keycloak public keys for token verification. + :return: Dictionary mapping key IDs to their corresponding public keys. """ try: response = requests.get(f"{self.realm_url}/protocol/openid-connect/certs") response.raise_for_status() certs = response.json() - return { - key["kid"]: jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key)) - for key in certs["keys"] - } + return {key["kid"]: jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key)) for key in certs["keys"]} except requests.RequestException as e: return {} def decode_token(self, token: str) -> Optional[Dict]: - """ - Decodes a Keycloak JWT token and verifies its signature and expiration. + """Decodes a Keycloak JWT token and verifies its signature and expiration. + :param token: JWT token as a string. :return: Decoded payload as a dictionary if valid, None otherwise. """ @@ -52,8 +54,8 @@ def decode_token(self, token: str) -> Optional[Dict]: raise InvalidTokenError def verify_token(self, token: str) -> bool: - """ - Verifies if the token is valid and not expired. + """Verifies if the token is valid and not expired. + :param token: JWT token as a string. :return: True if valid, False otherwise. """ @@ -61,8 +63,8 @@ def verify_token(self, token: str) -> bool: return decoded is not None def get_user_info(self, token: str) -> Optional[Dict]: - """ - Extracts user information from the JWT token payload. + """Extracts user information from the JWT token payload. 
+ :param token: JWT token as a string. :return: Dictionary of user information if available, None otherwise. """ @@ -71,6 +73,6 @@ def get_user_info(self, token: str) -> Optional[Dict]: return { "username": decoded.get("preferred_username"), "email": decoded.get("email"), - "roles": decoded.get("realm_access", {}).get("roles", []) + "roles": decoded.get("realm_access", {}).get("roles", []), } return None diff --git a/comps/cores/mega/utils.py b/comps/cores/mega/utils.py index 7ab733a7af..8a3159bddc 100644 --- a/comps/cores/mega/utils.py +++ b/comps/cores/mega/utils.py @@ -7,14 +7,13 @@ import random from socket import AF_INET, SOCK_STREAM, socket from typing import List, Optional, Union -from fastapi import Depends, HTTPException, Request, status -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials -from typing import Optional -from .keycloak import Keycloak -import jwt +import jwt import requests +from fastapi import Depends, HTTPException, Request, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from .keycloak import Keycloak from .logger import CustomLogger @@ -209,16 +208,21 @@ def get_access_token(token_url: str, client_id: str, client_secret: str) -> str: logger.error(f"Failed to retrieve access token: {response.status_code}, {response.text}") return "" + bearer_scheme = HTTPBearer(auto_error=False) + + def token_validator(allowed_roles: Optional[List[str]] = None): - async def validate_token(request: Request, credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme)): - """ - Validates the token, checks for allowed roles, and sets user details in request.state.user if valid. + async def validate_token( + request: Request, credentials: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme) + ): + """Validates the token, checks for allowed roles, and sets user details in request.state.user if valid. + Raises HTTPException with appropriate status code and message if validation fails. 
""" # If token is not provided, skip validation JWT_AUTH = os.getenv("JWT_AUTH", False) - if not JWT_AUTH: + if not JWT_AUTH: request.state.user = None return if credentials is None: @@ -235,34 +239,41 @@ async def validate_token(request: Request, credentials: Optional[HTTPAuthorizati token = credentials.credentials identity_provider = Keycloak() decoded_token = identity_provider.decode_token(token) - + if not decoded_token: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED , detail="Invalid token.") - + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token.") + # Extract roles from the token user_roles = decoded_token.get("realm_access", {}).get("roles", []) # Check if user has any of the allowed roles if allowed_roles and not any(role in user_roles for role in allowed_roles): - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN , detail="User does not have required permissions.") - + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="User does not have required permissions." 
+ ) + # Set user details in request.state.user request.state.user = { "username": decoded_token.get("preferred_username"), "email": decoded_token.get("email"), - "roles": user_roles + "roles": user_roles, } - + except jwt.ExpiredSignatureError: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired.") - + except jwt.InvalidTokenError: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token signature.") except HTTPException as e: - raise e + raise e except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Token validation error: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Token validation error: {str(e)}" + ) + return validate_token + + class SafeContextManager: """This context manager ensures that the `__exit__` method of the sub context is called, even when there is an Exception in the diff --git a/comps/dataprep/redis/README.md b/comps/dataprep/redis/README.md index f983ef99b5..acf22428e0 100644 --- a/comps/dataprep/redis/README.md +++ b/comps/dataprep/redis/README.md @@ -101,7 +101,9 @@ export REDIS_URL="redis://${your_ip}:6379" export INDEX_NAME=${your_index_name} export HUGGINGFACEHUB_API_TOKEN=${your_hf_api_token} ``` -if Authorization is needed with keycloak + +If authorization with Keycloak is needed, also export the following: + ```bash realm_name=productivitysuite export JWT_AUTH=True @@ -109,7 +111,8 @@ export REALM_URL="http://${your_ip}/realms/$realm_name" export ADMIN_ROLE="admin" export USER_ROLE="user" ``` -If JWT_AUTH is enabled make sure to follow [keycloak setup guide](https://github.com/opea-project/GenAIExamples/blob/main/ProductivitySuite/docker_compose/intel/cpu/xeon/keycloak_setup_guide.md) + +If JWT_AUTH is enabled, make sure to follow the [Keycloak setup guide](https://github.com/opea-project/GenAIExamples/blob/main/ProductivitySuite/docker_compose/intel/cpu/xeon/keycloak_setup_guide.md). Note that `config.py` reads the role names from the `ADMIN_ROLE_KEY` and `USER_ROLE_KEY` environment variables (defaulting to `admin` and `user`), so export those names to override the defaults. ### 2.3 Build Docker Image 
diff --git a/comps/dataprep/redis/langchain/config.py b/comps/dataprep/redis/langchain/config.py index 5e604ff0e6..3077d21eba 100644 --- a/comps/dataprep/redis/langchain/config.py +++ b/comps/dataprep/redis/langchain/config.py @@ -65,5 +65,5 @@ def format_redis_conn_from_env(): SEARCH_BATCH_SIZE = int(os.getenv("SEARCH_BATCH_SIZE", 10)) -ADMIN_ROLE=os.getenv("ADMIN_ROLE_KEY", "admin") -USER_ROLE=os.getenv("USER_ROLE_KEY", "user") +ADMIN_ROLE = os.getenv("ADMIN_ROLE_KEY", "admin") +USER_ROLE = os.getenv("USER_ROLE_KEY", "user") diff --git a/comps/dataprep/redis/langchain/prepare_doc_redis.py b/comps/dataprep/redis/langchain/prepare_doc_redis.py index a77d3cca58..99fcbd3624 100644 --- a/comps/dataprep/redis/langchain/prepare_doc_redis.py +++ b/comps/dataprep/redis/langchain/prepare_doc_redis.py @@ -8,8 +8,8 @@ # from pyspark import SparkConf, SparkContext import redis -from config import EMBED_MODEL, INDEX_NAME, KEY_INDEX_NAME, REDIS_URL, SEARCH_BATCH_SIZE, ADMIN_ROLE, USER_ROLE -from fastapi import Body, File, Form, HTTPException, UploadFile, Depends +from config import ADMIN_ROLE, EMBED_MODEL, INDEX_NAME, KEY_INDEX_NAME, REDIS_URL, SEARCH_BATCH_SIZE, USER_ROLE +from fastapi import Body, Depends, File, Form, HTTPException, UploadFile from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.embeddings import HuggingFaceBgeEmbeddings from langchain_community.vectorstores import Redis @@ -18,9 +18,8 @@ from redis.commands.search.field import TextField from redis.commands.search.indexDefinition import IndexDefinition, IndexType -from comps.cores.mega.utils import token_validator - from comps import CustomLogger, DocPath, opea_microservices, register_microservice +from comps.cores.mega.utils import token_validator from comps.dataprep.utils import ( create_upload_folder, document_loader, @@ -225,7 +224,7 @@ async def ingest_documents( chunk_overlap: int = Form(100), process_table: bool = Form(False), table_strategy: str = Form("fast"), - 
_: Optional[str] = Depends(token_validator([ADMIN_ROLE])) + _: Optional[str] = Depends(token_validator([ADMIN_ROLE])), ): if logflag: logger.info(f"[ upload ] files:{files}") @@ -344,7 +343,7 @@ async def ingest_documents( @register_microservice( name="opea_service@prepare_doc_redis", endpoint="/v1/dataprep/get_file", host="0.0.0.0", port=6007 ) -async def rag_get_file_structure(_: Optional[str] = Depends(token_validator([USER_ROLE,ADMIN_ROLE]))): +async def rag_get_file_structure(_: Optional[str] = Depends(token_validator([USER_ROLE, ADMIN_ROLE]))): if logflag: logger.info("[ get ] start to get file structure") @@ -378,7 +377,9 @@ async def rag_get_file_structure(_: Optional[str] = Depends(token_validator([USE @register_microservice( name="opea_service@prepare_doc_redis", endpoint="/v1/dataprep/delete_file", host="0.0.0.0", port=6007 ) -async def delete_single_file(file_path: str = Body(..., embed=True), _: Optional[str] = Depends(token_validator([ADMIN_ROLE]))): +async def delete_single_file( + file_path: str = Body(..., embed=True), _: Optional[str] = Depends(token_validator([ADMIN_ROLE])) +): """Delete file according to `file_path`. `file_path`: diff --git a/comps/dataprep/redis/langchain/requirements.txt b/comps/dataprep/redis/langchain/requirements.txt index 0708ca1aef..8bde8172d7 100644 --- a/comps/dataprep/redis/langchain/requirements.txt +++ b/comps/dataprep/redis/langchain/requirements.txt @@ -17,13 +17,13 @@ opentelemetry-sdk pandas Pillow prometheus-fastapi-instrumentator +pyjwt pymupdf pyspark pytesseract python-bidi python-docx python-pptx -pyjwt redis sentence_transformers shortuuid diff --git a/requirements.txt b/requirements.txt index 28f9951f2d..3a1c1b08f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,6 @@ docarray docx2txt fastapi httpx -pyjwt kubernetes langchain langchain-community @@ -13,6 +12,7 @@ opentelemetry-exporter-otlp opentelemetry-sdk Pillow prometheus-fastapi-instrumentator +pyjwt pypdf python-multipart pyyaml