From f137d24781e9ca277a435f78ce3547335dd87288 Mon Sep 17 00:00:00 2001 From: abdurrehman11 Date: Fri, 3 May 2024 03:57:01 +0500 Subject: [PATCH 1/2] fixed trunk issues in conftest --- app/conftest.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/app/conftest.py b/app/conftest.py index 2fcab8d..35961ae 100644 --- a/app/conftest.py +++ b/app/conftest.py @@ -77,7 +77,10 @@ def user_mocked_info(db_mocked_app_client): "password": "pass_mock", } response = db_mocked_app_client.post(url="user/register", json=user_info) - assert 201 == response.status_code + if response.status_code != 201: + raise AssertionError( + f"Expected status code 201, but got {response.status_code}" + ) # Login response = db_mocked_app_client.post( @@ -88,7 +91,10 @@ def user_mocked_info(db_mocked_app_client): }, ) - assert 200 == response.status_code + if response.status_code != 200: + raise AssertionError( + f"Expected status code 200, but got {response.status_code}" + ) user_info["access_token"] = response.json()["access_token"] user_info["refresh_token"] = response.json()["refresh_token"] From 8ce4fa1693e99695a276131dedcce398ed3b746b Mon Sep 17 00:00:00 2001 From: abdurrehman11 Date: Sat, 4 May 2024 21:14:38 +0500 Subject: [PATCH 2/2] fixed trunk issues --- app/core/admin/auth.py | 2 + app/core/admin/models.py | 3 +- app/core/db/migrations/env.py | 1 - app/core/db/session.py | 26 ++++++--- app/core/models/base.py | 5 +- app/core/test/admin/test_auth.py | 21 +++++-- app/core/test/auth/test_functions.py | 42 ++++++++++---- app/core/test/db/test_engine.py | 3 +- app/core/test/db/test_session.py | 12 ++-- app/core/test/models/test_base.py | 40 +++++++------ app/core/test/test_main.py | 21 ++++--- app/services/user/routes.py | 17 +++--- app/services/user/test/test_models.py | 3 +- app/services/user/test/test_routes.py | 82 ++++++++++++++++++++++----- generate_client/sdk_client_script.py | 2 +- 15 files changed, 195 insertions(+), 85 deletions(-) diff --git a/app/core/admin/auth.py b/app/core/admin/auth.py index 06bb170..cfa40b7 100644 --- a/app/core/admin/auth.py +++ b/app/core/admin/auth.py @@ -39,3 +39,5 @@ async def authenticate(self, request: Request) -> RedirectResponse | None: ) request.session.update({"username": username}) + + return None diff --git a/app/core/admin/models.py b/app/core/admin/models.py index 3a9f724..1c5e9e4 100644 --- a/app/core/admin/models.py +++ b/app/core/admin/models.py @@ -4,6 +4,7 @@ from sqladmin._queries import Query from starlette.requests import Request +from app.core.models.base import ModelCore from app.core.models.record import Record source = "ADMIN" @@ -52,7 +53,7 @@ async def update_model(self, request: Request, pk: str, data: dict) -> Any: async def delete_model(self, request: Request, pk: Any) -> None: """Overwrite default delete method for safe_delete/reverse_delete methods""" - model: object = await self.get_object_for_edit(pk) + model: ModelCore = await self.get_object_for_edit(pk) if model.deleted: deleted = False action_description = "REVERSE_DELETE" diff --git a/app/core/db/migrations/env.py b/app/core/db/migrations/env.py index 3c43dae..1c54bf4 100644 --- a/app/core/db/migrations/env.py +++ b/app/core/db/migrations/env.py @@ -18,7 +18,6 @@ # add your model's MetaData object here # for 'autogenerate' support -from app.core.db.migrations.models import * meta = MetaData() for table in SQLModel.metadata.tables.values(): diff --git a/app/core/db/session.py b/app/core/db/session.py index 9fb5f3a..80b8f9c 100644 --- a/app/core/db/session.py +++ b/app/core/db/session.py @@ -1,5 +1,5 @@ from contextvars import ContextVar -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union from sqlalchemy.engine import Engine from sqlalchemy.engine.url import URL @@ -13,6 +13,8 @@ from starlette.requests import Request from starlette.types import ASGIApp +from app.core.models.base import ModelCore + class sessionmaker(sessionmaker_): def __init__(self, *args, **kwargs): @@ -31,8 +33,8 @@ def __init__( app: ASGIApp, db_url: Optional[Union[str, URL]] = None, custom_engine: Optional[Engine] = None, - engine_args: Dict = None, - session_args: Dict = None, + engine_args: Optional[Dict[Any, Any]] = None, + session_args: Optional[Dict[Any, Any]] = None, commit_on_exit: bool = False, ): super().__init__(app) @@ -106,7 +108,9 @@ def session(self) -> Session: raise MissingSessionError return session - def get_one(self, model: any, key: any, value: any) -> object: + def get_one( + self, model: ModelCore, key: Any, value: Any + ) -> Optional[ModelCore]: try: statement = select(model).where( key == value, model.deleted == False @@ -116,7 +120,9 @@ def get_one(self, model: any, key: any, value: any) -> object: except NoResultFound: return None - def get_all(self, model: any, offset: int, limit: int, order_by: any): + def get_all( + self, model: ModelCore, offset: int, limit: int, order_by: Any + ): """Get all items excluding delete ones""" statement = ( select(model) @@ -128,20 +134,20 @@ def get_all(self, model: any, offset: int, limit: int, order_by: any): result = self.session.exec(statement).all() return result - def update(self, item: any) -> object: + def update(self, item: Any) -> object: """Create or modify""" self.session.add(item) self.session.commit() self.session.refresh(item) return item - def delete(self, item: any) -> bool: + def delete(self, item: Any) -> bool: """TODO: Add return False when exceptions are observed""" self.session.delete(item) self.session.commit() return True - def count(self, model: any) -> int: + def count(self, model: Any) -> int: result = len( self.session.exec( select(model).where(model.deleted == False) @@ -152,7 +158,9 @@ def count(self, model: any) -> int: class DBSession(metaclass=DBSessionMeta): def __init__( - self, session_args: Dict = None, commit_on_exit: bool = False + self, + session_args: Optional[Dict[Any, Any]] = None, + commit_on_exit: bool = False, ): self.token = None self.session_args = session_args or {} diff --git a/app/core/models/base.py b/app/core/models/base.py index 54bf1e8..0a7c28d 100644 --- a/app/core/models/base.py +++ b/app/core/models/base.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Any, Optional from sqlmodel import Field, SQLModel @@ -64,7 +65,7 @@ def delete( return True @classmethod - def get_one(cls, value: any, key: any = None) -> any: + def get_one(cls, value: Any, key: Any = None) -> Optional["ModelCore"]: """ Get one item based in key/value - value: the value of the field @@ -75,7 +76,7 @@ def get_one(cls, value: any, key: any = None) -> any: return db.get_one(cls, key=key, value=value) @classmethod - def get_all(cls, offset: int = 0, limit: int = 100, order_by: any = None): + def get_all(cls, offset: int = 0, limit: int = 100, order_by: Any = None): """ Return a list of items Params: diff --git a/app/core/test/admin/test_auth.py b/app/core/test/admin/test_auth.py index a719adb..f9630a9 100644 --- a/app/core/test/admin/test_auth.py +++ b/app/core/test/admin/test_auth.py @@ -39,7 +39,8 @@ async def test_admin_auth_login_successful(admin_auth): result = await admin_auth.login(mock_request) - assert result is True + if result is not True: + raise AssertionError("Expected result to be True, but it was not.") @pytest.mark.asyncio @@ -53,7 +54,8 @@ async def test_admin_auth_login_failed(admin_auth): result = await admin_auth.login(mock_request) - assert result is False + if result is not True: + raise AssertionError("Expected result to be True, but it was not.") @pytest.mark.asyncio @@ -62,7 +64,8 @@ async def test_admin_auth_logout(admin_auth): result = await admin_auth.logout(mock_request) - assert result is True + if result is not True: + raise AssertionError("Expected result to be True, but it was not.") @pytest.mark.asyncio @@ -71,7 +74,8 @@ async def test_admin_auth_authenticate_valid_token(admin_auth, valid_token): result = await admin_auth.authenticate(mock_request) - assert result is None + if result is not None: + raise AssertionError("Expected result to be None, but it was not.") @pytest.mark.asyncio @@ -80,5 +84,10 @@ async def test_admin_auth_authenticate_invalid_token(admin_auth): result = await admin_auth.authenticate(mock_request) - assert isinstance(result, RedirectResponse) - assert result.status_code == 302 + if not isinstance(result, RedirectResponse): + raise AssertionError( + "Expected result to be an instance of RedirectResponse." + ) + + if result.status_code != 302: + raise AssertionError("Expected status code to be 302, but it was not.") diff --git a/app/core/test/auth/test_functions.py b/app/core/test/auth/test_functions.py index c8b04cf..1cff215 100644 --- a/app/core/test/auth/test_functions.py +++ b/app/core/test/auth/test_functions.py @@ -19,24 +19,28 @@ def test_verify_password_ok(self): password_plain = "test" password_hashed = hash_password(password_plain) - assert verify_password(password_plain, password_hashed) + if not verify_password(password_plain, password_hashed): + raise AssertionError("Password verification failed.") def test_verify_password_ko(self): password_plain = "test" password_hashed = hash_password(password_plain) - assert not verify_password("incorrect_password", password_hashed) + if verify_password("incorrect_password", password_hashed): + raise AssertionError("Password verification should fail.") class TestToken(TestCase): def test_create_token(self): access_token = create_access_token(username="test") - assert str == type(access_token) + if type(access_token) != str: + raise AssertionError("Access token should be of type str.") def test_get_current_user_ok(self): access_token = create_access_token(username="test") username = get_current_user(access_token) - assert "test" == username + if username != "test": + raise AssertionError("Username should be 'test'.") def test_get_current_user_no_username(self): access_token = create_jwt_token( @@ -59,46 +63,60 @@ def test_get_current_user_expired_signature(self): def test_get_current_admin_ok(self): access_token = create_access_token(username="test") username = get_current_admin(access_token) - assert "test" == username + if username != "test": + raise AssertionError("Username should be 'test'.") def test_get_current_admin_no_username(self): access_token = create_jwt_token( data={}, expiration_delta=timedelta(minutes=30) ) username = get_current_admin(access_token) - assert username is None + if username is not None: + raise AssertionError("Username should be None.") def test_get_current_admin_invalid_signature(self): username = get_current_admin("incorrect_token") - assert username is None + if username is not None: + raise AssertionError("Username should be None.") def test_get_current_admin_expired_signature(self): access_token = create_jwt_token( data={}, expiration_delta=timedelta(minutes=-30) ) username = get_current_admin(access_token) - assert username is None + if username is not None: + raise AssertionError("Username should be None.") def test_verify_refresh_token_ok(self): payload = {"scopes": "refresh_token"} token = create_jwt_token( data=payload, expiration_delta=timedelta(minutes=30) ) - assert payload == verify_refresh_token(token) + if payload != verify_refresh_token(token): + raise AssertionError( + "Payload should match the result of verify_refresh_token." + ) def test_verify_refresh_token_no_scopes(self): payload = {} token = create_jwt_token( data=payload, expiration_delta=timedelta(minutes=30) ) - assert verify_refresh_token(token) is None + if payload != verify_refresh_token(token): + raise AssertionError( + "Payload should match the result of verify_refresh_token." + ) def test_verify_refresh_token_invalid_signature(self): - assert verify_refresh_token("invalid_token") is None + if verify_refresh_token("invalid_token") is not None: + raise AssertionError("Invalid token should return None.") def test_verify_refresh_token_expired_signature(self): payload = {"scopes": "refresh_token"} token = create_jwt_token( data=payload, expiration_delta=timedelta(minutes=-30) ) - assert verify_refresh_token(token) is None + if verify_refresh_token(token) is not None: + raise AssertionError( + "Refresh token verification should return None." + ) diff --git a/app/core/test/db/test_engine.py b/app/core/test/db/test_engine.py index 83dbbdf..9d1e6f0 100644 --- a/app/core/test/db/test_engine.py +++ b/app/core/test/db/test_engine.py @@ -4,4 +4,5 @@ def test_get_engine(): engine = get_engine() - assert settings.database_url == str(engine.url) + if settings.database_url != str(engine.url): + raise AssertionError("Database URL does not match engine URL.") diff --git a/app/core/test/db/test_session.py b/app/core/test/db/test_session.py index 9461a0c..2ae0831 100644 --- a/app/core/test/db/test_session.py +++ b/app/core/test/db/test_session.py @@ -67,10 +67,10 @@ def test_db_session_constructor(self): self.assertIsInstance(db_session, DBSessionMeta) # Exit rollback - with self.assertRaises(Exception): + with self.assertRaises(RuntimeError): with DBSession() as db_session: self.assertIsNotNone(db_session) - raise Exception("Simulated error") + raise RuntimeError("Simulated error") # Exit commit with DBSession(commit_on_exit=True) as db_session: @@ -91,15 +91,17 @@ def _db_mocked(self, db_mocked): def test_session_ok(self): db_test = DBSession session = db_test.session - assert session is not None + if session is None: + raise AssertionError("Session should not be None.") def test_session_missing(self): with patch( "app.core.db.session._session", MagicMock() - ) as mock_session, self.assertRaises(MissingSessionError): + ) as mock_session: mock_session.get.return_value = None db_test = DBSession - db_test.session + session = db_test.session + self.assertRaises(MissingSessionError, session) @patch("app.core.db.session._Session", None) def test_session_no_initialised(self): diff --git a/app/core/test/models/test_base.py b/app/core/test/models/test_base.py index 74e06d7..6f95cde 100644 --- a/app/core/test/models/test_base.py +++ b/app/core/test/models/test_base.py @@ -6,7 +6,7 @@ from app.core.models.base import ModelCore -class TestModel(ModelCore, table=True): +class TestModel(ModelCore): name: str @@ -21,49 +21,52 @@ def test_save(self): model_saved.delete() - def test_delete_soft(self): + def test_delete_soft(self) -> None: model_to_delete: TestModel = TestModel(name="test_delete").save() model_to_delete.delete() - assert ( - TestModel.get_one(value="test_delete", key=TestModel.name) is None + self.assertIsNone( + TestModel.get_one(value="test_delete", key=TestModel.name) ) - def test_delete_hard(self): + def test_delete_hard(self) -> None: model_to_delete: TestModel = TestModel(name="test_delete").save() model_to_delete.delete(hard=True) - assert ( - TestModel.get_one(value="test_delete", key=TestModel.name) is None + self.assertIsNone( + TestModel.get_one(value="test_delete", key=TestModel.name) ) - def test_delete_error(self): + def test_delete_error(self) -> None: model_to_delete: TestModel = TestModel(name="test_delete").save() with patch( "app.core.models.base.ModelCore.save", MagicMock() ) as save_method: save_method.return_value = False - assert not model_to_delete.delete() + self.assertFalse(model_to_delete.delete()) model_to_delete.delete() - def test_get_one(self): + def test_get_one(self) -> None: model_to_get: TestModel = TestModel(name="model_to_get").save() - assert ( - TestModel.get_one(value="model_to_get", key=TestModel.name) - == model_to_get + self.assertEqual( + TestModel.get_one(value="model_to_get", key=TestModel.name), + model_to_get, ) model_to_get.delete() - def test_get_one_by_default_id(self): + def test_get_one_by_default_id(self) -> None: model_to_get: TestModel = TestModel(name="model_to_get").save() - assert TestModel.get_one(value=model_to_get.id) == model_to_get + self.assertEqual( + TestModel.get_one(value=model_to_get.id), + model_to_get, + ) model_to_get.delete() @@ -72,7 +75,10 @@ def test_get_all(self): for name in names: TestModel(name=name).save() - assert len(names) == len(model_list := TestModel.get_all()) + self.assertEqual( + len(names), + len(model_list := TestModel.get_all()), + ) for model in model_list: model.delete() @@ -82,7 +88,7 @@ def test_count(self): for name in names: TestModel(name=name).save() - assert len(names) == TestModel.count() + self.assertEqual(len(names), TestModel.count()) for model in TestModel.get_all(): model.delete() diff --git a/app/core/test/test_main.py b/app/core/test/test_main.py index d2808bf..ea9c513 100644 --- a/app/core/test/test_main.py +++ b/app/core/test/test_main.py @@ -12,7 +12,7 @@ def _db_mocked_app_client(self, db_mocked_app_client): def test_endpoint_main(self): response = self.app.get("/") - assert 200 == response.status_code + self.assertEqual(200, response.status_code) def test_check_health_ok(self): with patch("app.main.get_engine", MagicMock()), patch( @@ -21,7 +21,7 @@ def test_check_health_ok(self): "app.main.redis", MagicMock() ): response = self.app.get("/check_health") - assert 200 == response.status_code + self.assertEqual(200, response.status_code) def test_check_health_ko(self): with patch( @@ -44,10 +44,13 @@ def test_check_health_ko(self): response = self.app.get("/check_health") - assert 200 == response.status_code - assert { - "celery": "down", - "db": "down", - "rabbitmq": "down", - "redis": "down", - } == response.json() + self.assertEqual(200, response.status_code) + self.assertEqual( + { + "celery": "down", + "db": "down", + "rabbitmq": "down", + "redis": "down", + }, + response.json(), + ) diff --git a/app/services/user/routes.py b/app/services/user/routes.py index da94a74..08699db 100644 --- a/app/services/user/routes.py +++ b/app/services/user/routes.py @@ -75,19 +75,22 @@ def refresh_token(refresh_token: RefreshToken): """ Refresh access token from a refresh token """ - refresh_token = refresh_token.refresh_token + refresh_token_val = refresh_token.refresh_token # Validate the refresh token and get the user - if not (payload := verify_refresh_token(refresh_token)): + if not (payload := verify_refresh_token(refresh_token_val)): raise HTTPException(status_code=400, detail="Not authorized") username = payload.get("sub") # Check if the refresh token is revoked - if RevokedToken.is_revoked(refresh_token): + if RevokedToken.is_revoked(refresh_token_val): raise HTTPException(status_code=400, detail="Revoked refresh token") - # Return new access token - access_token = create_access_token(username) + if username is None: + raise HTTPException(status_code=400, detail="Invalid payload") + else: + # Return new access token + access_token = create_access_token(username) return TokenRefreshed(access_token=access_token, token_type="bearer") @@ -102,9 +105,9 @@ def logout(refresh_token: RefreshToken): # TODO: include a cron task to remove expired revoked tokens return: Always return 200 OK (for security reasons) """ - refresh_token = refresh_token.refresh_token + refresh_token_val = refresh_token.refresh_token # Validate token - if payload := verify_refresh_token(refresh_token): + if payload := verify_refresh_token(refresh_token_val): # Extract payload username = payload["sub"] expiration_date = payload["exp"] diff --git a/app/services/user/test/test_models.py b/app/services/user/test/test_models.py index 8970c46..d967b50 100644 --- a/app/services/user/test/test_models.py +++ b/app/services/user/test/test_models.py @@ -4,4 +4,5 @@ def test_authenticate_user_no_exists(db_mocked_app_client): user = User.authenticate_user(username="test", password="test") - assert user is None + if user is not None: + raise AssertionError("User is not None") diff --git a/app/services/user/test/test_routes.py b/app/services/user/test/test_routes.py index 5da288a..5f33a21 100644 --- a/app/services/user/test/test_routes.py +++ b/app/services/user/test/test_routes.py @@ -24,7 +24,12 @@ def test_register(self): }, ) - assert response.status_code == status.HTTP_201_CREATED + if response.status_code != status.HTTP_201_CREATED: + raise AssertionError( + "Expected status code 201, but got {}".format( + response.status_code + ) + ) # Register the same user response = self.app.post( @@ -35,7 +40,12 @@ def test_register(self): }, ) - assert response.status_code == status.HTTP_400_BAD_REQUEST + if response.status_code != status.HTTP_400_BAD_REQUEST: + raise AssertionError( + "Expected status code 400, but got {}".format( + response.status_code + ) + ) def test_login(self): # Login ok @@ -47,7 +57,12 @@ def test_login(self): }, ) - assert response.status_code == status.HTTP_200_OK + if response.status_code != status.HTTP_200_OK: + raise AssertionError( + "Expected status code 200, but got {}".format( + response.status_code + ) + ) # Login ko response = self.app.post( @@ -58,7 +73,12 @@ def test_login(self): }, ) - assert response.status_code == status.HTTP_400_BAD_REQUEST + if response.status_code != status.HTTP_400_BAD_REQUEST: + raise AssertionError( + "Expected status code 400, but got {}".format( + response.status_code + ) + ) def test_refresh_token(self): # Login @@ -83,8 +103,14 @@ def test_refresh_token(self): }, ) - assert "access_token" in response.json() - assert response.status_code == status.HTTP_200_OK + if "access_token" not in response.json(): + raise AssertionError("Expected 'access_token' in response JSON") + if response.status_code != status.HTTP_200_OK: + raise AssertionError( + "Expected status code 200, but got {}".format( + response.status_code + ) + ) # Refresh token ko response = self.app.post( @@ -94,7 +120,12 @@ def test_refresh_token(self): }, ) - assert response.status_code == status.HTTP_400_BAD_REQUEST + if response.status_code != status.HTTP_400_BAD_REQUEST: + raise AssertionError( + "Expected status code 400, but got {}".format( + response.status_code + ) + ) def test_logout_and_protected(self): # Access to protected @@ -108,8 +139,12 @@ def test_logout_and_protected(self): }, ) - assert response.status_code == status.HTTP_200_OK - + if response.status_code != status.HTTP_200_OK: + raise AssertionError( + "Expected status code 200, but got {}".format( + response.status_code + ) + ) # Login again response = self.app.post( url="user/login", @@ -119,7 +154,13 @@ def test_logout_and_protected(self): }, ) - assert response.status_code == status.HTTP_200_OK + if response.status_code != status.HTTP_200_OK: + raise AssertionError( + "Expected status code 200, but got {}".format( + response.status_code + ) + ) + refresh_token = response.json()["refresh_token"] # Logout and access to protected @@ -127,7 +168,12 @@ def test_logout_and_protected(self): url="user/logout", json={"refresh_token": refresh_token} ) - assert response.status_code == status.HTTP_200_OK + if response.status_code != status.HTTP_200_OK: + raise AssertionError( + "Expected status code 200, but got {}".format( + response.status_code + ) + ) # Refresh revoked token response = self.app.post( @@ -137,11 +183,21 @@ def test_logout_and_protected(self): }, ) - assert response.status_code == status.HTTP_400_BAD_REQUEST + if response.status_code != status.HTTP_400_BAD_REQUEST: + raise AssertionError( + "Expected status code 400, but got {}".format( + response.status_code + ) + ) # Access to protected without valid token response = self.app.get( url="user/protected", ) - assert response.status_code == status.HTTP_401_UNAUTHORIZED + if response.status_code != status.HTTP_401_UNAUTHORIZED: + raise AssertionError( + "Expected status code 401, but got {}".format( + response.status_code + ) + ) diff --git a/generate_client/sdk_client_script.py b/generate_client/sdk_client_script.py index bf4ac09..a97519f 100644 --- a/generate_client/sdk_client_script.py +++ b/generate_client/sdk_client_script.py @@ -8,7 +8,7 @@ destination_file = script_dir / "openapi.json" print("HOLA") print(destination_file) -response = requests.get(url) +response = requests.get(url, timeout=5) if response.status_code == 200: with open(destination_file, "wb") as file: