🐛 [Bug] make_refitable + reuse_cached_engines causing IndexError #3136

Closed
HolyWu opened this issue Aug 30, 2024 · 1 comment
Labels: bug (Something isn't working)
HolyWu (Contributor) commented Aug 30, 2024

Bug Description

First execution

WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=False, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=True, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=True, reuse_cached_engines=True)

INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
WARNING:torch_tensorrt.dynamo.utils:Could not detect the device on which the model exists. Assuming the model is on CPU
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +2, GPU +0, now: CPU 146, GPU 999 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +2088, GPU +384, now: CPU 2389, GPU 1383 (MiB)
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (1, 3, 224, 224)@torch.float32))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node relu/relu [aten.relu.default] (Inputs: (x: (1, 3, 224, 224)@torch.float32) | Outputs: (relu: (1, 3, 224, 224)@torch.float32))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (relu: (1, 3, 224, 224)@torch.float32) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.004099
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Building weight name mapping...
WARNING:torch_tensorrt.dynamo.utils:Could not detect the device on which the model exists. Assuming the model is on CPU
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt [TensorRT Conversion Context]:Global timing cache in use. Profiling results in this builder pass will be stored.
INFO:torch_tensorrt [TensorRT Conversion Context]:Detected 1 inputs and 1 output network tensors.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Host Persistent Memory: 256
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Device Persistent Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Scratch Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Activation Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Weights Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Engine generation completed in 1.16053 seconds.
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 0 MiB, GPU 13 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 4103 MiB
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:01.165938
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 13844 bytes of Memory
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 26 bytes of code generator cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 16 timing cache entries
INFO:torch_tensorrt.dynamo._engine_caching:The blob was saved to /tmp/torch_tensorrt_engine_cache/qcp2nbn7adw2zbhxzqql4er37brkw33awbzf4jqx33pbhuvino3/blob.bin
ERROR: [Torch-TensorRT] - Platform constructor: linux_x86_64

Second execution

WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=False, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refitable=True, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=True, reuse_cached_engines=True)

INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
WARNING:torch_tensorrt.dynamo.utils:Could not detect the device on which the model exists. Assuming the model is on CPU
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +2, GPU +0, now: CPU 146, GPU 999 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +2088, GPU +384, now: CPU 2389, GPU 1383 (MiB)
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Found the cached engine that corresponds to this graph. It is directly loaded.
INFO:torch_tensorrt [TensorRT Conversion Context]:Loaded engine size: 0 MiB
Traceback (most recent call last):
  File "/home/holywu/test.py", line 16, in <module>
    trt_model = torch_tensorrt.compile(
                ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/_compile.py", line 269, in compile
    trt_graph_module = dynamo_compile(
                       ^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 288, in compile
    trt_gm = compile_module(
             ^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 462, in compile_module
    trt_module = convert_module(
                 ^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 141, in convert_module
    interpreter_result = interpret_module_to_result(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 120, in interpret_module_to_result
    interpreter_result = interpreter.run()
                         ^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 561, in run
    _refit_single_trt_engine_with_gm(
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/_refit.py", line 149, in _refit_single_trt_engine_with_gm
    torch_device = list(new_gm.state_dict().values())[0].device.type
                   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^
IndexError: list index out of range
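
The crash appears to come from _refit.py assuming that new_gm.state_dict() contains at least one tensor whose device it can read. The module in the repro below only wraps a ReLU, so its state_dict() is empty and indexing [0] fails. A minimal sketch of that situation (plain PyTorch, no Torch-TensorRT involved, not part of the original report):

import torch

# A parameter-free module, like the ReLU-only module in the repro below,
# produces an empty state_dict, so there is no first entry to read a device from.
m = torch.nn.ReLU()
weights = list(m.state_dict().values())
print(len(weights))  # 0
# weights[0].device.type would raise IndexError: list index out of range,
# matching the traceback above.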

To Reproduce

import torch
import torch_tensorrt


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(x)


model = MyModule().eval().cuda()
inputs = (torch.randn((1, 3, 224, 224), device="cuda"),)
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.float32},
    min_block_size=1,
    make_refitable=True,
)
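
Not part of the original report: until this is fixed, one possible way to avoid the failing refit path on the second run is to turn off engine caching explicitly. This is only a sketch, assuming the cache_built_engines / reuse_cached_engines settings shown in the CompilationSettings log above are accepted as keyword arguments by torch_tensorrt.compile (unverified on this nightly build):

# Hypothetical workaround sketch: disable the engine cache so the second
# execution rebuilds the engine instead of refitting a cached one.
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    enabled_precisions={torch.float32},
    min_block_size=1,
    make_refitable=True,
    cache_built_engines=False,
    reuse_cached_engines=False,
)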

Environment

  • Torch-TensorRT Version (e.g. 1.0.0): 2.5.0.dev20240830+cu124
  • PyTorch Version (e.g. 1.0): 2.5.0.dev20240830+cu124
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:
HolyWu added the bug (Something isn't working) label on Aug 30, 2024
HolyWu (Contributor, Author) commented Sep 5, 2024

Seems to be fixed by commit 8759736.

HolyWu closed this as completed on Sep 5, 2024