[TorchToLinalg] Fix integer type handling for aten.mm (#2615)
Despite aten.mm requiring that the input and output types match, we still opt
to maintain signedness semantics in case later passes try to do any sort
of integer type narrowing.
qedawkins committed Dec 7, 2023
1 parent c011570 commit 141202b
Showing 3 changed files with 78 additions and 13 deletions.
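For context, a rough sketch (not taken from this commit) of the kind of linalg IR the updated pattern produces for unsigned operands; the ui32/i32 element types, SSA value names, and dynamic shapes are illustrative only and mirror the lit test added below:

    // Hypothetical lowering of aten.mm with ui32 operands (illustrative only).
    %init = tensor.empty(%lhsDim0, %rhsDim1) : tensor<?x?xi32>
    %zero = arith.constant 0 : i32
    %fill = linalg.fill ins(%zero : i32) outs(%init : tensor<?x?xi32>) -> tensor<?x?xi32>
    // Unsigned torch dtypes now select the unsigned named op instead of linalg.matmul.
    %mm = linalg.matmul_unsigned ins(%lhs, %rhs : tensor<?x?xi32>, tensor<?x?xi32>)
                                 outs(%fill : tensor<?x?xi32>) -> tensor<?x?xi32>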
42 changes: 30 additions & 12 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -51,12 +51,24 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
     // The compiler cannot crash even if the user wrote an erroneous program!
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
-    if (lhs.getType().cast<RankedTensorType>().getRank() != 2 ||
-        rhs.getType().cast<RankedTensorType>().getRank() != 2) {
+
+    RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
+    RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
+
+    if (lhsType.getRank() != 2 || rhsType.getRank() != 2) {
       return rewriter.notifyMatchFailure(
           op, "expected both operands to aten.mm to be rank 2");
     }

+    ValueTensorType lhsTorchType =
+        op.getSelf().getType().cast<ValueTensorType>();
+    ValueTensorType rhsTorchType =
+        op.getMat2().getType().cast<ValueTensorType>();
+    if (lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported: aten.mm with different input element types");
+    }
+
     Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
     Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);

@@ -73,16 +85,22 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {

     Type newResultType = getTypeConverter()->convertType(op.getType());
     Type elementType = newResultType.cast<TensorType>().getElementType();
-    Value initTensor = rewriter.create<tensor::EmptyOp>(
-        loc, ArrayRef<OpFoldResult>{lhsDim0, rhsDim1}, elementType);
-    Value c0 = rewriter.create<arith::ConstantOp>(
-        loc, FloatAttr::get(elementType, 0.0));
-    Value zeroFill =
-        rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
-    Value matmul = rewriter
-                       .create<linalg::MatmulOp>(loc, zeroFill.getType(),
-                                                 ValueRange{lhs, rhs}, zeroFill)
-                       .getResult(0);
+    Value zeroFill = createZeroInitTensor(
+        rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
+
+    Value matmul;
+    auto intType = dyn_cast<mlir::IntegerType>(lhsTorchType.getDtype());
+    if (intType && intType.isUnsigned()) {
+      matmul = rewriter
+                   .create<linalg::MatmulUnsignedOp>(
+                       loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill)
+                   .getResult(0);
+    } else {
+      matmul = rewriter
+                   .create<linalg::MatmulOp>(loc, zeroFill.getType(),
+                                             ValueRange{lhs, rhs}, zeroFill)
+                   .getResult(0);
+    }
     // When constructed with just dynamic sizes, EmptyOp will have a result
     // type which has all `?`'s for dimensions, which might not be the result
     // type of `op`. The constraints on later linalg ops means that the result
38 changes: 37 additions & 1 deletion projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
@@ -225,4 +225,40 @@ def forward(self, m, v):

 @register_test_case(module_factory=lambda: Mv())
 def Mv_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 2), tu.rand(2))
\ No newline at end of file
+    module.forward(tu.rand(2, 2), tu.rand(2))
+
+# ==============================================================================
+
+class AtenMmFloatTypes(torch.nn.Module):
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a, b):
+        return torch.ops.aten.mm(a, b)
+
+
+@register_test_case(module_factory=lambda: AtenMmFloatTypes())
+def AtenMmFloatTypes_basic(module, tu: TestUtils):
+    module.forward(tu.rand(8, 8), tu.rand(8, 8))
+
+# ==============================================================================
+
+class AtenMmIntTypes(torch.nn.Module):
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, a, b):
+        return torch.ops.aten.mm(a, b)
+
+
+@register_test_case(module_factory=lambda: AtenMmIntTypes())
+def AtenMmIntTypes_basic(module, tu: TestUtils):
+    module.forward(tu.randint(16, 4, high=100), tu.randint(4, 16, high=100))
11 changes: 11 additions & 0 deletions test/Conversion/TorchToLinalg/basic.mlir
@@ -40,6 +40,17 @@ func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !

 // -----

+// CHECK-LABEL: func.func @torch.aten.mm$basic_unsigned(
+// CHECK:         linalg.matmul_unsigned
+func.func @torch.aten.mm$basic_unsigned(%arg0: !torch.vtensor<[?,?],ui32>, %arg1: !torch.vtensor<[?,?],ui32>) -> !torch.vtensor<[?,2],ui32>
+    attributes {torch.assume_strict_symbolic_shapes}
+{
+  %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],ui32>, !torch.vtensor<[?,?],ui32> -> !torch.vtensor<[?,2],ui32>
+  return %0 : !torch.vtensor<[?,2],ui32>
+}
+
+// -----
+
 // If the operands are missing dtype, we cannot lower it.
 func.func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
   // expected-error@+1 {{failed to legalize}}
