Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backend: Add refresh token logic to OAuth providers #252

Closed
wants to merge 44 commits into from
Closed
Show file tree
Hide file tree
Changes from 40 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
7705f76
add login page components
misspia-cohere May 28, 2024
ddf4ee4
remove video from cell bg and add missing image
misspia-cohere May 28, 2024
7ddb0ef
Add Register page and hooks for auth
malexw Jun 4, 2024
572add0
Add the register page and connect all the frontend elements
malexw Jun 6, 2024
91b9d31
Cleanup
malexw Jun 6, 2024
eb37ac2
Merge branch 'main' into alexw/coral-login
tianjing-li Jun 6, 2024
1c556a5
Temporarily turn off auth for API /users route and fix some registrat…
malexw Jun 6, 2024
9d04c3e
Redirect to /login if the token expires and clean up some console errors
malexw Jun 8, 2024
ab704c4
Initial version of Google SSO login plus OpenID components
malexw Jun 11, 2024
488573d
Add error messages for failed logins
malexw Jun 12, 2024
dec4a35
Merge branch 'main' into alexw/coral-login
malexw Jun 12, 2024
9701765
merge main
tianjing-li Jun 13, 2024
41be885
Updates
tianjing-li Jun 13, 2024
53bb041
add comment
tianjing-li Jun 13, 2024
ed07ae8
Pass code instead of state to Google OAuth callback
malexw Jun 14, 2024
31de868
refactor OAuth logic
tianjing-li Jun 14, 2024
b3b7219
Add google oauth authorize logic
tianjing-li Jun 14, 2024
206b6d6
Add backend changes to fetch endpoints dynamically
tianjing-li Jun 14, 2024
6865612
Add oidc logic
tianjing-li Jun 14, 2024
c1bd1c8
Dynamically set SSO login buttons from auth_strategies
malexw Jun 14, 2024
38484b4
Cleanup to tie the E2E SSO flows together
malexw Jun 15, 2024
90785ff
Don't enable Basic Auth for everyone
malexw Jun 17, 2024
2637536
A few fixes from code review
malexw Jun 17, 2024
d33d23d
Merge branch 'main' into alexw/coral-login
malexw Jun 17, 2024
a91f3d8
Merge branch 'alexw/coral-login' into alexw/google-oidc-sso
malexw Jun 18, 2024
5b8e861
Remove old SSO code
malexw Jun 18, 2024
2e1d40a
Fix styling in backend
malexw Jun 18, 2024
b44873c
Removing dead code and old comments
malexw Jun 18, 2024
a2ec56a
Merge branch 'main' into refresh-token
tianjing-li Jun 19, 2024
9d7564c
add refresh token logic (wip)
tianjing-li Jun 19, 2024
c363248
Add refresh token guide
tianjing-li Jun 19, 2024
4a12a8e
wip refresh token
tianjing-li Jun 20, 2024
b97edc0
add testing
tianjing-li Jun 20, 2024
e8c0a15
Handle auth token expiry and refresh in the frontend
malexw Jul 3, 2024
96959e2
add merge conflicts
tianjing-li Jul 4, 2024
f2ba0c8
Resolve merge conflicts and fix builds
malexw Jul 4, 2024
9b60500
Fixes for various auth issues
malexw Jul 8, 2024
7ffe14f
Rewrite refresh to attempt JWT reissue on requests within 1 hour of e…
malexw Jul 15, 2024
923ea06
Merge branch 'main' into refresh-token
malexw Jul 18, 2024
20601a3
Reimplementation of JWT refresh; also set auth timings lower for testing
malexw Jul 18, 2024
c8ea211
Merge branch 'main' into refresh-token
malexw Jul 18, 2024
67c5d0a
Merge branch 'main' into refresh-token
malexw Jul 21, 2024
601dd86
Block JWTs once we've issued a replacement
malexw Jul 24, 2024
6e130cd
Merge branch 'main' into refresh-token
malexw Aug 14, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions docs/auth_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,21 @@ import secrets
print(secrets.token_hex(32))
```

## Configuring your OAuth app's Redirect URI
## Configuring your OAuth app

### Redirect URI

When configuring your OAuth apps, make sure to whitelist the Redirect URI to the frontend endpoint, it should look like
`<FRONTEND_HOST>/auth/<STRATEGY_NAME>`. For example, your Redirect URI will be `http://localhost:4000/auth/google` if you're running the GoogleOAuth class locally.
`<FRONTEND_HOST>/auth/<STRATEGY_NAME>`. For example, your Redirect URI will be `http://localhost:4000/auth/google` if you're running the GoogleOAuth class locally. The strategy name is defined in the `NAME` class attribute.

### Enabling Refresh Tokens

To enable refresh tokens, you must implement the `get_refresh_token_params()` method in your auth strategy class. This should return a dictionary containing key-value pairs that contain the query parameters the auth provider needs to return a refresh token. For example, if your auth provider requires a `?scope=offline` query parameter, you should add:

```python
def get_refresh_token_params(self):
return {"scope": "offline"}
```

## Enabling Proof of Key Code Exchange (PKCE)

