diff --git a/examples/bench_attn/bench_attn.py b/examples/bench_attn/bench_attn.py
new file mode 100644
index 0000000..ae7e8a3
--- /dev/null
+++ b/examples/bench_attn/bench_attn.py
@@ -0,0 +1,136 @@
+import torch
+from transformers import AutoConfig
+
+# modular_qwen3 is a code-generation source file; the importable module is
+# modeling_qwen3.
+from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention
+from loguru import logger
+
+# --- 1. Setup Configuration ---
+model_name = "/data/yiliu/models/Qwen/Qwen3-8B"
+# model_name = "/data5/yliu7/HF_HOME/Qwen/Qwen3-8B/"
+# Load config
+config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+config.use_cache = False  # Disable KV cache for training/benchmarking logic
+
+# Benchmarking Parameters
+batch_size = 8
+seq_len = 2048
+hidden_size = config.hidden_size
+num_heads = config.num_attention_heads
+# Qwen3 configs carry an explicit head_dim; fall back to the usual derivation.
+head_dim = getattr(config, "head_dim", hidden_size // num_heads)
+
+dtype = torch.bfloat16
+device = "xpu"
+
+
+def benchmark_implementation(impl_name, forward_only=False, num_steps=10):
+    mode_str = "Forward Only" if forward_only else "Forward + Backward"
+    logger.info(f"\n--- Benchmarking: {impl_name} [{mode_str}] ---")
+
+    # Select the attention backend for this run
+    config._attn_implementation = impl_name
+
+    try:
+        attn_mod = Qwen3Attention(config, layer_idx=0).to(device).to(dtype)
+    except Exception as e:
+        logger.info(f"Skipping {impl_name}: could not initialize. Error: {e}")
+        return
+
+    # 2. Prepare Data
+    # Only require gradients if we are testing the backward pass
+    req_grad = not forward_only
+
+    hidden_states = torch.randn(
+        batch_size, seq_len, hidden_size, device=device, dtype=dtype, requires_grad=req_grad
+    )
+
+    # Generate Position Embeddings (RoPE) - cos/sin tuple
+    cos = torch.randn(1, seq_len, head_dim, device=device, dtype=dtype)
+    sin = torch.randn(1, seq_len, head_dim, device=device, dtype=dtype)
+    position_embeddings = (cos, sin)
+
+    attention_mask = None
+
+    # Helper function to run one step
+    def run_step():
+        if forward_only:
+            with torch.no_grad():
+                output, _ = attn_mod(hidden_states, position_embeddings, attention_mask=attention_mask)
+            return output
+        output, _ = attn_mod(hidden_states, position_embeddings, attention_mask=attention_mask)
+        loss = output.mean()
+        loss.backward()
+        # Zero gradients so they do not accumulate across steps
+        hidden_states.grad = None
+        for param in attn_mod.parameters():
+            param.grad = None
+        return output
+
+    # 3. Warmup
+    logger.info("  Warming up...")
+    try:
+        for _ in range(10):
+            run_step()
+    except Exception as e:
+        logger.info(f"  Error during warmup: {e}")
+        return
+
+    torch.xpu.synchronize()
+
+    # 4. Benchmark Loop
+    # XPU events record timestamps on the device queue, so elapsed_time()
+    # below measures device wall-clock time between the two events.
+    logger.info("  Running benchmark...")
+    start_event = torch.xpu.Event(enable_timing=True)
+    end_event = torch.xpu.Event(enable_timing=True)
+
+    start_event.record()
+    # Optional: wrap the loop with the local profiler (see profiler_wrapper.py).
+    # Keep it disabled for timing runs, since it adds overhead inside the
+    # measured region.
+    # from profiler_wrapper import XPUTorchProfilerWrapper as Profiler
+    # profiler = Profiler(worker_name=f"attn_{impl_name}_{'fwd' if forward_only else 'fwd_bwd'}", local_rank=0)
+    # profiler.start()
+    for _ in range(num_steps):
+        run_step()
+        # profiler.step()
+    # profiler.stop()
+
+    end_event.record()
+    torch.xpu.synchronize()
+
+    # 5. Stats
+    elapsed_time_ms = start_event.elapsed_time(end_event)
+    avg_time_per_step = elapsed_time_ms / num_steps
+
+    logger.info(f"  > Average Time: {avg_time_per_step:.2f} ms")
+
+    # Approx TFLOPS Calculation
+    # Base FLOPs for the forward pass (linear projections + attention):
+    #   Projections (Q,K,V,O): 4 projections * 2 (mul+add) * B * S * H^2 -> 8 B S H^2
+    #   Attention (matmuls):  ~4 * B * S^2 * H
+    fwd_flops = (8 * batch_size * seq_len * hidden_size**2) + (4 * batch_size * seq_len**2 * hidden_size)
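+    # Worked example (added for clarity), assuming Qwen3-8B's hidden_size of
+    # 4096 with the defaults B=8, S=2048:
+    #   projections: 8 * 8 * 2048 * 4096^2 = 2^41 ≈ 2.20e12 FLOPs
+    #   attention:   4 * 8 * 2048^2 * 4096 = 2^39 ≈ 0.55e12 FLOPs
+    #   forward total ≈ 2.75 TFLOPs per step
+    # Note this treats all four projections as H x H; under GQA the K/V
+    # projections are narrower, so the figure is a mild overestimate.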
+
+    # Backward pass is roughly 2x the forward pass cost
+    total_flops_per_step = fwd_flops if forward_only else (fwd_flops * 3)
+
+    tflops = (total_flops_per_step / (avg_time_per_step / 1000)) / 1e12
+    logger.info(f"  > Approx TFLOPS: {tflops:.2f}")
+
+
+# --- Run Benchmarks ---
+implementations = [
+    # "flash_attention_2",
+    "sdpa",
+    "flex_attention",
+    "eager",
+]
+
+# Test Forward Only
+# logger.info("=== MODE: FORWARD ONLY ===")
+# for impl in implementations:
+#     torch.xpu.empty_cache()
+#     benchmark_implementation(impl, forward_only=True)
+
+# Test Forward + Backward
+logger.info("\n=== MODE: FORWARD + BACKWARD ===")
+for impl in implementations:
+    torch.xpu.empty_cache()
+    benchmark_implementation(impl, forward_only=False)
diff --git a/examples/bench_attn/profiler_wrapper.py b/examples/bench_attn/profiler_wrapper.py
new file mode 100644
index 0000000..9c63027
--- /dev/null
+++ b/examples/bench_attn/profiler_wrapper.py
@@ -0,0 +1,282 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from abc import ABC, abstractmethod
+from contextlib import nullcontext
+from dataclasses import dataclass
+
+import torch
+from typing_extensions import override
+
+# Adapted from vLLM; the vLLM logger and env plumbing are replaced with
+# loguru and a local dataclass of defaults.
+# import vllm.envs as envs
+# from vllm.logger import init_logger
+from loguru import logger
+
+
+@dataclass
+class envs:
+    VLLM_PROFILER_DELAY_ITERS: int = 0
+    VLLM_PROFILER_MAX_ITERS: int = 0
+    VLLM_TORCH_PROFILER_DIR: str = "./ar"
+    VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = True
+    VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
+    VLLM_TORCH_PROFILER_WITH_STACK: bool = True
+    VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
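+
+
+# WorkerProfiler keeps two pieces of state: `_active` tracks whether a
+# start_profile/stop_profile request is in flight, while `_running` tracks
+# whether the underlying profiler is actually recording. Keeping them
+# separate is what allows delayed starts (VLLM_PROFILER_DELAY_ITERS) and
+# automatic stops (VLLM_PROFILER_MAX_ITERS) to be handled inside step().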
+class WorkerProfiler(ABC):
+    def __init__(self) -> None:
+        self._delay_iters = envs.VLLM_PROFILER_DELAY_ITERS
+        if self._delay_iters > 0:
+            # loguru has no info_once; plain info is used here
+            logger.info(
+                f"GPU profiling will start {self._delay_iters} steps after start_profile."
+            )
+
+        self._max_iters = envs.VLLM_PROFILER_MAX_ITERS
+        if self._max_iters > 0:
+            logger.info(
+                f"GPU profiling will stop after {self._max_iters} worker steps, "
+                "or when stop_profile is received."
+            )
+
+        # Track when the profiler gets triggered by start_profile
+        self._active_iteration_count = 0
+        self._active = False
+
+        # Track when the profiler is actually running
+        self._profiling_for_iters = 0
+        self._running = False
+
+    @abstractmethod
+    def _start(self) -> None:
+        """Start the profiler."""
+        pass
+
+    @abstractmethod
+    def _stop(self) -> None:
+        """Stop the profiler."""
+        pass
+
+    def _call_start(self) -> None:
+        """Call _start with error handling but no safeguards."""
+        try:
+            self._start()
+            self._running = True  # Only mark as running if start succeeds
+        except Exception as e:
+            logger.warning(f"Failed to start profiler: {e}")
+
+    def _call_stop(self) -> None:
+        """Call _stop with error handling but no safeguards."""
+        try:
+            self._stop()
+            logger.info("Profiler stopped successfully.")
+        except Exception as e:
+            logger.warning(f"Failed to stop profiler: {e}")
+        self._running = False  # Always mark as not running, assume stop worked
+
+    def start(self) -> None:
+        """Attempt to start the profiler, accounting for delayed starts."""
+        if self._active:
+            logger.debug(
+                "start_profile received when profiler is already active. "
+                "Ignoring request."
+            )
+            return
+        self._active = True
+        if self._delay_iters == 0:
+            self._call_start()
+
+    def step(self) -> None:
+        """Update the profiler state at each worker step,
+        to handle delayed starts and max iteration limits."""
+        if not self._active:
+            return
+
+        self._active_iteration_count += 1
+
+        if (
+            not self._running
+            and self._delay_iters > 0
+            and self._active_iteration_count == self._delay_iters
+        ):
+            logger.info("Starting profiler after delay...")
+            self._call_start()
+
+        if self._running:
+            self._profiling_for_iters += 1
+
+        if (
+            self._max_iters > 0
+            and self._running
+            and self._profiling_for_iters > self._max_iters
+        ):
+            # Automatically stop the profiler after max iters. It will be
+            # marked as not running but left active, so that stop() can
+            # clean up properly.
+            logger.info("Max profiling iterations reached. Stopping profiler...")
+            self._call_stop()
+            return
+
+    def stop(self) -> None:
+        """Attempt to stop the profiler, accounting for overlapped calls."""
+        if not self._active:
+            logger.debug(
+                "stop_profile received when profiler is not active. Ignoring request."
+            )
+            return
+        self._active = False
+        self._active_iteration_count = 0
+        self._profiling_for_iters = 0
+
+        if self._running:
+            self._call_stop()
+
+    def shutdown(self) -> None:
+        """Ensure the profiler is stopped when shutting down."""
+        logger.info("Shutting down profiler")
+        if self._running:
+            self.stop()
+
+    def annotate_context_manager(self, name: str):
+        """Return a context manager to annotate profiler traces."""
+        return nullcontext()
+
+
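+# The concrete wrappers below plug device-specific profilers into the
+# WorkerProfiler lifecycle: TorchProfilerWrapper drives torch.profiler with
+# CUDA activities, CudaProfilerWrapper toggles the CUDA runtime profiler
+# (for external tools such as Nsight), and XPUTorchProfilerWrapper further
+# down does the same as TorchProfilerWrapper for XPU.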
+class TorchProfilerWrapper(WorkerProfiler):
+    def __init__(self, worker_name: str, local_rank: int) -> None:
+        super().__init__()
+
+        self.local_rank = local_rank
+        torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+        if local_rank in (None, 0):
+            logger.info(
+                f"Torch profiling enabled. Traces will be saved to: {torch_profiler_trace_dir}"
+            )
+            logger.debug(
+                f"Profiler config: record_shapes={envs.VLLM_TORCH_PROFILER_RECORD_SHAPES},"
+                f"profile_memory={envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY},"
+                f"with_stack={envs.VLLM_TORCH_PROFILER_WITH_STACK},"
+                f"with_flops={envs.VLLM_TORCH_PROFILER_WITH_FLOPS}"
+            )
+        self.profiler = torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+            record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
+            profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
+            with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
+            with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
+            on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
+            ),
+        )
+
+    @override
+    def _start(self) -> None:
+        self.profiler.start()
+
+    @override
+    def _stop(self) -> None:
+        self.profiler.stop()
+
+        rank = self.local_rank
+        profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
+        profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
+        sort_key = "self_cuda_time_total"
+        table = self.profiler.key_averages().table(sort_by=sort_key)
+
+        with open(profiler_out_file, "w") as f:
+            print(table, file=f)
+
+        # Only print profiler results on rank 0
+        if rank == 0:
+            print(table)
+
+    @override
+    def annotate_context_manager(self, name: str):
+        return torch.profiler.record_function(name)
+
+
+class CudaProfilerWrapper(WorkerProfiler):
+    def __init__(self) -> None:
+        super().__init__()
+        # Note: lazy import to avoid dependency issues if CUDA is not available.
+        import torch.cuda.profiler as cuda_profiler
+
+        self._cuda_profiler = cuda_profiler
+
+    @override
+    def _start(self) -> None:
+        self._cuda_profiler.start()
+
+    @override
+    def _stop(self) -> None:
+        self._cuda_profiler.stop()
+
+    @override
+    def annotate_context_manager(self, name: str):
+        return torch.cuda.nvtx.range(name)
+
+
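+# NOTE: XPUTorchProfilerWrapper mirrors TorchProfilerWrapper with
+# ProfilerActivity.XPU and the "self_xpu_time_total" sort key swapped in.
+# If both classes need to be maintained, the shared logic could be hoisted
+# into a common base parameterized by activity and sort key.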
+class XPUTorchProfilerWrapper(WorkerProfiler):
+    def __init__(self, worker_name: str, local_rank: int) -> None:
+        super().__init__()
+
+        self.local_rank = local_rank
+        torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+        if local_rank in (None, 0):
+            logger.info(
+                f"Torch profiling enabled. Traces will be saved to: {torch_profiler_trace_dir}"
+            )
+            logger.debug(
+                f"Profiler config: record_shapes={envs.VLLM_TORCH_PROFILER_RECORD_SHAPES},"
+                f"profile_memory={envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY},"
+                f"with_stack={envs.VLLM_TORCH_PROFILER_WITH_STACK},"
+                f"with_flops={envs.VLLM_TORCH_PROFILER_WITH_FLOPS}"
+            )
+        self.profiler = torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.XPU,
+            ],
+            record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
+            profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
+            with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
+            with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
+            on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
+            ),
+        )
+
+    @override
+    def _start(self) -> None:
+        self.profiler.start()
+
+    @override
+    def _stop(self) -> None:
+        self.profiler.stop()
+
+        rank = self.local_rank
+        profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
+        profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
+        sort_key = "self_xpu_time_total"
+        table = self.profiler.key_averages().table(sort_by=sort_key)
+
+        with open(profiler_out_file, "w") as f:
+            print(table, file=f)
+
+        # Only print profiler results on rank 0
+        if rank == 0:
+            print(table)
+
+    @override
+    def annotate_context_manager(self, name: str):
+        return torch.profiler.record_function(name)
diff --git a/examples/bench_attn/run_bench.sh b/examples/bench_attn/run_bench.sh
new file mode 100755
index 0000000..5f5532a
--- /dev/null
+++ b/examples/bench_attn/run_bench.sh
@@ -0,0 +1 @@
+ZE_AFFINITY_MASK=3 python bench_attn.py
\ No newline at end of file