Added cpu support for llama generate.py/eval.py (#1307)
jainapurva authored Nov 20, 2024
1 parent 129316d commit d224653
Showing 2 changed files with 22 additions and 20 deletions.
11 changes: 3 additions & 8 deletions torchao/_models/llama/eval.py
@@ -10,27 +10,22 @@
 from generate import (
     _load_model,
     device_sync,
-
 )
-from torchao.quantization.quant_api import (
+from torchao.quantization import (
     quantize_,
     int4_weight_only,
     int8_weight_only,
     int8_dynamic_activation_int8_weight,
     fpx_weight_only,
     uintx_weight_only,
-    unwrap_tensor_subclass,
     float8_weight_only,
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
 )
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder
 from torchao._models.llama.model import prepare_inputs_for_model
-from torchao.quantization.granularity import PerRow, PerTensor
-
+from torchao.quantization import PerRow, PerTensor
 from tokenizer import get_tokenizer
 import time
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass

 def run_evaluation(
     checkpoint_path: Path,
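
The import changes above simply move everything to torchao's public entry points (torchao.quantization and torchao.utils). For reference, a minimal usage sketch of those names, assuming a torchao build that exports them as the new imports indicate; the toy model is illustrative:

# Minimal sketch: weight-only int8 quantization via the names imported above.
# Assumes quantize_ and int8_weight_only are exported from torchao.quantization,
# as eval.py's updated imports indicate; the model below is a placeholder.
import torch
from torchao.quantization import quantize_, int8_weight_only

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024))
quantize_(model, int8_weight_only())  # replaces Linear weights with int8 weight-only tensors
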
31 changes: 19 additions & 12 deletions torchao/_models/llama/generate.py
@@ -67,7 +67,7 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tenso
 def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
     new_tokens, new_probs = [], []
     for i in range(num_new_tokens):
-        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
+        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
             next_token, next_prob = decode_one_token(
                 model, cur_token, input_pos, **sampling_kwargs
             )
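
The change above swaps the deprecated, CUDA-only torch.backends.cuda.sdp_kernel context manager for torch.nn.attention.sdpa_kernel, which selects a scaled-dot-product-attention backend regardless of device. A minimal sketch of the new API, with illustrative tensor shapes and CPU placement:

# Minimal sketch of forcing the math SDPA backend, as decode_n_tokens now does.
# Tensor shapes and CPU placement are illustrative assumptions.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

with sdpa_kernel(SDPBackend.MATH):
    out = F.scaled_dot_product_attention(q, k, v)
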
@@ -345,15 +345,19 @@ def main(
             prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

     if memory_profile:
-        torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True)
+        if device != "cuda":
+            print("Memory profiling only works on CUDA")
+        else:
+            torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True)
     aggregate_metrics = {
         'tokens_per_sec': [],
     }
     start = -1 if compile else 0

     for i in range(start, num_samples):
         if i==0:
-            torch.cuda.reset_peak_memory_stats()
+            if device == "cuda":
+                torch.cuda.reset_peak_memory_stats() # MKG
         device_sync(device=device) # MKG
         if i >= 0 and interactive:
             prompt = input("What is your prompt? ")
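
The CPU support in this hunk is a guard pattern: torch.cuda state is only touched when the target device is CUDA, so the same loop runs unmodified on CPU. A minimal standalone sketch of that pattern; the device selection below is an illustrative assumption:

# Minimal sketch of the device guard used above: CUDA-only calls are skipped
# unless a CUDA device is actually in use. Device selection is illustrative.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda":
    torch.cuda.reset_peak_memory_stats()  # CUDA-only bookkeeping

# ... run token generation on `device` ...

if device == "cuda":
    print(f"Peak memory: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
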
@@ -421,15 +425,18 @@ def callback(x):
print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s")

if memory_profile and i==0:
snapshot = torch.cuda.memory._snapshot()
with open(f"{memory_profile}.pickle", 'wb') as f:
from pickle import dump
dump(snapshot, f)
print(
f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use",
"python pytorch/torch/cuda/_memory_viz.py trace_plot <pickle file> -o <desired output name>.html"
)
break
if device != "cuda":
print("Memory profiling only works on CUDA")
else:
snapshot = torch.cuda.memory._snapshot()
with open(f"{memory_profile}.pickle", 'wb') as f:
from pickle import dump
dump(snapshot, f)
print(
f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use",
"python pytorch/torch/cuda/_memory_viz.py trace_plot <pickle file> -o <desired output name>.html"
)
break

print("==========")

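For reference, the memory-profiling branch that stays CUDA-only follows the record-history / snapshot / pickle workflow shown in the diff; a minimal end-to-end sketch, assuming a CUDA device is present and using an illustrative output filename:

# Minimal sketch of the CUDA memory-snapshot workflow from generate.py.
# Assumes a CUDA device; "profile.pickle" is an illustrative filename.
import pickle
import torch

torch.cuda.memory._record_memory_history(
    True, trace_alloc_max_entries=250000, trace_alloc_record_context=True
)

# ... run the CUDA workload to be profiled ...

snapshot = torch.cuda.memory._snapshot()
with open("profile.pickle", "wb") as f:
    pickle.dump(snapshot, f)

# Convert to an interactive HTML trace with:
#   python pytorch/torch/cuda/_memory_viz.py trace_plot profile.pickle -o profile.html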