Expand Down
16 changes: 15 additions & 1 deletion src/backend/config/auth.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import os
import sys
from typing import Union

from dotenv import load_dotenv
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from backend.services.auth import BasicAuthentication, GoogleOAuth, OpenIDConnect
from backend.services.auth.strategies.base import (
BaseAuthenticationStrategy,
BaseOAuthStrategy,
)

load_dotenv()

SKIP_AUTH = os.getenv("SKIP_AUTH", None)
# Add Auth strategy classes here to enable them
# Ex: [BasicAuthentication]
ENABLED_AUTH_STRATEGIES = []
ENABLED_AUTH_STRATEGIES = [BasicAuthentication, GoogleOAuth]
if "pytest" in sys.modules or SKIP_AUTH == "true":
ENABLED_AUTH_STRATEGIES = []

Expand Down Expand Up @@ -49,6 +54,15 @@ def is_authentication_enabled() -> bool:
return False


def get_auth_strategy(
strategy_name: str,
) -> Union[BaseAuthenticationStrategy, BaseOAuthStrategy]:
if strategy_name not in ENABLED_AUTH_STRATEGY_MAPPING.keys():
return None

return ENABLED_AUTH_STRATEGY_MAPPING[strategy_name]


async def get_auth_strategy_endpoints() -> None:
"""
Fetches the endpoints for each enabled strategy.
Expand Down
2 changes: 2 additions & 0 deletions src/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from backend.routers.snapshot import router as snapshot_router
from backend.routers.tool import router as tool_router
from backend.routers.user import router as user_router
from backend.services.auth.request_validators import UPDATE_TOKEN_HEADER
from backend.services.logger import LoggingMiddleware, get_logger
from backend.services.metrics import MetricsMiddleware

Expand Down Expand Up @@ -75,6 +76,7 @@ def create_app():
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=[UPDATE_TOKEN_HEADER],
)
app.add_middleware(LoggingMiddleware)
app.add_middleware(MetricsMiddleware)
Expand Down
5 changes: 5 additions & 0 deletions src/backend/middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from backend.middleware.logging import LoggingMiddleware

__all__ = [
"LoggingMiddleware",
]
14 changes: 14 additions & 0 deletions src/backend/middleware/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import logging
import time

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request


class LoggingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
start_time = time.time()
response = await call_next(request)

logging.info(f"{request.method} {request.url.path}\n{request.headers}")
return response
16 changes: 10 additions & 6 deletions src/backend/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from fastapi.responses import RedirectResponse
from starlette.requests import Request

from backend.config.auth import ENABLED_AUTH_STRATEGY_MAPPING
from backend.config.auth import ENABLED_AUTH_STRATEGY_MAPPING, get_auth_strategy
from backend.config.routers import RouterName
from backend.config.tools import AVAILABLE_TOOLS
from backend.crud import blacklist as blacklist_crud
Expand Down Expand Up @@ -49,6 +49,11 @@ def get_strategies() -> list[ListAuthStrategy]:
if hasattr(strategy_instance, "get_authorization_endpoint")
else None
),
"refresh_token_params": (
strategy_instance.get_refresh_token_params()
if hasattr(strategy_instance, "get_refresh_token_params")
else None
),
"pkce_enabled": (
strategy_instance.get_pkce_enabled()
if hasattr(strategy_instance, "get_pkce_enabled")
Expand Down Expand Up @@ -80,13 +85,12 @@ async def login(request: Request, login: Login, session: DBSessionDep):
strategy_name = login.strategy
payload = login.payload

if not is_enabled_authentication_strategy(strategy_name):
strategy = get_auth_strategy(strategy_name)
if not strategy:
raise HTTPException(
status_code=422, detail=f"Invalid Authentication strategy: {strategy_name}."
)

# Check that the payload required is given
strategy = ENABLED_AUTH_STRATEGY_MAPPING[strategy_name]
strategy_payload = strategy.get_required_payload()
if not set(strategy_payload).issubset(payload.keys()):
missing_keys = [key for key in strategy_payload if key not in payload.keys()]
Expand All @@ -102,7 +106,7 @@ async def login(request: Request, login: Login, session: DBSessionDep):
detail=f"Error performing {strategy_name} authentication with payload: {payload}.",
)

token = JWTService().create_and_encode_jwt(user)
token = JWTService().create_and_encode_jwt(user, strategy_name)

return {"token": token}

Expand Down Expand Up @@ -165,7 +169,7 @@ async def authorize(
# Get or create user, then set session user
user = get_or_create_user(session, userinfo)

token = JWTService().create_and_encode_jwt(user)
token = JWTService().create_and_encode_jwt(user, strategy_name)

return {"token": token}

Expand Down
118 changes: 111 additions & 7 deletions src/backend/services/auth/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,23 @@
import uuid

import jwt
from fastapi import Depends
from sqlalchemy.orm import Session

from backend.database_models import Blacklist, get_session
from backend.services.logger import get_logger

logger = get_logger()


class JWTService:
ISSUER = "cohere-toolkit"
EXPIRY_DAYS = 90
# AUTH_EXPIRY_MONTHS = 3
# JWT_EXPIRY_HOURS = 72
# REFRESH_AVAILABILITY_HOURS = 36
AUTH_EXPIRY_SECONDS = 60 * 2
JWT_EXPIRY_SECONDS = 30
REFRESH_AVAILABILITY_SECONDS = 15
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

iiuc, this may be too small of a value? if im right in thinking, i would suggest changing the actual authentication expiry of the current access_token to a value that is higher? IE, access_token validity==1hr && refresh_token validity==6-24hrs etc

i assume since this is local deployment/backend, these can be set as per the environment requirements so it may be best to go with a boilerplate approach and add some definition around the struct for easy adoption

apologies if i am missing something here :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The commented out timings here are intended for production, while the shorter times in seconds are just for testing on this branch to make sure things are expired or refreshed properly

ALGORITHM = "HS256"

def __init__(self):
Expand All @@ -25,29 +33,56 @@ def __init__(self):

self.secret_key = secret_key

def create_and_encode_jwt(self, user: dict) -> str:
def create_and_encode_jwt(self, user: dict, strategy_name: str, **kwargs) -> str:
"""
Creates a payload based on user info and creates a JWT token.

