Skip to content

Commit

Permalink
Add aten.min.dim to linalg lowering (#2600)
Browse files Browse the repository at this point in the history
  • Loading branch information
frederik-h authored Dec 5, 2023
1 parent d0b49a9 commit 6248216
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 60 deletions.
148 changes: 89 additions & 59 deletions lib/Conversion/TorchToLinalg/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,70 +30,80 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
// Aten maxdim lowering represents the MaxDim op as an linalg.indexed_generic
// op, producing two output buffers.
// Aten max.dim (min.dim) lowering represents the MaxDimOp (MinDimOp) as an
// linalg.indexed_generic op, producing two output buffers.
//
// The first output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type.
// The first output buffer contains the maximum (minium) value found. It is
// initialized to the minimum (maximum) representable value of the input
// element type.
//
// The second output buffer contains the index of the found maximum value. It is
// initialized to 0 and is resulting integer type.
// The second output buffer contains the index of the found maximum (minimum)
// value. It is initialized to 0 and is resulting integer type.
//
// The indexed_generic op updates both the maximum value and index if the
// current value exceeds the running max.
class ConvertAtenMaxDimOp : public OpConversionPattern<AtenMaxDimOp> {
// The indexed_generic op updates both the maximum (minimum) value and index
// if the current value exceeds the running max (min).
template <typename OpTy>
class ConvertAtenMinMaxDimOp : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern<AtenMaxDimOp>::OpConversionPattern;
using OpConversionPattern<OpTy>::OpConversionPattern;
using OpConversionPattern<OpTy>::getTypeConverter;

using OpAdaptor = typename OpTy::Adaptor;

LogicalResult
matchAndRewrite(AtenMaxDimOp maxDimOp, OpAdaptor adaptor,
matchAndRewrite(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
static_assert(std::is_same<OpTy, AtenMaxDimOp>() ||
std::is_same<OpTy, AtenMinDimOp>());
constexpr bool isMax = std::is_same<OpTy, AtenMaxDimOp>();
const llvm::StringRef opName = op->getName().getStringRef();

Location loc = maxDimOp.getLoc();
Location loc = op.getLoc();
Value input = adaptor.getSelf();
RankedTensorType valResultType =
getTypeConverter()
->convertType(maxDimOp.getResult(0).getType())
.cast<RankedTensorType>();
->convertType(op.getResult(0).getType())
.template cast<RankedTensorType>();

RankedTensorType idxResultType =
getTypeConverter()
->convertType(maxDimOp.getResult(1).getType())
.cast<RankedTensorType>();
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
this->getTypeConverter()
->convertType(op.getResult(1).getType())
.template cast<RankedTensorType>();
RankedTensorType inputType =
input.getType().template cast<RankedTensorType>();
Type idxElementType = idxResultType.getElementType();
if (!idxElementType.isa<IntegerType>())
return rewriter.notifyMatchFailure(
maxDimOp,
"aten.max_dim to linalg.* requires integer-like result type");
op, opName + " to linalg.* requires integer-like result type");

bool keepDim = false;
if (!matchPattern(maxDimOp.getKeepdim(), m_TorchConstantBool(&keepDim)))
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
return rewriter.notifyMatchFailure(
maxDimOp, "aten.max_dim requires boolean value for keepdim");
op, opName + " requires boolean value for keepdim");

int64_t dim;
if (!matchPattern(maxDimOp.getDim(), m_TorchConstantInt(&dim)))
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
maxDimOp, "aten.max_dim to linalg.* requires int value for Dim");
op, opName + " to linalg.* requires int value for Dim");
dim = toPositiveDim(dim, inputType.getRank());
if (!isValidDim(dim, inputType.getRank()))
return rewriter.notifyMatchFailure(maxDimOp, "dim is not a valid dim");
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");

