diff --git a/examples/dynamo/torch_export_gpt2.py b/examples/dynamo/torch_export_gpt2.py new file mode 100644 index 0000000000..a26305e4a3 --- /dev/null +++ b/examples/dynamo/torch_export_gpt2.py @@ -0,0 +1,86 @@ +""" +.. _torch_export_gpt2: + +Compiling GPT2 using Torch-TensorRT with the dynamo backend +============================================================ + +This interactive script is intended as a sample of the Torch-TensorRT workflow with the dynamo backend on a GPT2 model.""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +import torch +import torch_tensorrt +from transformers import AutoModelForCausalLM, AutoTokenizer +from utils import export_llm, generate + +# %% + +# Define the parameters and initialize the model +MAX_TOKENS = 32 +DEVICE = torch.device("cuda:0") + +# Define the GPT2 model from Hugging Face +# kv_cache is not supported in Torch-TRT currently. +# CPU is used here so that GPU memory is reserved for TRT compilation. +with torch.no_grad(): + tokenizer = AutoTokenizer.from_pretrained("gpt2") + model = AutoModelForCausalLM.from_pretrained( + "gpt2", + pad_token_id=tokenizer.eos_token_id, + use_cache=False, + attn_implementation="eager", + ).eval() + +# %% +# Tokenize a sample input prompt and get PyTorch model outputs +prompt = "I enjoy walking with my cute dog" +model_inputs = tokenizer(prompt, return_tensors="pt") +input_ids = model_inputs["input_ids"] + +# Auto-regressive generation loop for greedy decoding using PyTorch model +# We use a custom generate function which is very similar to the Hugging Face one. +pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id) + + +# %% +# Compilation with `Torch-TensorRT` using the dynamo backend and generating TensorRT outputs +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Export the GPT2 model into an ExportedProgram, which is the input to TRT compilation +gpt2_ep = export_llm(model, input_ids, max_seq_len=1024) +trt_model = torch_tensorrt.dynamo.compile( + gpt2_ep, + inputs=[input_ids], + enabled_precisions={torch.float32}, + truncate_double=True, + device=DEVICE, + disable_tf32=True, +) + +# Auto-regressive generation loop for greedy decoding using TensorRT model +# We use a custom generate function which is very similar to the Hugging Face one. +# Move inputs to GPU +input_ids = input_ids.to(DEVICE) +trt_gen_tokens = generate(trt_model, input_ids, MAX_TOKENS, tokenizer.eos_token_id) + +# %% +# Decode the output sentences of PyTorch and TensorRT +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +print("=============================") +print( + "Pytorch model generated text: ", + tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True), +) +print("=============================") +print( + "TensorRT model generated text: ", + tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True), +) + +# %% +# The output sentences should look like +# ============================= +# Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my +# ============================= +# TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my diff --git a/examples/dynamo/torch_export_llama2.py b/examples/dynamo/torch_export_llama2.py new file mode 100644 index 0000000000..195944688b --- /dev/null +++ b/examples/dynamo/torch_export_llama2.py @@ -0,0 +1,90 @@ +""" +.. 
_torch_export_llama2: + +Compiling Llama2 using the Torch-TensorRT with dynamo backend +========================================================== + +This interactive script is intended as a sample of the Torch-TensorRT workflow with dynamo backend on a Llama2 model.""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +import torch +import torch_tensorrt +from transformers import AutoModelForCausalLM, AutoTokenizer +from utils import export_llm, generate + +# %% +# Define the parameters and initialize the model +MAX_TOKENS = 32 +DEVICE = torch.device("cuda:0") + +# Define the Llama2 model from hugging face +# kv_cache is not supported in Torch-TRT currently. +# CPU is used here so that GPU memory is reserved for TRT compilation. +llama_path = "meta-llama/Llama-2-7b-chat-hf" +with torch.no_grad(): + model = AutoModelForCausalLM.from_pretrained( + llama_path, use_cache=False, attn_implementation="eager" + ).eval() + +tokenizer = AutoTokenizer.from_pretrained(llama_path) + +# %% +# Tokenize a sample input prompt and get pytorch model outputs +prompt = "What is dynamic programming?" +model_inputs = tokenizer(prompt, return_tensors="pt") +input_ids = model_inputs.input_ids + +# Auto-regressive generation loop for greedy decoding using PyTorch model +# We use a custom generate function which is very similar to the huggingface one. +pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id) + +# %% +# Compilation with `Torch-TensorRT` using dynamo backend and generate TensorRT outputs +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Export the llama2 model into an ExportedProgram which is input of TRT compilation +llama2_ep = export_llm(model, input_ids, max_seq_len=64) +trt_model = torch_tensorrt.dynamo.compile( + llama2_ep, + inputs=[input_ids], + enabled_precisions={torch.float32}, + min_block_size=1, + truncate_double=True, + device=DEVICE, + disable_tf32=True, +) + +# Auto-regressive generation loop for greedy decoding using TensorRT model +# We use a custom generate function which is very similar to the huggingface one. +# Move inputs to GPU +input_ids = input_ids.to(DEVICE) +trt_gen_tokens = generate(trt_model, input_ids, MAX_TOKENS, tokenizer.eos_token_id) + +# %% +# Decode the output sentences of PyTorch and TensorRT +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +print("=============================") +print( + "Pytorch model generated text: ", + tokenizer.batch_decode( + pyt_gen_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0], +) +print("=============================") +print( + "TensorRT model generated text: ", + tokenizer.batch_decode( + trt_gen_tokens, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + )[0], +) + +# %% +# The output sentences should look like +# ============================= +# Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my +# ============================= +# TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. 
I'm not sure if I'll ever be able to walk with my diff --git a/examples/dynamo/utils.py b/examples/dynamo/utils.py new file mode 100644 index 0000000000..25ad99c12d --- /dev/null +++ b/examples/dynamo/utils.py @@ -0,0 +1,63 @@ +import torch +from transformers import StoppingCriteriaList +from transformers.generation.stopping_criteria import ( + EosTokenCriteria, + MaxLengthCriteria, +) + + +def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): + """ + Exports the LLM model into an ExportedProgram with dynamic shapes. + In the case of guard failures due to some PyTorch kernel implements, we also + try to re-export the graph by expressing them as runtime assert nodes + """ + with torch.no_grad(): + # max=1024 has contraint violation error. https://github.com/pytorch/pytorch/issues/125604 + seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len) + try: + print("Trying to export the model using torch.export.export()..") + # strict=False only enables aotautograd tracing and excludes dynamo. + ep = torch.export.export( + model, (inputs,), dynamic_shapes=({1: seq_len},), strict=False + ) + except: + print( + "Trying torch.export._trace._export to trace the graph since torch.export.export() failed" + ) + # This API is used to express the constraint violation guards as asserts in the graph. + ep = torch.export._trace._export( + model, + (inputs,), + dynamic_shapes=({1: seq_len},), + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + + return ep + + +def generate(model, input_seq, max_tokens, eos_token_id): + """ + Greedy decoding of the model. This generates up to max_tokens. + """ + # Max length of output seq = current input_seq length + max_tokens allowed to generate + max_output_seq_length = input_seq.shape[1] + max_tokens + stopping_criteria = StoppingCriteriaList( + [ + MaxLengthCriteria(max_length=max_output_seq_length), + EosTokenCriteria(eos_token_id=eos_token_id), + ] + ) + + while True: + outputs = model(input_seq) + logits = outputs.logits + next_token_logits = logits[:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1) + input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1) + # TODO: Handle batch in this check + if stopping_criteria(input_seq, logits).item(): + break + + return input_seq diff --git a/py/torch_tensorrt/dynamo/_DryRunTracker.py b/py/torch_tensorrt/dynamo/_DryRunTracker.py index 46d99ffe31..43789c4a0f 100644 --- a/py/torch_tensorrt/dynamo/_DryRunTracker.py +++ b/py/torch_tensorrt/dynamo/_DryRunTracker.py @@ -20,18 +20,18 @@ class PerSubgraphData: Args: subgraph_name (str): Name of the subgraph in the GraphModule subgraph_op_count (int): Number of operations in the subgraph - subgraph_input_shapes (Any): Shapes of input Tensors of the subgraph - subgraph_input_dtypes (Any): Input data types of the subgraph - subgraph_output_shapes (Any): Shapes of output Tensors of the subgraph - subgraph_output_dtypes (Any): Output data types of the subgraph + input_shapes (Any): Shapes of input Tensors of the subgraph + input_dtypes (Any): Input data types of the subgraph + output_shapes (Any): Shapes of output Tensors of the subgraph + output_dtypes (Any): Output data types of the subgraph """ subgraph_name: str = "" subgraph_op_count: int = 0 - subgraph_input_shapes: Any = field(default_factory=list) - subgraph_input_dtypes: Any = field(default_factory=list) - subgraph_output_shapes: Any = field(default_factory=list) - subgraph_output_dtypes: Any = field(default_factory=list) + input_shapes: Any = 
field(default_factory=list) + input_dtypes: Any = field(default_factory=list) + output_shapes: Any = field(default_factory=list) + output_dtypes: Any = field(default_factory=list) @dataclass @@ -41,10 +41,10 @@ class DryRunTracker: Args: total_ops_in_graph (int): Total number of operators in graph supported_ops_in_graph (int): Number of supported operators in graph - graph_input_shapes (Any): Shapes of input Tensors of the graph - graph_input_dtypes (Any): Input data types of the graph - graph_output_shapes (Any): Shapes of output Tensors of the graph - graph_output_dtypes (Any): Output data types of the graph + input_shapes (Any): Shapes of input Tensors of the graph + input_dtypes (Any): Input data types of the graph + output_shapes (Any): Shapes of output Tensors of the graph + output_dtypes (Any): Output data types of the graph per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class tensorrt_graph_count (int): Number of TensorRT engines to be generated compilation_settings (CompilationSettings): User Compilation Settings @@ -54,10 +54,10 @@ class DryRunTracker: total_ops_in_graph: int = 0 supported_ops_in_graph: int = 0 - graph_input_shapes: Any = field(default_factory=list) - graph_input_dtypes: Any = field(default_factory=list) - graph_output_shapes: Any = field(default_factory=list) - graph_output_dtypes: Any = field(default_factory=list) + input_shapes: Any = field(default_factory=list) + input_dtypes: Any = field(default_factory=list) + output_shapes: Any = field(default_factory=list) + output_dtypes: Any = field(default_factory=list) per_subgraph_data: List[PerSubgraphData] = field(default_factory=list) tensorrt_graph_count: int = 0 compilation_settings: CompilationSettings = field( @@ -111,7 +111,7 @@ def dryrun_stats_display( formatted_stats += " " * 2 + "Graph Structure:\n\n" formatted_stats += ( " " * 3 - + f"Inputs: {input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}\n" + + f"Inputs: {input_formatter(dryrun_tracker.input_shapes, dryrun_tracker.input_dtypes)}\n" ) for i, trt_subgraph_data in enumerate(dryrun_tracker.per_subgraph_data): @@ -122,7 +122,7 @@ def dryrun_stats_display( ) formatted_stats += ( " " * 5 - + f"Engine Inputs: {input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}\n" + + f"Engine Inputs: {input_formatter(trt_subgraph_data.input_shapes, trt_subgraph_data.input_dtypes)}\n" ) formatted_stats += ( " " * 5 @@ -130,13 +130,13 @@ def dryrun_stats_display( ) formatted_stats += ( " " * 5 - + f"Engine Outputs: {input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}\n" + + f"Engine Outputs: {input_formatter(trt_subgraph_data.output_shapes, trt_subgraph_data.output_dtypes)}\n" ) formatted_stats += " " * 4 + "...\n" formatted_stats += ( " " * 3 - + f"Outputs: {input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}\n" + + f"Outputs: {input_formatter(dryrun_tracker.output_shapes, dryrun_tracker.output_dtypes)}\n" ) # Print aggregate statistics about the graph structure, including recommended "min_block_size" options @@ -225,11 +225,20 @@ def input_formatter(shapes: Any, dtypes: Any) -> str: def input_formatter_helper(shapes: Any, dtypes: Any) -> str: """Helper for input formatter""" - # Base case - single shape, single dtype - if isinstance(shapes, tuple) and all(isinstance(elt, int) for elt in shapes): - return f"Tensor: {shapes}@{str(dtypes)[6:]}, " - - # Base case - dynamic shape, 
single dtype + # Base case 1 - single static/dynamic shape, single dtype + if isinstance(shapes, tuple) and all( + isinstance(elt, (int, tuple)) for elt in shapes + ): + input_shape_string = "Tensor: (" + for elt in shapes: + if isinstance(elt, tuple): + input_shape_string += f"(min={elt[0]}, max={elt[1]}), " + else: + input_shape_string += f"{elt}, " + input_shape_string = input_shape_string[:-2] + ")" + f"@{str(dtypes)[6:]}, " + return input_shape_string + + # Base case 2 - dynamic shape, single dtype elif ( isinstance(shapes, dict) and len(shapes) == 3 diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 0e5e09de8a..a4849f257e 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -35,8 +35,7 @@ ) from torch_tensorrt.dynamo.utils import ( get_flat_args_with_check, - get_torch_inputs, - parse_complex_tensor_structs, + parse_graph_io, prepare_inputs, set_log_level, to_torch_device, @@ -220,6 +219,7 @@ def compile( ) gm = exported_program.module() logger.debug("Input graph: " + str(gm.graph)) + # Apply lowering on the graph module gm = post_lowering(gm) logger.debug("Lowered Input graph: " + str(gm.graph)) @@ -299,14 +299,6 @@ def compile_module( dryrun_tracker.total_ops_in_graph = total_ops dryrun_tracker.supported_ops_in_graph = num_supported_ops - dryrun_tracker.graph_input_shapes = parse_complex_tensor_structs( - sample_arg_inputs, - "shape", - lambda x: dict(x) if isinstance(x, dict) else tuple(x), - ) - dryrun_tracker.graph_input_dtypes = parse_complex_tensor_structs( - sample_arg_inputs, "dtype", lambda t: t.to(torch.dtype, use_default=True) - ) dryrun_tracker.compilation_settings = settings if settings.dryrun and settings.min_block_size > 1: @@ -393,6 +385,11 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: # Criteria for a module to be convertible to TRT if settings.use_fast_partitioner and "_run_on_acc" not in name: dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule)) + logger.debug( + "Submodule in PyTorch: %s\n %s", + str(name), + str(submodule.graph), + ) continue subgraph_data = PerSubgraphData() @@ -427,28 +424,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: name, ) - subgraph_data.subgraph_input_shapes = parse_complex_tensor_structs( - submodule_inputs, - "shape", - lambda x: dict(x) if isinstance(x, dict) else tuple(x), - ) - subgraph_data.subgraph_input_dtypes = parse_complex_tensor_structs( - submodule_inputs, "dtype", lambda t: t.to(torch.dtype) - ) - - submodule_outputs = submodule( - *get_torch_inputs(submodule_inputs, to_torch_device(settings.device)) - ) - - subgraph_data.subgraph_output_shapes = parse_complex_tensor_structs( - submodule_outputs, - "shape", - lambda x: dict(x) if isinstance(x, dict) else tuple(x), - ) - subgraph_data.subgraph_output_dtypes = parse_complex_tensor_structs( - submodule_outputs, "dtype" - ) - + # Parse the subgraph I/O and store it + parse_graph_io(submodule, subgraph_data) dryrun_tracker.tensorrt_graph_count += 1 dryrun_tracker.per_subgraph_data.append(subgraph_data) @@ -463,23 +440,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: trt_modules[name] = trt_module - torch_sample_arg_inputs = get_torch_inputs( - sample_arg_inputs, to_torch_device(settings.device) - ) - torch_sample_kwarg_inputs = get_torch_inputs( - sample_kwarg_inputs, to_torch_device(settings.device) - ) - sample_outputs = gm(*torch_sample_arg_inputs, **torch_sample_kwarg_inputs) - - if not isinstance(sample_outputs, (list, 
tuple)): - sample_outputs = [sample_outputs] - - dryrun_tracker.graph_output_shapes = parse_complex_tensor_structs( - sample_outputs, "shape", lambda x: dict(x) if isinstance(x, dict) else tuple(x) - ) - dryrun_tracker.graph_output_dtypes = parse_complex_tensor_structs( - sample_outputs, "dtype" - ) + # Parse the graph I/O and store it in dryrun tracker + parse_graph_io(gm, dryrun_tracker) # Replace all FX Modules with TRT Modules for name, trt_module in trt_modules.items(): diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 4cedcb80cb..e0643cf996 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -16,7 +16,7 @@ TRTInterpreterResult, ) from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule -from torch_tensorrt.dynamo.utils import get_torch_inputs +from torch_tensorrt.dynamo.utils import get_model_device, get_torch_inputs logger = logging.getLogger(__name__) @@ -29,15 +29,21 @@ def infer_module_output_dtypes( truncate_double: bool = False, ) -> List[dtype]: """ - inputs can be either arg_inputs or flattened input list. If it is flattened list, kwarg_inputs - should be None, as it is already included in the flattened input. + This function performs model inference to determine the output dtypes + and truncates them accordingly. inputs can be either arg_inputs or flattened input list. + If it is flattened list, kwarg_inputs should be None, as it is already included in the flattened input. """ + # TODO: We can also determine output dtypes from the module.graph based on node metadata. + # However, our converter tests use fx.symbolic_trace which sometimes does not provide metadata, + # so we stick to the model inference approach currently. with unset_fake_temporarily(): + # Get the device on which the model exists + # For large models, this can be done on CPU to save GPU memory allocation for TRT. 
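# (Editor's note: the TODO above mentions that output dtypes could instead be read
#  from graph metadata rather than by running the module. A hedged sketch of that
#  alternative, reusing the node.meta["val"] convention this diff already uses in
#  get_graph_io_attrs(); the helper name is hypothetical and not part of this patch.)
from typing import List

import torch


def output_dtypes_from_metadata(module: torch.fx.GraphModule) -> List[torch.dtype]:
    """Collect output dtypes from FakeTensor metadata instead of running the module."""
    dtypes = []
    for node in module.graph.nodes:
        if node.op != "output":
            continue
        for out in node.all_input_nodes:
            fake_val = out.meta.get("val", None)  # FakeTensor recorded during export
            if fake_val is None:
                raise ValueError(f"Node {out.name} has no 'val' metadata")
            dtypes.append(fake_val.dtype)
    return dtypes


# This only works when every output node carries "val" metadata, which is why the
# patch keeps the inference-based path as the default for fx.symbolic_trace'd tests.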
+ device = get_model_device(module) torch_inputs = get_torch_inputs(inputs, device) if kwarg_inputs is None: kwarg_inputs = {} torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) - module = module.to(device.to(torch.device)) module_outputs = module(*torch_inputs, **torch_kwarg_inputs) if not isinstance(module_outputs, (list, tuple)): module_outputs = [module_outputs] @@ -106,6 +112,7 @@ def interpret_module_to_result( output_dtypes=output_dtypes, compilation_settings=settings, ) + interpreter_result = interpreter.run() return interpreter_result diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 9878119d60..b4daaaff25 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -319,6 +319,7 @@ def aten_ops_embedding_bag( ) +@dynamo_tensorrt_converter(operator.mod, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.fmod.Scalar, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.fmod.Tensor, supports_dynamic_shapes=True) def aten_ops_fmod( @@ -2023,6 +2024,7 @@ def aten_ops_div( ) +@dynamo_tensorrt_converter(operator.pow, supports_dynamic_shapes=True) @dynamo_tensorrt_converter( torch.ops.aten.pow.Tensor_Tensor, supports_dynamic_shapes=True ) @@ -3106,16 +3108,20 @@ def upsample_compute_output_size( input_size: torch.Size, output_size: Optional[Sequence[int]], scale_factors: Optional[Sequence[float]], -) -> Sequence[int]: +) -> Optional[Sequence[int]]: spatial_dimensions = len(input_size) - 2 + if output_size is None and scale_factors is None: + raise AssertionError( + "Must specify exactly one of output_size and scale_factors" + ) + if output_size is not None: torch._check( scale_factors is None, lambda: "Must specify exactly one of output_size and scale_factors", ) torch._check(len(output_size) == spatial_dimensions) - return output_size if scale_factors is not None: torch._check( @@ -3126,11 +3132,8 @@ def upsample_compute_output_size( output_size = [] for i, s in enumerate(scale_factors): output_size.append(int(input_size[i + 2] * s)) - return output_size - torch._check( - False, lambda: "Must specify exactly one of output_size and scale_factors" - ) + return output_size @torch.ops.aten.upsample_nearest1d.vec.py_impl( diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index af0d6b720a..70135f86d3 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload import numpy as np +import tensorrt as trt import torch import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.node import Argument, Target @@ -16,8 +17,6 @@ DynamoConverterImplSignature, ) -import tensorrt as trt - from ..types import Shape, TRTDataType, TRTLayer, TRTTensor _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -157,10 +156,9 @@ def cast_trt_tensor( target_str = ConverterRegistry.qualified_name_or_str(target) target_name = f"{source_ir}_ops{('.' 
+ target_str) if target_str else ''}" - identity_layer = ctx.net.add_identity(input_val) - identity_layer.set_output_type(0, trt_dtype) - identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - [{target_name}]-[{name}]" - return identity_layer.get_output(0) + cast_layer = ctx.net.add_cast(input_val, trt_dtype) + cast_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - [{target_name}]-[{name}]" + return cast_layer.get_output(0) else: return input_val diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cast.py b/py/torch_tensorrt/dynamo/conversion/impl/cast.py index b6d024eb08..0b69f98fc9 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/cast.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/cast.py @@ -2,17 +2,19 @@ from typing import Optional, Union import numpy as np +import tensorrt as trt import torch from torch.fx.node import Target from torch_tensorrt import _enums from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry -from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor +from torch_tensorrt.dynamo.conversion.converter_utils import ( + cast_trt_tensor, + get_trt_tensor, +) from torch_tensorrt.fx.types import TRTDataType, TRTTensor -import tensorrt as trt - LOGGER: logging.Logger = logging.getLogger(__name__) @@ -21,14 +23,12 @@ def to_copy( target: Target, source_ir: Optional[SourceIR], name: str, - input: TRTTensor, + input: Union[TRTTensor, torch.Tensor, np.ndarray], dtype: Union[TRTDataType, torch.dtype, np.dtype, _enums.dtype], force_layer: bool = False, ) -> TRTTensor: if not isinstance(input, TRTTensor): - raise RuntimeError( - f"to_copy received input {input} that is not a TensorRT ITensor" - ) + input = get_trt_tensor(ctx, input, f"{name}_copy_tensor") # If cast is forced, insert identity layer regardless of whether the dtype # doesn't change @@ -38,10 +38,9 @@ def to_copy( target_str = ConverterRegistry.qualified_name_or_str(target) target_name = f"{source_ir}_ops{('.' 
+ target_str) if target_str else ''}" - identity_layer = ctx.net.add_identity(input) - identity_layer.set_output_type(0, trt_dtype) - identity_layer.name = f"Forced Cast ITensor {input.name} from {input.dtype} to {trt_dtype} - [{target_name}]-[{name}]" - return identity_layer.get_output(0) + cast_layer = ctx.net.add_cast(input, trt_dtype) + cast_layer.name = f"Forced Cast ITensor {input.name} from {input.dtype} to {trt_dtype} - [{target_name}]-[{name}]" + return cast_layer.get_output(0) else: casted_tensor = cast_trt_tensor(ctx, input, dtype, name, target, source_ir) return casted_tensor diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py index c7502cf97e..6a6b4ea3a1 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -546,12 +546,13 @@ def pow( lhs_val: Union[TRTTensor, int, float], rhs_val: Union[TRTTensor, int, float], ) -> TRTTensor: - if isinstance(lhs_val, TRTTensor) and isinstance(rhs_val, TRTTensor): - lhs_val, rhs_val = cast_int_int_div_trt_tensor(ctx, lhs_val, rhs_val, name) - - return convert_binary_elementwise( + # POW operation supports only float32 and int8 inputs + lhs_val = get_trt_tensor(ctx, lhs_val, name + "_lhs_val", trt.float32) + rhs_val = get_trt_tensor(ctx, rhs_val, name + "_rhs_val", trt.float32) + out = convert_binary_elementwise( ctx, target, source_ir, name, trt.ElementWiseOperation.POW, lhs_val, rhs_val ) + return out def floor_divide( diff --git a/py/torch_tensorrt/dynamo/conversion/impl/full.py b/py/torch_tensorrt/dynamo/conversion/impl/full.py index 34a2af564f..fc079f7f32 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/full.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/full.py @@ -66,10 +66,9 @@ def full( output = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", output, 0) # https://stackoverflow.com/questions/37888620/comparing-boolean-and-int-using-isinstance if type(fill_value) in (int, float): - if isinstance(fill_value, float): - output = cast_trt_tensor( - ctx, output, trt.float32, name + "_casted", target, source_ir - ) + output = cast_trt_tensor( + ctx, output, output_dtype, name + "_casted", target, source_ir + ) output = impl.elementwise.add( ctx, target, source_ir, name + "_add", output, fill_value ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py index 77dd7ae6f5..12e8abf00d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py @@ -24,6 +24,7 @@ def matrix_multiply( input_matrix_op: trt.MatrixOperation = trt.MatrixOperation.NONE, other_matrix_op: trt.MatrixOperation = trt.MatrixOperation.NONE, ) -> TRTTensor: + if not isinstance(input, trt.ITensor): input = get_trt_tensor(ctx, input, f"{name}_input") if not isinstance(other, trt.ITensor): diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py index 1d498a9930..587bfd0373 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -156,6 +156,13 @@ def layer_norm( axes = get_axes_for_reduce_op(dims) weight = get_trt_tensor(ctx, weight, f"{name}_weight") bias = get_trt_tensor(ctx, bias, f"{name}_bias") + # Cast weight and bias to have same dtype as input + weight = cast_trt_tensor( + ctx, weight, 
input.dtype, f"{name}_weight_cast", target, source_ir + ) + bias = cast_trt_tensor( + ctx, bias, input.dtype, f"{name}_bias_cast", target, source_ir + ) if tuple(input.shape) != tuple(weight.shape): weight = impl.slice.expand( ctx, target, source_ir, f"{name}_expand_weight", weight, input.shape diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 6653e9e1a5..48193dbe11 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -96,10 +96,9 @@ def index( tensor_indices.append(ind) if not tensor_indices: - identity_layer = ctx.net.add_identity(input) - identity_layer.set_output_type(0, trt.int32) - set_layer_name(identity_layer, target, name + "_index_identity", source_ir) - return identity_layer.get_output(0) + cast_layer = ctx.net.add_cast(input, trt.int32) + set_layer_name(cast_layer, target, name + "_index_casted", source_ir) + return cast_layer.get_output(0) elif len(tensor_indices) == 1: indices_tensor = get_trt_tensor( ctx, tensor_indices[0], name + "_parameter_to_fp32_tensor" diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 378d407416..7fe0032d80 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -216,7 +216,7 @@ def slice_scatter_decomposition( index_tensor_shape.append(src_each_dim) for index in range(start, end, step): cat_tensors.append(index * torch.ones(index_tensor_shape, dtype=torch.int64)) - index_tensor = torch.stack(cat_tensors, dim).cuda() + index_tensor = torch.stack(cat_tensors, dim).to(input_tensor.device) index_tensor_64 = index_tensor.to(torch.int64) output_tensor = torch.scatter(input_tensor, dim, index_tensor_64, src_tensor) return output_tensor @@ -271,7 +271,7 @@ def scatter_add_decomposition( index_slice = torch.unsqueeze(index_slice, dim) # moving tensor to default device - device = to_torch_device(default_device()) + device = input_tensor.device scatter_add_tensor = scatter_add_tensor.to(device) to_scatter_tensor = to_scatter_tensor.to(device) index_slice = index_slice.to(device) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index dc76ca8036..b7c65f1880 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -8,6 +8,7 @@ from .lower_linear import lower_linear from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention from .pass_manager import DynamoPassManager +from .remove_assert_scalar import remove_assert_scalar from .remove_detach import remove_detach from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones from .repair_input_as_output import repair_input_as_output @@ -26,6 +27,7 @@ replace_max_pool_with_indices, replace_full_like_with_full, view_to_reshape, + remove_assert_scalar, ] ) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index 0b2500f6f2..956cbd5a4d 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -30,9 +30,11 @@ def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: cf = _TorchTensorRTConstantFolder(gm, skip_constructors=False) 
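# (Editor's note: a self-contained toy sketch of the device change made just below --
#  a folded constant can be re-registered on whatever device it was computed on
#  (CPU here) instead of being forced onto the GPU with .cuda(), which keeps GPU
#  memory free for the TensorRT build. ToyModule and folded_weight are hypothetical.)
import torch


class ToyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(4), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # "self.weight * 2" depends only on a frozen parameter, so it is foldable.
        return x + self.weight * 2


toy = ToyModule()
folded = torch.nn.Parameter(toy.weight.detach() * 2, requires_grad=False)
toy.register_parameter("folded_weight", folded)
print(folded.device)  # cpu -- the folded constant is not eagerly moved to the GPU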
cf.run() + # The constants are created on CPU to save GPU memory for TensorRT compilation. + # For TRT INetwork construction the constants are moved to CPU in get_attr call. for node, constant in cf.node_replacements.items(): replace_node_with_constant( - gm, node, torch.nn.Parameter(constant.cuda(), requires_grad=False) + gm, node, torch.nn.Parameter(constant, requires_grad=False) ) erased_params = [] diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py index 19a21cf01d..bd9c873590 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py @@ -24,7 +24,6 @@ def lower_scaled_dot_product_attention( """ original_fns, replacement = scaled_dot_product_attention_replacement() replaced_nodes = [] - # For each original function, search for it in the graph and replace for original in original_fns: replaced_nodes += torch.fx.subgraph_rewriter.replace_pattern_with_filters( diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py new file mode 100644 index 0000000000..ee468145f6 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py @@ -0,0 +1,24 @@ +import logging + +import torch +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) + + +def remove_assert_scalar(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Remove assert_scalar ops in the graph""" + count = 0 + for node in gm.graph.nodes: + if node.target == torch.ops.aten._assert_scalar.default: + gm.graph.erase_node(node) + count += 1 + + if count > 0: + gm = clean_up_graph_after_modifications(gm) + + logger.debug(f"Removed {count} assert_scalar nodes:\n{gm.graph}") + + return gm diff --git a/py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py b/py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py index a7e163c43b..81197b4ab0 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py @@ -43,7 +43,8 @@ def replace_max_pool_with_indices( kwargs=node.kwargs, ) maxpool_fused.meta = node.meta - + # The metadata for this node should exclude the indices metadata + maxpool_fused.meta["val"] = maxpool_fused.meta["val"][0] logger.debug( f"Replacing all uses of nodes {node}, {getitem_node} with fused maxpool node {maxpool_fused} " f"is the only user of placeholder {node} and was inserted by the compiler." diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py index d87cda60fb..bd4fe73406 100644 --- a/py/torch_tensorrt/dynamo/partitioning/common.py +++ b/py/torch_tensorrt/dynamo/partitioning/common.py @@ -6,19 +6,16 @@ from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._defaults import DEBUG +from torch_tensorrt.dynamo.utils import contains_sym_int, extract_var_range_info logger = logging.getLogger(__name__) -def contains_sym_int(tensor: torch.Tensor) -> bool: - """ - Returns true if the given tensor has symbolic shape. 
- """ - return any(isinstance(dim, torch.SymInt) for dim in tensor) - - def construct_dynamic_input( - input_shape: torch.Size, input_dtype: torch.dtype, is_shape_tensor: bool = False + input_shape: torch.Size, + input_dtype: torch.dtype, + name: str = "", + is_shape_tensor: bool = False, ) -> Input: """ Constructs a torch_tensorrt.Input based on a symbolic input @@ -32,27 +29,10 @@ def construct_dynamic_input( max_shape = [] for dim in input_shape: if isinstance(dim, torch.SymInt): - node = dim.node - expr = node.expr - shape_env = node.shape_env - # An expr can be a independent SymInt node (eg: s0 or s1) or a composition of them eg: (48*s0 or s0*s1). - # In the case of expr which has symbolic computation, bound_sympy evaluates them. - # https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.ShapeEnv.html#torch.fx.experimental.symbolic_shapes.ShapeEnv.bound_sympy - # expr.xreplace replaces the symbolic variables with their current values and computes the expression. - var_range = shape_env.var_to_range.get(expr, None) or shape_env.bound_sympy( - expr - ) - var_val = shape_env.var_to_val.get(expr, None) or expr.xreplace( - shape_env.var_to_val - ) - assert var_range, var_val - # Torchdynamo 0/1 specialization outlier - if var_range.lower == 2: - min_shape.append(1) - else: - min_shape.append(int(var_range.lower)) - opt_shape.append(int(var_val)) - max_shape.append(int(var_range.upper)) + min_max_opt = extract_var_range_info(dim) + min_shape.append(min_max_opt["min"]) + opt_shape.append(min_max_opt["opt"]) + max_shape.append(min_max_opt["max"]) else: min_shape.append(dim) opt_shape.append(dim) @@ -63,22 +43,28 @@ def construct_dynamic_input( opt_shape=opt_shape, max_shape=max_shape, dtype=input_dtype, + name=name, is_shape_tensor=is_shape_tensor, ) def get_input( - input_shape: torch.Size, dtype: torch.dtype, is_shape_tensor: bool = False + input_shape: torch.Size, + dtype: torch.dtype, + name: str = "", + is_shape_tensor: bool = False, ) -> Input: """ Based on type of dimensions in the input_shape, construct regular or dynamic shaped inputs """ if contains_sym_int(input_shape): return construct_dynamic_input( - input_shape, dtype, is_shape_tensor=is_shape_tensor + input_shape, dtype, name=name, is_shape_tensor=is_shape_tensor ) else: - return Input(shape=input_shape, dtype=dtype, is_shape_tensor=is_shape_tensor) + return Input( + shape=input_shape, dtype=dtype, name=name, is_shape_tensor=is_shape_tensor + ) def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]: @@ -101,11 +87,18 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]: input_meta = input.meta["val"] if isinstance(input_meta, (FakeTensor, torch.Tensor)): input_shape = input_meta.size() - torchtrt_inputs.append(get_input(input_shape, input_meta.dtype)) + torchtrt_inputs.append( + get_input(input_shape, input_meta.dtype, name=input.name) + ) elif isinstance(input_meta, torch.SymInt): # Assuming sym_integers | shape inputs always have torch.int64 dtype torchtrt_inputs.append( - get_input([input_meta], torch.int64, is_shape_tensor=True) + get_input( + [input_meta], + torch.int64, + name=input.name, + is_shape_tensor=True, + ) ) else: raise ValueError( @@ -115,7 +108,9 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]: elif "tensor_meta" in input.meta: input_meta = input.meta["tensor_meta"] input_shape = input_meta.shape - torchtrt_inputs.append(get_input(input_shape, input_meta.dtype)) + torchtrt_inputs.append( + 
get_input(input_shape, input_meta.dtype, name=input.name) + ) else: raise AssertionError( f"Input {input.name} does not contain val and tensor_meta fields in the metadata. Please ensure you have exported the graph correctly" diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 6d74ab61bf..dfd22e7f9f 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -8,6 +8,7 @@ import numpy as np import tensorrt as trt import torch +from torch._subclasses.fake_tensor import FakeTensor from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype from torch_tensorrt._Input import Input @@ -129,46 +130,70 @@ def input_is_dynamic(inputs: Sequence[Union[Input, torch.Tensor]]) -> bool: ) +def get_torch_tensor( + input: Input, + device: torch.device, + mode: str = "", +) -> Union[int, torch.Tensor]: + if input.is_shape_tensor: + # TODO: All the shape tensors we've encountered so far are plain integers. + # Validate this assumption on more models. + return input.shape["opt_shape"][0] + + if len(mode) > 0: + return input.example_tensor(mode).to(device) + else: + return input.torch_tensor.to(device) + + def get_torch_inputs( inputs: Sequence[Input] | Dict[Any, Any], device: Union[Device, torch.device, str], mode: str = "", -) -> Sequence[torch.tensor] | Dict[Any, Any]: +) -> Sequence[torch.Tensor] | Dict[str, torch.Tensor]: """ Return the torch_tensor from the Input object. If mode is set, this implies user is using dynamic shaped inputs and return the corresponding input based on the mode requested. """ device = to_torch_device(device) - if mode: - if isinstance(inputs, dict): - result = {} - for k, v in inputs.items(): - if isinstance(v, (list, tuple, dict)): - result[k] = get_torch_inputs(v, device) - else: - result[k] = v.example_tensor(mode).to(device) - return result - else: - return [ - input.example_tensor(mode).to(device) - for input in inputs - if isinstance(input, Input) - ] if isinstance(inputs, dict): result = {} for k, v in inputs.items(): if isinstance(v, (list, tuple, dict)): result[k] = get_torch_inputs(v, device) - else: - result[k] = v.torch_tensor.to(device) - return result + elif isinstance(v, Input): + result[k] = get_torch_tensor(v, device, mode) else: - return [ - input.torch_tensor.to(device) if isinstance(input, Input) else input - for input in inputs - ] + result = [] + for input in inputs: + if isinstance(input, Input): + result.append(get_torch_tensor(input, device, mode)) + elif isinstance(input, torch.Tensor): + result.append(input.to(device)) + else: + raise AssertionError(f"Input type {type(input)} is not a valid type") + + return result + + +def get_model_device(module: torch.fx.GraphModule) -> Union[Device, torch.device, str]: + """ + Returns the device on which the module parameters exist. + """ + device = None + for parameter in list(module.parameters()): + if isinstance(parameter, (torch.nn.parameter.Parameter, torch.Tensor)): + device = parameter.device + break + + if device is None: + device = torch.device("cpu") + logger.warning( + "Could not detect the device on which the model exists. Assuming the model is on CPU" + ) + return device def set_log_level(parent_logger: Any, level: Any) -> None: @@ -273,6 +298,118 @@ def parse_complex_tensor_structs( ) +def contains_sym_int(tensor: torch.Tensor) -> bool: + """ + Returns true if the given tensor has symbolic shape. 
+ """ + return any(isinstance(dim, torch.SymInt) for dim in tensor) + + +def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, Any]: + """ + This function returns the min, max, opt values of a symbolic integer. + """ + node = symbolic_integer.node + expr = node.expr + shape_env = node.shape_env + # An expr can be a independent SymInt node (eg: s0 or s1) or a composition of them eg: (48*s0 or s0*s1). + # In the case of expr which has symbolic computation, bound_sympy evaluates them. + # https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.ShapeEnv.html#torch.fx.experimental.symbolic_shapes.ShapeEnv.bound_sympy + # expr.xreplace replaces the symbolic variables with their current values and computes the expression. + var_range = shape_env.var_to_range.get(expr, None) or shape_env.bound_sympy(expr) + var_val = shape_env.var_to_val.get(expr, None) or expr.xreplace( + shape_env.var_to_val + ) + assert var_range, var_val + min_val, max_val, opt_val = int(var_range.lower), int(var_range.upper), int(var_val) + # Torchdynamo 0/1 specialization outlier + min_val = 1 if min_val == 2 else min_val + min_max_opt = {} + min_max_opt["min"] = min_val + min_max_opt["max"] = max_val + min_max_opt["opt"] = opt_val + + return min_max_opt + + +def unwrap_tensor_shape( + tensor: Union[torch.Tensor, FakeTensor, torch.SymInt] +) -> Sequence[Any]: + """ + This is a helper function used to print/return the shape of the tensor. + For regular torch.tensor's, it returns the static shape. + For symbolic tensors, eg:(1, s0, 4), this function returns [1, [min, max], 4]. The min + and max correspond to the lower and upper values of s0 symbolic dimension. + """ + tensor_shape = [] + # for dimension in tensor.shape: + if isinstance(tensor, int): + tensor_shape.append(tensor) + elif isinstance(tensor, torch.SymInt): + min_max_opt = extract_var_range_info(tensor) + tensor_shape.append((min_max_opt["min"], min_max_opt["max"])) + elif isinstance(tensor, (torch.Tensor, FakeTensor)): + for dimension in tensor.shape: + tensor_shape.extend(unwrap_tensor_shape(dimension)) + + return tuple(tensor_shape) + + +def unwrap_tensor_dtype(tensor: Union[torch.Tensor, FakeTensor, torch.SymInt]) -> Any: + """ + Returns the dtype of torch.tensor or FakeTensor. 
For symbolic integers, we return int64 + """ + if isinstance(tensor, (torch.Tensor, FakeTensor)): + return tensor.dtype + elif isinstance(tensor, torch.SymInt): + return torch.int64 + else: + raise ValueError(f"Found invalid tensor type {type(tensor)}") + + +def get_graph_io_attrs( + io_nodes: Sequence[torch.fx.Node], attr_type: str +) -> Sequence[Any]: + """ + Returns a list of attributes (shapes or dtypes) of the I/O nodes + """ + assert attr_type in ["shape", "dtype"] + attr_fn = unwrap_tensor_shape if attr_type == "shape" else unwrap_tensor_dtype + graph_io_attrs = [] + for node in io_nodes: + if "val" in node.meta: + metadata = node.meta["val"] + if isinstance(metadata, (tuple, list)): + for tensor in metadata: + graph_io_attrs.append(attr_fn(tensor)) + else: + graph_io_attrs.append(attr_fn(metadata)) + + return graph_io_attrs + + +def parse_graph_io(module: torch.fx.GraphModule, dryrun_tracker: Any) -> None: + """ + Parse the graph I/O shape/dtype info for the whole graph and store in the dryrun tracker + """ + # Parse inputs of the graph + input_nodes = [node for node in module.graph.nodes if node.op == "placeholder"] + input_shapes = get_graph_io_attrs(input_nodes, "shape") + input_dtypes = get_graph_io_attrs(input_nodes, "dtype") + dryrun_tracker.input_shapes = input_shapes + dryrun_tracker.input_dtypes = input_dtypes + + # Parse outputs of the graph + mark_output_nodes = [node for node in module.graph.nodes if node.op == "output"] + output_nodes = [] + for node in mark_output_nodes: + output_nodes.extend(node.all_input_nodes) + output_shapes = get_graph_io_attrs(output_nodes, "shape") + output_dtypes = get_graph_io_attrs(output_nodes, "dtype") + dryrun_tracker.output_shapes = output_shapes + dryrun_tracker.output_dtypes = output_dtypes + + def to_torch_device(device: Optional[Union[Device, torch.device, str]]) -> torch.device: """Cast a device-type to torch.device diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index df1e4ee934..f53bdf5d59 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -244,7 +244,7 @@ def generate_graph( try: ShapeProp(fx_module).propagate(*torch_inputs) except (RuntimeError, AssertionError): - logger.warning( + _LOGGER.warning( "Shape Propagation failed on Graph, skipping it", exc_info=False, ) diff --git a/tests/py/dynamo/conversion/test_full_aten.py b/tests/py/dynamo/conversion/test_full_aten.py index d09b5cc56c..29b48d1451 100644 --- a/tests/py/dynamo/conversion/test_full_aten.py +++ b/tests/py/dynamo/conversion/test_full_aten.py @@ -50,9 +50,7 @@ def forward(self, shape): ) ] self.run_test_with_dynamic_shape( - full(), - inputs, - use_example_tensors=False, + full(), inputs, use_example_tensors=False, check_dtype=False ) @parameterized.expand( diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py index 50fa9a2f50..67eaddcc6c 100644 --- a/tests/py/dynamo/models/test_dyn_models.py +++ b/tests/py/dynamo/models/test_dyn_models.py @@ -11,9 +11,6 @@ assertions = unittest.TestCase() -@unittest.skip( - "Skipping this test for now due to constraint violation error: https://github.com/pytorch/TensorRT/issues/2794" -) @pytest.mark.unit def test_base_dynamic(ir): """ @@ -66,9 +63,6 @@ def forward(self, x): ) -@unittest.skip( - "Skipping this test for now due to constraint violation error: https://github.com/pytorch/TensorRT/issues/2794" -) @pytest.mark.unit def test_base_dynamic_fallback(ir): """ diff --git a/tools/perf/perf_run.py 
b/tools/perf/perf_run.py index c52fb6ba56..5d41df13d6 100644 --- a/tools/perf/perf_run.py +++ b/tools/perf/perf_run.py @@ -18,10 +18,14 @@ import torch_tensorrt as torchtrt from utils import ( BENCHMARK_MODELS, + export_llm, parse_backends, parse_inputs, parse_precisions, precision_to_dtype, + time_generate, + torch_device_from_trt, + torch_dtype_from_trt, ) WARMUP_ITER = 10 @@ -41,30 +45,120 @@ def wrapper_func(*args, **kwargs): return wrapper_func -# Runs inference using Torch backend -@run_with_try_except -def run_torch(model, input_tensors, params, precision, batch_size): - print("Running Torch for precision: ", precision, " batch_size : ", batch_size) - iters = params.get("iterations", 20) +def recordStats(backend, timings, precision, batch_size=1, compile_time_s=None): + """ + Records different timing stats and adds it to the result + """ + times = np.array(timings) + speeds = batch_size / times + time_mean = np.mean(times) + time_med = np.median(times) + time_99th = np.percentile(times, 99) + time_std = np.std(times, ddof=0) + speed_mean = np.mean(speeds) + speed_med = np.median(speeds) + stats = { + "Backend": backend, + "Precision": precision, + "Batch size": batch_size, + "Median(FPS)": speed_med, + "Mean(FPS)": speed_mean, + "Median-Latency(ms)": time_med * 1000, + "Mean-Latency(ms)": time_mean * 1000, + "Latency-StdDev(ms)": time_std * 1000, + "Compile Time(s)": compile_time_s, + } + results.append(stats) + + +def record_llm_perf( + model, + backend, + input_tensors, + precision, + output_seq_length, + batch_size, + iterations, + compile_time_s=None, +): + """ + Measure LLM generation time and record the stats + """ + # We only support single input (B x seq_len) for LLMs now + input_seq = input_tensors[0] + with torch.no_grad(): + # Warm up for 3 iterations + _ = time_generate(model, input_seq, output_seq_length, iterations=iterations) + + torch.cuda.synchronize() + + # Actual perf measurement + timings = time_generate( + model, input_seq, output_seq_length, iterations=iterations + ) + + recordStats( + "Torch-TensorRT " + backend, timings, precision, batch_size, compile_time_s + ) + + +def record_perf( + model, + backend, + input_tensors, + precision, + iterations, + batch_size, + compile_time_s=None, +): + """ + Run the model for certain number of iterations and record the perf. 
+ Model is warmed up initially + """ # Warm up with torch.no_grad(): for _ in range(WARMUP_ITER): - features = model(*input_tensors) + model(*input_tensors) torch.cuda.synchronize() timings = [] with torch.no_grad(): - for i in range(iters): + for i in range(iterations): start_time = timeit.default_timer() - features = model(*input_tensors) + _ = model(*input_tensors) torch.cuda.synchronize() end_time = timeit.default_timer() - meas_time = end_time - start_time - timings.append(meas_time) + timings.append(end_time - start_time) + + recordStats( + "Torch-TensorRT " + backend, timings, precision, batch_size, compile_time_s + ) + + +# Runs inference using Torch backend +@run_with_try_except +def run_torch(model, input_tensors, params, precision, batch_size): + print("Running Torch for precision: ", precision, " batch_size : ", batch_size) + iters = params.get("iterations", 20) + model = model.to("cuda:0") + if params["is_text_llm"]: + output_seq_length = params["output_sequence_length"] + return record_llm_perf( + model, + "Torch", + input_tensors, + precision, + output_seq_length, + batch_size, + iters, + None, + ) - recordStats("Torch", timings, precision, batch_size) + record_perf( + model, "Torch", input_tensors, precision, iters, batch_size, compile_time_s=None + ) # Runs inference using Torch-TensorRT backend @@ -86,31 +180,55 @@ def run_ts_trt(model, input_tensors, params, precision, batch_size): if precision == "int8": compile_settings.update({"calib": params.get("calibration_cache")}) - start_compile = time.time_ns() + start_compile = timeit.default_timer() model = torchtrt.compile(model, ir="ts", **compile_settings) - end_compile = time.time_ns() - compile_time_s = (end_compile - start_compile) / 1e9 + end_compile = timeit.default_timer() + compile_time_s = end_compile - start_compile iters = params.get("iterations", 20) - # Warm up - with torch.no_grad(): - for _ in range(WARMUP_ITER): - features = model(*input_tensors) - torch.cuda.synchronize() + record_perf( + model, + "Torchscript", + input_tensors, + precision, + iters, + batch_size, + compile_time_s, + ) - timings = [] - with torch.no_grad(): - for i in range(iters): - start_time = timeit.default_timer() - features = model(*input_tensors) - torch.cuda.synchronize() - end_time = timeit.default_timer() - meas_time = end_time - start_time - timings.append(meas_time) - recordStats( - "Torch-TensorRT [Torchscript]", timings, precision, batch_size, compile_time_s +@run_with_try_except +def run_hf_dynamo(model, input_tensors, params, precision, batch_size): + """ + Compile the huggingface model using Torch-TensorRT dynamo frontend and record performance stats + """ + + osl = params["output_sequence_length"] + iters = params.get("iterations", 20) + # Move the model and inputs to cpu and trace it. 
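# (Editor's note: time_generate is imported from tools/perf/utils.py, which is not
#  part of this diff. A plausible sketch inferred only from the call sites in
#  record_llm_perf above and from the greedy-decoding generate() helper added in
#  examples/dynamo/utils.py; the exact upstream signature may differ.)
import timeit

import torch


def time_generate_sketch(model, input_seq, output_seq_length, iterations=20):
    """Greedy-decode up to output_seq_length tokens per run; return wall times in seconds."""
    timings = []
    for _ in range(iterations):
        seq = input_seq.clone()
        start = timeit.default_timer()
        with torch.no_grad():
            while seq.shape[1] < output_seq_length:
                logits = model(seq).logits  # HF causal-LM style output
                next_tokens = torch.argmax(logits[:, -1, :], dim=-1)
                seq = torch.cat([seq, next_tokens[:, None]], dim=-1)
        torch.cuda.synchronize()
        timings.append(timeit.default_timer() - start)
    return timings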
+ model = model.to("cpu") + inputs_cpu = [tensor.clone().cpu() for tensor in input_tensors] + exp_program = export_llm(model, inputs_cpu, min_seq_len=1, max_seq_len=osl) + start_compile = timeit.default_timer() + + trt_model = torchtrt.dynamo.compile( + exp_program, + inputs=input_tensors, + enabled_precisions={precision_to_dtype(precision)}, + truncate_double=params.get("truncate", False), + ) + end_compile = timeit.default_timer() + compile_time_s = end_compile - start_compile + record_llm_perf( + trt_model, + "Dynamo", + input_tensors, + precision, + osl, + batch_size, + iters, + compile_time_s, ) @@ -125,7 +243,10 @@ def run_dynamo(model, input_tensors, params, precision, batch_size): " batch_size : ", batch_size, ) - start_compile = time.time_ns() + if params["is_text_llm"]: + return run_hf_dynamo(model, input_tensors, params, precision, batch_size) + + start_compile = timeit.default_timer() model = torchtrt.compile( model, inputs=input_tensors, @@ -135,28 +256,12 @@ def run_dynamo(model, input_tensors, params, precision, batch_size): debug=False, truncate_long_and_double=params.get("truncate", False), ) - end_compile = time.time_ns() - compile_time_s = (end_compile - start_compile) / 1e9 + end_compile = timeit.default_timer() + compile_time_s = end_compile - start_compile iters = params.get("iterations", 20) - # Warm up - with torch.no_grad(): - for _ in range(WARMUP_ITER): - features = model(*input_tensors) - - torch.cuda.synchronize() - - timings = [] - with torch.no_grad(): - for i in range(iters): - start_time = timeit.default_timer() - features = model(*input_tensors) - torch.cuda.synchronize() - end_time = timeit.default_timer() - meas_time = end_time - start_time - timings.append(meas_time) - recordStats( - "Torch-TensorRT [Dynamo]", timings, precision, batch_size, compile_time_s + record_perf( + model, "Dynamo", input_tensors, precision, iters, batch_size, compile_time_s ) @@ -165,6 +270,8 @@ def run_torch_compile(model, input_tensors, params, precision, batch_size): """ Compile the given model using Torch-TensorRT torch.compile frontend and record performance stats """ + # Move the model to GPU + model = model.to("cuda:0") torch._dynamo.reset() print( @@ -176,41 +283,52 @@ def run_torch_compile(model, input_tensors, params, precision, batch_size): compile_spec = { "inputs": input_tensors, "enabled_precisions": {precision_to_dtype(precision)}, - "truncate_long_and_double": params.get("truncate", False), + "truncate": params.get("truncate", False), "min_block_size": params.get("min_block_size", 1), } - start_compile = time.time_ns() - model = torch.compile( - model, backend="tensorrt", dynamic=False, options=compile_spec - ) + start_compile = timeit.default_timer() + model = torch.compile(model, backend="tensorrt", dynamic=None, options=compile_spec) model(*input_tensors) - end_compile = time.time_ns() - compile_time_s = (end_compile - start_compile) / 1e9 + end_compile = timeit.default_timer() + compile_time_s = end_compile - start_compile iters = params.get("iterations", 20) - # Warm up - with torch.no_grad(): - for _ in range(WARMUP_ITER): - features = model(*input_tensors) - torch.cuda.synchronize() + record_perf( + model, + "torch_compile", + input_tensors, + precision, + iters, + batch_size, + compile_time_s, + ) - timings = [] - with torch.no_grad(): - for i in range(iters): - start_time = timeit.default_timer() - features = model(*input_tensors) - torch.cuda.synchronize() - end_time = timeit.default_timer() - meas_time = end_time - start_time - 
timings.append(meas_time) - # Reset torch dynamo cache - torch._dynamo.reset() - recordStats( - "Torch-TensorRT [torch_compile]", - timings, +@run_with_try_except +def run_hf_inductor(model, input_tensors, params, precision, batch_size): + """ + Compile the huggingface model using torch inductor and record performance stats + """ + osl = params["output_sequence_length"] + # Mark dynamic shapes for input sequence + input_seq = input_tensors[0] + torch._dynamo.mark_dynamic(input_seq, 1, min=1, max=osl) + start_compile = timeit.default_timer() + # Compile the model + model = torch.compile(model, backend="inductor", dynamic=None, mode="max-autotune") + model(input_seq) + end_compile = timeit.default_timer() + compile_time_s = end_compile - start_compile + iters = params.get("iterations", 20) + + record_llm_perf( + model, + "Inductor", + input_tensors, precision, + osl, batch_size, + iters, compile_time_s, ) @@ -221,72 +339,28 @@ def run_inductor(model, input_tensors, params, precision, batch_size): Compile the given model using torch inductor and record performance stats """ torch._dynamo.reset() - + model = model.to("cuda:0") print( "Running Torch [inductor] for precision: ", precision, " batch_size : ", batch_size, ) + if params["is_text_llm"]: + return run_hf_inductor(model, input_tensors, params, precision, batch_size) - start_compile = time.time_ns() - model = torch.compile(model, backend="inductor", dynamic=False, mode="max-autotune") + start_compile = timeit.default_timer() + model = torch.compile(model, backend="inductor", dynamic=None, mode="max-autotune") model(*input_tensors) - end_compile = time.time_ns() - compile_time_s = (end_compile - start_compile) / 1e9 + end_compile = timeit.default_timer() + compile_time_s = end_compile - start_compile iters = params.get("iterations", 20) - # Warm up - with torch.no_grad(): - for _ in range(WARMUP_ITER): - features = model(*input_tensors) - - torch.cuda.synchronize() - - timings = [] - with torch.no_grad(): - for i in range(iters): - start_time = timeit.default_timer() - features = model(*input_tensors) - torch.cuda.synchronize() - end_time = timeit.default_timer() - meas_time = end_time - start_time - timings.append(meas_time) - # Reset torch dynamo cache - torch._dynamo.reset() - recordStats( - "Torch [inductor]", - timings, - precision, - batch_size, - compile_time_s, + record_perf( + model, "inductor", input_tensors, precision, iters, batch_size, compile_time_s ) -def torch_dtype_from_trt(dtype): - if dtype == trt.int8: - return torch.int8 - elif dtype == trt.bool: - return torch.bool - elif dtype == trt.int32: - return torch.int32 - elif dtype == trt.float16: - return torch.float16 - elif dtype == trt.float32: - return torch.float32 - else: - raise TypeError("%s is not supported by torch" % dtype) - - -def torch_device_from_trt(device): - if device == trt.TensorLocation.DEVICE: - return torch.device("cuda") - elif device == trt.TensorLocation.HOST: - return torch.device("cpu") - else: - return TypeError("%s is not supported by torch" % device) - - @run_with_try_except def run_tensorrt( model, @@ -310,10 +384,10 @@ def run_tensorrt( config = builder.create_builder_config() if precision == "fp16": config.set_flag(trt.BuilderFlag.FP16) - start_compile = time.time_ns() + start_compile = timeit.default_timer() serialized_engine = builder.build_serialized_network(network, config) - end_compile = time.time_ns() - compile_time_s = (end_compile - start_compile) / 1e9 + end_compile = timeit.default_timer() + compile_time_s = end_compile - 
start_compile # Deserialize the TensorRT engine with trt.Runtime(logger) as runtime: engine = runtime.deserialize_cuda_engine(serialized_engine) @@ -443,32 +517,6 @@ def run( run_inductor(model_torch, input_tensors, params, precision, batch_size) -# Generate report -def recordStats(backend, timings, precision, batch_size=1, compile_time_s=None): - times = np.array(timings) - steps = len(times) - speeds = batch_size / times - time_mean = np.mean(times) - time_med = np.median(times) - time_99th = np.percentile(times, 99) - time_std = np.std(times, ddof=0) - speed_mean = np.mean(speeds) - speed_med = np.median(speeds) - - stats = { - "Backend": backend, - "Precision": precision, - "Batch size": batch_size, - "Median(FPS)": speed_med, - "Mean(FPS)": speed_mean, - "Median-Latency(ms)": time_med * 1000, - "Mean-Latency(ms)": time_mean * 1000, - "Latency-StdDev(ms)": time_std * 1000, - "Compile Time(s)": compile_time_s, - } - results.append(stats) - - if __name__ == "__main__": arg_parser = argparse.ArgumentParser( description="Run inference on a model with random input values" @@ -493,9 +541,24 @@ def recordStats(backend, timings, precision, batch_size=1, compile_time_s=None): type=str, help="List of input shapes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT", ) + arg_parser.add_argument( + "--is_text_llm", + action="store_true", + help="Boolean flag to determine if model is a huggingface model", + ) + arg_parser.add_argument( + "-osl", + "--output_sequence_length", + type=int, + help="Length of output sequence to HF model", + default=128, + ) arg_parser.add_argument( "--batch_size", type=int, default=1, help="Batch size to build and run" ) + arg_parser.add_argument( + "--iterations", type=int, default=20, help="Iterations to measure the perf" + ) arg_parser.add_argument( "--precision", default="fp32", @@ -542,16 +605,14 @@ def recordStats(backend, timings, precision, batch_size=1, compile_time_s=None): # Load PyTorch Model, if provided if len(model_name_torch) > 0 and os.path.exists(model_name_torch): print("Loading user provided torch model: ", model_name_torch) - model_torch = torch.load(model_name_torch).eval().cuda() + model_torch = torch.load(model_name_torch).eval() elif model_name_torch in BENCHMARK_MODELS: - model_torch = BENCHMARK_MODELS[model_name_torch]["model"].eval().cuda() + model_torch = BENCHMARK_MODELS[model_name_torch]["model"].eval() # If neither model type was provided if (model is None) and (model_torch is None): raise ValueError( - "No valid models specified. Please provide a torchscript model file or model name " - + "(among the following options vgg16|resnet50|efficientnet_b0|vit) " - + "or provide a torch model file" + "No valid models specified. 
Please provide a torchscript model file or model name (defined in hub.py) or model_hf name in huggingface models " ) backends = parse_backends(params["backends"]) diff --git a/tools/perf/requirements.txt b/tools/perf/requirements.txt index 881241e24d..fcfb0b3d53 100644 --- a/tools/perf/requirements.txt +++ b/tools/perf/requirements.txt @@ -2,8 +2,8 @@ numpy argparse pyyaml onnx -transformers==4.38.0 +pandas +transformers diffusers==0.21.4 -pandas==2.0.1 timm==0.9.8 diff --git a/tools/perf/run_hf_model.sh b/tools/perf/run_hf_model.sh new file mode 100644 index 0000000000..ced2a46dfb --- /dev/null +++ b/tools/perf/run_hf_model.sh @@ -0,0 +1,26 @@ +#!/bin/bash +batch_size=$1 +backend=$2 +model_name=$3 +isl=$4 +osl=$5 +precision=$6 +iterations=$7 +modified_model_name=$(echo "$model_name" | sed 's/\//-/g') +echo "Benchmarking ${model_name} model for bs ${batch_size} with ISL ${isl}, OSL ${osl} and backend ${backend} for ${iterations} iterations" +python perf_run.py --model_torch ${model_name} \ + --is_text_llm \ + --precision ${precision} \ + --inputs "(${batch_size}, ${isl})@int64" \ + --output_sequence_length ${osl} \ + --batch_size ${batch_size} \ + --truncate \ + --backends ${backend} \ + --iterations ${iterations} \ + --report "${modified_model_name}_perf_bs${batch_size}_backend_${backend}_isl${isl}_osl${osl}.csv" + +# Move the report file to the mounted volume in the docker +mv "${modified_model_name}_perf_bs${batch_size}_backend_${backend}_isl${isl}_osl${osl}.csv" /work + +# Clear HF cache +rm -rf ~/.cache/huggingface/hub/ diff --git a/tools/perf/stage1.sh b/tools/perf/stage1.sh new file mode 100644 index 0000000000..412396ee0b --- /dev/null +++ b/tools/perf/stage1.sh @@ -0,0 +1,49 @@ +#!/bin/bash +# global parameters +precision="fp16" +iterations=1 +backends=("dynamo" "inductor") +batch_sizes=(1 16) +hf_token="" +image_name="" + +# Stage 1 : GPT2 experiment +models=("gpt2") +isl=(128) +osl=(256) +for model in ${models[@]} + do + for bs in ${batch_sizes[@]} + do + for backend in ${backends[@]} + do + for i in ${!isl[@]}; + do + docker run --rm -it --gpus 0 --shm-size=10.24g --ulimit stack=67108864 -v "$PWD:/work" --ipc=host ${image_name} /bin/bash -c "cd /opt/torch_tensorrt/tools/perf; HF_TOKEN="${hf_token}" bash run_hf_model.sh "${bs}" "$backend" "$model" "${isl[i]}" "${osl[i]}" "${precision}" "${iterations}"; exit" + done + done + done + done +# Clear HF cache +rm -rf ~/.cache/huggingface/hub/ + +# Stage 2 : non-GPT2 experiments +isl=(128 128) +osl=(256 2176) +models=("meta-llama/Meta-Llama-3.1-8B-Instruct" "meta-llama/Llama-2-7b-chat-hf" "mistralai/Mistral-7B-Instruct-v0.3") +backends=("dynamo" "inductor") +for model in ${models[@]} + do + for bs in ${batch_sizes[@]} + do + for backend in ${backends[@]} + do + for i in ${!isl[@]}; + do + docker run --rm -it --gpus 0 --shm-size=10.24g --ulimit stack=67108864 -v "$PWD:/work" --ipc=host ${image_name} /bin/bash -c "cd /opt/torch_tensorrt/tools/perf; HF_TOKEN="${hf_token}" bash run_hf_model.sh "${bs}" "$backend" "$model" "${isl[i]}" "${osl[i]}" "${precision}" "${iterations}"; exit" + done + done + done + done +# Clear HF cache +rm -rf ~/.cache/huggingface/hub/ \ No newline at end of file diff --git a/tools/perf/stage2.sh b/tools/perf/stage2.sh new file mode 100644 index 0000000000..9411cd9d09 --- /dev/null +++ b/tools/perf/stage2.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# global parameters +precision="fp16" +iterations=1 +backends=("dynamo" "inductor") +batch_sizes=(1 16) +hf_token="" +image_name="" + +# Stage 2 : slower non-GPT2 experiments 
+isl=(2048) +osl=(2176) +models=("meta-llama/Meta-Llama-3.1-8B-Instruct" "meta-llama/Llama-2-7b-chat-hf" "mistralai/Mistral-7B-Instruct-v0.3") +backends=("dynamo" "inductor") +for model in ${models[@]} + do + for bs in ${batch_sizes[@]} + do + for backend in ${backends[@]} + do + for i in ${!isl[@]}; + do + docker run --rm -it --gpus 0 --shm-size=10.24g --ulimit stack=67108864 -v "$PWD:/work" --ipc=host ${image_name} /bin/bash -c "cd /opt/torch_tensorrt/tools/perf; HF_TOKEN="${hf_token}" bash run_hf_model.sh "${bs}" "$backend" "$model" "${isl[i]}" "${osl[i]}" "${precision}" "${iterations}"; exit" + done + done + done + done +# Clear HF cache +rm -rf ~/.cache/huggingface/hub/ \ No newline at end of file diff --git a/tools/perf/utils.py b/tools/perf/utils.py index a6f8ba236d..5dae807892 100644 --- a/tools/perf/utils.py +++ b/tools/perf/utils.py @@ -1,7 +1,13 @@ +import copy +import timeit + import custom_models as cm +import numpy as np +import tensorrt as trt import timm import torch import torchvision.models as models +from transformers import AutoModel, AutoModelForCausalLM BENCHMARK_MODEL_NAMES = { "vgg16", @@ -12,9 +18,28 @@ "vit_large", "bert_base_uncased", "sd_unet", + "meta-llama/Llama-2-7b-chat-hf", + "gpt2", + "meta-llama/Meta-Llama-3-8B", + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "apple/DCLM-7B", + "mistralai/Mistral-7B-Instruct-v0.3", + "microsoft/Phi-3-mini-4k-instruct", } +def load_hf_model(model_name_hf): + print("Loading user-specified HF model: ", model_name_hf) + model_hf = AutoModelForCausalLM.from_pretrained( + model_name_hf, + trust_remote_code=True, + use_cache=False, + attn_implementation="eager", + ).eval() + + return {"model": model_hf} + + class ModelStorage: def __contains__(self, name: str): return name in BENCHMARK_MODEL_NAMES @@ -63,6 +88,26 @@ def __getitem__(self, name: str): "model": cm.StableDiffusionUnet(), "path": "pytorch", } + elif name in [ + "gpt2", + "meta-llama/Meta-Llama-3-8B", + "meta-llama/Llama-2-7b-chat-hf", + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "mistralai/Mistral-7B-Instruct-v0.3", + "microsoft/Phi-3-mini-4k-instruct", + ]: + hf_artifact = load_hf_model(name) + return { + "model": hf_artifact["model"], + "path": "pytorch", + } + elif name == "apple/DCLM-7B": + # Load the model directly; AutoModel.from_pretrained returns the model itself + hf_artifact = AutoModel.from_pretrained("apple/DCLM-7B").eval() + return { + "model": hf_artifact, + "path": "pytorch", + } else: raise AssertionError(f"Invalid model name {name}") @@ -81,6 +126,8 @@ def precision_to_dtype(pr): return torch.half elif pr == "int32": return torch.int32 + elif pr == "int64": + return torch.int64 elif pr == "bool": return torch.bool else: @@ -102,7 +149,7 @@ def parse_inputs(user_inputs, dtype): input_shape.append(int(input_dim)) if input_shape != [1]: - if dtype == torch.int32: + if dtype == torch.int32 or dtype == torch.int64: torchtrt_inputs.append( torch.randint(0, 5, input_shape, dtype=dtype).cuda() ) @@ -120,3 +167,91 @@ def parse_backends(backends): return precisions.split(",") + + +def torch_dtype_from_trt(dtype): + if dtype == trt.int8: + return torch.int8 + elif dtype == trt.bool: + return torch.bool + elif dtype == trt.int32: + return torch.int32 + elif dtype == trt.float16: + return torch.float16 + elif dtype == trt.float32: + return torch.float32 + else: + raise TypeError("%s is not supported by torch" % dtype) + + +def torch_device_from_trt(device): + if device == trt.TensorLocation.DEVICE: + return torch.device("cuda") + elif device == trt.TensorLocation.HOST: + return
torch.device("cpu") + else: + raise TypeError("%s is not supported by torch" % device) + + +def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): + """ + Exports the LLM model into an ExportedProgram with dynamic shapes. + In the case of guard failures due to some PyTorch kernel implementations, we also + try to re-export the graph by expressing the constraint violation guards as runtime assert nodes + """ + assert isinstance(inputs, list) + + with torch.no_grad(): + # max=1024 has a constraint violation error. https://github.com/pytorch/pytorch/issues/125604 + seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len) + try: + print("Trying to export the model using torch.export.export()..") + # strict=False only enables aotautograd tracing and excludes dynamo. + ep = torch.export.export( + model, tuple(inputs), dynamic_shapes=({1: seq_len},), strict=False + ) + except: + print( + "Trying torch.export._trace._export to trace the graph since torch.export.export() failed" + ) + # This API is used to express the constraint violation guards as asserts in the graph. + ep = torch.export._trace._export( + model, + tuple(inputs), + dynamic_shapes=({1: seq_len},), + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + + return ep + + +def generate(model, input_seq, output_seq_length): + """ + Greedy decoding of the model. This generates new tokens until the sequence reaches output_seq_length. + """ + + while input_seq.shape[1] <= output_seq_length: + outputs = model(input_seq) + logits = outputs.logits + next_token_logits = logits[:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1) + input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1) + + return input_seq + + +def time_generate(model, inputs, output_seq_length, iterations=10): + """ + Measure the time taken to generate a full output sequence over a given number of iterations + """ + timings = [] + for _ in range(iterations): + start_time = timeit.default_timer() + inputs_copy = copy.copy(inputs) + _ = generate(model, inputs_copy, output_seq_length) + torch.cuda.synchronize() + end_time = timeit.default_timer() + timings.append(end_time - start_time) + + return timings
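For reference, the helpers added to tools/perf/utils.py above (export_llm, generate, time_generate) are meant to be chained the same way run_hf_dynamo does: export the model on CPU with a dynamic sequence dimension, compile the ExportedProgram with Torch-TensorRT, then run greedy decoding and time it. The following is a minimal sketch of that flow; the gpt2 checkpoint, the output length of 64, and the iteration count are illustrative assumptions, not values taken from the benchmark scripts.

import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import export_llm, generate, time_generate

osl = 64  # assumed output sequence length
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained(
    "gpt2", use_cache=False, attn_implementation="eager"
).eval()
input_ids = tokenizer("What is dynamic programming?", return_tensors="pt").input_ids

# Export on CPU with a dynamic sequence dimension, mirroring run_hf_dynamo
ep = export_llm(model.to("cpu"), [input_ids], min_seq_len=1, max_seq_len=osl)

# Compile the exported program, then decode and time on the GPU
input_ids = input_ids.to("cuda:0")
trt_model = torch_tensorrt.dynamo.compile(
    ep,
    inputs=[input_ids],
    enabled_precisions={torch.float32},
    truncate_double=True,
)
tokens = generate(trt_model, input_ids, osl)
timings = time_generate(trt_model, input_ids, osl, iterations=3)
print(tokenizer.decode(tokens[0], skip_special_tokens=True), timings)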
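The per-iteration timings returned by time_generate are what the record_llm_perf call in perf_run.py aggregates into the report; its definition sits outside this hunk, so the reduction below is only a hypothetical illustration in the spirit of the removed recordStats helper. The function name summarize_llm_timings and the exact set of reported fields are assumptions, not the actual record_llm_perf implementation.

import numpy as np

def summarize_llm_timings(timings, batch_size, isl, osl):
    # Each entry in `timings` covers one full greedy-decoding pass from isl to osl tokens
    times = np.array(timings)
    generated_tokens = max(osl - isl, 1)
    tokens_per_second = (batch_size * generated_tokens) / times
    return {
        "Median-Latency(s)": float(np.median(times)),
        "Mean-Latency(s)": float(np.mean(times)),
        "Latency-StdDev(s)": float(np.std(times, ddof=0)),
        "Mean-Throughput(tokens/sec)": float(np.mean(tokens_per_second)),
    }

For example, summarize_llm_timings(timings, batch_size=1, isl=input_ids.shape[1], osl=osl) would turn the list produced by time_generate into a single row of summary statistics.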