
Commit

add padding
jcaip committed Dec 3, 2024
1 parent 443db19 commit 054717e
Showing 2 changed files with 10 additions and 3 deletions.
5 changes: 4 additions & 1 deletion torchao/_models/llama/benchmark_results.txt
@@ -136,4 +136,7 @@ TTFT(Time to First Token) Benchmarks
20241202185310, tok/s=113.04, tok/s_decode=167.77, ttft=0.5761, mem/s= 850.12 GB/s, peak_mem=28.81 GB, model_size= 7.52 GB quant: int8wo, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20241202185731, tok/s=113.14, tok/s_decode=157.50, ttft=0.4971, mem/s= 849.12 GB/s, peak_mem=20.65 GB, model_size= 7.51 GB quant: float8dq, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization float8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20241202185845, tok/s=109.58, tok/s_decode=160.29, ttft=0.5766, mem/s= 822.78 GB/s, peak_mem=20.62 GB, model_size= 7.51 GB quant: float8wo, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization float8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20241202190021, tok/s= 38.13, tok/s_decode=216.50, ttft=4.3203, mem/s= 160.99 GB/s, peak_mem=16.35 GB, model_size= 4.22 GB quant: int4wo-64, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20241202190201, tok/s=112.36, tok/s_decode=179.92, ttft=0.6674, mem/s= 436.68 GB/s, peak_mem=16.00 GB, model_size= 3.89 GB quant: sparse-marlin, sparse: semi-structured, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.float16, device: cuda repro: python generate.py --quantization sparse-marlin --sparsity semi-structured --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.float16 --compile --compile_prefill --prefill_size 8000--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20241202190351, tok/s=132.14, tok/s_decode=134.81, ttft=0.0291, mem/s=1983.35 GB/s, peak_mem=16.46 GB, model_size=15.01 GB quant: None, sparse: None, mod: SparseLlama-3-8B-pruned_50.2of4, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20241202190509, tok/s=138.52, tok/s_decode=162.58, ttft=0.2128, mem/s=1395.88 GB/s, peak_mem=22.90 GB, model_size=10.08 GB quant: None, sparse: semi-structured, mod: SparseLlama-3-8B-pruned_50.2of4, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --sparsity semi-structured --checkpoint_path ../../../checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
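
Each benchmark row above packs the throughput, latency, and memory figures together with the quantization/sparsity configuration and a repro command into one line. Below is a minimal, hypothetical parser for that row format; the function, regexes, and field names are assumptions derived only from the rows shown here, not a torchao utility.

import re

# Hypothetical helper (not part of torchao): split one benchmark_results.txt row,
# as shown above, into its timestamp, numeric metrics, config fields, and repro command.
METRIC_RE = re.compile(r"(tok/s_decode|tok/s|ttft|mem/s|peak_mem|model_size)\s*=\s*([\d.]+)")
CONFIG_RE = re.compile(r"(kv_quant|quant|sparse|mod|compile_prefill|compile|dtype|device):\s*([^,]+)")

def parse_benchmark_row(row: str) -> dict:
    # Everything after "repro:" is the command line that produced the run.
    head, _, repro = row.partition("repro:")
    return {
        "timestamp": head.split(",", 1)[0].strip(),
        "metrics": {k: float(v) for k, v in METRIC_RE.findall(head)},
        "config": {k: v.strip() for k, v in CONFIG_RE.findall(head)},
        "repro": repro.strip() or None,
    }

Note that in the logged repro strings the --prefill_size value runs directly into --num_samples with no separating space, so the repro field is best read as a record of the flags rather than re-run verbatim.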
8 changes: 6 additions & 2 deletions torchao/dtypes/uintx/semi_sparse_layout.py
@@ -41,13 +41,17 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(
     w_vals_int8 = weight_tensor.tensor_impl.int_data
     w_scales = weight_tensor.tensor_impl.scale
     tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
+    # must pad
+    row, col = tmp.shape
+    from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
+    tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp)
     # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm
     y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
         w_vals_int8,
-        tmp.t(),
+        tmp_padded.t(),
         alpha=w_scales.to(torch.float32),
         out_dtype=torch.bfloat16,
-    ).t()
+    ).t()[:row, :]
     y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
         *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
     )
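
The hunk above pads the flattened activations before the cuSPARSELt matmul and slices the padded rows back off afterwards, so prefill lengths that do not meet the kernel's shape requirements (such as the 8000-token prefill used in the benchmarks above) still work. Here is a self-contained sketch of that pad, sparse-mm, slice pattern; the function name padded_sparse_mm is hypothetical, and it assumes the weight is already compressed for torch._cslt_sparse_mm and the activations are int8 on a CUDA device with cuSPARSELt support.

import torch
from torch.sparse import SparseSemiStructuredTensorCUSPARSELT

# Illustrative sketch (not the torchao implementation) of the pad -> sparse mm -> slice
# pattern introduced by the patch above. Assumes w_compressed is a cuSPARSELt-compressed
# int8 weight and x_int8 holds int8 activations on a CUDA device.
def padded_sparse_mm(w_compressed, x_int8, w_scales):
    # Flatten leading dimensions so the dense operand is 2D, as in the kernel above.
    tmp = x_int8.reshape(-1, x_int8.shape[-1])
    row, _ = tmp.shape
    # Pad the dense input so its shape satisfies the sparse kernel's requirements;
    # this is the same private helper the patch calls.
    tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp)
    y = torch._cslt_sparse_mm(
        w_compressed,
        tmp_padded.t(),
        alpha=w_scales.to(torch.float32),
        out_dtype=torch.bfloat16,
    ).t()
    # Drop the rows introduced by padding so the output matches the original
    # (unpadded) number of activation rows.
    return y[:row, :]

Without the final slice, the padded rows would flow into the reshape back to the original batch and sequence dimensions a few lines below and break the output shape.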
