Skip to content

Commit 13debf7

Browse files
committed
Handle multiple messages in send and group_send
1 parent 13cef45 commit 13debf7

File tree

2 files changed

+199
-34
lines changed

2 files changed

+199
-34
lines changed

channels_redis/core.py

+145-34
Original file line numberDiff line numberDiff line change
@@ -180,22 +180,40 @@ async def send(self, channel, message):
180180
"""
181181
Send a message onto a (general or specific) channel.
182182
"""
183+
await self.send_bulk(channel, (message,))
184+
185+
async def send_bulk(self, channel, messages):
186+
"""
187+
Send multiple messages in bulk onto a (general or specific) channel.
188+
The `messages` argument should be an iterable of dicts.
189+
"""
190+
183191
# Typecheck
184-
assert isinstance(message, dict), "message is not a dict"
185192
assert self.valid_channel_name(channel), "Channel name not valid"
186-
# Make sure the message does not contain reserved keys
187-
assert "__asgi_channel__" not in message
193+
188194
# If it's a process-local channel, strip off local part and stick full name in message
189195
channel_non_local_name = channel
190-
if "!" in channel:
191-
message = dict(message.items())
192-
message["__asgi_channel__"] = channel
196+
process_local = "!" in channel
197+
if process_local:
193198
channel_non_local_name = self.non_local_name(channel)
199+
200+
now = time.time()
201+
mapping = {}
202+
for message in messages:
203+
assert isinstance(message, dict), "message is not a dict"
204+
# Make sure the message does not contain reserved keys
205+
assert "__asgi_channel__" not in message
206+
if process_local:
207+
message = dict(message.items())
208+
message["__asgi_channel__"] = channel
209+
210+
mapping[self.serialize(message)] = now
211+
194212
# Write out message into expiring key (avoids big items in list)
195213
channel_key = self.prefix + channel_non_local_name
196214
# Pick a connection to the right server - consistent for specific
197215
# channels, random for general channels
198-
if "!" in channel:
216+
if process_local:
199217
index = self.consistent_hash(channel)
200218
else:
201219
index = next(self._send_index_generator)
@@ -207,13 +225,13 @@ async def send(self, channel, message):
207225

208226
# Check the length of the list before send
209227
# This can allow the list to leak slightly over capacity, but that's fine.
210-
if await connection.zcount(channel_key, "-inf", "+inf") >= self.get_capacity(
211-
channel
212-
):
228+
current_length = await connection.zcount(channel_key, "-inf", "+inf")
229+
230+
if current_length + len(messages) > self.get_capacity(channel):
213231
raise ChannelFull()
214232

215233
# Push onto the list then set it to expire in case it's not consumed
216-
await connection.zadd(channel_key, {self.serialize(message): time.time()})
234+
await connection.zadd(channel_key, mapping)
217235
await connection.expire(channel_key, int(self.expiry))
218236

219237
def _backup_channel_name(self, channel):
@@ -517,10 +535,7 @@ async def group_discard(self, group, channel):
517535
connection = self.connection(self.consistent_hash(group))
518536
await connection.zrem(key, channel)
519537

520-
async def group_send(self, group, message):
521-
"""
522-
Sends a message to the entire group.
523-
"""
538+
async def _get_group_connection_and_channels(self, group):
524539
assert self.valid_group_name(group), "Group name not valid"
525540
# Retrieve list of all channel names
526541
key = self._group_key(group)
@@ -532,11 +547,36 @@ async def group_send(self, group, message):
532547

533548
channel_names = [x.decode("utf8") for x in await connection.zrange(key, 0, -1)]
534549

550+
return connection, channel_names
551+
552+
async def _exec_group_lua_script(
553+
self, conn_idx, group, channel_redis_keys, channel_names, script, args
554+
):
555+
# channel_keys does not contain a single redis key more than once
556+
connection = self.connection(conn_idx)
557+
channels_over_capacity = await connection.eval(
558+
script, len(channel_redis_keys), *channel_redis_keys, *args
559+
)
560+
if channels_over_capacity > 0:
561+
logger.info(
562+
"%s of %s channels over capacity in group %s",
563+
channels_over_capacity,
564+
len(channel_names),
565+
group,
566+
)
567+
568+
async def group_send(self, group, message):
569+
"""
570+
Sends a message to the entire group.
571+
"""
572+
573+
connection, channel_names = await self._get_group_connection_and_channels(group)
574+
535575
(
536576
connection_to_channel_keys,
537577
channel_keys_to_message,
538578
channel_keys_to_capacity,
539-
) = self._map_channel_keys_to_connection(channel_names, message)
579+
) = self._map_channel_keys_to_connection(channel_names, (message,))
540580

