
Commit cac6291

Merge branch 'main' into dev/iirzynsk/test_prefix_caching
2 parents: 8ae1dec + 5716c5d

File tree

5 files changed: +42 -14 lines

  tests/full_tests/ci_gsm8k_tests.sh
  vllm_gaudi/extension/bucketing/common.py
  vllm_gaudi/extension/bucketing/linear.py
  vllm_gaudi/extension/unified.py
  vllm_gaudi/v1/worker/hpu_model_runner.py

tests/full_tests/ci_gsm8k_tests.sh

Lines changed: 9 additions & 0 deletions

@@ -158,6 +158,14 @@ run_gsm8k_granite_test() {
     echo "✅ Test with granite-8b passed."
 }
 
+# GSM8K on granite-8b (unified attn)
+run_gsm8k_granite_test_unified_attn() {
+    echo "➡️ Testing GSM8K on granite-8b with unified attention..."
+    VLLM_UNIFIED_ATTN=True VLLM_SKIP_WARMUP=True PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 \
+        pytest -v -s "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/test_common.py" --model_card_path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/granite-8b.yaml"
+    echo "✅ Test with granite-8b unified attention passed."
+}
+
 # GSM8K on granite-8b with async scheduling
 run_gsm8k_granite_async_test() {
     echo "➡️ Testing GSM8K on granite-8b with async scheduling..."

@@ -230,6 +238,7 @@ launch_all_tests() {
     run_compressed_w4a16_channelwise_test
     run_compressed_w4a16_moe_gidx_test
     run_gsm8k_granite_test
+    run_gsm8k_granite_test_unified_attn
     run_gsm8k_granite_async_test
     run_gsm8k_deepseek_test
     run_gsm8k_qwen3_30b_test

vllm_gaudi/extension/bucketing/common.py

Lines changed: 13 additions & 10 deletions

@@ -250,32 +250,35 @@ def expand_to_neighbor_buckets(bs_idx, bs_range, ctx_idx, ctx_range, max_num_bat
     # filter rules for buckets
     # prompt
     def not_over_max_model_len(bs, query, ctx):
-        if not bs * (query + ctx * block_size) <= max_model_len:
+        smaller_than_limit = bs * (query + ctx * block_size) <= max_model_len
+        if not smaller_than_limit:
             omitted_buckets.add(
                 ("condition: bs * (query + ctx * block_size) <= max_model_len", "-> bs, query, ctx: ", bs, query, ctx))
-        return bs * (query + ctx * block_size) <= max_model_len
+        return smaller_than_limit
 
     def not_over_max_num_batched_tokens(bs, query, ctx):
-        if not bs * query <= max_num_batched_tokens:
+        smaller_than_limit = bs * query <= max_num_batched_tokens
+        if not smaller_than_limit:
             omitted_buckets.add(
                 ("condition: bs * query <= max_num_batched_tokens", "-> bs, query, ctx: ", bs, query, ctx))
-        return bs * query <= max_num_batched_tokens
+        return smaller_than_limit
 
     def ctx_not_over_max_ctx_for_merged_prefill(bs, query, ctx):
-        if not ctx <= max_num_prefill_seqs * math.ceil(
-                (max_model_len - math.floor(query / max_num_prefill_seqs)) // block_size):
+        smaller_than_limit = ctx <= max_num_prefill_seqs * math.ceil(
+            (max_model_len - math.floor(query / max_num_prefill_seqs)) // block_size)
+        if not smaller_than_limit:
             omitted_buckets.add((
                 "ctx <= max_num_prefill_seqs * math.ceil((max_model_len - math.floor(query / max_num_prefill_seqs)) // block_size)",
                 "-> bs, query, ctx: ", bs, query, ctx))
-        return ctx <= max_num_prefill_seqs * math.ceil(
-            (max_model_len - math.floor(query / max_num_prefill_seqs)) // block_size)
+        return smaller_than_limit
 
     # decode
     def block_not_greater_than_max_model_len(bs, query, ctx):
-        if not ctx <= bs * math.ceil(max_model_len / block_size):
+        smaller_than_limit = ctx <= bs * math.ceil(max_model_len / block_size)
+        if not smaller_than_limit:
             omitted_buckets.add(
                 ("condition: ctx <= bs * math.ceil(max_model_len / block_size)", "-> bs, query, ctx: ", bs, query, ctx))
-        return ctx <= bs * math.ceil(max_model_len / block_size)
+        return smaller_than_limit
 
     def batch_size_smaller_than_blocks(bs, query, ctx):
         if not bs <= ctx:
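
Illustration (not part of the commit): a minimal, self-contained sketch of the refactored filter-rule pattern, in which the limit condition is evaluated once, the offending bucket is recorded when the condition fails, and the cached result is returned. The limit values below are made up for the example.

max_model_len = 2048
block_size = 128
omitted_buckets = set()

def not_over_max_model_len(bs, query, ctx):
    # Evaluate the limit once, record the bucket on failure, and return the cached result.
    smaller_than_limit = bs * (query + ctx * block_size) <= max_model_len
    if not smaller_than_limit:
        omitted_buckets.add(
            ("condition: bs * (query + ctx * block_size) <= max_model_len", "-> bs, query, ctx: ", bs, query, ctx))
    return smaller_than_limit

print(not_over_max_model_len(bs=1, query=512, ctx=4))  # True: 1 * (512 + 4 * 128) = 1024 <= 2048
print(not_over_max_model_len(bs=4, query=512, ctx=4))  # False: 4 * 1024 = 4096 > 2048, bucket recorded
print(len(omitted_buckets))                            # 1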

vllm_gaudi/extension/bucketing/linear.py

Lines changed: 2 additions & 2 deletions

@@ -109,11 +109,11 @@ def warmup_range(config: Tuple[int, int, int]):
     """
     bmin, bstep, bmax = config
     add_zero_bucket = bmin == 0
-    if add_zero_bucket:
-        bmin = bstep
     assert bmin <= bmax, ("Min. batch size cannot be greater than max. "
                           "batch size. If you want to skip warmup, "
                           "set VLLM_SKIP_WARMUP=true")
+    if add_zero_bucket:
+        bmin = bstep
     base = itertools.repeat(2)
     ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin)
     ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, \
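
Illustration (not part of the commit), reproducing only the reordered lines above in a hedged sketch: the old order validated bmin after the zero bucket had already been mapped onto bstep, while the new order validates the user-provided bmin first and only then performs the substitution.

def validate_old(config):
    bmin, bstep, bmax = config
    add_zero_bucket = bmin == 0
    if add_zero_bucket:
        bmin = bstep  # bmin is overwritten before validation
    assert bmin <= bmax, "Min. batch size cannot be greater than max. batch size."
    return bmin

def validate_new(config):
    bmin, bstep, bmax = config
    add_zero_bucket = bmin == 0
    assert bmin <= bmax, "Min. batch size cannot be greater than max. batch size."
    if add_zero_bucket:
        bmin = bstep  # substitution happens only after validation
    return bmin

# With bmin == 0 and bstep > bmax, the old order asserts on the substituted value and
# raises, while the new order accepts the original bmin of 0.
print(validate_new((0, 32, 16)))  # 32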

vllm_gaudi/extension/unified.py

Lines changed: 6 additions & 2 deletions

@@ -357,8 +357,12 @@ def create(total_tokens: torch.tensor, block_table: torch.tensor, block_size: in
 
     group_ids, group_offsets = indices_and_offsets(num_ctx_blocks)
     block_ids = fetch_2d(block_table, group_ids, group_offsets)
-    block_usages = torch.clamp(
-        total_tokens.index_select(0, group_ids) - group_offsets * block_size + 1, 1, block_size)
+    # NOTE(kzawora): Originally, we were clamping
+    # total_tokens.index_select(0, group_ids) - group_offsets * block_size + 1.
+    # I'm not sure why the +1 was there originally, but in non-block-aligned prefix-prefill scenarios
+    # it made the causal mask not cover the first unused token
+    # (e.g. with context 28, the 28th slot was unmasked, causing the effective context length to be 29).
+    block_usages = torch.clamp(total_tokens.index_select(0, group_ids) - group_offsets * block_size, 1, block_size)
 
     ctx = Context(group_ids, group_offsets, block_ids, block_usages)
     all_shapes = [v.shape for v in ctx._values() if torch.is_tensor(v)]
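
Worked example (not part of the commit) reproducing the clamp arithmetic for the case mentioned in the NOTE: a single request with context length 28 and block_size 16. The group_ids and group_offsets values are assumed to describe the two context blocks of request 0.

import torch

total_tokens = torch.tensor([28])
block_size = 16
group_ids = torch.tensor([0, 0])      # both context blocks belong to request 0
group_offsets = torch.tensor([0, 1])  # block index within the request

old = torch.clamp(total_tokens.index_select(0, group_ids) - group_offsets * block_size + 1, 1, block_size)
new = torch.clamp(total_tokens.index_select(0, group_ids) - group_offsets * block_size, 1, block_size)

print(old.tolist(), old.sum().item())  # [16, 13] -> 29 context slots left unmasked (one too many)
print(new.tolist(), new.sum().item())  # [16, 12] -> 28, matching the real context length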

vllm_gaudi/v1/worker/hpu_model_runner.py

Lines changed: 12 additions & 0 deletions

@@ -2825,6 +2825,18 @@ def unified_execute_model(
         self.input_batch.token_ids_cpu_tensor.index_put_((batch.logits_groups_cpu, batch.new_token_positions_cpu),
                                                          sampled_token_ids_cpu)
 
+        ######### UPDATE REQUEST STATE WITH GENERATED TOKENS #########
+        num_reqs = len(selected_req_ids)
+        for req_id in self.input_batch.req_ids[:num_reqs]:
+            req_state = self.requests[req_id]
+            i = self.input_batch.req_id_to_index[req_id]
+            seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id])
+            token_ids = sampled_token_ids[i]
+            num_tokens = len(token_ids)
+            self.input_batch.token_ids_cpu[i, seq_len:seq_len + num_tokens] = token_ids
+            self.input_batch.num_tokens[i] += len(token_ids)
+            req_state.output_token_ids.extend(token_ids)
+
         model_runner_output = ModelRunnerOutput(
             req_ids=batch.req_ids_cpu,
             req_id_to_index=self.input_batch.req_id_to_index,
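
Illustration (not part of the commit): a simplified sketch of what the new block does, using plain dicts and lists as hypothetical stand-ins for the runner's request and batch objects. Each request's freshly sampled tokens are written into the CPU token buffer right after the tokens scheduled this step, the per-request token count is bumped, and the tokens are appended to the request's output history.

def update_request_state(req_ids, requests, req_id_to_index, token_ids_cpu,
                         num_tokens, num_scheduled_tokens, sampled_token_ids):
    for req_id in req_ids:
        req_state = requests[req_id]
        i = req_id_to_index[req_id]
        # First free slot: context already computed plus the tokens scheduled this step.
        seq_len = req_state["num_computed_tokens"] + num_scheduled_tokens[req_id]
        new_tokens = sampled_token_ids[i]
        token_ids_cpu[i][seq_len:seq_len + len(new_tokens)] = new_tokens
        num_tokens[i] += len(new_tokens)
        req_state["output_token_ids"].extend(new_tokens)

# Request "r0" had 3 computed tokens, was scheduled for 1 more, and just sampled token 99.
requests = {"r0": {"num_computed_tokens": 3, "output_token_ids": []}}
token_ids_cpu = [[11, 12, 13, 0, 0, 0]]
num_tokens = [4]
update_request_state(["r0"], requests, {"r0": 0}, token_ids_cpu,
                     num_tokens, {"r0": 1}, [[99]])
print(token_ids_cpu[0])                                # [11, 12, 13, 0, 99, 0]
print(num_tokens, requests["r0"]["output_token_ids"])  # [5] [99]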
