Commit

undo change
jcaip committed Dec 3, 2024
1 parent 0e579ae commit 443db19
Showing 2 changed files with 10 additions and 6 deletions.
3 changes: 2 additions & 1 deletion torchao/_models/llama/benchmark_results.txt
@@ -135,4 +135,5 @@ TTFT(Time to First Token) Benchmarks
20241202185149, tok/s= 86.09, tok/s_decode=113.40, ttft=0.5585, mem/s=1292.25 GB/s, peak_mem=35.37 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20241202185310, tok/s=113.04, tok/s_decode=167.77, ttft=0.5761, mem/s= 850.12 GB/s, peak_mem=28.81 GB, model_size= 7.52 GB quant: int8wo, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20241202185731, tok/s=113.14, tok/s_decode=157.50, ttft=0.4971, mem/s= 849.12 GB/s, peak_mem=20.65 GB, model_size= 7.51 GB quant: float8dq, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization float8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
-20241202185845, tok/s=109.58, tok/s_decode=160.29, ttft=0.5766, mem/s= 822.78 GB/s, peak_mem=20.62 GB, model_size= 7.51 GB quant: float8wo, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization float8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20241202185845, tok/s=109.58, tok/s_decode=160.29, ttft=0.5766, mem/s= 822.78 GB/s, peak_mem=20.62 GB, model_size= 7.51 GB quant: float8wo, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization float8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
+20241202190021, tok/s= 38.13, tok/s_decode=216.50, ttft=4.3203, mem/s= 160.99 GB/s, peak_mem=16.35 GB, model_size= 4.22 GB quant: int4wo-64, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000--num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
13 changes: 8 additions & 5 deletions torchao/sparsity/sparse_api.py
@@ -1,16 +1,15 @@
-import torch
+from typing import Callable, Optional
 
+import torch
 from torch.ao.pruning import WeightNormSparsifier
 from torch.sparse import to_sparse_semi_structured
 
 from torchao.quantization.quant_api import (
     _get_linear_subclass_inserter,
     _is_linear,
-    quantize_,
+    _replace_with_custom_fn_if_matches_filter,
 )
 
-from typing import Callable, Optional
-
 
 # Sparsity helper functions
 def apply_fake_sparsity(model, **kwargs):
@@ -76,4 +75,8 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
         from torchao.dtypes import SemiSparseLayout
         m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn)
     """
-    quantize_(model, apply_tensor_subclass, _is_linear if filter_fn is None else filter_fn)
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        apply_tensor_subclass,
+        _is_linear if filter_fn is None else filter_fn,
+    )
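
For context on the sparse_api.py hunk: sparsify_ now hands the supplied tensor-subclass transform to _replace_with_custom_fn_if_matches_filter, which applies it to every module that passes the filter (all nn.Linear modules by default). A minimal usage sketch, assuming a CUDA device, a half-precision model, and the apply_fake_sparsity/semi_sparse_weight helpers exported from torchao.sparsity:

    import torch
    from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_

    # Toy model; semi-structured (2:4) sparse kernels expect fp16/bf16 on CUDA.
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()

    # Prune weights into a 2:4 pattern first, then let sparsify_ swap each
    # nn.Linear weight for a semi-structured sparse tensor subclass.
    apply_fake_sparsity(model)
    sparsify_(model, semi_sparse_weight())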
