
Fix(MInference): fix e2e benchmark guideline & fix A-shape multi gpu (#66)

Co-authored-by: Yucheng Li <[email protected]>
Co-authored-by: Chengruidong Zhang <[email protected]>
Co-authored-by: Yuqing Yang <[email protected]>
4 people authored Aug 8, 2024
1 parent 7a11a33 commit f0cae77
Showing 3 changed files with 54 additions and 46 deletions.
6 changes: 5 additions & 1 deletion experiments/README.md
````diff
@@ -51,7 +51,11 @@ wget https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/
 2. Run a single context window size test using one method:
 
 ```bash
-python experiments/benchmarks/benchmark_e2e.py --attn_type hf --context_window 1_000_000
+# If the context window is greater than 700K, you need to enable kv_cache_cpu.
+python experiments/benchmarks/benchmark_e2e.py --attn_type minference --context_window 1_000_000 --kv_cache_cpu
+python experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense --context_window 1_000_000 --kv_cache_cpu
+
+python experiments/benchmarks/benchmark_e2e.py --attn_type minference --context_window 500_000
 ```
 
 3. Run all latency experiments using different methods:
````
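As a side note, the 700K rule in the new README comment lends itself to automation. Below is a minimal sketch (not part of this commit) that picks the flag from the context window size; the threshold and flag names are taken directly from the commands above.

```python
import subprocess

def run_benchmark(attn_type: str, context_window: int) -> None:
    """Run one e2e benchmark, enabling kv_cache_cpu past the 700K threshold."""
    cmd = [
        "python", "experiments/benchmarks/benchmark_e2e.py",
        "--attn_type", attn_type,
        "--context_window", str(context_window),
    ]
    if context_window > 700_000:
        # Offload the KV cache to CPU, per the README note above.
        cmd.append("--kv_cache_cpu")
    subprocess.run(cmd, check=True)

run_benchmark("minference", 1_000_000)  # adds --kv_cache_cpu
run_benchmark("minference", 500_000)    # fits on GPU without offloading
```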
4 changes: 3 additions & 1 deletion experiments/infinite_bench/run_infinitebench.py
```diff
@@ -247,9 +247,11 @@ def load_model(
     if os.path.exists(output_path) and not args.rewrite:
         print(f"Output file {output_path} exists. Loading from file.")
         compute_scores(output_path, data_name, real_model_name, max_seq_length)
+        with open(output_path) as f:
+            preds = [json.loads(ii) for ii in f.readlines()]
 
     for i, eg in tqdm(enumerate(examples)):
-        if i < args.start_example_id:
+        if i < args.start_example_id or i < len(preds):
             continue
         input_text = create_prompt(eg, data_name, real_model_name, args.data_dir)
         ground_truth = get_answer(eg, data_name)
```
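The new lines make the script resumable: predictions already written to `output_path` are reloaded, and the loop skips every example whose result is on disk. Note that `preds` is only bound inside the `os.path.exists` branch, so a standalone version of this pattern should initialize it first. A minimal sketch under that assumption, with hypothetical paths and stand-in data:

```python
import json
import os

def load_finished_preds(output_path: str) -> list:
    """Load previously saved predictions so finished examples can be skipped."""
    preds = []  # guard: the commit's snippet leaves `preds` unbound when no file exists
    if os.path.exists(output_path):
        with open(output_path) as f:
            preds = [json.loads(line) for line in f]
    return preds

# Usage sketch: only unfinished examples are processed.
preds = load_finished_preds("results/preds.jsonl")  # hypothetical output path
examples = ["ex0", "ex1", "ex2"]                    # stand-in for the benchmark data
for i, eg in enumerate(examples):
    if i < len(preds):
        continue  # this example's prediction is already on disk
    print(f"would run the model on {eg!r}")
```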
90 changes: 46 additions & 44 deletions minference/ops/streaming_kernel.py
```diff
@@ -429,55 +429,57 @@ def _forward(
     global _BLOCK_M
 
     try:
-        _attn_fwd[grid](
-            q, k, v, sm_scale, m, o, l, #
-            q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
-            k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
-            v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
-            o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
-            q.shape[0], q.shape[1], k.shape[1], #
-            q.shape[2], #
-            q_round_len,
-            k.shape[2],
-            sliding_window_offset,
-            sliding_window_size,
-            BLOCK_DMODEL=Lk, #
-            END=end,
-            INIT=init,
-            BLOCK_M=_BLOCK_M,
-            BLOCK_N=_BLOCK_N,
-            SLIDING_WINDOW=(sliding_window is not None),
-            COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,
-            num_warps=4,
-            num_stages=4
-        )
+        with torch.cuda.device(q.device):
+            _attn_fwd[grid](
+                q, k, v, sm_scale, m, o, l, #
+                q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
+                k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
+                v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
+                o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
+                q.shape[0], q.shape[1], k.shape[1], #
+                q.shape[2], #
+                q_round_len,
+                k.shape[2],
+                sliding_window_offset,
+                sliding_window_size,
+                BLOCK_DMODEL=Lk, #
+                END=end,
+                INIT=init,
+                BLOCK_M=_BLOCK_M,
+                BLOCK_N=_BLOCK_N,
+                SLIDING_WINDOW=(sliding_window is not None),
+                COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,
+                num_warps=4,
+                num_stages=4
+            )
     except triton.OutOfResources as E:
         _BLOCK_N = _BLOCK_N // 2
         _BLOCK_M = _BLOCK_M // 2
         from warnings import warn
         warn(f"Triton attention ran out of resources: {E}\nUsing smaller block size {_BLOCK_N}.")
-        _attn_fwd[grid](
-            q, k, v, sm_scale, m, o, l, #
-            q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
-            k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
-            v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
-            o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
-            q.shape[0], q.shape[1], k.shape[1], #
-            q.shape[2], #
-            q_round_len,
-            k.shape[2],
-            sliding_window_offset,
-            sliding_window_size,
-            BLOCK_DMODEL=Lk, #
-            END=end,
-            INIT=init,
-            BLOCK_M=_BLOCK_M,
-            BLOCK_N=_BLOCK_N,
-            SLIDING_WINDOW=(sliding_window is not None),
-            COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,
-            num_warps=4,
-            num_stages=4
-        )
+        with torch.cuda.device(q.device):
+            _attn_fwd[grid](
+                q, k, v, sm_scale, m, o, l, #
+                q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
+                k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
+                v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
+                o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
+                q.shape[0], q.shape[1], k.shape[1], #
+                q.shape[2], #
+                q_round_len,
+                k.shape[2],
+                sliding_window_offset,
+                sliding_window_size,
+                BLOCK_DMODEL=Lk, #
+                END=end,
+                INIT=init,
+                BLOCK_M=_BLOCK_M,
+                BLOCK_N=_BLOCK_N,
+                SLIDING_WINDOW=(sliding_window is not None),
+                COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,
+                num_warps=4,
+                num_stages=4
+            )
 
 
     if end:
```
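The only functional change here is wrapping each kernel launch in `with torch.cuda.device(q.device)`. Triton launches kernels on the current CUDA device, so when the tensors live on another GPU (the A-shape multi-GPU case from the commit title) the launch fails or targets the wrong device; pinning the context to the tensor's device fixes that. A minimal sketch of the pattern with an illustrative copy kernel, not the attention kernel above:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    # Trivial elementwise copy standing in for the attention kernel above.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

def launch_on_owning_device(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 1024),)
    # Without this context manager Triton launches on the *current* CUDA
    # device (often cuda:0), which breaks when x lives on another GPU.
    with torch.cuda.device(x.device):
        _copy_kernel[grid](x, out, x.numel(), BLOCK=1024)
    return out

# Usage: works regardless of which GPU holds the tensor, e.g.
# y = launch_on_owning_device(torch.randn(4096, device="cuda:1"))
```

The same wrapper is applied to the fallback path, which halves `_BLOCK_M` and `_BLOCK_N` after a `triton.OutOfResources` error and relaunches with the smaller tiles.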
