Skip to content

Commit 91b8f6c

Browse files
authored
_refit: Properly compare device type (#3149)
1 parent 8154408 commit 91b8f6c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

py/torch_tensorrt/dynamo/_refit.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def _refit_single_trt_engine_with_gm(
156156
# Get the refitting mapping
157157
trt_wt_location = (
158158
trt.TensorLocation.DEVICE
159-
if torch_device == "cuda"
159+
if torch_device.type == "cuda"
160160
else trt.TensorLocation.HOST
161161
)
162162
mapping = construct_refit_mapping_from_weight_name_map(

0 commit comments

Comments
 (0)