Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix internal communications #13

Merged
merged 3 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ The plugin required the following environment variables but also supported `.env
| OAUTHLIB_INSECURE_TRANSPORT | Development only. Allow to use insecure endpoints for OIDC |
| LOG_LEVEL | Application log level |
| OIDC_USERS_DB_URI | Database connection string |
| MLFLOW_TRACKING_USERNAME | Credentials for internal communications via API |
| MLFLOW_TRACKING_PASSWORD | Credentials for internal communications via API |
| MLFLOW_TRACKING_URI | URI for internal communications via API |

# Configuration examples

Expand Down Expand Up @@ -56,16 +59,21 @@ OIDC_ADMIN_GROUP_NAME = "mlflow_admins_group_name"
> please note, that for getting group membership information, the application should have "GroupMember.Read.All" permission

# Development

Preconditions:

The following tools should be installed for local development:

* git
* nodejs
* python

```shell
git clone https://github.com/data-platform-hq/mlflow-oidc-auth
cd mlflow-oidc-auth
python3 -m venv venv
source venv/bin/activate
pip install --editable .
mlflow server --dev --app-name oidc-auth --host 0.0.0.0 --port 8080
./scripts/run-dev-server.sh
```


# License
Apache 2 Licensed. For more information please see [LICENSE](./LICENSE)

Expand Down
4 changes: 4 additions & 0 deletions mlflow_oidc_auth/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import secrets
import requests
import secrets

from dotenv import load_dotenv
from mlflow_oidc_auth.app import app
Expand Down Expand Up @@ -34,6 +35,9 @@ class AppConfig:
OIDC_REDIRECT_URI = os.environ.get("OIDC_REDIRECT_URI", None)
OIDC_CLIENT_ID = os.environ.get("OIDC_CLIENT_ID", None)
OIDC_CLIENT_SECRET = os.environ.get("OIDC_CLIENT_SECRET", None)
MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:8080")
MLFLOW_TRACKING_USERNAME = os.environ.get("MLFLOW_TRACKING_USERNAME", secrets.token_urlsafe(32))
MLFLOW_TRACKING_PASSWORD = os.environ.get("MLFLOW_TRACKING_PASSWORD", secrets.token_urlsafe(72))

@staticmethod
def get_property(property_name):
Expand Down
37 changes: 25 additions & 12 deletions mlflow_oidc_auth/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from mlflow.protos.service_pb2 import (
CreateExperiment,
SearchExperiments,
)
from mlflow_oidc_auth.permissions import Permission, get_permission
from mlflow.server.handlers import (
Expand All @@ -47,7 +48,7 @@

# Create the OAuth2 client
auth_client = WebApplicationClient(AppConfig.get_property("OIDC_CLIENT_ID"))
mlflow_client = MlflowClient()
mlflow_client = MlflowClient(tracking_uri=AppConfig.get_property("MLFLOW_TRACKING_URI"))
store = SqlAlchemyStore()
store.init_db((AppConfig.get_property("OIDC_USERS_DB_URI")))
_logger = logging.getLogger(__name__)
Expand All @@ -59,6 +60,7 @@ def _get_experiment_id(request_data: dict) -> str:
experiment_id = mlflow_client.get_experiment_by_name(request_data.get("experiment_name")).experiment_id
return experiment_id


def _get_request_param(param: str) -> str:
if request.method == "GET":
args = request.args
Expand All @@ -83,7 +85,14 @@ def _get_request_param(param: str) -> str:


def _is_unprotected_route(path: str) -> bool:
return path.startswith(("/static", "/favicon.ico", "/health", "/login", "/callback", "/oidc/static", "/oidc/ui"))
return path.startswith(
(
"/health",
"/login",
"/callback",
"/oidc/static",
)
)


def _get_permission_from_store_or_default(store_permission_func: Callable[[], str]) -> Permission:
Expand All @@ -105,6 +114,14 @@ def authenticate_request_basic_auth() -> Union[Authorization, Response]:
username = request.authorization.username
password = request.authorization.password
_logger.debug("Authenticating user %s", username)
# check for internal call, if credentials are correct, return True
if username == AppConfig.get_property("MLFLOW_TRACKING_USERNAME") and password == AppConfig.get_property(
"MLFLOW_TRACKING_PASSWORD"
):
_set_username(username)
_set_is_admin(True)
_logger.debug("User %s authenticated", username)
return True
if store.authenticate_user(username, password):
_set_username(username)
_logger.debug("User %s authenticated", username)
Expand Down Expand Up @@ -232,9 +249,9 @@ def get_experiment_permission():


# TODO
@catch_mlflow_exception
def search_experiment():
return render_template("home.html", username=_get_username())
# @catch_mlflow_exception
# def search_experiment():
# return render_template("home.html", username=_get_username())


def login():
Expand Down Expand Up @@ -338,9 +355,9 @@ def oidc_ui(filename=None):
return send_from_directory(ui_directory, filename)


# TODO
def search_model():
return render_template("home.html", username=_get_username())
# # TODO
# def search_model():
# return render_template("home.html", username=_get_username())


def create_user():
Expand Down Expand Up @@ -405,10 +422,6 @@ def get_user():
return jsonify({"user": user.to_json()})


def oidc_home():
return render_template("home.html", username=_get_username())


def permissions():
return redirect(url_for("list_users"))

Expand Down
2 changes: 1 addition & 1 deletion scripts/run-dev-server.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ cleanup() {
python_preconfigure() {
if [ ! -d venv ]; then
python3 -m venv venv
source venv/bin/activate
python3 -m pip install --upgrade pip
python3 -m pip install build setuptools
source venv/bin/activate
python3 -m pip install -e .
fi
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<div class="controls-container pb-3">
<button mat-flat-button aria-label="login" color="primary" class="me-2">
<button mat-flat-button aria-label="login" color="primary" class="me-2" (click)="redirectToMLFlow()">
<span>Login to MLFlow</span>
</button>
<button mat-flat-button aria-label="login" color="primary" (click)="showAccessKeyModal()">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,8 @@ export class HomePageComponent implements OnInit {
this.dialog.open<AccessKeyModalComponent, AccessKeyDialogData>(AccessKeyModalComponent, { data })
});
}

redirectToMLFlow() {
window.location.href = '/';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<a mat-button [routerLink]="['/home']">Home</a>
<a mat-button [routerLink]="['/admin-panel']">Admin panel</a>
</nav>
<button mat-button [routerLink]="'/logout'" aria-label="logout" class="logout">
<button mat-button [routerLink]="'/logout'" aria-label="logout" class="logout" (click)="logout()">
<mat-icon>exit_to_app</mat-icon>
<span>Logout</span>
</button>
Expand Down
4 changes: 4 additions & 0 deletions web-ui/src/app/shared/components/header/header.component.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,8 @@ export class HeaderComponent implements OnInit {

ngOnInit(): void {
}

logout() {
window.location.href = '/logout';
}
}
Loading