Commit fc525e6

fix refit issue for torch.compile
1 parent 42d18ac commit fc525e6

2 files changed: +2, -9 lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py  (+2, -3)

@@ -40,7 +40,7 @@
     get_node_name,
     get_trt_tensor,
 )
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device, to_torch_device
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
 
@@ -434,9 +434,8 @@ def _save_weight_mapping(self) -> None:
         """
         _LOGGER.info("Building weight name mapping...")
         # Stage 1: Name mapping
-        sd = self.module.state_dict()
         torch_device = to_torch_device(self.compilation_settings.device)
-        gm_is_on_cuda = list(sd.values())[0].device.type == "cuda"
+        gm_is_on_cuda = get_model_device(self.module).type == "cuda"
         if not gm_is_on_cuda:
             # If the model original position is on CPU, move it GPU
             sd = {
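The substantive change replaces the state_dict-based device probe with a direct query of the module's device. Under torch.compile the graph module's state_dict can be empty or unrepresentative, so indexing its first value is fragile; get_model_device(self.module) avoids that. A minimal sketch of what such a device lookup can look like, purely illustrative and not the actual torch_tensorrt.dynamo.utils.get_model_device implementation:

import torch

def infer_module_device(module: torch.nn.Module) -> torch.device:
    # Illustrative helper only (assumed behavior, not the library's code):
    # prefer a parameter's device, fall back to a buffer's, and default to
    # CPU when the module carries neither.
    param = next(module.parameters(), None)
    if param is not None:
        return param.device
    buffer = next(module.buffers(), None)
    if buffer is not None:
        return buffer.device
    return torch.device("cpu")

# Mirrors the patched check in _save_weight_mapping:
# gm_is_on_cuda = infer_module_device(gm).type == "cuda"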

tests/py/dynamo/models/test_engine_cache.py  (-6)

@@ -184,9 +184,6 @@ def test_dynamo_compile_with_custom_engine_cache(self):
             msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms",
         )
 
-    @pytest.mark.skip(
-        reason="The test needs a fix for refit, which is reported in https://github.com/pytorch/TensorRT/issues/3126"
-    )
     def test_torch_compile_with_default_disk_engine_cache(self):
         # Custom Engine Cache
         model = models.resnet18(pretrained=True).eval().to("cuda")

@@ -251,9 +248,6 @@ def test_torch_compile_with_default_disk_engine_cache(self):
             msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms",
         )
 
-    @pytest.mark.skip(
-        reason="The test needs a fix for refit, which is reported in https://github.com/pytorch/TensorRT/issues/3126"
-    )
     def test_torch_compile_with_custom_engine_cache(self):
         # Custom Engine Cache
         model = models.resnet18(pretrained=True).eval().to("cuda")
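Removing the skip markers puts the torch.compile engine-caching tests, which rely on refit to reuse cached engines, back in the suite. A hedged sketch of the pattern those tests exercise; the backend name and option keys are assumptions based on the test names and torch_tensorrt's engine-caching settings, and may differ across releases:

import torch
import torch_tensorrt  # noqa: F401  # registers the "tensorrt" dynamo backend
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]

for _ in range(2):
    # Reset dynamo so the second pass recompiles and can hit the engine cache.
    torch._dynamo.reset()
    compiled = torch.compile(
        model,
        backend="tensorrt",
        options={
            "cache_built_engines": True,   # assumed option name
            "reuse_cached_engines": True,  # assumed option name
        },
    )
    compiled(*inputs)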
