diff --git a/python/hidet/graph/transforms/conv_channel_last.py b/python/hidet/graph/transforms/conv_channel_last.py index d7b4a5465..ce7c6f5df 100644 --- a/python/hidet/graph/transforms/conv_channel_last.py +++ b/python/hidet/graph/transforms/conv_channel_last.py @@ -43,6 +43,9 @@ def reforward(self, tensor_map: Dict[Tensor, Tuple[Tensor, Optional[List[int]]]] node = self.op new_inputs: List[Tensor] = [] update_attributes: Dict[str, Any] = {} + + rank_to_perm = {4: [0, 2, 3, 1], 3: [1, 2, 0], 2: [1, 0], 1: [0]} + for x in node.inputs: if x in tensor_map: current_x, current_perm = tensor_map[x] @@ -55,16 +58,10 @@ def reforward(self, tensor_map: Dict[Tensor, Tuple[Tensor, Optional[List[int]]]] else: # Input is not channel last, convert it to channel last x_rank = len(x.shape) - if x_rank == 4: - new_perm = [0, 2, 3, 1] - elif x_rank == 3: - new_perm = [1, 2, 0] - elif x_rank == 2: - new_perm = [1, 0] - elif x_rank == 1: - new_perm = [0] - else: + new_perm = rank_to_perm.get(x_rank, None) + if new_perm is None: raise ValueError('Channel Last Pass met input tensor of scoped operator with shape > 4.') + new_x = transpose(current_x, new_perm) tensor_map[x] = (new_x, new_perm) new_inputs.append(new_x) @@ -72,7 +69,7 @@ def reforward(self, tensor_map: Dict[Tensor, Tuple[Tensor, Optional[List[int]]]] update_attributes['axis'] = new_perm.index(node.attrs['axis']) outputs = node.reforward(new_inputs, update_attributes) for idx, y in enumerate(node.outputs): - tensor_map[y] = (outputs[idx], new_perm) + tensor_map[y] = (outputs[idx], rank_to_perm[len(outputs[idx].shape)]) @staticmethod def initialize_scoped_ops():