chore: Fixes required for LLM models (#3002)
peri044 authored Aug 29, 2024
1 parent 015f13b commit fa812a9
Showing 31 changed files with 1,029 additions and 353 deletions.
86 changes: 86 additions & 0 deletions examples/dynamo/torch_export_gpt2.py
@@ -0,0 +1,86 @@
"""
.. _torch_export_gpt2:
Compiling GPT2 using the Torch-TensorRT with dynamo backend
==========================================================
This interactive script is intended as a sample of the Torch-TensorRT workflow with dynamo backend on a GPT2 model."""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import export_llm, generate

# %%

# Define the parameters and initialize the model
MAX_TOKENS = 32
DEVICE = torch.device("cuda:0")

# Define the GPT2 model from Hugging Face
# kv_cache is not supported in Torch-TRT currently.
# CPU is used here so that GPU memory is reserved for TRT compilation.
with torch.no_grad():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        pad_token_id=tokenizer.eos_token_id,
        use_cache=False,
        attn_implementation="eager",
    ).eval()

# %%
# Tokenize a sample input prompt and get PyTorch model outputs
prompt = "I enjoy walking with my cute dog"
model_inputs = tokenizer(prompt, return_tensors="pt")
input_ids = model_inputs["input_ids"]

# Auto-regressive generation loop for greedy decoding using the PyTorch model
# We use a custom generate function which is very similar to the HuggingFace one.
pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)


# %%
# Compile with `Torch-TensorRT` using the dynamo backend and generate TensorRT outputs
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Export the GPT2 model into an ExportedProgram, which is the input to TRT compilation
gpt2_ep = export_llm(model, input_ids, max_seq_len=1024)
trt_model = torch_tensorrt.dynamo.compile(
    gpt2_ep,
    inputs=[input_ids],
    enabled_precisions={torch.float32},
    truncate_double=True,
    device=DEVICE,
    disable_tf32=True,
)

# Auto-regressive generation loop for greedy decoding using the TensorRT model
# We use a custom generate function which is very similar to the HuggingFace one.
# Move inputs to GPU
input_ids = input_ids.to(DEVICE)
trt_gen_tokens = generate(trt_model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)

# %%
# Decode the output sentences of PyTorch and TensorRT
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
print("=============================")
print(
"Pytorch model generated text: ",
tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True),
)
print("=============================")
print(
"TensorRT model generated text: ",
tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True),
)

# %%
# The output sentences should look like the following
# =============================
# Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
# =============================
# TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my
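As a sanity check on the custom greedy loop, a roughly equivalent call to HuggingFace's built-in generation utility can be made on the eager PyTorch model. This is a minimal sketch, not part of the commit, and assumes it is run before input_ids is moved to the GPU (the model stays on CPU in this example):

# Sketch: greedy decoding with HuggingFace's generate() on the eager model.
# The compiled TRT module is still driven by the custom generate() helper
# from utils.py, since kv_cache is disabled here.
hf_gen_tokens = model.generate(
    input_ids,
    max_new_tokens=MAX_TOKENS,
    do_sample=False,  # greedy decoding, matching the custom loop
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(hf_gen_tokens[0], skip_special_tokens=True))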
90 changes: 90 additions & 0 deletions examples/dynamo/torch_export_llama2.py
@@ -0,0 +1,90 @@
"""
.. _torch_export_llama2:
Compiling Llama2 using the Torch-TensorRT with dynamo backend
==========================================================
This interactive script is intended as a sample of the Torch-TensorRT workflow with dynamo backend on a Llama2 model."""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import export_llm, generate

# %%
# Define the parameters and initialize the model
MAX_TOKENS = 32
DEVICE = torch.device("cuda:0")

# Define the Llama2 model from Hugging Face
# kv_cache is not supported in Torch-TRT currently.
# CPU is used here so that GPU memory is reserved for TRT compilation.
llama_path = "meta-llama/Llama-2-7b-chat-hf"
with torch.no_grad():
    model = AutoModelForCausalLM.from_pretrained(
        llama_path, use_cache=False, attn_implementation="eager"
    ).eval()

tokenizer = AutoTokenizer.from_pretrained(llama_path)

# %%
# Tokenize a sample input prompt and get PyTorch model outputs
prompt = "What is dynamic programming?"
model_inputs = tokenizer(prompt, return_tensors="pt")
input_ids = model_inputs.input_ids

# Auto-regressive generation loop for greedy decoding using the PyTorch model
# We use a custom generate function which is very similar to the HuggingFace one.
pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)

# %%
# Compile with `Torch-TensorRT` using the dynamo backend and generate TensorRT outputs
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Export the Llama2 model into an ExportedProgram, which is the input to TRT compilation
llama2_ep = export_llm(model, input_ids, max_seq_len=64)
trt_model = torch_tensorrt.dynamo.compile(
    llama2_ep,
    inputs=[input_ids],
    enabled_precisions={torch.float32},
    min_block_size=1,
    truncate_double=True,
    device=DEVICE,
    disable_tf32=True,
)

# Auto-regressive generation loop for greedy decoding using the TensorRT model
# We use a custom generate function which is very similar to the HuggingFace one.
# Move inputs to GPU
input_ids = input_ids.to(DEVICE)
trt_gen_tokens = generate(trt_model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)

# %%
# Decode the output sentences of PyTorch and TensorRT
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
print("=============================")
print(
"Pytorch model generated text: ",
tokenizer.batch_decode(
pyt_gen_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0],
)
print("=============================")
print(
"TensorRT model generated text: ",
tokenizer.batch_decode(
trt_gen_tokens,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)[0],
)

# %%
# The output sentences of the PyTorch and TensorRT models should match and look like the following
# =============================
# Pytorch model generated text: ...
# =============================
# TensorRT model generated text: ...
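Because export_llm marks dimension 1 (the sequence length) as dynamic up to max_seq_len=64, the compiled module can be reused for prompts of a different length without recompilation, as long as the prompt plus the generated tokens stays within the exported range. A minimal sketch under that assumption (the second prompt is illustrative and not part of the commit):

# Sketch: reuse the compiled TRT module with a different prompt length,
# relying on the dynamic seq_len dimension exported above.
new_ids = tokenizer("What is gradient descent?", return_tensors="pt").input_ids.to(DEVICE)
new_tokens = generate(trt_model, new_ids, MAX_TOKENS, tokenizer.eos_token_id)
print(tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0])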
63 changes: 63 additions & 0 deletions examples/dynamo/utils.py
@@ -0,0 +1,63 @@
import torch
from transformers import StoppingCriteriaList
from transformers.generation.stopping_criteria import (
    EosTokenCriteria,
    MaxLengthCriteria,
)


def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
    """
    Exports the LLM model into an ExportedProgram with dynamic shapes.
    In the case of guard failures due to some PyTorch kernel implementations, we also
    try to re-export the graph by expressing the guards as runtime assert nodes.
    """
    with torch.no_grad():
        # max=1024 raises a constraint violation error. https://github.com/pytorch/pytorch/issues/125604
        seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len)
        try:
            print("Trying to export the model using torch.export.export()..")
            # strict=False only enables aotautograd tracing and excludes dynamo.
            ep = torch.export.export(
                model, (inputs,), dynamic_shapes=({1: seq_len},), strict=False
            )
        except:
            print(
                "Trying torch.export._trace._export to trace the graph since torch.export.export() failed"
            )
            # This API is used to express the constraint violation guards as asserts in the graph.
            ep = torch.export._trace._export(
                model,
                (inputs,),
                dynamic_shapes=({1: seq_len},),
                strict=False,
                allow_complex_guards_as_runtime_asserts=True,
            )

    return ep


def generate(model, input_seq, max_tokens, eos_token_id):
    """
    Greedy decoding of the model. This generates up to max_tokens.
    """
    # Max length of output seq = current input_seq length + max_tokens allowed to generate
    max_output_seq_length = input_seq.shape[1] + max_tokens
    stopping_criteria = StoppingCriteriaList(
        [
            MaxLengthCriteria(max_length=max_output_seq_length),
            EosTokenCriteria(eos_token_id=eos_token_id),
        ]
    )

    while True:
        outputs = model(input_seq)
        logits = outputs.logits
        next_token_logits = logits[:, -1, :]
        next_tokens = torch.argmax(next_token_logits, dim=-1)
        input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1)
        # TODO: Handle batch in this check
        if stopping_criteria(input_seq, logits).item():
            break

    return input_seq
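For reference, the dynamic-shape export pattern used by export_llm can be exercised on a toy module in isolation. This sketch is illustrative only (ToyModel is not part of the commit) and marks dimension 1 as a torch.export.Dim, exactly as done above for the LLM input_ids:

import torch

class ToyModel(torch.nn.Module):
    def forward(self, ids):
        # Cast token ids to float and average over the sequence dimension.
        return ids.float().mean(dim=1)

example = torch.zeros(1, 8, dtype=torch.int64)
seq_len = torch.export.Dim("seq_len", min=2, max=16)
ep = torch.export.export(ToyModel(), (example,), dynamic_shapes=({1: seq_len},))
# The exported graph now accepts any sequence length in [2, 16]:
print(ep.module()(torch.zeros(1, 12, dtype=torch.int64)).shape)  # torch.Size([1])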
59 changes: 34 additions & 25 deletions py/torch_tensorrt/dynamo/_DryRunTracker.py
@@ -20,18 +20,18 @@ class PerSubgraphData:
     Args:
         subgraph_name (str): Name of the subgraph in the GraphModule
         subgraph_op_count (int): Number of operations in the subgraph
-        subgraph_input_shapes (Any): Shapes of input Tensors of the subgraph
-        subgraph_input_dtypes (Any): Input data types of the subgraph
-        subgraph_output_shapes (Any): Shapes of output Tensors of the subgraph
-        subgraph_output_dtypes (Any): Output data types of the subgraph
+        input_shapes (Any): Shapes of input Tensors of the subgraph
+        input_dtypes (Any): Input data types of the subgraph
+        output_shapes (Any): Shapes of output Tensors of the subgraph
+        output_dtypes (Any): Output data types of the subgraph
     """

     subgraph_name: str = ""
     subgraph_op_count: int = 0
-    subgraph_input_shapes: Any = field(default_factory=list)
-    subgraph_input_dtypes: Any = field(default_factory=list)
-    subgraph_output_shapes: Any = field(default_factory=list)
-    subgraph_output_dtypes: Any = field(default_factory=list)
+    input_shapes: Any = field(default_factory=list)
+    input_dtypes: Any = field(default_factory=list)
+    output_shapes: Any = field(default_factory=list)
+    output_dtypes: Any = field(default_factory=list)


 @dataclass
@@ -41,10 +41,10 @@ class DryRunTracker:
     Args:
         total_ops_in_graph (int): Total number of operators in graph
         supported_ops_in_graph (int): Number of supported operators in graph
-        graph_input_shapes (Any): Shapes of input Tensors of the graph
-        graph_input_dtypes (Any): Input data types of the graph
-        graph_output_shapes (Any): Shapes of output Tensors of the graph
-        graph_output_dtypes (Any): Output data types of the graph
+        input_shapes (Any): Shapes of input Tensors of the graph
+        input_dtypes (Any): Input data types of the graph
+        output_shapes (Any): Shapes of output Tensors of the graph
+        output_dtypes (Any): Output data types of the graph
         per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class
         tensorrt_graph_count (int): Number of TensorRT engines to be generated
         compilation_settings (CompilationSettings): User Compilation Settings
@@ -54,10 +54,10 @@

     total_ops_in_graph: int = 0
     supported_ops_in_graph: int = 0
-    graph_input_shapes: Any = field(default_factory=list)
-    graph_input_dtypes: Any = field(default_factory=list)
-    graph_output_shapes: Any = field(default_factory=list)
-    graph_output_dtypes: Any = field(default_factory=list)
+    input_shapes: Any = field(default_factory=list)
+    input_dtypes: Any = field(default_factory=list)
+    output_shapes: Any = field(default_factory=list)
+    output_dtypes: Any = field(default_factory=list)
     per_subgraph_data: List[PerSubgraphData] = field(default_factory=list)
     tensorrt_graph_count: int = 0
     compilation_settings: CompilationSettings = field(
@@ -111,7 +111,7 @@ def dryrun_stats_display(
     formatted_stats += " " * 2 + "Graph Structure:\n\n"
     formatted_stats += (
         " " * 3
-        + f"Inputs: {input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}\n"
+        + f"Inputs: {input_formatter(dryrun_tracker.input_shapes, dryrun_tracker.input_dtypes)}\n"
     )

     for i, trt_subgraph_data in enumerate(dryrun_tracker.per_subgraph_data):
@@ -122,21 +122,21 @@
         )
         formatted_stats += (
             " " * 5
-            + f"Engine Inputs: {input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}\n"
+            + f"Engine Inputs: {input_formatter(trt_subgraph_data.input_shapes, trt_subgraph_data.input_dtypes)}\n"
         )
         formatted_stats += (
             " " * 5
             + f"Number of Operators in Engine: {trt_subgraph_data.subgraph_op_count}\n"
         )
         formatted_stats += (
             " " * 5
-            + f"Engine Outputs: {input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}\n"
+            + f"Engine Outputs: {input_formatter(trt_subgraph_data.output_shapes, trt_subgraph_data.output_dtypes)}\n"
         )

     formatted_stats += " " * 4 + "...\n"
     formatted_stats += (
         " " * 3
-        + f"Outputs: {input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}\n"
+        + f"Outputs: {input_formatter(dryrun_tracker.output_shapes, dryrun_tracker.output_dtypes)}\n"
     )

     # Print aggregate statistics about the graph structure, including recommended "min_block_size" options
@@ -225,11 +225,20 @@ def input_formatter(shapes: Any, dtypes: Any) -> str:

     def input_formatter_helper(shapes: Any, dtypes: Any) -> str:
         """Helper for input formatter"""
-        # Base case - single shape, single dtype
-        if isinstance(shapes, tuple) and all(isinstance(elt, int) for elt in shapes):
-            return f"Tensor: {shapes}@{str(dtypes)[6:]}, "
-
-        # Base case - dynamic shape, single dtype
+        # Base case 1 - single static/dynamic shape, single dtype
+        if isinstance(shapes, tuple) and all(
+            isinstance(elt, (int, tuple)) for elt in shapes
+        ):
+            input_shape_string = "Tensor: ("
+            for elt in shapes:
+                if isinstance(elt, tuple):
+                    input_shape_string += f"(min={elt[0]}, max={elt[1]}), "
+                else:
+                    input_shape_string += f"{elt}, "
+            input_shape_string = input_shape_string[:-2] + ")" + f"@{str(dtypes)[6:]}, "
+            return input_shape_string
+
+        # Base case 2 - dynamic shape, single dtype
         elif (
             isinstance(shapes, dict)
             and len(shapes) == 3
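The new first base case covers shapes whose per-dimension entries may themselves be (min, max) tuples for dynamic dimensions. A small self-contained sketch of the string it produces, with illustrative values and the added logic inlined rather than calling the private helper:

# Sketch: formatting of a mixed static/dynamic shape, mirroring the logic above.
shape = ((1, 8), 128)  # dim 0 is dynamic (min=1, max=8), dim 1 is static
dtype = "torch.float32"
parts = []
for elt in shape:
    parts.append(f"(min={elt[0]}, max={elt[1]})" if isinstance(elt, tuple) else str(elt))
print(f"Tensor: ({', '.join(parts)})@{dtype[6:]}, ")
# -> Tensor: ((min=1, max=8), 128)@float32,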
