95 changes: 68 additions & 27 deletions benchmarks/benchmark_throughput_autotune.py
@@ -1067,34 +1067,75 @@ def main() -> None:
 oot_enabled = False
 oot_whitelist = None

-val = run_benchmark_multi(
-    label=label,
-    throughput_args=throughput_args,
-    num_runs=args.num_runs,
-    run_dir=run_dir,
-    log_counter=log_counter,
-    use_flaggems=use_flaggems,
-    gems_whitelist=gems_whitelist,
-    oot_whitelist=oot_whitelist,
-    oot_enabled=oot_enabled,
-    config_dir=round_config_dir,
-    config_payload=op_config_payload,
-)
+if backend == "reference":
+    best_ref_type = "FLASH_ATTN"
+
+    for attn_mode in ["FLASH_ATTN", "FLASHINFER"]:
+        logger.info(f" [Autotune Reference] Trying {attn_mode}...")
+
+        test_env = {"VLLM_ATTENTION_BACKEND": attn_mode}
+        os.environ["VLLM_ATTENTION_BACKEND"] = attn_mode
+
+        val = run_benchmark_multi(
+            label=f"op_{op}_reference_{attn_mode.lower()}",
+            throughput_args=throughput_args,
+            num_runs=args.num_runs,
+            run_dir=run_dir,
+            log_counter=log_counter,
+            use_flaggems=use_flaggems,
+            gems_whitelist=gems_whitelist,
+            oot_whitelist=oot_whitelist,
+            oot_enabled=oot_enabled,
+            config_dir=round_config_dir,
+            config_payload=op_config_payload,
+        )

-per_op_backend_results.setdefault(op, {})[backend] = val
-write_results_csv(
-    csv_path,
-    baseline_result,
-    per_op_results,
-    baseline_fake_result=baseline_fake_result,
-    baseline_enable_result=baseline_enable_result,
-    op_backends=tuned_op_backends,
-    per_op_backend_results=per_op_backend_results,
-)
-if val is not None and baseline_total is not None:
-    if best_result is None or val[0] > best_result[0]:
-        best_result = val
-        best_backend = backend
+        per_op_backend_results.setdefault(op, {})[backend] = val
+        write_results_csv(
+            csv_path,
+            baseline_result,
+            per_op_results,
+            baseline_fake_result=baseline_fake_result,
+            baseline_enable_result=baseline_enable_result,
+            op_backends=tuned_op_backends,
+            per_op_backend_results=per_op_backend_results,
+        )
+        if val is not None and baseline_total is not None:
+            if best_result is None or val[0] > best_result[0]:
+                best_result = val
+                best_backend = backend
+                best_ref_type = attn_mode
+
+    os.environ["VLLM_ATTENTION_BACKEND"] = best_ref_type
+else:
+    val = run_benchmark_multi(
+        label=label,
+        throughput_args=throughput_args,
+        num_runs=args.num_runs,
+        run_dir=run_dir,
+        log_counter=log_counter,
+        use_flaggems=use_flaggems,
+        gems_whitelist=gems_whitelist,
+        oot_whitelist=oot_whitelist,
+        oot_enabled=oot_enabled,
+        config_dir=round_config_dir,
+        config_payload=op_config_payload,
+    )
+
+    per_op_backend_results.setdefault(op, {})[backend] = val
+    write_results_csv(
+        csv_path,
+        baseline_result,
+        per_op_results,
+        baseline_fake_result=baseline_fake_result,
+        baseline_enable_result=baseline_enable_result,
+        op_backends=tuned_op_backends,
+        per_op_backend_results=per_op_backend_results,
+    )
+    if val is not None and baseline_total is not None:
+        if best_result is None or val[0] > best_result[0]:
+            best_result = val
+            best_backend = backend

 per_op_results[op] = best_result
 if (
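For context, a minimal self-contained sketch of the sweep pattern this hunk introduces for the reference backend: benchmark each candidate attention backend via `VLLM_ATTENTION_BACKEND` and pin the environment to the fastest one afterward. `pick_best_backend`, `benchmark`, and `fake_scores` are illustrative stand-ins, not names from the PR (the real script calls `run_benchmark_multi`, whose result tuple carries throughput in element 0):

```python
import os

def pick_best_backend(candidates, benchmark):
    """Try each candidate attention mode, keep the one with the best score.

    `benchmark` stands in for run_benchmark_multi; it returns a throughput
    number or None when a run fails.
    """
    best_throughput = None
    best_mode = candidates[0]  # fall back to the first candidate
    for mode in candidates:
        os.environ["VLLM_ATTENTION_BACKEND"] = mode
        result = benchmark(mode)
        if result is not None and (best_throughput is None or result > best_throughput):
            best_throughput = result
            best_mode = mode
    # Leave the environment pinned to the winner, as the hunk above does.
    os.environ["VLLM_ATTENTION_BACKEND"] = best_mode
    return best_mode, best_throughput

if __name__ == "__main__":
    # Hypothetical throughput numbers, just to exercise the loop.
    fake_scores = {"FLASH_ATTN": 123.4, "FLASHINFER": 150.2}
    best, score = pick_best_backend(["FLASH_ATTN", "FLASHINFER"], fake_scores.get)
    print(best, score)  # FLASHINFER 150.2
```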
9 changes: 8 additions & 1 deletion vllm_fl/dispatch/backends/reference/reference.py
@@ -134,7 +134,14 @@ def attention_backend(self, use_mla: bool = False) -> str:
     # Return vLLM's native flash attention backend as reference
     from vllm.attention.backends.registry import AttentionBackendEnum

+    import os
+
+    env_backend = os.environ.get("VLLM_ATTENTION_BACKEND", "FLASH_ATTN").upper()
+
     if use_mla:
         # vLLM native MLA backend
         return AttentionBackendEnum.MLA.get_path()
-    return AttentionBackendEnum.FLASH_ATTN.get_path()
+    if env_backend == "FLASHINFER":
+        return AttentionBackendEnum.FLASHINFER.get_path()
+    else:
+        return AttentionBackendEnum.FLASH_ATTN.get_path()
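A standalone sketch of the dispatch rule this change adds, assuming only what the hunk shows: `VLLM_ATTENTION_BACKEND` picks between the two non-MLA paths, case-insensitively, defaulting to flash attention for any other value. `resolve_attention_backend` is a hypothetical name, and the returned strings stand in for the `AttentionBackendEnum.*.get_path()` values in the real module:

```python
import os

def resolve_attention_backend(use_mla: bool = False) -> str:
    # Read the env var the same way the patched method does; unset or
    # unrecognized values fall back to flash attention.
    env_backend = os.environ.get("VLLM_ATTENTION_BACKEND", "FLASH_ATTN").upper()
    if use_mla:
        return "MLA"  # the MLA path is unaffected by the env var
    if env_backend == "FLASHINFER":
        return "FLASHINFER"
    return "FLASH_ATTN"

# Usage: the .upper() call makes the selection case-insensitive.
os.environ["VLLM_ATTENTION_BACKEND"] = "flashinfer"
assert resolve_attention_backend() == "FLASHINFER"
assert resolve_attention_backend(use_mla=True) == "MLA"
```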