Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/ucm_config_example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ enable_event_sync: true
# Enable UCM metrics so they can be monitored online via Grafana and Prometheus.
# metrics_config_path: "/workspace/unified-cache-management/examples/metrics/metrics_configs.yaml"

chunk_size: 256

# Sparse attention configuration
# ucm_sparse_config:
# GSAOnDevice: {}
Expand Down
3 changes: 3 additions & 0 deletions test/suites/E2E/test_online_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def test_online_accuracy_hbm_ssd_mixed(
if ucm_connector_name == "UcmNfsStore"
else {}
),
"chunk_size": 256,
}

# Build vllm_server_startup_args
Expand Down Expand Up @@ -183,6 +184,7 @@ def test_online_accuracy_hbm_ssd_mixed_pp(
if ucm_connector_name == "UcmNfsStore"
else {}
),
"chunk_size": 256,
}

# Build vllm_server_startup_args with pipeline parallel size
Expand Down Expand Up @@ -256,6 +258,7 @@ def test_online_accuracy_hbm_ssd_mixed_tp(
if ucm_connector_name == "UcmNfsStore"
else {}
),
"chunk_size": 256,
}

# Build vllm_server_startup_args with tensor parallel size
Expand Down
62 changes: 48 additions & 14 deletions ucm/integration/vllm/ucm_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ class RequestDispatchMeta:

