12 changes: 12 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Math/Gemm.cpp
@@ -79,6 +79,18 @@ LogicalResult ONNXGemmOpShapeHelper::computeShape() {
if (aDims[1].isLiteral() && bDims[0].isLiteral() &&
aDims[1].getLiteral() != bDims[0].getLiteral()) {
return op->emitError("Gemm 2nd dim of A is different than 1st dim of B");
} else if (aDims[1].isLiteral()) {
// Save aDims[1] into bDims[0], in case bDims[0] was runtime.
bDims[0] = aDims[1];
// Update the 1st dim of B to the literal aDims[1].
this->updateInputDimAt(
B, aDims[1].getLiteral(), gemmOp.getTransB() == 0 ? 0 : 1);
} else if (bDims[0].isLiteral()) {
// Save bDims[0] into aDims[1], in case aDims[1] was runtime.
aDims[1] = bDims[0];
// Update the last dim of A to the literal bDims[0].
this->updateInputDimAt(
A, bDims[0].getLiteral(), gemmOp.getTransA() == 0 ? 1 : 0);
}
if (hasBias) {
// Check first dim.
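
Note: a minimal before/after sketch of the new Gemm refinement (hypothetical IR; %A, %B, %none are illustrative names, not from the patch). A static K on A now also refines a runtime K on B, in addition to the result shape:

// Before shape inference: K is 3 on A but runtime on B.
%0 = "onnx.Gemm"(%A, %B, %none) : (tensor<2x3xf32>, tensor<?x4xf32>, none) -> tensor<*xf32>
// After: B's 1st dim is refined to 3 and the output becomes fully static.
%0 = "onnx.Gemm"(%A, %B, %none) : (tensor<2x3xf32>, tensor<3x4xf32>, none) -> tensor<2x4xf32>
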
8 changes: 6 additions & 2 deletions src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp
@@ -140,11 +140,15 @@ LogicalResult ONNXGenericMatMulOpShapeHelper<OP_TYPE>::computeShape() {
if (aDims[aK].getLiteral() != bDims[bK].getLiteral())
return this->op->emitError("reduction dimension must be the same");
} else if (aDims[aK].isLiteral()) {
// Save the aK dim into the bK dim, in case the bK dim was runtime.
bDims[bK] = aDims[aK];
// Update bK dim to the literal aK.
this->updateInputDimAt(B, aDims[aK].getLiteral(), -2);
} else if (bDims[bK].isLiteral()) {
// Save the bK dim into the aK dim, in case the aK dim was runtime.
aDims[aK] = bDims[bK];
// Update aK dim to the literal bK.
this->updateInputDimAt(A, bDims[bK].getLiteral(), -1);
}
// Add lower N x M dimensions if they are not padded dimensions.
if (!aPadDims[aN])
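
Note: the MatMul change is the symmetric case, now also writing the refined dim back into the input type. A sketch on hypothetical IR (%A, %B are illustrative names): a runtime reduction dim on A is filled in from B's static K via updateInputDimAt(A, 16, -1):

// Before shape inference: the reduction dim is runtime on A, 16 on B.
%0 = "onnx.MatMul"(%A, %B) : (tensor<8x?xf32>, tensor<16x32xf32>) -> tensor<*xf32>
// After: A's last dim picks up the literal 16 and the result is static.
%0 = "onnx.MatMul"(%A, %B) : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32>
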
5 changes: 5 additions & 0 deletions src/Dialect/ONNX/ONNXOps/RNN/RNN.cpp
@@ -55,6 +55,11 @@ LogicalResult ONNXGenericRNNShapeHelper<OP_TYPE>::customComputeShape(
IndexExpr seqLength = batchwiseLayout ? xDims[1] : xDims[0];
IndexExpr batchSize = batchwiseLayout ? xDims[0] : xDims[1];

// If input_size dim is dynamic in the input and static in the weight,
// update the input_size dim in the input to be static.
if (!xDims[2].isLiteral() && wDims[2].isLiteral())
this->updateInputDimAt(X, wDims[2].getLiteral(), 2);

// Get hidden size from hidden_size attribute.
IndexExpr hiddenSize;
if (operandAdaptor.getHiddenSize().has_value()) {
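
Note: with the default layout, X is [seq_length, batch_size, input_size] and W is [num_directions, k*hidden_size, input_size] (k depends on the RNN variant), so dim 2 of both carries input_size. A hypothetical sketch mirroring the new RNN test below (%X, %W, %R, %none are illustrative names):

// Before shape inference: input_size is runtime on X (?), static on W (2).
%Y, %Y_h = "onnx.RNN"(%X, %W, %R, %none, %none, %none) {hidden_size = 3 : si64} : (tensor<4x3x?xf32>, tensor<1x3x2xf32>, tensor<1x3x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
// After: X is refined to tensor<4x3x2xf32>.
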
52 changes: 52 additions & 0 deletions src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp
@@ -185,6 +185,58 @@ void ONNXOpShapeHelper::setOutputDims(
}
}

void ONNXOpShapeHelper::updateInputDimAt(
Value inputVal, uint64_t dimSize, int64_t axis) {
auto valType = mlir::dyn_cast<RankedTensorType>(inputVal.getType());
if (!valType)
return;

// Compute a new shape by updating the dim size at axis.
ArrayRef<int64_t> shape = getShape(valType);
SmallVector<int64_t> newShape(shape.begin(), shape.end());
if (axis < 0)
axis += newShape.size();
assert(axis >= 0 && axis < static_cast<int64_t>(newShape.size()) &&
"axis is out of bound");
newShape[axis] = dimSize;

// Build a new type.
Attribute encoding = valType.getEncoding();
RankedTensorType newType;
if (encoding)
newType =
RankedTensorType::get(newShape, valType.getElementType(), encoding);
else
newType = RankedTensorType::get(newShape, valType.getElementType());

// Update value type.
inputVal.setType(newType);

// Update the function signature if the value is a BlockArgument.
if (auto blockArg = llvm::dyn_cast<BlockArgument>(inputVal)) {
// Get the block that owns the argument.
Block *block = blockArg.getOwner();
// Get the region that owns the block.
Region *region = block->getParent();
// Get the operation that owns the region.
Operation *op = region->getParentOp();
// Cast to FuncOp if possible.
auto funcOp = dyn_cast<func::FuncOp>(op);
if (funcOp) {
// Get the current function type.
FunctionType oldFuncType = funcOp.getFunctionType();
// Create a new input type list with the updated type.
SmallVector<Type, 4> newInputTypes(
oldFuncType.getInputs().begin(), oldFuncType.getInputs().end());
newInputTypes[blockArg.getArgNumber()] = newType;
// Create the new function type.
FunctionType newFuncType = FunctionType::get(
funcOp.getContext(), newInputTypes, oldFuncType.getResults());
// Update the function type.
funcOp.setType(newFuncType);
}
}
}

LogicalResult ONNXOpShapeHelper::setOutputDimsFromOperand(
Value operand, int n, bool refineShape) {
// Output and operand have the same shape. Just pass the operand shape to the
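
Note: a negative axis is normalized NumPy-style (for a rank-r input, axis -1 maps to r-1), and refining a top-level block argument also rewrites the enclosing func.func signature. A minimal sketch of that path (hypothetical function @f; rank 2, so axis -1 normalizes to 1):

// Before updateInputDimAt(%arg0, /*dimSize=*/3, /*axis=*/-1):
func.func @f(%arg0: tensor<2x?xf32>) {
  return
}
// After: both %arg0's type and the function type carry the static dim.
func.func @f(%arg0: tensor<2x3xf32>) {
  return
}
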
3 changes: 3 additions & 0 deletions src/Interface/ShapeHelperOpInterface.hpp
@@ -130,6 +130,9 @@ struct ONNXOpShapeHelper {
void setOutputDims(
const DimsExpr &inferredDims, int n = 0, bool refineShape = true);

// Update the dimension at a specific axis of an input value's type, e.g. to
// refine a runtime dim to a static size. A negative axis counts from the
// innermost dimension.
void updateInputDimAt(mlir::Value inputVal, uint64_t dimSize, int64_t axis);

// Obtain the n-th output result as value.
mlir::Value getOutput(int n = 0) { return op->getResult(n); }

79 changes: 78 additions & 1 deletion test/mlir/onnx/onnx_shape_inference.mlir
@@ -183,7 +183,7 @@ func.func @test_matmul_4(%arg0 : tensor<64x42xf32>, %arg1 : tensor<?x?x?x?xf32>)
"onnx.Return"(%0) : (tensor<*xf32>) -> ()

// CHECK-LABEL: test_matmul_4
// CHECK: [[RES4:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<?x?x42x?xf32>) -> tensor<?x?x64x?xf32>
// CHECK: onnx.Return [[RES4]] : tensor<?x?x64x?xf32>
}

@@ -267,6 +267,30 @@ func.func @test_matmul_10(%arg0 : tensor<?x42x32xf32>, %arg1 : tensor<32xf32>) -

// -----

// COM: update the last dimension of the 1st input.
func.func @test_matmul_11(%arg0 : tensor<16x?x64x?xf32>, %arg1 : tensor<42x32xf32>) -> tensor<*xf32> {
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x?xf32>, tensor<42x32xf32>) -> tensor<*xf32>
"onnx.Return"(%0) : (tensor<*xf32>) -> ()

// CHECK-LABEL: test_matmul_11
// CHECK: [[RES2:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<42x32xf32>) -> tensor<16x?x64x32xf32>
// CHECK: onnx.Return [[RES2]] : tensor<16x?x64x32xf32>
}

// -----

// COM: update the 2nd dimension from the right of the 2nd input.
func.func @test_matmul_12(%arg0 : tensor<16x?x64x42xf32>, %arg1 : tensor<?x32xf32>) -> tensor<*xf32> {
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<?x32xf32>) -> tensor<*xf32>
"onnx.Return"(%0) : (tensor<*xf32>) -> ()

// CHECK-LABEL: test_matmul_12
// CHECK: [[RES2:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<42x32xf32>) -> tensor<16x?x64x32xf32>
// CHECK: onnx.Return [[RES2]] : tensor<16x?x64x32xf32>
}

// -----

/// QLinearMatMul

func.func @test_qlinearmatmul_1(%arg0: tensor<2x2x4xui8>, %arg1: tensor<1xf32>, %arg2: tensor<1xui8>, %arg3: tensor<2x4x3xui8>, %arg4: tensor<1xf32>, %arg5: tensor<1xui8>, %arg6: tensor<1xf32>, %arg7: tensor<1xui8>) -> tensor<*xui8> {
@@ -294,6 +318,46 @@ func.func @test_matmulinteger_1(%arg0: tensor<4x3xui8>, %arg1: tensor<3x2xui8>,

// -----

//===----------------------------------------------------------------------===//
/// Test shape inference for Gemm
//===----------------------------------------------------------------------===//

// COM: update the 2nd dimension of the 1st input.
func.func @test_gemm_1(%arg0: tensor<1x?xf32>, %arg1: tensor<5x4xf32>) -> tensor<*xf32> {
%none = "onnx.NoValue"() {value} : () -> none
%0 = "onnx.Gemm"(%arg0, %arg1, %none) : (tensor<1x?xf32>, tensor<5x4xf32>, none) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

// -----

// COM: update the 2nd dimension of the 1st transposed input.
func.func @test_gemm_2(%arg0: tensor<?x1xf32>, %arg1: tensor<5x4xf32>) -> tensor<*xf32> {
%none = "onnx.NoValue"() {value} : () -> none
%0 = "onnx.Gemm"(%arg0, %arg1, %none) {transA = 1 :si64} : (tensor<?x1xf32>, tensor<5x4xf32>, none) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

// -----

// COM: update the 1st dimension of the 2nd input.
func.func @test_gemm_3(%arg0: tensor<1x5xf32>, %arg1: tensor<?x4xf32>) -> tensor<*xf32> {
%none = "onnx.NoValue"() {value} : () -> none
%0 = "onnx.Gemm"(%arg0, %arg1, %none) : (tensor<1x5xf32>, tensor<?x4xf32>, none) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

// -----

// COM: update the 1st dimension of the 2nd transposed input.
func.func @test_gemm_4(%arg0: tensor<1x5xf32>, %arg1: tensor<4x?xf32>) -> tensor<*xf32> {
%none = "onnx.NoValue"() {value} : () -> none
%0 = "onnx.Gemm"(%arg0, %arg1, %none) {transB = 1 : si64} : (tensor<1x5xf32>, tensor<4x?xf32>, none) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
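
// COM: For reference, a hedged sketch of the refinements the four Gemm tests
// COM: above exercise (expected types derived by hand from the Gemm change;
// COM: these are not CHECK lines from the patch):
// COM:   test_gemm_1: A tensor<1x?xf32> -> tensor<1x5xf32> (K = 5 from B)
// COM:   test_gemm_2: A tensor<?x1xf32> -> tensor<5x1xf32> (transA; K = 5 from B)
// COM:   test_gemm_3: B tensor<?x4xf32> -> tensor<5x4xf32> (K = 5 from A)
// COM:   test_gemm_4: B tensor<4x?xf32> -> tensor<4x5xf32> (transB; K = 5 from A)
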

// -----

//===----------------------------------------------------------------------===//
/// Test shape inference for Conv (first with no bias) operation and all its attributes.
//===----------------------------------------------------------------------===//
@@ -1121,6 +1185,19 @@ func.func @test_rnn_infer_hidden_size_from_W(%arg0: tensor<4x3x2xf32>, %arg1: te

// -----

func.func @test_rnn_update_dynamic_input_size_from_weight(%arg0: tensor<4x3x?xf32>, %arg1: tensor<1x3x2xf32>, %arg2: tensor<1x3x3xf32>) -> tensor<*xf32> {
%cst = "onnx.NoValue"() {value} : () -> none
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : si64} : (tensor<4x3x?xf32>, tensor<1x3x2xf32>, tensor<1x3x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
onnx.Return %Y_h : tensor<*xf32>

// CHECK-LABEL: test_rnn_update_dynamic_input_size_from_weight
// CHECK: [[CST:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK-NEXT: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, [[CST]], [[CST]], [[CST]]) {activations = ["Tanh", "Tanh"], direction = "forward", hidden_size = 3 : si64, layout = 0 : si64} : (tensor<4x3x2xf32>, tensor<1x3x2xf32>, tensor<1x3x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>)
// CHECK: onnx.Return [[RES]] : tensor<1x3x3xf32>
}

// -----

func.func @test_rnn_no_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x3x2xf32>, %arg2: tensor<1x3x3xf32>) -> (none) {
%cst = "onnx.NoValue"() {value} : () -> none
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : si64} : (tensor<4x3x2xf32>, tensor<1x3x2xf32>, tensor<1x3x3xf32>, none, none, none) -> (none, none)