Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/jit_cpp/paged_attention_highperf/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
outputs/
*.so
pa_highperf_jit.so
pa_highperf_jit_bench.csv
__pycache__/
.pytest_cache/
35 changes: 35 additions & 0 deletions examples/jit_cpp/paged_attention_highperf/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Paged Attention HighPerf

JIT demo for a PTO-ISA paged-attention decode kernel using PyTorch/NPU tensors.

The example includes:

| File | Purpose |
|---|---|
| pa_kernel.cpp | host-callable JIT entry point |
| pa_entry.hpp | AIC/AIV dispatch wrapper |
| pa_kernel_impl.hpp | PTO kernel implementation |
| pa_tiling_struct.hpp | tiling type definitions |
| pa_tiling.py | Python tiling/workspace construction |
| pa_compile_and_run.py | correctness smoke test |
| pa_benchmark.py | benchmark driver |
| jit_util_pa.py | JIT compile and ctypes wrapper |

## Requirements

Set `PTO_LIB_PATH` to the PTO-ISA root (the directory that contains `include/`) and `ASCEND_TOOLKIT_HOME` to the CANN toolkit path if they are not already set in the environment:

export PTO_LIB_PATH=/path/to/pto-isa
export ASCEND_TOOLKIT_HOME=/usr/local/Ascend/cann-9.0.0

## Run

cd examples/jit_cpp/paged_attention_highperf
python3 pa_compile_and_run.py
python3 pa_benchmark.py --device npu:0 --shape b=8,s=4096 --check --warmup 1 --iters 1

The full benchmark sweep can be run with:

python3 pa_benchmark.py --device npu:0

Use --check for correctness validation against the Python/PyTorch reference. For very large shapes, the reference can dominate runtime and memory use.
158 changes: 158 additions & 0 deletions examples/jit_cpp/paged_attention_highperf/jit_util_pa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
#!/usr/bin/python3
# coding=utf-8
# --------------------------------------------------------------------------------
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the
# terms and conditions of CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance
# with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY,
# OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# --------------------------------------------------------------------------------

import ctypes
import os
import subprocess
from pathlib import Path

import torch

# Root of the CANN toolkit installation. Read from the ASCEND_TOOLKIT_HOME
# environment variable, with a fixed fallback path when it is unset. Several
# of its subdirectories are added to the compiler include path.
ASCEND_TOOLKIT_HOME = os.environ.get(
    "ASCEND_TOOLKIT_HOME", "/usr/local/Ascend/cann-9.0.0"
)
# Directory whose "include" subdirectory is added to the compiler include
# path. Read from PTO_LIB_PATH; the fallback is the directory three levels
# above the directory that contains this file.
PTO_LIB_PATH = os.environ.get("PTO_LIB_PATH", str(Path(__file__).resolve().parents[3]))


def torch_to_ctypes(t: torch.Tensor) -> ctypes.c_void_p:
    """Wrap the address reported by ``t.data_ptr()`` in a ``ctypes.c_void_p``.

    Only the integer address is captured; the returned object does not keep
    ``t`` alive, so the caller must hold a reference for as long as the
    pointer is in use.
    """
    address = t.data_ptr()
    return ctypes.c_void_p(address)


def _npu_arch_flag() -> str:
    """Return the value to pass to the compiler's ``--npu-arch`` option.

    The value comes from the ``NPU_ARCH`` environment variable with
    surrounding whitespace removed.  When the variable is unset, empty, or
    contains only whitespace, ``"dav-2201"`` is returned, so the compile
    command never ends up with an empty ``--npu-arch=`` flag.
    """
    arch = os.environ.get("NPU_ARCH", "").strip()
    return arch or "dav-2201"