class KVCacheLayout:
def __init__(
self, kvcaches, use_layerwise: bool, vllm_config: "VllmConfig"
self, kvcaches, launch_config: dict, vllm_config: "VllmConfig"
) -> None:
# each row is a layer, each column is a tensor_size/ptr in the layer (e.g., k, v, rope, k_index)
self.base_ptrs: np.ndarray # (n_layers, n_ptrs)
self.tensor_size_lists: np.ndarray # (n_layers, n_tensor_sizes)
self.use_layerwise = use_layerwise
self.use_layerwise = launch_config.get("use_layerwise", False)
self.vllm_config = vllm_config
self.pp_size = self.vllm_config.parallel_config.pipeline_parallel_size
self.num_hidden_layers = getattr(
Expand Down Expand Up @@ -246,7 +246,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
self.enable_event_sync = self.launch_config.get("enable_event_sync", True)
assert len(self.connector_configs) > 0, "no storage connector name in config."

self.chunk_size = self.block_size
self.chunk_size = self.launch_config.get("chunk_size", self.block_size)
assert (
self.chunk_size % self.block_size == 0
), "chunk_size must be divisible by block_size"
self.blocks_per_chunk = self.chunk_size // self.block_size

if role == KVConnectorRole.SCHEDULER:
Expand Down Expand Up @@ -361,7 +364,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for i, tensor in enumerate(sample_kv_layer):
logger.info(f"kv cache shape {i}: {tensor.shape}")
self.kv_cache_layout = KVCacheLayout(
self.kv_caches, self.use_layerwise, self._vllm_config
self.kv_caches, self.launch_config, self._vllm_config
)
self.block_data_size = self.kv_cache_layout.block_size
self.layer_name_to_id = self.kv_cache_layout.layer_name_to_id
Expand Down Expand Up @@ -395,10 +398,10 @@ def get_num_new_matched_tokens(
num_computed_tokens: int,
) -> tuple[int, bool]:
assert num_computed_tokens % self.block_size == 0
hbm_hit_block_num = num_computed_tokens // self.block_size
hbm_hit_block_num = num_computed_tokens // self.chunk_size

ucm_block_ids = self.generate_hash(
self.block_size, request.all_token_ids, self._seed
self.chunk_size, request.all_token_ids, self._seed
)

external_block_ids = ucm_block_ids[hbm_hit_block_num:]
Expand All @@ -422,12 +425,15 @@ def get_num_new_matched_tokens(

total_hit_block_num = hbm_hit_block_num + external_hit_blocks

external_hit_tokens = external_hit_blocks * self.block_size
external_hit_tokens = 0
if external_hit_blocks > 0:
remainder = num_computed_tokens % self.chunk_size
external_hit_tokens = external_hit_blocks * self.chunk_size - remainder
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be external_hit_tokens = (external_hit_blocks - remainder) * self.chunk_size?


# When all the tokens are cached in ssd or hbm,
# we need to recompute the last token. This if condition will be removed
# once vLLM scheduler provides a better solution in the future.
num_total_hit_tokens = total_hit_block_num * self.block_size
num_total_hit_tokens = external_hit_tokens + num_computed_tokens
if num_total_hit_tokens == request.num_tokens:
external_hit_tokens -= 1

Expand Down Expand Up @@ -474,13 +480,19 @@ def _generate_dispatch_meta(
dump_ucm_block_ids, dump_vllm_block_ids = [], []
if need_load:
load_ucm_block_ids = ucm_block_ids[hbm_hit_block_num:total_hit_block_num]
load_vllm_block_ids = vllm_block_ids[hbm_hit_block_num:total_hit_block_num]
load_vllm_block_ids = vllm_block_ids[
hbm_hit_block_num
* self.blocks_per_chunk : total_hit_block_num
* self.blocks_per_chunk
]

if req_meta.token_processed < req_meta.num_token_ids:
start_idx = req_meta.token_processed // self.block_size
end_idx = (req_meta.token_processed + new_tokens) // self.block_size
start_idx = req_meta.token_processed // self.chunk_size
end_idx = (req_meta.token_processed + new_tokens) // self.chunk_size
dump_ucm_block_ids = ucm_block_ids[start_idx:end_idx]
dump_vllm_block_ids = req_meta.vllm_block_ids[start_idx:end_idx]
dump_vllm_block_ids = req_meta.vllm_block_ids[
start_idx * self.blocks_per_chunk : end_idx * self.blocks_per_chunk
]
req_meta.token_processed += new_tokens

return RequestDispatchMeta(
Expand Down Expand Up @@ -569,7 +581,10 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
for i, ucm_block_id in enumerate(ucm_block_ids):
ucm_block_ids[i] = self.request_hasher(ucm_block_id)
total_ptrs = self.kv_cache_layout.extract_block_addrs(vllm_block_ids)
total_ptrs = total_ptrs.reshape(total_ptrs.shape[0], -1)
total_ptrs = total_ptrs.reshape(
total_ptrs.shape[0] // self.blocks_per_chunk, -1
)
assert total_ptrs.shape[0] == len(ucm_block_ids)
shard_indexs = [0] * len(ucm_block_ids)
try:
task = self.store.load_data(ucm_block_ids, shard_indexs, total_ptrs)
Expand Down Expand Up @@ -662,7 +677,10 @@ def wait_for_save(self) -> None:

if is_save:
total_ptrs = self.kv_cache_layout.extract_block_addrs(total_vllm_block_ids)
total_ptrs = total_ptrs.reshape(total_ptrs.shape[0], -1)
total_ptrs = total_ptrs.reshape(
total_ptrs.shape[0] // self.blocks_per_chunk, -1
)
assert total_ptrs.shape[0] == len(total_ucm_block_ids)
shard_indexs = [0] * len(total_ucm_block_ids)
try:
event_handle = self._get_dump_event_handle()
Expand Down Expand Up @@ -777,6 +795,14 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
total_ptrs = self.kv_cache_layout.extract_block_addrs(
vllm_block_ids, layer_first=True
)
# (n_layers, num_blocks, n_ptrs) -> (n_layers, num_blocks//bpc, bpc*n_ptrs)
n_layers, n_blocks, n_ptrs = total_ptrs.shape
total_ptrs = total_ptrs.reshape(
n_layers,
n_blocks // self.blocks_per_chunk,
self.blocks_per_chunk * n_ptrs,
)
assert total_ptrs.shape[1] == len(ucm_block_ids)
self.request_data.append((request_id, ucm_block_ids, total_ptrs))

if self.need_load:
Expand Down Expand Up @@ -840,6 +866,14 @@ def save_kv_layer(
self.dump_total_ptrs = self.kv_cache_layout.extract_block_addrs(
total_vllm_block_ids, layer_first=True
)
# (n_layers, num_blocks, n_ptrs) -> (n_layers, num_blocks//bpc, bpc*n_ptrs)
n_layers, n_blocks, n_ptrs = self.dump_total_ptrs.shape
self.dump_total_ptrs = self.dump_total_ptrs.reshape(
n_layers,
n_blocks // self.blocks_per_chunk,
self.blocks_per_chunk * n_ptrs,
)
assert self.dump_total_ptrs.shape[1] == len(total_ucm_block_ids)
shard_indexs = [layer_id] * len(total_ucm_block_ids)
try:
layer_ptrs = np.ascontiguousarray(self.dump_total_ptrs[local_layer_id])
Expand Down
Loading