Commit f2cec18

Handle multiple messages in send and group_send

1 parent 13cef45 commit f2cec18

2 files changed: +198 -34

channels_redis/core.py (+144 -34)
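Before the diff itself, a quick sketch of how the new bulk entry points are intended to be called. This is illustrative only (the layer construction, channel and group names, and payloads are not part of the commit); it assumes the usual `RedisChannelLayer` from `channels_redis.core`:

    import asyncio
    from channels_redis.core import RedisChannelLayer

    async def main():
        # Hypothetical layer pointing at a local Redis; adjust hosts as needed.
        channel_layer = RedisChannelLayer(hosts=[("localhost", 6379)])

        messages = [
            {"type": "chat.message", "text": "first"},
            {"type": "chat.message", "text": "second"},
        ]

        # One capacity check and one ZADD for the whole batch on a single channel.
        await channel_layer.send_bulk("some-channel", messages)

        # Fan the same batch out to every channel currently in the group.
        await channel_layer.group_send_bulk("some-group", messages)

    asyncio.run(main())

The single-message `send` and `group_send` keep their existing signatures: `send` now delegates to `send_bulk` with a one-element tuple, and `group_send` passes `(message,)` through the shared mapping helper.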
@@ -180,22 +180,40 @@ async def send(self, channel, message):
         """
         Send a message onto a (general or specific) channel.
         """
+        await self.send_bulk(channel, (message,))
+
+    async def send_bulk(self, channel, messages):
+        """
+        Send one or more messages onto a (general or specific) channel.
+        `messages` must be an iterable of message dicts.
+        """
+
         # Typecheck
-        assert isinstance(message, dict), "message is not a dict"
         assert self.valid_channel_name(channel), "Channel name not valid"
-        # Make sure the message does not contain reserved keys
-        assert "__asgi_channel__" not in message
+
         # If it's a process-local channel, strip off local part and stick full name in message
         channel_non_local_name = channel
-        if "!" in channel:
-            message = dict(message.items())
-            message["__asgi_channel__"] = channel
+        process_local = "!" in channel
+        if process_local:
             channel_non_local_name = self.non_local_name(channel)
+
+        now = time.time()
+        mapping = {}
+        for message in messages:
+            assert isinstance(message, dict), "message is not a dict"
+            # Make sure the message does not contain reserved keys
+            assert "__asgi_channel__" not in message
+            if process_local:
+                message = dict(message.items())
+                message["__asgi_channel__"] = channel
+
+            mapping[self.serialize(message)] = now
+
         # Write out message into expiring key (avoids big items in list)
         channel_key = self.prefix + channel_non_local_name
         # Pick a connection to the right server - consistent for specific
         # channels, random for general channels
-        if "!" in channel:
+        if process_local:
             index = self.consistent_hash(channel)
         else:
             index = next(self._send_index_generator)
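A note on the loop above: `mapping` ends up as `{serialized message: score}` with one shared timestamp, so the whole batch is written with a single `ZADD` further down and shares the score that the layer's expiry cleanup later compares against. A standalone sketch of the shape it takes (the `serialize` stand-in and payloads are illustrative, not the layer's real serializer):

    import time

    def build_mapping(messages, serialize=repr):
        # Mirrors the send_bulk loop: every message in the batch gets the same
        # score, i.e. the timestamp used when pruning expired entries.
        now = time.time()
        return {serialize(message): now for message in messages}

    print(build_mapping([{"type": "a"}, {"type": "b"}]))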
@@ -207,13 +225,13 @@ async def send(self, channel, message):
 
         # Check the length of the list before send
         # This can allow the list to leak slightly over capacity, but that's fine.
-        if await connection.zcount(channel_key, "-inf", "+inf") >= self.get_capacity(
-            channel
-        ):
+        current_length = await connection.zcount(channel_key, "-inf", "+inf")
+
+        if current_length + len(messages) > self.get_capacity(channel):
             raise ChannelFull()
 
         # Push onto the list then set it to expire in case it's not consumed
-        await connection.zadd(channel_key, {self.serialize(message): time.time()})
+        await connection.zadd(channel_key, mapping)
         await connection.expire(channel_key, int(self.expiry))
 
     def _backup_channel_name(self, channel):
@@ -517,10 +535,7 @@ async def group_discard(self, group, channel):
         connection = self.connection(self.consistent_hash(group))
         await connection.zrem(key, channel)
 
-    async def group_send(self, group, message):
-        """
-        Sends a message to the entire group.
-        """
+    async def _get_group_connection_and_channels(self, group):
         assert self.valid_group_name(group), "Group name not valid"
         # Retrieve list of all channel names
         key = self._group_key(group)
@@ -532,11 +547,36 @@ async def group_send(self, group, message):
 
         channel_names = [x.decode("utf8") for x in await connection.zrange(key, 0, -1)]
 
+        return connection, channel_names
+
+    async def _exec_group_lua_script(
+        self, conn_idx, group, channel_redis_keys, channel_names, script, args
+    ):
+        # channel_keys does not contain a single redis key more than once
+        connection = self.connection(conn_idx)
+        channels_over_capacity = await connection.eval(
+            script, len(channel_redis_keys), *channel_redis_keys, *args
+        )
+        if channels_over_capacity > 0:
+            logger.info(
+                "%s of %s channels over capacity in group %s",
+                channels_over_capacity,
+                len(channel_names),
+                group,
+            )
+
+    async def group_send(self, group, message):
+        """
+        Sends a message to the entire group.
+        """
+
+        connection, channel_names = await self._get_group_connection_and_channels(group)
+
         (
             connection_to_channel_keys,
             channel_keys_to_message,
             channel_keys_to_capacity,
-        ) = self._map_channel_keys_to_connection(channel_names, message)
+        ) = self._map_channel_keys_to_connection(channel_names, (message,))
 
         for connection_index, channel_redis_keys in connection_to_channel_keys.items():
             # Discard old messages based on expiry
@@ -569,7 +609,7 @@ async def group_send(self, group, message):
 
             # We need to filter the messages to keep those related to the connection
             args = [
-                channel_keys_to_message[channel_key]
+                channel_keys_to_message[channel_key][0]
                 for channel_key in channel_redis_keys
             ]
 
@@ -581,20 +621,87 @@ async def group_send(self, group, message):
 
             args += [time.time(), self.expiry]
 
-            # channel_keys does not contain a single redis key more than once
-            connection = self.connection(connection_index)
-            channels_over_capacity = await connection.eval(
-                group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args
+            await self._exec_group_lua_script(
+                connection_index,
+                group,
+                channel_redis_keys,
+                channel_names,
+                group_send_lua,
+                args,
             )
-            if channels_over_capacity > 0:
-                logger.info(
-                    "%s of %s channels over capacity in group %s",
-                    channels_over_capacity,
-                    len(channel_names),
-                    group,
+
+    async def group_send_bulk(self, group, messages):
+        """
+        Sends multiple messages in bulk to the entire group.
+        """
+
+        connection, channel_names = await self._get_group_connection_and_channels(group)
+
+        (
+            connection_to_channel_keys,
+            channel_keys_to_message,
+            channel_keys_to_capacity,
+        ) = self._map_channel_keys_to_connection(channel_names, messages)
+
+        for connection_index, channel_redis_keys in connection_to_channel_keys.items():
+            # Discard old messages based on expiry
+            pipe = connection.pipeline()
+            for key in channel_redis_keys:
+                pipe.zremrangebyscore(
+                    key, min=0, max=int(time.time()) - int(self.expiry)
                 )
+            await pipe.execute()
+
+            # Create a LUA script specific for this connection.
+            # Make sure to use the message list specific to this channel, it is
+            # stored in channel_to_message dict and each message contains the
+            # __asgi_channel__ key.
+
+            group_send_lua = """
+                local over_capacity = 0
+                local num_messages = tonumber(ARGV[#ARGV - 2])
+                local current_time = ARGV[#ARGV - 1]
+                local expiry = ARGV[#ARGV]
+                for i=1,#KEYS do
+                    if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS * num_messages]) then
+                        local messages = {}
+                        for j=num_messages * (i - 1) + 1, num_messages * i do
+                            table.insert(messages, current_time)
+                            table.insert(messages, ARGV[j])
+                        end
+                        redis.call('ZADD', KEYS[i], unpack(messages))
+                        redis.call('EXPIRE', KEYS[i], expiry)
+                    else
+                        over_capacity = over_capacity + 1
+                    end
+                end
+                return over_capacity
+            """
+
+            # We need to filter the messages to keep those related to the connection
+            args = []
+
+            for channel_key in channel_redis_keys:
+                args += channel_keys_to_message[channel_key]
+
+            # We need to send the capacity for each channel
+            args += [
+                channel_keys_to_capacity[channel_key]
+                for channel_key in channel_redis_keys
+            ]
 
-    def _map_channel_keys_to_connection(self, channel_names, message):
+            args += [len(messages), time.time(), self.expiry]
+
+            await self._exec_group_lua_script(
+                connection_index,
+                group,
+                channel_redis_keys,
+                channel_names,
+                group_send_lua,
+                args,
+            )
+
+    def _map_channel_keys_to_connection(self, channel_names, messages):
         """
         For a list of channel names, GET
 
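The argument layout that `group_send_bulk` feeds this script is easy to misread, so spelling it out: for K redis keys and M messages per key, ARGV holds the M serialized messages for KEYS[1], then those for KEYS[2], and so on, followed by the K per-channel capacities, and finally `num_messages`, `current_time`, `expiry`. That is why the capacity of KEYS[i] is `ARGV[i + #KEYS * num_messages]` and its messages sit at `ARGV[num_messages * (i - 1) + 1 .. num_messages * i]`. A toy Python-side illustration (the keys, capacities and serialized placeholders are made up):

    import time

    channel_redis_keys = ["prefix:chan-a", "prefix:chan-b"]        # KEYS, K = 2
    channel_keys_to_message = {
        "prefix:chan-a": ["a1", "a2", "a3"],                       # M = 3 serialized messages each
        "prefix:chan-b": ["b1", "b2", "b3"],
    }
    channel_keys_to_capacity = {"prefix:chan-a": 100, "prefix:chan-b": 100}

    args = []
    for key in channel_redis_keys:
        args += channel_keys_to_message[key]                       # ARGV[1 .. K*M]
    args += [channel_keys_to_capacity[key] for key in channel_redis_keys]  # ARGV[K*M+1 .. K*M+K]
    args += [3, time.time(), 60]                                   # num_messages, current_time, expiry

    # Capacity of KEYS[2] is ARGV[2 + 2*3] = ARGV[8]; its messages are ARGV[4..6].
    assert args[7] == channel_keys_to_capacity["prefix:chan-b"]    # ARGV is 1-indexed in Lua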
@@ -609,7 +716,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
         # Connection dict keyed by index to list of redis keys mapped on that index
         connection_to_channel_keys = collections.defaultdict(list)
         # Message dict maps redis key to the message that needs to be send on that key
-        channel_key_to_message = dict()
+        channel_key_to_message = collections.defaultdict(list)
         # Channel key mapped to its capacity
         channel_key_to_capacity = dict()
 
@@ -623,20 +730,23 @@ def _map_channel_keys_to_connection(self, channel_names, message):
             # Have we come across the same redis key?
             if channel_key not in channel_key_to_message:
                 # If not, fill the corresponding dicts
-                message = dict(message.items())
-                message["__asgi_channel__"] = [channel]
-                channel_key_to_message[channel_key] = message
+                for message in messages:
+                    message = dict(message.items())
+                    message["__asgi_channel__"] = [channel]
+                    channel_key_to_message[channel_key].append(message)
                 channel_key_to_capacity[channel_key] = self.get_capacity(channel)
                 idx = self.consistent_hash(channel_non_local_name)
                 connection_to_channel_keys[idx].append(channel_key)
             else:
                 # Yes, Append the channel in message dict
-                channel_key_to_message[channel_key]["__asgi_channel__"].append(channel)
+                for message in channel_key_to_message[channel_key]:
+                    message["__asgi_channel__"].append(channel)
 
         # Now that we know what message needs to be send on a redis key we serialize it
         for key, value in channel_key_to_message.items():
             # Serialize the message stored for each redis key
-            channel_key_to_message[key] = self.serialize(value)
+            for idx, message in enumerate(value):
+                channel_key_to_message[key][idx] = self.serialize(message)
 
         return (
             connection_to_channel_keys,
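One behavioural difference worth noting between the two paths: `send_bulk` treats the batch atomically with respect to capacity and raises `ChannelFull` when `current_length + len(messages)` would exceed the channel's capacity, whereas the group-side Lua script simply skips over-capacity channels and logs how many were skipped. A hedged sketch of caller-side chunking around that (the helper name, chunk size and drop-on-full policy are illustrative, not part of the layer):

    from channels.exceptions import ChannelFull

    async def send_in_chunks(channel_layer, channel, messages, chunk_size=10):
        # Send a large batch in smaller chunks; chunks that would overflow the
        # channel are dropped here, but callers could equally retry or back off.
        for start in range(0, len(messages), chunk_size):
            try:
                await channel_layer.send_bulk(channel, messages[start:start + chunk_size])
            except ChannelFull:
                continue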

tests/test_core.py (+54)
@@ -1,4 +1,5 @@
 import asyncio
+import collections
 import random
 
 import async_timeout
@@ -125,6 +126,25 @@ async def listen2():
     async_to_sync(channel_layer.flush)()
 
 
+@pytest.mark.asyncio
+async def test_send_multiple(channel_layer):
+    messages = [
+        {"type": "test.message.1"},
+        {"type": "test.message.2"},
+        {"type": "test.message.3"},
+    ]
+
+    await channel_layer.send_bulk("test-channel-1", messages)
+
+    expected = {"test.message.1", "test.message.2", "test.message.3"}
+    received = set()
+    for _ in range(3):
+        msg = await channel_layer.receive("test-channel-1")
+        received.add(msg["type"])
+
+    assert received == expected
+
+
 @pytest.mark.asyncio
 async def test_send_capacity(channel_layer):
     """
@@ -225,6 +245,40 @@ async def test_groups_basic(channel_layer):
     await channel_layer.flush()
 
 
+@pytest.mark.asyncio
+async def test_groups_multiple(channel_layer):
+    """
+    Tests sending multiple messages in bulk to a group.
+    """
+    channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan-1")
+    channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan-2")
+    channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan-3")
+    await channel_layer.group_add("test-group", channel_name1)
+    await channel_layer.group_add("test-group", channel_name2)
+    await channel_layer.group_add("test-group", channel_name3)
+
+    messages = [
+        {"type": "message.1"},
+        {"type": "message.2"},
+        {"type": "message.3"},
+    ]
+
+    expected = {msg["type"] for msg in messages}
+
+    await channel_layer.group_send_bulk("test-group", messages)
+
+    received = collections.defaultdict(set)
+
+    for channel_name in (channel_name1, channel_name2, channel_name3):
+        async with async_timeout.timeout(1):
+            for _ in range(len(messages)):
+                received[channel_name].add(
+                    (await channel_layer.receive(channel_name))["type"]
+                )
+
+        assert received[channel_name] == expected
+
+
 @pytest.mark.asyncio
 async def test_groups_channel_full(channel_layer):
     """
