Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
jiqing-feng committed Jul 2, 2024
1 parent cc7998e commit 8f7ecf3
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,4 +573,4 @@ def dispatch_bgmv_low_level(
names_and_values_to_update[k] = hint_on_error(v)

names_and_values.update(names_and_values_to_update)
del names_and_values_to_update, names_and_values, v, k, fn_type
del names_and_values_to_update, names_and_values, v, k, fn_type
4 changes: 2 additions & 2 deletions vllm/attention/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def get_attn_backend(
ROCmFlashAttentionBackend)
return ROCmFlashAttentionBackend
elif backend == _Backend.TORCH_SDPA:
assert is_cpu(), RuntimeError(
"Torch SDPA backend is only used for the CPU device.")
# assert is_cpu(), RuntimeError(
# "Torch SDPA backend is only used for the CPU device.")
logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
return TorchSDPABackend
Expand Down
6 changes: 3 additions & 3 deletions vllm/spec_decode/spec_decode_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_threshold=speculative_config.
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=speculative_config.
typical_acceptance_sampler_posterior_alpha
typical_acceptance_sampler_posterior_alpha,
cpu_draft_worker=speculative_config.cpu_draft_worker)

return spec_decode_worker
Expand Down Expand Up @@ -132,8 +132,8 @@ def create_worker(
draft_tp = draft_parallel_config.tensor_parallel_size
target_tp = scorer_worker.parallel_config.tensor_parallel_size

if draft_tp == 1:
draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
# if draft_tp == 1:
# draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
if cpu_draft_worker:
cpu_draft_worker_kwargs = copy.deepcopy(draft_worker_kwargs)
from vllm.executor.cpu_executor import (
Expand Down

0 comments on commit 8f7ecf3

Please sign in to comment.