-
Notifications
You must be signed in to change notification settings - Fork 363
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
backend: Add refresh token logic to OAuth providers #252
Closed
Closed
Changes from 40 commits
Commits
Show all changes
44 commits
Select commit
Hold shift + click to select a range
7705f76
add login page components
misspia-cohere ddf4ee4
remove video from cell bg and add missing image
misspia-cohere 7ddb0ef
Add Register page and hooks for auth
malexw 572add0
Add the register page and connect all the frontend elements
malexw 91b9d31
Cleanup
malexw eb37ac2
Merge branch 'main' into alexw/coral-login
tianjing-li 1c556a5
Temporarily turn off auth for API /users route and fix some registrat…
malexw 9d04c3e
Redirect to /login if the token expires and clean up some console errors
malexw ab704c4
Initial version of Google SSO login plus OpenID components
malexw 488573d
Add error messages for failed logins
malexw dec4a35
Merge branch 'main' into alexw/coral-login
malexw 9701765
merge main
tianjing-li 41be885
Updates
tianjing-li 53bb041
add comment
tianjing-li ed07ae8
Pass code instead of state to Google OAuth callback
malexw 31de868
refactor OAuth logic
tianjing-li b3b7219
Add google oauth authorize logic
tianjing-li 206b6d6
Add backend changes to fetch endpoints dynamically
tianjing-li 6865612
Add oidc logic
tianjing-li c1bd1c8
Dynamically set SSO login buttons from auth_strategies
malexw 38484b4
Cleanup to tie the E2E SSO flows together
malexw 90785ff
Don't enable Basic Auth for everyone
malexw 2637536
A few fixes from code review
malexw d33d23d
Merge branch 'main' into alexw/coral-login
malexw a91f3d8
Merge branch 'alexw/coral-login' into alexw/google-oidc-sso
malexw 5b8e861
Remove old SSO code
malexw 2e1d40a
Fix styling in backend
malexw b44873c
Removing dead code and old comments
malexw a2ec56a
Merge branch 'main' into refresh-token
tianjing-li 9d7564c
add refresh token logic (wip)
tianjing-li c363248
Add refresh token guide
tianjing-li 4a12a8e
wip refresh token
tianjing-li b97edc0
add testing
tianjing-li e8c0a15
Handle auth token expiry and refresh in the frontend
malexw 96959e2
add merge conflicts
tianjing-li f2ba0c8
Resolve merge conflicts and fix builds
malexw 9b60500
Fixes for various auth issues
malexw 7ffe14f
Rewrite refresh to attempt JWT reissue on requests within 1 hour of e…
malexw 923ea06
Merge branch 'main' into refresh-token
malexw 20601a3
Reimplementation of JWT refresh; also set auth timings lower for testing
malexw c8ea211
Merge branch 'main' into refresh-token
malexw 67c5d0a
Merge branch 'main' into refresh-token
malexw 601dd86
Block JWTs once we've issued a replacement
malexw 6e130cd
Merge branch 'main' into refresh-token
malexw File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from backend.middleware.logging import LoggingMiddleware | ||
|
||
__all__ = [ | ||
"LoggingMiddleware", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import logging | ||
import time | ||
|
||
from starlette.middleware.base import BaseHTTPMiddleware | ||
from starlette.requests import Request | ||
|
||
|
||
class LoggingMiddleware(BaseHTTPMiddleware): | ||
async def dispatch(self, request: Request, call_next): | ||
start_time = time.time() | ||
response = await call_next(request) | ||
|
||
logging.info(f"{request.method} {request.url.path}\n{request.headers}") | ||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,15 +4,23 @@ | |
import uuid | ||
|
||
import jwt | ||
from fastapi import Depends | ||
from sqlalchemy.orm import Session | ||
|
||
from backend.database_models import Blacklist, get_session | ||
from backend.services.logger import get_logger | ||
|
||
logger = get_logger() | ||
|
||
|
||
class JWTService: | ||
ISSUER = "cohere-toolkit" | ||
EXPIRY_DAYS = 90 | ||
# AUTH_EXPIRY_MONTHS = 3 | ||
# JWT_EXPIRY_HOURS = 72 | ||
# REFRESH_AVAILABILITY_HOURS = 36 | ||
AUTH_EXPIRY_SECONDS = 60 * 2 | ||
JWT_EXPIRY_SECONDS = 30 | ||
REFRESH_AVAILABILITY_SECONDS = 15 | ||
ALGORITHM = "HS256" | ||
|
||
def __init__(self): | ||
|
@@ -25,29 +33,56 @@ def __init__(self): | |
|
||
self.secret_key = secret_key | ||
|
||
def create_and_encode_jwt(self, user: dict) -> str: | ||
def create_and_encode_jwt(self, user: dict, strategy_name: str, **kwargs) -> str: | ||
""" | ||
Creates a payload based on user info and creates a JWT token. | ||
|
||
Args: | ||
user (dict): User data. | ||
strategy_name (str): Name of the authentication strategy. | ||
kwargs: Additional payload data. These parameters will override normal payload data. | ||
|
||
Returns: | ||
str: JWT token. | ||
""" | ||
now = datetime.datetime.utcnow() | ||
now = datetime.datetime.now(datetime.timezone.utc) | ||
payload = { | ||
"iss": self.ISSUER, | ||
"iat": now, | ||
"exp": now + datetime.timedelta(days=self.EXPIRY_DAYS), | ||
# "exp": now + datetime.timedelta(hours=self.AUTH_EXPIRY_HOURS), | ||
"exp": now + datetime.timedelta(seconds=self.JWT_EXPIRY_SECONDS), | ||
"jti": str(uuid.uuid4()), | ||
"strategy": strategy_name, | ||
"context": user, | ||
**kwargs, | ||
} | ||
|
||
token = jwt.encode(payload, self.secret_key, self.ALGORITHM) | ||
|
||
return token | ||
|
||
def refresh_jwt(self, token: str) -> str: | ||
""" | ||
Refreshes a given JWT token. Callers should check if the token is expired before calling this method. | ||
|
||
Args: | ||
token (str): JWT token. | ||
|
||
Returns: | ||
str: New JWT token. | ||
""" | ||
decoded_payload = self.decode_jwt(token) | ||
|
||
if not decoded_payload: | ||
return None | ||
|
||
# Create new token with same payload | ||
return self.create_and_encode_jwt( | ||
decoded_payload["context"], | ||
decoded_payload["strategy"], | ||
iat=decoded_payload["iat"], | ||
) | ||
|
||
def decode_jwt(self, token: str) -> dict: | ||
""" | ||
Decodes a given JWT token. | ||
|
@@ -64,8 +99,77 @@ def decode_jwt(self, token: str) -> dict: | |
) | ||
return decoded_payload | ||
except jwt.ExpiredSignatureError: | ||
logger.warning("Token has expired.") | ||
return None | ||
logger.warning("JWT Token has expired.") | ||
decoded_payload = jwt.decode( | ||
token, | ||
self.secret_key, | ||
algorithms=[self.ALGORITHM], | ||
options={"verify_exp": False}, | ||
) | ||
return decoded_payload | ||
except jwt.InvalidTokenError: | ||
logger.warning("Invalid token.") | ||
logger.warning("JWT Token is invalid.") | ||
return None | ||
|
||
@staticmethod | ||
def check_validity(payload: dict, session: Session = Depends(get_session)) -> str: | ||
""" | ||
Checks a given JWT payload for its validity. | ||
|
||
Args: | ||
payload (dict): JWT payload. | ||
|
||
Returns: | ||
str: One of the following strings depending on the validity: | ||
- "valid": Token is valid. | ||
- "refreshable": Token is valid and within the refresh availability window. | ||
- "expired": Token is expired or blacklisted. | ||
- "invalid": Token is invalid. | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thoughts on adding something here to prevent replays of refresh tokens? |
||
|
||
if payload is None or any( | ||
[ | ||
"context" not in payload, | ||
"jti" not in payload, | ||
"exp" not in payload, | ||
"strategy" not in payload, | ||
"iat" not in payload, | ||
] | ||
): | ||
return "invalid" | ||
GangGreenTemperTatum marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
now = datetime.datetime.now(datetime.timezone.utc) | ||
|
||
# Check if token is blacklisted | ||
blacklist = ( | ||
session.query(Blacklist) | ||
.filter(Blacklist.token_id == payload["jti"]) | ||
.first() | ||
) | ||
|
||
if blacklist is not None: | ||
logger.warning("JWT payload is blacklisted.") | ||
return "expired" | ||
|
||
# Check if token is expired; either we're past the expiry time or iat is more than 3 months ago | ||
GangGreenTemperTatum marked this conversation as resolved.
Show resolved
Hide resolved
|
||
expiry_datetime = datetime.datetime.fromtimestamp( | ||
payload["exp"], datetime.timezone.utc | ||
) | ||
issued_datetime = datetime.datetime.fromtimestamp( | ||
payload["iat"], datetime.timezone.utc | ||
) | ||
|
||
# if now > expiry_datetime or now > (issued_datetime + datetime.timedelta(months=JWTService.AUTH_EXPIRY_MONTHS)): | ||
if now > expiry_datetime or now > ( | ||
issued_datetime + datetime.timedelta(seconds=JWTService.AUTH_EXPIRY_SECONDS) | ||
): | ||
return "expired" | ||
|
||
# if now > (expiry_datetime - datetime.timedelta(hours=JWTService.REFRESH_AVAILABILITY_HOURS)): | ||
if now > ( | ||
expiry_datetime | ||
- datetime.timedelta(seconds=JWTService.REFRESH_AVAILABILITY_SECONDS) | ||
): | ||
return "refreshable" | ||
|
||
return "valid" |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
iiuc, this may be too small of a value? if im right in thinking, i would suggest changing the actual authentication expiry of the current
access_token
to a value that is higher? IE,access_token
validity==1hr &&refresh_token
validity==6-24hrs etci assume since this is local deployment/backend, these can be set as per the environment requirements so it may be best to go with a boilerplate approach and add some definition around the struct for easy adoption
apologies if i am missing something here :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The commented out timings here are intended for production, while the shorter times in seconds are just for testing on this branch to make sure things are expired or refreshed properly