Skip to content

Commit

Permalink
Add TTFT benchmarks + update sparsity benchmarks (#1140)
Browse files Browse the repository at this point in the history
This PR adds time-to-first-token (TTFT) benchmarks to torchao, and also updates the benchmarking script to handle sparsity more cleanly and to use the available 2:4 sparse checkpoints.

It also adds padding support for int8 dynamic quantization combined with 2:4 sparsity, which was previously missing.
  • Loading branch information
jcaip authored Dec 4, 2024
1 parent b7630f1 commit 1a0dbf1
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 16 deletions.
4 changes: 4 additions & 0 deletions scripts/prepare.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf
python scripts/download.py --repo_id meta-llama/Meta-Llama-3-8B
python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B
python scripts/download.py --repo_id meta-llama/Llama-3.2-3B
python scripts/download.py --repo_id nm-testing/SparseLlama-3-8B-pruned_50.2of4
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-3.2-3B
# neuralmagic doesn't come with tokenizer, so we need to copy it over
mkdir -p checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original && cp checkpoints/meta-llama/Meta-Llama-3-8B/original/tokenizer.model checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original/tokenizer.model
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4
3 changes: 3 additions & 0 deletions test/prototype/test_sparse_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def test_sparse(self):
sparsify_(model, semi_sparse_weight())
sparse_result = model(input)

if compile:
model = torch.compile(model)

torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)


Expand Down
21 changes: 19 additions & 2 deletions torchao/_models/llama/benchmarks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --wr
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt

Expand All @@ -62,7 +62,7 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --wr
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt

Expand All @@ -79,3 +79,20 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 1
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 128

# TTFT benchmarks
export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --prefill_size 8000
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --write_result benchmark_results.txt --prefill_size 8000
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8wo --write_result benchmark_results.txt --prefill_size 8000
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --sparsity semi-structured --write_result benchmark_results.txt --prefill_size 8000
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8dq --write_result benchmark_results.txt --prefill_size 8000
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8wo --write_result benchmark_results.txt --prefill_size 8000
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int4wo-64 --write_result benchmark_results.txt --prefill_size 8000
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization sparse-marlin --write_result benchmark_results.txt --prefill_size 8000 --precision float16 --sparsity semi-structured

# 2:4 sparse model
export MODEL_REPO=nm-testing/SparseLlama-3-8B-pruned_50.2of4
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
116 changes: 104 additions & 12 deletions torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,29 @@
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False

class HostEvent:
    """Host-side timing event with the ``record`` / ``elapsed_time`` interface.

    Timestamps come from ``time.perf_counter()``, and ``elapsed_time``
    reports milliseconds so results are comparable with CUDA events.
    """

    def __init__(self):
        # perf_counter timestamp (seconds) of the last record() call;
        # None until the event has been recorded.
        self.event_time = None

    def record(self):
        """Stamp this event with the current ``time.perf_counter()`` value."""
        self.event_time = time.perf_counter()

    def elapsed_time(self, other_event):
        """Return the absolute time between this event and ``other_event``.

        Args:
            other_event: another event exposing an ``event_time`` attribute.

        Returns:
            The absolute difference between the two timestamps, in
            milliseconds.

        Raises:
            ValueError: if either event has not been recorded yet.
        """
        # Validate both sides: an unrecorded ``other_event`` would otherwise
        # surface as an opaque TypeError from ``None - float``.
        if self.event_time is None or other_event.event_time is None:
            raise ValueError("Event not recorded!")
        # return ms to match cuda event
        return abs(other_event.event_time - self.event_time) * 1000

def device_timer(device):
    """Create a timing event appropriate for ``device``.

    Args:
        device: device string such as ``"cuda"``, ``"cuda:0"``, ``"cpu"``
            or ``"mps"``; matching is by substring.

    Returns:
        A ``torch.cuda.Event`` with timing enabled for CUDA devices, a
        ``HostEvent`` for CPU/MPS devices, or ``None`` (after printing a
        notice) for any other device.
    """
    if "cuda" in device:
        return torch.cuda.Event(enable_timing=True)
    if "cpu" in device or "mps" in device:
        return HostEvent()
    # No timer is available for this device; callers must handle None.
    print(f"device={device} is not yet supported")
    return None

