diff --git a/src/Dialect/ONNX/ONNXOps/Math/Gemm.cpp b/src/Dialect/ONNX/ONNXOps/Math/Gemm.cpp
index 2533561cfe..96b59154ab 100644
--- a/src/Dialect/ONNX/ONNXOps/Math/Gemm.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Math/Gemm.cpp
@@ -79,6 +79,18 @@ LogicalResult ONNXGemmOpShapeHelper::computeShape() {
   if (aDims[1].isLiteral() && bDims[0].isLiteral() &&
       aDims[1].getLiteral() != bDims[0].getLiteral()) {
     return op->emitError("Gemm 2nd dim of A is different than 1st dim of B");
+  } else if (aDims[1].isLiteral()) {
+    // Save aDims[1] into bDims[0], in case bDims[0] was runtime.
+    bDims[0] = aDims[1];
+    // Update the shared dim of B (dim 0, or dim 1 if transB) to aDims[1].
+    this->updateInputDimAt(
+        B, aDims[1].getLiteral(), gemmOp.getTransB() == 0 ? 0 : 1);
+  } else if (bDims[0].isLiteral()) {
+    // Save bDims[0] into aDims[1], in case aDims[1] was runtime.
+    aDims[1] = bDims[0];
+    // Update the shared dim of A (dim 1, or dim 0 if transA) to bDims[0].
+    this->updateInputDimAt(
+        A, bDims[0].getLiteral(), gemmOp.getTransA() == 0 ? 1 : 0);
   }
   if (hasBias) {
     // Check first dim.
diff --git a/src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp b/src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp
index d9fe81fce1..cb781e3695 100644
--- a/src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp
@@ -140,11 +140,15 @@ LogicalResult ONNXGenericMatMulOpShapeHelper<OP_TYPE>::computeShape() {
     if (aDims[aK].getLiteral() != bDims[bK].getLiteral())
       return this->op->emitError("reduction dimension must be the same");
   } else if (aDims[aK].isLiteral()) {
-    // Save aK dims into bK dims, in case bK dims was runtime
+    // Save aK dims into bK dims, in case bK dims was runtime.
     bDims[bK] = aDims[aK];
+    // Update the bK dim of B to the literal aK.
+    this->updateInputDimAt(B, aDims[aK].getLiteral(), -2);
   } else if (bDims[bK].isLiteral()) {
-    // Save bK dims into aK dims, in case aK dims was runtime
+    // Save bK dims into aK dims, in case aK dims was runtime.
     aDims[aK] = bDims[bK];
+    // Update the aK dim of A to the literal bK.
+    this->updateInputDimAt(A, bDims[bK].getLiteral(), -1);
   }
   // Add lower N x M dimensions if they are not padded dimensions.
   if (!aPadDims[aN])
diff --git a/src/Dialect/ONNX/ONNXOps/RNN/RNN.cpp b/src/Dialect/ONNX/ONNXOps/RNN/RNN.cpp
index f79c18f8f3..3defebf7e5 100644
--- a/src/Dialect/ONNX/ONNXOps/RNN/RNN.cpp
+++ b/src/Dialect/ONNX/ONNXOps/RNN/RNN.cpp
@@ -55,6 +55,11 @@ LogicalResult ONNXGenericRNNShapeHelper<OP_TYPE>::customComputeShape(
   IndexExpr seqLength = batchwiseLayout ? xDims[1] : xDims[0];
   IndexExpr batchSize = batchwiseLayout ? xDims[0] : xDims[1];
 
+  // If the input_size dim is dynamic in the input X and static in the
+  // weight W, update the input_size dim of X to be static.
+  if (!xDims[2].isLiteral() && wDims[2].isLiteral())
+    this->updateInputDimAt(X, wDims[2].getLiteral(), 2);
+
   // Get hidden size from hidden_size attribute.
   IndexExpr hiddenSize;
   if (operandAdaptor.getHiddenSize().has_value()) {
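The updateInputDimAt helper introduced in ShapeHelper.cpp below accepts a negative axis, which the MatMul call sites above rely on (-1 for the reduction dim of A, -2 for that of B). A minimal standalone sketch of that index arithmetic, with a hypothetical refineDimAt and -1 standing in for a dynamic dim (not onnx-mlir code):

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical mirror of the axis handling in updateInputDimAt: a negative
// axis counts from the back, then the dim at that axis is made static.
static void refineDimAt(
    std::vector<int64_t> &shape, uint64_t dimSize, int64_t axis) {
  if (axis < 0)
    axis += static_cast<int64_t>(shape.size());
  assert(axis >= 0 && axis < static_cast<int64_t>(shape.size()));
  shape[axis] = static_cast<int64_t>(dimSize);
}

int main() {
  std::vector<int64_t> b = {-1, 32}; // ?x32, as in test_matmul_12 below
  refineDimAt(b, 42, -2);            // axis -2 resolves to dim 0
  assert(b[0] == 42);                // b is now 42x32
  return 0;
}

Counting from the back keeps the call sites independent of the operands' ranks, which matters for MatMul since A and B may have different ranks after broadcasting.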
diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp
index 92fc63d07f..7384f80772 100644
--- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp
+++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp
@@ -185,6 +185,58 @@ void ONNXOpShapeHelper::setOutputDims(
   }
 }
 
+void ONNXOpShapeHelper::updateInputDimAt(
+    Value inputVal, uint64_t dimSize, int64_t axis) {
+  auto valType = mlir::dyn_cast<RankedTensorType>(inputVal.getType());
+  if (!valType)
+    return;
+
+  // Compute a new shape by updating the dim size at axis.
+  ArrayRef<int64_t> shape = getShape(valType);
+  SmallVector<int64_t> newShape =
+      SmallVector<int64_t>(shape.begin(), shape.end());
+  if (axis < 0)
+    axis += newShape.size();
+  newShape[axis] = dimSize;
+
+  // Build a new type.
+  Attribute encoding = valType.getEncoding();
+  RankedTensorType newType;
+  if (encoding)
+    newType =
+        RankedTensorType::get(newShape, valType.getElementType(), encoding);
+  else
+    newType = RankedTensorType::get(newShape, valType.getElementType());
+
+  // Update value type.
+  inputVal.setType(newType);
+
+  // Update the function signature if the value is a BlockArgument.
+  if (auto blockArg = llvm::dyn_cast<BlockArgument>(inputVal)) {
+    // Get the block that owns the argument.
+    Block *block = blockArg.getOwner();
+    // Get the region that owns the block.
+    Region *region = block->getParent();
+    // Get the operation that owns the region.
+    Operation *op = region->getParentOp();
+    // Cast to FuncOp if possible.
+    auto funcOp = dyn_cast<func::FuncOp>(op);
+    if (funcOp) {
+      // Get the current function type.
+      FunctionType oldFuncType = funcOp.getFunctionType();
+      // Create a new input type list with the updated type.
+      SmallVector<Type> newInputTypes(
+          oldFuncType.getInputs().begin(), oldFuncType.getInputs().end());
+      newInputTypes[blockArg.getArgNumber()] = newType;
+      // Create the new function type.
+      FunctionType newFuncType = FunctionType::get(
+          funcOp.getContext(), newInputTypes, oldFuncType.getResults());
+      // Update the function type.
+      funcOp.setType(newFuncType);
+    }
+  }
+}
+
 LogicalResult ONNXOpShapeHelper::setOutputDimsFromOperand(
     Value operand, int n, bool refineShape) {
   // Output and operand have the same shape. Just pass the operand shape to the
diff --git a/src/Interface/ShapeHelperOpInterface.hpp b/src/Interface/ShapeHelperOpInterface.hpp
index 0e7ec891c8..ed12173b4b 100644
--- a/src/Interface/ShapeHelperOpInterface.hpp
+++ b/src/Interface/ShapeHelperOpInterface.hpp
@@ -130,6 +130,9 @@ struct ONNXOpShapeHelper {
   void setOutputDims(
       const DimsExpr &inferredDims, int n = 0, bool refineShape = true);
 
+  // Update the dim size at a specific axis (may be negative) of an input value.
+  void updateInputDimAt(mlir::Value inputVal, uint64_t dimSize, int64_t axis);
+
   // Obtain the n-th output result as value.
   mlir::Value getOutput(int n = 0) { return op->getResult(n); }
 
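One subtlety in updateInputDimAt above: for a value that is a func.func entry-block argument, Value::setType alone would leave the enclosing FunctionType stale, so the helper rebuilds the signature as well. The BlockArgument branch condenses to roughly the following (a sketch assuming the standard MLIR func-dialect APIs; refineFuncArgType is a hypothetical name):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Update an entry-block argument's type and mirror the change into the
// enclosing function's signature so the module still verifies.
static void refineFuncArgType(
    func::FuncOp funcOp, unsigned argIdx, Type newType) {
  funcOp.getBody().front().getArgument(argIdx).setType(newType);
  FunctionType oldTy = funcOp.getFunctionType();
  SmallVector<Type> inputs(oldTy.getInputs().begin(), oldTy.getInputs().end());
  inputs[argIdx] = newType;
  funcOp.setType(
      FunctionType::get(funcOp.getContext(), inputs, oldTy.getResults()));
}

The test_rnn_update_dynamic_input_size_from_weight test at the end of this patch exercises exactly this path: %arg0 is declared as tensor<4x3x?xf32> and the CHECK lines expect tensor<4x3x2xf32> afterwards.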
diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index ec8145659e..7455efe5e1 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -183,7 +183,7 @@ func.func @test_matmul_4(%arg0 : tensor<64x42xf32>, %arg1 : tensor<?x?xf32>) ->
   "onnx.Return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_matmul_4
-  // CHECK: [[RES4:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<?x?xf32>) -> tensor<64x?xf32>
+  // CHECK: [[RES4:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<42x?xf32>) -> tensor<64x?xf32>
   // CHECK: onnx.Return [[RES4]] : tensor<64x?xf32>
 }
 
@@ -267,6 +267,30 @@ func.func @test_matmul_10
 
 // -----
 
+// COM: update the last dimension of the 1st input.
+func.func @test_matmul_11(%arg0 : tensor<16x?x64x?xf32>, %arg1 : tensor<42x32xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x?xf32>, tensor<42x32xf32>) -> tensor<*xf32>
+  "onnx.Return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_11
+  // CHECK: [[RES2:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<42x32xf32>) -> tensor<16x?x64x32xf32>
+  // CHECK: onnx.Return [[RES2]] : tensor<16x?x64x32xf32>
+}
+
+// -----
+
+// COM: update the 2nd dimension from the right of the 2nd input.
+func.func @test_matmul_12(%arg0 : tensor<16x?x64x42xf32>, %arg1 : tensor<?x32xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<?x32xf32>) -> tensor<*xf32>
+  "onnx.Return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_12
+  // CHECK: [[RES2:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<42x32xf32>) -> tensor<16x?x64x32xf32>
+  // CHECK: onnx.Return [[RES2]] : tensor<16x?x64x32xf32>
+}
+
+// -----
+
 /// QLinearMatMul
 
 func.func @test_qlinearmatmul_1(%arg0: tensor<2x2x4xui8>, %arg1: tensor<1xf32>, %arg2: tensor<1xui8>, %arg3: tensor<2x4x3xui8>, %arg4: tensor<1xf32>, %arg5: tensor<1xui8>, %arg6: tensor<1xf32>, %arg7: tensor<1xui8>) -> tensor<*xui8> {
@@ -294,6 +318,46 @@ func.func @test_matmulinteger_1(%arg0: tensor<4x3xui8>, %arg1: tensor<3x2xui8>,
 
 // -----
 
+//===----------------------------------------------------------------------===//
+/// Test shape inference for Gemm
+//===----------------------------------------------------------------------===//
+
+// COM: update the 2nd dimension of the 1st input.
+func.func @test_gemm_1(%arg0: tensor<1x?xf32>, %arg1: tensor<5x4xf32>) -> tensor<*xf32> {
+  %none = "onnx.NoValue"() {value} : () -> none
+  %0 = "onnx.Gemm"(%arg0, %arg1, %none) : (tensor<1x?xf32>, tensor<5x4xf32>, none) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// COM: update the 2nd dimension of the 1st transposed input.
+func.func @test_gemm_2(%arg0: tensor<?x1xf32>, %arg1: tensor<5x4xf32>) -> tensor<*xf32> {
+  %none = "onnx.NoValue"() {value} : () -> none
+  %0 = "onnx.Gemm"(%arg0, %arg1, %none) {transA = 1 : si64} : (tensor<?x1xf32>, tensor<5x4xf32>, none) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// COM: update the 1st dimension of the 2nd input.
+func.func @test_gemm_3(%arg0: tensor<1x5xf32>, %arg1: tensor<?x4xf32>) -> tensor<*xf32> {
+  %none = "onnx.NoValue"() {value} : () -> none
+  %0 = "onnx.Gemm"(%arg0, %arg1, %none) : (tensor<1x5xf32>, tensor<?x4xf32>, none) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// COM: update the 1st dimension of the 2nd transposed input.
+func.func @test_gemm_4(%arg0: tensor<1x5xf32>, %arg1: tensor<4x?xf32>) -> tensor<*xf32> {
+  %none = "onnx.NoValue"() {value} : () -> none
+  %0 = "onnx.Gemm"(%arg0, %arg1, %none) {transB = 1 : si64} : (tensor<1x5xf32>, tensor<4x?xf32>, none) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 /// Test shape inference for Conv (first with no bias) operation and all its attributes.
 //===----------------------------------------------------------------------===//
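In the Gemm tests above, which axis updateInputDimAt touches depends on transA/transB: when an operand is transposed, the shared (reduction) dim sits on the opposite axis of the stored tensor. A hypothetical standalone sketch of just that selection, mirroring the ternaries in Gemm.cpp:

#include <cassert>

// Mirrors gemmOp.getTransA() == 0 ? 1 : 0 and getTransB() == 0 ? 0 : 1
// from Gemm.cpp: the axis of the stored tensor that carries the shared dim.
static int sharedAxisOfA(bool transA) { return transA ? 0 : 1; }
static int sharedAxisOfB(bool transB) { return transB ? 1 : 0; }

int main() {
  assert(sharedAxisOfA(false) == 1); // test_gemm_1: A is 1x?, refine dim 1
  assert(sharedAxisOfA(true) == 0);  // test_gemm_2: A is ?x1, refine dim 0
  assert(sharedAxisOfB(false) == 0); // test_gemm_3: B is ?x4, refine dim 0
  assert(sharedAxisOfB(true) == 1);  // test_gemm_4: B is 4x?, refine dim 1
  return 0;
}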
@@ -1121,6 +1185,19 @@ func.func @test_rnn_infer_hidden_size_from_W(%arg0: tensor<4x3x2xf32>, %arg1: te
 
 // -----
 
+func.func @test_rnn_update_dynamic_input_size_from_weight(%arg0: tensor<4x3x?xf32>, %arg1: tensor<1x3x2xf32>, %arg2: tensor<1x3x3xf32>) -> tensor<*xf32> {
+  %cst = "onnx.NoValue"() {value} : () -> none
+  %Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : si64} : (tensor<4x3x?xf32>, tensor<1x3x2xf32>, tensor<1x3x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
+  onnx.Return %Y_h : tensor<*xf32>
+
+  // CHECK-LABEL: test_rnn_update_dynamic_input_size_from_weight
+  // CHECK: [[CST:%.+]] = "onnx.NoValue"() {value} : () -> none
+  // CHECK-NEXT: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, [[CST]], [[CST]], [[CST]]) {activations = ["Tanh", "Tanh"], direction = "forward", hidden_size = 3 : si64, layout = 0 : si64} : (tensor<4x3x2xf32>, tensor<1x3x2xf32>, tensor<1x3x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>)
+  // CHECK: onnx.Return [[RES]] : tensor<1x3x3xf32>
+}
+
+// -----
+
 func.func @test_rnn_no_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x3x2xf32>, %arg2: tensor<1x3x3xf32>) -> (none) {
   %cst = "onnx.NoValue"() {value} : () -> none
   %Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : si64} : (tensor<4x3x2xf32>, tensor<1x3x2xf32>, tensor<1x3x3xf32>, none, none, none) -> (none, none)