Type inElementType = inputType.getElementType();
if (!inElementType.isa<mlir::FloatType>()) {
if (inElementType.isa<mlir::IntegerType>()) {
auto integerTy = maxDimOp.getSelf()
auto integerTy = op.getSelf()
.getType()
.cast<BaseTensorType>()
.template cast<BaseTensorType>()
.getDtype()
.dyn_cast<mlir::IntegerType>();
.template dyn_cast<mlir::IntegerType>();
if (integerTy.isUnsigned())
return rewriter.notifyMatchFailure(
maxDimOp, "aten.max_dim to linalg.* requires input element type "
op, opName + " to linalg.* requires input element type "
"to be signed in case of integer");
} else {
return rewriter.notifyMatchFailure(
maxDimOp, "aten.max_dim to linalg.* requires Float or Integer "
op, opName + " to linalg.* requires Float or Integer "
"input element type");
}
}
Expand All @@ -112,29 +122,29 @@ class ConvertAtenMaxDimOp : public OpConversionPattern<AtenMaxDimOp> {
Value filledTensorIdx =
createZeroInitTensor(rewriter, loc, resultShape, idxElementType);

// Second fill the output buffer for the running max.
Value initTensorMax = rewriter.create<tensor::EmptyOp>(
// Second fill the output buffer for the running max or min.
Value initTensorVal = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(resultShape), inElementType);

Value fillValueMax;
Value fillValue;
if (inElementType.isa<mlir::FloatType>()) {
fillValueMax = rewriter.create<arith::ConstantOp>(
fillValue = rewriter.create<arith::ConstantOp>(
loc,
rewriter.getFloatAttr(
inElementType,
APFloat::getInf(
inElementType.cast<mlir::FloatType>().getFloatSemantics(),
/*Negative=*/true)));
/*Negative=*/isMax)));
} else {
fillValueMax = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(
inElementType,
APSInt::getSignedMinValue(
inElementType.cast<mlir::IntegerType>().getWidth())));
auto width = inElementType.cast<mlir::IntegerType>().getWidth();
auto init = isMax ? APSInt::getSignedMinValue(width)
: APSInt::getSignedMaxValue(width);
fillValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(inElementType, init));
}

Value filledTensorMax =
rewriter.create<linalg::FillOp>(loc, fillValueMax, initTensorMax)
Value filledTensorVal =
rewriter.create<linalg::FillOp>(loc, fillValue, initTensorVal)
.result();

// Create the affine expressions that will be used to
Expand All @@ -161,8 +171,8 @@ class ConvertAtenMaxDimOp : public OpConversionPattern<AtenMaxDimOp> {
auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs});
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc,
ArrayRef<Type>({filledTensorMax.getType(), filledTensorIdx.getType()}),
input, ValueRange({filledTensorMax, filledTensorIdx}), maps,
ArrayRef<Type>({filledTensorVal.getType(), filledTensorIdx.getType()}),
input, ValueRange({filledTensorVal, filledTensorIdx}), maps,
iteratorTypes,
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange blockArgs) {
Expand All @@ -174,33 +184,51 @@ class ConvertAtenMaxDimOp : public OpConversionPattern<AtenMaxDimOp> {
nestedLoc, oldIndex.getType(),
rewriter.create<linalg::IndexOp>(loc, dim));

Value resultMax, predicate;
Value resultVal, predicate;
if (inElementType.isa<mlir::FloatType>()) {
resultMax = rewriter.create<arith::MaximumFOp>(nestedLoc, newValue,
oldValue);
predicate = rewriter.create<arith::CmpFOp>(
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
arith::CmpFPredicate predType;
if constexpr (isMax) {
predType = arith::CmpFPredicate::OGT;
resultVal = rewriter.create<arith::MaximumFOp>(
nestedLoc, newValue, oldValue);
} else {
predType = arith::CmpFPredicate::OLT;
resultVal = rewriter.create<arith::MinimumFOp>(
nestedLoc, newValue, oldValue);
}

predicate = rewriter.create<arith::CmpFOp>(nestedLoc, predType,
newValue, oldValue);
} else {
resultMax =
rewriter.create<arith::MaxSIOp>(nestedLoc, newValue, oldValue);
predicate = rewriter.create<arith::CmpIOp>(
nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
arith::CmpIPredicate predType;
if constexpr (isMax) {
predType = arith::CmpIPredicate::sgt;
resultVal = rewriter.create<arith::MaxSIOp>(nestedLoc, newValue,
oldValue);
} else {
predType = arith::CmpIPredicate::slt;
resultVal = rewriter.create<arith::MinSIOp>(nestedLoc, newValue,
oldValue);
}
predicate = rewriter.create<arith::CmpIOp>(nestedLoc, predType,
newValue, oldValue);
}
auto resultIndex = rewriter.create<arith::SelectOp>(
nestedLoc, predicate, newIndex, oldIndex);
nestedBuilder.create<linalg::YieldOp>(
nestedLoc, ValueRange({resultMax, resultIndex}));
nestedLoc, ValueRange({resultVal, resultIndex}));
});

