123 changes: 122 additions & 1 deletion vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -104,6 +104,127 @@ class BucketingFailedException(Exception):
    pass


class HpuSampler(Sampler):
    @staticmethod
    def make_ndarray_with_pad(
        x: list[list[T]],
        pad: T,
        dtype: npt.DTypeLike,
        *,
        max_len: int | None = None,
    ) -> npt.NDArray:
        """
        Make a padded array from 2D inputs.

        The padding is applied to the end of each inner list until it reaches
        `max_len`.
        """
        if max_len is None:
            # Unlike for most functions, map is faster than a genexpr over `len`
            max_len = max(map(len, x)) if x else 0

        padded_x = np.full((len(x), max_len), pad, dtype=dtype)
        for ind, blocktb in enumerate(x):
            assert len(blocktb) <= max_len
            padded_x[ind, : len(blocktb)] = blocktb

        return padded_x
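    # Illustrative example (not in the original diff): with pad=0,
    # make_ndarray_with_pad([[1, 2, 3], [4]], 0, np.int64) yields
    # array([[1, 2, 3], [4, 0, 0]]) -- each row is right-padded to the
    # length of the longest row.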

    @staticmethod
    def make_tensor_with_pad(
        x: list[list[T]],
        pad: T,
        dtype: torch.dtype,
        *,
        max_len: int | None = None,
        device: str | torch.device | None = None,
        pin_memory: bool = False,
    ) -> torch.Tensor:
        """
        Make a padded tensor from 2D inputs.

        The padding is applied to the end of each inner list until it reaches
        `max_len`.
        """
        if max_len is None:
            max_len = max(len(row) for row in x) if x else 0

        padded_tensor = torch.full(
            (len(x), max_len), fill_value=pad, dtype=dtype, device=device
        )

        for i, row in enumerate(x):
            row_len = len(row)
            if row_len > 0:
                row_tensor = torch.as_tensor(row, dtype=dtype, device=device)
                padded_tensor[i, :row_len] = row_tensor

        if pin_memory and padded_tensor.device.type == "cpu":
            return padded_tensor.pin_memory()

        return padded_tensor
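    # Illustrative example (not in the original diff):
    # make_tensor_with_pad([[1, 2, 3], [4]], pad=0, dtype=torch.int64) yields
    # tensor([[1, 2, 3], [4, 0, 0]]). pin_memory only applies to CPU tensors;
    # a pinned staging buffer lets the subsequent host-to-device copy run
    # non-blocking (see _convert_to_tensors below).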

    @staticmethod
    def _convert_to_tensors(
        output_token_ids: list[list[int]], vocab_size: int, device: torch.device
    ) -> torch.Tensor:
        """
        Convert the different list data structures to tensors.
        """
        output_tokens_tensor = HpuSampler.make_tensor_with_pad(
            output_token_ids,
            # Use the value of vocab_size as a pad since we don't have a
            # token_id of this value.
            pad=vocab_size,
            device="cpu",
            dtype=torch.int64,
            pin_memory=is_pin_memory_available(),
        )
        return output_tokens_tensor.to(device, non_blocking=True)
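    # Illustrative example (not in the original diff): with vocab_size=4,
    # output_token_ids [[1, 2], [3]] becomes tensor([[1, 2], [3, 4]]); the pad
    # id 4 lies outside the vocabulary, so it cannot collide with a real token
    # when occurrences are counted during penalty application.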

    @staticmethod
    def apply_all_penalties(
        logits: torch.Tensor,
        prompt_token_ids: torch.Tensor,
        presence_penalties: torch.Tensor,
        frequency_penalties: torch.Tensor,
        repetition_penalties: torch.Tensor,
        output_token_ids: list[list[int]],
    ) -> torch.Tensor:
        """
        Applies presence, frequency and repetition penalties to the logits.
        """
        from vllm.model_executor.layers.utils import apply_penalties

        _, vocab_size = logits.shape
        output_tokens_t = HpuSampler._convert_to_tensors(
            output_token_ids, vocab_size, logits.device
        )
        return apply_penalties(
            logits,
            prompt_token_ids,
            output_tokens_t,
            presence_penalties,
            frequency_penalties,
            repetition_penalties,
        )
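    # Note (not in the original diff): presence and frequency penalties follow
    # the usual OpenAI-style semantics (a constant, respectively a per-count
    # amount, subtracted from the logits of tokens already generated), while
    # the repetition penalty rescales the logits of tokens seen in either the
    # prompt or the output.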

    @staticmethod
    def apply_penalties(
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        output_token_ids: list[list[int]],
    ) -> torch.Tensor:
        if sampling_metadata.no_penalties:
            return logits

        return HpuSampler.apply_all_penalties(
            logits,
            sampling_metadata.prompt_token_ids,
            sampling_metadata.presence_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.repetition_penalties,
            output_token_ids,
        )


# Wrapper for ModelRunnerOutput to support overlapped execution.
class AsyncHPUModelRunnerOutput(AsyncModelRunnerOutput):

Expand Down Expand Up @@ -674,7 +795,7 @@ def __init__(
        self.use_aux_hidden_state_outputs = False
        self.supports_mm_inputs = False

-        self.sampler = Sampler()
+        self.sampler = HpuSampler()

        # NOTE(kzawora) update_env is a hack to work around VLLMKVCache in
        # hpu-extension which selects fetch_from_cache implementation based
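
For reference, a minimal sketch (not part of the diff) of the penalty path that HpuSampler.apply_all_penalties delegates to. It assumes a vLLM installation exposing apply_penalties at the import path used above; all token ids, shapes and penalty values below are made up:

import torch
from vllm.model_executor.layers.utils import apply_penalties

vocab_size = 8
logits = torch.randn(2, vocab_size)

# Ragged per-request token lists, right-padded with `vocab_size` as in
# HpuSampler.make_tensor_with_pad / _convert_to_tensors.
prompt_token_ids = torch.tensor(
    [[1, 2, 3], [4, vocab_size, vocab_size]], dtype=torch.int64
)
output_token_ids = torch.tensor([[2, 2], [5, vocab_size]], dtype=torch.int64)

penalized = apply_penalties(
    logits.clone(),              # logits (cloned; the kernel modifies them)
    prompt_token_ids,            # padded prompt token ids
    output_token_ids,            # padded output token ids
    torch.tensor([0.5, 0.0]),    # presence penalties
    torch.tensor([0.2, 0.0]),    # frequency penalties
    torch.tensor([1.2, 1.0]),    # repetition penalties
)
# Request 0's logit for token 2 (seen in the prompt and generated twice) is
# reduced; request 1 uses neutral penalties and is left unchanged.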