Skip to content

Commit

Permalink
feat: permissions management (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
kharkevich authored Dec 17, 2024
1 parent f840eac commit 2144744
Show file tree
Hide file tree
Showing 23 changed files with 5,883 additions and 5,753 deletions.
2 changes: 0 additions & 2 deletions mlflow_oidc_auth/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ def to_json(self):
"username": self.username,
"is_admin": self.is_admin,
"display_name": self.display_name,
"experiment_permissions": [p.to_json() for p in self.experiment_permissions],
"registered_model_permissions": [p.to_json() for p in self.registered_model_permissions],
"groups": [g.to_json() for g in self.groups],
}

Expand Down
42 changes: 39 additions & 3 deletions mlflow_oidc_auth/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from typing import Callable, NamedTuple

from flask import request, session
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import BAD_REQUEST, INVALID_PARAMETER_VALUE
from mlflow.exceptions import ErrorCode, MlflowException
from mlflow.protos.databricks_pb2 import BAD_REQUEST, INVALID_PARAMETER_VALUE, RESOURCE_DOES_NOT_EXIST
from mlflow.server import app
from mlflow.server.handlers import _get_tracking_store

from mlflow_oidc_auth.store import store
from mlflow_oidc_auth.app import app
from mlflow_oidc_auth.auth import validate_token
from mlflow_oidc_auth.config import config
from mlflow_oidc_auth.permissions import Permission, get_permission
from mlflow_oidc_auth.store import store


def get_request_param(param: str) -> str:
Expand Down Expand Up @@ -69,3 +74,34 @@ def get_experiment_id() -> str:
"Either 'experiment_id' or 'experiment_name' must be provided in the request data.",
INVALID_PARAMETER_VALUE,
)


class PermissionResult(NamedTuple):
permission: Permission
type: str


def get_permission_from_store_or_default(
store_permission_user_func: Callable[[], str], store_permission_group_func: Callable[[], str]
) -> PermissionResult:
"""
Attempts to get permission from store,
and returns default permission if no record is found.
user permission takes precedence over group permission
"""
try:
perm = store_permission_user_func()
app.logger.debug("User permission found")
perm_type = "user"
except MlflowException as e:
if e.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST):
try:
perm = store_permission_group_func()
app.logger.debug("Group permission found")
perm_type = "group"
except MlflowException as e:
if e.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST):
perm = config.DEFAULT_MLFLOW_PERMISSION
app.logger.debug("Default permission used")
perm_type = "fallback"
return PermissionResult(get_permission(perm), perm_type)
33 changes: 0 additions & 33 deletions mlflow_oidc_auth/validators/_permissions.py

This file was deleted.

10 changes: 4 additions & 6 deletions mlflow_oidc_auth/validators/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
from mlflow_oidc_auth.app import config
from mlflow_oidc_auth.permissions import Permission, get_permission
from mlflow_oidc_auth.store import store
from mlflow_oidc_auth.utils import get_experiment_id, get_request_param, get_username

from ._permissions import get_permission_from_store_or_default
from mlflow_oidc_auth.utils import get_experiment_id, get_permission_from_store_or_default, get_request_param, get_username


def _get_permission_from_experiment_id() -> Permission:
Expand All @@ -19,7 +17,7 @@ def _get_permission_from_experiment_id() -> Permission:
return get_permission_from_store_or_default(
lambda: store.get_experiment_permission(experiment_id, username).permission,
lambda: store.get_user_groups_experiment_permission(experiment_id, username).permission,
)
).permission


def _get_permission_from_experiment_name() -> Permission:
Expand All @@ -34,7 +32,7 @@ def _get_permission_from_experiment_name() -> Permission:
return get_permission_from_store_or_default(
lambda: store.get_experiment_permission(store_exp.experiment_id, username).permission,
lambda: store.get_user_groups_experiment_permission(store_exp.experiment_id, username).permission,
)
).permission


_EXPERIMENT_ID_PATTERN = re.compile(r"^(\d+)/")
Expand All @@ -54,7 +52,7 @@ def _get_permission_from_experiment_id_artifact_proxy() -> Permission:
return get_permission_from_store_or_default(
lambda: store.get_experiment_permission(experiment_id, username).permission,
lambda: store.get_user_groups_experiment_permission(experiment_id, username).permission,
)
).permission
return get_permission(config.DEFAULT_MLFLOW_PERMISSION)


Expand Down
6 changes: 2 additions & 4 deletions mlflow_oidc_auth/validators/registered_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from mlflow_oidc_auth.permissions import Permission
from mlflow_oidc_auth.store import store
from mlflow_oidc_auth.utils import get_request_param, get_username

from ._permissions import get_permission_from_store_or_default
from mlflow_oidc_auth.utils import get_permission_from_store_or_default, get_request_param, get_username


def _get_permission_from_registered_model_name() -> Permission:
Expand All @@ -11,7 +9,7 @@ def _get_permission_from_registered_model_name() -> Permission:
return get_permission_from_store_or_default(
lambda: store.get_registered_model_permission(model_name, username).permission,
lambda: store.get_user_groups_registered_model_permission(model_name, username).permission,
)
).permission


