🐛 [Bug] torch_tensorrt.load a model saved with dynamic shape is throwing error #3174

Open
lanluo-nvidia opened this issue Sep 23, 2024 · 0 comments
Labels
bug Something isn't working

Comments

@lanluo-nvidia
Collaborator

Bug Description

A model compiled with dynamic shapes is saved successfully, but loading it throws an error.

To Reproduce

Run the following script to reproduce the behavior:

import torch

from torch.export import Dim
import torch.nn as nn
import torch_tensorrt as torchtrt
import os
import tempfile

class bitwise_and(nn.Module):
    def forward(self, lhs_val, rhs_val):
        return torch.ops.aten.bitwise_and.Tensor(lhs_val, rhs_val)

dyn_dim = Dim("dyn_dim", min=3, max=6)
lhs = torch.randint(0, 2, (2, 4, 2), dtype=bool, device="cuda")
rhs = torch.randint(0, 2, (4, 2), dtype=bool, device="cuda")
inputs = (lhs, rhs)
torchtrt_inputs = [torchtrt.Input(shape=lhs.shape, dtype=torch.bool), 
                   torchtrt.Input(shape=rhs.shape, dtype=torch.bool)] 
mod = bitwise_and()


fx_mod=torch.export.export(mod, inputs, dynamic_shapes={"lhs_val": {1: dyn_dim}, "rhs_val": {0: dyn_dim}})
print(f"lan added fx_mod={fx_mod}")
trt_model = torchtrt.dynamo.compile(fx_mod, inputs=inputs, enable_precisions={torch.bool}, min_block_size=1)
trt_ep_path = os.path.join(tempfile.gettempdir(), "trt.ep")

lhs1 = torch.randint(0, 2, (2, 5, 2), dtype=bool, device="cuda")
rhs1 = torch.randint(0, 2, (5, 2), dtype=bool, device="cuda")
torchtrt.save(trt_model, trt_ep_path, inputs=[lhs1, rhs1])
print(f"lan added saved model to {trt_ep_path}")

loaded_trt_module = torch.export.load(trt_ep_path)
print(f"lan added load model from {trt_ep_path}") 

output = loaded_trt_module(lhs1, rhs1)
print(f"lan added got {output=}")
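As a sanity check on the repro itself, the sketch below uses plain Python (no torch or GPU needed) to confirm that every value of dyn_dim in the declared range 3..6, including the value 5 used at save time, gives lhs/rhs shapes that broadcast to the lhs shape. The helper `broadcast_shape` is a hypothetical illustration written for this check, not a torch or Torch-TensorRT function. If it holds, the inputs are valid and the failure is more likely in the save/load path than in the shapes.

```python
def broadcast_shape(a, b):
    """Broadcast two shapes using right-aligned NumPy/PyTorch-style rules."""
    result = []
    for i in range(1, max(len(a), len(b)) + 1):
        # Missing leading dimensions are treated as size 1.
        x = a[-i] if i <= len(a) else 1
        y = b[-i] if i <= len(b) else 1
        if x != y and x != 1 and y != 1:
            raise ValueError(f"shapes {a} and {b} do not broadcast")
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


# Range taken from the repro: Dim("dyn_dim", min=3, max=6).
DYN_MIN, DYN_MAX = 3, 6

results = {}
for d in range(DYN_MIN, DYN_MAX + 1):
    lhs_shape = (2, d, 2)  # dyn_dim is dimension 1 of lhs_val
    rhs_shape = (d, 2)     # dyn_dim is dimension 0 of rhs_val
    results[d] = broadcast_shape(lhs_shape, rhs_shape)

print(results)
```
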

Expected behavior

torch.export.load should load the saved model without raising an error.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

The full error output is shown below:

E0923 09:53:27.553340 1300597 site-packages/torch/fx/experimental/recording.py:298] failed while running evaluate_expr(*(s0 >= 0, True), **{'fx_node': False})
Traceback (most recent call last):
  File "/home/lanl/git/script/python/export_dynamic_shape_save_load_torchtrt_example.py", line 35, in <module>
    loaded_trt_module = torch.export.load(trt_ep_path)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/export/__init__.py", line 569, in load
    ep = deserialize(artifact, expected_opset_version)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_export/serde/serialize.py", line 2436, in deserialize
    ExportedProgramDeserializer(expected_opset_version)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_export/serde/serialize.py", line 2315, in deserialize
    GraphModuleDeserializer()
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_export/serde/serialize.py", line 1906, in deserialize
    self.deserialize_graph(serialized_graph_module.graph)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_export/serde/serialize.py", line 1612, in deserialize_graph
    meta_val = self.deserialize_tensor_meta(tensor_value)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_export/serde/serialize.py", line 1579, in deserialize_tensor_meta
    torch.empty_strided(
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1238, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1692, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1339, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 2009, in _dispatch_impl
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_subclasses/fake_impls.py", line 176, in constructors
    r = func(*args, **new_kwargs)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/_ops.py", line 716, in __call__
    return self._op(*args, **kwargs)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/sym_node.py", line 479, in expect_size
    r = b.expect_true(file, line)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/sym_node.py", line 465, in expect_true
    return self.guard_bool(file, line)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/sym_node.py", line 449, in guard_bool
    r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/recording.py", line 262, in wrapper
    return retlog(fn(*args, **kwargs))
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5207, in evaluate_expr
    return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py", line 5283, in _evaluate_expr
    static_expr = self._maybe_evaluate_static(expr,
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py", line 1604, in wrapper
    return fn_cache(self, *args, **kwargs)
  File "/home/lanl/miniconda3/envs/torch_tensorrt_py39/lib/python3.9/site-packages/torch/fx/experimental/symbolic_shapes.py", line 4552, in _maybe_evaluate_static
    vr = var_ranges[k]
KeyError: s0
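
The last frame (`vr = var_ranges[k]`) together with the logged `evaluate_expr(s0 >= 0, ...)` suggests that the symbol s0 had no entry in the range table when the deserializer evaluated the size check. The toy sketch below only illustrates that lookup pattern. It is not PyTorch's implementation: the function name and the dict layout are assumptions made for illustration, and why the range is missing has not been confirmed.

```python
# Toy model only: this is NOT PyTorch's ShapeEnv. It mimics the dict lookup
# seen in the last traceback frame (`vr = var_ranges[k]`).

def lookup_ranges(free_symbols, var_ranges):
    """Return the (lower, upper) range for every symbol in an expression."""
    # Plain dict indexing: a symbol without a registered range raises KeyError.
    return {k: var_ranges[k] for k in free_symbols}


# Case 1: the symbol has a registered range (the repro's Dim uses 3..6).
ok = lookup_ranges(["s0"], {"s0": (3, 6)})

# Case 2: the symbol appears in an expression but has no registered range,
# which is the situation the KeyError in the traceback points to.
try:
    lookup_ranges(["s0"], {})
    failed_symbol = None
except KeyError as err:
    failed_symbol = err.args[0]

print(ok, failed_symbol)
```

If this reading is right, a useful next step would be to check whether the range constraints for the dynamic dimension are present in the saved artifact.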

lanluo-nvidia added the bug label on Sep 23, 2024
lanluo-nvidia changed the title from "🐛 [Bug] torch_tensorrt.save and load with dynamic shape is not working" to "🐛 [Bug] torch_tensorrt.load a model saved with dynamic shape is throwing error" on Sep 23, 2024