541581
for connection_index, channel_redis_keys in connection_to_channel_keys.items():
542582
# Discard old messages based on expiry
@@ -569,7 +609,7 @@ async def group_send(self, group, message):
569609

570610
# We need to filter the messages to keep those related to the connection
571611
args = [
572-
channel_keys_to_message[channel_key]
612+
channel_keys_to_message[channel_key][0]
573613
for channel_key in channel_redis_keys
574614
]
575615

@@ -581,20 +621,88 @@ async def group_send(self, group, message):
581621

582622
args += [time.time(), self.expiry]
583623

584-
# channel_keys does not contain a single redis key more than once
585-
connection = self.connection(connection_index)
586-
channels_over_capacity = await connection.eval(
587-
group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args
624+
await self._exec_group_lua_script(
625+
connection_index,
626+
group,
627+
channel_redis_keys,
628+
channel_names,
629+
group_send_lua,
630+
args,
588631
)
589-
if channels_over_capacity > 0:
590-
logger.info(
591-
"%s of %s channels over capacity in group %s",
592-
channels_over_capacity,
593-
len(channel_names),
594-
group,
632+
633+
async def group_send_bulk(self, group, messages):
634+
"""
635+
Sends multiple messages in bulk to the entire group.
636+
The `messages` argument should be an iterable of dicts.
637+
"""
638+
639+
connection, channel_names = await self._get_group_connection_and_channels(group)
640+
641+
(
642+
connection_to_channel_keys,
643+
channel_keys_to_message,
644+
channel_keys_to_capacity,
645+
) = self._map_channel_keys_to_connection(channel_names, messages)
646+
647+
for connection_index, channel_redis_keys in connection_to_channel_keys.items():
648+
# Discard old messages based on expiry
649+
pipe = connection.pipeline()
650+
for key in channel_redis_keys:
651+
pipe.zremrangebyscore(
652+
key, min=0, max=int(time.time()) - int(self.expiry)
595653
)
654+
await pipe.execute()
655+
656+
# Create a LUA script specific for this connection.
657+
# Make sure to use the message list specific to this channel, it is
658+
# stored in channel_to_message dict and each message contains the
659+
# __asgi_channel__ key.
660+
661+
group_send_lua = """
662+
local over_capacity = 0
663+
local num_messages = tonumber(ARGV[#ARGV - 2])
664+
local current_time = ARGV[#ARGV - 1]
665+
local expiry = ARGV[#ARGV]
666+
for i=1,#KEYS do
667+
if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS * num_messages]) then
668+
local messages = {}
669+
for j=num_messages * (i - 1) + 1, num_messages * i do
670+
table.insert(messages, current_time)
671+
table.insert(messages, ARGV[j])
672+
end
673+
redis.call('ZADD', KEYS[i], unpack(messages))
674+
redis.call('EXPIRE', KEYS[i], expiry)
675+
else
676+
over_capacity = over_capacity + 1
677+
end
678+
end
679+
return over_capacity
680+
"""
681+
682+
# We need to filter the messages to keep those related to the connection
683+
args = []
684+
685+
for channel_key in channel_redis_keys:
686+
args += channel_keys_to_message[channel_key]
687+
688+
# We need to send the capacity for each channel
689+
args += [
690+
channel_keys_to_capacity[channel_key]
691+
for channel_key in channel_redis_keys
692+
]
596693

597-
def _map_channel_keys_to_connection(self, channel_names, message):
694+
args += [len(messages), time.time(), self.expiry]
695+
696+
await self._exec_group_lua_script(
697+
connection_index,
698+
group,
699+
channel_redis_keys,
700+
channel_names,
701+
group_send_lua,
702+
args,
703+
)
704+
705+
def _map_channel_keys_to_connection(self, channel_names, messages):
598706
"""
599707
For a list of channel names, GET
600708
@@ -609,7 +717,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
609717
# Connection dict keyed by index to list of redis keys mapped on that index
610718
connection_to_channel_keys = collections.defaultdict(list)
611719
# Message dict maps redis key to the message that needs to be send on that key
612-
channel_key_to_message = dict()
720+
channel_key_to_message = collections.defaultdict(list)
613721
# Channel key mapped to its capacity
614722
channel_key_to_capacity = dict()
615723

@@ -623,20 +731,23 @@ def _map_channel_keys_to_connection(self, channel_names, message):
623731
# Have we come across the same redis key?
624732
if channel_key not in channel_key_to_message:
625733
# If not, fill the corresponding dicts
626-
message = dict(message.items())
627-
message["__asgi_channel__"] = [channel]
628-
channel_key_to_message[channel_key] = message
734+
for message in messages:
735+
message = dict(message.items())
736+
message["__asgi_channel__"] = [channel]
737+
channel_key_to_message[channel_key].append(message)
629738
channel_key_to_capacity[channel_key] = self.get_capacity(channel)
630739
idx = self.consistent_hash(channel_non_local_name)
631740
connection_to_channel_keys[idx].append(channel_key)
632741
else:
633742
# Yes, Append the channel in message dict
634-
channel_key_to_message[channel_key]["__asgi_channel__"].append(channel)
743+
for message in channel_key_to_message[channel_key]:
744+
message["__asgi_channel__"].append(channel)
635745

636746
# Now that we know what message needs to be send on a redis key we serialize it
637747
for key, value in channel_key_to_message.items():
638748
# Serialize the message stored for each redis key
639-
channel_key_to_message[key] = self.serialize(value)
749+
for idx, message in enumerate(value):
750+
channel_key_to_message[key][idx] = self.serialize(message)
640751

641752
return (
642753
connection_to_channel_keys,

tests/test_core.py

+54
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import collections
23
import random
34

45
import async_timeout
@@ -125,6 +126,25 @@ async def listen2():
125126
async_to_sync(channel_layer.flush)()
126127

127128

129+
@pytest.mark.asyncio
130+
async def test_send_multiple(channel_layer):
131+
messsages = [
132+
{"type": "test.message.1"},
133+
{"type": "test.message.2"},
134+
{"type": "test.message.3"},
135+
]
136+
137+
await channel_layer.send_bulk("test-channel-1", messsages)
138+
139+
expected = {"test.message.1", "test.message.2", "test.message.3"}
140+
received = set()
141+
for _ in range(3):
142+
msg = await channel_layer.receive("test-channel-1")
143+
received.add(msg["type"])
144+
145+
assert received == expected
146+
147+
128148
@pytest.mark.asyncio
129149
async def test_send_capacity(channel_layer):
130150
"""
@@ -225,6 +245,40 @@ async def test_groups_basic(channel_layer):
225245
await channel_layer.flush()
226246

227247

248+
@pytest.mark.asyncio
249+
async def test_groups_multiple(channel_layer):
250+
"""
251+
Tests basic group operation.
252+
"""
253+
channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan-1")
254+
channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan-2")
255+
channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan-3")
256+
await channel_layer.group_add("test-group", channel_name1)
257+
await channel_layer.group_add("test-group", channel_name2)
258+
await channel_layer.group_add("test-group", channel_name3)
259+
260+
messages = [
261+
{"type": "message.1"},
262+
{"type": "message.2"},
263+
{"type": "message.3"},
264+
]
265+
266+
expected = {msg["type"] for msg in messages}
267+
268+
await channel_layer.group_send_bulk("test-group", messages)
269+
270+
received = collections.defaultdict(set)
271+
272+
for channel_name in (channel_name1, channel_name2, channel_name3):
273+
async with async_timeout.timeout(1):
274+
for _ in range(len(messages)):
275+
received[channel_name].add(
276+
(await channel_layer.receive(channel_name))["type"]
277+
)
278+
279+
assert received[channel_name] == expected
280+
281+
228282
@pytest.mark.asyncio
229283
async def test_groups_channel_full(channel_layer):
230284
"""

0 commit comments

Comments
 (0)