diff --git a/flask_authz/casbin_enforcer.py b/flask_authz/casbin_enforcer.py index c807574..0386ca3 100644 --- a/flask_authz/casbin_enforcer.py +++ b/flask_authz/casbin_enforcer.py @@ -16,7 +16,7 @@ class CasbinEnforcer: e = None - def __init__(self, app, adapter, watcher=None): + def __init__(self, app=None, adapter=None, watcher=None): """ Args: app (object): Flask App object to get Casbin Model @@ -24,10 +24,18 @@ def __init__(self, app, adapter, watcher=None): """ self.app = app self.adapter = adapter - self.e = casbin.Enforcer(app.config.get("CASBIN_MODEL"), self.adapter) - if watcher: - self.e.set_watcher(watcher) + self.e = None + self.watcher = watcher self._owner_loader = None + self.user_name_headers = None + if self.app is not None: + self.init_app(self.app) + + def init_app(self, app): + self.app = app + self.e = casbin.Enforcer(app.config.get("CASBIN_MODEL"), self.adapter) + if self.watcher: + self.e.set_watcher(self.watcher) self.user_name_headers = app.config.get("CASBIN_USER_NAME_HEADERS", None) def set_watcher(self, watcher): diff --git a/tests/test_casbin_enforcer_init_app.py b/tests/test_casbin_enforcer_init_app.py new file mode 100644 index 0000000..321f2c9 --- /dev/null +++ b/tests/test_casbin_enforcer_init_app.py @@ -0,0 +1,207 @@ +import pytest +from casbin.enforcer import Enforcer +from flask import request, jsonify +from casbin_sqlalchemy_adapter import Adapter +from casbin_sqlalchemy_adapter import Base +from casbin_sqlalchemy_adapter import CasbinRule +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from flask_authz import CasbinEnforcer + + +def enforcer_partial(): + engine = create_engine("sqlite://") + adapter = Adapter(engine) + + session = sessionmaker(bind=engine) + Base.metadata.create_all(engine) + s = session() + s.query(CasbinRule).delete() + s.add(CasbinRule(ptype="p", v0="alice", v1="/item", v2="GET")) + s.add(CasbinRule(ptype="p", v0="bob", v1="/item", v2="GET")) + s.add(CasbinRule(ptype="p", v0="data2_admin", v1="/item", v2="POST")) + s.add(CasbinRule(ptype="p", v0="data2_admin", v1="/item", v2="DELETE")) + s.add(CasbinRule(ptype="p", v0="data2_admin", v1="/item", v2="GET")) + s.add(CasbinRule(ptype="g", v0="alice", v1="data2_admin")) + s.add(CasbinRule(ptype="g", v0="users", v1="data2_admin")) + s.commit() + s.close() + + return CasbinEnforcer(adapter=adapter) + + +@pytest.fixture +def enforcer(app_fixture): + e = enforcer_partial() + e.init_app(app_fixture) + yield e + + +@pytest.fixture +def watcher(): + class SomeWatcher: + def should_reload(self): + return True + + def update_callback(self): + pass + + yield SomeWatcher + + +@pytest.mark.parametrize( + "header, user, method, status, user_name", + [ + ("X-User", "alice", "GET", 200, "X-User"), + ("X-USER", "alice", "GET", 200, "x-user"), + ("x-user", "alice", "GET", 200, "X-USER"), + ("X-User", "alice", "GET", 200, "X-USER"), + ("X-User", "alice", "GET", 200, "X-Not-A-Header"), + ("X-User", "alice", "POST", 201, None), + ("X-User", "alice", "DELETE", 202, None), + ("X-User", "bob", "GET", 200, None), + ("X-User", "bob", "POST", 401, None), + ("X-User", "bob", "DELETE", 401, None), + ("X-Idp-Groups", "admin", "GET", 401, "X-User"), + ("X-Idp-Groups", "users", "GET", 200, None), + ("X-Idp-Groups", "noexist,testnoexist,users", "GET", 200, None), + ("X-Idp-Groups", "noexist testnoexist users", "GET", 200, None), + ("X-Idp-Groups", "noexist, testnoexist, users", "GET", 200, None), + ("Authorization", "Basic Ym9iOnBhc3N3b3Jk", "GET", 200, "Authorization"), + ("Authorization", "Unsupported Ym9iOnBhc3N3b3Jk", "GET", 401, None), + ], +) +def test_enforcer(app_fixture, enforcer, header, user, method, status, user_name): + # enable auditing with user name + if user_name: + enforcer.user_name_headers = {user_name} + + @app_fixture.route("/") + @enforcer.enforcer + def index(): + return jsonify({"message": "passed"}), 200 + + @app_fixture.route("/item", methods=["GET", "POST", "DELETE"]) + @enforcer.enforcer + def item(): + if request.method == "GET": + return jsonify({"message": "passed"}), 200 + elif request.method == "POST": + return jsonify({"message": "passed"}), 201 + elif request.method == "DELETE": + return jsonify({"message": "passed"}), 202 + + headers = {header: user} + c = app_fixture.test_client() + # c.post('/add', data=dict(title='2nd Item', text='The text')) + rv = c.get("/") + assert rv.status_code == 401 + caller = getattr(c, method.lower()) + rv = caller("/item", headers=headers) + assert rv.status_code == status + + +@pytest.mark.parametrize( + "header, user, method, status", + [ + ("X-User", "alice", "GET", 200), + ("X-User", "alice", "POST", 201), + ("X-User", "alice", "DELETE", 202), + ("X-User", "bob", "GET", 200), + ("X-User", "bob", "POST", 401), + ("X-User", "bob", "DELETE", 401), + ("X-Idp-Groups", "admin", "GET", 401), + ("X-Idp-Groups", "users", "GET", 200), + ("Authorization", "Basic Ym9iOnBhc3N3b3Jk", "GET", 200), + ("Authorization", "Unsupported Ym9iOnBhc3N3b3Jk", "GET", 401), + ], +) +def test_enforcer_with_watcher( + app_fixture, enforcer, header, user, method, status, watcher +): + enforcer.set_watcher(watcher()) + + @app_fixture.route("/") + @enforcer.enforcer + def index(): + return jsonify({"message": "passed"}), 200 + + @app_fixture.route("/item", methods=["GET", "POST", "DELETE"]) + @enforcer.enforcer + def item(): + if request.method == "GET": + return jsonify({"message": "passed"}), 200 + elif request.method == "POST": + return jsonify({"message": "passed"}), 201 + elif request.method == "DELETE": + return jsonify({"message": "passed"}), 202 + + headers = {header: user} + c = app_fixture.test_client() + # c.post('/add', data=dict(title='2nd Item', text='The text')) + rv = c.get("/") + assert rv.status_code == 401 + caller = getattr(c, method.lower()) + rv = caller("/item", headers=headers) + assert rv.status_code == status + + +def test_manager(app_fixture, enforcer): + @app_fixture.route("/manager", methods=["POST"]) + @enforcer.manager + def manager(manager): + assert isinstance(manager, Enforcer) + return jsonify({"message": "passed"}), 201 + + c = app_fixture.test_client() + c.post("/manager") + + +def test_enforcer_set_watcher(enforcer, watcher): + assert enforcer.e.watcher is None + enforcer.set_watcher(watcher()) + assert isinstance(enforcer.e.watcher, watcher) + + +@pytest.mark.parametrize( + "owner, method, status", + [ + (["alice"], "GET", 200), + (["alice"], "POST", 201), + (["alice"], "DELETE", 202), + (["bob"], "GET", 200), + (["bob"], "POST", 401), + (["bob"], "DELETE", 401), + (["admin"], "GET", 401), + (["users"], "GET", 200), + (["alice", "bob"], "POST", 201), + (["noexist", "testnoexist"], "POST", 401), + ], +) +def test_enforcer_with_owner_loader(app_fixture, enforcer, owner, method, status): + @app_fixture.route("/") + @enforcer.enforcer + def index(): + return jsonify({"message": "passed"}), 200 + + @app_fixture.route("/item", methods=["GET", "POST", "DELETE"]) + @enforcer.enforcer + def item(): + if request.method == "GET": + return jsonify({"message": "passed"}), 200 + elif request.method == "POST": + return jsonify({"message": "passed"}), 201 + elif request.method == "DELETE": + return jsonify({"message": "passed"}), 202 + + @enforcer.owner_loader + def owner_loader(): + return owner + + c = app_fixture.test_client() + # c.post('/add', data=dict(title='2nd Item', text='The text')) + rv = c.get("/") + assert rv.status_code == 401 + caller = getattr(c, method.lower()) + rv = caller("/item") + assert rv.status_code == status