Skip to content

Sharded Broadcast with Dual-Stream Pipeline for Reduced Peak NPU Memory#257

Open
larksudo wants to merge 6 commits into
LMCache:mainfrom
larksudo:broadcast3
Open

Sharded Broadcast with Dual-Stream Pipeline for Reduced Peak NPU Memory#257
larksudo wants to merge 6 commits into
LMCache:mainfrom
larksudo:broadcast3

Conversation

@larksudo

Copy link
Copy Markdown
Collaborator

Summary

This PR introduces sharded broadcast for KV cache retrieval in multi-rank (TP) deployments. Instead of broadcasting all KV cache chunks in a single collective and then transferring them to GPU, the broadcast is split into smaller shards that are pipelined through two NPU streams — broadcast_stream and load_stream — enabling overlap between communication and GPU-side data transfer.

Motivation

The upstream implementation broadcasts all KV cache chunks at once, which causes two problems:

  1. Peak NPU memory pressure: all receive buffers for the entire KV cache are allocated simultaneously, driving up peak memory usage linearly with context length.
  2. No overlap: broadcast and to_gpu execute sequentially — the GPU is idle during broadcast, and the NPU is idle during data transfer.

This PR addresses both by sharding the broadcast and pipelining it with to_gpu on separate streams.

Key Changes

cache_engine.py

  • Added _pipelined_sharded_broadcast_and_load() as the core entry point for sharded broadcast.
  • _pipeline_sender() (rank 0) and _pipeline_receiver() (non-rank-0) handle the per-shard loop.
  • A dedicated broadcast_stream is created when save_only_first_rank=True and a load_stream is available on the GPU connector.

Dual-stream pipeline

Each shard records an Event on broadcast_stream after the broadcast completes. The load_stream waits on this event before executing to_gpu. The CPU dispatcher immediately moves to the next shard after submitting to_gpu, allowing the NPU to overlap the next shard's broadcast with the current shard's to_gpu.

Time →

broadcast_stream:  |--bc shard[0]--|--bc shard[1]--|--bc shard[2]--|...
load_stream:                        |--to_gpu shard[0]--|--to_gpu shard[1]--|...

Per-shard memory allocation

NPU receive buffers are allocated per shard rather than for the entire KV cache at once. A _try_release_pending() mechanism reclaims each shard's buffers once its to_gpu event fires, keeping peak NPU memory bounded at approximately 2 × shard_size × chunk_size.

Rank 0 H2D optimization

CPU-to-NPU data transfer on rank 0 uses to(non_blocking=True) with pinned memory, avoiding synchronous H2D copies that previously dominated broadcast latency.

Configuration

Add the following to your LMCache config YAML:

chunk_size: 256
local_cpu: True
max_local_cpu_size: 2

extra_config:
  save_only_first_rank: True
  first_rank_max_local_cpu_size: 50
  broadcast_shard_size: 8   # number of chunks per broadcast shard (default: 16)

Parameters

Parameter Description Default
save_only_first_rank Only rank 0 persists KV cache to storage. Enables the dual-stream pipeline; non-rank-0 workers skip store and lookup operations, reducing CPU-side overhead. False
broadcast_shard_size Number of chunks per broadcast shard. Smaller values reduce peak NPU memory but increase the number of shard iterations. Set to 0 or a value larger than the total chunk count to fall back to the original single-broadcast behavior. 16

Trade-offs

broadcast_shard_size Peak NPU Memory Pipeline Overhead
Small (e.g., 4) Lower More shard iterations
Large (e.g., 32) Higher Fewer shard iterations
0 / max Same as upstream No sharding (original behavior)

