chore: Fixes required for LLM models (#3002)
Showing 31 changed files with 1,029 additions and 353 deletions.
@@ -0,0 +1,86 @@
""" | ||
.. _torch_export_gpt2: | ||
Compiling GPT2 using the Torch-TensorRT with dynamo backend | ||
========================================================== | ||
This interactive script is intended as a sample of the Torch-TensorRT workflow with dynamo backend on a GPT2 model.""" | ||
|
||
# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import export_llm, generate

# %%

# Define the parameters and initialize the model
MAX_TOKENS = 32
DEVICE = torch.device("cuda:0")

# Define the GPT2 model from Hugging Face
# kv_cache is not supported in Torch-TRT currently.
# CPU is used here so that GPU memory is reserved for TRT compilation.
with torch.no_grad():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        pad_token_id=tokenizer.eos_token_id,
        use_cache=False,
        attn_implementation="eager",
    ).eval()

# %%
# Tokenize a sample input prompt and get PyTorch model outputs
prompt = "I enjoy walking with my cute dog"
model_inputs = tokenizer(prompt, return_tensors="pt")
input_ids = model_inputs["input_ids"]

# Auto-regressive generation loop for greedy decoding using the PyTorch model.
# We use a custom generate function which is very similar to the Hugging Face one.
pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)

# %%
# Compilation with `Torch-TensorRT` using the dynamo backend and generation of TensorRT outputs
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Export the GPT2 model into an ExportedProgram, which is the input to TRT compilation
gpt2_ep = export_llm(model, input_ids, max_seq_len=1024)
trt_model = torch_tensorrt.dynamo.compile(
    gpt2_ep,
    inputs=[input_ids],
    enabled_precisions={torch.float32},
    truncate_double=True,
    device=DEVICE,
    disable_tf32=True,
)

# Auto-regressive generation loop for greedy decoding using the TensorRT model.
# We use a custom generate function which is very similar to the Hugging Face one.
# Move inputs to GPU
input_ids = input_ids.to(DEVICE)
trt_gen_tokens = generate(trt_model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)

# %%
# Decode the output sentences of PyTorch and TensorRT
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
print("=============================")
print(
    "Pytorch model generated text: ",
    tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True),
)
print("=============================")
print(
    "TensorRT model generated text: ",
    tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True),
)

# %%
# The output sentences should look like the following
# =============================
# Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
# =============================
# TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
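The two token sequences can also be compared directly instead of only eyeballing the decoded strings. A minimal sanity check, not part of this commit: greedy decoding is deterministic, so the token ids are expected to match exactly; if kernel-level numeric differences ever make them diverge, comparing the decoded text is the more forgiving check.

# Illustrative sanity check (not part of the committed example).
# Move the CPU-generated PyTorch tokens to the GPU before comparing.
if torch.equal(pyt_gen_tokens.to(DEVICE), trt_gen_tokens):
    print("PyTorch and TensorRT greedy outputs match exactly")
else:
    print("Token ids differ; compare the decoded strings instead")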
@@ -0,0 +1,90 @@
""" | ||
.. _torch_export_llama2: | ||
Compiling Llama2 using the Torch-TensorRT with dynamo backend | ||
========================================================== | ||
This interactive script is intended as a sample of the Torch-TensorRT workflow with dynamo backend on a Llama2 model.""" | ||
|
||
# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import export_llm, generate

# %%
# Define the parameters and initialize the model
MAX_TOKENS = 32
DEVICE = torch.device("cuda:0")

# Define the Llama2 model from Hugging Face
# kv_cache is not supported in Torch-TRT currently.
# CPU is used here so that GPU memory is reserved for TRT compilation.
llama_path = "meta-llama/Llama-2-7b-chat-hf"
with torch.no_grad():
    model = AutoModelForCausalLM.from_pretrained(
        llama_path, use_cache=False, attn_implementation="eager"
    ).eval()

tokenizer = AutoTokenizer.from_pretrained(llama_path)

# %%
# Tokenize a sample input prompt and get PyTorch model outputs
prompt = "What is dynamic programming?"
model_inputs = tokenizer(prompt, return_tensors="pt")
input_ids = model_inputs.input_ids

# Auto-regressive generation loop for greedy decoding using the PyTorch model.
# We use a custom generate function which is very similar to the Hugging Face one.
pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)

# %%
# Compilation with `Torch-TensorRT` using the dynamo backend and generation of TensorRT outputs
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Export the Llama2 model into an ExportedProgram, which is the input to TRT compilation
llama2_ep = export_llm(model, input_ids, max_seq_len=64)
trt_model = torch_tensorrt.dynamo.compile(
    llama2_ep,
    inputs=[input_ids],
    enabled_precisions={torch.float32},
    min_block_size=1,
    truncate_double=True,
    device=DEVICE,
    disable_tf32=True,
)

# Auto-regressive generation loop for greedy decoding using the TensorRT model.
# We use a custom generate function which is very similar to the Hugging Face one.
# Move inputs to GPU
input_ids = input_ids.to(DEVICE)
trt_gen_tokens = generate(trt_model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)

# %%
# Decode the output sentences of PyTorch and TensorRT
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
print("=============================")
print(
    "Pytorch model generated text: ",
    tokenizer.batch_decode(
        pyt_gen_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0],
)
print("=============================")
print(
    "TensorRT model generated text: ",
    tokenizer.batch_decode(
        trt_gen_tokens,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0],
)

# %%
# The PyTorch and TensorRT models should generate the same text: a greedy-decoded
# continuation of the prompt "What is dynamic programming?", printed once for each model.
@@ -0,0 +1,63 @@
import torch
from transformers import StoppingCriteriaList
from transformers.generation.stopping_criteria import (
    EosTokenCriteria,
    MaxLengthCriteria,
)


def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
    """
    Exports the LLM model into an ExportedProgram with dynamic shapes.
    In the case of guard failures due to some PyTorch kernel implementations, we also
    try to re-export the graph by expressing the guards as runtime assert nodes.
    """
    with torch.no_grad():
        # max=1024 has a constraint violation error. https://github.com/pytorch/pytorch/issues/125604
        seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len)
        try:
            print("Trying to export the model using torch.export.export()..")
            # strict=False only enables aotautograd tracing and excludes dynamo.
            ep = torch.export.export(
                model, (inputs,), dynamic_shapes=({1: seq_len},), strict=False
            )
        except Exception:
            print(
                "Trying torch.export._trace._export to trace the graph since torch.export.export() failed"
            )
            # This API is used to express the constraint violation guards as asserts in the graph.
            ep = torch.export._trace._export(
                model,
                (inputs,),
                dynamic_shapes=({1: seq_len},),
                strict=False,
                allow_complex_guards_as_runtime_asserts=True,
            )

    return ep


def generate(model, input_seq, max_tokens, eos_token_id):
    """
    Greedy decoding of the model. This generates up to max_tokens.
    """
    # Max length of output seq = current input_seq length + max_tokens allowed to generate
    max_output_seq_length = input_seq.shape[1] + max_tokens
    stopping_criteria = StoppingCriteriaList(
        [
            MaxLengthCriteria(max_length=max_output_seq_length),
            EosTokenCriteria(eos_token_id=eos_token_id),
        ]
    )

    while True:
        outputs = model(input_seq)
        logits = outputs.logits
        next_token_logits = logits[:, -1, :]
        next_tokens = torch.argmax(next_token_logits, dim=-1)
        input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1)
        # TODO: Handle batch in this check
        if stopping_criteria(input_seq, logits).item():
            break

    return input_seq
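The TODO in the loop above notes that the stopping check only handles batch size 1. A minimal sketch of a batch-aware greedy loop is shown below; it assumes finished sequences are simply padded with EOS, and the helper name `generate_batched` is illustrative rather than part of this utility file.

def generate_batched(model, input_seq, max_tokens, eos_token_id):
    """Greedy decoding that stops each sequence in the batch independently."""
    max_output_seq_length = input_seq.shape[1] + max_tokens
    # 1 while a sequence is still generating, 0 once it has produced EOS.
    unfinished = torch.ones(
        input_seq.shape[0], dtype=torch.long, device=input_seq.device
    )
    while input_seq.shape[1] < max_output_seq_length and unfinished.any():
        logits = model(input_seq).logits
        next_tokens = torch.argmax(logits[:, -1, :], dim=-1)
        # Finished sequences keep emitting EOS instead of new tokens.
        next_tokens = next_tokens * unfinished + eos_token_id * (1 - unfinished)
        input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1)
        unfinished = unfinished * (next_tokens != eos_token_id).long()
    return input_seq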