@@ -180,22 +180,40 @@ async def send(self, channel, message):
180
180
"""
181
181
Send a message onto a (general or specific) channel.
182
182
"""
183
+ await self .send_bulk (channel , (message ,))
184
+
185
+ async def send_bulk (self , channel , messages ):
186
+ """
187
+ Send multiple messages in bulk onto a (general or specific) channel.
188
+ The `messages` argument should be an iterable of dicts.
189
+ """
190
+
183
191
# Typecheck
184
- assert isinstance (message , dict ), "message is not a dict"
185
192
assert self .valid_channel_name (channel ), "Channel name not valid"
186
- # Make sure the message does not contain reserved keys
187
- assert "__asgi_channel__" not in message
193
+
188
194
# If it's a process-local channel, strip off local part and stick full name in message
189
195
channel_non_local_name = channel
190
- if "!" in channel :
191
- message = dict (message .items ())
192
- message ["__asgi_channel__" ] = channel
196
+ process_local = "!" in channel
197
+ if process_local :
193
198
channel_non_local_name = self .non_local_name (channel )
199
+
200
+ now = time .time ()
201
+ mapping = {}
202
+ for message in messages :
203
+ assert isinstance (message , dict ), "message is not a dict"
204
+ # Make sure the message does not contain reserved keys
205
+ assert "__asgi_channel__" not in message
206
+ if process_local :
207
+ message = dict (message .items ())
208
+ message ["__asgi_channel__" ] = channel
209
+
210
+ mapping [self .serialize (message )] = now
211
+
194
212
# Write out message into expiring key (avoids big items in list)
195
213
channel_key = self .prefix + channel_non_local_name
196
214
# Pick a connection to the right server - consistent for specific
197
215
# channels, random for general channels
198
- if "!" in channel :
216
+ if process_local :
199
217
index = self .consistent_hash (channel )
200
218
else :
201
219
index = next (self ._send_index_generator )
@@ -207,13 +225,13 @@ async def send(self, channel, message):
207
225
208
226
# Check the length of the list before send
209
227
# This can allow the list to leak slightly over capacity, but that's fine.
210
- if await connection .zcount (channel_key , "-inf" , "+inf" ) >= self . get_capacity (
211
- channel
212
- ):
228
+ current_length = await connection .zcount (channel_key , "-inf" , "+inf" )
229
+
230
+ if current_length + len ( messages ) > self . get_capacity ( channel ):
213
231
raise ChannelFull ()
214
232
215
233
# Push onto the list then set it to expire in case it's not consumed
216
- await connection .zadd (channel_key , { self . serialize ( message ): time . time ()} )
234
+ await connection .zadd (channel_key , mapping )
217
235
await connection .expire (channel_key , int (self .expiry ))
218
236
219
237
def _backup_channel_name (self , channel ):
@@ -517,10 +535,7 @@ async def group_discard(self, group, channel):
517
535
connection = self .connection (self .consistent_hash (group ))
518
536
await connection .zrem (key , channel )
519
537
520
- async def group_send (self , group , message ):
521
- """
522
- Sends a message to the entire group.
523
- """
538
+ async def _get_group_connection_and_channels (self , group ):
524
539
assert self .valid_group_name (group ), "Group name not valid"
525
540
# Retrieve list of all channel names
526
541
key = self ._group_key (group )
@@ -532,11 +547,36 @@ async def group_send(self, group, message):
532
547
533
548
channel_names = [x .decode ("utf8" ) for x in await connection .zrange (key , 0 , - 1 )]
534
549
550
+ return connection , channel_names
551
+
552
+ async def _exec_group_lua_script (
553
+ self , conn_idx , group , channel_redis_keys , channel_names , script , args
554
+ ):
555
+ # channel_keys does not contain a single redis key more than once
556
+ connection = self .connection (conn_idx )
557
+ channels_over_capacity = await connection .eval (
558
+ script , len (channel_redis_keys ), * channel_redis_keys , * args
559
+ )
560
+ if channels_over_capacity > 0 :
561
+ logger .info (
562
+ "%s of %s channels over capacity in group %s" ,
563
+ channels_over_capacity ,
564
+ len (channel_names ),
565
+ group ,
566
+ )
567
+
568
+ async def group_send (self , group , message ):
569
+ """
570
+ Sends a message to the entire group.
571
+ """
572
+
573
+ connection , channel_names = await self ._get_group_connection_and_channels (group )
574
+
535
575
(
536
576
connection_to_channel_keys ,
537
577
channel_keys_to_message ,
538
578
channel_keys_to_capacity ,
539
- ) = self ._map_channel_keys_to_connection (channel_names , message )
579
+ ) = self ._map_channel_keys_to_connection (channel_names , ( message ,) )
540
580
541
581
for connection_index , channel_redis_keys in connection_to_channel_keys .items ():
542
582
# Discard old messages based on expiry
@@ -569,7 +609,7 @@ async def group_send(self, group, message):
569
609
570
610
# We need to filter the messages to keep those related to the connection
571
611
args = [
572
- channel_keys_to_message [channel_key ]
612
+ channel_keys_to_message [channel_key ][ 0 ]
573
613
for channel_key in channel_redis_keys
574
614
]
575
615
@@ -581,20 +621,88 @@ async def group_send(self, group, message):
581
621
582
622
args += [time .time (), self .expiry ]
583
623
584
- # channel_keys does not contain a single redis key more than once
585
- connection = self .connection (connection_index )
586
- channels_over_capacity = await connection .eval (
587
- group_send_lua , len (channel_redis_keys ), * channel_redis_keys , * args
624
+ await self ._exec_group_lua_script (
625
+ connection_index ,
626
+ group ,
627
+ channel_redis_keys ,
628
+ channel_names ,
629
+ group_send_lua ,
630
+ args ,
588
631
)
589
- if channels_over_capacity > 0 :
590
- logger .info (
591
- "%s of %s channels over capacity in group %s" ,
592
- channels_over_capacity ,
593
- len (channel_names ),
594
- group ,
632
+
633
+ async def group_send_bulk (self , group , messages ):
634
+ """
635
+ Sends multiple messages in bulk to the entire group.
636
+ The `messages` argument should be an iterable of dicts.
637
+ """
638
+
639
+ connection , channel_names = await self ._get_group_connection_and_channels (group )
640
+
641
+ (
642
+ connection_to_channel_keys ,
643
+ channel_keys_to_message ,
644
+ channel_keys_to_capacity ,
645
+ ) = self ._map_channel_keys_to_connection (channel_names , messages )
646
+
647
+ for connection_index , channel_redis_keys in connection_to_channel_keys .items ():
648
+ # Discard old messages based on expiry
649
+ pipe = connection .pipeline ()
650
+ for key in channel_redis_keys :
651
+ pipe .zremrangebyscore (
652
+ key , min = 0 , max = int (time .time ()) - int (self .expiry )
595
653
)
654
+ await pipe .execute ()
655
+
656
+ # Create a LUA script specific for this connection.
657
+ # Make sure to use the message list specific to this channel, it is
658
+ # stored in channel_to_message dict and each message contains the
659
+ # __asgi_channel__ key.
660
+
661
+ group_send_lua = """
662
+ local over_capacity = 0
663
+ local num_messages = tonumber(ARGV[#ARGV - 2])
664
+ local current_time = ARGV[#ARGV - 1]
665
+ local expiry = ARGV[#ARGV]
666
+ for i=1,#KEYS do
667
+ if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS * num_messages]) then
668
+ local messages = {}
669
+ for j=num_messages * (i - 1) + 1, num_messages * i do
670
+ table.insert(messages, current_time)
671
+ table.insert(messages, ARGV[j])
672
+ end
673
+ redis.call('ZADD', KEYS[i], unpack(messages))
674
+ redis.call('EXPIRE', KEYS[i], expiry)
675
+ else
676
+ over_capacity = over_capacity + 1
677
+ end
678
+ end
679
+ return over_capacity
680
+ """
681
+
682
+ # We need to filter the messages to keep those related to the connection
683
+ args = []
684
+
685
+ for channel_key in channel_redis_keys :
686
+ args += channel_keys_to_message [channel_key ]
687
+
688
+ # We need to send the capacity for each channel
689
+ args += [
690
+ channel_keys_to_capacity [channel_key ]
691
+ for channel_key in channel_redis_keys
692
+ ]
596
693
597
- def _map_channel_keys_to_connection (self , channel_names , message ):
694
+ args += [len (messages ), time .time (), self .expiry ]
695
+
696
+ await self ._exec_group_lua_script (
697
+ connection_index ,
698
+ group ,
699
+ channel_redis_keys ,
700
+ channel_names ,
701
+ group_send_lua ,
702
+ args ,
703
+ )
704
+
705
+ def _map_channel_keys_to_connection (self , channel_names , messages ):
598
706
"""
599
707
For a list of channel names, GET
600
708
@@ -609,7 +717,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
609
717
# Connection dict keyed by index to list of redis keys mapped on that index
610
718
connection_to_channel_keys = collections .defaultdict (list )
611
719
# Message dict maps redis key to the message that needs to be send on that key
612
- channel_key_to_message = dict ( )
720
+ channel_key_to_message = collections . defaultdict ( list )
613
721
# Channel key mapped to its capacity
614
722
channel_key_to_capacity = dict ()
615
723
@@ -623,20 +731,23 @@ def _map_channel_keys_to_connection(self, channel_names, message):
623
731
# Have we come across the same redis key?
624
732
if channel_key not in channel_key_to_message :
625
733
# If not, fill the corresponding dicts
626
- message = dict (message .items ())
627
- message ["__asgi_channel__" ] = [channel ]
628
- channel_key_to_message [channel_key ] = message
734
+ for message in messages :
735
+ message = dict (message .items ())
736
+ message ["__asgi_channel__" ] = [channel ]
737
+ channel_key_to_message [channel_key ].append (message )
629
738
channel_key_to_capacity [channel_key ] = self .get_capacity (channel )
630
739
idx = self .consistent_hash (channel_non_local_name )
631
740
connection_to_channel_keys [idx ].append (channel_key )
632
741
else :
633
742
# Yes, Append the channel in message dict
634
- channel_key_to_message [channel_key ]["__asgi_channel__" ].append (channel )
743
+ for message in channel_key_to_message [channel_key ]:
744
+ message ["__asgi_channel__" ].append (channel )
635
745
636
746
# Now that we know what message needs to be send on a redis key we serialize it
637
747
for key , value in channel_key_to_message .items ():
638
748
# Serialize the message stored for each redis key
639
- channel_key_to_message [key ] = self .serialize (value )
749
+ for idx , message in enumerate (value ):
750
+ channel_key_to_message [key ][idx ] = self .serialize (message )
640
751
641
752
return (
642
753
connection_to_channel_keys ,
0 commit comments