A value of 8 or 16 is recommended for most workloads.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a sharded, pipelined broadcast and load mechanism (_pipelined_sharded_broadcast_and_load) in the Ascend cache engine to reduce peak NPU memory usage by interleaving broadcast and GPU loading operations on separate streams. The review feedback identifies several critical issues: a race condition where CPU memory backing reordered_chunks is prematurely freed while the NPU is still reading from it; a runtime AttributeError due to a missing get_size() method on MemoryObjMetadata; a potential memory leak if the metadata broadcast returns an empty plan; and a crash/hang risk in the CPU fallback mechanism because HCCL collectives do not support CPU tensors.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +337 to +457
def _pipeline_sender(
self,
plan: Dict[str, Any],
reordered_chunks: list,
load_stream: "torch.npu.Stream",
**kwargs,
) -> None:
"""Rank 0 side of the sharded broadcast pipeline (Phase 3).

Iterates ``plan["shard_plan"]`` and for each shard submits:

* On ``broadcast_stream``: pinned CPU→NPU H2D via
:meth:`to` with ``non_blocking=True``, then ``broadcast_fn`` (HCCL). One
``Event`` is recorded after the shard's last broadcast.
* On ``load_stream``: ``gpu_connector.to_gpu`` per chunk, gated
on the broadcast ``Event`` via ``wait_event``.

The loop body contains no collective barriers — the only
barrier happened in :meth:`_broadcast_metadata_table`. CPU
returns to the next shard immediately after submitting to_gpu.
"""
meta_table: List[Tuple[int, int, Dict[str, Any]]] = plan["meta"]
shard_plan: List[Tuple[int, int]] = plan["shard_plan"]

pending: List[Tuple[List[MemoryObj], torch.npu.Event]] = []

try:
for shard_idx, (offset, count) in enumerate(shard_plan):
t_bc_start = time.perf_counter()

sub_objs: List[TensorMemoryObj] = []
sub_starts: List[int] = []
sub_ends: List[int] = []
t_h2d_total = 0.0

with torch.npu.stream(self.broadcast_stream):
for i in range(offset, offset + count):
_, mem_obj, _, _ = reordered_chunks[i]
start, end_pos, _ = meta_table[i]

raw = mem_obj.raw_tensor
if raw is None:
raise ValueError(
"rank=0 _pipeline_sender: chunk [%d:%d] "
"raw_tensor is None "
"(is_valid=%s, ref_count=%d)." %
(
start,
end_pos,
mem_obj.is_valid(),
mem_obj.get_ref_count(),
)
)
t_h2d = time.perf_counter()
gpu_tensor = raw.to(f"npu:{self.metadata.worker_id}", non_blocking=True)
t_h2d_total += time.perf_counter() - t_h2d
self.broadcast_fn(gpu_tensor, self.metadata.first_rank)

meta = mem_obj.metadata
meta_copy = MemoryObjMetadata(
shape=meta.shape,
dtype=meta.dtype,
address=meta.address,
phy_size=meta.phy_size,
ref_count=1,
fmt=meta.fmt,
shapes=meta.shapes,
dtypes=meta.dtypes,
)
sub_objs.append(
TensorMemoryObj(
raw_data=gpu_tensor,
metadata=meta_copy,
parent_allocator=None,
)
)
sub_starts.append(start)
sub_ends.append(end_pos)

ev_bc = torch.npu.Event()
ev_bc.record()

t_bc_end = time.perf_counter()

load_stream.wait_event(ev_bc)
with torch.npu.stream(load_stream):
for obj, start, end_pos in zip(
sub_objs, sub_starts, sub_ends
):
self.gpu_connector.to_gpu(
obj, start, end_pos, **kwargs
)
ev_togpu = torch.npu.Event()
ev_togpu.record()

t_togpu_submit = time.perf_counter()

logger.debug(
"rank=0 shard[%d] cnt=%d bc=%.2fms h2d=%.2fms enqueue=%.2fms",
shard_idx,
count,
(t_bc_end - t_bc_start) * 1000,
t_h2d_total * 1000,
(t_togpu_submit - t_bc_end) * 1000,
)

pending.append((sub_objs, ev_togpu))
self._try_release_pending(pending)

