Skip to content

Commit

Permalink
Fixed public API authentication for teams. Now works with new Python …
Browse files Browse the repository at this point in the history
…Client

Signed-off-by: Trey <[email protected]>
  • Loading branch information
TreyWW committed Nov 3, 2024
1 parent f618412 commit e3bfec5
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 67 deletions.
50 changes: 43 additions & 7 deletions backend/core/api/public/authentication.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,66 @@
from typing import Type

from rest_framework.authentication import TokenAuthentication
from rest_framework.authentication import TokenAuthentication, get_authorization_header
from rest_framework.exceptions import AuthenticationFailed

from django.utils.translation import gettext_lazy as _
from backend.core.api.public.models import APIAuthToken
from backend.models import User, Organization

from rest_framework import exceptions


class CustomBearerAuthentication(TokenAuthentication):
keyword = "Bearer"

def get_model(self) -> Type[APIAuthToken]:
return APIAuthToken

def authenticate(self, request):
auth = get_authorization_header(request).split()

if not auth or auth[0].lower() != self.keyword.lower().encode():
return None

if len(auth) == 1:
msg = _("Invalid token header. No credentials provided.")
raise exceptions.AuthenticationFailed(msg)
elif len(auth) > 2:
msg = _("Invalid token header. Token string should not contain spaces.")
raise exceptions.AuthenticationFailed(msg)

try:
token = auth[1].decode()
except UnicodeError:
msg = _("Invalid token header. Token string should not contain invalid characters.")
raise exceptions.AuthenticationFailed(msg)

user_or_org, token = self.authenticate_credentials(token)

request.actor = user_or_org

if isinstance(user_or_org, Organization):
request.team = user_or_org
request.team_id = user_or_org.id
else:
request.team = None
request.team_id = None

return (user_or_org, token)

def authenticate_credentials(self, raw_key) -> tuple[User | Organization | None, APIAuthToken]:
model = self.get_model()

try:
token = model.objects.get(hashed_key=model.hash_raw_key(raw_key), active=True)
except model.DoesNotExist:
raise AuthenticationFailed("Invalid token.")
raise AuthenticationFailed(_("Invalid token."))

if token.has_expired:
raise AuthenticationFailed("Token has expired.")
raise AuthenticationFailed(_("Token has expired."))

user_or_org = token.user or token.organization

# todo: make sure this is safe to set request.user = <Team> obj
return token.user or token.organization, token
if user_or_org is None:
raise AuthenticationFailed(_("Associated user or organization not found."))

# todo: override more methods + add hashing
return user_or_org, token
3 changes: 2 additions & 1 deletion backend/core/api/public/endpoints/Invoices/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from rest_framework.response import Response

from backend.core.api.public.decorators import require_scopes
from backend.core.api.public.helpers.response import APIResponse
from backend.core.api.public.serializers.invoices import InvoiceSerializer
from backend.core.api.public.swagger_ui import TEAM_PARAMETER
from backend.core.api.public.types import APIRequest
Expand Down Expand Up @@ -82,4 +83,4 @@ def list_invoices_endpoint(request: APIRequest) -> Response:

serializer = InvoiceSerializer(invoices, many=True)

return Response({"success": True, "invoices": serializer.data}, status=status.HTTP_200_OK)
return APIResponse(True, {"invoices": serializer.data}, status=status.HTTP_200_OK)
3 changes: 2 additions & 1 deletion backend/core/api/public/endpoints/clients/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from rest_framework.response import Response

from backend.core.api.public.decorators import require_scopes
from backend.core.api.public.helpers.response import APIResponse
from backend.core.api.public.serializers.clients import ClientSerializer
from backend.core.api.public.swagger_ui import TEAM_PARAMETER
from backend.core.api.public.types import APIRequest
Expand Down Expand Up @@ -57,4 +58,4 @@ def client_create_endpoint(request: APIRequest):
else:
client = serializer.save(user=request.user)

return Response({"client_id": client.id, "success": True}, status=status.HTTP_201_CREATED)
return APIResponse(True, {"client_id": client.id}, status=status.HTTP_201_CREATED)
4 changes: 0 additions & 4 deletions backend/core/api/public/endpoints/clients/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,9 @@ def list_clients_endpoint(request: APIRequest):

search_text = request.data.get("search")

if not request.team and isinstance(request.auth.owner, Organization):
return APIResponse(False, "When using a team API Key the team_id field must be provided.")

clients: FetchClientServiceResponse = fetch_clients(request, search_text=search_text, team=request.team)

# queryset = paginator.paginate_queryset(clients, request)

serializer = ClientSerializer(clients.response, many=True)

