From e3bfec5c8362f6cecffc49f05f053496f2794b90 Mon Sep 17 00:00:00 2001 From: Trey <73353716+TreyWW@users.noreply.github.com> Date: Sun, 3 Nov 2024 23:01:50 +0000 Subject: [PATCH] Fixed public API authentication for teams. Now works with new Python Client Signed-off-by: Trey <73353716+TreyWW@users.noreply.github.com> --- backend/core/api/public/authentication.py | 50 ++++++++++++++++--- .../api/public/endpoints/Invoices/list.py | 3 +- .../api/public/endpoints/clients/create.py | 3 +- .../core/api/public/endpoints/clients/list.py | 4 -- backend/core/api/public/middleware.py | 47 ----------------- backend/core/models.py | 4 ++ settings/settings.py | 12 ++--- 7 files changed, 56 insertions(+), 67 deletions(-) diff --git a/backend/core/api/public/authentication.py b/backend/core/api/public/authentication.py index 82e075e4..77b52d59 100644 --- a/backend/core/api/public/authentication.py +++ b/backend/core/api/public/authentication.py @@ -1,11 +1,13 @@ from typing import Type -from rest_framework.authentication import TokenAuthentication +from rest_framework.authentication import TokenAuthentication, get_authorization_header from rest_framework.exceptions import AuthenticationFailed - +from django.utils.translation import gettext_lazy as _ from backend.core.api.public.models import APIAuthToken from backend.models import User, Organization +from rest_framework import exceptions + class CustomBearerAuthentication(TokenAuthentication): keyword = "Bearer" @@ -13,18 +15,52 @@ class CustomBearerAuthentication(TokenAuthentication): def get_model(self) -> Type[APIAuthToken]: return APIAuthToken + def authenticate(self, request): + auth = get_authorization_header(request).split() + + if not auth or auth[0].lower() != self.keyword.lower().encode(): + return None + + if len(auth) == 1: + msg = _("Invalid token header. No credentials provided.") + raise exceptions.AuthenticationFailed(msg) + elif len(auth) > 2: + msg = _("Invalid token header. Token string should not contain spaces.") + raise exceptions.AuthenticationFailed(msg) + + try: + token = auth[1].decode() + except UnicodeError: + msg = _("Invalid token header. Token string should not contain invalid characters.") + raise exceptions.AuthenticationFailed(msg) + + user_or_org, token = self.authenticate_credentials(token) + + request.actor = user_or_org + + if isinstance(user_or_org, Organization): + request.team = user_or_org + request.team_id = user_or_org.id + else: + request.team = None + request.team_id = None + + return (user_or_org, token) + def authenticate_credentials(self, raw_key) -> tuple[User | Organization | None, APIAuthToken]: model = self.get_model() try: token = model.objects.get(hashed_key=model.hash_raw_key(raw_key), active=True) except model.DoesNotExist: - raise AuthenticationFailed("Invalid token.") + raise AuthenticationFailed(_("Invalid token.")) if token.has_expired: - raise AuthenticationFailed("Token has expired.") + raise AuthenticationFailed(_("Token has expired.")) + + user_or_org = token.user or token.organization - # todo: make sure this is safe to set request.user = obj - return token.user or token.organization, token + if user_or_org is None: + raise AuthenticationFailed(_("Associated user or organization not found.")) - # todo: override more methods + add hashing + return user_or_org, token diff --git a/backend/core/api/public/endpoints/Invoices/list.py b/backend/core/api/public/endpoints/Invoices/list.py index 524aa44a..adb81033 100644 --- a/backend/core/api/public/endpoints/Invoices/list.py +++ b/backend/core/api/public/endpoints/Invoices/list.py @@ -8,6 +8,7 @@ from rest_framework.response import Response from backend.core.api.public.decorators import require_scopes +from backend.core.api.public.helpers.response import APIResponse from backend.core.api.public.serializers.invoices import InvoiceSerializer from backend.core.api.public.swagger_ui import TEAM_PARAMETER from backend.core.api.public.types import APIRequest @@ -82,4 +83,4 @@ def list_invoices_endpoint(request: APIRequest) -> Response: serializer = InvoiceSerializer(invoices, many=True) - return Response({"success": True, "invoices": serializer.data}, status=status.HTTP_200_OK) + return APIResponse(True, {"invoices": serializer.data}, status=status.HTTP_200_OK) diff --git a/backend/core/api/public/endpoints/clients/create.py b/backend/core/api/public/endpoints/clients/create.py index 030cbcd8..7899eb3d 100644 --- a/backend/core/api/public/endpoints/clients/create.py +++ b/backend/core/api/public/endpoints/clients/create.py @@ -5,6 +5,7 @@ from rest_framework.response import Response from backend.core.api.public.decorators import require_scopes +from backend.core.api.public.helpers.response import APIResponse from backend.core.api.public.serializers.clients import ClientSerializer from backend.core.api.public.swagger_ui import TEAM_PARAMETER from backend.core.api.public.types import APIRequest @@ -57,4 +58,4 @@ def client_create_endpoint(request: APIRequest): else: client = serializer.save(user=request.user) - return Response({"client_id": client.id, "success": True}, status=status.HTTP_201_CREATED) + return APIResponse(True, {"client_id": client.id}, status=status.HTTP_201_CREATED) diff --git a/backend/core/api/public/endpoints/clients/list.py b/backend/core/api/public/endpoints/clients/list.py index af55c6b3..29c59f96 100644 --- a/backend/core/api/public/endpoints/clients/list.py +++ b/backend/core/api/public/endpoints/clients/list.py @@ -44,13 +44,9 @@ def list_clients_endpoint(request: APIRequest): search_text = request.data.get("search") - if not request.team and isinstance(request.auth.owner, Organization): - return APIResponse(False, "When using a team API Key the team_id field must be provided.") - clients: FetchClientServiceResponse = fetch_clients(request, search_text=search_text, team=request.team) # queryset = paginator.paginate_queryset(clients, request) serializer = ClientSerializer(clients.response, many=True) - return APIResponse(True, {"clients": serializer.data}) diff --git a/backend/core/api/public/middleware.py b/backend/core/api/public/middleware.py index d329a895..e69de29b 100644 --- a/backend/core/api/public/middleware.py +++ b/backend/core/api/public/middleware.py @@ -1,47 +0,0 @@ -from django.utils.deprecation import MiddlewareMixin - -from backend.core.api.public import APIAuthToken -from backend.models import Organization - - -class AttachTokenMiddleware(MiddlewareMixin): - def process_request(self, request): - if not request.path.startswith("/api/public/"): - return - - auth_header = request.headers.get("Authorization") - - if not (auth_header and auth_header.startswith("Bearer ")): - request.auth = None - return - - token_key = auth_header.split(" ")[1] - try: - token = APIAuthToken.objects.get(key=token_key, active=True) - if not token.has_expired: - request.auth = token - request.api_token = token - except APIAuthToken.DoesNotExist: - request.auth = None - - -class HandleTeamContextMiddleware(MiddlewareMixin): - def process_request(self, request): - if not request.path.startswith("/api/public/"): - return - - if hasattr(request, "query_params"): - team_id = request.query_params.get("team_id") - else: - team_id = request.GET.get("team_id") - request.team = None - request.team_id = team_id - - if not team_id: - # No team_id provided, proceed with user context - return - - team = Organization.objects.filter(id=team_id).first() - - request.team = team - request.actor = team diff --git a/backend/core/models.py b/backend/core/models.py index b1c1235d..6294713e 100644 --- a/backend/core/models.py +++ b/backend/core/models.py @@ -383,6 +383,10 @@ def filter_by_owner(cls: typing.Type[M], owner: Union[User, Organization]) -> Qu else: raise ValueError("Owner must be either a User or an Organization") + @property + def is_team(self): + return isinstance(self.owner, Organization) + class PasswordSecret(ExpiresBase): user = models.OneToOneField(User, on_delete=models.CASCADE, related_name="password_secrets") diff --git a/settings/settings.py b/settings/settings.py index 5a8c8844..3819e8f1 100644 --- a/settings/settings.py +++ b/settings/settings.py @@ -234,7 +234,7 @@ "tz_detect.middleware.TimezoneMiddleware", "backend.middleware.HTMXPartialLoadMiddleware", # "backend.core.api.public.middleware.AttachTokenMiddleware", - "backend.core.api.public.middleware.HandleTeamContextMiddleware", + # "backend.core.api.public.middleware.HandleTeamContextMiddleware", ] if DEBUG: @@ -305,11 +305,9 @@ BILLING_ENABLED = get_var("BILLING_ENABLED", "").lower() == "true" if BILLING_ENABLED or TYPE_CHECKING: - print("BILLING MODULE IS ENABLED") + print("[BACKEND] BILLING MODULE IS ENABLED") INSTALLED_APPS.append("billing") MIDDLEWARE.extend(["billing.middleware.CheckUserSubScriptionMiddleware"]) - # TEMPLATES[0]["DIRS"].append(BASE_DIR / "billing/templates") - print(TEMPLATES) # endregion "Billing" @@ -414,7 +412,7 @@ class CustomPrivateMediaStorage(S3Storage): AWS_STATIC_ENABLED = get_var("AWS_STATIC_ENABLED", default="False").lower() == "true" AWS_STATIC_CDN_TYPE = get_var("AWS_STATIC_CDN_TYPE") -logging.info(f"{AWS_STATIC_ENABLED=} | {AWS_STATIC_CDN_TYPE=}") +logging.debug(f"{AWS_STATIC_ENABLED=} | {AWS_STATIC_CDN_TYPE=}") if AWS_STATIC_ENABLED or AWS_STATIC_CDN_TYPE.lower() == "aws": STATICFILES_STORAGE = "settings.settings.CustomStaticStorage" @@ -422,12 +420,12 @@ class CustomPrivateMediaStorage(S3Storage): STORAGES["staticfiles"] = { "BACKEND": "settings.settings.CustomStaticStorage", } - logging.info(f"{STATIC_LOCATION=} | {STATICFILES_STORAGE=}") + logging.debug(f"{STATIC_LOCATION=} | {STATICFILES_STORAGE=}") else: STATIC_URL = f"/static/" STATIC_ROOT = os.path.join(BASE_DIR, "static") STATICFILES_STORAGE = "django.contrib.staticfiles.storage.ManifestStaticFilesStorage" - logging.info(f"{STATIC_URL=} | {STATIC_ROOT=} | {STATICFILES_STORAGE=}") + logging.debug(f"{STATIC_URL=} | {STATIC_ROOT=} | {STATICFILES_STORAGE=}") AWS_MEDIA_PUBLIC_ENABLED = get_var("AWS_MEDIA_PUBLIC_ENABLED", default="False").lower() == "true"