diff --git a/app/persistence/azure_queue_storage.py b/app/persistence/azure_queue_storage.py index d677006..4a07e90 100644 --- a/app/persistence/azure_queue_storage.py +++ b/app/persistence/azure_queue_storage.py @@ -1,3 +1,5 @@ +from base64 import b64decode, b64encode +from binascii import Error as BinasciiError from typing import Any, AsyncGenerator from azure.core.exceptions import ( @@ -5,7 +7,6 @@ ResourceNotFoundError, ServiceRequestError, ) -from azure.storage.queue import TextBase64DecodePolicy, TextBase64EncodePolicy from azure.storage.queue.aio import QueueClient, QueueServiceClient from pydantic import BaseModel from tenacity import ( @@ -47,7 +48,7 @@ async def send_message( self, message: str, ) -> None: - await self._client.send_message(message) + await self._client.send_message(self._escape(message)) @retry( reraise=True, @@ -66,7 +67,7 @@ async def receive_messages( ) async for message in messages: yield Message( - content=message.content, + content=self._unescape(message.content), delete_token=message.pop_receipt, dequeue_count=message.dequeue_count, message_id=message.id, @@ -99,13 +100,28 @@ async def delete_queue( await self._client.delete_queue() logger.info('Deleted Queue Storage "%s"', self._config.name) + def _escape(self, value: str) -> str: + """ + Escape value to base64 encoding. + """ + return b64encode(value.encode(self.encoding)).decode(self.encoding) + + def _unescape(self, value: str) -> str: + """ + Unescape value from base64 encoding. + + If the value is not base64 encoded, return the original value as string. This will handle retro-compatibility with old messages. + """ + try: + return b64decode(value.encode(self.encoding)).decode(self.encoding) + except BinasciiError: + return value + async def __aenter__(self) -> "AzureQueueStorage": self._service = QueueServiceClient.from_connection_string( self._config.connection_string ) self._client = self._service.get_queue_client( - message_decode_policy=TextBase64DecodePolicy(), - message_encode_policy=TextBase64EncodePolicy(), queue=self._config.name, ) # Create if it does not exist diff --git a/app/persistence/iqueue.py b/app/persistence/iqueue.py index dc02f0a..d28ba6d 100644 --- a/app/persistence/iqueue.py +++ b/app/persistence/iqueue.py @@ -16,6 +16,8 @@ class Provider(str, Enum): class IQueue: + encoding = "utf-8" + @abstractmethod async def send_message( self, diff --git a/tests/blob.py b/tests/blob.py index fd1045f..c696b6c 100644 --- a/tests/blob.py +++ b/tests/blob.py @@ -297,12 +297,20 @@ async def _validate( def _random_name() -> str: + """ + Generate a random name with 32 characters. + + All lowercase letters and digits are used. + """ return "".join( random.choice(string.ascii_lowercase + string.digits) for _ in range(32) ) def _random_content() -> str: - return "".join( - random.choice(string.printable) for _ in range(random.randint(1, 512)) - ) + """ + Generate a random content with a length of 512 characters. + + All printable ASCII characters are used. + """ + return "".join(random.choice(string.printable) for _ in range(512)) diff --git a/tests/queue.py b/tests/queue.py index 56967f9..e133445 100644 --- a/tests/queue.py +++ b/tests/queue.py @@ -169,12 +169,20 @@ async def test_send_many(provider: QueueProvider) -> None: def _random_name() -> str: + """ + Generate a random name with 32 characters. + + All lowercase letters and digits are used. + """ return "".join( random.choice(string.ascii_lowercase + string.digits) for _ in range(32) ) def _random_content() -> str: - return "".join( - random.choice(string.printable) for _ in range(random.randint(1, 512)) - ) + """ + Generate a random content with a length of 512 characters. + + All printable ASCII characters are used. + """ + return "".join(random.choice(string.printable) for _ in range(512))