Skip to content

Commit 15b50ae

Browse files
committed
Handle multiple messages in send and group_send
1 parent 13cef45 commit 15b50ae

File tree

2 files changed

+117
-30
lines changed

2 files changed

+117
-30
lines changed

channels_redis/core.py

+63-30
Original file line numberDiff line numberDiff line change
@@ -178,24 +178,37 @@ def _setup_encryption(self, symmetric_encryption_keys):
178178

179179
async def send(self, channel, message):
180180
"""
181-
Send a message onto a (general or specific) channel.
181+
Send one or multiple messages onto a (general or specific) channel.
182+
The `message` can be a single dict or an iterable of dicts.
182183
"""
184+
messages = self._parse_messages(message)
185+
183186
# Typecheck
184-
assert isinstance(message, dict), "message is not a dict"
185187
assert self.valid_channel_name(channel), "Channel name not valid"
186-
# Make sure the message does not contain reserved keys
187-
assert "__asgi_channel__" not in message
188+
188189
# If it's a process-local channel, strip off local part and stick full name in message
189190
channel_non_local_name = channel
190-
if "!" in channel:
191-
message = dict(message.items())
192-
message["__asgi_channel__"] = channel
191+
process_local = "!" in channel
192+
if process_local:
193193
channel_non_local_name = self.non_local_name(channel)
194+
195+
now = time.time()
196+
mapping = {}
197+
for message in messages:
198+
assert isinstance(message, dict), "message is not a dict"
199+
# Make sure the message does not contain reserved keys
200+
assert "__asgi_channel__" not in message
201+
if process_local:
202+
message = dict(message.items())
203+
message["__asgi_channel__"] = channel
204+
205+
mapping[self.serialize(message)] = now
206+
194207
# Write out message into expiring key (avoids big items in list)
195208
channel_key = self.prefix + channel_non_local_name
196209
# Pick a connection to the right server - consistent for specific
197210
# channels, random for general channels
198-
if "!" in channel:
211+
if process_local:
199212
index = self.consistent_hash(channel)
200213
else:
201214
index = next(self._send_index_generator)
@@ -207,15 +220,23 @@ async def send(self, channel, message):
207220

208221
# Check the length of the list before send
209222
# This can allow the list to leak slightly over capacity, but that's fine.
210-
if await connection.zcount(channel_key, "-inf", "+inf") >= self.get_capacity(
211-
channel
212-
):
223+
current_length = await connection.zcount(channel_key, "-inf", "+inf")
224+
225+
if current_length + len(messages) > self.get_capacity(channel):
213226
raise ChannelFull()
214227

215228
# Push onto the list then set it to expire in case it's not consumed
216-
await connection.zadd(channel_key, {self.serialize(message): time.time()})
229+
await connection.zadd(channel_key, mapping)
217230
await connection.expire(channel_key, int(self.expiry))
218231

232+
def _parse_messages(self, message):
233+
"""
234+
Convert a passed message arg to a tuple of messages.
235+
"""
236+
if not isinstance(message, dict) and hasattr(message, "__iter__"):
237+
return tuple(message)
238+
return (message,)
239+
219240
def _backup_channel_name(self, channel):
220241
"""
221242
Construct the key used as a backup queue for the given channel.
@@ -519,8 +540,11 @@ async def group_discard(self, group, channel):
519540

520541
async def group_send(self, group, message):
521542
"""
522-
Sends a message to the entire group.
543+
Sends one or multiple messages to the entire group.
544+
The `message` can be a single dict or an iterable of dicts.
523545
"""
546+
messages = self._parse_messages(message)
547+
524548
assert self.valid_group_name(group), "Group name not valid"
525549
# Retrieve list of all channel names
526550
key = self._group_key(group)
@@ -536,7 +560,7 @@ async def group_send(self, group, message):
536560
connection_to_channel_keys,
537561
channel_keys_to_message,
538562
channel_keys_to_capacity,
539-
) = self._map_channel_keys_to_connection(channel_names, message)
563+
) = self._map_channel_keys_to_connection(channel_names, messages)
540564

541565
for connection_index, channel_redis_keys in connection_to_channel_keys.items():
542566
# Discard old messages based on expiry
@@ -548,17 +572,23 @@ async def group_send(self, group, message):
548572
await pipe.execute()
549573

550574
# Create a LUA script specific for this connection.
551-
# Make sure to use the message specific to this channel, it is
552-
# stored in channel_to_message dict and contains the
575+
# Make sure to use the message list specific to this channel, it is
576+
# stored in channel_to_message dict and each message contains the
553577
# __asgi_channel__ key.
554578

555579
group_send_lua = """
556580
local over_capacity = 0
581+
local num_messages = tonumber(ARGV[#ARGV - 2])
557582
local current_time = ARGV[#ARGV - 1]
558583
local expiry = ARGV[#ARGV]
559584
for i=1,#KEYS do
560-
if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS]) then
561-
redis.call('ZADD', KEYS[i], current_time, ARGV[i])
585+
if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS * num_messages]) then
586+
local messages = {}
587+
for j=num_messages * (i - 1) + 1, num_messages * i do
588+
table.insert(messages, current_time)
589+
table.insert(messages, ARGV[j])
590+
end
591+
redis.call('ZADD', KEYS[i], unpack(messages))
562592
redis.call('EXPIRE', KEYS[i], expiry)
563593
else
564594
over_capacity = over_capacity + 1
@@ -568,18 +598,18 @@ async def group_send(self, group, message):
568598
"""
569599

570600
# We need to filter the messages to keep those related to the connection
571-
args = [
572-
channel_keys_to_message[channel_key]
573-
for channel_key in channel_redis_keys
574-
]
601+
args = []
602+
603+
for channel_key in channel_redis_keys:
604+
args += channel_keys_to_message[channel_key]
575605

576606
# We need to send the capacity for each channel
577607
args += [
578608
channel_keys_to_capacity[channel_key]
579609
for channel_key in channel_redis_keys
580610
]
581611

582-
args += [time.time(), self.expiry]
612+
args += [len(messages), time.time(), self.expiry]
583613

584614
# channel_keys does not contain a single redis key more than once
585615
connection = self.connection(connection_index)
@@ -594,7 +624,7 @@ async def group_send(self, group, message):
594624
group,
595625
)
596626

597-
def _map_channel_keys_to_connection(self, channel_names, message):
627+
def _map_channel_keys_to_connection(self, channel_names, messages):
598628
"""
599629
For a list of channel names, GET
600630
@@ -609,7 +639,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
609639
# Connection dict keyed by index to list of redis keys mapped on that index
610640
connection_to_channel_keys = collections.defaultdict(list)
611641
# Message dict maps redis key to the message that needs to be send on that key
612-
channel_key_to_message = dict()
642+
channel_key_to_message = collections.defaultdict(list)
613643
# Channel key mapped to its capacity
614644
channel_key_to_capacity = dict()
615645