def compile_paged_attention(
    kernel_cpp: str, verbose: bool = False, timeout: int = 300
) -> str:
    """Compile ``kernel_cpp`` into a shared library with the ``bisheng`` compiler.

    The output is always named ``pa_highperf_jit.so`` and is placed in the
    same directory as ``kernel_cpp``.

    Args:
        kernel_cpp: Path of the kernel source file to compile.
        verbose: When true, print the full compile command before running it.
        timeout: Maximum number of seconds the compiler may run.

    Returns:
        The path of the produced shared library.

    Raises:
        subprocess.CalledProcessError: The compiler exited with a non-zero status.
        subprocess.TimeoutExpired: The compiler did not finish within ``timeout``.
    """
    kernel_dir = os.path.dirname(kernel_cpp)
    lib_path = os.path.join(kernel_dir, "pa_highperf_jit.so")
    # A bare file name has an empty dirname.  A lone "-I" would then consume
    # the next argument as its directory and silently drop that include path,
    # so fall back to the current directory instead.
    example_dir = kernel_dir or "."
    flags = [
        "-fPIC",
        "-shared",
        "-xcce",
        f"--npu-arch={_npu_arch_flag()}",
        "-O2",
        "-std=c++17",
        "-Wno-ignored-attributes",
        "-Wno-macro-redefined",
        f"-I{PTO_LIB_PATH}/include",
        f"-I{example_dir}",
        f"-I{ASCEND_TOOLKIT_HOME}/include",
        f"-I{ASCEND_TOOLKIT_HOME}/pkg_inc/runtime",
        f"-I{ASCEND_TOOLKIT_HOME}/pkg_inc",
        f"-I{ASCEND_TOOLKIT_HOME}/pkg_inc/profiling",
    ]
    cmd = ["bisheng", *flags, kernel_cpp, "-o", lib_path]
    if verbose:
        print("compile command:\n", " ".join(cmd))
    subprocess.run(cmd, check=True, timeout=timeout)
    return lib_path


def load_paged_attention_lib(lib_path: str, check_type: bool = True):
    """Load the compiled library and return a Python launcher for ``call_kernel``.

    Args:
        lib_path: Path of the shared library; it is made absolute before loading.
        check_type: When true, declare the ctypes signature of ``call_kernel``.

    Returns:
        A function ``paged_attention(q, k, v, block_table, workspace_sizes,
        tiling, stream_ptr=..., block_dim=24)``.  It allocates an output with
        ``torch.empty_like(q)``, launches the kernel, and returns that output.
        ``stream_ptr`` defaults to the stream that was current when this
        loader ran.  Workspace buffers are cached inside the returned function
        and rebuilt only when the device, the workspace sizes, or the tiling
        values change.
    """
    lib = ctypes.CDLL(os.path.abspath(lib_path))
    if check_type:
        # Pointer arguments, in order: stream, q, k, v, block table, out, s, p,
        # o_tmp, go, o_core_tmp, l, gm_k16, gm_v16, tiling, null.  The final
        # argument is the block dim.
        pointer_args = [ctypes.c_void_p] * 16
        lib.call_kernel.argtypes = [*pointer_args, ctypes.c_uint32]
        lib.call_kernel.restype = None

    cache = {}
    # pylint: disable-next=protected-access
    default_stream_ptr = torch.npu.current_stream()._as_parameter_
    # Buffers that are cleared right after allocation; the rest stay uninitialised.
    zero_filled = {"s", "p", "o_tmp", "o_core_tmp", "l"}
    # Cached buffers passed to the kernel after q/k/v/block_table/out, in call order.
    kernel_buffers = (
        "s",
        "p",
        "o_tmp",
        "go",
        "o_core_tmp",
        "l",
        "k16",
        "v16",
        "tiling",
        "null",
    )

    def _ensure_workspace(device, workspace_sizes, tiling):
        """Rebuild the cached buffers unless device, sizes and tiling are unchanged."""
        tiling_values = tuple(int(x) for x in tiling.detach().cpu().tolist())
        size_items = tuple(
            sorted((name, int(size)) for name, size in workspace_sizes.items())
        )
        key = (str(device), size_items, tiling_values)
        if cache.get("key") != key:
            cache.clear()
            cache["key"] = key
            for name, size in workspace_sizes.items():
                buf = torch.empty((int(size),), device=device, dtype=torch.uint8)
                cache[name] = buf
                if name in zero_filled:
                    buf.zero_()
            cache["null"] = torch.zeros((1,), device=device, dtype=torch.uint8)
            cache["tiling"] = tiling.to(device=device, dtype=torch.int32)

    def paged_attention(
        q,
        k,
        v,
        block_table,
        workspace_sizes,
        tiling,
        stream_ptr=default_stream_ptr,
        block_dim: int = 24,
    ):
        _ensure_workspace(q.device, workspace_sizes, tiling)
        out = torch.empty_like(q)
        # Keep every tensor referenced in a list so the raw addresses stay
        # valid for the duration of the call.
        tensors = [q, k, v, block_table, out]
        tensors.extend(cache[name] for name in kernel_buffers)
        lib.call_kernel(stream_ptr, *map(torch_to_ctypes, tensors), block_dim)
        return out

    return paged_attention


