diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..4b965c5 --- /dev/null +++ b/.flake8 @@ -0,0 +1,13 @@ +[flake8] +exclude = + .git, + __pycache__, + build, + dist, + versioneer.py, + tiled/_version.py, + tiled/serialization/_zipfile_py39.py, + docs/source/conf.py, + share, + web-frontend, +max-line-length = 115 \ No newline at end of file diff --git a/.github/workflows/build-deploy.yml b/.github/workflows/build-deploy.yml index 3a3655d..eb47be3 100644 --- a/.github/workflows/build-deploy.yml +++ b/.github/workflows/build-deploy.yml @@ -67,4 +67,3 @@ jobs: push: true tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} - \ No newline at end of file diff --git a/.gitignore b/.gitignore index 7595db1..ca98e97 100644 --- a/.gitignore +++ b/.gitignore @@ -128,4 +128,4 @@ api_keys.yml users.yml splash_flows_globus/ -splash_auth/ \ No newline at end of file +splash_auth/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..413679f --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,29 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-ast + - id: check-case-conflict + - id: check-merge-conflict + - id: check-symlinks + - id: check-yaml + - id: debug-statements + +- repo: https://github.com/pycqa/flake8 + rev: 4.0.1 + hooks: + - id: flake8 + +- repo: https://github.com/timothycrosley/isort + rev: 5.12.0 + hooks: + - id: isort + +- repo: https://github.com/psf/black + rev: 22.8.0 + hooks: + - id: black \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..3800a70 --- /dev/null +++ b/README.md @@ -0,0 +1,21 @@ +# Project Name + +## Description + +[Insert project description here] + +## Installation + +[Insert installation instructions here] + +## Usage + +[Insert usage instructions here] + +## Contributing + +[Insert contributing guidelines here] + +## License + +[Insert license information here] diff --git a/makefile b/makefile index 843ad18..409f72a 100644 --- a/makefile +++ b/makefile @@ -19,8 +19,7 @@ build: @podman build -t ${TAG} . --platform=linux/amd64 @echo "tagging to: " ${TAG} ${REGISTRY_TAG} @podman tag ${TAG} ${REGISTRY_TAG} - + push: @echo "Pushing " ${REGISTRY_TAG} @podman push ${REGISTRY_TAG} - diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8c95dbd --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,38 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "splash_auth" +description = "Authenticating proxy with oidc and api_key support" +readme = { file = "README.md", content-type = "text/markdown" } + +requires-python = ">=3.11" + + + +classifiers = [ + "Development Status :: 4 - Beta", + "License :: OSI Approved :: BSD License", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.11" +] + +dependencies = [ + "fastapi>=0.68.0,<0.69.0", + "uvicorn>=0.15.0,<0.16.0", + "python-jose", + "httpx==0.23.3", + "pyyaml" + +] + +dynamic = ["version"] + + +[tool.hatch] +version.source = "vcs" +build.hooks.vcs.version-file = "server/_version.py" + +[tool.hatch.metadata] +allow-direct-references = true \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 6334c2d..5735dd8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ fastapi>=0.68.0,<0.69.0 -pydantic>=1.8.0,<2.0.0 +# pydantic>=1.8.0,<2.0.0 uvicorn>=0.15.0,<0.16.0 python-jose httpx==0.23.3 -pyyaml \ No newline at end of file +pyyaml diff --git a/server/_version.py b/server/_version.py new file mode 100644 index 0000000..6bd539b --- /dev/null +++ b/server/_version.py @@ -0,0 +1,16 @@ +# file generated by setuptools_scm +# don't change, don't track in version control +TYPE_CHECKING = False +if TYPE_CHECKING: + from typing import Tuple, Union + VERSION_TUPLE = Tuple[Union[int, str], ...] +else: + VERSION_TUPLE = object + +version: str +__version__: str +__version_tuple__: VERSION_TUPLE +version_tuple: VERSION_TUPLE + +__version__ = version = '0.1.11.dev1+gce6e10f.d20231011' +__version_tuple__ = version_tuple = (0, 1, 11, 'dev1', 'gce6e10f.d20231011') diff --git a/server/config.py b/server/config.py index 6902a71..2ee7f11 100644 --- a/server/config.py +++ b/server/config.py @@ -1,37 +1,38 @@ -import os import logging +import os logger = logging.getLogger("splash_auth.config") - JWT_SECRET = os.environ["JWT_SECRET"] TOKEN_TIME = int(os.environ["TOKEN_EXP_TIME"]) -OAUTH_AUTH_ENDPOINT=os.environ["OAUTH_AUTH_ENDPOINT"] +OAUTH_AUTH_ENDPOINT = os.environ["OAUTH_AUTH_ENDPOINT"] OAUTH_CLIENT_ID = os.environ["OAUTH_CLIENT_ID"] OAUTH_CLIENT_SECRET = os.environ["OAUTH_CLIENT_SECRET"] OAUTH_REDIRECT_URI = os.environ["OAUTH_REDIRECT_URI"] -OAUTH_TOKEN_URI = os.environ["OAUTH_TOKEN_URI"] # can be found at https://accounts.google.com/.well-known/openid-configuration +OAUTH_TOKEN_URI = os.environ[ + "OAUTH_TOKEN_URI" +] # can be found at https://accounts.google.com/.well-known/openid-configuration OUATH_SUCCESS_REDIRECT_URI = os.environ["OUATH_SUCCESS_REDIRECT_URI"] OUATH_FAIL_REDIRECT_URI = os.environ["OUATH_FAIL_REDIRECT_URI"] OUATH_JWKS_URI = os.environ["OUATH_JWKS_URI"] HTTP_CLIENT_MAX_CONNECTIONS = os.getenv("HTTP_CLIENT_MAX_CONNECTIONS", 100) -HTTP_CLIENT_TIMEOUT_ALL= os.getenv("HTTP_CLIENT_TIMEOUT_ALL", 5.0) +HTTP_CLIENT_TIMEOUT_ALL = os.getenv("HTTP_CLIENT_TIMEOUT_ALL", 5.0) HTTP_CLIENT_TIMEOUT_CONNECT = os.getenv("HTTP_CLIENT_TIMEOUT_CONNECT", 3.0) HTTP_CLIENT_TIMEOUT_POOL = os.getenv("HTTP_CLIENT_TIMEOUT_POOL", 10) - -google_claims = {'iss': 'https://accounts.google.com', - 'azp': OAUTH_CLIENT_ID, - 'aud': OAUTH_CLIENT_ID +google_claims = { + "iss": "https://accounts.google.com", + "azp": OAUTH_CLIENT_ID, + "aud": OAUTH_CLIENT_ID, } logger.info(f"JWT_SECRET {JWT_SECRET}") logger.info(f"TOKEN_TIME {TOKEN_TIME}") logger.info(f"OAUTH_AUTH_ENDPOINT {OAUTH_AUTH_ENDPOINT}") logger.info(f"OAUTH_CLIENT_ID {OAUTH_CLIENT_ID}") -logger.info(f"OAUTH_CLIENT_SECRET is a secret") +logger.info("OAUTH_CLIENT_SECRET is a secret") logger.info(f"OAUTH_REDIRECT_URI {OAUTH_REDIRECT_URI}") logger.info(f"OAUTH_TOKEN_URI {OAUTH_TOKEN_URI}") logger.info(f"OUATH_SUCCESS_REDIRECT_URI {OUATH_SUCCESS_REDIRECT_URI}") @@ -41,7 +42,9 @@ logger.info(f"HTTP_CLIENT_TIMEOUT_ALL {HTTP_CLIENT_TIMEOUT_ALL}") logger.info(f"HTTP_CLIENT_TIMEOUT_CONNECT {HTTP_CLIENT_TIMEOUT_CONNECT}") logger.info(f"HTTP_CLIENT_TIMEOUT_POOL {HTTP_CLIENT_TIMEOUT_POOL}") -class Config(): + + +class Config: jwt_secret = JWT_SECRET token_time = TOKEN_TIME oauth_endpoint = OAUTH_AUTH_ENDPOINT @@ -58,4 +61,5 @@ class Config(): http_client_timeout_connect = HTTP_CLIENT_TIMEOUT_CONNECT http_client_timeout_pool = HTTP_CLIENT_TIMEOUT_POOL -config = Config() \ No newline at end of file + +config = Config() diff --git a/server/main.py b/server/main.py index 787d0ee..9da4593 100644 --- a/server/main.py +++ b/server/main.py @@ -1,25 +1,17 @@ import logging -from enum import Enum import os +from enum import Enum from typing import List, Optional, Union -from fastapi import ( - Cookie, - Depends, - FastAPI, - HTTPException, - Request, - Response -) +import httpx +from fastapi import Cookie, Depends, FastAPI, HTTPException, Request, Response from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.security import HTTPBearer from jose import jwt -import httpx -from starlette.responses import StreamingResponse from starlette.background import BackgroundTask +from starlette.responses import StreamingResponse from starlette.status import HTTP_403_FORBIDDEN, HTTP_502_BAD_GATEWAY - from .config import config from .oidc import oidc_router from .user_db import users_db @@ -61,18 +53,24 @@ """ + def new_httpx_client(): limits = httpx.Limits(max_connections=config.http_client_max_connections) timeout = httpx.Timeout( config.http_client_timeout_all, connect=config.http_client_timeout_connect, - pool=config.http_client_timeout_pool) - + pool=config.http_client_timeout_pool, + ) + return httpx.AsyncClient( - base_url="http://prefect_server:4200", limits=limits, timeout=timeout) + base_url="http://prefect_server:4200", limits=limits, timeout=timeout + ) + + client = new_httpx_client() -@app.on_event('shutdown') + +@app.on_event("shutdown") async def shutdown_event(): global client await client.aclose() @@ -84,37 +82,43 @@ class Scopes(str, Enum): @app.get("/login", response_class=HTMLResponse) -async def endpoint_login(redirect : Union[str, None] = None): +async def endpoint_login(redirect: Union[str, None] = None): """ This endpoint prints a login form. Currently, this directs the user to google for login. The mechanics of OIDC login from google are handled in oidc.py - """ - return AUTH_SITE.format(config.oauth_endpoint, config.oauth_redirect_uri, config.oauth_client_id ) + """ + return AUTH_SITE.format( + config.oauth_endpoint, config.oauth_redirect_uri, config.oauth_client_id + ) # return RedirectResponse("http://noether.lbl.gov:7443/data_workspace/login") -@app.api_route("/{path:path}", methods=["GET", "POST", "PATCH", "DELETE", "OPTIONS", "HEAD"]) -async def endpoint_reverse_proxy(request: Request, - response: Response, - als_token: Union[str, None] = Cookie(default=None), - # api_key: APIKey = Depends(get_api_key_from_request), - bearer: HTTPBearer = Depends(http_bearer)) -> StreamingResponse: +@app.api_route( + "/{path:path}", methods=["GET", "POST", "PATCH", "DELETE", "OPTIONS", "HEAD"] +) +async def endpoint_reverse_proxy( + request: Request, + response: Response, + als_token: Union[str, None] = Cookie(default=None), + # api_key: APIKey = Depends(get_api_key_from_request), + bearer: HTTPBearer = Depends(http_bearer), +) -> StreamingResponse: """ - This endpoint server as a reverse proxy for prefect messages. It authenticates every message using one of + This endpoint server as a reverse proxy for prefect messages. It authenticates every message using one of two methods. 1. Authorization: Bearer - This method is allows clients to send a provided key. It is the primary way that prefect agents + This method is allows clients to send a provided key. It is the primary way that prefect agents can authenticate. 2. Cookie - This method is used for logging into the prefect UI. If a cookie is not set in the message, the + This method is used for logging into the prefect UI. If a cookie is not set in the message, the user is redirected to the /login endpoint, which allows them to login. - + """ logger.info(f"{request.method} - {request.url}") - ### check for api key in bearer - if bearer: + # check for api key in bearer + if bearer: if bearer.credentials in users_db.api_keys: return await _reverse_proxy(request) else: @@ -122,14 +126,14 @@ async def endpoint_reverse_proxy(request: Request, raise HTTPException( status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials" ) - - ### check for cookie + + # check for cookie if not als_token: return RedirectResponse("/login") - - ### check if cookie's value is valid + + # check if cookie's value is valid try: - decoded_value = jwt.decode(als_token, config.jwt_secret, algorithms=["HS256"]) + jwt.decode(als_token, config.jwt_secret, algorithms=["HS256"]) except jwt.ExpiredSignatureError: # Signature has expired logger.debug("Signature expired in cookie") @@ -138,34 +142,35 @@ async def endpoint_reverse_proxy(request: Request, response.status_code = 200 try: return await _reverse_proxy(request) - except Exception as e: + except Exception: # a problem exists with the client not accepting new connections - # this is ugly, but we try and keep the service running by killing + # this is ugly, but we try and keep the service running by killing # the client and starting fresh logger.error("Exception from http client", exc_info=1) global client await client.aclose() client = new_httpx_client() raise HTTPException( - status_code=HTTP_502_BAD_GATEWAY, detail=f"Excpetion talking to service" + status_code=HTTP_502_BAD_GATEWAY, detail="Excpetion talking to service" ) - async def close(resp: StreamingResponse): await resp.aclose() -async def _reverse_proxy(request: Request, scopes: Optional[List[str]] = None) -> StreamingResponse: + +async def _reverse_proxy( + request: Request, scopes: Optional[List[str]] = None +) -> StreamingResponse: # # cheap and quick scope feature # if scopes and request.method.lower() in sceope global client - url = httpx.URL(path=request.url.path, - query=request.url.query.encode("utf-8")) - rp_req = client.build_request(request.method, url, - headers=request.headers.raw, - content=await request.body()) - + url = httpx.URL(path=request.url.path, query=request.url.query.encode("utf-8")) + rp_req = client.build_request( + request.method, url, headers=request.headers.raw, content=await request.body() + ) + rp_resp = await client.send(rp_req, stream=True) return StreamingResponse( rp_resp.aiter_raw(), @@ -173,4 +178,3 @@ async def _reverse_proxy(request: Request, scopes: Optional[List[str]] = None) - headers=rp_resp.headers, background=BackgroundTask(close, rp_resp), ) - diff --git a/server/oidc.py b/server/oidc.py index e6d2880..6932cbc 100644 --- a/server/oidc.py +++ b/server/oidc.py @@ -1,36 +1,19 @@ -from fastapi import APIRouter -from dataclasses import dataclass -from datetime import datetime, timezone, timedelta -import logging -import os +from datetime import datetime, timedelta, timezone from typing import Dict - -from fastapi import ( - Cookie, - Depends, - FastAPI, - Request, - Response, - Security -) -from fastapi.security.api_key import APIKeyQuery, APIKeyCookie, APIKeyHeader, APIKey - -from pydantic import BaseModel - -from fastapi.responses import HTMLResponse, RedirectResponse -from jose import jwt, jwk import httpx +from fastapi import APIRouter, Request +from fastapi.responses import RedirectResponse +from jose import jwk, jwt -from .config import google_claims, config +from .config import config, google_claims from .user_db import users_db - oidc_router = APIRouter(prefix="/oidc") # Singleton keyset, can be refreshed in get_keys() -oauth_validation_keyset = None +oauth_validation_keyset = None class OAuthKeysUnavailableException(Exception): @@ -42,10 +25,11 @@ class KeyNotFoundError(Exception): def contstruct_key(kid: str, keys: Dict): - for key in keys['keys']: - if key['kid'] == kid: + for key in keys["keys"]: + if key["kid"] == kid: return jwk.construct(key) + async def find_key(token): """finds a key from the configured keys based on the kid claim of the token Args: @@ -56,7 +40,7 @@ async def find_key(token): Key: found key object """ unverified = jwt.get_unverified_header(token) - kid = unverified.get('kid') + kid = unverified.get("kid") print(f"~!!!!!!! kid {kid}") if not kid: raise KeyNotFoundError("kid not found in jwt") @@ -70,8 +54,9 @@ async def find_key(token): raise KeyError(f"Key not found in fetched keys {keys}") return key + def validate_jwt(token, key, access_token): - jwt.decode(token, key, audience=google_claims['aud'], access_token=access_token) + jwt.decode(token, key, audience=google_claims["aud"], access_token=access_token) async def exchange_code(token_url, auth_code, client_id, client_secret, redirect_uri): @@ -88,44 +73,41 @@ async def exchange_code(token_url, auth_code, client_id, client_secret, redirect "client_id": client_id, "redirect_uri": redirect_uri, "code": auth_code, - "client_secret": client_secret - } + "client_secret": client_secret, + }, ) return response.json() -async def get_keys(stale = False): + +async def get_keys(stale=False): """ - Fetch oauth_validateion_keyset from OUATH_JWKS_URI - + Fetch oauth_validateion_keyset from OUATH_JWKS_URI + """ if oauth_validation_keyset and not stale: return oauth_validation_keyset - + async with httpx.AsyncClient() as client: response = await client.get(config.oauth_jwks_uri) if response.is_success: return response.json() - raise OAuthKeysUnavailableException(f"Cannot get keyset from OAuth server {response}") + raise OAuthKeysUnavailableException( + f"Cannot get keyset from OAuth server {response}" + ) + async def get_user_info(user_info_url, access_token): - """Unused but useful method for getting additional user information - """ + """Unused but useful method for getting additional user information""" response = httpx.get( - url=user_info_url, - headers={ - "Authorization": "Bearer " + access_token - } + url=user_info_url, headers={"Authorization": "Bearer " + access_token} ) return response.json() + async def get_user_info(user_info_url, access_token): - """Unused but useful method for getting additional user information - """ + """Unused but useful method for getting additional user information""" response = httpx.get( - url=user_info_url, - headers={ - "Authorization": "Bearer " + access_token - } + url=user_info_url, headers={"Authorization": "Bearer " + access_token} ) return response.json() @@ -133,19 +115,25 @@ async def get_user_info(user_info_url, access_token): @oidc_router.get("/auth/code") async def endpoint_validate_ouath_code(request: Request): """ - Do OAuth2 token exchange with the configured service (Google, ORCID). + Do OAuth2 token exchange with the configured service (Google, ORCID). + + Does a back-channel communicaiton with the service, and returns a + JWT that we produce. - Does a back-channel communicaiton with the service, and returns a - JWT that we produce. - """ print(f"request.query_params {request.query_params}") - code = request.query_params['code'] - response_body = await exchange_code(config.oauth_token_url, code, config.oauth_client_id, config.oauth_client_secret, config.oauth_redirect_uri) + code = request.query_params["code"] + response_body = await exchange_code( + config.oauth_token_url, + code, + config.oauth_client_id, + config.oauth_client_secret, + config.oauth_redirect_uri, + ) print(response_body) - id_token = response_body['id_token'] - access_token = response_body['access_token'] + id_token = response_body["id_token"] + access_token = response_body["access_token"] key = await find_key(id_token) validate_jwt(id_token, key, access_token) # # below is a second method that we used to get more profile info...but then we found that in the html @@ -154,15 +142,19 @@ async def endpoint_validate_ouath_code(request: Request): # if token_user_info['email'] not in users_db.users: # return RedirectResponse(config.oauth_fail_redirect_uri) id_claims = jwt.get_unverified_claims(id_token) - if id_claims['email'] not in users_db.users: + if id_claims["email"] not in users_db.users: return RedirectResponse(config.oauth_fail_redirect_uri) - encoded_jwt = jwt.encode({ - "email" : id_claims['email'], - "exp": datetime.now(tz=timezone.utc) + timedelta(seconds=config.token_time) - }, - config.jwt_secret, - algorithm="HS256") + encoded_jwt = jwt.encode( + { + "email": id_claims["email"], + "exp": datetime.now(tz=timezone.utc) + timedelta(seconds=config.token_time), + }, + config.jwt_secret, + algorithm="HS256", + ) response = RedirectResponse(config.oauth_success_redirect_uri) - response.set_cookie(key="als_user", value=id_claims['email'], httponly=True) - response.set_cookie(key="als_token", value=encoded_jwt, httponly=True, max_age=86400) + response.set_cookie(key="als_user", value=id_claims["email"], httponly=True) + response.set_cookie( + key="als_token", value=encoded_jwt, httponly=True, max_age=86400 + ) return response diff --git a/server/user_db.py b/server/user_db.py index fe8589d..c1f668e 100644 --- a/server/user_db.py +++ b/server/user_db.py @@ -1,17 +1,17 @@ - from dataclasses import dataclass from typing import Dict, List import yaml + @dataclass -class Users(): +class Users: users: List[str] api_keys: List[str] def populate_users(): - users= Users([], []) + users = Users([], []) with open("/app/users.yml", "r") as users_file: users.users = yaml.safe_load(users_file)["users"] @@ -19,4 +19,5 @@ def populate_users(): users.api_keys = yaml.safe_load(api_keys_file)["valid_keys"] return users -users_db = populate_users() \ No newline at end of file + +users_db = populate_users()