Skip to content

Commit

Permalink
changing the device setting in conversion.py
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Sep 10, 2024
1 parent 485adf9 commit f124297
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def reduce_operation_with_scatter(
print("Invalid Operation for Reduce op!!")

operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor)
device = to_torch_device(default_device())
device = to_torch_device(scatter_tensor.device)
operation_lhs = operation_lhs.to(device)
operation_rhs = operation_rhs.to(device)
return self.func(operation_lhs, operation_rhs)
Expand Down
10 changes: 7 additions & 3 deletions py/torch_tensorrt/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch_tensorrt._enums import dtype
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import _defaults
from torch_tensorrt.dynamo._defaults import default_device
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
from torch_tensorrt.dynamo._settings import CompilationSettings

Expand Down Expand Up @@ -186,11 +187,14 @@ def get_model_device(module: torch.fx.GraphModule) -> torch.device:
device = None
for parameter in list(module.parameters()):
if isinstance(parameter, (torch.nn.parameter.Parameter, torch.Tensor)):
device = parameter.device
break
return parameter.device

for buffer in list(module.buffers()):
if isinstance(buffer, (torch.Tensor)):
return buffer.device

if device is None:
device = torch.device("cpu")
device = to_torch_device(default_device())
logger.warning(
"Could not detect the device on which the model exists. Assuming the model is on CPU"
)
Expand Down

0 comments on commit f124297

Please sign in to comment.