Skip to content

Commit 2f4258b

Browse files
committed
Implement access token strategy db adapter
1 parent d6d1b0f commit 2f4258b

File tree

7 files changed

+187
-11
lines changed

7 files changed

+187
-11
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ format: isort
88

99
test:
1010
docker stop $(MONGODB_CONTAINER_NAME) || true
11-
docker run -d --rm --name $(MONGODB_CONTAINER_NAME) -p 27017:27017 mongo:4.2
11+
docker run -d --rm --name $(MONGODB_CONTAINER_NAME) -p 27017:27017 mongo:4.4
1212
pytest --cov=fastapi_users_db_mongodb/ --cov-report=term-missing --cov-fail-under=100
1313
docker stop $(MONGODB_CONTAINER_NAME)
1414

fastapi_users_db_mongodb/__init__.py

+24-8
Original file line numberDiff line numberDiff line change
@@ -29,31 +29,30 @@ def __init__(
2929
):
3030
super().__init__(user_db_model)
3131
self.collection = collection
32-
self.collection.create_index("id", unique=True)
33-
self.collection.create_index("email", unique=True)
32+
self.initialized = False
3433

3534
if email_collation:
3635
self.email_collation = email_collation # pragma: no cover
3736
else:
3837
self.email_collation = Collation("en", strength=2)
3938

40-
self.collection.create_index(
41-
"email",
42-
name="case_insensitive_email_index",
43-
collation=self.email_collation,
44-
)
45-
4639
async def get(self, id: UUID4) -> Optional[UD]:
40+
await self._initialize()
41+
4742
user = await self.collection.find_one({"id": id})
4843
return self.user_db_model(**user) if user else None
4944

5045
async def get_by_email(self, email: str) -> Optional[UD]:
46+
await self._initialize()
47+
5148
user = await self.collection.find_one(
5249
{"email": email}, collation=self.email_collation
5350
)
5451
return self.user_db_model(**user) if user else None
5552

