Skip to content

Commit

Permalink
fix: cast env var values according to its type
Browse files Browse the repository at this point in the history
  • Loading branch information
Iasmini Gomes committed Jul 1, 2024
1 parent 8e6df7a commit 20d0178
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 25 deletions.
61 changes: 38 additions & 23 deletions django_cloud_tasks/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,24 @@ def __init__(self, *args, **kwargs):
self.on_demand_tasks = {}
self.periodic_tasks = {}
self.subscriber_tasks = {}
self.domain = self._fetch_config(name="ENDPOINT", default="http://localhost:8080")
self.app_name = self._fetch_config(name="APP_NAME", default=os.environ.get("APP_NAME", None))
self.delimiter = self._fetch_config(name="DELIMITER", default="--")
self.eager = self._fetch_config(name="EAGER", default=False)
self.tasks_url_name = self._fetch_config(name="URL_NAME", default="tasks-endpoint")
self.tasks_max_eta = self._fetch_config(name="MAXIMUM_ETA_TASK", default=None)
self.subscribers_url_name = self._fetch_config(name="SUBSCRIBERS_URL_NAME", default="subscriptions-endpoint")

self.subscribers_max_retries = self._fetch_config(name="SUBSCRIBER_MAX_RETRIES", default=None)
self.subscribers_min_backoff = self._fetch_config(name="SUBSCRIBER_MIN_BACKOFF", default=None)
self.subscribers_max_backoff = self._fetch_config(name="SUBSCRIBER_MAX_BACKOFF", default=None)
self.subscribers_expiration = self._fetch_config(name="SUBSCRIBER_EXPIRATION", default=None)

