From 13debf7a3af363987b08209c6e398f71a95e617c Mon Sep 17 00:00:00 2001
From: olzhasar <o.arystanov@gmail.com>
Date: Sun, 29 Sep 2024 03:42:30 +0500
Subject: [PATCH] Handle multiple messages in send and group_send

---
 channels_redis/core.py | 179 +++++++++++++++++++++++++++++++++--------
 tests/test_core.py     |  54 +++++++++++++
 2 files changed, 199 insertions(+), 34 deletions(-)

diff --git a/channels_redis/core.py b/channels_redis/core.py
index a164059..6ed3865 100644
--- a/channels_redis/core.py
+++ b/channels_redis/core.py
@@ -180,22 +180,40 @@ async def send(self, channel, message):
         """
         Send a message onto a (general or specific) channel.
         """
+        await self.send_bulk(channel, (message,))
+
+    async def send_bulk(self, channel, messages):
+        """
+        Send multiple messages in bulk onto a (general or specific) channel.
+        The `messages` argument should be an iterable of dicts.
+        """
+
         # Typecheck
-        assert isinstance(message, dict), "message is not a dict"
         assert self.valid_channel_name(channel), "Channel name not valid"
-        # Make sure the message does not contain reserved keys
-        assert "__asgi_channel__" not in message
+
         # If it's a process-local channel, strip off local part and stick full name in message
         channel_non_local_name = channel
-        if "!" in channel:
-            message = dict(message.items())
-            message["__asgi_channel__"] = channel
+        process_local = "!" in channel
+        if process_local:
             channel_non_local_name = self.non_local_name(channel)
+
+        now = time.time()
+        mapping = {}
+        for message in messages:
+            assert isinstance(message, dict), "message is not a dict"
+            # Make sure the message does not contain reserved keys
+            assert "__asgi_channel__" not in message
+            if process_local:
+                message = dict(message.items())
+                message["__asgi_channel__"] = channel
+
+            mapping[self.serialize(message)] = now
+
         # Write out message into expiring key (avoids big items in list)
         channel_key = self.prefix + channel_non_local_name
         # Pick a connection to the right server - consistent for specific
         # channels, random for general channels
-        if "!" in channel:
+        if process_local:
             index = self.consistent_hash(channel)
         else:
             index = next(self._send_index_generator)
@@ -207,13 +225,13 @@ async def send(self, channel, message):
 
         # Check the length of the list before send
         # This can allow the list to leak slightly over capacity, but that's fine.
-        if await connection.zcount(channel_key, "-inf", "+inf") >= self.get_capacity(
-            channel
-        ):
+        current_length = await connection.zcount(channel_key, "-inf", "+inf")
+
+        if current_length + len(messages) > self.get_capacity(channel):
             raise ChannelFull()
 
         # Push onto the list then set it to expire in case it's not consumed
-        await connection.zadd(channel_key, {self.serialize(message): time.time()})
+        await connection.zadd(channel_key, mapping)
         await connection.expire(channel_key, int(self.expiry))
 
     def _backup_channel_name(self, channel):
@@ -517,10 +535,7 @@ async def group_discard(self, group, channel):
         connection = self.connection(self.consistent_hash(group))
         await connection.zrem(key, channel)
 
-    async def group_send(self, group, message):
-        """
-        Sends a message to the entire group.
-        """
+    async def _get_group_connection_and_channels(self, group):
         assert self.valid_group_name(group), "Group name not valid"
         # Retrieve list of all channel names
         key = self._group_key(group)
@@ -532,11 +547,36 @@ async def group_send(self, group, message):
 
         channel_names = [x.decode("utf8") for x in await connection.zrange(key, 0, -1)]
 
