Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions lmcache_ascend/integration/vllm/vllm_v1_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,55 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
self.current_layer = 0
self._wait_for_save_done = False
super().start_load_kv(forward_context, **kwargs)
self._mark_failed_p2p_loads_for_recompute()

def _mark_failed_p2p_loads_for_recompute(self) -> None:
gpu_connector = getattr(self.lmcache_engine, "gpu_connector", None)
drain = getattr(gpu_connector, "drain_failed_load_req_ids", None)
if drain is None:
return
failed_req_ids = drain()
if not failed_req_ids:
return

metadata = self._parent._get_connector_metadata()
if not isinstance(metadata, LMCacheConnectorMetadata):
return

for request in metadata.requests:
if request.req_id not in failed_req_ids:
continue
load_spec = request.load_spec
if load_spec is None or not load_spec.can_load:
continue

tokens = request.token_ids
slot_mapping = request.slot_mapping
token_mask = torch.ones(len(tokens), dtype=torch.bool)
masked_token_count = (
load_spec.vllm_cached_tokens
// self._lmcache_chunk_size
* self._lmcache_chunk_size
)
token_mask[:masked_token_count] = False

lmcache_cached_tokens = load_spec.lmcache_cached_tokens
expected_mask = token_mask[:lmcache_cached_tokens]
ret_mask = torch.zeros(lmcache_cached_tokens, dtype=torch.bool)

missing_blocks = self.record_failed_blocks(
request.req_id,
expected_mask,
ret_mask,
slot_mapping[:lmcache_cached_tokens],
)
self._invalid_block_ids.update(missing_blocks)
logger.error(
"Marked %d KV blocks invalid for req %s after P2P pull "
"failure; vLLM will recompute them locally.",
len(missing_blocks),
request.req_id,
)

@_lmcache_nvtx_annotate
def wait_for_save(self):
Expand Down
31 changes: 31 additions & 0 deletions lmcache_ascend/v1/npu_connector/npu_connectors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
from typing import Any, List, Optional, Set, Union
import threading

# Third Party
from lmcache.integration.vllm.utils import ENGINE_NAME
Expand Down Expand Up @@ -444,6 +445,9 @@ def __init__(

super().__init__(hidden_dim_size, num_layers, use_gpu, **kwargs)

self._failed_load_req_ids: Set[str] = set()
self._failed_load_lock = threading.Lock()

if is_310p():
assert "num_kv_head" in kwargs, ("num_kv_head should be provided in 310p",)
assert "head_size" in kwargs, ("head_size should be provided in 310p",)
Expand Down Expand Up @@ -913,6 +917,22 @@ def batched_to_gpu(self, memory_objs, starts, ends, **kwargs):
self.to_gpu(memory_obj, start, end, **kwargs)
self.load_stream.synchronize()

def _record_failed_load(self, req_id: Optional[str]) -> None:
if not req_id:
logger.error(
"P2P pull failed but no req_id was provided; cannot mark "
"blocks invalid for recompute."
)
return
with self._failed_load_lock:
self._failed_load_req_ids.add(req_id)

def drain_failed_load_req_ids(self) -> Set[str]:
with self._failed_load_lock:
failed = self._failed_load_req_ids
self._failed_load_req_ids = set()
return failed

def _clear_proxy_batch(self, batch) -> None:
"""Clear the backing objects of the proxy batch."""
for proxy, _, _ in batch:
Expand Down Expand Up @@ -1053,6 +1073,17 @@ def _remote_batched_to_gpu(self, memory_objs, starts, ends, **kwargs):
**kwargs,
)
self._clear_proxy_batch(prev_batch)
except Exception as exc:
req_id = kwargs.get("req_id")
logger.error(
"P2P pull failed for req %s (%s): %s; treating KV as "
"a cache miss for local recompute.",
req_id,
type(exc).__name__,
exc,
exc_info=True,
)
self._record_failed_load(req_id)
finally:
# Guarantee ping-pong buffers are returned and the Done
# signal is sent even if the pipeline raises or
Expand Down
4 changes: 4 additions & 0 deletions lmcache_ascend/v1/proxy_memory_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ def resolve(self) -> None:
"Cannot resolve: no backing buffer assigned. Call set_backing_obj() first."
)

