Add decomposition for ONNXSoftmaxCrossEntropyLossOp (#2968)
Signed-off-by: Sam <[email protected]>
Co-authored-by: Alexandre Eichenberger <[email protected]>
srcarroll and AlexandreEichenberger authored Nov 12, 2024
1 parent 7411403 commit cb9a949
Showing 2 changed files with 302 additions and 4 deletions.
123 changes: 119 additions & 4 deletions src/Dialect/ONNX/Transforms/Decompose.cpp
@@ -20,6 +20,8 @@
//
//===----------------------------------------------------------------------===//

#include <numeric>

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -484,9 +486,7 @@ namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Dialect/ONNX/Transforms/ONNXDecompose.inc"

-#ifdef ONNX_MLIR_ENABLE_STABLEHLO
-
-RankedTensorType createResultType(
+RankedTensorType createReducedType(
Type outputType, int64_t axisValue, bool keepDims) {
RankedTensorType outputShapeType =
mlir::dyn_cast<RankedTensorType>(outputType);
@@ -507,6 +507,8 @@ RankedTensorType createResultType(
return resultType;
}

+#ifdef ONNX_MLIR_ENABLE_STABLEHLO
+
struct SoftmaxPattern : public OpRewritePattern<ONNXSoftmaxOp> {
using OpRewritePattern<ONNXSoftmaxOp>::OpRewritePattern;

@@ -526,7 +528,7 @@ struct SoftmaxPattern : public OpRewritePattern<ONNXSoftmaxOp> {
rewriter.getIntegerType(64, /*isSigned=*/true), 1);
ArrayAttr axisAttr = rewriter.getI64ArrayAttr({axisValue});
RankedTensorType resultType =
-        createResultType(inputType, axisValue, /*keepDims=*/true);
+        createReducedType(inputType, axisValue, /*keepDims=*/true);
Value maxInput = rewriter.create<ONNXReduceMaxV13Op>(
odsLoc, resultType, input, axisAttr, keepDimsAttr);
Value subValue =
@@ -985,6 +987,117 @@ struct GroupNormIntoLayerNormPattern2
}
};

/// Decompose `onnx.SoftmaxCrossEntropyLoss` into the following sequence,
/// assuming the class dimension of `scores` is dim=1:
/// 1. one_hot_encoded = onnx.CastLike(onnx.OneHot(labels, dim=1), scores)
/// 2. log_softmax = onnx.Log(onnx.Softmax(scores, dim=1))
/// 3. product = onnx.Mul(log_softmax, one_hot_encoded)
///    If the `weights` arg is not `none`, we additionally perform
///      product = onnx.Mul(product, onnx.Unsqueeze(weights))
///    where unsqueezing makes the operands broadcast-compatible.
/// 4. reduce_sum = onnx.ReduceSum(product, dim=1)
/// 5. loss = onnx.ReduceMean(reduce_sum) if reduction == "mean"
///           onnx.ReduceSum(reduce_sum)  if reduction == "sum"
///           onnx.Squeeze(reduce_sum)    if reduction == "none"
/// 6. loss = onnx.Neg(loss)
///
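/// For example, with scores of shape [N, C, d1] and labels of shape [N, d1],
/// the one-hot encoding and the product have shape [N, C, d1], the ReduceSum
/// over dim=1 yields [N, 1, d1], and the final loss is a scalar for "mean"
/// and "sum" or, after squeezing dim=1, has shape [N, d1] for "none".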
struct SoftmaxCrossEntropyPattern
: public OpRewritePattern<ONNXSoftmaxCrossEntropyLossOp> {
using OpRewritePattern<ONNXSoftmaxCrossEntropyLossOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ONNXSoftmaxCrossEntropyLossOp sceOp,
PatternRewriter &rewriter) const final {
auto loc = sceOp.getLoc();
onnx_mlir::OnnxBuilder create(rewriter, loc);
auto scores = sceOp.getScores();
auto labels = sceOp.getLabels();
auto weights = sceOp.getWeights();
auto scoresTy = cast<ShapedType>(scores.getType());
auto labelsTy = cast<ShapedType>(labels.getType());
SmallVector<int64_t> newLabelsShape(labelsTy.getShape());
newLabelsShape.insert(newLabelsShape.begin() + 1, scoresTy.getShape()[1]);
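// Labels have shape [N, d1, ..., dk]; inserting the class count at dim=1
// gives the one-hot shape [N, C, d1, ..., dk], matching scores.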
auto none = rewriter.create<ONNXNoneOp>(loc);
auto numClasses = (scoresTy.isDynamicDim(1))
? create.dim(scores, 1)
: create.constantInt64({scoresTy.getShape()[1]});
auto elemTy = scoresTy.getElementType();
// Compute one hot encoded labels and cast to `scores` element type.
auto oneHotValsAttr = DenseIntElementsAttr::get(
RankedTensorType::get({2}, rewriter.getI64Type()),
ArrayRef<int64_t>{0, 1});
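// Per the ONNX OneHot spec, `values` is a 2-element tensor holding
// [off_value, on_value], here [0, 1].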
auto oneHotVals = create.constant(oneHotValsAttr);
auto oneHot = create.cast(
rewriter.create<ONNXOneHotOp>(loc,
RankedTensorType::get(newLabelsShape, labelsTy.getElementType()),
labels, numClasses, oneHotVals, /*axis=*/1),
/*saturate=*/
rewriter.getIntegerAttr(rewriter.getIntegerType(64, true), 1),
TypeAttr::get(elemTy));
// Compute logsoftmax of scores.
auto softmax =
rewriter.create<ONNXSoftmaxOp>(loc, scoresTy, scores, /*axis=*/1);
auto logSoftmax = rewriter.create<ONNXLogOp>(loc, scoresTy, softmax);
auto prod = rewriter.create<ONNXMulOp>(loc, logSoftmax, oneHot);
// Multiply by `weights` if not none.
if (auto weightTy = dyn_cast<ShapedType>(weights.getType())) {
// Unsqueeze weight from [C] to [1 x C x 1 x ... x 1] to make it
// broadcast-compliant.
llvm::SmallVector<int64_t, 4> unsqueezedShape(scoresTy.getRank(), 1);
unsqueezedShape[1] = scoresTy.getShape()[1];
llvm::SmallVector<int64_t, 4> axesList(scoresTy.getRank() - 1, 0);
std::iota(axesList.begin() + 1, axesList.end(), 2);
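// E.g. for rank-4 scores, axesList = {0, 2, 3}, so weights of shape [C]
// are unsqueezed to [1, C, 1, 1].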
auto axes = create.constantInt64(axesList);
auto weightsUnsqueezed = create.unsqueeze(
RankedTensorType::get(unsqueezedShape, elemTy), weights, axes);
prod = rewriter.create<ONNXMulOp>(loc, prod, weightsUnsqueezed);
}
// Reduction across `class` (dim=1) axis.
auto axes = create.constant(onnx_mlir::createDenseArrayAttr(
rewriter, rewriter.getI64ArrayAttr({1})));
auto reducedType = createReducedType(scoresTy, 1, /*keepdims=*/true);
Value loss = rewriter.create<ONNXReduceSumOp>(loc, reducedType, prod, axes);
// ReduceMean/ReduceSum/Squeeze if reduction = mean/sum/none respectively.
// Set `axes=none` to indicate reducing all dims.
auto reduction = cast<StringAttr>(sceOp.getReductionAttr()).getValue();
if (reduction == "mean") {
if (isa<NoneType>(weights.getType())) {
loss = rewriter.create<ONNXReduceMeanOp>(loc,
RankedTensorType::get({}, elemTy), loss, none,
/*keepdims=*/0);
} else {
auto sumL = rewriter.create<ONNXReduceSumOp>(loc,
RankedTensorType::get({}, elemTy), loss, none,
/*keepdims=*/0);
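// Per the ONNX spec, the weighted "mean" reduction divides the sum of
// weighted per-sample losses by the sum of the target-class weights.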
// Perform einsum(one_hot, weights) as a simple way of producing
// W[n][d1][d2]...[dk] = weights[labels[n][d1][d2]...[dk]]
auto scatteredWeights = rewriter.create<ONNXEinsumOp>(loc,
RankedTensorType::get(labelsTy.getShape(), elemTy),
ValueRange{oneHot, weights}, "ij...,j->i...");
auto sumW = rewriter.create<ONNXReduceSumOp>(loc,
RankedTensorType::get({}, elemTy), scatteredWeights, none,
/*keepdims=*/0);
loss = rewriter.create<ONNXDivOp>(loc, sumL, sumW);
}
} else if (reduction == "sum") {
loss = rewriter.create<ONNXReduceSumOp>(loc,
RankedTensorType::get({}, elemTy), loss, none,
/*keepdims=*/0);
} else if (reduction == "none") {
loss = rewriter.create<ONNXSqueezeOp>(loc,
createReducedType(reducedType, 1, /*keepdims=*/false), loss, axes);
} else {
llvm_unreachable("unexpected reduction type");
}
// Negate.
loss = rewriter.create<ONNXNegOp>(loc, loss.getType(), loss);
// The replacement for the second result depends on whether `log_prob` is `none`.
if (isa<NoneType>(sceOp.getLogProb().getType()))
rewriter.replaceOp(sceOp, {loss, none});
else
rewriter.replaceOp(sceOp, {loss, logSoftmax});
return success();
}
};

/// Decompose `onnx.Sum` to a sequence of `onnx.Add`
struct SumToAddPattern : public OpRewritePattern<ONNXSumOp> {
using OpRewritePattern<ONNXSumOp>::OpRewritePattern;
@@ -1114,6 +1227,7 @@ void DecomposeONNXToONNXPass::runOnOperation() {
target.addIllegalOp<ONNXScalerOp>();
target.addIllegalOp<ONNXScatterOp>();
target.addIllegalOp<ONNXSequenceConstructOp>();
target.addIllegalOp<ONNXSoftmaxCrossEntropyLossOp>();
target.addIllegalOp<ONNXSplitV11Op>();
target.addIllegalOp<ONNXSplitV13Op>();
target.addIllegalOp<ONNXSqueezeV11Op>();
@@ -1190,6 +1304,7 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
patterns.insert<InstanceNormIntoLayerNormPattern>(context);
patterns.insert<GroupNormIntoLayerNormPattern1>(context);
patterns.insert<GroupNormIntoLayerNormPattern2>(context);
patterns.insert<SoftmaxCrossEntropyPattern>(context);
patterns.insert<SumToAddPattern>(context);

// TODO: consider whether to include SoftmaxPattern here
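As a sanity check on the arithmetic, the following is a minimal standalone sketch (plain C++17, independent of onnx-mlir; the values and helper names are illustrative, not part of the commit) that mirrors the decomposition for scores of shape [N=2, C=3] with reduction="mean" and no weights. It relies on the fact that multiplying by the one-hot encoding and reducing over dim=1 simply selects the target class's log-softmax value.

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  constexpr int N = 2, C = 3;
  const double scores[N][C] = {{1.0, 2.0, 0.5}, {0.1, 0.2, 3.0}};
  const int labels[N] = {1, 2};

  double sumLoss = 0.0;
  for (int n = 0; n < N; ++n) {
    // Step 2: numerically stable log-softmax along the class axis (dim=1);
    // the max subtraction matches what SoftmaxPattern above emits.
    double maxVal = scores[n][0];
    for (int c = 1; c < C; ++c)
      maxVal = std::max(maxVal, scores[n][c]);
    double sumExp = 0.0;
    for (int c = 0; c < C; ++c)
      sumExp += std::exp(scores[n][c] - maxVal);
    // Steps 1, 3, 4: the one-hot multiply plus ReduceSum over dim=1 select
    // the log-probability of the target class.
    double logProbTarget = (scores[n][labels[n]] - maxVal) - std::log(sumExp);
    sumLoss += -logProbTarget; // step 6 negation folded in
  }
  // Step 5: reduction == "mean".
  std::printf("mean loss = %f\n", sumLoss / N);
  return 0;
}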