Sharded Broadcast with Dual-Stream Pipeline for Reduced Peak NPU Memory#257
Sharded Broadcast with Dual-Stream Pipeline for Reduced Peak NPU Memory#257larksudo wants to merge 6 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a sharded, pipelined broadcast and load mechanism (_pipelined_sharded_broadcast_and_load) in the Ascend cache engine to reduce peak NPU memory usage by interleaving broadcast and GPU loading operations on separate streams. The review feedback identifies several critical issues: a race condition where CPU memory backing reordered_chunks is prematurely freed while the NPU is still reading from it; a runtime AttributeError due to a missing get_size() method on MemoryObjMetadata; a potential memory leak if the metadata broadcast returns an empty plan; and a crash/hang risk in the CPU fallback mechanism because HCCL collectives do not support CPU tensors.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def _pipeline_sender( | ||
| self, | ||
| plan: Dict[str, Any], | ||
| reordered_chunks: list, | ||
| load_stream: "torch.npu.Stream", | ||
| **kwargs, | ||
| ) -> None: | ||
| """Rank 0 side of the sharded broadcast pipeline (Phase 3). | ||
|
|
||
| Iterates ``plan["shard_plan"]`` and for each shard submits: | ||
|
|
||
| * On ``broadcast_stream``: pinned CPU→NPU H2D via | ||
| :meth:`to` with ``non_blocking=True``, then ``broadcast_fn`` (HCCL). One | ||
| ``Event`` is recorded after the shard's last broadcast. | ||
| * On ``load_stream``: ``gpu_connector.to_gpu`` per chunk, gated | ||
| on the broadcast ``Event`` via ``wait_event``. | ||
|
|
||
| The loop body contains no collective barriers — the only | ||
| barrier happened in :meth:`_broadcast_metadata_table`. CPU | ||
| returns to the next shard immediately after submitting to_gpu. | ||
| """ | ||
| meta_table: List[Tuple[int, int, Dict[str, Any]]] = plan["meta"] | ||
| shard_plan: List[Tuple[int, int]] = plan["shard_plan"] | ||
|
|
||
| pending: List[Tuple[List[MemoryObj], torch.npu.Event]] = [] | ||
|
|
||
| try: | ||
| for shard_idx, (offset, count) in enumerate(shard_plan): | ||
| t_bc_start = time.perf_counter() | ||
|
|
||
| sub_objs: List[TensorMemoryObj] = [] | ||
| sub_starts: List[int] = [] | ||
| sub_ends: List[int] = [] | ||
| t_h2d_total = 0.0 | ||
|
|
||
| with torch.npu.stream(self.broadcast_stream): | ||
| for i in range(offset, offset + count): | ||
| _, mem_obj, _, _ = reordered_chunks[i] | ||
| start, end_pos, _ = meta_table[i] | ||
|
|
||
| raw = mem_obj.raw_tensor | ||
| if raw is None: | ||
| raise ValueError( | ||
| "rank=0 _pipeline_sender: chunk [%d:%d] " | ||
| "raw_tensor is None " | ||
| "(is_valid=%s, ref_count=%d)." % | ||
| ( | ||
| start, | ||
| end_pos, | ||
| mem_obj.is_valid(), | ||
| mem_obj.get_ref_count(), | ||
| ) | ||
| ) | ||
| t_h2d = time.perf_counter() | ||
| gpu_tensor = raw.to(f"npu:{self.metadata.worker_id}", non_blocking=True) | ||
| t_h2d_total += time.perf_counter() - t_h2d | ||
| self.broadcast_fn(gpu_tensor, self.metadata.first_rank) | ||
|
|
||
| meta = mem_obj.metadata | ||
| meta_copy = MemoryObjMetadata( | ||
| shape=meta.shape, | ||
| dtype=meta.dtype, | ||
| address=meta.address, | ||
| phy_size=meta.phy_size, | ||
| ref_count=1, | ||
| fmt=meta.fmt, | ||
| shapes=meta.shapes, | ||
| dtypes=meta.dtypes, | ||
| ) | ||
| sub_objs.append( | ||
| TensorMemoryObj( | ||
| raw_data=gpu_tensor, | ||
| metadata=meta_copy, | ||
| parent_allocator=None, | ||
| ) | ||
| ) | ||
| sub_starts.append(start) | ||
| sub_ends.append(end_pos) | ||
|
|
||
| ev_bc = torch.npu.Event() | ||
| ev_bc.record() | ||
|
|
||
| t_bc_end = time.perf_counter() | ||
|
|
||
| load_stream.wait_event(ev_bc) | ||
| with torch.npu.stream(load_stream): | ||
| for obj, start, end_pos in zip( | ||
| sub_objs, sub_starts, sub_ends | ||
| ): | ||
| self.gpu_connector.to_gpu( | ||
| obj, start, end_pos, **kwargs | ||
| ) | ||
| ev_togpu = torch.npu.Event() | ||
| ev_togpu.record() | ||
|
|
||
| t_togpu_submit = time.perf_counter() | ||
|
|
||
| logger.debug( | ||
| "rank=0 shard[%d] cnt=%d bc=%.2fms h2d=%.2fms enqueue=%.2fms", | ||
| shard_idx, | ||
| count, | ||
| (t_bc_end - t_bc_start) * 1000, | ||
| t_h2d_total * 1000, | ||
| (t_togpu_submit - t_bc_end) * 1000, | ||
| ) | ||
|
|
||
| pending.append((sub_objs, ev_togpu)) | ||
| self._try_release_pending(pending) | ||
|
|
||
| finally: | ||
| for objs, ev in pending: | ||
| try: | ||
| ev.synchronize() | ||
| except Exception: | ||
| pass | ||
| for obj in objs: | ||
| try: | ||
| obj.ref_count_down() | ||
| except Exception: | ||
| pass | ||
| pending.clear() |
There was a problem hiding this comment.
There is a critical race condition regarding the lifetime of the CPU memory backing reordered_chunks. Since _pipeline_sender submits H2D copies asynchronously and returns immediately, calling memory_obj.ref_count_down() in retrieve right after the pipeline call will prematurely free the CPU memory while the NPU is still reading from it. To prevent data corruption, we should delegate the lifetime management of the original CPU MemoryObjs to the pipeline, adding them to pending so they are only released after the ev_togpu event completes. We also track last_processed_chunk_idx to prevent memory leaks if an exception occurs mid-loop. Note that we do not wrap the broadcast or transfer operations in stream contexts here, as stream context wrapping should be localized inside the specific transfer functions themselves.
def _pipeline_sender(
self,
plan: Dict[str, Any],
reordered_chunks: list,
load_stream: "torch.npu.Stream",
**kwargs,
) -> None:
meta_table: List[Tuple[int, int, Dict[str, Any]]] = plan["meta"]
shard_plan: List[Tuple[int, int]] = plan["shard_plan"]
pending: List[Tuple[List[MemoryObj], torch.npu.Event]] = []
last_processed_chunk_idx = 0
try:
for shard_idx, (offset, count) in enumerate(shard_plan):
t_bc_start = time.perf_counter()
sub_objs: List[TensorMemoryObj] = []
orig_objs: List[MemoryObj] = []
sub_starts: List[int] = []
sub_ends: List[int] = []
t_h2d_total = 0.0
for i in range(offset, offset + count):
_, mem_obj, _, _ = reordered_chunks[i]
orig_objs.append(mem_obj)
start, end_pos, _ = meta_table[i]
raw = mem_obj.raw_tensor
if raw is None:
raise ValueError(
"rank=0 _pipeline_sender: chunk [%d:%d] "
"raw_tensor is None "
"(is_valid=%s, ref_count=%d)." %
(
start,
end_pos,
mem_obj.is_valid(),
mem_obj.get_ref_count(),
)
)
t_h2d = time.perf_counter()
gpu_tensor = raw.to(f"npu:{self.metadata.worker_id}", non_blocking=True)
t_h2d_total += time.perf_counter() - t_h2d
self.broadcast_fn(gpu_tensor, self.metadata.first_rank)
meta = mem_obj.metadata
meta_copy = MemoryObjMetadata(
shape=meta.shape,
dtype=meta.dtype,
address=meta.address,
phy_size=meta.phy_size,
ref_count=1,
fmt=meta.fmt,
shapes=meta.shapes,
dtypes=meta.dtypes,
)
sub_objs.append(
TensorMemoryObj(
raw_data=gpu_tensor,
metadata=meta_copy,
parent_allocator=None,
)
)
sub_starts.append(start)
sub_ends.append(end_pos)
ev_bc = torch.npu.Event()
ev_bc.record()
t_bc_end = time.perf_counter()
load_stream.wait_event(ev_bc)
for obj, start, end_pos in zip(
sub_objs, sub_starts, sub_ends
):
self.gpu_connector.to_gpu(
obj, start, end_pos, **kwargs
)
ev_togpu = torch.npu.Event()
ev_togpu.record()
t_togpu_submit = time.perf_counter()
logger.debug(
"rank=0 shard[%d] cnt=%d bc=%.2fms h2d=%.2fms enqueue=%.2fms",
shard_idx,
count,
(t_bc_end - t_bc_start) * 1000,
t_h2d_total * 1000,
(t_togpu_submit - t_bc_end) * 1000,
)
pending.append((sub_objs + orig_objs, ev_togpu))
last_processed_chunk_idx = offset + count
self._try_release_pending(pending)
finally:
for objs, ev in pending:
try:
ev.synchronize()
except Exception:
pass
for obj in objs:
try:
obj.ref_count_down()
except Exception:
pass
pending.clear()
for i in range(last_processed_chunk_idx, len(reordered_chunks)):
try:
reordered_chunks[i][1].ref_count_down()
except Exception:
passReferences
- When fixing stream synchronization or stream-wrapping issues for specific hardware/device transfer functions, localize the stream context (e.g., wrapping operations in
with torch.npu.stream(...)) inside the specific function itself rather than wrapping it at the caller level.
| for key, memory_obj, _, _ in reordered_chunks: | ||
| if self.remove_after_retrieve and not self._is_passive(): | ||
| if self.storage_manager is None: | ||
| raise ValueError("storage_manager is required for remove") | ||
| self.storage_manager.remove(key, self.retrieve_locations) | ||
| if self._is_sync_pd_backend(): | ||
| memory_obj.ref_count_down() | ||
| elif not self.async_loading: | ||
| memory_obj.ref_count_down() |
There was a problem hiding this comment.
Skip calling ref_count_down() immediately in retrieve when self.save_only_first_rank is True. This delegates the lifetime management of the CPU memory objects to the sharded broadcast pipeline, preventing a race condition where the CPU memory is freed while the NPU is still reading from it.
| for key, memory_obj, _, _ in reordered_chunks: | |
| if self.remove_after_retrieve and not self._is_passive(): | |
| if self.storage_manager is None: | |
| raise ValueError("storage_manager is required for remove") | |
| self.storage_manager.remove(key, self.retrieve_locations) | |
| if self._is_sync_pd_backend(): | |
| memory_obj.ref_count_down() | |
| elif not self.async_loading: | |
| memory_obj.ref_count_down() | |
| for key, memory_obj, _, _ in reordered_chunks: | |
| if self.remove_after_retrieve and not self._is_passive(): | |
| if self.storage_manager is None: | |
| raise ValueError("storage_manager is required for remove") | |
| self.storage_manager.remove(key, self.retrieve_locations) | |
| if self._is_sync_pd_backend() and not self.save_only_first_rank: | |
| memory_obj.ref_count_down() | |
| elif not self.async_loading and not self.save_only_first_rank: | |
| memory_obj.ref_count_down() |
| try: | ||
| return torch.empty( | ||
| torch.Size([metadata.get_size()]), | ||
| dtype=torch.uint8, | ||
| device=f"npu:{local_rank}", | ||
| ) | ||
| except RuntimeError as e: | ||
| logger.warning( | ||
| "rank=%d NPU OOM allocating broadcast buffer for chunk " | ||
| "[%d:%d] (size=%d bytes), falling back to CPU pinned memory. " | ||
| "This chunk will be treated as cache miss. Error: %s", | ||
| self.metadata.worker_id, | ||
| start, | ||
| end, | ||
| metadata.get_size(), | ||
| e, | ||
| ) | ||
| return torch.empty( | ||
| torch.Size([metadata.get_size()]), | ||
| dtype=torch.uint8, | ||
| device="cpu", | ||
| pin_memory=True, | ||
| ) |
There was a problem hiding this comment.
MemoryObjMetadata does not implement a get_size() method. Calling metadata.get_size() will raise an AttributeError at runtime on non-rank-0 workers. Use metadata.phy_size instead, which correctly represents the physical size of the memory object in bytes.
try:
return torch.empty(
torch.Size([metadata.phy_size]),
dtype=torch.uint8,
device=f"npu:{local_rank}",
)
except RuntimeError as e:
logger.warning(
"rank=%d NPU OOM allocating broadcast buffer for chunk "
"[%d:%d] (size=%d bytes), falling back to CPU pinned memory. "
"This chunk will be treated as cache miss. Error: %s",
self.metadata.worker_id,
start,
end,
metadata.phy_size,
e,
)
return torch.empty(
torch.Size([metadata.phy_size]),
dtype=torch.uint8,
device="cpu",
pin_memory=True,
)| plan = self._broadcast_metadata_table( | ||
| reordered_chunks, shard_size, first_rank | ||
| ) | ||
| if plan is None or plan.get("total", 0) == 0: | ||
| return |
There was a problem hiding this comment.
If the metadata table broadcast returns None or an empty plan, we must release the reordered_chunks immediately to prevent memory leaks, since they are no longer handed over to the pipeline.
plan = self._broadcast_metadata_table(
reordered_chunks, shard_size, first_rank
)
if plan is None or plan.get("total", 0) == 0:
if self.metadata.is_first_rank():
for _, mem_obj, _, _ in reordered_chunks:
mem_obj.ref_count_down()
return| raw = self._alloc_broadcast_buffer( | ||
| metadata, local_rank, self.metadata.first_rank, | ||
| start, end_pos, | ||
| ) | ||
| self.broadcast_fn(raw, self.metadata.first_rank) | ||
|
|
||
| if raw.device.type == "cpu": | ||
| logger.warning( | ||
| "rank=%d chunk [%d:%d] received on CPU due to " | ||
| "NPU OOM, skipping to_gpu (cache miss)", | ||
| self.metadata.worker_id, | ||
| start, | ||
| end_pos, | ||
| ) | ||
| continue |
There was a problem hiding this comment.
The CPU fallback mechanism in _alloc_broadcast_buffer returns a CPU tensor on NPU OOM. However, calling self.broadcast_fn (which is an HCCL collective) on a CPU tensor on line 509 will crash, as HCCL only supports NPU tensors. Furthermore, if any rank skips the collective broadcast, it will cause a desynchronization and hang the other ranks. It is better to let the RuntimeError propagate or handle the OOM as a fatal error rather than attempting a CPU fallback that crashes or hangs.
Summary
This PR introduces sharded broadcast for KV cache retrieval in multi-rank (TP) deployments. Instead of broadcasting all KV cache chunks in a single collective and then transferring them to GPU, the broadcast is split into smaller shards that are pipelined through two NPU streams —
broadcast_streamandload_stream— enabling overlap between communication and GPU-side data transfer.Motivation
The upstream implementation broadcasts all KV cache chunks at once, which causes two problems:
to_gpuexecute sequentially — the GPU is idle during broadcast, and the NPU is idle during data transfer.This PR addresses both by sharding the broadcast and pipelining it with
to_gpuon separate streams.Key Changes
cache_engine.py_pipelined_sharded_broadcast_and_load()as the core entry point for sharded broadcast._pipeline_sender()(rank 0) and_pipeline_receiver()(non-rank-0) handle the per-shard loop.broadcast_streamis created whensave_only_first_rank=Trueand aload_streamis available on the GPU connector.Dual-stream pipeline
Each shard records an
Eventonbroadcast_streamafter the broadcast completes. Theload_streamwaits on this event before executingto_gpu. The CPU dispatcher immediately moves to the next shard after submittingto_gpu, allowing the NPU to overlap the next shard's broadcast with the current shard'sto_gpu.Per-shard memory allocation
NPU receive buffers are allocated per shard rather than for the entire KV cache at once. A
_try_release_pending()mechanism reclaims each shard's buffers once itsto_gpuevent fires, keeping peak NPU memory bounded at approximately2 × shard_size × chunk_size.Rank 0 H2D optimization
CPU-to-NPU data transfer on rank 0 uses
to(non_blocking=True)with pinned memory, avoiding synchronous H2D copies that previously dominated broadcast latency.Configuration
Add the following to your LMCache config YAML:
Parameters
save_only_first_rankFalsebroadcast_shard_size0or a value larger than the total chunk count to fall back to the original single-broadcast behavior.16Trade-offs
broadcast_shard_sizeA value of
8or16is recommended for most workloads.