Skip to content

Commit

Permalink
Fixing slice scatter and select scatter decomposition (#3093)
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose authored Aug 29, 2024
1 parent ffa4f64 commit 3a3d62a
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def slice_scatter_decomposition(
step: Optional[int] = None,
) -> torch.Tensor:
dim_size = input_tensor.shape[dim]
device_input_tensor = input_tensor.device
start = get_positive_dim(start, input_tensor.shape[dim])
if end is None:
end = dim_size
Expand All @@ -216,7 +217,8 @@ def slice_scatter_decomposition(
index_tensor_shape.append(src_each_dim)
for index in range(start, end, step):
cat_tensors.append(index * torch.ones(index_tensor_shape, dtype=torch.int64))
index_tensor = torch.stack(cat_tensors, dim).to(input_tensor.device)
index_tensor = torch.stack(cat_tensors, dim)
index_tensor = index_tensor.to(device_input_tensor)
index_tensor_64 = index_tensor.to(torch.int64)
output_tensor = torch.scatter(input_tensor, dim, index_tensor_64, src_tensor)
return output_tensor
Expand Down

0 comments on commit 3a3d62a

Please sign in to comment.