From 9b2fe7ed634c0a35124e049f828022f6745627e5 Mon Sep 17 00:00:00 2001 From: MirkoDeVita98 Date: Fri, 26 Jun 2026 13:45:50 +0000 Subject: [PATCH 01/11] Add paged attention highperf JIT example --- .../paged_attention_highperf/.gitignore | 6 + .../paged_attention_highperf/README.md | 35 + .../paged_attention_highperf/jit_util_pa.py | 136 ++ .../paged_attention_highperf/pa_benchmark.py | 135 ++ .../pa_compile_and_run.py | 142 ++ .../paged_attention_highperf/pa_entry.hpp | 147 ++ .../paged_attention_highperf/pa_kernel.cpp | 66 + .../pa_kernel_impl.hpp | 1797 +++++++++++++++++ .../paged_attention_highperf/pa_tiling.py | 484 +++++ .../pa_tiling_struct.hpp | 97 + 10 files changed, 3045 insertions(+) create mode 100644 examples/jit_cpp/paged_attention_highperf/.gitignore create mode 100644 examples/jit_cpp/paged_attention_highperf/README.md create mode 100644 examples/jit_cpp/paged_attention_highperf/jit_util_pa.py create mode 100644 examples/jit_cpp/paged_attention_highperf/pa_benchmark.py create mode 100644 examples/jit_cpp/paged_attention_highperf/pa_compile_and_run.py create mode 100644 examples/jit_cpp/paged_attention_highperf/pa_entry.hpp create mode 100644 examples/jit_cpp/paged_attention_highperf/pa_kernel.cpp create mode 100644 examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp create mode 100644 examples/jit_cpp/paged_attention_highperf/pa_tiling.py create mode 100644 examples/jit_cpp/paged_attention_highperf/pa_tiling_struct.hpp diff --git a/examples/jit_cpp/paged_attention_highperf/.gitignore b/examples/jit_cpp/paged_attention_highperf/.gitignore new file mode 100644 index 00000000..885a34f5 --- /dev/null +++ b/examples/jit_cpp/paged_attention_highperf/.gitignore @@ -0,0 +1,6 @@ +outputs/ +*.so +pa_highperf_jit.so +pa_highperf_jit_bench.csv +__pycache__/ +.pytest_cache/ diff --git a/examples/jit_cpp/paged_attention_highperf/README.md b/examples/jit_cpp/paged_attention_highperf/README.md new file mode 100644 index 00000000..b5518ca9 --- /dev/null +++ b/examples/jit_cpp/paged_attention_highperf/README.md @@ -0,0 +1,35 @@ +# Paged Attention HighPerf + +JIT demo for a PTO-ISA paged-attention decode kernel using PyTorch/NPU tensors. + +The example includes: + +| File | Purpose | +|---|---| +| pa_kernel.cpp | host-callable JIT entry point | +| pa_entry.hpp | AIC/AIV dispatch wrapper | +| pa_kernel_impl.hpp | PTO kernel implementation | +| pa_tiling_struct.hpp | tiling type definitions | +| pa_tiling.py | Python tiling/workspace construction | +| pa_compile_and_run.py | correctness smoke test | +| pa_benchmark.py | benchmark driver | +| jit_util_pa.py | JIT compile and ctypes wrapper | + +## Requirements + +Set the PTO-ISA include root and CANN toolkit path if they are not already in the environment: + + export PTO_LIB_PATH=/path/to/pto-isa + export ASCEND_TOOLKIT_HOME=/usr/local/Ascend/cann-9.0.0 + +## Run + + cd examples/jit_cpp/paged_attention_highperf + python3 pa_compile_and_run.py + python3 pa_benchmark.py --device npu:0 --shape b=8,s=4096 --check --warmup 1 --iters 1 + +The full benchmark sweep can be run with: + + python3 pa_benchmark.py --device npu:0 + +Use --check for correctness validation against the Python/PyTorch reference. For very large shapes, the reference can dominate runtime and memory use. diff --git a/examples/jit_cpp/paged_attention_highperf/jit_util_pa.py b/examples/jit_cpp/paged_attention_highperf/jit_util_pa.py new file mode 100644 index 00000000..759f0041 --- /dev/null +++ b/examples/jit_cpp/paged_attention_highperf/jit_util_pa.py @@ -0,0 +1,136 @@ +#!/usr/bin/python3 +# coding=utf-8 +# -------------------------------------------------------------------------------- +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the +# terms and conditions of CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance +# with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, +# OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# -------------------------------------------------------------------------------- + +import ctypes +import os +import subprocess +from pathlib import Path + +import torch + +ASCEND_TOOLKIT_HOME = os.environ.get("ASCEND_TOOLKIT_HOME", "/usr/local/Ascend/cann-9.0.0") +PTO_LIB_PATH = os.environ.get("PTO_LIB_PATH", str(Path(__file__).resolve().parents[3])) + + +def torch_to_ctypes(t: torch.Tensor) -> ctypes.c_void_p: + return ctypes.c_void_p(t.data_ptr()) + + +def _npu_arch_flag() -> str: + return os.environ.get("NPU_ARCH", "dav-2201").strip() + + +def compile_paged_attention(kernel_cpp: str, verbose: bool = False, timeout: int = 300) -> str: + lib_path = os.path.join(os.path.dirname(kernel_cpp), "pa_highperf_jit.so") + example_dir = os.path.dirname(kernel_cpp) + flags = [ + "-fPIC", + "-shared", + "-xcce", + f"--npu-arch={_npu_arch_flag()}", + "-O2", + "-std=c++17", + "-Wno-ignored-attributes", + "-Wno-macro-redefined", + f"-I{PTO_LIB_PATH}/include", + f"-I{example_dir}", + f"-I{ASCEND_TOOLKIT_HOME}/include", + f"-I{ASCEND_TOOLKIT_HOME}/pkg_inc/runtime", + f"-I{ASCEND_TOOLKIT_HOME}/pkg_inc", + f"-I{ASCEND_TOOLKIT_HOME}/pkg_inc/profiling", + ] + cmd = ["bisheng", *flags, kernel_cpp, "-o", lib_path] + if verbose: + print("compile command:\n", " ".join(cmd)) + subprocess.run(cmd, check=True, timeout=timeout) + return lib_path + + +def load_paged_attention_lib(lib_path: str, check_type: bool = True): + lib = ctypes.CDLL(os.path.abspath(lib_path)) + if check_type: + lib.call_kernel.argtypes = [ + ctypes.c_void_p, # stream + ctypes.c_void_p, # q + ctypes.c_void_p, # k + ctypes.c_void_p, # v + ctypes.c_void_p, # block table + ctypes.c_void_p, # out + ctypes.c_void_p, # s + ctypes.c_void_p, # p + ctypes.c_void_p, # o_tmp + ctypes.c_void_p, # go + ctypes.c_void_p, # o_core_tmp + ctypes.c_void_p, # l + ctypes.c_void_p, # gm_k16 + ctypes.c_void_p, # gm_v16 + ctypes.c_void_p, # tiling + ctypes.c_void_p, # null + ctypes.c_uint32, # block dim + ] + lib.call_kernel.restype = None + + workspace = {} + default_stream_ptr = torch.npu.current_stream()._as_parameter_ + + def _alloc(device, workspace_sizes, tiling): + tiling_cpu = tuple(int(x) for x in tiling.detach().cpu().tolist()) + sizes_key = tuple(sorted((name, int(size)) for name, size in workspace_sizes.items())) + key = (str(device), sizes_key, tiling_cpu) + if workspace.get("key") == key: + return + workspace.clear() + workspace["key"] = key + for name, size in workspace_sizes.items(): + workspace[name] = torch.empty((int(size),), device=device, dtype=torch.uint8) + workspace["null"] = torch.zeros((1,), device=device, dtype=torch.uint8) + workspace["tiling"] = tiling.to(device=device, dtype=torch.int32) + + def paged_attention(q, k, v, block_table, workspace_sizes, tiling, stream_ptr=default_stream_ptr, block_dim: int = 24): + _alloc(q.device, workspace_sizes, tiling) + out = torch.empty_like(q) + lib.call_kernel( + stream_ptr, + torch_to_ctypes(q), + torch_to_ctypes(k), + torch_to_ctypes(v), + torch_to_ctypes(block_table), + torch_to_ctypes(out), + torch_to_ctypes(workspace["s"]), + torch_to_ctypes(workspace["p"]), + torch_to_ctypes(workspace["o_tmp"]), + torch_to_ctypes(workspace["go"]), + torch_to_ctypes(workspace["o_core_tmp"]), + torch_to_ctypes(workspace["l"]), + torch_to_ctypes(workspace["k16"]), + torch_to_ctypes(workspace["v16"]), + torch_to_ctypes(workspace["tiling"]), + torch_to_ctypes(workspace["null"]), + block_dim, + ) + return out + + return paged_attention + + +def jit_compile_paged_attention(verbose: bool = False, clean_up: bool = True, kernel_cpp: str = "pa_kernel.cpp"): + kernel_path = str((Path(__file__).resolve().parent / kernel_cpp).resolve()) + lib_path = compile_paged_attention(kernel_path, verbose=verbose) + fn = load_paged_attention_lib(lib_path) + if clean_up: + try: + os.remove(lib_path) + except OSError: + pass + return fn diff --git a/examples/jit_cpp/paged_attention_highperf/pa_benchmark.py b/examples/jit_cpp/paged_attention_highperf/pa_benchmark.py new file mode 100644 index 00000000..4b638406 --- /dev/null +++ b/examples/jit_cpp/paged_attention_highperf/pa_benchmark.py @@ -0,0 +1,135 @@ +#!/usr/bin/python3 +# coding=utf-8 +import argparse +import csv +import gc + +import torch + +from jit_util_pa import jit_compile_paged_attention +from pa_compile_and_run import PaShape, golden_attention, make_inputs, make_launch_config + +NUM_ITERATIONS = 50 +WARMUP = 10 +BATCHES = [1, 2, 4, 8, 32, 64] +SEQ_LENS = [128, 512, 4096, 8192, 16384, 32768, 65536, 131072] +DEFAULT_SHAPES = [PaShape(batch=batch, seq_len=seq_len) for batch in BATCHES for seq_len in SEQ_LENS] + + +def paged_attention_flops(shape: PaShape): + qk_and_pv = 4 * shape.batch * shape.num_heads * shape.seq_len * shape.head_dim + scale = shape.batch * shape.num_heads * shape.seq_len + rows = shape.batch * shape.num_heads + softmax = rows * ((shape.seq_len - 1) + shape.seq_len + shape.seq_len + (shape.seq_len - 1) + shape.seq_len) + return qk_and_pv + scale + softmax + + +def tensor_bytes(shape: PaShape): + dtype_bytes = 2 + q_bytes = shape.batch * shape.num_heads * shape.head_dim * dtype_bytes + k_bytes = shape.batch * shape.seq_len * shape.num_kv_heads * shape.head_dim * dtype_bytes + v_bytes = shape.batch * shape.seq_len * shape.num_kv_heads * shape.head_dim * dtype_bytes + out_bytes = shape.batch * shape.num_heads * shape.head_dim * dtype_bytes + blocks_per_batch = (shape.seq_len + shape.block_size - 1) // shape.block_size + block_table_bytes = shape.batch * blocks_per_batch * 4 + return q_bytes + k_bytes + v_bytes + out_bytes + block_table_bytes + + +def tflops(flops, ms): + return flops / (ms * 1e-3) / 1e12 + + +def tb_per_second(num_bytes, ms): + return num_bytes / (ms * 1e-3) / 1e12 + + +def time_npu(fn, iters, warmup): + for _ in range(warmup): + _ = fn() + torch.npu.synchronize() + start = torch.npu.Event(enable_timing=True) + end = torch.npu.Event(enable_timing=True) + start.record() + for _ in range(iters): + _ = fn() + end.record() + torch.npu.synchronize() + return start.elapsed_time(end) / iters + + +def parse_shape(text): + values = {} + for item in text.split(","): + key, value = item.split("=", 1) + values[key.strip()] = int(value) + return PaShape( + batch=values.get("b", values.get("batch", 1)), + seq_len=values.get("s", values.get("seq", 128)), + num_heads=values.get("h", values.get("heads", 32)), + num_kv_heads=values.get("kv", values.get("kv_heads", 8)), + head_dim=values.get("d", values.get("head_dim", 128)), + block_size=values.get("bs", values.get("block_size", 128)), + block_dim=values.get("bd", values.get("block_dim", 24)), + ) + + +def run_shape(pa, shape, device, iters, warmup, check): + q, k, v, block_table = make_inputs(shape, device, deterministic=check) + ws, tiling, _ = make_launch_config(shape) + if check: + out = pa(q, k, v, block_table, ws, tiling, block_dim=shape.block_dim) + torch.npu.synchronize() + torch.testing.assert_close(out.float(), golden_attention(q, k, v, block_table, shape), rtol=5e-3, atol=2e-2) + ms = time_npu(lambda: pa(q, k, v, block_table, ws, tiling, block_dim=shape.block_dim), iters, warmup) + flops = paged_attention_flops(shape) + bytes_total = tensor_bytes(shape) + perf = tflops(flops, ms) + norm_perf = perf * shape.block_dim + return { + "shape": shape.name, + "batch": shape.batch, + "seq_len": shape.seq_len, + "block_dim": shape.block_dim, + "jit_time_us": f"{ms * 1000:.3f}", + "jit_tflops": f"{perf:.6f}", + "jit_tflops_normalized": f"{norm_perf:.6f}", + "jit_bandwidth_tb_s": f"{tb_per_second(bytes_total, ms):.6f}", + } + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--csv", default="pa_highperf_jit_bench.csv") + parser.add_argument("--iters", type=int, default=NUM_ITERATIONS) + parser.add_argument("--warmup", type=int, default=WARMUP) + parser.add_argument("--device", default="npu:0") + parser.add_argument("--shape", action="append", help="Shape override, e.g. b=2,s=8192 or batch=4,seq=512") + parser.add_argument("--check", action="store_true", help="Run correctness check before timing each shape.") + parser.add_argument("--no-check", action="store_true", help=argparse.SUPPRESS) + args = parser.parse_args() + + torch.npu.set_device(args.device) + shapes = [parse_shape(item) for item in args.shape] if args.shape else DEFAULT_SHAPES + pa = jit_compile_paged_attention(verbose=False) + rows = [] + for shape in shapes: + row = run_shape(pa, shape, args.device, args.iters, args.warmup, args.check and not args.no_check) + rows.append(row) + print( + f"paged_attention_highperf_jit {row['shape']}: {row['jit_time_us']} us/iter, " + f"{row['jit_tflops']} TFLOPS logical, {row['jit_tflops_normalized']} TFLOPS normalized, " + f"{row['jit_bandwidth_tb_s']} TB/s, block_dim={row['block_dim']}" + ) + + fieldnames = [ + "shape", "batch", "seq_len", "block_dim", "jit_time_us", "jit_tflops", + "jit_tflops_normalized", "jit_bandwidth_tb_s" + ] + with open(args.csv, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/paged_attention_highperf/pa_compile_and_run.py b/examples/jit_cpp/paged_attention_highperf/pa_compile_and_run.py new file mode 100644 index 00000000..306df052 --- /dev/null +++ b/examples/jit_cpp/paged_attention_highperf/pa_compile_and_run.py @@ -0,0 +1,142 @@ +#!/usr/bin/python3 +# coding=utf-8 +import math +import os +from dataclasses import dataclass + +import torch +import torch_npu + +from jit_util_pa import jit_compile_paged_attention +from pa_tiling import make_pa_nd_decode_tiling, workspace_sizes + + +@dataclass(frozen=True) +class PaShape: + batch: int + num_heads: int = 32 + num_kv_heads: int = 8 + seq_len: int = 128 + head_dim: int = 128 + block_size: int = 128 + block_dim: int = 24 + dtype: torch.dtype = torch.float16 + + @property + def name(self): + return f"b{self.batch}_h{self.num_heads}_kv{self.num_kv_heads}_s{self.seq_len}_bs{self.block_size}_fp16" + + +def pack_kv_to_paged(k_dense, v_dense, shape: PaShape): + num_blocks = shape.seq_len // shape.block_size + k_page = ( + k_dense.view(shape.batch, shape.seq_len, shape.num_kv_heads, shape.head_dim) + .view(shape.batch, num_blocks, shape.block_size, shape.num_kv_heads, shape.head_dim) + .reshape(shape.batch * num_blocks, shape.block_size, shape.num_kv_heads, shape.head_dim) + .contiguous() + ) + v_page = ( + v_dense.view(shape.batch, shape.seq_len, shape.num_kv_heads, shape.head_dim) + .view(shape.batch, num_blocks, shape.block_size, shape.num_kv_heads, shape.head_dim) + .reshape(shape.batch * num_blocks, shape.block_size, shape.num_kv_heads, shape.head_dim) + .contiguous() + ) + block_table = ( + torch.arange(num_blocks, device=k_dense.device, dtype=torch.int32).unsqueeze(0).expand(shape.batch, -1).clone() + + torch.arange(shape.batch, device=k_dense.device, dtype=torch.int32).unsqueeze(1) * num_blocks + ) + return k_page, v_page, block_table + + +def make_inputs(shape: PaShape = PaShape(batch=1), device="npu:0", deterministic=True): + q = torch.zeros((shape.batch, shape.num_heads, shape.head_dim), device=device, dtype=shape.dtype) + k_dense = torch.zeros( + (shape.batch, shape.seq_len, shape.num_kv_heads * shape.head_dim), device=device, dtype=shape.dtype + ) + if deterministic: + token = torch.arange(shape.seq_len, device=device, dtype=torch.float32).view(1, shape.seq_len, 1, 1) + kv_head = torch.arange(shape.num_kv_heads, device=device, dtype=torch.float32).view(1, 1, shape.num_kv_heads, 1) + dim = torch.arange(shape.head_dim, device=device, dtype=torch.float32).view(1, 1, 1, shape.head_dim) + batch = torch.arange(shape.batch, device=device, dtype=torch.float32).view(shape.batch, 1, 1, 1) + q_head = torch.arange(shape.num_heads, device=device, dtype=torch.float32).view(1, shape.num_heads, 1) + q_dim = torch.arange(shape.head_dim, device=device, dtype=torch.float32).view(1, 1, shape.head_dim) + q_values = (((batch[:, 0, 0, 0].view(shape.batch, 1, 1) * 3 + q_head * 5 + q_dim * 7) + .remainder(251) / 125.0) - 1.0) * 0.02 + k_values = (((batch * 11 + token * 13 + kv_head * 17 + dim * 19).remainder(257) / 128.0) - 1.0) * 0.02 + v_values = (((batch * 13 + token * 17 + kv_head * 31 + dim * 7).remainder(257) / 128.0) - 1.0) * 0.25 + q.copy_(q_values.to(shape.dtype)) + k_dense.copy_(k_values.reshape(shape.batch, shape.seq_len, shape.num_kv_heads * shape.head_dim).to(shape.dtype)) + v_dense = v_values.reshape(shape.batch, shape.seq_len, shape.num_kv_heads * shape.head_dim).to(shape.dtype) + else: + v_dense = torch.zeros_like(k_dense) + k_page, v_page, block_table = pack_kv_to_paged(k_dense, v_dense, shape) + return q, k_page, v_page, block_table + + +def make_launch_config(shape: PaShape, device="cpu"): + scale = 1.0 / math.sqrt(float(shape.head_dim)) + num_blocks = shape.batch * (shape.seq_len // shape.block_size) + max_blocks_per_query = shape.seq_len // shape.block_size + tiling, effective_block_dim = make_pa_nd_decode_tiling( + batch=shape.batch, + kv_seq_lens=[shape.seq_len] * shape.batch, + num_heads=shape.num_heads, + kv_heads=shape.num_kv_heads, + head_dim=shape.head_dim, + head_dim_v=shape.head_dim, + num_blocks=num_blocks, + block_size=shape.block_size, + max_blocks_per_query=max_blocks_per_query, + scale=scale, + block_dim=shape.block_dim, + device=device, + dtype=shape.dtype, + ) + ws = workspace_sizes(shape.batch, shape.num_heads, shape.head_dim, shape.head_dim, shape.block_dim) + return ws, tiling, effective_block_dim + + +def golden_attention(q, k_page, v_page, block_table, shape: PaShape): + heads_per_kv = shape.num_heads // shape.num_kv_heads + scale = 1.0 / math.sqrt(float(shape.head_dim)) + out = torch.empty((shape.batch, shape.num_heads, shape.head_dim), device=v_page.device, dtype=torch.float32) + for batch_idx in range(shape.batch): + blocks = block_table[batch_idx] + keys = k_page[blocks.long()].reshape(shape.seq_len, shape.num_kv_heads, shape.head_dim).float() + values = v_page[blocks.long()].reshape(shape.seq_len, shape.num_kv_heads, shape.head_dim).float() + for head in range(shape.num_heads): + kv_head = head // heads_per_kv + scores = torch.mv(keys[:, kv_head, :], q[batch_idx, head].float()) * scale + probs = torch.softmax(scores, dim=0) + out[batch_idx, head] = torch.mv(values[:, kv_head, :].t(), probs) + return out + + +def golden_uniform(v_page, block_table, shape: PaShape): + heads_per_kv = shape.num_heads // shape.num_kv_heads + out = torch.empty((shape.batch, shape.num_heads, shape.head_dim), device=v_page.device, dtype=torch.float32) + for batch_idx in range(shape.batch): + blocks = block_table[batch_idx] + values = v_page[blocks.long()].reshape(shape.seq_len, shape.num_kv_heads, shape.head_dim).float() + kv_avg = values.mean(dim=0) + for head in range(shape.num_heads): + out[batch_idx, head] = kv_avg[head // heads_per_kv] + return out + + +def main(): + device = os.environ.get("PA_DEVICE", "npu:0") + torch.npu.set_device(device) + shape = PaShape(batch=8, seq_len=4096) + q, k, v, block_table = make_inputs(shape, device) + ws, tiling, _ = make_launch_config(shape) + pa = jit_compile_paged_attention(verbose=False) + out = pa(q, k, v, block_table, ws, tiling, block_dim=shape.block_dim) + torch.npu.synchronize() + ref = golden_attention(q, k, v, block_table, shape) + torch.testing.assert_close(out.float(), ref, rtol=5e-3, atol=2e-2) + print("PTO-ISA paged attention JIT: PASSED") + + +if __name__ == "__main__": + main() diff --git a/examples/jit_cpp/paged_attention_highperf/pa_entry.hpp b/examples/jit_cpp/paged_attention_highperf/pa_entry.hpp new file mode 100644 index 00000000..8d0c29cc --- /dev/null +++ b/examples/jit_cpp/paged_attention_highperf/pa_entry.hpp @@ -0,0 +1,147 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#ifndef PTO_PAGED_ATTENTION_HIGHPERF_ENTRY_HPP +#define PTO_PAGED_ATTENTION_HIGHPERF_ENTRY_HPP + +#include "pa_kernel_impl.hpp" + +static AICORE __attribute__((always_inline)) void paged_attention_mask_body( + __gm__ uint8_t *__restrict__ sync, + uint32_t ptoBlockIdx, + uint32_t ptoBlockNum, + uint32_t ptoSubBlockId, + __gm__ uint8_t *__restrict__ qGm, + __gm__ uint8_t *__restrict__ kGm, + __gm__ uint8_t *__restrict__ vGm, + __gm__ uint8_t *__restrict__ blockTablesGm, + __gm__ uint8_t *__restrict__ maskGm, + __gm__ uint8_t *__restrict__ deqScale1Gm, + __gm__ uint8_t *__restrict__ offset1Gm, + __gm__ uint8_t *__restrict__ deqScale2Gm, + __gm__ uint8_t *__restrict__ offset2Gm, + __gm__ uint8_t *__restrict__ razorOffset, + __gm__ uint8_t *__restrict__ scaleGm, + __gm__ uint8_t *__restrict__ logNGm, + __gm__ uint8_t *__restrict__ eyeGm, + __gm__ uint8_t *__restrict__ oGm, + __gm__ uint8_t *__restrict__ sGm, + __gm__ uint8_t *__restrict__ pGm, + __gm__ uint8_t *__restrict__ oTmpGm, + __gm__ uint8_t *__restrict__ goGm, + __gm__ uint8_t *__restrict__ oCoreTmpGm, + __gm__ uint8_t *__restrict__ lGm, + __gm__ uint8_t *__restrict__ gmK16, + __gm__ uint8_t *__restrict__ gmV16, + __gm__ uint8_t *__restrict__ tilingParaGm) +{ + (void)maskGm; + (void)deqScale1Gm; + (void)offset1Gm; + (void)deqScale2Gm; + (void)offset2Gm; + (void)razorOffset; + (void)scaleGm; + (void)logNGm; + (void)eyeGm; + (void)sGm; + (void)pGm; + (void)oTmpGm; + (void)goGm; + (void)gmK16; + (void)gmV16; + + if (sync != nullptr) { + set_ffts_base_addr(reinterpret_cast(sync)); + } + set_atomic_none(); + set_mask_norm(); + +#ifdef __DAV_C220_CUBE__ + const int64_t workerIdx = static_cast(ptoBlockIdx); + const int64_t workerNum = static_cast(ptoBlockNum); + const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); + if (SupportsPtoPagedAttentionHighPerf(tilingParaGm)) { + RunPtoPagedAttentionCubePipeline(qGm, kGm, vGm, blockTablesGm, sGm, pGm, oTmpGm, tilingParaGm, workerIdx, workerNum); + } else if (SupportsPtoPagedAttentionRawSplitKV(tilingParaGm)) { + RunPtoPagedAttentionCubePipelineSplitKV(qGm, kGm, vGm, blockTablesGm, sGm, pGm, oTmpGm, tilingParaGm, + workerIdx, workerNum); + } else if (ctx.kvSplitCoreNum > 1) { + pipe_barrier(PIPE_ALL); + } else { + pipe_barrier(PIPE_ALL); + } +#elif defined(__DAV_C220_VEC__) + if (SupportsPtoPagedAttentionHighPerf(tilingParaGm)) { + const int64_t workerIdx = static_cast(ptoBlockIdx); + const int64_t workerNum = static_cast(ptoBlockNum); + RunPtoPagedAttentionVecPipeline(oGm, sGm, pGm, oTmpGm, tilingParaGm, workerIdx, workerNum, ptoSubBlockId); + } else { + const int64_t workerIdx = static_cast(ptoBlockIdx) * 2 + static_cast(ptoSubBlockId); + const int64_t workerNum = static_cast(ptoBlockNum) * 2; + const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); + if (SupportsPtoPagedAttentionRawSplitKV(tilingParaGm)) { + RunPtoPagedAttentionVecPipelineSplitKV(oGm, sGm, pGm, oTmpGm, oCoreTmpGm, lGm, tilingParaGm, + static_cast(ptoBlockIdx), static_cast(ptoBlockNum), ptoSubBlockId); + } else if (ctx.kvSplitCoreNum > 1) { + RunPtoPagedAttentionDecodeSplitKV(qGm, kGm, vGm, blockTablesGm, oGm, oCoreTmpGm, lGm, tilingParaGm, + static_cast(ptoBlockIdx), static_cast(ptoBlockNum), ptoSubBlockId); + } else { + RunPtoPagedAttentionDecode(qGm, kGm, vGm, blockTablesGm, oGm, tilingParaGm, workerIdx, workerNum); + } + } +#else + pipe_barrier(PIPE_ALL); +#endif +} + +#ifndef PTO_PA_NO_GLOBAL_ENTRY +extern "C" __global__ AICORE void paged_attention_mask( + __gm__ uint8_t *__restrict__ sync, + __gm__ uint8_t *__restrict__ qGm, + __gm__ uint8_t *__restrict__ kGm, + __gm__ uint8_t *__restrict__ vGm, + __gm__ uint8_t *__restrict__ blockTablesGm, + __gm__ uint8_t *__restrict__ maskGm, + __gm__ uint8_t *__restrict__ deqScale1Gm, + __gm__ uint8_t *__restrict__ offset1Gm, + __gm__ uint8_t *__restrict__ deqScale2Gm, + __gm__ uint8_t *__restrict__ offset2Gm, + __gm__ uint8_t *__restrict__ razorOffset, + __gm__ uint8_t *__restrict__ scaleGm, + __gm__ uint8_t *__restrict__ logNGm, + __gm__ uint8_t *__restrict__ eyeGm, + __gm__ uint8_t *__restrict__ oGm, + __gm__ uint8_t *__restrict__ sGm, + __gm__ uint8_t *__restrict__ pGm, + __gm__ uint8_t *__restrict__ oTmpGm, + __gm__ uint8_t *__restrict__ goGm, + __gm__ uint8_t *__restrict__ oCoreTmpGm, + __gm__ uint8_t *__restrict__ lGm, + __gm__ uint8_t *__restrict__ gmK16, + __gm__ uint8_t *__restrict__ gmV16, + __gm__ uint8_t *__restrict__ tilingParaGm) +{ + const uint32_t ptoBlockIdx = static_cast(get_block_idx()); + const uint32_t ptoBlockNum = static_cast(get_block_num()); +#ifdef __DAV_C220_VEC__ + const uint32_t ptoSubBlockId = static_cast(get_subblockid()); +#else + const uint32_t ptoSubBlockId = 0; +#endif + + paged_attention_mask_body( + sync, ptoBlockIdx, ptoBlockNum, ptoSubBlockId, qGm, kGm, vGm, blockTablesGm, maskGm, deqScale1Gm, offset1Gm, + deqScale2Gm, offset2Gm, razorOffset, scaleGm, logNGm, eyeGm, oGm, sGm, pGm, oTmpGm, goGm, oCoreTmpGm, + lGm, gmK16, gmV16, tilingParaGm); +} +#endif + +#endif diff --git a/examples/jit_cpp/paged_attention_highperf/pa_kernel.cpp b/examples/jit_cpp/paged_attention_highperf/pa_kernel.cpp new file mode 100644 index 00000000..65b905ce --- /dev/null +++ b/examples/jit_cpp/paged_attention_highperf/pa_kernel.cpp @@ -0,0 +1,66 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include + +#include +#include "runtime/rt.h" + +#include "pa_entry.hpp" + +extern "C" void call_kernel( + void *stream, + uint8_t *qGm, + uint8_t *kGm, + uint8_t *vGm, + uint8_t *blockTablesGm, + uint8_t *oGm, + uint8_t *sGm, + uint8_t *pGm, + uint8_t *oTmpGm, + uint8_t *goGm, + uint8_t *oCoreTmpGm, + uint8_t *lGm, + uint8_t *gmK16, + uint8_t *gmV16, + uint8_t *tilingParaGm, + uint8_t *nullGm, + uint32_t blockDim) +{ + uint64_t ffts = 0; + uint32_t fftsLen = 0; + rtGetC2cCtrlAddr(&ffts, &fftsLen); + + paged_attention_mask<<>>( + reinterpret_cast<__gm__ uint8_t *>(ffts), + qGm, + kGm, + vGm, + blockTablesGm, + nullGm, + nullGm, + nullGm, + nullGm, + nullGm, + nullGm, + nullGm, + nullGm, + nullGm, + oGm, + sGm, + pGm, + oTmpGm, + goGm, + oCoreTmpGm, + lGm, + gmK16, + gmV16, + tilingParaGm); +} diff --git a/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp b/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp new file mode 100644 index 00000000..9c401817 --- /dev/null +++ b/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp @@ -0,0 +1,1797 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#ifndef PTO_PAGED_ATTENTION_HIGHPERF_IMPL_HPP +#define PTO_PAGED_ATTENTION_HIGHPERF_IMPL_HPP + +#include +#include + +#include "pa_tiling_struct.hpp" + +using namespace pto; + +constexpr int32_t TILING_BATCH = 0; +constexpr int32_t TILING_NUMHEADS = 1; +constexpr int32_t TILING_HEADDIM = 2; +constexpr int32_t TILING_BLOCKSIZE = 4; +constexpr int32_t TILING_MAXBLOCKS = 5; +constexpr int32_t TILING_KVHEADS = 7; +constexpr int32_t TILING_FORMER_BATCH = 8; +constexpr int32_t TILING_FORMER_HEAD = 9; +constexpr int32_t TILING_TAIL_BATCH = 10; +constexpr int32_t TILING_TAIL_HEAD = 11; +constexpr int32_t TILING_HEADNUM_MOVE = 12; +constexpr int32_t TILING_KEY = 16; +constexpr int32_t TILING_HEADSIZE = 17; +constexpr int32_t TILING_PARASIZE = 18; +constexpr int32_t TILING_GROUPNUM = 19; +constexpr int32_t TILING_FORMER_GROUP_MOVE = 20; +constexpr int32_t TILING_TAIL_GROUP_MOVE = 21; +constexpr int32_t TILING_MAX_KVSEQLEN = 22; +constexpr int32_t TILING_KVSPLIT = 23; +constexpr int32_t TILING_KVCORENUM = 24; +constexpr int32_t TILING_BLOCKSIZE_CALC = 25; +constexpr int32_t TILING_DECODER_BS = 28; +constexpr int32_t TILING_HEADDIM_V = 29; + +constexpr int32_t kParaKvSeqLen = 1; +constexpr int32_t kParaBatchIndex = 13; + +AICORE inline int32_t LoadTilingI32(__gm__ uint8_t *tiling, int32_t index) +{ + return *(reinterpret_cast<__gm__ int32_t *>(tiling) + index); +} + +AICORE inline int32_t LoadBlockTable(__gm__ uint8_t *blockTablesGm, int64_t offset) +{ + return *(reinterpret_cast<__gm__ int32_t *>(blockTablesGm) + offset); +} + +AICORE inline float LoadFp16(__gm__ uint8_t *gm, int64_t offset) +{ + __gm__ half *ptr = reinterpret_cast<__gm__ half *>(gm); + return static_cast(ptr[offset]); +} + +AICORE inline void StoreOutputFp16(__gm__ uint8_t *oGm, int64_t offset, float value) +{ + __gm__ half *out = reinterpret_cast<__gm__ half *>(oGm); + out[offset] = static_cast(value); +} + +AICORE inline float LoadScale(__gm__ uint8_t *tiling) +{ + union { + int32_t i; + float f; + } scale; + scale.i = LoadTilingI32(tiling, 6); + return scale.f; +} + +struct PaTilingContext { + int32_t batch; + int32_t decoderBatch; + int32_t numHeads; + int32_t kvHeads; + int32_t headDim; + int32_t headDimV; + int32_t blockSize; + int32_t maxBlocksPerQuery; + int32_t maxKvSeqLen; + int32_t formerBatch; + int32_t formerHeadSplit; + int32_t tailBatch; + int32_t tailHeadSplit; + int32_t headNumMove; + int32_t groupNum; + int32_t formerGroupMove; + int32_t tailGroupMove; + int32_t kvSplitPerCore; + int32_t kvSplitCoreNum; + int32_t blockSizeCalc; + int32_t headSize; + int32_t paraSize; + float scale; +}; + +AICORE inline PaTilingContext LoadPaTilingContext(__gm__ uint8_t *tiling) +{ + PaTilingContext ctx{}; + ctx.batch = LoadTilingI32(tiling, TILING_BATCH); + ctx.decoderBatch = LoadTilingI32(tiling, TILING_DECODER_BS); + ctx.numHeads = LoadTilingI32(tiling, TILING_NUMHEADS); + ctx.kvHeads = LoadTilingI32(tiling, TILING_KVHEADS); + ctx.headDim = LoadTilingI32(tiling, TILING_HEADDIM); + ctx.headDimV = LoadTilingI32(tiling, TILING_HEADDIM_V); + ctx.blockSize = LoadTilingI32(tiling, TILING_BLOCKSIZE); + ctx.maxBlocksPerQuery = LoadTilingI32(tiling, TILING_MAXBLOCKS); + ctx.maxKvSeqLen = LoadTilingI32(tiling, TILING_MAX_KVSEQLEN); + ctx.formerBatch = LoadTilingI32(tiling, TILING_FORMER_BATCH); + ctx.formerHeadSplit = LoadTilingI32(tiling, TILING_FORMER_HEAD); + ctx.tailBatch = LoadTilingI32(tiling, TILING_TAIL_BATCH); + ctx.tailHeadSplit = LoadTilingI32(tiling, TILING_TAIL_HEAD); + ctx.headNumMove = LoadTilingI32(tiling, TILING_HEADNUM_MOVE); + ctx.groupNum = LoadTilingI32(tiling, TILING_GROUPNUM); + ctx.formerGroupMove = LoadTilingI32(tiling, TILING_FORMER_GROUP_MOVE); + ctx.tailGroupMove = LoadTilingI32(tiling, TILING_TAIL_GROUP_MOVE); + ctx.kvSplitPerCore = LoadTilingI32(tiling, TILING_KVSPLIT); + ctx.kvSplitCoreNum = LoadTilingI32(tiling, TILING_KVCORENUM); + ctx.blockSizeCalc = LoadTilingI32(tiling, TILING_BLOCKSIZE_CALC); + ctx.headSize = LoadTilingI32(tiling, TILING_HEADSIZE); + ctx.paraSize = LoadTilingI32(tiling, TILING_PARASIZE); + ctx.scale = LoadScale(tiling); + return ctx; +} + +template +AICORE inline float PtoExpScalar(ScalarTile &tile, float value) +{ + tile.data()[0] = value; + set_flag(PIPE_S, PIPE_V, EVENT_ID2); + wait_flag(PIPE_S, PIPE_V, EVENT_ID2); + TEXP(tile, tile); + pipe_barrier(PIPE_V); + set_flag(PIPE_V, PIPE_S, EVENT_ID3); + wait_flag(PIPE_V, PIPE_S, EVENT_ID3); + return tile.data()[0]; +} + +template +AICORE inline float PtoLogScalar(ScalarTile &tile, float value) +{ + if (value <= 0.0f) { + return -3.4028234663852886e38f; + } + tile.data()[0] = value; + set_flag(PIPE_S, PIPE_V, EVENT_ID2); + wait_flag(PIPE_S, PIPE_V, EVENT_ID2); + TLOG(tile, tile); + pipe_barrier(PIPE_V); + set_flag(PIPE_V, PIPE_S, EVENT_ID3); + wait_flag(PIPE_V, PIPE_S, EVENT_ID3); + return tile.data()[0]; +} + +AICORE inline float LoadPagedKByBlock( + __gm__ uint8_t *kGm, + int32_t blockId, + int32_t offsetInBlock, + int32_t blockSize, + int32_t kvHeads, + int32_t kvHead, + int32_t headDim, + int32_t dim) +{ + const int64_t offset = (((static_cast(blockId) * blockSize + offsetInBlock) * kvHeads + kvHead) * headDim + dim); + return LoadFp16(kGm, offset); +} + +AICORE inline float LoadPagedVByBlock( + __gm__ uint8_t *vGm, + int32_t blockId, + int32_t offsetInBlock, + int32_t blockSize, + int32_t kvHeads, + int32_t kvHead, + int32_t headDim, + int32_t dim) +{ + const int64_t offset = (((static_cast(blockId) * blockSize + offsetInBlock) * kvHeads + kvHead) * headDim + dim); + return LoadFp16(vGm, offset); +} + +AICORE inline void ResolvePagedPosition( + __gm__ uint8_t *blockTablesGm, + int32_t batchIndex, + int32_t maxBlocksPerQuery, + int32_t pos, + int32_t blockSize, + int32_t &blockId, + int32_t &offsetInBlock) +{ + const int32_t tableCol = pos / blockSize; + offsetInBlock = pos - tableCol * blockSize; + blockId = LoadBlockTable(blockTablesGm, static_cast(batchIndex) * maxBlocksPerQuery + tableCol); +} + +AICORE inline float ComputeScoreByBlock( + const float *qValues, + __gm__ uint8_t *kGm, + int32_t blockId, + int32_t offsetInBlock, + int32_t blockSize, + int32_t kvHead, + int32_t headDim, + int32_t kvHeads, + float scale) +{ + float score = 0.0f; + for (int32_t dim = 0; dim < headDim; ++dim) { + const float k = LoadPagedKByBlock(kGm, blockId, offsetInBlock, blockSize, kvHeads, kvHead, headDim, dim); + score += qValues[dim] * k; + } + return score * scale; +} + + + +constexpr int32_t PA_TILE_TOKENS = 128; +constexpr uint8_t PA_QK_FIFO_FLAG = 0; +constexpr uint8_t PA_P_FIFO_FLAG = 2; +constexpr uint8_t PA_PV_FIFO_FLAG = 4; +constexpr uint32_t PA_FIFO_DEPTH = 2; +constexpr uint8_t PTO_PA_REDUCE_READY_DECODER = static_cast(SYNC_AIV_ONLY_ALL); +constexpr uint8_t PTO_PA_RAW_QK_READY = 0; +constexpr uint8_t PTO_PA_RAW_QK_FREE = 2; +constexpr uint8_t PTO_PA_RAW_P_READY = 4; +constexpr uint8_t PTO_PA_RAW_P_FREE = 6; +constexpr uint8_t PTO_PA_RAW_PV_READY = 8; +constexpr uint8_t PTO_PA_RAW_PV_FREE = 10; + +AICORE inline uint8_t PtoPaSlotFlag(uint8_t baseFlag, uint8_t slot) +{ + return static_cast(baseFlag + slot); +} + +AICORE inline uint16_t PtoPaGetFftsMsg(uint16_t mode, uint16_t eventId, uint16_t baseConst = 0x1) +{ + return ((baseConst & 0xf) + ((mode & 0x3) << 4) + ((eventId & 0xf) << 8)); +} + +AICORE inline void PtoPaSignalFromCube(uint8_t flagId) +{ + pipe_barrier(PIPE_ALL); + ffts_cross_core_sync(PIPE_FIX, PtoPaGetFftsMsg(0x2, flagId)); +} + +AICORE inline void PtoPaSignalFromVec(uint8_t flagId) +{ + pipe_barrier(PIPE_ALL); + ffts_cross_core_sync(PIPE_MTE3, PtoPaGetFftsMsg(0x2, flagId)); +} + +AICORE inline void PtoPaSignalFreeFromVec(uint8_t flagId) +{ + pipe_barrier(PIPE_ALL); + ffts_cross_core_sync(PIPE_MTE3, PtoPaGetFftsMsg(0x2, flagId)); +} + +AICORE inline void PtoPaSignalFreeFromCube(uint8_t flagId) +{ + pipe_barrier(PIPE_ALL); + ffts_cross_core_sync(PIPE_FIX, PtoPaGetFftsMsg(0x2, flagId)); +} + +// Real (un-sorted) batch index lives at +8 inside the per-batch para block; +13 (kParaBatchIndex) +// holds the sorted/remap slot. The CCE reference always double-indirects: read the sorted slot at +// +13, re-derive the para base, then read the real batch at +8. See pa_kernel.cce:523-525. +constexpr int32_t kParaRealBatchIndex = 8; + +AICORE inline int32_t ResolveSortedParaBase(__gm__ uint8_t *tiling, const PaTilingContext &ctx, int32_t batchSlot) +{ + const int32_t paraBase = ctx.headSize + batchSlot * ctx.paraSize; + const int32_t sortedBatch = LoadTilingI32(tiling, paraBase + kParaBatchIndex); + return ctx.headSize + sortedBatch * ctx.paraSize; +} + +// One-time sticky SPR setup mirroring the CCE reference SetArgs (pa_kernel.cce:441-444). +AICORE inline void PtoPaInitCoreState() +{ +#if defined(__DAV_C220_CUBE__) + set_padding(0); + set_nd_para(1ULL); +#endif + set_atomic_none(); + set_mask_norm(); +} + +AICORE inline bool SupportsPtoPagedAttentionHighPerf(__gm__ uint8_t *tilingParaGm) +{ + const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); + if (ctx.headDim != PA_TILE_TOKENS || ctx.headDimV != PA_TILE_TOKENS || ctx.blockSize != PA_TILE_TOKENS) { + return false; + } + if (ctx.batch <= 0 || ctx.numHeads <= 0 || ctx.kvHeads <= 0 || ctx.numHeads % ctx.kvHeads != 0) { + return false; + } + if (ctx.kvSplitCoreNum > 1) { + return false; + } + return true; +} + +AICORE inline bool SupportsPtoPagedAttentionRawSplitKV(__gm__ uint8_t *tilingParaGm) +{ + const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); + if (ctx.headDim != PA_TILE_TOKENS || ctx.headDimV != PA_TILE_TOKENS || ctx.blockSize != PA_TILE_TOKENS) { + return false; + } + if (ctx.batch <= 0 || ctx.numHeads <= 0 || ctx.kvHeads <= 0 || ctx.numHeads % ctx.kvHeads != 0) { + return false; + } + if (ctx.kvSplitCoreNum <= 1) { + return false; + } + const int32_t headsPerKv = ctx.numHeads / ctx.kvHeads; + const int32_t formerHeadSplit = ctx.formerHeadSplit > 0 ? ctx.formerHeadSplit : 1; + if (headsPerKv != 4 || formerHeadSplit % headsPerKv != 0 || formerHeadSplit < 16 || formerHeadSplit % 16 != 0) { + return false; + } + return ctx.kvSplitPerCore <= 8192; +} + +AICORE inline void DdrBarrierBeforePtoFfts() +{ +#if defined(__CPU_SIM) + dsb(0); +#else + dsb(DSB_DDR); +#endif + pipe_barrier(PIPE_ALL); +} + +AICORE inline void DdrFenceBeforePtoAivReduce() +{ +#if defined(__CPU_SIM) + dsb(0); +#else + dsb(DSB_DDR); +#endif +} + +AICORE inline void PtoPaSetFloatVectorMask(uint32_t len) +{ + set_mask_norm(); + constexpr uint32_t kFloatVectorSize = 64; + if (len >= kFloatVectorSize) { + set_vector_mask(static_cast(-1), static_cast(-1)); + return; + } + uint64_t mask = 0; + for (uint32_t i = 0; i < len; ++i) { + mask |= 1ULL << i; + } + set_vector_mask(0, mask); +} + +AICORE inline void PtoPaStageSync() +{ + SYNCALL(); +} + +#ifdef __DAV_C220_VEC__ +template +__tf__ AICORE void PtoPaConvF32ToF16Raw(typename DstTileData::TileDType __out__ dst, + typename SrcTileData::TileDType __in__ src, uint8_t repeat) +{ + __ubuf__ half *dstAddr = reinterpret_cast<__ubuf__ half *>(__cce_get_tile_ptr(dst)); + __ubuf__ float *srcAddr = reinterpret_cast<__ubuf__ float *>(__cce_get_tile_ptr(src)); + vconv_f322f16(dstAddr, srcAddr, repeat, 1, 1, 4, 8); +} + +template +AICORE inline void PtoPaConvF32ToF16(DstTileData &dst, SrcTileData &src, uint8_t repeat) +{ + PtoPaConvF32ToF16Raw(dst.data(), src.data(), repeat); +} +#endif + +#ifdef __DAV_C220_CUBE__ +template +__tf__ AICORE void PtoPaLoadNzHeadGroupToCaRaw(typename DstTileData::TileDType __out__ dst, + typename SrcTileData::TileDType __in__ src, uint32_t headGroupBase, uint16_t repeatTimes) +{ + using DataType = typename SrcTileData::DType; + static constexpr uint32_t kC0 = 16; + __ca__ DataType *dstAddr = reinterpret_cast<__ca__ DataType *>(__cce_get_tile_ptr(dst)); + __cbuf__ DataType *srcAddr = reinterpret_cast<__cbuf__ DataType *>(__cce_get_tile_ptr(src)) + headGroupBase * kC0; + load_cbuf_to_ca(dstAddr, srcAddr, 0, repeatTimes, 1, 0, 0, false, false, addr_cal_mode_t(0)); +} + +template +AICORE inline void PtoPaLoadNzHeadGroupToCa(DstTileData &dst, SrcTileData &src, uint32_t headGroupBase, + uint16_t repeatTimes) +{ + PtoPaLoadNzHeadGroupToCaRaw(dst.data(), src.data(), headGroupBase, repeatTimes); +} + +template +__tf__ AICORE void PtoPaLoadCbufToCbRaw(typename DstTileData::TileDType __out__ dst, + typename SrcTileData::TileDType __in__ src, uint32_t srcElementOffset, uint16_t repeatTimes, + uint16_t srcStride) +{ + using DataType = typename SrcTileData::DType; + __cb__ DataType *dstAddr = reinterpret_cast<__cb__ DataType *>(__cce_get_tile_ptr(dst)); + __cbuf__ DataType *srcAddr = reinterpret_cast<__cbuf__ DataType *>(__cce_get_tile_ptr(src)) + srcElementOffset; + load_cbuf_to_cb(dstAddr, srcAddr, 0, repeatTimes, srcStride, 0, 0, false, addr_cal_mode_t(0)); +} + +template +AICORE inline void PtoPaLoadCbufToCbRaw(DstTileData &dst, SrcTileData &src, uint32_t srcElementOffset, + uint16_t repeatTimes, uint16_t srcStride) +{ + PtoPaLoadCbufToCbRaw(dst.data(), src.data(), srcElementOffset, repeatTimes, srcStride); +} + +template +__tf__ AICORE void PtoPaLoadCbufToCbTranspose128Raw(typename DstTileData::TileDType __out__ dst, + typename SrcTileData::TileDType __in__ src) +{ + using DataType = typename SrcTileData::DType; + __cb__ DataType *dstAddr = reinterpret_cast<__cb__ DataType *>(__cce_get_tile_ptr(dst)); + __cbuf__ DataType *srcAddr = reinterpret_cast<__cbuf__ DataType *>(__cce_get_tile_ptr(src)); + constexpr uint32_t kBlock = 16; + constexpr uint32_t kRows = 128; + constexpr uint32_t kCols = 128; + for (uint32_t idx = 0; idx < kCols / kBlock; ++idx) { + load_cbuf_to_cb_transpose(dstAddr + idx * kRows * kBlock, srcAddr + idx * kBlock * kBlock, 0, + kRows / kBlock, kCols / kBlock, 0, addr_cal_mode_t(0), 0); + } +} + +template +AICORE inline void PtoPaLoadCbufToCbTranspose128Raw(DstTileData &dst, SrcTileData &src) +{ + PtoPaLoadCbufToCbTranspose128Raw(dst.data(), src.data()); +} +#endif + +#ifdef __DAV_C220_CUBE__ +AICORE inline void RunPtoPagedAttentionCubePipeline(__gm__ uint8_t *qGm, __gm__ uint8_t *kGm, + __gm__ uint8_t *vGm, __gm__ uint8_t *blockTablesGm, __gm__ uint8_t *sGm, __gm__ uint8_t *pGm, + __gm__ uint8_t *oTmpGm, __gm__ uint8_t *tilingParaGm, int64_t workerIdx, int64_t workerNum) +{ + constexpr int32_t kHeadDim = PA_TILE_TOKENS; + constexpr int32_t kTileTokens = PA_TILE_TOKENS; + constexpr int32_t kM = 16; + constexpr int32_t kN = kTileTokens; + constexpr int32_t kK = 256; + + PtoPaInitCoreState(); + const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); + if (!SupportsPtoPagedAttentionHighPerf(tilingParaGm) || workerIdx < 0 || workerNum <= 0) { + pipe_barrier(PIPE_ALL); + return; + } + + const int32_t effectiveBatch = ctx.decoderBatch > 0 ? ctx.decoderBatch : ctx.batch; + const int32_t maxBlocksPerQuery = ctx.maxBlocksPerQuery > 0 ? ctx.maxBlocksPerQuery : + (ctx.maxKvSeqLen + ctx.blockSize - 1) / ctx.blockSize; + const int32_t headsPerKv = ctx.numHeads / ctx.kvHeads; + const int64_t totalRows = static_cast(effectiveBatch) * ctx.numHeads; + + using QKPipe = TPipe; + using PPipe = TPipe; + using PVPipe = TPipe; + using QGlobal = GlobalTensor, Stride>; + using KGlobal = GlobalTensor, + Stride<1, 1, 1, 1, 8 * kHeadDim>, Layout::DN>; + using VGlobal = GlobalTensor, + Stride>; + using PGlobal = GlobalTensor, + Stride>; + using ScoreGlobal = GlobalTensor, + Stride>; + using OTmpGlobal = GlobalTensor, + Stride>; + + using QMatTile = Tile; + using KMatTile = Tile; + using PMatTile = Tile; + using VMatTile = Tile; + using LeftQTile = TileLeft; + using LeftPTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + QMatTile qMatTile; + KMatTile kMatTile; + PMatTile pMatTile; + VMatTile vMatTile; + LeftQTile qLeftTile; + LeftPTile pLeftTile; + RightTile rightTile; + AccTile accTile; + TASSIGN(qMatTile, 0x00000); + TASSIGN(kMatTile, 0x20000); + TASSIGN(pMatTile, 0x00000); + TASSIGN(vMatTile, 0x20000); + TASSIGN(qLeftTile, 0x00000); + TASSIGN(pLeftTile, 0x00000); + TASSIGN(rightTile, 0x00000); + TASSIGN(accTile, 0x00000); + + __gm__ uint8_t *scoreBase = sGm + workerIdx * QKPipe::RingFiFo::SLOT_SIZE * QKPipe::RingFiFo::SLOT_NUM; + __gm__ uint8_t *probBase = pGm + workerIdx * PPipe::RingFiFo::SLOT_SIZE * PPipe::RingFiFo::SLOT_NUM; + __gm__ uint8_t *outBase = oTmpGm + workerIdx * PVPipe::RingFiFo::SLOT_SIZE * PVPipe::RingFiFo::SLOT_NUM; + QKPipe qkPipe(reinterpret_cast<__gm__ void *>(scoreBase), 0, 0); + PPipe pPipe(reinterpret_cast<__gm__ void *>(probBase), 0, 0); + PVPipe pvPipe(reinterpret_cast<__gm__ void *>(outBase), 0, 0); + + for (int64_t row = workerIdx; row < totalRows; row += workerNum) { + const int32_t head = static_cast(row % ctx.numHeads); + const int32_t batchSlot = static_cast(row / ctx.numHeads); + const int32_t paraBase = ResolveSortedParaBase(tilingParaGm, ctx, batchSlot); + const int32_t kvSeqLen = LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); + const int32_t batchIndex = LoadTilingI32(tilingParaGm, paraBase + kParaRealBatchIndex); + const int32_t kvHead = head / headsPerKv; + const int32_t tileCount = (kvSeqLen + kTileTokens - 1) / kTileTokens; + const int64_t qBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; + QGlobal qGlobal(reinterpret_cast<__gm__ half *>(qGm) + qBase); + TLOAD(qMatTile, qGlobal); + pipe_barrier(PIPE_ALL); + TEXTRACT(qLeftTile, qMatTile, 0, 0); + pipe_barrier(PIPE_ALL); + + for (int32_t tile = 0; tile < tileCount; ++tile) { + const int32_t blockId = LoadBlockTable(blockTablesGm, static_cast(batchIndex) * maxBlocksPerQuery + tile); + const int64_t kvBase = (static_cast(blockId) * ctx.blockSize * ctx.kvHeads + kvHead) * ctx.headDim; + + + KGlobal kGlobal(reinterpret_cast<__gm__ half *>(kGm) + kvBase); + TLOAD(kMatTile, kGlobal); + pipe_barrier(PIPE_ALL); + TMOV(rightTile, kMatTile); + pipe_barrier(PIPE_ALL); + TGEMV(accTile, qLeftTile, rightTile); + pipe_barrier(PIPE_ALL); + DdrBarrierBeforePtoFfts(); + TPUSH(qkPipe, accTile); + + TPOP(pPipe, pMatTile); + VGlobal vGlobal(reinterpret_cast<__gm__ half *>(vGm) + kvBase); + + TLOAD(vMatTile, vGlobal); + pipe_barrier(PIPE_ALL); + TEXTRACT(pLeftTile, pMatTile, 0, 0); + TMOV(rightTile, vMatTile); + pipe_barrier(PIPE_ALL); + TGEMV(accTile, pLeftTile, rightTile); + pipe_barrier(PIPE_ALL); + DdrBarrierBeforePtoFfts(); + TPUSH(pvPipe, accTile); + } + } + pipe_barrier(PIPE_ALL); +} +#endif + +#ifdef __DAV_C220_VEC__ +AICORE inline void RunPtoPagedAttentionVecPipeline(__gm__ uint8_t *oGm, __gm__ uint8_t *sGm, + __gm__ uint8_t *pGm, __gm__ uint8_t *oTmpGm, __gm__ uint8_t *tilingParaGm, int64_t workerIdx, int64_t workerNum, + uint32_t subBlockId) +{ + constexpr int32_t kHeadDim = PA_TILE_TOKENS; + constexpr int32_t kTileTokens = PA_TILE_TOKENS; + if (!SupportsPtoPagedAttentionHighPerf(tilingParaGm) || workerIdx < 0 || workerNum <= 0) { + pipe_barrier(PIPE_ALL); + return; + } + + PtoPaInitCoreState(); + const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); + const int32_t effectiveBatch = ctx.decoderBatch > 0 ? ctx.decoderBatch : ctx.batch; + const int64_t totalRows = static_cast(effectiveBatch) * ctx.numHeads; + + using QKPipe = TPipe; + using PPipe = TPipe; + using PVPipe = TPipe; + using VecFloat128 = Tile; + using VecHalf128 = Tile; + using VecHalf256 = Tile; + using VecFloat8 = Tile; + using GlobalFloat128 = GlobalTensor, Stride>; + using GlobalHalf128 = GlobalTensor, Stride>; + + VecFloat128 weightedTile; + VecFloat128 scoreTile; + VecFloat128 pvTile; + VecHalf256 probTile; + VecHalf128 outHalfTile; + VecFloat8 scalarMathTile; + TASSIGN(weightedTile, 0x0000); + TASSIGN(scoreTile, 0x0800); + TASSIGN(pvTile, 0x1000); + TASSIGN(probTile, 0x1800); + TASSIGN(outHalfTile, 0x2000); + TASSIGN(scalarMathTile, 0x2800); + + __gm__ uint8_t *scoreBase = sGm + workerIdx * QKPipe::RingFiFo::SLOT_SIZE * QKPipe::RingFiFo::SLOT_NUM; + __gm__ uint8_t *probBase = pGm + workerIdx * PPipe::RingFiFo::SLOT_SIZE * PPipe::RingFiFo::SLOT_NUM; + __gm__ uint8_t *outTmpBase = oTmpGm + workerIdx * PVPipe::RingFiFo::SLOT_SIZE * PVPipe::RingFiFo::SLOT_NUM; + QKPipe qkPipe(reinterpret_cast<__gm__ void *>(scoreBase), 0, 0); + PPipe pPipe(reinterpret_cast<__gm__ void *>(probBase), 0, 0); + PVPipe pvPipe(reinterpret_cast<__gm__ void *>(outTmpBase), 0, 0); + + for (int64_t row = workerIdx; row < totalRows; row += workerNum) { + const int32_t head = static_cast(row % ctx.numHeads); + const int32_t batchSlot = static_cast(row / ctx.numHeads); + const int32_t paraBase = ResolveSortedParaBase(tilingParaGm, ctx, batchSlot); + const int32_t kvSeqLen = LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); + const int32_t batchIndex = LoadTilingI32(tilingParaGm, paraBase + kParaRealBatchIndex); + const int32_t tileCount = (kvSeqLen + kTileTokens - 1) / kTileTokens; + const bool doWork = subBlockId == 0; + float maxScore = -3.4028234663852886e38f; + float sumExp = 0.0f; + TEXPANDS(weightedTile, 0.0f); + pipe_barrier(PIPE_ALL); + + for (int32_t tile = 0; tile < tileCount; ++tile) { + const int32_t validTokens = ((tile + 1) * kTileTokens <= kvSeqLen) ? kTileTokens : (kvSeqLen - tile * kTileTokens); + TPOP(qkPipe, scoreTile); + float tileMax = -3.4028234663852886e38f; + for (int32_t pos = 0; pos < validTokens; ++pos) { + const float score = scoreTile.data()[pos] * ctx.scale; + tileMax = score > tileMax ? score : tileMax; + } + const float newMax = tileMax > maxScore ? tileMax : maxScore; + const float oldScale = (tile == 0) ? 0.0f : PtoExpScalar(scalarMathTile, maxScore - newMax); + float tileSum = 0.0f; + TEXPANDS(probTile, static_cast(0.0)); + for (int32_t pos = 0; pos < kTileTokens; ++pos) { + float prob = 0.0f; + if (pos < validTokens) { + prob = PtoExpScalar(scalarMathTile, scoreTile.data()[pos] * ctx.scale - newMax); + tileSum += prob; + } + probTile.data()[pos] = static_cast(prob); + } + sumExp = sumExp * oldScale + tileSum; + TMULS(weightedTile, weightedTile, oldScale); + pipe_barrier(PIPE_ALL); + maxScore = newMax; + DdrBarrierBeforePtoFfts(); + TPUSH(pPipe, probTile); + + TPOP(pvPipe, pvTile); + if (doWork) { + pipe_barrier(PIPE_ALL); + TAXPY(weightedTile, pvTile, 1.0f); + pipe_barrier(PIPE_ALL); + } + } + + if (!doWork) { + continue; + } + const float invSum = sumExp > 0.0f ? 1.0f / sumExp : 0.0f; + const int64_t outBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; + GlobalHalf128 outGlobal(reinterpret_cast<__gm__ half *>(oGm) + outBase); + TMULS(weightedTile, weightedTile, invSum); + pipe_barrier(PIPE_ALL); + TCVT(outHalfTile, weightedTile, RoundMode::CAST_RINT); + pipe_barrier(PIPE_ALL); + TSTORE(outGlobal, outHalfTile); + pipe_barrier(PIPE_ALL); + } + pipe_barrier(PIPE_ALL); +} +#endif + + +AICORE inline uint64_t LoadTilingOffset64(__gm__ uint8_t *tiling, int32_t base, int32_t highIdx, int32_t lowIdx) +{ + const uint32_t high = static_cast(LoadTilingI32(tiling, base + highIdx)); + const uint32_t low = static_cast(LoadTilingI32(tiling, base + lowIdx)); + return (static_cast(high) << 32) | static_cast(low); +} + +#ifdef __DAV_C220_CUBE__ +AICORE inline void RunPtoPagedAttentionCubePipelineSplitKV(__gm__ uint8_t *qGm, __gm__ uint8_t *kGm, + __gm__ uint8_t *vGm, __gm__ uint8_t *blockTablesGm, __gm__ uint8_t *sGm, __gm__ uint8_t *pGm, + __gm__ uint8_t *oTmpGm, __gm__ uint8_t *tilingParaGm, int64_t workerIdx, int64_t workerNum) +{ + constexpr int32_t kHeadDim = PA_TILE_TOKENS; + constexpr int32_t kTileTokens = PA_TILE_TOKENS; + constexpr int32_t kM = 16; + constexpr int32_t kMValid = 4; + constexpr int32_t kN = kTileTokens; + constexpr int32_t kK = 256; + constexpr int32_t kHeadGroup = 16; + + PtoPaInitCoreState(); + const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); + if (workerIdx < 0 || workerNum <= 0 || ctx.headDim != kHeadDim || ctx.headDimV != kHeadDim || + ctx.blockSize != kTileTokens || ctx.kvSplitCoreNum <= 1 || ctx.numHeads <= 0 || ctx.kvHeads <= 0 || + ctx.numHeads % ctx.kvHeads != 0) { + pipe_barrier(PIPE_ALL); + return; + } + + const int32_t maxBlocksPerQuery = ctx.maxBlocksPerQuery > 0 ? ctx.maxBlocksPerQuery : + (ctx.maxKvSeqLen + ctx.blockSize - 1) / ctx.blockSize; + const int32_t headsPerKv = ctx.numHeads / ctx.kvHeads; + const int32_t formerHeadSplit = ctx.formerHeadSplit > 0 ? ctx.formerHeadSplit : 1; + const int32_t maxHeadGroups = (formerHeadSplit + kHeadGroup - 1) / kHeadGroup; + const int32_t corePerBatch = (ctx.numHeads + formerHeadSplit - 1) / formerHeadSplit; + const int64_t processNum = static_cast(ctx.formerBatch) * corePerBatch * ctx.kvSplitCoreNum; + + using QGlobal = GlobalTensor, + Stride, Layout::ND>; + using KGlobal = GlobalTensor, + Stride<1, 1, 1, 1, 8 * kHeadDim>, Layout::DN>; + using VGlobal = GlobalTensor, + Stride>; + + using QMatTile = Tile; + using KMatTile = Tile; + using PMatTile = Tile; + using VMatTile = Tile; + using LeftQTile = TileLeft; + using LeftPTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + QMatTile qMatTile; + KMatTile kMatTile; + PMatTile pMatTile; + VMatTile vMatTile; + LeftQTile qLeftTile; + LeftPTile pLeftTile; + RightTile rightTile; + AccTile accTile; + TASSIGN(qMatTile, 0x00000); + TASSIGN(kMatTile, 0x20000); + TASSIGN(pMatTile, 0x00000); + TASSIGN(vMatTile, 0x20000); + TASSIGN(qLeftTile, 0x00000); + TASSIGN(pLeftTile, 0x00000); + TASSIGN(rightTile, 0x00000); + TASSIGN(accTile, 0x00000); + + using ScoreGlobal = GlobalTensor, + Stride>; + using ProbGlobal = GlobalTensor, + Stride, Layout::ND>; + using OutGlobal = GlobalTensor, + Stride>; + + constexpr int64_t scoreHeadBytes = kMValid * kTileTokens * sizeof(float); + constexpr int64_t probHeadBytes = 256 * sizeof(half); + constexpr int64_t outHeadBytes = kMValid * kHeadDim * sizeof(float); + constexpr int64_t scoreGroupBytes = kHeadGroup * scoreHeadBytes; + constexpr int64_t probGroupBytes = kHeadGroup * probHeadBytes; + constexpr int64_t outGroupBytes = kHeadGroup * outHeadBytes; + const int64_t scoreSlotBytes = static_cast(maxHeadGroups) * scoreGroupBytes; + const int64_t probSlotBytes = static_cast(maxHeadGroups) * probGroupBytes; + const int64_t outSlotBytes = static_cast(maxHeadGroups) * outGroupBytes; + __gm__ uint8_t *scoreBase = sGm + workerIdx * scoreSlotBytes * 2; + __gm__ uint8_t *probBase = pGm + workerIdx * probSlotBytes * 2; + __gm__ uint8_t *outBase = oTmpGm + workerIdx * outSlotBytes * 2; + + const int64_t processRounds = (processNum + workerNum - 1) / workerNum; + const int32_t stageTileCount = (ctx.kvSplitPerCore + kTileTokens - 1) / kTileTokens; + for (int64_t processRound = 0; processRound < processRounds; ++processRound) { + const int64_t process = processRound * workerNum + workerIdx; + bool validProcess = process < processNum; + int32_t batchIndex = 0; + int32_t curHeadNum = 0; + int32_t startHead = 0; + int32_t startTile = 0; + int32_t tileCount = 0; + int32_t curKvSeqLen = 0; + if (validProcess) { + int32_t curBatchSlot = static_cast(process / (corePerBatch * ctx.kvSplitCoreNum)); + int32_t paraBase = ctx.headSize + curBatchSlot * ctx.paraSize; + const int32_t sortedBatch = LoadTilingI32(tilingParaGm, paraBase + kParaBatchIndex); + paraBase = ctx.headSize + sortedBatch * ctx.paraSize; + batchIndex = LoadTilingI32(tilingParaGm, paraBase + 8); + const int32_t kvSeqLen = LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); + const int32_t kvSeqLenAlign = ((kvSeqLen + ctx.blockSize - 1) / ctx.blockSize) * ctx.blockSize; + const int32_t kvLoop = (kvSeqLenAlign + ctx.kvSplitPerCore - 1) / ctx.kvSplitPerCore; + const int32_t curSplit = static_cast(process % ctx.kvSplitCoreNum); + validProcess = kvSeqLen > 0 && curSplit < kvLoop; + if (validProcess) { + const int32_t curHeadBlock = static_cast((process / ctx.kvSplitCoreNum) % corePerBatch); + startHead = curHeadBlock * formerHeadSplit; + curHeadNum = formerHeadSplit; + if (curHeadBlock == corePerBatch - 1) { + curHeadNum = ctx.numHeads - curHeadBlock * formerHeadSplit; + } + const int32_t startKv = curSplit * ctx.kvSplitPerCore; + curKvSeqLen = ctx.kvSplitPerCore; + if (curSplit == kvLoop - 1) { + curKvSeqLen = kvSeqLen - startKv; + } + tileCount = (curKvSeqLen + kTileTokens - 1) / kTileTokens; + startTile = startKv / kTileTokens; + } + } + + for (int32_t tilePairBase = 0; tilePairBase < stageTileCount; tilePairBase += 2) { + const bool hasStage2 = (tilePairBase + 1) < stageTileCount; + const bool activeTile0 = validProcess && tilePairBase < tileCount; + const bool activeTile1 = validProcess && hasStage2 && (tilePairBase + 1) < tileCount; + + for (uint32_t stage = 0; stage < 2; ++stage) { + if (stage == 1 && !hasStage2) { + break; + } + const int32_t tile = tilePairBase + static_cast(stage); + const uint8_t slot = static_cast(stage); + const bool activeTileStage = (stage == 0) ? activeTile0 : activeTile1; + for (int32_t headGroup = 0; headGroup < maxHeadGroups; ++headGroup) { + const int32_t groupHeadBase = headGroup * kHeadGroup; + const bool validGroup = validProcess && groupHeadBase < curHeadNum; + const bool activeGroup = validGroup && activeTileStage; + if (!activeGroup) { + continue; + } + const int32_t firstHead = startHead + groupHeadBase; + const int64_t qBase = (static_cast(batchIndex) * ctx.numHeads + firstHead) * ctx.headDim; + QGlobal qGlobal(reinterpret_cast<__gm__ half *>(qGm) + qBase); + TLOAD(qMatTile, qGlobal); + pipe_barrier(PIPE_ALL); + const int32_t blockId = LoadBlockTable(blockTablesGm, + static_cast(batchIndex) * maxBlocksPerQuery + startTile + tile); + for (int32_t headInGroupBase = 0; headInGroupBase < kHeadGroup; headInGroupBase += headsPerKv) { + const int32_t baseHeadLocal = groupHeadBase + headInGroupBase; + if (baseHeadLocal >= curHeadNum) { + break; + } + const int32_t baseHead = startHead + baseHeadLocal; + const int32_t kvHead = baseHead / headsPerKv; + const int64_t kvBase = + (static_cast(blockId) * ctx.blockSize * ctx.kvHeads + kvHead) * ctx.headDim; + PtoPaLoadNzHeadGroupToCa(qLeftTile, qMatTile, static_cast(headInGroupBase), + static_cast(kHeadDim / 16)); + KGlobal kGlobal(reinterpret_cast<__gm__ half *>(kGm) + kvBase); + TLOAD(kMatTile, kGlobal); + auto matmulEvent = EVENT_ID1; + set_flag(PIPE_FIX, PIPE_M, matmulEvent); + wait_flag(PIPE_FIX, PIPE_M, matmulEvent); + set_flag(PIPE_MTE2, PIPE_MTE1, matmulEvent); + wait_flag(PIPE_MTE2, PIPE_MTE1, matmulEvent); + set_flag(PIPE_M, PIPE_MTE1, matmulEvent); + wait_flag(PIPE_M, PIPE_MTE1, matmulEvent); + PtoPaLoadCbufToCbRaw(rightTile, kMatTile, 0, + static_cast((kHeadDim * kTileTokens) / 256), 1); + set_flag(PIPE_MTE1, PIPE_M, matmulEvent); + wait_flag(PIPE_MTE1, PIPE_M, matmulEvent); + TMATMUL(accTile, qLeftTile, rightTile); + set_flag(PIPE_MTE1, PIPE_MTE2, matmulEvent); + wait_flag(PIPE_MTE1, PIPE_MTE2, matmulEvent); + set_flag(PIPE_M, PIPE_FIX, matmulEvent); + wait_flag(PIPE_M, PIPE_FIX, matmulEvent); + ScoreGlobal scoreGlobal(reinterpret_cast<__gm__ float *>(scoreBase + + static_cast(slot) * scoreSlotBytes + + static_cast(headGroup) * scoreGroupBytes + + static_cast(headInGroupBase) * scoreHeadBytes)); + TSTORE(scoreGlobal, accTile); + } + } + DdrFenceBeforePtoAivReduce(); + PtoPaSignalFromCube(PtoPaSlotFlag(PTO_PA_RAW_QK_READY, slot)); + } + + for (uint32_t stage = 0; stage < 2; ++stage) { + if (stage == 1 && !hasStage2) { + break; + } + const int32_t tile = tilePairBase + static_cast(stage); + const uint8_t slot = static_cast(stage); + const bool activeTileStage = (stage == 0) ? activeTile0 : activeTile1; + wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_P_READY, slot)); + for (int32_t headGroup = 0; headGroup < maxHeadGroups; ++headGroup) { + const int32_t groupHeadBase = headGroup * kHeadGroup; + const bool validGroup = validProcess && groupHeadBase < curHeadNum; + const bool activeGroup = validGroup && activeTileStage; + if (!activeGroup) { + continue; + } + const int32_t blockId = LoadBlockTable(blockTablesGm, + static_cast(batchIndex) * maxBlocksPerQuery + startTile + tile); + ProbGlobal probGlobal(reinterpret_cast<__gm__ half *>(probBase + + static_cast(slot) * probSlotBytes + + static_cast(headGroup) * probGroupBytes)); + TLOAD(pMatTile, probGlobal); + pipe_barrier(PIPE_ALL); + for (int32_t headInGroupBase = 0; headInGroupBase < kHeadGroup; headInGroupBase += headsPerKv) { + const int32_t baseHeadLocal = groupHeadBase + headInGroupBase; + if (baseHeadLocal >= curHeadNum) { + break; + } + const int32_t baseHead = startHead + baseHeadLocal; + const int32_t kvHead = baseHead / headsPerKv; + const int64_t kvBase = + (static_cast(blockId) * ctx.blockSize * ctx.kvHeads + kvHead) * ctx.headDim; + PtoPaLoadNzHeadGroupToCa(pLeftTile, pMatTile, static_cast(headInGroupBase), + static_cast(kTileTokens / 16)); + VGlobal vGlobal(reinterpret_cast<__gm__ half *>(vGm) + kvBase); + TLOAD(vMatTile, vGlobal); + auto matmulEvent = EVENT_ID1; + set_flag(PIPE_FIX, PIPE_M, matmulEvent); + wait_flag(PIPE_FIX, PIPE_M, matmulEvent); + set_flag(PIPE_MTE2, PIPE_MTE1, matmulEvent); + wait_flag(PIPE_MTE2, PIPE_MTE1, matmulEvent); + set_flag(PIPE_M, PIPE_MTE1, matmulEvent); + wait_flag(PIPE_M, PIPE_MTE1, matmulEvent); + PtoPaLoadCbufToCbTranspose128Raw(rightTile, vMatTile); + set_flag(PIPE_MTE1, PIPE_M, matmulEvent); + wait_flag(PIPE_MTE1, PIPE_M, matmulEvent); + TMATMUL(accTile, pLeftTile, rightTile); + set_flag(PIPE_MTE1, PIPE_MTE2, matmulEvent); + wait_flag(PIPE_MTE1, PIPE_MTE2, matmulEvent); + set_flag(PIPE_M, PIPE_FIX, matmulEvent); + wait_flag(PIPE_M, PIPE_FIX, matmulEvent); + OutGlobal outGlobal(reinterpret_cast<__gm__ float *>(outBase + + static_cast(slot) * outSlotBytes + + static_cast(headGroup) * outGroupBytes + + static_cast(headInGroupBase) * outHeadBytes)); + TSTORE(outGlobal, accTile); + } + } + DdrFenceBeforePtoAivReduce(); + PtoPaSignalFromCube(PtoPaSlotFlag(PTO_PA_RAW_PV_READY, slot)); + } + wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_PV_FREE, 0)); + if (hasStage2) { + wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_PV_FREE, 1)); + } + } + + } + pipe_barrier(PIPE_ALL); +} +#endif + +#ifdef __DAV_C220_VEC__ +AICORE inline void RunPtoPagedAttentionVecPipelineSplitKV(__gm__ uint8_t *oGm, __gm__ uint8_t *sGm, + __gm__ uint8_t *pGm, __gm__ uint8_t *oTmpGm, __gm__ uint8_t *oCoreTmpGm, __gm__ uint8_t *lGm, + __gm__ uint8_t *tilingParaGm, int64_t workerIdx, int64_t workerNum, uint32_t subBlockId) +{ + constexpr int32_t kHeadDim = PA_TILE_TOKENS; + constexpr int32_t kTileTokens = PA_TILE_TOKENS; + constexpr int32_t kHeadGroup = 16; + constexpr int32_t kMaxHeadsPerProcess = 32; + PtoPaInitCoreState(); + const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); + if (workerIdx < 0 || workerNum <= 0 || ctx.headDim != kHeadDim || ctx.headDimV != kHeadDim || + ctx.blockSize != kTileTokens || ctx.kvSplitCoreNum <= 1 || ctx.numHeads <= 0 || ctx.kvHeads <= 0 || + ctx.numHeads % ctx.kvHeads != 0) { + pipe_barrier(PIPE_ALL); + return; + } + + const bool activeSubBlock = subBlockId < 2; + const bool combineSubBlock = subBlockId == 0; + const int32_t formerHeadSplit = ctx.formerHeadSplit > 0 ? ctx.formerHeadSplit : 1; + const int32_t maxHeadGroups = (formerHeadSplit + kHeadGroup - 1) / kHeadGroup; + const int32_t headsPerKv = ctx.numHeads / ctx.kvHeads; + const int32_t corePerBatch = (ctx.numHeads + formerHeadSplit - 1) / formerHeadSplit; + const int64_t processNum = static_cast(ctx.formerBatch) * corePerBatch * ctx.kvSplitCoreNum; + __gm__ float *partialOut = reinterpret_cast<__gm__ float *>(oCoreTmpGm); + __gm__ float *partialL = reinterpret_cast<__gm__ float *>(lGm); + + using VecFloat128 = Tile; + using VecHalf128 = Tile; + using VecHalf256 = Tile; + using VecFloat8 = Tile; + using VecFloat4x128 = Tile; + using VecFloat4x1 = Tile; + using VecFloat1x8 = Tile; + using ScoreGlobal = GlobalTensor, + Stride>; + using ScoreRowsGlobal = GlobalTensor, + Stride<1, 1, 1, kHeadDim, 1>>; + using ProbGlobal = GlobalTensor, Stride<256, 256, 256, 256, 1>>; + using ProbRowGlobal = GlobalTensor, + Stride>; + using OutGlobal = GlobalTensor, + Stride>; + using OutputGlobal = GlobalTensor, + Stride>; + using OutRowsGlobal = GlobalTensor, + Stride<1, 1, 1, kHeadDim, 1>>; + + VecFloat128 weightedTile; + VecFloat128 scoreTile; + VecFloat128 scoreWorkTile; + VecFloat128 pvTile; + VecHalf128 probHalfTile; + VecHalf128 outHalfTile; + VecHalf256 probTile; + VecFloat8 rowMaxTile; + VecFloat8 rowSumTile; + VecFloat8 scalarMathTile; + VecFloat4x128 scoreRowsTile; + VecFloat4x128 scoreRowsWorkTile; + VecFloat128 probRowView; // 1x128 view aliasing one row of scoreRowsWorkTile for TCVT + VecFloat4x128 pvRowsTile; + VecFloat4x1 rowMaxRowsTile; + VecFloat4x1 maxStateRowsTile; + VecFloat4x1 newMaxRowsTile; + VecFloat4x1 oldScaleRowsTile; + VecFloat4x1 rowSumRowsTile; + VecFloat4x1 sumStateRowsTile; + VecFloat1x8 rowMaxRowsView; + VecFloat1x8 maxStateRowsView; + VecFloat1x8 newMaxRowsView; + VecFloat1x8 oldScaleRowsView; + VecFloat1x8 rowSumRowsView; + VecFloat1x8 sumStateRowsView; + TASSIGN(weightedTile, 0x0000); + TASSIGN(scoreTile, 0x0800); + TASSIGN(scoreWorkTile, 0x1000); + TASSIGN(pvTile, 0x1800); + TASSIGN(probHalfTile, 0x2000); + TASSIGN(outHalfTile, 0x2000); + TASSIGN(probTile, 0x2800); + TASSIGN(rowMaxTile, 0x3000); + TASSIGN(rowSumTile, 0x3040); + TASSIGN(scalarMathTile, 0x3080); + TASSIGN(scoreRowsTile, 0x0800); + constexpr uint32_t kScoreRowsWorkUb = 0x1000; + TASSIGN(scoreRowsWorkTile, kScoreRowsWorkUb); + TASSIGN(pvRowsTile, 0x1800); + TASSIGN(rowMaxRowsTile, 0x3000); + TASSIGN(maxStateRowsTile, 0x3020); + TASSIGN(newMaxRowsTile, 0x3040); + TASSIGN(oldScaleRowsTile, 0x3060); + TASSIGN(rowSumRowsTile, 0x3080); + TASSIGN(sumStateRowsTile, 0x30a0); + TRESHAPE(rowMaxRowsView, rowMaxRowsTile); + TRESHAPE(maxStateRowsView, maxStateRowsTile); + TRESHAPE(newMaxRowsView, newMaxRowsTile); + TRESHAPE(oldScaleRowsView, oldScaleRowsTile); + TRESHAPE(rowSumRowsView, rowSumRowsTile); + TRESHAPE(sumStateRowsView, sumStateRowsTile); + + constexpr uint32_t kAccumUbBase = 0x4000; + constexpr uint32_t kAccumHeadBytes = kHeadDim * sizeof(float); + constexpr int64_t scoreHeadBytes = 4 * kTileTokens * sizeof(float); + constexpr int64_t probHeadBytes = 256 * sizeof(half); + constexpr int64_t outHeadBytes = 4 * kHeadDim * sizeof(float); + constexpr int64_t scoreGroupBytes = kHeadGroup * scoreHeadBytes; + constexpr int64_t probGroupBytes = kHeadGroup * probHeadBytes; + constexpr int64_t outGroupBytes = kHeadGroup * outHeadBytes; + const int64_t scoreSlotBytes = static_cast(maxHeadGroups) * scoreGroupBytes; + const int64_t probSlotBytes = static_cast(maxHeadGroups) * probGroupBytes; + const int64_t outSlotBytes = static_cast(maxHeadGroups) * outGroupBytes; + __gm__ uint8_t *scoreBase = sGm + workerIdx * scoreSlotBytes * 2; + __gm__ uint8_t *probBase = pGm + workerIdx * probSlotBytes * 2; + __gm__ uint8_t *outTmpBase = oTmpGm + workerIdx * outSlotBytes * 2; + + const int64_t processRounds = (processNum + workerNum - 1) / workerNum; + const int32_t stageTileCount = (ctx.kvSplitPerCore + kTileTokens - 1) / kTileTokens; + for (int64_t processRound = 0; processRound < processRounds; ++processRound) { + const int64_t process = processRound * workerNum + workerIdx; + bool validProcess = process < processNum; + int32_t batchIndex = 0; + int32_t curHeadNum = 0; + int32_t startHead = 0; + int32_t tileCount = 0; + int32_t curKvSeqLen = 0; + int32_t curSplit = 0; + uint64_t lBase = 0; + uint64_t oFdBase = 0; + if (validProcess) { + int32_t curBatchSlot = static_cast(process / (corePerBatch * ctx.kvSplitCoreNum)); + int32_t paraBase = ctx.headSize + curBatchSlot * ctx.paraSize; + const int32_t sortedBatch = LoadTilingI32(tilingParaGm, paraBase + kParaBatchIndex); + paraBase = ctx.headSize + sortedBatch * ctx.paraSize; + batchIndex = LoadTilingI32(tilingParaGm, paraBase + 8); + const int32_t kvSeqLen = LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); + const int32_t kvSeqLenAlign = ((kvSeqLen + ctx.blockSize - 1) / ctx.blockSize) * ctx.blockSize; + const int32_t kvLoop = (kvSeqLenAlign + ctx.kvSplitPerCore - 1) / ctx.kvSplitPerCore; + curSplit = static_cast(process % ctx.kvSplitCoreNum); + validProcess = kvSeqLen > 0 && curSplit < kvLoop; + if (validProcess) { + const int32_t curHeadBlock = static_cast((process / ctx.kvSplitCoreNum) % corePerBatch); + startHead = curHeadBlock * formerHeadSplit; + curHeadNum = formerHeadSplit; + if (curHeadBlock == corePerBatch - 1) { + curHeadNum = ctx.numHeads - curHeadBlock * formerHeadSplit; + } + const int32_t startKv = curSplit * ctx.kvSplitPerCore; + curKvSeqLen = ctx.kvSplitPerCore; + if (curSplit == kvLoop - 1) { + curKvSeqLen = kvSeqLen - startKv; + } + tileCount = (curKvSeqLen + kTileTokens - 1) / kTileTokens; + lBase = LoadTilingOffset64(tilingParaGm, paraBase, 11, 12); + oFdBase = LoadTilingOffset64(tilingParaGm, paraBase, 15, 16); + } + } + + const int32_t subHeadBegin = subBlockId == 0 ? 0 : curHeadNum / 2; + const int32_t subHeadEnd = subBlockId == 0 ? curHeadNum / 2 : curHeadNum; + float maxScore[kMaxHeadsPerProcess]; + float sumExp[kMaxHeadsPerProcess]; + float oldScaleByHead[kMaxHeadsPerProcess]; + if (activeSubBlock && validProcess) { + for (int32_t headLocal = subHeadBegin; headLocal < subHeadEnd; ++headLocal) { + maxScore[headLocal] = -3.4028234663852886e38f; + sumExp[headLocal] = 0.0f; + oldScaleByHead[headLocal] = 0.0f; + TASSIGN(weightedTile, kAccumUbBase + static_cast(headLocal) * kAccumHeadBytes); + TEXPANDS(weightedTile, 0.0f); + pipe_barrier(PIPE_V); + } + } + + for (int32_t tilePairBase = 0; tilePairBase < stageTileCount; tilePairBase += 2) { + const bool hasStage2 = (tilePairBase + 1) < stageTileCount; + const bool activeTile0 = validProcess && tilePairBase < tileCount; + const bool activeTile1 = validProcess && hasStage2 && (tilePairBase + 1) < tileCount; + + for (uint32_t stage = 0; stage < 2; ++stage) { + if (stage == 1 && !hasStage2) { + break; + } + const int32_t tile = tilePairBase + static_cast(stage); + const uint8_t slot = static_cast(stage); + const bool activeTileStage = (stage == 0) ? activeTile0 : activeTile1; + wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_QK_READY, slot)); + for (int32_t headGroup = 0; headGroup < maxHeadGroups; ++headGroup) { + const int32_t groupHeadBase = headGroup * kHeadGroup; + const bool validGroup = validProcess && groupHeadBase < curHeadNum; + const bool activeGroup = validGroup && activeTileStage; + if (!(activeSubBlock && activeGroup)) { + continue; + } + for (int32_t headInGroupBase = 0; headInGroupBase < kHeadGroup; headInGroupBase += headsPerKv) { + const int32_t headLocal = groupHeadBase + headInGroupBase; + if (headLocal >= curHeadNum) { + break; + } + if (headsPerKv != 4 || headLocal < subHeadBegin || headLocal + 4 > subHeadEnd) { + continue; + } + ScoreRowsGlobal scoreGlobal(reinterpret_cast<__gm__ float *>(scoreBase + + static_cast(slot) * scoreSlotBytes + + static_cast(headGroup) * scoreGroupBytes + + static_cast(headInGroupBase) * scoreHeadBytes)); + TLOAD(scoreRowsTile, scoreGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(scoreRowsTile, scoreRowsTile, ctx.scale); + pipe_barrier(PIPE_V); + for (int32_t row = 0; row < 4; ++row) { + maxStateRowsTile.data()[row] = maxScore[headLocal + row]; + sumStateRowsTile.data()[row] = sumExp[headLocal + row]; + } + set_flag(PIPE_S, PIPE_V, EVENT_ID2); + wait_flag(PIPE_S, PIPE_V, EVENT_ID2); + TROWMAX(rowMaxRowsTile, scoreRowsTile, scoreRowsWorkTile); + pipe_barrier(PIPE_V); + TMAX(newMaxRowsView, rowMaxRowsView, maxStateRowsView); + pipe_barrier(PIPE_V); + TSUB(oldScaleRowsView, maxStateRowsView, newMaxRowsView); + pipe_barrier(PIPE_V); + TEXP(oldScaleRowsView, oldScaleRowsView); + pipe_barrier(PIPE_V); + if (tile == 0) { + TEXPANDS(oldScaleRowsView, 0.0f); + pipe_barrier(PIPE_V); + } + TROWEXPANDSUB(scoreRowsWorkTile, scoreRowsTile, newMaxRowsTile); + pipe_barrier(PIPE_V); + TEXP(scoreRowsWorkTile, scoreRowsWorkTile); + pipe_barrier(PIPE_V); + TROWSUM(rowSumRowsTile, scoreRowsWorkTile, scoreRowsTile); + pipe_barrier(PIPE_V); + TMUL(sumStateRowsView, sumStateRowsView, oldScaleRowsView); + pipe_barrier(PIPE_V); + TADD(sumStateRowsView, sumStateRowsView, rowSumRowsView); + pipe_barrier(PIPE_V); + set_flag(PIPE_V, PIPE_S, EVENT_ID3); + wait_flag(PIPE_V, PIPE_S, EVENT_ID3); + __gm__ half *probScratch = reinterpret_cast<__gm__ half *>(probBase + + static_cast(slot) * probSlotBytes + + static_cast(headGroup) * probGroupBytes); + for (int32_t row = 0; row < 4; ++row) { + maxScore[headLocal + row] = newMaxRowsTile.data()[row]; + sumExp[headLocal + row] = sumStateRowsTile.data()[row]; + oldScaleByHead[headLocal + row] = oldScaleRowsTile.data()[row]; + TASSIGN(probRowView, kScoreRowsWorkUb + + static_cast(row) * kTileTokens * sizeof(float)); + PtoPaConvF32ToF16(probHalfTile, probRowView, 2); + pipe_barrier(PIPE_V); + ProbRowGlobal probRowGlobal(probScratch + + static_cast(headInGroupBase + row) * kTileTokens); + TSTORE(probRowGlobal, probHalfTile); + } + } + } + DdrFenceBeforePtoAivReduce(); + PtoPaSignalFromVec(PtoPaSlotFlag(PTO_PA_RAW_P_READY, slot)); + } + + for (uint32_t stage = 0; stage < 2; ++stage) { + if (stage == 1 && !hasStage2) { + break; + } + const uint8_t slot = static_cast(stage); + const bool activeTileStage = (stage == 0) ? activeTile0 : activeTile1; + wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_PV_READY, slot)); + for (int32_t headGroup = 0; headGroup < maxHeadGroups; ++headGroup) { + const int32_t groupHeadBase = headGroup * kHeadGroup; + const bool validGroup = validProcess && groupHeadBase < curHeadNum; + const bool activeGroup = validGroup && activeTileStage; + if (!(activeSubBlock && activeGroup)) { + continue; + } + for (int32_t headInGroupBase = 0; headInGroupBase < kHeadGroup; headInGroupBase += headsPerKv) { + const int32_t headLocal = groupHeadBase + headInGroupBase; + if (headLocal >= curHeadNum) { + break; + } + if (headsPerKv != 4 || headLocal < subHeadBegin || headLocal + 4 > subHeadEnd) { + continue; + } + for (int32_t row = 0; row < 4; ++row) { + oldScaleRowsTile.data()[row] = oldScaleByHead[headLocal + row]; + } + set_flag(PIPE_S, PIPE_V, EVENT_ID2); + wait_flag(PIPE_S, PIPE_V, EVENT_ID2); + OutRowsGlobal outGlobal(reinterpret_cast<__gm__ float *>(outTmpBase + + static_cast(slot) * outSlotBytes + + static_cast(headGroup) * outGroupBytes + + static_cast(headInGroupBase) * outHeadBytes)); + TLOAD(pvRowsTile, outGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + VecFloat4x128 weightedRowsTile; + TASSIGN(weightedRowsTile, kAccumUbBase + static_cast(headLocal) * kAccumHeadBytes); + TROWEXPANDMUL(weightedRowsTile, weightedRowsTile, oldScaleRowsTile); + pipe_barrier(PIPE_V); + TADD(weightedRowsTile, weightedRowsTile, pvRowsTile); + pipe_barrier(PIPE_V); + } + } + PtoPaSignalFreeFromVec(PtoPaSlotFlag(PTO_PA_RAW_PV_FREE, slot)); + } + } + + + if (activeSubBlock && validProcess) { + for (int32_t headLocal = subHeadBegin; headLocal < subHeadEnd; ++headLocal) { + const int32_t head = startHead + headLocal; + const float invSum = sumExp[headLocal] > 0.0f ? 1.0f / sumExp[headLocal] : 0.0f; + const uint64_t outOffset = oFdBase * ctx.kvSplitCoreNum + + static_cast(head) * ctx.headDim * ctx.kvSplitCoreNum + + static_cast(curSplit) * ctx.headDim; + const uint64_t lOffset = lBase + static_cast(head) * ctx.kvSplitCoreNum + curSplit; + partialL[lOffset] = maxScore[headLocal] + PtoLogScalar(scalarMathTile, sumExp[headLocal]); + TASSIGN(weightedTile, kAccumUbBase + static_cast(headLocal) * kAccumHeadBytes); + TMULS(weightedTile, weightedTile, invSum); + pipe_barrier(PIPE_V); + OutGlobal weightedGlobal(reinterpret_cast<__gm__ float *>(partialOut + outOffset)); + TSTORE(weightedGlobal, weightedTile); + } + } + } + + DdrBarrierBeforePtoFfts(); + ffts_cross_core_sync(PIPE_MTE3, PtoPaGetFftsMsg(0x0, PTO_PA_REDUCE_READY_DECODER)); + wait_flag_dev(PTO_PA_REDUCE_READY_DECODER); + if (!combineSubBlock) { + pipe_barrier(PIPE_ALL); + return; + } + const int32_t effectiveBatch = ctx.decoderBatch > 0 ? ctx.decoderBatch : ctx.batch; + const int64_t totalRows = static_cast(effectiveBatch) * ctx.numHeads; + const int64_t combineWorkerIdx = workerIdx; + const int64_t combineWorkerNum = workerNum; + for (int64_t row = combineWorkerIdx; row < totalRows; row += combineWorkerNum) { + const int32_t head = static_cast(row % ctx.numHeads); + const int32_t batchSlot = static_cast(row / ctx.numHeads); + int32_t paraBase = ctx.headSize + batchSlot * ctx.paraSize; + const int32_t sortedBatch = LoadTilingI32(tilingParaGm, paraBase + kParaBatchIndex); + paraBase = ctx.headSize + sortedBatch * ctx.paraSize; + const int32_t batchIndex = LoadTilingI32(tilingParaGm, paraBase + 8); + const int32_t kvSeqLen = LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); + if (kvSeqLen <= 0) { + continue; + } + const int32_t kvSeqLenAlign = ((kvSeqLen + ctx.blockSize - 1) / ctx.blockSize) * ctx.blockSize; + const int32_t kvLoop = (kvSeqLenAlign + ctx.kvSplitPerCore - 1) / ctx.kvSplitPerCore; + const uint64_t lBase = LoadTilingOffset64(tilingParaGm, paraBase, 11, 12); + const uint64_t oFdBase = LoadTilingOffset64(tilingParaGm, paraBase, 15, 16); + __ubuf__ float *splitScale = reinterpret_cast<__ubuf__ float *>((uintptr_t)0x3200); + __ubuf__ float *splitReduce = reinterpret_cast<__ubuf__ float *>((uintptr_t)0x3400); + const uint64_t lOffset = lBase + static_cast(head) * ctx.kvSplitCoreNum; + const uint32_t lRemain = static_cast(kvLoop % 8); + copy_gm_to_ubuf_align_b32(splitScale, partialL + lOffset, 0, 1, static_cast(kvLoop * 4), 0, + lRemain == 0 ? 0 : 8 - lRemain, 0, 0); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + PtoPaSetFloatVectorMask(static_cast(kvLoop)); + vcmax(splitReduce, splitScale, 1, 1, 1, 8, static_cast(ONLY_VALUE)); + pipe_barrier(PIPE_V); + set_flag(PIPE_V, PIPE_S, EVENT_ID3); + wait_flag(PIPE_V, PIPE_S, EVENT_ID3); + const float lMax = splitReduce[0]; + set_flag(PIPE_S, PIPE_V, EVENT_ID2); + wait_flag(PIPE_S, PIPE_V, EVENT_ID2); + vadds(splitScale, splitScale, -lMax, 1, 1, 1, 8, 8); + pipe_barrier(PIPE_V); + vexp(splitScale, splitScale, 1, 1, 1, 8, 8); + pipe_barrier(PIPE_V); + vcadd(splitReduce, splitScale, 1, 1, 1, 8, 0); + pipe_barrier(PIPE_V); + set_vector_mask(static_cast(-1), static_cast(-1)); + set_flag(PIPE_V, PIPE_S, EVENT_ID3); + wait_flag(PIPE_V, PIPE_S, EVENT_ID3); + const float denom = splitReduce[0]; + set_flag(PIPE_S, PIPE_V, EVENT_ID2); + wait_flag(PIPE_S, PIPE_V, EVENT_ID2); + float invDenom = denom > 0.0f ? 1.0f / denom : 0.0f; + PtoPaSetFloatVectorMask(static_cast(kvLoop)); + vmuls(splitScale, splitScale, static_cast(invDenom), 1, 1, 1, 8, 8); + pipe_barrier(PIPE_V); + set_vector_mask(static_cast(-1), static_cast(-1)); + TASSIGN(weightedTile, 0x0000); + TEXPANDS(weightedTile, 0.0f); + pipe_barrier(PIPE_V); + constexpr uint32_t kCombineOutUb = 0x4000; + __ubuf__ float *splitOut = reinterpret_cast<__ubuf__ float *>((uintptr_t)kCombineOutUb); + const uint64_t firstOutOffset = oFdBase * ctx.kvSplitCoreNum + + static_cast(head) * ctx.headDim * ctx.kvSplitCoreNum; + copy_gm_to_ubuf_align_b32(splitOut, partialOut + firstOutOffset, 0, 1, + static_cast(kvLoop * ctx.headDim * 4), 0, 0, 0, 0); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + for (int32_t split = 0; split < kvLoop; ++split) { + TASSIGN(pvTile, kCombineOutUb + static_cast(split) * kHeadDim * sizeof(float)); + set_flag(PIPE_V, PIPE_S, EVENT_ID3); + wait_flag(PIPE_V, PIPE_S, EVENT_ID3); + const float splitWeight = splitScale[split]; + set_flag(PIPE_S, PIPE_V, EVENT_ID2); + wait_flag(PIPE_S, PIPE_V, EVENT_ID2); + TMULS(pvTile, pvTile, splitWeight); + pipe_barrier(PIPE_V); + TADD(weightedTile, weightedTile, pvTile); + pipe_barrier(PIPE_V); + } + const int64_t outBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; + OutputGlobal outGlobal(reinterpret_cast<__gm__ half *>(oGm) + outBase); + TCVT(outHalfTile, weightedTile, RoundMode::CAST_RINT); + pipe_barrier(PIPE_V); + TSTORE(outGlobal, outHalfTile); + pipe_barrier(PIPE_ALL); + } + pipe_barrier(PIPE_ALL); +} +#endif + +#ifdef __DAV_C220_VEC__ +AICORE inline void RunPtoPagedAttentionDecodeSplitKV( + __gm__ uint8_t *qGm, + __gm__ uint8_t *kGm, + __gm__ uint8_t *vGm, + __gm__ uint8_t *blockTablesGm, + __gm__ uint8_t *oGm, + __gm__ uint8_t *oCoreTmpGm, + __gm__ uint8_t *lGm, + __gm__ uint8_t *tilingParaGm, + int64_t workerIdx, + int64_t workerNum, + uint32_t subBlockId) +{ + constexpr int32_t kHeadDim = 128; + PtoPaInitCoreState(); + const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); + if (workerIdx < 0 || workerNum <= 0 || ctx.headDim != kHeadDim || ctx.headDimV != kHeadDim || + ctx.blockSize != PA_TILE_TOKENS || ctx.kvSplitCoreNum <= 1 || ctx.numHeads <= 0 || ctx.kvHeads <= 0 || + ctx.numHeads % ctx.kvHeads != 0) { + pipe_barrier(PIPE_ALL); + return; + } + + const int32_t maxBlocksPerQuery = ctx.maxBlocksPerQuery > 0 ? ctx.maxBlocksPerQuery : + (ctx.maxKvSeqLen + ctx.blockSize - 1) / ctx.blockSize; + const int32_t headsPerKv = ctx.numHeads / ctx.kvHeads; + const int32_t formerHeadSplit = ctx.formerHeadSplit > 0 ? ctx.formerHeadSplit : 1; + const int32_t corePerBatch = (ctx.numHeads + formerHeadSplit - 1) / formerHeadSplit; + const int64_t processNum = static_cast(ctx.formerBatch) * corePerBatch * ctx.kvSplitCoreNum; + __gm__ float *partialOut = reinterpret_cast<__gm__ float *>(oCoreTmpGm); + __gm__ float *partialL = reinterpret_cast<__gm__ float *>(lGm); + + using VecHalf128 = Tile; + using VecFloat128 = Tile; + using VecFloat8 = Tile; + using GlobalHalf128 = + GlobalTensor, Stride<1, 1, 1, kHeadDim, 1>>; + + VecHalf128 qHalfTile; + VecFloat128 qFloatTile; + VecHalf128 kHalfTile; + VecFloat128 kFloatTile; + VecFloat128 qkProductTile; + VecFloat8 scoreTile; + VecFloat128 reduceTmpTile; + VecHalf128 vHalfTile; + VecFloat128 vFloatTile; + VecFloat128 weightedTile; + VecFloat8 scalarMathTile; + TASSIGN(qHalfTile, 0x0800); + TASSIGN(qFloatTile, 0x1000); + TASSIGN(kHalfTile, 0x1800); + TASSIGN(kFloatTile, 0x2000); + TASSIGN(qkProductTile, 0x2800); + TASSIGN(scoreTile, 0x3000); + TASSIGN(reduceTmpTile, 0x3800); + TASSIGN(vHalfTile, 0x4000); + TASSIGN(vFloatTile, 0x4800); + TASSIGN(weightedTile, 0x5000); + TASSIGN(scalarMathTile, 0x5800); + + for (int64_t process = workerIdx; process < processNum; process += workerNum) { + int32_t curBatchSlot = static_cast(process / (corePerBatch * ctx.kvSplitCoreNum)); + int32_t paraBase = ctx.headSize + curBatchSlot * ctx.paraSize; + const int32_t sortedBatch = LoadTilingI32(tilingParaGm, paraBase + kParaBatchIndex); + paraBase = ctx.headSize + sortedBatch * ctx.paraSize; + const int32_t batchIndex = LoadTilingI32(tilingParaGm, paraBase + 8); + const int32_t kvSeqLen = LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); + if (kvSeqLen <= 0) { + continue; + } + + const int32_t kvSeqLenAlign = ((kvSeqLen + ctx.blockSize - 1) / ctx.blockSize) * ctx.blockSize; + const int32_t kvLoop = (kvSeqLenAlign + ctx.kvSplitPerCore - 1) / ctx.kvSplitPerCore; + const int32_t curSplit = static_cast(process % ctx.kvSplitCoreNum); + if (curSplit >= kvLoop) { + continue; + } + + const int32_t curHeadBlock = static_cast((process / ctx.kvSplitCoreNum) % corePerBatch); + const int32_t startHead = curHeadBlock * formerHeadSplit; + int32_t curHeadNum = formerHeadSplit; + if (curHeadBlock == corePerBatch - 1) { + curHeadNum = ctx.numHeads - curHeadBlock * formerHeadSplit; + } + const int32_t startKv = curSplit * ctx.kvSplitPerCore; + int32_t curKvSeqLen = ctx.kvSplitPerCore; + if (curSplit == kvLoop - 1) { + curKvSeqLen = kvSeqLen - startKv; + } + + const uint64_t lBase = LoadTilingOffset64(tilingParaGm, paraBase, 11, 12); + const uint64_t oFdBase = LoadTilingOffset64(tilingParaGm, paraBase, 15, 16); + const int32_t headBegin = subBlockId == 0 ? 0 : curHeadNum / 2; + const int32_t headEnd = subBlockId == 0 ? curHeadNum / 2 : curHeadNum; + for (int32_t headLocal = headBegin; headLocal < headEnd; ++headLocal) { + const int32_t head = startHead + headLocal; + const int32_t kvHead = head / headsPerKv; + const int64_t qBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; + GlobalHalf128 qGlobal(reinterpret_cast<__gm__ half *>(qGm) + qBase); + TLOAD(qHalfTile, qGlobal); + pipe_barrier(PIPE_ALL); + TCVT(qFloatTile, qHalfTile, RoundMode::CAST_NONE); + pipe_barrier(PIPE_V); + + float maxScore = -3.4028234663852886e38f; + float sumExp = 0.0f; + TEXPANDS(weightedTile, 0.0f); + pipe_barrier(PIPE_V); + for (int32_t relPos = 0; relPos < curKvSeqLen; ++relPos) { + const int32_t pos = startKv + relPos; + int32_t blockId = 0; + int32_t offsetInBlock = 0; + ResolvePagedPosition(blockTablesGm, batchIndex, maxBlocksPerQuery, pos, ctx.blockSize, blockId, + offsetInBlock); + const int64_t kvOffset = (((static_cast(blockId) * ctx.blockSize + offsetInBlock) * + ctx.kvHeads + kvHead) * ctx.headDim); + + GlobalHalf128 kGlobal(reinterpret_cast<__gm__ half *>(kGm) + kvOffset); + TLOAD(kHalfTile, kGlobal); + pipe_barrier(PIPE_ALL); + TCVT(kFloatTile, kHalfTile, RoundMode::CAST_NONE); + pipe_barrier(PIPE_V); + TMUL(qkProductTile, qFloatTile, kFloatTile); + pipe_barrier(PIPE_V); + TROWSUM(scoreTile, qkProductTile, reduceTmpTile); + pipe_barrier(PIPE_V); + const float score = scoreTile.data()[0] * ctx.scale; + const float newMax = score > maxScore ? score : maxScore; + float oldScale = 0.0f; + if (relPos != 0) { + scalarMathTile.data()[0] = maxScore - newMax; + TEXP(scalarMathTile, scalarMathTile); + pipe_barrier(PIPE_V); + oldScale = scalarMathTile.data()[0]; + } + scalarMathTile.data()[0] = score - newMax; + TEXP(scalarMathTile, scalarMathTile); + pipe_barrier(PIPE_V); + const float probUnnorm = scalarMathTile.data()[0]; + sumExp = sumExp * oldScale + probUnnorm; + + GlobalHalf128 vGlobal(reinterpret_cast<__gm__ half *>(vGm) + kvOffset); + TMULS(weightedTile, weightedTile, oldScale); + pipe_barrier(PIPE_V); + TLOAD(vHalfTile, vGlobal); + pipe_barrier(PIPE_ALL); + TCVT(vFloatTile, vHalfTile, RoundMode::CAST_NONE); + pipe_barrier(PIPE_V); + TAXPY(weightedTile, vFloatTile, probUnnorm); + pipe_barrier(PIPE_V); + maxScore = newMax; + } + + const float invSum = sumExp > 0.0f ? 1.0f / sumExp : 0.0f; + const uint64_t outOffset = oFdBase * ctx.kvSplitCoreNum + + static_cast(head) * ctx.headDim * ctx.kvSplitCoreNum + + static_cast(curSplit) * ctx.headDim; + const uint64_t lOffset = lBase + static_cast(head) * ctx.kvSplitCoreNum + curSplit; + float logSumExp = -3.4028234663852886e38f; + if (sumExp > 0.0f) { + scalarMathTile.data()[0] = sumExp; + TLOG(scalarMathTile, scalarMathTile); + pipe_barrier(PIPE_V); + logSumExp = scalarMathTile.data()[0]; + } + partialL[lOffset] = maxScore + logSumExp; + for (int32_t dim = 0; dim < kHeadDim; ++dim) { + partialOut[outOffset + dim] = weightedTile.data()[dim] * invSum; + } + } + } + + DdrFenceBeforePtoAivReduce(); + ffts_cross_core_sync(PIPE_MTE3, PtoPaGetFftsMsg(0x0, PTO_PA_REDUCE_READY_DECODER)); + wait_flag_dev(PTO_PA_REDUCE_READY_DECODER); + + const int32_t effectiveBatch = ctx.decoderBatch > 0 ? ctx.decoderBatch : ctx.batch; + const int64_t totalRows = static_cast(effectiveBatch) * ctx.numHeads; + const int64_t combineWorkerIdx = workerIdx * 2 + static_cast(subBlockId); + const int64_t combineWorkerNum = workerNum * 2; + for (int64_t row = combineWorkerIdx; row < totalRows; row += combineWorkerNum) { + const int32_t head = static_cast(row % ctx.numHeads); + const int32_t batchSlot = static_cast(row / ctx.numHeads); + int32_t paraBase = ctx.headSize + batchSlot * ctx.paraSize; + const int32_t sortedBatch = LoadTilingI32(tilingParaGm, paraBase + kParaBatchIndex); + paraBase = ctx.headSize + sortedBatch * ctx.paraSize; + const int32_t batchIndex = LoadTilingI32(tilingParaGm, paraBase + 8); + const int32_t kvSeqLen = LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); + if (kvSeqLen <= 0) { + continue; + } + + const int32_t kvSeqLenAlign = ((kvSeqLen + ctx.blockSize - 1) / ctx.blockSize) * ctx.blockSize; + const int32_t kvLoop = (kvSeqLenAlign + ctx.kvSplitPerCore - 1) / ctx.kvSplitPerCore; + const uint64_t lBase = LoadTilingOffset64(tilingParaGm, paraBase, 11, 12); + const uint64_t oFdBase = LoadTilingOffset64(tilingParaGm, paraBase, 15, 16); + float lMax = -3.4028234663852886e38f; + for (int32_t split = 0; split < kvLoop; ++split) { + const float lValue = partialL[lBase + static_cast(head) * ctx.kvSplitCoreNum + split]; + lMax = lValue > lMax ? lValue : lMax; + } + float denom = 0.0f; + float splitScale[64]; + for (int32_t split = 0; split < kvLoop; ++split) { + const float lValue = partialL[lBase + static_cast(head) * ctx.kvSplitCoreNum + split]; + scalarMathTile.data()[0] = lValue - lMax; + TEXP(scalarMathTile, scalarMathTile); + pipe_barrier(PIPE_V); + const float scale = scalarMathTile.data()[0]; + splitScale[split] = scale; + denom += scale; + } + const float invDenom = denom > 0.0f ? 1.0f / denom : 0.0f; + for (int32_t split = 0; split < kvLoop; ++split) { + splitScale[split] *= invDenom; + } + const int64_t outBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; + for (int32_t dim = 0; dim < kHeadDim; ++dim) { + float value = 0.0f; + for (int32_t split = 0; split < kvLoop; ++split) { + const uint64_t outOffset = oFdBase * ctx.kvSplitCoreNum + + static_cast(head) * ctx.headDim * ctx.kvSplitCoreNum + + static_cast(split) * ctx.headDim; + value += partialOut[outOffset + dim] * splitScale[split]; + } + StoreOutputFp16(oGm, outBase + dim, value); + } + } + pipe_barrier(PIPE_ALL); +} +#endif + +AICORE inline void RunPtoPagedAttentionDecode( + __gm__ uint8_t *qGm, + __gm__ uint8_t *kGm, + __gm__ uint8_t *vGm, + __gm__ uint8_t *blockTablesGm, + __gm__ uint8_t *oGm, + __gm__ uint8_t *tilingParaGm, + int64_t workerIdx, + int64_t workerNum) +{ + constexpr int32_t kMaxHeadDim = 128; + PtoPaInitCoreState(); + const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); + const int32_t effectiveBatch = ctx.decoderBatch > 0 ? ctx.decoderBatch : ctx.batch; + const int32_t maxBlocksPerQuery = ctx.maxBlocksPerQuery > 0 ? ctx.maxBlocksPerQuery : + (ctx.maxKvSeqLen + ctx.blockSize - 1) / ctx.blockSize; + const int32_t headsPerKv = ctx.numHeads / ctx.kvHeads; + const int64_t totalRows = static_cast(effectiveBatch) * ctx.numHeads; + if (workerIdx < 0 || workerNum <= 0) { + pipe_barrier(PIPE_ALL); + return; + } + + if (ctx.headDim > kMaxHeadDim) { + pipe_barrier(PIPE_ALL); + return; + } + + using DecodeScalarTile = Tile; + DecodeScalarTile decodeScalarMathTile; + TASSIGN(decodeScalarMathTile, 0x5800); + + if (ctx.headDim == kMaxHeadDim) { + constexpr uint64_t kWeightedUb = 0x0000; + constexpr uint64_t kQHalfUb = 0x0800; + constexpr uint64_t kQFloatUb = 0x1000; + constexpr uint64_t kKHalfUb = 0x1800; + constexpr uint64_t kKFloatUb = 0x2000; + constexpr uint64_t kQKProductUb = 0x2800; + constexpr uint64_t kScoreUb = 0x3000; + constexpr uint64_t kReduceTmpUb = 0x3800; + constexpr uint64_t kVHalfUb = 0x4000; + constexpr uint64_t kVFloatUb = 0x4800; + constexpr uint64_t kOutHalfUb = 0x5000; + + using VecHalf128 = Tile; + using VecFloat128 = Tile; + using VecFloat8 = Tile; + using GlobalHalf128 = + GlobalTensor, Stride<1, 1, 1, kMaxHeadDim, 1>>; + + VecFloat128 weightedTile; + VecHalf128 qHalfTile; + VecFloat128 qFloatTile; + VecHalf128 kHalfTile; + VecFloat128 kFloatTile; + VecFloat128 qkProductTile; + VecFloat8 scoreTile; + VecFloat128 reduceTmpTile; + VecHalf128 vHalfTile; + VecFloat128 vFloatTile; + VecHalf128 outHalfTile; + + TASSIGN(weightedTile, kWeightedUb); + TASSIGN(qHalfTile, kQHalfUb); + TASSIGN(qFloatTile, kQFloatUb); + TASSIGN(kHalfTile, kKHalfUb); + TASSIGN(kFloatTile, kKFloatUb); + TASSIGN(qkProductTile, kQKProductUb); + TASSIGN(scoreTile, kScoreUb); + TASSIGN(reduceTmpTile, kReduceTmpUb); + TASSIGN(vHalfTile, kVHalfUb); + TASSIGN(vFloatTile, kVFloatUb); + TASSIGN(outHalfTile, kOutHalfUb); + + for (int64_t row = workerIdx; row < totalRows; row += workerNum) { + const int32_t head = static_cast(row % ctx.numHeads); + const int32_t batchSlot = static_cast(row / ctx.numHeads); + const int32_t paraBase = ResolveSortedParaBase(tilingParaGm, ctx, batchSlot); + const int32_t kvSeqLen = LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); + const int32_t batchIndex = LoadTilingI32(tilingParaGm, paraBase + kParaRealBatchIndex); + const int32_t kvHead = head / headsPerKv; + const int64_t qBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; + GlobalHalf128 qGlobal(reinterpret_cast<__gm__ half *>(qGm) + qBase); + TLOAD(qHalfTile, qGlobal); + pipe_barrier(PIPE_ALL); + TCVT(qFloatTile, qHalfTile, RoundMode::CAST_NONE); + pipe_barrier(PIPE_ALL); + + float maxScore = -3.4028234663852886e38f; + float sumExp = 0.0f; + TEXPANDS(weightedTile, 0.0f); + pipe_barrier(PIPE_ALL); + + for (int32_t pos = 0; pos < kvSeqLen; ++pos) { + int32_t blockId = 0; + int32_t offsetInBlock = 0; + ResolvePagedPosition(blockTablesGm, batchIndex, maxBlocksPerQuery, pos, ctx.blockSize, blockId, + offsetInBlock); + const int64_t kvOffset = (((static_cast(blockId) * ctx.blockSize + offsetInBlock) * + ctx.kvHeads + kvHead) * ctx.headDim); + + GlobalHalf128 kGlobal(reinterpret_cast<__gm__ half *>(kGm) + kvOffset); + TLOAD(kHalfTile, kGlobal); + pipe_barrier(PIPE_ALL); + TCVT(kFloatTile, kHalfTile, RoundMode::CAST_NONE); + pipe_barrier(PIPE_ALL); + TMUL(qkProductTile, qFloatTile, kFloatTile); + pipe_barrier(PIPE_ALL); + TROWSUM(scoreTile, qkProductTile, reduceTmpTile); + pipe_barrier(PIPE_ALL); + const float rawScore = scoreTile.data()[0]; + const float score = rawScore * ctx.scale; + + const float newMax = score > maxScore ? score : maxScore; + const float oldScale = (pos == 0) ? 0.0f : PtoExpScalar(decodeScalarMathTile, maxScore - newMax); + const float probUnnorm = PtoExpScalar(decodeScalarMathTile, score - newMax); + sumExp = sumExp * oldScale + probUnnorm; + + GlobalHalf128 vGlobal(reinterpret_cast<__gm__ half *>(vGm) + kvOffset); + TMULS(weightedTile, weightedTile, oldScale); + pipe_barrier(PIPE_ALL); + TLOAD(vHalfTile, vGlobal); + pipe_barrier(PIPE_ALL); + TCVT(vFloatTile, vHalfTile, RoundMode::CAST_NONE); + pipe_barrier(PIPE_ALL); + TAXPY(weightedTile, vFloatTile, probUnnorm); + pipe_barrier(PIPE_ALL); + maxScore = newMax; + } + + const float invSum = sumExp > 0.0f ? 1.0f / sumExp : 0.0f; + const int64_t outBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; + GlobalHalf128 outGlobal(reinterpret_cast<__gm__ half *>(oGm) + outBase); + TMULS(weightedTile, weightedTile, invSum); + pipe_barrier(PIPE_ALL); + TCVT(outHalfTile, weightedTile, RoundMode::CAST_RINT); + pipe_barrier(PIPE_ALL); + TSTORE(outGlobal, outHalfTile); + pipe_barrier(PIPE_ALL); + } + + pipe_barrier(PIPE_ALL); + return; + } + + for (int64_t row = workerIdx; row < totalRows; row += workerNum) { + const int32_t head = static_cast(row % ctx.numHeads); + const int32_t batchSlot = static_cast(row / ctx.numHeads); + const int32_t paraBase = ResolveSortedParaBase(tilingParaGm, ctx, batchSlot); + const int32_t kvSeqLen = LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); + const int32_t batchIndex = LoadTilingI32(tilingParaGm, paraBase + kParaRealBatchIndex); + const int32_t kvHead = head / headsPerKv; + + float qValues[kMaxHeadDim]; + const int64_t qBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; + for (int32_t dim = 0; dim < ctx.headDim; ++dim) { + qValues[dim] = LoadFp16(qGm, qBase + dim); + } + + float maxScore = -3.4028234663852886e38f; + float sumExp = 0.0f; + float weighted[kMaxHeadDim]; + for (int32_t dim = 0; dim < ctx.headDim; ++dim) { + weighted[dim] = 0.0f; + } + + for (int32_t pos = 0; pos < kvSeqLen; ++pos) { + int32_t blockId = 0; + int32_t offsetInBlock = 0; + ResolvePagedPosition(blockTablesGm, batchIndex, maxBlocksPerQuery, pos, ctx.blockSize, blockId, offsetInBlock); + const float score = ComputeScoreByBlock(qValues, kGm, blockId, offsetInBlock, ctx.blockSize, kvHead, + ctx.headDim, ctx.kvHeads, ctx.scale); + const bool updateMax = score > maxScore; + const float newMax = updateMax ? score : maxScore; + const float oldScale = (pos == 0) ? 0.0f : PtoExpScalar(decodeScalarMathTile, maxScore - newMax); + const float probUnnorm = PtoExpScalar(decodeScalarMathTile, score - newMax); + sumExp = sumExp * oldScale + probUnnorm; + for (int32_t dim = 0; dim < ctx.headDim; ++dim) { + const float value = LoadPagedVByBlock(vGm, blockId, offsetInBlock, ctx.blockSize, ctx.kvHeads, kvHead, ctx.headDim, dim); + weighted[dim] = weighted[dim] * oldScale + probUnnorm * value; + } + maxScore = newMax; + } + + const float invSum = sumExp > 0.0f ? 1.0f / sumExp : 0.0f; + const int64_t outBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; + for (int32_t dim = 0; dim < ctx.headDim; ++dim) { + StoreOutputFp16(oGm, outBase + dim, weighted[dim] * invSum); + } + } + + pipe_barrier(PIPE_ALL); +} + +#endif diff --git a/examples/jit_cpp/paged_attention_highperf/pa_tiling.py b/examples/jit_cpp/paged_attention_highperf/pa_tiling.py new file mode 100644 index 00000000..5daf9c13 --- /dev/null +++ b/examples/jit_cpp/paged_attention_highperf/pa_tiling.py @@ -0,0 +1,484 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +""" +Python port of the PagedAttention ND tiling logic from ascend-transformer. +""" + +from __future__ import annotations + +import struct + +import torch + +TILING_HEAD_SIZE = 44 +TILING_PARA_SIZE = 17 + + +TILING_BATCH = 0 +TILING_NUMHEADS = 1 +TILING_HEADDIM = 2 +TILING_NUMBLOKS = 3 +TILING_BLOCKSIZE = 4 +TILING_MAXBLOCKS = 5 +TILING_TOR = 6 +TILING_KVHEADS = 7 +TILING_FORMER_BATCH = 8 +TILING_FORMER_HEAD = 9 +TILING_TAIL_BATCH = 10 +TILING_TAIL_HEAD = 11 +TILING_HEADNUM_MOVE = 12 +TILING_MASK_MAX_LEN = 13 +TILING_BATCH_STRIDE = 14 +TILING_HEAD_STRIDE = 15 +TILING_KEY = 16 +TILING_HEADSIZE = 17 +TILING_PARASIZE = 18 +TILING_GROUPNUM = 19 +TILING_FORMER_GROUP_MOVE = 20 +TILING_TAIL_GROUP_MOVE = 21 +TILING_MAX_KVSEQLEN = 22 +TILING_KVSPLIT = 23 +TILING_KVCORENUM = 24 +TILING_BLOCKSIZE_CALC = 25 +TILING_TOTAL_BLOCK_NUM = 26 +TILING_PREFILL_BS = 27 +TILING_DECODER_BS = 28 +TILING_HEADDIM_V = 29 +TILING_MODCOEF = 30 +TILING_DIVCOEF = 31 +TILING_QHEADORIGINAL = 32 +TILING_COMPRESSHEAD = 33 +TILING_QUANTYPE = 34 +TILING_DATA_SHAPE_TYPE = 35 +TILING_SCALETYPE = 36 +TILING_MASK_TYPE_ND = 37 +TILING_HEADDIM_K_SPLIT = 38 +TILING_HEADDIM_V_SPLIT = 39 +TILING_HEADDIM_V_SPLIT_VECTOR_FORMER = 40 +TILING_HEADDIM_V_SPLIT_VECTOR_TAIL = 41 + + +WORKSPACE_BLOCK_SIZE_DB = 65536 +BLOCK_SIZE_ALIGN = 16 +SPLITKV_RATIO = 0.8 +SPLITHEAD_RATIO = 0.9 +HEADNUM_LIMIT = 128 +HEADNUM_LIMIT_REGU = 32 +EMBEDDING_LIMIT = 128 +MLA_THRESHOLD = 256 +KV_SEQLEN_SLICE = 128 +KV_SEQLEN_SLICE_256 = 256 +KV_SEQLEN_SLICE_512 = 512 +BLOCK_LIMIT = 128 * 128 +BLOCK_LIMIT_NO_PINGPONG_UINT8 = 128 * 256 * 2 +PP_MM = [16, 32, 48, 64, 80, 96, 112, 128] +PP_BLOCK_BUFFER_SIZE = 128 * 128 +SPECIAL_NUM_TOKENS = 16 +SPECIAL_NUM_HEADS = 32 + + +def _round_up(v: int, align: int) -> int: + return ((v + align - 1) // align) * align + + +def _ceil_div(a: int, b: int) -> int: + return (a + b - 1) // b + + +def _f32_bits(f: float) -> int: + return struct.unpack("I", struct.pack("f", float(f)))[0] + + +def _hi32(v: int) -> int: + return (v >> 32) & 0xFFFFFFFF + + +def _lo32(v: int) -> int: + return v & 0xFFFFFFFF + + +def _u32_to_i32(v: int) -> int: + v &= 0xFFFFFFFF + return v - 0x100000000 if v & 0x80000000 else v + + +def _calcu_head_nd(num_heads: int, kv_heads: int, former_head_split: int, tail_head_split: int): + """CalcuHeadNd: compute group move factors.""" + kv_real = kv_heads if kv_heads > 0 else num_heads + group_num = num_heads // kv_real + + former_group_move = 1 + if former_head_split % group_num == 0: + former_group_move = group_num + elif former_head_split < group_num and (kv_real == 1 or group_num % former_head_split == 0): + former_group_move = former_head_split + + tail_group_move = 1 + if tail_head_split > 0: + if tail_head_split % group_num == 0: + tail_group_move = group_num + elif tail_head_split < group_num and (kv_real == 1 or group_num % tail_head_split == 0): + tail_group_move = tail_head_split + + return group_num, former_group_move, tail_group_move + + +def _split_core_bn_nd( + num_heads: int, + kv_heads: int, + decoder_batch: int, + block_dim: int, + max_kv_seq_len: int, + block_size: int, + is_mla: bool, + is_quant: bool, +): + """SplitCoreBNND: split by (Batch, Head) dimensions.""" + kv_real = kv_heads if kv_heads > 0 else num_heads + core_per_batch = _ceil_div(block_dim, decoder_batch) + + if block_dim * SPLITKV_RATIO <= decoder_batch <= block_dim and is_quant and kv_real == 1: + core_per_batch = 1 + + head_split = _ceil_div(num_heads, core_per_batch) + head_split = min(head_split, HEADNUM_LIMIT_REGU) + + if decoder_batch == SPECIAL_NUM_TOKENS and num_heads == SPECIAL_NUM_HEADS: + head_split = 8 + + loop_len = _ceil_div(num_heads, head_split) + block = loop_len * decoder_batch + + former_batch = decoder_batch + tail_batch = 0 + former_head_split = head_split + tail_head_split = 0 + + if block > block_dim: + process_loop = block // block_dim + former_batch = process_loop * block_dim // loop_len + tail_batch = decoder_batch - former_batch + process_remain = tail_batch * loop_len + adj_last_head = (process_remain < SPECIAL_NUM_TOKENS) and (tail_batch > 0) + if (num_heads != kv_real) and not (kv_real == 1): + adj_last_head = adj_last_head and (tail_batch <= block_dim // 2) + if adj_last_head: + if is_mla and is_quant: + core_per_batch2 = block_dim // tail_batch + else: + core_per_batch2 = _ceil_div(block_dim, tail_batch) + tail_head_split = _ceil_div(num_heads, core_per_batch2) + tail_head_split = min(tail_head_split, HEADNUM_LIMIT_REGU) + else: + former_batch = decoder_batch + tail_batch = 0 + + eff_block_dim = min(block_dim, block) + kv_split_per_core = _round_up(max_kv_seq_len, block_size) + kv_split_core_num = 1 + + group_num, former_gm, tail_gm = _calcu_head_nd(num_heads, kv_real, former_head_split, tail_head_split) + return ( + eff_block_dim, + former_batch, + former_head_split, + tail_batch, + tail_head_split, + kv_split_per_core, + kv_split_core_num, + group_num, + former_gm, + tail_gm, + ) + + +def _split_core_bns_nd( + num_heads: int, + kv_heads: int, + decoder_batch: int, + block_dim: int, + max_kv_seq_len: int, + block_size: int, + is_long_seq: bool, +): + """SplitCoreBNSND: split by (Batch, Head, KVseq) dimensions.""" + kv_real = kv_heads if kv_heads > 0 else num_heads + kv_seq_aligned = _round_up(max_kv_seq_len, block_size) + kv_seq_block_num = kv_seq_aligned // block_size + + if is_long_seq: + kv_block_per_core = _ceil_div(kv_seq_block_num, block_dim) + else: + core_per_batch = _ceil_div(block_dim, decoder_batch) + kv_block_per_core = _ceil_div(kv_seq_block_num, core_per_batch) + + kv_split_per_core = kv_block_per_core * block_size + kv_split_core_num = _ceil_div(kv_seq_aligned, kv_split_per_core) + + core_per_kv = 1 + if decoder_batch * kv_split_core_num < block_dim: + core_per_kv = _ceil_div(block_dim, decoder_batch * kv_split_core_num) + + head_split = _ceil_div(num_heads, core_per_kv) + head_split = min(head_split, HEADNUM_LIMIT_REGU) + + head_core_num = _ceil_div(num_heads, head_split) + block = head_core_num * decoder_batch * kv_split_core_num + eff_block_dim = min(block_dim, block) + + former_batch = decoder_batch + tail_batch = 0 + former_head_split = head_split + tail_head_split = 0 + + group_num, former_gm, tail_gm = _calcu_head_nd(num_heads, kv_real, former_head_split, tail_head_split) + return ( + eff_block_dim, + former_batch, + former_head_split, + tail_batch, + tail_head_split, + kv_split_per_core, + kv_split_core_num, + group_num, + former_gm, + tail_gm, + ) + + +def make_pa_nd_decode_tiling( # noqa: PLR0913, PLR0915 + batch: int, + kv_seq_lens: list[int], + num_heads: int, + kv_heads: int, + head_dim: int, + head_dim_v: int, + num_blocks: int, + block_size: int, + max_blocks_per_query: int, + scale: float, + block_dim: int, + device: str = "npu", + dtype: torch.dtype = torch.float16, +) -> tuple[torch.Tensor, int]: + """ + Build PAGED_ATTENTION_MASK_ND tiling for decode-only GQA. + + Args: + batch: number of sequences + kv_seq_lens: KV context length per sequence + num_heads: number of Q attention heads + kv_heads: number of KV heads (GQA), 0 means = num_heads + head_dim: head dimension for QK + head_dim_v: head dimension for V (== head_dim for standard GQA) + num_blocks: total number of KV cache blocks + block_size: tokens per KV cache block + max_blocks_per_query: max blocks in block_table row + scale: softmax scale (1/sqrt(head_dim) typically) + block_dim: number of cube cores (from device properties) + device: torch device string + dtype: fp16 or bf16 (selects tiling key 0 or 1) + + Returns: + (tiling_tensor, effective_block_dim) + """ + kv_real = kv_heads if kv_heads > 0 else num_heads + max_kv = max(kv_seq_lens) + is_mla = head_dim > MLA_THRESHOLD or head_dim_v > MLA_THRESHOLD or head_dim != head_dim_v + is_quant = False # fp16/bf16 only + + indices: list[int] = sorted(range(batch), key=lambda i: kv_seq_lens[i]) + + decoder_batch = batch + is_long_seq = max_kv >= KV_SEQLEN_SLICE_512 * 8 + + use_bn = is_mla or (decoder_batch * num_heads >= block_dim * SPLITKV_RATIO and not is_long_seq) + + if use_bn: + (eff_bd, fB, fH, tB, tH, kvSplit, kvCN, gN, fGM, tGM) = _split_core_bn_nd( + num_heads, + kv_real, + decoder_batch, + block_dim, + max_kv, + block_size, + is_mla, + is_quant, + ) + else: + (eff_bd, fB, fH, tB, tH, kvSplit, kvCN, gN, fGM, tGM) = _split_core_bns_nd( + num_heads, + kv_real, + decoder_batch, + block_dim, + max_kv, + block_size, + is_long_seq, + ) + + if ( + head_dim % 16 == 0 + and head_dim <= EMBEDDING_LIMIT + and head_dim_v % 16 == 0 + and head_dim_v <= EMBEDDING_LIMIT + and kv_real == num_heads + and not is_quant + ): + head_num_move = 2 + else: + head_num_move = 1 + + head_dim_k_split = min(head_dim, MLA_THRESHOLD) + head_dim_v_split = min(head_dim_v, MLA_THRESHOLD) + head_dim_v_split_former = min(head_dim_v, MLA_THRESHOLD) if fGM <= 64 else min(head_dim_v, EMBEDDING_LIMIT) + head_dim_v_split_tail = min(head_dim_v, MLA_THRESHOLD) if tGM <= 64 else min(head_dim_v, EMBEDDING_LIMIT) + + if ( + block_size <= KV_SEQLEN_SLICE // 2 + and block_size * 2 * head_dim_k_split <= BLOCK_LIMIT + and block_size * 2 * head_dim_v_split <= BLOCK_LIMIT + ): + block_size_calc = block_size * 2 + elif block_size >= KV_SEQLEN_SLICE and head_dim == KV_SEQLEN_SLICE_256 and head_dim_v == KV_SEQLEN_SLICE_256: + block_size_calc = KV_SEQLEN_SLICE + else: + block_size_calc = block_size + + is_split_key = int(kvCN > 1) + is_split_block = int( + block_size >= KV_SEQLEN_SLICE and head_dim == KV_SEQLEN_SLICE_256 and head_dim_v == KV_SEQLEN_SLICE_256 + ) + type_key = 0 if dtype == torch.float16 else 1 + tiling_key = (is_split_block << 7) + (is_split_key << 4) + type_key + + total_words = TILING_HEAD_SIZE + batch * TILING_PARA_SIZE + tiling = [0] * total_words + + tiling[TILING_BATCH] = batch + tiling[TILING_NUMHEADS] = num_heads + tiling[TILING_HEADDIM] = head_dim + tiling[TILING_NUMBLOKS] = num_blocks + tiling[TILING_BLOCKSIZE] = block_size + tiling[TILING_MAXBLOCKS] = max_blocks_per_query + tiling[TILING_TOR] = _f32_bits(scale) + tiling[TILING_KVHEADS] = kv_real + tiling[TILING_FORMER_BATCH] = fB + tiling[TILING_FORMER_HEAD] = fH + tiling[TILING_TAIL_BATCH] = tB + tiling[TILING_TAIL_HEAD] = tH + tiling[TILING_HEADNUM_MOVE] = head_num_move + tiling[TILING_MASK_MAX_LEN] = 0 + tiling[TILING_BATCH_STRIDE] = 0 + tiling[TILING_HEAD_STRIDE] = 0 + tiling[TILING_KEY] = tiling_key + tiling[TILING_HEADSIZE] = TILING_HEAD_SIZE + tiling[TILING_PARASIZE] = TILING_PARA_SIZE + tiling[TILING_GROUPNUM] = gN + tiling[TILING_FORMER_GROUP_MOVE] = fGM + tiling[TILING_TAIL_GROUP_MOVE] = tGM + tiling[TILING_MAX_KVSEQLEN] = max_kv + tiling[TILING_KVSPLIT] = kvSplit + tiling[TILING_KVCORENUM] = kvCN + tiling[TILING_BLOCKSIZE_CALC] = block_size_calc + tiling[TILING_TOTAL_BLOCK_NUM] = 0 + tiling[TILING_PREFILL_BS] = 0 + tiling[TILING_DECODER_BS] = batch + tiling[TILING_HEADDIM_V] = head_dim_v + tiling[TILING_MODCOEF] = 0xFFFFFFFF + tiling[TILING_DIVCOEF] = 1 + tiling[TILING_QHEADORIGINAL] = num_heads + tiling[TILING_COMPRESSHEAD] = 0 + tiling[TILING_QUANTYPE] = 0 + tiling[TILING_DATA_SHAPE_TYPE] = 0 + tiling[TILING_SCALETYPE] = 0 + tiling[TILING_MASK_TYPE_ND] = 0 + tiling[TILING_HEADDIM_K_SPLIT] = head_dim_k_split + tiling[TILING_HEADDIM_V_SPLIT] = head_dim_v_split + tiling[TILING_HEADDIM_V_SPLIT_VECTOR_FORMER] = head_dim_v_split_former + tiling[TILING_HEADDIM_V_SPLIT_VECTOR_TAIL] = head_dim_v_split_tail + + addr_q = 0 + addr_o = 0 + total_q_blk = 0 + + for seq_idx in range(batch): + kv_seqlen = kv_seq_lens[seq_idx] + q_seqlen = 1 + + q_aligned = _round_up(q_seqlen, BLOCK_SIZE_ALIGN) + m_raw = (PP_BLOCK_BUFFER_SIZE // max(head_dim, block_size) // BLOCK_SIZE_ALIGN) * BLOCK_SIZE_ALIGN + m_ubd = min(m_raw, q_aligned) + m_ubd = max(m_ubd, BLOCK_SIZE_ALIGN) + m_idx = min(7, max(0, m_ubd // 16 - 1)) + m_ubd = PP_MM[m_idx] + + base = TILING_HEAD_SIZE + seq_idx * TILING_PARA_SIZE + tiling[base + 0] = q_seqlen + tiling[base + 1] = kv_seqlen + tiling[base + 2] = m_ubd + tiling[base + 3] = block_size + tiling[base + 4] = _hi32(addr_q) + tiling[base + 5] = _lo32(addr_q) + tiling[base + 6] = _hi32(addr_o) + tiling[base + 7] = _lo32(addr_o) + tiling[base + 8] = seq_idx + tiling[base + 9] = total_q_blk + tiling[base + 10] = 0 + tiling[base + 13] = indices[seq_idx] + tiling[base + 14] = 0 + + addr_q += num_heads * head_dim * q_seqlen + addr_o += num_heads * head_dim_v * q_seqlen + + addr_l = 0 + addr_ofd = 0 + + for seq_idx in range(batch): + kv_seqlen = kv_seq_lens[seq_idx] + if kv_seqlen == 0: + continue + q_seqlen = 1 + base = TILING_HEAD_SIZE + seq_idx * TILING_PARA_SIZE + tiling[base + 11] = _hi32(addr_l) + tiling[base + 12] = _lo32(addr_l) + tiling[base + 15] = _hi32(addr_ofd) + tiling[base + 16] = _lo32(addr_ofd) + addr_l += kvCN * num_heads * q_seqlen + addr_ofd += num_heads * head_dim * q_seqlen + + tiling_i32 = [_u32_to_i32(word) for word in tiling] + tiling_tensor = torch.tensor(tiling_i32, dtype=torch.int32, device=device) + + return tiling_tensor, eff_bd + + +def workspace_sizes( + batch: int, + num_heads: int, + head_dim: int, + head_dim_v: int, + block_dim: int, +) -> dict[str, int]: + """Return byte sizes for each workspace tensor (from PagedAttentionTiling scratch sizes).""" + basic_half = block_dim * WORKSPACE_BLOCK_SIZE_DB * 2 + basic_float = block_dim * WORKSPACE_BLOCK_SIZE_DB * 4 + o_core = batch * num_heads * block_dim * head_dim * 4 + l_size = batch * num_heads * block_dim * 4 + k16 = 2 * block_dim * 256 * num_heads * head_dim * 2 + v16 = 2 * block_dim * 256 * num_heads * head_dim_v * 2 + return { + "s": basic_float, + "p": basic_half, + "o_tmp": basic_float * 2, + "go": basic_float, + "o_core_tmp": max(16, o_core), + "l": max(16, l_size), + "k16": max(16, k16), + "v16": max(16, v16), + } diff --git a/examples/jit_cpp/paged_attention_highperf/pa_tiling_struct.hpp b/examples/jit_cpp/paged_attention_highperf/pa_tiling_struct.hpp new file mode 100644 index 00000000..4938fafb --- /dev/null +++ b/examples/jit_cpp/paged_attention_highperf/pa_tiling_struct.hpp @@ -0,0 +1,97 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ +#ifndef PAGED_HATTENTION_H +#define PAGED_HATTENTION_H + +#include + +namespace AtbOps { +constexpr int32_t BLOCK_SIZE = 16; +constexpr int32_t BLOCK_SIZE_32 = 32; +constexpr int32_t TILING_PARA_SIZE = 17; +constexpr int32_t TILING_HEAD_SIZE = 44; +constexpr int32_t TILING_HEAD_SIZE_NZ = 128; +constexpr int32_t TILING_HEAD_SIZE_910A = 192; +constexpr int32_t TILING_PARA_SIZE_NZ = 8; +constexpr int32_t M_LIMIT = 128; +constexpr int32_t FLOAT_LIMIT = 64; +constexpr int32_t MAX_EMBEDDING = 576; +constexpr int32_t ND_BATCH_LIMIT = INT32_MAX; +constexpr int32_t BLOCK_LIMIT = 128 * 128; +constexpr int32_t BLOCK_LIMIT_NO_PINGPONG = 128 * 256; +constexpr int32_t BLOCK_LIMIT_NO_PINGPONG_UINT8 = 128 * 256 * 2; +constexpr int32_t NZ_BLOCK_SIZE = 16; +constexpr int32_t TILING_KEY_ID = 16; +constexpr int32_t MLA_BLOCK_SIZE_LIMIT = 128; +constexpr int32_t MLA_THRESHOLD = 256; +constexpr int32_t PREFILL_BATCH = 27; +constexpr int32_t PARALLEL_MAX_HEAD = 256; +constexpr int32_t PARALLEL_MAX_BLK_SIZE = 128; +constexpr int32_t PARALLEL_MAX_BATCH = 2000; +constexpr int32_t WORKSPACE_BLOCK_SIZE_DB = 65536; // 128 * 256 * 2 + +enum class TilingKeyType { + TILING_HALF_DATA = 0, + TILING_BF16_DATA = 1, + TILING_INT8_DATA = 2, + TILING_INT8_CUBE_QUANT = 4, + TILING_INT8_VEC_QUANT = 8, + TILING_INT8_VEC_QUANTBF16 = 9, + TILING_QUANT_FP16OUT = 12, + TILING_QUANT_BF16OUT = 14 +}; + +enum class CalcType { CALC_TYPE_DEFAULT = 0, CALC_TYPE_MIX = 1, CALC_TYPE_PREFILL = 2 }; + +enum class DataShapeType { BSND = 0, BNSD = 1 }; + +enum class CompressType { COMPRESS_TYPE_UNDEFINED = 0, COMPRESS_TYPE_KVHEAD = 1 }; + +enum class PagedAttnVariant { DEFAULT = 0, MULTI_LATENT = 1 }; + +using PagedAttentionInfo = struct PagedAttentionTilingParams { + int32_t numTokens = 0; + int32_t numHeads = 0; + int32_t embeddingSize = 0; + int32_t embeddingSizeV = 0; + int32_t numBlocks = 0; + int32_t blockSize = 0; + int32_t maxNumBlocksPerQuery = 0; + float tor = 0; + int32_t kvHeads = 0; + int32_t maxPromptLen = 0; + int32_t batchStride = 0; + int32_t headStride = 0; + TilingKeyType type = TilingKeyType::TILING_HALF_DATA; + int32_t batch = 0; + int32_t isMaskSquare = 0; + int32_t *batchRunStatus{nullptr}; + int32_t *kvSeqLen{nullptr}; + int32_t modCoef{-1}; + int32_t divCoef{1}; + int32_t *qSeqLen{nullptr}; + int32_t qHeadOriginal = 0; + int32_t compressHead = 0; + int32_t tBlockAlign = 16; // L1 tile alignment: 16 for fp16, 32 for int8 + int32_t dataShapeType = 0; +}; + +using AddrOffsets = struct AddressOffsetInfo { + uint64_t addrQSeqOffset = 0; + uint64_t addrOSeqOffset = 0; + uint64_t addrOFdSeqOffset = 0; + uint64_t addrLSeqOffset = 0; +}; + +} // namespace AtbOps + +#endif +// PAGED_HATTENTION_H From 19186c22d5144717d137a3667fd7fc45ed0861cb Mon Sep 17 00:00:00 2001 From: MirkoDeVita98 Date: Fri, 26 Jun 2026 15:24:25 +0000 Subject: [PATCH 02/11] Optimize paged attention split-KV sync path --- .../paged_attention_highperf/pa_entry.hpp | 33 +- .../pa_kernel_impl.hpp | 286 ++---------------- 2 files changed, 33 insertions(+), 286 deletions(-) diff --git a/examples/jit_cpp/paged_attention_highperf/pa_entry.hpp b/examples/jit_cpp/paged_attention_highperf/pa_entry.hpp index 8d0c29cc..f6e4f40f 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_entry.hpp +++ b/examples/jit_cpp/paged_attention_highperf/pa_entry.hpp @@ -67,35 +67,24 @@ static AICORE __attribute__((always_inline)) void paged_attention_mask_body( #ifdef __DAV_C220_CUBE__ const int64_t workerIdx = static_cast(ptoBlockIdx); const int64_t workerNum = static_cast(ptoBlockNum); - const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); - if (SupportsPtoPagedAttentionHighPerf(tilingParaGm)) { - RunPtoPagedAttentionCubePipeline(qGm, kGm, vGm, blockTablesGm, sGm, pGm, oTmpGm, tilingParaGm, workerIdx, workerNum); - } else if (SupportsPtoPagedAttentionRawSplitKV(tilingParaGm)) { + if (SupportsPtoPagedAttentionRawSplitKV(tilingParaGm)) { RunPtoPagedAttentionCubePipelineSplitKV(qGm, kGm, vGm, blockTablesGm, sGm, pGm, oTmpGm, tilingParaGm, workerIdx, workerNum); - } else if (ctx.kvSplitCoreNum > 1) { - pipe_barrier(PIPE_ALL); } else { pipe_barrier(PIPE_ALL); } #elif defined(__DAV_C220_VEC__) - if (SupportsPtoPagedAttentionHighPerf(tilingParaGm)) { - const int64_t workerIdx = static_cast(ptoBlockIdx); - const int64_t workerNum = static_cast(ptoBlockNum); - RunPtoPagedAttentionVecPipeline(oGm, sGm, pGm, oTmpGm, tilingParaGm, workerIdx, workerNum, ptoSubBlockId); + const int64_t workerIdx = static_cast(ptoBlockIdx) * 2 + static_cast(ptoSubBlockId); + const int64_t workerNum = static_cast(ptoBlockNum) * 2; + const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); + if (SupportsPtoPagedAttentionRawSplitKV(tilingParaGm)) { + RunPtoPagedAttentionVecPipelineSplitKV(oGm, sGm, pGm, oTmpGm, oCoreTmpGm, lGm, tilingParaGm, + static_cast(ptoBlockIdx), static_cast(ptoBlockNum), ptoSubBlockId); + } else if (ctx.kvSplitCoreNum > 1) { + RunPtoPagedAttentionDecodeSplitKV(qGm, kGm, vGm, blockTablesGm, oGm, oCoreTmpGm, lGm, tilingParaGm, + static_cast(ptoBlockIdx), static_cast(ptoBlockNum), ptoSubBlockId); } else { - const int64_t workerIdx = static_cast(ptoBlockIdx) * 2 + static_cast(ptoSubBlockId); - const int64_t workerNum = static_cast(ptoBlockNum) * 2; - const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); - if (SupportsPtoPagedAttentionRawSplitKV(tilingParaGm)) { - RunPtoPagedAttentionVecPipelineSplitKV(oGm, sGm, pGm, oTmpGm, oCoreTmpGm, lGm, tilingParaGm, - static_cast(ptoBlockIdx), static_cast(ptoBlockNum), ptoSubBlockId); - } else if (ctx.kvSplitCoreNum > 1) { - RunPtoPagedAttentionDecodeSplitKV(qGm, kGm, vGm, blockTablesGm, oGm, oCoreTmpGm, lGm, tilingParaGm, - static_cast(ptoBlockIdx), static_cast(ptoBlockNum), ptoSubBlockId); - } else { - RunPtoPagedAttentionDecode(qGm, kGm, vGm, blockTablesGm, oGm, tilingParaGm, workerIdx, workerNum); - } + RunPtoPagedAttentionDecode(qGm, kGm, vGm, blockTablesGm, oGm, tilingParaGm, workerIdx, workerNum); } #else pipe_barrier(PIPE_ALL); diff --git a/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp b/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp index 9c401817..c18f0df4 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp +++ b/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp @@ -225,11 +225,7 @@ AICORE inline float ComputeScoreByBlock( constexpr int32_t PA_TILE_TOKENS = 128; -constexpr uint8_t PA_QK_FIFO_FLAG = 0; -constexpr uint8_t PA_P_FIFO_FLAG = 2; -constexpr uint8_t PA_PV_FIFO_FLAG = 4; -constexpr uint32_t PA_FIFO_DEPTH = 2; -constexpr uint8_t PTO_PA_REDUCE_READY_DECODER = static_cast(SYNC_AIV_ONLY_ALL); +constexpr uint8_t PTO_PA_REDUCE_READY_DECODER = 14; constexpr uint8_t PTO_PA_RAW_QK_READY = 0; constexpr uint8_t PTO_PA_RAW_QK_FREE = 2; constexpr uint8_t PTO_PA_RAW_P_READY = 4; @@ -294,21 +290,6 @@ AICORE inline void PtoPaInitCoreState() set_mask_norm(); } -AICORE inline bool SupportsPtoPagedAttentionHighPerf(__gm__ uint8_t *tilingParaGm) -{ - const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); - if (ctx.headDim != PA_TILE_TOKENS || ctx.headDimV != PA_TILE_TOKENS || ctx.blockSize != PA_TILE_TOKENS) { - return false; - } - if (ctx.batch <= 0 || ctx.numHeads <= 0 || ctx.kvHeads <= 0 || ctx.numHeads % ctx.kvHeads != 0) { - return false; - } - if (ctx.kvSplitCoreNum > 1) { - return false; - } - return true; -} - AICORE inline bool SupportsPtoPagedAttentionRawSplitKV(__gm__ uint8_t *tilingParaGm) { const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); @@ -363,10 +344,6 @@ AICORE inline void PtoPaSetFloatVectorMask(uint32_t len) set_vector_mask(0, mask); } -AICORE inline void PtoPaStageSync() -{ - SYNCALL(); -} #ifdef __DAV_C220_VEC__ template @@ -445,239 +422,6 @@ AICORE inline void PtoPaLoadCbufToCbTranspose128Raw(DstTileData &dst, SrcTileDat } #endif -#ifdef __DAV_C220_CUBE__ -AICORE inline void RunPtoPagedAttentionCubePipeline(__gm__ uint8_t *qGm, __gm__ uint8_t *kGm, - __gm__ uint8_t *vGm, __gm__ uint8_t *blockTablesGm, __gm__ uint8_t *sGm, __gm__ uint8_t *pGm, - __gm__ uint8_t *oTmpGm, __gm__ uint8_t *tilingParaGm, int64_t workerIdx, int64_t workerNum) -{ - constexpr int32_t kHeadDim = PA_TILE_TOKENS; - constexpr int32_t kTileTokens = PA_TILE_TOKENS; - constexpr int32_t kM = 16; - constexpr int32_t kN = kTileTokens; - constexpr int32_t kK = 256; - - PtoPaInitCoreState(); - const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); - if (!SupportsPtoPagedAttentionHighPerf(tilingParaGm) || workerIdx < 0 || workerNum <= 0) { - pipe_barrier(PIPE_ALL); - return; - } - - const int32_t effectiveBatch = ctx.decoderBatch > 0 ? ctx.decoderBatch : ctx.batch; - const int32_t maxBlocksPerQuery = ctx.maxBlocksPerQuery > 0 ? ctx.maxBlocksPerQuery : - (ctx.maxKvSeqLen + ctx.blockSize - 1) / ctx.blockSize; - const int32_t headsPerKv = ctx.numHeads / ctx.kvHeads; - const int64_t totalRows = static_cast(effectiveBatch) * ctx.numHeads; - - using QKPipe = TPipe; - using PPipe = TPipe; - using PVPipe = TPipe; - using QGlobal = GlobalTensor, Stride>; - using KGlobal = GlobalTensor, - Stride<1, 1, 1, 1, 8 * kHeadDim>, Layout::DN>; - using VGlobal = GlobalTensor, - Stride>; - using PGlobal = GlobalTensor, - Stride>; - using ScoreGlobal = GlobalTensor, - Stride>; - using OTmpGlobal = GlobalTensor, - Stride>; - - using QMatTile = Tile; - using KMatTile = Tile; - using PMatTile = Tile; - using VMatTile = Tile; - using LeftQTile = TileLeft; - using LeftPTile = TileLeft; - using RightTile = TileRight; - using AccTile = TileAcc; - - QMatTile qMatTile; - KMatTile kMatTile; - PMatTile pMatTile; - VMatTile vMatTile; - LeftQTile qLeftTile; - LeftPTile pLeftTile; - RightTile rightTile; - AccTile accTile; - TASSIGN(qMatTile, 0x00000); - TASSIGN(kMatTile, 0x20000); - TASSIGN(pMatTile, 0x00000); - TASSIGN(vMatTile, 0x20000); - TASSIGN(qLeftTile, 0x00000); - TASSIGN(pLeftTile, 0x00000); - TASSIGN(rightTile, 0x00000); - TASSIGN(accTile, 0x00000); - - __gm__ uint8_t *scoreBase = sGm + workerIdx * QKPipe::RingFiFo::SLOT_SIZE * QKPipe::RingFiFo::SLOT_NUM; - __gm__ uint8_t *probBase = pGm + workerIdx * PPipe::RingFiFo::SLOT_SIZE * PPipe::RingFiFo::SLOT_NUM; - __gm__ uint8_t *outBase = oTmpGm + workerIdx * PVPipe::RingFiFo::SLOT_SIZE * PVPipe::RingFiFo::SLOT_NUM; - QKPipe qkPipe(reinterpret_cast<__gm__ void *>(scoreBase), 0, 0); - PPipe pPipe(reinterpret_cast<__gm__ void *>(probBase), 0, 0); - PVPipe pvPipe(reinterpret_cast<__gm__ void *>(outBase), 0, 0); - - for (int64_t row = workerIdx; row < totalRows; row += workerNum) { - const int32_t head = static_cast(row % ctx.numHeads); - const int32_t batchSlot = static_cast(row / ctx.numHeads); - const int32_t paraBase = ResolveSortedParaBase(tilingParaGm, ctx, batchSlot); - const int32_t kvSeqLen = LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); - const int32_t batchIndex = LoadTilingI32(tilingParaGm, paraBase + kParaRealBatchIndex); - const int32_t kvHead = head / headsPerKv; - const int32_t tileCount = (kvSeqLen + kTileTokens - 1) / kTileTokens; - const int64_t qBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; - QGlobal qGlobal(reinterpret_cast<__gm__ half *>(qGm) + qBase); - TLOAD(qMatTile, qGlobal); - pipe_barrier(PIPE_ALL); - TEXTRACT(qLeftTile, qMatTile, 0, 0); - pipe_barrier(PIPE_ALL); - - for (int32_t tile = 0; tile < tileCount; ++tile) { - const int32_t blockId = LoadBlockTable(blockTablesGm, static_cast(batchIndex) * maxBlocksPerQuery + tile); - const int64_t kvBase = (static_cast(blockId) * ctx.blockSize * ctx.kvHeads + kvHead) * ctx.headDim; - - - KGlobal kGlobal(reinterpret_cast<__gm__ half *>(kGm) + kvBase); - TLOAD(kMatTile, kGlobal); - pipe_barrier(PIPE_ALL); - TMOV(rightTile, kMatTile); - pipe_barrier(PIPE_ALL); - TGEMV(accTile, qLeftTile, rightTile); - pipe_barrier(PIPE_ALL); - DdrBarrierBeforePtoFfts(); - TPUSH(qkPipe, accTile); - - TPOP(pPipe, pMatTile); - VGlobal vGlobal(reinterpret_cast<__gm__ half *>(vGm) + kvBase); - - TLOAD(vMatTile, vGlobal); - pipe_barrier(PIPE_ALL); - TEXTRACT(pLeftTile, pMatTile, 0, 0); - TMOV(rightTile, vMatTile); - pipe_barrier(PIPE_ALL); - TGEMV(accTile, pLeftTile, rightTile); - pipe_barrier(PIPE_ALL); - DdrBarrierBeforePtoFfts(); - TPUSH(pvPipe, accTile); - } - } - pipe_barrier(PIPE_ALL); -} -#endif - -#ifdef __DAV_C220_VEC__ -AICORE inline void RunPtoPagedAttentionVecPipeline(__gm__ uint8_t *oGm, __gm__ uint8_t *sGm, - __gm__ uint8_t *pGm, __gm__ uint8_t *oTmpGm, __gm__ uint8_t *tilingParaGm, int64_t workerIdx, int64_t workerNum, - uint32_t subBlockId) -{ - constexpr int32_t kHeadDim = PA_TILE_TOKENS; - constexpr int32_t kTileTokens = PA_TILE_TOKENS; - if (!SupportsPtoPagedAttentionHighPerf(tilingParaGm) || workerIdx < 0 || workerNum <= 0) { - pipe_barrier(PIPE_ALL); - return; - } - - PtoPaInitCoreState(); - const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); - const int32_t effectiveBatch = ctx.decoderBatch > 0 ? ctx.decoderBatch : ctx.batch; - const int64_t totalRows = static_cast(effectiveBatch) * ctx.numHeads; - - using QKPipe = TPipe; - using PPipe = TPipe; - using PVPipe = TPipe; - using VecFloat128 = Tile; - using VecHalf128 = Tile; - using VecHalf256 = Tile; - using VecFloat8 = Tile; - using GlobalFloat128 = GlobalTensor, Stride>; - using GlobalHalf128 = GlobalTensor, Stride>; - - VecFloat128 weightedTile; - VecFloat128 scoreTile; - VecFloat128 pvTile; - VecHalf256 probTile; - VecHalf128 outHalfTile; - VecFloat8 scalarMathTile; - TASSIGN(weightedTile, 0x0000); - TASSIGN(scoreTile, 0x0800); - TASSIGN(pvTile, 0x1000); - TASSIGN(probTile, 0x1800); - TASSIGN(outHalfTile, 0x2000); - TASSIGN(scalarMathTile, 0x2800); - - __gm__ uint8_t *scoreBase = sGm + workerIdx * QKPipe::RingFiFo::SLOT_SIZE * QKPipe::RingFiFo::SLOT_NUM; - __gm__ uint8_t *probBase = pGm + workerIdx * PPipe::RingFiFo::SLOT_SIZE * PPipe::RingFiFo::SLOT_NUM; - __gm__ uint8_t *outTmpBase = oTmpGm + workerIdx * PVPipe::RingFiFo::SLOT_SIZE * PVPipe::RingFiFo::SLOT_NUM; - QKPipe qkPipe(reinterpret_cast<__gm__ void *>(scoreBase), 0, 0); - PPipe pPipe(reinterpret_cast<__gm__ void *>(probBase), 0, 0); - PVPipe pvPipe(reinterpret_cast<__gm__ void *>(outTmpBase), 0, 0); - - for (int64_t row = workerIdx; row < totalRows; row += workerNum) { - const int32_t head = static_cast(row % ctx.numHeads); - const int32_t batchSlot = static_cast(row / ctx.numHeads); - const int32_t paraBase = ResolveSortedParaBase(tilingParaGm, ctx, batchSlot); - const int32_t kvSeqLen = LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); - const int32_t batchIndex = LoadTilingI32(tilingParaGm, paraBase + kParaRealBatchIndex); - const int32_t tileCount = (kvSeqLen + kTileTokens - 1) / kTileTokens; - const bool doWork = subBlockId == 0; - float maxScore = -3.4028234663852886e38f; - float sumExp = 0.0f; - TEXPANDS(weightedTile, 0.0f); - pipe_barrier(PIPE_ALL); - - for (int32_t tile = 0; tile < tileCount; ++tile) { - const int32_t validTokens = ((tile + 1) * kTileTokens <= kvSeqLen) ? kTileTokens : (kvSeqLen - tile * kTileTokens); - TPOP(qkPipe, scoreTile); - float tileMax = -3.4028234663852886e38f; - for (int32_t pos = 0; pos < validTokens; ++pos) { - const float score = scoreTile.data()[pos] * ctx.scale; - tileMax = score > tileMax ? score : tileMax; - } - const float newMax = tileMax > maxScore ? tileMax : maxScore; - const float oldScale = (tile == 0) ? 0.0f : PtoExpScalar(scalarMathTile, maxScore - newMax); - float tileSum = 0.0f; - TEXPANDS(probTile, static_cast(0.0)); - for (int32_t pos = 0; pos < kTileTokens; ++pos) { - float prob = 0.0f; - if (pos < validTokens) { - prob = PtoExpScalar(scalarMathTile, scoreTile.data()[pos] * ctx.scale - newMax); - tileSum += prob; - } - probTile.data()[pos] = static_cast(prob); - } - sumExp = sumExp * oldScale + tileSum; - TMULS(weightedTile, weightedTile, oldScale); - pipe_barrier(PIPE_ALL); - maxScore = newMax; - DdrBarrierBeforePtoFfts(); - TPUSH(pPipe, probTile); - - TPOP(pvPipe, pvTile); - if (doWork) { - pipe_barrier(PIPE_ALL); - TAXPY(weightedTile, pvTile, 1.0f); - pipe_barrier(PIPE_ALL); - } - } - - if (!doWork) { - continue; - } - const float invSum = sumExp > 0.0f ? 1.0f / sumExp : 0.0f; - const int64_t outBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; - GlobalHalf128 outGlobal(reinterpret_cast<__gm__ half *>(oGm) + outBase); - TMULS(weightedTile, weightedTile, invSum); - pipe_barrier(PIPE_ALL); - TCVT(outHalfTile, weightedTile, RoundMode::CAST_RINT); - pipe_barrier(PIPE_ALL); - TSTORE(outGlobal, outHalfTile); - pipe_barrier(PIPE_ALL); - } - pipe_barrier(PIPE_ALL); -} -#endif - - AICORE inline uint64_t LoadTilingOffset64(__gm__ uint8_t *tiling, int32_t base, int32_t highIdx, int32_t lowIdx) { const uint32_t high = static_cast(LoadTilingI32(tiling, base + highIdx)); @@ -739,9 +483,12 @@ AICORE inline void RunPtoPagedAttentionCubePipelineSplitKV(__gm__ uint8_t *qGm, LeftPTile pLeftTile; RightTile rightTile; AccTile accTile; - TASSIGN(qMatTile, 0x00000); + constexpr uint32_t kQCacheBase = 0x00000; + constexpr uint32_t kQGroupBytes = kM * kHeadDim * sizeof(half); + constexpr uint32_t kPmatBase = 0x10000; + TASSIGN(qMatTile, kQCacheBase); TASSIGN(kMatTile, 0x20000); - TASSIGN(pMatTile, 0x00000); + TASSIGN(pMatTile, kPmatBase); TASSIGN(vMatTile, 0x20000); TASSIGN(qLeftTile, 0x00000); TASSIGN(pLeftTile, 0x00000); @@ -807,6 +554,21 @@ AICORE inline void RunPtoPagedAttentionCubePipelineSplitKV(__gm__ uint8_t *qGm, } } + if (validProcess) { + for (int32_t headGroup = 0; headGroup < maxHeadGroups; ++headGroup) { + const int32_t groupHeadBase = headGroup * kHeadGroup; + if (groupHeadBase >= curHeadNum) { + break; + } + const int32_t firstHead = startHead + groupHeadBase; + const int64_t qBase = (static_cast(batchIndex) * ctx.numHeads + firstHead) * ctx.headDim; + QGlobal qGlobal(reinterpret_cast<__gm__ half *>(qGm) + qBase); + TASSIGN(qMatTile, kQCacheBase + static_cast(headGroup) * kQGroupBytes); + TLOAD(qMatTile, qGlobal); + pipe_barrier(PIPE_ALL); + } + } + for (int32_t tilePairBase = 0; tilePairBase < stageTileCount; tilePairBase += 2) { const bool hasStage2 = (tilePairBase + 1) < stageTileCount; const bool activeTile0 = validProcess && tilePairBase < tileCount; @@ -826,11 +588,7 @@ AICORE inline void RunPtoPagedAttentionCubePipelineSplitKV(__gm__ uint8_t *qGm, if (!activeGroup) { continue; } - const int32_t firstHead = startHead + groupHeadBase; - const int64_t qBase = (static_cast(batchIndex) * ctx.numHeads + firstHead) * ctx.headDim; - QGlobal qGlobal(reinterpret_cast<__gm__ half *>(qGm) + qBase); - TLOAD(qMatTile, qGlobal); - pipe_barrier(PIPE_ALL); + TASSIGN(qMatTile, kQCacheBase + static_cast(headGroup) * kQGroupBytes); const int32_t blockId = LoadBlockTable(blockTablesGm, static_cast(batchIndex) * maxBlocksPerQuery + startTile + tile); for (int32_t headInGroupBase = 0; headInGroupBase < kHeadGroup; headInGroupBase += headsPerKv) { From 4b3c4d708934acdb6ea15b73ec27abf90e96ec9f Mon Sep 17 00:00:00 2001 From: MirkoDeVita98 Date: Mon, 29 Jun 2026 12:13:31 +0200 Subject: [PATCH 03/11] solved numerical accuracy issues --- .../paged_attention_highperf/jit_util_pa.py | 2 + .../paged_attention_highperf/pa_benchmark.py | 11 ++++- .../pa_compile_and_run.py | 46 ++++++++++++------- .../pa_kernel_impl.hpp | 3 ++ 4 files changed, 44 insertions(+), 18 deletions(-) diff --git a/examples/jit_cpp/paged_attention_highperf/jit_util_pa.py b/examples/jit_cpp/paged_attention_highperf/jit_util_pa.py index 759f0041..df60a9c6 100644 --- a/examples/jit_cpp/paged_attention_highperf/jit_util_pa.py +++ b/examples/jit_cpp/paged_attention_highperf/jit_util_pa.py @@ -94,6 +94,8 @@ def _alloc(device, workspace_sizes, tiling): workspace["key"] = key for name, size in workspace_sizes.items(): workspace[name] = torch.empty((int(size),), device=device, dtype=torch.uint8) + if name in {"s", "p", "o_tmp", "o_core_tmp", "l"}: + workspace[name].zero_() workspace["null"] = torch.zeros((1,), device=device, dtype=torch.uint8) workspace["tiling"] = tiling.to(device=device, dtype=torch.int32) diff --git a/examples/jit_cpp/paged_attention_highperf/pa_benchmark.py b/examples/jit_cpp/paged_attention_highperf/pa_benchmark.py index 4b638406..70eeb914 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_benchmark.py +++ b/examples/jit_cpp/paged_attention_highperf/pa_benchmark.py @@ -3,6 +3,7 @@ import argparse import csv import gc +import time import torch @@ -11,6 +12,7 @@ NUM_ITERATIONS = 50 WARMUP = 10 +RUN_DELAY_SECONDS = 2.0 BATCHES = [1, 2, 4, 8, 32, 64] SEQ_LENS = [128, 512, 4096, 8192, 16384, 32768, 65536, 131072] DEFAULT_SHAPES = [PaShape(batch=batch, seq_len=seq_len) for batch in BATCHES for seq_len in SEQ_LENS] @@ -102,6 +104,8 @@ def main(): parser.add_argument("--csv", default="pa_highperf_jit_bench.csv") parser.add_argument("--iters", type=int, default=NUM_ITERATIONS) parser.add_argument("--warmup", type=int, default=WARMUP) + parser.add_argument("--run-delay", type=float, default=RUN_DELAY_SECONDS, + help="Seconds to wait between benchmark shapes; set to 0 to disable.") parser.add_argument("--device", default="npu:0") parser.add_argument("--shape", action="append", help="Shape override, e.g. b=2,s=8192 or batch=4,seq=512") parser.add_argument("--check", action="store_true", help="Run correctness check before timing each shape.") @@ -112,7 +116,7 @@ def main(): shapes = [parse_shape(item) for item in args.shape] if args.shape else DEFAULT_SHAPES pa = jit_compile_paged_attention(verbose=False) rows = [] - for shape in shapes: + for idx, shape in enumerate(shapes): row = run_shape(pa, shape, args.device, args.iters, args.warmup, args.check and not args.no_check) rows.append(row) print( @@ -120,6 +124,11 @@ def main(): f"{row['jit_tflops']} TFLOPS logical, {row['jit_tflops_normalized']} TFLOPS normalized, " f"{row['jit_bandwidth_tb_s']} TB/s, block_dim={row['block_dim']}" ) + torch.npu.synchronize() + gc.collect() + torch.npu.empty_cache() + if args.run_delay > 0 and idx + 1 < len(shapes): + time.sleep(args.run_delay) fieldnames = [ "shape", "batch", "seq_len", "block_dim", "jit_time_us", "jit_tflops", diff --git a/examples/jit_cpp/paged_attention_highperf/pa_compile_and_run.py b/examples/jit_cpp/paged_attention_highperf/pa_compile_and_run.py index 306df052..31198ed5 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_compile_and_run.py +++ b/examples/jit_cpp/paged_attention_highperf/pa_compile_and_run.py @@ -50,26 +50,38 @@ def pack_kv_to_paged(k_dense, v_dense, shape: PaShape): def make_inputs(shape: PaShape = PaShape(batch=1), device="npu:0", deterministic=True): q = torch.zeros((shape.batch, shape.num_heads, shape.head_dim), device=device, dtype=shape.dtype) - k_dense = torch.zeros( - (shape.batch, shape.seq_len, shape.num_kv_heads * shape.head_dim), device=device, dtype=shape.dtype + num_blocks = shape.seq_len // shape.block_size + k_page = torch.empty( + (shape.batch * num_blocks, shape.block_size, shape.num_kv_heads, shape.head_dim), + device=device, + dtype=shape.dtype, ) - if deterministic: - token = torch.arange(shape.seq_len, device=device, dtype=torch.float32).view(1, shape.seq_len, 1, 1) - kv_head = torch.arange(shape.num_kv_heads, device=device, dtype=torch.float32).view(1, 1, shape.num_kv_heads, 1) - dim = torch.arange(shape.head_dim, device=device, dtype=torch.float32).view(1, 1, 1, shape.head_dim) - batch = torch.arange(shape.batch, device=device, dtype=torch.float32).view(shape.batch, 1, 1, 1) - q_head = torch.arange(shape.num_heads, device=device, dtype=torch.float32).view(1, shape.num_heads, 1) - q_dim = torch.arange(shape.head_dim, device=device, dtype=torch.float32).view(1, 1, shape.head_dim) - q_values = (((batch[:, 0, 0, 0].view(shape.batch, 1, 1) * 3 + q_head * 5 + q_dim * 7) - .remainder(251) / 125.0) - 1.0) * 0.02 + v_page = torch.empty_like(k_page) + block_table = ( + torch.arange(num_blocks, device=device, dtype=torch.int32).unsqueeze(0).expand(shape.batch, -1).clone() + + torch.arange(shape.batch, device=device, dtype=torch.int32).unsqueeze(1) * num_blocks + ) + if not deterministic: + k_page.zero_() + v_page.zero_() + return q, k_page, v_page, block_table + + token = torch.arange(shape.seq_len, device=device, dtype=torch.float32).view(num_blocks, shape.block_size, 1, 1) + kv_head = torch.arange(shape.num_kv_heads, device=device, dtype=torch.float32).view(1, 1, shape.num_kv_heads, 1) + dim = torch.arange(shape.head_dim, device=device, dtype=torch.float32).view(1, 1, 1, shape.head_dim) + q_head = torch.arange(shape.num_heads, device=device, dtype=torch.float32).view(1, shape.num_heads, 1) + q_dim = torch.arange(shape.head_dim, device=device, dtype=torch.float32).view(1, 1, shape.head_dim) + for batch_idx in range(shape.batch): + batch = torch.tensor(float(batch_idx), device=device, dtype=torch.float32) + q_values = (((batch * 3 + q_head * 5 + q_dim * 7).remainder(251) / 125.0) - 1.0) * 0.02 + q[batch_idx].copy_(q_values[0].to(shape.dtype)) + block_offset = batch_idx * num_blocks k_values = (((batch * 11 + token * 13 + kv_head * 17 + dim * 19).remainder(257) / 128.0) - 1.0) * 0.02 + k_page[block_offset:block_offset + num_blocks].copy_(k_values.to(shape.dtype)) + del k_values v_values = (((batch * 13 + token * 17 + kv_head * 31 + dim * 7).remainder(257) / 128.0) - 1.0) * 0.25 - q.copy_(q_values.to(shape.dtype)) - k_dense.copy_(k_values.reshape(shape.batch, shape.seq_len, shape.num_kv_heads * shape.head_dim).to(shape.dtype)) - v_dense = v_values.reshape(shape.batch, shape.seq_len, shape.num_kv_heads * shape.head_dim).to(shape.dtype) - else: - v_dense = torch.zeros_like(k_dense) - k_page, v_page, block_table = pack_kv_to_paged(k_dense, v_dense, shape) + v_page[block_offset:block_offset + num_blocks].copy_(v_values.to(shape.dtype)) + del v_values return q, k_page, v_page, block_table diff --git a/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp b/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp index c18f0df4..a5a39942 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp +++ b/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp @@ -288,6 +288,9 @@ AICORE inline void PtoPaInitCoreState() #endif set_atomic_none(); set_mask_norm(); +#if defined(__DAV_C220_VEC__) + set_vector_mask(static_cast(-1), static_cast(-1)); +#endif } AICORE inline bool SupportsPtoPagedAttentionRawSplitKV(__gm__ uint8_t *tilingParaGm) From 204a8adaab440d7de32fc4b0611024b881eeb788 Mon Sep 17 00:00:00 2001 From: MirkoDeVita98 Date: Mon, 29 Jun 2026 12:22:37 +0200 Subject: [PATCH 04/11] linting --- .../paged_attention_highperf/jit_util_pa.py | 31 ++++- .../paged_attention_highperf/pa_benchmark.py | 83 +++++++++-- .../pa_compile_and_run.py | 131 ++++++++++++++---- .../paged_attention_highperf/pa_tiling.py | 60 ++++++-- 4 files changed, 245 insertions(+), 60 deletions(-) diff --git a/examples/jit_cpp/paged_attention_highperf/jit_util_pa.py b/examples/jit_cpp/paged_attention_highperf/jit_util_pa.py index df60a9c6..b382147e 100644 --- a/examples/jit_cpp/paged_attention_highperf/jit_util_pa.py +++ b/examples/jit_cpp/paged_attention_highperf/jit_util_pa.py @@ -19,7 +19,9 @@ import torch -ASCEND_TOOLKIT_HOME = os.environ.get("ASCEND_TOOLKIT_HOME", "/usr/local/Ascend/cann-9.0.0") +ASCEND_TOOLKIT_HOME = os.environ.get( + "ASCEND_TOOLKIT_HOME", "/usr/local/Ascend/cann-9.0.0" +) PTO_LIB_PATH = os.environ.get("PTO_LIB_PATH", str(Path(__file__).resolve().parents[3])) @@ -31,7 +33,9 @@ def _npu_arch_flag() -> str: return os.environ.get("NPU_ARCH", "dav-2201").strip() -def compile_paged_attention(kernel_cpp: str, verbose: bool = False, timeout: int = 300) -> str: +def compile_paged_attention( + kernel_cpp: str, verbose: bool = False, timeout: int = 300 +) -> str: lib_path = os.path.join(os.path.dirname(kernel_cpp), "pa_highperf_jit.so") example_dir = os.path.dirname(kernel_cpp) flags = [ @@ -86,20 +90,33 @@ def load_paged_attention_lib(lib_path: str, check_type: bool = True): def _alloc(device, workspace_sizes, tiling): tiling_cpu = tuple(int(x) for x in tiling.detach().cpu().tolist()) - sizes_key = tuple(sorted((name, int(size)) for name, size in workspace_sizes.items())) + sizes_key = tuple( + sorted((name, int(size)) for name, size in workspace_sizes.items()) + ) key = (str(device), sizes_key, tiling_cpu) if workspace.get("key") == key: return workspace.clear() workspace["key"] = key for name, size in workspace_sizes.items(): - workspace[name] = torch.empty((int(size),), device=device, dtype=torch.uint8) + workspace[name] = torch.empty( + (int(size),), device=device, dtype=torch.uint8 + ) if name in {"s", "p", "o_tmp", "o_core_tmp", "l"}: workspace[name].zero_() workspace["null"] = torch.zeros((1,), device=device, dtype=torch.uint8) workspace["tiling"] = tiling.to(device=device, dtype=torch.int32) - def paged_attention(q, k, v, block_table, workspace_sizes, tiling, stream_ptr=default_stream_ptr, block_dim: int = 24): + def paged_attention( + q, + k, + v, + block_table, + workspace_sizes, + tiling, + stream_ptr=default_stream_ptr, + block_dim: int = 24, + ): _alloc(q.device, workspace_sizes, tiling) out = torch.empty_like(q) lib.call_kernel( @@ -126,7 +143,9 @@ def paged_attention(q, k, v, block_table, workspace_sizes, tiling, stream_ptr=de return paged_attention -def jit_compile_paged_attention(verbose: bool = False, clean_up: bool = True, kernel_cpp: str = "pa_kernel.cpp"): +def jit_compile_paged_attention( + verbose: bool = False, clean_up: bool = True, kernel_cpp: str = "pa_kernel.cpp" +): kernel_path = str((Path(__file__).resolve().parent / kernel_cpp).resolve()) lib_path = compile_paged_attention(kernel_path, verbose=verbose) fn = load_paged_attention_lib(lib_path) diff --git a/examples/jit_cpp/paged_attention_highperf/pa_benchmark.py b/examples/jit_cpp/paged_attention_highperf/pa_benchmark.py index 70eeb914..c0a0401f 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_benchmark.py +++ b/examples/jit_cpp/paged_attention_highperf/pa_benchmark.py @@ -8,29 +8,46 @@ import torch from jit_util_pa import jit_compile_paged_attention -from pa_compile_and_run import PaShape, golden_attention, make_inputs, make_launch_config +from pa_compile_and_run import ( + PaShape, + golden_attention, + make_inputs, + make_launch_config, +) NUM_ITERATIONS = 50 WARMUP = 10 RUN_DELAY_SECONDS = 2.0 BATCHES = [1, 2, 4, 8, 32, 64] SEQ_LENS = [128, 512, 4096, 8192, 16384, 32768, 65536, 131072] -DEFAULT_SHAPES = [PaShape(batch=batch, seq_len=seq_len) for batch in BATCHES for seq_len in SEQ_LENS] +DEFAULT_SHAPES = [ + PaShape(batch=batch, seq_len=seq_len) for batch in BATCHES for seq_len in SEQ_LENS +] def paged_attention_flops(shape: PaShape): qk_and_pv = 4 * shape.batch * shape.num_heads * shape.seq_len * shape.head_dim scale = shape.batch * shape.num_heads * shape.seq_len rows = shape.batch * shape.num_heads - softmax = rows * ((shape.seq_len - 1) + shape.seq_len + shape.seq_len + (shape.seq_len - 1) + shape.seq_len) + softmax = rows * ( + (shape.seq_len - 1) + + shape.seq_len + + shape.seq_len + + (shape.seq_len - 1) + + shape.seq_len + ) return qk_and_pv + scale + softmax def tensor_bytes(shape: PaShape): dtype_bytes = 2 q_bytes = shape.batch * shape.num_heads * shape.head_dim * dtype_bytes - k_bytes = shape.batch * shape.seq_len * shape.num_kv_heads * shape.head_dim * dtype_bytes - v_bytes = shape.batch * shape.seq_len * shape.num_kv_heads * shape.head_dim * dtype_bytes + k_bytes = ( + shape.batch * shape.seq_len * shape.num_kv_heads * shape.head_dim * dtype_bytes + ) + v_bytes = ( + shape.batch * shape.seq_len * shape.num_kv_heads * shape.head_dim * dtype_bytes + ) out_bytes = shape.batch * shape.num_heads * shape.head_dim * dtype_bytes blocks_per_batch = (shape.seq_len + shape.block_size - 1) // shape.block_size block_table_bytes = shape.batch * blocks_per_batch * 4 @@ -81,8 +98,17 @@ def run_shape(pa, shape, device, iters, warmup, check): if check: out = pa(q, k, v, block_table, ws, tiling, block_dim=shape.block_dim) torch.npu.synchronize() - torch.testing.assert_close(out.float(), golden_attention(q, k, v, block_table, shape), rtol=5e-3, atol=2e-2) - ms = time_npu(lambda: pa(q, k, v, block_table, ws, tiling, block_dim=shape.block_dim), iters, warmup) + torch.testing.assert_close( + out.float(), + golden_attention(q, k, v, block_table, shape), + rtol=5e-3, + atol=2e-2, + ) + ms = time_npu( + lambda: pa(q, k, v, block_table, ws, tiling, block_dim=shape.block_dim), + iters, + warmup, + ) flops = paged_attention_flops(shape) bytes_total = tensor_bytes(shape) perf = tflops(flops, ms) @@ -104,20 +130,41 @@ def main(): parser.add_argument("--csv", default="pa_highperf_jit_bench.csv") parser.add_argument("--iters", type=int, default=NUM_ITERATIONS) parser.add_argument("--warmup", type=int, default=WARMUP) - parser.add_argument("--run-delay", type=float, default=RUN_DELAY_SECONDS, - help="Seconds to wait between benchmark shapes; set to 0 to disable.") + parser.add_argument( + "--run-delay", + type=float, + default=RUN_DELAY_SECONDS, + help="Seconds to wait between benchmark shapes; set to 0 to disable.", + ) parser.add_argument("--device", default="npu:0") - parser.add_argument("--shape", action="append", help="Shape override, e.g. b=2,s=8192 or batch=4,seq=512") - parser.add_argument("--check", action="store_true", help="Run correctness check before timing each shape.") + parser.add_argument( + "--shape", + action="append", + help="Shape override, e.g. b=2,s=8192 or batch=4,seq=512", + ) + parser.add_argument( + "--check", + action="store_true", + help="Run correctness check before timing each shape.", + ) parser.add_argument("--no-check", action="store_true", help=argparse.SUPPRESS) args = parser.parse_args() torch.npu.set_device(args.device) - shapes = [parse_shape(item) for item in args.shape] if args.shape else DEFAULT_SHAPES + shapes = ( + [parse_shape(item) for item in args.shape] if args.shape else DEFAULT_SHAPES + ) pa = jit_compile_paged_attention(verbose=False) rows = [] for idx, shape in enumerate(shapes): - row = run_shape(pa, shape, args.device, args.iters, args.warmup, args.check and not args.no_check) + row = run_shape( + pa, + shape, + args.device, + args.iters, + args.warmup, + args.check and not args.no_check, + ) rows.append(row) print( f"paged_attention_highperf_jit {row['shape']}: {row['jit_time_us']} us/iter, " @@ -131,8 +178,14 @@ def main(): time.sleep(args.run_delay) fieldnames = [ - "shape", "batch", "seq_len", "block_dim", "jit_time_us", "jit_tflops", - "jit_tflops_normalized", "jit_bandwidth_tb_s" + "shape", + "batch", + "seq_len", + "block_dim", + "jit_time_us", + "jit_tflops", + "jit_tflops_normalized", + "jit_bandwidth_tb_s", ] with open(args.csv, "w", newline="") as f: writer = csv.DictWriter(f, fieldnames=fieldnames) diff --git a/examples/jit_cpp/paged_attention_highperf/pa_compile_and_run.py b/examples/jit_cpp/paged_attention_highperf/pa_compile_and_run.py index 31198ed5..6f5ca22b 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_compile_and_run.py +++ b/examples/jit_cpp/paged_attention_highperf/pa_compile_and_run.py @@ -31,56 +31,113 @@ def pack_kv_to_paged(k_dense, v_dense, shape: PaShape): num_blocks = shape.seq_len // shape.block_size k_page = ( k_dense.view(shape.batch, shape.seq_len, shape.num_kv_heads, shape.head_dim) - .view(shape.batch, num_blocks, shape.block_size, shape.num_kv_heads, shape.head_dim) - .reshape(shape.batch * num_blocks, shape.block_size, shape.num_kv_heads, shape.head_dim) + .view( + shape.batch, + num_blocks, + shape.block_size, + shape.num_kv_heads, + shape.head_dim, + ) + .reshape( + shape.batch * num_blocks, + shape.block_size, + shape.num_kv_heads, + shape.head_dim, + ) .contiguous() ) v_page = ( v_dense.view(shape.batch, shape.seq_len, shape.num_kv_heads, shape.head_dim) - .view(shape.batch, num_blocks, shape.block_size, shape.num_kv_heads, shape.head_dim) - .reshape(shape.batch * num_blocks, shape.block_size, shape.num_kv_heads, shape.head_dim) + .view( + shape.batch, + num_blocks, + shape.block_size, + shape.num_kv_heads, + shape.head_dim, + ) + .reshape( + shape.batch * num_blocks, + shape.block_size, + shape.num_kv_heads, + shape.head_dim, + ) .contiguous() ) block_table = ( - torch.arange(num_blocks, device=k_dense.device, dtype=torch.int32).unsqueeze(0).expand(shape.batch, -1).clone() - + torch.arange(shape.batch, device=k_dense.device, dtype=torch.int32).unsqueeze(1) * num_blocks + torch.arange(num_blocks, device=k_dense.device, dtype=torch.int32) + .unsqueeze(0) + .expand(shape.batch, -1) + .clone() + + torch.arange(shape.batch, device=k_dense.device, dtype=torch.int32).unsqueeze( + 1 + ) + * num_blocks ) return k_page, v_page, block_table def make_inputs(shape: PaShape = PaShape(batch=1), device="npu:0", deterministic=True): - q = torch.zeros((shape.batch, shape.num_heads, shape.head_dim), device=device, dtype=shape.dtype) + q = torch.zeros( + (shape.batch, shape.num_heads, shape.head_dim), device=device, dtype=shape.dtype + ) num_blocks = shape.seq_len // shape.block_size k_page = torch.empty( - (shape.batch * num_blocks, shape.block_size, shape.num_kv_heads, shape.head_dim), + ( + shape.batch * num_blocks, + shape.block_size, + shape.num_kv_heads, + shape.head_dim, + ), device=device, dtype=shape.dtype, ) v_page = torch.empty_like(k_page) block_table = ( - torch.arange(num_blocks, device=device, dtype=torch.int32).unsqueeze(0).expand(shape.batch, -1).clone() - + torch.arange(shape.batch, device=device, dtype=torch.int32).unsqueeze(1) * num_blocks + torch.arange(num_blocks, device=device, dtype=torch.int32) + .unsqueeze(0) + .expand(shape.batch, -1) + .clone() + + torch.arange(shape.batch, device=device, dtype=torch.int32).unsqueeze(1) + * num_blocks ) if not deterministic: k_page.zero_() v_page.zero_() return q, k_page, v_page, block_table - token = torch.arange(shape.seq_len, device=device, dtype=torch.float32).view(num_blocks, shape.block_size, 1, 1) - kv_head = torch.arange(shape.num_kv_heads, device=device, dtype=torch.float32).view(1, 1, shape.num_kv_heads, 1) - dim = torch.arange(shape.head_dim, device=device, dtype=torch.float32).view(1, 1, 1, shape.head_dim) - q_head = torch.arange(shape.num_heads, device=device, dtype=torch.float32).view(1, shape.num_heads, 1) - q_dim = torch.arange(shape.head_dim, device=device, dtype=torch.float32).view(1, 1, shape.head_dim) + token = torch.arange(shape.seq_len, device=device, dtype=torch.float32).view( + num_blocks, shape.block_size, 1, 1 + ) + kv_head = torch.arange(shape.num_kv_heads, device=device, dtype=torch.float32).view( + 1, 1, shape.num_kv_heads, 1 + ) + dim = torch.arange(shape.head_dim, device=device, dtype=torch.float32).view( + 1, 1, 1, shape.head_dim + ) + q_head = torch.arange(shape.num_heads, device=device, dtype=torch.float32).view( + 1, shape.num_heads, 1 + ) + q_dim = torch.arange(shape.head_dim, device=device, dtype=torch.float32).view( + 1, 1, shape.head_dim + ) for batch_idx in range(shape.batch): batch = torch.tensor(float(batch_idx), device=device, dtype=torch.float32) - q_values = (((batch * 3 + q_head * 5 + q_dim * 7).remainder(251) / 125.0) - 1.0) * 0.02 + q_values = ( + ((batch * 3 + q_head * 5 + q_dim * 7).remainder(251) / 125.0) - 1.0 + ) * 0.02 q[batch_idx].copy_(q_values[0].to(shape.dtype)) block_offset = batch_idx * num_blocks - k_values = (((batch * 11 + token * 13 + kv_head * 17 + dim * 19).remainder(257) / 128.0) - 1.0) * 0.02 - k_page[block_offset:block_offset + num_blocks].copy_(k_values.to(shape.dtype)) + k_values = ( + ((batch * 11 + token * 13 + kv_head * 17 + dim * 19).remainder(257) / 128.0) + - 1.0 + ) * 0.02 + k_page[block_offset : block_offset + num_blocks].copy_(k_values.to(shape.dtype)) del k_values - v_values = (((batch * 13 + token * 17 + kv_head * 31 + dim * 7).remainder(257) / 128.0) - 1.0) * 0.25 - v_page[block_offset:block_offset + num_blocks].copy_(v_values.to(shape.dtype)) + v_values = ( + ((batch * 13 + token * 17 + kv_head * 31 + dim * 7).remainder(257) / 128.0) + - 1.0 + ) * 0.25 + v_page[block_offset : block_offset + num_blocks].copy_(v_values.to(shape.dtype)) del v_values return q, k_page, v_page, block_table @@ -104,18 +161,32 @@ def make_launch_config(shape: PaShape, device="cpu"): device=device, dtype=shape.dtype, ) - ws = workspace_sizes(shape.batch, shape.num_heads, shape.head_dim, shape.head_dim, shape.block_dim) + ws = workspace_sizes( + shape.batch, shape.num_heads, shape.head_dim, shape.head_dim, shape.block_dim + ) return ws, tiling, effective_block_dim def golden_attention(q, k_page, v_page, block_table, shape: PaShape): heads_per_kv = shape.num_heads // shape.num_kv_heads scale = 1.0 / math.sqrt(float(shape.head_dim)) - out = torch.empty((shape.batch, shape.num_heads, shape.head_dim), device=v_page.device, dtype=torch.float32) + out = torch.empty( + (shape.batch, shape.num_heads, shape.head_dim), + device=v_page.device, + dtype=torch.float32, + ) for batch_idx in range(shape.batch): blocks = block_table[batch_idx] - keys = k_page[blocks.long()].reshape(shape.seq_len, shape.num_kv_heads, shape.head_dim).float() - values = v_page[blocks.long()].reshape(shape.seq_len, shape.num_kv_heads, shape.head_dim).float() + keys = ( + k_page[blocks.long()] + .reshape(shape.seq_len, shape.num_kv_heads, shape.head_dim) + .float() + ) + values = ( + v_page[blocks.long()] + .reshape(shape.seq_len, shape.num_kv_heads, shape.head_dim) + .float() + ) for head in range(shape.num_heads): kv_head = head // heads_per_kv scores = torch.mv(keys[:, kv_head, :], q[batch_idx, head].float()) * scale @@ -126,10 +197,18 @@ def golden_attention(q, k_page, v_page, block_table, shape: PaShape): def golden_uniform(v_page, block_table, shape: PaShape): heads_per_kv = shape.num_heads // shape.num_kv_heads - out = torch.empty((shape.batch, shape.num_heads, shape.head_dim), device=v_page.device, dtype=torch.float32) + out = torch.empty( + (shape.batch, shape.num_heads, shape.head_dim), + device=v_page.device, + dtype=torch.float32, + ) for batch_idx in range(shape.batch): blocks = block_table[batch_idx] - values = v_page[blocks.long()].reshape(shape.seq_len, shape.num_kv_heads, shape.head_dim).float() + values = ( + v_page[blocks.long()] + .reshape(shape.seq_len, shape.num_kv_heads, shape.head_dim) + .float() + ) kv_avg = values.mean(dim=0) for head in range(shape.num_heads): out[batch_idx, head] = kv_avg[head // heads_per_kv] diff --git a/examples/jit_cpp/paged_attention_highperf/pa_tiling.py b/examples/jit_cpp/paged_attention_highperf/pa_tiling.py index 5daf9c13..884f4e2e 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_tiling.py +++ b/examples/jit_cpp/paged_attention_highperf/pa_tiling.py @@ -108,7 +108,9 @@ def _u32_to_i32(v: int) -> int: return v - 0x100000000 if v & 0x80000000 else v -def _calcu_head_nd(num_heads: int, kv_heads: int, former_head_split: int, tail_head_split: int): +def _calcu_head_nd( + num_heads: int, kv_heads: int, former_head_split: int, tail_head_split: int +): """CalcuHeadNd: compute group move factors.""" kv_real = kv_heads if kv_heads > 0 else num_heads group_num = num_heads // kv_real @@ -116,14 +118,18 @@ def _calcu_head_nd(num_heads: int, kv_heads: int, former_head_split: int, tail_h former_group_move = 1 if former_head_split % group_num == 0: former_group_move = group_num - elif former_head_split < group_num and (kv_real == 1 or group_num % former_head_split == 0): + elif former_head_split < group_num and ( + kv_real == 1 or group_num % former_head_split == 0 + ): former_group_move = former_head_split tail_group_move = 1 if tail_head_split > 0: if tail_head_split % group_num == 0: tail_group_move = group_num - elif tail_head_split < group_num and (kv_real == 1 or group_num % tail_head_split == 0): + elif tail_head_split < group_num and ( + kv_real == 1 or group_num % tail_head_split == 0 + ): tail_group_move = tail_head_split return group_num, former_group_move, tail_group_move @@ -143,7 +149,11 @@ def _split_core_bn_nd( kv_real = kv_heads if kv_heads > 0 else num_heads core_per_batch = _ceil_div(block_dim, decoder_batch) - if block_dim * SPLITKV_RATIO <= decoder_batch <= block_dim and is_quant and kv_real == 1: + if ( + block_dim * SPLITKV_RATIO <= decoder_batch <= block_dim + and is_quant + and kv_real == 1 + ): core_per_batch = 1 head_split = _ceil_div(num_heads, core_per_batch) @@ -183,7 +193,9 @@ def _split_core_bn_nd( kv_split_per_core = _round_up(max_kv_seq_len, block_size) kv_split_core_num = 1 - group_num, former_gm, tail_gm = _calcu_head_nd(num_heads, kv_real, former_head_split, tail_head_split) + group_num, former_gm, tail_gm = _calcu_head_nd( + num_heads, kv_real, former_head_split, tail_head_split + ) return ( eff_block_dim, former_batch, @@ -237,7 +249,9 @@ def _split_core_bns_nd( former_head_split = head_split tail_head_split = 0 - group_num, former_gm, tail_gm = _calcu_head_nd(num_heads, kv_real, former_head_split, tail_head_split) + group_num, former_gm, tail_gm = _calcu_head_nd( + num_heads, kv_real, former_head_split, tail_head_split + ) return ( eff_block_dim, former_batch, @@ -290,7 +304,9 @@ def make_pa_nd_decode_tiling( # noqa: PLR0913, PLR0915 """ kv_real = kv_heads if kv_heads > 0 else num_heads max_kv = max(kv_seq_lens) - is_mla = head_dim > MLA_THRESHOLD or head_dim_v > MLA_THRESHOLD or head_dim != head_dim_v + is_mla = ( + head_dim > MLA_THRESHOLD or head_dim_v > MLA_THRESHOLD or head_dim != head_dim_v + ) is_quant = False # fp16/bf16 only indices: list[int] = sorted(range(batch), key=lambda i: kv_seq_lens[i]) @@ -298,7 +314,9 @@ def make_pa_nd_decode_tiling( # noqa: PLR0913, PLR0915 decoder_batch = batch is_long_seq = max_kv >= KV_SEQLEN_SLICE_512 * 8 - use_bn = is_mla or (decoder_batch * num_heads >= block_dim * SPLITKV_RATIO and not is_long_seq) + use_bn = is_mla or ( + decoder_batch * num_heads >= block_dim * SPLITKV_RATIO and not is_long_seq + ) if use_bn: (eff_bd, fB, fH, tB, tH, kvSplit, kvCN, gN, fGM, tGM) = _split_core_bn_nd( @@ -336,8 +354,16 @@ def make_pa_nd_decode_tiling( # noqa: PLR0913, PLR0915 head_dim_k_split = min(head_dim, MLA_THRESHOLD) head_dim_v_split = min(head_dim_v, MLA_THRESHOLD) - head_dim_v_split_former = min(head_dim_v, MLA_THRESHOLD) if fGM <= 64 else min(head_dim_v, EMBEDDING_LIMIT) - head_dim_v_split_tail = min(head_dim_v, MLA_THRESHOLD) if tGM <= 64 else min(head_dim_v, EMBEDDING_LIMIT) + head_dim_v_split_former = ( + min(head_dim_v, MLA_THRESHOLD) + if fGM <= 64 + else min(head_dim_v, EMBEDDING_LIMIT) + ) + head_dim_v_split_tail = ( + min(head_dim_v, MLA_THRESHOLD) + if tGM <= 64 + else min(head_dim_v, EMBEDDING_LIMIT) + ) if ( block_size <= KV_SEQLEN_SLICE // 2 @@ -345,14 +371,20 @@ def make_pa_nd_decode_tiling( # noqa: PLR0913, PLR0915 and block_size * 2 * head_dim_v_split <= BLOCK_LIMIT ): block_size_calc = block_size * 2 - elif block_size >= KV_SEQLEN_SLICE and head_dim == KV_SEQLEN_SLICE_256 and head_dim_v == KV_SEQLEN_SLICE_256: + elif ( + block_size >= KV_SEQLEN_SLICE + and head_dim == KV_SEQLEN_SLICE_256 + and head_dim_v == KV_SEQLEN_SLICE_256 + ): block_size_calc = KV_SEQLEN_SLICE else: block_size_calc = block_size is_split_key = int(kvCN > 1) is_split_block = int( - block_size >= KV_SEQLEN_SLICE and head_dim == KV_SEQLEN_SLICE_256 and head_dim_v == KV_SEQLEN_SLICE_256 + block_size >= KV_SEQLEN_SLICE + and head_dim == KV_SEQLEN_SLICE_256 + and head_dim_v == KV_SEQLEN_SLICE_256 ) type_key = 0 if dtype == torch.float16 else 1 tiling_key = (is_split_block << 7) + (is_split_key << 4) + type_key @@ -412,7 +444,9 @@ def make_pa_nd_decode_tiling( # noqa: PLR0913, PLR0915 q_seqlen = 1 q_aligned = _round_up(q_seqlen, BLOCK_SIZE_ALIGN) - m_raw = (PP_BLOCK_BUFFER_SIZE // max(head_dim, block_size) // BLOCK_SIZE_ALIGN) * BLOCK_SIZE_ALIGN + m_raw = ( + PP_BLOCK_BUFFER_SIZE // max(head_dim, block_size) // BLOCK_SIZE_ALIGN + ) * BLOCK_SIZE_ALIGN m_ubd = min(m_raw, q_aligned) m_ubd = max(m_ubd, BLOCK_SIZE_ALIGN) m_idx = min(7, max(0, m_ubd // 16 - 1)) From 3a30572af2aaf4d7eca49293b4f273c8370bf429 Mon Sep 17 00:00:00 2001 From: MirkoDeVita98 Date: Mon, 29 Jun 2026 12:33:14 +0200 Subject: [PATCH 05/11] make CI happy --- .../paged_attention_highperf/jit_util_pa.py | 1 + .../paged_attention_highperf/pa_benchmark.py | 2 +- .../pa_compile_and_run.py | 2 +- .../pa_kernel_impl.hpp | 2852 +++++++++-------- .../paged_attention_highperf/pa_tiling.py | 6 +- 5 files changed, 1515 insertions(+), 1348 deletions(-) diff --git a/examples/jit_cpp/paged_attention_highperf/jit_util_pa.py b/examples/jit_cpp/paged_attention_highperf/jit_util_pa.py index b382147e..ceaaa4b1 100644 --- a/examples/jit_cpp/paged_attention_highperf/jit_util_pa.py +++ b/examples/jit_cpp/paged_attention_highperf/jit_util_pa.py @@ -86,6 +86,7 @@ def load_paged_attention_lib(lib_path: str, check_type: bool = True): lib.call_kernel.restype = None workspace = {} + # pylint: disable-next=protected-access default_stream_ptr = torch.npu.current_stream()._as_parameter_ def _alloc(device, workspace_sizes, tiling): diff --git a/examples/jit_cpp/paged_attention_highperf/pa_benchmark.py b/examples/jit_cpp/paged_attention_highperf/pa_benchmark.py index c0a0401f..e8649bfa 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_benchmark.py +++ b/examples/jit_cpp/paged_attention_highperf/pa_benchmark.py @@ -187,7 +187,7 @@ def main(): "jit_tflops_normalized", "jit_bandwidth_tb_s", ] - with open(args.csv, "w", newline="") as f: + with open(args.csv, "w", newline="", encoding="utf-8") as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() writer.writerows(rows) diff --git a/examples/jit_cpp/paged_attention_highperf/pa_compile_and_run.py b/examples/jit_cpp/paged_attention_highperf/pa_compile_and_run.py index 6f5ca22b..f8ac9be4 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_compile_and_run.py +++ b/examples/jit_cpp/paged_attention_highperf/pa_compile_and_run.py @@ -5,7 +5,7 @@ from dataclasses import dataclass import torch -import torch_npu +import torch_npu # noqa: F401 # pylint: disable=unused-import from jit_util_pa import jit_compile_paged_attention from pa_tiling import make_pa_nd_decode_tiling, workspace_sizes diff --git a/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp b/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp index a5a39942..b7e00b9a 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp +++ b/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp @@ -1,11 +1,13 @@ /** Copyright (c) 2026 Huawei Technologies Co., Ltd. -This program is free software, you can redistribute it and/or modify it under the terms and conditions of -CANN Open Software License Agreement Version 2.0 (the "License"). -Please refer to the License for details. You may not use this file except in compliance with the License. -THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -See LICENSE in the root of the software repository for the full text of the License. +This program is free software, you can redistribute it and/or modify it under +the terms and conditions of CANN Open Software License Agreement Version 2.0 +(the "License"). Please refer to the License for details. You may not use this +file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN "AS +IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING +BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A +PARTICULAR PURPOSE. See LICENSE in the root of the software repository for the +full text of the License. */ #ifndef PTO_PAGED_ATTENTION_HIGHPERF_IMPL_HPP @@ -45,185 +47,167 @@ constexpr int32_t TILING_HEADDIM_V = 29; constexpr int32_t kParaKvSeqLen = 1; constexpr int32_t kParaBatchIndex = 13; -AICORE inline int32_t LoadTilingI32(__gm__ uint8_t *tiling, int32_t index) -{ - return *(reinterpret_cast<__gm__ int32_t *>(tiling) + index); +AICORE inline int32_t LoadTilingI32(__gm__ uint8_t *tiling, int32_t index) { + return *(reinterpret_cast<__gm__ int32_t *>(tiling) + index); } -AICORE inline int32_t LoadBlockTable(__gm__ uint8_t *blockTablesGm, int64_t offset) -{ - return *(reinterpret_cast<__gm__ int32_t *>(blockTablesGm) + offset); +AICORE inline int32_t LoadBlockTable(__gm__ uint8_t *blockTablesGm, + int64_t offset) { + return *(reinterpret_cast<__gm__ int32_t *>(blockTablesGm) + offset); } -AICORE inline float LoadFp16(__gm__ uint8_t *gm, int64_t offset) -{ - __gm__ half *ptr = reinterpret_cast<__gm__ half *>(gm); - return static_cast(ptr[offset]); +AICORE inline float LoadFp16(__gm__ uint8_t *gm, int64_t offset) { + __gm__ half *ptr = reinterpret_cast<__gm__ half *>(gm); + return static_cast(ptr[offset]); } -AICORE inline void StoreOutputFp16(__gm__ uint8_t *oGm, int64_t offset, float value) -{ - __gm__ half *out = reinterpret_cast<__gm__ half *>(oGm); - out[offset] = static_cast(value); +AICORE inline void StoreOutputFp16(__gm__ uint8_t *oGm, int64_t offset, + float value) { + __gm__ half *out = reinterpret_cast<__gm__ half *>(oGm); + out[offset] = static_cast(value); } -AICORE inline float LoadScale(__gm__ uint8_t *tiling) -{ - union { - int32_t i; - float f; - } scale; - scale.i = LoadTilingI32(tiling, 6); - return scale.f; +AICORE inline float LoadScale(__gm__ uint8_t *tiling) { + union { + int32_t i; + float f; + } scale; + scale.i = LoadTilingI32(tiling, 6); + return scale.f; } struct PaTilingContext { - int32_t batch; - int32_t decoderBatch; - int32_t numHeads; - int32_t kvHeads; - int32_t headDim; - int32_t headDimV; - int32_t blockSize; - int32_t maxBlocksPerQuery; - int32_t maxKvSeqLen; - int32_t formerBatch; - int32_t formerHeadSplit; - int32_t tailBatch; - int32_t tailHeadSplit; - int32_t headNumMove; - int32_t groupNum; - int32_t formerGroupMove; - int32_t tailGroupMove; - int32_t kvSplitPerCore; - int32_t kvSplitCoreNum; - int32_t blockSizeCalc; - int32_t headSize; - int32_t paraSize; - float scale; + int32_t batch; + int32_t decoderBatch; + int32_t numHeads; + int32_t kvHeads; + int32_t headDim; + int32_t headDimV; + int32_t blockSize; + int32_t maxBlocksPerQuery; + int32_t maxKvSeqLen; + int32_t formerBatch; + int32_t formerHeadSplit; + int32_t tailBatch; + int32_t tailHeadSplit; + int32_t headNumMove; + int32_t groupNum; + int32_t formerGroupMove; + int32_t tailGroupMove; + int32_t kvSplitPerCore; + int32_t kvSplitCoreNum; + int32_t blockSizeCalc; + int32_t headSize; + int32_t paraSize; + float scale; }; -AICORE inline PaTilingContext LoadPaTilingContext(__gm__ uint8_t *tiling) -{ - PaTilingContext ctx{}; - ctx.batch = LoadTilingI32(tiling, TILING_BATCH); - ctx.decoderBatch = LoadTilingI32(tiling, TILING_DECODER_BS); - ctx.numHeads = LoadTilingI32(tiling, TILING_NUMHEADS); - ctx.kvHeads = LoadTilingI32(tiling, TILING_KVHEADS); - ctx.headDim = LoadTilingI32(tiling, TILING_HEADDIM); - ctx.headDimV = LoadTilingI32(tiling, TILING_HEADDIM_V); - ctx.blockSize = LoadTilingI32(tiling, TILING_BLOCKSIZE); - ctx.maxBlocksPerQuery = LoadTilingI32(tiling, TILING_MAXBLOCKS); - ctx.maxKvSeqLen = LoadTilingI32(tiling, TILING_MAX_KVSEQLEN); - ctx.formerBatch = LoadTilingI32(tiling, TILING_FORMER_BATCH); - ctx.formerHeadSplit = LoadTilingI32(tiling, TILING_FORMER_HEAD); - ctx.tailBatch = LoadTilingI32(tiling, TILING_TAIL_BATCH); - ctx.tailHeadSplit = LoadTilingI32(tiling, TILING_TAIL_HEAD); - ctx.headNumMove = LoadTilingI32(tiling, TILING_HEADNUM_MOVE); - ctx.groupNum = LoadTilingI32(tiling, TILING_GROUPNUM); - ctx.formerGroupMove = LoadTilingI32(tiling, TILING_FORMER_GROUP_MOVE); - ctx.tailGroupMove = LoadTilingI32(tiling, TILING_TAIL_GROUP_MOVE); - ctx.kvSplitPerCore = LoadTilingI32(tiling, TILING_KVSPLIT); - ctx.kvSplitCoreNum = LoadTilingI32(tiling, TILING_KVCORENUM); - ctx.blockSizeCalc = LoadTilingI32(tiling, TILING_BLOCKSIZE_CALC); - ctx.headSize = LoadTilingI32(tiling, TILING_HEADSIZE); - ctx.paraSize = LoadTilingI32(tiling, TILING_PARASIZE); - ctx.scale = LoadScale(tiling); - return ctx; +AICORE inline PaTilingContext LoadPaTilingContext(__gm__ uint8_t *tiling) { + PaTilingContext ctx{}; + ctx.batch = LoadTilingI32(tiling, TILING_BATCH); + ctx.decoderBatch = LoadTilingI32(tiling, TILING_DECODER_BS); + ctx.numHeads = LoadTilingI32(tiling, TILING_NUMHEADS); + ctx.kvHeads = LoadTilingI32(tiling, TILING_KVHEADS); + ctx.headDim = LoadTilingI32(tiling, TILING_HEADDIM); + ctx.headDimV = LoadTilingI32(tiling, TILING_HEADDIM_V); + ctx.blockSize = LoadTilingI32(tiling, TILING_BLOCKSIZE); + ctx.maxBlocksPerQuery = LoadTilingI32(tiling, TILING_MAXBLOCKS); + ctx.maxKvSeqLen = LoadTilingI32(tiling, TILING_MAX_KVSEQLEN); + ctx.formerBatch = LoadTilingI32(tiling, TILING_FORMER_BATCH); + ctx.formerHeadSplit = LoadTilingI32(tiling, TILING_FORMER_HEAD); + ctx.tailBatch = LoadTilingI32(tiling, TILING_TAIL_BATCH); + ctx.tailHeadSplit = LoadTilingI32(tiling, TILING_TAIL_HEAD); + ctx.headNumMove = LoadTilingI32(tiling, TILING_HEADNUM_MOVE); + ctx.groupNum = LoadTilingI32(tiling, TILING_GROUPNUM); + ctx.formerGroupMove = LoadTilingI32(tiling, TILING_FORMER_GROUP_MOVE); + ctx.tailGroupMove = LoadTilingI32(tiling, TILING_TAIL_GROUP_MOVE); + ctx.kvSplitPerCore = LoadTilingI32(tiling, TILING_KVSPLIT); + ctx.kvSplitCoreNum = LoadTilingI32(tiling, TILING_KVCORENUM); + ctx.blockSizeCalc = LoadTilingI32(tiling, TILING_BLOCKSIZE_CALC); + ctx.headSize = LoadTilingI32(tiling, TILING_HEADSIZE); + ctx.paraSize = LoadTilingI32(tiling, TILING_PARASIZE); + ctx.scale = LoadScale(tiling); + return ctx; } template -AICORE inline float PtoExpScalar(ScalarTile &tile, float value) -{ - tile.data()[0] = value; - set_flag(PIPE_S, PIPE_V, EVENT_ID2); - wait_flag(PIPE_S, PIPE_V, EVENT_ID2); - TEXP(tile, tile); - pipe_barrier(PIPE_V); - set_flag(PIPE_V, PIPE_S, EVENT_ID3); - wait_flag(PIPE_V, PIPE_S, EVENT_ID3); - return tile.data()[0]; +AICORE inline float PtoExpScalar(ScalarTile &tile, float value) { + tile.data()[0] = value; + set_flag(PIPE_S, PIPE_V, EVENT_ID2); + wait_flag(PIPE_S, PIPE_V, EVENT_ID2); + TEXP(tile, tile); + pipe_barrier(PIPE_V); + set_flag(PIPE_V, PIPE_S, EVENT_ID3); + wait_flag(PIPE_V, PIPE_S, EVENT_ID3); + return tile.data()[0]; } template -AICORE inline float PtoLogScalar(ScalarTile &tile, float value) -{ - if (value <= 0.0f) { - return -3.4028234663852886e38f; - } - tile.data()[0] = value; - set_flag(PIPE_S, PIPE_V, EVENT_ID2); - wait_flag(PIPE_S, PIPE_V, EVENT_ID2); - TLOG(tile, tile); - pipe_barrier(PIPE_V); - set_flag(PIPE_V, PIPE_S, EVENT_ID3); - wait_flag(PIPE_V, PIPE_S, EVENT_ID3); - return tile.data()[0]; +AICORE inline float PtoLogScalar(ScalarTile &tile, float value) { + if (value <= 0.0f) { + return -3.4028234663852886e38f; + } + tile.data()[0] = value; + set_flag(PIPE_S, PIPE_V, EVENT_ID2); + wait_flag(PIPE_S, PIPE_V, EVENT_ID2); + TLOG(tile, tile); + pipe_barrier(PIPE_V); + set_flag(PIPE_V, PIPE_S, EVENT_ID3); + wait_flag(PIPE_V, PIPE_S, EVENT_ID3); + return tile.data()[0]; } -AICORE inline float LoadPagedKByBlock( - __gm__ uint8_t *kGm, - int32_t blockId, - int32_t offsetInBlock, - int32_t blockSize, - int32_t kvHeads, - int32_t kvHead, - int32_t headDim, - int32_t dim) -{ - const int64_t offset = (((static_cast(blockId) * blockSize + offsetInBlock) * kvHeads + kvHead) * headDim + dim); - return LoadFp16(kGm, offset); +AICORE inline float LoadPagedKByBlock(__gm__ uint8_t *kGm, int32_t blockId, + int32_t offsetInBlock, int32_t blockSize, + int32_t kvHeads, int32_t kvHead, + int32_t headDim, int32_t dim) { + const int64_t offset = + (((static_cast(blockId) * blockSize + offsetInBlock) * kvHeads + + kvHead) * + headDim + + dim); + return LoadFp16(kGm, offset); } -AICORE inline float LoadPagedVByBlock( - __gm__ uint8_t *vGm, - int32_t blockId, - int32_t offsetInBlock, - int32_t blockSize, - int32_t kvHeads, - int32_t kvHead, - int32_t headDim, - int32_t dim) -{ - const int64_t offset = (((static_cast(blockId) * blockSize + offsetInBlock) * kvHeads + kvHead) * headDim + dim); - return LoadFp16(vGm, offset); +AICORE inline float LoadPagedVByBlock(__gm__ uint8_t *vGm, int32_t blockId, + int32_t offsetInBlock, int32_t blockSize, + int32_t kvHeads, int32_t kvHead, + int32_t headDim, int32_t dim) { + const int64_t offset = + (((static_cast(blockId) * blockSize + offsetInBlock) * kvHeads + + kvHead) * + headDim + + dim); + return LoadFp16(vGm, offset); } -AICORE inline void ResolvePagedPosition( - __gm__ uint8_t *blockTablesGm, - int32_t batchIndex, - int32_t maxBlocksPerQuery, - int32_t pos, - int32_t blockSize, - int32_t &blockId, - int32_t &offsetInBlock) -{ - const int32_t tableCol = pos / blockSize; - offsetInBlock = pos - tableCol * blockSize; - blockId = LoadBlockTable(blockTablesGm, static_cast(batchIndex) * maxBlocksPerQuery + tableCol); +AICORE inline void ResolvePagedPosition(__gm__ uint8_t *blockTablesGm, + int32_t batchIndex, + int32_t maxBlocksPerQuery, int32_t pos, + int32_t blockSize, int32_t &blockId, + int32_t &offsetInBlock) { + const int32_t tableCol = pos / blockSize; + offsetInBlock = pos - tableCol * blockSize; + blockId = LoadBlockTable( + blockTablesGm, + static_cast(batchIndex) * maxBlocksPerQuery + tableCol); } -AICORE inline float ComputeScoreByBlock( - const float *qValues, - __gm__ uint8_t *kGm, - int32_t blockId, - int32_t offsetInBlock, - int32_t blockSize, - int32_t kvHead, - int32_t headDim, - int32_t kvHeads, - float scale) -{ - float score = 0.0f; - for (int32_t dim = 0; dim < headDim; ++dim) { - const float k = LoadPagedKByBlock(kGm, blockId, offsetInBlock, blockSize, kvHeads, kvHead, headDim, dim); - score += qValues[dim] * k; - } - return score * scale; +AICORE inline float ComputeScoreByBlock(const float *qValues, + __gm__ uint8_t *kGm, int32_t blockId, + int32_t offsetInBlock, + int32_t blockSize, int32_t kvHead, + int32_t headDim, int32_t kvHeads, + float scale) { + float score = 0.0f; + for (int32_t dim = 0; dim < headDim; ++dim) { + const float k = LoadPagedKByBlock(kGm, blockId, offsetInBlock, blockSize, + kvHeads, kvHead, headDim, dim); + score += qValues[dim] * k; + } + return score * scale; } - - constexpr int32_t PA_TILE_TOKENS = 128; constexpr uint8_t PTO_PA_REDUCE_READY_DECODER = 14; constexpr uint8_t PTO_PA_RAW_QK_READY = 0; @@ -233,938 +217,1340 @@ constexpr uint8_t PTO_PA_RAW_P_FREE = 6; constexpr uint8_t PTO_PA_RAW_PV_READY = 8; constexpr uint8_t PTO_PA_RAW_PV_FREE = 10; -AICORE inline uint8_t PtoPaSlotFlag(uint8_t baseFlag, uint8_t slot) -{ - return static_cast(baseFlag + slot); +AICORE inline uint8_t PtoPaSlotFlag(uint8_t baseFlag, uint8_t slot) { + return static_cast(baseFlag + slot); } -AICORE inline uint16_t PtoPaGetFftsMsg(uint16_t mode, uint16_t eventId, uint16_t baseConst = 0x1) -{ - return ((baseConst & 0xf) + ((mode & 0x3) << 4) + ((eventId & 0xf) << 8)); +AICORE inline uint16_t PtoPaGetFftsMsg(uint16_t mode, uint16_t eventId, + uint16_t baseConst = 0x1) { + return ((baseConst & 0xf) + ((mode & 0x3) << 4) + ((eventId & 0xf) << 8)); } -AICORE inline void PtoPaSignalFromCube(uint8_t flagId) -{ - pipe_barrier(PIPE_ALL); - ffts_cross_core_sync(PIPE_FIX, PtoPaGetFftsMsg(0x2, flagId)); +AICORE inline void PtoPaSignalFromCube(uint8_t flagId) { + pipe_barrier(PIPE_ALL); + ffts_cross_core_sync(PIPE_FIX, PtoPaGetFftsMsg(0x2, flagId)); } -AICORE inline void PtoPaSignalFromVec(uint8_t flagId) -{ - pipe_barrier(PIPE_ALL); - ffts_cross_core_sync(PIPE_MTE3, PtoPaGetFftsMsg(0x2, flagId)); +AICORE inline void PtoPaSignalFromVec(uint8_t flagId) { + pipe_barrier(PIPE_ALL); + ffts_cross_core_sync(PIPE_MTE3, PtoPaGetFftsMsg(0x2, flagId)); } -AICORE inline void PtoPaSignalFreeFromVec(uint8_t flagId) -{ - pipe_barrier(PIPE_ALL); - ffts_cross_core_sync(PIPE_MTE3, PtoPaGetFftsMsg(0x2, flagId)); +AICORE inline void PtoPaSignalFreeFromVec(uint8_t flagId) { + pipe_barrier(PIPE_ALL); + ffts_cross_core_sync(PIPE_MTE3, PtoPaGetFftsMsg(0x2, flagId)); } -AICORE inline void PtoPaSignalFreeFromCube(uint8_t flagId) -{ - pipe_barrier(PIPE_ALL); - ffts_cross_core_sync(PIPE_FIX, PtoPaGetFftsMsg(0x2, flagId)); +AICORE inline void PtoPaSignalFreeFromCube(uint8_t flagId) { + pipe_barrier(PIPE_ALL); + ffts_cross_core_sync(PIPE_FIX, PtoPaGetFftsMsg(0x2, flagId)); } -// Real (un-sorted) batch index lives at +8 inside the per-batch para block; +13 (kParaBatchIndex) -// holds the sorted/remap slot. The CCE reference always double-indirects: read the sorted slot at -// +13, re-derive the para base, then read the real batch at +8. See pa_kernel.cce:523-525. +// Real (un-sorted) batch index lives at +8 inside the per-batch para block; +13 +// (kParaBatchIndex) holds the sorted/remap slot. The CCE reference always +// double-indirects: read the sorted slot at +13, re-derive the para base, then +// read the real batch at +8. See pa_kernel.cce:523-525. constexpr int32_t kParaRealBatchIndex = 8; -AICORE inline int32_t ResolveSortedParaBase(__gm__ uint8_t *tiling, const PaTilingContext &ctx, int32_t batchSlot) -{ - const int32_t paraBase = ctx.headSize + batchSlot * ctx.paraSize; - const int32_t sortedBatch = LoadTilingI32(tiling, paraBase + kParaBatchIndex); - return ctx.headSize + sortedBatch * ctx.paraSize; +AICORE inline int32_t ResolveSortedParaBase(__gm__ uint8_t *tiling, + const PaTilingContext &ctx, + int32_t batchSlot) { + const int32_t paraBase = ctx.headSize + batchSlot * ctx.paraSize; + const int32_t sortedBatch = LoadTilingI32(tiling, paraBase + kParaBatchIndex); + return ctx.headSize + sortedBatch * ctx.paraSize; } -// One-time sticky SPR setup mirroring the CCE reference SetArgs (pa_kernel.cce:441-444). -AICORE inline void PtoPaInitCoreState() -{ +// One-time sticky SPR setup mirroring the CCE reference SetArgs +// (pa_kernel.cce:441-444). +AICORE inline void PtoPaInitCoreState() { #if defined(__DAV_C220_CUBE__) - set_padding(0); - set_nd_para(1ULL); + set_padding(0); + set_nd_para(1ULL); #endif - set_atomic_none(); - set_mask_norm(); + set_atomic_none(); + set_mask_norm(); #if defined(__DAV_C220_VEC__) - set_vector_mask(static_cast(-1), static_cast(-1)); + set_vector_mask(static_cast(-1), static_cast(-1)); #endif } -AICORE inline bool SupportsPtoPagedAttentionRawSplitKV(__gm__ uint8_t *tilingParaGm) -{ - const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); - if (ctx.headDim != PA_TILE_TOKENS || ctx.headDimV != PA_TILE_TOKENS || ctx.blockSize != PA_TILE_TOKENS) { - return false; - } - if (ctx.batch <= 0 || ctx.numHeads <= 0 || ctx.kvHeads <= 0 || ctx.numHeads % ctx.kvHeads != 0) { - return false; - } - if (ctx.kvSplitCoreNum <= 1) { - return false; - } - const int32_t headsPerKv = ctx.numHeads / ctx.kvHeads; - const int32_t formerHeadSplit = ctx.formerHeadSplit > 0 ? ctx.formerHeadSplit : 1; - if (headsPerKv != 4 || formerHeadSplit % headsPerKv != 0 || formerHeadSplit < 16 || formerHeadSplit % 16 != 0) { - return false; - } - return ctx.kvSplitPerCore <= 8192; +AICORE inline bool SupportsPtoPagedAttentionRawSplitKV( + __gm__ uint8_t *tilingParaGm) { + const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); + if (ctx.headDim != PA_TILE_TOKENS || ctx.headDimV != PA_TILE_TOKENS || + ctx.blockSize != PA_TILE_TOKENS) { + return false; + } + if (ctx.batch <= 0 || ctx.numHeads <= 0 || ctx.kvHeads <= 0 || + ctx.numHeads % ctx.kvHeads != 0) { + return false; + } + if (ctx.kvSplitCoreNum <= 1) { + return false; + } + const int32_t headsPerKv = ctx.numHeads / ctx.kvHeads; + const int32_t formerHeadSplit = + ctx.formerHeadSplit > 0 ? ctx.formerHeadSplit : 1; + if (headsPerKv != 4 || formerHeadSplit % headsPerKv != 0 || + formerHeadSplit < 16 || formerHeadSplit % 16 != 0) { + return false; + } + return ctx.kvSplitPerCore <= 8192; } -AICORE inline void DdrBarrierBeforePtoFfts() -{ +AICORE inline void DdrBarrierBeforePtoFfts() { #if defined(__CPU_SIM) - dsb(0); + dsb(0); #else - dsb(DSB_DDR); + dsb(DSB_DDR); #endif - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_ALL); } -AICORE inline void DdrFenceBeforePtoAivReduce() -{ +AICORE inline void DdrFenceBeforePtoAivReduce() { #if defined(__CPU_SIM) - dsb(0); + dsb(0); #else - dsb(DSB_DDR); + dsb(DSB_DDR); #endif } -AICORE inline void PtoPaSetFloatVectorMask(uint32_t len) -{ - set_mask_norm(); - constexpr uint32_t kFloatVectorSize = 64; - if (len >= kFloatVectorSize) { - set_vector_mask(static_cast(-1), static_cast(-1)); - return; - } - uint64_t mask = 0; - for (uint32_t i = 0; i < len; ++i) { - mask |= 1ULL << i; - } - set_vector_mask(0, mask); +AICORE inline void PtoPaSetFloatVectorMask(uint32_t len) { + set_mask_norm(); + constexpr uint32_t kFloatVectorSize = 64; + if (len >= kFloatVectorSize) { + set_vector_mask(static_cast(-1), static_cast(-1)); + return; + } + uint64_t mask = 0; + for (uint32_t i = 0; i < len; ++i) { + mask |= 1ULL << i; + } + set_vector_mask(0, mask); } - #ifdef __DAV_C220_VEC__ template -__tf__ AICORE void PtoPaConvF32ToF16Raw(typename DstTileData::TileDType __out__ dst, - typename SrcTileData::TileDType __in__ src, uint8_t repeat) -{ - __ubuf__ half *dstAddr = reinterpret_cast<__ubuf__ half *>(__cce_get_tile_ptr(dst)); - __ubuf__ float *srcAddr = reinterpret_cast<__ubuf__ float *>(__cce_get_tile_ptr(src)); - vconv_f322f16(dstAddr, srcAddr, repeat, 1, 1, 4, 8); +__tf__ AICORE void PtoPaConvF32ToF16Raw( + typename DstTileData::TileDType __out__ dst, + typename SrcTileData::TileDType __in__ src, uint8_t repeat) { + __ubuf__ half *dstAddr = + reinterpret_cast<__ubuf__ half *>(__cce_get_tile_ptr(dst)); + __ubuf__ float *srcAddr = + reinterpret_cast<__ubuf__ float *>(__cce_get_tile_ptr(src)); + vconv_f322f16(dstAddr, srcAddr, repeat, 1, 1, 4, 8); } template -AICORE inline void PtoPaConvF32ToF16(DstTileData &dst, SrcTileData &src, uint8_t repeat) -{ - PtoPaConvF32ToF16Raw(dst.data(), src.data(), repeat); +AICORE inline void PtoPaConvF32ToF16(DstTileData &dst, SrcTileData &src, + uint8_t repeat) { + PtoPaConvF32ToF16Raw(dst.data(), src.data(), + repeat); } #endif #ifdef __DAV_C220_CUBE__ template -__tf__ AICORE void PtoPaLoadNzHeadGroupToCaRaw(typename DstTileData::TileDType __out__ dst, - typename SrcTileData::TileDType __in__ src, uint32_t headGroupBase, uint16_t repeatTimes) -{ - using DataType = typename SrcTileData::DType; - static constexpr uint32_t kC0 = 16; - __ca__ DataType *dstAddr = reinterpret_cast<__ca__ DataType *>(__cce_get_tile_ptr(dst)); - __cbuf__ DataType *srcAddr = reinterpret_cast<__cbuf__ DataType *>(__cce_get_tile_ptr(src)) + headGroupBase * kC0; - load_cbuf_to_ca(dstAddr, srcAddr, 0, repeatTimes, 1, 0, 0, false, false, addr_cal_mode_t(0)); +__tf__ AICORE void PtoPaLoadNzHeadGroupToCaRaw( + typename DstTileData::TileDType __out__ dst, + typename SrcTileData::TileDType __in__ src, uint32_t headGroupBase, + uint16_t repeatTimes) { + using DataType = typename SrcTileData::DType; + static constexpr uint32_t kC0 = 16; + __ca__ DataType *dstAddr = + reinterpret_cast<__ca__ DataType *>(__cce_get_tile_ptr(dst)); + __cbuf__ DataType *srcAddr = + reinterpret_cast<__cbuf__ DataType *>(__cce_get_tile_ptr(src)) + + headGroupBase * kC0; + load_cbuf_to_ca(dstAddr, srcAddr, 0, repeatTimes, 1, 0, 0, false, false, + addr_cal_mode_t(0)); } template -AICORE inline void PtoPaLoadNzHeadGroupToCa(DstTileData &dst, SrcTileData &src, uint32_t headGroupBase, - uint16_t repeatTimes) -{ - PtoPaLoadNzHeadGroupToCaRaw(dst.data(), src.data(), headGroupBase, repeatTimes); +AICORE inline void PtoPaLoadNzHeadGroupToCa(DstTileData &dst, SrcTileData &src, + uint32_t headGroupBase, + uint16_t repeatTimes) { + PtoPaLoadNzHeadGroupToCaRaw( + dst.data(), src.data(), headGroupBase, repeatTimes); } template -__tf__ AICORE void PtoPaLoadCbufToCbRaw(typename DstTileData::TileDType __out__ dst, - typename SrcTileData::TileDType __in__ src, uint32_t srcElementOffset, uint16_t repeatTimes, - uint16_t srcStride) -{ - using DataType = typename SrcTileData::DType; - __cb__ DataType *dstAddr = reinterpret_cast<__cb__ DataType *>(__cce_get_tile_ptr(dst)); - __cbuf__ DataType *srcAddr = reinterpret_cast<__cbuf__ DataType *>(__cce_get_tile_ptr(src)) + srcElementOffset; - load_cbuf_to_cb(dstAddr, srcAddr, 0, repeatTimes, srcStride, 0, 0, false, addr_cal_mode_t(0)); +__tf__ AICORE void PtoPaLoadCbufToCbRaw( + typename DstTileData::TileDType __out__ dst, + typename SrcTileData::TileDType __in__ src, uint32_t srcElementOffset, + uint16_t repeatTimes, uint16_t srcStride) { + using DataType = typename SrcTileData::DType; + __cb__ DataType *dstAddr = + reinterpret_cast<__cb__ DataType *>(__cce_get_tile_ptr(dst)); + __cbuf__ DataType *srcAddr = + reinterpret_cast<__cbuf__ DataType *>(__cce_get_tile_ptr(src)) + + srcElementOffset; + load_cbuf_to_cb(dstAddr, srcAddr, 0, repeatTimes, srcStride, 0, 0, false, + addr_cal_mode_t(0)); } template -AICORE inline void PtoPaLoadCbufToCbRaw(DstTileData &dst, SrcTileData &src, uint32_t srcElementOffset, - uint16_t repeatTimes, uint16_t srcStride) -{ - PtoPaLoadCbufToCbRaw(dst.data(), src.data(), srcElementOffset, repeatTimes, srcStride); +AICORE inline void PtoPaLoadCbufToCbRaw(DstTileData &dst, SrcTileData &src, + uint32_t srcElementOffset, + uint16_t repeatTimes, + uint16_t srcStride) { + PtoPaLoadCbufToCbRaw( + dst.data(), src.data(), srcElementOffset, repeatTimes, srcStride); } template -__tf__ AICORE void PtoPaLoadCbufToCbTranspose128Raw(typename DstTileData::TileDType __out__ dst, - typename SrcTileData::TileDType __in__ src) -{ - using DataType = typename SrcTileData::DType; - __cb__ DataType *dstAddr = reinterpret_cast<__cb__ DataType *>(__cce_get_tile_ptr(dst)); - __cbuf__ DataType *srcAddr = reinterpret_cast<__cbuf__ DataType *>(__cce_get_tile_ptr(src)); - constexpr uint32_t kBlock = 16; - constexpr uint32_t kRows = 128; - constexpr uint32_t kCols = 128; - for (uint32_t idx = 0; idx < kCols / kBlock; ++idx) { - load_cbuf_to_cb_transpose(dstAddr + idx * kRows * kBlock, srcAddr + idx * kBlock * kBlock, 0, - kRows / kBlock, kCols / kBlock, 0, addr_cal_mode_t(0), 0); - } +__tf__ AICORE void PtoPaLoadCbufToCbTranspose128Raw( + typename DstTileData::TileDType __out__ dst, + typename SrcTileData::TileDType __in__ src) { + using DataType = typename SrcTileData::DType; + __cb__ DataType *dstAddr = + reinterpret_cast<__cb__ DataType *>(__cce_get_tile_ptr(dst)); + __cbuf__ DataType *srcAddr = + reinterpret_cast<__cbuf__ DataType *>(__cce_get_tile_ptr(src)); + constexpr uint32_t kBlock = 16; + constexpr uint32_t kRows = 128; + constexpr uint32_t kCols = 128; + for (uint32_t idx = 0; idx < kCols / kBlock; ++idx) { + load_cbuf_to_cb_transpose( + dstAddr + idx * kRows * kBlock, srcAddr + idx * kBlock * kBlock, 0, + kRows / kBlock, kCols / kBlock, 0, addr_cal_mode_t(0), 0); + } } template -AICORE inline void PtoPaLoadCbufToCbTranspose128Raw(DstTileData &dst, SrcTileData &src) -{ - PtoPaLoadCbufToCbTranspose128Raw(dst.data(), src.data()); +AICORE inline void PtoPaLoadCbufToCbTranspose128Raw(DstTileData &dst, + SrcTileData &src) { + PtoPaLoadCbufToCbTranspose128Raw(dst.data(), + src.data()); } #endif -AICORE inline uint64_t LoadTilingOffset64(__gm__ uint8_t *tiling, int32_t base, int32_t highIdx, int32_t lowIdx) -{ - const uint32_t high = static_cast(LoadTilingI32(tiling, base + highIdx)); - const uint32_t low = static_cast(LoadTilingI32(tiling, base + lowIdx)); - return (static_cast(high) << 32) | static_cast(low); +AICORE inline uint64_t LoadTilingOffset64(__gm__ uint8_t *tiling, int32_t base, + int32_t highIdx, int32_t lowIdx) { + const uint32_t high = + static_cast(LoadTilingI32(tiling, base + highIdx)); + const uint32_t low = + static_cast(LoadTilingI32(tiling, base + lowIdx)); + return (static_cast(high) << 32) | static_cast(low); } #ifdef __DAV_C220_CUBE__ -AICORE inline void RunPtoPagedAttentionCubePipelineSplitKV(__gm__ uint8_t *qGm, __gm__ uint8_t *kGm, - __gm__ uint8_t *vGm, __gm__ uint8_t *blockTablesGm, __gm__ uint8_t *sGm, __gm__ uint8_t *pGm, - __gm__ uint8_t *oTmpGm, __gm__ uint8_t *tilingParaGm, int64_t workerIdx, int64_t workerNum) -{ - constexpr int32_t kHeadDim = PA_TILE_TOKENS; - constexpr int32_t kTileTokens = PA_TILE_TOKENS; - constexpr int32_t kM = 16; - constexpr int32_t kMValid = 4; - constexpr int32_t kN = kTileTokens; - constexpr int32_t kK = 256; - constexpr int32_t kHeadGroup = 16; - - PtoPaInitCoreState(); - const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); - if (workerIdx < 0 || workerNum <= 0 || ctx.headDim != kHeadDim || ctx.headDimV != kHeadDim || - ctx.blockSize != kTileTokens || ctx.kvSplitCoreNum <= 1 || ctx.numHeads <= 0 || ctx.kvHeads <= 0 || - ctx.numHeads % ctx.kvHeads != 0) { - pipe_barrier(PIPE_ALL); - return; +AICORE inline void RunPtoPagedAttentionCubePipelineSplitKV( + __gm__ uint8_t *qGm, __gm__ uint8_t *kGm, __gm__ uint8_t *vGm, + __gm__ uint8_t *blockTablesGm, __gm__ uint8_t *sGm, __gm__ uint8_t *pGm, + __gm__ uint8_t *oTmpGm, __gm__ uint8_t *tilingParaGm, int64_t workerIdx, + int64_t workerNum) { + constexpr int32_t kHeadDim = PA_TILE_TOKENS; + constexpr int32_t kTileTokens = PA_TILE_TOKENS; + constexpr int32_t kM = 16; + constexpr int32_t kMValid = 4; + constexpr int32_t kN = kTileTokens; + constexpr int32_t kK = 256; + constexpr int32_t kHeadGroup = 16; + + PtoPaInitCoreState(); + const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); + if (workerIdx < 0 || workerNum <= 0 || ctx.headDim != kHeadDim || + ctx.headDimV != kHeadDim || ctx.blockSize != kTileTokens || + ctx.kvSplitCoreNum <= 1 || ctx.numHeads <= 0 || ctx.kvHeads <= 0 || + ctx.numHeads % ctx.kvHeads != 0) { + pipe_barrier(PIPE_ALL); + return; + } + + const int32_t maxBlocksPerQuery = + ctx.maxBlocksPerQuery > 0 + ? ctx.maxBlocksPerQuery + : (ctx.maxKvSeqLen + ctx.blockSize - 1) / ctx.blockSize; + const int32_t headsPerKv = ctx.numHeads / ctx.kvHeads; + const int32_t formerHeadSplit = + ctx.formerHeadSplit > 0 ? ctx.formerHeadSplit : 1; + const int32_t maxHeadGroups = (formerHeadSplit + kHeadGroup - 1) / kHeadGroup; + const int32_t corePerBatch = + (ctx.numHeads + formerHeadSplit - 1) / formerHeadSplit; + const int64_t processNum = + static_cast(ctx.formerBatch) * corePerBatch * ctx.kvSplitCoreNum; + + using QGlobal = GlobalTensor< + half, Shape<1, 1, 1, kM, kHeadDim>, + Stride, + Layout::ND>; + using KGlobal = GlobalTensor, + Stride<1, 1, 1, 1, 8 * kHeadDim>, Layout::DN>; + using VGlobal = + GlobalTensor, + Stride>; + + using QMatTile = Tile; + using KMatTile = Tile; + using PMatTile = Tile; + using VMatTile = Tile; + using LeftQTile = TileLeft; + using LeftPTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + QMatTile qMatTile; + KMatTile kMatTile; + PMatTile pMatTile; + VMatTile vMatTile; + LeftQTile qLeftTile; + LeftPTile pLeftTile; + RightTile rightTile; + AccTile accTile; + constexpr uint32_t kQCacheBase = 0x00000; + constexpr uint32_t kQGroupBytes = kM * kHeadDim * sizeof(half); + constexpr uint32_t kPmatBase = 0x10000; + TASSIGN(qMatTile, kQCacheBase); + TASSIGN(kMatTile, 0x20000); + TASSIGN(pMatTile, kPmatBase); + TASSIGN(vMatTile, 0x20000); + TASSIGN(qLeftTile, 0x00000); + TASSIGN(pLeftTile, 0x00000); + TASSIGN(rightTile, 0x00000); + TASSIGN(accTile, 0x00000); + + using ScoreGlobal = + GlobalTensor, + Stride>; + using ProbGlobal = GlobalTensor, + Stride, + Layout::ND>; + using OutGlobal = + GlobalTensor, + Stride>; + + constexpr int64_t scoreHeadBytes = kMValid * kTileTokens * sizeof(float); + constexpr int64_t probHeadBytes = 256 * sizeof(half); + constexpr int64_t outHeadBytes = kMValid * kHeadDim * sizeof(float); + constexpr int64_t scoreGroupBytes = kHeadGroup * scoreHeadBytes; + constexpr int64_t probGroupBytes = kHeadGroup * probHeadBytes; + constexpr int64_t outGroupBytes = kHeadGroup * outHeadBytes; + const int64_t scoreSlotBytes = + static_cast(maxHeadGroups) * scoreGroupBytes; + const int64_t probSlotBytes = + static_cast(maxHeadGroups) * probGroupBytes; + const int64_t outSlotBytes = + static_cast(maxHeadGroups) * outGroupBytes; + __gm__ uint8_t *scoreBase = sGm + workerIdx * scoreSlotBytes * 2; + __gm__ uint8_t *probBase = pGm + workerIdx * probSlotBytes * 2; + __gm__ uint8_t *outBase = oTmpGm + workerIdx * outSlotBytes * 2; + + const int64_t processRounds = (processNum + workerNum - 1) / workerNum; + const int32_t stageTileCount = + (ctx.kvSplitPerCore + kTileTokens - 1) / kTileTokens; + for (int64_t processRound = 0; processRound < processRounds; ++processRound) { + const int64_t process = processRound * workerNum + workerIdx; + bool validProcess = process < processNum; + int32_t batchIndex = 0; + int32_t curHeadNum = 0; + int32_t startHead = 0; + int32_t startTile = 0; + int32_t tileCount = 0; + int32_t curKvSeqLen = 0; + if (validProcess) { + int32_t curBatchSlot = + static_cast(process / (corePerBatch * ctx.kvSplitCoreNum)); + int32_t paraBase = ctx.headSize + curBatchSlot * ctx.paraSize; + const int32_t sortedBatch = + LoadTilingI32(tilingParaGm, paraBase + kParaBatchIndex); + paraBase = ctx.headSize + sortedBatch * ctx.paraSize; + batchIndex = LoadTilingI32(tilingParaGm, paraBase + 8); + const int32_t kvSeqLen = + LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); + const int32_t kvSeqLenAlign = + ((kvSeqLen + ctx.blockSize - 1) / ctx.blockSize) * ctx.blockSize; + const int32_t kvLoop = + (kvSeqLenAlign + ctx.kvSplitPerCore - 1) / ctx.kvSplitPerCore; + const int32_t curSplit = + static_cast(process % ctx.kvSplitCoreNum); + validProcess = kvSeqLen > 0 && curSplit < kvLoop; + if (validProcess) { + const int32_t curHeadBlock = + static_cast((process / ctx.kvSplitCoreNum) % corePerBatch); + startHead = curHeadBlock * formerHeadSplit; + curHeadNum = formerHeadSplit; + if (curHeadBlock == corePerBatch - 1) { + curHeadNum = ctx.numHeads - curHeadBlock * formerHeadSplit; + } + const int32_t startKv = curSplit * ctx.kvSplitPerCore; + curKvSeqLen = ctx.kvSplitPerCore; + if (curSplit == kvLoop - 1) { + curKvSeqLen = kvSeqLen - startKv; + } + tileCount = (curKvSeqLen + kTileTokens - 1) / kTileTokens; + startTile = startKv / kTileTokens; + } } - const int32_t maxBlocksPerQuery = ctx.maxBlocksPerQuery > 0 ? ctx.maxBlocksPerQuery : - (ctx.maxKvSeqLen + ctx.blockSize - 1) / ctx.blockSize; - const int32_t headsPerKv = ctx.numHeads / ctx.kvHeads; - const int32_t formerHeadSplit = ctx.formerHeadSplit > 0 ? ctx.formerHeadSplit : 1; - const int32_t maxHeadGroups = (formerHeadSplit + kHeadGroup - 1) / kHeadGroup; - const int32_t corePerBatch = (ctx.numHeads + formerHeadSplit - 1) / formerHeadSplit; - const int64_t processNum = static_cast(ctx.formerBatch) * corePerBatch * ctx.kvSplitCoreNum; - - using QGlobal = GlobalTensor, - Stride, Layout::ND>; - using KGlobal = GlobalTensor, - Stride<1, 1, 1, 1, 8 * kHeadDim>, Layout::DN>; - using VGlobal = GlobalTensor, - Stride>; - - using QMatTile = Tile; - using KMatTile = Tile; - using PMatTile = Tile; - using VMatTile = Tile; - using LeftQTile = TileLeft; - using LeftPTile = TileLeft; - using RightTile = TileRight; - using AccTile = TileAcc; - - QMatTile qMatTile; - KMatTile kMatTile; - PMatTile pMatTile; - VMatTile vMatTile; - LeftQTile qLeftTile; - LeftPTile pLeftTile; - RightTile rightTile; - AccTile accTile; - constexpr uint32_t kQCacheBase = 0x00000; - constexpr uint32_t kQGroupBytes = kM * kHeadDim * sizeof(half); - constexpr uint32_t kPmatBase = 0x10000; - TASSIGN(qMatTile, kQCacheBase); - TASSIGN(kMatTile, 0x20000); - TASSIGN(pMatTile, kPmatBase); - TASSIGN(vMatTile, 0x20000); - TASSIGN(qLeftTile, 0x00000); - TASSIGN(pLeftTile, 0x00000); - TASSIGN(rightTile, 0x00000); - TASSIGN(accTile, 0x00000); - - using ScoreGlobal = GlobalTensor, - Stride>; - using ProbGlobal = GlobalTensor, - Stride, Layout::ND>; - using OutGlobal = GlobalTensor, - Stride>; - - constexpr int64_t scoreHeadBytes = kMValid * kTileTokens * sizeof(float); - constexpr int64_t probHeadBytes = 256 * sizeof(half); - constexpr int64_t outHeadBytes = kMValid * kHeadDim * sizeof(float); - constexpr int64_t scoreGroupBytes = kHeadGroup * scoreHeadBytes; - constexpr int64_t probGroupBytes = kHeadGroup * probHeadBytes; - constexpr int64_t outGroupBytes = kHeadGroup * outHeadBytes; - const int64_t scoreSlotBytes = static_cast(maxHeadGroups) * scoreGroupBytes; - const int64_t probSlotBytes = static_cast(maxHeadGroups) * probGroupBytes; - const int64_t outSlotBytes = static_cast(maxHeadGroups) * outGroupBytes; - __gm__ uint8_t *scoreBase = sGm + workerIdx * scoreSlotBytes * 2; - __gm__ uint8_t *probBase = pGm + workerIdx * probSlotBytes * 2; - __gm__ uint8_t *outBase = oTmpGm + workerIdx * outSlotBytes * 2; - - const int64_t processRounds = (processNum + workerNum - 1) / workerNum; - const int32_t stageTileCount = (ctx.kvSplitPerCore + kTileTokens - 1) / kTileTokens; - for (int64_t processRound = 0; processRound < processRounds; ++processRound) { - const int64_t process = processRound * workerNum + workerIdx; - bool validProcess = process < processNum; - int32_t batchIndex = 0; - int32_t curHeadNum = 0; - int32_t startHead = 0; - int32_t startTile = 0; - int32_t tileCount = 0; - int32_t curKvSeqLen = 0; - if (validProcess) { - int32_t curBatchSlot = static_cast(process / (corePerBatch * ctx.kvSplitCoreNum)); - int32_t paraBase = ctx.headSize + curBatchSlot * ctx.paraSize; - const int32_t sortedBatch = LoadTilingI32(tilingParaGm, paraBase + kParaBatchIndex); - paraBase = ctx.headSize + sortedBatch * ctx.paraSize; - batchIndex = LoadTilingI32(tilingParaGm, paraBase + 8); - const int32_t kvSeqLen = LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); - const int32_t kvSeqLenAlign = ((kvSeqLen + ctx.blockSize - 1) / ctx.blockSize) * ctx.blockSize; - const int32_t kvLoop = (kvSeqLenAlign + ctx.kvSplitPerCore - 1) / ctx.kvSplitPerCore; - const int32_t curSplit = static_cast(process % ctx.kvSplitCoreNum); - validProcess = kvSeqLen > 0 && curSplit < kvLoop; - if (validProcess) { - const int32_t curHeadBlock = static_cast((process / ctx.kvSplitCoreNum) % corePerBatch); - startHead = curHeadBlock * formerHeadSplit; - curHeadNum = formerHeadSplit; - if (curHeadBlock == corePerBatch - 1) { - curHeadNum = ctx.numHeads - curHeadBlock * formerHeadSplit; - } - const int32_t startKv = curSplit * ctx.kvSplitPerCore; - curKvSeqLen = ctx.kvSplitPerCore; - if (curSplit == kvLoop - 1) { - curKvSeqLen = kvSeqLen - startKv; - } - tileCount = (curKvSeqLen + kTileTokens - 1) / kTileTokens; - startTile = startKv / kTileTokens; - } + if (validProcess) { + for (int32_t headGroup = 0; headGroup < maxHeadGroups; ++headGroup) { + const int32_t groupHeadBase = headGroup * kHeadGroup; + if (groupHeadBase >= curHeadNum) { + break; } + const int32_t firstHead = startHead + groupHeadBase; + const int64_t qBase = + (static_cast(batchIndex) * ctx.numHeads + firstHead) * + ctx.headDim; + QGlobal qGlobal(reinterpret_cast<__gm__ half *>(qGm) + qBase); + TASSIGN(qMatTile, + kQCacheBase + static_cast(headGroup) * kQGroupBytes); + TLOAD(qMatTile, qGlobal); + pipe_barrier(PIPE_ALL); + } + } - if (validProcess) { - for (int32_t headGroup = 0; headGroup < maxHeadGroups; ++headGroup) { - const int32_t groupHeadBase = headGroup * kHeadGroup; - if (groupHeadBase >= curHeadNum) { - break; - } - const int32_t firstHead = startHead + groupHeadBase; - const int64_t qBase = (static_cast(batchIndex) * ctx.numHeads + firstHead) * ctx.headDim; - QGlobal qGlobal(reinterpret_cast<__gm__ half *>(qGm) + qBase); - TASSIGN(qMatTile, kQCacheBase + static_cast(headGroup) * kQGroupBytes); - TLOAD(qMatTile, qGlobal); - pipe_barrier(PIPE_ALL); - } - } + for (int32_t tilePairBase = 0; tilePairBase < stageTileCount; + tilePairBase += 2) { + const bool hasStage2 = (tilePairBase + 1) < stageTileCount; + const bool activeTile0 = validProcess && tilePairBase < tileCount; + const bool activeTile1 = + validProcess && hasStage2 && (tilePairBase + 1) < tileCount; - for (int32_t tilePairBase = 0; tilePairBase < stageTileCount; tilePairBase += 2) { - const bool hasStage2 = (tilePairBase + 1) < stageTileCount; - const bool activeTile0 = validProcess && tilePairBase < tileCount; - const bool activeTile1 = validProcess && hasStage2 && (tilePairBase + 1) < tileCount; - - for (uint32_t stage = 0; stage < 2; ++stage) { - if (stage == 1 && !hasStage2) { - break; - } - const int32_t tile = tilePairBase + static_cast(stage); - const uint8_t slot = static_cast(stage); - const bool activeTileStage = (stage == 0) ? activeTile0 : activeTile1; - for (int32_t headGroup = 0; headGroup < maxHeadGroups; ++headGroup) { - const int32_t groupHeadBase = headGroup * kHeadGroup; - const bool validGroup = validProcess && groupHeadBase < curHeadNum; - const bool activeGroup = validGroup && activeTileStage; - if (!activeGroup) { - continue; - } - TASSIGN(qMatTile, kQCacheBase + static_cast(headGroup) * kQGroupBytes); - const int32_t blockId = LoadBlockTable(blockTablesGm, - static_cast(batchIndex) * maxBlocksPerQuery + startTile + tile); - for (int32_t headInGroupBase = 0; headInGroupBase < kHeadGroup; headInGroupBase += headsPerKv) { - const int32_t baseHeadLocal = groupHeadBase + headInGroupBase; - if (baseHeadLocal >= curHeadNum) { - break; - } - const int32_t baseHead = startHead + baseHeadLocal; - const int32_t kvHead = baseHead / headsPerKv; - const int64_t kvBase = - (static_cast(blockId) * ctx.blockSize * ctx.kvHeads + kvHead) * ctx.headDim; - PtoPaLoadNzHeadGroupToCa(qLeftTile, qMatTile, static_cast(headInGroupBase), - static_cast(kHeadDim / 16)); - KGlobal kGlobal(reinterpret_cast<__gm__ half *>(kGm) + kvBase); - TLOAD(kMatTile, kGlobal); - auto matmulEvent = EVENT_ID1; - set_flag(PIPE_FIX, PIPE_M, matmulEvent); - wait_flag(PIPE_FIX, PIPE_M, matmulEvent); - set_flag(PIPE_MTE2, PIPE_MTE1, matmulEvent); - wait_flag(PIPE_MTE2, PIPE_MTE1, matmulEvent); - set_flag(PIPE_M, PIPE_MTE1, matmulEvent); - wait_flag(PIPE_M, PIPE_MTE1, matmulEvent); - PtoPaLoadCbufToCbRaw(rightTile, kMatTile, 0, - static_cast((kHeadDim * kTileTokens) / 256), 1); - set_flag(PIPE_MTE1, PIPE_M, matmulEvent); - wait_flag(PIPE_MTE1, PIPE_M, matmulEvent); - TMATMUL(accTile, qLeftTile, rightTile); - set_flag(PIPE_MTE1, PIPE_MTE2, matmulEvent); - wait_flag(PIPE_MTE1, PIPE_MTE2, matmulEvent); - set_flag(PIPE_M, PIPE_FIX, matmulEvent); - wait_flag(PIPE_M, PIPE_FIX, matmulEvent); - ScoreGlobal scoreGlobal(reinterpret_cast<__gm__ float *>(scoreBase + - static_cast(slot) * scoreSlotBytes + - static_cast(headGroup) * scoreGroupBytes + - static_cast(headInGroupBase) * scoreHeadBytes)); - TSTORE(scoreGlobal, accTile); - } - } - DdrFenceBeforePtoAivReduce(); - PtoPaSignalFromCube(PtoPaSlotFlag(PTO_PA_RAW_QK_READY, slot)); + for (uint32_t stage = 0; stage < 2; ++stage) { + if (stage == 1 && !hasStage2) { + break; + } + const int32_t tile = tilePairBase + static_cast(stage); + const uint8_t slot = static_cast(stage); + const bool activeTileStage = (stage == 0) ? activeTile0 : activeTile1; + for (int32_t headGroup = 0; headGroup < maxHeadGroups; ++headGroup) { + const int32_t groupHeadBase = headGroup * kHeadGroup; + const bool validGroup = validProcess && groupHeadBase < curHeadNum; + const bool activeGroup = validGroup && activeTileStage; + if (!activeGroup) { + continue; + } + TASSIGN(qMatTile, kQCacheBase + static_cast(headGroup) * + kQGroupBytes); + const int32_t blockId = + LoadBlockTable(blockTablesGm, static_cast(batchIndex) * + maxBlocksPerQuery + + startTile + tile); + for (int32_t headInGroupBase = 0; headInGroupBase < kHeadGroup; + headInGroupBase += headsPerKv) { + const int32_t baseHeadLocal = groupHeadBase + headInGroupBase; + if (baseHeadLocal >= curHeadNum) { + break; } + const int32_t baseHead = startHead + baseHeadLocal; + const int32_t kvHead = baseHead / headsPerKv; + const int64_t kvBase = + (static_cast(blockId) * ctx.blockSize * ctx.kvHeads + + kvHead) * + ctx.headDim; + PtoPaLoadNzHeadGroupToCa(qLeftTile, qMatTile, + static_cast(headInGroupBase), + static_cast(kHeadDim / 16)); + KGlobal kGlobal(reinterpret_cast<__gm__ half *>(kGm) + kvBase); + TLOAD(kMatTile, kGlobal); + auto matmulEvent = EVENT_ID1; + set_flag(PIPE_FIX, PIPE_M, matmulEvent); + wait_flag(PIPE_FIX, PIPE_M, matmulEvent); + set_flag(PIPE_MTE2, PIPE_MTE1, matmulEvent); + wait_flag(PIPE_MTE2, PIPE_MTE1, matmulEvent); + set_flag(PIPE_M, PIPE_MTE1, matmulEvent); + wait_flag(PIPE_M, PIPE_MTE1, matmulEvent); + PtoPaLoadCbufToCbRaw( + rightTile, kMatTile, 0, + static_cast((kHeadDim * kTileTokens) / 256), 1); + set_flag(PIPE_MTE1, PIPE_M, matmulEvent); + wait_flag(PIPE_MTE1, PIPE_M, matmulEvent); + TMATMUL(accTile, qLeftTile, rightTile); + set_flag(PIPE_MTE1, PIPE_MTE2, matmulEvent); + wait_flag(PIPE_MTE1, PIPE_MTE2, matmulEvent); + set_flag(PIPE_M, PIPE_FIX, matmulEvent); + wait_flag(PIPE_M, PIPE_FIX, matmulEvent); + ScoreGlobal scoreGlobal(reinterpret_cast<__gm__ float *>( + scoreBase + static_cast(slot) * scoreSlotBytes + + static_cast(headGroup) * scoreGroupBytes + + static_cast(headInGroupBase) * scoreHeadBytes)); + TSTORE(scoreGlobal, accTile); + } + } + DdrFenceBeforePtoAivReduce(); + PtoPaSignalFromCube(PtoPaSlotFlag(PTO_PA_RAW_QK_READY, slot)); + } - for (uint32_t stage = 0; stage < 2; ++stage) { - if (stage == 1 && !hasStage2) { - break; - } - const int32_t tile = tilePairBase + static_cast(stage); - const uint8_t slot = static_cast(stage); - const bool activeTileStage = (stage == 0) ? activeTile0 : activeTile1; - wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_P_READY, slot)); - for (int32_t headGroup = 0; headGroup < maxHeadGroups; ++headGroup) { - const int32_t groupHeadBase = headGroup * kHeadGroup; - const bool validGroup = validProcess && groupHeadBase < curHeadNum; - const bool activeGroup = validGroup && activeTileStage; - if (!activeGroup) { - continue; - } - const int32_t blockId = LoadBlockTable(blockTablesGm, - static_cast(batchIndex) * maxBlocksPerQuery + startTile + tile); - ProbGlobal probGlobal(reinterpret_cast<__gm__ half *>(probBase + - static_cast(slot) * probSlotBytes + - static_cast(headGroup) * probGroupBytes)); - TLOAD(pMatTile, probGlobal); - pipe_barrier(PIPE_ALL); - for (int32_t headInGroupBase = 0; headInGroupBase < kHeadGroup; headInGroupBase += headsPerKv) { - const int32_t baseHeadLocal = groupHeadBase + headInGroupBase; - if (baseHeadLocal >= curHeadNum) { - break; - } - const int32_t baseHead = startHead + baseHeadLocal; - const int32_t kvHead = baseHead / headsPerKv; - const int64_t kvBase = - (static_cast(blockId) * ctx.blockSize * ctx.kvHeads + kvHead) * ctx.headDim; - PtoPaLoadNzHeadGroupToCa(pLeftTile, pMatTile, static_cast(headInGroupBase), - static_cast(kTileTokens / 16)); - VGlobal vGlobal(reinterpret_cast<__gm__ half *>(vGm) + kvBase); - TLOAD(vMatTile, vGlobal); - auto matmulEvent = EVENT_ID1; - set_flag(PIPE_FIX, PIPE_M, matmulEvent); - wait_flag(PIPE_FIX, PIPE_M, matmulEvent); - set_flag(PIPE_MTE2, PIPE_MTE1, matmulEvent); - wait_flag(PIPE_MTE2, PIPE_MTE1, matmulEvent); - set_flag(PIPE_M, PIPE_MTE1, matmulEvent); - wait_flag(PIPE_M, PIPE_MTE1, matmulEvent); - PtoPaLoadCbufToCbTranspose128Raw(rightTile, vMatTile); - set_flag(PIPE_MTE1, PIPE_M, matmulEvent); - wait_flag(PIPE_MTE1, PIPE_M, matmulEvent); - TMATMUL(accTile, pLeftTile, rightTile); - set_flag(PIPE_MTE1, PIPE_MTE2, matmulEvent); - wait_flag(PIPE_MTE1, PIPE_MTE2, matmulEvent); - set_flag(PIPE_M, PIPE_FIX, matmulEvent); - wait_flag(PIPE_M, PIPE_FIX, matmulEvent); - OutGlobal outGlobal(reinterpret_cast<__gm__ float *>(outBase + - static_cast(slot) * outSlotBytes + - static_cast(headGroup) * outGroupBytes + - static_cast(headInGroupBase) * outHeadBytes)); - TSTORE(outGlobal, accTile); - } - } - DdrFenceBeforePtoAivReduce(); - PtoPaSignalFromCube(PtoPaSlotFlag(PTO_PA_RAW_PV_READY, slot)); - } - wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_PV_FREE, 0)); - if (hasStage2) { - wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_PV_FREE, 1)); + for (uint32_t stage = 0; stage < 2; ++stage) { + if (stage == 1 && !hasStage2) { + break; + } + const int32_t tile = tilePairBase + static_cast(stage); + const uint8_t slot = static_cast(stage); + const bool activeTileStage = (stage == 0) ? activeTile0 : activeTile1; + wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_P_READY, slot)); + for (int32_t headGroup = 0; headGroup < maxHeadGroups; ++headGroup) { + const int32_t groupHeadBase = headGroup * kHeadGroup; + const bool validGroup = validProcess && groupHeadBase < curHeadNum; + const bool activeGroup = validGroup && activeTileStage; + if (!activeGroup) { + continue; + } + const int32_t blockId = + LoadBlockTable(blockTablesGm, static_cast(batchIndex) * + maxBlocksPerQuery + + startTile + tile); + ProbGlobal probGlobal(reinterpret_cast<__gm__ half *>( + probBase + static_cast(slot) * probSlotBytes + + static_cast(headGroup) * probGroupBytes)); + TLOAD(pMatTile, probGlobal); + pipe_barrier(PIPE_ALL); + for (int32_t headInGroupBase = 0; headInGroupBase < kHeadGroup; + headInGroupBase += headsPerKv) { + const int32_t baseHeadLocal = groupHeadBase + headInGroupBase; + if (baseHeadLocal >= curHeadNum) { + break; } + const int32_t baseHead = startHead + baseHeadLocal; + const int32_t kvHead = baseHead / headsPerKv; + const int64_t kvBase = + (static_cast(blockId) * ctx.blockSize * ctx.kvHeads + + kvHead) * + ctx.headDim; + PtoPaLoadNzHeadGroupToCa(pLeftTile, pMatTile, + static_cast(headInGroupBase), + static_cast(kTileTokens / 16)); + VGlobal vGlobal(reinterpret_cast<__gm__ half *>(vGm) + kvBase); + TLOAD(vMatTile, vGlobal); + auto matmulEvent = EVENT_ID1; + set_flag(PIPE_FIX, PIPE_M, matmulEvent); + wait_flag(PIPE_FIX, PIPE_M, matmulEvent); + set_flag(PIPE_MTE2, PIPE_MTE1, matmulEvent); + wait_flag(PIPE_MTE2, PIPE_MTE1, matmulEvent); + set_flag(PIPE_M, PIPE_MTE1, matmulEvent); + wait_flag(PIPE_M, PIPE_MTE1, matmulEvent); + PtoPaLoadCbufToCbTranspose128Raw(rightTile, vMatTile); + set_flag(PIPE_MTE1, PIPE_M, matmulEvent); + wait_flag(PIPE_MTE1, PIPE_M, matmulEvent); + TMATMUL(accTile, pLeftTile, rightTile); + set_flag(PIPE_MTE1, PIPE_MTE2, matmulEvent); + wait_flag(PIPE_MTE1, PIPE_MTE2, matmulEvent); + set_flag(PIPE_M, PIPE_FIX, matmulEvent); + wait_flag(PIPE_M, PIPE_FIX, matmulEvent); + OutGlobal outGlobal(reinterpret_cast<__gm__ float *>( + outBase + static_cast(slot) * outSlotBytes + + static_cast(headGroup) * outGroupBytes + + static_cast(headInGroupBase) * outHeadBytes)); + TSTORE(outGlobal, accTile); + } } - + DdrFenceBeforePtoAivReduce(); + PtoPaSignalFromCube(PtoPaSlotFlag(PTO_PA_RAW_PV_READY, slot)); + } + wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_PV_FREE, 0)); + if (hasStage2) { + wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_PV_FREE, 1)); + } } - pipe_barrier(PIPE_ALL); + } + pipe_barrier(PIPE_ALL); } #endif #ifdef __DAV_C220_VEC__ -AICORE inline void RunPtoPagedAttentionVecPipelineSplitKV(__gm__ uint8_t *oGm, __gm__ uint8_t *sGm, - __gm__ uint8_t *pGm, __gm__ uint8_t *oTmpGm, __gm__ uint8_t *oCoreTmpGm, __gm__ uint8_t *lGm, - __gm__ uint8_t *tilingParaGm, int64_t workerIdx, int64_t workerNum, uint32_t subBlockId) -{ - constexpr int32_t kHeadDim = PA_TILE_TOKENS; - constexpr int32_t kTileTokens = PA_TILE_TOKENS; - constexpr int32_t kHeadGroup = 16; - constexpr int32_t kMaxHeadsPerProcess = 32; - PtoPaInitCoreState(); - const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); - if (workerIdx < 0 || workerNum <= 0 || ctx.headDim != kHeadDim || ctx.headDimV != kHeadDim || - ctx.blockSize != kTileTokens || ctx.kvSplitCoreNum <= 1 || ctx.numHeads <= 0 || ctx.kvHeads <= 0 || - ctx.numHeads % ctx.kvHeads != 0) { - pipe_barrier(PIPE_ALL); - return; +AICORE inline void RunPtoPagedAttentionVecPipelineSplitKV( + __gm__ uint8_t *oGm, __gm__ uint8_t *sGm, __gm__ uint8_t *pGm, + __gm__ uint8_t *oTmpGm, __gm__ uint8_t *oCoreTmpGm, __gm__ uint8_t *lGm, + __gm__ uint8_t *tilingParaGm, int64_t workerIdx, int64_t workerNum, + uint32_t subBlockId) { + constexpr int32_t kHeadDim = PA_TILE_TOKENS; + constexpr int32_t kTileTokens = PA_TILE_TOKENS; + constexpr int32_t kHeadGroup = 16; + constexpr int32_t kMaxHeadsPerProcess = 32; + PtoPaInitCoreState(); + const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); + if (workerIdx < 0 || workerNum <= 0 || ctx.headDim != kHeadDim || + ctx.headDimV != kHeadDim || ctx.blockSize != kTileTokens || + ctx.kvSplitCoreNum <= 1 || ctx.numHeads <= 0 || ctx.kvHeads <= 0 || + ctx.numHeads % ctx.kvHeads != 0) { + pipe_barrier(PIPE_ALL); + return; + } + + const bool activeSubBlock = subBlockId < 2; + const bool combineSubBlock = subBlockId == 0; + const int32_t formerHeadSplit = + ctx.formerHeadSplit > 0 ? ctx.formerHeadSplit : 1; + const int32_t maxHeadGroups = (formerHeadSplit + kHeadGroup - 1) / kHeadGroup; + const int32_t headsPerKv = ctx.numHeads / ctx.kvHeads; + const int32_t corePerBatch = + (ctx.numHeads + formerHeadSplit - 1) / formerHeadSplit; + const int64_t processNum = + static_cast(ctx.formerBatch) * corePerBatch * ctx.kvSplitCoreNum; + __gm__ float *partialOut = reinterpret_cast<__gm__ float *>(oCoreTmpGm); + __gm__ float *partialL = reinterpret_cast<__gm__ float *>(lGm); + + using VecFloat128 = + Tile; + using VecHalf128 = + Tile; + using VecHalf256 = + Tile; + using VecFloat8 = Tile; + using VecFloat4x128 = + Tile; + using VecFloat4x1 = Tile; + using VecFloat1x8 = Tile; + using ScoreGlobal = + GlobalTensor, + Stride>; + using ScoreRowsGlobal = GlobalTensor, + Stride<1, 1, 1, kHeadDim, 1>>; + using ProbGlobal = + GlobalTensor, Stride<256, 256, 256, 256, 1>>; + using ProbRowGlobal = + GlobalTensor, + Stride>; + using OutGlobal = + GlobalTensor, + Stride>; + using OutputGlobal = + GlobalTensor, + Stride>; + using OutRowsGlobal = GlobalTensor, + Stride<1, 1, 1, kHeadDim, 1>>; + + VecFloat128 weightedTile; + VecFloat128 scoreTile; + VecFloat128 scoreWorkTile; + VecFloat128 pvTile; + VecHalf128 probHalfTile; + VecHalf128 outHalfTile; + VecHalf256 probTile; + VecFloat8 rowMaxTile; + VecFloat8 rowSumTile; + VecFloat8 scalarMathTile; + VecFloat4x128 scoreRowsTile; + VecFloat4x128 scoreRowsWorkTile; + VecFloat128 + probRowView; // 1x128 view aliasing one row of scoreRowsWorkTile for TCVT + VecFloat4x128 pvRowsTile; + VecFloat4x1 rowMaxRowsTile; + VecFloat4x1 maxStateRowsTile; + VecFloat4x1 newMaxRowsTile; + VecFloat4x1 oldScaleRowsTile; + VecFloat4x1 rowSumRowsTile; + VecFloat4x1 sumStateRowsTile; + VecFloat1x8 rowMaxRowsView; + VecFloat1x8 maxStateRowsView; + VecFloat1x8 newMaxRowsView; + VecFloat1x8 oldScaleRowsView; + VecFloat1x8 rowSumRowsView; + VecFloat1x8 sumStateRowsView; + TASSIGN(weightedTile, 0x0000); + TASSIGN(scoreTile, 0x0800); + TASSIGN(scoreWorkTile, 0x1000); + TASSIGN(pvTile, 0x1800); + TASSIGN(probHalfTile, 0x2000); + TASSIGN(outHalfTile, 0x2000); + TASSIGN(probTile, 0x2800); + TASSIGN(rowMaxTile, 0x3000); + TASSIGN(rowSumTile, 0x3040); + TASSIGN(scalarMathTile, 0x3080); + TASSIGN(scoreRowsTile, 0x0800); + constexpr uint32_t kScoreRowsWorkUb = 0x1000; + TASSIGN(scoreRowsWorkTile, kScoreRowsWorkUb); + TASSIGN(pvRowsTile, 0x1800); + TASSIGN(rowMaxRowsTile, 0x3000); + TASSIGN(maxStateRowsTile, 0x3020); + TASSIGN(newMaxRowsTile, 0x3040); + TASSIGN(oldScaleRowsTile, 0x3060); + TASSIGN(rowSumRowsTile, 0x3080); + TASSIGN(sumStateRowsTile, 0x30a0); + TRESHAPE(rowMaxRowsView, rowMaxRowsTile); + TRESHAPE(maxStateRowsView, maxStateRowsTile); + TRESHAPE(newMaxRowsView, newMaxRowsTile); + TRESHAPE(oldScaleRowsView, oldScaleRowsTile); + TRESHAPE(rowSumRowsView, rowSumRowsTile); + TRESHAPE(sumStateRowsView, sumStateRowsTile); + + constexpr uint32_t kAccumUbBase = 0x4000; + constexpr uint32_t kAccumHeadBytes = kHeadDim * sizeof(float); + constexpr int64_t scoreHeadBytes = 4 * kTileTokens * sizeof(float); + constexpr int64_t probHeadBytes = 256 * sizeof(half); + constexpr int64_t outHeadBytes = 4 * kHeadDim * sizeof(float); + constexpr int64_t scoreGroupBytes = kHeadGroup * scoreHeadBytes; + constexpr int64_t probGroupBytes = kHeadGroup * probHeadBytes; + constexpr int64_t outGroupBytes = kHeadGroup * outHeadBytes; + const int64_t scoreSlotBytes = + static_cast(maxHeadGroups) * scoreGroupBytes; + const int64_t probSlotBytes = + static_cast(maxHeadGroups) * probGroupBytes; + const int64_t outSlotBytes = + static_cast(maxHeadGroups) * outGroupBytes; + __gm__ uint8_t *scoreBase = sGm + workerIdx * scoreSlotBytes * 2; + __gm__ uint8_t *probBase = pGm + workerIdx * probSlotBytes * 2; + __gm__ uint8_t *outTmpBase = oTmpGm + workerIdx * outSlotBytes * 2; + + const int64_t processRounds = (processNum + workerNum - 1) / workerNum; + const int32_t stageTileCount = + (ctx.kvSplitPerCore + kTileTokens - 1) / kTileTokens; + for (int64_t processRound = 0; processRound < processRounds; ++processRound) { + const int64_t process = processRound * workerNum + workerIdx; + bool validProcess = process < processNum; + int32_t batchIndex = 0; + int32_t curHeadNum = 0; + int32_t startHead = 0; + int32_t tileCount = 0; + int32_t curKvSeqLen = 0; + int32_t curSplit = 0; + uint64_t lBase = 0; + uint64_t oFdBase = 0; + if (validProcess) { + int32_t curBatchSlot = + static_cast(process / (corePerBatch * ctx.kvSplitCoreNum)); + int32_t paraBase = ctx.headSize + curBatchSlot * ctx.paraSize; + const int32_t sortedBatch = + LoadTilingI32(tilingParaGm, paraBase + kParaBatchIndex); + paraBase = ctx.headSize + sortedBatch * ctx.paraSize; + batchIndex = LoadTilingI32(tilingParaGm, paraBase + 8); + const int32_t kvSeqLen = + LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); + const int32_t kvSeqLenAlign = + ((kvSeqLen + ctx.blockSize - 1) / ctx.blockSize) * ctx.blockSize; + const int32_t kvLoop = + (kvSeqLenAlign + ctx.kvSplitPerCore - 1) / ctx.kvSplitPerCore; + curSplit = static_cast(process % ctx.kvSplitCoreNum); + validProcess = kvSeqLen > 0 && curSplit < kvLoop; + if (validProcess) { + const int32_t curHeadBlock = + static_cast((process / ctx.kvSplitCoreNum) % corePerBatch); + startHead = curHeadBlock * formerHeadSplit; + curHeadNum = formerHeadSplit; + if (curHeadBlock == corePerBatch - 1) { + curHeadNum = ctx.numHeads - curHeadBlock * formerHeadSplit; + } + const int32_t startKv = curSplit * ctx.kvSplitPerCore; + curKvSeqLen = ctx.kvSplitPerCore; + if (curSplit == kvLoop - 1) { + curKvSeqLen = kvSeqLen - startKv; + } + tileCount = (curKvSeqLen + kTileTokens - 1) / kTileTokens; + lBase = LoadTilingOffset64(tilingParaGm, paraBase, 11, 12); + oFdBase = LoadTilingOffset64(tilingParaGm, paraBase, 15, 16); + } } - const bool activeSubBlock = subBlockId < 2; - const bool combineSubBlock = subBlockId == 0; - const int32_t formerHeadSplit = ctx.formerHeadSplit > 0 ? ctx.formerHeadSplit : 1; - const int32_t maxHeadGroups = (formerHeadSplit + kHeadGroup - 1) / kHeadGroup; - const int32_t headsPerKv = ctx.numHeads / ctx.kvHeads; - const int32_t corePerBatch = (ctx.numHeads + formerHeadSplit - 1) / formerHeadSplit; - const int64_t processNum = static_cast(ctx.formerBatch) * corePerBatch * ctx.kvSplitCoreNum; - __gm__ float *partialOut = reinterpret_cast<__gm__ float *>(oCoreTmpGm); - __gm__ float *partialL = reinterpret_cast<__gm__ float *>(lGm); - - using VecFloat128 = Tile; - using VecHalf128 = Tile; - using VecHalf256 = Tile; - using VecFloat8 = Tile; - using VecFloat4x128 = Tile; - using VecFloat4x1 = Tile; - using VecFloat1x8 = Tile; - using ScoreGlobal = GlobalTensor, - Stride>; - using ScoreRowsGlobal = GlobalTensor, - Stride<1, 1, 1, kHeadDim, 1>>; - using ProbGlobal = GlobalTensor, Stride<256, 256, 256, 256, 1>>; - using ProbRowGlobal = GlobalTensor, - Stride>; - using OutGlobal = GlobalTensor, - Stride>; - using OutputGlobal = GlobalTensor, - Stride>; - using OutRowsGlobal = GlobalTensor, - Stride<1, 1, 1, kHeadDim, 1>>; + const int32_t subHeadBegin = subBlockId == 0 ? 0 : curHeadNum / 2; + const int32_t subHeadEnd = subBlockId == 0 ? curHeadNum / 2 : curHeadNum; + float maxScore[kMaxHeadsPerProcess]; + float sumExp[kMaxHeadsPerProcess]; + float oldScaleByHead[kMaxHeadsPerProcess]; + if (activeSubBlock && validProcess) { + for (int32_t headLocal = subHeadBegin; headLocal < subHeadEnd; + ++headLocal) { + maxScore[headLocal] = -3.4028234663852886e38f; + sumExp[headLocal] = 0.0f; + oldScaleByHead[headLocal] = 0.0f; + TASSIGN(weightedTile, kAccumUbBase + static_cast(headLocal) * + kAccumHeadBytes); + TEXPANDS(weightedTile, 0.0f); + pipe_barrier(PIPE_V); + } + } - VecFloat128 weightedTile; - VecFloat128 scoreTile; - VecFloat128 scoreWorkTile; - VecFloat128 pvTile; - VecHalf128 probHalfTile; - VecHalf128 outHalfTile; - VecHalf256 probTile; - VecFloat8 rowMaxTile; - VecFloat8 rowSumTile; - VecFloat8 scalarMathTile; - VecFloat4x128 scoreRowsTile; - VecFloat4x128 scoreRowsWorkTile; - VecFloat128 probRowView; // 1x128 view aliasing one row of scoreRowsWorkTile for TCVT - VecFloat4x128 pvRowsTile; - VecFloat4x1 rowMaxRowsTile; - VecFloat4x1 maxStateRowsTile; - VecFloat4x1 newMaxRowsTile; - VecFloat4x1 oldScaleRowsTile; - VecFloat4x1 rowSumRowsTile; - VecFloat4x1 sumStateRowsTile; - VecFloat1x8 rowMaxRowsView; - VecFloat1x8 maxStateRowsView; - VecFloat1x8 newMaxRowsView; - VecFloat1x8 oldScaleRowsView; - VecFloat1x8 rowSumRowsView; - VecFloat1x8 sumStateRowsView; - TASSIGN(weightedTile, 0x0000); - TASSIGN(scoreTile, 0x0800); - TASSIGN(scoreWorkTile, 0x1000); - TASSIGN(pvTile, 0x1800); - TASSIGN(probHalfTile, 0x2000); - TASSIGN(outHalfTile, 0x2000); - TASSIGN(probTile, 0x2800); - TASSIGN(rowMaxTile, 0x3000); - TASSIGN(rowSumTile, 0x3040); - TASSIGN(scalarMathTile, 0x3080); - TASSIGN(scoreRowsTile, 0x0800); - constexpr uint32_t kScoreRowsWorkUb = 0x1000; - TASSIGN(scoreRowsWorkTile, kScoreRowsWorkUb); - TASSIGN(pvRowsTile, 0x1800); - TASSIGN(rowMaxRowsTile, 0x3000); - TASSIGN(maxStateRowsTile, 0x3020); - TASSIGN(newMaxRowsTile, 0x3040); - TASSIGN(oldScaleRowsTile, 0x3060); - TASSIGN(rowSumRowsTile, 0x3080); - TASSIGN(sumStateRowsTile, 0x30a0); - TRESHAPE(rowMaxRowsView, rowMaxRowsTile); - TRESHAPE(maxStateRowsView, maxStateRowsTile); - TRESHAPE(newMaxRowsView, newMaxRowsTile); - TRESHAPE(oldScaleRowsView, oldScaleRowsTile); - TRESHAPE(rowSumRowsView, rowSumRowsTile); - TRESHAPE(sumStateRowsView, sumStateRowsTile); - - constexpr uint32_t kAccumUbBase = 0x4000; - constexpr uint32_t kAccumHeadBytes = kHeadDim * sizeof(float); - constexpr int64_t scoreHeadBytes = 4 * kTileTokens * sizeof(float); - constexpr int64_t probHeadBytes = 256 * sizeof(half); - constexpr int64_t outHeadBytes = 4 * kHeadDim * sizeof(float); - constexpr int64_t scoreGroupBytes = kHeadGroup * scoreHeadBytes; - constexpr int64_t probGroupBytes = kHeadGroup * probHeadBytes; - constexpr int64_t outGroupBytes = kHeadGroup * outHeadBytes; - const int64_t scoreSlotBytes = static_cast(maxHeadGroups) * scoreGroupBytes; - const int64_t probSlotBytes = static_cast(maxHeadGroups) * probGroupBytes; - const int64_t outSlotBytes = static_cast(maxHeadGroups) * outGroupBytes; - __gm__ uint8_t *scoreBase = sGm + workerIdx * scoreSlotBytes * 2; - __gm__ uint8_t *probBase = pGm + workerIdx * probSlotBytes * 2; - __gm__ uint8_t *outTmpBase = oTmpGm + workerIdx * outSlotBytes * 2; - - const int64_t processRounds = (processNum + workerNum - 1) / workerNum; - const int32_t stageTileCount = (ctx.kvSplitPerCore + kTileTokens - 1) / kTileTokens; - for (int64_t processRound = 0; processRound < processRounds; ++processRound) { - const int64_t process = processRound * workerNum + workerIdx; - bool validProcess = process < processNum; - int32_t batchIndex = 0; - int32_t curHeadNum = 0; - int32_t startHead = 0; - int32_t tileCount = 0; - int32_t curKvSeqLen = 0; - int32_t curSplit = 0; - uint64_t lBase = 0; - uint64_t oFdBase = 0; - if (validProcess) { - int32_t curBatchSlot = static_cast(process / (corePerBatch * ctx.kvSplitCoreNum)); - int32_t paraBase = ctx.headSize + curBatchSlot * ctx.paraSize; - const int32_t sortedBatch = LoadTilingI32(tilingParaGm, paraBase + kParaBatchIndex); - paraBase = ctx.headSize + sortedBatch * ctx.paraSize; - batchIndex = LoadTilingI32(tilingParaGm, paraBase + 8); - const int32_t kvSeqLen = LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); - const int32_t kvSeqLenAlign = ((kvSeqLen + ctx.blockSize - 1) / ctx.blockSize) * ctx.blockSize; - const int32_t kvLoop = (kvSeqLenAlign + ctx.kvSplitPerCore - 1) / ctx.kvSplitPerCore; - curSplit = static_cast(process % ctx.kvSplitCoreNum); - validProcess = kvSeqLen > 0 && curSplit < kvLoop; - if (validProcess) { - const int32_t curHeadBlock = static_cast((process / ctx.kvSplitCoreNum) % corePerBatch); - startHead = curHeadBlock * formerHeadSplit; - curHeadNum = formerHeadSplit; - if (curHeadBlock == corePerBatch - 1) { - curHeadNum = ctx.numHeads - curHeadBlock * formerHeadSplit; - } - const int32_t startKv = curSplit * ctx.kvSplitPerCore; - curKvSeqLen = ctx.kvSplitPerCore; - if (curSplit == kvLoop - 1) { - curKvSeqLen = kvSeqLen - startKv; - } - tileCount = (curKvSeqLen + kTileTokens - 1) / kTileTokens; - lBase = LoadTilingOffset64(tilingParaGm, paraBase, 11, 12); - oFdBase = LoadTilingOffset64(tilingParaGm, paraBase, 15, 16); - } - } + for (int32_t tilePairBase = 0; tilePairBase < stageTileCount; + tilePairBase += 2) { + const bool hasStage2 = (tilePairBase + 1) < stageTileCount; + const bool activeTile0 = validProcess && tilePairBase < tileCount; + const bool activeTile1 = + validProcess && hasStage2 && (tilePairBase + 1) < tileCount; - const int32_t subHeadBegin = subBlockId == 0 ? 0 : curHeadNum / 2; - const int32_t subHeadEnd = subBlockId == 0 ? curHeadNum / 2 : curHeadNum; - float maxScore[kMaxHeadsPerProcess]; - float sumExp[kMaxHeadsPerProcess]; - float oldScaleByHead[kMaxHeadsPerProcess]; - if (activeSubBlock && validProcess) { - for (int32_t headLocal = subHeadBegin; headLocal < subHeadEnd; ++headLocal) { - maxScore[headLocal] = -3.4028234663852886e38f; - sumExp[headLocal] = 0.0f; - oldScaleByHead[headLocal] = 0.0f; - TASSIGN(weightedTile, kAccumUbBase + static_cast(headLocal) * kAccumHeadBytes); - TEXPANDS(weightedTile, 0.0f); - pipe_barrier(PIPE_V); - } + for (uint32_t stage = 0; stage < 2; ++stage) { + if (stage == 1 && !hasStage2) { + break; } - - for (int32_t tilePairBase = 0; tilePairBase < stageTileCount; tilePairBase += 2) { - const bool hasStage2 = (tilePairBase + 1) < stageTileCount; - const bool activeTile0 = validProcess && tilePairBase < tileCount; - const bool activeTile1 = validProcess && hasStage2 && (tilePairBase + 1) < tileCount; - - for (uint32_t stage = 0; stage < 2; ++stage) { - if (stage == 1 && !hasStage2) { - break; - } - const int32_t tile = tilePairBase + static_cast(stage); - const uint8_t slot = static_cast(stage); - const bool activeTileStage = (stage == 0) ? activeTile0 : activeTile1; - wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_QK_READY, slot)); - for (int32_t headGroup = 0; headGroup < maxHeadGroups; ++headGroup) { - const int32_t groupHeadBase = headGroup * kHeadGroup; - const bool validGroup = validProcess && groupHeadBase < curHeadNum; - const bool activeGroup = validGroup && activeTileStage; - if (!(activeSubBlock && activeGroup)) { - continue; - } - for (int32_t headInGroupBase = 0; headInGroupBase < kHeadGroup; headInGroupBase += headsPerKv) { - const int32_t headLocal = groupHeadBase + headInGroupBase; - if (headLocal >= curHeadNum) { - break; - } - if (headsPerKv != 4 || headLocal < subHeadBegin || headLocal + 4 > subHeadEnd) { - continue; - } - ScoreRowsGlobal scoreGlobal(reinterpret_cast<__gm__ float *>(scoreBase + - static_cast(slot) * scoreSlotBytes + - static_cast(headGroup) * scoreGroupBytes + - static_cast(headInGroupBase) * scoreHeadBytes)); - TLOAD(scoreRowsTile, scoreGlobal); - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - TMULS(scoreRowsTile, scoreRowsTile, ctx.scale); - pipe_barrier(PIPE_V); - for (int32_t row = 0; row < 4; ++row) { - maxStateRowsTile.data()[row] = maxScore[headLocal + row]; - sumStateRowsTile.data()[row] = sumExp[headLocal + row]; - } - set_flag(PIPE_S, PIPE_V, EVENT_ID2); - wait_flag(PIPE_S, PIPE_V, EVENT_ID2); - TROWMAX(rowMaxRowsTile, scoreRowsTile, scoreRowsWorkTile); - pipe_barrier(PIPE_V); - TMAX(newMaxRowsView, rowMaxRowsView, maxStateRowsView); - pipe_barrier(PIPE_V); - TSUB(oldScaleRowsView, maxStateRowsView, newMaxRowsView); - pipe_barrier(PIPE_V); - TEXP(oldScaleRowsView, oldScaleRowsView); - pipe_barrier(PIPE_V); - if (tile == 0) { - TEXPANDS(oldScaleRowsView, 0.0f); - pipe_barrier(PIPE_V); - } - TROWEXPANDSUB(scoreRowsWorkTile, scoreRowsTile, newMaxRowsTile); - pipe_barrier(PIPE_V); - TEXP(scoreRowsWorkTile, scoreRowsWorkTile); - pipe_barrier(PIPE_V); - TROWSUM(rowSumRowsTile, scoreRowsWorkTile, scoreRowsTile); - pipe_barrier(PIPE_V); - TMUL(sumStateRowsView, sumStateRowsView, oldScaleRowsView); - pipe_barrier(PIPE_V); - TADD(sumStateRowsView, sumStateRowsView, rowSumRowsView); - pipe_barrier(PIPE_V); - set_flag(PIPE_V, PIPE_S, EVENT_ID3); - wait_flag(PIPE_V, PIPE_S, EVENT_ID3); - __gm__ half *probScratch = reinterpret_cast<__gm__ half *>(probBase + - static_cast(slot) * probSlotBytes + - static_cast(headGroup) * probGroupBytes); - for (int32_t row = 0; row < 4; ++row) { - maxScore[headLocal + row] = newMaxRowsTile.data()[row]; - sumExp[headLocal + row] = sumStateRowsTile.data()[row]; - oldScaleByHead[headLocal + row] = oldScaleRowsTile.data()[row]; - TASSIGN(probRowView, kScoreRowsWorkUb + - static_cast(row) * kTileTokens * sizeof(float)); - PtoPaConvF32ToF16(probHalfTile, probRowView, 2); - pipe_barrier(PIPE_V); - ProbRowGlobal probRowGlobal(probScratch + - static_cast(headInGroupBase + row) * kTileTokens); - TSTORE(probRowGlobal, probHalfTile); - } - } - } - DdrFenceBeforePtoAivReduce(); - PtoPaSignalFromVec(PtoPaSlotFlag(PTO_PA_RAW_P_READY, slot)); + const int32_t tile = tilePairBase + static_cast(stage); + const uint8_t slot = static_cast(stage); + const bool activeTileStage = (stage == 0) ? activeTile0 : activeTile1; + wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_QK_READY, slot)); + for (int32_t headGroup = 0; headGroup < maxHeadGroups; ++headGroup) { + const int32_t groupHeadBase = headGroup * kHeadGroup; + const bool validGroup = validProcess && groupHeadBase < curHeadNum; + const bool activeGroup = validGroup && activeTileStage; + if (!(activeSubBlock && activeGroup)) { + continue; + } + for (int32_t headInGroupBase = 0; headInGroupBase < kHeadGroup; + headInGroupBase += headsPerKv) { + const int32_t headLocal = groupHeadBase + headInGroupBase; + if (headLocal >= curHeadNum) { + break; } - - for (uint32_t stage = 0; stage < 2; ++stage) { - if (stage == 1 && !hasStage2) { - break; - } - const uint8_t slot = static_cast(stage); - const bool activeTileStage = (stage == 0) ? activeTile0 : activeTile1; - wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_PV_READY, slot)); - for (int32_t headGroup = 0; headGroup < maxHeadGroups; ++headGroup) { - const int32_t groupHeadBase = headGroup * kHeadGroup; - const bool validGroup = validProcess && groupHeadBase < curHeadNum; - const bool activeGroup = validGroup && activeTileStage; - if (!(activeSubBlock && activeGroup)) { - continue; - } - for (int32_t headInGroupBase = 0; headInGroupBase < kHeadGroup; headInGroupBase += headsPerKv) { - const int32_t headLocal = groupHeadBase + headInGroupBase; - if (headLocal >= curHeadNum) { - break; - } - if (headsPerKv != 4 || headLocal < subHeadBegin || headLocal + 4 > subHeadEnd) { - continue; - } - for (int32_t row = 0; row < 4; ++row) { - oldScaleRowsTile.data()[row] = oldScaleByHead[headLocal + row]; - } - set_flag(PIPE_S, PIPE_V, EVENT_ID2); - wait_flag(PIPE_S, PIPE_V, EVENT_ID2); - OutRowsGlobal outGlobal(reinterpret_cast<__gm__ float *>(outTmpBase + - static_cast(slot) * outSlotBytes + - static_cast(headGroup) * outGroupBytes + - static_cast(headInGroupBase) * outHeadBytes)); - TLOAD(pvRowsTile, outGlobal); - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - VecFloat4x128 weightedRowsTile; - TASSIGN(weightedRowsTile, kAccumUbBase + static_cast(headLocal) * kAccumHeadBytes); - TROWEXPANDMUL(weightedRowsTile, weightedRowsTile, oldScaleRowsTile); - pipe_barrier(PIPE_V); - TADD(weightedRowsTile, weightedRowsTile, pvRowsTile); - pipe_barrier(PIPE_V); - } - } - PtoPaSignalFreeFromVec(PtoPaSlotFlag(PTO_PA_RAW_PV_FREE, slot)); + if (headsPerKv != 4 || headLocal < subHeadBegin || + headLocal + 4 > subHeadEnd) { + continue; } - } - - - if (activeSubBlock && validProcess) { - for (int32_t headLocal = subHeadBegin; headLocal < subHeadEnd; ++headLocal) { - const int32_t head = startHead + headLocal; - const float invSum = sumExp[headLocal] > 0.0f ? 1.0f / sumExp[headLocal] : 0.0f; - const uint64_t outOffset = oFdBase * ctx.kvSplitCoreNum + - static_cast(head) * ctx.headDim * ctx.kvSplitCoreNum + - static_cast(curSplit) * ctx.headDim; - const uint64_t lOffset = lBase + static_cast(head) * ctx.kvSplitCoreNum + curSplit; - partialL[lOffset] = maxScore[headLocal] + PtoLogScalar(scalarMathTile, sumExp[headLocal]); - TASSIGN(weightedTile, kAccumUbBase + static_cast(headLocal) * kAccumHeadBytes); - TMULS(weightedTile, weightedTile, invSum); - pipe_barrier(PIPE_V); - OutGlobal weightedGlobal(reinterpret_cast<__gm__ float *>(partialOut + outOffset)); - TSTORE(weightedGlobal, weightedTile); + ScoreRowsGlobal scoreGlobal(reinterpret_cast<__gm__ float *>( + scoreBase + static_cast(slot) * scoreSlotBytes + + static_cast(headGroup) * scoreGroupBytes + + static_cast(headInGroupBase) * scoreHeadBytes)); + TLOAD(scoreRowsTile, scoreGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(scoreRowsTile, scoreRowsTile, ctx.scale); + pipe_barrier(PIPE_V); + for (int32_t row = 0; row < 4; ++row) { + maxStateRowsTile.data()[row] = maxScore[headLocal + row]; + sumStateRowsTile.data()[row] = sumExp[headLocal + row]; + } + set_flag(PIPE_S, PIPE_V, EVENT_ID2); + wait_flag(PIPE_S, PIPE_V, EVENT_ID2); + TROWMAX(rowMaxRowsTile, scoreRowsTile, scoreRowsWorkTile); + pipe_barrier(PIPE_V); + TMAX(newMaxRowsView, rowMaxRowsView, maxStateRowsView); + pipe_barrier(PIPE_V); + TSUB(oldScaleRowsView, maxStateRowsView, newMaxRowsView); + pipe_barrier(PIPE_V); + TEXP(oldScaleRowsView, oldScaleRowsView); + pipe_barrier(PIPE_V); + if (tile == 0) { + TEXPANDS(oldScaleRowsView, 0.0f); + pipe_barrier(PIPE_V); } + TROWEXPANDSUB(scoreRowsWorkTile, scoreRowsTile, newMaxRowsTile); + pipe_barrier(PIPE_V); + TEXP(scoreRowsWorkTile, scoreRowsWorkTile); + pipe_barrier(PIPE_V); + TROWSUM(rowSumRowsTile, scoreRowsWorkTile, scoreRowsTile); + pipe_barrier(PIPE_V); + TMUL(sumStateRowsView, sumStateRowsView, oldScaleRowsView); + pipe_barrier(PIPE_V); + TADD(sumStateRowsView, sumStateRowsView, rowSumRowsView); + pipe_barrier(PIPE_V); + set_flag(PIPE_V, PIPE_S, EVENT_ID3); + wait_flag(PIPE_V, PIPE_S, EVENT_ID3); + __gm__ half *probScratch = reinterpret_cast<__gm__ half *>( + probBase + static_cast(slot) * probSlotBytes + + static_cast(headGroup) * probGroupBytes); + for (int32_t row = 0; row < 4; ++row) { + maxScore[headLocal + row] = newMaxRowsTile.data()[row]; + sumExp[headLocal + row] = sumStateRowsTile.data()[row]; + oldScaleByHead[headLocal + row] = oldScaleRowsTile.data()[row]; + TASSIGN(probRowView, + kScoreRowsWorkUb + static_cast(row) * + kTileTokens * sizeof(float)); + PtoPaConvF32ToF16(probHalfTile, probRowView, 2); + pipe_barrier(PIPE_V); + ProbRowGlobal probRowGlobal( + probScratch + + static_cast(headInGroupBase + row) * kTileTokens); + TSTORE(probRowGlobal, probHalfTile); + } + } } - } + DdrFenceBeforePtoAivReduce(); + PtoPaSignalFromVec(PtoPaSlotFlag(PTO_PA_RAW_P_READY, slot)); + } - DdrBarrierBeforePtoFfts(); - ffts_cross_core_sync(PIPE_MTE3, PtoPaGetFftsMsg(0x0, PTO_PA_REDUCE_READY_DECODER)); - wait_flag_dev(PTO_PA_REDUCE_READY_DECODER); - if (!combineSubBlock) { - pipe_barrier(PIPE_ALL); - return; - } - const int32_t effectiveBatch = ctx.decoderBatch > 0 ? ctx.decoderBatch : ctx.batch; - const int64_t totalRows = static_cast(effectiveBatch) * ctx.numHeads; - const int64_t combineWorkerIdx = workerIdx; - const int64_t combineWorkerNum = workerNum; - for (int64_t row = combineWorkerIdx; row < totalRows; row += combineWorkerNum) { - const int32_t head = static_cast(row % ctx.numHeads); - const int32_t batchSlot = static_cast(row / ctx.numHeads); - int32_t paraBase = ctx.headSize + batchSlot * ctx.paraSize; - const int32_t sortedBatch = LoadTilingI32(tilingParaGm, paraBase + kParaBatchIndex); - paraBase = ctx.headSize + sortedBatch * ctx.paraSize; - const int32_t batchIndex = LoadTilingI32(tilingParaGm, paraBase + 8); - const int32_t kvSeqLen = LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); - if (kvSeqLen <= 0) { - continue; + for (uint32_t stage = 0; stage < 2; ++stage) { + if (stage == 1 && !hasStage2) { + break; } - const int32_t kvSeqLenAlign = ((kvSeqLen + ctx.blockSize - 1) / ctx.blockSize) * ctx.blockSize; - const int32_t kvLoop = (kvSeqLenAlign + ctx.kvSplitPerCore - 1) / ctx.kvSplitPerCore; - const uint64_t lBase = LoadTilingOffset64(tilingParaGm, paraBase, 11, 12); - const uint64_t oFdBase = LoadTilingOffset64(tilingParaGm, paraBase, 15, 16); - __ubuf__ float *splitScale = reinterpret_cast<__ubuf__ float *>((uintptr_t)0x3200); - __ubuf__ float *splitReduce = reinterpret_cast<__ubuf__ float *>((uintptr_t)0x3400); - const uint64_t lOffset = lBase + static_cast(head) * ctx.kvSplitCoreNum; - const uint32_t lRemain = static_cast(kvLoop % 8); - copy_gm_to_ubuf_align_b32(splitScale, partialL + lOffset, 0, 1, static_cast(kvLoop * 4), 0, - lRemain == 0 ? 0 : 8 - lRemain, 0, 0); - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - PtoPaSetFloatVectorMask(static_cast(kvLoop)); - vcmax(splitReduce, splitScale, 1, 1, 1, 8, static_cast(ONLY_VALUE)); - pipe_barrier(PIPE_V); - set_flag(PIPE_V, PIPE_S, EVENT_ID3); - wait_flag(PIPE_V, PIPE_S, EVENT_ID3); - const float lMax = splitReduce[0]; - set_flag(PIPE_S, PIPE_V, EVENT_ID2); - wait_flag(PIPE_S, PIPE_V, EVENT_ID2); - vadds(splitScale, splitScale, -lMax, 1, 1, 1, 8, 8); - pipe_barrier(PIPE_V); - vexp(splitScale, splitScale, 1, 1, 1, 8, 8); - pipe_barrier(PIPE_V); - vcadd(splitReduce, splitScale, 1, 1, 1, 8, 0); - pipe_barrier(PIPE_V); - set_vector_mask(static_cast(-1), static_cast(-1)); - set_flag(PIPE_V, PIPE_S, EVENT_ID3); - wait_flag(PIPE_V, PIPE_S, EVENT_ID3); - const float denom = splitReduce[0]; - set_flag(PIPE_S, PIPE_V, EVENT_ID2); - wait_flag(PIPE_S, PIPE_V, EVENT_ID2); - float invDenom = denom > 0.0f ? 1.0f / denom : 0.0f; - PtoPaSetFloatVectorMask(static_cast(kvLoop)); - vmuls(splitScale, splitScale, static_cast(invDenom), 1, 1, 1, 8, 8); - pipe_barrier(PIPE_V); - set_vector_mask(static_cast(-1), static_cast(-1)); - TASSIGN(weightedTile, 0x0000); - TEXPANDS(weightedTile, 0.0f); - pipe_barrier(PIPE_V); - constexpr uint32_t kCombineOutUb = 0x4000; - __ubuf__ float *splitOut = reinterpret_cast<__ubuf__ float *>((uintptr_t)kCombineOutUb); - const uint64_t firstOutOffset = oFdBase * ctx.kvSplitCoreNum + - static_cast(head) * ctx.headDim * ctx.kvSplitCoreNum; - copy_gm_to_ubuf_align_b32(splitOut, partialOut + firstOutOffset, 0, 1, - static_cast(kvLoop * ctx.headDim * 4), 0, 0, 0, 0); - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - for (int32_t split = 0; split < kvLoop; ++split) { - TASSIGN(pvTile, kCombineOutUb + static_cast(split) * kHeadDim * sizeof(float)); - set_flag(PIPE_V, PIPE_S, EVENT_ID3); - wait_flag(PIPE_V, PIPE_S, EVENT_ID3); - const float splitWeight = splitScale[split]; + const uint8_t slot = static_cast(stage); + const bool activeTileStage = (stage == 0) ? activeTile0 : activeTile1; + wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_PV_READY, slot)); + for (int32_t headGroup = 0; headGroup < maxHeadGroups; ++headGroup) { + const int32_t groupHeadBase = headGroup * kHeadGroup; + const bool validGroup = validProcess && groupHeadBase < curHeadNum; + const bool activeGroup = validGroup && activeTileStage; + if (!(activeSubBlock && activeGroup)) { + continue; + } + for (int32_t headInGroupBase = 0; headInGroupBase < kHeadGroup; + headInGroupBase += headsPerKv) { + const int32_t headLocal = groupHeadBase + headInGroupBase; + if (headLocal >= curHeadNum) { + break; + } + if (headsPerKv != 4 || headLocal < subHeadBegin || + headLocal + 4 > subHeadEnd) { + continue; + } + for (int32_t row = 0; row < 4; ++row) { + oldScaleRowsTile.data()[row] = oldScaleByHead[headLocal + row]; + } set_flag(PIPE_S, PIPE_V, EVENT_ID2); wait_flag(PIPE_S, PIPE_V, EVENT_ID2); - TMULS(pvTile, pvTile, splitWeight); + OutRowsGlobal outGlobal(reinterpret_cast<__gm__ float *>( + outTmpBase + static_cast(slot) * outSlotBytes + + static_cast(headGroup) * outGroupBytes + + static_cast(headInGroupBase) * outHeadBytes)); + TLOAD(pvRowsTile, outGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + VecFloat4x128 weightedRowsTile; + TASSIGN(weightedRowsTile, + kAccumUbBase + + static_cast(headLocal) * kAccumHeadBytes); + TROWEXPANDMUL(weightedRowsTile, weightedRowsTile, oldScaleRowsTile); pipe_barrier(PIPE_V); - TADD(weightedTile, weightedTile, pvTile); + TADD(weightedRowsTile, weightedRowsTile, pvRowsTile); pipe_barrier(PIPE_V); + } } - const int64_t outBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; - OutputGlobal outGlobal(reinterpret_cast<__gm__ half *>(oGm) + outBase); - TCVT(outHalfTile, weightedTile, RoundMode::CAST_RINT); + PtoPaSignalFreeFromVec(PtoPaSlotFlag(PTO_PA_RAW_PV_FREE, slot)); + } + } + + if (activeSubBlock && validProcess) { + for (int32_t headLocal = subHeadBegin; headLocal < subHeadEnd; + ++headLocal) { + const int32_t head = startHead + headLocal; + const float invSum = + sumExp[headLocal] > 0.0f ? 1.0f / sumExp[headLocal] : 0.0f; + const uint64_t outOffset = + oFdBase * ctx.kvSplitCoreNum + + static_cast(head) * ctx.headDim * ctx.kvSplitCoreNum + + static_cast(curSplit) * ctx.headDim; + const uint64_t lOffset = + lBase + static_cast(head) * ctx.kvSplitCoreNum + curSplit; + partialL[lOffset] = maxScore[headLocal] + + PtoLogScalar(scalarMathTile, sumExp[headLocal]); + TASSIGN(weightedTile, kAccumUbBase + static_cast(headLocal) * + kAccumHeadBytes); + TMULS(weightedTile, weightedTile, invSum); pipe_barrier(PIPE_V); - TSTORE(outGlobal, outHalfTile); - pipe_barrier(PIPE_ALL); + OutGlobal weightedGlobal( + reinterpret_cast<__gm__ float *>(partialOut + outOffset)); + TSTORE(weightedGlobal, weightedTile); + } + } + } + + DdrBarrierBeforePtoFfts(); + ffts_cross_core_sync(PIPE_MTE3, + PtoPaGetFftsMsg(0x0, PTO_PA_REDUCE_READY_DECODER)); + wait_flag_dev(PTO_PA_REDUCE_READY_DECODER); + if (!combineSubBlock) { + pipe_barrier(PIPE_ALL); + return; + } + const int32_t effectiveBatch = + ctx.decoderBatch > 0 ? ctx.decoderBatch : ctx.batch; + const int64_t totalRows = static_cast(effectiveBatch) * ctx.numHeads; + const int64_t combineWorkerIdx = workerIdx; + const int64_t combineWorkerNum = workerNum; + for (int64_t row = combineWorkerIdx; row < totalRows; + row += combineWorkerNum) { + const int32_t head = static_cast(row % ctx.numHeads); + const int32_t batchSlot = static_cast(row / ctx.numHeads); + int32_t paraBase = ctx.headSize + batchSlot * ctx.paraSize; + const int32_t sortedBatch = + LoadTilingI32(tilingParaGm, paraBase + kParaBatchIndex); + paraBase = ctx.headSize + sortedBatch * ctx.paraSize; + const int32_t batchIndex = LoadTilingI32(tilingParaGm, paraBase + 8); + const int32_t kvSeqLen = + LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); + if (kvSeqLen <= 0) { + continue; } + const int32_t kvSeqLenAlign = + ((kvSeqLen + ctx.blockSize - 1) / ctx.blockSize) * ctx.blockSize; + const int32_t kvLoop = + (kvSeqLenAlign + ctx.kvSplitPerCore - 1) / ctx.kvSplitPerCore; + const uint64_t lBase = LoadTilingOffset64(tilingParaGm, paraBase, 11, 12); + const uint64_t oFdBase = LoadTilingOffset64(tilingParaGm, paraBase, 15, 16); + __ubuf__ float *splitScale = + reinterpret_cast<__ubuf__ float *>((uintptr_t)0x3200); + __ubuf__ float *splitReduce = + reinterpret_cast<__ubuf__ float *>((uintptr_t)0x3400); + const uint64_t lOffset = + lBase + static_cast(head) * ctx.kvSplitCoreNum; + const uint32_t lRemain = static_cast(kvLoop % 8); + copy_gm_to_ubuf_align_b32(splitScale, partialL + lOffset, 0, 1, + static_cast(kvLoop * 4), 0, + lRemain == 0 ? 0 : 8 - lRemain, 0, 0); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + PtoPaSetFloatVectorMask(static_cast(kvLoop)); + vcmax(splitReduce, splitScale, 1, 1, 1, 8, + static_cast(ONLY_VALUE)); + pipe_barrier(PIPE_V); + set_flag(PIPE_V, PIPE_S, EVENT_ID3); + wait_flag(PIPE_V, PIPE_S, EVENT_ID3); + const float lMax = splitReduce[0]; + set_flag(PIPE_S, PIPE_V, EVENT_ID2); + wait_flag(PIPE_S, PIPE_V, EVENT_ID2); + vadds(splitScale, splitScale, -lMax, 1, 1, 1, 8, 8); + pipe_barrier(PIPE_V); + vexp(splitScale, splitScale, 1, 1, 1, 8, 8); + pipe_barrier(PIPE_V); + vcadd(splitReduce, splitScale, 1, 1, 1, 8, 0); + pipe_barrier(PIPE_V); + set_vector_mask(static_cast(-1), static_cast(-1)); + set_flag(PIPE_V, PIPE_S, EVENT_ID3); + wait_flag(PIPE_V, PIPE_S, EVENT_ID3); + const float denom = splitReduce[0]; + set_flag(PIPE_S, PIPE_V, EVENT_ID2); + wait_flag(PIPE_S, PIPE_V, EVENT_ID2); + float invDenom = denom > 0.0f ? 1.0f / denom : 0.0f; + PtoPaSetFloatVectorMask(static_cast(kvLoop)); + vmuls(splitScale, splitScale, static_cast(invDenom), 1, 1, 1, 8, 8); + pipe_barrier(PIPE_V); + set_vector_mask(static_cast(-1), static_cast(-1)); + TASSIGN(weightedTile, 0x0000); + TEXPANDS(weightedTile, 0.0f); + pipe_barrier(PIPE_V); + constexpr uint32_t kCombineOutUb = 0x4000; + __ubuf__ float *splitOut = + reinterpret_cast<__ubuf__ float *>((uintptr_t)kCombineOutUb); + const uint64_t firstOutOffset = + oFdBase * ctx.kvSplitCoreNum + + static_cast(head) * ctx.headDim * ctx.kvSplitCoreNum; + copy_gm_to_ubuf_align_b32(splitOut, partialOut + firstOutOffset, 0, 1, + static_cast(kvLoop * ctx.headDim * 4), + 0, 0, 0, 0); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + for (int32_t split = 0; split < kvLoop; ++split) { + TASSIGN(pvTile, kCombineOutUb + static_cast(split) * kHeadDim * + sizeof(float)); + set_flag(PIPE_V, PIPE_S, EVENT_ID3); + wait_flag(PIPE_V, PIPE_S, EVENT_ID3); + const float splitWeight = splitScale[split]; + set_flag(PIPE_S, PIPE_V, EVENT_ID2); + wait_flag(PIPE_S, PIPE_V, EVENT_ID2); + TMULS(pvTile, pvTile, splitWeight); + pipe_barrier(PIPE_V); + TADD(weightedTile, weightedTile, pvTile); + pipe_barrier(PIPE_V); + } + const int64_t outBase = + (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; + OutputGlobal outGlobal(reinterpret_cast<__gm__ half *>(oGm) + outBase); + TCVT(outHalfTile, weightedTile, RoundMode::CAST_RINT); + pipe_barrier(PIPE_V); + TSTORE(outGlobal, outHalfTile); pipe_barrier(PIPE_ALL); + } + pipe_barrier(PIPE_ALL); } #endif #ifdef __DAV_C220_VEC__ AICORE inline void RunPtoPagedAttentionDecodeSplitKV( - __gm__ uint8_t *qGm, - __gm__ uint8_t *kGm, - __gm__ uint8_t *vGm, - __gm__ uint8_t *blockTablesGm, - __gm__ uint8_t *oGm, - __gm__ uint8_t *oCoreTmpGm, - __gm__ uint8_t *lGm, - __gm__ uint8_t *tilingParaGm, - int64_t workerIdx, - int64_t workerNum, - uint32_t subBlockId) -{ - constexpr int32_t kHeadDim = 128; - PtoPaInitCoreState(); - const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); - if (workerIdx < 0 || workerNum <= 0 || ctx.headDim != kHeadDim || ctx.headDimV != kHeadDim || - ctx.blockSize != PA_TILE_TOKENS || ctx.kvSplitCoreNum <= 1 || ctx.numHeads <= 0 || ctx.kvHeads <= 0 || - ctx.numHeads % ctx.kvHeads != 0) { + __gm__ uint8_t *qGm, __gm__ uint8_t *kGm, __gm__ uint8_t *vGm, + __gm__ uint8_t *blockTablesGm, __gm__ uint8_t *oGm, + __gm__ uint8_t *oCoreTmpGm, __gm__ uint8_t *lGm, + __gm__ uint8_t *tilingParaGm, int64_t workerIdx, int64_t workerNum, + uint32_t subBlockId) { + constexpr int32_t kHeadDim = 128; + PtoPaInitCoreState(); + const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); + if (workerIdx < 0 || workerNum <= 0 || ctx.headDim != kHeadDim || + ctx.headDimV != kHeadDim || ctx.blockSize != PA_TILE_TOKENS || + ctx.kvSplitCoreNum <= 1 || ctx.numHeads <= 0 || ctx.kvHeads <= 0 || + ctx.numHeads % ctx.kvHeads != 0) { + pipe_barrier(PIPE_ALL); + return; + } + + const int32_t maxBlocksPerQuery = + ctx.maxBlocksPerQuery > 0 + ? ctx.maxBlocksPerQuery + : (ctx.maxKvSeqLen + ctx.blockSize - 1) / ctx.blockSize; + const int32_t headsPerKv = ctx.numHeads / ctx.kvHeads; + const int32_t formerHeadSplit = + ctx.formerHeadSplit > 0 ? ctx.formerHeadSplit : 1; + const int32_t corePerBatch = + (ctx.numHeads + formerHeadSplit - 1) / formerHeadSplit; + const int64_t processNum = + static_cast(ctx.formerBatch) * corePerBatch * ctx.kvSplitCoreNum; + __gm__ float *partialOut = reinterpret_cast<__gm__ float *>(oCoreTmpGm); + __gm__ float *partialL = reinterpret_cast<__gm__ float *>(lGm); + + using VecHalf128 = + Tile; + using VecFloat128 = + Tile; + using VecFloat8 = Tile; + using GlobalHalf128 = GlobalTensor, + Stride<1, 1, 1, kHeadDim, 1>>; + + VecHalf128 qHalfTile; + VecFloat128 qFloatTile; + VecHalf128 kHalfTile; + VecFloat128 kFloatTile; + VecFloat128 qkProductTile; + VecFloat8 scoreTile; + VecFloat128 reduceTmpTile; + VecHalf128 vHalfTile; + VecFloat128 vFloatTile; + VecFloat128 weightedTile; + VecFloat8 scalarMathTile; + TASSIGN(qHalfTile, 0x0800); + TASSIGN(qFloatTile, 0x1000); + TASSIGN(kHalfTile, 0x1800); + TASSIGN(kFloatTile, 0x2000); + TASSIGN(qkProductTile, 0x2800); + TASSIGN(scoreTile, 0x3000); + TASSIGN(reduceTmpTile, 0x3800); + TASSIGN(vHalfTile, 0x4000); + TASSIGN(vFloatTile, 0x4800); + TASSIGN(weightedTile, 0x5000); + TASSIGN(scalarMathTile, 0x5800); + + for (int64_t process = workerIdx; process < processNum; + process += workerNum) { + int32_t curBatchSlot = + static_cast(process / (corePerBatch * ctx.kvSplitCoreNum)); + int32_t paraBase = ctx.headSize + curBatchSlot * ctx.paraSize; + const int32_t sortedBatch = + LoadTilingI32(tilingParaGm, paraBase + kParaBatchIndex); + paraBase = ctx.headSize + sortedBatch * ctx.paraSize; + const int32_t batchIndex = LoadTilingI32(tilingParaGm, paraBase + 8); + const int32_t kvSeqLen = + LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); + if (kvSeqLen <= 0) { + continue; + } + + const int32_t kvSeqLenAlign = + ((kvSeqLen + ctx.blockSize - 1) / ctx.blockSize) * ctx.blockSize; + const int32_t kvLoop = + (kvSeqLenAlign + ctx.kvSplitPerCore - 1) / ctx.kvSplitPerCore; + const int32_t curSplit = static_cast(process % ctx.kvSplitCoreNum); + if (curSplit >= kvLoop) { + continue; + } + + const int32_t curHeadBlock = + static_cast((process / ctx.kvSplitCoreNum) % corePerBatch); + const int32_t startHead = curHeadBlock * formerHeadSplit; + int32_t curHeadNum = formerHeadSplit; + if (curHeadBlock == corePerBatch - 1) { + curHeadNum = ctx.numHeads - curHeadBlock * formerHeadSplit; + } + const int32_t startKv = curSplit * ctx.kvSplitPerCore; + int32_t curKvSeqLen = ctx.kvSplitPerCore; + if (curSplit == kvLoop - 1) { + curKvSeqLen = kvSeqLen - startKv; + } + + const uint64_t lBase = LoadTilingOffset64(tilingParaGm, paraBase, 11, 12); + const uint64_t oFdBase = LoadTilingOffset64(tilingParaGm, paraBase, 15, 16); + const int32_t headBegin = subBlockId == 0 ? 0 : curHeadNum / 2; + const int32_t headEnd = subBlockId == 0 ? curHeadNum / 2 : curHeadNum; + for (int32_t headLocal = headBegin; headLocal < headEnd; ++headLocal) { + const int32_t head = startHead + headLocal; + const int32_t kvHead = head / headsPerKv; + const int64_t qBase = + (static_cast(batchIndex) * ctx.numHeads + head) * + ctx.headDim; + GlobalHalf128 qGlobal(reinterpret_cast<__gm__ half *>(qGm) + qBase); + TLOAD(qHalfTile, qGlobal); + pipe_barrier(PIPE_ALL); + TCVT(qFloatTile, qHalfTile, RoundMode::CAST_NONE); + pipe_barrier(PIPE_V); + + float maxScore = -3.4028234663852886e38f; + float sumExp = 0.0f; + TEXPANDS(weightedTile, 0.0f); + pipe_barrier(PIPE_V); + for (int32_t relPos = 0; relPos < curKvSeqLen; ++relPos) { + const int32_t pos = startKv + relPos; + int32_t blockId = 0; + int32_t offsetInBlock = 0; + ResolvePagedPosition(blockTablesGm, batchIndex, maxBlocksPerQuery, pos, + ctx.blockSize, blockId, offsetInBlock); + const int64_t kvOffset = + (((static_cast(blockId) * ctx.blockSize + offsetInBlock) * + ctx.kvHeads + + kvHead) * + ctx.headDim); + + GlobalHalf128 kGlobal(reinterpret_cast<__gm__ half *>(kGm) + kvOffset); + TLOAD(kHalfTile, kGlobal); + pipe_barrier(PIPE_ALL); + TCVT(kFloatTile, kHalfTile, RoundMode::CAST_NONE); + pipe_barrier(PIPE_V); + TMUL(qkProductTile, qFloatTile, kFloatTile); + pipe_barrier(PIPE_V); + TROWSUM(scoreTile, qkProductTile, reduceTmpTile); + pipe_barrier(PIPE_V); + const float score = scoreTile.data()[0] * ctx.scale; + const float newMax = score > maxScore ? score : maxScore; + float oldScale = 0.0f; + if (relPos != 0) { + scalarMathTile.data()[0] = maxScore - newMax; + TEXP(scalarMathTile, scalarMathTile); + pipe_barrier(PIPE_V); + oldScale = scalarMathTile.data()[0]; + } + scalarMathTile.data()[0] = score - newMax; + TEXP(scalarMathTile, scalarMathTile); + pipe_barrier(PIPE_V); + const float probUnnorm = scalarMathTile.data()[0]; + sumExp = sumExp * oldScale + probUnnorm; + + GlobalHalf128 vGlobal(reinterpret_cast<__gm__ half *>(vGm) + kvOffset); + TMULS(weightedTile, weightedTile, oldScale); + pipe_barrier(PIPE_V); + TLOAD(vHalfTile, vGlobal); pipe_barrier(PIPE_ALL); - return; + TCVT(vFloatTile, vHalfTile, RoundMode::CAST_NONE); + pipe_barrier(PIPE_V); + TAXPY(weightedTile, vFloatTile, probUnnorm); + pipe_barrier(PIPE_V); + maxScore = newMax; + } + + const float invSum = sumExp > 0.0f ? 1.0f / sumExp : 0.0f; + const uint64_t outOffset = + oFdBase * ctx.kvSplitCoreNum + + static_cast(head) * ctx.headDim * ctx.kvSplitCoreNum + + static_cast(curSplit) * ctx.headDim; + const uint64_t lOffset = + lBase + static_cast(head) * ctx.kvSplitCoreNum + curSplit; + float logSumExp = -3.4028234663852886e38f; + if (sumExp > 0.0f) { + scalarMathTile.data()[0] = sumExp; + TLOG(scalarMathTile, scalarMathTile); + pipe_barrier(PIPE_V); + logSumExp = scalarMathTile.data()[0]; + } + partialL[lOffset] = maxScore + logSumExp; + for (int32_t dim = 0; dim < kHeadDim; ++dim) { + partialOut[outOffset + dim] = weightedTile.data()[dim] * invSum; + } + } + } + + DdrFenceBeforePtoAivReduce(); + ffts_cross_core_sync(PIPE_MTE3, + PtoPaGetFftsMsg(0x0, PTO_PA_REDUCE_READY_DECODER)); + wait_flag_dev(PTO_PA_REDUCE_READY_DECODER); + + const int32_t effectiveBatch = + ctx.decoderBatch > 0 ? ctx.decoderBatch : ctx.batch; + const int64_t totalRows = static_cast(effectiveBatch) * ctx.numHeads; + const int64_t combineWorkerIdx = + workerIdx * 2 + static_cast(subBlockId); + const int64_t combineWorkerNum = workerNum * 2; + for (int64_t row = combineWorkerIdx; row < totalRows; + row += combineWorkerNum) { + const int32_t head = static_cast(row % ctx.numHeads); + const int32_t batchSlot = static_cast(row / ctx.numHeads); + int32_t paraBase = ctx.headSize + batchSlot * ctx.paraSize; + const int32_t sortedBatch = + LoadTilingI32(tilingParaGm, paraBase + kParaBatchIndex); + paraBase = ctx.headSize + sortedBatch * ctx.paraSize; + const int32_t batchIndex = LoadTilingI32(tilingParaGm, paraBase + 8); + const int32_t kvSeqLen = + LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); + if (kvSeqLen <= 0) { + continue; + } + + const int32_t kvSeqLenAlign = + ((kvSeqLen + ctx.blockSize - 1) / ctx.blockSize) * ctx.blockSize; + const int32_t kvLoop = + (kvSeqLenAlign + ctx.kvSplitPerCore - 1) / ctx.kvSplitPerCore; + const uint64_t lBase = LoadTilingOffset64(tilingParaGm, paraBase, 11, 12); + const uint64_t oFdBase = LoadTilingOffset64(tilingParaGm, paraBase, 15, 16); + float lMax = -3.4028234663852886e38f; + for (int32_t split = 0; split < kvLoop; ++split) { + const float lValue = + partialL[lBase + static_cast(head) * ctx.kvSplitCoreNum + + split]; + lMax = lValue > lMax ? lValue : lMax; + } + float denom = 0.0f; + float splitScale[64]; + for (int32_t split = 0; split < kvLoop; ++split) { + const float lValue = + partialL[lBase + static_cast(head) * ctx.kvSplitCoreNum + + split]; + scalarMathTile.data()[0] = lValue - lMax; + TEXP(scalarMathTile, scalarMathTile); + pipe_barrier(PIPE_V); + const float scale = scalarMathTile.data()[0]; + splitScale[split] = scale; + denom += scale; } + const float invDenom = denom > 0.0f ? 1.0f / denom : 0.0f; + for (int32_t split = 0; split < kvLoop; ++split) { + splitScale[split] *= invDenom; + } + const int64_t outBase = + (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; + for (int32_t dim = 0; dim < kHeadDim; ++dim) { + float value = 0.0f; + for (int32_t split = 0; split < kvLoop; ++split) { + const uint64_t outOffset = + oFdBase * ctx.kvSplitCoreNum + + static_cast(head) * ctx.headDim * ctx.kvSplitCoreNum + + static_cast(split) * ctx.headDim; + value += partialOut[outOffset + dim] * splitScale[split]; + } + StoreOutputFp16(oGm, outBase + dim, value); + } + } + pipe_barrier(PIPE_ALL); +} +#endif + +AICORE inline void RunPtoPagedAttentionDecode( + __gm__ uint8_t *qGm, __gm__ uint8_t *kGm, __gm__ uint8_t *vGm, + __gm__ uint8_t *blockTablesGm, __gm__ uint8_t *oGm, + __gm__ uint8_t *tilingParaGm, int64_t workerIdx, int64_t workerNum) { + constexpr int32_t kMaxHeadDim = 128; + PtoPaInitCoreState(); + const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); + const int32_t effectiveBatch = + ctx.decoderBatch > 0 ? ctx.decoderBatch : ctx.batch; + const int32_t maxBlocksPerQuery = + ctx.maxBlocksPerQuery > 0 + ? ctx.maxBlocksPerQuery + : (ctx.maxKvSeqLen + ctx.blockSize - 1) / ctx.blockSize; + const int32_t headsPerKv = ctx.numHeads / ctx.kvHeads; + const int64_t totalRows = static_cast(effectiveBatch) * ctx.numHeads; + if (workerIdx < 0 || workerNum <= 0) { + pipe_barrier(PIPE_ALL); + return; + } - const int32_t maxBlocksPerQuery = ctx.maxBlocksPerQuery > 0 ? ctx.maxBlocksPerQuery : - (ctx.maxKvSeqLen + ctx.blockSize - 1) / ctx.blockSize; - const int32_t headsPerKv = ctx.numHeads / ctx.kvHeads; - const int32_t formerHeadSplit = ctx.formerHeadSplit > 0 ? ctx.formerHeadSplit : 1; - const int32_t corePerBatch = (ctx.numHeads + formerHeadSplit - 1) / formerHeadSplit; - const int64_t processNum = static_cast(ctx.formerBatch) * corePerBatch * ctx.kvSplitCoreNum; - __gm__ float *partialOut = reinterpret_cast<__gm__ float *>(oCoreTmpGm); - __gm__ float *partialL = reinterpret_cast<__gm__ float *>(lGm); - - using VecHalf128 = Tile; - using VecFloat128 = Tile; + if (ctx.headDim > kMaxHeadDim) { + pipe_barrier(PIPE_ALL); + return; + } + + using DecodeScalarTile = + Tile; + DecodeScalarTile decodeScalarMathTile; + TASSIGN(decodeScalarMathTile, 0x5800); + + if (ctx.headDim == kMaxHeadDim) { + constexpr uint64_t kWeightedUb = 0x0000; + constexpr uint64_t kQHalfUb = 0x0800; + constexpr uint64_t kQFloatUb = 0x1000; + constexpr uint64_t kKHalfUb = 0x1800; + constexpr uint64_t kKFloatUb = 0x2000; + constexpr uint64_t kQKProductUb = 0x2800; + constexpr uint64_t kScoreUb = 0x3000; + constexpr uint64_t kReduceTmpUb = 0x3800; + constexpr uint64_t kVHalfUb = 0x4000; + constexpr uint64_t kVFloatUb = 0x4800; + constexpr uint64_t kOutHalfUb = 0x5000; + + using VecHalf128 = Tile; + using VecFloat128 = Tile; using VecFloat8 = Tile; - using GlobalHalf128 = - GlobalTensor, Stride<1, 1, 1, kHeadDim, 1>>; + using GlobalHalf128 = GlobalTensor, + Stride<1, 1, 1, kMaxHeadDim, 1>>; + VecFloat128 weightedTile; VecHalf128 qHalfTile; VecFloat128 qFloatTile; VecHalf128 kHalfTile; @@ -1174,385 +1560,163 @@ AICORE inline void RunPtoPagedAttentionDecodeSplitKV( VecFloat128 reduceTmpTile; VecHalf128 vHalfTile; VecFloat128 vFloatTile; - VecFloat128 weightedTile; - VecFloat8 scalarMathTile; - TASSIGN(qHalfTile, 0x0800); - TASSIGN(qFloatTile, 0x1000); - TASSIGN(kHalfTile, 0x1800); - TASSIGN(kFloatTile, 0x2000); - TASSIGN(qkProductTile, 0x2800); - TASSIGN(scoreTile, 0x3000); - TASSIGN(reduceTmpTile, 0x3800); - TASSIGN(vHalfTile, 0x4000); - TASSIGN(vFloatTile, 0x4800); - TASSIGN(weightedTile, 0x5000); - TASSIGN(scalarMathTile, 0x5800); - - for (int64_t process = workerIdx; process < processNum; process += workerNum) { - int32_t curBatchSlot = static_cast(process / (corePerBatch * ctx.kvSplitCoreNum)); - int32_t paraBase = ctx.headSize + curBatchSlot * ctx.paraSize; - const int32_t sortedBatch = LoadTilingI32(tilingParaGm, paraBase + kParaBatchIndex); - paraBase = ctx.headSize + sortedBatch * ctx.paraSize; - const int32_t batchIndex = LoadTilingI32(tilingParaGm, paraBase + 8); - const int32_t kvSeqLen = LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); - if (kvSeqLen <= 0) { - continue; - } - - const int32_t kvSeqLenAlign = ((kvSeqLen + ctx.blockSize - 1) / ctx.blockSize) * ctx.blockSize; - const int32_t kvLoop = (kvSeqLenAlign + ctx.kvSplitPerCore - 1) / ctx.kvSplitPerCore; - const int32_t curSplit = static_cast(process % ctx.kvSplitCoreNum); - if (curSplit >= kvLoop) { - continue; - } - - const int32_t curHeadBlock = static_cast((process / ctx.kvSplitCoreNum) % corePerBatch); - const int32_t startHead = curHeadBlock * formerHeadSplit; - int32_t curHeadNum = formerHeadSplit; - if (curHeadBlock == corePerBatch - 1) { - curHeadNum = ctx.numHeads - curHeadBlock * formerHeadSplit; - } - const int32_t startKv = curSplit * ctx.kvSplitPerCore; - int32_t curKvSeqLen = ctx.kvSplitPerCore; - if (curSplit == kvLoop - 1) { - curKvSeqLen = kvSeqLen - startKv; - } + VecHalf128 outHalfTile; - const uint64_t lBase = LoadTilingOffset64(tilingParaGm, paraBase, 11, 12); - const uint64_t oFdBase = LoadTilingOffset64(tilingParaGm, paraBase, 15, 16); - const int32_t headBegin = subBlockId == 0 ? 0 : curHeadNum / 2; - const int32_t headEnd = subBlockId == 0 ? curHeadNum / 2 : curHeadNum; - for (int32_t headLocal = headBegin; headLocal < headEnd; ++headLocal) { - const int32_t head = startHead + headLocal; - const int32_t kvHead = head / headsPerKv; - const int64_t qBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; - GlobalHalf128 qGlobal(reinterpret_cast<__gm__ half *>(qGm) + qBase); - TLOAD(qHalfTile, qGlobal); - pipe_barrier(PIPE_ALL); - TCVT(qFloatTile, qHalfTile, RoundMode::CAST_NONE); - pipe_barrier(PIPE_V); + TASSIGN(weightedTile, kWeightedUb); + TASSIGN(qHalfTile, kQHalfUb); + TASSIGN(qFloatTile, kQFloatUb); + TASSIGN(kHalfTile, kKHalfUb); + TASSIGN(kFloatTile, kKFloatUb); + TASSIGN(qkProductTile, kQKProductUb); + TASSIGN(scoreTile, kScoreUb); + TASSIGN(reduceTmpTile, kReduceTmpUb); + TASSIGN(vHalfTile, kVHalfUb); + TASSIGN(vFloatTile, kVFloatUb); + TASSIGN(outHalfTile, kOutHalfUb); - float maxScore = -3.4028234663852886e38f; - float sumExp = 0.0f; - TEXPANDS(weightedTile, 0.0f); - pipe_barrier(PIPE_V); - for (int32_t relPos = 0; relPos < curKvSeqLen; ++relPos) { - const int32_t pos = startKv + relPos; - int32_t blockId = 0; - int32_t offsetInBlock = 0; - ResolvePagedPosition(blockTablesGm, batchIndex, maxBlocksPerQuery, pos, ctx.blockSize, blockId, - offsetInBlock); - const int64_t kvOffset = (((static_cast(blockId) * ctx.blockSize + offsetInBlock) * - ctx.kvHeads + kvHead) * ctx.headDim); - - GlobalHalf128 kGlobal(reinterpret_cast<__gm__ half *>(kGm) + kvOffset); - TLOAD(kHalfTile, kGlobal); - pipe_barrier(PIPE_ALL); - TCVT(kFloatTile, kHalfTile, RoundMode::CAST_NONE); - pipe_barrier(PIPE_V); - TMUL(qkProductTile, qFloatTile, kFloatTile); - pipe_barrier(PIPE_V); - TROWSUM(scoreTile, qkProductTile, reduceTmpTile); - pipe_barrier(PIPE_V); - const float score = scoreTile.data()[0] * ctx.scale; - const float newMax = score > maxScore ? score : maxScore; - float oldScale = 0.0f; - if (relPos != 0) { - scalarMathTile.data()[0] = maxScore - newMax; - TEXP(scalarMathTile, scalarMathTile); - pipe_barrier(PIPE_V); - oldScale = scalarMathTile.data()[0]; - } - scalarMathTile.data()[0] = score - newMax; - TEXP(scalarMathTile, scalarMathTile); - pipe_barrier(PIPE_V); - const float probUnnorm = scalarMathTile.data()[0]; - sumExp = sumExp * oldScale + probUnnorm; - - GlobalHalf128 vGlobal(reinterpret_cast<__gm__ half *>(vGm) + kvOffset); - TMULS(weightedTile, weightedTile, oldScale); - pipe_barrier(PIPE_V); - TLOAD(vHalfTile, vGlobal); - pipe_barrier(PIPE_ALL); - TCVT(vFloatTile, vHalfTile, RoundMode::CAST_NONE); - pipe_barrier(PIPE_V); - TAXPY(weightedTile, vFloatTile, probUnnorm); - pipe_barrier(PIPE_V); - maxScore = newMax; - } - - const float invSum = sumExp > 0.0f ? 1.0f / sumExp : 0.0f; - const uint64_t outOffset = oFdBase * ctx.kvSplitCoreNum + - static_cast(head) * ctx.headDim * ctx.kvSplitCoreNum + - static_cast(curSplit) * ctx.headDim; - const uint64_t lOffset = lBase + static_cast(head) * ctx.kvSplitCoreNum + curSplit; - float logSumExp = -3.4028234663852886e38f; - if (sumExp > 0.0f) { - scalarMathTile.data()[0] = sumExp; - TLOG(scalarMathTile, scalarMathTile); - pipe_barrier(PIPE_V); - logSumExp = scalarMathTile.data()[0]; - } - partialL[lOffset] = maxScore + logSumExp; - for (int32_t dim = 0; dim < kHeadDim; ++dim) { - partialOut[outOffset + dim] = weightedTile.data()[dim] * invSum; - } - } + for (int64_t row = workerIdx; row < totalRows; row += workerNum) { + const int32_t head = static_cast(row % ctx.numHeads); + const int32_t batchSlot = static_cast(row / ctx.numHeads); + const int32_t paraBase = + ResolveSortedParaBase(tilingParaGm, ctx, batchSlot); + const int32_t kvSeqLen = + LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); + const int32_t batchIndex = + LoadTilingI32(tilingParaGm, paraBase + kParaRealBatchIndex); + const int32_t kvHead = head / headsPerKv; + const int64_t qBase = + (static_cast(batchIndex) * ctx.numHeads + head) * + ctx.headDim; + GlobalHalf128 qGlobal(reinterpret_cast<__gm__ half *>(qGm) + qBase); + TLOAD(qHalfTile, qGlobal); + pipe_barrier(PIPE_ALL); + TCVT(qFloatTile, qHalfTile, RoundMode::CAST_NONE); + pipe_barrier(PIPE_ALL); + + float maxScore = -3.4028234663852886e38f; + float sumExp = 0.0f; + TEXPANDS(weightedTile, 0.0f); + pipe_barrier(PIPE_ALL); + + for (int32_t pos = 0; pos < kvSeqLen; ++pos) { + int32_t blockId = 0; + int32_t offsetInBlock = 0; + ResolvePagedPosition(blockTablesGm, batchIndex, maxBlocksPerQuery, pos, + ctx.blockSize, blockId, offsetInBlock); + const int64_t kvOffset = + (((static_cast(blockId) * ctx.blockSize + offsetInBlock) * + ctx.kvHeads + + kvHead) * + ctx.headDim); + + GlobalHalf128 kGlobal(reinterpret_cast<__gm__ half *>(kGm) + kvOffset); + TLOAD(kHalfTile, kGlobal); + pipe_barrier(PIPE_ALL); + TCVT(kFloatTile, kHalfTile, RoundMode::CAST_NONE); + pipe_barrier(PIPE_ALL); + TMUL(qkProductTile, qFloatTile, kFloatTile); + pipe_barrier(PIPE_ALL); + TROWSUM(scoreTile, qkProductTile, reduceTmpTile); + pipe_barrier(PIPE_ALL); + const float rawScore = scoreTile.data()[0]; + const float score = rawScore * ctx.scale; + + const float newMax = score > maxScore ? score : maxScore; + const float oldScale = + (pos == 0) ? 0.0f + : PtoExpScalar(decodeScalarMathTile, maxScore - newMax); + const float probUnnorm = + PtoExpScalar(decodeScalarMathTile, score - newMax); + sumExp = sumExp * oldScale + probUnnorm; + + GlobalHalf128 vGlobal(reinterpret_cast<__gm__ half *>(vGm) + kvOffset); + TMULS(weightedTile, weightedTile, oldScale); + pipe_barrier(PIPE_ALL); + TLOAD(vHalfTile, vGlobal); + pipe_barrier(PIPE_ALL); + TCVT(vFloatTile, vHalfTile, RoundMode::CAST_NONE); + pipe_barrier(PIPE_ALL); + TAXPY(weightedTile, vFloatTile, probUnnorm); + pipe_barrier(PIPE_ALL); + maxScore = newMax; + } + + const float invSum = sumExp > 0.0f ? 1.0f / sumExp : 0.0f; + const int64_t outBase = + (static_cast(batchIndex) * ctx.numHeads + head) * + ctx.headDim; + GlobalHalf128 outGlobal(reinterpret_cast<__gm__ half *>(oGm) + outBase); + TMULS(weightedTile, weightedTile, invSum); + pipe_barrier(PIPE_ALL); + TCVT(outHalfTile, weightedTile, RoundMode::CAST_RINT); + pipe_barrier(PIPE_ALL); + TSTORE(outGlobal, outHalfTile); + pipe_barrier(PIPE_ALL); } - DdrFenceBeforePtoAivReduce(); - ffts_cross_core_sync(PIPE_MTE3, PtoPaGetFftsMsg(0x0, PTO_PA_REDUCE_READY_DECODER)); - wait_flag_dev(PTO_PA_REDUCE_READY_DECODER); - - const int32_t effectiveBatch = ctx.decoderBatch > 0 ? ctx.decoderBatch : ctx.batch; - const int64_t totalRows = static_cast(effectiveBatch) * ctx.numHeads; - const int64_t combineWorkerIdx = workerIdx * 2 + static_cast(subBlockId); - const int64_t combineWorkerNum = workerNum * 2; - for (int64_t row = combineWorkerIdx; row < totalRows; row += combineWorkerNum) { - const int32_t head = static_cast(row % ctx.numHeads); - const int32_t batchSlot = static_cast(row / ctx.numHeads); - int32_t paraBase = ctx.headSize + batchSlot * ctx.paraSize; - const int32_t sortedBatch = LoadTilingI32(tilingParaGm, paraBase + kParaBatchIndex); - paraBase = ctx.headSize + sortedBatch * ctx.paraSize; - const int32_t batchIndex = LoadTilingI32(tilingParaGm, paraBase + 8); - const int32_t kvSeqLen = LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); - if (kvSeqLen <= 0) { - continue; - } - - const int32_t kvSeqLenAlign = ((kvSeqLen + ctx.blockSize - 1) / ctx.blockSize) * ctx.blockSize; - const int32_t kvLoop = (kvSeqLenAlign + ctx.kvSplitPerCore - 1) / ctx.kvSplitPerCore; - const uint64_t lBase = LoadTilingOffset64(tilingParaGm, paraBase, 11, 12); - const uint64_t oFdBase = LoadTilingOffset64(tilingParaGm, paraBase, 15, 16); - float lMax = -3.4028234663852886e38f; - for (int32_t split = 0; split < kvLoop; ++split) { - const float lValue = partialL[lBase + static_cast(head) * ctx.kvSplitCoreNum + split]; - lMax = lValue > lMax ? lValue : lMax; - } - float denom = 0.0f; - float splitScale[64]; - for (int32_t split = 0; split < kvLoop; ++split) { - const float lValue = partialL[lBase + static_cast(head) * ctx.kvSplitCoreNum + split]; - scalarMathTile.data()[0] = lValue - lMax; - TEXP(scalarMathTile, scalarMathTile); - pipe_barrier(PIPE_V); - const float scale = scalarMathTile.data()[0]; - splitScale[split] = scale; - denom += scale; - } - const float invDenom = denom > 0.0f ? 1.0f / denom : 0.0f; - for (int32_t split = 0; split < kvLoop; ++split) { - splitScale[split] *= invDenom; - } - const int64_t outBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; - for (int32_t dim = 0; dim < kHeadDim; ++dim) { - float value = 0.0f; - for (int32_t split = 0; split < kvLoop; ++split) { - const uint64_t outOffset = oFdBase * ctx.kvSplitCoreNum + - static_cast(head) * ctx.headDim * ctx.kvSplitCoreNum + - static_cast(split) * ctx.headDim; - value += partialOut[outOffset + dim] * splitScale[split]; - } - StoreOutputFp16(oGm, outBase + dim, value); - } - } pipe_barrier(PIPE_ALL); -} -#endif - -AICORE inline void RunPtoPagedAttentionDecode( - __gm__ uint8_t *qGm, - __gm__ uint8_t *kGm, - __gm__ uint8_t *vGm, - __gm__ uint8_t *blockTablesGm, - __gm__ uint8_t *oGm, - __gm__ uint8_t *tilingParaGm, - int64_t workerIdx, - int64_t workerNum) -{ - constexpr int32_t kMaxHeadDim = 128; - PtoPaInitCoreState(); - const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); - const int32_t effectiveBatch = ctx.decoderBatch > 0 ? ctx.decoderBatch : ctx.batch; - const int32_t maxBlocksPerQuery = ctx.maxBlocksPerQuery > 0 ? ctx.maxBlocksPerQuery : - (ctx.maxKvSeqLen + ctx.blockSize - 1) / ctx.blockSize; - const int32_t headsPerKv = ctx.numHeads / ctx.kvHeads; - const int64_t totalRows = static_cast(effectiveBatch) * ctx.numHeads; - if (workerIdx < 0 || workerNum <= 0) { - pipe_barrier(PIPE_ALL); - return; + return; + } + + for (int64_t row = workerIdx; row < totalRows; row += workerNum) { + const int32_t head = static_cast(row % ctx.numHeads); + const int32_t batchSlot = static_cast(row / ctx.numHeads); + const int32_t paraBase = + ResolveSortedParaBase(tilingParaGm, ctx, batchSlot); + const int32_t kvSeqLen = + LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); + const int32_t batchIndex = + LoadTilingI32(tilingParaGm, paraBase + kParaRealBatchIndex); + const int32_t kvHead = head / headsPerKv; + + float qValues[kMaxHeadDim]; + const int64_t qBase = + (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; + for (int32_t dim = 0; dim < ctx.headDim; ++dim) { + qValues[dim] = LoadFp16(qGm, qBase + dim); } - if (ctx.headDim > kMaxHeadDim) { - pipe_barrier(PIPE_ALL); - return; + float maxScore = -3.4028234663852886e38f; + float sumExp = 0.0f; + float weighted[kMaxHeadDim]; + for (int32_t dim = 0; dim < ctx.headDim; ++dim) { + weighted[dim] = 0.0f; } - using DecodeScalarTile = Tile; - DecodeScalarTile decodeScalarMathTile; - TASSIGN(decodeScalarMathTile, 0x5800); - - if (ctx.headDim == kMaxHeadDim) { - constexpr uint64_t kWeightedUb = 0x0000; - constexpr uint64_t kQHalfUb = 0x0800; - constexpr uint64_t kQFloatUb = 0x1000; - constexpr uint64_t kKHalfUb = 0x1800; - constexpr uint64_t kKFloatUb = 0x2000; - constexpr uint64_t kQKProductUb = 0x2800; - constexpr uint64_t kScoreUb = 0x3000; - constexpr uint64_t kReduceTmpUb = 0x3800; - constexpr uint64_t kVHalfUb = 0x4000; - constexpr uint64_t kVFloatUb = 0x4800; - constexpr uint64_t kOutHalfUb = 0x5000; - - using VecHalf128 = Tile; - using VecFloat128 = Tile; - using VecFloat8 = Tile; - using GlobalHalf128 = - GlobalTensor, Stride<1, 1, 1, kMaxHeadDim, 1>>; - - VecFloat128 weightedTile; - VecHalf128 qHalfTile; - VecFloat128 qFloatTile; - VecHalf128 kHalfTile; - VecFloat128 kFloatTile; - VecFloat128 qkProductTile; - VecFloat8 scoreTile; - VecFloat128 reduceTmpTile; - VecHalf128 vHalfTile; - VecFloat128 vFloatTile; - VecHalf128 outHalfTile; - - TASSIGN(weightedTile, kWeightedUb); - TASSIGN(qHalfTile, kQHalfUb); - TASSIGN(qFloatTile, kQFloatUb); - TASSIGN(kHalfTile, kKHalfUb); - TASSIGN(kFloatTile, kKFloatUb); - TASSIGN(qkProductTile, kQKProductUb); - TASSIGN(scoreTile, kScoreUb); - TASSIGN(reduceTmpTile, kReduceTmpUb); - TASSIGN(vHalfTile, kVHalfUb); - TASSIGN(vFloatTile, kVFloatUb); - TASSIGN(outHalfTile, kOutHalfUb); - - for (int64_t row = workerIdx; row < totalRows; row += workerNum) { - const int32_t head = static_cast(row % ctx.numHeads); - const int32_t batchSlot = static_cast(row / ctx.numHeads); - const int32_t paraBase = ResolveSortedParaBase(tilingParaGm, ctx, batchSlot); - const int32_t kvSeqLen = LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); - const int32_t batchIndex = LoadTilingI32(tilingParaGm, paraBase + kParaRealBatchIndex); - const int32_t kvHead = head / headsPerKv; - const int64_t qBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; - GlobalHalf128 qGlobal(reinterpret_cast<__gm__ half *>(qGm) + qBase); - TLOAD(qHalfTile, qGlobal); - pipe_barrier(PIPE_ALL); - TCVT(qFloatTile, qHalfTile, RoundMode::CAST_NONE); - pipe_barrier(PIPE_ALL); - - float maxScore = -3.4028234663852886e38f; - float sumExp = 0.0f; - TEXPANDS(weightedTile, 0.0f); - pipe_barrier(PIPE_ALL); - - for (int32_t pos = 0; pos < kvSeqLen; ++pos) { - int32_t blockId = 0; - int32_t offsetInBlock = 0; - ResolvePagedPosition(blockTablesGm, batchIndex, maxBlocksPerQuery, pos, ctx.blockSize, blockId, - offsetInBlock); - const int64_t kvOffset = (((static_cast(blockId) * ctx.blockSize + offsetInBlock) * - ctx.kvHeads + kvHead) * ctx.headDim); - - GlobalHalf128 kGlobal(reinterpret_cast<__gm__ half *>(kGm) + kvOffset); - TLOAD(kHalfTile, kGlobal); - pipe_barrier(PIPE_ALL); - TCVT(kFloatTile, kHalfTile, RoundMode::CAST_NONE); - pipe_barrier(PIPE_ALL); - TMUL(qkProductTile, qFloatTile, kFloatTile); - pipe_barrier(PIPE_ALL); - TROWSUM(scoreTile, qkProductTile, reduceTmpTile); - pipe_barrier(PIPE_ALL); - const float rawScore = scoreTile.data()[0]; - const float score = rawScore * ctx.scale; - - const float newMax = score > maxScore ? score : maxScore; - const float oldScale = (pos == 0) ? 0.0f : PtoExpScalar(decodeScalarMathTile, maxScore - newMax); - const float probUnnorm = PtoExpScalar(decodeScalarMathTile, score - newMax); - sumExp = sumExp * oldScale + probUnnorm; - - GlobalHalf128 vGlobal(reinterpret_cast<__gm__ half *>(vGm) + kvOffset); - TMULS(weightedTile, weightedTile, oldScale); - pipe_barrier(PIPE_ALL); - TLOAD(vHalfTile, vGlobal); - pipe_barrier(PIPE_ALL); - TCVT(vFloatTile, vHalfTile, RoundMode::CAST_NONE); - pipe_barrier(PIPE_ALL); - TAXPY(weightedTile, vFloatTile, probUnnorm); - pipe_barrier(PIPE_ALL); - maxScore = newMax; - } - - const float invSum = sumExp > 0.0f ? 1.0f / sumExp : 0.0f; - const int64_t outBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; - GlobalHalf128 outGlobal(reinterpret_cast<__gm__ half *>(oGm) + outBase); - TMULS(weightedTile, weightedTile, invSum); - pipe_barrier(PIPE_ALL); - TCVT(outHalfTile, weightedTile, RoundMode::CAST_RINT); - pipe_barrier(PIPE_ALL); - TSTORE(outGlobal, outHalfTile); - pipe_barrier(PIPE_ALL); - } - - pipe_barrier(PIPE_ALL); - return; + for (int32_t pos = 0; pos < kvSeqLen; ++pos) { + int32_t blockId = 0; + int32_t offsetInBlock = 0; + ResolvePagedPosition(blockTablesGm, batchIndex, maxBlocksPerQuery, pos, + ctx.blockSize, blockId, offsetInBlock); + const float score = ComputeScoreByBlock( + qValues, kGm, blockId, offsetInBlock, ctx.blockSize, kvHead, + ctx.headDim, ctx.kvHeads, ctx.scale); + const bool updateMax = score > maxScore; + const float newMax = updateMax ? score : maxScore; + const float oldScale = + (pos == 0) ? 0.0f + : PtoExpScalar(decodeScalarMathTile, maxScore - newMax); + const float probUnnorm = + PtoExpScalar(decodeScalarMathTile, score - newMax); + sumExp = sumExp * oldScale + probUnnorm; + for (int32_t dim = 0; dim < ctx.headDim; ++dim) { + const float value = + LoadPagedVByBlock(vGm, blockId, offsetInBlock, ctx.blockSize, + ctx.kvHeads, kvHead, ctx.headDim, dim); + weighted[dim] = weighted[dim] * oldScale + probUnnorm * value; + } + maxScore = newMax; } - for (int64_t row = workerIdx; row < totalRows; row += workerNum) { - const int32_t head = static_cast(row % ctx.numHeads); - const int32_t batchSlot = static_cast(row / ctx.numHeads); - const int32_t paraBase = ResolveSortedParaBase(tilingParaGm, ctx, batchSlot); - const int32_t kvSeqLen = LoadTilingI32(tilingParaGm, paraBase + kParaKvSeqLen); - const int32_t batchIndex = LoadTilingI32(tilingParaGm, paraBase + kParaRealBatchIndex); - const int32_t kvHead = head / headsPerKv; - - float qValues[kMaxHeadDim]; - const int64_t qBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; - for (int32_t dim = 0; dim < ctx.headDim; ++dim) { - qValues[dim] = LoadFp16(qGm, qBase + dim); - } - - float maxScore = -3.4028234663852886e38f; - float sumExp = 0.0f; - float weighted[kMaxHeadDim]; - for (int32_t dim = 0; dim < ctx.headDim; ++dim) { - weighted[dim] = 0.0f; - } - - for (int32_t pos = 0; pos < kvSeqLen; ++pos) { - int32_t blockId = 0; - int32_t offsetInBlock = 0; - ResolvePagedPosition(blockTablesGm, batchIndex, maxBlocksPerQuery, pos, ctx.blockSize, blockId, offsetInBlock); - const float score = ComputeScoreByBlock(qValues, kGm, blockId, offsetInBlock, ctx.blockSize, kvHead, - ctx.headDim, ctx.kvHeads, ctx.scale); - const bool updateMax = score > maxScore; - const float newMax = updateMax ? score : maxScore; - const float oldScale = (pos == 0) ? 0.0f : PtoExpScalar(decodeScalarMathTile, maxScore - newMax); - const float probUnnorm = PtoExpScalar(decodeScalarMathTile, score - newMax); - sumExp = sumExp * oldScale + probUnnorm; - for (int32_t dim = 0; dim < ctx.headDim; ++dim) { - const float value = LoadPagedVByBlock(vGm, blockId, offsetInBlock, ctx.blockSize, ctx.kvHeads, kvHead, ctx.headDim, dim); - weighted[dim] = weighted[dim] * oldScale + probUnnorm * value; - } - maxScore = newMax; - } - - const float invSum = sumExp > 0.0f ? 1.0f / sumExp : 0.0f; - const int64_t outBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; - for (int32_t dim = 0; dim < ctx.headDim; ++dim) { - StoreOutputFp16(oGm, outBase + dim, weighted[dim] * invSum); - } + const float invSum = sumExp > 0.0f ? 1.0f / sumExp : 0.0f; + const int64_t outBase = + (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; + for (int32_t dim = 0; dim < ctx.headDim; ++dim) { + StoreOutputFp16(oGm, outBase + dim, weighted[dim] * invSum); } + } - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_ALL); } #endif diff --git a/examples/jit_cpp/paged_attention_highperf/pa_tiling.py b/examples/jit_cpp/paged_attention_highperf/pa_tiling.py index 884f4e2e..726d863e 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_tiling.py +++ b/examples/jit_cpp/paged_attention_highperf/pa_tiling.py @@ -266,6 +266,7 @@ def _split_core_bns_nd( ) +# pylint: disable-next=too-many-positional-arguments def make_pa_nd_decode_tiling( # noqa: PLR0913, PLR0915 batch: int, kv_seq_lens: list[int], @@ -340,14 +341,15 @@ def make_pa_nd_decode_tiling( # noqa: PLR0913, PLR0915 is_long_seq, ) - if ( + supports_head_move = ( head_dim % 16 == 0 and head_dim <= EMBEDDING_LIMIT and head_dim_v % 16 == 0 and head_dim_v <= EMBEDDING_LIMIT and kv_real == num_heads and not is_quant - ): + ) + if supports_head_move: head_num_move = 2 else: head_num_move = 1 From 8b5f6a082f58adba8d10cf65133c1c30555e090f Mon Sep 17 00:00:00 2001 From: MirkoDeVita98 Date: Mon, 29 Jun 2026 13:03:52 +0200 Subject: [PATCH 06/11] fix clang lint --- .../paged_attention_highperf/pa_entry.hpp | 177 +++++++++--------- 1 file changed, 84 insertions(+), 93 deletions(-) diff --git a/examples/jit_cpp/paged_attention_highperf/pa_entry.hpp b/examples/jit_cpp/paged_attention_highperf/pa_entry.hpp index f6e4f40f..9623a2a3 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_entry.hpp +++ b/examples/jit_cpp/paged_attention_highperf/pa_entry.hpp @@ -1,11 +1,13 @@ /** Copyright (c) 2026 Huawei Technologies Co., Ltd. -This program is free software, you can redistribute it and/or modify it under the terms and conditions of -CANN Open Software License Agreement Version 2.0 (the "License"). -Please refer to the License for details. You may not use this file except in compliance with the License. -THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -See LICENSE in the root of the software repository for the full text of the License. +This program is free software, you can redistribute it and/or modify it under +the terms and conditions of CANN Open Software License Agreement Version 2.0 +(the "License"). Please refer to the License for details. You may not use this +file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN "AS +IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING +BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A +PARTICULAR PURPOSE. See LICENSE in the root of the software repository for the +full text of the License. */ #ifndef PTO_PAGED_ATTENTION_HIGHPERF_ENTRY_HPP @@ -14,12 +16,9 @@ See LICENSE in the root of the software repository for the full text of the Lice #include "pa_kernel_impl.hpp" static AICORE __attribute__((always_inline)) void paged_attention_mask_body( - __gm__ uint8_t *__restrict__ sync, - uint32_t ptoBlockIdx, - uint32_t ptoBlockNum, - uint32_t ptoSubBlockId, - __gm__ uint8_t *__restrict__ qGm, - __gm__ uint8_t *__restrict__ kGm, + __gm__ uint8_t *__restrict__ sync, uint32_t ptoBlockIdx, + uint32_t ptoBlockNum, uint32_t ptoSubBlockId, + __gm__ uint8_t *__restrict__ qGm, __gm__ uint8_t *__restrict__ kGm, __gm__ uint8_t *__restrict__ vGm, __gm__ uint8_t *__restrict__ blockTablesGm, __gm__ uint8_t *__restrict__ maskGm, @@ -28,75 +27,73 @@ static AICORE __attribute__((always_inline)) void paged_attention_mask_body( __gm__ uint8_t *__restrict__ deqScale2Gm, __gm__ uint8_t *__restrict__ offset2Gm, __gm__ uint8_t *__restrict__ razorOffset, - __gm__ uint8_t *__restrict__ scaleGm, - __gm__ uint8_t *__restrict__ logNGm, - __gm__ uint8_t *__restrict__ eyeGm, - __gm__ uint8_t *__restrict__ oGm, - __gm__ uint8_t *__restrict__ sGm, - __gm__ uint8_t *__restrict__ pGm, - __gm__ uint8_t *__restrict__ oTmpGm, - __gm__ uint8_t *__restrict__ goGm, - __gm__ uint8_t *__restrict__ oCoreTmpGm, - __gm__ uint8_t *__restrict__ lGm, - __gm__ uint8_t *__restrict__ gmK16, - __gm__ uint8_t *__restrict__ gmV16, - __gm__ uint8_t *__restrict__ tilingParaGm) -{ - (void)maskGm; - (void)deqScale1Gm; - (void)offset1Gm; - (void)deqScale2Gm; - (void)offset2Gm; - (void)razorOffset; - (void)scaleGm; - (void)logNGm; - (void)eyeGm; - (void)sGm; - (void)pGm; - (void)oTmpGm; - (void)goGm; - (void)gmK16; - (void)gmV16; + __gm__ uint8_t *__restrict__ scaleGm, __gm__ uint8_t *__restrict__ logNGm, + __gm__ uint8_t *__restrict__ eyeGm, __gm__ uint8_t *__restrict__ oGm, + __gm__ uint8_t *__restrict__ sGm, __gm__ uint8_t *__restrict__ pGm, + __gm__ uint8_t *__restrict__ oTmpGm, __gm__ uint8_t *__restrict__ goGm, + __gm__ uint8_t *__restrict__ oCoreTmpGm, __gm__ uint8_t *__restrict__ lGm, + __gm__ uint8_t *__restrict__ gmK16, __gm__ uint8_t *__restrict__ gmV16, + __gm__ uint8_t *__restrict__ tilingParaGm) { + (void)maskGm; + (void)deqScale1Gm; + (void)offset1Gm; + (void)deqScale2Gm; + (void)offset2Gm; + (void)razorOffset; + (void)scaleGm; + (void)logNGm; + (void)eyeGm; + (void)sGm; + (void)pGm; + (void)oTmpGm; + (void)goGm; + (void)gmK16; + (void)gmV16; - if (sync != nullptr) { - set_ffts_base_addr(reinterpret_cast(sync)); - } - set_atomic_none(); - set_mask_norm(); + if (sync != nullptr) { + set_ffts_base_addr(reinterpret_cast(sync)); + } + set_atomic_none(); + set_mask_norm(); #ifdef __DAV_C220_CUBE__ - const int64_t workerIdx = static_cast(ptoBlockIdx); - const int64_t workerNum = static_cast(ptoBlockNum); - if (SupportsPtoPagedAttentionRawSplitKV(tilingParaGm)) { - RunPtoPagedAttentionCubePipelineSplitKV(qGm, kGm, vGm, blockTablesGm, sGm, pGm, oTmpGm, tilingParaGm, - workerIdx, workerNum); - } else { - pipe_barrier(PIPE_ALL); - } + const int64_t workerIdx = static_cast(ptoBlockIdx); + const int64_t workerNum = static_cast(ptoBlockNum); + if (SupportsPtoPagedAttentionRawSplitKV(tilingParaGm)) { + RunPtoPagedAttentionCubePipelineSplitKV(qGm, kGm, vGm, blockTablesGm, sGm, + pGm, oTmpGm, tilingParaGm, + workerIdx, workerNum); + } else { + pipe_barrier(PIPE_ALL); + } #elif defined(__DAV_C220_VEC__) - const int64_t workerIdx = static_cast(ptoBlockIdx) * 2 + static_cast(ptoSubBlockId); - const int64_t workerNum = static_cast(ptoBlockNum) * 2; - const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); - if (SupportsPtoPagedAttentionRawSplitKV(tilingParaGm)) { - RunPtoPagedAttentionVecPipelineSplitKV(oGm, sGm, pGm, oTmpGm, oCoreTmpGm, lGm, tilingParaGm, - static_cast(ptoBlockIdx), static_cast(ptoBlockNum), ptoSubBlockId); - } else if (ctx.kvSplitCoreNum > 1) { - RunPtoPagedAttentionDecodeSplitKV(qGm, kGm, vGm, blockTablesGm, oGm, oCoreTmpGm, lGm, tilingParaGm, - static_cast(ptoBlockIdx), static_cast(ptoBlockNum), ptoSubBlockId); - } else { - RunPtoPagedAttentionDecode(qGm, kGm, vGm, blockTablesGm, oGm, tilingParaGm, workerIdx, workerNum); - } + const int64_t workerIdx = static_cast(ptoBlockIdx) * 2 + + static_cast(ptoSubBlockId); + const int64_t workerNum = static_cast(ptoBlockNum) * 2; + const PaTilingContext ctx = LoadPaTilingContext(tilingParaGm); + if (SupportsPtoPagedAttentionRawSplitKV(tilingParaGm)) { + RunPtoPagedAttentionVecPipelineSplitKV( + oGm, sGm, pGm, oTmpGm, oCoreTmpGm, lGm, tilingParaGm, + static_cast(ptoBlockIdx), static_cast(ptoBlockNum), + ptoSubBlockId); + } else if (ctx.kvSplitCoreNum > 1) { + RunPtoPagedAttentionDecodeSplitKV( + qGm, kGm, vGm, blockTablesGm, oGm, oCoreTmpGm, lGm, tilingParaGm, + static_cast(ptoBlockIdx), static_cast(ptoBlockNum), + ptoSubBlockId); + } else { + RunPtoPagedAttentionDecode(qGm, kGm, vGm, blockTablesGm, oGm, tilingParaGm, + workerIdx, workerNum); + } #else - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_ALL); #endif } #ifndef PTO_PA_NO_GLOBAL_ENTRY extern "C" __global__ AICORE void paged_attention_mask( - __gm__ uint8_t *__restrict__ sync, - __gm__ uint8_t *__restrict__ qGm, - __gm__ uint8_t *__restrict__ kGm, - __gm__ uint8_t *__restrict__ vGm, + __gm__ uint8_t *__restrict__ sync, __gm__ uint8_t *__restrict__ qGm, + __gm__ uint8_t *__restrict__ kGm, __gm__ uint8_t *__restrict__ vGm, __gm__ uint8_t *__restrict__ blockTablesGm, __gm__ uint8_t *__restrict__ maskGm, __gm__ uint8_t *__restrict__ deqScale1Gm, @@ -104,32 +101,26 @@ extern "C" __global__ AICORE void paged_attention_mask( __gm__ uint8_t *__restrict__ deqScale2Gm, __gm__ uint8_t *__restrict__ offset2Gm, __gm__ uint8_t *__restrict__ razorOffset, - __gm__ uint8_t *__restrict__ scaleGm, - __gm__ uint8_t *__restrict__ logNGm, - __gm__ uint8_t *__restrict__ eyeGm, - __gm__ uint8_t *__restrict__ oGm, - __gm__ uint8_t *__restrict__ sGm, - __gm__ uint8_t *__restrict__ pGm, - __gm__ uint8_t *__restrict__ oTmpGm, - __gm__ uint8_t *__restrict__ goGm, - __gm__ uint8_t *__restrict__ oCoreTmpGm, - __gm__ uint8_t *__restrict__ lGm, - __gm__ uint8_t *__restrict__ gmK16, - __gm__ uint8_t *__restrict__ gmV16, - __gm__ uint8_t *__restrict__ tilingParaGm) -{ - const uint32_t ptoBlockIdx = static_cast(get_block_idx()); - const uint32_t ptoBlockNum = static_cast(get_block_num()); + __gm__ uint8_t *__restrict__ scaleGm, __gm__ uint8_t *__restrict__ logNGm, + __gm__ uint8_t *__restrict__ eyeGm, __gm__ uint8_t *__restrict__ oGm, + __gm__ uint8_t *__restrict__ sGm, __gm__ uint8_t *__restrict__ pGm, + __gm__ uint8_t *__restrict__ oTmpGm, __gm__ uint8_t *__restrict__ goGm, + __gm__ uint8_t *__restrict__ oCoreTmpGm, __gm__ uint8_t *__restrict__ lGm, + __gm__ uint8_t *__restrict__ gmK16, __gm__ uint8_t *__restrict__ gmV16, + __gm__ uint8_t *__restrict__ tilingParaGm) { + const uint32_t ptoBlockIdx = static_cast(get_block_idx()); + const uint32_t ptoBlockNum = static_cast(get_block_num()); #ifdef __DAV_C220_VEC__ - const uint32_t ptoSubBlockId = static_cast(get_subblockid()); + const uint32_t ptoSubBlockId = static_cast(get_subblockid()); #else - const uint32_t ptoSubBlockId = 0; + const uint32_t ptoSubBlockId = 0; #endif - paged_attention_mask_body( - sync, ptoBlockIdx, ptoBlockNum, ptoSubBlockId, qGm, kGm, vGm, blockTablesGm, maskGm, deqScale1Gm, offset1Gm, - deqScale2Gm, offset2Gm, razorOffset, scaleGm, logNGm, eyeGm, oGm, sGm, pGm, oTmpGm, goGm, oCoreTmpGm, - lGm, gmK16, gmV16, tilingParaGm); + paged_attention_mask_body(sync, ptoBlockIdx, ptoBlockNum, ptoSubBlockId, qGm, + kGm, vGm, blockTablesGm, maskGm, deqScale1Gm, + offset1Gm, deqScale2Gm, offset2Gm, razorOffset, + scaleGm, logNGm, eyeGm, oGm, sGm, pGm, oTmpGm, goGm, + oCoreTmpGm, lGm, gmK16, gmV16, tilingParaGm); } #endif From 55899eba57e3370c2242bcf6493aaa3ca3f3fe46 Mon Sep 17 00:00:00 2001 From: MirkoDeVita98 Date: Mon, 29 Jun 2026 13:10:57 +0200 Subject: [PATCH 07/11] fix clang of remaining files --- .../paged_attention_highperf/pa_kernel.cpp | 78 +++++---------- .../pa_tiling_struct.hpp | 97 ++++++++++--------- 2 files changed, 76 insertions(+), 99 deletions(-) diff --git a/examples/jit_cpp/paged_attention_highperf/pa_kernel.cpp b/examples/jit_cpp/paged_attention_highperf/pa_kernel.cpp index 65b905ce..80d41a15 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_kernel.cpp +++ b/examples/jit_cpp/paged_attention_highperf/pa_kernel.cpp @@ -1,66 +1,34 @@ /** Copyright (c) 2026 Huawei Technologies Co., Ltd. -This program is free software, you can redistribute it and/or modify it under the terms and conditions of -CANN Open Software License Agreement Version 2.0 (the "License"). -Please refer to the License for details. You may not use this file except in compliance with the License. -THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -See LICENSE in the root of the software repository for the full text of the License. +This program is free software, you can redistribute it and/or modify it under +the terms and conditions of CANN Open Software License Agreement Version 2.0 +(the "License"). Please refer to the License for details. You may not use this +file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN "AS +IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING +BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A +PARTICULAR PURPOSE. See LICENSE in the root of the software repository for the +full text of the License. */ #include - #include -#include "runtime/rt.h" #include "pa_entry.hpp" +#include "runtime/rt.h" -extern "C" void call_kernel( - void *stream, - uint8_t *qGm, - uint8_t *kGm, - uint8_t *vGm, - uint8_t *blockTablesGm, - uint8_t *oGm, - uint8_t *sGm, - uint8_t *pGm, - uint8_t *oTmpGm, - uint8_t *goGm, - uint8_t *oCoreTmpGm, - uint8_t *lGm, - uint8_t *gmK16, - uint8_t *gmV16, - uint8_t *tilingParaGm, - uint8_t *nullGm, - uint32_t blockDim) -{ - uint64_t ffts = 0; - uint32_t fftsLen = 0; - rtGetC2cCtrlAddr(&ffts, &fftsLen); +extern "C" void call_kernel(void *stream, uint8_t *qGm, uint8_t *kGm, + uint8_t *vGm, uint8_t *blockTablesGm, uint8_t *oGm, + uint8_t *sGm, uint8_t *pGm, uint8_t *oTmpGm, + uint8_t *goGm, uint8_t *oCoreTmpGm, uint8_t *lGm, + uint8_t *gmK16, uint8_t *gmV16, + uint8_t *tilingParaGm, uint8_t *nullGm, + uint32_t blockDim) { + uint64_t ffts = 0; + uint32_t fftsLen = 0; + rtGetC2cCtrlAddr(&ffts, &fftsLen); - paged_attention_mask<<>>( - reinterpret_cast<__gm__ uint8_t *>(ffts), - qGm, - kGm, - vGm, - blockTablesGm, - nullGm, - nullGm, - nullGm, - nullGm, - nullGm, - nullGm, - nullGm, - nullGm, - nullGm, - oGm, - sGm, - pGm, - oTmpGm, - goGm, - oCoreTmpGm, - lGm, - gmK16, - gmV16, - tilingParaGm); + paged_attention_mask<<>>( + reinterpret_cast<__gm__ uint8_t *>(ffts), qGm, kGm, vGm, blockTablesGm, + nullGm, nullGm, nullGm, nullGm, nullGm, nullGm, nullGm, nullGm, nullGm, + oGm, sGm, pGm, oTmpGm, goGm, oCoreTmpGm, lGm, gmK16, gmV16, tilingParaGm); } diff --git a/examples/jit_cpp/paged_attention_highperf/pa_tiling_struct.hpp b/examples/jit_cpp/paged_attention_highperf/pa_tiling_struct.hpp index 4938fafb..eccbf56e 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_tiling_struct.hpp +++ b/examples/jit_cpp/paged_attention_highperf/pa_tiling_struct.hpp @@ -1,11 +1,13 @@ /* * Copyright (c) PyPTO Contributors. - * This program is free software, you can redistribute it and/or modify it under the terms and conditions of - * CANN Open Software License Agreement Version 2.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. + * This program is free software, you can redistribute it and/or modify it under + * the terms and conditions of CANN Open Software License Agreement Version 2.0 + * (the "License"). Please refer to the License for details. You may not use + * this file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON + * AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. * ----------------------------------------------------------------------------------------------------------- */ #ifndef PAGED_HATTENTION_H @@ -39,56 +41,63 @@ constexpr int32_t PARALLEL_MAX_BATCH = 2000; constexpr int32_t WORKSPACE_BLOCK_SIZE_DB = 65536; // 128 * 256 * 2 enum class TilingKeyType { - TILING_HALF_DATA = 0, - TILING_BF16_DATA = 1, - TILING_INT8_DATA = 2, - TILING_INT8_CUBE_QUANT = 4, - TILING_INT8_VEC_QUANT = 8, - TILING_INT8_VEC_QUANTBF16 = 9, - TILING_QUANT_FP16OUT = 12, - TILING_QUANT_BF16OUT = 14 + TILING_HALF_DATA = 0, + TILING_BF16_DATA = 1, + TILING_INT8_DATA = 2, + TILING_INT8_CUBE_QUANT = 4, + TILING_INT8_VEC_QUANT = 8, + TILING_INT8_VEC_QUANTBF16 = 9, + TILING_QUANT_FP16OUT = 12, + TILING_QUANT_BF16OUT = 14 }; -enum class CalcType { CALC_TYPE_DEFAULT = 0, CALC_TYPE_MIX = 1, CALC_TYPE_PREFILL = 2 }; +enum class CalcType { + CALC_TYPE_DEFAULT = 0, + CALC_TYPE_MIX = 1, + CALC_TYPE_PREFILL = 2 +}; enum class DataShapeType { BSND = 0, BNSD = 1 }; -enum class CompressType { COMPRESS_TYPE_UNDEFINED = 0, COMPRESS_TYPE_KVHEAD = 1 }; +enum class CompressType { + COMPRESS_TYPE_UNDEFINED = 0, + COMPRESS_TYPE_KVHEAD = 1 +}; enum class PagedAttnVariant { DEFAULT = 0, MULTI_LATENT = 1 }; using PagedAttentionInfo = struct PagedAttentionTilingParams { - int32_t numTokens = 0; - int32_t numHeads = 0; - int32_t embeddingSize = 0; - int32_t embeddingSizeV = 0; - int32_t numBlocks = 0; - int32_t blockSize = 0; - int32_t maxNumBlocksPerQuery = 0; - float tor = 0; - int32_t kvHeads = 0; - int32_t maxPromptLen = 0; - int32_t batchStride = 0; - int32_t headStride = 0; - TilingKeyType type = TilingKeyType::TILING_HALF_DATA; - int32_t batch = 0; - int32_t isMaskSquare = 0; - int32_t *batchRunStatus{nullptr}; - int32_t *kvSeqLen{nullptr}; - int32_t modCoef{-1}; - int32_t divCoef{1}; - int32_t *qSeqLen{nullptr}; - int32_t qHeadOriginal = 0; - int32_t compressHead = 0; - int32_t tBlockAlign = 16; // L1 tile alignment: 16 for fp16, 32 for int8 - int32_t dataShapeType = 0; + int32_t numTokens = 0; + int32_t numHeads = 0; + int32_t embeddingSize = 0; + int32_t embeddingSizeV = 0; + int32_t numBlocks = 0; + int32_t blockSize = 0; + int32_t maxNumBlocksPerQuery = 0; + float tor = 0; + int32_t kvHeads = 0; + int32_t maxPromptLen = 0; + int32_t batchStride = 0; + int32_t headStride = 0; + TilingKeyType type = TilingKeyType::TILING_HALF_DATA; + int32_t batch = 0; + int32_t isMaskSquare = 0; + int32_t *batchRunStatus{nullptr}; + int32_t *kvSeqLen{nullptr}; + int32_t modCoef{-1}; + int32_t divCoef{1}; + int32_t *qSeqLen{nullptr}; + int32_t qHeadOriginal = 0; + int32_t compressHead = 0; + int32_t tBlockAlign = 16; // L1 tile alignment: 16 for fp16, 32 for int8 + int32_t dataShapeType = 0; }; using AddrOffsets = struct AddressOffsetInfo { - uint64_t addrQSeqOffset = 0; - uint64_t addrOSeqOffset = 0; - uint64_t addrOFdSeqOffset = 0; - uint64_t addrLSeqOffset = 0; + uint64_t addrQSeqOffset = 0; + uint64_t addrOSeqOffset = 0; + uint64_t addrOFdSeqOffset = 0; + uint64_t addrLSeqOffset = 0; }; } // namespace AtbOps From cff9da4d909974be009050de2a563c28b309706b Mon Sep 17 00:00:00 2001 From: MirkoDeVita98 Date: Mon, 29 Jun 2026 15:11:18 +0200 Subject: [PATCH 08/11] improved sync now 950 GB/s --- .../jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp | 6 ++++-- examples/jit_cpp/paged_attention_highperf/pa_tiling.py | 4 ++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp b/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp index b7e00b9a..df74cbe3 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp +++ b/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp @@ -603,7 +603,8 @@ AICORE inline void RunPtoPagedAttentionCubePipelineSplitKV( TASSIGN(qMatTile, kQCacheBase + static_cast(headGroup) * kQGroupBytes); TLOAD(qMatTile, qGlobal); - pipe_barrier(PIPE_ALL); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); } } @@ -702,7 +703,8 @@ AICORE inline void RunPtoPagedAttentionCubePipelineSplitKV( probBase + static_cast(slot) * probSlotBytes + static_cast(headGroup) * probGroupBytes)); TLOAD(pMatTile, probGlobal); - pipe_barrier(PIPE_ALL); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); for (int32_t headInGroupBase = 0; headInGroupBase < kHeadGroup; headInGroupBase += headsPerKv) { const int32_t baseHeadLocal = groupHeadBase + headInGroupBase; diff --git a/examples/jit_cpp/paged_attention_highperf/pa_tiling.py b/examples/jit_cpp/paged_attention_highperf/pa_tiling.py index 726d863e..433df17f 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_tiling.py +++ b/examples/jit_cpp/paged_attention_highperf/pa_tiling.py @@ -226,6 +226,10 @@ def _split_core_bns_nd( if is_long_seq: kv_block_per_core = _ceil_div(kv_seq_block_num, block_dim) + if block_size == KV_SEQLEN_SLICE and kv_seq_block_num <= 64: + kv_block_per_core = max(kv_block_per_core, 4) + elif block_size == KV_SEQLEN_SLICE and kv_seq_block_num <= 128: + kv_block_per_core = max(kv_block_per_core, 16) else: core_per_batch = _ceil_div(block_dim, decoder_batch) kv_block_per_core = _ceil_div(kv_seq_block_num, core_per_batch) From 4477d99b37938a010bd78a8b369172e86443c3c2 Mon Sep 17 00:00:00 2001 From: MirkoDeVita98 Date: Mon, 29 Jun 2026 15:56:04 +0200 Subject: [PATCH 09/11] optimized combine path --- .../jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp b/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp index df74cbe3..5dd8279d 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp +++ b/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp @@ -778,7 +778,6 @@ AICORE inline void RunPtoPagedAttentionVecPipelineSplitKV( } const bool activeSubBlock = subBlockId < 2; - const bool combineSubBlock = subBlockId == 0; const int32_t formerHeadSplit = ctx.formerHeadSplit > 0 ? ctx.formerHeadSplit : 1; const int32_t maxHeadGroups = (formerHeadSplit + kHeadGroup - 1) / kHeadGroup; @@ -1132,15 +1131,16 @@ AICORE inline void RunPtoPagedAttentionVecPipelineSplitKV( ffts_cross_core_sync(PIPE_MTE3, PtoPaGetFftsMsg(0x0, PTO_PA_REDUCE_READY_DECODER)); wait_flag_dev(PTO_PA_REDUCE_READY_DECODER); - if (!combineSubBlock) { + if (!activeSubBlock) { pipe_barrier(PIPE_ALL); return; } const int32_t effectiveBatch = ctx.decoderBatch > 0 ? ctx.decoderBatch : ctx.batch; const int64_t totalRows = static_cast(effectiveBatch) * ctx.numHeads; - const int64_t combineWorkerIdx = workerIdx; - const int64_t combineWorkerNum = workerNum; + const int64_t combineWorkerIdx = + workerIdx * 2 + static_cast(subBlockId); + const int64_t combineWorkerNum = workerNum * 2; for (int64_t row = combineWorkerIdx; row < totalRows; row += combineWorkerNum) { const int32_t head = static_cast(row % ctx.numHeads); From 2edb8cff382cf360cfcadf03d963a1e65111802b Mon Sep 17 00:00:00 2001 From: MirkoDeVita98 Date: Mon, 29 Jun 2026 16:05:36 +0200 Subject: [PATCH 10/11] vector convert in UB, then aligned GM copy --- .../paged_attention_highperf/pa_kernel_impl.hpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp b/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp index 5dd8279d..db1fec66 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp +++ b/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp @@ -1228,11 +1228,15 @@ AICORE inline void RunPtoPagedAttentionVecPipelineSplitKV( } const int64_t outBase = (static_cast(batchIndex) * ctx.numHeads + head) * ctx.headDim; - OutputGlobal outGlobal(reinterpret_cast<__gm__ half *>(oGm) + outBase); - TCVT(outHalfTile, weightedTile, RoundMode::CAST_RINT); + PtoPaConvF32ToF16(outHalfTile, weightedTile, 2); pipe_barrier(PIPE_V); - TSTORE(outGlobal, outHalfTile); - pipe_barrier(PIPE_ALL); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + copy_ubuf_to_gm_align_b16( + reinterpret_cast<__gm__ half *>(oGm) + outBase, + reinterpret_cast<__ubuf__ half *>(outHalfTile.data()), 0, 1, + static_cast(ctx.headDim * sizeof(half)), 0, 0, 0, 0); + pipe_barrier(PIPE_MTE3); } pipe_barrier(PIPE_ALL); } From a8da83c7edae24253a4972cd70a7b21d2bb9a48a Mon Sep 17 00:00:00 2001 From: MirkoDeVita98 Date: Mon, 29 Jun 2026 16:39:56 +0200 Subject: [PATCH 11/11] 980GB/s version --- .../paged_attention_highperf/pa_kernel_impl.hpp | 15 +++++++++++---- .../jit_cpp/paged_attention_highperf/pa_tiling.py | 4 ---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp b/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp index db1fec66..acb8351c 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp +++ b/examples/jit_cpp/paged_attention_highperf/pa_kernel_impl.hpp @@ -545,6 +545,7 @@ AICORE inline void RunPtoPagedAttentionCubePipelineSplitKV( const int64_t processRounds = (processNum + workerNum - 1) / workerNum; const int32_t stageTileCount = (ctx.kvSplitPerCore + kTileTokens - 1) / kTileTokens; + bool qkSlotNeedsFree[2] = {false, false}; for (int64_t processRound = 0; processRound < processRounds; ++processRound) { const int64_t process = processRound * workerNum + workerIdx; bool validProcess = process < processNum; @@ -621,6 +622,10 @@ AICORE inline void RunPtoPagedAttentionCubePipelineSplitKV( } const int32_t tile = tilePairBase + static_cast(stage); const uint8_t slot = static_cast(stage); + if (qkSlotNeedsFree[slot]) { + wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_PV_FREE, slot)); + qkSlotNeedsFree[slot] = false; + } const bool activeTileStage = (stage == 0) ? activeTile0 : activeTile1; for (int32_t headGroup = 0; headGroup < maxHeadGroups; ++headGroup) { const int32_t groupHeadBase = headGroup * kHeadGroup; @@ -746,11 +751,13 @@ AICORE inline void RunPtoPagedAttentionCubePipelineSplitKV( } DdrFenceBeforePtoAivReduce(); PtoPaSignalFromCube(PtoPaSlotFlag(PTO_PA_RAW_PV_READY, slot)); + qkSlotNeedsFree[slot] = true; } - wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_PV_FREE, 0)); - if (hasStage2) { - wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_PV_FREE, 1)); - } + } + } + for (uint8_t slot = 0; slot < 2; ++slot) { + if (qkSlotNeedsFree[slot]) { + wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_PV_FREE, slot)); } } pipe_barrier(PIPE_ALL); diff --git a/examples/jit_cpp/paged_attention_highperf/pa_tiling.py b/examples/jit_cpp/paged_attention_highperf/pa_tiling.py index 433df17f..726d863e 100644 --- a/examples/jit_cpp/paged_attention_highperf/pa_tiling.py +++ b/examples/jit_cpp/paged_attention_highperf/pa_tiling.py @@ -226,10 +226,6 @@ def _split_core_bns_nd( if is_long_seq: kv_block_per_core = _ceil_div(kv_seq_block_num, block_dim) - if block_size == KV_SEQLEN_SLICE and kv_seq_block_num <= 64: - kv_block_per_core = max(kv_block_per_core, 4) - elif block_size == KV_SEQLEN_SLICE and kv_seq_block_num <= 128: - kv_block_per_core = max(kv_block_per_core, 16) else: core_per_batch = _ceil_div(block_dim, decoder_batch) kv_block_per_core = _ceil_div(kv_seq_block_num, core_per_batch)