Fix scalar axes in lowering reduction operations (#2925)
Signed-off-by: Tung D. Le <[email protected]>
tungld authored Sep 3, 2024
1 parent f7d8db5 commit 1900ea7
Showing 2 changed files with 38 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/Conversion/ONNXToKrnl/Math/Reduction.cpp
@@ -534,10 +534,13 @@ struct ONNXReductionOpLowering : public OpConversionPattern<ONNXReductionOp> {
       // Default value of having no axes.
       hasNoAxes = true;
     } else {
-      // Check it has a rank of 1.
-      assert(
-          create.krnlIE.getShapedTypeRank(axesVal) == 1 && "expect rank 1");
-      axisShape0 = create.krnlIE.getShapeAsDim(axesVal, 0);
+      // Check it has a rank of 0 or 1.
+      int64_t axisRank = create.krnlIE.getShapedTypeRank(axesVal);
+      assert((axisRank == 0 || axisRank == 1) && "expect rank 0 or 1");
+      if (axisRank == 0)
+        axisShape0 = LitIE(1);
+      else
+        axisShape0 = create.krnlIE.getShapeAsDim(axesVal, 0);
 
       if (!axisShape0.isLiteral())
         // Don't even know the shape of the axis... it is dynamic.
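The gist of the fix: the optional axes input of the reduction ops is nominally a 1-D tensor, but models in the wild also supply it as a rank-0 (scalar) tensor, and the lowering previously asserted rank 1 and aborted on that input. With this change, a scalar axes value is treated as a single reduction axis by substituting the literal extent LitIE(1) for the axes shape. A minimal sketch of the two input forms in ONNX dialect IR, using the same syntax as the test below (function and value names are illustrative, not from the commit):

func.func @axes_forms(%arg0: tensor<?x64x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
  // Rank-0 (scalar) axes: previously tripped the rank-1 assertion, now lowered.
  %axes_scalar = onnx.Constant dense<-2> : tensor<i64>
  %0 = "onnx.ReduceSum"(%arg0, %axes_scalar) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<?x64x?xf32>, tensor<i64>) -> tensor<?x?xf32>
  // Rank-1 axes: the form the lowering already handled.
  %axes_vector = onnx.Constant dense<[-2]> : tensor<1xi64>
  %1 = "onnx.ReduceSum"(%arg0, %axes_vector) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<?x64x?xf32>, tensor<1xi64>) -> tensor<?x?xf32>
  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
}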
@@ -8,6 +8,37 @@

 // -----
 
+func.func @test_reduce_scalar_axes(%arg0: tensor<?x64x?xf32>) -> tensor<?x?xf32> {
+  %axes = onnx.Constant dense<-2> : tensor<i64>
+  %0 = "onnx.ReduceSum"(%arg0, %axes) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<?x64x?xf32>, tensor<i64>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+
+// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0, s1] -> (s1)>
+// CHECK-LABEL: func.func @test_reduce_scalar_axes
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x64x?xf32>) -> memref<?x?xf32> {
+// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?x64x?xf32>
+// CHECK-DAG: [[VAR_dim_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref<?x64x?xf32>
+// CHECK: [[RES_:%.+]] = memref.alloc([[VAR_dim_]], [[VAR_dim_]]_0) {{.*}}: memref<?x?xf32>
+// CHECK: krnl.memset [[RES_]], [[CST_0_dot_000000_]] : memref<?x?xf32>
+// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3
+// CHECK-DAG: [[VAR_dim_1_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?x64x?xf32>
+// CHECK-DAG: [[VAR_dim_2_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref<?x64x?xf32>
+// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_1_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 64, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_1_]], [[VAR_dim_2_]]{{.}}){
+// CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
+// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]#2] : memref<?x64x?xf32>
+// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#2] : memref<?x?xf32>
+// CHECK: [[VAR_4_:%.+]] = arith.addf [[LOAD_RES_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32
+// CHECK: krnl.store [[VAR_4_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#2] : memref<?x?xf32>
+// CHECK: }
+// CHECK: return [[RES_]] : memref<?x?xf32>
+// CHECK: }
+}
+
+// -----
 
 func.func private @test_reducemax_v13(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> {
   %0 ="onnx.ReduceMaxV13"(%arg0) {axes=[1], keepdims = 0 : si64} : (tensor<3x2x2xf32>)-> tensor<*xf32>
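Reading the CHECK lines: axis -2 of a rank-3 tensor is axis 1, the static size-64 dimension, and with keepdims = 0 that dimension is dropped, so the result is indexed by input dims 0 and 2. That is why the loads and stores on [[RES_]] use induction variables #0 and #2 only, and why [[MAP_0_]] = affine_map<()[s0, s1] -> (s1)> selects its second symbol (input dim 2) as the upper bound of the innermost loop. For contrast, a hedged sketch of the keepdims = 1 variant, which is not part of this commit (the function name is illustrative); it keeps the reduced axis as a unit dimension:

func.func @reduce_scalar_axes_keepdims(%arg0: tensor<?x64x?xf32>) -> tensor<?x1x?xf32> {
  // Same scalar axes constant; keepdims = 1 retains axis 1 with extent 1.
  %axes = onnx.Constant dense<-2> : tensor<i64>
  %0 = "onnx.ReduceSum"(%arg0, %axes) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<?x64x?xf32>, tensor<i64>) -> tensor<?x1x?xf32>
  return %0 : tensor<?x1x?xf32>
}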