def validate_can_read_registered_model():
Expand Down
12 changes: 5 additions & 7 deletions mlflow_oidc_auth/validators/run.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from mlflow_oidc_auth.store import store
from mlflow_oidc_auth.utils import get_request_param, get_username
from mlflow.server.handlers import _get_tracking_store

from mlflow_oidc_auth.permissions import Permission
from mlflow.server.handlers import (
_get_tracking_store,
)
from ._permissions import get_permission_from_store_or_default
from mlflow_oidc_auth.store import store
from mlflow_oidc_auth.utils import get_permission_from_store_or_default, get_request_param, get_username


def _get_permission_from_run_id() -> Permission:
Expand All @@ -17,7 +15,7 @@ def _get_permission_from_run_id() -> Permission:
return get_permission_from_store_or_default(
lambda: store.get_experiment_permission(experiment_id, username).permission,
lambda: store.get_user_groups_experiment_permission(experiment_id, username).permission,
)
).permission


def validate_can_read_run():
Expand Down
33 changes: 30 additions & 3 deletions mlflow_oidc_auth/views/experiment.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from flask import jsonify, make_response
from mlflow.server.handlers import _get_tracking_store, catch_mlflow_exception

from mlflow_oidc_auth.responses.client_error import make_forbidden_response
from mlflow_oidc_auth.store import store
from mlflow_oidc_auth.utils import get_experiment_id, get_request_param
from mlflow_oidc_auth.utils import (
get_experiment_id,
get_is_admin,
get_permission_from_store_or_default,
get_request_param,
get_username,
)


@catch_mlflow_exception
Expand Down Expand Up @@ -44,7 +51,19 @@ def delete_experiment_permission():

@catch_mlflow_exception
def get_experiments():
list_experiments = _get_tracking_store().search_experiments()
current_user = store.get_user(get_username())
is_admin = get_is_admin()
if is_admin:
list_experiments = _get_tracking_store().search_experiments()
else:
list_experiments = []
for experiment in _get_tracking_store().search_experiments():
permission = get_permission_from_store_or_default(
lambda: store.get_experiment_permission(experiment.experiment_id, current_user.username).permission,
lambda: store.get_user_groups_experiment_permission(experiment.experiment_id, current_user.username).permission,
).permission
if permission.can_manage:
list_experiments.append(experiment)
experiments = [
{
"name": experiment.name,
Expand All @@ -59,7 +78,15 @@ def get_experiments():
@catch_mlflow_exception
def get_experiment_users(experiment_id: str):
experiment_id = str(experiment_id)
# Get the list of all users
current_user = store.get_user(get_username())
is_admin = get_is_admin()
if not is_admin:
permission = get_permission_from_store_or_default(
lambda: store.get_experiment_permission(experiment_id, current_user.username).permission,
lambda: store.get_user_groups_experiment_permission(experiment_id, current_user.username).permission,
).permission
if not permission.can_manage:
return make_forbidden_response()
list_users = store.list_users()
# Filter users who are associated with the given experiment
users = []
Expand Down
28 changes: 24 additions & 4 deletions mlflow_oidc_auth/views/registered_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from flask import jsonify, make_response
from mlflow.server.handlers import _get_model_registry_store, catch_mlflow_exception

from mlflow_oidc_auth.responses.client_error import make_forbidden_response
from mlflow_oidc_auth.store import store
from mlflow_oidc_auth.utils import get_request_param
from mlflow_oidc_auth.utils import get_is_admin, get_permission_from_store_or_default, get_request_param, get_username


@catch_mlflow_exception
Expand Down Expand Up @@ -41,8 +42,19 @@ def delete_registered_model_permission():

@catch_mlflow_exception
def get_registered_models():
# TODO: Implement pagination
registered_models = _get_model_registry_store().search_registered_models(max_results=1000)
current_user = store.get_user(get_username())
is_admin = get_is_admin()
if is_admin:
registered_models = _get_model_registry_store().search_registered_models(max_results=1000)
else:
registered_models = []
for model in _get_model_registry_store().search_registered_models(max_results=1000):
permission = get_permission_from_store_or_default(
lambda: store.get_registered_model_permission(model.name, current_user.username).permission,
lambda: store.get_user_groups_registered_model_permission(model.name, current_user.username).permission,
).permission
if permission.can_manage:
registered_models.append(model)
models = [
{
"name": model.name,
Expand All @@ -57,7 +69,15 @@ def get_registered_models():

@catch_mlflow_exception
def get_registered_model_users(model_name):
# Get the list of all users
current_user = store.get_user(get_username())
is_admin = get_is_admin()
if not is_admin:
permission = get_permission_from_store_or_default(
lambda: store.get_registered_model_permission(model_name, current_user.username).permission,
lambda: store.get_user_groups_registered_model_permission(model_name, current_user.username).permission,
).permission
if not permission.can_manage:
return make_forbidden_response()
list_users = store.list_users()
# Filter users who are associated with the given model
users = []
Expand Down
Loading

0 comments on commit 2144744

Please sign in to comment.