[Graph] Minor bug fixes (#358)
Fix a bug in the batch_matmul schedule. Fix two bugs in the channel-last pass.
hjjq committed Sep 19, 2023
1 parent f206735 commit 9eee253
Showing 3 changed files with 22 additions and 10 deletions.
6 changes: 3 additions & 3 deletions python/hidet/graph/ops/matmul/batch_matmul.py
@@ -474,7 +474,7 @@ def resolve_mma_type(a_dtype: DataType, b_dtype: DataType, c_dtype: DataType):
 
         @hidet.script
         def copy_a_g2r(
-            a: a_dtype[bs, m_size, k_size],
+            a: input_a_dtype[bs, m_size, k_size],
             regs_a_ldg: TensorType(dtype=a_dtype, layout=regs_a_ldg_layout),
             offset_m: i32,
             offset_k: i32,
@@ -511,7 +511,7 @@ def copy_a_s2r(
 
         @hidet.script
         def copy_b_g2r(
-            b: b_dtype[bs, k_size, n_size],
+            b: input_b_dtype[bs, k_size, n_size],
             regs_b_ldg: TensorType(dtype=b_dtype, layout=regs_b_ldg_layout),
             offset_k: i32,
             offset_n: i32,
@@ -549,7 +549,7 @@ def copy_b_s2r(
         @hidet.script
         def copy_c_r2g(
             regs_c: TensorType(dtype=c_dtype, layout=regs_c_layout),
-            c: c_dtype[bs, m_size, n_size],
+            c: input_c_dtype[bs, m_size, n_size],
             offset_m: i32,
             offset_n: i32,
             smem: void_p,
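Note: the three hunks above apply the same fix. A plausible reading (an inference, not stated in the commit): resolve_mma_type may choose an operand dtype for the mma computation that differs from the dtype the tensors actually have in global memory, so annotating the global-memory parameters with a_dtype/b_dtype/c_dtype instead of the input dtypes declares those buffers with the wrong element type. A minimal numpy analogy of that bug class, with all names hypothetical:

    import numpy as np

    # Bytes stored as f16 but read back as if they were f32.
    buf = np.arange(8, dtype=np.float16).tobytes()

    as_stored = np.frombuffer(buf, dtype=np.float16)    # 8 elements, correct
    as_declared = np.frombuffer(buf, dtype=np.float32)  # 4 elements, garbled

    print(as_stored)    # [0. 1. 2. 3. 4. 5. 6. 7.]
    print(as_declared)  # same bytes, wrong interpretation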
25 changes: 18 additions & 7 deletions python/hidet/graph/transforms/conv_channel_last.py
@@ -43,7 +43,6 @@ def reforward(self, tensor_map: Dict[Tensor, Tuple[Tensor, Optional[List[int]]]]
         node = self.op
         new_inputs: List[Tensor] = []
         update_attributes: Dict[str, Any] = {}
-        new_perms: List[List[int]] = []
         for x in node.inputs:
             if x in tensor_map:
                 current_x, current_perm = tensor_map[x]
@@ -55,13 +54,20 @@ def reforward(self, tensor_map: Dict[Tensor, Tuple[Tensor, Optional[List[int]]]]
                     new_perm = current_perm
                 else:
                     # Input is not channel last, convert it to channel last
-                    new_perm = [0, 2, 3, 1]
+                    x_rank = len(x.shape)
+                    if x_rank == 4:
+                        new_perm = [0, 2, 3, 1]
+                    elif x_rank == 3:
+                        new_perm = [1, 2, 0]
+                    elif x_rank == 2:
+                        new_perm = [1, 0]
+                    elif x_rank == 1:
+                        new_perm = [0]
+                    else:
+                        raise ValueError('Channel-last pass encountered an input tensor of rank > 4 in a scoped operator.')
                     new_x = transpose(current_x, new_perm)
                     tensor_map[x] = (new_x, new_perm)
                 new_inputs.append(new_x)
-            new_perms.append(new_perm)
-        new_perm = new_perms[0]
-        assert all(p == new_perm for p in new_perms)
         if 'axis' in node.attrs and isinstance(node.attrs['axis'], int):
             update_attributes['axis'] = new_perm.index(node.attrs['axis'])
         outputs = node.reforward(new_inputs, update_attributes)
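Note: the axis remapping kept at the end of this hunk is worth a worked example. After transposing an input with new_perm, the dimension that sat at index i in the old layout sits at index new_perm.index(i) in the new one, which is exactly what the attribute update computes. A minimal sketch:

    # After a transpose with perm, old axis i moves to position perm.index(i).
    perm = [0, 2, 3, 1]    # NCHW -> NHWC
    old_axis = 1           # channel axis in NCHW
    new_axis = perm.index(old_axis)
    assert new_axis == 3   # channel axis in NHWC

The rank-3/2/1 permutations added above ([1, 2, 0], [1, 0], [0]) appear to be the lower-rank analogues of the same channels-to-last move.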
@@ -215,10 +221,15 @@ def process_graph(self, graph: FlowGraph) -> FlowGraph:
         # TODO: Deal with FP16/FP32
         from hidet.graph.ops.conv2d import Conv2dOp
         from hidet.graph.ops.transform import transpose
+        from hidet.ir.dtypes import float16
 
         nodes: List[Operator] = graph.nodes
-        # Start from all conv2d operators as seeds
-        seeds = [node for node in nodes if isinstance(node, Conv2dOp)]
+        # Start from all fp16 conv2d operators as seeds
+        seeds: List[Operator] = []
+        for node in nodes:
+            if isinstance(node, Conv2dOp):
+                if node.inputs[0].dtype == float16 and node.inputs[1].dtype == float16:
+                    seeds.append(node)
 
         # Only use this pass if there is convolution in the graph
         if len(seeds) == 0:
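Note: the added loop reads equivalently as a comprehension; a sketch under the same names (nodes, Conv2dOp, float16 from the hunk above):

    # Seed only on conv2d ops whose data and weight are both fp16.
    seeds = [
        node
        for node in nodes
        if isinstance(node, Conv2dOp)
        and node.inputs[0].dtype == float16
        and node.inputs[1].dtype == float16
    ]

The explicit loop in the commit is arguably easier to extend when more dtype cases are handled later (see the TODO at the top of the hunk).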
1 change: 1 addition & 0 deletions scripts/regression/op_performance.py
@@ -60,6 +60,7 @@ def reduce_regression() -> ResultGroup:
 
 
 def op_performance_regression(report_file):
+    hidet.option.cache_dir(hidet.option.get_cache_dir() + '/regression')
     result_groups = []
     result_groups.append(matmul_regression())
     result_groups.append(fmha_regression())
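Note: hidet.option.cache_dir and hidet.option.get_cache_dir are the calls used in the added line; presumably the point is to keep regression artifacts out of the default compilation cache. A minimal standalone usage sketch:

    import hidet

    # Redirect hidet's compilation cache into a dedicated subdirectory.
    hidet.option.cache_dir(hidet.option.get_cache_dir() + '/regression')
    print(hidet.option.get_cache_dir())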
