From fbfa946708985269bb83ef9e96fb2c84b42c7745 Mon Sep 17 00:00:00 2001 From: Paul Bauriegel Date: Sun, 16 Feb 2025 11:05:07 +0100 Subject: [PATCH 1/5] Create login_backend.py --- .../authentication/db/login_backend.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 argilla-server/src/argilla_server/security/authentication/db/login_backend.py diff --git a/argilla-server/src/argilla_server/security/authentication/db/login_backend.py b/argilla-server/src/argilla_server/security/authentication/db/login_backend.py new file mode 100644 index 0000000000..14e13dee1b --- /dev/null +++ b/argilla-server/src/argilla_server/security/authentication/db/login_backend.py @@ -0,0 +1,38 @@ +from starlette.authentication import AuthenticationBackend +from argilla_server.jobs.queues import REDIS_CONNECTION + +MAX_ATTEMPTS = 5 +LOCKOUT_TIME = 300 # 5 minutes + +class LoginAuthenticationBackend(AuthenticationBackend): + """ + Authentication backend which locks the user out after a certain amount + of wrong attempts to login + """ + + def __init__(self): + self.redis = REDIS_CONNECTION + + async def check_lockout(self, credential_key: str) -> bool: + """Check if credential key (username or API key) is locked out""" + key = f"failed_auth:{credential_key}" + attempts = await self.redis.get(key) + attempts = int(attempts or 0) + if attempts >= MAX_ATTEMPTS: + return False + return True + + async def increase_lockout(self, credential_key: str) -> None: + """Increment failed attempts""" + key = f"failed_auth:{credential_key}" + await self.redis.incr(key) + + # Ensure expiration is set after first failure + ttl = await self.redis.ttl(key) + if ttl == -1: # Key exists but no expiration set + await self.redis.expire(key, LOCKOUT_TIME) + + async def clear_lockout(self, credential_key: str) -> None: + """ Reset failed attempts on successful authentication""" + key = f"failed_auth:{credential_key}" + await self.redis.delete(key) From a87c915b5d5ae11610ecf59929379cf6d66219cb Mon Sep 17 00:00:00 2001 From: Paul Bauriegel Date: Sun, 16 Feb 2025 11:09:43 +0100 Subject: [PATCH 2/5] Add lockout feature to bearer_token_backend --- .../authentication/db/bearer_token_backend.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/argilla-server/src/argilla_server/security/authentication/db/bearer_token_backend.py b/argilla-server/src/argilla_server/security/authentication/db/bearer_token_backend.py index 8e84debc5a..0b4a6b2b49 100644 --- a/argilla-server/src/argilla_server/security/authentication/db/bearer_token_backend.py +++ b/argilla-server/src/argilla_server/security/authentication/db/bearer_token_backend.py @@ -16,14 +16,15 @@ from fastapi import Request from fastapi.security import HTTPBearer -from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser +from starlette.authentication import AuthCredentials, BaseUser from argilla_server.contexts import accounts from argilla_server.security.authentication.jwt import JWT from argilla_server.security.authentication.userinfo import UserInfo +from argilla_server.security.authentication.db.login_backend import LoginAuthenticationBackend -class BearerTokenAuthenticationBackend(AuthenticationBackend): +class BearerTokenAuthenticationBackend(LoginAuthenticationBackend): """Authenticate the user using the username and password Bearer header""" scheme = HTTPBearer(auto_error=False) @@ -36,12 +37,16 @@ async def authenticate(self, request: Request) -> typing.Optional[typing.Tuple[A token = credentials.credentials username = JWT.decode(token).get("username") + is_locked = self.check_lockout(username) + if is_locked: + return None db = request.state.db user = await accounts.get_user_by_username(db, username) if not user: + self.increase_lockout(user) return None - + self.clear_lockout(user) return AuthCredentials(), UserInfo( username=user.username, name=user.first_name, role=user.role, identity=str(user.id) ) From 52383acc242dde8412341c3cbc38921bd29e4599 Mon Sep 17 00:00:00 2001 From: Paul Bauriegel Date: Sun, 16 Feb 2025 11:15:06 +0100 Subject: [PATCH 3/5] Add lockout via client ip to api_key_backend --- .../security/authentication/db/api_key_backend.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/argilla-server/src/argilla_server/security/authentication/db/api_key_backend.py b/argilla-server/src/argilla_server/security/authentication/db/api_key_backend.py index 1b23902317..34b1bdf5fd 100644 --- a/argilla-server/src/argilla_server/security/authentication/db/api_key_backend.py +++ b/argilla-server/src/argilla_server/security/authentication/db/api_key_backend.py @@ -16,14 +16,15 @@ from fastapi import Request from fastapi.security import APIKeyHeader -from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser +from starlette.authentication import AuthCredentials, BaseUser from argilla_server.constants import API_KEY_HEADER_NAME from argilla_server.contexts import accounts from argilla_server.security.authentication.userinfo import UserInfo +from argilla_server.security.authentication.db.login_backend import LoginAuthenticationBackend -class APIKeyAuthenticationBackend(AuthenticationBackend): +class APIKeyAuthenticationBackend(LoginAuthenticationBackend): """Authentication backend for API Key authentication""" scheme = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False) @@ -31,13 +32,19 @@ class APIKeyAuthenticationBackend(AuthenticationBackend): async def authenticate(self, request: Request) -> Optional[Tuple[AuthCredentials, BaseUser]]: """Authenticate the user using the API Key header""" api_key: str = await self.scheme(request) + client_ip = request.client.host if not api_key: return None + is_locked = self.check_lockout(client_ip) + if is_locked: + return None db = request.state.db user = await accounts.get_user_by_api_key(db, api_key=api_key) if not user: + self.increase_lockout(client_ip) return None + self.clear_lockout(client_ip) return AuthCredentials(), UserInfo( username=user.username, name=user.first_name, role=user.role, identity=str(user.id) From f9c71618e74353e17b13b299141d766261b0d531 Mon Sep 17 00:00:00 2001 From: Paul Bauriegel Date: Sun, 16 Feb 2025 11:17:11 +0100 Subject: [PATCH 4/5] Add client ip bearer_token_backend --- .../security/authentication/db/bearer_token_backend.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/argilla-server/src/argilla_server/security/authentication/db/bearer_token_backend.py b/argilla-server/src/argilla_server/security/authentication/db/bearer_token_backend.py index 0b4a6b2b49..2eb2e9016f 100644 --- a/argilla-server/src/argilla_server/security/authentication/db/bearer_token_backend.py +++ b/argilla-server/src/argilla_server/security/authentication/db/bearer_token_backend.py @@ -32,21 +32,25 @@ class BearerTokenAuthenticationBackend(LoginAuthenticationBackend): async def authenticate(self, request: Request) -> typing.Optional[typing.Tuple[AuthCredentials, BaseUser]]: """Authenticate the user using the username and password Bearer header""" credentials = await self.scheme(request) + client_ip = request.client.host if not credentials: return None token = credentials.credentials username = JWT.decode(token).get("username") is_locked = self.check_lockout(username) - if is_locked: + is_locked_ip = self.check_lockout(client_ip) + if is_locked or is_locked_ip: return None db = request.state.db user = await accounts.get_user_by_username(db, username) if not user: - self.increase_lockout(user) + self.increase_lockout(username) + self.increase_lockout(client_ip) return None - self.clear_lockout(user) + self.clear_lockout(username) + self.clear_lockout(client_ip) return AuthCredentials(), UserInfo( username=user.username, name=user.first_name, role=user.role, identity=str(user.id) ) From c11024dd23639861690585cbbd7abcd372908c75 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 16 Feb 2025 10:18:56 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../authentication/db/login_backend.py | 55 ++++++++++++------- 1 file changed, 35 insertions(+), 20 deletions(-) diff --git a/argilla-server/src/argilla_server/security/authentication/db/login_backend.py b/argilla-server/src/argilla_server/security/authentication/db/login_backend.py index 14e13dee1b..6a791550d2 100644 --- a/argilla-server/src/argilla_server/security/authentication/db/login_backend.py +++ b/argilla-server/src/argilla_server/security/authentication/db/login_backend.py @@ -1,12 +1,27 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from starlette.authentication import AuthenticationBackend from argilla_server.jobs.queues import REDIS_CONNECTION MAX_ATTEMPTS = 5 LOCKOUT_TIME = 300 # 5 minutes + class LoginAuthenticationBackend(AuthenticationBackend): """ - Authentication backend which locks the user out after a certain amount + Authentication backend which locks the user out after a certain amount of wrong attempts to login """ @@ -14,25 +29,25 @@ def __init__(self): self.redis = REDIS_CONNECTION async def check_lockout(self, credential_key: str) -> bool: - """Check if credential key (username or API key) is locked out""" - key = f"failed_auth:{credential_key}" - attempts = await self.redis.get(key) - attempts = int(attempts or 0) - if attempts >= MAX_ATTEMPTS: - return False - return True - + """Check if credential key (username or API key) is locked out""" + key = f"failed_auth:{credential_key}" + attempts = await self.redis.get(key) + attempts = int(attempts or 0) + if attempts >= MAX_ATTEMPTS: + return False + return True + async def increase_lockout(self, credential_key: str) -> None: - """Increment failed attempts""" - key = f"failed_auth:{credential_key}" - await self.redis.incr(key) - - # Ensure expiration is set after first failure - ttl = await self.redis.ttl(key) - if ttl == -1: # Key exists but no expiration set - await self.redis.expire(key, LOCKOUT_TIME) + """Increment failed attempts""" + key = f"failed_auth:{credential_key}" + await self.redis.incr(key) + + # Ensure expiration is set after first failure + ttl = await self.redis.ttl(key) + if ttl == -1: # Key exists but no expiration set + await self.redis.expire(key, LOCKOUT_TIME) async def clear_lockout(self, credential_key: str) -> None: - """ Reset failed attempts on successful authentication""" - key = f"failed_auth:{credential_key}" - await self.redis.delete(key) + """Reset failed attempts on successful authentication""" + key = f"failed_auth:{credential_key}" + await self.redis.delete(key)