diff --git a/examples/hstu/configs/inference_config.py b/examples/hstu/configs/inference_config.py index ded06924b..06b1dc0af 100755 --- a/examples/hstu/configs/inference_config.py +++ b/examples/hstu/configs/inference_config.py @@ -63,17 +63,17 @@ class KVCacheMetadata: """ # paged cache metadata - kv_indices: torch.Tensor = None - kv_indptr: torch.Tensor = None - kv_last_page_len: torch.Tensor = None - total_history_lengths: torch.Tensor = None - total_history_offsets: torch.Tensor = None + kv_indices: torch.Tensor = None # num_pages + kv_indptr: torch.Tensor = None # num_seq + 1 + kv_last_page_len: torch.Tensor = None # num_seq + total_history_lengths: torch.Tensor = None # num_seq + total_history_offsets: torch.Tensor = None # num_seq + 1 # appending metadata - batch_indices: torch.Tensor = None - position: torch.Tensor = None + batch_indices: torch.Tensor = None # num_tokens + position: torch.Tensor = None # num_tokens new_history_nnz: int = 0 - new_history_nnz_cuda: torch.Tensor = None + new_history_nnz_cuda: torch.Tensor = None # 1 # onload utility onload_history_kv_buffer: Optional[List[torch.Tensor]] = None @@ -82,6 +82,16 @@ class KVCacheMetadata: # paged cache table pointers kv_cache_table: Optional[List[torch.Tensor]] = None + # async attributes + kv_onload_handle: Optional[object] = None + kv_offload_handle: Optional[object] = None + + offload_user_ids: Optional[torch.Tensor] = None + offload_page_ids: Optional[torch.Tensor] = None + new_offload_startpos: Optional[torch.Tensor] = None + new_offload_lengths: Optional[torch.Tensor] = None + + max_seqlen: Optional[int] = 0 @dataclass class KVCacheConfig: diff --git a/examples/hstu/dataset/inference_dataset.py b/examples/hstu/dataset/inference_dataset.py index a5db14464..bc4e011b0 100644 --- a/examples/hstu/dataset/inference_dataset.py +++ b/examples/hstu/dataset/inference_dataset.py @@ -160,12 +160,25 @@ def __iter__(self) -> Iterator[Batch]: ) dates.append(self._batch_logs_frame.iloc[sample_id][self._date_name]) seq_endptrs.append(seq_endptr) - if len(user_ids) == 0: - continue + + last_date = dates[0] + final_user_ids: List[int] = [] + final_dates: List[int] = [] + final_seq_endptrs: List[int] = [] + for (uid, date, endp) in zip(user_ids, dates, seq_endptrs): + if date != last_date: + continue + if uid not in final_user_ids: + final_user_ids.append(uid) + final_dates.append(date) + final_seq_endptrs.append(endp) + else: + idx = final_user_ids.index(uid) + final_seq_endptrs[idx] = max(final_seq_endptrs[idx], endp) yield ( - torch.tensor(user_ids), - torch.tensor(dates), - torch.tensor(seq_endptrs), + torch.tensor(final_user_ids), + torch.tensor(final_dates), + torch.tensor(final_seq_endptrs), ) def get_input_batch( @@ -306,7 +319,7 @@ def get_input_batch( labels = torch.tensor(labels, dtype=torch.int64, device=self._device) batch_kwargs = dict( features=features, - batch_size=self._batch_size, + batch_size=len(user_ids), # self._batch_size, feature_to_max_seqlen=feature_to_max_seqlen, contextual_feature_names=self._contextual_feature_names, item_feature_name=self._item_feature_name, diff --git a/examples/hstu/inference/README.md b/examples/hstu/inference/README.md index bfcc8d70f..a90417c3f 100644 --- a/examples/hstu/inference/README.md +++ b/examples/hstu/inference/README.md @@ -56,28 +56,14 @@ ERROR: The input sequence has overlapping tokens from 5 to 9 (both inclusive). ## How to Setup -1. Build TensorRT-LLM (with HSTU KV cache extension): - -The HSTU inference utilize customized KV cache manager from TensorRT-LLM. 
-The current version is based on the HSTU specialized implementation based on TensorRT-LLM v0.19.0. - -```bash -~$ cd ${WORKING_DIR} -~$ git clone -b hstu-kvcache-recsys-examples https://github.com/geoffreyQiu/TensorRT-LLM.git tensorrt-llm-kvcache && cd tensorrt-llm-kvcache -~$ git submodule update --init --recursive -~$ make -C docker release_build CUDA_ARCHS="80-real;86-real" -# This will build a docker image with TensorRT-LLM installed. -``` - -2. Install the dependencies for Recsys-Examples. +1. Install the dependencies for Recsys-Examples. Turn on option `INFERENCEBUILD=1` to skip Megatron installation, which is not required for inference. ```bash ~$ cd ${WORKING_DIR} ~$ git clone --recursive -b ${TEST_BRANCH} ${TEST_REPO} recsys-examples && cd recsys-examples -~$ TRTLLM_KVCACHE_IMAGE="tensorrt_llm/release:latest" docker build \ - --build-arg BASE_IMAGE=${TRTLLM_KVCACHE_IMAGE} \ +~$ docker build \ --build-arg INFERENCEBUILD=1 \ -t recsys-examples:inference \ -f docker/Dockerfile . @@ -93,7 +79,7 @@ Turn on option `INFERENCEBUILD=1` to skip Megatron installation, which is not re ~$ python3 ./preprocessor.py --dataset_name "kuairand-1k" --inference ~$ ~$ # Run the inference example -~$ python3 ./inference/inference_gr_ranking.py --gin_config_file ./inference/configs/kuairand_1k_inference_ranking.gin --checkpoint_dir ${PATH_TO_CHECKPOINT} --mode eval +~$ python3 ./inference/inference_gr_ranking_async.py --gin_config_file ./inference/configs/kuairand_1k_inference_ranking.gin --checkpoint_dir ${PATH_TO_CHECKPOINT} --mode eval ``` ## Consistency Check for Inference @@ -131,7 +117,7 @@ TrainerArgs.ckpt_save_interval = 550 2. Evaluation metrics from inference ``` -/workspace/recsys-examples$ PYTHONPATH=${PYTHONPATH}:$(realpath ../) python3 ./inference/inference_gr_ranking.py --gin_config_file ./inference/configs/kuairand_1k_inference_ranking.gin --checkpoint_dir ${PATH_TO_CHECKPOINT} --mode eval +/workspace/recsys-examples$ PYTHONPATH=${PYTHONPATH}:$(realpath ../) python3 ./inference/inference_gr_ranking_async.py --gin_config_file ./inference/configs/kuairand_1k_inference_ranking.gin --checkpoint_dir ${PATH_TO_CHECKPOINT} --mode eval ... [inference output] ... [eval]: Metrics.task0.AUC:0.556894 @@ -142,4 +128,4 @@ TrainerArgs.ckpt_save_interval = 550 Metrics.task5.AUC:0.580227 Metrics.task6.AUC:0.620498 Metrics.task7.AUC:0.556064 -... [inference output] ... \ No newline at end of file +... [inference output] ... diff --git a/examples/hstu/inference/async_kvcache_eval.py b/examples/hstu/inference/async_kvcache_eval.py new file mode 100644 index 000000000..b3c6efcb9 --- /dev/null +++ b/examples/hstu/inference/async_kvcache_eval.py @@ -0,0 +1,392 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
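+
+# NOTE: consistency-check sketch. This script replays the offline evaluation
+# batches over two rounds: round 0 feeds a truncated prefix of each user's
+# history (see get_new_batch) so that KV pages are populated and offloaded,
+# and round 1 feeds the full history so attention must reuse the pages cached
+# in round 0. Eval metrics are accumulated only in the final round and should
+# match the --disable_kvcache path (forward_nokvcache). The /tmp/*.npy
+# dump/move logic depends on a debug flag (pg.dmp) assumed to be exposed by
+# modules.paged_hstu_infer_layer; it is a debugging aid, not a stable API.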
+import argparse +import enum +import math +import sys +import time +import os +import shutil + +import gin +import torch +from commons.utils.stringify import stringify_dict +from configs import ( + InferenceEmbeddingConfig, + PositionEncodingConfig, + RankingConfig, + get_inference_hstu_config, + get_kvcache_config, +) +from dataset import get_data_loader +from dataset.inference_dataset import InferenceDataset +from dataset.sequence_dataset import get_dataset +from modules.metrics import get_multi_event_metric_module +from preprocessor import get_common_preprocessors +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +from utils import DatasetArgs, NetworkArgs, RankingArgs + +sys.path.append("./model/") +from inference_ranking_gr import InferenceRankingGR + +import modules.paged_hstu_infer_layer as pg +from modules.paged_hstu_infer_layer import init + +class RunningMode(enum.Enum): + EVAL = "eval" + SIMULATE = "simulate" + + def __str__(self): + return self.value + + +def get_inference_dataset_and_embedding_configs(): + dataset_args = DatasetArgs() + embedding_dim = NetworkArgs().hidden_size + HASH_SIZE = 10_000_000 + if dataset_args.dataset_name == "kuairand-1k": + embedding_configs = [ + InferenceEmbeddingConfig( + feature_names=["user_id"], + table_name="user_id", + vocab_size=1000, + dim=embedding_dim, + use_dynamicemb=True, + ), + InferenceEmbeddingConfig( + feature_names=["user_active_degree"], + table_name="user_active_degree", + vocab_size=8, + dim=embedding_dim, + use_dynamicemb=False, + ), + InferenceEmbeddingConfig( + feature_names=["follow_user_num_range"], + table_name="follow_user_num_range", + vocab_size=9, + dim=embedding_dim, + use_dynamicemb=False, + ), + InferenceEmbeddingConfig( + feature_names=["fans_user_num_range"], + table_name="fans_user_num_range", + vocab_size=9, + dim=embedding_dim, + use_dynamicemb=False, + ), + InferenceEmbeddingConfig( + feature_names=["friend_user_num_range"], + table_name="friend_user_num_range", + vocab_size=8, + dim=embedding_dim, + use_dynamicemb=False, + ), + InferenceEmbeddingConfig( + feature_names=["register_days_range"], + table_name="register_days_range", + vocab_size=8, + dim=embedding_dim, + use_dynamicemb=False, + ), + InferenceEmbeddingConfig( + feature_names=["video_id"], + table_name="video_id", + vocab_size=HASH_SIZE, + dim=embedding_dim, + use_dynamicemb=True, + ), + InferenceEmbeddingConfig( + feature_names=["action_weights"], + table_name="action_weights", + vocab_size=233, + dim=embedding_dim, + use_dynamicemb=False, + ), + ] + return dataset_args, embedding_configs + + raise ValueError(f"dataset {dataset_args.dataset_name} is not supported") + + +def get_inference_hstu_model( + emb_configs, + max_batch_size, + num_contextual_features, + total_max_seqlen, + checkpoint_dir, +): + network_args = NetworkArgs() + if network_args.dtype_str == "bfloat16": + inference_dtype = torch.bfloat16 + # elif network_args.dtype_str == "float16": + # inference_dtype = torch.float16 + else: + raise ValueError( + f"Inference data type {network_args.dtype_str} is not supported" + ) + + position_encoding_config = PositionEncodingConfig( + num_position_buckets=8192, + num_time_buckets=2048, + use_time_encoding=False, + static_max_seq_len=math.ceil(total_max_seqlen / 32) * 32, + ) + + hstu_config = get_inference_hstu_config( + hidden_size=network_args.hidden_size, + num_layers=network_args.num_layers, + num_attention_heads=network_args.num_attention_heads, + head_dim=network_args.kv_channels, + dtype=inference_dtype, + 
position_encoding_config=position_encoding_config, + contextual_max_seqlen=num_contextual_features, + scaling_seqlen=network_args.scaling_seqlen, + ) + + kvcache_args = { + "blocks_in_primary_pool": 10240, + "page_size": 32, + "offload_chunksize": 1024, + "max_batch_size": max_batch_size, + "max_seq_len": math.ceil(total_max_seqlen / 32) * 32, + } + kv_cache_config = get_kvcache_config(**kvcache_args) + + ranking_args = RankingArgs() + task_config = RankingConfig( + embedding_configs=emb_configs, + prediction_head_arch=ranking_args.prediction_head_arch, + prediction_head_act_type=ranking_args.prediction_head_act_type, + prediction_head_bias=ranking_args.prediction_head_bias, + num_tasks=ranking_args.num_tasks, + eval_metrics=ranking_args.eval_metrics, + ) + + hstu_cudagraph_configs = { + "batch_size": [1], + "length_per_sequence": [128] + [i * 256 for i in range(1, 34)], + } + + model = InferenceRankingGR( + hstu_config=hstu_config, + kvcache_config=kv_cache_config, + task_config=task_config, + use_cudagraph=False, + cudagraph_configs=hstu_cudagraph_configs, + ) + if hstu_config.bf16: + model.bfloat16() + elif hstu_config.fp16: + model.half() + model.load_checkpoint(checkpoint_dir) + model.eval() + + return model + + +def get_new_batch( + batch, hist_lengths, ratio, num_contextuals +): + partial_lengths = torch.ceil(hist_lengths * ratio).long() - num_contextuals + partial_lengths = partial_lengths // 2 + + kjt_dict = batch.features.to_dict() + item_jt = kjt_dict["video_id"] + vals = item_jt.values() + lens = item_jt.lengths() + num_candidates = batch.num_candidates + split_lens = torch.stack( + [partial_lengths + num_candidates, lens - partial_lengths - num_candidates], dim=1 + ).reshape((-1,)) + stripped_vals = torch.split(vals, split_lens.tolist())[::2] + kjt_dict["video_id"] = JaggedTensor.from_dense(stripped_vals) + + action_jt = kjt_dict["action_weights"] + vals = action_jt.values() + lens = action_jt.lengths() + split_lens = torch.stack( + [partial_lengths, lens - partial_lengths], dim=1 + ).reshape((-1,)) + stripped_vals = torch.split(vals, split_lens.tolist())[::2] + kjt_dict["action_weights"] = JaggedTensor.from_dense(stripped_vals) + + batch.features = KeyedJaggedTensor.from_jt_dict(kjt_dict) + hist_lengths = num_contextuals + partial_lengths * 2 + + return batch, hist_lengths + + + +def run_kvcache_consistency_check( + checkpoint_dir: str, + disable_kvcache: bool = False, +): + dataset_args, emb_configs = get_inference_dataset_and_embedding_configs() + + dataproc = get_common_preprocessors("")[dataset_args.dataset_name] + num_contextual_features = len(dataproc._contextual_feature_names) + + max_batch_size = 1 + total_max_seqlen = dataset_args.max_sequence_length * 2 + num_contextual_features + print("total_max_seqlen", total_max_seqlen) + + def strip_candidate_action_tokens(batch, action_feature_name): + kjt_dict = batch.features.to_dict() + action_jagged_tensor = kjt_dict[action_feature_name] + values = action_jagged_tensor.values() + lengths = action_jagged_tensor.lengths() + num_candidates = batch.num_candidates + split_lengths = torch.stack( + [lengths - num_candidates, num_candidates], dim=1 + ).reshape((-1,)) + stripped_value = torch.split(values, split_lengths.tolist())[::2] + kjt_dict[action_feature_name] = JaggedTensor.from_dense(stripped_value) + batch.features = KeyedJaggedTensor.from_jt_dict(kjt_dict) + return batch + + def strip_padding_batch(batch, unpadded_batch_size): + batch.batch_size = unpadded_batch_size + kjt_dict = batch.features.to_dict() + for k in 
kjt_dict: + kjt_dict[k] = JaggedTensor.from_dense_lengths( + kjt_dict[k].to_padded_dense()[: batch.batch_size], + kjt_dict[k].lengths()[: batch.batch_size].long(), + ) + batch.features = KeyedJaggedTensor.from_jt_dict(kjt_dict) + batch.num_candidates = batch.num_candidates[: batch.batch_size] + return batch + + with torch.inference_mode(): + model = get_inference_hstu_model( + emb_configs, + max_batch_size, + num_contextual_features, + total_max_seqlen, + checkpoint_dir, + ) + + eval_module = get_multi_event_metric_module( + num_classes=model._task_config.prediction_head_arch[-1], + num_tasks=model._task_config.num_tasks, + metric_types=model._task_config.eval_metrics, + ) + + train_dataset, _ = get_dataset( + dataset_name=dataset_args.dataset_name, + dataset_path=dataset_args.dataset_path, + max_sequence_length=dataset_args.max_sequence_length, + max_num_candidates=dataset_args.max_num_candidates, + num_tasks=model._task_config.num_tasks, + batch_size=max_batch_size, + rank=0, + world_size=1, + shuffle=False, + random_seed=0, + eval_batch_size=max_batch_size, + ) + + dataloader = get_data_loader(dataset=train_dataset) + + num_kvc_test_rounds = 2 + + # torch.cuda.memory._record_memory_history() + # torch.cuda.profiler.start() + for round_id in [0, 1]: + dataloader_iter = iter(dataloader) + + length_ratio = (round_id + 1) / num_kvc_test_rounds + while True: + try: + batch = next(dataloader_iter) + if model._task_config.num_tasks > 0: + batch = strip_candidate_action_tokens( + batch, dataproc._action_feature_name + ) + + batch = batch.to(device=torch.cuda.current_device()) + + d = batch.features.to_dict() + user_ids = d["user_id"].values().cpu().long() + if user_ids.shape[0] != batch.batch_size: + batch = strip_padding_batch(batch, user_ids.shape[0]) + total_history_lengths = torch.sum(batch.features.lengths().view(-1, batch.batch_size), 0).view(-1) - batch.num_candidates + + if round_id != num_kvc_test_rounds - 1: + batch, total_history_lengths = get_new_batch(batch, total_history_lengths, length_ratio, num_contextual_features) + + # if int(user_ids[0]) == 0: + # pg.dmp = True + if not disable_kvcache: + logits = model.forward(batch, user_ids, total_history_lengths.cpu()) + else: + logits = model.forward_nokvcache(batch) + + if pg.dmp: + if disable_kvcache: + for lidx in range(model._hstu_config.num_layers): + if user_ids[0] < 10 or user_ids[0] >= 690: + shutil.move(f"/tmp/in_l{lidx}.npy", f"dump/round{round_id}_user{user_ids[0]}_in_l{lidx}.npy") + shutil.move(f"/tmp/key_l{lidx}.npy", f"dump/round{round_id}_user{user_ids[0]}_key_l{lidx}.npy") + shutil.move(f"/tmp/value_l{lidx}.npy", f"dump/round{round_id}_user{user_ids[0]}_value_l{lidx}.npy") + shutil.move(f"/tmp/attn_l{lidx}.npy", f"dump/round{round_id}_user{user_ids[0]}_attn_l{lidx}.npy") + shutil.move(f"/tmp/out_l{lidx}.npy", f"dump/round{round_id}_user{user_ids[0]}_out_l{lidx}.npy") + + else: + os.remove(f"/tmp/key_l{lidx}.npy") + os.remove(f"/tmp/value_l{lidx}.npy") + else: + for lidx in range(model._hstu_config.num_layers): + if user_ids[0] < 10 or user_ids[0] >= 690: + shutil.move(f"/tmp/in_l{lidx}.npy", f"cached/round{round_id}_user{user_ids[0]}_in_l{lidx}.npy") + shutil.move(f"/tmp/key_l{lidx}.npy", f"cached/round{round_id}_user{user_ids[0]}_key_l{lidx}.npy") + shutil.move(f"/tmp/value_l{lidx}.npy", f"cached/round{round_id}_user{user_ids[0]}_value_l{lidx}.npy") + shutil.move(f"/tmp/attn_l{lidx}.npy", f"cached/round{round_id}_user{user_ids[0]}_attn_l{lidx}.npy") + shutil.move(f"/tmp/out_l{lidx}.npy", 
f"cached/round{round_id}_user{user_ids[0]}_out_l{lidx}.npy") + else: + os.remove(f"/tmp/key_l{lidx}.npy") + os.remove(f"/tmp/value_l{lidx}.npy") + pg.dmp = False + + if round_id == num_kvc_test_rounds - 1: + eval_module(logits, batch.labels) + except StopIteration: + break + # torch.cuda.profiler.stop() + # torch.cuda.memory._dump_snapshot("my_snapshot.pickle") + + eval_metric_dict = eval_module.compute() + print( + f"[eval]:\n " + + stringify_dict(eval_metric_dict, prefix="Metrics", sep="\n ") + ) + # print("X") + +if __name__ == "__main__": + init() + parser = argparse.ArgumentParser(description="Inference End-to-end Example") + parser.add_argument("--gin_config_file", type=str, required=True) + parser.add_argument("--checkpoint_dir", type=str, required=True) + parser.add_argument("--disable_kvcache", action="store_true") + # parser.add_argument("--max_bs", type=int, required=True) + + + args = parser.parse_args() + gin.parse_config_file(args.gin_config_file) + + run_kvcache_consistency_check( + checkpoint_dir=args.checkpoint_dir, + disable_kvcache=args.disable_kvcache, + ) + print("Finished.") diff --git a/examples/hstu/inference/inference_gr_ranking_async.py b/examples/hstu/inference/inference_gr_ranking_async.py new file mode 100644 index 000000000..ddb1c8d0e --- /dev/null +++ b/examples/hstu/inference/inference_gr_ranking_async.py @@ -0,0 +1,458 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse +import enum +import math +import sys +import time +import os + +import gin +import torch +from commons.utils.stringify import stringify_dict +from configs import ( + InferenceEmbeddingConfig, + PositionEncodingConfig, + RankingConfig, + get_inference_hstu_config, + get_kvcache_config, +) +from dataset import get_data_loader +from dataset.inference_dataset import InferenceDataset +from dataset.sequence_dataset import get_dataset +from modules.metrics import get_multi_event_metric_module +from preprocessor import get_common_preprocessors +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +from utils import DatasetArgs, NetworkArgs, RankingArgs + +sys.path.append("./model/") +from inference_ranking_gr import InferenceRankingGR + + +class RunningMode(enum.Enum): + EVAL = "eval" + SIMULATE = "simulate" + + def __str__(self): + return self.value + + +def get_inference_dataset_and_embedding_configs( + disable_contextual_features: bool = False, +): + dataset_args = DatasetArgs() + embedding_dim = NetworkArgs().hidden_size + HASH_SIZE = 10_000_000 + if dataset_args.dataset_name == "kuairand-1k": + embedding_configs = [ + InferenceEmbeddingConfig( + feature_names=["user_id"], + table_name="user_id", + vocab_size=1000, + dim=embedding_dim, + use_dynamicemb=True, + ), + InferenceEmbeddingConfig( + feature_names=["user_active_degree"], + table_name="user_active_degree", + vocab_size=8, + dim=embedding_dim, + use_dynamicemb=False, + ), + InferenceEmbeddingConfig( + feature_names=["follow_user_num_range"], + table_name="follow_user_num_range", + vocab_size=9, + dim=embedding_dim, + use_dynamicemb=False, + ), + InferenceEmbeddingConfig( + feature_names=["fans_user_num_range"], + table_name="fans_user_num_range", + vocab_size=9, + dim=embedding_dim, + use_dynamicemb=False, + ), + InferenceEmbeddingConfig( + feature_names=["friend_user_num_range"], + table_name="friend_user_num_range", + vocab_size=8, + dim=embedding_dim, + use_dynamicemb=False, + ), + InferenceEmbeddingConfig( + feature_names=["register_days_range"], + table_name="register_days_range", + vocab_size=8, + dim=embedding_dim, + use_dynamicemb=False, + ), + InferenceEmbeddingConfig( + feature_names=["video_id"], + table_name="video_id", + vocab_size=HASH_SIZE, + dim=embedding_dim, + use_dynamicemb=True, + ), + InferenceEmbeddingConfig( + feature_names=["action_weights"], + table_name="action_weights", + vocab_size=233, + dim=embedding_dim, + use_dynamicemb=False, + ), + ] + return ( + dataset_args, + embedding_configs + if not disable_contextual_features + else embedding_configs[-2:], + ) + + raise ValueError(f"dataset {dataset_args.dataset_name} is not supported") + + +def get_inference_hstu_model( + emb_configs, + max_batch_size, + num_contextual_features, + total_max_seqlen, + checkpoint_dir, +): + network_args = NetworkArgs() + if network_args.dtype_str == "bfloat16": + inference_dtype = torch.bfloat16 + # elif network_args.dtype_str == "float16": + # inference_dtype = torch.float16 + else: + raise ValueError( + f"Inference data type {network_args.dtype_str} is not supported" + ) + + position_encoding_config = PositionEncodingConfig( + num_position_buckets=8192, + num_time_buckets=2048, + use_time_encoding=False, + static_max_seq_len=math.ceil(total_max_seqlen / 32) * 32, + ) + + hstu_config = get_inference_hstu_config( + hidden_size=network_args.hidden_size, + num_layers=network_args.num_layers, + num_attention_heads=network_args.num_attention_heads, + head_dim=network_args.kv_channels, + 
dtype=inference_dtype, + position_encoding_config=position_encoding_config, + contextual_max_seqlen=num_contextual_features, + scaling_seqlen=network_args.scaling_seqlen, + ) + + kvcache_args = { + "blocks_in_primary_pool": 10240, + "page_size": 32, + "offload_chunksize": 1024, + "max_batch_size": max_batch_size, + "max_seq_len": math.ceil(total_max_seqlen / 32) * 32, + } + kv_cache_config = get_kvcache_config(**kvcache_args) + + ranking_args = RankingArgs() + task_config = RankingConfig( + embedding_configs=emb_configs, + prediction_head_arch=ranking_args.prediction_head_arch, + prediction_head_act_type=ranking_args.prediction_head_act_type, + prediction_head_bias=ranking_args.prediction_head_bias, + num_tasks=ranking_args.num_tasks, + eval_metrics=ranking_args.eval_metrics, + ) + + hstu_cudagraph_configs = { + "batch_size": [1], + "length_per_sequence": [128] + [i * 256 for i in range(1, 34)], + } + + model = InferenceRankingGR( + hstu_config=hstu_config, + kvcache_config=kv_cache_config, + task_config=task_config, + use_cudagraph=False, + cudagraph_configs=hstu_cudagraph_configs, + ) + if hstu_config.bf16: + model.bfloat16() + elif hstu_config.fp16: + model.half() + model.load_checkpoint(checkpoint_dir) + model.eval() + + return model + + +def run_ranking_gr_simulate( + checkpoint_dir: str, + check_auc: bool = False, + disable_contextual_features: bool = False, + disable_kvcache: bool = False, + max_bs: int = 1, +): + dataset_args, emb_configs = get_inference_dataset_and_embedding_configs( + disable_contextual_features + ) + + dataproc = get_common_preprocessors("")[dataset_args.dataset_name] + num_contextual_features = ( + len(dataproc._contextual_feature_names) + if not disable_contextual_features + else 0 + ) + + max_batch_size = max_bs + total_max_seqlen = dataset_args.max_sequence_length * 2 + num_contextual_features + print("total_max_seqlen", total_max_seqlen) + + with torch.inference_mode(): + model = get_inference_hstu_model( + emb_configs, + max_batch_size, + num_contextual_features, + total_max_seqlen, + checkpoint_dir, + ) + + if check_auc: + eval_module = get_multi_event_metric_module( + num_classes=model._task_config.prediction_head_arch[-1], + num_tasks=model._task_config.num_tasks, + metric_types=model._task_config.eval_metrics, + ) + + dataset = InferenceDataset( + seq_logs_file=dataproc._inference_sequence_file, + batch_logs_file=dataproc._inference_batch_file, + batch_size=max_batch_size, + max_seqlen=dataset_args.max_sequence_length, + item_feature_name=dataproc._item_feature_name, + contextual_feature_names=dataproc._contextual_feature_names + if not disable_contextual_features + else [], + action_feature_name=dataproc._action_feature_name, + max_num_candidates=dataset_args.max_num_candidates, + item_vocab_size=10_000_000, + userid_name="user_id", + date_name="date", + sequence_endptr_name="interval_indptr", + timestamp_names=["date", "interval_end_ts"], + ) + + dataloader = get_data_loader(dataset=dataset) + dataloader_iter = iter(dataloader) + + num_batches_ctr = 0 + start_time = time.time() + cur_date = None + while True: + try: + uids, dates, seq_endptrs = next(dataloader_iter) + if dates[0] != cur_date: + # if cur_date is not None: + # eval_metric_dict = eval_module.compute() + # print( + # f"[eval]:\n " + # + stringify_dict( + # eval_metric_dict, prefix="Metrics", sep="\n " + # ) + # ) + # model.clear_kv_cache() + if cur_date is not None: + break + cur_date = dates[0] + + batch = dataset.get_input_batch( + uids, + dates, + seq_endptrs, + 
torch.zeros_like(seq_endptrs), + with_contextual_features=True, + with_ranking_labels=False, + ) + total_history_lengths = seq_endptrs * 2 + num_contextual_features + + if batch is not None: + if not disable_kvcache: + logits = model.forward( + batch, + uids, + total_history_lengths, + ) + else: + logits = model.forward_nokvcache(batch) + # eval_module(logits, batch.labels) + + num_batches_ctr += 1 + # if num_batches_ctr == 1000: + # break + except StopIteration: + break + end_time = time.time() + print("Total #batch:", num_batches_ctr) + print("Total time(s):", end_time - start_time) + + +def run_ranking_gr_evaluate( + checkpoint_dir: str, + disable_contextual_features: bool = False, + disable_kvcache: bool = False, +): + dataset_args, emb_configs = get_inference_dataset_and_embedding_configs( + disable_contextual_features + ) + + dataproc = get_common_preprocessors("")[dataset_args.dataset_name] + num_contextual_features = ( + len(dataproc._contextual_feature_names) + if not disable_contextual_features + else 0 + ) + + max_batch_size = 1 + total_max_seqlen = dataset_args.max_sequence_length * 2 + num_contextual_features + print("total_max_seqlen", total_max_seqlen) + + def strip_candidate_action_tokens(batch, action_feature_name): + kjt_dict = batch.features.to_dict() + action_jagged_tensor = kjt_dict[action_feature_name] + values = action_jagged_tensor.values() + lengths = action_jagged_tensor.lengths() + num_candidates = batch.num_candidates + split_lengths = torch.stack( + [lengths - num_candidates, num_candidates], dim=1 + ).reshape((-1,)) + stripped_value = torch.split(values, split_lengths.tolist())[::2] + kjt_dict[action_feature_name] = JaggedTensor.from_dense(stripped_value) + batch.features = KeyedJaggedTensor.from_jt_dict(kjt_dict) + return batch + + def strip_padding_batch(batch, unpadded_batch_size): + batch.batch_size = unpadded_batch_size + kjt_dict = batch.features.to_dict() + for k in kjt_dict: + kjt_dict[k] = JaggedTensor.from_dense_lengths( + kjt_dict[k].to_padded_dense()[: batch.batch_size], + kjt_dict[k].lengths()[: batch.batch_size].long(), + ) + batch.features = KeyedJaggedTensor.from_jt_dict(kjt_dict) + batch.num_candidates = batch.num_candidates[: batch.batch_size] + return batch + + with torch.inference_mode(): + model = get_inference_hstu_model( + emb_configs, + max_batch_size, + num_contextual_features, + total_max_seqlen, + checkpoint_dir, + ) + + eval_module = get_multi_event_metric_module( + num_classes=model._task_config.prediction_head_arch[-1], + num_tasks=model._task_config.num_tasks, + metric_types=model._task_config.eval_metrics, + ) + + eval_dataset, _ = get_dataset( + dataset_name=dataset_args.dataset_name, + dataset_path=dataset_args.dataset_path, + max_sequence_length=dataset_args.max_sequence_length, + max_num_candidates=dataset_args.max_num_candidates, + num_tasks=model._task_config.num_tasks, + batch_size=max_batch_size, + rank=0, + world_size=1, + shuffle=False, + random_seed=0, + eval_batch_size=max_batch_size, + ) + + dataloader = get_data_loader(dataset=eval_dataset) + dataloader_iter = iter(dataloader) + + # torch.cuda.profiler.start() + while True: + try: + batch = next(dataloader_iter) + if model._task_config.num_tasks > 0: + batch = strip_candidate_action_tokens( + batch, dataproc._action_feature_name + ) + + batch = batch.to(device=torch.cuda.current_device()) + d = batch.features.to_dict() + user_ids = d["user_id"].values().cpu().long() + if user_ids.shape[0] != batch.batch_size: + batch = strip_padding_batch(batch, user_ids.shape[0]) 
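+                # Per-sequence history length: sum the jagged lengths of every
+                # feature key for that sequence, then subtract the candidate
+                # count so only already-seen tokens are counted for KV caching.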
+ total_history_lengths = torch.sum(batch.features.lengths().view(-1, batch.batch_size), 0).view(-1) - batch.num_candidates + total_history_lengths = total_history_lengths.cpu() + print(batch.features.lengths()) + + if not disable_kvcache: + logits = model.forward(batch, user_ids, total_history_lengths) + else: + logits = model.forward_nokvcache(batch) + eval_module(logits, batch.labels) + except StopIteration: + break + # torch.cuda.profiler.stop() + + eval_metric_dict = eval_module.compute() + print( + f"[eval]:\n " + + stringify_dict(eval_metric_dict, prefix="Metrics", sep="\n ") + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Inference End-to-end Example") + parser.add_argument("--gin_config_file", type=str, required=True) + parser.add_argument("--checkpoint_dir", type=str, required=True) + parser.add_argument( + "--mode", type=RunningMode, choices=list(RunningMode), required=True + ) + parser.add_argument("--disable_auc", action="store_true") + parser.add_argument("--disable_context", action="store_true") + parser.add_argument("--disable_kvcache", action="store_true") + parser.add_argument("--max_bs", type=int, required=True) + + + args = parser.parse_args() + gin.parse_config_file(args.gin_config_file) + + if args.mode == RunningMode.EVAL: + if args.disable_auc: + print("disable_auc is ignored in Eval mode.") + if args.disable_context: + print("disable_context is ignored in Eval mode.") + run_ranking_gr_evaluate( + checkpoint_dir=args.checkpoint_dir, + disable_kvcache=args.disable_kvcache, + ) + elif args.mode == RunningMode.SIMULATE: + run_ranking_gr_simulate( + checkpoint_dir=args.checkpoint_dir, + check_auc=not args.disable_auc, + disable_contextual_features=args.disable_context, + disable_kvcache=args.disable_kvcache, + max_bs=args.max_bs, + ) + print("Finished.") diff --git a/examples/hstu/model/inference_ranking_gr.py b/examples/hstu/model/inference_ranking_gr.py index cdc4267a8..808ac0416 100755 --- a/examples/hstu/model/inference_ranking_gr.py +++ b/examples/hstu/model/inference_ranking_gr.py @@ -26,14 +26,13 @@ get_kvcache_metadata_buffer, ) from dataset.utils import Batch -from modules.gpu_kv_cache_manager import HSTUGpuKVCacheManager -from modules.host_kv_storage_manager import HSTUHostKVStorageManager from modules.hstu_block_inference import HSTUBlockInference from modules.inference_embedding import InferenceEmbedding from modules.jagged_data import JaggedData from modules.mlp import MLP from ops.triton_ops.triton_jagged import triton_concat_2D_jagged - +from modules.async_kvcache_manager import AsyncHSTUKVCacheManager +import math def get_jagged_metadata_buffer(max_batch_size, max_seq_len, contextual_max_seqlen): int_dtype = torch.int32 @@ -134,11 +133,6 @@ def __init__( self._embedding_collection = InferenceEmbedding(task_config.embedding_configs) - self._gpu_kv_cache_manager = HSTUGpuKVCacheManager(hstu_config, kvcache_config) - self._host_kv_storage_manager = HSTUHostKVStorageManager( - hstu_config, kvcache_config - ) - self._hstu_block = HSTUBlockInference(hstu_config, kvcache_config) self._mlp = MLP( self._embedding_dim, @@ -170,33 +164,33 @@ def __init__( self._jagged_metadata = get_jagged_metadata_buffer( max_batch_size, max_seq_len, hstu_config.contextual_max_seqlen ) - self._kvcache_metadata = get_kvcache_metadata_buffer( - hstu_config, kvcache_config + + self.async_kvcache = AsyncHSTUKVCacheManager( + hstu_config.num_layers, + hstu_config.num_heads, + hstu_config.head_dim, + kvcache_config.page_size, + 
kvcache_config.blocks_in_primary_pool, + math.ceil(kvcache_config.max_batch_size * kvcache_config.max_seq_len / kvcache_config.page_size), + 4 * math.ceil(kvcache_config.max_batch_size * kvcache_config.max_seq_len / kvcache_config.page_size), + kvcache_config.offload_chunksize, + -1, + kvcache_config.max_seq_len, + kvcache_config.max_batch_size, ) - self._offload_states = None - self._kvcache_metadata.onload_history_kv_buffer = [ - self._gpu_kv_cache_manager.get_onload_buffers(layer_idx) - for layer_idx in range(hstu_config.num_layers) - ] - self._kvcache_metadata.onload_history_kv_events = [ - torch.cuda.Event() for _ in range(hstu_config.num_layers) - ] - self._kvcache_metadata.kv_cache_table = [ - self._gpu_kv_cache_manager.get_kvcache_table(layer_idx) - for layer_idx in range(hstu_config.num_layers) - ] - # TODO(junyiq): Add cudagraph optimization for the MLP as well. - self.use_cudagraph = use_cudagraph - if use_cudagraph: - self._hstu_block.set_cudagraph( - max_batch_size, - max_seq_len, - self._hidden_states, - self._jagged_metadata, - self._kvcache_metadata, - cudagraph_configs=cudagraph_configs, - ) + from ops.triton_ops.common import set_use_runtime_max_seq_len, set_static_max_seq_lens + set_use_runtime_max_seq_len(False) + set_static_max_seq_lens(max_seq_len, max_seq_len) + # from ops.triton_ops.triton_position import triton_add_position_embeddings + # triton_add_position_embeddings( + # jagged=self._hidden_states, + # jagged_offsets=torch.tensor[], #seq_offsets, + # high_inds=high_inds, + # max_seq_len=max_seq_len, + # dense=self._position_embeddings_weight, + # scale=alpha, + # ind_offsets=ind_offsets) def bfloat16(self): """ @@ -314,223 +308,358 @@ def load_state_dict(self, model_state_dict, *args, **kwargs): if self._hstu_config.contextual_max_seqlen != 0: assert unloaded_modules.unexpected_keys == [] - def get_user_kvdata_info( - self, - user_ids: Union[List[int], torch.Tensor], - allow_bubble: bool = False, - dbg_print: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor]: - kvdata_start_pos = list() - kvdata_lengths = list() - for idx in range(len(user_ids)): - uid = int(user_ids[idx]) - host_sp, host_len = self._host_kv_storage_manager.get_user_kvdata_info(uid) - gpu_sp, gpu_len = self._gpu_kv_cache_manager.get_user_kvdata_info(uid) - sp = host_sp if gpu_sp == -1 or gpu_len == 0 else min(host_sp, gpu_sp) - length = ( - host_len - if gpu_sp == -1 or gpu_len == 0 - else (gpu_sp + gpu_len - host_sp) - ) - if gpu_sp > host_sp + host_len and not allow_bubble: - warnings.warn( - f"KVdata missing between host storage and gpu kvcache for user {uid}" - ) - length = host_len - kvdata_start_pos.append(sp) - kvdata_lengths.append(length) - return ( - torch.tensor(kvdata_start_pos, dtype=torch.int32), - torch.tensor(kvdata_lengths, dtype=torch.int32), - ) + # def get_user_kvdata_info( + # self, + # user_ids: Union[List[int], torch.Tensor], + # allow_bubble: bool = False, + # dbg_print: bool = False, + # ) -> Tuple[torch.Tensor, torch.Tensor]: + # kvdata_start_pos = list() + # kvdata_lengths = list() + # for idx in range(len(user_ids)): + # uid = int(user_ids[idx]) + # host_sp, host_len = self._host_kv_storage_manager.get_user_kvdata_info(uid) + # gpu_sp, gpu_len = self._gpu_kv_cache_manager.get_user_kvdata_info(uid) + # sp = host_sp if gpu_sp == -1 or gpu_len == 0 else min(host_sp, gpu_sp) + # length = ( + # host_len + # if gpu_sp == -1 or gpu_len == 0 + # else (gpu_sp + gpu_len - host_sp) + # ) + # if gpu_sp > host_sp + host_len and not allow_bubble: + # warnings.warn( + # 
f"KVdata missing between host storage and gpu kvcache for user {uid}" + # ) + # length = host_len + # kvdata_start_pos.append(sp) + # kvdata_lengths.append(length) + # return ( + # torch.tensor(kvdata_start_pos, dtype=torch.int32), + # torch.tensor(kvdata_lengths, dtype=torch.int32), + # ) + + # def strip_contextual_features(self, embeddings, batch, user_start_pos): + # if int(min(user_start_pos)) >= len(batch.contextual_feature_names): + # embeddings = { + # batch.item_feature_name: embeddings[batch.item_feature_name], + # batch.action_feature_name: embeddings[batch.action_feature_name], + # } + # batch.contextual_feature_names = [] + # return embeddings, batch + # elif int(max(user_start_pos)) < len(batch.contextual_feature_names): + # return embeddings, batch + # else: + # raise Exception("Do not accept mixing contextual features input") + + # def prepare_kv_cache( + # self, batch: Batch, user_ids: torch.Tensor, user_start_pos: torch.Tensor + # ) -> KVCacheMetadata: + # batch_size = user_ids.shape[0] + # new_history_lengths = ( + # torch.sum(batch.features.lengths().view(-1, batch_size), 0).view(-1) + # - batch.num_candidates + # ) + # ( + # cached_start_pos, + # cached_lengths, + # ) = self._gpu_kv_cache_manager.get_batch_kvdata_info(user_ids) + + # self._gpu_kv_cache_manager.allocate( + # user_ids, user_start_pos, new_history_lengths + # ) + # kv_cache_metadata = self._gpu_kv_cache_manager.get_cache_metadata(user_ids) + # append_metadata = self._gpu_kv_cache_manager.get_append_metadata( + # new_history_lengths, kv_cache_metadata.total_history_lengths + # ) + # for _field_name in [ + # "batch_indices", + # "position", + # "new_history_nnz", + # "new_history_nnz_cuda", + # ]: + # setattr( + # kv_cache_metadata, _field_name, getattr(append_metadata, _field_name) + # ) + + # kv_cache_metadata.onload_history_kv_buffer = ( + # self._kvcache_metadata.onload_history_kv_buffer[:] + # ) + # kv_cache_metadata.onload_history_kv_events = ( + # self._kvcache_metadata.onload_history_kv_events[:] + # ) + # kv_cache_metadata.kv_cache_table = self._kvcache_metadata.kv_cache_table[:] + # ( + # onload_length, + # onload_kv_page_ids, + # onload_kv_page_indptr, + # ) = self._host_kv_storage_manager.lookup_kvdata( + # user_ids, cached_start_pos, cached_lengths + # ) + # if onload_length > 0: + # kv_page_ids = triton_concat_2D_jagged( + # max_seq_len=onload_kv_page_indptr[-1] + kv_cache_metadata.kv_indptr[-1], + # values_a=onload_kv_page_ids.view(-1, 1), + # values_b=kv_cache_metadata.kv_indices.view(-1, 1), + # offsets_a=onload_kv_page_indptr.to(torch.int64), + # offsets_b=kv_cache_metadata.kv_indptr.to(torch.int64), + # ) + # kv_cache_metadata.kv_indices = kv_page_ids.view(-1) + # kv_cache_metadata.kv_indptr = ( + # onload_kv_page_indptr + kv_cache_metadata.kv_indptr + # ) + # self._gpu_kv_cache_manager.onload( + # self._host_kv_storage_manager.get_lookup_buffer(), + # onload_length, + # self._kvcache_metadata if self.use_cudagraph else kv_cache_metadata, + # ) + + # # cudagraph preparation + # if self.use_cudagraph: + # copy_kvcache_metadata(self._kvcache_metadata, kv_cache_metadata) + # # assert max(kv_cache_metadata.kv_indices.tolist()) < self._kvcache_metadata.kv_cache_table[0].shape[0] + + # return kv_cache_metadata + + # def finalize_kv_cache(self, user_ids: torch.Tensor, **kwargs): + # pass - def strip_contextual_features(self, embeddings, batch, user_start_pos): - if int(min(user_start_pos)) >= len(batch.contextual_feature_names): - embeddings = { - batch.item_feature_name: 
embeddings[batch.item_feature_name], - batch.action_feature_name: embeddings[batch.action_feature_name], - } - batch.contextual_feature_names = [] - return embeddings, batch - elif int(max(user_start_pos)) < len(batch.contextual_feature_names): - return embeddings, batch - else: - raise Exception("Do not accept mixing contextual features input") - - def prepare_kv_cache( - self, batch: Batch, user_ids: torch.Tensor, user_start_pos: torch.Tensor - ) -> KVCacheMetadata: - batch_size = user_ids.shape[0] - new_history_lengths = ( - torch.sum(batch.features.lengths().view(-1, batch_size), 0).view(-1) - - batch.num_candidates - ) - ( - cached_start_pos, - cached_lengths, - ) = self._gpu_kv_cache_manager.get_batch_kvdata_info(user_ids) + def clear_kv_cache(self): + self._gpu_kv_cache_manager.evict_all() + self._host_kv_storage_manager.evict_all_kvdata() - self._gpu_kv_cache_manager.allocate( - user_ids, user_start_pos, new_history_lengths - ) - kv_cache_metadata = self._gpu_kv_cache_manager.get_cache_metadata(user_ids) - append_metadata = self._gpu_kv_cache_manager.get_append_metadata( - new_history_lengths, kv_cache_metadata.total_history_lengths - ) - for _field_name in [ - "batch_indices", - "position", - "new_history_nnz", - "new_history_nnz_cuda", - ]: - setattr( - kv_cache_metadata, _field_name, getattr(append_metadata, _field_name) - ) + # def offload_kv_cache( + # self, user_ids: torch.Tensor, kvcache_metadata: KVCacheMetadata + # ): + # offload_results = self.offload_kv_cache_async(user_ids, kvcache_metadata) + # if offload_results is not None: + # self.offload_kv_cache_wait(offload_results) + + # def offload_kv_cache_async( + # self, user_ids: torch.Tensor, kvcache_metadata: KVCacheMetadata + # ): + # host_kvdata_start_pos, host_kvdata_lengths = zip( + # *[ + # self._host_kv_storage_manager.get_user_kvdata_info(int(user_ids[idx])) + # for idx in range(len(user_ids)) + # ] + # ) + # host_kvdata_start_pos = torch.tensor(host_kvdata_start_pos, dtype=torch.int32) + # host_kvdata_lengths = torch.tensor(host_kvdata_lengths, dtype=torch.int32) + + # offload_results = self._gpu_kv_cache_manager.offload_async( + # user_ids, host_kvdata_start_pos, host_kvdata_lengths, kvcache_metadata + # ) + # return offload_results + + # def offload_kv_cache_wait( + # self, + # offload_results: Optional[ + # Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor] + # ], + # ): + # if offload_results is not None: + # self._gpu_kv_cache_manager.offload_wait() + # self._host_kv_storage_manager.append_kvdata(*offload_results) + + # def forward( + # self, + # batch: Batch, + # user_ids: torch.Tensor, + # user_start_pos: torch.Tensor, + # ): + # with torch.inference_mode(): + # user_start_pos_cuda = user_start_pos.to( + # device=torch.cuda.current_device(), non_blocking=True + # ) + # kvcache_metadata = self.prepare_kv_cache(batch, user_ids, user_start_pos) + # embeddings = self._embedding_collection(batch.features) + # embeddings, batch = self.strip_contextual_features( + # embeddings, batch, user_start_pos + # ) + # jagged_data = self._hstu_block._preprocessor( + # embeddings=embeddings, + # batch=batch, + # seq_start_position=user_start_pos_cuda, + # ) + + # num_tokens = batch.features.values().shape[0] + # if self.use_cudagraph: + # self._hidden_states[:num_tokens, ...].copy_( + # jagged_data.values, non_blocking=True + # ) + # copy_jagged_metadata(self._jagged_metadata, jagged_data) + # self._kvcache_metadata.total_history_offsets += ( + # self._jagged_metadata.num_candidates_offsets + # ) + # # 
self.offload_kv_cache_wait(self._offload_states) + + # hstu_output = self._hstu_block.predict( + # batch.batch_size, + # num_tokens, + # self._hidden_states, + # self._jagged_metadata, + # self._kvcache_metadata, + # ) + # jagged_data.values = hstu_output + # else: + # kvcache_metadata.total_history_offsets += ( + # jagged_data.num_candidates_offsets + # ) + # # self.offload_kv_cache_wait(self._offload_states) + # hstu_output = self._hstu_block.predict( + # batch.batch_size, + # num_tokens, + # jagged_data.values, + # jagged_data, + # kvcache_metadata, + # ) + # jagged_data.values = hstu_output + + # self._gpu_kv_cache_manager._offload_start_event.record( + # torch.cuda.current_stream() + # ) + + # jagged_data = self._hstu_block._postprocessor(jagged_data) + # jagged_item_logit = self._mlp(jagged_data.values) + # self._offload_states = self.offload_kv_cache_async( + # user_ids, kvcache_metadata + # ) + # self.offload_kv_cache_wait(self._offload_states) + # self.finalize_kv_cache(user_ids) + + # return jagged_item_logit - kv_cache_metadata.onload_history_kv_buffer = ( - self._kvcache_metadata.onload_history_kv_buffer[:] - ) - kv_cache_metadata.onload_history_kv_events = ( - self._kvcache_metadata.onload_history_kv_events[:] - ) - kv_cache_metadata.kv_cache_table = self._kvcache_metadata.kv_cache_table[:] - ( - onload_length, - onload_kv_page_ids, - onload_kv_page_indptr, - ) = self._host_kv_storage_manager.lookup_kvdata( - user_ids, cached_start_pos, cached_lengths - ) - if onload_length > 0: - kv_page_ids = triton_concat_2D_jagged( - max_seq_len=onload_kv_page_indptr[-1] + kv_cache_metadata.kv_indptr[-1], - values_a=onload_kv_page_ids.view(-1, 1), - values_b=kv_cache_metadata.kv_indices.view(-1, 1), - offsets_a=onload_kv_page_indptr.to(torch.int64), - offsets_b=kv_cache_metadata.kv_indptr.to(torch.int64), - ) - kv_cache_metadata.kv_indices = kv_page_ids.view(-1) - kv_cache_metadata.kv_indptr = ( - onload_kv_page_indptr + kv_cache_metadata.kv_indptr + def forward( + self, + batch: Batch, + user_ids: torch.Tensor, + total_history_lengths: torch.Tensor, + ): + with torch.inference_mode(): + # print("[DEBUG] total_history_lengths", total_history_lengths) + user_ids_list = user_ids.tolist() + + prepare_kvcache_result = self.async_kvcache.prepare_kvcache_async( + batch.batch_size, + user_ids_list, + total_history_lengths.tolist(), + self.async_kvcache.static_page_ids_gpu_buffer, + self.async_kvcache.static_offload_page_ids_gpu_buffer, + self.async_kvcache.static_onload_handle, ) - self._gpu_kv_cache_manager.onload( - self._host_kv_storage_manager.get_lookup_buffer(), - onload_length, - self._kvcache_metadata if self.use_cudagraph else kv_cache_metadata, + # print("[DEBUG] return from trigger\n", flush=True) + + ( + old_cached_lengths, + num_history_tokens, + offload_uids_buffer, + metadata_host_buffer, + metadata_gpu_buffer, + kvcache_metadata_fut, + onload_fut, + ) = prepare_kvcache_result + # print("[DEBUG] old_cached_lengths", old_cached_lengths) + old_cached_lengths = torch.tensor(old_cached_lengths, dtype=torch.int32) + + striped_batch = self.async_kvcache.strip_cached_tokens( + batch, old_cached_lengths, ) - # cudagraph preparation - if self.use_cudagraph: - copy_kvcache_metadata(self._kvcache_metadata, kv_cache_metadata) - # assert max(kv_cache_metadata.kv_indices.tolist()) < self._kvcache_metadata.kv_cache_table[0].shape[0] - - return kv_cache_metadata - - def finalize_kv_cache(self, user_ids: torch.Tensor, **kwargs): - pass + embeddings = 
self._embedding_collection(striped_batch.features) + jagged_data = self._hstu_block._preprocessor( + embeddings=embeddings, + batch=striped_batch, + seq_start_position=old_cached_lengths.cuda(), + ) - def clear_kv_cache(self): - self._gpu_kv_cache_manager.evict_all() - self._host_kv_storage_manager.evict_all_kvdata() + kvcache_metadata = self.async_kvcache.prepare_kvcache_wait( + onload_fut, + kvcache_metadata_fut, + batch.batch_size, + num_history_tokens, + self.async_kvcache.static_page_ids_gpu_buffer, + self.async_kvcache.static_offload_page_ids_gpu_buffer, + offload_uids_buffer, + metadata_host_buffer, + metadata_gpu_buffer, + self.async_kvcache.static_onload_handle, + ) - def offload_kv_cache( - self, user_ids: torch.Tensor, kvcache_metadata: KVCacheMetadata - ): - offload_results = self.offload_kv_cache_async(user_ids, kvcache_metadata) - if offload_results is not None: - self.offload_kv_cache_wait(offload_results) + # print("[DEBUG] kv_indices", kvcache_metadata.kv_indices) + # print("[DEBUG] kv_indptr", kvcache_metadata.kv_indptr) + # print("[DEBUG] kv_last_page_len", kvcache_metadata.kv_last_page_len) + # print("[DEBUG] total_history_lengths", kvcache_metadata.total_history_lengths) + # print("[DEBUG] total_history_offsets", kvcache_metadata.total_history_offsets) + # print("[DEBUG] seqlen", jagged_data.seqlen) + # print("[DEBUG] seqlen_offsets", jagged_data.seqlen_offsets) + # print("[DEBUG] num_candidates_offsets", jagged_data.num_candidates_offsets) + # print("[DEBUG] >>> ", kvcache_metadata.batch_indices.shape) + + # print("[DEBUG] batch_indices", kvcache_metadata.batch_indices) + # print("[DEBUG] position", kvcache_metadata.position) + # print("[DEBUG] batch_indices", kvcache_metadata.position[:154]) + # print("[DEBUG] batch_indices", kvcache_metadata.position[154:154*2]) + # print("[DEBUG] batch_indices", kvcache_metadata.position[154*2:154*3]) + # print("[DEBUG] batch_indices", kvcache_metadata.position[154*3:154*4]) + # print("[DEBUG] batch_indices", kvcache_metadata.position[154*4:154*5]) + # print("[DEBUG] batch_indices", kvcache_metadata.position[154*5:154*6]) + # print("[DEBUG] batch_indices", kvcache_metadata.position[154*6:154*7]) + # print("[DEBUG] batch_indices", kvcache_metadata.position[154*7:154*8]) + # print("[DEBUG] >>> ", kvcache_metadata.position.shape) + # print("[DEBUG] new_history_nnz", kvcache_metadata.new_history_nnz) + # print("[DEBUG] new_history_nnz_cuda", kvcache_metadata.new_history_nnz_cuda) + + # print("kvcache_metadata.offload_user_ids", kvcache_metadata.offload_user_ids) + # print("kvcache_metadata.offload_page_ids", kvcache_metadata.offload_page_ids.shape) + + kvcache_metadata.total_history_offsets += jagged_data.num_candidates_offsets + kvcache_metadata.total_history_lengths += jagged_data.num_candidates + kvcache_metadata.max_seqlen += jagged_data.max_num_candidates + + # print("[DEBUG] <<>> total_history_offsets", kvcache_metadata.total_history_offsets) + # input() + + num_tokens = striped_batch.features.values().shape[0] + hstu_output = self._hstu_block.predict( + striped_batch.batch_size, + num_tokens, + jagged_data.values, + jagged_data, + kvcache_metadata, + ) + jagged_data.values = hstu_output - def offload_kv_cache_async( - self, user_ids: torch.Tensor, kvcache_metadata: KVCacheMetadata - ): - host_kvdata_start_pos, host_kvdata_lengths = zip( - *[ - self._host_kv_storage_manager.get_user_kvdata_info(int(user_ids[idx])) - for idx in range(len(user_ids)) - ] - ) - host_kvdata_start_pos = torch.tensor(host_kvdata_start_pos, 
dtype=torch.int32) - host_kvdata_lengths = torch.tensor(host_kvdata_lengths, dtype=torch.int32) + self.async_kvcache.offload_kvcache(kvcache_metadata) - offload_results = self._gpu_kv_cache_manager.offload_async( - user_ids, host_kvdata_start_pos, host_kvdata_lengths, kvcache_metadata - ) - return offload_results + jagged_data = self._hstu_block._postprocessor(jagged_data) + jagged_item_logit = self._mlp(jagged_data.values) - def offload_kv_cache_wait( - self, - offload_results: Optional[ - Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor] - ], - ): - if offload_results is not None: - self._gpu_kv_cache_manager.offload_wait() - self._host_kv_storage_manager.append_kvdata(*offload_results) + return jagged_item_logit + - def forward( + def forward_nokvcache( self, batch: Batch, - user_ids: torch.Tensor, - user_start_pos: torch.Tensor, ): with torch.inference_mode(): - user_start_pos_cuda = user_start_pos.to( - device=torch.cuda.current_device(), non_blocking=True - ) - kvcache_metadata = self.prepare_kv_cache(batch, user_ids, user_start_pos) + embeddings = self._embedding_collection(batch.features) - embeddings, batch = self.strip_contextual_features( - embeddings, batch, user_start_pos - ) jagged_data = self._hstu_block._preprocessor( embeddings=embeddings, batch=batch, - seq_start_position=user_start_pos_cuda, ) num_tokens = batch.features.values().shape[0] - if self.use_cudagraph: - self._hidden_states[:num_tokens, ...].copy_( - jagged_data.values, non_blocking=True - ) - copy_jagged_metadata(self._jagged_metadata, jagged_data) - self._kvcache_metadata.total_history_offsets += ( - self._jagged_metadata.num_candidates_offsets - ) - # self.offload_kv_cache_wait(self._offload_states) - - hstu_output = self._hstu_block.predict( - batch.batch_size, - num_tokens, - self._hidden_states, - self._jagged_metadata, - self._kvcache_metadata, - ) - jagged_data.values = hstu_output - else: - kvcache_metadata.total_history_offsets += ( - jagged_data.num_candidates_offsets - ) - # self.offload_kv_cache_wait(self._offload_states) - hstu_output = self._hstu_block.predict( - batch.batch_size, - num_tokens, - jagged_data.values, - jagged_data, - kvcache_metadata, - ) - jagged_data.values = hstu_output - - self._gpu_kv_cache_manager._offload_start_event.record( - torch.cuda.current_stream() + hstu_output = self._hstu_block.predict( + batch.batch_size, + num_tokens, + jagged_data.values, + jagged_data, + None, ) - + jagged_data.values = hstu_output jagged_data = self._hstu_block._postprocessor(jagged_data) jagged_item_logit = self._mlp(jagged_data.values) - self._offload_states = self.offload_kv_cache_async( - user_ids, kvcache_metadata - ) - self.offload_kv_cache_wait(self._offload_states) - self.finalize_kv_cache(user_ids) return jagged_item_logit + diff --git a/examples/hstu/modules/async_kvcache_manager.py b/examples/hstu/modules/async_kvcache_manager.py new file mode 100644 index 000000000..fb68b1de8 --- /dev/null +++ b/examples/hstu/modules/async_kvcache_manager.py @@ -0,0 +1,222 @@ +import torch +from concurrent.futures import ThreadPoolExecutor +import paged_kvcache_ops +import math +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from configs import KVCacheMetadata +import os + +# def offload_callback(fut): +# torch.cuda.empty_cache() + +class AsyncHSTUKVCacheManager: + def __init__( + self, + num_layers, + num_kv_heads, + kv_headdim, + num_tokens_per_page, + num_primary_cache_pages, + num_onload_buffer_pages, + num_reserved_buffer_pages, + num_tokens_per_chunk, + 
max_num_sequences, + max_sequence_length, + max_batch_size, + ): + self.executor = ThreadPoolExecutor(max_workers=1) + self.onload_worker = ThreadPoolExecutor(max_workers=1) + # self.offload_worker = ThreadPoolExecutor(max_workers=4) + + self.num_layers = num_layers + self.num_heads = num_kv_heads + self.head_dim = kv_headdim + self.page_size = num_tokens_per_page + self.num_primary_cache_pages = num_primary_cache_pages + self.num_onload_buffer_pages = num_onload_buffer_pages + self.num_reserved_buffer_pages = num_reserved_buffer_pages + self.chunk_size = num_tokens_per_chunk + self.max_num_sequences = max_num_sequences + self.max_sequence_length = max_sequence_length + self.max_batch_size = max_batch_size + self.max_num_pages_per_seq = math.ceil(self.max_sequence_length / self.page_size) + + self.cache_table = torch.empty( + [num_layers, (num_primary_cache_pages + num_onload_buffer_pages), 2, self.page_size, self.num_heads, self.head_dim], + dtype=torch.bfloat16, device=torch.cuda.current_device() + ) + + self.host_kv_mgr = paged_kvcache_ops.HostKVStorageImpl( + self.num_layers, self.num_heads, self.head_dim, self.page_size, self.chunk_size + ) + self.gpu_kvcache_mgr = paged_kvcache_ops.GPUKVCacheMangerImpl( + self.num_layers, self.num_heads, self.head_dim, self.page_size, + self.num_primary_cache_pages, self.num_onload_buffer_pages, + self.num_reserved_buffer_pages, self.chunk_size, + self.max_num_sequences, self.max_num_sequences, + self.cache_table, + self.host_kv_mgr + ) + + self.static_page_ids_gpu_buffer = torch.empty([self.max_batch_size * self.max_num_pages_per_seq,], dtype=torch.int32).cuda() + self.static_offload_page_ids_gpu_buffer = torch.empty([self.max_batch_size * self.max_num_pages_per_seq,], dtype=torch.int32).cuda() + # self.static_pinned_kv_buffer = torch.empty( + # [self.num_layers, self.max_batch_size * self.max_num_pages_per_seq, 2, self.page_size, self.num_heads, self.head_dim], + # dtype=torch.bfloat16, pin_memory=True + # ) + self.static_onload_handle = paged_kvcache_ops.KVOnloadHandle(self.num_layers) + + self.cache_table_list = [ self.cache_table[idx] for idx in range(self.num_layers) ] + + def prepare_kvcache_async(self, + batch_size, + user_ids, + total_history_lengths, + static_page_ids_gpu_buffer, + static_offload_page_ids_gpu_buffer, + # static_pinned_kv_buffer, + static_onload_handle, + ): + origin_cached_lengths = self.gpu_kvcache_mgr.get_total_cache_length(user_ids) + new_tokens = sum([ total_history_lengths[idx] - origin_cached_lengths[idx] for idx in range(batch_size) ]) + if new_tokens <= 0: + print(total_history_lengths) + print(origin_cached_lengths) + + offload_uids_buffer = torch.empty([batch_size,], dtype=torch.int64) + metadata_host_buffer = torch.empty([batch_size * 7 + 7,], dtype=torch.int, pin_memory=True) + metadata_gpu_buffer = torch.empty([batch_size * 5 + 4 + new_tokens * 2,], dtype=torch.int, device = torch.cuda.current_device()) + + kvcache_metadata_fut = self.executor.submit(paged_kvcache_ops.prepare_kvcache, + self.gpu_kvcache_mgr, self.host_kv_mgr, + user_ids, total_history_lengths, + static_page_ids_gpu_buffer, static_offload_page_ids_gpu_buffer, + offload_uids_buffer, + metadata_host_buffer, metadata_gpu_buffer) + + static_onload_handle.reset() + onload_fut = self.onload_worker.submit(self.gpu_kvcache_mgr.onload_kvcache, + user_ids, static_onload_handle) + + return origin_cached_lengths, new_tokens, offload_uids_buffer, metadata_host_buffer, metadata_gpu_buffer, kvcache_metadata_fut, onload_fut + + def prepare_kvcache_wait(self, + 
onload_fut, + kvcache_metadata_fut, + batch_size, + new_tokens, + static_page_ids_gpu_buffer, + static_offload_page_ids_gpu_buffer, + offload_uids_buffer, + metadata_host_buffer, + metadata_gpu_buffer, + static_onload_handle): + + # onload_fut.result() + kvcache_metadata_fut.result() + return self.get_kvcache_metadata_from_buffer( + batch_size, + new_tokens, + static_page_ids_gpu_buffer, + static_offload_page_ids_gpu_buffer, + offload_uids_buffer, + metadata_host_buffer, + metadata_gpu_buffer, + static_onload_handle) + + def offload_kvcache(self, kvcache_metadata): + num_offload_pages = len(kvcache_metadata.offload_page_ids) + if num_offload_pages == 0: + return None + + kvcache_metadata.kv_offload_handle.record_ready() + + kvcache_metadata.gather_kv_gpu_buffer = torch.empty( + [self.num_layers * num_offload_pages, 2, self.page_size, self.num_heads, self.head_dim], + dtype = torch.bfloat16, device = torch.cuda.current_device(), + ) + + self.gpu_kvcache_mgr.offload_kvcache( + kvcache_metadata.kv_offload_handle, + kvcache_metadata.offload_user_ids, + kvcache_metadata.offload_page_ids, + kvcache_metadata.gather_kv_gpu_buffer, + kvcache_metadata.new_offload_startpos, + kvcache_metadata.new_offload_lengths, + ) + + def get_kvcache_metadata_from_buffer(self, + batch_size, + new_tokens, + static_page_ids_gpu_buffer, + static_offload_page_ids_gpu_buffer, + offload_uids_buffer, + metadata_host_buffer, + metadata_gpu_buffer, + static_onload_handle): + # assert int(metadata_host_buffer[batch_size * 4 + 2]) == new_tokens + return KVCacheMetadata( + kv_indices = static_page_ids_gpu_buffer[: metadata_host_buffer[batch_size * 7 + 4]], + kv_indptr = metadata_gpu_buffer[: batch_size + 1], + kv_last_page_len = metadata_gpu_buffer[batch_size + 1 : batch_size * 2 + 1], + total_history_lengths = metadata_gpu_buffer[batch_size * 2 + 1 : batch_size * 3 + 1], + total_history_offsets = metadata_gpu_buffer[batch_size * 3 + 1 : batch_size * 4 + 2], + batch_indices = metadata_gpu_buffer[batch_size * 5 + 4:batch_size * 5 + 4 + new_tokens], + position = metadata_gpu_buffer[batch_size * 5 + 4 + new_tokens:batch_size * 5 + 4 + new_tokens*2], + new_history_nnz = new_tokens, + new_history_nnz_cuda = metadata_gpu_buffer[batch_size * 4 + 2 : batch_size * 4 + 3], + kv_cache_table = self.cache_table_list, + kv_onload_handle = static_onload_handle, + kv_offload_handle = paged_kvcache_ops.KVOffloadHandle(), + offload_user_ids = offload_uids_buffer[: metadata_host_buffer[batch_size * 7 + 6]], + offload_page_ids = static_offload_page_ids_gpu_buffer[: int(metadata_host_buffer[batch_size * 7 + 5])].clone(), + new_offload_startpos = metadata_host_buffer[batch_size * 5 + 4 : batch_size * 6 + 4], + new_offload_lengths = metadata_host_buffer[batch_size * 6 + 4 : batch_size * 7 + 4], + max_seqlen = torch.max(metadata_host_buffer[batch_size * 2 + 1 : batch_size * 3 + 1]).item(), + ) + + + def strip_cached_tokens(self, batch, origin_num_cached): + torch.cuda.nvtx.range_push("strip_cached_tokens") + + num_context = len(batch.contextual_feature_names) + + num_cached = torch.maximum(origin_num_cached - num_context, torch.tensor([0], dtype=torch.int32)) + num_cached_action = num_cached // 2 + num_cached_item = num_cached - num_cached_action + num_hist_cached = torch.concat([num_cached_item, num_cached_action], dim=0) + + old_offsets = batch.features.offsets().cpu() + old_lengths = batch.features.lengths().cpu() + + item_offset = num_context * batch.batch_size + act_offset = item_offset + batch.batch_size + + new_lengths = 
torch.zeros_like(old_lengths) + new_lengths[:item_offset] = torch.where( + (origin_num_cached == 0).view(-1, batch.batch_size), + old_lengths[:item_offset].view(-1, batch.batch_size), + new_lengths[:item_offset].view(-1, batch.batch_size)).view(-1) + new_lengths[item_offset:] = old_lengths[item_offset:] - num_hist_cached + + startpos = old_offsets[item_offset : item_offset + 2 * batch.batch_size] + num_hist_cached + endpos = old_offsets[item_offset + 1 :] + + old_values = batch.features.values() + new_hist_value = [ + old_values[startpos[idx]:endpos[idx]] for idx in range(2*batch.batch_size) + ] + + new_context_value = [ + old_values[idx: idx + 1] for idx in range(num_context*batch.batch_size) if int(new_lengths[idx]) > 0 + ] + + new_features = KeyedJaggedTensor( + values = torch.cat(new_context_value + new_hist_value, dim = 0), + lengths = new_lengths.cuda(), + keys = batch.features.keys() + ) + batch.features = new_features + + torch.cuda.nvtx.range_pop() + return batch diff --git a/examples/hstu/modules/hstu_block_inference.py b/examples/hstu/modules/hstu_block_inference.py index c687f1d81..60c3092c5 100644 --- a/examples/hstu/modules/hstu_block_inference.py +++ b/examples/hstu/modules/hstu_block_inference.py @@ -160,9 +160,10 @@ def predict_cudagraph( self._hstu_graph[batch_size][num_tokens_padded][0].replay() # type: ignore for idx in range(1, self.config.num_layers + 1): - kv_cache_metadata.onload_history_kv_events[idx - 1].wait( - torch.cuda.current_stream() - ) + # kv_cache_metadata.onload_history_kv_events[idx - 1].wait( + # torch.cuda.current_stream() + # ) + kv_cache_metadata.kv_onload_handle.wait_host(idx - 1) self._hstu_graph[batch_size][num_tokens_padded][idx].replay() # type: ignore hstu_output = torch.zeros_like(hidden_states[:num_tokens, ...]) diff --git a/examples/hstu/modules/hstu_processor.py b/examples/hstu/modules/hstu_processor.py index 38c99c4dc..cd71a5c97 100644 --- a/examples/hstu/modules/hstu_processor.py +++ b/examples/hstu/modules/hstu_processor.py @@ -155,29 +155,32 @@ def hstu_preprocess_embeddings( contextual_jts_offsets, contextual_max_seqlens, ) - if contextual_mlp is not None: - contextual_sequence_embeddings = contextual_mlp( - contextual_sequence_embeddings + if torch.sum(contextual_seqlen, dim=0).cpu().item() == 0: + contextual_seqlen = None + else: + if contextual_mlp is not None: + contextual_sequence_embeddings = contextual_mlp( + contextual_sequence_embeddings + ) + contextual_seqlen_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + contextual_seqlen + ) + contextual_max_seqlen = max( + len(batch.contextual_feature_names), sum(contextual_max_seqlens) + ) + ( + sequence_embeddings, + sequence_embeddings_lengths, + ) = jagged_2D_tensor_concat( + [contextual_sequence_embeddings, sequence_embeddings], + [contextual_seqlen_offsets, sequence_embeddings_lengths_offsets], + [contextual_max_seqlen, sequence_max_seqlen], ) - contextual_seqlen_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( - contextual_seqlen - ) - contextual_max_seqlen = max( - len(batch.contextual_feature_names), sum(contextual_max_seqlens) - ) - ( - sequence_embeddings, - sequence_embeddings_lengths, - ) = jagged_2D_tensor_concat( - [contextual_sequence_embeddings, sequence_embeddings], - [contextual_seqlen_offsets, sequence_embeddings_lengths_offsets], - [contextual_max_seqlen, sequence_max_seqlen], - ) - sequence_embeddings_lengths_offsets = ( - torch.ops.fbgemm.asynchronous_complete_cumsum(sequence_embeddings_lengths) - ) - sequence_max_seqlen = sequence_max_seqlen + 
contextual_max_seqlen + sequence_embeddings_lengths_offsets = ( + torch.ops.fbgemm.asynchronous_complete_cumsum(sequence_embeddings_lengths) + ) + sequence_max_seqlen = sequence_max_seqlen + contextual_max_seqlen return JaggedData( values=sequence_embeddings, diff --git a/examples/hstu/modules/paged_hstu_infer_layer.py b/examples/hstu/modules/paged_hstu_infer_layer.py index 8680f94b8..86f0b30a7 100644 --- a/examples/hstu/modules/paged_hstu_infer_layer.py +++ b/examples/hstu/modules/paged_hstu_infer_layer.py @@ -23,6 +23,11 @@ from ops.triton_ops.triton_layer_norm import triton_weighted_layer_norm_fwd from ops.triton_ops.triton_norm_mul_dropout import triton_layer_norm_mul_dropout_fwd +import numpy as np + +def init(): + global dmp + dmp = False class PagedHSTUInferLayer(torch.nn.Module): """ @@ -268,7 +273,8 @@ def forward_naive( eps=self._eps, ) - mixed_uvqk = self.uvqk_addmm_impl(normed_input, num_tokens) + # mixed_uvqk = self.uvqk_addmm_impl(normed_input, 0) + mixed_uvqk = F.silu(torch.matmul(normed_input, self._linear_uvqk_weight) + self._linear_uvqk.bias) (user, value, query, key) = torch.split( mixed_uvqk, self._split_arg_list, @@ -279,64 +285,83 @@ def forward_naive( query = query.view(-1, self._num_heads, self._attention_dim_per_head) key = key.view(-1, self._num_heads, self._attention_dim_per_head) - kv_cache_table = kv_cache_metadata.kv_cache_table[self.layer_idx] - (paged_k_cache, paged_v_cache) = kv_cache_table.unbind(dim=1) - paged_kvcache_ops.append_kvcache( - key, - value, - kv_cache_metadata.batch_indices, - kv_cache_metadata.position, - jd.num_candidates_offsets[: batch_size + 1], - kv_cache_metadata.new_history_nnz_cuda, - num_tokens, # kv_cache_metadata.new_history_nnz - paged_k_cache, - paged_v_cache, - kv_cache_metadata.kv_indices, - kv_cache_metadata.kv_indptr, - kv_cache_metadata.kv_last_page_len, - 0, # NHD layout - ) + if kv_cache_metadata is not None: + kv_cache_table = kv_cache_metadata.kv_cache_table[self.layer_idx] + (paged_k_cache, paged_v_cache) = kv_cache_table.unbind(dim=1) + paged_kvcache_ops.append_kvcache( + key, + value, + kv_cache_metadata.batch_indices, + kv_cache_metadata.position, + jd.num_candidates_offsets[: batch_size + 1], + kv_cache_metadata.new_history_nnz_cuda, + kv_cache_metadata.new_history_nnz, + paged_k_cache, + paged_v_cache, + kv_cache_metadata.kv_indices, + kv_cache_metadata.kv_indptr, + kv_cache_metadata.kv_last_page_len, + 0, # NHD layout + ) + + kv_cache_metadata.kv_onload_handle.wait_host(self.layer_idx) + jagged_attn_output = hstu_attn.hstu_attn_varlen_func( + query, + key, + value, + jd.seqlen_offsets[: batch_size + 1], + kv_cache_metadata.total_history_offsets[: batch_size + 1], + jd.max_seqlen, + jd.max_seqlen, # kv_cache_metadata.max_seqlen, + num_contexts = None, + num_targets=jd.num_candidates[:batch_size], + target_group_size=1, + window_size=(-1, 0), + alpha=self._alpha, + rab=None, + has_drab=False, + kv_cache=kv_cache_table, + page_offsets=kv_cache_metadata.kv_indptr, + page_ids=kv_cache_metadata.kv_indices, + last_page_lens=kv_cache_metadata.kv_last_page_len, + cu_seqlens_t=jd.num_candidates_offsets[: batch_size + 1], + scaling_seqlen=jd.scaling_seqlen, + ) + else: + jagged_attn_output = hstu_attn.hstu_attn_varlen_func( + query, + key, + value, + jd.seqlen_offsets[: batch_size + 1], + jd.seqlen_offsets[: batch_size + 1], + jd.max_seqlen, + jd.max_seqlen, + num_contexts = None, + num_targets=jd.num_candidates[:batch_size], + target_group_size=1, + window_size=(-1, 0), + alpha=self._alpha, + rab=None, + has_drab=False, + 
scaling_seqlen=jd.scaling_seqlen, + ) - kv_cache_metadata.onload_history_kv_events[self.layer_idx].wait( - torch.cuda.current_stream() - ) - jagged_attn_output = hstu_attn.hstu_attn_varlen_func( - query, - key, - value, - jd.seqlen_offsets[: batch_size + 1], - kv_cache_metadata.total_history_offsets[: batch_size + 1], - self._max_seqlen, - self._max_seqlen, - num_contexts=jd.contextual_seqlen - if jd.contextual_seqlen is None - else jd.contextual_seqlen[:batch_size], - num_targets=jd.num_candidates[:batch_size], - target_group_size=1, - window_size=(-1, 0), - alpha=self._alpha, - rab=None, - has_drab=False, - kv_cache=kv_cache_table, - page_offsets=kv_cache_metadata.kv_indptr, - page_ids=kv_cache_metadata.kv_indices, - last_page_lens=kv_cache_metadata.kv_last_page_len, - cu_seqlens_t=jd.num_candidates_offsets[: batch_size + 1], - scaling_seqlen=jd.scaling_seqlen, - ) jagged_attn_output = jagged_attn_output.view( -1, self._num_heads * self._linear_dim_per_head ) parallel_input = self.norm_mul_impl( - jagged_attn_output, user, num_tokens >= 2048 + jagged_attn_output, user, False #num_tokens >= 2048 ) if self._residual: - layer_output = self.proj_addmm_impl(parallel_input, layer_input, num_tokens) + # layer_output = self.proj_addmm_impl(parallel_input, layer_input, 0) + layer_output = (torch.matmul(parallel_input, self._linear_proj_weight) + layer_input) else: - layer_output = self._linear_proj(parallel_input) + # layer_output = self._linear_proj(parallel_input) + layer_output = torch.matmul(parallel_input, self._linear_proj_weight) + return layer_output @torch.inference_mode() diff --git a/examples/hstu/ops/cuda_ops/csrc/paged_kvcache_ops_cuda.cpp b/examples/hstu/ops/cuda_ops/csrc/paged_kvcache_ops_cuda.cpp index 168f1d669..4909e7210 100755 --- a/examples/hstu/ops/cuda_ops/csrc/paged_kvcache_ops_cuda.cpp +++ b/examples/hstu/ops/cuda_ops/csrc/paged_kvcache_ops_cuda.cpp @@ -23,10 +23,33 @@ #include #include #include -// #include +#include #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#define cudaCheck(ans) { cudaSuccesAssert((ans), __FILE__, __LINE__); } +inline void cudaSuccesAssert(cudaError_t code, const char *file, int line, bool abort=true) +{ + if (code != cudaSuccess) + { + fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); + // if (abort) exit(code); + } +} + template cudaError_t AppendPagedKVCache(DType* k_data, DType* v_data, @@ -59,6 +82,32 @@ cudaError_t GatherPagedKVCache(DType* gather_kv, uint32_t nnz, cudaStream_t stream); +template +cudaError_t GatherPagedKVCacheAllLayers(DType* gather_kv, + IdType* page_ids, + uint32_t num_layers, + uint32_t stride_gather, + uint32_t stride_layer, + uint32_t num_heads, + uint32_t head_dim, + uint32_t page_size, + uint32_t stride_page, + uint32_t stride_k2v, + uint32_t stride_n, + uint32_t stride_h, + DType* kv_cache, + uint32_t nnz, + cudaStream_t stream); + +cudaError_t GetPagedBatchIndicesPositions( + int32_t batch_size, + int32_t* append_indptr, + int32_t* seq_lens_ptr, + int32_t* batch_indices_ptr, + int32_t* positions_ptr, + cudaStream_t stream +); + void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::Tensor batch_indices, at::Tensor positions, at::Tensor seqlen_offsets, at::Tensor nnz_cuda, unsigned int nnz, @@ -204,7 +253,1108 @@ void gather_paged_kv_cache(at::Tensor gather_kv_gpu_buffer, "GatherPagedKVCache failed with error: ", cudaGetErrorString(status)); } +void 
gather_paged_kv_cache_all_layers(uint16_t *gather_kv_gpu_buffer, + uint16_t *paged_kv_cache, + int *page_ids_to_offload, + uint32_t num_layers, + uint32_t stride_gather, + uint32_t stride_layer, + uint32_t num_heads, + uint32_t head_dim, + uint32_t page_size, + uint32_t stride_page, + uint32_t stride_k2v, + uint32_t stride_n, + uint32_t stride_h, + uint32_t num_pages, + cudaStream_t stream) { + // auto device = paged_kv_cache.device(); + // const c10::cuda::OptionalCUDAGuard device_guard(device); + + cudaError_t status; + status = GatherPagedKVCacheAllLayers( + reinterpret_cast(gather_kv_gpu_buffer), + static_cast(page_ids_to_offload), + num_layers, stride_gather, stride_layer, + num_heads, head_dim, page_size, + stride_page, stride_k2v, stride_n, stride_h, + reinterpret_cast(paged_kv_cache), + num_pages * page_size, stream); + TORCH_CHECK(status == cudaSuccess, + "GatherPagedKVCacheAllLayers failed with error: ", cudaGetErrorString(status)); +} + +namespace kvcache { + +class HostKVStorageImpl +{ +public: + HostKVStorageImpl( + int num_layers, + int num_kv_heads, + int kv_headdim, + int num_tokens_per_page, + int64_t num_tokens_per_chunk + ) + : num_layers(num_layers) + , num_kv_heads(num_kv_heads) + , kv_headdim(kv_headdim) + , page_size(num_tokens_per_page) + , chunk_size(num_tokens_per_chunk) + , _uid_to_chunk_id(num_layers, std::unordered_map>()) + { + this->chunk_numel = num_tokens_per_chunk * 2 * num_kv_heads * kv_headdim; + this->page_numel = 2 * page_size * num_kv_heads * kv_headdim; + this->per_token_numel = 2 * num_kv_heads * kv_headdim; + }; + + ~HostKVStorageImpl() + {} + + int64_t get_kvdata_length(int64_t user_id) { + auto it = _uid_to_length.find(user_id); + if (it == _uid_to_length.end()) return 0; + return it->second; + }; + + void append_kvdata_v2(int64_t user_id, int64_t start_position, int64_t length, uint16_t *pinned_input_ptr, size_t gather_layer_stride) { + assert(length % this->chunk_size == 0); + if (start_position != 0) { + assert(_uid_to_length[user_id] == start_position); + } + else { + assert(_uid_to_length.find(user_id) == _uid_to_length.end()); + for (int layer_idx = 0; layer_idx < num_layers; layer_idx++) + _uid_to_chunk_id[layer_idx][user_id] = std::vector(); + _uid_to_mempool[user_id] = std::vector(); + } + + size_t num_chunks = length / chunk_size; + size_t num_elem = length * per_token_numel; + size_t kvdata_size = num_elem * sizeof(uint16_t); + + for (int layer_idx = 0; layer_idx < num_layers; layer_idx++) { + uint16_t* src_ptr = pinned_input_ptr + layer_idx * gather_layer_stride; + + for (size_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx ++) { + _uid_to_chunk_id[layer_idx][user_id].push_back(reinterpret_cast(src_ptr + chunk_idx * this->chunk_numel)); + } + + _uid_to_mempool[user_id].push_back(reinterpret_cast(src_ptr)); + } + _uid_to_length[user_id] = start_position + length; + }; + + std::vector get_kvdata_v2(int64_t user_id, int64_t length, int64_t layer_idx) { + // int64_t offloaded_length = get_kvdata_length(user_id); + // assert(offloaded_length >= length); + + std::vector chunk_ptrs; + if (length == 0) { + return chunk_ptrs; + } + // assert(length % this->chunk_size == 0); + size_t num_chunks = length / chunk_size; + const size_t chunk_bytesize = this->chunk_numel * sizeof(uint16_t); + const auto &chunk_ptr_list = _uid_to_chunk_id[layer_idx][user_id]; + // assert(chunk_ptr_list.size() >= num_chunks); + for (size_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx ++) { + uint16_t* src_ptr = reinterpret_cast(chunk_ptr_list[chunk_idx]); + 
chunk_ptrs.push_back(src_ptr); + } + return chunk_ptrs; + }; + +public: + std::vector>> _uid_to_chunk_id; + std::unordered_map _uid_to_length; + std::unordered_map> _uid_to_mempool; + + const int num_layers; + const int num_kv_heads; + const int kv_headdim; + const int page_size; + + const int64_t chunk_size; + size_t chunk_numel; + size_t page_numel; + size_t per_token_numel; + size_t layer_numel; +}; + +// class PinnedDoubleBuffer { +// public: +// public: +// }; + +class KVOnloadHandle { +public: + KVOnloadHandle( + int num_layers + ) + : num_layers(num_layers) + , event(std::vector(num_layers)) + , host_complete(num_layers, 0) { + for (int layer_idx = 0; layer_idx < num_layers; layer_idx ++) { + cudaEventCreate(&event[layer_idx]); + } + }; + + ~KVOnloadHandle(){ + for (int layer_idx = 0; layer_idx < num_layers; layer_idx ++) { + cudaEventDestroy(event[layer_idx]); + } + }; + + // void record(int layer_idx, cudaStream_t stream) { + // cudaEventRecord(event[layer_idx], stream); + // }; + + // void wait(int layer_idx) { + // auto stream = at::cuda::getCurrentCUDAStream(); + // cudaStreamWaitEvent(stream, event[layer_idx], 0); + // }; + void reset() { + for (int layer_idx = 0; layer_idx < num_layers; layer_idx ++) { + host_complete[layer_idx] = 0; + } + } + + void complete_host(int layer_idx, cudaStream_t stream) { + cudaEventRecord(event[layer_idx], stream); + { + std::unique_lock lock(mtx_); + host_complete[layer_idx] = 1; + } + cv_.notify_one(); + }; + + void wait_host(int layer_idx) { + { + std::unique_lock lock(mtx_); + cv_.wait(lock, [this, layer_idx](){ return host_complete[layer_idx] == 1; }); + } + auto stream = at::cuda::getCurrentCUDAStream(); + cudaStreamWaitEvent(stream, event[layer_idx], 0); + }; + + void wait_all(void); +public: + int num_layers; + std::vector event; + std::mutex mtx_; + std::condition_variable cv_; + std::vector host_complete; +}; + +class KVOffloadHandle { +public: + +void record_ready(void) { + auto stream = at::cuda::getCurrentCUDAStream(); + cudaEventCreate(&ready_event); + cudaEventRecord(ready_event, stream); + }; + + // void wait_ready(cudaStream_t stream) { + // cudaStreamWaitEvent(stream, ready_event); + // }; + +public: + cudaEvent_t ready_event; +}; + +class GPUKVCacheMangerImpl +{ +public: + GPUKVCacheMangerImpl( + int num_layers, + int num_kv_heads, + int kv_headdim, + int num_tokens_per_page, + int num_primary_cache_pages, + int num_onload_buffer_pages, + int num_reserved_buffer_pages, + int num_tokens_per_chunk, + int max_num_sequences, + int max_sequence_length, + at::Tensor cache_table_tensor, + HostKVStorageImpl& host_kv_mgr) + : num_layers(num_layers) + , num_kv_heads(num_kv_heads) + , kv_headdim(kv_headdim) + , num_tokens_per_page(num_tokens_per_page) + , num_primary_cache_pages(num_primary_cache_pages) + , num_onload_buffer_pages(num_onload_buffer_pages) + , num_reserved_buffer_pages(num_reserved_buffer_pages) + , num_tokens_per_chunk(num_tokens_per_chunk) + , max_num_sequences(max_num_sequences) + , max_sequence_length(max_sequence_length) + , cache_table(static_cast(cache_table_tensor.data_ptr())) + , device(cache_table_tensor.device()) + , onload_pinned_buffers(2) + , onload_memcpy_event(2) + , offload_pinned_buffers(2) + , offload_memcpy_event(2) + , onload_memcpy_barrier_(std::barrier<>(3 + 1)) + { + const c10::cuda::OptionalCUDAGuard device_guard(this->device); + + for (int page_id = 0; page_id < num_primary_cache_pages; page_id++) + _empty_pages.push(page_id); + + page_stride = 2 * num_tokens_per_page * num_kv_heads * kv_headdim; 
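+        // Element strides over the flat bf16 cache table. These mirror the tensor
+        // allocated on the Python side (AsyncHSTUKVCacheManager.cache_table) with
+        // layout [num_layers, num_primary + num_onload pages, 2 (K/V), page_size,
+        // num_heads, head_dim]:
+        //   page_stride         - elements per page (both K and V halves)
+        //   k2v_stride          - offset from a page's K half to its V half
+        //   layer_stride        - elements per layer of the cache table
+        //   per_token_kv_stride - K+V elements stored per token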
+ k2v_stride = num_tokens_per_page * num_kv_heads * kv_headdim; + layer_stride = (num_primary_cache_pages + num_onload_buffer_pages) * page_stride; + + per_token_kv_stride = 2 * num_kv_heads * kv_headdim; + + cudaStreamCreate(&worker_stream); + cudaStreamCreate(&onload_stream); + cudaStreamCreate(&offload_stream); + + this->host_kv_mgr = &host_kv_mgr; + + for (int i = 0; i < 2; i++) { + cudaMallocHost((void**)&onload_pinned_buffers[i], host_kv_mgr.chunk_numel * sizeof(uint16_t)); + cudaEventCreate(&onload_memcpy_event[i]); + } + + for (int i = 0; i < 2; i++) { + cudaMallocHost((void**)&offload_pinned_buffers[i], host_kv_mgr.chunk_numel * sizeof(uint16_t)); + cudaEventCreate(&offload_memcpy_event[i]); + } + + this->terminate_ = false; + this->num_onload_memcpy_worker = 3; + for (int i = 0; i < num_onload_memcpy_worker; i++) { + this->onload_memcpy_worker.emplace_back(std::thread( + &GPUKVCacheMangerImpl::onload_host_memcpy_loop, this + )); + } + // ; + + this->num_offload_memcpy_worker = 3; + for (int i = 0; i < num_offload_memcpy_worker; i++) { + this->offload_memcpy_worker.emplace_back(std::thread( + &GPUKVCacheMangerImpl::offload_host_memcpy_loop, this + )); + } + + this->queued_offload_tokens = 0; + this->queued_offload_limits = num_reserved_buffer_pages * num_tokens_per_page; + this->offload_busy_.store(false); + this->offload_worker = std::thread(&GPUKVCacheMangerImpl::offload_loop, this); + + }; + + ~GPUKVCacheMangerImpl() { + { + std::unique_lock lock(onload_memcpy_task_mtx_); + std::unique_lock lock2(offload_memcpy_task_mtx_); + std::unique_lock lock3(offload_task_mutex_); + this->terminate_ = true; + } + onload_memcpy_task_cv_.notify_all(); + offload_memcpy_task_cv_.notify_all(); + offload_task_cv_.notify_all(); + + for (int i = 0; i < num_onload_memcpy_worker; i++) { + this->onload_memcpy_worker[i].join(); + } + + for (int i = 0; i < num_offload_memcpy_worker; i++) { + this->offload_memcpy_worker[i].join(); + } + + this->offload_worker.join(); + + for (int i = 0; i < 2; i++) { + cudaFree(onload_pinned_buffers[i]); + cudaFree(offload_pinned_buffers[i]); + cudaEventDestroy(onload_memcpy_event[i]); + cudaEventDestroy(offload_memcpy_event[i]); + } + } + + int64_t getUIdToEvict(std::unordered_set extra_freezed_uids) { + while (true) { + int num_offloading_uids = 0; + { + std::unique_lock lock(offload_freezed_uids_mtx_); + num_offloading_uids = offload_freezed_uids_.size(); + // std::cout << "Saw " << num_offloading_uids << " freezed for offloading" << std::endl; + + for (auto it = std::rbegin(_lru_list); it != std::rend(_lru_list); ++it) { + if (offload_freezed_uids_.find((int64_t)*it) != offload_freezed_uids_.end()) + continue; + if (extra_freezed_uids.find((int64_t)*it) != extra_freezed_uids.end()) + continue; + return *it; + } + } + if (num_offloading_uids == 0) assert(false); + + std::this_thread::yield(); + } + + return _lru_list.back(); + }; + + std::vector& alloc(int64_t uid, int new_total_length, std::unordered_set freezed_uids) { + int cur_cached_start = 0; + int cur_cached_len = 0; + // int padding_last_page = 0; + + bool found_in_cache = retain(uid); + if (found_in_cache) { + cur_cached_start = _uid_to_paged_cache_startpos[uid]; + cur_cached_len = _uid_to_paged_cache_length[uid]; + } else { + _uid_to_page_id[uid] = std::vector(); + if (_uid_to_offloaded_length.find(uid) != _uid_to_offloaded_length.end()) { + _uid_to_paged_cache_startpos[uid] = _uid_to_offloaded_length[uid]; + cur_cached_start = _uid_to_offloaded_length[uid]; + } + else { + _uid_to_paged_cache_startpos[uid] 
= 0; + } + } + + int new_cached_len = new_total_length - cur_cached_start; + int cur_num_pages = (cur_cached_len + num_tokens_per_page - 1) / num_tokens_per_page; + int new_num_pages = (new_cached_len + num_tokens_per_page - 1) / num_tokens_per_page; + + int num_append_pages = new_num_pages - cur_num_pages; + // std::cout << " *** " << cur_num_pages << " " << cur_cached_len << std::endl; + // std::cout << " *** " << new_num_pages << " " << new_cached_len << std::endl; + + while ((size_t)num_append_pages > _empty_pages.size()) { + int64_t uid_to_evict = getUIdToEvict(freezed_uids); + evict(uid_to_evict); + // std::cout << "evict " << uid_to_evict << std::endl; + } + + std::vector& page_ids = _uid_to_page_id[uid]; + // for (auto pid : page_ids) { + // // std::cout << " - " << pid << std::endl; + // } + for (int i = 0; i < num_append_pages; i++) { + page_ids.push_back(_empty_pages.front()); + _empty_pages.pop(); + } + _uid_to_paged_cache_length[uid] = new_cached_len; + + return page_ids; + }; + + std::vector get_total_cache_length(std::vector& uids) { + int batch_size = uids.size(); + std::vector total_cached_lengths(batch_size); + for (int seq_idx = 0; seq_idx < batch_size; seq_idx++) { + int64_t uid = uids[seq_idx]; + if (_uid_to_paged_cache_startpos.find(uid) != _uid_to_paged_cache_startpos.end()) { + total_cached_lengths[seq_idx] = _uid_to_paged_cache_startpos[uid] + _uid_to_paged_cache_length[uid]; + } else if (_uid_to_offloaded_length.find(uid) != _uid_to_offloaded_length.end()) + total_cached_lengths[seq_idx] = _uid_to_offloaded_length[uid]; + else { + total_cached_lengths[seq_idx] = 0; + } + } + return total_cached_lengths; + }; + + void evict(int64_t uid) + { + auto const tableIt = _lru_lookup_table.find(uid); + assert(_lru_lookup_table.end() != tableIt); + // if (_lru_lookup_table.end() != tableIt) { + _lru_list.erase(tableIt->second); + _lru_lookup_table.erase(tableIt); + // assert(_uid_to_page_id[uid].size() > 0); + + for (auto page_id : _uid_to_page_id[uid]) { + _empty_pages.push(page_id); + } + + _uid_to_page_id.erase(uid); + _uid_to_paged_cache_startpos.erase(uid); + _uid_to_paged_cache_length.erase(uid); + // } + }; + + void evict_all() + { + std::queue empty_pages; + std::swap(_empty_pages, empty_pages); + _lru_list.clear(); + _lru_lookup_table.clear(); + _uid_to_page_id.clear(); + _uid_to_paged_cache_startpos.clear(); + _uid_to_paged_cache_length.clear(); + + for (int page_id = 0; page_id < this->num_primary_cache_pages; page_id++) + _empty_pages.push(page_id); + }; + + void invalid(int64_t uid) { + auto const tableIt = _lru_lookup_table.find(uid); + if (_lru_lookup_table.end() != tableIt) { + _lru_list.erase(tableIt->second); + _lru_lookup_table.erase(tableIt); + + for (auto page_id : _uid_to_page_id[uid]) { + _empty_pages.push(page_id); + } + + _uid_to_page_id.erase(uid); + _uid_to_paged_cache_startpos.erase(uid); + _uid_to_paged_cache_length.erase(uid); + _uid_to_offloaded_length.erase(uid); + } + }; + + bool retain(int64_t uid) + { + auto const tableIt = _lru_lookup_table.find(uid); + bool found = (_lru_lookup_table.end() != tableIt); + if (found) { + _lru_list.erase(tableIt->second); + } + _lru_list.push_front(uid); + _lru_lookup_table[uid] = _lru_list.begin(); + return found; + }; + + uint16_t *get_cache_table(void) { + return cache_table; + }; + + uint16_t *get_cache_table_by_layer(int layer_idx) { + return cache_table + layer_idx * layer_stride; + }; + +public: + void onload_kvcache( + std::vector& user_ids, + KVOnloadHandle& onloadhandle) { + const 
c10::cuda::OptionalCUDAGuard device_guard(this->device); + + // std::cout << "onload_kvcache start" << std::endl << std::flush; + const int batch_size = user_ids.size(); + + std::vector onload_length(batch_size); + std::vector onload_offsets(batch_size + 1); + onload_offsets[0] = 0; + for (int seq_idx = 0; seq_idx < batch_size; seq_idx++) { + auto uid = user_ids[seq_idx]; + if (this->_uid_to_paged_cache_startpos.find(uid) != this->_uid_to_paged_cache_startpos.end()) + onload_length[seq_idx] = this->_uid_to_paged_cache_startpos[uid]; + else if (this->_uid_to_offloaded_length.find(uid) != this->_uid_to_offloaded_length.end()) + onload_length[seq_idx] = this->_uid_to_offloaded_length[uid]; + else + onload_length[seq_idx] = 0; + + onload_offsets[seq_idx + 1] = onload_offsets[seq_idx] + onload_length[seq_idx]; + } + size_t total_onload_length = onload_offsets[batch_size]; + if (total_onload_length == 0) { + for (int layer_idx = 0; layer_idx < this->num_layers; layer_idx++) + onloadhandle.complete_host(layer_idx, this->onload_stream); + // std::cout << "onload_kvcache empty" << std::endl << std::flush; + return; + } + // std::cout << "[Onload Launch] {" << user_ids[0] << "}: " << onload_length[0] << std::endl; + + const size_t chunk_numel_part = host_kv_mgr->chunk_numel / (this->num_onload_memcpy_worker + 1); + + int task_idx = 0; + for (int layer_idx = 0; layer_idx < this->num_layers; layer_idx++) { + uint16_t *gpu_onload_buffer = this->get_cache_table_by_layer(layer_idx) + this->num_primary_cache_pages * this->page_stride; + + for (int seq_idx = 0; seq_idx < batch_size; seq_idx++) { + std::vector chunk_ptrs = host_kv_mgr->get_kvdata_v2(user_ids[0], onload_length[0], layer_idx); + + // std::cout << "[Onload] uid: " << user_ids[seq_idx] << " - " << chunk_ptrs.size() << std::endl; + // std::cout << "\t" << host_kv_mgr->chunk_numel << " - " << host_kv_mgr->chunk_size * this->per_token_kv_stride << std::endl; + for (int chunk_idx = 0; chunk_idx < chunk_ptrs.size(); chunk_idx++) { + // std::cout << "\t" << reinterpret_cast(chunk_ptrs[chunk_idx]) << " - " << (chunk_ptrs[chunk_idx] - chunk_ptrs[0]) << std::endl; + + onload_host_memcpy(onload_pinned_buffers[task_idx%2], chunk_ptrs[chunk_idx], host_kv_mgr->chunk_numel * sizeof(uint16_t), chunk_numel_part * sizeof(uint16_t)); + + cudaCheck(cudaMemcpyAsync(gpu_onload_buffer + onload_offsets[seq_idx] * this->per_token_kv_stride + chunk_idx * host_kv_mgr->chunk_numel, + onload_pinned_buffers[task_idx%2], host_kv_mgr->chunk_numel * sizeof(uint16_t), cudaMemcpyHostToDevice, this->onload_stream)); + cudaCheck(cudaEventRecord(onload_memcpy_event[task_idx%2], this->onload_stream)); + + if (task_idx > 0) { + cudaCheck(cudaEventSynchronize(onload_memcpy_event[(task_idx - 1)%2])); + } + + task_idx++; + } + } + + onloadhandle.complete_host(layer_idx, this->onload_stream); + } + // std::cout << "onload_kvcache end" << std::endl << std::flush; + }; + + void offload_kvcache( + KVOffloadHandle& offload_handle, + at::Tensor offload_user_ids, // host + at::Tensor offload_page_ids, // gpu + at::Tensor gather_kv_gpu_buffer, // gpu + at::Tensor new_offload_startpos, // host + at::Tensor new_offload_lengths) // host + { + const size_t num_offload_uids = offload_user_ids.numel(); + { + std::unique_lock lock(queued_offload_lastpos_mutex_); + if (queued_offload_tokens >= queued_offload_limits) { + return; + } + for (auto seq_idx = 0; seq_idx < num_offload_uids; seq_idx++) { + queued_offload_tokens += ((int*)new_offload_lengths.data_ptr())[seq_idx]; + } + } + + std::vector 
offload_host_metadata(4*num_offload_uids); + + std::memcpy((void*)offload_host_metadata.data(), + (void*)offload_user_ids.data_ptr(), num_offload_uids * sizeof(int64_t)); + std::memcpy(offload_host_metadata.data() + num_offload_uids * 2, + new_offload_startpos.data_ptr(), num_offload_uids * sizeof(int)); + std::memcpy(offload_host_metadata.data() + num_offload_uids * 3, + new_offload_lengths.data_ptr(), num_offload_uids * sizeof(int)); + + { + std::unique_lock lock(offload_freezed_uids_mtx_); + int64_t *offload_uids = reinterpret_cast(offload_host_metadata.data()); + for (int idx = 0; idx < num_offload_uids; idx++) { + int cur_freezed_times = offload_freezed_uids_[offload_uids[idx]]; + offload_freezed_uids_[offload_uids[idx]] = cur_freezed_times + 1; + // std::cout << "Freezing " << offload_uids[idx] << " from " << cur_freezed_times << " to " << (cur_freezed_times + 1) << std::endl; + } + } + { + std::unique_lock lock(offload_task_mutex_); + offload_task_queue.push(std::make_tuple( + offload_host_metadata, + offload_page_ids, + gather_kv_gpu_buffer, + offload_handle.ready_event + )); + } + + offload_task_cv_.notify_one(); + }; + + bool is_busy_offloading() { + return !offload_task_queue.empty() || this->offload_busy_.load(); + } + +private: + void onload_host_memcpy(void* dst, void* src, size_t bytes, size_t bytes_part) { + NVTX3_FUNC_RANGE(); + { + std::unique_lock lock(onload_memcpy_task_mtx_); + // this->onload_memcpy_cmplt_flag = 0; + for (int i = 1; i < (this->num_onload_memcpy_worker + 1); i++) + onload_memcpy_task_queue.push(std::make_tuple( + (reinterpret_cast(dst) + i * bytes_part), + (reinterpret_cast(src) + i * bytes_part), + bytes_part + )); + } + onload_memcpy_task_cv_.notify_all(); + std::memcpy(dst, src, bytes_part); + + // { + // std::unique_lock lock(onload_memcpy_cmplt_mtx_); + // onload_memcpy_cmplt_cv_.wait(lock, [this] { + // return this->onload_memcpy_cmplt_flag == this->num_onload_memcpy_worker || this->terminate_; + // }); + // } + onload_memcpy_barrier_.arrive_and_wait(); + } + + void onload_host_memcpy_loop() { + while (true) { + void *dst, *src; + size_t bytes; + { + std::unique_lock lock(onload_memcpy_task_mtx_); + onload_memcpy_task_cv_.wait(lock, [this]{ + return !this->onload_memcpy_task_queue.empty() || this->terminate_; + }); + if (terminate_) return; + std::tie(dst, src, bytes) = onload_memcpy_task_queue.front(); + onload_memcpy_task_queue.pop(); + } + std::memcpy(dst, src, bytes); + // { + // std::unique_lock lock(onload_memcpy_cmplt_mtx_); + // onload_memcpy_cmplt_flag += 1; + // if (onload_memcpy_cmplt_flag == this->num_onload_memcpy_worker) + // onload_memcpy_cmplt_cv_.notify_one(); + // } + onload_memcpy_barrier_.arrive_and_wait(); + } + } + + void offload_host_memcpy(void* dst, void* src, size_t bytes, size_t bytes_part) { + NVTX3_FUNC_RANGE(); + { + std::unique_lock lock(offload_memcpy_task_mtx_); + this->offload_memcpy_cmplt_flag = 0; + for (int i = 1; i < (this->num_offload_memcpy_worker + 1); i++) + offload_memcpy_task_queue.push(std::make_tuple( + (reinterpret_cast(dst) + i * bytes_part), + (reinterpret_cast(src) + i * bytes_part), + bytes_part + )); + } + offload_memcpy_task_cv_.notify_all(); + std::memcpy(dst, src, bytes_part); + + { + std::unique_lock lock(offload_memcpy_cmplt_mtx_); + offload_memcpy_cmplt_cv_.wait(lock, [this] { + return this->offload_memcpy_cmplt_flag == this->num_offload_memcpy_worker || this->terminate_; + }); + } + } + + void offload_host_memcpy_loop() { + while (true) { + void *dst, *src; + size_t bytes; + { + 
std::unique_lock lock(offload_memcpy_task_mtx_); + offload_memcpy_task_cv_.wait(lock, [this]{ + return !this->offload_memcpy_task_queue.empty() || this->terminate_; + }); + if (terminate_) return; + std::tie(dst, src, bytes) = offload_memcpy_task_queue.front(); + offload_memcpy_task_queue.pop(); + } + std::memcpy(dst, src, bytes); + { + std::unique_lock lock(offload_memcpy_cmplt_mtx_); + offload_memcpy_cmplt_flag += 1; + if (offload_memcpy_cmplt_flag == this->num_offload_memcpy_worker) + offload_memcpy_cmplt_cv_.notify_one(); + } + } + } + + void offload_loop() + { + const c10::cuda::OptionalCUDAGuard device_guard(this->device); + + while (true) { + std::vector host_metadata; + at::Tensor offload_page_ids, gather_kv_gpu_buffer; + cudaEvent_t offload_gpu_acq_event; + { + nvtx3::scoped_range r{"offload_prelogue"}; + + std::unique_lock lock(offload_task_mutex_); + offload_task_cv_.wait(lock, [this] { + return !offload_task_queue.empty() || this->terminate_; + }); + if (this->terminate_) { + break; + } + + { + nvtx3::scoped_range r1{"offload_prelogue unpack_input"}; + + std::tie( + host_metadata, offload_page_ids, gather_kv_gpu_buffer, offload_gpu_acq_event + ) = offload_task_queue.front(); + } + { + nvtx3::scoped_range r2{"offload_prelogue pop"}; + offload_task_queue.pop(); + } + + this->offload_busy_.store(true); + } + + int64_t *offload_uids = reinterpret_cast(host_metadata.data()); + uint16_t *gather_kv_gpu_buffer_data_ptr = static_cast(gather_kv_gpu_buffer.data_ptr()); + const int num_offload_uids = host_metadata.size() / 4; + const int num_offload_pages = offload_page_ids.numel(); + size_t gather_layer_stride = num_offload_pages * this->page_stride; + + int64_t dbg_uid; + { + int64_t *offload_uids = reinterpret_cast(host_metadata.data()); + dbg_uid = offload_uids[0]; + } + // std::cout << "[Offload Launch] {" << dbg_uid << "}: waiting event" << std::endl; + cudaStreamWaitEvent(this->offload_stream, offload_gpu_acq_event); + cudaEventDestroy(offload_gpu_acq_event); + // std::cout << "[Offload Launch] {" << dbg_uid << "}: waited event" << std::endl; + + // gather + gather_paged_kv_cache_all_layers( + gather_kv_gpu_buffer_data_ptr, + this->get_cache_table(), + static_cast(offload_page_ids.data_ptr()), + this->num_layers, + gather_layer_stride, + this->layer_stride, + this->num_kv_heads, + this->kv_headdim, + this->num_tokens_per_page, + this->page_stride, + this->k2v_stride, + this->num_kv_heads * this->kv_headdim, + this->kv_headdim, + num_offload_pages, + this->offload_stream); + cudaStreamSynchronize(this->offload_stream); + // release on gpu kvcache + { + std::unique_lock lock(offload_freezed_uids_mtx_); + for (int idx = 0; idx < num_offload_uids; idx++) { + int cur_freezed_times = offload_freezed_uids_[offload_uids[idx]]; + if (cur_freezed_times == 1) { + offload_freezed_uids_.erase(offload_uids[idx]); + } else { + offload_freezed_uids_[offload_uids[idx]] = cur_freezed_times - 1; + } + // std::cout << "Released " << offload_uids[idx] << " from " << cur_freezed_times << " to " << (cur_freezed_times - 1) << std::endl; + } + } + // std::cout << "[Offload Launch] {" << dbg_uid << "}: gathered gpu" << std::endl; + // skipped + + size_t pinned_bytes = this->num_layers * gather_layer_stride * sizeof(uint16_t); + uint16_t *host_kv_ptr = static_cast(aligned_alloc(sysconf(_SC_PAGESIZE), pinned_bytes)); + + size_t chunk_numel_part = host_kv_mgr->chunk_numel / (this->num_offload_memcpy_worker + 1); + + int num_chunks = (num_offload_pages * this->num_tokens_per_page) / this->num_tokens_per_chunk * 
this->num_layers; + for (int chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++) { + cudaMemcpyAsync( + offload_pinned_buffers[chunk_idx%2], + gather_kv_gpu_buffer_data_ptr + chunk_idx * host_kv_mgr->chunk_numel, + host_kv_mgr->chunk_numel * sizeof(uint16_t), + cudaMemcpyDeviceToHost, + this->offload_stream); + cudaEventRecord(offload_memcpy_event[chunk_idx%2], this->offload_stream); + + if (chunk_idx > 0) { + cudaEventSynchronize(offload_memcpy_event[(chunk_idx-1)%2]); + offload_host_memcpy( + host_kv_ptr + (chunk_idx-1) * host_kv_mgr->chunk_numel, + offload_pinned_buffers[(chunk_idx-1)%2], + host_kv_mgr->chunk_numel * sizeof(uint16_t), + chunk_numel_part * sizeof(uint16_t)); + } + } + { + cudaEventSynchronize(offload_memcpy_event[(num_chunks-1)%2]); + offload_host_memcpy( + host_kv_ptr + (num_chunks-1) * host_kv_mgr->chunk_numel, + offload_pinned_buffers[(num_chunks-1)%2], + host_kv_mgr->chunk_numel * sizeof(uint16_t), + chunk_numel_part * sizeof(uint16_t)); + } + // std::cout << "[Offload Launch] {" << dbg_uid << "}: copied data" << std::endl; + + // bookkeepping int host kv storage + { + nvtx3::scoped_range r{"offload_epilogue"}; + { + nvtx3::scoped_range r1{"offload_bookkeeping"}; + + size_t page_offset = 0; + int* offload_startpos = reinterpret_cast(host_metadata.data() + num_offload_uids * 2); + int* offload_lengths = reinterpret_cast(host_metadata.data() + num_offload_uids * 3); + + for (int seq_idx = 0; seq_idx < num_offload_uids; seq_idx++) { + int64_t uid = offload_uids[seq_idx]; + uint16_t *input_ptr = host_kv_ptr + page_offset * this->page_stride; + host_kv_mgr->append_kvdata_v2(uid, offload_startpos[seq_idx], offload_lengths[seq_idx], input_ptr, gather_layer_stride); + this->_uid_to_offloaded_length[uid] = offload_startpos[seq_idx] + offload_lengths[seq_idx]; + page_offset += offload_lengths[seq_idx] / this->num_tokens_per_page; + { + std::unique_lock lock(queued_offload_lastpos_mutex_); + if (offload_startpos[seq_idx] + offload_lengths[seq_idx] == queued_offload_lastpos[uid]) { + queued_offload_lastpos.erase(uid); + } + queued_offload_tokens -= offload_lengths[seq_idx]; + } + } + } + + this->offload_busy_.store(false); + } + // std::cout << "[Offload Launch] {" << dbg_uid << "}: offloading finish" << std::endl; + } + } + +public: + int num_layers; + int num_kv_heads; + int kv_headdim; + int num_tokens_per_page; + int num_primary_cache_pages; + int num_onload_buffer_pages; + int num_reserved_buffer_pages; + int num_tokens_per_chunk; + int max_num_sequences; + int max_sequence_length; + + size_t layer_stride; + size_t k2v_stride; + size_t page_stride; + size_t per_token_kv_stride; + +public: + std::list _lru_list; + std::unordered_map::iterator> _lru_lookup_table; + std::queue _empty_pages; + std::unordered_map> _uid_to_page_id; + std::unordered_map _uid_to_paged_cache_startpos; + std::unordered_map _uid_to_paged_cache_length; + std::unordered_map _uid_to_offloaded_length; + + int num_onload_memcpy_worker; + std::vector onload_memcpy_worker; + std::vector onload_pinned_buffers; + std::vector onload_memcpy_event; + + std::queue> onload_memcpy_task_queue; + std::mutex onload_memcpy_task_mtx_; + std::condition_variable onload_memcpy_task_cv_; + + // int onload_memcpy_cmplt_flag; + // std::mutex onload_memcpy_cmplt_mtx_; + // std::condition_variable onload_memcpy_cmplt_cv_; + + std::barrier<> onload_memcpy_barrier_; + + std::thread offload_worker; + + std::queue, at::Tensor, at::Tensor, cudaEvent_t>> offload_task_queue; + std::mutex offload_task_mutex_; + std::condition_variable 
offload_task_cv_; + std::unordered_map queued_offload_lastpos; + size_t queued_offload_tokens; + std::mutex queued_offload_lastpos_mutex_; + size_t queued_offload_limits; + + int num_offload_memcpy_worker; + std::vector offload_memcpy_worker; + std::vector offload_pinned_buffers; + std::vector offload_memcpy_event; + + std::queue> offload_memcpy_task_queue; + std::mutex offload_memcpy_task_mtx_; + std::condition_variable offload_memcpy_task_cv_; + + int offload_memcpy_cmplt_flag; + std::mutex offload_memcpy_cmplt_mtx_; + std::condition_variable offload_memcpy_cmplt_cv_; + + std::unordered_map offload_freezed_uids_; + std::mutex offload_freezed_uids_mtx_; + + bool terminate_; + std::atomic offload_busy_; + + cudaStream_t worker_stream; + cudaStream_t onload_stream; + cudaStream_t offload_stream; + + HostKVStorageImpl *host_kv_mgr; + +public: + uint16_t *cache_table; + c10::Device device; +}; + + +void prepare_kvcache( + GPUKVCacheMangerImpl& gpu_mgr, + HostKVStorageImpl& host_mgr, + std::vector& user_ids, + std::vector& total_hist_lens, // all histo w/o candi + at::Tensor page_ids_gpu_buffer, + at::Tensor offload_page_ids_gpu_buffer, + at::Tensor offload_uids_buffer, + at::Tensor metadata_host_buffer, + at::Tensor metadata_gpu_buffer) { + + const c10::cuda::OptionalCUDAGuard device_guard(gpu_mgr.device); + + // std::cout << "prepare_kvcache start" << std::endl << std::flush; + + int batch_size = user_ids.size(); + + std::vector old_history_lengths = gpu_mgr.get_total_cache_length(user_ids); + + std::vector page_indices; + std::vector offload_page_ids; + int64_t *offload_user_ids = static_cast(offload_uids_buffer.data_ptr()); + + int *host_bufptr = static_cast(metadata_host_buffer.data_ptr()); + + int *page_indptr = host_bufptr + 0; + int *last_page_len = host_bufptr + batch_size + 1; + int *total_history_lengths = host_bufptr + batch_size * 2 + 1; + int *total_history_offsets = host_bufptr + batch_size * 3 + 1; + int *new_history_nnz_cuda = host_bufptr + batch_size * 4 + 2; + int *new_history_offsets = host_bufptr + batch_size * 4 + 3; + // === ^ GPU === v Host === + int *new_offload_startpos = host_bufptr + batch_size * 5 + 4; + int *new_offload_lengths = host_bufptr + batch_size * 6 + 4; + + int *num_page_ids = host_bufptr + batch_size * 7 + 4; + int *num_offload_page_ids = host_bufptr + batch_size * 7 + 5; + int *num_offload_user_ids = host_bufptr + batch_size * 7 + 6; + + size_t onload_page_offset = gpu_mgr.num_primary_cache_pages; + size_t num_offload_uids = 0; + size_t num_offload_pages = 0; + + page_indptr[0] = 0; + total_history_offsets[0] = 0; + new_history_offsets[0] = 0; + + const std::unordered_set freezed_uids(user_ids.begin(), user_ids.end()); + for (int seq_idx = 0; seq_idx < batch_size; seq_idx++) { + int64_t uid = user_ids[seq_idx]; + int total_history_length = total_hist_lens[seq_idx]; + + std::vector& page_ids = gpu_mgr.alloc(uid, total_history_length, freezed_uids); + int gpu_cache_startpos = gpu_mgr._uid_to_paged_cache_startpos[uid]; + int gpu_cache_length = gpu_mgr._uid_to_paged_cache_length[uid]; + + int num_onload_pages = gpu_cache_startpos / gpu_mgr.num_tokens_per_page; + for (int i = 0; i < num_onload_pages; i++) page_indices.push_back(onload_page_offset+i); + page_indices.insert(page_indices.end(), page_ids.begin(), page_ids.end()); + page_indptr[seq_idx + 1] = page_indptr[seq_idx] + page_ids.size() + num_onload_pages; + last_page_len[seq_idx] = gpu_cache_length % gpu_mgr.num_tokens_per_page; + onload_page_offset += num_onload_pages; + + 
total_history_lengths[seq_idx] = total_history_length; + total_history_offsets[seq_idx + 1] = total_history_offsets[seq_idx] + total_history_length; + new_history_offsets[seq_idx + 1] = new_history_offsets[seq_idx] + total_history_length - old_history_lengths[seq_idx]; + + auto offloaded_length = 0; + auto chunked_length = total_history_length - total_history_length % gpu_mgr.num_tokens_per_chunk; + if (gpu_mgr._uid_to_offloaded_length.find(uid) != gpu_mgr._uid_to_offloaded_length.end()) + offloaded_length = gpu_mgr._uid_to_offloaded_length[uid]; + { + std::unique_lock lock(gpu_mgr.queued_offload_lastpos_mutex_); + if (gpu_mgr.queued_offload_lastpos.find(uid) != gpu_mgr.queued_offload_lastpos.end()) { + offloaded_length = gpu_mgr.queued_offload_lastpos[uid]; + } + if (total_history_length - offloaded_length >= gpu_mgr.num_tokens_per_chunk) { + gpu_mgr.queued_offload_lastpos[uid] = chunked_length; + } + } + if (total_history_length - offloaded_length >= gpu_mgr.num_tokens_per_chunk) { + // auto chunked_length = total_history_length - total_history_length % gpu_mgr.num_tokens_per_chunk; + auto new_offload_page_start = (offloaded_length - gpu_cache_startpos) / gpu_mgr.num_tokens_per_page; + + offload_user_ids[num_offload_uids] = uid; + new_offload_startpos[num_offload_uids] = offloaded_length; + new_offload_lengths[num_offload_uids] = chunked_length - offloaded_length; + auto num_pages = new_offload_lengths[num_offload_uids] / gpu_mgr.num_tokens_per_page; + offload_page_ids.insert(offload_page_ids.end(), page_ids.begin() + new_offload_page_start, page_ids.begin() + new_offload_page_start + num_pages); + + num_offload_uids += 1; + num_offload_pages += num_pages; + } + } + auto new_tokens = new_history_offsets[batch_size]; + *new_history_nnz_cuda = new_tokens; + *num_page_ids = page_indptr[batch_size]; + *num_offload_page_ids = num_offload_pages; + *num_offload_user_ids = num_offload_uids; + // std::cout << "num_offload_pages: " << num_offload_pages << std::endl << std::flush; + + cudaMemcpyAsync(page_ids_gpu_buffer.data_ptr(), page_indices.data(), page_indptr[batch_size] * sizeof(int32_t), cudaMemcpyHostToDevice, gpu_mgr.worker_stream); + cudaMemcpyAsync(offload_page_ids_gpu_buffer.data_ptr(), offload_page_ids.data(), num_offload_pages * sizeof(int32_t), cudaMemcpyHostToDevice, gpu_mgr.worker_stream); + + size_t host_buffer_d2h_size = (batch_size * 5 + 4) * sizeof(int32_t); + cudaMemcpyAsync(metadata_gpu_buffer.data_ptr(), metadata_host_buffer.data_ptr(), host_buffer_d2h_size, cudaMemcpyHostToDevice, gpu_mgr.worker_stream); + + int *gpu_bufptr = static_cast(metadata_gpu_buffer.data_ptr()); + + int *total_history_lengths_dev = gpu_bufptr + batch_size * 2 + 1; + int *new_history_offsets_dev = gpu_bufptr + batch_size * 4 + 3; + int *batch_indices_dev = gpu_bufptr + batch_size * 5 + 4; + int *position_dev = gpu_bufptr + batch_size * 5 + 4 + new_tokens; + + GetPagedBatchIndicesPositions( + batch_size, + new_history_offsets_dev, + total_history_lengths_dev, + batch_indices_dev, + position_dev, + gpu_mgr.worker_stream + ); + + cudaStreamSynchronize(gpu_mgr.worker_stream); + // std::cout << "prepare_kvcache stop" << std::endl << std::flush; + } + +} // namespace kvcache + PYBIND11_MODULE(paged_kvcache_ops, m) { - m.def("append_kvcache", &append_paged_kv_cache, "append paged kv cache on GPU"); - m.def("gather_kvcache", &gather_paged_kv_cache, "gather paged kv cache on GPU"); + m.def("append_kvcache", &append_paged_kv_cache, "append paged kv cache on GPU", py::call_guard()); + m.def("gather_kvcache", 
&gather_paged_kv_cache, "gather paged kv cache on GPU", py::call_guard()); + + py::class_(m, "HostKVStorageImpl") + .def(py::init(), + py::arg("num_layers"), + py::arg("num_kv_heads"), + py::arg("kv_headdim"), + py::arg("num_tokens_per_page"), + py::arg("num_tokens_per_chunk")) + ; + + py::class_(m, "GPUKVCacheMangerImpl") + .def(py::init(), + py::arg("num_layers"), + py::arg("num_kv_heads"), + py::arg("kv_headdim"), + py::arg("num_tokens_per_page"), + py::arg("num_primary_cache_pages"), + py::arg("num_onload_buffer_pages"), + py::arg("num_reserved_buffer_pages"), + py::arg("num_tokens_per_chunk"), + py::arg("max_num_sequences"), + py::arg("max_sequence_length"), + py::arg("cache_table"), + py::arg("host_kv_mgr")) + .def("get_total_cache_length", &kvcache::GPUKVCacheMangerImpl::get_total_cache_length) + .def("evict_all", &kvcache::GPUKVCacheMangerImpl::evict_all) + .def("onload_kvcache", &kvcache::GPUKVCacheMangerImpl::onload_kvcache, py::call_guard()) + .def("offload_kvcache", &kvcache::GPUKVCacheMangerImpl::offload_kvcache, py::call_guard()) + .def("is_busy_offloading", &kvcache::GPUKVCacheMangerImpl::is_busy_offloading) + ; + + py::class_(m, "KVOnloadHandle") + .def(py::init(), py::arg("num_layers")) + // .def("wait", &kvcache::KVOnloadHandle::wait) + .def("wait_host", &kvcache::KVOnloadHandle::wait_host) + .def("reset", &kvcache::KVOnloadHandle::reset) + ; + + py::class_(m, "KVOffloadHandle") + .def(py::init()) + .def("record_ready", &kvcache::KVOffloadHandle::record_ready) + ; + + m.def("prepare_kvcache", &kvcache::prepare_kvcache, "prepare_kvcache", py::call_guard()); } \ No newline at end of file diff --git a/examples/hstu/ops/cuda_ops/csrc/paged_kvcache_ops_kernel.cu b/examples/hstu/ops/cuda_ops/csrc/paged_kvcache_ops_kernel.cu index 1324de9a7..4a08ece4c 100755 --- a/examples/hstu/ops/cuda_ops/csrc/paged_kvcache_ops_kernel.cu +++ b/examples/hstu/ops/cuda_ops/csrc/paged_kvcache_ops_kernel.cu @@ -342,4 +342,155 @@ cudaError_t GatherPagedKVCache( uint32_t stride_h, nv_half* kv_cache, uint32_t nnz, - cudaStream_t stream); \ No newline at end of file + cudaStream_t stream); + +template +__global__ void GatherPagedKVCacheAllLayersKernel(DType* gather_kv, + IdType* page_ids, + uint32_t num_layers, + uint32_t stride_layer_gather, + uint32_t stride_layer, + uint32_t page_size, + uint32_t stride_page, + uint32_t stride_k2v, + uint32_t stride_n, + uint32_t stride_h, + uint32_t nnz, + DType* __restrict__ kv_cache, + uint32_t m, uint32_t s, uint32_t a) { + uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t head_idx = ty; + uint32_t cta_id = blockIdx.x; + uint32_t num_ctas = gridDim.x; + + for (uint32_t layer_idx = 0; layer_idx < num_layers; layer_idx++) { + DType* gather_k = gather_kv + layer_idx * stride_layer_gather; + DType* gather_v = gather_kv + layer_idx * stride_layer_gather + stride_k2v; + DType* __restrict__ k_cache = kv_cache + layer_idx * stride_layer; + DType* __restrict__ v_cache = kv_cache + layer_idx * stride_layer + stride_k2v; + +#pragma unroll 4 + for (uint32_t i = cta_id; i < nnz; i += num_ctas) { + uint32_t page_id_idx, entry_idx; + divmod(i, page_size, m, s, a, + page_id_idx, entry_idx); + size_t inner_page_offset = head_idx * stride_h + entry_idx * stride_n + tx * vec_size; + size_t src_offset = __ldg(page_ids + page_id_idx) * stride_page + inner_page_offset; + size_t dst_offset = page_id_idx * stride_page + inner_page_offset; + vec_t::memcpy( + gather_k + dst_offset, k_cache + src_offset); + vec_t::memcpy( + gather_v + dst_offset, v_cache + src_offset); + } + } +} + 
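+// Work decomposition of the all-layer gather kernel above:
+//  - blockIdx.x grid-strides over nnz = num_pages * page_size page entries;
+//  - threadIdx.y selects the KV head, threadIdx.x copies head_dim in
+//    vec_size-wide vectorized loads/stores;
+//  - the precomputed fast-division constants (m, s, a) split a flat entry
+//    index into (page slot, in-page offset) without an integer divide;
+//  - the outer loop walks every layer, so a single launch packs all layers'
+//    pages densely into the gather buffer (destination indexed by page slot,
+//    source page looked up through page_ids).
+// The host launcher below sizes the grid from occupancy, capped at
+// ceil(nnz / num_sms) blocks per SM.
+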
+template +cudaError_t GatherPagedKVCacheAllLayers(DType* gather_kv, + IdType* page_ids, + uint32_t num_layers, + uint32_t stride_gather, + uint32_t stride_layer, + uint32_t num_heads, + uint32_t head_dim, + uint32_t page_size, + uint32_t stride_page, + uint32_t stride_k2v, + uint32_t stride_n, + uint32_t stride_h, + DType* kv_cache, + uint32_t nnz, + cudaStream_t stream) { + int dev_id = 0; + int num_sms = 0; + int num_blocks_per_sm = 0; + cudaGetDevice(&dev_id); + cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); + uint32_t bdx = HEAD_DIM / vec_size; + uint32_t bdy = num_heads; + uint32_t num_threads = bdx * bdy; + uint32_t smem_size = 0; + auto kernel = GatherPagedKVCacheAllLayersKernel; + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, + num_threads, smem_size); + num_blocks_per_sm = min(num_blocks_per_sm, ((int(nnz) + num_sms - 1) / num_sms)); + dim3 nblks(num_blocks_per_sm * num_sms); + dim3 nthrs(bdx, bdy); + + uint32_t m, s, a; + get_uint_fastdiv_msa(page_size, m, s, a); + + void* args[] = {(void*)&gather_kv, (void*)&page_ids, (void*)&num_layers, + (void*)&stride_gather, (void*)&stride_layer, (void*)&page_size, + (void*)&stride_page, (void*)&stride_k2v, (void*)&stride_n, + (void*)&stride_h, (void*)&nnz, (void*)&kv_cache, + (void*)&m, (void*)&s, (void*)&a}; + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream); + }); + return cudaSuccess; +} + +template +cudaError_t GatherPagedKVCacheAllLayers( + nv_bfloat16* gather_kv, + int32_t* page_ids, + uint32_t num_layers, + uint32_t stride_gather, + uint32_t stride_layer, + uint32_t num_heads, + uint32_t head_dim, + uint32_t page_size, + uint32_t stride_page, + uint32_t stride_k2v, + uint32_t stride_n, + uint32_t stride_h, + nv_bfloat16* kv_cache, + uint32_t nnz, + cudaStream_t stream); + + +__global__ void GetPagedBatchIndicesPositionsKernel( + int32_t batch_size, + int32_t* append_indptr, + int32_t* seq_lens_ptr, + int32_t* batch_indices_ptr, + int32_t* positions_ptr) { + + int32_t tx = threadIdx.x; + int32_t seq_idx = blockIdx.x; + int32_t seq_start = append_indptr[seq_idx]; + int32_t total_seq_len = seq_lens_ptr[seq_idx]; + int32_t append_per_seq = append_indptr[seq_idx + 1] - seq_start; + + int32_t* batch_indices_ptr_per_seq = batch_indices_ptr + seq_start; + int32_t* positions_ptr_per_seq = positions_ptr + seq_start; + int32_t pos_start = total_seq_len - append_per_seq; + +#pragma unroll 4 + for (int32_t i = tx; i < append_per_seq; i += blockDim.x) { + batch_indices_ptr_per_seq[i] = seq_idx; + positions_ptr_per_seq[i] = pos_start + i; + } +} + +cudaError_t GetPagedBatchIndicesPositions( + int32_t batch_size, + int32_t* append_indptr, + int32_t* seq_lens_ptr, + int32_t* batch_indices_ptr, + int32_t* positions_ptr, + cudaStream_t stream +) +{ + dim3 nblks(batch_size); + dim3 nthrs(128, 1); + + void* args[] = {(void*)&batch_size, (void*)&append_indptr, (void*)&seq_lens_ptr, + (void*)&batch_indices_ptr, (void*)&positions_ptr}; + auto kernel = GetPagedBatchIndicesPositionsKernel; + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream); + return cudaSuccess; +} \ No newline at end of file diff --git a/examples/hstu/setup.py b/examples/hstu/setup.py index a5db640e6..1467593e8 100644 --- a/examples/hstu/setup.py +++ b/examples/hstu/setup.py @@ -46,7 +46,7 @@ def nvcc_threads_args(): "ops/cuda_ops/csrc/paged_kvcache_ops_kernel.cu", ], extra_compile_args={ - 
"cxx": ["-O3", "-std=c++17"], + "cxx": ["-O3", "-std=c++20", "-fvisibility=hidden"], "nvcc": nvcc_threads_args() + nvcc_flags, }, ), diff --git a/examples/hstu/test_async_kvcache.py b/examples/hstu/test_async_kvcache.py new file mode 100644 index 000000000..fa778e8c9 --- /dev/null +++ b/examples/hstu/test_async_kvcache.py @@ -0,0 +1,141 @@ +import torch + +import math +from modules.async_kvcache_manager import AsyncHSTUKVCacheManager +import random +import time + +def forward( + async_kvcache, + batch_size: int, + user_ids: torch.Tensor, + total_history_lengths: torch.Tensor, + ): + with torch.inference_mode(): + # print("[DEBUG] total_history_lengths", total_history_lengths) + user_ids_list = user_ids.tolist() + + prepare_kvcache_result = async_kvcache.prepare_kvcache_async( + batch_size, + user_ids_list, + total_history_lengths.tolist(), + async_kvcache.static_page_ids_gpu_buffer, + async_kvcache.static_offload_page_ids_gpu_buffer, + async_kvcache.static_onload_handle, + ) + # print("[DEBUG] return from trigger\n", flush=True) + + ( + old_cached_lengths, + num_history_tokens, + offload_uids_buffer, + metadata_host_buffer, + metadata_gpu_buffer, + kvcache_metadata_fut, + onload_fut, + ) = prepare_kvcache_result + # print("[DEBUG] old_cached_lengths", old_cached_lengths) + old_cached_lengths = torch.tensor(old_cached_lengths, dtype=torch.int32) + + + kvcache_metadata = async_kvcache.prepare_kvcache_wait( + onload_fut, + kvcache_metadata_fut, + batch_size, + num_history_tokens, + async_kvcache.static_page_ids_gpu_buffer, + async_kvcache.static_offload_page_ids_gpu_buffer, + offload_uids_buffer, + metadata_host_buffer, + metadata_gpu_buffer, + async_kvcache.static_onload_handle, + ) + + for layer_idx in range(async_kvcache.num_layers): + kvcache_metadata.kv_onload_handle.wait_host(layer_idx) + + async_kvcache.offload_kvcache(kvcache_metadata) + + # async_kvcache.onload_kvcache_finalize(user_ids_list) + + return None + +if __name__ == "__main__": + with torch.inference_mode(): + + max_batch_size = 4 + max_seq_len = 20000 + kwargs = { + "num_layers": 3, + "num_kv_heads": 4, + "kv_headdim": 128, + "num_tokens_per_page": 32, + "num_primary_cache_pages": 10240, + "num_onload_buffer_pages": math.ceil(max_batch_size * max_seq_len / 32), + "num_reserved_buffer_pages": 0, + "num_tokens_per_chunk": 2048, + "max_num_sequences": -1, + "max_sequence_length": max_seq_len, + "max_batch_size": max_batch_size, + } + kvc_mgr = AsyncHSTUKVCacheManager(**kwargs) + + max_num_users = 50 + user_ids_pool = list(range(max_num_users)) + + init_user_ids = list(user_ids_pool) + random.shuffle(init_user_ids) + + torch.cuda.profiler.start() + + running_batch_size = 1 + for ind in range(0, len(init_user_ids), running_batch_size): + batch_size = running_batch_size + user_ids = init_user_ids[ind:ind+batch_size] + total_history_lengths = [ 6000 for _ in user_ids ] + + user_ids = torch.tensor(user_ids, dtype=torch.int64) + total_history_lengths = torch.tensor(total_history_lengths, dtype=torch.int32) + + print("forward", init_user_ids[ind:ind+batch_size]) + + forward(kvc_mgr, batch_size, user_ids, total_history_lengths) + + while (kvc_mgr.gpu_kvcache_mgr.is_busy_offloading()): + pass + + # torch.cuda.profiler.stop() + + # for uid in range(max_num_users): + # print(uid, kvc_mgr.gpu_kvcache_mgr.get_total_cache_length([uid])) + + kvc_mgr.gpu_kvcache_mgr.evict_all() + + # for uid in range(max_num_users): + # print(uid, kvc_mgr.gpu_kvcache_mgr.get_total_cache_length([uid])) + + running_batch_size = 1 + appending_user_ids = 
list(init_user_ids) + random.shuffle(appending_user_ids) + appending_user_ids = appending_user_ids[:10] + + # torch.cuda.profiler.start() + + for ind in range(0, len(appending_user_ids), running_batch_size): + batch_size = running_batch_size + user_ids = appending_user_ids[ind:ind+batch_size] + total_history_lengths = [ 6200 for _ in user_ids ] + + user_ids = torch.tensor(user_ids, dtype=torch.int64) + total_history_lengths = torch.tensor(total_history_lengths, dtype=torch.int32) + + forward(kvc_mgr, batch_size, user_ids, total_history_lengths) + + while (kvc_mgr.gpu_kvcache_mgr.is_busy_offloading()): + pass + + torch.cuda.profiler.stop() + + print("Done") + +# nsys profile -f true -o ./bench_async_init10x6000_append10x200_offload -c cudaProfilerApi --cuda-graph-trace=node python3 ./test_async_kvcache.py \ No newline at end of file
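+#
+# Call pattern exercised by forward() above (a sketch of the intended async flow,
+# not an additional API):
+#   1. prepare_kvcache_async(...)             - schedules metadata build and host-to-GPU onload on worker threads
+#   2. prepare_kvcache_wait(...)              - blocks on the metadata future and returns a KVCacheMetadata
+#   3. kv_onload_handle.wait_host(layer_idx)  - per-layer sync before attention reads the onloaded pages
+#   4. offload_kvcache(kvcache_metadata)      - queues GPU-to-host offload of newly filled, chunk-aligned pages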