diff --git a/backend/Dockerfile b/backend/Dockerfile index 82c1cb665..3613f22fe 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -20,8 +20,12 @@ RUN set -ex \ | sort -u)" \ && apk add --virtual rundeps $runDeps \ && apk del .build-deps +RUN apk add --no-cache curl ca-certificates + +# Add AWS RDS eu-central-1 CA bundle +ADD https://truststore.pki.rds.amazonaws.com/eu-central-1/eu-central-1-bundle.pem /etc/ssl/certs/rds-ca-bundle.pem +RUN chmod 644 /etc/ssl/certs/rds-ca-bundle.pem -RUN apk add --no-cache curl RUN addgroup -S app && adduser -S app -G app ADD . /app WORKDIR /app diff --git a/backend/api/auth.py b/backend/api/auth.py index 8421a4ec0..190d9308f 100644 --- a/backend/api/auth.py +++ b/backend/api/auth.py @@ -63,14 +63,20 @@ def authenticate(self, request): if secret_id: found = False try: - secret = Secret.objects.get(id=secret_id) + # Pre-fetch environment, app, and organisation + secret = Secret.objects.select_related( + "environment__app__organisation" + ).get(id=secret_id) env = secret.environment found = True except Secret.DoesNotExist: pass if not found: try: - dyn_secret = DynamicSecret.objects.get(id=secret_id) + # Pre-fetch environment, app, and organisation + dyn_secret = DynamicSecret.objects.select_related( + "environment__app__organisation" + ).get(id=secret_id) env = dyn_secret.environment found = True except DynamicSecret.DoesNotExist: @@ -84,7 +90,10 @@ def authenticate(self, request): # Try resolving env from header if env_id: try: - env = Environment.objects.get(id=env_id) + # Pre-fetch app and organisation + env = Environment.objects.select_related("app__organisation").get( + id=env_id + ) except Environment.DoesNotExist: raise exceptions.AuthenticationFailed("Environment not found") @@ -99,7 +108,10 @@ def authenticate(self, request): ) if not env_name: raise exceptions.AuthenticationFailed("Missing env parameter") - env = Environment.objects.get(app_id=app_id, name__iexact=env_name) + # Pre-fetch app and organisation + env = Environment.objects.select_related("app__organisation").get( + app_id=app_id, name__iexact=env_name + ) except Environment.DoesNotExist: # Check if the app exists to give a more specific error App = apps.get_model("api", "App") diff --git a/backend/api/throttling.py b/backend/api/throttling.py new file mode 100644 index 000000000..b5289fee3 --- /dev/null +++ b/backend/api/throttling.py @@ -0,0 +1,58 @@ +from rest_framework.throttling import SimpleRateThrottle +from django.conf import settings + +CLOUD_HOSTED = settings.APP_HOST == "cloud" + + +class PlanBasedRateThrottle(SimpleRateThrottle): + """ + Limits the rate of API calls based on the Organisation's plan. + Uses the pre-fetched organisation data from request.auth to avoid DB lookups. + """ + + scope = "plan_based" + + def get_cache_key(self, request, view): + # Identify the user or service account + ident = self.get_ident(request) + + if request.user.is_authenticated and request.auth: + if request.auth.get("org_member"): + ident = f"user_{request.auth['org_member'].id}" + elif request.auth.get("service_account"): + ident = f"sa_{request.auth['service_account'].id}" + elif request.auth.get("service_token"): + ident = f"st_{request.auth['service_token'].id}" + else: + ident = f"anon_{ident}" + + return self.cache_format % {"scope": self.scope, "ident": ident} + + def allow_request(self, request, view): + """ + Override allow_request to dynamically set the rate based on the request user's plan. + """ + # Default fallback (reads from REST_FRAMEWORK['DEFAULT_THROTTLE_RATES']['plan_based']) + new_rate = self.get_rate() + + if request.user.is_authenticated and request.auth: + env = request.auth.get("environment") + if env: + try: + plan = env.app.organisation.plan + new_rate = self.get_rate_for_plan(plan) + except AttributeError: + pass + + # Update the throttle configuration for this specific request + self.rate = new_rate + self.num_requests, self.duration = self.parse_rate(self.rate) + + return super().allow_request(request, view) + + @staticmethod + def get_rate_for_plan(plan): + # If self-hosted return the default rate limit. If not set, this will disable throttling + if not CLOUD_HOSTED: + return settings.PLAN_RATE_LIMITS["DEFAULT"] + return settings.PLAN_RATE_LIMITS.get(plan, settings.PLAN_RATE_LIMITS["DEFAULT"]) diff --git a/backend/api/views/identities/aws/iam.py b/backend/api/views/identities/aws/iam.py index f29add821..d318c2f3d 100644 --- a/backend/api/views/identities/aws/iam.py +++ b/backend/api/views/identities/aws/iam.py @@ -7,13 +7,14 @@ from defusedxml.ElementTree import parse from django.http import JsonResponse from django.utils import timezone -from rest_framework.decorators import api_view, permission_classes +from rest_framework.decorators import api_view, permission_classes, throttle_classes from rest_framework.permissions import AllowAny from api.utils.identity.common import ( resolve_service_account, mint_service_account_token, ) +from api.throttling import PlanBasedRateThrottle def get_normalized_host(uri): @@ -27,6 +28,7 @@ def get_normalized_host(uri): @api_view(["POST"]) @permission_classes([AllowAny]) +@throttle_classes([PlanBasedRateThrottle]) def aws_iam_auth(request): """Accepts SigV4-signed STS GetCallerIdentity request and issues a ServiceAccount token if trusted.""" try: diff --git a/backend/api/views/secrets.py b/backend/api/views/secrets.py index ec6217b5f..4a9981632 100644 --- a/backend/api/views/secrets.py +++ b/backend/api/views/secrets.py @@ -33,6 +33,7 @@ import json from api.content_negotiation import CamelCaseContentNegotiation from api.utils.access.middleware import IsIPAllowed +from api.throttling import PlanBasedRateThrottle from ee.integrations.secrets.dynamic.exceptions import ( DynamicSecretError, PlanRestrictionError, @@ -60,6 +61,7 @@ class E2EESecretsView(APIView): authentication_classes = [PhaseTokenAuthentication] permission_classes = [IsAuthenticated, IsIPAllowed] + throttle_classes = [PlanBasedRateThrottle] content_negotiation_class = CamelCaseContentNegotiation def initial(self, request, *args, **kwargs): @@ -489,6 +491,7 @@ def delete(self, request, *args, **kwargs): class PublicSecretsView(APIView): authentication_classes = [PhaseTokenAuthentication] permission_classes = [IsAuthenticated, IsIPAllowed] + throttle_classes = [PlanBasedRateThrottle] renderer_classes = [ CamelCaseJSONRenderer, ] diff --git a/backend/backend/settings.py b/backend/backend/settings.py index 0ab17492b..f60e110f8 100644 --- a/backend/backend/settings.py +++ b/backend/backend/settings.py @@ -1,5 +1,6 @@ import os from pathlib import Path +from urllib.parse import quote as urlquote import logging.config from backend.utils.secrets import get_secret from ee.licensing.verifier import check_license @@ -50,16 +51,10 @@ def get_version(): VERSION = get_version() - -# Quick-start development settings - unsuitable for production -# See https://docs.djangoproject.com/en/4.1/howto/deployment/checklist/ - -# SECURITY WARNING: keep the secret key used in production secret! SECRET_KEY = get_secret("SECRET_KEY") SERVER_SECRET = get_secret("SERVER_SECRET") -# SECURITY WARNING: don't run with debug turned on in production! DEBUG = True if os.getenv("DEBUG") == "True" else False ALLOWED_HOSTS = os.getenv("ALLOWED_HOSTS", []).split(",") @@ -72,8 +67,6 @@ def get_version(): SESSION_COOKIE_AGE = 604800 # 1 week, in seconds -# Application definition - INSTALLED_APPS = [ "django.contrib.admin", "django.contrib.auth", @@ -241,6 +234,16 @@ def get_version(): "USER_DETAILS_SERIALIZER": "api.serializers.CustomUserSerializer" } +# Global rate limit +PLAN_RATE_LIMITS = { + # PHASE CLOUD + "FR": os.getenv("RATE_LIMIT_FREE"), + "PR": os.getenv("RATE_LIMIT_PRO"), + "EN": os.getenv("RATE_LIMIT_ENTERPRISE"), + # PHASE SELF-HOSTED + "DEFAULT": os.getenv("RATE_LIMIT_DEFAULT"), +} + REST_FRAMEWORK = { "DEFAULT_PERMISSION_CLASSES": [ "rest_framework.permissions.IsAuthenticated", @@ -248,6 +251,10 @@ def get_version(): "DEFAULT_AUTHENTICATION_CLASSES": ( "rest_framework.authentication.SessionAuthentication", ), + "DEFAULT_THROTTLE_CLASSES": [], + "DEFAULT_THROTTLE_RATES": { + "plan_based": PLAN_RATE_LIMITS["DEFAULT"], + }, "EXCEPTION_HANDLER": "backend.exceptions.custom_exception_handler", "DEFAULT_RENDERER_CLASSES": [ "rest_framework.renderers.JSONRenderer", @@ -290,7 +297,80 @@ def get_version(): "PASSWORD": get_secret("DATABASE_PASSWORD"), "NAME": os.getenv("DATABASE_NAME"), "HOST": os.getenv("DATABASE_HOST"), - "PORT": os.getenv("DATABASE_PORT"), + "PORT": int(os.getenv("DATABASE_PORT", "5432")), + "OPTIONS": ( + { + "sslmode": "verify-full" + if os.getenv("DATABASE_SSL_CA_PATH") + else "require", + **( + {"sslrootcert": os.getenv("DATABASE_SSL_CA_PATH")} + if os.getenv("DATABASE_SSL_CA_PATH") + else {} + ), + } + if os.getenv("DATABASE_SSL", "False").lower() == "true" + else {} + ), + }, +} + +REDIS_HOST = os.getenv("REDIS_HOST") +REDIS_PORT = int(os.getenv("REDIS_PORT", "6379")) +REDIS_USER = os.getenv("REDIS_USER") or None +REDIS_PASSWORD = get_secret("REDIS_PASSWORD") +REDIS_SSL = os.getenv("REDIS_SSL", "False").lower() == "true" +REDIS_PROTOCOL = "rediss" if REDIS_SSL else "redis" + +if REDIS_USER and REDIS_PASSWORD: + REDIS_AUTH = f"{urlquote(REDIS_USER, safe='')}:{urlquote(REDIS_PASSWORD, safe='')}@" +elif REDIS_PASSWORD: + REDIS_AUTH = f":{urlquote(REDIS_PASSWORD, safe='')}@" +else: + REDIS_AUTH = "" + +CACHES = { + "default": { + "BACKEND": "django.core.cache.backends.redis.RedisCache", + "LOCATION": f"{REDIS_PROTOCOL}://{REDIS_AUTH}{REDIS_HOST}:{REDIS_PORT}/1", + "OPTIONS": ( + { + "ssl_cert_reqs": "required", + "ssl_ca_certs": os.getenv("REDIS_SSL_CA_PATH"), + } + if REDIS_SSL + else {} + ), + } +} + +RQ_SSL_OPTIONS = ( + { + "ssl_cert_reqs": "required", + "ssl_ca_certs": os.getenv("REDIS_SSL_CA_PATH"), + } + if REDIS_SSL + else None +) + +RQ_QUEUES = { + "default": { + "HOST": REDIS_HOST, + "PORT": REDIS_PORT, + "USERNAME": REDIS_USER, + "PASSWORD": REDIS_PASSWORD, + "SSL": REDIS_SSL, + "SSL_OPTIONS": RQ_SSL_OPTIONS, + "DB": 0, + }, + "scheduled-jobs": { + "HOST": REDIS_HOST, + "PORT": REDIS_PORT, + "USERNAME": REDIS_USER, + "PASSWORD": REDIS_PASSWORD, + "SSL": REDIS_SSL, + "SSL_OPTIONS": RQ_SSL_OPTIONS, + "DB": 0, }, } @@ -358,22 +438,6 @@ def get_version(): except: APP_HOST = "self" -RQ_QUEUES = { - "default": { - "HOST": os.getenv("REDIS_HOST"), - "PORT": os.getenv("REDIS_PORT"), - "PASSWORD": get_secret("REDIS_PASSWORD"), - "SSL": os.getenv("REDIS_SSL", None), - "DB": 0, - }, - "scheduled-jobs": { - "HOST": os.getenv("REDIS_HOST"), - "PORT": os.getenv("REDIS_PORT"), - "PASSWORD": get_secret("REDIS_PASSWORD"), - "SSL": os.getenv("REDIS_SSL", None), - "DB": 0, - }, -} PHASE_LICENSE = check_license(get_secret("PHASE_LICENSE_OFFLINE")) diff --git a/backend/conftest.py b/backend/conftest.py new file mode 100644 index 000000000..a169614a2 --- /dev/null +++ b/backend/conftest.py @@ -0,0 +1,23 @@ +import os +import django + +# Set environment variables required for settings.py to import successfully +os.environ.setdefault("ALLOWED_HOSTS", "localhost") +os.environ.setdefault("ALLOWED_ORIGINS", "http://localhost") + +# Set dummy Redis values so settings.py generates a valid URL (e.g. redis://localhost:6379/1) +os.environ.setdefault("REDIS_HOST", "localhost") +os.environ.setdefault("REDIS_PORT", "6379") + +# Set dummy database config +os.environ.setdefault("DATABASE_HOST", "localhost") +os.environ.setdefault("DATABASE_PORT", "5432") +os.environ.setdefault("DATABASE_NAME", "dummy_db") +os.environ.setdefault("DATABASE_USER", "dummy_user") +os.environ.setdefault("DATABASE_PASSWORD", "dummy_password") + +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "backend.settings") + + +def pytest_configure(): + django.setup() diff --git a/backend/dev-requirements.txt b/backend/dev-requirements.txt index 06447a128..751f2fab0 100644 --- a/backend/dev-requirements.txt +++ b/backend/dev-requirements.txt @@ -1,4 +1,5 @@ pytest==8.3.4 +pytest-django==4.11.1 pytest-cov==7.0.0 Faker==37.4.0 colorama==0.4.6 diff --git a/backend/ee/integrations/secrets/dynamic/rest/views.py b/backend/ee/integrations/secrets/dynamic/rest/views.py index ba1635ef8..c3ba41760 100644 --- a/backend/ee/integrations/secrets/dynamic/rest/views.py +++ b/backend/ee/integrations/secrets/dynamic/rest/views.py @@ -18,6 +18,7 @@ ) from api.utils.access.middleware import IsIPAllowed +from api.throttling import PlanBasedRateThrottle from ee.integrations.secrets.dynamic.aws.utils import ( revoke_aws_dynamic_secret_lease, ) @@ -50,6 +51,7 @@ class DynamicSecretsView(APIView): authentication_classes = [PhaseTokenAuthentication] permission_classes = [IsAuthenticated, IsIPAllowed] + throttle_classes = [PlanBasedRateThrottle] renderer_classes = [ CamelCaseJSONRenderer, ] @@ -200,6 +202,7 @@ def get(self, request, *args, **kwargs): class DynamicSecretLeaseView(APIView): authentication_classes = [PhaseTokenAuthentication] permission_classes = [IsAuthenticated, IsIPAllowed] + throttle_classes = [PlanBasedRateThrottle] renderer_classes = [CamelCaseJSONRenderer] def _get_account_and_org(self, request): diff --git a/backend/pytest.ini b/backend/pytest.ini new file mode 100644 index 000000000..f912902b3 --- /dev/null +++ b/backend/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +python_files = tests.py test_*.py *_tests.py \ No newline at end of file diff --git a/backend/tests/api/test_throttling.py b/backend/tests/api/test_throttling.py new file mode 100644 index 000000000..084cd228a --- /dev/null +++ b/backend/tests/api/test_throttling.py @@ -0,0 +1,172 @@ +import pytest +from unittest.mock import Mock, patch +from rest_framework.test import APIRequestFactory +from api.throttling import PlanBasedRateThrottle + + +class TestPlanBasedRateThrottle: + + @pytest.fixture(autouse=True) + def setup_test_env(self, settings): + """ + Configures the test environment, overrides settings, and initializes objects. + Replaces setup_method to ensure correct ordering with settings overrides. + """ + # 1. Override cache backend to use LocMemCache + settings.CACHES = { + "default": { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", + } + } + settings.DATABASES = { + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": ":memory:", + } + } + + # 2. Force Django to reload the cache backend + from django.core.cache import caches + + try: + if "default" in caches: + del caches["default"] + except (AttributeError, KeyError): + # Handle cases where asgiref/thread-local storage is inconsistent + pass + + # 3. Initialize test objects (formerly in setup_method) + self.factory = APIRequestFactory() + self.throttle = PlanBasedRateThrottle() + + # 4. Assign the fresh LocMemCache to the throttle and clear it + from django.core.cache import cache + + self.throttle.cache = cache + cache.clear() + + def test_get_cache_key_authenticated_user(self): + """Test cache key generation for standard authenticated user""" + request = self.factory.get("/") + request.user = Mock(is_authenticated=True) + request.auth = {"org_member": Mock(id=123)} + + key = self.throttle.get_cache_key(request, None) + assert "plan_based" in key + assert "user_123" in key + + def test_get_cache_key_service_account(self): + """Test cache key generation for service account""" + request = self.factory.get("/") + request.user = Mock(is_authenticated=True) + request.auth = {"service_account": Mock(id=456)} + + key = self.throttle.get_cache_key(request, None) + assert "plan_based" in key + assert "sa_456" in key + + def test_get_cache_key_anonymous(self): + """Test cache key generation for anonymous user""" + request = self.factory.get("/") + request.user = Mock(is_authenticated=False) + request.auth = None + + key = self.throttle.get_cache_key(request, None) + assert "plan_based" in key + assert "anon" in key + + @patch("api.throttling.CLOUD_HOSTED", True) + def test_rate_selection_cloud_hosted_free_plan(self, settings): + """Test that Free plan rate is applied in cloud mode""" + settings.PLAN_RATE_LIMITS = { + "FR": "10/min", + "PR": "100/min", + "DEFAULT": "5/min", + } + + request = self.factory.get("/") + request.user = Mock(is_authenticated=True) + + # Mock environment structure + mock_env = Mock() + mock_env.app.organisation.plan = "FR" + request.auth = {"environment": mock_env} + + # Mock get_rate to return default (simulating DRF settings) + with patch.object(PlanBasedRateThrottle, "get_rate", return_value="5/min"): + self.throttle.allow_request(request, None) + + assert self.throttle.rate == "10/min" + assert self.throttle.num_requests == 10 + assert self.throttle.duration == 60 + + @patch("api.throttling.CLOUD_HOSTED", True) + def test_rate_selection_cloud_hosted_pro_plan(self, settings): + """Test that Pro plan rate is applied in cloud mode""" + settings.PLAN_RATE_LIMITS = { + "FR": "10/min", + "PR": "100/min", + "DEFAULT": "5/min", + } + + request = self.factory.get("/") + request.user = Mock(is_authenticated=True) + + mock_env = Mock() + mock_env.app.organisation.plan = "PR" + request.auth = {"environment": mock_env} + + with patch.object(PlanBasedRateThrottle, "get_rate", return_value="5/min"): + self.throttle.allow_request(request, None) + + assert self.throttle.rate == "100/min" + + @patch("api.throttling.CLOUD_HOSTED", False) + def test_rate_selection_self_hosted_uses_default(self, settings): + """Test that self-hosted mode ignores plan and uses default""" + settings.PLAN_RATE_LIMITS = {"FR": "10/min", "DEFAULT": "5/min"} + + request = self.factory.get("/") + request.user = Mock(is_authenticated=True) + + mock_env = Mock() + mock_env.app.organisation.plan = "FR" # Should be ignored + request.auth = {"environment": mock_env} + + with patch.object(PlanBasedRateThrottle, "get_rate", return_value="5/min"): + self.throttle.allow_request(request, None) + + assert self.throttle.rate == "5/min" + + @patch("api.throttling.CLOUD_HOSTED", True) + def test_rate_selection_anonymous_uses_default(self, settings): + """Test that anonymous requests use the default rate""" + settings.PLAN_RATE_LIMITS = {"DEFAULT": "5/min"} + + request = self.factory.get("/") + request.user = Mock(is_authenticated=False) + request.auth = None + + with patch.object(PlanBasedRateThrottle, "get_rate", return_value="5/min"): + self.throttle.allow_request(request, None) + + assert self.throttle.rate == "5/min" + + @patch("api.throttling.CLOUD_HOSTED", False) + def test_self_hosted_no_default_disables_throttling(self, settings): + """Test that self-hosted mode with no default set disables throttling""" + settings.PLAN_RATE_LIMITS = {"DEFAULT": None} + + request = self.factory.get("/") + request.user = Mock(is_authenticated=True) + request.auth = {} + + # Reset rate to None to force allow_request to call get_rate() again + # This ensures we test the logic inside get_rate with the new settings + self.throttle.rate = None + + # We do NOT mock get_rate here. We want to verify that the REAL get_rate + # returns None when CLOUD_HOSTED is False and DEFAULT is None. + allowed = self.throttle.allow_request(request, None) + + assert allowed is True