finally:
for objs, ev in pending:
try:
ev.synchronize()
except Exception:
pass
for obj in objs:
try:
obj.ref_count_down()
except Exception:
pass
pending.clear()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There is a critical race condition regarding the lifetime of the CPU memory backing reordered_chunks. Since _pipeline_sender submits H2D copies asynchronously and returns immediately, calling memory_obj.ref_count_down() in retrieve right after the pipeline call will prematurely free the CPU memory while the NPU is still reading from it. To prevent data corruption, we should delegate the lifetime management of the original CPU MemoryObjs to the pipeline, adding them to pending so they are only released after the ev_togpu event completes. We also track last_processed_chunk_idx to prevent memory leaks if an exception occurs mid-loop. Note that we do not wrap the broadcast or transfer operations in stream contexts here, as stream context wrapping should be localized inside the specific transfer functions themselves.

    def _pipeline_sender(
        self,
        plan: Dict[str, Any],
        reordered_chunks: list,
        load_stream: "torch.npu.Stream",
        **kwargs,
    ) -> None:
        meta_table: List[Tuple[int, int, Dict[str, Any]]] = plan["meta"]
        shard_plan: List[Tuple[int, int]] = plan["shard_plan"]

        pending: List[Tuple[List[MemoryObj], torch.npu.Event]] = []
        last_processed_chunk_idx = 0

        try:
            for shard_idx, (offset, count) in enumerate(shard_plan):
                t_bc_start = time.perf_counter()

                sub_objs: List[TensorMemoryObj] = []
                orig_objs: List[MemoryObj] = []
                sub_starts: List[int] = []
                sub_ends: List[int] = []
                t_h2d_total = 0.0

                for i in range(offset, offset + count):
                    _, mem_obj, _, _ = reordered_chunks[i]
                    orig_objs.append(mem_obj)
                    start, end_pos, _ = meta_table[i]

                    raw = mem_obj.raw_tensor
                    if raw is None:
                        raise ValueError(
                            "rank=0 _pipeline_sender: chunk [%d:%d] "
                            "raw_tensor is None "
                            "(is_valid=%s, ref_count=%d)." %
                            (
                                start,
                                end_pos,
                                mem_obj.is_valid(),
                                mem_obj.get_ref_count(),
                            )
                        )
                    t_h2d = time.perf_counter()
                    gpu_tensor = raw.to(f"npu:{self.metadata.worker_id}", non_blocking=True)
                    t_h2d_total += time.perf_counter() - t_h2d
                    self.broadcast_fn(gpu_tensor, self.metadata.first_rank)

                    meta = mem_obj.metadata
                    meta_copy = MemoryObjMetadata(
                        shape=meta.shape,
                        dtype=meta.dtype,
                        address=meta.address,
                        phy_size=meta.phy_size,
                        ref_count=1,
                        fmt=meta.fmt,
                        shapes=meta.shapes,
                        dtypes=meta.dtypes,
                    )
                    sub_objs.append(
                        TensorMemoryObj(
                            raw_data=gpu_tensor,
                            metadata=meta_copy,
                            parent_allocator=None,
                        )
                    )
                    sub_starts.append(start)
                    sub_ends.append(end_pos)

                ev_bc = torch.npu.Event()
                ev_bc.record()

                t_bc_end = time.perf_counter()

                load_stream.wait_event(ev_bc)
                for obj, start, end_pos in zip(
                    sub_objs, sub_starts, sub_ends
                ):
                    self.gpu_connector.to_gpu(
                        obj, start, end_pos, **kwargs
                    )
                ev_togpu = torch.npu.Event()
                ev_togpu.record()

                t_togpu_submit = time.perf_counter()

                logger.debug(
                    "rank=0 shard[%d] cnt=%d bc=%.2fms h2d=%.2fms enqueue=%.2fms",
                    shard_idx,
                    count,
                    (t_bc_end - t_bc_start) * 1000,
                    t_h2d_total * 1000,
                    (t_togpu_submit - t_bc_end) * 1000,
                )

                pending.append((sub_objs + orig_objs, ev_togpu))
                last_processed_chunk_idx = offset + count
                self._try_release_pending(pending)

        finally:
            for objs, ev in pending:
                try:
                    ev.synchronize()
                except Exception:
                    pass
                for obj in objs:
                    try:
                        obj.ref_count_down()
                    except Exception:
                        pass
            pending.clear()

            for i in range(last_processed_chunk_idx, len(reordered_chunks)):
                try:
                    reordered_chunks[i][1].ref_count_down()
                except Exception:
                    pass
References
  1. When fixing stream synchronization or stream-wrapping issues for specific hardware/device transfer functions, localize the stream context (e.g., wrapping operations in with torch.npu.stream(...)) inside the specific function itself rather than wrapping it at the caller level.

Comment on lines +668 to +676
for key, memory_obj, _, _ in reordered_chunks:
if self.remove_after_retrieve and not self._is_passive():
if self.storage_manager is None:
raise ValueError("storage_manager is required for remove")
self.storage_manager.remove(key, self.retrieve_locations)
if self._is_sync_pd_backend():
memory_obj.ref_count_down()
elif not self.async_loading:
memory_obj.ref_count_down()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Skip calling ref_count_down() immediately in retrieve when self.save_only_first_rank is True. This delegates the lifetime management of the CPU memory objects to the sharded broadcast pipeline, preventing a race condition where the CPU memory is freed while the NPU is still reading from it.