Args:
user (dict): User data.
strategy_name (str): Name of the authentication strategy.
kwargs: Additional payload data. These parameters will override normal payload data.

Returns:
str: JWT token.
"""
now = datetime.datetime.utcnow()
now = datetime.datetime.now(datetime.timezone.utc)
payload = {
"iss": self.ISSUER,
"iat": now,
"exp": now + datetime.timedelta(days=self.EXPIRY_DAYS),
# "exp": now + datetime.timedelta(hours=self.AUTH_EXPIRY_HOURS),
"exp": now + datetime.timedelta(seconds=self.JWT_EXPIRY_SECONDS),
"jti": str(uuid.uuid4()),
"strategy": strategy_name,
"context": user,
**kwargs,
}

token = jwt.encode(payload, self.secret_key, self.ALGORITHM)

return token

def refresh_jwt(self, token: str) -> str:
"""
Refreshes a given JWT token. Callers should check if the token is expired before calling this method.

Args:
token (str): JWT token.

Returns:
str: New JWT token.
"""
decoded_payload = self.decode_jwt(token)

if not decoded_payload:
return None

# Create new token with same payload
return self.create_and_encode_jwt(
decoded_payload["context"],
decoded_payload["strategy"],
iat=decoded_payload["iat"],
)

def decode_jwt(self, token: str) -> dict:
"""
Decodes a given JWT token.
Expand All @@ -64,8 +99,77 @@ def decode_jwt(self, token: str) -> dict:
)
return decoded_payload
except jwt.ExpiredSignatureError:
logger.warning("Token has expired.")
return None
logger.warning("JWT Token has expired.")
decoded_payload = jwt.decode(
token,
self.secret_key,
algorithms=[self.ALGORITHM],
options={"verify_exp": False},
)
return decoded_payload
except jwt.InvalidTokenError:
logger.warning("Invalid token.")
logger.warning("JWT Token is invalid.")
return None

@staticmethod
def check_validity(payload: dict, session: Session = Depends(get_session)) -> str:
"""
Checks a given JWT payload for its validity.

Args:
payload (dict): JWT payload.

Returns:
str: One of the following strings depending on the validity:
- "valid": Token is valid.
- "refreshable": Token is valid and within the refresh availability window.
- "expired": Token is expired or blacklisted.
- "invalid": Token is invalid.
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thoughts on adding something here to prevent replays of refresh tokens?


if payload is None or any(
[
"context" not in payload,
"jti" not in payload,
"exp" not in payload,
"strategy" not in payload,
"iat" not in payload,
]
):
return "invalid"
GangGreenTemperTatum marked this conversation as resolved.
Show resolved Hide resolved

now = datetime.datetime.now(datetime.timezone.utc)

# Check if token is blacklisted
blacklist = (
session.query(Blacklist)
.filter(Blacklist.token_id == payload["jti"])
.first()
)

if blacklist is not None:
logger.warning("JWT payload is blacklisted.")
return "expired"

# Check if token is expired; either we're past the expiry time or iat is more than 3 months ago
GangGreenTemperTatum marked this conversation as resolved.
Show resolved Hide resolved
expiry_datetime = datetime.datetime.fromtimestamp(
payload["exp"], datetime.timezone.utc
)
issued_datetime = datetime.datetime.fromtimestamp(
payload["iat"], datetime.timezone.utc
)

# if now > expiry_datetime or now > (issued_datetime + datetime.timedelta(months=JWTService.AUTH_EXPIRY_MONTHS)):
if now > expiry_datetime or now > (
issued_datetime + datetime.timedelta(seconds=JWTService.AUTH_EXPIRY_SECONDS)
):
return "expired"

# if now > (expiry_datetime - datetime.timedelta(hours=JWTService.REFRESH_AVAILABILITY_HOURS)):
if now > (
expiry_datetime
- datetime.timedelta(seconds=JWTService.REFRESH_AVAILABILITY_SECONDS)
):
return "refreshable"

return "valid"
Loading
Loading