Skip to content

Commit

Permalink
#31 #8 add login routine via JWT
Browse files Browse the repository at this point in the history
  • Loading branch information
mrwunderbar666 committed Sep 12, 2023
1 parent 84268ac commit 8cca4c9
Show file tree
Hide file tree
Showing 9 changed files with 195 additions and 22 deletions.
6 changes: 5 additions & 1 deletion flaskinventory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#### Load Extensions ####
# Login Extension
from flask_login import LoginManager, AnonymousUserMixin

# E-Mail Extension
from flask_mail import Mail
# Forms Extension
Expand All @@ -29,7 +30,6 @@
class AnonymousUser(AnonymousUserMixin):
_role = 0


login_manager = LoginManager()
login_manager.login_view = 'users.login'
login_manager.login_message_category = 'info'
Expand All @@ -40,6 +40,8 @@ class AnonymousUser(AnonymousUserMixin):

limiter = Limiter(key_func=get_remote_address)

# JWT Extension
from flaskinventory.users.authentication import jwt

def create_app(config_class=Config, config_json=None):
# assert versions
Expand Down Expand Up @@ -91,6 +93,8 @@ def create_app(config_class=Config, config_json=None):
app.register_blueprint(main)
app.register_blueprint(errors)

jwt.init_app(app)

dgraph.init_app(app)
login_manager.init_app(app)
mail.init_app(app)
Expand Down
3 changes: 2 additions & 1 deletion flaskinventory/api/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,6 @@

LoginToken = typing.TypedDict('LoginToken', {
"status": str,
"token": str
"access_token": str,
"refresh_token": str
})
143 changes: 126 additions & 17 deletions flaskinventory/api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ class API(Blueprint):

routes = {}

_jwt_required_kwargs = ['optional', 'fresh', 'refresh', 'locations',
'verify_type', 'skip_revocation_check']