Suggested change
for key, memory_obj, _, _ in reordered_chunks:
if self.remove_after_retrieve and not self._is_passive():
if self.storage_manager is None:
raise ValueError("storage_manager is required for remove")
self.storage_manager.remove(key, self.retrieve_locations)
if self._is_sync_pd_backend():
memory_obj.ref_count_down()
elif not self.async_loading:
memory_obj.ref_count_down()
for key, memory_obj, _, _ in reordered_chunks:
if self.remove_after_retrieve and not self._is_passive():
if self.storage_manager is None:
raise ValueError("storage_manager is required for remove")
self.storage_manager.remove(key, self.retrieve_locations)
if self._is_sync_pd_backend() and not self.save_only_first_rank:
memory_obj.ref_count_down()
elif not self.async_loading and not self.save_only_first_rank:
memory_obj.ref_count_down()

Comment on lines +182 to +204
try:
return torch.empty(
torch.Size([metadata.get_size()]),
dtype=torch.uint8,
device=f"npu:{local_rank}",
)
except RuntimeError as e:
logger.warning(
"rank=%d NPU OOM allocating broadcast buffer for chunk "
"[%d:%d] (size=%d bytes), falling back to CPU pinned memory. "
"This chunk will be treated as cache miss. Error: %s",
self.metadata.worker_id,
start,
end,
metadata.get_size(),
e,
)
return torch.empty(
torch.Size([metadata.get_size()]),
dtype=torch.uint8,
device="cpu",
pin_memory=True,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

MemoryObjMetadata does not implement a get_size() method. Calling metadata.get_size() will raise an AttributeError at runtime on non-rank-0 workers. Use metadata.phy_size instead, which correctly represents the physical size of the memory object in bytes.

        try:
            return torch.empty( 
                torch.Size([metadata.phy_size]),
                dtype=torch.uint8,
                device=f"npu:{local_rank}",
            )
        except RuntimeError as e:
            logger.warning(
                "rank=%d NPU OOM allocating broadcast buffer for chunk "
                "[%d:%d] (size=%d bytes), falling back to CPU pinned memory. "
                "This chunk will be treated as cache miss. Error: %s",
                self.metadata.worker_id,
                start,
                end,
                metadata.phy_size,
                e,
            )
            return torch.empty(
                torch.Size([metadata.phy_size]),
                dtype=torch.uint8,
                device="cpu",
                pin_memory=True,
            )

Comment on lines +316 to +320
plan = self._broadcast_metadata_table(
reordered_chunks, shard_size, first_rank
)
if plan is None or plan.get("total", 0) == 0:
return

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If the metadata table broadcast returns None or an empty plan, we must release the reordered_chunks immediately to prevent memory leaks, since they are no longer handed over to the pipeline.

        plan = self._broadcast_metadata_table(
            reordered_chunks, shard_size, first_rank
        )
        if plan is None or plan.get("total", 0) == 0:
            if self.metadata.is_first_rank():
                for _, mem_obj, _, _ in reordered_chunks:
                    mem_obj.ref_count_down()
            return

Comment on lines +505 to +519
raw = self._alloc_broadcast_buffer(
metadata, local_rank, self.metadata.first_rank,
start, end_pos,
)
self.broadcast_fn(raw, self.metadata.first_rank)

if raw.device.type == "cpu":
logger.warning(
"rank=%d chunk [%d:%d] received on CPU due to "
"NPU OOM, skipping to_gpu (cache miss)",
self.metadata.worker_id,
start,
end_pos,
)
continue

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The CPU fallback mechanism in _alloc_broadcast_buffer returns a CPU tensor on NPU OOM. However, calling self.broadcast_fn (which is an HCCL collective) on a CPU tensor on line 509 will crash, as HCCL only supports NPU tensors. Furthermore, if any rank skips the collective broadcast, it will cause a desynchronization and hang the other ranks. It is better to let the RuntimeError propagate or handle the OOM as a fatal error rather than attempting a CPU fallback that crashes or hangs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant