
Commit a7b2a41

Unified attention improvements (#363)

Authored by adobrzyn, michalkuligowski, and kzawora-intel

- [x] warmup functioning
- [x] no recompiles

Signed-off-by: Agata Dobrzyniewicz <[email protected]>
Co-authored-by: Michał Kuligowski <[email protected]>
Co-authored-by: Konrad Zawora <[email protected]>

1 parent b0bd04b commit a7b2a41

File tree: 3 files changed, +32 −29 lines


vllm_gaudi/extension/bucketing/common.py

Lines changed: 4 additions & 4 deletions
Lines changed: 4 additions & 4 deletions

```diff
@@ -340,12 +340,12 @@ def generate_unified_buckets(query_range, shared_ctx_range, unique_ctx_range, bs
             max_bs = min(bs, query)
             if math.ceil(shared_ctx * block_size // max_bs) <= max_model_len:
                 buckets.add((query, shared_ctx, unique_ctx, causal))
-            elif (query <= bs):
+            elif query <= bs:
                 # non causal query = current bs
                 if shared_ctx > 0 or unique_ctx > 0:
-                    if shared_ctx == 0 or (query > 1 and \
-                            math.ceil(shared_ctx * block_size // (query // 2)) <= max_model_len):
-                        buckets.add((query, shared_ctx, unique_ctx, causal))
+                    if shared_ctx == 0 or (math.ceil(shared_ctx * block_size // (query // 2)) <= max_model_len):
+                        if shared_ctx > 0 or query <= unique_ctx:
+                            buckets.add((query, shared_ctx, unique_ctx, causal))

     return sorted(buckets)
```
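For intuition, here is a hedged standalone restatement of the reworked non-causal admission check. `admits_non_causal` is a hypothetical helper that mirrors the post-change `elif` branch; all numeric values are illustrative only.

```python
import math

def admits_non_causal(query, shared_ctx, unique_ctx, block_size, max_model_len):
    # Mirrors the reworked elif branch: keep a (query, shared_ctx, unique_ctx)
    # bucket only if it carries some context at all, ...
    if shared_ctx == 0 and unique_ctx == 0:
        return False
    # ... its shared context still fits max_model_len when spread over roughly
    # query // 2 sequences (this sketch assumes query > 1 in that case, since
    # the change dropped the old explicit query > 1 guard), ...
    if shared_ctx > 0 and math.ceil(shared_ctx * block_size // (query // 2)) > max_model_len:
        return False
    # ... and pure-unique buckets (shared_ctx == 0) provide at least one
    # unique block per query -- the newly added condition.
    return shared_ctx > 0 or query <= unique_ctx

# Illustrative values: block_size=128, max_model_len=4096.
print(admits_non_causal(8, 0, 4, 128, 4096))   # False: no shared ctx and 8 > 4
print(admits_non_causal(8, 0, 16, 128, 4096))  # True: 8 <= 16 unique blocks
print(admits_non_causal(8, 64, 2, 128, 4096))  # True: 64 * 128 // 4 = 2048 <= 4096
```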

vllm_gaudi/extension/bucketing/unified.py

Lines changed: 16 additions & 11 deletions
```diff
@@ -14,8 +14,8 @@ class UnifiedBucketingStrategy():

     def get_unified_cfgs(self, bs, max_model_len, block_size, max_blocks, max_num_batched_tokens):
         # [min, max, turning_point]
-        query_cfg = [block_size, max_num_batched_tokens, bs]
-        max_shared_ctx = math.ceil(max_model_len // block_size) * bs
+        query_cfg = [1, max_num_batched_tokens, bs]
+        max_shared_ctx = min(math.ceil(max_model_len // block_size), max_blocks)
         shared_ctx_cfg = [0, max_shared_ctx, bs]
         max_unique_ctx = max_blocks
         unique_ctx_cfg = [0, max_unique_ctx, bs]
```
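A quick worked example of the new bounds, as a sketch rather than a call into the real class; all values are illustrative:

```python
import math

bs, max_model_len, block_size = 32, 4096, 128
max_blocks, max_num_batched_tokens = 512, 8192

query_cfg = [1, max_num_batched_tokens, bs]  # min is now 1, not block_size
max_shared_ctx = min(math.ceil(max_model_len // block_size), max_blocks)
shared_ctx_cfg = [0, max_shared_ctx, bs]

print(query_cfg)       # [1, 8192, 32]
print(shared_ctx_cfg)  # [0, 32, 32]
```

With these numbers the old formula, `math.ceil(max_model_len // block_size) * bs`, would have given 32 * 32 = 1024 shared-context blocks, well past `max_blocks = 512`; the new `min(..., max_blocks)` keeps the bound physical.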
```diff
@@ -28,19 +28,24 @@ def get_range(self, cfg):

 def warmup_unified_range(cfg):
     bmin, bmax, turning_point = cfg
+    limit = 10
+    round_up = 128

     buckets: Set[Tuple[int, int]] = set()

     if bmin == 0:
         buckets.add(bmin)
-
-    # alpha version: [bs/4, bs/2, bs, bt/4, bt/2, bt]
-
-    buckets.add(turning_point // 4)
-    buckets.add(turning_point // 2)
-    buckets.add(turning_point)
-    buckets.add(bmax // 4)
-    buckets.add(bmax // 2)
-    buckets.add(bmax)
+        bmin = 1
+
+    num_buckets_exp = limit
+    first_step = bmax
+
+    for i in range(num_buckets_exp):
+        power_unpadded = bmin * np.float_power(first_step / bmin, (1. / float(num_buckets_exp - 1)) * i)
+        if i == limit - 1:
+            bucket = bmax
+        else:
+            bucket = math.ceil(power_unpadded / round_up) * round_up
+        buckets.add(bucket)

     return list(sorted(buckets))
```
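The new warmup range is a geometric progression rounded to hardware-friendly sizes. A minimal self-contained sketch, assuming `numpy` is available as `np` in the real module; `exponential_buckets` is a hypothetical name used here for illustration:

```python
import math
import numpy as np

def exponential_buckets(bmin: int, bmax: int, limit: int = 10, round_to: int = 128):
    buckets = set()
    if bmin == 0:
        buckets.add(0)  # keep an explicit zero bucket
        bmin = 1        # avoid zero as the base of the progression
    for i in range(limit):
        # geometric interpolation between bmin and bmax...
        raw = bmin * np.float_power(bmax / bmin, i / (limit - 1))
        # ...rounded up to a multiple of round_to, except the last bucket,
        # which is pinned exactly to bmax so the full range is covered
        buckets.add(bmax if i == limit - 1 else math.ceil(raw / round_to) * round_to)
    return sorted(buckets)

print(exponential_buckets(0, 8192))
# [0, 128, 256, 512, 1152, 3072, 8192] -- duplicates collapse in the set
```

Compared with the old fixed six-entry list built from `turning_point` and `bmax`, the exponential spacing covers the whole range with a bounded bucket count, which presumably is what the "no recompiles" checkbox in the commit message refers to.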

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 12 additions & 14 deletions
```diff
@@ -863,16 +863,8 @@ def unified_bucketing_fn(self, is_causal, query_len, shared_blocks, unique_block
         if not get_config().use_bucketing:
             return query_len, shared_blocks, unique_blocks, logits

-        def bucketize(x, buckets):
-            if x < buckets[-1]:
-                return next(b for b in buckets if b >= x)
-            else:
-                return round_up(x, buckets[-1])
-
-        logits_buckets = [self.max_num_seqs]
-        logits = min(bucketize(logits, logits_buckets), query_len)
         new_bucket = self.bucketing_manager.find_unified_bucket(query_len, shared_blocks, unique_blocks, is_causal)
-        return (new_bucket[0], new_bucket[1], new_bucket[2], logits)
+        return (new_bucket[0], new_bucket[1], new_bucket[2], self.max_num_seqs)

     def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: list[int], is_prompt: bool):
         '''
```
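For context, the removed `bucketize` helper produced a data-dependent logits size, so differently shaped batches could select different buckets; the new code pins the fourth tuple element to `self.max_num_seqs`, one fixed shape for every call. A sketch of the old helper's behavior (the `round_up` definition here is an assumption: round x up to a multiple of m):

```python
def round_up(x, m):
    # assumed semantics: smallest multiple of m that is >= x
    return -(-x // m) * m

def bucketize(x, buckets):
    if x < buckets[-1]:
        return next(b for b in buckets if b >= x)  # smallest bucket covering x
    return round_up(x, buckets[-1])                # past the largest bucket: round up

print(bucketize(3, [8]))   # 8
print(bucketize(20, [8]))  # 24
```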
```diff
@@ -1491,7 +1483,7 @@ def _generate_req_id_output_token_ids_lst(self,
             # Merged prefill case: remove requests without logits
             req_id_output_token_ids_lst = [r for r in req_id_output_token_ids_lst if r[0] in logits_reqs]
         else:
-            if pad_to is not None:
+            if pad_to is not None and len(req_id_output_token_ids_lst) > 0:
                 while len(req_id_output_token_ids_lst) < pad_to:
                     req_id_output_token_ids_lst.append(req_id_output_token_ids_lst[0])
         return req_id_output_token_ids_lst
```
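A minimal repro of the case the added guard covers (names mirror the diff; the `pad_to` value is illustrative): with an empty list the old condition reached `lst[0]` and raised IndexError, while the new condition skips padding entirely.

```python
req_id_output_token_ids_lst = []
pad_to = 4

# New guard: a no-op on an empty list.
if pad_to is not None and len(req_id_output_token_ids_lst) > 0:
    while len(req_id_output_token_ids_lst) < pad_to:
        # the old code reached this line with an empty list and crashed
        req_id_output_token_ids_lst.append(req_id_output_token_ids_lst[0])

print(req_id_output_token_ids_lst)  # [] -- unchanged, no IndexError
```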
```diff
@@ -3858,12 +3850,10 @@ def _prepare_dummy_unified_scenario(self, unified_cfg):
             for query, blocks in zip(prompt_reqs_query, prompt_reqs_blocks):
                 self._add_dummy_unified_request(requests, True, False, blocks, num_computed_tokens, query,
                                                 scheduled_tokens)
-
         else:
             remaining_samples = query_len
             base = shared_ctx_len // remaining_samples
             remain = shared_ctx_len % remaining_samples
-
             all_shared_blocks_ids = [block for block in range(shared_ctx_len)]
             unique_block = unique_ctx_len - 1
             # do not use unique block id
```
```diff
@@ -3887,8 +3877,16 @@ def _prepare_dummy_unified_scenario(self, unified_cfg):
                     split_shared_blocks_ids[target].append(block)

             # add unique id
-            min_idx = min(range(remaining_samples), key=lambda j: len(split_shared_blocks_ids[j]))
-            split_shared_blocks_ids[min_idx].append(unique_block)
+            if unique_ctx_len > 0:
+                min_idx = min(range(remaining_samples), key=lambda j: len(split_shared_blocks_ids[j]))
+                split_shared_blocks_ids[min_idx].append(unique_block)
+
+            for i in range(len(split_shared_blocks_ids)):
+                if not split_shared_blocks_ids[i]:
+                    if unique_block - i >= 0:
+                        split_shared_blocks_ids[i] = [unique_block - i]
+                    else:
+                        split_shared_blocks_ids[i] = [all_shared_blocks_ids[0]]

             for request_blocks in split_shared_blocks_ids:
                 self._add_dummy_unified_request(requests, False, False, request_blocks, num_computed_tokens, 1,
```
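A standalone sketch of the dummy-scenario fix, with illustrative values: the unique block id is placed only when unique context exists, and any sample left without blocks is backfilled so every dummy request owns at least one block id.

```python
shared_ctx_len, unique_ctx_len, remaining_samples = 3, 6, 5
all_shared_blocks_ids = list(range(shared_ctx_len))
unique_block = unique_ctx_len - 1
# e.g. state after the shared blocks were distributed across 5 samples:
split_shared_blocks_ids = [[0], [1], [2], [], []]

if unique_ctx_len > 0:  # only add a unique id when unique context exists
    min_idx = min(range(remaining_samples), key=lambda j: len(split_shared_blocks_ids[j]))
    split_shared_blocks_ids[min_idx].append(unique_block)

for i in range(len(split_shared_blocks_ids)):
    if not split_shared_blocks_ids[i]:
        # backfill: prefer distinct ids counting down from the unique block,
        # otherwise fall back to the first shared block id
        if unique_block - i >= 0:
            split_shared_blocks_ids[i] = [unique_block - i]
        else:
            split_shared_blocks_ids[i] = [all_shared_blocks_ids[0]]

print(split_shared_blocks_ids)  # [[0], [1], [2], [5], [1]]
```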