def device_sync(device):
if "cuda" in device:
torch.cuda.synchronize(device)
Expand Down Expand Up @@ -98,6 +121,10 @@ def generate(
kv_cache_quantization: bool = False,
cache_size: Optional[int] = None,
linear_causal_mask: bool=False,
prefill_start_event: Optional[torch.cuda.Event]=None,
prefill_end_event: Optional[torch.cuda.Event]=None,
decode_start_event: Optional[torch.cuda.Event]=None,
decode_end_event: Optional[torch.cuda.Event]=None,
**sampling_kwargs
) -> torch.Tensor:
"""
Expand Down Expand Up @@ -128,12 +155,21 @@ def generate(
model.setup_caches(max_batch_size=batch_size, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T)

# execute prefill
if prefill_start_event is not None:
prefill_start_event.record()
next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
seq[:, T] = next_token.squeeze()
if prefill_end_event is not None:
prefill_end_event.record()

# execute token generation
if decode_start_event is not None:
decode_start_event.record()
input_pos = torch.tensor([T], device=device, dtype=torch.int)
generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)
seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1)
if decode_end_event is not None:
decode_end_event.record()

return seq

Expand All @@ -157,6 +193,7 @@ def _load_model(checkpoint_path, device, precision):
B_INST, E_INST = "[INST]", "[/INST]"

def main(
prefill_size: Optional[int] = None,
prompt: str = "Hello, my name is",
interactive: bool = False,
num_samples: int = 5,
Expand All @@ -166,6 +203,7 @@ def main(
temperature: float = 0.8,
checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
quantization: Optional[str] = None,
sparsity: Optional[str] = None,
kv_cache_quantization: bool = False,
cache_size: Optional[int] = None,
linear_causal_mask: bool=False,
Expand All @@ -181,6 +219,10 @@ def main(
"""Generates text samples based on a pre-trained Transformer model and tokenizer.
"""

if prefill_size is not None and prefill_size > 0:
# create prompt of prefill size
prompt = "prompt " * (int(prefill_size)-3)

torchao.quantization.utils.recommended_inductor_config_setter()

assert checkpoint_path.is_file(), checkpoint_path
Expand All @@ -205,6 +247,14 @@ def main(

torch.manual_seed(1234)

def ffn_only(mod, fqn):
    """Accept ``nn.Linear`` modules whose qualified name contains "feed_forward"."""
    if not isinstance(mod, torch.nn.Linear):
        return False
    return "feed_forward" in fqn


def not_ffn_only(mod, fqn):
    """Accept every ``nn.Linear`` module that ``ffn_only`` rejects."""
    is_linear = isinstance(mod, torch.nn.Linear)
    return is_linear and not ffn_only(mod, fqn)


def ffn_or_attn_only(mod, fqn):
    """Accept ``nn.Linear`` modules named with "feed_forward" or "attention"."""
    if not isinstance(mod, torch.nn.Linear):
        return False
    return any(tag in fqn for tag in ("feed_forward", "attention"))

if quantization:
from torchao.quantization import (
Expand All @@ -228,9 +278,14 @@ def main(
apply_spinquant(model)
if "int8wo" in quantization:
quantize_(model, int8_weight_only())
elif "int8dq" in quantization:
quantize_(model, int8_dynamic_activation_int8_weight())
elif "int4wo" in quantization:
if "int8dq" in quantization:
if sparsity and "semi" in sparsity:
from torchao.dtypes import SemiSparseLayout
quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), filter_fn=ffn_only)
quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only)
else:
quantize_(model, int8_dynamic_activation_int8_weight())
if "int4wo" in quantization:
if "hqq" in quantization:
use_hqq=True
else:
Expand All @@ -250,9 +305,9 @@ def main(
layout=MarlinQQQLayout(),
),
)
else:
elif "semi" in sparsity:
from torchao.dtypes import MarlinSparseLayout
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()), filter_fn=ffn_or_attn_only)
if "fp6" in quantization:
quantize_(model, fpx_weight_only(3, 2))
elif "embed-int8wo" in quantization:
Expand Down Expand Up @@ -440,6 +495,13 @@ def main(
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(model)

# standalone sparsity
elif sparsity:
from torchao.sparsity import semi_sparse_weight, sparsify_
if "semi" in sparsity:
#TODO there is a bug here, need to fix
sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only)

model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9

if save:
Expand All @@ -465,6 +527,9 @@ def main(

aggregate_metrics = {
'tokens_per_sec': [],
'time': [],
'decode_tokens_per_sec': [],
'prefill_time': [],
}
start = -1 if compile else 0

Expand Down Expand Up @@ -499,6 +564,8 @@ def callback(x):
else:
callback = lambda x : x
t0 = time.perf_counter()
prefill_start_event, prefill_end_event = device_timer(device), device_timer(device)
decode_start_event, decode_end_event = device_timer(device), device_timer(device)
import contextlib
if (i != num_samples - 1 or not profile):
prof = contextlib.nullcontext()
Expand All @@ -518,6 +585,10 @@ def callback(x):
kv_cache_quantization=kv_cache_quantization,
cache_size=cache_size,
linear_causal_mask=linear_causal_mask,
prefill_start_event=prefill_start_event,
prefill_end_event=prefill_end_event,
decode_start_event=decode_start_event,
decode_end_event=decode_end_event,
)
if i == -1:
print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
Expand All @@ -527,7 +598,7 @@ def callback(x):
device_sync(device=device) # MKG
t = time.perf_counter() - t0

if not interactive:
if not interactive and prefill_size is None:
tok_list = y[0].tolist()
# truncate text after end of string token
tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())]
Expand All @@ -537,7 +608,14 @@ def callback(x):
tokens_generated = (y.size(-1) - prompt_length)
tokens_sec = tokens_generated / t
aggregate_metrics['tokens_per_sec'].append(tokens_sec)
print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
aggregate_metrics['time'].append(t)
decode_time = decode_start_event.elapsed_time(decode_end_event) / 1000
decode_tokens_sec = tokens_generated / decode_time
aggregate_metrics['decode_tokens_per_sec'].append(decode_tokens_sec)
prefill_time = prefill_start_event.elapsed_time(prefill_end_event) / 1000
aggregate_metrics['prefill_time'].append(prefill_time)
print(f"Sample {i+1} | overall time {t:.04f} s {tokens_sec:.02f} tokens/sec",
f"| prefill time {prefill_time:.04f} s decode {decode_tokens_sec:.02f} tokens/sec")
print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s")

if memory_profile and i==0:
Expand All @@ -558,8 +636,15 @@ def callback(x):
break
print("==========")

#ignore first sample for warmup
tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
ttft = torch.mean(torch.tensor(aggregate_metrics['prefill_time'])).item()
decode_tokpersec = torch.mean(torch.tensor(aggregate_metrics['decode_tokens_per_sec'])).item()
bandwidth = model_size * tokpersec
mem = torch.cuda.max_memory_reserved() /1e9
print(f"Average overall tokens/sec: {tokpersec:.2f}")
# Report the mean over all samples (decode_tokpersec), not the last sample's
# value; the quantity is a rate in tokens/sec, so no " s" suffix.
print(f"Average decode tokens/sec: {decode_tokpersec:.2f}")
print(f"Average TTFT: {ttft:.04f} s")
if device == "cuda":
mem = torch.cuda.max_memory_reserved() /1e9
elif device == "xpu":
Expand All @@ -571,15 +656,17 @@ def callback(x):
print(f"Peak Memory Usage: {mem:.02f} GB")
print(f"Model Size: {model_size:.02f} GB")
if write_result:
result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, tok/s_decode={decode_tokpersec:6.2f}, ttft={ttft:5.4f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
result_txt += f"quant: {quantization}, sparse: {sparsity}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
# Append a copy-pasteable repro command. Every fragment ends with a space so
# consecutive flags never run together.
result_txt += f"repro: python generate.py "
result_txt += f"--quantization {quantization} " if quantization else ""
result_txt += f"--sparsity {sparsity} " if sparsity else ""
result_txt += f"--checkpoint_path {checkpoint_path} "
result_txt += f"--device {device} "
result_txt += f"--precision {precision} "
result_txt += f"--compile " if compile else ""
result_txt += f"--compile_prefill " if compile_prefill else ""
result_txt += f"--prefill_size {prefill_size} " if prefill_size else ""
result_txt += f"--profile {profile} " if profile else ""
# The memory profile path has its own flag; emitting it as --profile would
# make the repro command run a different profiling mode.
result_txt += f"--memory_profile {memory_profile} " if memory_profile else ""
result_txt += f"--interactive " if interactive else ""
Expand All @@ -601,7 +688,7 @@ def callback(x):
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Your CLI description.')

# Default is None rather than 0: main() checks `prefill_size is None` before
# printing generated text, so a 0 default would suppress output on normal runs.
parser.add_argument('--prefill_size', type=int, default=None, help='Whether to run in ttft mode')
parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
Expand All @@ -617,6 +704,11 @@ def callback(x):
+'embed-int8wo, marlin_qqq'
)
)
parser.add_argument('-s', '--sparsity', type=str,
help=(
'Which sparsity techniques to apply: semi-structured'
)
)
parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
Expand All @@ -631,6 +723,6 @@ def callback(x):

args = parser.parse_args()
main(
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
args.prefill_size, args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
args.temperature, args.checkpoint_path, args.quantization, args.sparsity, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
)
Loading

0 comments on commit 1a0dbf1

Please sign in to comment.