From 081c84d03e61f8fa14275262e4e6907849fd9c3c Mon Sep 17 00:00:00 2001
From: Bolin Sun
Date: Tue, 27 Aug 2024 10:09:45 -0400
Subject: [PATCH] [BUG] Fixing a bug triggered while compiling in-place
 operator `torch.Tensor.scatter_add_` (#429)

Closes #424

The additional bug described in the comments of the linked issue
([here](https://github.com/CentML/hidet/issues/424#issuecomment-2297325930))
is caused by accessing a PyTorch tensor in
[this line](https://github.com/CentML/hidet/blob/18f68ae34d8a08ca1b38ee00ac2ca7f15e599d0b/python/hidet/runtime/compiled_task.py#L161)
where a Hidet tensor was expected.
---
 python/hidet/graph/ops/scatter.py             |  4 +-
 python/hidet/runtime/compiled_task.py         | 44 +++++++++++++++----
 .../torch/test_torch_interoperability.py      | 21 +++++++++
 3 files changed, 58 insertions(+), 11 deletions(-)

diff --git a/python/hidet/graph/ops/scatter.py b/python/hidet/graph/ops/scatter.py
index 2f237f363..67a46b5c2 100644
--- a/python/hidet/graph/ops/scatter.py
+++ b/python/hidet/graph/ops/scatter.py
@@ -54,7 +54,7 @@ def __init__(self, input: Tensor, index: Tensor, src: Tensor, dim: int, fname: s
         super().__init__(
             name=f'scatter_{fname}_dim_{dim}{"_inplace" if inplace else ""}',
             inputs={'input': input, 'index': index, 'src': src},
-            attributes={'dim': dim, 'fname': fname},
+            attributes={'dim': dim, 'fname': fname, 'inplace': inplace},
             share_map=share_map,
         )
 
@@ -130,7 +130,7 @@ def scatter_internal_func(
 
                 linear_idx = work_per_block * blockIdx.x + threadIdx.x
 
-                assert input == out
+                _ = input
 
                 work_i = 0
                 while work_i < work_per_thread:
diff --git a/python/hidet/runtime/compiled_task.py b/python/hidet/runtime/compiled_task.py
index f62a128f6..500bcf40f 100644
--- a/python/hidet/runtime/compiled_task.py
+++ b/python/hidet/runtime/compiled_task.py
@@ -154,12 +154,20 @@ def create_outputs(self, inputs):
             if idx not in self.meta_data.share_map:
                 outputs.append(hidet.empty(shape, sig.dtype, sig.device))
             else:
-                input_tensor = hidet.Tensor(
-                    shape=shape,
-                    dtype=sig.dtype,
-                    device=sig.device,
-                    storage=inputs[self.meta_data.share_map[idx]].storage,
-                )
+                shared_tensor = inputs[self.meta_data.share_map[idx]]
+                if not isinstance(shared_tensor, hidet.Tensor):
+                    import torch
+
+                    assert isinstance(shared_tensor, torch.Tensor), "Unknown tensor type"
+                    tensor_dtype = getattr(torch, sig.dtype)
+
+                    # We need to turn the tensor into a view with the graph output's shape & dtype.
+                    input_tensor = shared_tensor.view(*shape).view(tensor_dtype)
+                else:
+                    input_tensor = hidet.Tensor(
+                        shape=shape, dtype=sig.dtype, device=sig.device, storage=shared_tensor.storage
+                    )
+
                 outputs.append(input_tensor)
 
         return outputs
@@ -254,9 +262,27 @@ def profile(self, *args, warmup=1, number=2, repeat=10):
         latency: List[float]
             The measured latency in milliseconds. The length of the list is equal to `repeat`.
         """
-        num_outputs = len(self.meta_data.outputs)
-        inputs = args[:num_outputs]
-        outputs = args[num_outputs:]
+
+        num_inputs = len(self.meta_data.inputs)
+        inputs = args[:num_inputs]
+        outputs = args[num_inputs:]
+
+        # For operators like scatter_add_, running the kernel multiple times on the same input & output tensors
+        # would corrupt them, since the shared buffers get updated repeatedly.
+        # To avoid this, clone each output tensor that shares memory with an input tensor.
+        if len(self.meta_data.share_map) > 0:
+            from hidet import Tensor
+
+            outputs = list(outputs)
+            inputs = list(inputs)
+            for output_idx in self.meta_data.share_map:
+                original_output = outputs[output_idx]
+                if isinstance(original_output, Tensor):
+                    outputs[output_idx] = original_output.copy()
+                else:
+                    outputs[output_idx] = original_output.clone()
+            args = inputs + outputs
+
         candidate = self.candidates[self.pick_best_candidate(inputs, outputs)]
         return candidate.profile(*args, warmup=warmup, number=number, repeat=repeat)
 
diff --git a/tests/frontends/torch/test_torch_interoperability.py b/tests/frontends/torch/test_torch_interoperability.py
index 506a52098..405348c4a 100644
--- a/tests/frontends/torch/test_torch_interoperability.py
+++ b/tests/frontends/torch/test_torch_interoperability.py
@@ -371,5 +371,26 @@ def test_torch_einsum(equation, operand_shapes):
     )
 
 
+def test_scatter_add_compile():
+    # This operator is already tested in tests/operators/test_inplace_operator.py.
+    # This adds one more test to ensure the bug mentioned in #429 is gone.
+    input_tensor = torch.zeros((6, 6), dtype=torch.float32, device='cuda')
+
+    index_tensor = torch.tensor([[4, 1, 4, 4, 2], [0, 1, 4, 5, 2], [5, 1, 3, 4, 2]]).to(dtype=torch.int64).cuda()
+
+    input_tensor_clone = input_tensor.clone()
+    src = torch.tensor([[0, 5, 3, 6, 5], [9, 6, 8, 8, 4], [7, 4, 5, 4, 7]]).to(dtype=torch.float32).cuda()
+
+    dim = 1
+
+    check_module(
+        FunctionalModule(op=lambda x, y, z: x.scatter_add_(dim, y, z)),
+        args=[input_tensor_clone, index_tensor, src],
+        atol=0,
+        rtol=0,
+        dynamic=False,
+    )
+
+
 if __name__ == '__main__':
     pytest.main([__file__])