From c6f0e66cdd6fa6b3798bb494c2e734600be894c7 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Fri, 30 Aug 2024 14:04:58 -0400 Subject: [PATCH] [BUG] Fixing an error triggered from the `conv_channel_last_pass` while compiling the model `sam` (#444) Closes #325 The error in the linked issue was caused by [this code segment](https://github.com/CentML/hidet/blob/bfbb4db6d7792ed3de3be4e9702e597b8fbbe373/python/hidet/graph/transforms/conv_channel_last.py#L46-L75) in `graph/transforms/conv_channel_last.py`. By the logic flow of this code segment, if the operator `node` has two inputs, the first one with rank 4 and the second with rank 3 (an example case in the model: an `AddOp` where the first input has shape `[1, 256, 64, 64]` and the second `[256, 1, 1]`), then by the time the code reaches line 75, the variable `new_perm` would have the value `[1, 2, 0]`, and this value will be recorded as the permutation scheme used to get the new output, which is incorrect, as the appropriate value should be `[0, 2, 3, 1]` here. 
--- .../hidet/graph/transforms/conv_channel_last.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/python/hidet/graph/transforms/conv_channel_last.py b/python/hidet/graph/transforms/conv_channel_last.py index d7b4a5465..ce7c6f5df 100644 --- a/python/hidet/graph/transforms/conv_channel_last.py +++ b/python/hidet/graph/transforms/conv_channel_last.py @@ -43,6 +43,9 @@ def reforward(self, tensor_map: Dict[Tensor, Tuple[Tensor, Optional[List[int]]]] node = self.op new_inputs: List[Tensor] = [] update_attributes: Dict[str, Any] = {} + + rank_to_perm = {4: [0, 2, 3, 1], 3: [1, 2, 0], 2: [1, 0], 1: [0]} + for x in node.inputs: if x in tensor_map: current_x, current_perm = tensor_map[x] @@ -55,16 +58,10 @@ def reforward(self, tensor_map: Dict[Tensor, Tuple[Tensor, Optional[List[int]]]] else: # Input is not channel last, convert it to channel last x_rank = len(x.shape) - if x_rank == 4: - new_perm = [0, 2, 3, 1] - elif x_rank == 3: - new_perm = [1, 2, 0] - elif x_rank == 2: - new_perm = [1, 0] - elif x_rank == 1: - new_perm = [0] - else: + new_perm = rank_to_perm.get(x_rank, None) + if new_perm is None: raise ValueError('Channel Last Pass met input tensor of scoped operator with shape > 4.') + new_x = transpose(current_x, new_perm) tensor_map[x] = (new_x, new_perm) new_inputs.append(new_x) @@ -72,7 +69,7 @@ def reforward(self, tensor_map: Dict[Tensor, Tuple[Tensor, Optional[List[int]]]] update_attributes['axis'] = new_perm.index(node.attrs['axis']) outputs = node.reforward(new_inputs, update_attributes) for idx, y in enumerate(node.outputs): - tensor_map[y] = (outputs[idx], new_perm) + tensor_map[y] = (outputs[idx], rank_to_perm[len(outputs[idx].shape)]) @staticmethod def initialize_scoped_ops():