9 changes: 9 additions & 0 deletions tests/full_tests/ci_gsm8k_tests.sh
@@ -158,6 +158,14 @@ run_gsm8k_granite_test() {
echo "✅ Test with granite-8b passed."
}

# GSM8K on granite-8b (unified attn)
run_gsm8k_granite_test_unified_attn() {
echo "➡️ Testing GSM8K on granite-8b with unified attention..."
VLLM_UNIFIED_ATTN=True VLLM_SKIP_WARMUP=True PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 \
pytest -v -s "${VLLM_GAUDI_PREFIX}/tests/models/language/generation/test_common.py" --model_card_path "${VLLM_GAUDI_PREFIX}/tests/full_tests/model_cards/granite-8b.yaml"
echo "✅ Test with granite-8b unified attention passed."
}

# GSM8K on granite-8b with async scheduling
run_gsm8k_granite_async_test() {
echo "➡️ Testing GSM8K on granite-8b with async scheduling..."
@@ -230,6 +238,7 @@ launch_all_tests() {
run_compressed_w4a16_channelwise_test
run_compressed_w4a16_moe_gidx_test
run_gsm8k_granite_test
run_gsm8k_granite_test_unified_attn
run_gsm8k_granite_async_test
run_gsm8k_deepseek_test
run_gsm8k_qwen3_30b_test
8 changes: 4 additions & 4 deletions vllm_gaudi/extension/bucketing/common.py
@@ -334,12 +334,12 @@ def generate_unified_buckets(query_range, shared_ctx_range, unique_ctx_range, bs
max_bs = min(bs, query)
if math.ceil(shared_ctx * block_size // max_bs) <= max_model_len:
buckets.add((query, shared_ctx, unique_ctx, causal))
elif (query <= bs):
elif query <= bs:
# non causal query = current bs
if shared_ctx > 0 or unique_ctx > 0:
if shared_ctx == 0 or (query > 1 and \
math.ceil(shared_ctx * block_size // (query // 2)) <= max_model_len):
buckets.add((query, shared_ctx, unique_ctx, causal))
if shared_ctx == 0 or (math.ceil(shared_ctx * block_size // (query // 2)) <= max_model_len):
if shared_ctx > 0 or query <= unique_ctx:
buckets.add((query, shared_ctx, unique_ctx, causal))

return sorted(buckets)

27 changes: 16 additions & 11 deletions vllm_gaudi/extension/bucketing/unified.py
@@ -14,8 +14,8 @@ class UnifiedBucketingStrategy():

def get_unified_cfgs(self, bs, max_model_len, block_size, max_blocks, max_num_batched_tokens):
# [min, max, turning_point]
query_cfg = [block_size, max_num_batched_tokens, bs]
max_shared_ctx = math.ceil(max_model_len // block_size) * bs
query_cfg = [1, max_num_batched_tokens, bs]
max_shared_ctx = min(math.ceil(max_model_len // block_size), max_blocks)
shared_ctx_cfg = [0, max_shared_ctx, bs]
max_unique_ctx = max_blocks
unique_ctx_cfg = [0, max_unique_ctx, bs]
@@ -28,19 +28,24 @@ def get_range(self, cfg):

def warmup_unified_range(cfg):
bmin, bmax, turning_point = cfg
limit = 10
round_up = 128

buckets: Set[Tuple[int, int]] = set()

if bmin == 0:
buckets.add(bmin)

# alpha version: [bs/4, bs/2, bs, bt/4, bt/2, bt]

buckets.add(turning_point // 4)
buckets.add(turning_point // 2)
buckets.add(turning_point)
buckets.add(bmax // 4)
buckets.add(bmax // 2)
buckets.add(bmax)
bmin = 1

num_buckets_exp = limit
first_step = bmax

for i in range(num_buckets_exp):
power_unpadded = bmin * np.float_power(first_step / bmin, (1. / float(num_buckets_exp - 1)) * i)
if i == limit - 1:
bucket = bmax
else:
bucket = math.ceil(power_unpadded / round_up) * round_up
buckets.add(bucket)

return list(sorted(buckets))
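
A minimal standalone sketch (not part of the diff) of the exponential bucket spacing used above, with purely illustrative inputs (bmin=1, bmax=2048, limit=10, round_up=128); it mirrors the formula to show how the warmup points spread between bmin and bmax.

import math
import numpy as np

bmin, bmax, limit, round_up = 1, 2048, 10, 128   # illustrative values only
buckets = set()
for i in range(limit):
    # exponentially spaced point between bmin and bmax, rounded up to a multiple of round_up
    power_unpadded = bmin * np.float_power(bmax / bmin, (1. / float(limit - 1)) * i)
    bucket = bmax if i == limit - 1 else math.ceil(power_unpadded / round_up) * round_up
    buckets.add(bucket)
print(sorted(buckets))  # -> [128, 256, 384, 896, 2048]

With these inputs the smallest points collapse onto the 128 rounding step, so the warmup set stays compact while still reaching bmax exactly.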
9 changes: 9 additions & 0 deletions vllm_gaudi/extension/defragmentation.py
@@ -82,6 +82,15 @@ def initialize(self, kv_caches: tuple[tuple[torch.tensor, torch.tensor]], block_
if self.debug:
self.debug('initialized')

def clear_state(self):
""" Clear internal state (e.g. after warmup) """
self.used_blocks.clear()
self.req_blocks.clear()
self.fwd_mapping_table.clear()
self.bwd_mapping_table.clear()
if self.debug:
self.debug('state cleared')

def _extend_mapping_table(self, block_id: int):
""" Make sure mapping_tables are big enough to hold block_id """
if len(self.fwd_mapping_table) <= block_id:
4 changes: 4 additions & 0 deletions vllm_gaudi/extension/features.py
@@ -87,5 +87,9 @@ def get_features():
Value('unified_attn', False),
Value('scale_adjustment', True, env_var='VLLM_SCALE_ADJUSTMENT', env_var_type=boolean),
Value('flatten_input', Any(ModelType('qwen3_moe'), ModelType('granitemoe'), ModelType('glm4_moe'))),
Value('unified_attn_shared_cache_ratio',
1.,
env_var='VLLM_UNIFIED_ATTENTION_SHARED_CACHE_RATIO',
env_var_type=float),
]
return split_values_and_flags(features)
87 changes: 61 additions & 26 deletions vllm_gaudi/extension/unified.py
@@ -281,17 +281,25 @@ def to_hpu(data: Optional[Union[torch.tensor, list]], dtype: Optional[torch.dtyp
return to_hpu(torch.tensor(data, dtype=dtype, device='cpu'))


def mask_to_bias(mask: torch.tensor, dtype: torch.dtype) -> torch.tensor:
def mask_to_bias(mask: torch.tensor, dtype: torch.dtype, bias_placeholder: torch.tensor = None) -> torch.tensor:
"""Convert attn mask to attn bias"""
if bias_placeholder is not None:
# CRITICAL: Must clone to avoid corrupting persistent array across batches
bias = bias_placeholder[:mask.shape[0], :mask.shape[1]].clone()
assert bias.shape == mask.shape
bias.fill_(0)
bias.masked_fill_(mask, -math.inf)
return bias
return torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
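
A small illustrative sketch (not part of the diff) of the PyTorch behaviour behind the clone() calls flagged as CRITICAL in this file: slicing a tensor returns a view that shares storage with the persistent buffer, so writing through the slice without clone() would corrupt the buffer for later batches.

import torch

persistent = torch.zeros(4, 4)
view = persistent[:2, :2]             # slicing returns a view into the same storage
view.fill_(1.0)                       # writes through to `persistent`
assert persistent[0, 0] == 1.0

persistent.zero_()
copy = persistent[:2, :2].clone()     # clone() detaches the slice from the buffer
copy.fill_(1.0)
assert persistent[0, 0] == 0.0        # the persistent buffer stays intact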


def create_causal_bias(groups: torch.tensor, positions: torch.tensor, dtype: torch.dtype) -> torch.tensor:
def create_causal_bias(groups: torch.tensor, positions: torch.tensor, dtype: torch.dtype,
bias_placeholder: torch.tensor) -> torch.tensor:
"""Create causal bias from groups and positions"""
group_mask = groups.unsqueeze(-1) != groups.unsqueeze(0)
position_mask = positions.unsqueeze(-1) < positions.unsqueeze(0)
causal_mask = (group_mask | position_mask)
return mask_to_bias(causal_mask, dtype)
return mask_to_bias(causal_mask, dtype, bias_placeholder)


def indices_and_offsets(counts: torch.tensor) -> tuple[torch.tensor, torch.tensor]:
@@ -322,11 +330,11 @@ def group_sum(groups: torch.tensor, values: torch.tensor):
return tmp.index_select(0, groups)


def generate_bias(block_usages: torch.tensor, block_size: torch.tensor, dtype: torch.dtype) -> torch.tensor:
def generate_bias(block_usages: torch.tensor, block_size: torch.tensor, dtype: torch.dtype, block_len_range,
bias_placeholder: torch.tensor) -> torch.tensor:
""" Generate block bias based on block_usage """
block_len_range = torch.arange(1, block_size + 1, dtype=block_usages.dtype, device=block_usages.device)
block_mask = block_len_range.unsqueeze(0) > block_usages.unsqueeze(-1)
return mask_to_bias(block_mask, dtype=dtype)
return mask_to_bias(block_mask, dtype=dtype, bias_placeholder=bias_placeholder)


@dataclass
@@ -357,8 +365,12 @@ def create(total_tokens: torch.tensor, block_table: torch.tensor, block_size: in

group_ids, group_offsets = indices_and_offsets(num_ctx_blocks)
block_ids = fetch_2d(block_table, group_ids, group_offsets)
block_usages = torch.clamp(
total_tokens.index_select(0, group_ids) - group_offsets * block_size + 1, 1, block_size)
#NOTE(kzawora): Originally we clamped
# total_tokens.index_select(0, group_ids) - group_offsets * block_size + 1
# It is unclear why the +1 was there, but in non-block-aligned prefix-prefill scenarios
# it left the first unused slot uncovered by the causal mask
# (e.g. with context 28, the 28th slot was unmasked, making the effective context length 29).
block_usages = torch.clamp(total_tokens.index_select(0, group_ids) - group_offsets * block_size, 1, block_size)
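# Illustrative arithmetic for the note above, assuming block_size=16 and a total context of 28 tokens:
#   old: clamp(28 - 1*16 + 1, 1, 16) = 13 usable slots in the second block -> 16 + 13 = 29 tokens attended
#   new: clamp(28 - 1*16,     1, 16) = 12 usable slots in the second block -> 16 + 12 = 28 tokens attended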

ctx = Context(group_ids, group_offsets, block_ids, block_usages)
all_shapes = [v.shape for v in ctx._values() if torch.is_tensor(v)]
@@ -400,11 +412,35 @@ def hpu_tensor(tensor: torch.tensor, shape: tuple, pad_value: Union[int, float])
return to_hpu(tensor)


class UnifiedBatchPersistentContext:

def __init__(self, max_num_batched_tokens, max_shared_blocks, max_unique_blocks, block_size, dtype):
self.shared_bias = torch.full((max_num_batched_tokens, max_shared_blocks, block_size),
-math.inf,
dtype=dtype,
device='cpu')

# NOTE(kzawora): shared block bias is a weird entity - it maps block usage to each individual token in the context -
# so the upper bound should be max_shared_blocks*block_size
self.shared_block_bias = torch.full((max_shared_blocks * block_size, block_size),
-math.inf,
dtype=dtype,
device='cpu') # ?

self.unique_bias = torch.full((max_unique_blocks, block_size), -math.inf, dtype=dtype, device='cpu')
self.unique_block_bias = torch.full((max_unique_blocks, block_size), -math.inf, dtype=dtype, device='cpu') # ?
self.unique_block_mapping = torch.full((max_unique_blocks, ), -1, dtype=torch.int64, device='cpu')
self.block_len_range = torch.arange(1, block_size + 1, dtype=torch.int32, device='cpu')
self.causal_bias = torch.full((max_num_batched_tokens, max_num_batched_tokens),
-math.inf,
dtype=dtype,
device='cpu')
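
A hedged usage sketch (not part of the diff): the persistent context is meant to be built once and reused across batches so these CPU-side bias buffers are not re-allocated every step. The sizing numbers below are illustrative assumptions, not values from the PR.

import torch
from vllm_gaudi.extension.unified import UnifiedBatchPersistentContext

persistent_ctx = UnifiedBatchPersistentContext(
    max_num_batched_tokens=2048,   # illustrative
    max_shared_blocks=512,         # illustrative
    max_unique_blocks=4096,        # illustrative
    block_size=128,                # illustrative
    dtype=torch.bfloat16,
)
# The same object is then passed to every create_unified_batch(...) call below.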


def create_unified_batch(req_ids: list[str], all_token_ids: torch.tensor, num_computed_tokens: torch.tensor,
num_scheduled_tokens: torch.tensor, num_prompt_tokens: torch.tensor, block_table: torch.tensor,
block_size: int, dtype: torch.dtype, bucketing_fn: Callable[[bool, int, int, int, int],
tuple[int, int, int,
int]]) -> UnifiedBatch:
block_size: int, dtype: torch.dtype, persistent_ctx: UnifiedBatchPersistentContext,
bucketing_fn: Callable[[bool, int, int, int, int], tuple[int, int, int, int]]) -> UnifiedBatch:
""" Calculate all necessary tensors needed for batch scheduling """
total_tokens = num_computed_tokens + num_scheduled_tokens
query_len = num_scheduled_tokens.sum().item()
@@ -440,7 +476,7 @@ def first_dim(t: Optional[torch.tensor]) -> int:
unique_bias = None

if contains_prompts:
causal_bias = create_causal_bias(token_groups, token_positions, dtype)
causal_bias = create_causal_bias(token_groups, token_positions, dtype, persistent_ctx.causal_bias)

ctx = Context.create(cached_tokens, block_table, block_size)
if ctx:
@@ -456,27 +492,26 @@ def first_dim(t: Optional[torch.tensor]) -> int:
shared_token_idx = shared_group_starts.index_select(0, shared_token_indices) + shared_token_offsets
shared_block_idx = orig_shared_blocks.index_select(0, shared_token_indices)
shared_block_usage = shared_ctx.block_usages.index_select(0, shared_token_indices)
shared_block_bias = generate_bias(shared_block_usage, block_size, dtype)
shared_block_bias = generate_bias(shared_block_usage, block_size, dtype, persistent_ctx.block_len_range,
persistent_ctx.shared_block_bias)

shared_bias = torch.full((query_len, shared_blocks.size(0), block_size),
-math.inf,
dtype=dtype,
device=shared_blocks.device)
# CRITICAL: Must clone to avoid corrupting persistent array - slicing creates a view!
shared_bias = persistent_ctx.shared_bias[:query_len, :shared_blocks.size(0), :block_size].clone()
shared_bias.fill_(-math.inf)
shared_bias.index_put_((shared_token_idx, shared_block_idx), shared_block_bias)

if unique_ctx:
unique_blocks = torch.amax(unique_ctx.block_ids).item() + 1
unique_bias = torch.full((unique_blocks, block_size),
-math.inf,
dtype=dtype,
device=unique_ctx.block_ids.device)
unique_block_bias = generate_bias(unique_ctx.block_usages, block_size, dtype)
# CRITICAL: Must clone to avoid corrupting persistent array - slicing creates a view!
unique_bias = persistent_ctx.unique_bias[:unique_blocks, :block_size].clone()
unique_bias.fill_(-math.inf)
unique_block_bias = generate_bias(unique_ctx.block_usages, block_size, dtype,
persistent_ctx.block_len_range, persistent_ctx.unique_block_bias)
unique_bias.index_copy_(0, unique_ctx.block_ids.to(torch.int64), unique_block_bias)
unique_group_starts = group_starts.index_select(0, unique_ctx.group_ids)
unique_block_mapping = torch.full((unique_blocks, ),
-1,
dtype=torch.int64,
device=unique_ctx.block_ids.device)
# CRITICAL: Must clone to avoid corrupting persistent array - slicing creates a view!
unique_block_mapping = persistent_ctx.unique_block_mapping[:unique_blocks].clone()
unique_block_mapping.fill_(-1)
unique_block_mapping.index_copy_(0, unique_ctx.block_ids.to(torch.int64), unique_group_starts)

bucket = bucketing_fn(contains_prompts, first_dim(token_ids), first_dim(shared_blocks), unique_blocks,