[RFC] Explicit AIV split: pl.split_aiv + aiv_shard/aic_gather replacing SplitVectorKernel auto-analysis #1820

Description

@Hzfengsy

Summary

Introduce an explicit, user-driven AIV-split DSL for mixed cube+vector (AIC/AIV) kernels, replacing the compiler-inferred splitting in the SplitVectorKernel pass. The surface is three constructs: a for aiv_id in pl.split_aiv(2): SPMD loop that binds the sub-core index, and two boundary ops that mark every full↔half transition across the cube↔vector boundary — pl.aiv_shard(tile, mode) (full→half, AIC→AIV) and pl.aic_gather(tile, mode) (half→full, AIV→AIC). The split mode (UP_DOWN/LEFT_RIGHT) is declared on the transfer ops in 2D-tile vocabulary and validated, not inferred. This deletes the brittle index-space analysis that currently miscompiles split-axis sub-slice stores, and lands on existing, already-correct codegen (tpop_from_aic halves, tpop_from_aiv keeps full) so codegen/PTOAS need no changes.

Related: #1034 (split-axis reduction / GM-mediated semantics), #1447 (SplitVectorKernel odd-N validation), #1789 (cross-core split dual-AIV deadlock).

Motivation

The concrete blocker. A distilled flash-attention inner step (qk_pv: QK^T → online-softmax → PV) at M_TILE=64, HEAD_DIM=512, K_TILE=128 hits a progression of compiler blockers:

  1. L0C (Acc) overflow — two co-resident FP32 matmul accumulators, QK [64,128] (32 KB) + PV [64,512] (128 KB) = 160 KB > 128 KB. Fixed by feat(ir): global largest-first packing for MemoryReuse buffer reuse #1805 (global largest-first packing now reuses the dead QK slot).
  2. Vec overflow — the [64,128] softmax intermediates hit 197632 B > 188416 B. Worked around by adding optimizations=[pl.split(pl.SplitMode.UP_DOWN)], which halves the row axis across the two AIV sub-cores.
  3. [CURRENT] ptoas codegen failure — with the UP_DOWN split active, the strided H_TILE-row stores (mi[r0:r0+16] = qk_mi[r0:r0+16]) produce:
    'pto.tstore' op expects dst static element count (8) to match src valid_shape static element count (16)
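
The buffer arithmetic behind the first two blockers can be checked with a few lines of Python. This is only a back-of-the-envelope sketch: the helper `tile_bytes` is hypothetical, and the capacities are the figures quoted in the list above, not values read from any hardware header.

```python
def tile_bytes(rows, cols, elem_bytes):
    """Size in bytes of a dense 2D tile."""
    return rows * cols * elem_bytes

FP32_BYTES = 4
L0C_CAPACITY = 128 * 1024            # 128 KB, as quoted in blocker 1

qk_acc = tile_bytes(64, 128, FP32_BYTES)   # QK accumulator [M_TILE, K_TILE]
pv_acc = tile_bytes(64, 512, FP32_BYTES)   # PV accumulator [M_TILE, HEAD_DIM]
total = qk_acc + pv_acc

print(qk_acc // 1024, pv_acc // 1024, total // 1024)  # 32 128 160
print(total > L0C_CAPACITY)                           # True

# Blocker 2: the Vec figures quoted above.
VEC_USED, VEC_CAPACITY = 197632, 188416
print(VEC_USED - VEC_CAPACITY)                        # 9216 bytes over budget
```

With both accumulators live the total is 160 KB, which is why reusing the dead QK slot (the #1805 fix) is enough to fit the 128 KB budget.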
    

Root cause. SplitVectorKernel is a syntactic, per-statement shape-halver with no model of the index space. It halves a tile.slice's result type ([16,1] → [8,1]) but leaves its shape/offset args unchanged ([16,1]), manufacturing the mismatched tstore(subview<16>, partition<8>). The failing pattern (strided sub-slice stores along the split axis) is structural to online-softmax, so the failure is endemic under buffer pressure, not incidental. The pass is a per-op whitelist (load/store/full/create/reshape get arg rewrites; tile.slice and everything else fall through to result-type-only halving), which is fragile and unbounded as new ops appear.
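
A toy model makes the mismatch concrete. The `SliceOp` record and `result_only_halving` function below are hypothetical stand-ins, not the pass's real data structures; they only mimic the described fall-through behavior, where the result type is halved and the shape argument is left alone.

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class SliceOp:
    shape_arg: tuple      # the shape argument passed to the slice
    result_shape: tuple   # the declared result type

def halve(shape, axis):
    out = list(shape)
    out[axis] //= 2
    return tuple(out)

def result_only_halving(op, axis=0):
    # Models the fall-through path: only the result type is rewritten.
    return replace(op, result_shape=halve(op.result_shape, axis))

def element_count(shape):
    n = 1
    for d in shape:
        n *= d
    return n

op = SliceOp(shape_arg=(16, 1), result_shape=(16, 1))
bad = result_only_halving(op)

# The source view still describes 16 elements; the destination expects 8.
print(element_count(bad.shape_arg), element_count(bad.result_shape))  # 16 8
```

The 16-versus-8 pair is the same disagreement that the `pto.tstore` error message reports.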

Why explicit beats fixing the auto-pass. When the user authors the partition, the slice shapes and result types are self-consistent by construction, so the entire mismatch class disappears. The compiler stops doing index-space inference and instead just validates declared intent and stamps the per-transfer split. Kernel authors already write explicit tile offsets, get_subblock_idx, and cross-core handoffs — this is idiomatic, not new burden.

Design

API (three constructs)

def split_aiv(n: int = 2) -> SplitAivIterator: ...
    # for aiv_id in pl.split_aiv(2):  -> body runs once per AIV sub-core.
    #   aiv_id = pl.tile.get_subblock_idx(); n is HW-fixed at 2 (validated).

def aiv_shard(tile: T, mode: SplitMode) -> T: ...
    # AIC->AIV scatter: this lane's half of a full cube tile.
    # result type = input with the split axis HALVED.

def aic_gather(tile: T, mode: SplitMode) -> T: ...
    # AIV->AIC gather: reassemble per-lane halves into the full tile the cube consumes.
    # result type = input with the split axis DOUBLED.  (inverse of aiv_shard)

mode is SplitMode.UP_DOWN (split rows / axis 0) or SplitMode.LEFT_RIGHT (split cols / axis 1), expressed in 2D-tile vocabulary.
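
The shape rule for the two ops can be sketched in plain Python. `Mode`, `aiv_shard_shape`, and `aic_gather_shape` are illustrative names, not the proposed implementation; the integer values follow the None=0/UpDown=1/LeftRight=2 encoding mentioned later in this RFC, and the even-extent check is an assumption about where validation would sit.

```python
from enum import IntEnum

class Mode(IntEnum):
    NONE = 0
    UP_DOWN = 1      # split rows (axis 0)
    LEFT_RIGHT = 2   # split cols (axis 1)

def split_axis(mode):
    if mode == Mode.UP_DOWN:
        return 0
    if mode == Mode.LEFT_RIGHT:
        return 1
    raise ValueError("a transfer needs UP_DOWN or LEFT_RIGHT")

def aiv_shard_shape(shape, mode):
    """full -> half: halve the split axis, rejecting odd extents."""
    axis = split_axis(mode)
    if shape[axis] % 2 != 0:
        raise ValueError(f"axis {axis} extent {shape[axis]} is odd")
    out = list(shape)
    out[axis] //= 2
    return tuple(out)

def aic_gather_shape(shape, mode):
    """half -> full: double the split axis (inverse of aiv_shard_shape)."""
    axis = split_axis(mode)
    out = list(shape)
    out[axis] *= 2
    return tuple(out)

print(aiv_shard_shape((64, 128), Mode.UP_DOWN))      # (32, 128)
print(aiv_shard_shape((64, 128), Mode.LEFT_RIGHT))   # (64, 64)
print(aic_gather_shape((32, 128), Mode.UP_DOWN))     # (64, 128)
```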

Worked example (the qk_pv kernel)

with pl.at(level=pl.Level.CORE_GROUP, name_hint="qk_pv"):
    qk_raw  = pl.matmul(q, k, b_trans=True, out_dtype=pl.FP32)         # AIC, full [M, K]
    for aiv_id in pl.split_aiv(2):
        qk_h    = pl.aiv_shard(qk_raw, pl.SplitMode.UP_DOWN)           # [M//2, K]  full->half
        qk_sc   = pl.mul(qk_h, SOFTMAX_SCALE)
        qk_mi   = pl.row_max(qk_sc)                                    # [M//2, 1]
        qk_exp  = pl.exp(pl.row_expand_sub(qk_sc, qk_mi))
        qk_li   = pl.row_sum(qk_exp)
        qk_bf16 = pl.cast(qk_exp, target_type=pl.BF16)                 # [M//2, K]
        qk_full = pl.aic_gather(qk_bf16, pl.SplitMode.UP_DOWN)         # [M, K]    half->full
        qk_oi_f = pl.matmul(qk_full, v, out_dtype=pl.FP32)            # [M, HEAD] cube sees FULL, deduce ok
        qk_oi   = pl.aiv_shard(qk_oi_f, pl.SplitMode.UP_DOWN)          # [M//2, HEAD] full->half
        base = aiv_id * (M_TILE // 2)
        for sub in pl.unroll((M_TILE // 2) // H_TILE):
            r0 = sub * H_TILE
            oi[base + r0 : base + r0 + H_TILE, :]  = qk_oi[r0:r0+H_TILE, :]   # GM store: lane's owned rows
            mi[base + r0 : base + r0 + H_TILE, 0:1] = qk_mi[r0:r0+H_TILE, 0:1]
            li[base + r0 : base + r0 + H_TILE, 0:1] = qk_li[r0:r0+H_TILE, 0:1]

The user makes one annotation per full↔half transition. Everything else flows by type.
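
As a sanity check on the store offsets in the example, the row ranges each lane writes can be enumerated in plain Python. This mirrors the `base`/`r0` arithmetic only; M_TILE=64 and H_TILE=16 are taken from the motivating case, and `rows_written` is a hypothetical helper.

```python
M_TILE, H_TILE = 64, 16

def rows_written(aiv_id):
    """GM rows a lane stores to, following the example's base/r0 arithmetic."""
    base = aiv_id * (M_TILE // 2)
    rows = []
    for sub in range((M_TILE // 2) // H_TILE):
        r0 = sub * H_TILE
        rows.extend(range(base + r0, base + r0 + H_TILE))
    return rows

lane0, lane1 = rows_written(0), rows_written(1)
print(lane0[0], lane0[-1], lane1[0], lane1[-1])   # 0 31 32 63

# Every row of the full tile is written exactly once across the two lanes.
assert sorted(lane0 + lane1) == list(range(M_TILE))
```

The two lanes cover disjoint halves of the output, which is the property the "GM offset ↔ mode axis" validation rule is meant to protect.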

Semantic model — closed under two boundary ops

A tile is either full (AIC vocabulary, [M,…]) or half (AIV-lane vocabulary, [M//2,…]). aiv_shard is the only full→half edge; aic_gather is the only half→full edge. The type itself ([M] vs [M//2]) encodes which — no split-state analysis or propagation is needed, and cube ops always see full operands so their shape deduction is unchanged. This is what makes qk_oi_f = matmul(qk_full, v) type-check: qk_full is full [M,K] (post-gather), so the matmul deduces [M,HEAD] normally. Feeding a half into a cube op is rejected by the verifier (see below).
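
One way to see why the model is closed for the softmax stage: row-wise ops never mix rows, so computing on each UP_DOWN half and gathering gives the same tile as computing on the full tile. The sketch below uses nested lists as tiles; `shard_up_down`, `gather_up_down`, and `row_exp_sub_max` are illustrative helpers, not DSL calls.

```python
import math

def row_exp_sub_max(tile):
    """Row-wise exp(x - row_max), the numerator of online softmax."""
    out = []
    for row in tile:
        m = max(row)
        out.append([math.exp(x - m) for x in row])
    return out

def shard_up_down(tile, lane):
    half = len(tile) // 2
    return tile[lane * half:(lane + 1) * half]

def gather_up_down(halves):
    return halves[0] + halves[1]

full = [[float((r * 7 + c * 3) % 5) for c in range(4)] for r in range(6)]

per_lane = [row_exp_sub_max(shard_up_down(full, lane)) for lane in (0, 1)]
assert gather_up_down(per_lane) == row_exp_sub_max(full)
```

This only demonstrates the row-wise case under a row split; an op that reduces along the split axis would not commute this way.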

Mode is declared, not inferred

The split direction is a property of the 2D tile at the transfer: tile.tpush_to_aiv/tpop_from_aic carry a split int attr (None=0/UpDown=1/LeftRight=2), and tiles are already 2D by ExpandMixedKernel time (FlattenTileNdTo2D is pass 15, ExpandMixedKernel 21, SplitVectorKernel 23). A generic high-dim slice cannot reliably name "row" vs "col" because the high-dim→2D-tile axis mapping (flatten/layout/transpose) is decided by passes that run after the user's source. So the user declares the mode in 2D-tile vocab on the shard/gather op, and the compiler validates it at the 2D-tile boundary rather than guessing.
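
A small example of the ambiguity, assuming a plain row-major flatten of a [B, M, K] tensor into a [B*M, K] tile (the mapping FlattenTileNdTo2D actually uses may differ; that is the point). Halving the source M axis does not select the same rows as the UP_DOWN half of the flattened tile.

```python
B, M = 2, 4   # leading dims of a hypothetical [B, M, K] source tensor

def flat_row(b, m):
    """Row index in the flattened [B*M, K] tile under row-major flattening."""
    return b * M + m

# "First half of M" in source coordinates, expressed as flattened rows.
high_dim_half = sorted(flat_row(b, m) for b in range(B) for m in range(M // 2))

# Lane 0 of an UP_DOWN split on the flattened 2D tile.
up_down_half = list(range((B * M) // 2))

print(high_dim_half)   # [0, 1, 4, 5]
print(up_down_half)    # [0, 1, 2, 3]
```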

Lowering — onto existing, already-correct paths

aiv_shard(x, mode)  ->  cube  tile.tpush_to_aiv(x, split=N)  +  AIV  tile.tpop_from_aic(split=N)   # HALVES (existing)
aic_gather(x, mode) ->  AIV   tile.tpush_to_aic(x, split=N)  +  cube tile.tpop_from_aiv(split=N)    # KEEPS FULL (existing)

From split_vector_kernel_pass.cpp:400-406: "tpop_from_aic: AIV consumes from cube — halve the popped tile … tpop_from_aiv: AIC consumes from vector — keep full tile shape." The two new ops are exactly the two existing pop behaviors surfaced to the user; orchestration codegen (RequiresDualAivDispatch is already boolean) and the .pto/PTOAS layer (per-transfer split) are unchanged.
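
The two pop behaviors can be modeled on nested lists. The functions below are toy models named after the ops, not the runtime's implementation; in particular, how the cube physically assembles the two halves is still an open question in this RFC, so the concatenation here is an assumption.

```python
UP_DOWN, LEFT_RIGHT = 1, 2   # split attr values quoted in this RFC

def model_tpop_from_aic(slot_tile, lane, split):
    """AIV side: this lane's half of the full tile sitting in the slot."""
    if split == UP_DOWN:
        half = len(slot_tile) // 2
        return slot_tile[lane * half:(lane + 1) * half]
    if split == LEFT_RIGHT:
        half = len(slot_tile[0]) // 2
        return [row[lane * half:(lane + 1) * half] for row in slot_tile]
    raise ValueError(f"unsupported split {split}")

def model_tpop_from_aiv(lane_tiles, split):
    """AIC side: the full tile rebuilt from both lanes' halves."""
    lane0, lane1 = lane_tiles
    if split == UP_DOWN:
        return lane0 + lane1
    if split == LEFT_RIGHT:
        return [a + b for a, b in zip(lane0, lane1)]
    raise ValueError(f"unsupported split {split}")

full = [[r * 4 + c for c in range(4)] for r in range(4)]
for split in (UP_DOWN, LEFT_RIGHT):
    halves = [model_tpop_from_aic(full, lane, split) for lane in (0, 1)]
    assert model_tpop_from_aiv(halves, split) == full   # shard then gather is identity

print(model_tpop_from_aic(full, 1, UP_DOWN))      # [[8, 9, 10, 11], [12, 13, 14, 15]]
print(model_tpop_from_aic(full, 0, LEFT_RIGHT))   # [[0, 1], [4, 5], [8, 9], [12, 13]]
```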

Compiler-side changes, by pass

  • Frontend (language/dsl_api.py, parser/ast_parser.py): add split_aiv to _VALID_ITERATORS; _parse_split_aiv_for_loop (mirror _parse_spmd_for_loop) opens an InCore scope with attr {"split_aiv": True} and binds aiv_id = tile.get_subblock_idx(). Reject n != 2, loop kwargs (init_values/chunk/stage/deps), and a scope that also carries optimizations=[pl.split(...)].
  • New IR ops tile.aiv_shard / tile.aic_gather (src/ir/op/tile_ops/): one tile arg + split int attr; type inference halves / doubles the mode axis (CHECK_SPAN even extent).
  • ExpandMixedKernel (pass 21): lower aiv_shard/aic_gather to their tpush/tpop pairs with split = int(mode) from the op (today MakeSplitKwargs() stamps the split=0 placeholder; now it carries the declared mode).
  • SplitVectorKernel (pass 23): for split_aiv functions, skip the per-op halving entirely (tiles already half via the shard type). Retain only dual_aiv_dispatch stamping.
  • New AivSplitVerifier (registered in PropertyVerifierRegistry): the validation below.

Validation (pass-time, span'd)

| Check | Rule |
| --- | --- |
| even extent | aiv_shard: split-axis even; aic_gather: result extent even |
| full-operand pairing | a cube op may only consume full tiles; a half must pass through aic_gather first (this is the check that catches the matmul(half, …) mistake) |
| mode consistency | a half from aiv_shard(_, a) later aic_gather'd must use axis a (def-use link) |
| get_subblock_idx scope | only valid inside a split_aiv function (else it silently returns 0 and both lanes write the same rows) |
| GM offset ↔ mode axis | a Split(a) tile's GM store must address axis a by aiv_id |
| slot-size | GetCommonSlotSizeBytes succeeds across all transfers |
| mode policy | conservative single-mode gate (see Open Questions) — one deletable CHECK |
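
Two of these checks (full-operand pairing and mode consistency) can be sketched over a tiny op list. This is a simplified illustration, not the proposed AivSplitVerifier: the `Op` record is hypothetical, and half-ness is tracked by value name here only because the toy has no tile types to inspect.

```python
from dataclasses import dataclass

@dataclass
class Op:
    name: str           # "aiv_shard", "aic_gather", "matmul", or a vector op
    result: str
    operands: tuple
    mode: int = 0       # 1 = UP_DOWN, 2 = LEFT_RIGHT, 0 = not a transfer

def verify(ops):
    half_mode = {}      # value name -> mode it was sharded with
    errors = []
    for op in ops:
        if op.name == "aiv_shard":
            half_mode[op.result] = op.mode
        elif op.name == "aic_gather":
            src = op.operands[0]
            if src not in half_mode:
                errors.append(f"{op.result}: aic_gather of a full tile {src}")
            elif half_mode[src] != op.mode:
                errors.append(f"{op.result}: gather mode differs from shard mode")
        elif op.name == "matmul":
            for v in op.operands:
                if v in half_mode:
                    errors.append(f"{op.result}: cube op consumes half tile {v}")
        else:
            # Vector ops keep their operand's half-ness (def-use link).
            for v in op.operands:
                if v in half_mode:
                    half_mode[op.result] = half_mode[v]
                    break
    return errors

good = [
    Op("matmul", "qk_raw", ("q", "k")),
    Op("aiv_shard", "qk_h", ("qk_raw",), mode=1),
    Op("exp", "qk_exp", ("qk_h",)),
    Op("aic_gather", "qk_full", ("qk_exp",), mode=1),
    Op("matmul", "qk_oi_f", ("qk_full", "v")),
]
bad = good[:3] + [Op("matmul", "qk_oi_f", ("qk_exp", "v"))]

print(verify(good))   # []
print(verify(bad))    # ['qk_oi_f: cube op consumes half tile qk_exp']
```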

Multi-mode finding (ISA fact, corrects an earlier assumption)

We verified that per-transfer split modes are mechanically supported by the runtime/PTOAS, not limited to one mode per kernel: the pipe pipe_key (FlagID|DirType|SlotNum|LocalSlotNum|SlotSize) does not encode the split mode; CreateInitializePipe fixes only slot size/direction/depth (GetCommonSlotSizeBytes = max); the split is realized per-tpop ("AIC pushes one full tile, both AIV0/AIV1 consume their halves from the same slot", runtime/.../tpush-tpop-sim.md); and RequiresDualAivDispatch is a boolean. The existing "conflicting cross-core split modes" error in SplitVectorKernel is therefore a software simplification tied to its one-split_dim-per-function rewrite, not a hardware law. The explicit model removes that coupling, so mixed modes become structurally possible (subject to a PTOAS confirmation probe — see Open Questions).
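
The argument can be restated as a small model: if the pipe key is built only from the listed fields and the slot size is the maximum over transfers, then changing one transfer's split mode leaves the key unchanged. `PipeKey`, `Transfer`, and `pipe_key_for` are hypothetical, and the default field values are placeholders; only the field list and the max rule come from the text above.

```python
from typing import NamedTuple

class PipeKey(NamedTuple):
    flag_id: int
    dir_type: str
    slot_num: int
    local_slot_num: int
    slot_size: int          # note: no split-mode field

class Transfer(NamedTuple):
    tile_bytes: int
    split: int              # 1 = UP_DOWN, 2 = LEFT_RIGHT

def common_slot_size(transfers):
    """Slot size shared by all transfers on a pipe: the largest tile."""
    return max(t.tile_bytes for t in transfers)

def pipe_key_for(transfers, flag_id=0, dir_type="AIC_TO_AIV",
                 slot_num=2, local_slot_num=1):
    # Placeholder defaults; the split mode never enters the key.
    return PipeKey(flag_id, dir_type, slot_num, local_slot_num,
                   common_slot_size(transfers))

single_mode = [Transfer(32768, 1), Transfer(131072, 1)]
mixed_mode = [Transfer(32768, 1), Transfer(131072, 2)]

assert pipe_key_for(single_mode) == pipe_key_for(mixed_mode)
print(pipe_key_for(mixed_mode).slot_size)   # 131072
```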

Affected Files

  • python/pypto/language/dsl_api.py — Modify — split_aiv, aiv_shard, aic_gather
  • python/pypto/language/parser/ast_parser.py — Modify — _VALID_ITERATORS + _parse_split_aiv_for_loop + double-decl reject
  • python/pypto/language/__init__.py, op/__init__.py, pypto_core/*.pyi — Modify — exports + stubs
  • src/ir/op/tile_ops/aiv_shard.cpp (+ aic_gather) — New — tile.aiv_shard / tile.aic_gather ops + type inference
  • python/bindings/modules/*.cpp — Modify — bindings for the new ops
  • src/ir/transforms/expand_mixed_kernel_pass.cpp — Modify — lower shard/gather to tpush/tpop with split=N
  • src/ir/transforms/split_vector_kernel_pass.cpp — Modify — bypass per-op halving for split_aiv functions
  • src/ir/transforms/aiv_split_verifier.cpp — New — validation + registry entry
  • docs/{en,zh-cn}/dev/passes/ + DSL docs — New/Modify — pass doc + pl.split_aiv/aiv_shard/aic_gather

Testing Plan

tests/ut/ir/transforms/test_aiv_shard.py (before/after, ir.assert_structural_equal):

  1. The qk_pv case: aiv_shard(UP_DOWN) + H_TILE sub-slice store compiles through ptoas (the original regression).
  2. LEFT_RIGHT column shard → split=2, even-N validated.
  3. Type check: aiv_shard([M,K], UP_DOWN) → [M//2,K]; aic_gather([M//2,K], UP_DOWN) → [M,K].
  4. Full-operand pairing: matmul(half, …) without aic_gather raises a span'd error.
  5. Error cases: odd extent; reduce-on-split-axis; get_subblock_idx outside split_aiv; GM offset on the wrong axis; double-declaration with pl.split.
  6. Single-mode gate fires on two differing modes (until lifted).
  7. Codegen parity: post-lowering IR matches the legacy pl.split path for an equivalent whole-tile kernel.

Alternatives Considered

  • Fix the auto-pass (SplitVectorKernel) systematically — add a SplitState (FullAxis/SubRegion) model + coordinate-transform rewrite + keep split-axis loops rolled. Rejected as the primary path: it re-implements the index-space reasoning the pass keeps getting wrong; the unrolled strided store leaves the iteration structure unrecoverable (offsets 32,48 index a 32-row tile). Worth retaining as the legacy path for simple whole-tile kernels.
  • Generic slice as the transfer marker (qk_raw[aiv_id*half:…]) — rejected: a high-dim source slice can't be mapped to the 2D-tile row/col mode (flatten/layout/transpose intervene), so the compiler can't reliably check UP_DOWN vs LEFT_RIGHT.
  • mode= on the split_aiv loop — rejected: the mode is a property of each AIC↔AIV transfer, not of the SPMD loop; multiple transfers may differ.
  • Remove the loop entirely (infer dual-AIV from shard ops, use get_subblock_idx() directly) — rejected: the GM output store needs the lane index, so a bound aiv_id + a scoped per-lane region is cleaner than scattered get_subblock_idx() calls.

Rollout / Migration

  • pl.split_aiv (explicit) and pl.split(mode) (auto) coexist in parallel, converging on byte-identical post-detection IR (split Function attr + dual_aiv_dispatch, tpush/tpop carrying split=N), so codegen and PTOAS need zero changes.
  • Explicit is the recommended path for index-space-aware kernels; auto is retained as the simple on-ramp. Do not make auto lower to explicit (synthesizing correct aiv_id-indexed slices for arbitrary kernels is exactly the inference the auto path lacks).
  • No backward-incompatible change: existing pl.split users are untouched; new ops are additive.

Security Considerations

None. This is a compiler-internal DSL/IR change with no new external input surface, network, or secret handling. The added verifier strictly tightens validation (turns silent miscompiles into pass-time errors).

Open Questions

  1. Region boundary / cube roundtrip. How the cube assembles two lanes' halves for an intermediate matmul (via the aic_gather → tpop_from_aiv full path) needs a prototype to confirm against PTOAS on a small two-lane example before the pass is written.
  2. PTOAS mixed-mode probe. Hand-author a .pto with two tpops (split=1, split=2) and run ptoas to definitively decide whether the conservative single-mode gate can be deleted.
  3. Naming. aic_gather (scatter/gather symmetry) vs aic_merge / aic_assemble. Feedback welcome.
  4. Pure-AIV split (no cube transfer). Has no shard op to trigger dual-dispatch; out of scope here, or add a thin with pl.dual_aiv() marker only if a real case appears.
