Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 99 additions & 3 deletions sky/server/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import dataclasses
import importlib
import os
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple

from fastapi import FastAPI

Expand All @@ -20,9 +20,71 @@
f'{skylet_constants.SKYPILOT_SERVER_ENV_VAR_PREFIX}PLUGINS_CONFIG')


@dataclasses.dataclass
class ExtensionContext:
app: FastAPI
"""Context provided to plugins during installation.

Attributes:
app: The FastAPI application instance.
rbac_rules: List of RBAC rules registered by the plugin.
Example:
[
('user', RBACRule(path='/plugins/api/xx/*', method='POST')),
('user', RBACRule(path='/plugins/api/xx/*', method='DELETE'))
]
"""

def __init__(self, app: FastAPI):
self.app = app
self.rbac_rules: List[Tuple[str, RBACRule]] = []

def register_rbac_rule(self,
path: str,
method: str,
description: Optional[str] = None,
role: str = 'user') -> None:
"""Register an RBAC rule for this plugin.

This method allows plugins to declare which endpoints should be
restricted to admin users during the install phase.

Args:
path: The path pattern to restrict (supports wildcards with
keyMatch2).
Example: '/plugins/api/credentials/*'
method: The HTTP method to restrict. Example: 'POST', 'DELETE'
description: Optional description of what this rule protects.
role: The role to add this rule to (default: 'user').
Rules added to 'user' role block regular users but allow
admins.

Example:
def install(self, ctx: ExtensionContext):
# Only admin can upload credentials
ctx.register_rbac_rule(
path='/plugins/api/credentials/*',
method='POST',
description='Only admin can upload credentials'
)
"""
rule = RBACRule(path=path, method=method, description=description)
self.rbac_rules.append((role, rule))
logger.debug(f'Registered RBAC rule for {role}: {method} {path}'
f'{f" - {description}" if description else ""}')


@dataclasses.dataclass
class RBACRule:
"""RBAC rule for a plugin endpoint.

Attributes:
path: The path pattern to match (supports wildcards with keyMatch2).
Example: '/plugins/api/credentials/*'
method: The HTTP method to restrict. Example: 'POST', 'DELETE'
description: Optional description of what this rule protects.
"""
path: str
method: str
description: Optional[str] = None


class BasePlugin(abc.ABC):
Expand Down Expand Up @@ -84,10 +146,14 @@ def _load_plugin_config() -> Optional[config_utils.Config]:


_PLUGINS: Dict[str, BasePlugin] = {}
_EXTENSION_CONTEXT: Optional[ExtensionContext] = None


def load_plugins(extension_context: ExtensionContext):
"""Load and initialize plugins from the config."""
global _EXTENSION_CONTEXT
_EXTENSION_CONTEXT = extension_context

config = _load_plugin_config()
if not config:
return
Expand Down Expand Up @@ -120,3 +186,33 @@ def load_plugins(extension_context: ExtensionContext):
def get_plugins() -> List[BasePlugin]:
"""Return shallow copies of the registered plugins."""
return list(_PLUGINS.values())


def get_plugin_rbac_rules() -> Dict[str, List[Dict[str, str]]]:
"""Collect RBAC rules from all loaded plugins.

Collects rules from the ExtensionContext.

Returns:
Dictionary mapping role names to lists of blocklist rules.
Example:
{
'user': [
{'path': '/plugins/api/credentials/*', 'method': 'POST'},
{'path': '/plugins/api/credentials/*', 'method': 'DELETE'}
]
}
"""
rules_by_role: Dict[str, List[Dict[str, str]]] = {}

# Collect rules registered via ExtensionContext
if _EXTENSION_CONTEXT:
for role, rule in _EXTENSION_CONTEXT.rbac_rules:
if role not in rules_by_role:
rules_by_role[role] = []
rules_by_role[role].append({
'path': rule.path,
'method': rule.method,
})

return rules_by_role
2 changes: 1 addition & 1 deletion sky/users/model.conf
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ g = _, _
e = some(where (p.eft == allow))

[matchers]
m = (g(r.sub, p.sub)|| p.sub == '*') && r.obj == p.obj && r.act == p.act
m = (g(r.sub, p.sub)|| p.sub == '*') && keyMatch2(r.obj, p.obj) && r.act == p.act
25 changes: 24 additions & 1 deletion sky/users/permission.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,26 @@ def _ensure_enforcer(self) -> casbin.Enforcer:
'Enforcer should be initialized after _lazy_initialize()')
return self.enforcer

def _get_plugin_rbac_rules(self):
"""Get RBAC rules from loaded plugins.
Returns:
Dictionary of plugin RBAC rules, or empty dict if plugins module
is not available or no rules are defined.
"""
try:
# pylint: disable=import-outside-toplevel
from sky.server import plugins as server_plugins
return server_plugins.get_plugin_rbac_rules()
except ImportError:
# Plugin module not available (e.g., not running as server)
logger.debug(
'Plugin module not available, skipping plugin RBAC rules')
return {}
except Exception as e: # pylint: disable=broad-except
logger.warning(f'Failed to get plugin RBAC rules: {e}')
return {}

def _maybe_initialize_basic_auth_user(self) -> None:
"""Initialize basic auth user if it is enabled."""
basic_auth = os.environ.get(constants.SKYPILOT_INITIAL_BASIC_AUTH)
Expand Down Expand Up @@ -101,9 +121,12 @@ def _maybe_initialize_policies(self) -> None:
enforcer = self._ensure_enforcer()
existing_policies = enforcer.get_policy()

# Get plugin RBAC rules dynamically
plugin_rules = self._get_plugin_rbac_rules()

# If we already have policies for the expected roles, skip
# initialization
role_permissions = rbac.get_role_permissions()
role_permissions = rbac.get_role_permissions(plugin_rules=plugin_rules)
expected_policies = []
for role, permissions in role_permissions.items():
if permissions['permissions'] and 'blocklist' in permissions[
Expand Down
34 changes: 31 additions & 3 deletions sky/users/rbac.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""RBAC (Role-Based Access Control) functionality for SkyPilot API Server."""

import enum
from typing import Dict, List
from typing import Dict, List, Optional

from sky import sky_logging
from sky import skypilot_config
Expand Down Expand Up @@ -55,8 +55,13 @@ def get_default_role() -> str:


def get_role_permissions(
plugin_rules: Optional[Dict[str, List[Dict[str, str]]]] = None
) -> Dict[str, Dict[str, Dict[str, List[Dict[str, str]]]]]:
"""Get all role permissions from config.
"""Get all role permissions from config and plugins.

Args:
plugin_rules: Optional dictionary of plugin RBAC rules to merge.
Format: {'user': [{'path': '...', 'method': '...'}]}

Returns:
Dictionary containing all roles and their permissions configuration.
Expand Down Expand Up @@ -91,9 +96,32 @@ def get_role_permissions(
if 'user' not in config_permissions:
config_permissions['user'] = {
'permissions': {
'blocklist': _DEFAULT_USER_BLOCKLIST
'blocklist': _DEFAULT_USER_BLOCKLIST.copy()
}
}

# Merge plugin rules into the appropriate roles
if plugin_rules:
for role, rules in plugin_rules.items():
if role not in supported_roles:
logger.warning(f'Plugin specified invalid role: {role}')
continue
if role not in config_permissions:
config_permissions[role] = {'permissions': {'blocklist': []}}
if 'permissions' not in config_permissions[role]:
config_permissions[role]['permissions'] = {'blocklist': []}
if 'blocklist' not in config_permissions[role]['permissions']:
config_permissions[role]['permissions']['blocklist'] = []

# Merge plugin rules, avoiding duplicates
existing_rules = config_permissions[role]['permissions'][
'blocklist']
for rule in rules:
if rule not in existing_rules:
existing_rules.append(rule)
logger.debug(f'Added plugin RBAC rule for {role}: '
f'{rule["method"]} {rule["path"]}')

return config_permissions


Expand Down