diff --git a/test/test_ucm_connector_save_load.py b/test/test_ucm_connector_save_load.py
index c0def663f..29252d355 100644
--- a/test/test_ucm_connector_save_load.py
+++ b/test/test_ucm_connector_save_load.py
@@ -57,6 +57,8 @@
     UCMConnectorMetadata,
 )
 from ucm.logger import init_logger
+from ucm.store.factory_v1 import UcmConnectorFactoryV1
+from ucm.store.ucmstore_v1 import UcmKVStoreBaseV1
 
 logger = init_logger(__name__)
 
@@ -91,7 +93,7 @@ def make_buffers(
     is_mla: bool,
-) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+) -> Tuple[List[bytes], Dict[str, torch.Tensor]]:
     logger.info(f"Allocating buffers: blocks={block_number}, batch_size={batch_size}")
-    hashes = [secrets.token_hex(16) for _ in range(block_number)]
+    hashes = [secrets.token_bytes(16) for _ in range(block_number)]
 
     device = f"cuda:{device_id}"
     kv_caches: Dict[str, torch.Tensor] = {}
@@ -123,8 +125,8 @@ def build_vllm_config(
     tp_size: int,
     connector_name: str,
     storage_backends: str,
-    transfer_stream_number: int,
-    use_direct: bool,
+    stream_number: int,
+    io_direct: bool,
 ) -> VllmConfig:
     cache_config = CacheConfig(
         block_size=block_size,
@@ -189,8 +191,8 @@
         "ucm_connector_name": connector_name,
         "ucm_connector_config": {
             "storage_backends": storage_backends,
-            "use_direct": use_direct,
-            "stream_number": transfer_stream_number,
+            "io_direct": io_direct,
+            "stream_number": stream_number,
             "local_rank_size": 1,
         },
     }
@@ -241,6 +243,7 @@ def compute_total_bytes(
 
 def run_once(
     connector: UCMConnector,
+    scheduler: UcmKVStoreBaseV1,
     kv_caches: Dict[str, torch.Tensor],
-    hashes: List[str],
+    hashes: List[bytes],
     batch_size: int,
@@ -254,7 +257,9 @@
         load_block_ids=([], []),
         dump_block_ids=(dump_hashes, dump_vllm_block_ids),
     )
-    connector.connector.kv_caches = kv_caches
+
+    if not hasattr(connector.connector, "store") or connector.connector.store is None:
+        connector.connector.register_kv_caches(kv_caches)
     connector.bind_connector_metadata(metadata)
 
     total_bytes = compute_total_bytes(kv_caches, batch_size, is_mla)
@@ -267,7 +272,7 @@
 
     write_bw = (total_bytes / (1024**3)) / write_time if write_time > 0 else 0.0
 
-    lookup = connector.connector.store.lookup(dump_hashes)
+    lookup = scheduler.lookup(dump_hashes)
     if not all(lookup):
         raise RuntimeError("Found missing cache blocks before load test.")
 
@@ -277,7 +282,7 @@
         load_block_ids=(dump_hashes, load_vllm_block_ids),
         dump_block_ids=([], []),
     )
-    connector.connector.kv_caches = kv_caches
+
     connector.bind_connector_metadata(load_metadata)
 
     forward_context = build_forward_context(kv_caches, is_mla)
@@ -316,8 +321,8 @@ def run_test(
     ucm_connector_name: str,
     total_tp_size: int,
     model_path: str,
-    transfer_stream_number: int,
-    use_direct: bool,
+    stream_number: int,
+    io_direct: bool,
 ) -> Tuple[float, float, float, float, float, float]:
     block_dim = head_size * num_head
     io_size = block_dim * block_len * block_elem_size
@@ -335,8 +340,8 @@ def run_test(
         tp_size=total_tp_size,
         connector_name=ucm_connector_name,
         storage_backends=storage_backends,
-        transfer_stream_number=transfer_stream_number,
-        use_direct=use_direct,
+        stream_number=stream_number,
+        io_direct=io_direct,
     )
 
     dummy_world_group = type("DummyWorldGroup", (), {"local_rank": 0})()
@@ -375,6 +380,25 @@ def broadcast(self, tensor, src):
         mla,
     )
 
+    connector.connector.register_kv_caches(kv_caches)
+
+    storage_backends_list = [
+        os.path.join(path, "kv") for path in storage_backends.split(":") if path
+    ]
+
+    scheduler_config = {
+        "storage_backends": storage_backends_list,
+        "block_size": block_size,
+        "device_id": -1,  # device_id=-1 means transferEnable=false
+        "tensor_size": io_size,
+        "stream_number": stream_number,
+        "io_direct": io_direct,
+        "unique_id": secrets.token_hex(8),
+    }
+    scheduler = UcmConnectorFactoryV1.create_connector(
+        ucm_connector_name, scheduler_config
+    )
+
     w_sizes, w_times, w_bws = [], [], []
     r_sizes, r_times, r_bws = [], [], []
 
@@ -385,10 +409,10 @@
         round_hashes = hashes[start_hash_idx:end_hash_idx]
 
         if len(round_hashes) < batch_size:
-            round_hashes = [secrets.token_hex(16) for _ in range(batch_size)]
+            round_hashes = [secrets.token_bytes(16) for _ in range(batch_size)]
 
         (w_size, w_time, w_bw), (r_size, r_time, r_bw) = run_once(
-            connector, kv_caches, round_hashes, batch_size, mla
+            connector, scheduler, kv_caches, round_hashes, batch_size, mla
         )
 
         if round_idx != 0:
@@ -451,7 +475,7 @@ def main():
     num_tokens_list = [2048, 4096, 8192, 16384, 32768]
     ucm_connector_name = "UcmNfsStore"
    model_path = "/home/models/QwQ-32B"
-    transfer_stream_numbers = [32, 64, 128]
+    stream_numbers = [32, 64, 128]
     os.environ["UC_LOGGER_LEVEL"] = "debug"
 
     print("1. Model Selection:")
@@ -462,8 +486,8 @@ def main():
     print("\n2. IoDirect Transfer:")
     print("   1 - Disable IoDirect (default)")
     print("   2 - Enable IoDirect")
-    use_direct = get_user_input("Please select Direct IO mode", "1")
-    use_direct = False if use_direct == "1" else True
+    io_direct = get_user_input("Please select Direct IO mode", "1")
+    io_direct = False if io_direct == "1" else True
 
     if mla:
         block_lens = [64]
@@ -515,7 +539,7 @@ def main():
 
     for num_head in num_head_list:
         for block_len in block_lens:
-            for transfer_stream_number in transfer_stream_numbers:
+            for stream_number in stream_numbers:
                 block_dim = head_size * num_head
                 io_size = block_dim * block_len * block_elem_size
 
@@ -548,8 +572,8 @@ def main():
                         ucm_connector_name,
                         total_tp_size,
                         model_path,
-                        transfer_stream_number,
-                        use_direct,
+                        stream_number,
+                        io_direct,
                     ),
                 )
 
@@ -579,7 +603,7 @@ def main():
                     kv,
                     num_head,
                     block_len,
-                    transfer_stream_number,
+                    stream_number,
                     io_count,
                     io_size,
                     f"{avg_w_size:.4f}",
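Note on the scheduler introduced above: it is a second, lookup-only store instance built through the V1 factory with device_id=-1. A minimal self-contained sketch of that pattern (the backend path and byte sizes are placeholder values, not taken from the test):

    import secrets

    from ucm.store.factory_v1 import UcmConnectorFactoryV1

    scheduler_config = {
        "storage_backends": ["/tmp/ucm/kv"],  # placeholder backend path
        "block_size": 73728 * 61,             # bytes per KV block (placeholder)
        "device_id": -1,                      # -1 disables device transfer: lookup-only role
        "tensor_size": 73728,                 # bytes per per-layer tensor (placeholder)
        "stream_number": 32,
        "io_direct": False,
        "unique_id": secrets.token_hex(8),
    }
    scheduler = UcmConnectorFactoryV1.create_connector("UcmNfsStore", scheduler_config)

    hashes = [secrets.token_bytes(16) for _ in range(4)]
    print(scheduler.lookup(hashes))  # all False until a worker dumps these blocks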
+ "tensor_size": io_size, + "stream_number": stream_number, + "io_direct": io_direct, + "unique_id": secrets.token_hex(8), + } + scheduler = UcmConnectorFactoryV1.create_connector( + ucm_connector_name, scheduler_config + ) + w_sizes, w_times, w_bws = [], [], [] r_sizes, r_times, r_bws = [], [], [] @@ -385,10 +409,10 @@ def broadcast(self, tensor, src): round_hashes = hashes[start_hash_idx:end_hash_idx] if len(round_hashes) < batch_size: - round_hashes = [secrets.token_hex(16) for _ in range(batch_size)] + round_hashes = [secrets.token_bytes(16) for _ in range(batch_size)] (w_size, w_time, w_bw), (r_size, r_time, r_bw) = run_once( - connector, kv_caches, round_hashes, batch_size, mla + connector, scheduler, kv_caches, round_hashes, batch_size, mla ) if round_idx != 0: @@ -451,7 +475,7 @@ def main(): num_tokens_list = [2048, 4096, 8192, 16384, 32768] ucm_connector_name = "UcmNfsStore" model_path = "/home/models/QwQ-32B" - transfer_stream_numbers = [32, 64, 128] + stream_numbers = [32, 64, 128] os.environ["UC_LOGGER_LEVEL"] = "debug" print("1. Model Selection:") @@ -462,8 +486,8 @@ def main(): print("\n2. IoDirect Transfer:") print(" 1 - Disable IoDirect (default)") print(" 2 - Enable IoDirect") - use_direct = get_user_input("Please select Direct IO mode", "1") - use_direct = False if use_direct == "1" else True + io_direct = get_user_input("Please select Direct IO mode", "1") + io_direct = False if io_direct == "1" else True if mla: block_lens = [64] @@ -515,7 +539,7 @@ def main(): for num_head in num_head_list: for block_len in block_lens: - for transfer_stream_number in transfer_stream_numbers: + for stream_number in stream_numbers: block_dim = head_size * num_head io_size = block_dim * block_len * block_elem_size @@ -548,8 +572,8 @@ def main(): ucm_connector_name, total_tp_size, model_path, - transfer_stream_number, - use_direct, + stream_number, + io_direct, ), ) @@ -579,7 +603,7 @@ def main(): kv, num_head, block_len, - transfer_stream_number, + stream_number, io_count, io_size, f"{avg_w_size:.4f}", diff --git a/ucm/store/test/e2e/nfsstore_embed_fetch.py b/ucm/store/test/e2e/nfsstore_embed_fetch.py index 1132afa50..9c5ebb1ea 100644 --- a/ucm/store/test/e2e/nfsstore_embed_fetch.py +++ b/ucm/store/test/e2e/nfsstore_embed_fetch.py @@ -32,7 +32,9 @@ import torch from ucm.store.nfsstore.nfsstore_connector import UcmNfsStore +from ucm.store.pcstore.pcstore_connector_v1 import UcmPcStoreV1 from ucm.store.ucmstore import UcmKVStoreBase +from ucm.store.ucmstore_v1 import UcmKVStoreBaseV1 def setup( @@ -40,19 +42,19 @@ def setup( block_size, device_id, io_size, - transferStreamNumber, - transferIoDirect, -) -> UcmKVStoreBase: + stream_number, + io_direct, +) -> UcmKVStoreBaseV1: config = { - "storage_backends": storage_backends, - "kv_block_size": block_size, - "role": "worker", - "device": device_id, - "io_size": io_size, - "transferStreamNumber": transferStreamNumber, - "transferIoDirect": transferIoDirect, + "storage_backends": [storage_backends], + "block_size": block_size, + "device_id": device_id, + "tensor_size": io_size, + "stream_number": stream_number, + "io_direct": io_direct, + "unique_id": secrets.token_hex(8), } - return UcmNfsStore(config) + return UcmPcStoreV1(config) def make_aligned_tensor(shape, dtype, device, alignment=4096): @@ -79,66 +81,60 @@ def make_aligned_tensor(shape, dtype, device, alignment=4096): def make_buffers( block_number, device_id, batch_size, head_dim, block_len, block_layer, num_head, kv ): - hashes = [secrets.token_hex(16) for _ in range(block_number)] - 
@@ -151,45 +147,46 @@
 def fetch(
-    store: UcmKVStoreBase,
-    hashes: List[str],
+    store: UcmKVStoreBaseV1,
+    scheduler: UcmKVStoreBaseV1,
+    hashes: List[bytes],
     kvcaches: Dict[int, torch.Tensor],
     mla: bool,
 ):
-    start_time = time.perf_counter()
-
-    founds = store.lookup(hashes)
+    founds = scheduler.lookup(hashes)
     for f in founds:
         assert f, "Cache block miss detected"
 
-    block_ids, offsets, tensors = [], [], []
+    total_tensors = []
     total_size = 0
     for i, hash_val in enumerate(hashes):
-        offset = 0
+        tensors = []
         for layer_id, kv_layer in kvcaches.items():
-            k_tensor = kv_layer[0][i]  # kv=1
-            block_ids.append(hash_val)
-            offsets.append(offset)
+            k_tensor = kv_layer[0][i].contiguous()
             tensors.append(k_tensor)
             sz = k_tensor.numel() * k_tensor.element_size()
-            offset += sz
             total_size += sz
             if not mla:
-                v_tensor = kv_layer[1][i]
-                block_ids.append(hash_val)
-                offsets.append(offset)
+                v_tensor = kv_layer[1][i].contiguous()
                 tensors.append(v_tensor)
                 sz = v_tensor.numel() * v_tensor.element_size()
-                offset += sz
                 total_size += sz
+        total_tensors.append(tensors)
 
-    task = store.load(block_ids, offsets, tensors)
-    ret = store.wait(task)
-    assert ret == 0, "Load operation failed"
-
+    start_time = time.perf_counter()
+    task = store.load(hashes, [], total_tensors)
+    try:
+        ret = store.wait(task)
+        if ret is None:
+            ret = 0
+    except RuntimeError as e:
+        print(f"Load operation failed with error: {e}")
+        raise
+    assert ret == 0, f"Load operation failed with return code: {ret}"
     elapsed_time = time.perf_counter() - start_time
+
     throughput_gbps = (total_size / (1024**3)) / elapsed_time if elapsed_time > 0 else 0
 
     print(
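Side note on the layout change in embed() and fetch() above: the V1 dump/load API takes the block hashes plus one inner tensor list per block (every layer's K slice, plus V when not MLA), replacing the flat (block_id, offset, tensor) triples. A standalone sketch of that layout with invented CPU shapes:

    import secrets

    import torch

    block_layer, block_len, num_head, head_dim = 2, 64, 1, 576
    hashes = [secrets.token_bytes(16) for _ in range(3)]
    kvcaches = {
        layer_id: torch.zeros(
            1, len(hashes), block_len, num_head, head_dim, dtype=torch.bfloat16
        )
        for layer_id in range(block_layer)
    }

    total_tensors = []
    for i, _ in enumerate(hashes):
        tensors = []
        for kv_layer in kvcaches.values():
            tensors.append(kv_layer[0][i].contiguous())  # K slice; append V too when not MLA
        total_tensors.append(tensors)

    # With a real V1 store: task = store.dump(hashes, [], total_tensors); store.wait(task)
    assert len(total_tensors) == len(hashes)
    assert all(len(block) == block_layer for block in total_tensors)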
@@ -206,14 +203,14 @@ def run(
     repeat: int,
     num_head: int,
     block_len: int,
-    transferStreamNumber: int,
+    stream_number: int,
     num_tokens: int,
     block_layer: int,
     head_size: int,
     block_elem_size: int,
     kv: int,
     mla: bool,
-    transferIoDirect: bool,
+    io_direct: bool,
     operation_mode: str = "both",  # "write_only", "read_only", or "both"
 ) -> Tuple[float, float, float, float, float, float]:
     """
@@ -226,6 +223,10 @@
     block_dim = head_size * num_head
     io_size = block_dim * block_len * block_elem_size
     block_size = io_size * block_layer
+
+    if not mla:
+        block_size = block_size * 2
+
     batch_size = int(num_tokens / block_len)
     real_blocks = batch_size + 10
 
@@ -238,8 +239,17 @@
         block_size,
         device_id,
         io_size,
-        transferStreamNumber,
-        transferIoDirect,
+        stream_number,
+        io_direct,
+    )
+
+    scheduler = setup(
+        storage_backends,
+        block_size,
+        -1,  # device_id=-1 means transferEnable=false
+        io_size,
+        stream_number,
+        io_direct,
     )
 
     for r in range(repeat):
@@ -257,16 +267,13 @@
             kv,
         )
 
-        results = store.create(hashes[:batch_size])
-        assert sum(results) == 0, "Create operation failed"
-
         w_size, w_time, w_bw = embed(
             store,
             hashes[:batch_size],
             kvcaches,
             mla,
         )
-        store.commit(hashes[:batch_size], True)
+        time.sleep(1)
 
         if r == 0:
             store_all_hashes(hashes[:batch_size])
@@ -302,6 +309,7 @@
             r_size, r_time, r_bw = fetch(
                 store,
+                scheduler,
                 saved_hashes[:batch_size],
                 kvcaches,
                 mla,
@@ -309,6 +317,7 @@
         else:
             r_size, r_time, r_bw = fetch(
                 store,
+                scheduler,
                 hashes[:batch_size],
                 kvcaches,
                 mla,
@@ -349,18 +358,18 @@
     try:
         result = run(
             storage_backends=".",
-            device_id=1,
-            repeat=1,
+            device_id=6,
+            repeat=2,
             num_head=1,
-            block_len=128,
-            transferStreamNumber=32,
+            block_len=64,
+            stream_number=32,
             num_tokens=4096,
             block_layer=61,
             head_size=576,
             block_elem_size=2,
             kv=1,
             mla=True,
-            transferIoDirect=False,
+            io_direct=False,
             operation_mode="both",
         )
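For reference, the sizing math in run() worked through with the values from the __main__ call above (MLA, so the K/V doubling branch is skipped):

    head_size, num_head, block_len, block_elem_size, block_layer = 576, 1, 64, 2, 61
    mla = True

    block_dim = head_size * num_head                   # 576
    io_size = block_dim * block_len * block_elem_size  # 73728 bytes per layer tensor
    block_size = io_size * block_layer                 # 4497408 bytes per block
    if not mla:
        block_size *= 2                                # both K and V planes when not MLA
    print(io_size, block_size)                         # 73728 4497408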