diff --git a/mlflow_oidc_auth/config.py b/mlflow_oidc_auth/config.py index b33ff8c..b385778 100644 --- a/mlflow_oidc_auth/config.py +++ b/mlflow_oidc_auth/config.py @@ -58,4 +58,4 @@ def __init__(self): app.logger.error(f"Cache module for {self.CACHE_TYPE} could not be imported.") -config = AppConfig() +config = AppConfig() \ No newline at end of file diff --git a/mlflow_oidc_auth/views/authentication.py b/mlflow_oidc_auth/views/authentication.py index bd96eb3..9269789 100644 --- a/mlflow_oidc_auth/views/authentication.py +++ b/mlflow_oidc_auth/views/authentication.py @@ -3,7 +3,7 @@ from flask import redirect, session, url_for, render_template import mlflow_oidc_auth.utils as utils -from mlflow_oidc_auth.auth import get_oauth_instance +from mlflow_oidc_auth.auth import get_oauth_instance, validate_token from mlflow_oidc_auth.app import app from mlflow_oidc_auth.config import config from mlflow_oidc_auth.user import create_user, populate_groups, update_user @@ -48,7 +48,13 @@ def callback(): user_groups = importlib.import_module(config.OIDC_GROUP_DETECTION_PLUGIN).get_user_groups(token["access_token"]) else: - user_groups = token["userinfo"][config.OIDC_GROUPS_ATTRIBUTE] + group_attr = config.OIDC_GROUPS_ATTRIBUTE + user_info = token["userinfo"] + decoded_access_token = validate_token(token["access_token"]) + if group_attr in decoded_access_token: + user_groups = decoded_access_token[group_attr] + if group_attr in user_info: + user_groups = user_info[group_attr] app.logger.debug(f"User groups: {user_groups}")