diff --git a/experiments/README.md b/experiments/README.md
index 415c1a1..72e152a 100644
--- a/experiments/README.md
+++ b/experiments/README.md
@@ -51,7 +51,11 @@ wget https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/
 2. Run a single context window size test using one method:
 
 ```bash
-python experiments/benchmarks/benchmark_e2e.py --attn_type hf --context_window 1_000_000
+# For context windows larger than 700K tokens, enable --kv_cache_cpu to offload the KV cache to CPU memory.
+python experiments/benchmarks/benchmark_e2e.py --attn_type minference --context_window 1_000_000 --kv_cache_cpu
+python experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense --context_window 1_000_000 --kv_cache_cpu
+
+python experiments/benchmarks/benchmark_e2e.py --attn_type minference --context_window 500_000
 ```
 
 3. Run all latency experiments using different methods:
diff --git a/experiments/infinite_bench/run_infinitebench.py b/experiments/infinite_bench/run_infinitebench.py
index 4ca1a57..fd69551 100644
--- a/experiments/infinite_bench/run_infinitebench.py
+++ b/experiments/infinite_bench/run_infinitebench.py
@@ -247,9 +247,12 @@ def load_model(
+    preds = []  # resume point defaults to 0 when there is no prior output file
     if os.path.exists(output_path) and not args.rewrite:
         print(f"Output file {output_path} exists. Loading from file.")
         compute_scores(output_path, data_name, real_model_name, max_seq_length)
+        with open(output_path) as f:
+            preds = [json.loads(line) for line in f]
 
     for i, eg in tqdm(enumerate(examples)):
-        if i < args.start_example_id:
+        if i < args.start_example_id or i < len(preds):
             continue
         input_text = create_prompt(eg, data_name, real_model_name, args.data_dir)
         ground_truth = get_answer(eg, data_name)
diff --git a/minference/ops/streaming_kernel.py b/minference/ops/streaming_kernel.py
index 8297b16..1eff0d7 100644
--- a/minference/ops/streaming_kernel.py
+++ b/minference/ops/streaming_kernel.py
@@ -429,55 +429,57 @@ def _forward(
     global _BLOCK_M
 
     try:
-        _attn_fwd[grid](
-            q, k, v, sm_scale, m, o, l, #
-            q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
-            k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
-            v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
-            o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
-            q.shape[0], q.shape[1], k.shape[1], #
-            q.shape[2], #
-            q_round_len,
-            k.shape[2],
-            sliding_window_offset,
-            sliding_window_size,
-            BLOCK_DMODEL=Lk, #
-            END=end,
-            INIT=init,
-            BLOCK_M=_BLOCK_M,
-            BLOCK_N=_BLOCK_N,
-            SLIDING_WINDOW=(sliding_window is not None),
-            COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,
-            num_warps=4,
-            num_stages=4
-        )
+        with torch.cuda.device(q.device):
+            _attn_fwd[grid](
+                q, k, v, sm_scale, m, o, l, #
+                q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
+                k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
+                v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
+                o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
+                q.shape[0], q.shape[1], k.shape[1], #
+                q.shape[2], #
+                q_round_len,
+                k.shape[2],
+                sliding_window_offset,
+                sliding_window_size,
+                BLOCK_DMODEL=Lk, #
+                END=end,
+                INIT=init,
+                BLOCK_M=_BLOCK_M,
+                BLOCK_N=_BLOCK_N,
+                SLIDING_WINDOW=(sliding_window is not None),
+                COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,
+                num_warps=4,
+                num_stages=4
+            )
     except triton.OutOfResources as E:
         _BLOCK_N = _BLOCK_N // 2
         _BLOCK_M = _BLOCK_M // 2
         from warnings import warn
         warn(f"Triton attention ran out of resources: {E}\nRetrying with smaller block sizes BLOCK_M={_BLOCK_M}, BLOCK_N={_BLOCK_N}.")
-        _attn_fwd[grid](
-            q, k, v, sm_scale, m, o, l, #
-            q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
-            k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
-            v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
-            o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
-            q.shape[0], q.shape[1], k.shape[1], #
-            q.shape[2], #
-            q_round_len,
-            k.shape[2],
-            sliding_window_offset,
-            sliding_window_size,
-            BLOCK_DMODEL=Lk, #
-            END=end,
-            INIT=init,
-            BLOCK_M=_BLOCK_M,
-            BLOCK_N=_BLOCK_N,
-            SLIDING_WINDOW=(sliding_window is not None),
-            COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,
-            num_warps=4,
-            num_stages=4
-        )
+        with torch.cuda.device(q.device):
+            _attn_fwd[grid](
+                q, k, v, sm_scale, m, o, l, #
+                q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
+                k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
+                v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
+                o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
+                q.shape[0], q.shape[1], k.shape[1], #
+                q.shape[2], #
+                q_round_len,
+                k.shape[2],
+                sliding_window_offset,
+                sliding_window_size,
+                BLOCK_DMODEL=Lk, #
+                END=end,
+                INIT=init,
+                BLOCK_M=_BLOCK_M,
+                BLOCK_N=_BLOCK_N,
+                SLIDING_WINDOW=(sliding_window is not None),
+                COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,
+                num_warps=4,
+                num_stages=4
+            )
 
     if end:
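A note on the `run_infinitebench.py` change: by loading any existing JSONL output and skipping the first `len(preds)` examples, an interrupted run resumes where it stopped instead of recomputing finished predictions. Below is a minimal sketch of the same resume pattern; `resume_point` and the output path are illustrative, not part of the script:

```python
import json
import os

def resume_point(output_path: str) -> int:
    """Index of the first example without a prediction in the JSONL output."""
    if not os.path.exists(output_path):
        return 0
    with open(output_path) as f:
        return len([json.loads(line) for line in f])

examples = ["ex0", "ex1", "ex2"]      # stand-in for the benchmark examples
done = resume_point("preds.jsonl")    # illustrative path
for i, example in enumerate(examples):
    if i < done:
        continue  # already predicted in a previous run
    # ... run the model, then append the new prediction as one JSONL line
```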
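On the `streaming_kernel.py` change: Triton launches kernels on the *current* CUDA device, so when `q` lives on another GPU (e.g. under tensor parallelism), the launch has to happen inside `torch.cuda.device(q.device)` or it can fail or target the wrong GPU. A self-contained sketch of the pattern, using a trivial stand-in kernel rather than the actual `_attn_fwd`:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    # Trivial elementwise copy standing in for the attention kernel.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    tl.store(dst_ptr + offsets, tl.load(src_ptr + offsets, mask=mask), mask=mask)

def copy(src: torch.Tensor) -> torch.Tensor:
    dst = torch.empty_like(src)
    grid = (triton.cdiv(src.numel(), 1024),)
    # Pin the current CUDA device to the tensor's device for the launch;
    # without this, a cuda:1 tensor launched while cuda:0 is current can
    # raise an error or run on the wrong GPU.
    with torch.cuda.device(src.device):
        _copy_kernel[grid](src, dst, src.numel(), BLOCK=1024)
    return dst
```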