Skip to content

Translation of paged attention highperf from CCE to PTO-ISA #192

Description

@MirkoDeVita98

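The following audit file expands the PTO implementation of the kernel inline into the CCE primitives that each PTO API lowers to, so it can be compared side by side with the original CCE kernel: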

// PTO paged-attention highperf expanded-to-CCE audit
//
// Audit artifact, not a compiling kernel. This file inlines the current PTO
// implementation into the CCE primitives that the PTO API maps to on a2a3.
// Use it next to simpler's pa_kernel.cce to compare semantic differences.
//
// Sources used for this expansion:
// - kernels/manual/a2a3/paged_attention_highperf/pa_kernel_impl.hpp
// - include/pto/common/pto_instr.hpp
// - include/pto/common/arch/memory/tload_common.hpp
// - include/pto/npu/a2a3/TLoad.hpp
// - include/pto/npu/a2a3/TStore.hpp
// - include/pto/npu/a2a3/TMatmul.hpp

// ============================================================================
// 0. PTO API Lowering Cheatsheet
// ============================================================================

// Reference card: one representative CCE call for each PTO API that the
// paged-attention highperf kernel uses. Operands are placeholder names; the
// per-stage functions below show the concrete values at each call site.
void PTO_API_LOWERING_CHEATSHEET()
{
    // TLOAD(MatTile NZ, GlobalTensor ND half) -> TLoadGm2L1Nd2nz.
    // After dst/src/sid, the arguments follow the ND2NZ parameter order:
    //   ndNum, nValue, dValue, srcNdMatrixStride, srcDValue,
    //   dstNzC0Stride, dstNzNStride, dstNzMatrixStride.
    copy_gm_to_cbuf_multi_nd2nz_b16(
        dst_cbuf, src_gm,
        0,              // sid
        1,              // ndNum in this kernel's Q/K/V shapes
        nValue,
        dValue,
        0,              // srcNdMatrixStride (only matters when ndNum > 1)
        srcDValue,
        dstNzC0Stride,
        1,              // dstNzNStride
        0);             // dstNzMatrixStride (only matters when ndNum > 1)

    // TLOAD(VecTile ND, GlobalTensor ND float/half) -> TLoadGm2ubNd2nd.
    // The b32 form serves float tiles and the b16 form serves half tiles;
    // both take the same argument layout.
    copy_gm_to_ubuf_align_b32(dst_ub, src_gm, 0, nBurst,
        lenBurstBytes, 0, ubPad, gmGapBytes, ubGapBlocks);
    copy_gm_to_ubuf_align_b16(dst_ub, src_gm, 0, nBurst,
        lenBurstBytes, 0, ubPad, gmGapBytes, ubGapBlocks);

    // PtoPaLoadCbufToCaRaw: MatTile (cbuf) -> left tile (ca), starting
    // srcElementOffset elements into the source.
    load_cbuf_to_ca(dst_ca, src_cbuf + srcElementOffset,
        0, repeatTimes, srcStride, 0, 0, false, false,
        addr_cal_mode_t(0));

    // PtoPaLoadCbufToCbRaw: MatTile (cbuf) -> right tile (cb).
    load_cbuf_to_cb(dst_cb, src_cbuf + srcElementOffset,
        0, repeatTimes, srcStride, 0, 0, false,
        addr_cal_mode_t(0));

    // PtoPaLoadCbufToCbTranspose128Raw: eight transposing cbuf -> cb loads.
    // Each step advances the destination by 128 * 16 elements and the source
    // by one 16 x 16 fractal (256 elements).
    for (uint32_t idx = 0; idx < 8; ++idx) {
        load_cbuf_to_cb_transpose(
            dst_cb + idx * 128 * 16,
            src_cbuf + idx * 16 * 16,
            0, 8, 8, 0, addr_cal_mode_t(0), 0);
    }

    // TMATMUL(TileAcc, TileLeft, TileRight) -> TMatmul -> mad.
    // PTO TMatmul forces m=16 when m==1 for non-GEMV, but this kernel uses m=16.
    mad(dst_cc, a_ca, b_cb, m, k, n,
        0,              // AccPhase::Unspecified
        false,          // kDirectionAlign for half x half
        false,          // cmatrixSource
        true);          // cmatrixInitVal

    // TSTORE(GlobalTensor ND float, TileAcc float) -> TStoreAccNz2nd.
    // The actual implementation packs xm/xt registers then calls the raw form.
    copy_matrix_cc_to_gm(dst_gm, src_cc, xmReg, xtReg);

    // TSTORE(GlobalTensor ND half/float, VecTile) -> TStoreUb2gmNd2nd.
    // Note the gap order is mirrored relative to the load: UB gap, then GM gap.
    copy_ubuf_to_gm_align_b16(dst_gm, src_ub, 0, nBurst,
        lenBurstBytes, 0, 0, ubGapBlocks, gmGapBytes);
    copy_ubuf_to_gm_align_b32(dst_gm, src_ub, 0, nBurst,
        lenBurstBytes, 0, 0, ubGapBlocks, gmGapBytes);
}

// ============================================================================
// 1. CUBE QK Hot Loop Expanded from PTO
// ============================================================================

// Cube side of S = Q x K^T for one 16-head Q group against one 128-token K
// tile. The fp32 scores are written to the slot's score workspace, and
// QK_READY is raised for the slot.
void PTO_CUBE_QK_EXPANDED_TO_CCE()
{
    // PTO sequence being expanded (pa_kernel_impl.hpp):
    //   QGlobal qGlobal(qGm + qBase);
    //   TLOAD(qMatTile, qGlobal);
    //   PtoPaLoadCbufToCaRaw(qLeftTile, qMatTile,
    //       headInGroupBase * 16, kHeadDim / 16, 1);
    //   KGlobal kGlobal(kGm + kvBase);
    //   TLOAD(kMatTile, kGlobal);
    //   PtoPaLoadCbufToCbRaw(rightTile, kMatTile, 0,
    //       (kHeadDim * kTileTokens) / 256, 1);
    //   TMATMUL(accTile, qLeftTile, rightTile);
    //   TSTORE(scoreGlobal, accTile);

    // Fixed shape of the current split-KV highperf path, in elements.
    constexpr int kQRows = 16;          // Q rows in the group; mad m
    constexpr int kDim = 128;           // head dim; mad k
    constexpr int kTok = 128;           // tokens per K tile; mad n
    constexpr int kFrac = 16;           // edge of a 16 x 16 half fractal
    constexpr int kKvPitch = 8 * 128;   // K row pitch in GM (stride4 = 8 * headDim)

    // Q: GM ND half [1,1,1,16,128] -> NZ MatTile holding the 16-head group.
    copy_gm_to_cbuf_multi_nd2nz_b16(
        qMatTile_cbuf,
        reinterpret_cast<__gm__ half *>(qGm) + qBase,
        0, 1,
        kQRows,         // nValue: full 16-head group
        kDim,           // dValue
        0,
        kDim,           // srcDValue: Q rows are dense
        kQRows,         // dstNzC0Stride = RoundUp<16>(16)
        1,
        0);
    pipe_barrier(PIPE_ALL);

    // Q MatTile -> left tile; one repeat per 16 head-dim columns.
    load_cbuf_to_ca(
        qLeftTile_ca,
        qMatTile_cbuf + headInGroupBase * kFrac,
        0,
        kDim / kFrac,
        1,
        0,
        0,
        false,
        false,
        addr_cal_mode_t(0));

    // K: KGlobal is DN half [1,1,1,128,128]; for this fixed shape the PTO
    // path reaches TLoadGm2L1Dn2zn, which issues this nd2nz copy.
    copy_gm_to_cbuf_multi_nd2nz_b16(
        kMatTile_cbuf,
        reinterpret_cast<__gm__ half *>(kGm) + kvBase,
        0,
        1,
        kTok,           // nValue: tile tokens
        kDim,           // dValue: head dim
        0,
        kKvPitch,       // srcDValue / physical kv-head stride
        kTok,           // dstNzC0Stride
        1,
        0);

    // Wait on the FIX->M, MTE2->MTE1 and M->MTE1 edges before refilling
    // the right tile.
    set_flag(PIPE_FIX, PIPE_M, EVENT_ID1);
    wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1);
    set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
    wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
    set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
    wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);

    // K MatTile -> right tile: (128 * 128) / 256 = 64 repeats.
    load_cbuf_to_cb(
        rightTile_cb,
        kMatTile_cbuf,
        0,
        (kDim * kTok) / 256,
        1,
        0,
        0,
        false,
        addr_cal_mode_t(0));

    set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
    wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
    mad(accTile_cc, qLeftTile_ca, rightTile_cb,
        kQRows, kDim, kTok, 0, false, false, true);
    set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
    wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
    set_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
    wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1);

    // ScoreGlobal is ND float, so TSTORE(Acc2GM) lowers to copy_matrix_cc_to_gm.
    copy_matrix_cc_to_gm(
        reinterpret_cast<__gm__ float *>(scoreBase + slot * scoreGroupBytes
            + headInGroupBase * scoreHeadBytes),
        accTile_cc,
        xmReg_for_ND_float_16x128,
        xtReg_for_ND_float_16x128);

    // Publish the scores, then signal QK_READY for this slot from PIPE_FIX.
    dsb(DSB_DDR);
    pipe_barrier(PIPE_ALL);
    ffts_cross_core_sync(PIPE_FIX,
        1 | (2 << 4) | ((PTO_PA_RAW_QK_READY + slot) << 8));

    // Simpler CCE difference to check:
    // - It uses ProcessQK with explicit l0/l1 ping-pong state.
    // - It signals QK_READY_DECODER and QK_READY_STAGE2, not slot flags.
    // - It calls DdrBarrierBeforeFfts(), which includes pipe_barrier(PIPE_ALL).
}

// ============================================================================
// 2. AIV Softmax Stage1 Expanded from PTO
// ============================================================================

void PTO_AIV_SOFTMAX_STAGE1_EXPANDED_TO_CCE()
{
    // Original PTO shape:
    //   wait_flag_dev(PtoPaSlotFlag(PTO_PA_RAW_QK_READY, slot));
    //   TLOAD(scoreRowsTile, scoreGlobal);
    //   TMULS(scoreRowsTile, scoreRowsTile, ctx.scale);
    //   TROWMAX(rowMaxRowsTile, scoreRowsTile, scoreRowsWorkTile);
    //   TMAX(newMaxRowsView, rowMaxRowsView, maxStateRowsView);
    //   TSUB(oldScaleRowsView, maxStateRowsView, newMaxRowsView);
    //   TEXP(oldScaleRowsView, oldScaleRowsView);
    //   TROWEXPANDSUB(scoreRowsWorkTile, scoreRowsTile, newMaxRowsTile);
    //   TEXP(scoreRowsWorkTile, scoreRowsWorkTile);
    //   TROWSUM(rowSumRowsTile, scoreRowsWorkTile, scoreRowsTile);
    //   TSTORE(probRowGlobal, probHalfTile) for each row.

    wait_flag_dev(PTO_PA_RAW_QK_READY + slot);

    copy_gm_to_ubuf_align_b32(
        scoreRowsTile_ub,
        reinterpret_cast<__gm__ float *>(scoreBase + slot * scoreGroupBytes
            + headInGroupBase * scoreHeadBytes),
        0,
        4,              // 4 rows handled by this AIV sub-block
        128 * sizeof(float),
        0,
        0,
        0,
        0);
    set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
    wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);

    vmuls(scoreRowsTile_ub, scoreRowsTile_ub, ctx_scale,
        4, 1, 1, 128 / 8, 128 / 8);
    pipe_barrier(PIPE_V);

    // PTO TROWMAX is equivalent to CCE ReduceMaxRepeatM for 4 rows.
    vcmax(rowMaxRowsTile_ub, scoreRowsTile_ub,
        4, 1, 1, 128 / 8, ONLY_VALUE);
    pipe_barrier(PIPE_V);

    vmax(newMaxRowsView_ub, rowMaxRowsView_ub, maxStateRowsView_ub,
        1, 1, 1, 1, 8, 8, 8);
    pipe_barrier(PIPE_V);

    vsub(oldScaleRowsView_ub, maxStateRowsView_ub, newMaxRowsView_ub,
        1, 1, 1, 1, 8, 8, 8);
    pipe_barrier(PIPE_V);
    vexp(oldScaleRowsView_ub, oldScaleRowsView_ub,
        1, 1, 1, 8, 8);
    pipe_barrier(PIPE_V);

    if (tile == 0) {
        // TEXPANDS(oldScaleRowsView, 0.0f)
        vector_dup(oldScaleRowsView_ub, 0.0f, 1, 1, 8);
        pipe_barrier(PIPE_V);
    }

    // TROWEXPANDSUB(scoreRowsWorkTile, scoreRowsTile, newMaxRowsTile).
    vbrcb(reinterpret_cast<__ubuf__ uint32_t *>(rowBroadcast_ub),
        reinterpret_cast<__ubuf__ uint32_t *>(newMaxRowsTile_ub),
        1, 8, 16 / 8);
    pipe_barrier(PIPE_V);
    for (int colBlock = 0; colBlock < 2; ++colBlock) {
        vsub(scoreRowsWorkTile_ub + colBlock * 64,
            scoreRowsTile_ub + colBlock * 64,
            rowBroadcast_ub,
            4, 1, 1, 0, 128 / 8, 128 / 8, 1);
    }
    pipe_barrier(PIPE_V);

    vexp(scoreRowsWorkTile_ub, scoreRowsWorkTile_ub,
        (4 * 128 + 63) / 64, 1, 1, 8, 8);
    pipe_barrier(PIPE_V);

    vcadd(rowSumRowsTile_ub, scoreRowsWorkTile_ub,
        4, 1, 1, 128 / 8, 0);
    pipe_barrier(PIPE_V);

    vmul(sumStateRowsView_ub, sumStateRowsView_ub, oldScaleRowsView_ub,
        1, 1, 1, 1, 8, 8, 8);
    pipe_barrier(PIPE_V);
    vadd(sumStateRowsView_ub, sumStateRowsView_ub, rowSumRowsView_ub,
        1, 1, 1, 1, 8, 8, 8);
    pipe_barrier(PIPE_V);

    // Current PTO probability store writes two rows at a time using the
    // existing 256-half scratch tile. CCE still stores the whole sub_m tile.
    for (int row = 0; row < 4; row += 2) {
        vconv_f322f16a(probHalf256_ub,
            scoreRowsWorkTile_ub + row * 128,
            4, 1, 1, 4, 8);
        pipe_barrier(PIPE_V);
        copy_ubuf_to_gm_align_b16(
            reinterpret_cast<__gm__ half *>(probBase + slot * probGroupBytes)
                + (headInGroupBase + row) * 128,
            probHalf256_ub,
            0,
            1,
            256 * sizeof(half),
            0,
            0,
            0,
            0);
    }

    dsb(DSB_DDR);
    pipe_barrier(PIPE_ALL);
    ffts_cross_core_sync(PIPE_MTE3,
        1 | (2 << 4) | ((PTO_PA_RAW_P_READY + slot) << 8));

    // Simpler CCE difference to check:
    // - SoftmaxStage1 stores p_gm_tensor in one bulk copy_ubuf_to_gm call.
    // - It has stage1/stage2 scratch halves instead of slot flags.
    // - It keeps mask/logN/general shape handling; PTO highperf path omits them.
}

// ============================================================================
// 3. CUBE PV Hot Loop Expanded from PTO
// ============================================================================

// Cube side of O_tile = P x V for one 16-row probability group and one
// 128-token V tile. It waits for P_READY on the slot, and also waits for
// PV_FREE on the slot when tile >= 2. It raises PV_READY after the fp32
// result is stored.
void PTO_CUBE_PV_EXPANDED_TO_CCE()
{
    // Fixed shape of the current split-KV highperf path, in elements.
    constexpr int kPRows = 16;          // padded probability rows; mad m
    constexpr int kTok = 128;           // tokens per tile; mad k
    constexpr int kDim = 128;           // head dim; mad n
    constexpr int kFrac = 16;           // edge of a 16 x 16 half fractal
    constexpr int kKvPitch = 8 * 128;   // V row pitch in GM

    wait_flag_dev(PTO_PA_RAW_P_READY + slot);
    if (2 <= tile) {
        wait_flag_dev(PTO_PA_RAW_PV_FREE + slot);
    }

    // TLOAD(pMatTile, probGlobal): GM ND half -> NZ MatTile.
    copy_gm_to_cbuf_multi_nd2nz_b16(
        pMatTile_cbuf,
        reinterpret_cast<__gm__ half *>(probBase + slot * probGroupBytes),
        0,
        1,
        kPRows,         // padded probability rows in the group
        kTok,
        0,
        kTok,
        kPRows,
        1,
        0);
    pipe_barrier(PIPE_ALL);

    // P MatTile -> left tile; kTileTokens / 16 repeats.
    load_cbuf_to_ca(
        pLeftTile_ca,
        pMatTile_cbuf + headInGroupBase * kFrac,
        0,
        kTok / kFrac,
        1,
        0,
        0,
        false,
        false,
        addr_cal_mode_t(0));

    // TLOAD(vMatTile, vGlobal): VGlobal is ND half [1,1,1,128,128].
    copy_gm_to_cbuf_multi_nd2nz_b16(
        vMatTile_cbuf,
        reinterpret_cast<__gm__ half *>(vGm) + kvBase,
        0,
        1,
        kTok,
        kDim,
        0,
        kKvPitch,
        kTok,
        1,
        0);

    set_flag(PIPE_FIX, PIPE_M, EVENT_ID1);
    wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1);
    set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
    wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
    set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);
    wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1);

    // V MatTile -> right tile via eight transposing loads. Each step moves the
    // destination by 128 * 16 elements and the source by one fractal.
    for (uint32_t blk = 0; blk < 8; ++blk) {
        load_cbuf_to_cb_transpose(
            rightTile_cb + blk * 128 * kFrac,
            vMatTile_cbuf + blk * kFrac * kFrac,
            0,
            8,
            8,
            0,
            addr_cal_mode_t(0),
            0);
    }

    set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
    wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
    mad(accTile_cc, pLeftTile_ca, rightTile_cb,
        kPRows, kTok, kDim, 0, false, false, true);
    set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
    wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1);
    set_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
    wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1);

    // NOTE(review): the AIV stage2 expansion reads this result from
    // outTmpBase + slot * outGroupBytes + headInGroupBase * outHeadBytes,
    // while this store uses outBase. The final combine uses outBase as an
    // element index into oGm. Confirm which base the PTO source uses here.
    copy_matrix_cc_to_gm(
        reinterpret_cast<__gm__ float *>(outBase + slot * outGroupBytes
            + headInGroupBase * outHeadBytes),
        accTile_cc,
        xmReg_for_ND_float_16x128,
        xtReg_for_ND_float_16x128);

    dsb(DSB_DDR);
    pipe_barrier(PIPE_ALL);
    ffts_cross_core_sync(PIPE_FIX,
        1 | (2 << 4) | ((PTO_PA_RAW_PV_READY + slot) << 8));

    // Simpler CCE difference to check:
    // - ProcessPV loads p_gm_tensor from a bulk probability tile produced by AIV.
    // - It uses l1/l0 ping-pong flags and sets UPDATE_READY_DECODER/STAGE2.
}

// ============================================================================
// 4. AIV Stage2 Accumulate Expanded from PTO
// ============================================================================

// AIV stage 2: multiply each running output row by its oldScale, add this
// tile's PV result, and raise PV_FREE for the slot (the PV cube stage waits
// on it when tile >= 2). The trailing section writes the per-split output
// consumed by the final combine.
void PTO_AIV_STAGE2_EXPANDED_TO_CCE()
{
    wait_flag_dev(PTO_PA_RAW_PV_READY + slot);

    copy_gm_to_ubuf_align_b32(
        pvRowsTile_ub,
        reinterpret_cast<__gm__ float *>(outTmpBase + slot * outGroupBytes
            + headInGroupBase * outHeadBytes),
        0,
        4,
        128 * sizeof(float),
        0,
        0,
        0,
        0);
    set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
    wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);

    // TROWEXPANDMUL(weightedRowsTile, weightedRowsTile, oldScaleRowsTile).
    vbrcb(reinterpret_cast<__ubuf__ uint32_t *>(rowBroadcast_ub),
        reinterpret_cast<__ubuf__ uint32_t *>(oldScaleRowsTile_ub),
        1, 8, 16 / 8);
    pipe_barrier(PIPE_V);
    for (int colBlock = 0; colBlock < 2; ++colBlock) {
        vmul(weightedRowsTile_ub + colBlock * 64,
            weightedRowsTile_ub + colBlock * 64,
            rowBroadcast_ub,
            4, 1, 1, 0, 128 / 8, 128 / 8, 1);
    }
    pipe_barrier(PIPE_V);

    vadd(weightedRowsTile_ub, weightedRowsTile_ub, pvRowsTile_ub,
        (4 * 128 + 63) / 64, 1, 1, 1, 8, 8, 8);
    pipe_barrier(PIPE_V);

    dsb(DSB_DDR);
    pipe_barrier(PIPE_ALL);
    ffts_cross_core_sync(PIPE_MTE3,
        1 | (2 << 4) | ((PTO_PA_RAW_PV_FREE + slot) << 8));

    // Per split output before final combine:
    //   partialL[lOffset] = maxScore + log(sumExp)
    //   TMULS(weightedTile, weightedTile, invSum)
    //   TSTORE(weightedGlobal, weightedTile)
    // NOTE(review): only the log of the partialL formula is expanded below.
    // The "+ maxScore" step and the partialL store are not shown here.
    vln(scalarMathTile_ub, scalarMathTile_ub, 1, 1, 1, 8, 8);
    vmuls(weightedTile_ub, weightedTile_ub, invSum, 2, 1, 1, 8, 8);
    // RAW: MTE3 must not read weightedTile_ub before vmuls has written it.
    set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
    wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
    copy_ubuf_to_gm_align_b32(
        reinterpret_cast<__gm__ float *>(partialOut + outOffset),
        weightedTile_ub,
        0,
        1,
        128 * sizeof(float),
        0,
        0,
        0,
        0);
}

// ============================================================================
// 5. Current PTO Split-KV Final Combine Expanded to CCE
// ============================================================================

// Split-KV final combine for one head. It weights each split's partial
// output O_s by
//   w_s = exp(L_s - max_t L_t) / sum_t exp(L_t - max_t L_t)
// sums the weighted rows, and writes the fp16 result to oGm.
void PTO_FINAL_COMBINE_EXPANDED_TO_CCE()
{
    // Current PTO source now uses vector exp/sum for split scales, but still
    // scalar-loads L values and scalar-stores final fp16 output.

    __ubuf__ float *splitScale = reinterpret_cast<__ubuf__ float *>(0x3200);
    __ubuf__ float *splitReduce = reinterpret_cast<__ubuf__ float *>(0x3400);

    // Lane mask selecting the first kvLoop floats of a single repeat (second
    // set_vector_mask word). A float repeat has 64 lanes, so this expansion
    // assumes kvLoop <= 64. kvLoop == 64 is special-cased because
    // 1ULL << 64 is undefined behavior.
    const unsigned long long splitLaneMask =
        kvLoop >= 64 ? ~0ULL : (1ULL << kvLoop) - 1;

    float lMax = -3.4028234663852886e38f;
    for (int32_t split = 0; split < kvLoop; ++split) {
        float lValue = partialL[lBase + head * ctx.kvSplitCoreNum + split];
        splitScale[split] = lValue;          // scalar GM load + scalar UB store
        lMax = lValue > lMax ? lValue : lMax;
    }

    set_flag(PIPE_S, PIPE_V, EVENT_ID2);
    wait_flag(PIPE_S, PIPE_V, EVENT_ID2);
    set_mask_norm();
    set_vector_mask(0, splitLaneMask);

    vadds(splitScale, splitScale, -lMax, 1, 1, 1, 8, 8);
    pipe_barrier(PIPE_V);
    vexp(splitScale, splitScale, 1, 1, 1, 8, 8);
    pipe_barrier(PIPE_V);
    vcadd(splitReduce, splitScale, 1, 1, 1, 8, 0);
    pipe_barrier(PIPE_V);
    set_vector_mask(-1ULL, -1ULL);

    // RAW V -> S: the scalar read of splitReduce[0] needs vcadd to be done.
    set_flag(PIPE_V, PIPE_S, EVENT_ID2);
    wait_flag(PIPE_V, PIPE_S, EVENT_ID2);
    float invDenom = splitReduce[0] > 0.0f ? 1.0f / splitReduce[0] : 0.0f;
    set_mask_norm();
    set_vector_mask(0, splitLaneMask);
    vmuls(splitScale, splitScale, invDenom, 1, 1, 1, 8, 8);
    pipe_barrier(PIPE_V);
    set_vector_mask(-1ULL, -1ULL);

    vector_dup(weightedTile_ub, 0.0f, 2, 1, 8);
    pipe_barrier(PIPE_V);

    // RAW V -> S: splitScale[split] is read as a scalar operand in the loop.
    set_flag(PIPE_V, PIPE_S, EVENT_ID2);
    wait_flag(PIPE_V, PIPE_S, EVENT_ID2);

    for (int32_t split = 0; split < kvLoop; ++split) {
        if (split != 0) {
            // WAR V -> MTE2: the previous iteration's vmuls/vadd must finish
            // reading pvTile_ub before MTE2 overwrites it.
            set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
            wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0);
        }
        copy_gm_to_ubuf_align_b32(
            pvTile_ub,
            partialOut + oFdBase * ctx.kvSplitCoreNum
                + head * ctx.headDim * ctx.kvSplitCoreNum
                + split * ctx.headDim,
            0,
            1,
            128 * sizeof(float),
            0,
            0,
            0,
            0);
        set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
        wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
        vmuls(pvTile_ub, pvTile_ub, splitScale[split], 2, 1, 1, 8, 8);
        pipe_barrier(PIPE_V);
        vadd(weightedTile_ub, weightedTile_ub, pvTile_ub,
            2, 1, 1, 1, 8, 8, 8);
        pipe_barrier(PIPE_V);
    }

    // RAW V -> S: the scalar fp16 store below reads weightedTile_ub.
    set_flag(PIPE_V, PIPE_S, EVENT_ID2);
    wait_flag(PIPE_V, PIPE_S, EVENT_ID2);

    // Current final store is scalar in PTO, not TSTORE/conv bulk.
    for (int32_t dim = 0; dim < ctx.headDim; ++dim) {
        reinterpret_cast<__gm__ half *>(oGm)[outBase + dim] =
            static_cast<half>(weightedTile_ub[dim]);
    }

    // Simpler CCE CombineScale differences to check line-by-line:
    // - CCE bulk-loads all L values with copy_gm_to_ubuf_align_b32.
    // - CCE reduces with vcmax/vadds/vexp/vcadd/vln over UB.
    // - CCE bulk-loads O partials as m_split x head_dim.
    // - CCE broadcasts each split scale and accumulates full rows.
    // - CCE converts final fp32 tile with vconv_f322f16 and bulk stores with
    //   copy_ubuf_to_gm_align_b16.
}

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Fields

    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions