Skip to content

Commit

Permalink
[BUG] Fixing a bug triggered while compiling in-place operator `torch…
Browse files Browse the repository at this point in the history
….Tensor.scatter_add_` (#429)

Closes #424 

The additional bug described in the comments of the linked
issue ([here](CentML/hidet#424 (comment)))
was caused by accessing a PyTorch tensor in [this
line](https://github.com/CentML/hidet/blob/18f68ae34d8a08ca1b38ee00ac2ca7f15e599d0b/python/hidet/runtime/compiled_task.py#L161)
where a Hidet tensor should have been accessed instead.
  • Loading branch information
BolinSNLHM authored and vadiklyutiy committed Dec 20, 2024
1 parent f421a43 commit 4f142c4
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 11 deletions.
4 changes: 2 additions & 2 deletions python/hidet/graph/ops/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, input: Tensor, index: Tensor, src: Tensor, dim: int, fname: s
super().__init__(
name=f'scatter_{fname}_dim_{dim}{"_inplace" if inplace else ""}',
inputs={'input': input, 'index': index, 'src': src},
attributes={'dim': dim, 'fname': fname},
attributes={'dim': dim, 'fname': fname, 'inplace': inplace},
share_map=share_map,
)

Expand Down Expand Up @@ -130,7 +130,7 @@ def scatter_internal_func(

linear_idx = work_per_block * blockIdx.x + threadIdx.x

assert input == out
_ = input

work_i = 0
while work_i < work_per_thread:
Expand Down
44 changes: 35 additions & 9 deletions python/hidet/runtime/compiled_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,20 @@ def create_outputs(self, inputs):
if idx not in self.meta_data.share_map:
outputs.append(hidet.empty(shape, sig.dtype, sig.device))
else:
input_tensor = hidet.Tensor(
shape=shape,
dtype=sig.dtype,
device=sig.device,
storage=inputs[self.meta_data.share_map[idx]].storage,
)
shared_tensor = inputs[self.meta_data.share_map[idx]]
if not isinstance(shared_tensor, hidet.Tensor):
import torch

assert isinstance(shared_tensor, torch.Tensor), "Unknown tensor type"
tensor_dtype = getattr(torch, sig.dtype)

# we need to turn the tensor into a view with the graph output's shape & dtype
input_tensor = shared_tensor.view(*shape).view(tensor_dtype)
else:
input_tensor = hidet.Tensor(
shape=shape, dtype=sig.dtype, device=sig.device, storage=shared_tensor.storage
)

outputs.append(input_tensor)
return outputs

Expand Down Expand Up @@ -254,9 +262,27 @@ def profile(self, *args, warmup=1, number=2, repeat=10):
latency: List[float]
The measured latency in milliseconds. The length of the list is equal to `repeat`.
"""
num_outputs = len(self.meta_data.outputs)
inputs = args[:num_outputs]
outputs = args[num_outputs:]

num_inputs = len(self.meta_data.inputs)
inputs = args[:num_inputs]
outputs = args[num_inputs:]

# For in-place operators like scatter_add_, running the task multiple times on the same input & output tensors
# applies the update on every run, so both the input and output tensors end up with wrong values.
# To avoid this, clone every output tensor that shares memory with an input tensor.
if len(self.meta_data.share_map) > 0:
from hidet import Tensor

outputs = list(outputs)
inputs = list(inputs)
for output_idx in self.meta_data.share_map:
original_output = outputs[output_idx]
if isinstance(original_output, Tensor):
outputs[output_idx] = original_output.copy()
else:
outputs[output_idx] = original_output.clone()
args = inputs + outputs

candidate = self.candidates[self.pick_best_candidate(inputs, outputs)]
return candidate.profile(*args, warmup=warmup, number=number, repeat=repeat)

Expand Down
21 changes: 21 additions & 0 deletions tests/frontends/torch/test_torch_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,5 +371,26 @@ def test_torch_einsum(equation, operand_shapes):
)


def test_scatter_add_compile():
    """Run an in-place ``scatter_add_`` through ``check_module`` with zero tolerances."""
    # tests/operators/test_inplace_operator.py already covers this operator; this extra
    # case only makes sure the bug mentioned in #429 does not come back.
    scatter_dim = 1
    index_rows = [[4, 1, 4, 4, 2], [0, 1, 4, 5, 2], [5, 1, 3, 4, 2]]
    src_rows = [[0, 5, 3, 6, 5], [9, 6, 8, 8, 4], [7, 4, 5, 4, 7]]

    zeros = torch.zeros((6, 6), dtype=torch.float32, device='cuda')
    index = torch.tensor(index_rows).to(dtype=torch.int64).cuda()
    source = torch.tensor(src_rows).to(dtype=torch.float32).cuda()

    # scatter_add_ mutates its receiver, so the module is handed a clone of the zero tensor.
    module_args = [zeros.clone(), index, source]
    module = FunctionalModule(op=lambda x, y, z: x.scatter_add_(scatter_dim, y, z))

    check_module(module, args=module_args, atol=0, rtol=0, dynamic=False)


# Running this file directly hands it to pytest.
if __name__ == '__main__':
    pytest.main([__file__])

0 comments on commit 4f142c4

Please sign in to comment.