From 054717eb70217b071f23e0a362d3fae984e618de Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Mon, 2 Dec 2024 19:06:11 -0800
Subject: [PATCH] add padding

---
 torchao/_models/llama/benchmark_results.txt | 5 ++++-
 torchao/dtypes/uintx/semi_sparse_layout.py  | 8 ++++++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt
index a876d65d0e..768b43fcef 100644
--- a/torchao/_models/llama/benchmark_results.txt
+++ b/torchao/_models/llama/benchmark_results.txt
@@ -136,4 +136,7 @@ TTFT(Time to First Token) Benchmarks
 20241202185310, tok/s=113.04, tok/s_decode=167.77, ttft=0.5761, mem/s= 850.12 GB/s, peak_mem=28.81 GB, model_size= 7.52 GB quant: int8wo, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
 20241202185731, tok/s=113.14, tok/s_decode=157.50, ttft=0.4971, mem/s= 849.12 GB/s, peak_mem=20.65 GB, model_size= 7.51 GB quant: float8dq, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization float8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
 20241202185845, tok/s=109.58, tok/s_decode=160.29, ttft=0.5766, mem/s= 822.78 GB/s, peak_mem=20.62 GB, model_size= 7.51 GB quant: float8wo, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization float8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
-20241202190021, tok/s= 38.13, tok/s_decode=216.50, ttft=4.3203, mem/s= 160.99 GB/s, peak_mem=16.35 GB, model_size= 4.22 GB quant: int4wo-64, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
\ No newline at end of file
+20241202190021, tok/s= 38.13, tok/s_decode=216.50, ttft=4.3203, mem/s= 160.99 GB/s, peak_mem=16.35 GB, model_size= 4.22 GB quant: int4wo-64, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20241202190201, tok/s=112.36, tok/s_decode=179.92, ttft=0.6674, mem/s= 436.68 GB/s, peak_mem=16.00 GB, model_size= 3.89 GB quant: sparse-marlin, sparse: semi-structured, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.float16, device: cuda repro: python generate.py --quantization sparse-marlin --sparsity semi-structured --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.float16 --compile --compile_prefill --prefill_size 8000--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20241202190351, tok/s=132.14, tok/s_decode=134.81, ttft=0.0291, mem/s=1983.35 GB/s, peak_mem=16.46 GB, model_size=15.01 GB quant: None, sparse: None, mod: SparseLlama-3-8B-pruned_50.2of4, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20241202190509, tok/s=138.52, tok/s_decode=162.58, ttft=0.2128, mem/s=1395.88 GB/s, peak_mem=22.90 GB, model_size=10.08 GB quant: None, sparse: semi-structured, mod: SparseLlama-3-8B-pruned_50.2of4, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --sparsity semi-structured --checkpoint_path ../../../checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
\ No newline at end of file
diff --git a/torchao/dtypes/uintx/semi_sparse_layout.py b/torchao/dtypes/uintx/semi_sparse_layout.py
index e2c94a7a38..d832731657 100644
--- a/torchao/dtypes/uintx/semi_sparse_layout.py
+++ b/torchao/dtypes/uintx/semi_sparse_layout.py
@@ -41,13 +41,17 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(
     w_vals_int8 = weight_tensor.tensor_impl.int_data
     w_scales = weight_tensor.tensor_impl.scale
     tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
+    # pad the flattened activations to a shape the cuSPARSELt kernel accepts; the padding rows are sliced off the output below
+    row, col = tmp.shape
+    from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
+    tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp)
     # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm
     y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
         w_vals_int8,
-        tmp.t(),
+        tmp_padded.t(),
         alpha=w_scales.to(torch.float32),
         out_dtype=torch.bfloat16,
-    ).t()
+    ).t()[:row, :]
     y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
         *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
     )
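
For context, the fix above follows a pad-then-slice pattern: the flattened activations are padded along the row dimension to a size the cuSPARSELt sparse mm accepts, and the padding rows are removed from the transposed output. Below is a minimal sketch of that pattern using a plain dense matmul in place of torch._cslt_sparse_mm; the row alignment of 8 and the helper name pad_rows are illustrative assumptions, not the torch API.

    import torch
    import torch.nn.functional as F


    def pad_rows(x: torch.Tensor, multiple: int = 8) -> torch.Tensor:
        """Pad the row dimension of a 2D tensor up to the next multiple of `multiple`."""
        pad = (-x.shape[0]) % multiple
        # F.pad pads the last dim first, so (0, 0, 0, pad) appends `pad` zero rows
        return F.pad(x, (0, 0, 0, pad)) if pad else x


    tmp = torch.randn(13, 64)   # flattened activations; 13 rows is not aligned
    w = torch.randn(32, 64)     # stands in for the compressed sparse weight

    row, _ = tmp.shape
    tmp_padded = pad_rows(tmp)  # (16, 64)

    # the kernel computes w @ tmp_padded.t(); transpose back and drop the padding
    # rows, mirroring `.t()[:row, :]` in the patch
    y = (w @ tmp_padded.t()).t()[:row, :]
    assert y.shape == (row, w.shape[0])

Because the padded rows are zeros, dropping them from the output leaves the result for the original rows unchanged, which is why slicing after the fused sparse mm is sufficient.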