diff --git a/ucm/sparse/gsa_on_device/gsa_on_device.py b/ucm/sparse/gsa_on_device/gsa_on_device.py index cfed6f428..b6a9a0984 100644 --- a/ucm/sparse/gsa_on_device/gsa_on_device.py +++ b/ucm/sparse/gsa_on_device/gsa_on_device.py @@ -218,6 +218,15 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): self.qk_rope_head_dim = getattr( vllm_config.model_config.hf_text_config, "qk_rope_head_dim", None ) + self.hash_encoder = HashEncoder( + input_dim=self.kv_lora_rank, + hash_bits=self.kv_lora_rank, + dtype=vllm_config.model_config.dtype, + device=self.device, + input_dim_rope=self.qk_rope_head_dim, + hash_bits_rope=self.qk_rope_head_dim, + is_mla=True, + ) self.hash_encoder_nope = HashEncoder( input_dim=self.kv_lora_rank, hash_bits=self.kv_lora_rank, @@ -426,14 +435,12 @@ def get_layer_state(self, layer_name: str): def cache_k_hash_mla_cuda( self, nope, rope, k_hash, attn_metadata, forward_context, layer_name ): - k_c_normed_hash, k_pe_hash = self.hash_code(nope=nope, rope=rope) - ops.concat_and_cache_mla( - k_c_normed_hash, - k_pe_hash.squeeze(1), - k_hash, - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype="auto", - scale=self._k_scale, + self.hash_encoder.compute_hash_and_cache_mla( + x=nope, + x_rope=rope.squeeze(1), + slot_mapping=attn_metadata.slot_mapping.flatten(), + k_hash_cache=k_hash, + block_size=self.block_size, ) if self.has_pc_hit: ## kvcache -> nope + rope @@ -444,14 +451,13 @@ def cache_k_hash_mla_cuda( k_c_normed, k_pe = torch.split(k_cache, [512, 64], dim=-1) k_c_normed = k_c_normed.reshape(-1, k_c_normed.shape[2]) k_pe = k_pe.reshape(-1, k_pe.shape[2]) - k_c_normed_hash, k_pe_hash = self.hash_code(nope=k_c_normed, rope=k_pe) - ops.concat_and_cache_mla( - k_c_normed_hash, - k_pe_hash, - k_hash, - self.prefix_slot_mapping.flatten(), - kv_cache_dtype="auto", - scale=self._k_scale, + + self.hash_encoder.compute_hash_and_cache_mla( + x=k_c_normed, + x_rope=k_pe, + slot_mapping=self.prefix_slot_mapping.flatten(), + k_hash_cache=k_hash, + block_size=self.block_size, ) def cache_k_hash_mla_npu( diff --git a/ucm/sparse/gsa_on_device/hash_encoder.py b/ucm/sparse/gsa_on_device/hash_encoder.py index 743fd6346..397fa2892 100644 --- a/ucm/sparse/gsa_on_device/hash_encoder.py +++ b/ucm/sparse/gsa_on_device/hash_encoder.py @@ -269,9 +269,9 @@ def fused_hash_and_cache_kernel( k_cache_ptr, T, H, - K, - N_BITS, - N_BYTES, + K: tl.constexpr, + N_BITS: tl.constexpr, + N_BYTES: tl.constexpr, stride_xt, stride_xh, stride_xk, @@ -355,6 +355,119 @@ def fused_hash_and_cache_kernel( ) tl.store(out_ptrs, packed, mask=out_mask) + @triton.jit + def fused_hash_and_cache_mla_kernel( + x_nope_ptr, + x_rope_ptr, + code_nope_ptr, + code_rope_ptr, + pack_w_ptr, + slot_ptr, + k_cache_ptr, + T, + K_nope: tl.constexpr, + K_rope: tl.constexpr, + N_BITS_NOPE: tl.constexpr, + N_BITS_ROPE: tl.constexpr, + nope_hash_bytes: tl.constexpr, + stride_xt_nope, + stride_xk_nope, + stride_xt_rope, + stride_xk_rope, + stride_codek_nope, + stride_coden_nope, + stride_codek_rope, + stride_coden_rope, + stride_packw, + stride_cb, + stride_cs, + stride_cw, + block_size: tl.constexpr, + cache_num_slots: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + if pid_m * BLOCK_M >= T: + return + + blocks_for_nope = N_BITS_NOPE // BLOCK_N + is_nope = pid_n < blocks_for_nope + + if is_nope: + K = K_nope + N_BITS = N_BITS_NOPE + x_ptr = x_nope_ptr + code_ptr = code_nope_ptr + stride_xt = stride_xt_nope + stride_xk = stride_xk_nope + stride_codek = stride_codek_nope + stride_coden = stride_coden_nope + byte_offset_base = 0 + rel_pid_n = pid_n + else: + K = K_rope + N_BITS = N_BITS_ROPE + x_ptr = x_rope_ptr + code_ptr = code_rope_ptr + stride_xt = stride_xt_rope + stride_xk = stride_xk_rope + stride_codek = stride_codek_rope + stride_coden = stride_coden_rope + byte_offset_base = nope_hash_bytes + rel_pid_n = pid_n - blocks_for_nope + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = rel_pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + + m_mask = offs_m < T + n_mask = offs_n < N_BITS + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for _ in range(0, tl.cdiv(K, BLOCK_K)): + k_mask = offs_k < K + x_ptrs = x_ptr + offs_m[:, None] * stride_xt + offs_k[None, :] * stride_xk + x = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0) + + code_ptrs = ( + code_ptr + + offs_k[:, None] * stride_codek + + offs_n[None, :] * stride_coden + ) + code = tl.load(code_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0) + + acc += tl.dot(x, code) + offs_k += BLOCK_K + + bits = (acc > 0).to(tl.uint8) + bits_2d = tl.reshape(bits, (BLOCK_M, BLOCK_N // 8, 8)) + pack_w = tl.load(pack_w_ptr + tl.arange(0, 8) * stride_packw) + packed = tl.sum(bits_2d * pack_w[None, None, :], axis=-1).to(tl.uint8) + + offs_byte = rel_pid_n * (BLOCK_N // 8) + tl.arange(0, BLOCK_N // 8) + slot = tl.load(slot_ptr + offs_m, mask=m_mask, other=-1) + + valid_slot = (slot >= 0) & (slot < cache_num_slots) + valid_row = m_mask & valid_slot + valid_byte = offs_byte[None, :] < (N_BITS // 8) + out_mask = valid_row[:, None] & valid_byte + + safe_slot = tl.where(valid_row, slot, tl.zeros((BLOCK_M,), dtype=tl.int64)) + b = (safe_slot // block_size)[:, None] + s = (safe_slot % block_size)[:, None] + + out_ptrs = ( + k_cache_ptr + + b * stride_cb + + s * stride_cs + + (byte_offset_base + offs_byte[None, :]) * stride_cw + ) + tl.store(out_ptrs, packed, mask=out_mask) + @torch.compile() def torch_hash_code(x, code, pack_weight): @@ -377,7 +490,14 @@ class HashEncoder: """ def __init__( - self, input_dim: int, hash_bits: int, dtype: torch.dtype, device: torch.device + self, + input_dim: int, + hash_bits: int, + dtype: torch.dtype, + device: torch.device, + input_dim_rope: int = 0, + hash_bits_rope: int = 0, + is_mla: bool = False, ) -> None: self.input_dim = input_dim @@ -391,6 +511,7 @@ def __init__( self.dtype = dtype self.device = device + self.is_mla = is_mla if self.device.type == "npu": if dtype not in [torch.float16, torch.float32, torch.float64]: @@ -400,23 +521,40 @@ def __init__( logger.warning("automatically using float16 for hash_weights now") self.dtype = torch.float16 - if self.device.type == "cuda" and dtype == torch.bfloat16: - logger.warning("geqrf_cuda not implemented for BFloat16") - logger.warning("automatically using float32 for hash_weights now") - self.dtype = torch.float32 + # if self.device.type == "cuda" and dtype == torch.bfloat16: + # logger.warning("geqrf_cuda not implemented for BFloat16") + # logger.warning("automatically using float32 for hash_weights now") + # self.dtype = torch.float32 - self._init_hash_weights() + self.hash_weights = self._init_hash_weights(self.input_dim, self.hash_bits) if self.device.type == "cuda" or self.device.type == "cpu": self._init_bit_masks() - def _init_hash_weights(self): + if self.is_mla: + self.input_dim_rope = input_dim_rope + self.hash_bits_rope = hash_bits_rope + if hash_bits_rope % 8 != 0: + raise ValueError("hash_bits_rope must be a multiple of 8") + self.hash_numbers_rope = hash_bits_rope // 8 + self.hash_weights_rope = self._init_hash_weights( + self.input_dim_rope, self.hash_bits_rope + ) + else: + self.hash_numbers_rope = 0 + self.hash_numbers_total = self.hash_numbers + self.hash_numbers_rope + + def _init_hash_weights(self, input_dim, hash_bits): # Step 1: 随机高斯矩阵 random_weights = torch.normal( mean=0, std=2, - size=(self.input_dim, self.hash_bits), - dtype=self.dtype, + size=(input_dim, hash_bits), + dtype=( + torch.float32 + if (self.device.type == "cuda" and self.dtype == torch.bfloat16) + else self.dtype + ), device=self.device, ) # Step 2: QR分解 @@ -424,7 +562,10 @@ def _init_hash_weights(self): # Step 3: 调整符号,保证Haar 分布 d = torch.sign(torch.diag(R)) - self.hash_weights = Q * d + hash_weights = Q * d + if self.device.type == "cuda" and self.dtype == torch.bfloat16: + hash_weights = hash_weights.to(self.dtype) + return hash_weights def set_hash_weight(self, hash_weights: torch.Tensor) -> None: if hash_weights.shape != (self.input_dim, self.hash_bits): @@ -447,7 +588,7 @@ def _init_bit_masks(self) -> None: 2, torch.arange(8, dtype=torch.uint8, device=self.device) ) - def compute_hash(self, x: torch.Tensor) -> torch.Tensor: + def compute_hash(self, x: torch.Tensor, is_rope: bool = False) -> torch.Tensor: """ Compute the hash code for input tensor x. Args: @@ -456,9 +597,13 @@ def compute_hash(self, x: torch.Tensor) -> torch.Tensor: A tensor of shape (..., hash_numbers=hash_bits // 8) representing the hash codes. Each element is a uint8 number representing 8 bits of the hash code. """ - if x.shape[-1] != self.input_dim: + input_dim = self.input_dim_rope if is_rope else self.input_dim + hash_weights = self.hash_weights_rope if is_rope else self.hash_weights + hash_numbers = self.hash_numbers_rope if is_rope else self.hash_numbers + + if x.shape[-1] != input_dim: raise ValueError( - f"x must be of shape (..., {self.input_dim}), but got {x.shape}" + f"x must be of shape (..., {input_dim}), but got {x.shape}" ) if x.device != self.device: raise ValueError( @@ -470,14 +615,14 @@ def compute_hash(self, x: torch.Tensor) -> torch.Tensor: orig_shape = x.shape[:-1] # [N, input_dim], e.g., N = s1*s2*s3 - x_flat = x.reshape(-1, self.input_dim) + x_flat = x.reshape(-1, input_dim) if x_flat.dtype != self.dtype: x_flat = x_flat.to(self.dtype) if self.device.type == "npu": # [N, hash_bits] - xW = torch.matmul(x_flat, self.hash_weights) + xW = torch.matmul(x_flat, hash_weights) # [N * hash_bits] xW_flat = xW.view(-1) # [N*hash_numbers], where hash_numbers = hash_bits // 8 @@ -485,14 +630,14 @@ def compute_hash(self, x: torch.Tensor) -> torch.Tensor: elif self.device.type == "cuda": packed_codes_flat = triton_hash_code( - x_flat, self.hash_weights, self.bit_masks + x_flat, hash_weights, self.bit_masks ).view( -1 ) # [N * hash_numbers] elif self.device.type == "cpu": packed_codes_flat = torch_hash_code( - x_flat, self.hash_weights, self.bit_masks + x_flat, hash_weights, self.bit_masks ).view( -1 ) # [N * hash_numbers] @@ -501,7 +646,7 @@ def compute_hash(self, x: torch.Tensor) -> torch.Tensor: raise ValueError(f"Unsupported device type: {self.device.type}") # e.g., [s1, s2, s3, hash_numbers] - out_shape = orig_shape + (self.hash_numbers,) + out_shape = orig_shape + (hash_numbers,) packed_codes = packed_codes_flat.view(out_shape) return packed_codes @@ -514,30 +659,30 @@ def _reinterpret_cache_as_u8(self, cache: torch.Tensor) -> torch.Tensor: 内部 view 成 uint8 后变成 [B, BS, H, hash_numbers] """ if cache.dtype == torch.uint8: - if cache.shape[-1] != self.hash_numbers: + if cache.shape[-1] != self.hash_numbers_total: raise ValueError( f"uint8 cache last dim mismatch: got {cache.shape[-1]}, " - f"expected {self.hash_numbers}" + f"expected {self.hash_numbers_total}" ) return cache if cache.dtype == torch.bfloat16: - if self.hash_numbers % 2 != 0: + if self.hash_numbers_total % 2 != 0: raise ValueError( - f"for bfloat16 cache, hash_numbers must be even, got {self.hash_numbers}" + f"for bfloat16 cache, hash_numbers must be even, got {self.hash_numbers_total}" ) - if cache.shape[-1] != self.hash_numbers // 2: + if cache.shape[-1] != self.hash_numbers_total // 2: raise ValueError( f"bfloat16 cache last dim mismatch: got {cache.shape[-1]}, " - f"expected {self.hash_numbers // 2}" + f"expected {self.hash_numbers_total // 2}" ) cache_u8 = cache.view(torch.uint8) - if cache_u8.shape[-1] != self.hash_numbers: + if cache_u8.shape[-1] != self.hash_numbers_total: raise ValueError( f"reinterpret bf16->u8 failed: got last dim {cache_u8.shape[-1]}, " - f"expected {self.hash_numbers}" + f"expected {self.hash_numbers_total}" ) return cache_u8 @@ -661,6 +806,77 @@ def compute_hash_and_cache( return k_hash_cache + def compute_hash_and_cache_mla( + self, + x: torch.Tensor, # [T, 1, K_nope] + x_rope: torch.Tensor, # [T, 1, K_rope] + slot_mapping: torch.Tensor, # [T] + k_hash_cache: torch.Tensor, # [B, BS, 1, N], uint8 or bf16 + block_size: int = 128, + BLOCK_M: int = 64, + BLOCK_K: int = 64, + BLOCK_N: int = 64, + num_warps: int = 4, + ): + T, K = x.shape + _, K_rope = x_rope.shape + B, _, _ = k_hash_cache.shape + + k_hash_cache_u8 = self._reinterpret_cache_as_u8(k_hash_cache) + + if x.dtype != self.dtype: + x = x.to(self.dtype) + if x_rope.dtype != self.dtype: + x_rope = x_rope.to(self.dtype) + + cache_num_slots = B * block_size + + stride_xt, stride_xk = x.stride() + stride_xt_rope, stride_xk_rope = x_rope.stride() + stride_codek, stride_coden = self.hash_weights.stride() + stride_codek_rope, stride_coden_rope = self.hash_weights_rope.stride() + (stride_packw,) = self.bit_masks.stride() + stride_cb, stride_cs, stride_cw = k_hash_cache_u8.stride() + + total_hash_bits = self.hash_bits + self.hash_bits_rope + grid = (triton.cdiv(T, BLOCK_M), triton.cdiv(total_hash_bits, BLOCK_N)) + + fused_hash_and_cache_mla_kernel[grid]( + x_nope_ptr=x, + x_rope_ptr=x_rope, + code_nope_ptr=self.hash_weights, + code_rope_ptr=self.hash_weights_rope, + pack_w_ptr=self.bit_masks, + slot_ptr=slot_mapping, + k_cache_ptr=k_hash_cache_u8, + T=T, + K_nope=K, + K_rope=K_rope, + N_BITS_NOPE=self.hash_bits, + N_BITS_ROPE=self.hash_bits_rope, + nope_hash_bytes=self.hash_numbers, + stride_xt_nope=stride_xt, + stride_xk_nope=stride_xk, + stride_xt_rope=stride_xt_rope, + stride_xk_rope=stride_xk_rope, + stride_codek_nope=stride_codek, + stride_coden_nope=stride_coden, + stride_codek_rope=stride_codek_rope, + stride_coden_rope=stride_coden_rope, + stride_packw=stride_packw, + stride_cb=stride_cb, + stride_cs=stride_cs, + stride_cw=stride_cw, + block_size=block_size, + cache_num_slots=cache_num_slots, + BLOCK_M=BLOCK_M, + BLOCK_K=BLOCK_K, + BLOCK_N=BLOCK_N, + num_warps=num_warps, + ) + + return k_hash_cache + def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor: """ Unpack the hash codes to +1 or -1 bits. diff --git a/ucm/sparse/test/gsa/test_cuda_cache_and_hash.py b/ucm/sparse/test/gsa/test_cuda_cache_and_hash-gqa.py similarity index 100% rename from ucm/sparse/test/gsa/test_cuda_cache_and_hash.py rename to ucm/sparse/test/gsa/test_cuda_cache_and_hash-gqa.py diff --git a/ucm/sparse/test/gsa/test_cuda_cache_and_hash-mla.py b/ucm/sparse/test/gsa/test_cuda_cache_and_hash-mla.py new file mode 100644 index 000000000..d6615549e --- /dev/null +++ b/ucm/sparse/test/gsa/test_cuda_cache_and_hash-mla.py @@ -0,0 +1,303 @@ +import pytest +import torch +from vllm import _custom_ops as ops + +from ucm.sparse.gsa_on_device.csrc.cuda.hash_and_cache import fused_mla_hash +from ucm.sparse.gsa_on_device.hash_encoder import ( + HashEncoder, + reshape_and_cache_khash_triton, +) + +torch.manual_seed(42) + +warmup_iters = 5 +test_iters = 20 + +num_tokens = 128 * 300 # T +num_heads = 1 # H, MLA直接不构造该维度 +head_dim = 512 # K (input_dim) +head_dim_rope = 64 # K_rope (input_dim_rope) +hash_bits = 512 # N (hash_bits) +hash_bits_rope = 64 # N_rope +hash_numbers = hash_bits // 8 # W (hash_numbers) +hash_numbers_rope = hash_bits_rope // 8 # W (hash_numbers) +block_size = 128 # BS +num_blocks = 300 # B + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +class TestCudaHashAndCacheMLA: + def get_input_data(self): + device = torch.device("cuda:6") + torch.cuda.set_device(6) + dtype = torch.bfloat16 + + assert ( + num_tokens <= num_blocks * block_size + ), "num_blocks is not large enough to contain all tokens." + + # 初始化 HashEncoder + encoder = HashEncoder( + input_dim=head_dim, + hash_bits=hash_bits, + dtype=dtype, + device=device, + input_dim_rope=head_dim_rope, + hash_bits_rope=hash_bits_rope, + is_mla=True, + ) + + # key: [T, H, K] + key = torch.randn((num_tokens, head_dim), device=device, dtype=dtype) + key_rope = torch.randn((num_tokens, head_dim_rope), device=device, dtype=dtype) + + # slot_mapping: [T], 随机映射到 cache 中的位置 + slot_mapping = torch.randperm(num_blocks * block_size)[:num_tokens].to( + device, dtype=torch.int64 + ) + + # 初始化两个相同的 cache 用于对比 + # k_hash_cache: [B, BS, H, W] 其中 W = hash_bits // 8 + # bf16格式,相比u8的维度减半 + cache_fused = torch.zeros( + (num_blocks, block_size, (hash_numbers + hash_numbers_rope) // 2), + device=device, + dtype=torch.bfloat16, + ) + cache_ref = torch.zeros_like(cache_fused) + + k_scale = torch.tensor(1.0, dtype=torch.float32) + return ( + encoder, + key, + key_rope, + slot_mapping, + cache_fused, + cache_ref, + num_tokens, + hash_numbers, + hash_numbers_rope, + block_size, + k_scale, + ) + + def test_cuda_hash_and_cache_gqa_accuracy(self): + + ( + encoder, + key, + key_rope, + slot_mapping, + cache_fused, + cache_ref, + num_tokens, + hash_numbers, + hash_numbers_rope, + block_size, + k_scale, + ) = self.get_input_data() + + # 融合算子 + encoder.compute_hash_and_cache_mla( + key, key_rope, slot_mapping, cache_fused, block_size=block_size + ) + # 基准计算 + # 1. 计算 Hash Code [T, H, W] + k_hash_computed = encoder.compute_hash(key).view(torch.bfloat16) + k_rope_hash_computed = encoder.compute_hash(key_rope, is_rope=True).view( + torch.bfloat16 + ) + # 2. 写入 Cache + ops.concat_and_cache_mla( + k_hash_computed, + k_rope_hash_computed, + cache_ref, + slot_mapping.flatten(), + kv_cache_dtype="auto", + scale=k_scale, + ) + torch.save( + cache_ref.to("cpu"), "/home/externals/wangwenxin21/fl/data/cache_triton.npy" + ) + + cache_ref.zero_() + torch.cuda.synchronize() + block_size = 128 + fused_mla_hash.fused_hash_and_cache_mla( + key, + key_rope, + encoder.hash_weights, + encoder.hash_weights_rope, + encoder.bit_masks, + slot_mapping.flatten(), + cache_ref.view(torch.uint8), + block_size, + ) + torch.save( + cache_ref.to("cpu"), "/home/externals/wangwenxin21/fl/data/cache_cuda.npy" + ) + # 验证融合算子的结果与分步计算的结果是否一致 + diff = torch.abs(cache_fused.view(torch.uint8) - cache_ref.view(torch.uint8)) + print( + f"\nBit flip rate: {diff.nonzero().shape[0]}/{diff.numel()} = {diff.nonzero().shape[0] / diff.numel():.4f}" + ) + assert ( + diff.nonzero().shape[0] / diff.numel() < 0.01 + ), "More than 1% of the elements differ between fused and reference results." + + def test_cuda_hash_and_cache_gqa_baseline(self): + ( + encoder, + key, + key_rope, + slot_mapping, + cache_fused, + cache_ref, + num_tokens, + hash_numbers, + hash_numbers_rope, + block_size, + k_scale, + ) = self.get_input_data() + + # 原版:分步计算 + # 预热 + for _ in range(warmup_iters): + k_hash_computed = encoder.compute_hash(key).view(torch.bfloat16) + k_rope_hash_computed = encoder.compute_hash(key_rope, is_rope=True).view( + torch.bfloat16 + ) + # 2. 写入 Cache + ops.concat_and_cache_mla( + k_hash_computed, + k_rope_hash_computed, + cache_ref, + slot_mapping.flatten(), + kv_cache_dtype="auto", + scale=k_scale, + ) + torch.cuda.synchronize() + + # 性能测试 + start_time = torch.cuda.Event(enable_timing=True) + end_time = torch.cuda.Event(enable_timing=True) + total_time = 0 + torch.cuda.synchronize() + start_time.record() + for _ in range(test_iters): + k_hash_computed = encoder.compute_hash(key).view(torch.bfloat16) + k_rope_hash_computed = encoder.compute_hash(key_rope, is_rope=True).view( + torch.bfloat16 + ) + # 2. 写入 Cache + ops.concat_and_cache_mla( + k_hash_computed, + k_rope_hash_computed, + cache_ref, + slot_mapping.flatten(), + kv_cache_dtype="auto", + scale=k_scale, + ) + end_time.record() + torch.cuda.synchronize() + total_time += start_time.elapsed_time(end_time) + avg_time_ms_ref = total_time / test_iters + print(f"\nAverage time per iteration (Unfused): {avg_time_ms_ref:.2f} ms") + + def test_cuda_hash_and_cache_mla_tritonFused_performance(self): + + ( + encoder, + key, + key_rope, + slot_mapping, + cache_fused, + cache_ref, + num_tokens, + hash_numbers, + hash_numbers_rope, + block_size, + k_scale, + ) = self.get_input_data() + + # 融合算子 + # 预热 + for _ in range(warmup_iters): + encoder.compute_hash_and_cache_mla( + key, key_rope, slot_mapping, cache_fused, block_size=block_size + ) + torch.cuda.synchronize() + + # 性能测试 + total_time = 0 + torch.cuda.synchronize() + start_time = torch.cuda.Event(enable_timing=True) + end_time = torch.cuda.Event(enable_timing=True) + start_time.record() + for _ in range(test_iters): + encoder.compute_hash_and_cache_mla( + key, key_rope, slot_mapping, cache_fused, block_size=block_size + ) + end_time.record() + torch.cuda.synchronize() + total_time += start_time.elapsed_time(end_time) + avg_time_ms = total_time / test_iters + print(f"\nAverage time per iteration (Triton fused): {avg_time_ms:.2f} ms") + + def test_cuda_hash_and_cache_mla_cudaFused_performance(self): + + ( + encoder, + key, + key_rope, + slot_mapping, + cache_fused, + cache_ref, + num_tokens, + hash_numbers, + hash_numbers_rope, + block_size, + k_scale, + ) = self.get_input_data() + + # 融合算子 + # 预热 + for _ in range(warmup_iters): + fused_mla_hash.fused_hash_and_cache_mla( + key, + key_rope, + encoder.hash_weights, + encoder.hash_weights_rope, + encoder.bit_masks, + slot_mapping.flatten(), + cache_ref.view(torch.uint8), + block_size, + ) + torch.cuda.synchronize() + + # 性能测试 + total_time = 0 + torch.cuda.synchronize() + start_time = torch.cuda.Event(enable_timing=True) + end_time = torch.cuda.Event(enable_timing=True) + start_time.record() + for _ in range(test_iters): + fused_mla_hash.fused_hash_and_cache_mla( + key, + key_rope, + encoder.hash_weights, + encoder.hash_weights_rope, + encoder.bit_masks, + slot_mapping.flatten(), + cache_ref.view(torch.uint8), + block_size, + ) + end_time.record() + torch.cuda.synchronize() + total_time += start_time.elapsed_time(end_time) + avg_time_ms = total_time / test_iters + print(f"\nAverage time per iteration (Cuda fused): {avg_time_ms:.2f} ms") + + +if __name__ == "__main__": + pytest.main([__file__])