Summary
The fused decode dispatch re-allocates (rtMalloc) and frees (rtFree) the
four per-step decode input tensors — hidden, seq_lens, block_table,
slot_mapping — on every decode step, even though their shapes are
constant for the entire generation. Request: have L3 own and pool these
device buffers once and decide per-step whether a re-copy is needed, so
each generated token no longer pays raw CANN rtMalloc/rtFree driver calls
(and redundant H2D copies).
Area: Executor or runtime
Motivation / Use Case
Measured on Qwen3-14B decode (a2a3, batch-16, 256-token context), per-token
TPOT ≈ 52 ms. Layer breakdown (chip-timing style: host_wall / runner_run
/ device_wall from run_prepared, plus the per-bind sub-split):
| layer | ms |
| --- | --- |
| ① kernel run time (device makespan) | ~32.0 |
| ② device init/finalize | ~1.2 |
| ③ host↔device handshake | ~1.0 |
| ④ attach + bind + validate | ~10.1 |
| └ bind args_malloc_copy | ~2.4 |
| └ bind prebuilt_arena | ~5.7 |
The args_malloc_copy cost is not the data transfer. Each of the four
input tensors goes through device_malloc → MemoryAllocator::alloc → rtMalloc
(the onboard allocator wraps raw CANN rtMalloc/rtFree with no reuse pool;
it only tracks live pointers in ptr_set_), and is then released through
device_free → rtFree in the per-run copy-back/cleanup loop. So every decode
token pays 4× rtMalloc + 4× rtFree raw driver calls, while the actual H2D
is tiny (hidden at 160 KB plus the other three inputs ≈ tens of µs). The
shapes never change during decode, so this alloc/free churn is pure
per-token overhead that should be a one-time cost.
Proposed API / Behavior
Have L3 manage the lifetime and the copy policy of the four decode inputs:
- Pool the device buffers: allocate once per callable (keyed by
callable id + arg slot + size), reuse across all decode steps, free at
finalize — no per-step rtMalloc/rtFree.
- Let L3 decide whether to re-copy (H2D) each step, instead of always
copying, based on who actually produces the value:
  - hidden: changes every step (embedding of the freshly sampled token)
    → re-copy.
  - seq_lens: is just prompt_len + step (a per-row counter); L3/device
    can maintain it → ideally no H2D.
  - slot_mapping: derivable from seq_len + the page table
    (block_table); L3/device can compute it → ideally no H2D.
  - block_table: changes only when a sequence crosses a page boundary
    (~every page_size = 128 tokens; signal = the KV manager allocated a new
    page) → re-copy only at boundaries.
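The "derivable" claim for seq_lens and slot_mapping can be sketched as follows. This is a minimal illustration, not the runtime's code: the function names are hypothetical, and it assumes the usual paged-KV layout in which slot = physical_page * page_size + offset and the token being decoded sits at position seq_len - 1. If the serving layer uses a different convention, the arithmetic shifts accordingly.

```python
PAGE_SIZE = 128  # page size quoted in the list above


def derive_decode_inputs(prompt_lens, step, block_table, page_size=PAGE_SIZE):
    """Return (seq_lens, slot_mapping) for one decode step.

    Hypothetical helper. Assumes seq_lens[row] = prompt_len + step and that
    the token written this step lives at position seq_len - 1.
    """
    seq_lens = [p + step for p in prompt_lens]
    slot_mapping = []
    for row, seq_len in enumerate(seq_lens):
        pos = seq_len - 1
        # Logical page index -> physical page id via the row's page table.
        physical_page = block_table[row][pos // page_size]
        slot_mapping.append(physical_page * page_size + pos % page_size)
    return seq_lens, slot_mapping


def starts_new_page(seq_len, page_size=PAGE_SIZE):
    """True when this step's token is the first one written to a page,
    i.e. the step at which block_table must already hold a new entry."""
    return (seq_len - 1) % page_size == 0
```

Under these assumptions nothing in seq_lens or slot_mapping depends on host-only state once prompt_lens and block_table are on the device, which is why neither needs a per-step H2D.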
Net effect: per token goes from "8 driver calls + 4 H2D" to "(at most) the
hidden H2D + an occasional block_table update", with pooled device buffers.
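A rough host-side model of the proposed behavior is sketched below. Everything in it is an assumption made for illustration: FakeDriver stands in for the device runtime and only counts calls, DecodeBufferPool and needs_copy are hypothetical names, the non-hidden buffer sizes are placeholders, and seeding seq_lens and slot_mapping once at step 0 is one possible policy rather than the requested design.

```python
class FakeDriver:
    """Hypothetical stand-in for the device runtime; it only counts calls."""

    def __init__(self):
        self.mallocs = 0
        self.frees = 0
        self.h2d_copies = 0
        self._next_ptr = 0x1000

    def malloc(self, size):
        self.mallocs += 1
        ptr = self._next_ptr
        self._next_ptr += size
        return ptr

    def free(self, ptr):
        self.frees += 1

    def copy_to_device(self, ptr, nbytes):
        self.h2d_copies += 1


class DecodeBufferPool:
    """Allocate each (callable id, arg slot, size) buffer once, free at finalize."""

    def __init__(self, driver):
        self.driver = driver
        self._buffers = {}

    def get(self, callable_id, slot, size):
        key = (callable_id, slot, size)
        if key not in self._buffers:
            self._buffers[key] = self.driver.malloc(size)
        return self._buffers[key]

    def finalize(self):
        for ptr in self._buffers.values():
            self.driver.free(ptr)
        self._buffers.clear()


# hidden is 160 KB as measured above; the other sizes are placeholders.
SIZES = {"hidden": 160 * 1024, "seq_lens": 64, "block_table": 1024,
         "slot_mapping": 64}


def needs_copy(name, step, page_allocated):
    """Per-step copy policy (assumed): who produces the value decides."""
    if name == "hidden":
        return True                       # new embedding every step
    if name == "block_table":
        return step == 0 or page_allocated
    return step == 0                      # seq_lens / slot_mapping: seed once


def run_decode(driver, num_steps, page_alloc_steps=frozenset()):
    pool = DecodeBufferPool(driver)
    for step in range(num_steps):
        for slot, (name, size) in enumerate(SIZES.items()):
            ptr = pool.get(0, slot, size)
            if needs_copy(name, step, step in page_alloc_steps):
                driver.copy_to_device(ptr, size)
    pool.finalize()
```

In this model a run of N steps issues 4 allocations and 4 frees in total instead of 4N of each, and the copy count is dominated by the one hidden copy per step.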
Measurement basis (how the numbers above were obtained)
To avoid a misread: the per-token timing lives at L2, not L3.
- The L3 DistributedWorker.run forks a chip-child that runs run_prepared
once per decode step. That L2 run_prepared already computes both
host_wall and device_wall (the latter is nonzero ~33 ms — the fused
decode's on-NPU orchestrator wall), and the on-device orch/sched markers
(Thread N: orch_start/orch_end ... sched_*, emitted at LOG_INFO_V9 under
the default PTO2_PROFILING, parsed by
simpler_setup.tools.device_log_timing) give the kernel makespan.
- The L3 parent's Worker.run returns RunTiming(python_wall, 0). That
device_wall=0 is expected — L3 is a host dispatcher with no device wall
of its own; it is not "missing data". The real numbers are at L2; the
parent simply drops the child's RunTiming.
- Surfacing the host-side split needs one LOG_INFO_V9 line in run_prepared
(c_api_shared.cpp) printing host_wall / runner_run / device_wall, at the
same V9 tier the orch/sched markers already use.
host_wall and device_wall are already computed; runner_run (a
steady_clock around runner->run()) is the only net-new measurement.
This pooling request and that one-line instrumentation are independent; the
instrumentation just makes the win measurable.
Alternatives Considered
- Pool only at the runtime allocator level (make MemoryAllocator reuse
same-size frees): removes the rtMalloc/rtFree churn but still re-copies
everything and does not let L3 skip redundant copies.
- Move the bookkeeping fully on-device (device-side seq_len counter +
page-table lookup for slot_mapping, and on-device sampling + embedding so
hidden never leaves the device): the ideal end-state — it eliminates the
per-token host roundtrip entirely — but a larger architectural change. The
L3-managed pooling requested here is the incremental first step.
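For comparison, the allocator-level alternative can be sketched as a same-size free list in front of the raw driver calls. This is a hypothetical model, not the existing MemoryAllocator: CachingAllocator and its injected raw_malloc / raw_free callables are names invented for the example.

```python
from collections import defaultdict


class CachingAllocator:
    """Keep freed blocks on a per-size free list and hand them back on the
    next same-size alloc; only cache misses reach the raw driver."""

    def __init__(self, raw_malloc, raw_free):
        self._raw_malloc = raw_malloc
        self._raw_free = raw_free
        self._free_by_size = defaultdict(list)  # size -> cached pointers
        self._live = {}                         # ptr -> size

    def alloc(self, size):
        cached = self._free_by_size[size]
        ptr = cached.pop() if cached else self._raw_malloc(size)
        self._live[ptr] = size
        return ptr

    def free(self, ptr):
        # Return the block to the cache instead of the driver.
        size = self._live.pop(ptr)
        self._free_by_size[size].append(ptr)

    def finalize(self):
        # Release everything still cached back to the driver.
        for ptrs in self._free_by_size.values():
            for ptr in ptrs:
                self._raw_free(ptr)
        self._free_by_size.clear()
```

This removes the per-step driver calls without any change to callers, but as noted above the caller still issues all four copies every step, because the allocator has no knowledge of which inputs actually changed.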
Additional Context
Relevant code paths:
- Host build of the inputs (pypto-serving): in npu_runner.py,
_prepare_decode_inputs builds fresh temps and _pad_decode_inputs copies them
into the persistent compiled.decode_*_buffer (already .share_memory_()'d
once, so parent↔chip-child is already zero-copy).
- Per-step device alloc + H2D (simpler runtime): in runtime_maker.cpp,
bind_callable_to_runtime_impl (per-tensor device_malloc + copy_to_device),
with tensor_pairs_ device_free'd in the copy-back/cleanup loop every run.
- Allocator (simpler): memory_allocator.h/.cpp; onboard = raw
rtMalloc/rtFree, ptr_set_ only (no reuse pool).
Because the persistent host buffers are already shared, the only remaining
per-token waste is on the device side: the rtMalloc/rtFree churn and the
redundant copies. This is a serving/runtime cross-cut — the ask is for L3
(the distributed runtime layer) to own these device buffers and the copy
decision across decode steps.