Skip to content

Commit 30a7fa4

Browse files
committed
Fix user creation
1 parent ed0a71d commit 30a7fa4

File tree

2 files changed

+31
-36
lines changed

2 files changed

+31
-36
lines changed

gefapi/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666

6767

6868
from gefapi.services import UserService # noqa:E402
69+
from gefapi.models import User # noqa:E402
6970

7071

7172
@app.route("/auth", methods=["POST"])
@@ -83,6 +84,12 @@ def create_token():
8384
return jsonify({"access_token": access_token, "user_id": user.id})
8485

8586

87+
@jwt.user_lookup_loader
88+
def user_lookup_callback(_jwt_header, jwt_data):
89+
identity = jwt_data["sub"]
90+
return User.query.filter_by(id=identity).one_or_none()
91+
92+
8693
@app.errorhandler(403)
8794
def forbidden(e):
8895
return error(status=403, detail="Forbidden")

gefapi/routes/api/v1/gef_api_router.py

+24-36
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import dateutil.parser
88
from flask import Response, json, jsonify, request, send_from_directory
9-
from flask_jwt_extended import get_jwt_identity, jwt_required
9+
from flask_jwt_extended import jwt_required, current_user
1010

1111
from gefapi.errors import (
1212
EmailError,
@@ -33,18 +33,6 @@
3333
logger = logging.getLogger()
3434

3535

36-
@jwt_required()
37-
def get_identity():
38-
user = None
39-
try:
40-
id = get_jwt_identity()
41-
user = UserService.get_user(id)
42-
except Exception as e:
43-
logger.error(str(e))
44-
logger.error("[JWT]: Error getting user for %s" % (id))
45-
return user
46-
47-
4836
# SCRIPT CREATION
4937
@endpoints.route("/script", strict_slashes=False, methods=["POST"])
5038
@jwt_required()
@@ -57,7 +45,7 @@ def create_script():
5745
sent_file = request.files.get("file")
5846
if sent_file.filename == "":
5947
sent_file.filename = "script"
60-
user = get_identity()
48+
user = current_user
6149
try:
6250
user = ScriptService.create_script(sent_file, user)
6351
except InvalidFile as e:
@@ -80,7 +68,7 @@ def get_scripts():
8068
include = request.args.get("include")
8169
include = include.split(",") if include else []
8270
try:
83-
scripts = ScriptService.get_scripts(get_identity())
71+
scripts = ScriptService.get_scripts(current_user)
8472
except Exception as e:
8573
logger.error("[ROUTER]: " + str(e))
8674
return error(status=500, detail="Generic Error")
@@ -95,7 +83,7 @@ def get_script(script):
9583
include = request.args.get("include")
9684
include = include.split(",") if include else []
9785
try:
98-
script = ScriptService.get_script(script, get_identity())
86+
script = ScriptService.get_script(script, current_user)
9987
except ScriptNotFound as e:
10088
logger.error("[ROUTER]: " + e.message)
10189
return error(status=404, detail=e.message)
@@ -111,7 +99,7 @@ def publish_script(script):
11199
"""Publish a script"""
112100
logger.info("[ROUTER]: Publishing script " + script)
113101
try:
114-
script = ScriptService.publish_script(script, get_identity())
102+
script = ScriptService.publish_script(script, current_user)
115103
except ScriptNotFound as e:
116104
logger.error("[ROUTER]: " + e.message)
117105
return error(status=404, detail=e.message)
@@ -127,7 +115,7 @@ def unpublish_script(script):
127115
"""Unpublish a script"""
128116
logger.info("[ROUTER]: Unpublishsing script " + script)
129117
try:
130-
script = ScriptService.unpublish_script(script, get_identity())
118+
script = ScriptService.unpublish_script(script, current_user)
131119
except ScriptNotFound as e:
132120
logger.error("[ROUTER]: " + e.message)
133121
return error(status=404, detail=e.message)
@@ -143,7 +131,7 @@ def download_script(script):
143131
"""Download a script"""
144132
logger.info("[ROUTER]: Download script " + script)
145133
try:
146-
script = ScriptService.get_script(script, get_identity())
134+
script = ScriptService.get_script(script, current_user)
147135

148136
temp_dir = tempfile.TemporaryDirectory().name
149137
script_file = script.slug + ".tar.gz"
@@ -191,7 +179,7 @@ def update_script(script):
191179
sent_file = request.files.get("file")
192180
if sent_file.filename == "":
193181
sent_file.filename = "script"
194-
user = get_identity()
182+
user = current_user
195183
# if user.role != 'ADMIN' and user.email != '[email protected]':
196184
# return error(status=403, detail='Forbidden')
197185
try:
@@ -216,7 +204,7 @@ def update_script(script):
216204
def delete_script(script):
217205
"""Delete a script"""
218206
logger.info("[ROUTER]: Deleting script: " + script)
219-
identity = get_identity()
207+
identity = current_user
220208
if identity.role != "ADMIN" and identity.email != "[email protected]":
221209
return error(status=403, detail="Forbidden")
222210
try:
@@ -236,7 +224,7 @@ def delete_script(script):
236224
def run_script(script):
237225
"""Run a script"""
238226
logger.info("[ROUTER]: Running script: " + script)
239-
user = get_identity()
227+
user = current_user
240228
try:
241229
params = request.args.to_dict() if request.args else {}
242230
if request.get_json(silent=True):
@@ -270,9 +258,7 @@ def get_executions():
270258
exclude = request.args.get("exclude")
271259
exclude = exclude.split(",") if exclude else []
272260
try:
273-
executions = ExecutionService.get_executions(
274-
get_identity(), user_id, updated_at
275-
)
261+
executions = ExecutionService.get_executions(current_user, user_id, updated_at)
276262
except Exception as e:
277263
logger.error("[ROUTER]: " + str(e))
278264
return error(status=500, detail="Generic Error")
@@ -291,7 +277,7 @@ def get_execution(execution):
291277
exclude = request.args.get("exclude")
292278
exclude = exclude.split(",") if exclude else []
293279
try:
294-
execution = ExecutionService.get_execution(execution, get_identity())
280+
execution = ExecutionService.get_execution(execution, current_user)
295281
except ExecutionNotFound as e:
296282
logger.error("[ROUTER]: " + e.message)
297283
return error(status=404, detail=e.message)
@@ -308,7 +294,7 @@ def update_execution(execution):
308294
"""Update an execution"""
309295
logger.info("[ROUTER]: Updating execution " + execution)
310296
body = request.get_json()
311-
user = get_identity()
297+
user = current_user
312298
if user.role != "ADMIN" and user.email != "[email protected]":
313299
return error(status=403, detail="Forbidden")
314300
try:
@@ -367,7 +353,7 @@ def create_execution_log(execution):
367353
"""Create log of an execution"""
368354
logger.info("[ROUTER]: Creating execution log for " + execution)
369355
body = request.get_json()
370-
user = get_identity()
356+
user = current_user
371357
if user.role != "ADMIN" and user.email != "[email protected]":
372358
return error(status=403, detail="Forbidden")
373359
try:
@@ -389,13 +375,15 @@ def create_user():
389375
logger.info("[ROUTER]: Creating user")
390376
body = request.get_json()
391377
if request.headers.get("Authorization", None) is not None:
378+
logger.debug("[ROUTER]: Authorization header found")
392379

393380
@jwt_required()
394381
def identity():
395382
pass
396383

397384
identity()
398-
identity = get_identity()
385+
logger.debug("[ROUTER]: Getting identity")
386+
identity = current_user
399387
if identity:
400388
user_role = body.get("role", "USER")
401389
if identity.role == "USER" and user_role == "ADMIN":
@@ -420,7 +408,7 @@ def get_users():
420408
logger.info("[ROUTER]: Getting all users")
421409
include = request.args.get("include")
422410
include = include.split(",") if include else []
423-
identity = get_identity()
411+
identity = current_user
424412
if identity.role != "ADMIN" and identity.email != "[email protected]":
425413
return error(status=403, detail="Forbidden")
426414
try:
@@ -438,7 +426,7 @@ def get_user(user):
438426
logger.info("[ROUTER]: Getting user" + user)
439427
include = request.args.get("include")
440428
include = include.split(",") if include else []
441-
identity = get_identity()
429+
identity = current_user
442430
if identity.role != "ADMIN" and identity.email != "[email protected]":
443431
return error(status=403, detail="Forbidden")
444432
try:
@@ -457,7 +445,7 @@ def get_user(user):
457445
def get_me():
458446
"""Get me"""
459447
logger.info("[ROUTER]: Getting my user")
460-
user = get_identity()
448+
user = current_user
461449
return jsonify(data=user.serialize()), 200
462450

463451

@@ -467,7 +455,7 @@ def update_profile():
467455
"""Update an user"""
468456
logger.info("[ROUTER]: Updating profile")
469457
body = request.get_json()
470-
identity = get_identity()
458+
identity = current_user
471459
try:
472460
password = body.get("password", None)
473461
repeat_password = body.get("repeatPassword", None)
@@ -501,7 +489,7 @@ def update_profile():
501489
def delete_profile():
502490
"""Delete Me"""
503491
logger.info("[ROUTER]: Delete me")
504-
identity = get_identity()
492+
identity = current_user
505493
try:
506494
user = UserService.delete_user(str(identity.id))
507495
except UserNotFound as e:
@@ -540,7 +528,7 @@ def update_user(user):
540528
"""Update an user"""
541529
logger.info("[ROUTER]: Updating user" + user)
542530
body = request.get_json()
543-
identity = get_identity()
531+
identity = current_user
544532
if identity.role != "ADMIN" and identity.email != "[email protected]":
545533
return error(status=403, detail="Forbidden")
546534
try:
@@ -559,7 +547,7 @@ def update_user(user):
559547
def delete_user(user):
560548
"""Delete an user"""
561549
logger.info("[ROUTER]: Deleting user" + user)
562-
identity = get_identity()
550+
identity = current_user
563551
if user == "[email protected]":
564552
return error(status=403, detail="Forbidden")
565553
if identity.role != "ADMIN" and identity.email != "[email protected]":

0 commit comments

Comments
 (0)