// This cast is required to fix the shape in the case of keepDim=True
Value maxValuesCast = rewriter.create<tensor::CastOp>(
Value valuesCast = rewriter.create<tensor::CastOp>(
loc, valResultType, linalgOp.getResult(0));
Value maxIdxCast = rewriter.create<tensor::CastOp>(loc, idxResultType,
linalgOp.getResult(1));
rewriter.replaceOp(maxDimOp, {maxValuesCast, maxIdxCast});
Value idxCast = rewriter.create<tensor::CastOp>(loc, idxResultType,
linalgOp.getResult(1));
rewriter.replaceOp(op, {valuesCast, idxCast});
return success();
}
};

} // namespace

static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
Expand Down Expand Up @@ -574,7 +602,9 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
ConversionTarget &target) {
MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenMaxDimOp>();
patterns.add<ConvertAtenMaxDimOp>(typeConverter, context);
patterns.add<ConvertAtenMinMaxDimOp<AtenMaxDimOp>>(typeConverter, context);
target.addIllegalOp<AtenMinDimOp>();
patterns.add<ConvertAtenMinMaxDimOp<AtenMinDimOp>>(typeConverter, context);
target.addIllegalOp<AtenSumOp>();
target.addIllegalOp<AtenSumDimIntListOp>();
target.addIllegalOp<AtenProdDimIntOp>();
Expand Down
12 changes: 12 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6872,6 +6872,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %2 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.min.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %2 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.amax\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %0 = torch.derefine %arg1 : !torch.list<int> to !torch.optional<list<int>>\n"
Expand Down Expand Up @@ -10691,6 +10697,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.min.dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<int, int> {\n"
" %int4 = torch.constant.int 4\n"
" %0 = call @\"__torch_mlir_dtype_fn.aten.min\"(%arg0) : (!torch.tuple<int, int>) -> !torch.int\n"
" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>) -> !torch.int {\n"
" %false = torch.constant.bool false\n"
" %none = torch.constant.none\n"
Expand Down
3 changes: 2 additions & 1 deletion projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic",
"IscloseStaticModuleTrue_basic"
}

TORCHDYNAMO_XFAIL_SET = {
Expand Down Expand Up @@ -69,6 +69,7 @@
#ERROR: value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=-1.336e-32, max=+0.9152, mean=+0.4837) is not close to golden value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=+0.02233, max=+0.9152, mean=+0.4777)
"UpSampleNearest2dDynamicFactor_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
#ERROR: value (-56) is not equal to golden value (200)
"AtenIntTensorByteDtypeModule_basic",
# ERROR: assert isinstance(e, FakeTensor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,10 @@ def aten〇max〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -
reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim)
return reduced_shape, reduced_shape

def aten〇min〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> Tuple[List[int], List[int]]:
reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim)
return reduced_shape, reduced_shape

def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)

Expand Down Expand Up @@ -3286,6 +3290,10 @@ def aten〇amax〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), k
def aten〇max〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]:
return aten〇max〡dtype(self_rank_dtype), torch.int64

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
def aten〇min〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]:
return aten〇min〡dtype(self_rank_dtype), torch.int64

@check_dtype_function(
_check_tensors_with_the_same_dtype(
num_of_tensors=1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"NativeGroupNormBackwardModule_basic",
"QuantizedMLP_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
}

Expand Down
Loading

0 comments on commit 6248216

Please sign in to comment.