Commit 71e33cb

chore: updates
1 parent a90191d commit 71e33cb

5 files changed: +47 −8 lines changed


examples/dynamo/torch_export_gpt2.py

+12-6
@@ -25,12 +25,16 @@
 # CPU is used here so that GPU memory is reserved for TRT compilation.
 with torch.no_grad():
     tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    model = AutoModelForCausalLM.from_pretrained(
-        "gpt2",
-        pad_token_id=tokenizer.eos_token_id,
-        use_cache=False,
-        attn_implementation="eager",
-    ).eval()
+    model = (
+        AutoModelForCausalLM.from_pretrained(
+            "gpt2",
+            pad_token_id=tokenizer.eos_token_id,
+            use_cache=False,
+            attn_implementation="eager",
+        )
+        .eval()
+        .half()
+    )

 # %%
 # Tokenize a sample input prompt and get pytorch model outputs
@@ -56,6 +60,8 @@
     truncate_double=True,
     device=DEVICE,
     disable_tf32=True,
+    use_strong_types=True,
+    use_fp32_acc=True,
 )

 # Auto-regressive generation loop for greedy decoding using TensorRT model
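
For context, the two new flags slot into the example roughly as in the sketch below. Only the model construction and the compile keyword arguments shown in the hunks above are taken from this commit; the DEVICE value, the tokenized prompt, and the direct torch.export.export call are assumptions standing in for the example's own export setup.

import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer

DEVICE = torch.device("cuda:0")  # assumed; the example defines DEVICE earlier

with torch.no_grad():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = (
        AutoModelForCausalLM.from_pretrained(
            "gpt2",
            pad_token_id=tokenizer.eos_token_id,
            use_cache=False,
            attn_implementation="eager",
        )
        .eval()
        .half()  # FP16 weights: the case use_fp32_acc is meant for
    )

input_ids = tokenizer("I enjoy walking", return_tensors="pt").input_ids.to(DEVICE)

# Assumed export step; the real example uses its own export utility with dynamic shapes.
ep = torch.export.export(model.to(DEVICE), (input_ids,))

trt_model = torch_tensorrt.dynamo.compile(
    ep,
    inputs=[input_ids],
    enabled_precisions={torch.float32},  # FP32-only is enforced when use_strong_types=True
    truncate_double=True,
    device=DEVICE,
    disable_tf32=True,
    use_strong_types=True,
    use_fp32_acc=True,
)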

py/torch_tensorrt/dynamo/_compiler.py

+27-1
@@ -88,6 +88,8 @@ def compile(
     engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
     engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
     custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
+    use_strong_types: bool = _defaults.USE_STRONG_TYPES,
+    use_fp32_acc: bool = _defaults.USE_FP32_ACC,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -158,6 +160,8 @@ def compile(
         engine_cache_dir (Optional[str]): Directory to store the cached TRT engines
         engine_cache_size (Optional[int]): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default
         custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
+        use_strong_types (bool): Enable strong typing in TensorRT compilation
+        use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -197,6 +201,20 @@ def compile(
             "\nThis feature is unimplemented in Torch-TRT Dynamo currently."
         )

+    if use_strong_types:
+        if len(enabled_precisions) != 1 or not any(
+            x in enabled_precisions for x in {torch.float32, dtype.f32}
+        ):
+            raise AssertionError(
+                f"When use_strong_types is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
+            )
+
+    if use_fp32_acc:
+        logger.debug(
+            "FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \
+            This flag inserts casts around matmul layers and ensures TensorRT executes the matmul layers in FP16 with FP32 accumulation."
+        )
+
     # Aliasing inputs to arg_inputs for better understanding
     if not arg_inputs and not inputs:
         raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
@@ -232,7 +250,7 @@ def compile(
     logger.debug("Input graph: " + str(gm.graph))

     # Apply lowering on the graph module
-    gm = post_lowering(gm)
+    gm = post_lowering(gm, use_fp32_acc=use_fp32_acc)
     logger.debug("Lowered Input graph: " + str(gm.graph))

     engine_cache = None
@@ -281,6 +299,8 @@ def compile(
         "lazy_engine_init": lazy_engine_init,
         "cache_built_engines": cache_built_engines,
         "reuse_cached_engines": reuse_cached_engines,
+        "use_strong_types": use_strong_types,
+        "use_fp32_acc": use_fp32_acc,
     }

     settings = CompilationSettings(**compilation_options)
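
To make use_fp32_acc concrete: conceptually, the lowering wraps each FP16 matmul in casts so the multiply-accumulate runs in FP32 and the result is cast back to FP16. The snippet below only illustrates that numeric pattern; the actual rewrite is the graph pass invoked via post_lowering(gm, use_fp32_acc=True).

import torch

def matmul_with_fp32_acc(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Illustration only: FP16 operands, accumulation carried out in FP32,
    # FP16 result. Per the commit's log message, TensorRT picks up this cast
    # pattern and executes the matmul in FP16 with FP32 accumulation.
    acc = a.to(torch.float32) @ b.to(torch.float32)
    return acc.to(torch.float16)

a = torch.randn(64, 64, dtype=torch.float16)
b = torch.randn(64, 64, dtype=torch.float16)
assert matmul_with_fp32_acc(a, b).dtype == torch.float16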
@@ -520,6 +540,8 @@ def convert_exported_program_to_serialized_trt_engine(
     calibrator: object = None,
     allow_shape_tensors: bool = False,
     timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
+    use_strong_types: bool = _defaults.USE_STRONG_TYPES,
+    use_fp32_acc: bool = _defaults.USE_FP32_ACC,
     **kwargs: Any,
 ) -> bytes:
     """Convert an ExportedProgram to a serialized TensorRT engine
@@ -578,6 +600,8 @@ def convert_exported_program_to_serialized_trt_engine(
         calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
         allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
+        use_strong_types (bool): Enable strong typing in TensorRT compilation
+        use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """
@@ -651,6 +675,8 @@ def convert_exported_program_to_serialized_trt_engine(
         "dla_local_dram_size": dla_local_dram_size,
         "dla_global_dram_size": dla_global_dram_size,
         "timing_cache_path": timing_cache_path,
+        "use_strong_types": use_strong_types,
+        "use_fp32_acc": use_fp32_acc,
     }

     exported_program = pre_export_lowering(exported_program)
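
The serialization path takes the same two flags. A rough usage sketch follows; the toy module is hypothetical, and the call assumes convert_exported_program_to_serialized_trt_engine is reachable from torch_tensorrt.dynamo with an inputs keyword, as the surrounding signature in this file suggests.

import torch
import torch_tensorrt

class TinyMLP(torch.nn.Module):
    # Hypothetical stand-in model, only to show the call shape.
    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(x))

example = torch.randn(2, 16).cuda()
ep = torch.export.export(TinyMLP().eval().cuda(), (example,))

# FP32-only enabled_precisions mirrors the strong-typing check added in compile().
engine_bytes = torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine(
    ep,
    inputs=[example],
    enabled_precisions={torch.float32},
    use_strong_types=True,
    use_fp32_acc=False,
)
with open("tiny_mlp.engine", "wb") as f:
    f.write(engine_bytes)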

py/torch_tensorrt/dynamo/_defaults.py

+2
@@ -40,6 +40,8 @@
 ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
 ENGINE_CACHE_SIZE = 1073741824
 CUSTOM_ENGINE_CACHE = None
+USE_STRONG_TYPES = False
+USE_FP32_ACC = False


 def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

+6
@@ -30,7 +30,9 @@
     TIMING_CACHE_PATH,
     TRUNCATE_DOUBLE,
     USE_FAST_PARTITIONER,
+    USE_FP32_ACC,
     USE_PYTHON_RUNTIME,
+    USE_STRONG_TYPES,
     VERSION_COMPATIBLE,
     WORKSPACE_SIZE,
     default_device,
@@ -78,6 +80,8 @@ class CompilationSettings:
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
         cache_built_engines (bool): Whether to save the compiled TRT engines to storage
         reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
+        use_strong_types (bool): Enable strong typing in TensorRT compilation
+        use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
     """

     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -112,6 +116,8 @@ class CompilationSettings:
     lazy_engine_init: bool = LAZY_ENGINE_INIT
     cache_built_engines: bool = CACHE_BUILT_ENGINES
     reuse_cached_engines: bool = REUSE_CACHED_ENGINES
+    use_strong_types: bool = USE_STRONG_TYPES
+    use_fp32_acc: bool = USE_FP32_ACC


 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
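
The new fields behave like any other CompilationSettings dataclass field: they default to the _defaults values (both False), so existing callers are unaffected, and they can be set explicitly when building settings programmatically. A quick sketch, assuming the internal import path shown in this diff:

from torch_tensorrt.dynamo._settings import CompilationSettings

settings = CompilationSettings()
assert settings.use_strong_types is False and settings.use_fp32_acc is False

settings = CompilationSettings(use_strong_types=True, use_fp32_acc=True)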

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

-1
@@ -124,7 +124,6 @@ def __init__(
             dict()
         )

-        self.compilation_settings = compilation_settings
         # Data types for TRT Module output Tensors
         self.output_dtypes = (
             [dtype._from(o) for o in output_dtypes] if output_dtypes else None