self.propagated_headers = self._fetch_config(
self.domain = self.get_str(name="ENDPOINT", default="http://localhost:8080")
self.app_name = self.get_str(name="APP_NAME", default=os.environ.get("APP_NAME", None))
self.delimiter = self.get_str(name="DELIMITER", default="--")
self.eager = self.get_bool(name="EAGER", default=False)
self.tasks_url_name = self.get_str(name="URL_NAME", default="tasks-endpoint")
self.tasks_max_eta = self.get_int(name="MAXIMUM_ETA_TASK", default=None)
self.subscribers_url_name = self.get_str(name="SUBSCRIBERS_URL_NAME", default="subscriptions-endpoint")

self.subscribers_max_retries = self.get_int(name="SUBSCRIBER_MAX_RETRIES", default=None)
self.subscribers_min_backoff = self.get_int(name="SUBSCRIBER_MIN_BACKOFF", default=None)
self.subscribers_max_backoff = self.get_int(name="SUBSCRIBER_MAX_BACKOFF", default=None)
self.subscribers_expiration = self.get_str(name="SUBSCRIBER_EXPIRATION", default=None)

self.propagated_headers = self.get_list(
name="PROPAGATED_HEADERS",
default=DEFAULT_PROPAGATION_HEADERS,
as_list=True,
)
self.propagated_headers_key = self._fetch_config(
self.propagated_headers_key = self.get_str(
name="PROPAGATED_HEADERS_KEY", default=DEFAULT_PROPAGATION_HEADERS_KEY
)

Expand Down Expand Up @@ -79,15 +78,15 @@ def get_task(self, name: str):
raise exceptions.TaskNotFound(name=name)

def get_backup_queue_name(self, original_name: str) -> str:
return self._fetch_config(
return self.get_str(
name="BACKUP_QUEUE_NAME",
default=f"{original_name}{self.delimiter}temp",
)

def get_task_metadata_class(self):
from django_cloud_tasks.tasks import TaskMetadata

metadata_class_name = self._fetch_config(
metadata_class_name = self.get_str(
name="TASK_METADATA_CLASS",
default="django_cloud_tasks.tasks.task.TaskMetadata",
)
Expand All @@ -104,13 +103,29 @@ def get_task_metadata_class(self):

return metadata_class

def _fetch_config(self, name: str, default: Any, as_list: bool = False) -> Any:
def _fetch_config(self, name: str, default: Any) -> Any:
config_name = f"{PREFIX}{name.upper()}"
return getattr(settings, config_name, os.environ.get(config_name, default))

value = getattr(settings, config_name, os.environ.get(config_name, default))
if as_list and not isinstance(value, list):
value = value.split(",")
return value
def get_str(self, name: str, default: Any) -> str:
value = self._fetch_config(name=name, default=default)
return str(value) if value is not None else default

def get_bool(self, name: str, default: Any) -> bool:
value = self._fetch_config(name=name, default=default)
return str(value).lower() in ("true", "1", "t", "y", "yes") if value is not None else default

def get_int(self, name: str, default: Any) -> int:
value = self._fetch_config(name=name, default=default)
return int(value) if value is not None else default

def get_float(self, name: str, default: Any) -> float:
value = self._fetch_config(name=name, default=default)
return float(value) if value is not None else default

def get_list(self, name: str, default: Any) -> list:
value = self._fetch_config(name=name, default=default)
return value.split(",") if value is not None and isinstance(value, str) else default

def register_task(self, task_class):
from django_cloud_tasks.tasks.periodic_task import PeriodicTask
Expand Down
4 changes: 2 additions & 2 deletions django_cloud_tasks/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,8 @@ def asap(cls, **kwargs):

@classmethod
def later(cls, task_kwargs: dict, eta: int | timedelta | datetime, queue: str = None, headers: dict | None = None):
delay_in_seconds = cls._calculate_delay_in_seconds(eta)
cls._validate_delay(delay_in_seconds)
delay_in_seconds = cls._calculate_delay_in_seconds(eta=eta)
cls._validate_delay(delay_in_seconds=delay_in_seconds)
return cls.push(
task_kwargs=task_kwargs,
queue=queue,
Expand Down
67 changes: 67 additions & 0 deletions sample_project/sample_app/tests/tests_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import unittest
from unittest.mock import patch
from django_cloud_tasks.apps import DjangoCloudTasksAppConfig
import os


class TestAppConfig(DjangoCloudTasksAppConfig):
path = os.path.dirname(__file__)


class DjangoCloudTasksAppConfigTest(unittest.TestCase):
def setUp(self):
self.config = TestAppConfig("django_cloud_tasks", "django_cloud_tasks")

@patch("django_cloud_tasks.apps.DjangoCloudTasksAppConfig._fetch_config")
def test_get_str(self, mock_fetch_config):
mock_fetch_config.return_value = "test_string"
result = self.config.get_str("name", "default_value")
self.assertEqual(result, "test_string")

mock_fetch_config.return_value = None
result = self.config.get_str("name", "default_value")
self.assertEqual(result, "default_value")

@patch("django_cloud_tasks.apps.DjangoCloudTasksAppConfig._fetch_config")
def test_get_bool(self, mock_fetch_config):
mock_fetch_config.return_value = "true"
result = self.config.get_bool("name", False)
self.assertEqual(result, True)

mock_fetch_config.return_value = "false"
result = self.config.get_bool("name", True)
self.assertEqual(result, False)

mock_fetch_config.return_value = None
result = self.config.get_bool("name", True)
self.assertEqual(result, True)

@patch("django_cloud_tasks.apps.DjangoCloudTasksAppConfig._fetch_config")
def test_get_int(self, mock_fetch_config):
mock_fetch_config.return_value = "10"
result = self.config.get_int("name", 0)
self.assertEqual(result, 10)

mock_fetch_config.return_value = None
result = self.config.get_int("name", 5)
self.assertEqual(result, 5)

@patch("django_cloud_tasks.apps.DjangoCloudTasksAppConfig._fetch_config")
def test_get_float(self, mock_fetch_config):
mock_fetch_config.return_value = "3.14"
result = self.config.get_float("name", 0.0)
self.assertEqual(result, 3.14)

mock_fetch_config.return_value = None
result = self.config.get_float("name", 2.71)
self.assertEqual(result, 2.71)

@patch("django_cloud_tasks.apps.DjangoCloudTasksAppConfig._fetch_config")
def test_get_list(self, mock_fetch_config):
mock_fetch_config.return_value = "item1,item2,item3"
result = self.config.get_list("name", [])
self.assertEqual(result, ["item1", "item2", "item3"])

mock_fetch_config.return_value = None
result = self.config.get_list("name", ["default"])
self.assertEqual(result, ["default"])

0 comments on commit 20d0178

Please sign in to comment.