diff --git a/push_notifications/apns_async.py b/push_notifications/apns_async.py index 811b0f90..00bce480 100644 --- a/push_notifications/apns_async.py +++ b/push_notifications/apns_async.py @@ -1,15 +1,16 @@ import asyncio import time - from dataclasses import asdict, dataclass -from typing import Awaitable, Callable, Dict, Optional, Union, Any, Tuple, List +from datetime import datetime +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union from aioapns import APNs, ConnectionError, NotificationRequest from aioapns.common import NotificationResult from . import models from .conf import get_manager -from .exceptions import APNSServerError, APNSError +from .exceptions import APNSError, APNSServerError + ErrFunc = Optional[Callable[[NotificationRequest, NotificationResult], Awaitable[None]]] """function to proces errors from aioapns send_message""" @@ -36,6 +37,15 @@ class CertificateCredentials(Credentials): client_cert: str +@dataclass +class BulkNotificationResult: + results: dict[str, Any] + errors: list[dict[str, Any]] + + @property + def has_errors(self) -> bool: + return len(self.errors) > 0 + @dataclass class Alert: """ @@ -305,7 +315,7 @@ def apns_send_bulk_message( mutable_content: Optional[bool] = False, category: Optional[str] = None, err_func: Optional[ErrFunc] = None, -) -> Dict[str, str]: +) -> BulkNotificationResult: """ Sends an APNS notification to one or more registration_ids. The registration_ids argument needs to be a list. @@ -361,7 +371,13 @@ def apns_send_bulk_message( "Success" if result.is_successful else result.description ) if not result.is_successful: - errors.append(result.description) + error_obj = { + 'registration_id': registration_id, + 'error_type': result.description, + 'error_message': result.description, + 'timestamp': datetime.now().isoformat(), + } + errors.append(error_obj) if result.description in [ "Unregistered", "BadDeviceToken", @@ -374,11 +390,7 @@ def apns_send_bulk_message( registration_id__in=inactive_tokens ).update(active=False) - if len(errors) > 0: - msg = "One or more errors failed with errors: {}".format(", ".join(errors)) - raise APNSError(msg) - - return results + return BulkNotificationResult(results=results, errors=errors) except ConnectionError as e: raise APNSServerError(status=e.__class__.__name__) diff --git a/tests/test_apns_async_push_payload.py b/tests/test_apns_async_push_payload.py index ee808156..5eff5f0e 100644 --- a/tests/test_apns_async_push_payload.py +++ b/tests/test_apns_async_push_payload.py @@ -8,7 +8,7 @@ try: from aioapns.common import NotificationResult - from push_notifications.apns_async import TokenCredentials, apns_send_message, CertificateCredentials + from push_notifications.apns_async import TokenCredentials, apns_send_message, apns_send_bulk_message, CertificateCredentials, BulkNotificationResult except ModuleNotFoundError: # skipping because apns2 is not supported on python 3.10 # it uses hyper that imports from collections which were changed in 3.10 @@ -276,3 +276,151 @@ def test_push_payload_with_content_available_not_set(self, mock_apns): req = args[0] assert "content-available" not in req.message["aps"] + + + +class APNSAsyncErrorHandlingTests(TestCase): + + @mock.patch("push_notifications.apns_async.asyncio.run") + @mock.patch("push_notifications.apns_async.get_manager") + def test_returns_bulk_notification_result(self, mock_manager, mock_asyncio_run): + mock_manager.return_value.get_apns_topic.return_value = "com.example.app" + + mock_asyncio_run.return_value = [ + ("token1", NotificationResult("123", "200")), + ] + + result = apns_send_bulk_message( + registration_ids=["token1"], + alert="Test", + creds=TokenCredentials(key="aaa", key_id="bbb", team_id="ccc"), + ) + + self.assertIsInstance(result, BulkNotificationResult) + self.assertIsInstance(result.results, dict) + self.assertIsInstance(result.errors, list) + + @mock.patch("push_notifications.apns_async.asyncio.run") + @mock.patch("push_notifications.apns_async.get_manager") + def test_all_success_returns_empty_errors(self, mock_manager, mock_asyncio_run): + # successful notifications return empty errors list + + mock_manager.return_value.get_apns_topic.return_value = "com.example.app" + + mock_asyncio_run.return_value = [ + ("token1", NotificationResult("123", "200")), + ("token2", NotificationResult("124", "200")), + ] + + result = apns_send_bulk_message( + registration_ids=["token1", "token2"], + alert="Test", + creds=TokenCredentials(key="aaa", key_id="bbb", team_id="ccc"), + ) + + self.assertEqual(len(result.errors), 0) + self.assertFalse(result.has_errors) + self.assertEqual(result.results["token1"], "Success") + self.assertEqual(result.results["token2"], "Success") + + @mock.patch("push_notifications.apns_async.asyncio.run") + @mock.patch("push_notifications.apns_async.get_manager") + def test_partial_failure_returns_error_objects(self, mock_manager, mock_asyncio_run): + mock_manager.return_value.get_apns_topic.return_value = "com.example.app" + + mock_asyncio_run.return_value = [ + ("token1", NotificationResult("123", "200")), + ("token2", NotificationResult("124", "400", description="BadDeviceToken")), + ] + + result = apns_send_bulk_message( + registration_ids=["token1", "token2"], + alert="Test", + creds=TokenCredentials(key="aaa", key_id="bbb", team_id="ccc"), + ) + + self.assertEqual(result.results["token1"], "Success") + self.assertEqual(result.results["token2"], "BadDeviceToken") + + self.assertEqual(len(result.errors), 1) + self.assertTrue(result.has_errors) + + error = result.errors[0] + self.assertEqual(error["registration_id"], "token2") + self.assertEqual(error["error_type"], "BadDeviceToken") + self.assertEqual(error["error_message"], "BadDeviceToken") + self.assertIn("timestamp", error) + + @mock.patch("push_notifications.apns_async.asyncio.run") + @mock.patch("push_notifications.apns_async.get_manager") + def test_multiple_errors_all_captured(self, mock_manager, mock_asyncio_run): + mock_manager.return_value.get_apns_topic.return_value = "com.example.app" + + mock_asyncio_run.return_value = [ + ("token1", NotificationResult("123", "400", description="Unregistered")), + ("token2", NotificationResult("124", "400", description="BadDeviceToken")), + ("token3", NotificationResult("125", "400", description="TimeoutError")), + ] + + result = apns_send_bulk_message( + registration_ids=["token1", "token2", "token3"], + alert="Test", + creds=TokenCredentials(key="aaa", key_id="bbb", team_id="ccc"), + ) + + self.assertEqual(len(result.errors), 3) + self.assertTrue(result.has_errors) + + error_types = [e["error_type"] for e in result.errors] + self.assertIn("Unregistered", error_types) + self.assertIn("BadDeviceToken", error_types) + self.assertIn("TimeoutError", error_types) + + @mock.patch("push_notifications.apns_async.models.APNSDevice.objects.filter") + @mock.patch("push_notifications.apns_async.asyncio.run") + @mock.patch("push_notifications.apns_async.get_manager") + def test_unregistered_tokens_marked_inactive(self, mock_manager, mock_asyncio_run, mock_filter): + mock_manager.return_value.get_apns_topic.return_value = "com.example.app" + mock_update = mock.Mock() + mock_filter.return_value.update = mock_update + + mock_asyncio_run.return_value = [ + ("token1", NotificationResult("123", "200")), + ("token2", NotificationResult("124", "400", description="Unregistered")), + ("token3", NotificationResult("125", "400", description="BadDeviceToken")), + ] + + _ = apns_send_bulk_message( + registration_ids=["token1", "token2", "token3"], + alert="Test", + creds=TokenCredentials(key="aaa", key_id="bbb", team_id="ccc"), + ) + + mock_filter.assert_called_once() + call_args = mock_filter.call_args[1] + self.assertIn("token2", call_args["registration_id__in"]) + self.assertIn("token3", call_args["registration_id__in"]) + mock_update.assert_called_once_with(active=False) + + @mock.patch("push_notifications.apns_async.asyncio.run") + @mock.patch("push_notifications.apns_async.get_manager") + def test_does_not_raise_exception_on_errors(self, mock_manager, mock_asyncio_run): + # Test that errors don't raise exceptions - they're returned in the result + mock_manager.return_value.get_apns_topic.return_value = "com.example.app" + + mock_asyncio_run.return_value = [ + ("token1", NotificationResult("123", "400", description="Unregistered")), + ] + + try: + result = apns_send_bulk_message( + registration_ids=["token1"], + alert="Test", + creds=TokenCredentials(key="aaa", key_id="bbb", team_id="ccc"), + ) + exception_raised = False + except Exception: + exception_raised = True + + self.assertFalse(exception_raised, "apns_send_bulk_message should not raise exceptions on errors") + self.assertTrue(result.has_errors)