@staticmethod
def abort(status_code, message=None):
response = jsonify({
Expand Down Expand Up @@ -304,7 +307,7 @@ def logic(*args, **kw):
return f(**params)
return logic

def route(self, rule: str, **options: t.Any) -> t.Callable[[F], F]:
def route(self, rule: str, authentication: bool = False, **options: t.Any) -> t.Callable[[F], F]:
""" Custom extension of Flask default routing / rule creation
This decorator extract function arguments and details and
stores it the blueprint class (the dict "routes")
Expand All @@ -313,6 +316,13 @@ def route(self, rule: str, **options: t.Any) -> t.Callable[[F], F]:
"""

methods = options.get('methods', ['GET'])
jwt_kwargs = {}
for k in self._jwt_required_kwargs:
try:
v = options.pop(k)
jwt_kwargs[k] = v
except:
continue

def decorator(f: F) -> F:
""" Custom extension """
Expand All @@ -327,10 +337,7 @@ def decorator(f: F) -> F:
for v in self.REGEX_RULE_PATH_PARAM.findall(rule):
path_parameters.append(v)

# print('*'* 80)
# print(f.__name__)
sig = inspect.signature(f)
# print(sig.return_annotation)
for arg, par in sig.parameters.items():

# if NoneType in par.annotation -> optional
Expand Down Expand Up @@ -359,6 +366,12 @@ def decorator(f: F) -> F:
self.routes[rule]['path'] = re.sub(r'<(?P<converter>[a-zA-Z_][a-zA-Z0-9_]*\:).*?>', '', rule).replace('<', '{').replace('>', '}')
self.routes[rule]['responses'] = sig.return_annotation

# Check if we need authentication
if authentication:
self.routes[rule]['security'] = [{'BearerAuth': []}]
_wrapper = jwtx.jwt_required(**jwt_kwargs)
f = _wrapper(f)

# Final step: apply @query_params to all routing funtions
# This way, every function has request arguments handled
# as keyword arguments
Expand Down Expand Up @@ -409,6 +422,11 @@ def schema() -> dict:
}
open_api['components'] = Schema.provide_types()
open_api['components']['parameters'] = Schema.provide_queryable_predicates()
open_api['components']['securitySchemes'] = {'BearerAuth': {
'type': 'http',
'scheme': 'bearer',
'bearerFormat': 'JWT'}
}
open_api['paths'] = {}

query_params_references = [{'$ref': '#/components/parameters/' + k} for k in open_api['components']['parameters'].keys()]
Expand Down Expand Up @@ -508,6 +526,10 @@ def schema() -> dict:
post_params = [details['parameters'][p] for p in details['request_body_params']]
content = api.annotation_to_request_params(post_params)
open_api['paths'][path][method.lower()]['requestBody']['content'] = content

if 'security' in details:
open_api['paths'][path][method.lower()]['security'] = details['security']

# for post_param in details['request_body_params']:
# post_param_type = PATH_TYPES[details['parameters'][post_param].annotation]
# p_param_val_pair = {post_param: {'type': post_param_type}}
Expand Down Expand Up @@ -1287,17 +1309,96 @@ def leave_comment(uid: str, message: str) -> SuccessfulAPIOperation:
""" User Related """

from flaskinventory.api.responses import LoginToken
import flask_jwt_extended as jwtx


@api.route('/user/login', methods=['POST'])
def login(email: str, password: str) -> LoginToken:
""" login to account, get a token back """
return api.abort(501)
"""
login to account, get a session cookie back
This route provides a JWT as a session cookie. It is
recommended method for the login routine, because
it allows Meteor to refresh tokens automatically.
"""

user = User.login(email, password)
if not user:
return api.abort(401, message="Wrong credentials. Make sure you have an account.")

response = jsonify({"message": "login successful"})
access_token = jwtx.create_access_token(identity=user)
jwtx.set_access_cookies(response, access_token)
return response

@api.route('/user/login/token', methods=['POST'])
def login_token(email: str, password: str) -> LoginToken:
""" login to account, get a JWT token back """

user = User.login(email, password)
if not user:
return api.abort(401, message="Wrong credentials. Make sure you have an account.")

access_token = jwtx.create_access_token(identity=user)
refresh_token = jwtx.create_refresh_token(identity=user)
return jsonify(access_token=access_token,
refresh_token=refresh_token,
status=200)

@api.route('/user/is_logged_in', authentication=True, optional=True)
def is_logged_in() -> SuccessfulAPIOperation:
if jwtx.get_jwt_identity():
return jsonify({'status': 200, 'message': 'Logged in',
'is_logged_in': True})
else:
return jsonify({'status': 200, 'message': 'Not logged in',
'is_logged_in': False})


@api.route('/user/login/refresh', methods=['POST'], authentication=True, refresh=True)
def refresh_token() -> LoginToken:
""" login to account, get a JWT token back """

identity = jwtx.get_jwt_identity()
access_token = jwtx.create_access_token(identity=identity)
return jsonify(access_token=access_token, status=200)

@api.after_app_request
def refresh_expiring_jwts(response):
""" Automatically handle refreshing of JWT (stored as session cookies) """
try:
exp_timestamp = jwtx.get_jwt()["exp"]
now = datetime.datetime.now(datetime.timezone.utc)
target_timestamp = datetime.datetime.timestamp(now + datetime.timedelta(minutes=30))
if target_timestamp > exp_timestamp:
access_token = jwtx.create_access_token(identity=jwtx.get_jwt_identity())
jwtx.set_access_cookies(response, access_token)
current_app.logger.debug(f'Refreshed Token for user <{jwtx.get_jwt_identity()}>')
return response
except (RuntimeError, KeyError):
# Case where there is not a valid JWT. Just return the original response
return response

@api.route('/user/logout')

@api.route('/user/logout', authentication=True, verify_type=False)
def logout() -> SuccessfulAPIOperation:
""" logout; invalidates token """
return api.abort(501)
"""
logout; invalidates session cookie and revokes current JWT
Ensure to call the API route twice to also invalidate the refresh token!
Implementation Details see: https://flask-jwt-extended.readthedocs.io/en/stable/blocklist_and_token_revoking.html
"""

response = jsonify({"message": "logout successful"})
token = jwtx.get_jwt()
dgraph.mutation({'uid': '_:jwt',
'dgraph.type': '_JWT',
'_jti': token["jti"],
'_token_type': token['type'],
'_revoked_timestamp': datetime.datetime.now().isoformat()})
jwtx.unset_jwt_cookies(response)
return response


@api.route('/user/register', methods=['POST'])
Expand Down Expand Up @@ -1332,18 +1433,26 @@ def change_password(old_pw: str, new_pw: str) -> SuccessfulAPIOperation:
return api.abort(501)


@api.route('/user/profile')
@api.route('/user/profile', authentication=True)
def profile() -> User:
""" view current user's profile """
return api.abort(501)
return jsonify(jwtx.current_user.json)

@api.route('/user/profile/update', methods=['POST'])
def update_profile(display_name: str,
affiliation: str,
orcid: str,
notifications: bool) -> SuccessfulAPIOperation:
@api.route('/user/profile/update', methods=['POST'], authentication=True)
def update_profile(display_name: str = None,
affiliation: str = None,
orcid: str = None,
notifications: bool = None) -> SuccessfulAPIOperation:
""" update current user's profile """
return api.abort(501)
try:
jwtx.current_user.update_profile({'display_name': display_name,
'affiliation': affiliation,
'orcid': orcid,
'notifications': notifications})
return jsonify({'status': 200,
'message': 'Profile updated'})
except Exception as e:
return api.abort(500, message=f'{e}')


@api.route('/user/profile/delete', methods=['POST'])
Expand Down
2 changes: 2 additions & 0 deletions flaskinventory/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ class Config:
TESTING = False
SECRET_KEY = os.environ.get(
'flaskinventory_SECRETKEY', secrets.token_hex(32))
JWT_SECRET_KEY = SECRET_KEY
JWT_TOKEN_LOCATION = ["headers", "cookies"]
DEBUG_MODE = os.environ.get('DEBUG_MODE', False)
MAIL_SERVER = os.environ.get('EMAIL_SERVER', 'localhost')
MAIL_PORT = os.environ.get('EMAIL_PORT', 25)
Expand Down
2 changes: 2 additions & 0 deletions flaskinventory/flaskdgraph/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,8 @@ def provide_types(cls) -> Iterable[dict]:
for t in Schema.__types_meta__:
# if Schema.__types_meta__[t]['private']:
# continue
if t.startswith('_'):
continue
schemas[t] = {'type': 'object',
'x-private': Schema.__types_meta__[t]['private'],
'properties': {}}
Expand Down
13 changes: 13 additions & 0 deletions flaskinventory/main/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1439,3 +1439,16 @@ class Rejected(Schema):
_former_types = ListString(edit=False)

entry_review_status = String(default="rejected")

class _JWT(Schema):

""" Block List of revoked tokens """

__permission_new__ = 99
__permission_edit__ = 99
__private__ = True

uid = UIDPredicate()
_jti = String(directives=['@index(hash)'])
_token_type = String(directives=['@index(hash)'])
_revoked_timestamp = DateTime()
34 changes: 34 additions & 0 deletions flaskinventory/users/authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import flask_jwt_extended as jwtx
from flask import current_app
from flaskinventory.users.dgraph import UserLogin
from flaskinventory.main.model import User
from flaskinventory import dgraph

jwt = jwtx.JWTManager()


# Register a callback function that takes whatever object is passed in as the
# identity when creating JWTs and converts it to a JSON serializable format.
@jwt.user_identity_loader
def user_identity_lookup(user):
try:
return user.id
except:
return user


# Register a callback function that loads a user from your database whenever
# a protected route is accessed. This should return any python object on a
# successful lookup, or None if the lookup failed for any reason (for example
# if the user has been deleted from the database).
@jwt.user_lookup_loader
def user_lookup_callback(_jwt_header, jwt_data):
identity = jwt_data["sub"]
user = User(uid=identity)
return user

@jwt.token_in_blocklist_loader
def check_if_token_is_revoked(jwt_header, jwt_payload: dict):
jti = jwt_payload["jti"]
token_in_redis = dgraph.get_uid(field="_jti", value=jti)
return token_in_redis is not None
11 changes: 9 additions & 2 deletions flaskinventory/users/dgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def generate_random_username() -> str:
class UserLogin(UserMixin):

id = None
json = {}

# Need more consistent ORM syntax
# Currently load users like:
Expand All @@ -39,14 +40,15 @@ def __repr__(self):
def get_user(self, **kwargs):
user_data = self.get_user_data(**kwargs)
if user_data:
self.json = user_data
for k, v in user_data.items():
if k == 'uid':
self.id = v
if '|' in k:
k = k.replace('|', '_')
setattr(self, k, v)
# Overwrite DGraph Predicates
# Maye find a more elegant solution later
# Maybe find a more elegant solution later
for attr in dir(self):
if isinstance(getattr(self, attr), _PrimitivePredicate):
setattr(self, attr, None)
Expand Down Expand Up @@ -145,9 +147,14 @@ def verify_email_token(cls, token: str) -> Union[bool, Any]:

def update_profile(self, form_data: dict) -> bool:
user_data = {}
for k, v in form_data.data.items():
try:
data = form_data.data
except:
data = form_data
for k, v in data.items():
if k in ['submit', 'csrf_token']:
continue
if v is None: continue
else:
user_data[k] = v
result = dgraph.update_entry(user_data, uid=self.id)
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,5 @@ pandas
openpyxl
pyarrow
tqdm
thefuzz
thefuzz
flask-jwt-extended

0 comments on commit 8cca4c9

Please sign in to comment.