[Feature] Let L3 own/pool the 4 per-step decode input tensors (skip per-token rtMalloc/rtFree) #44

Description

@ChaoWao

Summary

The fused decode dispatch re-allocates (rtMalloc) and frees (rtFree) the
four per-step decode input tensors — hidden, seq_lens, block_table,
slot_mapping — on every decode step, even though their shapes are
constant for the entire generation. Request: have L3 own and pool these
device buffers once and decide per-step whether a re-copy is needed, so
each generated token no longer pays raw CANN rtMalloc/rtFree driver calls
(and redundant H2D copies).

Area: Executor or runtime

Motivation / Use Case

Measured on Qwen3-14B decode (a2a3, batch-16, 256-token context): TPOT
≈ 52 ms. Layer breakdown (chip-timing style: host_wall / runner_run /
device_wall from run_prepared, plus the per-bind sub-split):

layer                                      ms
① kernel run time (device makespan)       ~32.0
② device init/finalize                     ~1.2
③ host↔device handshake                    ~1.0
④ attach + bind + validate                ~10.1
  └ bind args_malloc_copy                  ~2.4
  └ bind prebuilt_arena                    ~5.7

The args_malloc_copy cost is not the data transfer. Each of the four
input tensors goes through device_malloc → MemoryAllocator::alloc → rtMalloc
(the onboard allocator wraps raw CANN rtMalloc/rtFree with no reuse pool;
it only tracks live pointers in a ptr_set_), and is then released via
device_free → rtFree in the per-run copy-back/cleanup loop. So every decode
token pays 4× rtMalloc + 4× rtFree raw driver calls; the actual H2D is tiny
(hidden 160 KB + the rest ≈ tens of µs). The shapes never change during
decode, so this alloc/free churn is pure per-token overhead that should be
a one-time cost.

Proposed API / Behavior

Have L3 manage the lifetime and the copy policy of the four decode inputs:

  1. Pool the device buffers: allocate once per callable (keyed by
    callable id + arg slot + size), reuse across all decode steps, free at
    finalize — no per-step rtMalloc/rtFree.
  2. Let L3 decide whether to re-copy (H2D) each step, instead of always
    copying, based on who actually produces the value:
    • hidden — changes every step (embedding of the freshly sampled token)
      → re-copy.
    • seq_lens — is just prompt_len + step (a per-row counter); L3/device
      can maintain it → ideally no H2D.
    • slot_mapping — derivable from seq_lens + the page table
      (block_table); L3/device can compute it → ideally no H2D.
    • block_table — changes only when a sequence crosses a page boundary
      (~every page_size = 128 tokens; signal = the KV manager allocated a new
      page) → re-copy only at boundaries.
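
The per-input copy policy in item 2 can be sketched in a few lines of Python. This is an illustration under stated assumptions, not the runtime's code: `DecodeRow`, `write_position`, `slot_for_position`, and `plan_step` are hypothetical names; positions are assumed to be 0-based, with the new token written at `prompt_len + step`; and the slot formula (`physical_page * page_size + offset`) is the usual paged-KV layout, which the issue does not spell out.

```python
from dataclasses import dataclass

PAGE_SIZE = 128  # page size quoted in the issue


@dataclass
class DecodeRow:
    """Hypothetical per-sequence state that L3 (or the device) would keep."""
    prompt_len: int
    block_table: list[int]  # physical page ids, in logical page order


def write_position(row: DecodeRow, step: int) -> int:
    # Assumption: the prompt occupies positions 0..prompt_len-1, so the token
    # generated at decode step `step` (0-based) lands at prompt_len + step.
    return row.prompt_len + step


def slot_for_position(row: DecodeRow, pos: int, page_size: int = PAGE_SIZE) -> int:
    # Assumed paged-KV layout: look up the physical page for the logical
    # page, then add the offset inside that page.
    physical_page = row.block_table[pos // page_size]
    return physical_page * page_size + pos % page_size


def plan_step(row: DecodeRow, step: int, page_size: int = PAGE_SIZE) -> dict[str, bool]:
    """Return which of the four inputs need an H2D copy at this step."""
    pos = write_position(row, step)
    return {
        "hidden": True,         # new embedding every step
        "seq_lens": False,      # maintained as a counter
        "slot_mapping": False,  # derived from position + block_table
        # The first slot of a page means a new page was just allocated.
        "block_table": pos % page_size == 0,
    }
```

With a 256-token prompt and 128-token pages, `plan_step` asks for a block_table copy at step 0 (position 256 opens a new page) and again at step 128, and for nothing but hidden in between.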

Net effect: the per-token cost goes from "8 driver calls + 4 H2D" to "(at
most) the hidden H2D + an occasional block_table update", with pooled device
buffers.
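
To make the driver-call arithmetic concrete, here is a minimal sketch of a pool keyed by (callable id, arg slot, size), run against a fake driver that only counts calls. Everything here is hypothetical: `CountingDriver` stands in for rtMalloc/rtFree, `DecodeBufferPool` is not an existing class, and only the 160 KB hidden size comes from the measurements above; the other three sizes are placeholders.

```python
# Byte sizes per argument slot. Only the hidden size (160 KB) is taken from
# the issue; the other three are placeholder values for illustration.
ARG_SIZES = {
    "hidden": 160 * 1024,
    "seq_lens": 64,
    "block_table": 1024,
    "slot_mapping": 64,
}


class CountingDriver:
    """Stand-in for the device driver: counts calls, touches no hardware."""

    def __init__(self):
        self.mallocs = 0
        self.frees = 0
        self._next_ptr = 0x1000

    def malloc(self, size: int) -> int:
        self.mallocs += 1
        ptr = self._next_ptr
        self._next_ptr += size
        return ptr

    def free(self, ptr: int) -> None:
        self.frees += 1


class DecodeBufferPool:
    """Allocate each (callable id, arg slot, size) once; free at finalize."""

    def __init__(self, driver: CountingDriver):
        self._driver = driver
        self._buffers: dict[tuple[int, int, int], int] = {}

    def get(self, callable_id: int, arg_slot: int, size: int) -> int:
        key = (callable_id, arg_slot, size)
        if key not in self._buffers:
            self._buffers[key] = self._driver.malloc(size)
        return self._buffers[key]

    def finalize(self) -> None:
        for ptr in self._buffers.values():
            self._driver.free(ptr)
        self._buffers.clear()


def run_unpooled(driver: CountingDriver, steps: int) -> None:
    # Current behavior as described above: 4 mallocs + 4 frees per step.
    for _ in range(steps):
        ptrs = [driver.malloc(size) for size in ARG_SIZES.values()]
        for ptr in ptrs:
            driver.free(ptr)


def run_pooled(driver: CountingDriver, steps: int, callable_id: int = 0) -> None:
    # Requested behavior: 4 mallocs total, 4 frees at finalize.
    pool = DecodeBufferPool(driver)
    for _ in range(steps):
        for slot, size in enumerate(ARG_SIZES.values()):
            pool.get(callable_id, slot, size)
    pool.finalize()
```

Over 100 decode steps the unpooled loop issues 400 mallocs and 400 frees, while the pooled loop issues 4 of each regardless of step count. Including the size in the key means a shape change would get a fresh buffer rather than silently reusing one that is too small.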

Measurement basis (how the numbers above were obtained)

To avoid a misread: the per-token timing lives at L2, not L3.

  • The L3 DistributedWorker.run forks a chip-child that runs run_prepared
    once per decode step. That L2 run_prepared already computes both
    host_wall and device_wall (the latter is nonzero ~33 ms — the fused
    decode's on-NPU orchestrator wall), and the on-device orch/sched markers
    (Thread N: orch_start/orch_end ... sched_*, emitted at LOG_INFO_V9 under
    the default PTO2_PROFILING, parsed by
    simpler_setup.tools.device_log_timing) give the kernel makespan.
  • The L3 parent's Worker.run returns RunTiming(python_wall, 0). That
    device_wall=0 is expected — L3 is a host dispatcher with no device wall
    of its own; it is not "missing data". The real numbers are at L2; the
    parent simply drops the child's RunTiming.
  • Surfacing the host-side split needs one LOG_INFO_V9 line in run_prepared
    (c_api_shared.cpp) printing host_wall / runner_run / device_wall, at the
    same V9 tier the orch/sched markers already use. host_wall and
    device_wall are already computed; runner_run (a steady_clock around
    runner->run()) is the only net-new measurement.

This pooling request and that one-line instrumentation are independent; the
instrumentation just makes the win measurable.
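
The three-way split can be pictured as nested timers. The sketch below is a Python analog for illustration only; the actual change would be a C++ log line in run_prepared. `StepTiming`, `timed_step`, and `format_line` are hypothetical names, and the `bind`/`run` callables stand in for the attach/bind phase and runner->run().

```python
import time
from dataclasses import dataclass


@dataclass
class StepTiming:
    host_wall: float    # whole step, seconds
    runner_run: float   # time inside the run call only, seconds
    device_wall: float  # reported by the device, seconds

    @property
    def host_overhead(self) -> float:
        # Host time spent outside the run call (bind, validate, cleanup...).
        return self.host_wall - self.runner_run


def timed_step(bind, run, clock=time.perf_counter) -> StepTiming:
    """Time one step; `run` is assumed to return the device-side wall time."""
    t_start = clock()
    bind()
    t_run = clock()
    device_wall = run()
    t_end = clock()
    return StepTiming(
        host_wall=t_end - t_start,
        runner_run=t_end - t_run,
        device_wall=device_wall,
    )


def format_line(t: StepTiming) -> str:
    # One line per step, in milliseconds.
    return (
        f"host_wall={t.host_wall * 1e3:.1f}ms "
        f"runner_run={t.runner_run * 1e3:.1f}ms "
        f"device_wall={t.device_wall * 1e3:.1f}ms"
    )
```

The gap between host_wall and runner_run is where the per-step bind cost shows up, so it is the number that should shrink once the buffers are pooled.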

Alternatives Considered

  • Pool only at the runtime allocator level (make MemoryAllocator reuse
    same-size frees): removes the rtMalloc/rtFree churn but still re-copies
    everything and does not let L3 skip redundant copies.
  • Move the bookkeeping fully on-device (device-side seq_len counter +
    page-table lookup for slot_mapping, and on-device sampling + embedding so
    hidden never leaves the device): the ideal end-state — it eliminates the
    per-token host roundtrip entirely — but a larger architectural change. The
    L3-managed pooling requested here is the incremental first step.
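
For comparison, the first alternative (reuse inside the allocator) might look like the size-bucketed free list below. This is a sketch, not a proposal for the real MemoryAllocator: `ReusingAllocator` and its methods are hypothetical, and `raw_malloc`/`raw_free` are injected callables standing in for rtMalloc/rtFree.

```python
from collections import defaultdict


class ReusingAllocator:
    """Keep freed blocks in per-size buckets and hand them back on alloc."""

    def __init__(self, raw_malloc, raw_free):
        self._raw_malloc = raw_malloc
        self._raw_free = raw_free
        self._live: dict[int, int] = {}  # ptr -> size, for live blocks
        self._free_by_size: dict[int, list[int]] = defaultdict(list)

    def alloc(self, size: int) -> int:
        bucket = self._free_by_size[size]
        # Reuse a cached block of the exact size if one exists.
        ptr = bucket.pop() if bucket else self._raw_malloc(size)
        self._live[ptr] = size
        return ptr

    def free(self, ptr: int) -> None:
        # Cache the block instead of returning it to the driver.
        size = self._live.pop(ptr)
        self._free_by_size[size].append(ptr)

    def release_all(self) -> None:
        # Return every cached block to the driver (e.g. at finalize).
        for bucket in self._free_by_size.values():
            for ptr in bucket:
                self._raw_free(ptr)
        self._free_by_size.clear()
```

This removes the driver calls without the caller changing anything, but the caller still sees a fresh alloc each step and so has no basis for skipping a copy, which is the limitation noted in the first bullet.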

Additional Context

Relevant code paths:

  • Host build of the inputs (pypto-serving): npu_runner.py
    _prepare_decode_inputs builds fresh temps, _pad_decode_inputs copies them
    into the persistent compiled.decode_*_buffer (already .share_memory_()'d
    once, so parent↔chip-child is already zero-copy).
  • Per-step device alloc + H2D (simpler runtime): runtime_maker.cpp
    bind_callable_to_runtime_impl (per-tensor device_malloc + copy_to_device),
    with tensor_pairs_ device_free'd in the copy-back/cleanup loop every run.
  • Allocator (simpler): memory_allocator.h/.cpp onboard = raw
    rtMalloc/rtFree, ptr_set_ only (no reuse pool).

Because the persistent host buffers are already shared, the only remaining
per-token waste is on the device side: the rtMalloc/rtFree churn and the
redundant copies. This is a serving/runtime cross-cut — the ask is for L3
(the distributed runtime layer) to own these device buffers and the copy
decision across decode steps.
