// PTO paged-attention highperf expanded-to-CCE audit
//
// Audit artifact, not a compiling kernel. This file inlines the current PTO
// implementation into the CCE primitives that the PTO API maps to on a2a3.
// Use it next to simpler's pa_kernel.cce to compare semantic differences.
//
// Sources used for this expansion:
// - kernels/manual/a2a3/paged_attention_highperf/pa_kernel_impl.hpp
// - include/pto/common/pto_instr.hpp
// - include/pto/common/arch/memory/tload_common.hpp
// - include/pto/npu/a2a3/TLoad.hpp
// - include/pto/npu/a2a3/TStore.hpp
// - include/pto/npu/a2a3/TMatmul.hpp
// ============================================================================
// 0. PTO API Lowering Cheatsheet
// ============================================================================
// Reference listing: for each PTO API call used by the kernel, shows the CCE
// primitive call shape this audit expands it to. Argument names are
// placeholders rather than declared variables (this file is an audit artifact,
// not a compiling kernel), so read each call as a template for the concrete
// expansions in sections 1-5 below.
void PTO_API_LOWERING_CHEATSHEET()
{
// TLOAD(MatTile NZ, GlobalTensor ND half) -> TLoadGm2L1Nd2nz.
copy_gm_to_cbuf_multi_nd2nz_b16(
dst_cbuf, src_gm,
0, // sid
1, // ndNum in this kernel's Q/K/V shapes
nValue,
dValue,
0, // srcNdMatrixStride
srcDValue,
dstNzC0Stride,
1, // dstNzMatrixStride
0); // dstNzC0StrideTail
// TLOAD(VecTile ND, GlobalTensor ND float/half) -> TLoadGm2ubNd2nd.
// Two variants are listed: the _b32 form for float and the _b16 form for half.
copy_gm_to_ubuf_align_b32(dst_ub, src_gm, 0, nBurst,
lenBurstBytes, 0, ubPad, gmGapBytes, ubGapBlocks);
copy_gm_to_ubuf_align_b16(dst_ub, src_gm, 0, nBurst,
lenBurstBytes, 0, ubPad, gmGapBytes, ubGapBlocks);
// PtoPaLoadCbufToCaRaw.
load_cbuf_to_ca(dst_ca, src_cbuf + srcElementOffset,
0, repeatTimes, srcStride, 0, 0, false, false,
addr_cal_mode_t(0));
// PtoPaLoadCbufToCbRaw.
// Note: this call takes one fewer bool argument than load_cbuf_to_ca above.
load_cbuf_to_cb(dst_cb, src_cbuf + srcElementOffset,
0, repeatTimes, srcStride, 0, 0, false,
addr_cal_mode_t(0));
// PtoPaLoadCbufToCbTranspose128Raw.
// Issues 8 transpose loads; each iteration advances the destination by
// 128 * 16 elements and the source by 16 * 16 elements.
for (uint32_t idx = 0; idx < 8; ++idx) {
load_cbuf_to_cb_transpose(
dst_cb + idx * 128 * 16,
src_cbuf + idx * 16 * 16,
0, 8, 8, 0, addr_cal_mode_t(0), 0);
}
// TMATMUL(TileAcc, TileLeft, TileRight) -> TMatmul -> mad.
// PTO TMatmul forces m=16 when m==1 for non-GEMV, but this kernel uses m=16.
mad(dst_cc, a_ca, b_cb, m, k, n,
0, // AccPhase::Unspecified
false, // kDirectionAlign for half x half
false, // cmatrixSource
true); // cmatrixInitVal
// TSTORE(GlobalTensor ND float, TileAcc float) -> TStoreAccNz2nd.
// The actual implementation packs xm/xt registers then calls the raw form.
copy_matrix_cc_to_gm(dst_gm, src_cc, xmReg, xtReg);
// TSTORE(GlobalTensor ND half/float, VecTile) -> TStoreUb2gmNd2nd.
// Argument order differs from the GM->UB loads above: here ubGapBlocks
// precedes gmGapBytes.
copy_ubuf_to_gm_align_b16(dst_gm, src_ub, 0, nBurst,
lenBurstBytes, 0, 0, ubGapBlocks, gmGapBytes);
copy_ubuf_to_gm_align_b32(dst_gm, src_ub, 0, nBurst,
lenBurstBytes, 0, 0, ubGapBlocks, gmGapBytes);
}
// ============================================================================
// 1. CUBE QK Hot Loop Expanded from PTO
// ============================================================================
// Expanded call sequence for the cube-side QK step, in listing order:
//   1. Q: GM -> cbuf (nd2nz), full barrier, then cbuf -> ca.
//   2. K: GM -> cbuf (nd2nz), three set/wait flag pairs, then cbuf -> cb.
//   3. mad(acc, Q, K) with m=16, k=128, n=128.
//   4. acc -> GM score buffer, dsb + full barrier, then an FFTS cross-core
//      signal carrying PTO_PA_RAW_QK_READY + slot.
// All identifiers (qGm, qBase, slot, ...) come from the surrounding kernel and
// are not declared in this audit file.
void PTO_CUBE_QK_EXPANDED_TO_CCE()
{
// Original PTO shape in pa_kernel_impl.hpp:
// QGlobal qGlobal(qGm + qBase);
// TLOAD(qMatTile, qGlobal);
// PtoPaLoadCbufToCaRaw(qLeftTile, qMatTile,
// headInGroupBase * 16, kHeadDim / 16, 1);
// KGlobal kGlobal(kGm + kvBase);
// TLOAD(kMatTile, kGlobal);
// PtoPaLoadCbufToCbRaw(rightTile, kMatTile, 0,
// (kHeadDim * kTileTokens) / 256, 1);
// TMATMUL(accTile, qLeftTile, rightTile);
// TSTORE(scoreGlobal, accTile);
// Current split-KV PTO QGlobal is ND half [1,1,1,16,128].
// qMatTile is an NZ MatTile holding the 16-head group.
copy_gm_to_cbuf_multi_nd2nz_b16(
qMatTile_cbuf,
reinterpret_cast<__gm__ half *>(qGm) + qBase,
0, 1,
16, // nValue: full 16-head group
128, // dValue
0,
128, // srcDValue
16, // dstNzC0Stride after RoundUp<16>(16)
1,
0);
// Full barrier between the Q load into cbuf and the cbuf -> ca move below.
pipe_barrier(PIPE_ALL);
load_cbuf_to_ca(
qLeftTile_ca,
qMatTile_cbuf + headInGroupBase * 16,
0,
8, // kHeadDim / 16
1,
0,
0,
false,
false,
addr_cal_mode_t(0));
// KGlobal is DN half [1,1,1,128,128] with stride4 = 8 * headDim.
// For the current fixed shape this reaches TLoadGm2L1Dn2zn.
// NOTE(review): the comment above names the Dn2zn path, but the call shown
// is the same nd2nz primitive used for Q — confirm which primitive the DN
// layout really lowers to.
copy_gm_to_cbuf_multi_nd2nz_b16(
kMatTile_cbuf,
reinterpret_cast<__gm__ half *>(kGm) + kvBase,
0,
1,
128, // nValue: tile tokens
128, // dValue: head dim
0,
8 * 128, // srcDValue / physical kv-head stride
128, // dstNzC0Stride
1,
0);
// Each flag below is set and immediately waited on with the same event id,
// so every pair acts as a point-to-point ordering between the two named
// pipes at this position in the listing.
set_flag(PIPE_FIX, PIPE_M, EVENT_ID1);
wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1);
set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
load_cbuf_to_cb(
rightTile_cb,
kMatTile_cbuf,
0,
64, // (128 * 128) / 256
1,
0,
0,
false,
addr_cal_mode_t(0));
set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
// m=16, k=128, n=128; trailing arguments follow the cheatsheet order
// (AccPhase, kDirectionAlign, cmatrixSource, cmatrixInitVal).
mad(accTile_cc, qLeftTile_ca, rightTile_cb,
16, 128, 128, 0, false, false, true);
set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
set_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
// ScoreGlobal is ND float; TSTORE(Acc2GM) lowers to copy_matrix_cc_to_gm.
// The destination is formed from byte offsets (scoreGroupBytes,
// scoreHeadBytes) before the cast to a float pointer.
copy_matrix_cc_to_gm(
reinterpret_cast<__gm__ float *>(scoreBase + slot * scoreGroupBytes
+ headInGroupBase * scoreHeadBytes),
accTile_cc,
xmReg_for_ND_float_16x128,
xtReg_for_ND_float_16x128);
// Data barrier plus full pipe barrier ahead of the cross-core signal.
dsb(DSB_DDR);
pipe_barrier(PIPE_ALL);
// The flag id (PTO_PA_RAW_QK_READY + slot) is packed into bits 8 and up of
// the config word; the low fields are the constants 1 and (2 << 4).
ffts_cross_core_sync(PIPE_FIX,
1 | (2 << 4) | ((PTO_PA_RAW_QK_READY + slot) << 8));
// Simpler CCE difference to check:
// - It uses ProcessQK with explicit l0/l1 ping-pong state.
// - It signals QK_READY_DECODER and QK_READY_STAGE2, not slot flags.
// - It calls DdrBarrierBeforeFfts(), which includes pipe_barrier(PIPE_ALL).
}
// ============================================================================
// 2. AIV Softmax Stage1 Expanded from PTO
// ============================================================================
// Expanded call sequence for the vector-side softmax stage 1, in listing
// order: wait for the QK-ready flag, load 4 score rows from GM, scale them,
// take the row max, merge it with the running max, compute the old-state
// rescale factor, subtract the new max and exponentiate, row-sum, update the
// running sum, convert to half and store the probabilities two rows at a time,
// then signal PTO_PA_RAW_P_READY + slot.
// All identifiers come from the surrounding kernel and are not declared here.
void PTO_AIV_SOFTMAX_STAGE1_EXPANDED_TO_CCE()
{
// Original PTO shape:
// wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_QK_READY, slot));
// TLOAD(scoreRowsTile, scoreGlobal);
// TMULS(scoreRowsTile, scoreRowsTile, ctx.scale);
// TROWMAX(rowMaxRowsTile, scoreRowsTile, scoreRowsWorkTile);
// TMAX(newMaxRowsView, rowMaxRowsView, maxStateRowsView);
// TSUB(oldScaleRowsView, maxStateRowsView, newMaxRowsView);
// TEXP(oldScaleRowsView, oldScaleRowsView);
// TROWEXPANDSUB(scoreRowsWorkTile, scoreRowsTile, newMaxRowsTile);
// TEXP(scoreRowsWorkTile, scoreRowsWorkTile);
// TROWSUM(rowSumRowsTile, scoreRowsWorkTile, scoreRowsTile);
// TSTORE(probRowGlobal, probHalfTile) for each row.
wait_flag_dev(PTO_PA_RAW_QK_READY + slot);
// Loads from the same address expression that the QK step stores to.
copy_gm_to_ubuf_align_b32(
scoreRowsTile_ub,
reinterpret_cast<__gm__ float *>(scoreBase + slot * scoreGroupBytes
+ headInGroupBase * scoreHeadBytes),
0,
4, // 4 rows handled by this AIV sub-block
128 * sizeof(float),
0,
0,
0,
0);
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
// NOTE(review): this uses 4 repeats with a repeat stride of 128 / 8 blocks,
// whereas the full-tile vexp below uses (4 * 128 + 63) / 64 = 8 repeats with
// stride 8, and the row-expand vsub runs twice (column offsets 0 and 64).
// Assuming one repeat covers 64 fp32 elements — TODO confirm — this call
// would scale only columns 0..63 of each row. The same pattern appears in
// the vcmax and vcadd calls below. Check against the real TMULS / TROWMAX /
// TROWSUM lowering.
vmuls(scoreRowsTile_ub, scoreRowsTile_ub, ctx_scale,
4, 1, 1, 128 / 8, 128 / 8);
pipe_barrier(PIPE_V);
// PTO TROWMAX is equivalent to CCE ReduceMaxRepeatM for 4 rows.
vcmax(rowMaxRowsTile_ub, scoreRowsTile_ub,
4, 1, 1, 128 / 8, ONLY_VALUE);
pipe_barrier(PIPE_V);
// newMax = max(rowMax, maxState).
// NOTE(review): vcmax writes rowMaxRowsTile_ub while this reads
// rowMaxRowsView_ub; presumably two views of the same buffer — confirm.
vmax(newMaxRowsView_ub, rowMaxRowsView_ub, maxStateRowsView_ub,
1, 1, 1, 1, 8, 8, 8);
pipe_barrier(PIPE_V);
// oldScale = exp(maxState - newMax).
// NOTE(review): maxStateRowsView_ub is only read in this listing; no call
// shown here writes newMax back into it. Confirm where the PTO source
// updates the running max.
vsub(oldScaleRowsView_ub, maxStateRowsView_ub, newMaxRowsView_ub,
1, 1, 1, 1, 8, 8, 8);
pipe_barrier(PIPE_V);
vexp(oldScaleRowsView_ub, oldScaleRowsView_ub,
1, 1, 1, 8, 8);
pipe_barrier(PIPE_V);
// On the first tile the freshly computed oldScale is overwritten with 0.0f.
if (tile == 0) {
// TEXPANDS(oldScaleRowsView, 0.0f)
vector_dup(oldScaleRowsView_ub, 0.0f, 1, 1, 8);
pipe_barrier(PIPE_V);
}
// TROWEXPANDSUB(scoreRowsWorkTile, scoreRowsTile, newMaxRowsTile).
vbrcb(reinterpret_cast<__ubuf__ uint32_t *>(rowBroadcast_ub),
reinterpret_cast<__ubuf__ uint32_t *>(newMaxRowsTile_ub),
1, 8, 16 / 8);
pipe_barrier(PIPE_V);
// Two passes, one per 64-element column half; the second source operand
// (rowBroadcast_ub) uses block stride 0 and repeat stride 1.
for (int colBlock = 0; colBlock < 2; ++colBlock) {
vsub(scoreRowsWorkTile_ub + colBlock * 64,
scoreRowsTile_ub + colBlock * 64,
rowBroadcast_ub,
4, 1, 1, 0, 128 / 8, 128 / 8, 1);
}
pipe_barrier(PIPE_V);
// Exponentiate the whole 4 x 128 work tile in place.
vexp(scoreRowsWorkTile_ub, scoreRowsWorkTile_ub,
(4 * 128 + 63) / 64, 1, 1, 8, 8);
pipe_barrier(PIPE_V);
vcadd(rowSumRowsTile_ub, scoreRowsWorkTile_ub,
4, 1, 1, 128 / 8, 0);
pipe_barrier(PIPE_V);
// Running-sum update: sumState = sumState * oldScale + rowSum.
vmul(sumStateRowsView_ub, sumStateRowsView_ub, oldScaleRowsView_ub,
1, 1, 1, 1, 8, 8, 8);
pipe_barrier(PIPE_V);
vadd(sumStateRowsView_ub, sumStateRowsView_ub, rowSumRowsView_ub,
1, 1, 1, 1, 8, 8, 8);
pipe_barrier(PIPE_V);
// Current PTO probability store writes two rows at a time using the
// existing 256-half scratch tile. CCE still stores the whole sub_m tile.
// The loop runs for row = 0 and row = 2; each pass converts from
// scoreRowsWorkTile_ub + row * 128 and stores 256 halves.
// NOTE(review): only pipe_barrier(PIPE_V) separates the conversion from the
// UB -> GM store, and probHalf256_ub is rewritten on the second pass with no
// flag shown between the store and the next conversion. Confirm whether the
// PTO source inserts V<->MTE3 flags here.
for (int row = 0; row < 4; row += 2) {
vconv_f322f16a(probHalf256_ub,
scoreRowsWorkTile_ub + row * 128,
4, 1, 1, 4, 8);
pipe_barrier(PIPE_V);
copy_ubuf_to_gm_align_b16(
reinterpret_cast<__gm__ half *>(probBase + slot * probGroupBytes)
+ (headInGroupBase + row) * 128,
probHalf256_ub,
0,
1,
256 * sizeof(half),
0,
0,
0,
0);
}
// Data barrier plus full pipe barrier ahead of the cross-core signal.
dsb(DSB_DDR);
pipe_barrier(PIPE_ALL);
ffts_cross_core_sync(PIPE_MTE3,
1 | (2 << 4) | ((PTO_PA_RAW_P_READY + slot) << 8));
// Simpler CCE difference to check:
// - SoftmaxStage1 stores p_gm_tensor in one bulk copy_ubuf_to_gm call.
// - It has stage1/stage2 scratch halves instead of slot flags.
// - It keeps mask/logN/general shape handling; PTO highperf path omits them.
}
// ============================================================================
// 3. CUBE PV Hot Loop Expanded from PTO
// ============================================================================
// Expanded call sequence for the cube-side PV step, in listing order: wait
// for the P-ready flag (and, from the third tile on, the PV-free flag), load
// P from GM into cbuf and on to ca, load V from GM into cbuf and on to cb via
// 8 transpose loads, mad(acc, P, V), store acc to GM, then signal
// PTO_PA_RAW_PV_READY + slot.
// All identifiers come from the surrounding kernel and are not declared here.
void PTO_CUBE_PV_EXPANDED_TO_CCE()
{
wait_flag_dev(PTO_PA_RAW_P_READY + slot);
// The PV-free wait is skipped for tile 0 and tile 1.
if (tile >= 2) {
wait_flag_dev(PTO_PA_RAW_PV_FREE + slot);
}
// TLOAD(pMatTile, probGlobal), Global ND half -> Mat NZ.
copy_gm_to_cbuf_multi_nd2nz_b16(
pMatTile_cbuf,
reinterpret_cast<__gm__ half *>(probBase + slot * probGroupBytes),
0,
1,
16, // padded probability rows in the group
128,
0,
128,
16,
1,
0);
// Full barrier between the P load into cbuf and the cbuf -> ca move below.
pipe_barrier(PIPE_ALL);
load_cbuf_to_ca(
pLeftTile_ca,
pMatTile_cbuf + headInGroupBase * 16,
0,
8, // kTileTokens / 16
1,
0,
0,
false,
false,
addr_cal_mode_t(0));
// TLOAD(vMatTile, vGlobal), VGlobal is ND half [1,1,1,128,128].
// Same argument values as the K load in the QK section (srcDValue 8 * 128).
copy_gm_to_cbuf_multi_nd2nz_b16(
vMatTile_cbuf,
reinterpret_cast<__gm__ half *>(vGm) + kvBase,
0,
1,
128,
128,
0,
8 * 128,
128,
1,
0);
// Same three back-to-back set/wait pairs as in the QK section.
set_flag(PIPE_FIX, PIPE_M, EVENT_ID1);
wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1);
set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
// PtoPaLoadCbufToCbTranspose128Raw expansion (see the cheatsheet): 8 loads,
// destination advancing 128 * 16 elements and source 16 * 16 per iteration.
for (uint32_t idx = 0; idx < 8; ++idx) {
load_cbuf_to_cb_transpose(
rightTile_cb + idx * 128 * 16,
vMatTile_cbuf + idx * 16 * 16,
0,
8,
8,
0,
addr_cal_mode_t(0),
0);
}
set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
mad(accTile_cc, pLeftTile_ca, rightTile_cb,
16, 128, 128, 0, false, false, true);
set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
set_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
// NOTE(review): this stores to an address derived from outBase, while
// PTO_AIV_STAGE2_EXPANDED_TO_CCE loads the PV rows from an address derived
// from outTmpBase with the same slot/head offsets, and the final combine
// uses outBase as an element index into oGm. Confirm which base name the
// PTO source uses for this scratch buffer.
copy_matrix_cc_to_gm(
reinterpret_cast<__gm__ float *>(outBase + slot * outGroupBytes
+ headInGroupBase * outHeadBytes),
accTile_cc,
xmReg_for_ND_float_16x128,
xtReg_for_ND_float_16x128);
// Data barrier plus full pipe barrier ahead of the cross-core signal.
dsb(DSB_DDR);
pipe_barrier(PIPE_ALL);
ffts_cross_core_sync(PIPE_FIX,
1 | (2 << 4) | ((PTO_PA_RAW_PV_READY + slot) << 8));
// Simpler CCE difference to check:
// - ProcessPV loads p_gm_tensor from a bulk probability tile produced by AIV.
// - It uses l1/l0 ping-pong flags and sets UPDATE_READY_DECODER/STAGE2.
}
// ============================================================================
// 4. AIV Stage2 Accumulate Expanded from PTO
// ============================================================================
// Expanded call sequence for the vector-side stage 2 accumulate, in listing
// order: wait for the PV-ready flag, load 4 PV rows from GM, multiply the
// running weighted tile by the per-row old scale, add the new PV rows, signal
// PTO_PA_RAW_PV_FREE + slot, then (per split, before the final combine) take
// a log, scale the weighted tile by invSum and store one 128-float row.
// All identifiers come from the surrounding kernel and are not declared here.
void PTO_AIV_STAGE2_EXPANDED_TO_CCE()
{
wait_flag_dev(PTO_PA_RAW_PV_READY + slot);
// NOTE(review): this loads from an outTmpBase-derived address, while
// PTO_CUBE_PV_EXPANDED_TO_CCE stores to an outBase-derived address with the
// same slot/head offsets. Confirm both refer to the same buffer.
copy_gm_to_ubuf_align_b32(
pvRowsTile_ub,
reinterpret_cast<__gm__ float *>(outTmpBase + slot * outGroupBytes
+ headInGroupBase * outHeadBytes),
0,
4,
128 * sizeof(float),
0,
0,
0,
0);
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
// TROWEXPANDMUL(weightedRowsTile, weightedRowsTile, oldScaleRowsTile).
vbrcb(reinterpret_cast<__ubuf__ uint32_t *>(rowBroadcast_ub),
reinterpret_cast<__ubuf__ uint32_t *>(oldScaleRowsTile_ub),
1, 8, 16 / 8);
pipe_barrier(PIPE_V);
// Two passes, one per 64-element column half, mirroring the row-expand
// subtraction in the softmax stage 1 listing.
for (int colBlock = 0; colBlock < 2; ++colBlock) {
vmul(weightedRowsTile_ub + colBlock * 64,
weightedRowsTile_ub + colBlock * 64,
rowBroadcast_ub,
4, 1, 1, 0, 128 / 8, 128 / 8, 1);
}
pipe_barrier(PIPE_V);
// weighted += pv over the whole 4 x 128 tile.
vadd(weightedRowsTile_ub, weightedRowsTile_ub, pvRowsTile_ub,
(4 * 128 + 63) / 64, 1, 1, 1, 8, 8, 8);
pipe_barrier(PIPE_V);
dsb(DSB_DDR);
pipe_barrier(PIPE_ALL);
// NOTE(review): the signal is issued on PIPE_MTE3 although no UB -> GM
// transfer appears before this point in the listing — confirm the pipe
// argument against the PTO source.
ffts_cross_core_sync(PIPE_MTE3,
1 | (2 << 4) | ((PTO_PA_RAW_PV_FREE + slot) << 8));
// Per split output before final combine:
// partialL[lOffset] = maxScore + log(sumExp)
// TMULS(weightedTile, weightedTile, invSum)
// TSTORE(weightedGlobal, weightedTile)
// NOTE(review): only the log (vln) of the partialL computation is expanded
// here; the add of maxScore, the write to partialL and the derivation of
// invSum are not shown. No barrier or flag is shown between vln, vmuls and
// the UB -> GM store either — confirm against the PTO source.
// This part names the tile weightedTile_ub, whereas the accumulate above
// uses weightedRowsTile_ub.
vln(scalarMathTile_ub, scalarMathTile_ub, 1, 1, 1, 8, 8);
vmuls(weightedTile_ub, weightedTile_ub, invSum, 2, 1, 1, 8, 8);
copy_ubuf_to_gm_align_b32(
reinterpret_cast<__gm__ float *>(partialOut + outOffset),
weightedTile_ub,
0,
1,
128 * sizeof(float),
0,
0,
0,
0);
}
// ============================================================================
// 5. Current PTO Split-KV Final Combine Expanded to CCE
// ============================================================================
// Expanded call sequence for the split-KV final combine of one head, in
// listing order:
//   1. Scalar loop copies each split's L value from GM into UB and tracks the
//      maximum.
//   2. Masked vector ops compute exp(L - lMax) per split and their sum.
//   3. The split weights are normalised by 1 / sum (or 0 if the sum is not
//      positive).
//   4. Each split's partial output row is loaded, scaled by its weight and
//      accumulated into weightedTile_ub.
//   5. A scalar loop converts the accumulated row to half and writes it to
//      oGm.
// All identifiers other than the locals declared below come from the
// surrounding kernel and are not declared in this audit file.
void PTO_FINAL_COMBINE_EXPANDED_TO_CCE()
{
// Current PTO source now uses vector exp/sum for split scales, but still
// scalar-loads L values and scalar-stores final fp16 output.
// Scratch buffers are placed at fixed UB addresses 0x3200 and 0x3400.
__ubuf__ float *splitScale = reinterpret_cast<__ubuf__ float *>(0x3200);
__ubuf__ float *splitReduce = reinterpret_cast<__ubuf__ float *>(0x3400);
// Start from the most negative finite float so any L value replaces it.
float lMax = -3.4028234663852886e38f;
for (int32_t split = 0; split < kvLoop; ++split) {
float lValue = partialL[lBase + head * ctx.kvSplitCoreNum + split];
splitScale[split] = lValue; // scalar GM load + scalar UB store
lMax = lValue > lMax ? lValue : lMax;
}
// Order the scalar UB stores above before the vector ops below.
set_flag(PIPE_S, PIPE_V, EVENT_ID2);
wait_flag(PIPE_S, PIPE_V, EVENT_ID2);
set_mask_norm();
// Enable only the first kvLoop lanes.
// NOTE(review): (1ULL << kvLoop) is undefined behaviour in C++ when
// kvLoop >= 64, and yields an empty mask when kvLoop == 0 — confirm the
// range the PTO source guarantees for kvLoop.
set_vector_mask(0, (1ULL << kvLoop) - 1);
// splitScale = exp(L - lMax); splitReduce[0] = sum over the active lanes.
vadds(splitScale, splitScale, -lMax, 1, 1, 1, 8, 8);
pipe_barrier(PIPE_V);
vexp(splitScale, splitScale, 1, 1, 1, 8, 8);
pipe_barrier(PIPE_V);
vcadd(splitReduce, splitScale, 1, 1, 1, 8, 0);
pipe_barrier(PIPE_V);
set_vector_mask(-1ULL, -1ULL);
// NOTE(review): splitReduce[0] is read by scalar code right after a vector
// write, with only pipe_barrier(PIPE_V) in between; no V -> S flag is shown,
// unlike the S -> V flag used above. The same applies to the scalar reads of
// splitScale[split] and weightedTile_ub[dim] further down. Confirm whether
// the PTO source synchronises these reads.
float invDenom = splitReduce[0] > 0.0f ? 1.0f / splitReduce[0] : 0.0f;
// The mask is restored to all-ones above and then set again to the same
// kvLoop-lane value here.
set_mask_norm();
set_vector_mask(0, (1ULL << kvLoop) - 1);
vmuls(splitScale, splitScale, invDenom, 1, 1, 1, 8, 8);
pipe_barrier(PIPE_V);
set_vector_mask(-1ULL, -1ULL);
// Zero the accumulator (2 repeats) before summing the weighted splits.
vector_dup(weightedTile_ub, 0.0f, 2, 1, 8);
pipe_barrier(PIPE_V);
for (int32_t split = 0; split < kvLoop; ++split) {
// One 128-float row per split; rows for a head are laid out
// consecutively by split, ctx.headDim elements apart.
copy_gm_to_ubuf_align_b32(
pvTile_ub,
partialOut + oFdBase * ctx.kvSplitCoreNum
+ head * ctx.headDim * ctx.kvSplitCoreNum
+ split * ctx.headDim,
0,
1,
128 * sizeof(float),
0,
0,
0,
0);
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
// weighted += pv * splitScale[split].
vmuls(pvTile_ub, pvTile_ub, splitScale[split], 2, 1, 1, 8, 8);
pipe_barrier(PIPE_V);
vadd(weightedTile_ub, weightedTile_ub, pvTile_ub,
2, 1, 1, 1, 8, 8, 8);
pipe_barrier(PIPE_V);
}
// Current final store is scalar in PTO, not TSTORE/conv bulk.
// NOTE(review): the vector ops above cover a fixed 128 floats (2 repeats,
// 128 * sizeof(float) bursts) while this loop runs to ctx.headDim —
// presumably ctx.headDim == 128 on this path; confirm.
for (int32_t dim = 0; dim < ctx.headDim; ++dim) {
reinterpret_cast<__gm__ half *>(oGm)[outBase + dim] =
static_cast<half>(weightedTile_ub[dim]);
}
// Simpler CCE CombineScale differences to check line-by-line:
// - CCE bulk-loads all L values with copy_gm_to_ubuf_align_b32.
// - CCE reduces with vcmax/vadds/vexp/vcadd/vln over UB.
// - CCE bulk-loads O partials as m_split x head_dim.
// - CCE broadcasts each split scale and accumulates full rows.
// - CCE converts final fp32 tile with vconv_f322f16 and bulk stores with
// copy_ubuf_to_gm_align_b16.
}
The following file inlines the PTO translation to CCE: