Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions push_notifications/apns_async.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import asyncio
import time

from dataclasses import asdict, dataclass
from typing import Awaitable, Callable, Dict, Optional, Union, Any, Tuple, List
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union

from aioapns import APNs, ConnectionError, NotificationRequest
from aioapns.common import NotificationResult

from . import models
from .conf import get_manager
from .exceptions import APNSServerError, APNSError
from .exceptions import APNSError, APNSServerError


ErrFunc = Optional[Callable[[NotificationRequest, NotificationResult], Awaitable[None]]]
"""function to proces errors from aioapns send_message"""
Expand All @@ -36,6 +37,15 @@ class CertificateCredentials(Credentials):
client_cert: str


@dataclass
class BulkNotificationResult:
results: dict[str, Any]
errors: list[dict[str, Any]]

@property
def has_errors(self) -> bool:
return len(self.errors) > 0

@dataclass
class Alert:
"""
Expand Down Expand Up @@ -305,7 +315,7 @@ def apns_send_bulk_message(
mutable_content: Optional[bool] = False,
category: Optional[str] = None,
err_func: Optional[ErrFunc] = None,
) -> Dict[str, str]:
) -> BulkNotificationResult:
"""
Sends an APNS notification to one or more registration_ids.
The registration_ids argument needs to be a list.
Expand Down Expand Up @@ -361,7 +371,13 @@ def apns_send_bulk_message(
"Success" if result.is_successful else result.description
)
if not result.is_successful:
errors.append(result.description)
error_obj = {
'registration_id': registration_id,
'error_type': result.description,
'error_message': result.description,
'timestamp': datetime.now().isoformat(),
}
errors.append(error_obj)
if result.description in [
"Unregistered",
"BadDeviceToken",
Expand All @@ -374,11 +390,7 @@ def apns_send_bulk_message(
registration_id__in=inactive_tokens
).update(active=False)

if len(errors) > 0:
msg = "One or more errors failed with errors: {}".format(", ".join(errors))
raise APNSError(msg)

return results
return BulkNotificationResult(results=results, errors=errors)

except ConnectionError as e:
raise APNSServerError(status=e.__class__.__name__)
Expand Down
150 changes: 149 additions & 1 deletion tests/test_apns_async_push_payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

try:
from aioapns.common import NotificationResult
from push_notifications.apns_async import TokenCredentials, apns_send_message, CertificateCredentials
from push_notifications.apns_async import TokenCredentials, apns_send_message, apns_send_bulk_message, CertificateCredentials, BulkNotificationResult
except ModuleNotFoundError:
# skipping because apns2 is not supported on python 3.10
# it uses hyper that imports from collections which were changed in 3.10
Expand Down Expand Up @@ -276,3 +276,151 @@ def test_push_payload_with_content_available_not_set(self, mock_apns):
req = args[0]

assert "content-available" not in req.message["aps"]



class APNSAsyncErrorHandlingTests(TestCase):

@mock.patch("push_notifications.apns_async.asyncio.run")
@mock.patch("push_notifications.apns_async.get_manager")
def test_returns_bulk_notification_result(self, mock_manager, mock_asyncio_run):
mock_manager.return_value.get_apns_topic.return_value = "com.example.app"

mock_asyncio_run.return_value = [
("token1", NotificationResult("123", "200")),
]

result = apns_send_bulk_message(
registration_ids=["token1"],
alert="Test",
creds=TokenCredentials(key="aaa", key_id="bbb", team_id="ccc"),
)

self.assertIsInstance(result, BulkNotificationResult)
self.assertIsInstance(result.results, dict)
self.assertIsInstance(result.errors, list)

@mock.patch("push_notifications.apns_async.asyncio.run")
@mock.patch("push_notifications.apns_async.get_manager")
def test_all_success_returns_empty_errors(self, mock_manager, mock_asyncio_run):
# successful notifications return empty errors list

mock_manager.return_value.get_apns_topic.return_value = "com.example.app"

mock_asyncio_run.return_value = [
("token1", NotificationResult("123", "200")),
("token2", NotificationResult("124", "200")),
]

result = apns_send_bulk_message(
registration_ids=["token1", "token2"],
alert="Test",
creds=TokenCredentials(key="aaa", key_id="bbb", team_id="ccc"),
)

self.assertEqual(len(result.errors), 0)
self.assertFalse(result.has_errors)
self.assertEqual(result.results["token1"], "Success")
self.assertEqual(result.results["token2"], "Success")

@mock.patch("push_notifications.apns_async.asyncio.run")
@mock.patch("push_notifications.apns_async.get_manager")
def test_partial_failure_returns_error_objects(self, mock_manager, mock_asyncio_run):
mock_manager.return_value.get_apns_topic.return_value = "com.example.app"

mock_asyncio_run.return_value = [
("token1", NotificationResult("123", "200")),
("token2", NotificationResult("124", "400", description="BadDeviceToken")),
]

result = apns_send_bulk_message(
registration_ids=["token1", "token2"],
alert="Test",
creds=TokenCredentials(key="aaa", key_id="bbb", team_id="ccc"),
)

self.assertEqual(result.results["token1"], "Success")
self.assertEqual(result.results["token2"], "BadDeviceToken")

self.assertEqual(len(result.errors), 1)
self.assertTrue(result.has_errors)

error = result.errors[0]
self.assertEqual(error["registration_id"], "token2")
self.assertEqual(error["error_type"], "BadDeviceToken")
self.assertEqual(error["error_message"], "BadDeviceToken")
self.assertIn("timestamp", error)

@mock.patch("push_notifications.apns_async.asyncio.run")
@mock.patch("push_notifications.apns_async.get_manager")
def test_multiple_errors_all_captured(self, mock_manager, mock_asyncio_run):
mock_manager.return_value.get_apns_topic.return_value = "com.example.app"

mock_asyncio_run.return_value = [
("token1", NotificationResult("123", "400", description="Unregistered")),
("token2", NotificationResult("124", "400", description="BadDeviceToken")),
("token3", NotificationResult("125", "400", description="TimeoutError")),
]

result = apns_send_bulk_message(
registration_ids=["token1", "token2", "token3"],
alert="Test",
creds=TokenCredentials(key="aaa", key_id="bbb", team_id="ccc"),
)

self.assertEqual(len(result.errors), 3)
self.assertTrue(result.has_errors)

error_types = [e["error_type"] for e in result.errors]
self.assertIn("Unregistered", error_types)
self.assertIn("BadDeviceToken", error_types)
self.assertIn("TimeoutError", error_types)

@mock.patch("push_notifications.apns_async.models.APNSDevice.objects.filter")
@mock.patch("push_notifications.apns_async.asyncio.run")
@mock.patch("push_notifications.apns_async.get_manager")
def test_unregistered_tokens_marked_inactive(self, mock_manager, mock_asyncio_run, mock_filter):
mock_manager.return_value.get_apns_topic.return_value = "com.example.app"
mock_update = mock.Mock()
mock_filter.return_value.update = mock_update

mock_asyncio_run.return_value = [
("token1", NotificationResult("123", "200")),
("token2", NotificationResult("124", "400", description="Unregistered")),
("token3", NotificationResult("125", "400", description="BadDeviceToken")),
]

_ = apns_send_bulk_message(
registration_ids=["token1", "token2", "token3"],
alert="Test",
creds=TokenCredentials(key="aaa", key_id="bbb", team_id="ccc"),
)

mock_filter.assert_called_once()
call_args = mock_filter.call_args[1]
self.assertIn("token2", call_args["registration_id__in"])
self.assertIn("token3", call_args["registration_id__in"])
mock_update.assert_called_once_with(active=False)

@mock.patch("push_notifications.apns_async.asyncio.run")
@mock.patch("push_notifications.apns_async.get_manager")
def test_does_not_raise_exception_on_errors(self, mock_manager, mock_asyncio_run):
# Test that errors don't raise exceptions - they're returned in the result
mock_manager.return_value.get_apns_topic.return_value = "com.example.app"

mock_asyncio_run.return_value = [
("token1", NotificationResult("123", "400", description="Unregistered")),
]

try:
result = apns_send_bulk_message(
registration_ids=["token1"],
alert="Test",
creds=TokenCredentials(key="aaa", key_id="bbb", team_id="ccc"),
)
exception_raised = False
except Exception:
exception_raised = True

self.assertFalse(exception_raised, "apns_send_bulk_message should not raise exceptions on errors")
self.assertTrue(result.has_errors)
Loading