+        return connection, channel_names
+
+    async def _exec_group_lua_script(
+        self, conn_idx, group, channel_redis_keys, channel_names, script, args
+    ):
+        # channel_keys does not contain a single redis key more than once
+        connection = self.connection(conn_idx)
+        channels_over_capacity = await connection.eval(
+            script, len(channel_redis_keys), *channel_redis_keys, *args
+        )
+        if channels_over_capacity > 0:
+            logger.info(
+                "%s of %s channels over capacity in group %s",
+                channels_over_capacity,
+                len(channel_names),
+                group,
+            )
+
+    async def group_send(self, group, message):
+        """
+        Sends a message to the entire group.
+        """
+
+        connection, channel_names = await self._get_group_connection_and_channels(group)
+
         (
             connection_to_channel_keys,
             channel_keys_to_message,
             channel_keys_to_capacity,
-        ) = self._map_channel_keys_to_connection(channel_names, message)
+        ) = self._map_channel_keys_to_connection(channel_names, (message,))
 
         for connection_index, channel_redis_keys in connection_to_channel_keys.items():
             # Discard old messages based on expiry
@@ -569,7 +609,7 @@ async def group_send(self, group, message):
 
             # We need to filter the messages to keep those related to the connection
             args = [
-                channel_keys_to_message[channel_key]
+                channel_keys_to_message[channel_key][0]
                 for channel_key in channel_redis_keys
             ]
 
@@ -581,20 +621,88 @@ async def group_send(self, group, message):
 
             args += [time.time(), self.expiry]
 
-            # channel_keys does not contain a single redis key more than once
-            connection = self.connection(connection_index)
-            channels_over_capacity = await connection.eval(
-                group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args
+            await self._exec_group_lua_script(
+                connection_index,
+                group,
+                channel_redis_keys,
+                channel_names,
+                group_send_lua,
+                args,
             )
-            if channels_over_capacity > 0:
-                logger.info(
-                    "%s of %s channels over capacity in group %s",
-                    channels_over_capacity,
-                    len(channel_names),
-                    group,
+
+    async def group_send_bulk(self, group, messages):
+        """
+        Sends multiple messages in bulk to the entire group.
+        The `messages` argument should be an iterable of dicts.
+        """
+
+        connection, channel_names = await self._get_group_connection_and_channels(group)
+
+        (
+            connection_to_channel_keys,
+            channel_keys_to_message,
+            channel_keys_to_capacity,
+        ) = self._map_channel_keys_to_connection(channel_names, messages)
+
+        for connection_index, channel_redis_keys in connection_to_channel_keys.items():
+            # Discard old messages based on expiry
+            pipe = connection.pipeline()
+            for key in channel_redis_keys:
+                pipe.zremrangebyscore(
+                    key, min=0, max=int(time.time()) - int(self.expiry)
                 )
+            await pipe.execute()
+
+            # Create a LUA script specific for this connection.
+            # Make sure to use the message list specific to this channel, it is
+            # stored in channel_to_message dict and each message contains the
+            # __asgi_channel__ key.
+
+            group_send_lua = """
+                local over_capacity = 0
+                local num_messages = tonumber(ARGV[#ARGV - 2])
+                local current_time = ARGV[#ARGV - 1]
+                local expiry = ARGV[#ARGV]
+                for i=1,#KEYS do
+                    if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS * num_messages]) then
+                        local messages = {}
+                        for j=num_messages * (i - 1) + 1, num_messages * i do
+                            table.insert(messages, current_time)
+                            table.insert(messages, ARGV[j])
+                        end
+                        redis.call('ZADD', KEYS[i], unpack(messages))
+                        redis.call('EXPIRE', KEYS[i], expiry)
+                    else
+                        over_capacity = over_capacity + 1
+                    end
+                end
+                return over_capacity
+            """
+
+            # We need to filter the messages to keep those related to the connection
+            args = []
+
+            for channel_key in channel_redis_keys:
+                args += channel_keys_to_message[channel_key]
+
+            # We need to send the capacity for each channel
+            args += [
+                channel_keys_to_capacity[channel_key]
+                for channel_key in channel_redis_keys
+            ]
 
-    def _map_channel_keys_to_connection(self, channel_names, message):
+            args += [len(messages), time.time(), self.expiry]
+
+            await self._exec_group_lua_script(
+                connection_index,
+                group,
+                channel_redis_keys,
+                channel_names,
+                group_send_lua,
+                args,
+            )
+
+    def _map_channel_keys_to_connection(self, channel_names, messages):
         """
         For a list of channel names, GET
 
@@ -609,7 +717,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
         # Connection dict keyed by index to list of redis keys mapped on that index
         connection_to_channel_keys = collections.defaultdict(list)
         # Message dict maps redis key to the message that needs to be send on that key
-        channel_key_to_message = dict()
+        channel_key_to_message = collections.defaultdict(list)
         # Channel key mapped to its capacity
         channel_key_to_capacity = dict()
 
@@ -623,20 +731,23 @@ def _map_channel_keys_to_connection(self, channel_names, message):
             # Have we come across the same redis key?
             if channel_key not in channel_key_to_message:
                 # If not, fill the corresponding dicts
-                message = dict(message.items())
-                message["__asgi_channel__"] = [channel]
-                channel_key_to_message[channel_key] = message
+                for message in messages:
+                    message = dict(message.items())
+                    message["__asgi_channel__"] = [channel]
+                    channel_key_to_message[channel_key].append(message)
                 channel_key_to_capacity[channel_key] = self.get_capacity(channel)
                 idx = self.consistent_hash(channel_non_local_name)
                 connection_to_channel_keys[idx].append(channel_key)
             else:
                 # Yes, Append the channel in message dict
-                channel_key_to_message[channel_key]["__asgi_channel__"].append(channel)
+                for message in channel_key_to_message[channel_key]:
+                    message["__asgi_channel__"].append(channel)
 
         # Now that we know what message needs to be send on a redis key we serialize it
         for key, value in channel_key_to_message.items():
             # Serialize the message stored for each redis key
-            channel_key_to_message[key] = self.serialize(value)
+            for idx, message in enumerate(value):
+                channel_key_to_message[key][idx] = self.serialize(message)
 
         return (
             connection_to_channel_keys,
diff --git a/tests/test_core.py b/tests/test_core.py
index 2752040..a8e5bef 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -1,4 +1,5 @@
 import asyncio
+import collections
 import random
 
 import async_timeout
@@ -125,6 +126,25 @@ async def listen2():
     async_to_sync(channel_layer.flush)()
 
 
+@pytest.mark.asyncio
+async def test_send_multiple(channel_layer):
+    messages = [
+        {"type": "test.message.1"},
+        {"type": "test.message.2"},
+        {"type": "test.message.3"},
+    ]
+
+    await channel_layer.send_bulk("test-channel-1", messages)
+
+    expected = {"test.message.1", "test.message.2", "test.message.3"}
+    received = set()
+    for _ in range(3):
+        msg = await channel_layer.receive("test-channel-1")
+        received.add(msg["type"])
+
+    assert received == expected
+
+
 @pytest.mark.asyncio
 async def test_send_capacity(channel_layer):
     """
@@ -225,6 +245,40 @@ async def test_groups_basic(channel_layer):
     await channel_layer.flush()
 
 
+@pytest.mark.asyncio
+async def test_groups_multiple(channel_layer):
+    """
+    Tests that group_send_bulk delivers every message to every channel in the group.
+    """
+    channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan-1")
+    channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan-2")
+    channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan-3")
+    await channel_layer.group_add("test-group", channel_name1)
+    await channel_layer.group_add("test-group", channel_name2)
+    await channel_layer.group_add("test-group", channel_name3)
+
+    messages = [
+        {"type": "message.1"},
+        {"type": "message.2"},
+        {"type": "message.3"},
+    ]
+
+    expected = {msg["type"] for msg in messages}
+
+    await channel_layer.group_send_bulk("test-group", messages)
+
+    received = collections.defaultdict(set)
+
+    for channel_name in (channel_name1, channel_name2, channel_name3):
+        async with async_timeout.timeout(1):
+            for _ in range(len(messages)):
+                received[channel_name].add(
+                    (await channel_layer.receive(channel_name))["type"]
+                )
+
+        assert received[channel_name] == expected
+
+
 @pytest.mark.asyncio
 async def test_groups_channel_full(channel_layer):
     """