@@ -623,20 +653,23 @@ def _map_channel_keys_to_connection(self, channel_names, message):
623653
# Have we come across the same redis key?
624654
if channel_key not in channel_key_to_message:
625655
# If not, fill the corresponding dicts
626-
message = dict(message.items())
627-
message["__asgi_channel__"] = [channel]
628-
channel_key_to_message[channel_key] = message
656+
for message in messages:
657+
message = dict(message.items())
658+
message["__asgi_channel__"] = [channel]
659+
channel_key_to_message[channel_key].append(message)
629660
channel_key_to_capacity[channel_key] = self.get_capacity(channel)
630661
idx = self.consistent_hash(channel_non_local_name)
631662
connection_to_channel_keys[idx].append(channel_key)
632663
else:
633664
# Yes, Append the channel in message dict
634-
channel_key_to_message[channel_key]["__asgi_channel__"].append(channel)
665+
for message in channel_key_to_message[channel_key]:
666+
message["__asgi_channel__"].append(channel)
635667

636668
# Now that we know what message needs to be send on a redis key we serialize it
637669
for key, value in channel_key_to_message.items():
638670
# Serialize the message stored for each redis key
639-
channel_key_to_message[key] = self.serialize(value)
671+
for idx, message in enumerate(value):
672+
channel_key_to_message[key][idx] = self.serialize(message)
640673

641674
return (
642675
connection_to_channel_keys,

tests/test_core.py

+54
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import collections
23
import random
34

45
import async_timeout
@@ -125,6 +126,25 @@ async def listen2():
125126
async_to_sync(channel_layer.flush)()
126127

127128

129+
@pytest.mark.asyncio
130+
async def test_send_multiple(channel_layer):
131+
messsages = [
132+
{"type": "test.message.1"},
133+
{"type": "test.message.2"},
134+
{"type": "test.message.3"},
135+
]
136+
137+
await channel_layer.send("test-channel-1", messsages)
138+
139+
expected = {"test.message.1", "test.message.2", "test.message.3"}
140+
received = set()
141+
for _ in range(3):
142+
msg = await channel_layer.receive("test-channel-1")
143+
received.add(msg["type"])
144+
145+
assert received == expected
146+
147+
128148
@pytest.mark.asyncio
129149
async def test_send_capacity(channel_layer):
130150
"""
@@ -225,6 +245,40 @@ async def test_groups_basic(channel_layer):
225245
await channel_layer.flush()
226246

227247

248+
@pytest.mark.asyncio
249+
async def test_groups_multiple(channel_layer):
250+
"""
251+
Tests basic group operation.
252+
"""
253+
channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan-1")
254+
channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan-2")
255+
channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan-3")
256+
await channel_layer.group_add("test-group", channel_name1)
257+
await channel_layer.group_add("test-group", channel_name2)
258+
await channel_layer.group_add("test-group", channel_name3)
259+
260+
messages = [
261+
{"type": "message.1"},
262+
{"type": "message.2"},
263+
{"type": "message.3"},
264+
]
265+
266+
expected = {msg["type"] for msg in messages}
267+
268+
await channel_layer.group_send("test-group", messages)
269+
270+
received = collections.defaultdict(set)
271+
272+
for channel_name in (channel_name1, channel_name2, channel_name3):
273+
async with async_timeout.timeout(1):
274+
for _ in range(len(messages)):
275+
received[channel_name].add(
276+
(await channel_layer.receive(channel_name))["type"]
277+
)
278+
279+
assert received[channel_name] == expected
280+
281+
228282
@pytest.mark.asyncio
229283
async def test_groups_channel_full(channel_layer):
230284
"""

0 commit comments

Comments
 (0)