chore: updates
peri044 committed Sep 26, 2024
1 parent a90191d commit 71e33cb
Showing 5 changed files with 47 additions and 8 deletions.
18 changes: 12 additions & 6 deletions examples/dynamo/torch_export_gpt2.py
@@ -25,12 +25,16 @@
 # CPU is used here so that GPU memory is reserved for TRT compilation.
 with torch.no_grad():
     tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    model = AutoModelForCausalLM.from_pretrained(
-        "gpt2",
-        pad_token_id=tokenizer.eos_token_id,
-        use_cache=False,
-        attn_implementation="eager",
-    ).eval()
+    model = (
+        AutoModelForCausalLM.from_pretrained(
+            "gpt2",
+            pad_token_id=tokenizer.eos_token_id,
+            use_cache=False,
+            attn_implementation="eager",
+        )
+        .eval()
+        .half()
+    )
 
 # %%
 # Tokenize a sample input prompt and get pytorch model outputs
@@ -56,6 +60,8 @@
     truncate_double=True,
     device=DEVICE,
     disable_tf32=True,
+    use_strong_types=True,
+    use_fp32_acc=True,
 )
 
 # Auto-regressive generation loop for greedy decoding using TensorRT model
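The example model is now cast to FP16 via .half(), which is why use_fp32_acc is enabled in the compile call above. For context, the greedy decoding loop that the diff truncates at this point typically looks like the sketch below; the names trt_model, input_ids, MAX_TOKENS, and tokenizer come from the surrounding example, and the assumption that the compiled module returns raw logits is illustrative rather than part of this commit.

# Illustrative greedy decoding sketch (not part of this commit); assumes the
# compiled module returns logits of shape [batch, seq_len, vocab].
generated = input_ids
for _ in range(MAX_TOKENS):
    logits = trt_model(generated)
    next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
    generated = torch.cat([generated, next_token], dim=-1)
    if (next_token == tokenizer.eos_token_id).all():
        break
print(tokenizer.decode(generated[0], skip_special_tokens=True))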
28 changes: 27 additions & 1 deletion py/torch_tensorrt/dynamo/_compiler.py
@@ -88,6 +88,8 @@ def compile(
     engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
     engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
     custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
+    use_strong_types: bool = _defaults.USE_STRONG_TYPES,
+    use_fp32_acc: bool = _defaults.USE_FP32_ACC,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -158,6 +160,8 @@ def compile(
         engine_cache_dir (Optional[str]): Directory to store the cached TRT engines
         engine_cache_size (Optional[int]): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default
         custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
+        use_strong_types (bool): Enable strong typing in TensorRT compilation
+        use_fp32_acc (bool): This option inserts FP32 cast nodes around matmul layers so that TensorRT accumulates the matmul results in FP32. Use this only when FP16 precision is configured in enabled_precisions.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -197,6 +201,20 @@ def compile(
             "\nThis feature is unimplemented in Torch-TRT Dynamo currently."
         )
 
+    if use_strong_types:
+        if len(enabled_precisions) != 1 or not any(
+            x in enabled_precisions for x in {torch.float32, dtype.f32}
+        ):
+            raise AssertionError(
+                f"When use_strong_types is enabled, only torch.float32 is allowed in enabled_precisions but found {enabled_precisions}"
+            )
+
+    if use_fp32_acc:
+        logger.debug(
+            "FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. "
+            "This flag inserts casts around matmul layers and ensures that TensorRT executes the matmul layers in FP16 with FP32 accumulation."
+        )
+
     # Aliasing inputs to arg_inputs for better understanding
     if not arg_inputs and not inputs:
         raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
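As the checks above imply, use_strong_types requires enabled_precisions to contain only FP32, letting the layer dtypes of the graph drive TensorRT's type selection, while use_fp32_acc targets models that already carry FP16 weights. A minimal usage sketch under those assumptions; the TinyMatmul module, tensor shapes, and variable names are illustrative, not part of this commit:

import torch
import torch_tensorrt

class TinyMatmul(torch.nn.Module):  # hypothetical stand-in model with a matmul (Linear) layer
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.linear(x)

model = TinyMatmul().half().eval().cuda()
example_inputs = (torch.randn(8, 64, dtype=torch.float16, device="cuda"),)
exported_program = torch.export.export(model, example_inputs)

trt_gm = torch_tensorrt.dynamo.compile(
    exported_program,
    inputs=example_inputs,
    enabled_precisions={torch.float32},  # required when use_strong_types=True
    use_strong_types=True,               # layer dtypes, not a precision set, drive TRT types
    use_fp32_acc=True,                   # matmuls execute in FP16 but accumulate in FP32
)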
@@ -232,7 +250,7 @@ def compile(
     logger.debug("Input graph: " + str(gm.graph))
 
     # Apply lowering on the graph module
-    gm = post_lowering(gm)
+    gm = post_lowering(gm, use_fp32_acc=use_fp32_acc)
     logger.debug("Lowered Input graph: " + str(gm.graph))
 
     engine_cache = None
@@ -281,6 +299,8 @@ def compile(
         "lazy_engine_init": lazy_engine_init,
         "cache_built_engines": cache_built_engines,
         "reuse_cached_engines": reuse_cached_engines,
+        "use_strong_types": use_strong_types,
+        "use_fp32_acc": use_fp32_acc,
     }
 
     settings = CompilationSettings(**compilation_options)
@@ -520,6 +540,8 @@ def convert_exported_program_to_serialized_trt_engine(
     calibrator: object = None,
     allow_shape_tensors: bool = False,
     timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
+    use_strong_types: bool = _defaults.USE_STRONG_TYPES,
+    use_fp32_acc: bool = _defaults.USE_FP32_ACC,
     **kwargs: Any,
 ) -> bytes:
     """Convert an ExportedProgram to a serialized TensorRT engine
@@ -578,6 +600,8 @@ def convert_exported_program_to_serialized_trt_engine(
         calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
         allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
+        use_strong_types (bool): Enable strong typing in TensorRT compilation
+        use_fp32_acc (bool): This option inserts FP32 cast nodes around matmul layers so that TensorRT accumulates the matmul results in FP32. Use this only when FP16 precision is configured in enabled_precisions.
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
@@ -651,6 +675,8 @@ def convert_exported_program_to_serialized_trt_engine(
         "dla_local_dram_size": dla_local_dram_size,
         "dla_global_dram_size": dla_global_dram_size,
         "timing_cache_path": timing_cache_path,
+        "use_strong_types": use_strong_types,
+        "use_fp32_acc": use_fp32_acc,
     }
 
     exported_program = pre_export_lowering(exported_program)
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -40,6 +40,8 @@
 ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
 ENGINE_CACHE_SIZE = 1073741824
 CUSTOM_ENGINE_CACHE = None
+USE_STRONG_TYPES = False
+USE_FP32_ACC = False
 
 
 def default_device() -> Device:
6 changes: 6 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -30,7 +30,9 @@
     TIMING_CACHE_PATH,
     TRUNCATE_DOUBLE,
     USE_FAST_PARTITIONER,
+    USE_FP32_ACC,
     USE_PYTHON_RUNTIME,
+    USE_STRONG_TYPES,
     VERSION_COMPATIBLE,
     WORKSPACE_SIZE,
     default_device,
@@ -78,6 +80,8 @@ class CompilationSettings:
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
         cache_built_engines (bool): Whether to save the compiled TRT engines to storage
         reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
+        use_strong_types (bool): Enable strong typing in TensorRT compilation
+        use_fp32_acc (bool): This option inserts FP32 cast nodes around matmul layers so that TensorRT accumulates the matmul results in FP32. Use this only when FP16 precision is configured in enabled_precisions.
     """
 
     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -112,6 +116,8 @@ class CompilationSettings:
     lazy_engine_init: bool = LAZY_ENGINE_INIT
     cache_built_engines: bool = CACHE_BUILT_ENGINES
     reuse_cached_engines: bool = REUSE_CACHED_ENGINES
+    use_strong_types: bool = USE_STRONG_TYPES
+    use_fp32_acc: bool = USE_FP32_ACC
 
 
 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
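Because the new dataclass fields default to the _defaults constants (both False), a CompilationSettings built without them preserves the previous behaviour, and either flag can be opted into explicitly. A small sketch, assuming the class is imported from torch_tensorrt.dynamo._settings as in this file:

from torch_tensorrt.dynamo._settings import CompilationSettings

default_settings = CompilationSettings()  # use_strong_types=False, use_fp32_acc=False
opted_in = CompilationSettings(
    use_strong_types=True,  # strong typing in TensorRT compilation
    use_fp32_acc=True,      # FP32 accumulation for matmul layers
)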
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -124,7 +124,6 @@ def __init__(
             dict()
         )
 
-        self.compilation_settings = compilation_settings
         # Data types for TRT Module output Tensors
         self.output_dtypes = (
             [dtype._from(o) for o in output_dtypes] if output_dtypes else None