self._transfer_context.check_lease()

channel_transfer_spec = {
TS_RECEIVER_ID: self._target_peer_url,
TS_REMOTE_BUFFER_UUIDS: [self._remote_buffer_uuid],
Expand Down Expand Up @@ -244,6 +246,7 @@ def resolve_batch(
return

first = unresolved[0]
first._transfer_context.check_lease()
buffers, channel_transfer_spec = ProxyMemoryObj._collect_batch_read_args(
unresolved
)
Expand Down Expand Up @@ -288,6 +291,7 @@ def submit_resolve_batch(
return None

first = unresolved[0]
first._transfer_context.check_lease()
channel = first._transfer_channel

if not hasattr(channel, "submit_batched_read"):
Expand Down
103 changes: 86 additions & 17 deletions lmcache_ascend/v1/storage_backend/p2p_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ class AscendBatchedLookupAndGetRetMsg(BatchedLookupAndGetRetMsg):
# so that the server can identify which mem handle and buffer
# the client is referring to in the subsequent pull request.
remote_mem_indexes: list[int] = []
# Producer slot lease in seconds. Non-zero only under host staging: the
# reader must finish its one-sided read within this window, minus a local
# guard band, before the producer may reclaim the staged arena slot.
lease_ttl_s: float = 0.0


class AscendBatchedLookupAndGetDoneMsg(msgspec.Struct, tag=True):
Expand Down Expand Up @@ -283,6 +287,24 @@ def __init__(
self.dtypes = self.memory_allocator.cpu_allocator.dtypes
self.fmt: MemoryFormat = resolve_memory_format(metadata.use_mla)

self.use_host_staging: bool = bool(
config.get_extra_config_value("use_host_staging", False)
)
self._pull_lease_guard_s: float = float(
config.get_extra_config_value("p2p_pull_lease_guard_s", 15.0)
)
if self.use_host_staging:
assert self.delay_pull, (
"use_host_staging requires p2p_delay_pull=True because the "
"reader must use ping-pong buffers when the producer CPU KV "
"pool is not registered."
)
logger.info(
"P2P host staging enabled (lease_ttl=%.1fs, guard=%.1fs)",
self._pull_pending_ttl,
self._pull_lease_guard_s,
)

buffer_ptrs = [self.memory_allocator.cpu_allocator.buffer_ptr]
buffer_sizes = [self.memory_allocator.cpu_allocator.buffer_size]
buffer_types = ["cpu"]
Expand Down Expand Up @@ -341,6 +363,17 @@ def __init__(
self.chunk_size = config.chunk_size

# Keep transfer-channel ZMQ I/O on the Ascend P2P backend loop.
channel_kwargs: dict[str, Any] = {}
if self.use_host_staging:
channel_kwargs["use_host_staging"] = True
channel_kwargs["staging_shapes"] = self.full_size_shapes
channel_kwargs["staging_dtypes"] = self.dtypes
channel_kwargs["staging_fmt"] = self.fmt
for key in ("os_staging_bytes", "os_staging_copy_threads"):
value = config.get_extra_config_value(key, None)
if value is not None:
channel_kwargs[key] = value

self.transfer_channel = CreateTransferChannel(
channel_type=config.transfer_channel,
async_mode=True,
Expand All @@ -353,6 +386,7 @@ def __init__(
peer_init_url=self.peer_init_url,
peer_lookup_url=self.peer_lookup_url,
event_loop=self.loop,
**channel_kwargs,
)

self.running = asyncio.Event()
Expand Down Expand Up @@ -967,19 +1001,42 @@ async def _handle_batched_lookup_and_get(
# by the receiver's pull request.
remote_buffer_uuids = []
remote_mem_indexes = []
lease_ttl_s = 0.0
if num_hit_chunks > 0 and mem_objs:
remote_buffer_uuids, remote_mem_indexes = (
self.transfer_channel.get_local_buffer_refs(mem_objs)
)
if self.use_host_staging:
(
remote_buffer_uuids,
remote_mem_indexes,
staged_objs,
) = await self.transfer_channel.stage(mem_objs)
release_memory_objects(mem_objs, unpin=True)
should_release = False
num_hit_chunks = len(staged_objs)
if num_hit_chunks > 0:
self.pending_pull_resources[lookup_id] = (
self.loop.time(),
staged_objs,
)
lease_ttl_s = self._pull_pending_ttl
else:
logger.debug(
"Host-staging arena full for lookup_id %s; "
"returning num_hit_chunks=0 to report a miss.",
lookup_id,
)
else:
remote_buffer_uuids, remote_mem_indexes = (
self.transfer_channel.get_local_buffer_refs(mem_objs)
)

# Store mem_objs to prevent premature release.
# Record the timestamp so the TTL sweep can detect
# stale entries if the peer never sends Done.
self.pending_pull_resources[lookup_id] = (
self.loop.time(),
mem_objs,
)
should_release = False
# Store mem_objs to prevent premature release.
# Record the timestamp so the TTL sweep can detect
# stale entries if the peer never sends Done.
self.pending_pull_resources[lookup_id] = (
self.loop.time(),
mem_objs,
)
should_release = False
else:
logger.debug(
"Pull mode enabled but no hit chunks "
Expand All @@ -991,6 +1048,7 @@ async def _handle_batched_lookup_and_get(
num_hit_chunks=num_hit_chunks,
remote_buffer_uuids=remote_buffer_uuids,
remote_mem_indexes=remote_mem_indexes,
lease_ttl_s=lease_ttl_s,
)
else:
remote_buffer_uuids = msg.buffer_uuids
Expand Down Expand Up @@ -1027,8 +1085,11 @@ async def _handle_batched_lookup_and_get_done(
logger.debug("Received Done signal for lookup_id %s", lookup_id)

if lookup_id in self.pending_pull_resources:
_, mem_objs = self.pending_pull_resources.pop(lookup_id)
release_memory_objects(mem_objs, unpin=True)
_, objs = self.pending_pull_resources.pop(lookup_id)
if self.use_host_staging:
self.transfer_channel.release_staged(objs)
else:
release_memory_objects(objs, unpin=True)
logger.debug("Released resources for lookup_id %s", lookup_id)
else:
logger.warning("No pending resources found for lookup_id %s", lookup_id)
Expand Down Expand Up @@ -1056,14 +1117,20 @@ async def _sweep_expired_pending_pull_resources(self):
for pid in expired_ids:
entry = self.pending_pull_resources.pop(pid, None)
if entry is not None:
_, mem_objs = entry
release_memory_objects(mem_objs, unpin=True)
_, objs = entry
if self.use_host_staging:
self.transfer_channel.release_staged(objs)
else:
release_memory_objects(objs, unpin=True)
logger.warning(
"P2P pull mode: TTL expired for lookup_id %s "
" released %d pinned MemObjs "
"- released %d %s "
"(peer may have crashed).",
pid,
len(mem_objs),
len(objs),
"arena slots"
if self.use_host_staging
else "pinned MemObjs",
)
except Exception as e:
logger.error(
Expand Down Expand Up @@ -1660,6 +1727,8 @@ async def batched_get_non_blocking(
dtypes=self.dtypes,
fmt=self.fmt,
use_npu=self.use_npu,
lease_ttl_s=ret_msg.lease_ttl_s,
lease_guard_s=self._pull_lease_guard_s,
)

proxy_objs: list[MemoryObj] = []
Expand Down
Loading
Loading