Skip to content

Commit

Permalink
changing the device setting in conversion.py
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Sep 9, 2024
1 parent 485adf9 commit 9b7f846
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
from torch_tensorrt.dynamo._defaults import default_device
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
TRTInterpreter,
TRTInterpreterResult,
)
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.utils import get_model_device, get_torch_inputs
from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device

import tensorrt as trt

Expand All @@ -41,7 +42,7 @@ def infer_module_output_dtypes(
with unset_fake_temporarily():
# Get the device on which the model exists
# For large models, this can be done on CPU to save GPU memory allocation for TRT.
device = get_model_device(module)
device = to_torch_device(default_device())
torch_inputs = get_torch_inputs(inputs, device)
if kwarg_inputs is None:
kwarg_inputs = {}
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def reduce_operation_with_scatter(
print("Invalid Operation for Reduce op!!")

operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor)
device = to_torch_device(default_device())
device = to_torch_device(scatter_tensor.device)
operation_lhs = operation_lhs.to(device)
operation_rhs = operation_rhs.to(device)
return self.func(operation_lhs, operation_rhs)
Expand Down

0 comments on commit 9b7f846

Please sign in to comment.