Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 45 additions & 21 deletions test/test_ucm_connector_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
UCMConnectorMetadata,
)
from ucm.logger import init_logger
from ucm.store.factory_v1 import UcmConnectorFactoryV1
from ucm.store.ucmstore_v1 import UcmKVStoreBaseV1

logger = init_logger(__name__)

Expand Down Expand Up @@ -91,7 +93,7 @@ def make_buffers(
is_mla: bool,
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
logger.info(f"Allocating buffers: blocks={block_number}, batch_size={batch_size}")
hashes = [secrets.token_hex(16) for _ in range(block_number)]
hashes = [secrets.token_bytes(16) for _ in range(block_number)]
device = f"cuda:{device_id}"
kv_caches: Dict[str, torch.Tensor] = {}

Expand Down Expand Up @@ -123,8 +125,8 @@ def build_vllm_config(
tp_size: int,
connector_name: str,
storage_backends: str,
transfer_stream_number: int,
use_direct: bool,
stream_number: int,
io_direct: bool,
) -> VllmConfig:
cache_config = CacheConfig(
block_size=block_size,
Expand Down Expand Up @@ -189,8 +191,8 @@ def build_vllm_config(
"ucm_connector_name": connector_name,
"ucm_connector_config": {
"storage_backends": storage_backends,
"use_direct": use_direct,
"stream_number": transfer_stream_number,
"io_direct": io_direct,
"stream_number": stream_number,
"local_rank_size": 1,
},
}
Expand Down Expand Up @@ -241,6 +243,7 @@ def compute_total_bytes(

def run_once(
connector: UCMConnector,
scheduler: UcmKVStoreBaseV1,
kv_caches: Dict[str, torch.Tensor],
hashes: List[str],
batch_size: int,
Expand All @@ -254,7 +257,9 @@ def run_once(
load_block_ids=([], []),
dump_block_ids=(dump_hashes, dump_vllm_block_ids),
)
connector.connector.kv_caches = kv_caches

if not hasattr(connector.connector, "store") or connector.connector.store is None:
connector.connector.register_kv_caches(kv_caches)
connector.bind_connector_metadata(metadata)

total_bytes = compute_total_bytes(kv_caches, batch_size, is_mla)
Expand All @@ -267,7 +272,7 @@ def run_once(

write_bw = (total_bytes / (1024**3)) / write_time if write_time > 0 else 0.0

lookup = connector.connector.store.lookup(dump_hashes)
lookup = scheduler.lookup(dump_hashes)
if not all(lookup):
raise RuntimeError("Found missing cache blocks before load test.")

Expand All @@ -277,7 +282,7 @@ def run_once(
load_block_ids=(dump_hashes, load_vllm_block_ids),
dump_block_ids=([], []),
)
connector.connector.kv_caches = kv_caches

connector.bind_connector_metadata(load_metadata)

forward_context = build_forward_context(kv_caches, is_mla)
Expand Down Expand Up @@ -316,8 +321,8 @@ def run_test(
ucm_connector_name: str,
total_tp_size: int,
model_path: str,
transfer_stream_number: int,
use_direct: bool,
stream_number: int,
io_direct: bool,
) -> Tuple[float, float, float, float, float, float]:
block_dim = head_size * num_head
io_size = block_dim * block_len * block_elem_size
Expand All @@ -335,8 +340,8 @@ def run_test(
tp_size=total_tp_size,
connector_name=ucm_connector_name,
storage_backends=storage_backends,
transfer_stream_number=transfer_stream_number,
use_direct=use_direct,
stream_number=stream_number,
io_direct=io_direct,
)

dummy_world_group = type("DummyWorldGroup", (), {"local_rank": 0})()
Expand Down Expand Up @@ -375,6 +380,25 @@ def broadcast(self, tensor, src):
mla,
)

connector.connector.register_kv_caches(kv_caches)

storage_backends_list = [
os.path.join(path, "kv") for path in storage_backends.split(":") if path
]

scheduler_config = {
"storage_backends": storage_backends_list,
"block_size": block_size,
"device_id": -1, # device_id=-1 means transferEnable=false
"tensor_size": io_size,
"stream_number": stream_number,
"io_direct": io_direct,
"unique_id": secrets.token_hex(8),
}
scheduler = UcmConnectorFactoryV1.create_connector(
ucm_connector_name, scheduler_config
)

w_sizes, w_times, w_bws = [], [], []
r_sizes, r_times, r_bws = [], [], []

Expand All @@ -385,10 +409,10 @@ def broadcast(self, tensor, src):
round_hashes = hashes[start_hash_idx:end_hash_idx]

if len(round_hashes) < batch_size:
round_hashes = [secrets.token_hex(16) for _ in range(batch_size)]
round_hashes = [secrets.token_bytes(16) for _ in range(batch_size)]

(w_size, w_time, w_bw), (r_size, r_time, r_bw) = run_once(
connector, kv_caches, round_hashes, batch_size, mla
connector, scheduler, kv_caches, round_hashes, batch_size, mla
)

if round_idx != 0:
Expand Down Expand Up @@ -451,7 +475,7 @@ def main():
num_tokens_list = [2048, 4096, 8192, 16384, 32768]
ucm_connector_name = "UcmNfsStore"
model_path = "/home/models/QwQ-32B"
transfer_stream_numbers = [32, 64, 128]
stream_numbers = [32, 64, 128]
os.environ["UC_LOGGER_LEVEL"] = "debug"

print("1. Model Selection:")
Expand All @@ -462,8 +486,8 @@ def main():
print("\n2. IoDirect Transfer:")
print(" 1 - Disable IoDirect (default)")
print(" 2 - Enable IoDirect")
use_direct = get_user_input("Please select Direct IO mode", "1")
use_direct = False if use_direct == "1" else True
io_direct = get_user_input("Please select Direct IO mode", "1")
io_direct = False if io_direct == "1" else True

if mla:
block_lens = [64]
Expand Down Expand Up @@ -515,7 +539,7 @@ def main():

for num_head in num_head_list:
for block_len in block_lens:
for transfer_stream_number in transfer_stream_numbers:
for stream_number in stream_numbers:
block_dim = head_size * num_head
io_size = block_dim * block_len * block_elem_size

Expand Down Expand Up @@ -548,8 +572,8 @@ def main():
ucm_connector_name,
total_tp_size,
model_path,
transfer_stream_number,
use_direct,
stream_number,
io_direct,
),
)

Expand Down Expand Up @@ -579,7 +603,7 @@ def main():
kv,
num_head,
block_len,
transfer_stream_number,
stream_number,
io_count,
io_size,
f"{avg_w_size:.4f}",
Expand Down
Loading