diff --git a/benchmarks/benchmark_throughput_autotune.py b/benchmarks/benchmark_throughput_autotune.py
index 407393f..4bbd843 100644
--- a/benchmarks/benchmark_throughput_autotune.py
+++ b/benchmarks/benchmark_throughput_autotune.py
@@ -1067,34 +1067,77 @@ def main() -> None:
             oot_enabled = False
             oot_whitelist = None
 
-        val = run_benchmark_multi(
-            label=label,
-            throughput_args=throughput_args,
-            num_runs=args.num_runs,
-            run_dir=run_dir,
-            log_counter=log_counter,
-            use_flaggems=use_flaggems,
-            gems_whitelist=gems_whitelist,
-            oot_whitelist=oot_whitelist,
-            oot_enabled=oot_enabled,
-            config_dir=round_config_dir,
-            config_payload=op_config_payload,
-        )
+        # For the reference backend, sweep both supported attention
+        # backends and keep whichever one yields the best throughput.
+        if backend == "reference":
+            best_ref_type = "FLASH_ATTN"
+
+            for attn_mode in ["FLASH_ATTN", "FLASHINFER"]:
+                logger.info(f" [Autotune Reference] Trying {attn_mode}...")
+
+                os.environ["VLLM_ATTENTION_BACKEND"] = attn_mode
+
+                val = run_benchmark_multi(
+                    label=f"op_{op}_reference_{attn_mode.lower()}",
+                    throughput_args=throughput_args,
+                    num_runs=args.num_runs,
+                    run_dir=run_dir,
+                    log_counter=log_counter,
+                    use_flaggems=use_flaggems,
+                    gems_whitelist=gems_whitelist,
+                    oot_whitelist=oot_whitelist,
+                    oot_enabled=oot_enabled,
+                    config_dir=round_config_dir,
+                    config_payload=op_config_payload,
+                )
 
-        per_op_backend_results.setdefault(op, {})[backend] = val
-        write_results_csv(
-            csv_path,
-            baseline_result,
-            per_op_results,
-            baseline_fake_result=baseline_fake_result,
-            baseline_enable_result=baseline_enable_result,
-            op_backends=tuned_op_backends,
-            per_op_backend_results=per_op_backend_results,
-        )
-        if val is not None and baseline_total is not None:
-            if best_result is None or val[0] > best_result[0]:
-                best_result = val
-                best_backend = backend
+                per_op_backend_results.setdefault(op, {})[backend] = val
+                write_results_csv(
+                    csv_path,
+                    baseline_result,
+                    per_op_results,
+                    baseline_fake_result=baseline_fake_result,
+                    baseline_enable_result=baseline_enable_result,
+                    op_backends=tuned_op_backends,
+                    per_op_backend_results=per_op_backend_results,
+                )
+                if val is not None and baseline_total is not None:
+                    if best_result is None or val[0] > best_result[0]:
+                        best_result = val
+                        best_backend = backend
+                        best_ref_type = attn_mode
+
+            # Pin the winning attention backend for subsequent rounds.
+            os.environ["VLLM_ATTENTION_BACKEND"] = best_ref_type
+        else:
+            val = run_benchmark_multi(
+                label=label,
+                throughput_args=throughput_args,
+                num_runs=args.num_runs,
+                run_dir=run_dir,
+                log_counter=log_counter,
+                use_flaggems=use_flaggems,
+                gems_whitelist=gems_whitelist,
+                oot_whitelist=oot_whitelist,
+                oot_enabled=oot_enabled,
+                config_dir=round_config_dir,
+                config_payload=op_config_payload,
+            )
+
+            per_op_backend_results.setdefault(op, {})[backend] = val
+            write_results_csv(
+                csv_path,
+                baseline_result,
+                per_op_results,
+                baseline_fake_result=baseline_fake_result,
+                baseline_enable_result=baseline_enable_result,
+                op_backends=tuned_op_backends,
+                per_op_backend_results=per_op_backend_results,
+            )
+            if val is not None and baseline_total is not None:
+                if best_result is None or val[0] > best_result[0]:
+                    best_result = val
+                    best_backend = backend
 
         per_op_results[op] = best_result
         if (
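The hunk above boils down to: export each candidate value of `VLLM_ATTENTION_BACKEND`, benchmark it, and pin the winner. The sketch below is illustrative only, assuming a higher-is-better throughput metric; `sweep_attention_backends` and `measure` are hypothetical stand-ins for the loop and `run_benchmark_multi`, not part of the patch:

```python
import os
from typing import Callable, Optional

def sweep_attention_backends(
    measure: Callable[[], Optional[float]],
    candidates: tuple[str, ...] = ("FLASH_ATTN", "FLASHINFER"),
) -> str:
    """Benchmark each candidate attention backend and pin the fastest one."""
    best_mode, best_score = candidates[0], float("-inf")
    for mode in candidates:
        # Export the candidate so the benchmarked process picks it up.
        os.environ["VLLM_ATTENTION_BACKEND"] = mode
        score = measure()  # stand-in for run_benchmark_multi(...)
        if score is not None and score > best_score:
            best_mode, best_score = mode, score
    # Pin the winner so later tuning rounds inherit it.
    os.environ["VLLM_ATTENTION_BACKEND"] = best_mode
    return best_mode

# Dry run with canned scores (a real run would launch vLLM):
scores = {"FLASH_ATTN": 1010.0, "FLASHINFER": 1080.0}
best = sweep_attention_backends(lambda: scores[os.environ["VLLM_ATTENTION_BACKEND"]])
assert best == "FLASHINFER" and os.environ["VLLM_ATTENTION_BACKEND"] == "FLASHINFER"
```

One difference from the sketch: in the actual hunk, `best_ref_type` is only updated when the reference run also beats the running `best_result` across all backends, so the pinned mode reflects the global winner, not just the faster of the two reference runs.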
diff --git a/vllm_fl/dispatch/backends/reference/reference.py b/vllm_fl/dispatch/backends/reference/reference.py
index 78514d0..63520f5 100644
--- a/vllm_fl/dispatch/backends/reference/reference.py
+++ b/vllm_fl/dispatch/backends/reference/reference.py
@@ -134,7 +134,14 @@ def attention_backend(self, use_mla: bool = False) -> str:
         # Return vLLM's native flash attention backend as reference
         from vllm.attention.backends.registry import AttentionBackendEnum
+        import os
+
+        # Allow VLLM_ATTENTION_BACKEND to override the reference backend.
+        env_backend = os.environ.get("VLLM_ATTENTION_BACKEND", "FLASH_ATTN").upper()
+
         if use_mla:
             # vLLM native MLA backend
             return AttentionBackendEnum.MLA.get_path()
 
-        return AttentionBackendEnum.FLASH_ATTN.get_path()
+        if env_backend == "FLASHINFER":
+            return AttentionBackendEnum.FLASHINFER.get_path()
+        return AttentionBackendEnum.FLASH_ATTN.get_path()
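On the serving side, the resolution rule reduces to a small pure function. A minimal sketch, assuming only the behavior visible in the hunk; `resolve_reference_backend` is a hypothetical mirror that returns backend names where the real method returns `AttentionBackendEnum.<NAME>.get_path()`:

```python
import os

def resolve_reference_backend(use_mla: bool = False) -> str:
    env_backend = os.environ.get("VLLM_ATTENTION_BACKEND", "FLASH_ATTN").upper()
    if use_mla:
        # MLA models keep vLLM's native MLA backend regardless of the override.
        return "MLA"
    # Only FLASHINFER is recognized; any other value falls back to FLASH_ATTN.
    return "FLASHINFER" if env_backend == "FLASHINFER" else "FLASH_ATTN"

os.environ["VLLM_ATTENTION_BACKEND"] = "flashinfer"  # .upper() normalizes case
assert resolve_reference_backend() == "FLASHINFER"
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"    # unrecognized -> default
assert resolve_reference_backend() == "FLASH_ATTN"
assert resolve_reference_backend(use_mla=True) == "MLA"
```

Note the fallback is silent: any unrecognized value benchmarks FLASH_ATTN rather than raising, which is safe here because the autotune sweep only ever exports the two exact strings the check recognizes.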