def jit_compile_paged_attention(
    verbose: bool = False, clean_up: bool = True, kernel_cpp: str = "pa_kernel.cpp"
):
    """Compile the kernel next to this file, load it, and return the launcher.

    Args:
        verbose: Forwarded to ``compile_paged_attention``.
        clean_up: When true, delete the compiled shared library once loading
            has been attempted, whether it succeeded or raised.
        kernel_cpp: Kernel source path, resolved relative to this file's directory.

    Returns:
        The callable produced by ``load_paged_attention_lib``.
    """
    kernel_path = str((Path(__file__).resolve().parent / kernel_cpp).resolve())
    lib_path = compile_paged_attention(kernel_path, verbose=verbose)
    try:
        return load_paged_attention_lib(lib_path)
    finally:
        # Run the cleanup in ``finally`` so a failed load does not leave the
        # build artifact behind when the caller asked for it to be removed.
        if clean_up:
            try:
                os.remove(lib_path)
            except OSError:
                # Best effort: a file that cannot be removed must not mask
                # the result (or the error) of loading.
                pass
197 changes: 197 additions & 0 deletions examples/jit_cpp/paged_attention_highperf/pa_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
#!/usr/bin/python3
# coding=utf-8
import argparse
import csv
import gc
import time

import torch

from jit_util_pa import jit_compile_paged_attention
from pa_compile_and_run import (
PaShape,
golden_attention,
make_inputs,
make_launch_config,
)

# Default number of timed calls per shape (default of the --iters option).
NUM_ITERATIONS = 50
# Default number of untimed warm-up calls per shape (default of --warmup).
WARMUP = 10
# Default pause, in seconds, between consecutive shapes (default of --run-delay).
RUN_DELAY_SECONDS = 2.0
# Batch sizes and sequence lengths whose cross product forms the default sweep.
BATCHES = [1, 2, 4, 8, 32, 64]
SEQ_LENS = [128, 512, 4096, 8192, 16384, 32768, 65536, 131072]
# Shapes benchmarked when no --shape option is given: every (batch, seq_len)
# pair, with all sequence lengths visited for one batch size before the next.
DEFAULT_SHAPES = [
    PaShape(batch=batch, seq_len=seq_len) for batch in BATCHES for seq_len in SEQ_LENS
]


def paged_attention_flops(shape: PaShape):
    """Return the operation count the benchmark attributes to one kernel call.

    The total is the sum of three terms:

    * ``4 * batch * num_heads * seq_len * head_dim`` for the QK and PV products,
    * ``batch * num_heads * seq_len`` for scaling,
    * ``5 * seq_len - 2`` per (batch, head) row for the softmax.
    """
    seq_len = shape.seq_len
    rows = shape.batch * shape.num_heads
    scores = rows * seq_len
    matmul_ops = 4 * scores * shape.head_dim
    # (s - 1) + s + s + (s - 1) + s collapses to 5 * s - 2 per row.
    softmax_ops = rows * (5 * seq_len - 2)
    return matmul_ops + scores + softmax_ops


def tensor_bytes(shape: PaShape):
    """Return the byte count the benchmark uses for its bandwidth figure.

    Counts q, k, v and out at 2 bytes per element, plus a block table with
    one 4-byte entry per block, where each batch entry needs
    ``ceil(seq_len / block_size)`` blocks.
    """
    dtype_bytes = 2
    table_entry_bytes = 4
    # q and out share one element count; k and v share another.
    qo_elems = shape.batch * shape.num_heads * shape.head_dim
    kv_elems = shape.batch * shape.seq_len * shape.num_kv_heads * shape.head_dim
    blocks_per_batch = (shape.seq_len + shape.block_size - 1) // shape.block_size
    table_bytes = shape.batch * blocks_per_batch * table_entry_bytes
    return 2 * qo_elems * dtype_bytes + 2 * kv_elems * dtype_bytes + table_bytes


def tflops(flops, ms):
    """Convert an operation count and a duration in milliseconds to TFLOP/s."""
    seconds = ms * 1e-3
    ops_per_second = flops / seconds
    return ops_per_second / 1e12


def tb_per_second(num_bytes, ms):
    """Convert a byte count and a duration in milliseconds to TB/s (1e12 bytes)."""
    elapsed_seconds = ms * 1e-3
    bytes_per_second = num_bytes / elapsed_seconds
    return bytes_per_second / 1e12


def time_npu(fn, iters, warmup):
    """Return the average time per call of ``fn``, measured with NPU events.

    ``fn`` is first called ``warmup`` times without timing.  After a device
    synchronisation, ``iters`` calls are bracketed by two timing events, and
    the elapsed time between them is divided by ``iters``.
    """
    for _ in range(warmup):
        fn()
    torch.npu.synchronize()

    start_evt = torch.npu.Event(enable_timing=True)
    stop_evt = torch.npu.Event(enable_timing=True)

    start_evt.record()
    for _ in range(iters):
        fn()
    stop_evt.record()

    # Wait for the recorded work to finish before reading the elapsed time.
    torch.npu.synchronize()
    total = start_evt.elapsed_time(stop_evt)
    return total / iters


def parse_shape(text):
    """Parse a ``key=value,key=value`` shape override into a ``PaShape``.

    Recognised keys and the default used when none of them is given:

    * ``b`` / ``batch``                          -> batch (1)
    * ``s`` / ``seq`` / ``seq_len``              -> seq_len (128)
    * ``h`` / ``heads`` / ``num_heads``          -> num_heads (32)
    * ``kv`` / ``kv_heads`` / ``num_kv_heads``   -> num_kv_heads (8)
    * ``d`` / ``head_dim``                       -> head_dim (128)
    * ``bs`` / ``block_size``                    -> block_size (128)
    * ``bd`` / ``block_dim``                     -> block_dim (24)

    When several aliases of one field are present, the one listed first wins.

    Raises:
        ValueError: An item is not of the form ``key=value``, a value is not
            an integer, or a key is not recognised.  Unknown keys are
            rejected rather than ignored, so a typo cannot silently
            benchmark the default shape.
    """
    aliases = (
        ("batch", ("b", "batch"), 1),
        ("seq_len", ("s", "seq", "seq_len"), 128),
        ("num_heads", ("h", "heads", "num_heads"), 32),
        ("num_kv_heads", ("kv", "kv_heads", "num_kv_heads"), 8),
        ("head_dim", ("d", "head_dim"), 128),
        ("block_size", ("bs", "block_size"), 128),
        ("block_dim", ("bd", "block_dim"), 24),
    )
    known = {key for _, keys, _ in aliases for key in keys}

    values = {}
    for item in text.split(","):
        key, sep, value = item.partition("=")
        key = key.strip()
        if not sep or not key:
            raise ValueError(f"invalid shape item {item!r}: expected key=value")
        if key not in known:
            raise ValueError(
                f"unknown shape key {key!r}: expected one of {sorted(known)}"
            )
        values[key] = int(value)

    kwargs = {
        field: next((values[key] for key in keys if key in values), default)
        for field, keys, default in aliases
    }
    return PaShape(**kwargs)


