Skip to content

Commit c1d3bed

Browse files
authored
users: add case-insensitive index to maintain backwards compatibility with fastapi-users (#1319)
follow up to #1290 Based on implementation in: https://github.com/fastapi-users/fastapi-users-db-mongodb/blob/main/fastapi_users_db_mongodb/__init__.py
1 parent 3c884f9 commit c1d3bed

File tree

3 files changed

+38
-5
lines changed

3 files changed

+38
-5
lines changed

backend/btrixcloud/users.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
)
2020

2121
from pymongo.errors import DuplicateKeyError
22+
from pymongo.collation import Collation
2223

2324
from .models import (
2425
UserCreate,
@@ -65,6 +66,8 @@ def __init__(self, mdb, email, invites):
6566
self.invites = invites
6667
self.org_ops = None
6768

69+
self.email_collation = Collation("en", strength=2)
70+
6871
self.registration_enabled = is_bool(os.environ.get("REGISTRATION_ENABLED"))
6972

7073
# pylint: disable=attribute-defined-outside-init
@@ -78,6 +81,13 @@ async def init_index(self):
7881
"""init lookup index"""
7982
await self.users.create_index("id", unique=True)
8083
await self.users.create_index("email", unique=True)
84+
85+
await self.users.create_index(
86+
"email",
87+
name="case_insensitive_email_index",
88+
collation=self.email_collation,
89+
)
90+
8191
# Expire failed logins object after one hour
8292
await self.failed_logins.create_index("attempted", expireAfterSeconds=3600)
8393

@@ -379,7 +389,9 @@ async def get_by_id(self, _id: uuid.UUID) -> Optional[User]:
379389

380390
async def get_by_email(self, email: str) -> Optional[User]:
381391
"""get user by email"""
382-
user = await self.users.find_one({"email": email})
392+
user = await self.users.find_one(
393+
{"email": email}, collation=self.email_collation
394+
)
383395
if not user:
384396
return None
385397

@@ -535,7 +547,9 @@ async def _update_password(self, user: User, new_password: str) -> None:
535547

536548
async def reset_failed_logins(self, email: str) -> None:
537549
"""Reset consecutive failed login attempts by deleting FailedLogin object"""
538-
await self.failed_logins.delete_one({"email": email})
550+
await self.failed_logins.delete_one(
551+
{"email": email}, collation=self.email_collation
552+
)
539553

540554
async def inc_failed_logins(self, email: str) -> None:
541555
"""Inc consecutive failed login attempts for user by 1
@@ -552,11 +566,14 @@ async def inc_failed_logins(self, email: str) -> None:
552566
"$inc": {"count": 1},
553567
},
554568
upsert=True,
569+
collation=self.email_collation,
555570
)
556571

557572
async def get_failed_logins_count(self, email: str) -> int:
558573
"""Get failed login attempts for user, falling back to 0"""
559-
failed_login = await self.failed_logins.find_one({"email": email})
574+
failed_login = await self.failed_logins.find_one(
575+
{"email": email}, collation=self.email_collation
576+
)
560577
if not failed_login:
561578
return 0
562579
return failed_login.get("count", 0)

backend/test/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
VIEWER_USERNAME = "[email protected]"
1616
VIEWER_PW = "viewerPASSW0RD!"
1717

18-
CRAWLER_USERNAME = "[email protected]"
18+
CRAWLER_USERNAME = "[email protected]"
19+
CRAWLER_USERNAME_LOWERCASE = "[email protected]"
1920
CRAWLER_PW = "crawlerPASSWORD!"
2021

2122
_admin_config_id = None

backend/test/test_users.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from .conftest import (
55
API_PREFIX,
66
CRAWLER_USERNAME,
7+
CRAWLER_USERNAME_LOWERCASE,
8+
CRAWLER_PW,
79
ADMIN_PW,
810
ADMIN_USERNAME,
911
FINISHED_STATES,
@@ -14,7 +16,6 @@
1416
VALID_USER_PW_RESET = "new!password"
1517
VALID_USER_PW_RESET_AGAIN = "new!password1"
1618

17-
1819
my_id = None
1920
valid_user_headers = None
2021

@@ -71,6 +72,20 @@ def test_me_id(admin_auth_headers, default_org_id):
7172
assert r.status_code == 404
7273

7374

75+
def test_login_case_insensitive_email():
76+
r = requests.post(
77+
f"{API_PREFIX}/auth/jwt/login",
78+
data={
79+
"username": CRAWLER_USERNAME_LOWERCASE,
80+
"password": CRAWLER_PW,
81+
"grant_type": "password",
82+
},
83+
)
84+
data = r.json()
85+
assert r.status_code == 200
86+
assert data["access_token"]
87+
88+
7489
def test_add_user_to_org_invalid_password(admin_auth_headers, default_org_id):
7590
r = requests.post(
7691
f"{API_PREFIX}/orgs/{default_org_id}/add-user",

0 commit comments

Comments
 (0)