return APIResponse(True, {"clients": serializer.data})
47 changes: 0 additions & 47 deletions backend/core/api/public/middleware.py
Original file line number Diff line number Diff line change
@@ -1,47 +0,0 @@
from django.utils.deprecation import MiddlewareMixin

from backend.core.api.public import APIAuthToken
from backend.models import Organization


class AttachTokenMiddleware(MiddlewareMixin):
def process_request(self, request):
if not request.path.startswith("/api/public/"):
return

auth_header = request.headers.get("Authorization")

if not (auth_header and auth_header.startswith("Bearer ")):
request.auth = None
return

token_key = auth_header.split(" ")[1]
try:
token = APIAuthToken.objects.get(key=token_key, active=True)
if not token.has_expired:
request.auth = token
request.api_token = token
except APIAuthToken.DoesNotExist:
request.auth = None


class HandleTeamContextMiddleware(MiddlewareMixin):
def process_request(self, request):
if not request.path.startswith("/api/public/"):
return

if hasattr(request, "query_params"):
team_id = request.query_params.get("team_id")
else:
team_id = request.GET.get("team_id")
request.team = None
request.team_id = team_id

if not team_id:
# No team_id provided, proceed with user context
return

team = Organization.objects.filter(id=team_id).first()

request.team = team
request.actor = team
4 changes: 4 additions & 0 deletions backend/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,10 @@ def filter_by_owner(cls: typing.Type[M], owner: Union[User, Organization]) -> Qu
else:
raise ValueError("Owner must be either a User or an Organization")

@property
def is_team(self):
return isinstance(self.owner, Organization)


class PasswordSecret(ExpiresBase):
user = models.OneToOneField(User, on_delete=models.CASCADE, related_name="password_secrets")
Expand Down
12 changes: 5 additions & 7 deletions settings/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@
"tz_detect.middleware.TimezoneMiddleware",
"backend.middleware.HTMXPartialLoadMiddleware",
# "backend.core.api.public.middleware.AttachTokenMiddleware",
"backend.core.api.public.middleware.HandleTeamContextMiddleware",
# "backend.core.api.public.middleware.HandleTeamContextMiddleware",
]

if DEBUG:
Expand Down Expand Up @@ -305,11 +305,9 @@
BILLING_ENABLED = get_var("BILLING_ENABLED", "").lower() == "true"

if BILLING_ENABLED or TYPE_CHECKING:
print("BILLING MODULE IS ENABLED")
print("[BACKEND] BILLING MODULE IS ENABLED")
INSTALLED_APPS.append("billing")
MIDDLEWARE.extend(["billing.middleware.CheckUserSubScriptionMiddleware"])
# TEMPLATES[0]["DIRS"].append(BASE_DIR / "billing/templates")
print(TEMPLATES)

# endregion "Billing"

Expand Down Expand Up @@ -414,20 +412,20 @@ class CustomPrivateMediaStorage(S3Storage):
AWS_STATIC_ENABLED = get_var("AWS_STATIC_ENABLED", default="False").lower() == "true"
AWS_STATIC_CDN_TYPE = get_var("AWS_STATIC_CDN_TYPE")

logging.info(f"{AWS_STATIC_ENABLED=} | {AWS_STATIC_CDN_TYPE=}")
logging.debug(f"{AWS_STATIC_ENABLED=} | {AWS_STATIC_CDN_TYPE=}")

if AWS_STATIC_ENABLED or AWS_STATIC_CDN_TYPE.lower() == "aws":
STATICFILES_STORAGE = "settings.settings.CustomStaticStorage"
STATIC_LOCATION = get_var("AWS_STATIC_LOCATION", default="static")
STORAGES["staticfiles"] = {
"BACKEND": "settings.settings.CustomStaticStorage",
}
logging.info(f"{STATIC_LOCATION=} | {STATICFILES_STORAGE=}")
logging.debug(f"{STATIC_LOCATION=} | {STATICFILES_STORAGE=}")
else:
STATIC_URL = f"/static/"
STATIC_ROOT = os.path.join(BASE_DIR, "static")
STATICFILES_STORAGE = "django.contrib.staticfiles.storage.ManifestStaticFilesStorage"
logging.info(f"{STATIC_URL=} | {STATIC_ROOT=} | {STATICFILES_STORAGE=}")
logging.debug(f"{STATIC_URL=} | {STATIC_ROOT=} | {STATICFILES_STORAGE=}")

AWS_MEDIA_PUBLIC_ENABLED = get_var("AWS_MEDIA_PUBLIC_ENABLED", default="False").lower() == "true"

Expand Down

0 comments on commit e3bfec5

Please sign in to comment.