examples/hstu/configs/inference_config.py (18 additions, 8 deletions)

@@ -63,17 +63,17 @@ class KVCacheMetadata:
"""

# paged cache metadata
kv_indices: torch.Tensor = None
kv_indptr: torch.Tensor = None
kv_last_page_len: torch.Tensor = None
total_history_lengths: torch.Tensor = None
total_history_offsets: torch.Tensor = None
kv_indices: torch.Tensor = None # num_pages
kv_indptr: torch.Tensor = None # num_seq + 1
kv_last_page_len: torch.Tensor = None # num_seq
total_history_lengths: torch.Tensor = None # num_seq
total_history_offsets: torch.Tensor = None # num_seq + 1

# appending metadata
batch_indices: torch.Tensor = None
position: torch.Tensor = None
batch_indices: torch.Tensor = None # num_tokens
position: torch.Tensor = None # num_tokens
new_history_nnz: int = 0
new_history_nnz_cuda: torch.Tensor = None
new_history_nnz_cuda: torch.Tensor = None # 1

# onload utility
onload_history_kv_buffer: Optional[List[torch.Tensor]] = None
@@ -82,6 +82,16 @@ class KVCacheMetadata:
     # paged cache table pointers
     kv_cache_table: Optional[List[torch.Tensor]] = None
 
+    # async attributes
+    kv_onload_handle: Optional[object] = None
+    kv_offload_handle: Optional[object] = None
+
+    offload_user_ids: Optional[torch.Tensor] = None
+    offload_page_ids: Optional[torch.Tensor] = None
+    new_offload_startpos: Optional[torch.Tensor] = None
+    new_offload_lengths: Optional[torch.Tensor] = None
+
+    max_seqlen: Optional[int] = 0
 
 @dataclass
 class KVCacheConfig:
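
For orientation, here is a minimal sketch of how the paged-cache fields annotated above fit together. The page size, sequence lengths, and all tensor values below are illustrative assumptions, not taken from this change:

```python
import torch

# Assumed toy setup: 2 sequences, page size 4.
# Sequence 0 has 6 cached tokens -> pages [0, 1], last page holds 2 tokens.
# Sequence 1 has 4 cached tokens -> page  [2],    last page holds 4 tokens.
kv_indices = torch.tensor([0, 1, 2])              # num_pages: flat list of page ids
kv_indptr = torch.tensor([0, 2, 3])               # num_seq + 1: page range per sequence
kv_last_page_len = torch.tensor([2, 4])           # num_seq: valid tokens in each last page
total_history_lengths = torch.tensor([6, 4])      # num_seq
total_history_offsets = torch.tensor([0, 6, 10])  # num_seq + 1: prefix sum of lengths

# Appending 3 new tokens to sequence 0 and 1 to sequence 1:
batch_indices = torch.tensor([0, 0, 0, 1])  # num_tokens: owning sequence per new token
position = torch.tensor([6, 7, 8, 4])       # num_tokens: append position in that sequence
new_history_nnz = 4                         # host-side count of appended tokens
new_history_nnz_cuda = torch.tensor([4])    # 1: the same count as a device tensor
```
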
examples/hstu/dataset/inference_dataset.py (19 additions, 6 deletions)

@@ -160,12 +160,25 @@ def __iter__(self) -> Iterator[Batch]:
                 )
                 dates.append(self._batch_logs_frame.iloc[sample_id][self._date_name])
                 seq_endptrs.append(seq_endptr)
+            if len(user_ids) == 0:
+                continue
+
+            last_date = dates[0]
+            final_user_ids: List[int] = []
+            final_dates: List[int] = []
+            final_seq_endptrs: List[int] = []
+            for (uid, date, endp) in zip(user_ids, dates, seq_endptrs):
+                if date != last_date:
+                    continue
+                if uid not in final_user_ids:
+                    final_user_ids.append(uid)
+                    final_dates.append(date)
+                    final_seq_endptrs.append(endp)
+                else:
+                    idx = final_user_ids.index(uid)
+                    final_seq_endptrs[idx] = max(final_seq_endptrs[idx], endp)
             yield (
-                torch.tensor(user_ids),
-                torch.tensor(dates),
-                torch.tensor(seq_endptrs),
+                torch.tensor(final_user_ids),
+                torch.tensor(final_dates),
+                torch.tensor(final_seq_endptrs),
             )
 
     def get_input_batch(
@@ -306,7 +319,7 @@ def get_input_batch(
         labels = torch.tensor(labels, dtype=torch.int64, device=self._device)
         batch_kwargs = dict(
             features=features,
-            batch_size=self._batch_size,
+            batch_size=len(user_ids),  # self._batch_size,
             feature_to_max_seqlen=feature_to_max_seqlen,
             contextual_feature_names=self._contextual_feature_names,
             item_feature_name=self._item_feature_name,
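
To make the new de-duplication pass in `__iter__` concrete, here is a standalone sketch with made-up inputs: entries whose date differs from the batch's first date are dropped, and a repeated user id keeps only its largest sequence end pointer. This is also why `get_input_batch` now uses `batch_size=len(user_ids)`: after filtering, the effective batch can be smaller than the configured `self._batch_size`.

```python
# Made-up inputs: user 7 appears twice, user 3 belongs to a later date.
user_ids = [7, 9, 7, 3]
dates = [20240501, 20240501, 20240501, 20240502]
seq_endptrs = [10, 5, 14, 8]

last_date = dates[0]
final_user_ids, final_dates, final_seq_endptrs = [], [], []
for uid, date, endp in zip(user_ids, dates, seq_endptrs):
    if date != last_date:
        continue  # drop entries from any other date (user 3)
    if uid not in final_user_ids:
        final_user_ids.append(uid)
        final_dates.append(date)
        final_seq_endptrs.append(endp)
    else:
        # duplicate user id: keep the longest history
        idx = final_user_ids.index(uid)
        final_seq_endptrs[idx] = max(final_seq_endptrs[idx], endp)

print(final_user_ids)     # [7, 9]
print(final_seq_endptrs)  # [14, 5] -- user 7 keeps its larger end pointer
```
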
examples/hstu/inference/README.md (5 additions, 19 deletions)

@@ -56,28 +56,14 @@ ERROR: The input sequence has overlapping tokens from 5 to 9 (both inclusive).

 ## How to Setup
 
-1. Build TensorRT-LLM (with HSTU KV cache extension):
-
-The HSTU inference utilizes a customized KV cache manager from TensorRT-LLM.
-The current version is based on an HSTU-specialized implementation of TensorRT-LLM v0.19.0.
-
-```bash
-~$ cd ${WORKING_DIR}
-~$ git clone -b hstu-kvcache-recsys-examples https://github.com/geoffreyQiu/TensorRT-LLM.git tensorrt-llm-kvcache && cd tensorrt-llm-kvcache
-~$ git submodule update --init --recursive
-~$ make -C docker release_build CUDA_ARCHS="80-real;86-real"
-# This will build a docker image with TensorRT-LLM installed.
-```
-
-2. Install the dependencies for Recsys-Examples.
+1. Install the dependencies for Recsys-Examples.
 
 Turn on option `INFERENCEBUILD=1` to skip Megatron installation, which is not required for inference.
 
 ```bash
 ~$ cd ${WORKING_DIR}
 ~$ git clone --recursive -b ${TEST_BRANCH} ${TEST_REPO} recsys-examples && cd recsys-examples
-~$ TRTLLM_KVCACHE_IMAGE="tensorrt_llm/release:latest" docker build \
-    --build-arg BASE_IMAGE=${TRTLLM_KVCACHE_IMAGE} \
+~$ docker build \
     --build-arg INFERENCEBUILD=1 \
     -t recsys-examples:inference \
     -f docker/Dockerfile .
@@ -93,7 +79,7 @@ Turn on option `INFERENCEBUILD=1` to skip Megatron installation, which is not re
 ~$ python3 ./preprocessor.py --dataset_name "kuairand-1k" --inference
 ~$
 ~$ # Run the inference example
-~$ python3 ./inference/inference_gr_ranking.py --gin_config_file ./inference/configs/kuairand_1k_inference_ranking.gin --checkpoint_dir ${PATH_TO_CHECKPOINT} --mode eval
+~$ python3 ./inference/inference_gr_ranking_async.py --gin_config_file ./inference/configs/kuairand_1k_inference_ranking.gin --checkpoint_dir ${PATH_TO_CHECKPOINT} --mode eval
 ```
 
 ## Consistency Check for Inference
@@ -131,7 +117,7 @@ TrainerArgs.ckpt_save_interval = 550

 2. Evaluation metrics from inference
 ```
-/workspace/recsys-examples$ PYTHONPATH=${PYTHONPATH}:$(realpath ../) python3 ./inference/inference_gr_ranking.py --gin_config_file ./inference/configs/kuairand_1k_inference_ranking.gin --checkpoint_dir ${PATH_TO_CHECKPOINT} --mode eval
+/workspace/recsys-examples$ PYTHONPATH=${PYTHONPATH}:$(realpath ../) python3 ./inference/inference_gr_ranking_async.py --gin_config_file ./inference/configs/kuairand_1k_inference_ranking.gin --checkpoint_dir ${PATH_TO_CHECKPOINT} --mode eval
 ... [inference output] ...
 [eval]:
 Metrics.task0.AUC:0.556894
@@ -142,4 +128,4 @@ TrainerArgs.ckpt_save_interval = 550
 Metrics.task5.AUC:0.580227
 Metrics.task6.AUC:0.620498
 Metrics.task7.AUC:0.556064
-... [inference output] ...
\ No newline at end of file
+... [inference output] ...