To Reproduce

import os
from tempfile import gettempdir

import torch
import torch_tensorrt


# Toy model with a single relu op
class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


with torch.inference_mode():
    device = torch.device("cuda", 0)
    model = MyModule().eval().to(device)

    # First compilation: static input shape (1, 3, 224, 224); the built engine
    # is stored in the disk engine cache.
    inputs1 = [torch.randn(1, 3, 224, 224, device=device)]
    trt_model1 = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs1,
        enabled_precisions={torch.float},
        device=device,
        make_refitable=True,
        debug=True,
        min_block_size=1,
        engine_cache_dir=os.path.join(gettempdir(), "torchtrt_issue3148"),
    )
    trt_model1(*inputs1)

    print("\n========================================\n")

    # Second compilation: static input shape (2, 3, 224, 224), same cache dir;
    # the engine cached for batch size 1 is reused and then fails at runtime.
    inputs2 = [torch.randn(2, 3, 224, 224, device=device)]
    trt_model2 = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs2,
        enabled_precisions={torch.float},
        device=device,
        make_refitable=True,
        debug=True,
        min_block_size=1,
        engine_cache_dir=os.path.join(gettempdir(), "torchtrt_issue3148"),
    )
    trt_model2(*inputs2)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
    return (relu,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
    return (relu,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
    return (relu,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
    return (relu,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
    return (relu,)
INFO:torch_tensorrt.dynamo._engine_cache:Disk engine cache initialized (cache directory:/tmp/torchtrt_issue3148, max size: 1073741824)
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=True, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=True, reuse_cached_engines=True)

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.relu.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.relu.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
Input shapes: [(1, 3, 224, 224)]
graph():
    %x : [num_users=1] = placeholder[target=x]
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
    return relu
WARNING:torch_tensorrt.dynamo.utils:Could not detect the device on which the model exists. Assuming the model is on CPU
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[1, 3, 224, 224], dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (1, 3, 224, 224)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /relu (kind: aten.relu.default, args: ('x <Node>',))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /relu [aten.relu.default] (Inputs: (x: (1, 3, 224, 224)@torch.float32) | Outputs: (relu: (1, 3, 224, 224)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('relu <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(1, 3, 224, 224), dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (relu: (1, 3, 224, 224)@torch.float32) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.004144
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Building weight name mapping...
WARNING:torch_tensorrt.dynamo.utils:Could not detect the device on which the model exists. Assuming the model is on CPU
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:03.345116
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 12932 bytes of Memory
DEBUG:torch_tensorrt.dynamo._engine_cache:The engine added to cache, saved to /tmp/torchtrt_issue3148/qcp2nbn7adw2zbhxzqql4er37brkw33awbzf4jqx33pbhuvino3/blob.bin
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 1 Total Operators, of which 1 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=True, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=True, reuse_cached_engines=True)

  Graph Structure:

   Inputs: List[Tensor: (1, 3, 224, 224)@float32]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (1, 3, 224, 224)@float32]
     Number of Operators in Engine: 1
     Engine Outputs: List[Tensor: (1, 3, 224, 224)@float32]
    ...
   Outputs: List[Tensor: (1, 3, 224, 224)@float32]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 1.0
   Most Operators in a TRT Engine: 1

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=1 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=1 which generates 1 TRT engine(s)

========================================

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
    return (relu,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
    return (relu,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
    return (relu,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
    return (relu,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
    return (relu,)
INFO:torch_tensorrt.dynamo._engine_cache:Disk engine cache initialized (cache directory:/tmp/torchtrt_issue3148, max size: 1073741824)
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=True, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=True, reuse_cached_engines=True)

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.relu.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.relu.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
Input shapes: [(2, 3, 224, 224)]
graph():
    %x : [num_users=1] = placeholder[target=x]
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%x,), kwargs = {})
    return relu
WARNING:torch_tensorrt.dynamo.utils:Could not detect the device on which the model exists. Assuming the model is on CPU
DEBUG:torch_tensorrt.dynamo._engine_cache:Engine found in cache, loaded from /tmp/torchtrt_issue3148/qcp2nbn7adw2zbhxzqql4er37brkw33awbzf4jqx33pbhuvino3/blob.bin
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Found the cached engine that corresponds to this graph. It is directly loaded.
WARNING:torch_tensorrt.dynamo.utils:Could not detect the device on which the model exists. Assuming the model is on CPU
WARNING:torch_tensorrt.dynamo.utils:Could not detect the device on which the model exists. Assuming the model is on CPU
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[2, 3, 224, 224], dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (2, 3, 224, 224)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /relu (kind: aten.relu.default, args: ('x <Node>',))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /relu [aten.relu.default] (Inputs: (x: (2, 3, 224, 224)@torch.float32) | Outputs: (relu: (2, 3, 224, 224)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('relu <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(2, 3, 224, 224), dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (relu: (2, 3, 224, 224)@torch.float32) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.001975
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 1 Total Operators, of which 1 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=True, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=True, reuse_cached_engines=True)

  Graph Structure:

   Inputs: List[Tensor: (2, 3, 224, 224)@float32]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (2, 3, 224, 224)@float32]
     Number of Operators in Engine: 1
     Engine Outputs: List[Tensor: (2, 3, 224, 224)@float32]
    ...
   Outputs: List[Tensor: (2, 3, 224, 224)@float32]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 1.0
   Most Operators in a TRT Engine: 1

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=1 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=1 which generates 1 TRT engine(s)

ERROR: [Torch-TensorRT] - IExecutionContext::setInputShape: Error Code 3: API Usage Error (Parameter check failed, condition: engineDims.d[i] == dims.d[i]. Static dimension mismatch while setting input shape for x. Set dimensions are [2,3,224,224]. Expected dimensions are [1,3,224,224].)
Traceback (most recent call last):
  File "/home/holywu/test.py", line 45, in <module>
    trt_model2(*inputs2)
  File "/home/holywu/.local/lib/python3.12/site-packages/torch/fx/graph_module.py", line 784, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch/fx/graph_module.py", line 361, in __call__
    raise e
  File "/home/holywu/.local/lib/python3.12/site-packages/torch/fx/graph_module.py", line 348, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.85", line 6, in forward
  File "/home/holywu/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/_features.py", line 56, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 279, in forward
    outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine(
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch/_ops.py", line 1116, in __call__
    return self._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: [Error thrown at core/runtime/execute_engine.cpp:231] Expected compiled_engine->exec_ctx->setInputShape(name.c_str(), dims) to be true but got false
Error while setting the input shape
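The second log shows the failure mode: the disk engine cache returns the blob that was built for the (1, 3, 224, 224) input even though this compilation was given (2, 3, 224, 224) inputs, and execution then fails against the engine's static dimensions. As a rough illustration only (this is not the actual torch_tensorrt cache code), a cache key that folded the static input shapes into the hash would keep the two builds separate:

import hashlib

# Hypothetical helper, for illustration: hash the graph together with the
# static input shapes so engines built for different shapes get distinct keys.
def engine_cache_key(graph_repr: str, input_shapes: list) -> str:
    payload = graph_repr + "|" + repr(sorted(input_shapes))
    return hashlib.sha256(payload.encode()).hexdigest()

key1 = engine_cache_key("aten.relu.default", [(1, 3, 224, 224)])
key2 = engine_cache_key("aten.relu.default", [(2, 3, 224, 224)])
assert key1 != key2  # different static shapes would no longer collide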
Environment

How you installed PyTorch (conda, pip, libtorch, source):
Can you try specifying a dim as dynamic, as in https://github.com/pytorch/TensorRT/blob/main/examples/dynamo/engine_caching_example.py#L132-L136?
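For reference, a rough sketch of that suggestion applied to the script above (reusing model, device, inputs1, and inputs2 from the reproduction; the min/opt/max batch sizes here are illustrative assumptions, following the pattern in the linked example):

# Illustrative sketch only: compile once with a dynamic batch dimension so a
# single cached engine is valid for both batch sizes in the reproduction.
dyn_inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 3, 224, 224),
        opt_shape=(2, 3, 224, 224),
        max_shape=(8, 3, 224, 224),  # assumed upper bound for this sketch
        dtype=torch.float32,
    )
]
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=dyn_inputs,
    enabled_precisions={torch.float},
    device=device,
    make_refitable=True,
    min_block_size=1,
    engine_cache_dir=os.path.join(gettempdir(), "torchtrt_issue3148"),
)
trt_model(*inputs1)  # (1, 3, 224, 224)
trt_model(*inputs2)  # (2, 3, 224, 224)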
There is no issue if the model is compiled with a dynamic input spec in the first place.