@@ -178,24 +178,37 @@ def _setup_encryption(self, symmetric_encryption_keys):
178
178
179
179
async def send (self , channel , message ):
180
180
"""
181
- Send a message onto a (general or specific) channel.
181
+ Send one or multiple messages onto a (general or specific) channel.
182
+ The `message` can be a single dict or an iterable of dicts.
182
183
"""
184
+ messages = self ._parse_messages (message )
185
+
183
186
# Typecheck
184
- assert isinstance (message , dict ), "message is not a dict"
185
187
assert self .valid_channel_name (channel ), "Channel name not valid"
186
- # Make sure the message does not contain reserved keys
187
- assert "__asgi_channel__" not in message
188
+
188
189
# If it's a process-local channel, strip off local part and stick full name in message
189
190
channel_non_local_name = channel
190
- if "!" in channel :
191
- message = dict (message .items ())
192
- message ["__asgi_channel__" ] = channel
191
+ process_local = "!" in channel
192
+ if process_local :
193
193
channel_non_local_name = self .non_local_name (channel )
194
+
195
+ now = time .time ()
196
+ mapping = {}
197
+ for message in messages :
198
+ assert isinstance (message , dict ), "message is not a dict"
199
+ # Make sure the message does not contain reserved keys
200
+ assert "__asgi_channel__" not in message
201
+ if process_local :
202
+ message = dict (message .items ())
203
+ message ["__asgi_channel__" ] = channel
204
+
205
+ mapping [self .serialize (message )] = now
206
+
194
207
# Write out message into expiring key (avoids big items in list)
195
208
channel_key = self .prefix + channel_non_local_name
196
209
# Pick a connection to the right server - consistent for specific
197
210
# channels, random for general channels
198
- if "!" in channel :
211
+ if process_local :
199
212
index = self .consistent_hash (channel )
200
213
else :
201
214
index = next (self ._send_index_generator )
@@ -207,15 +220,23 @@ async def send(self, channel, message):
207
220
208
221
# Check the length of the list before send
209
222
# This can allow the list to leak slightly over capacity, but that's fine.
210
- if await connection .zcount (channel_key , "-inf" , "+inf" ) >= self . get_capacity (
211
- channel
212
- ):
223
+ current_length = await connection .zcount (channel_key , "-inf" , "+inf" )
224
+
225
+ if current_length + len ( messages ) > self . get_capacity ( channel ):
213
226
raise ChannelFull ()
214
227
215
228
# Push onto the list then set it to expire in case it's not consumed
216
- await connection .zadd (channel_key , { self . serialize ( message ): time . time ()} )
229
+ await connection .zadd (channel_key , mapping )
217
230
await connection .expire (channel_key , int (self .expiry ))
218
231
232
+ def _parse_messages (self , message ):
233
+ """
234
+ Convert a passed message arg to a tuple of messages.
235
+ """
236
+ if not isinstance (message , dict ) and hasattr (message , "__iter__" ):
237
+ return tuple (message )
238
+ return (message ,)
239
+
219
240
def _backup_channel_name (self , channel ):
220
241
"""
221
242
Construct the key used as a backup queue for the given channel.
@@ -519,8 +540,11 @@ async def group_discard(self, group, channel):
519
540
520
541
async def group_send (self , group , message ):
521
542
"""
522
- Sends a message to the entire group.
543
+ Sends one or multiple messages to the entire group.
544
+ The `message` can be a single dict or an iterable of dicts.
523
545
"""
546
+ messages = self ._parse_messages (message )
547
+
524
548
assert self .valid_group_name (group ), "Group name not valid"
525
549
# Retrieve list of all channel names
526
550
key = self ._group_key (group )
@@ -536,7 +560,7 @@ async def group_send(self, group, message):
536
560
connection_to_channel_keys ,
537
561
channel_keys_to_message ,
538
562
channel_keys_to_capacity ,
539
- ) = self ._map_channel_keys_to_connection (channel_names , message )
563
+ ) = self ._map_channel_keys_to_connection (channel_names , messages )
540
564
541
565
for connection_index , channel_redis_keys in connection_to_channel_keys .items ():
542
566
# Discard old messages based on expiry
@@ -548,17 +572,23 @@ async def group_send(self, group, message):
548
572
await pipe .execute ()
549
573
550
574
# Create a LUA script specific for this connection.
551
- # Make sure to use the message specific to this channel, it is
552
- # stored in channel_to_message dict and contains the
575
+ # Make sure to use the message list specific to this channel, it is
576
+ # stored in channel_to_message dict and each message contains the
553
577
# __asgi_channel__ key.
554
578
555
579
group_send_lua = """
556
580
local over_capacity = 0
581
+ local num_messages = tonumber(ARGV[#ARGV - 2])
557
582
local current_time = ARGV[#ARGV - 1]
558
583
local expiry = ARGV[#ARGV]
559
584
for i=1,#KEYS do
560
- if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS]) then
561
- redis.call('ZADD', KEYS[i], current_time, ARGV[i])
585
+ if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS * num_messages]) then
586
+ local messages = {}
587
+ for j=num_messages * (i - 1) + 1, num_messages * i do
588
+ table.insert(messages, current_time)
589
+ table.insert(messages, ARGV[j])
590
+ end
591
+ redis.call('ZADD', KEYS[i], unpack(messages))
562
592
redis.call('EXPIRE', KEYS[i], expiry)
563
593
else
564
594
over_capacity = over_capacity + 1
@@ -568,18 +598,18 @@ async def group_send(self, group, message):
568
598
"""
569
599
570
600
# We need to filter the messages to keep those related to the connection
571
- args = [
572
- channel_keys_to_message [ channel_key ]
573
- for channel_key in channel_redis_keys
574
- ]
601
+ args = []
602
+
603
+ for channel_key in channel_redis_keys :
604
+ args += channel_keys_to_message [ channel_key ]
575
605
576
606
# We need to send the capacity for each channel
577
607
args += [
578
608
channel_keys_to_capacity [channel_key ]
579
609
for channel_key in channel_redis_keys
580
610
]
581
611
582
- args += [time .time (), self .expiry ]
612
+ args += [len ( messages ), time .time (), self .expiry ]
583
613
584
614
# channel_keys does not contain a single redis key more than once
585
615
connection = self .connection (connection_index )
@@ -594,7 +624,7 @@ async def group_send(self, group, message):
594
624
group ,
595
625
)
596
626
597
- def _map_channel_keys_to_connection (self , channel_names , message ):
627
+ def _map_channel_keys_to_connection (self , channel_names , messages ):
598
628
"""
599
629
For a list of channel names, GET
600
630
@@ -609,7 +639,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
609
639
# Connection dict keyed by index to list of redis keys mapped on that index
610
640
connection_to_channel_keys = collections .defaultdict (list )
611
641
# Message dict maps redis key to the message that needs to be send on that key
612
- channel_key_to_message = dict ( )
642
+ channel_key_to_message = collections . defaultdict ( list )
613
643
# Channel key mapped to its capacity
614
644
channel_key_to_capacity = dict ()
615
645
@@ -623,20 +653,23 @@ def _map_channel_keys_to_connection(self, channel_names, message):
623
653
# Have we come across the same redis key?
624
654
if channel_key not in channel_key_to_message :
625
655
# If not, fill the corresponding dicts
626
- message = dict (message .items ())
627
- message ["__asgi_channel__" ] = [channel ]
628
- channel_key_to_message [channel_key ] = message
656
+ for message in messages :
657
+ message = dict (message .items ())
658
+ message ["__asgi_channel__" ] = [channel ]
659
+ channel_key_to_message [channel_key ].append (message )
629
660
channel_key_to_capacity [channel_key ] = self .get_capacity (channel )
630
661
idx = self .consistent_hash (channel_non_local_name )
631
662
connection_to_channel_keys [idx ].append (channel_key )
632
663
else :
633
664
# Yes, Append the channel in message dict
634
- channel_key_to_message [channel_key ]["__asgi_channel__" ].append (channel )
665
+ for message in channel_key_to_message [channel_key ]:
666
+ message ["__asgi_channel__" ].append (channel )
635
667
636
668
# Now that we know what message needs to be send on a redis key we serialize it
637
669
for key , value in channel_key_to_message .items ():
638
670
# Serialize the message stored for each redis key
639
- channel_key_to_message [key ] = self .serialize (value )
671
+ for idx , message in enumerate (value ):
672
+ channel_key_to_message [key ][idx ] = self .serialize (message )
640
673
641
674
return (
642
675
connection_to_channel_keys ,
0 commit comments