From c4b0ab45061c59a85f3f3cc187489e35b606506f Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Mon, 24 Feb 2025 19:13:30 +0000 Subject: [PATCH] [mlir][tosa] Fix crash on attempt to fold int_div by zero Fixes #118268. Change-Id: Ib3eeed6e796a573b30f04a992f4213862f0e0eb6 Signed-off-by: Luke Hutton --- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 9 +++++---- mlir/test/Dialect/Tosa/canonicalize.mlir | 11 +++++++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 9bfc2aae1d6a5..b9bdc1c7101ab 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -680,10 +680,11 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) { return getInput1(); } - if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) { - if (llvm::isa(resultETy)) { - APInt l = lhsAttr.getSplatValue(); - APInt r = rhsAttr.getSplatValue(); + if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat() && + llvm::isa(resultETy)) { + APInt l = lhsAttr.getSplatValue(); + APInt r = rhsAttr.getSplatValue(); + if (!r.isZero()) { APInt result = l.sdiv(r); return DenseElementsAttr::get(resultTy, result); } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 0e177a076ee7a..c08517b33b0f9 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -1012,3 +1012,14 @@ func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> { %2 = "tosa.reciprocal"(%1): (tensor<3x600x1200xi32>) -> tensor<3x600x1200xi32> return %2 : tensor<3x600x1200xi32> } + +// ----- + +// CHECK-LABEL: @do_not_fold_int_div_division_by_0 +func.func @do_not_fold_int_div_division_by_0() -> tensor<1x24x2xi32> { + // CHECK: tosa.int_div + %1 = "tosa.const"() <{value = dense<0> : tensor<1x24x2xi32>}> : () -> tensor<1x24x2xi32> + %4 = "tosa.const"() <{value = dense<20> : tensor<1x24x2xi32>}> : () -> tensor<1x24x2xi32> + %16 = tosa.int_div %4, %1 : (tensor<1x24x2xi32>, tensor<1x24x2xi32>) -> tensor<1x24x2xi32> + return %16 : tensor<1x24x2xi32> +}