@@ -250,32 +250,35 @@ def expand_to_neighbor_buckets(bs_idx, bs_range, ctx_idx, ctx_range, max_num_bat
250250 # filter rules for buckets
251251 # prompt
def not_over_max_model_len(bs, query, ctx):
    """Check that ``bs * (query + ctx * block_size)`` fits in ``max_model_len``.

    A (bs, query, ctx) combination that violates the condition is added to
    ``omitted_buckets`` together with the text of the failed condition.

    Returns the result of the comparison.
    """
    total_tokens = bs * (query + ctx * block_size)
    within_limit = total_tokens <= max_model_len
    if within_limit:
        return within_limit
    # Remember why this bucket was filtered out.
    omitted_buckets.add(
        ("condition: bs * (query + ctx * block_size) <= max_model_len", "-> bs, query, ctx: ", bs, query, ctx))
    return within_limit
257258
def not_over_max_num_batched_tokens(bs, query, ctx):
    """Check that ``bs * query`` does not exceed ``max_num_batched_tokens``.

    ``ctx`` takes no part in the comparison; it is only reported when the
    bucket is rejected. Rejected combinations are added to
    ``omitted_buckets`` together with the text of the failed condition.

    Returns the result of the comparison.
    """
    batched_tokens = bs * query
    within_limit = batched_tokens <= max_num_batched_tokens
    if within_limit:
        return within_limit
    # Remember why this bucket was filtered out.
    omitted_buckets.add(
        ("condition: bs * query <= max_num_batched_tokens", "-> bs, query, ctx: ", bs, query, ctx))
    return within_limit
263265
def ctx_not_over_max_ctx_for_merged_prefill(bs, query, ctx):
    """Check ``ctx`` against the context limit derived for merged prefill.

    The limit is ``max_num_prefill_seqs`` times the number of blocks left in
    ``max_model_len`` after subtracting the per-sequence share of ``query``.
    ``bs`` takes no part in the comparison; it is only reported when the
    bucket is rejected. Rejected combinations are added to
    ``omitted_buckets`` together with the text of the failed condition.

    Returns the result of the comparison.
    """
    query_per_seq = math.floor(query / max_num_prefill_seqs)
    # NOTE(review): with integer operands `//` already yields an integer, so
    # math.ceil is a no-op here -- confirm whether `/` was intended.
    blocks_per_seq = math.ceil((max_model_len - query_per_seq) // block_size)
    ctx_limit = max_num_prefill_seqs * blocks_per_seq
    within_limit = ctx <= ctx_limit
    if within_limit:
        return within_limit
    # Remember why this bucket was filtered out.
    omitted_buckets.add((
        "ctx <= max_num_prefill_seqs * math.ceil((max_model_len - math.floor(query / max_num_prefill_seqs)) // block_size)",
        "-> bs, query, ctx: ", bs, query, ctx))
    return within_limit
272274
273275 # decode
def block_not_greater_than_max_model_len(bs, query, ctx):
    """Check that ``ctx`` does not exceed ``bs * ceil(max_model_len / block_size)``.

    ``query`` takes no part in the comparison; it is only reported when the
    bucket is rejected. Rejected combinations are added to
    ``omitted_buckets`` together with the text of the failed condition.

    Returns the result of the comparison.
    """
    blocks_per_seq = math.ceil(max_model_len / block_size)
    ctx_limit = bs * blocks_per_seq
    within_limit = ctx <= ctx_limit
    if within_limit:
        return within_limit
    # Remember why this bucket was filtered out.
    omitted_buckets.add(
        ("condition: ctx <= bs * math.ceil(max_model_len / block_size)", "-> bs, query, ctx: ", bs, query, ctx))
    return within_limit
279282
280283 def batch_size_smaller_than_blocks (bs , query , ctx ):
281284 if not bs <= ctx :
0 commit comments