5653
async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]:
54+
await self._initialize()
55+
5756
user = await self.collection.find_one(
5857
{
5958
"oauth_accounts.oauth_name": oauth,
@@ -63,12 +62,29 @@ async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD
6362
return self.user_db_model(**user) if user else None
6463

6564
async def create(self, user: UD) -> UD:
65+
await self._initialize()
66+
6667
await self.collection.insert_one(user.dict())
6768
return user
6869

6970
async def update(self, user: UD) -> UD:
71+
await self._initialize()
72+
7073
await self.collection.replace_one({"id": user.id}, user.dict())
7174
return user
7275

7376
async def delete(self, user: UD) -> None:
77+
await self._initialize()
78+
7479
await self.collection.delete_one({"id": user.id})
80+
81+
async def _initialize(self):
82+
if not self.initialized:
83+
await self.collection.create_index("id", unique=True)
84+
await self.collection.create_index("email", unique=True)
85+
await self.collection.create_index(
86+
"email",
87+
name="case_insensitive_email_index",
88+
collation=self.email_collation,
89+
)
90+
self.initialized = True
+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from datetime import datetime
2+
from typing import Any, Dict, Generic, Optional, Type
3+
4+
from fastapi_users.authentication.strategy.db import A, AccessTokenDatabase
5+
from motor.motor_asyncio import AsyncIOMotorCollection
6+
7+
8+
class MongoDBAccessTokenDatabase(AccessTokenDatabase, Generic[A]):
9+
"""
10+
Access token database adapter for MongoDB.
11+
12+
:param access_token_model: Pydantic model of a DB representation of an access token.
13+
:param collection: Collection instance from `motor`.
14+
"""
15+
16+
collection: AsyncIOMotorCollection
17+
18+
def __init__(self, access_token_model: Type[A], collection: AsyncIOMotorCollection):
19+
self.access_token_model = access_token_model
20+
self.collection = collection
21+
self.initialized = False
22+
23+
async def get_by_token(
24+
self, token: str, max_age: Optional[datetime] = None
25+
) -> Optional[A]:
26+
await self._initialize()
27+
28+
query: Dict[str, Any] = {"token": token}
29+
if max_age is not None:
30+
query["created_at"] = {"$gte": max_age}
31+
32+
access_token = await self.collection.find_one(query)
33+
return self.access_token_model(**access_token) if access_token else None
34+
35+
async def create(self, access_token: A) -> A:
36+
await self._initialize()
37+
38+
await self.collection.insert_one(access_token.dict())
39+
return access_token
40+
41+
async def update(self, access_token: A) -> A:
42+
await self._initialize()
43+
44+
await self.collection.replace_one(
45+
{"token": access_token.token}, access_token.dict()
46+
)
47+
return access_token
48+
49+
async def delete(self, access_token: A) -> None:
50+
await self._initialize()
51+
52+
await self.collection.delete_one({"token": access_token.token})
53+
54+
async def _initialize(self):
55+
if not self.initialized:
56+
await self.collection.create_index("token", unique=True)
57+
self.initialized = True

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ classifiers = [
2222
description-file = "README.md"
2323
requires-python = ">=3.7"
2424
requires = [
25-
"fastapi-users >= 6.1.2",
25+
"fastapi-users >= 9.1.0",
2626
"motor >=2.5.1,<3.0.0"
2727
]
2828

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
fastapi-users >= 6.1.2
1+
fastapi-users >= 9.1.0
22
motor >=2.5.1,<3.0.0

tests/test_access_token.py

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import uuid
2+
from datetime import datetime, timedelta, timezone
3+
from typing import AsyncGenerator
4+
5+
import pymongo.errors
6+
import pytest
7+
from fastapi_users.authentication.strategy.db.models import BaseAccessToken
8+
from motor.motor_asyncio import AsyncIOMotorClient
9+
from pydantic import UUID4
10+
11+
from fastapi_users_db_mongodb.access_token import MongoDBAccessTokenDatabase
12+
13+
14+
class AccessToken(BaseAccessToken):
15+
pass
16+
17+
18+
@pytest.fixture
19+
def user_id() -> UUID4:
20+
return uuid.uuid4()
21+
22+
23+
@pytest.fixture(scope="module")
24+
async def mongodb_client():
25+
client = AsyncIOMotorClient(
26+
"mongodb://localhost:27017",
27+
serverSelectionTimeoutMS=10000,
28+
uuidRepresentation="standard",
29+
)
30+
31+
try:
32+
await client.server_info()
33+
yield client
34+
client.close()
35+
except pymongo.errors.ServerSelectionTimeoutError:
36+
pytest.skip("MongoDB not available", allow_module_level=True)
37+
return
38+
39+
40+
@pytest.fixture
41+
@pytest.mark.asyncio
42+
async def mongodb_access_token_db(
43+
mongodb_client: AsyncIOMotorClient,
44+
) -> AsyncGenerator[MongoDBAccessTokenDatabase, None]:
45+
db = mongodb_client["test_database_access_token"]
46+
collection = db["access_tokens"]
47+
48+
yield MongoDBAccessTokenDatabase(AccessToken, collection)
49+
50+
await collection.delete_many({})
51+
52+
53+
@pytest.mark.asyncio
54+
@pytest.mark.db
55+
async def test_queries(
56+
mongodb_access_token_db: MongoDBAccessTokenDatabase[AccessToken],
57+
user_id: UUID4,
58+
):
59+
access_token = AccessToken(token="TOKEN", user_id=user_id)
60+
61+
# Create
62+
access_token_db = await mongodb_access_token_db.create(access_token)
63+
assert access_token_db.token == "TOKEN"
64+
assert access_token_db.user_id == user_id
65+
66+
# Update
67+
access_token_db.created_at = datetime.now(timezone.utc)
68+
await mongodb_access_token_db.update(access_token_db)
69+
70+
# Get by token
71+
access_token_by_token = await mongodb_access_token_db.get_by_token(
72+
access_token_db.token
73+
)
74+
assert access_token_by_token is not None
75+
76+
# Get by token expired
77+
access_token_by_token = await mongodb_access_token_db.get_by_token(
78+
access_token_db.token, max_age=datetime.now(timezone.utc) + timedelta(hours=1)
79+
)
80+
assert access_token_by_token is None
81+
82+
# Get by token not expired
83+
access_token_by_token = await mongodb_access_token_db.get_by_token(
84+
access_token_db.token, max_age=datetime.now(timezone.utc) - timedelta(hours=1)
85+
)
86+
assert access_token_by_token is not None
87+
88+
# Get by token unknown
89+
access_token_by_token = await mongodb_access_token_db.get_by_token(
90+
"NOT_EXISTING_TOKEN"
91+
)
92+
assert access_token_by_token is None
93+
94+
# Exception when inserting existing token
95+
with pytest.raises(pymongo.errors.DuplicateKeyError):
96+
await mongodb_access_token_db.create(access_token_db)
97+
98+
# Delete token
99+
await mongodb_access_token_db.delete(access_token_db)
100+
deleted_access_token = await mongodb_access_token_db.get_by_token(
101+
access_token_db.token
102+
)
103+
assert deleted_access_token is None
File renamed without changes.

0 commit comments

Comments
 (0)