Skip to content

Commit

Permalink
Merge pull request #20 from shenwpo/master
Browse files Browse the repository at this point in the history
add init_app method
  • Loading branch information
jessecooper authored Mar 8, 2021
2 parents a6abbc9 + 80504b0 commit ed96f06
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 4 deletions.
16 changes: 12 additions & 4 deletions flask_authz/casbin_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,26 @@ class CasbinEnforcer:

e = None

def __init__(self, app, adapter, watcher=None):
def __init__(self, app=None, adapter=None, watcher=None):
"""
Args:
app (object): Flask App object to get Casbin Model
adapter (object): Casbin Adapter
"""
self.app = app
self.adapter = adapter
self.e = casbin.Enforcer(app.config.get("CASBIN_MODEL"), self.adapter)
if watcher:
self.e.set_watcher(watcher)
self.e = None
self.watcher = watcher
self._owner_loader = None
self.user_name_headers = None
if self.app is not None:
self.init_app(self.app)

def init_app(self, app):
self.app = app
self.e = casbin.Enforcer(app.config.get("CASBIN_MODEL"), self.adapter)
if self.watcher:
self.e.set_watcher(self.watcher)
self.user_name_headers = app.config.get("CASBIN_USER_NAME_HEADERS", None)

def set_watcher(self, watcher):
Expand Down
207 changes: 207 additions & 0 deletions tests/test_casbin_enforcer_init_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
import pytest
from casbin.enforcer import Enforcer
from flask import request, jsonify
from casbin_sqlalchemy_adapter import Adapter
from casbin_sqlalchemy_adapter import Base
from casbin_sqlalchemy_adapter import CasbinRule
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from flask_authz import CasbinEnforcer


def enforcer_partial():
engine = create_engine("sqlite://")
adapter = Adapter(engine)

session = sessionmaker(bind=engine)
Base.metadata.create_all(engine)
s = session()
s.query(CasbinRule).delete()
s.add(CasbinRule(ptype="p", v0="alice", v1="/item", v2="GET"))
s.add(CasbinRule(ptype="p", v0="bob", v1="/item", v2="GET"))
s.add(CasbinRule(ptype="p", v0="data2_admin", v1="/item", v2="POST"))
s.add(CasbinRule(ptype="p", v0="data2_admin", v1="/item", v2="DELETE"))
s.add(CasbinRule(ptype="p", v0="data2_admin", v1="/item", v2="GET"))
s.add(CasbinRule(ptype="g", v0="alice", v1="data2_admin"))
s.add(CasbinRule(ptype="g", v0="users", v1="data2_admin"))
s.commit()
s.close()

return CasbinEnforcer(adapter=adapter)


@pytest.fixture
def enforcer(app_fixture):
e = enforcer_partial()
e.init_app(app_fixture)
yield e


@pytest.fixture
def watcher():
class SomeWatcher:
def should_reload(self):
return True

def update_callback(self):
pass

yield SomeWatcher


@pytest.mark.parametrize(
"header, user, method, status, user_name",
[
("X-User", "alice", "GET", 200, "X-User"),
("X-USER", "alice", "GET", 200, "x-user"),
("x-user", "alice", "GET", 200, "X-USER"),
("X-User", "alice", "GET", 200, "X-USER"),
("X-User", "alice", "GET", 200, "X-Not-A-Header"),
("X-User", "alice", "POST", 201, None),
("X-User", "alice", "DELETE", 202, None),
("X-User", "bob", "GET", 200, None),
("X-User", "bob", "POST", 401, None),
("X-User", "bob", "DELETE", 401, None),
("X-Idp-Groups", "admin", "GET", 401, "X-User"),
("X-Idp-Groups", "users", "GET", 200, None),
("X-Idp-Groups", "noexist,testnoexist,users", "GET", 200, None),
("X-Idp-Groups", "noexist testnoexist users", "GET", 200, None),
("X-Idp-Groups", "noexist, testnoexist, users", "GET", 200, None),
("Authorization", "Basic Ym9iOnBhc3N3b3Jk", "GET", 200, "Authorization"),
("Authorization", "Unsupported Ym9iOnBhc3N3b3Jk", "GET", 401, None),
],
)
def test_enforcer(app_fixture, enforcer, header, user, method, status, user_name):
# enable auditing with user name
if user_name:
enforcer.user_name_headers = {user_name}