def run_shape(pa, shape, device, iters, warmup, check):
    """Benchmark one shape and return its result row as a dict of CSV fields.

    Inputs and launch configuration are built with ``make_inputs`` and
    ``make_launch_config``.  When ``check`` is true, one kernel run is first
    compared against ``golden_attention`` with ``torch.testing.assert_close``
    (``rtol=5e-3``, ``atol=2e-2``).  The kernel is then timed with
    ``time_npu``.
    """
    q, k, v, block_table = make_inputs(shape, device, deterministic=check)
    ws, tiling, _ = make_launch_config(shape)

    def _launch():
        return pa(q, k, v, block_table, ws, tiling, block_dim=shape.block_dim)

    if check:
        out = _launch()
        torch.npu.synchronize()
        actual = out.float()
        expected = golden_attention(q, k, v, block_table, shape)
        torch.testing.assert_close(actual, expected, rtol=5e-3, atol=2e-2)

    ms = time_npu(_launch, iters, warmup)
    perf = tflops(paged_attention_flops(shape), ms)
    bandwidth = tb_per_second(tensor_bytes(shape), ms)
    norm_perf = perf * shape.block_dim

    row = {}
    row["shape"] = shape.name
    row["batch"] = shape.batch
    row["seq_len"] = shape.seq_len
    row["block_dim"] = shape.block_dim
    # time_npu's result is treated as milliseconds; report microseconds.
    row["jit_time_us"] = f"{ms * 1000:.3f}"
    row["jit_tflops"] = f"{perf:.6f}"
    row["jit_tflops_normalized"] = f"{norm_perf:.6f}"
    row["jit_bandwidth_tb_s"] = f"{bandwidth:.6f}"
    return row


def main():
    """Command-line entry point: benchmark each shape and write the rows to CSV.

    Parses the options, selects the NPU device, JIT-compiles the kernel once,
    runs ``run_shape`` for every requested shape (or ``DEFAULT_SHAPES`` when
    none is given), prints one summary line per shape, and finally writes all
    rows to the file named by ``--csv``.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--csv", default="pa_highperf_jit_bench.csv")
    add("--iters", type=int, default=NUM_ITERATIONS)
    add("--warmup", type=int, default=WARMUP)
    add(
        "--run-delay",
        type=float,
        default=RUN_DELAY_SECONDS,
        help="Seconds to wait between benchmark shapes; set to 0 to disable.",
    )
    add("--device", default="npu:0")
    add(
        "--shape",
        action="append",
        help="Shape override, e.g. b=2,s=8192 or batch=4,seq=512",
    )
    add(
        "--check",
        action="store_true",
        help="Run correctness check before timing each shape.",
    )
    add("--no-check", action="store_true", help=argparse.SUPPRESS)
    args = parser.parse_args()

    torch.npu.set_device(args.device)
    if args.shape:
        shapes = [parse_shape(item) for item in args.shape]
    else:
        shapes = DEFAULT_SHAPES
    pa = jit_compile_paged_attention(verbose=False)

    # The hidden --no-check flag overrides --check.
    check = args.check and not args.no_check
    last_idx = len(shapes) - 1
    rows = []
    for idx, shape in enumerate(shapes):
        row = run_shape(pa, shape, args.device, args.iters, args.warmup, check)
        rows.append(row)
        print(
            f"paged_attention_highperf_jit {row['shape']}: {row['jit_time_us']} us/iter, "
            f"{row['jit_tflops']} TFLOPS logical, {row['jit_tflops_normalized']} TFLOPS normalized, "
            f"{row['jit_bandwidth_tb_s']} TB/s, block_dim={row['block_dim']}"
        )
        # Release per-shape memory before moving on to the next shape.
        torch.npu.synchronize()
        gc.collect()
        torch.npu.empty_cache()
        if args.run_delay > 0 and idx < last_idx:
            time.sleep(args.run_delay)

    fieldnames = [
        "shape",
        "batch",
        "seq_len",
        "block_dim",
        "jit_time_us",
        "jit_tflops",
        "jit_tflops_normalized",
        "jit_bandwidth_tb_s",
    ]
    with open(args.csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


if __name__ == "__main__":
    main()
Loading
Loading