@@ -863,16 +863,8 @@ def unified_bucketing_fn(self, is_causal, query_len, shared_blocks, unique_block
         if not get_config().use_bucketing:
             return query_len, shared_blocks, unique_blocks, logits
 
-        def bucketize(x, buckets):
-            if x < buckets[-1]:
-                return next(b for b in buckets if b >= x)
-            else:
-                return round_up(x, buckets[-1])
-
-        logits_buckets = [self.max_num_seqs]
-        logits = min(bucketize(logits, logits_buckets), query_len)
         new_bucket = self.bucketing_manager.find_unified_bucket(query_len, shared_blocks, unique_blocks, is_causal)
-        return (new_bucket[0], new_bucket[1], new_bucket[2], logits)
+        return (new_bucket[0], new_bucket[1], new_bucket[2], self.max_num_seqs)
 
     def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: list[int], is_prompt: bool):
         '''
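For reference, a minimal standalone sketch of the `bucketize` helper removed above, with a hypothetical `round_up` standing in for the runner's own utility. It illustrates the rounding behaviour that the new code no longer performs locally: bucketing is left entirely to `bucketing_manager.find_unified_bucket`, and the fourth element of the returned bucket is now always `self.max_num_seqs`.

```python
# Sketch of the removed helper, not the runner's implementation.
def round_up(x: int, multiple: int) -> int:
    # hypothetical stand-in for the runner's round_up utility
    return ((x + multiple - 1) // multiple) * multiple

def bucketize(x: int, buckets: list[int]) -> int:
    # smallest configured bucket that fits x, otherwise a multiple of the largest bucket
    if x < buckets[-1]:
        return next(b for b in buckets if b >= x)
    return round_up(x, buckets[-1])

# With a single bucket of max_num_seqs = 32:
assert bucketize(20, [32]) == 32
assert bucketize(70, [32]) == 96
```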
@@ -1491,7 +1483,7 @@ def _generate_req_id_output_token_ids_lst(self,
             # Merged prefill case: remove requests without logits
             req_id_output_token_ids_lst = [r for r in req_id_output_token_ids_lst if r[0] in logits_reqs]
         else:
-            if pad_to is not None:
+            if pad_to is not None and len(req_id_output_token_ids_lst) > 0:
                 while len(req_id_output_token_ids_lst) < pad_to:
                     req_id_output_token_ids_lst.append(req_id_output_token_ids_lst[0])
         return req_id_output_token_ids_lst
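The new length guard matters because the padding loop repeats the first list entry; on an empty list, indexing `[0]` raises `IndexError`. A minimal sketch of the same pattern, under a hypothetical `pad_requests` name:

```python
def pad_requests(reqs: list, pad_to: int | None) -> list:
    # pad by repeating the first entry, but only when there is something to repeat
    if pad_to is not None and len(reqs) > 0:
        while len(reqs) < pad_to:
            reqs.append(reqs[0])
    return reqs

assert pad_requests([("req-0", [11])], 3) == [("req-0", [11])] * 3
assert pad_requests([], 3) == []  # without the length guard, reqs[0] would raise IndexError
```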
@@ -3858,12 +3850,10 @@ def _prepare_dummy_unified_scenario(self, unified_cfg):
             for query, blocks in zip(prompt_reqs_query, prompt_reqs_blocks):
                 self._add_dummy_unified_request(requests, True, False, blocks, num_computed_tokens, query,
                                                 scheduled_tokens)
-
         else:
             remaining_samples = query_len
             base = shared_ctx_len // remaining_samples
             remain = shared_ctx_len % remaining_samples
-
             all_shared_blocks_ids = [block for block in range(shared_ctx_len)]
             unique_block = unique_ctx_len - 1
             # do not use unique block id
@@ -3887,8 +3877,16 @@ def _prepare_dummy_unified_scenario(self, unified_cfg):
                     split_shared_blocks_ids[target].append(block)
 
             # add unique id
-            min_idx = min(range(remaining_samples), key=lambda j: len(split_shared_blocks_ids[j]))
-            split_shared_blocks_ids[min_idx].append(unique_block)
+            if unique_ctx_len > 0:
+                min_idx = min(range(remaining_samples), key=lambda j: len(split_shared_blocks_ids[j]))
+                split_shared_blocks_ids[min_idx].append(unique_block)
+
+            for i in range(len(split_shared_blocks_ids)):
+                if not split_shared_blocks_ids[i]:
+                    if unique_block - i >= 0:
+                        split_shared_blocks_ids[i] = [unique_block - i]
+                    else:
+                        split_shared_blocks_ids[i] = [all_shared_blocks_ids[0]]
 
             for request_blocks in split_shared_blocks_ids:
                 self._add_dummy_unified_request(requests, False, False, request_blocks, num_computed_tokens, 1,
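The added loop ensures every dummy request ends up with at least one block id, even when `unique_ctx_len` is 0 and no unique block was appended. A small sketch of just that fallback, under a hypothetical `fill_empty_samples` name with toy block ids:

```python
def fill_empty_samples(split_shared_blocks_ids, all_shared_blocks_ids, unique_block):
    # any sample left without shared blocks gets a fallback block id
    for i in range(len(split_shared_blocks_ids)):
        if not split_shared_blocks_ids[i]:
            if unique_block - i >= 0:
                split_shared_blocks_ids[i] = [unique_block - i]
            else:
                split_shared_blocks_ids[i] = [all_shared_blocks_ids[0]]
    return split_shared_blocks_ids

# Three samples, the middle one empty, unique block id 7:
print(fill_empty_samples([[0, 1], [], [2]], [0, 1, 2], 7))  # -> [[0, 1], [6], [2]]
```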