diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 9bfc2aae1d6a5..b9bdc1c7101ab 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -680,10 +680,11 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) { return getInput1(); } - if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) { - if (llvm::isa(resultETy)) { - APInt l = lhsAttr.getSplatValue(); - APInt r = rhsAttr.getSplatValue(); + if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() && + llvm::isa(resultETy)) { + APInt l = lhsAttr.getSplatValue(); + APInt r = rhsAttr.getSplatValue(); + if (!r.isZero()) { APInt result = l.sdiv(r); return DenseElementsAttr::get(resultTy, result); } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 0e177a076ee7a..c08517b33b0f9 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -1012,3 +1012,14 @@ func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> { %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32> return %2 : tensor<3x600x1200xi32> } + +// ----- + +// CHECK-LABEL: @do_not_fold_int_div_division_by_0 +func.func @do_not_fold_int_div_division_by_0() -> tensor<1x24x2xi32> { + // CHECK: tosa.int_div + %1 = "tosa.const"() <{value = dense<0> : tensor<1x24x2xi32>}> : () -> tensor<1x24x2xi32> + %4 = "tosa.const"() <{value = dense<20> : tensor<1x24x2xi32>}> : () -> tensor<1x24x2xi32> + %16 = tosa.int_div %4, %1 : (tensor<1x24x2xi32>, tensor<1x24x2xi32>) -> tensor<1x24x2xi32> + return %16 : tensor<1x24x2xi32> +}