@app_fixture.route("/")
@enforcer.enforcer
def index():
return jsonify({"message": "passed"}), 200

@app_fixture.route("/item", methods=["GET", "POST", "DELETE"])
@enforcer.enforcer
def item():
if request.method == "GET":
return jsonify({"message": "passed"}), 200
elif request.method == "POST":
return jsonify({"message": "passed"}), 201
elif request.method == "DELETE":
return jsonify({"message": "passed"}), 202

headers = {header: user}
c = app_fixture.test_client()
# c.post('/add', data=dict(title='2nd Item', text='The text'))
rv = c.get("/")
assert rv.status_code == 401
caller = getattr(c, method.lower())
rv = caller("/item", headers=headers)
assert rv.status_code == status


@pytest.mark.parametrize(
"header, user, method, status",
[
("X-User", "alice", "GET", 200),
("X-User", "alice", "POST", 201),
("X-User", "alice", "DELETE", 202),
("X-User", "bob", "GET", 200),
("X-User", "bob", "POST", 401),
("X-User", "bob", "DELETE", 401),
("X-Idp-Groups", "admin", "GET", 401),
("X-Idp-Groups", "users", "GET", 200),
("Authorization", "Basic Ym9iOnBhc3N3b3Jk", "GET", 200),
("Authorization", "Unsupported Ym9iOnBhc3N3b3Jk", "GET", 401),
],
)
def test_enforcer_with_watcher(
app_fixture, enforcer, header, user, method, status, watcher
):
enforcer.set_watcher(watcher())

@app_fixture.route("/")
@enforcer.enforcer
def index():
return jsonify({"message": "passed"}), 200

@app_fixture.route("/item", methods=["GET", "POST", "DELETE"])
@enforcer.enforcer
def item():
if request.method == "GET":
return jsonify({"message": "passed"}), 200
elif request.method == "POST":
return jsonify({"message": "passed"}), 201
elif request.method == "DELETE":
return jsonify({"message": "passed"}), 202

headers = {header: user}
c = app_fixture.test_client()
# c.post('/add', data=dict(title='2nd Item', text='The text'))
rv = c.get("/")
assert rv.status_code == 401
caller = getattr(c, method.lower())
rv = caller("/item", headers=headers)
assert rv.status_code == status


def test_manager(app_fixture, enforcer):
@app_fixture.route("/manager", methods=["POST"])
@enforcer.manager
def manager(manager):
assert isinstance(manager, Enforcer)
return jsonify({"message": "passed"}), 201

c = app_fixture.test_client()
c.post("/manager")


def test_enforcer_set_watcher(enforcer, watcher):
assert enforcer.e.watcher is None
enforcer.set_watcher(watcher())
assert isinstance(enforcer.e.watcher, watcher)


@pytest.mark.parametrize(
"owner, method, status",
[
(["alice"], "GET", 200),
(["alice"], "POST", 201),
(["alice"], "DELETE", 202),
(["bob"], "GET", 200),
(["bob"], "POST", 401),
(["bob"], "DELETE", 401),
(["admin"], "GET", 401),
(["users"], "GET", 200),
(["alice", "bob"], "POST", 201),
(["noexist", "testnoexist"], "POST", 401),
],
)
def test_enforcer_with_owner_loader(app_fixture, enforcer, owner, method, status):
@app_fixture.route("/")
@enforcer.enforcer
def index():
return jsonify({"message": "passed"}), 200

@app_fixture.route("/item", methods=["GET", "POST", "DELETE"])
@enforcer.enforcer
def item():
if request.method == "GET":
return jsonify({"message": "passed"}), 200
elif request.method == "POST":
return jsonify({"message": "passed"}), 201
elif request.method == "DELETE":
return jsonify({"message": "passed"}), 202

@enforcer.owner_loader
def owner_loader():
return owner

c = app_fixture.test_client()
# c.post('/add', data=dict(title='2nd Item', text='The text'))
rv = c.get("/")
assert rv.status_code == 401
caller = getattr(c, method.lower())
rv = caller("/item")
assert rv.status_code == status

0 comments on commit ed96f06

Please sign in to comment.