Softmax op dynamic shape addition for stablehlo (#2918)
Signed-off-by: Abhishek-TyRnT <[email protected]>
Co-authored-by: Alexandre Eichenberger <[email protected]>
Abhishek-TyRnT and AlexandreEichenberger authored Aug 27, 2024
1 parent 4e99738 commit 4f0a141
Showing 2 changed files with 74 additions and 53 deletions.
47 changes: 34 additions & 13 deletions src/Conversion/ONNXToStablehlo/Math/Softmax.cpp
@@ -121,8 +121,7 @@ struct ONNXSoftmaxOpLoweringToStablehlo : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {

Value operand = operands[0];
assert(
hasStaticShape(operand.getType()) && "Only Static shapes are accepted");
bool isStaticShape = hasStaticShape(operand.getType());

Location loc = op->getLoc();
Type outputType = *op->result_type_begin();
@@ -151,29 +150,51 @@ struct ONNXSoftmaxOpLoweringToStablehlo : public ConversionPattern {
// Sum of all the exponents for the denominator
SmallVector<int64_t> reducedShape =
getReductionShape(ExpOutputType, axes, false);
ShapedType ReducedShapeType = mlir::cast<ShapedType>(
RankedTensorType::get(reducedShape, ExpOutputType.getElementType()));
ShapedType ReducedShapeType;
if (isStaticShape) {
ReducedShapeType = mlir::cast<ShapedType>(
RankedTensorType::get(reducedShape, ExpOutputType.getElementType()));
} else {
SmallVector<int64_t> ReducedShapeVector =
getReductionShape(ExpOutputType, axes, true);
ReducedShapeType = mlir::cast<ShapedType>(RankedTensorType::get(
ReducedShapeVector, ExpOutputType.getElementType()));
}
Value identity = rewriter.create<stablehlo::ConstantOp>(
loc, rewriter.getZeroAttr(ExpOutputType.getElementType()));
Value ReduceSum = computeReduceSum(loc, ElementwiseExpStableHLO, identity,
reducedShape, axes, rewriter, false, ReducedShapeType);
reducedShape, axes, rewriter, !isStaticShape, ReducedShapeType);

if (ReduceSum == nullptr)
return failure();

SmallVector<int64_t> broadcast_dims =
getBroadcastDims(ElementwiseExpStableHLO, axes);
Value BroadCastOp =
rewriter.create<stablehlo::BroadcastInDimOp>(loc, ExpOutputType,
ReduceSum, rewriter.getDenseI64ArrayAttr(broadcast_dims));
Value BroadCastOp;
if (isStaticShape) {
SmallVector<int64_t> broadcast_dims =
getBroadcastDims(ElementwiseExpStableHLO, axes);
BroadCastOp =
rewriter.create<stablehlo::BroadcastInDimOp>(loc, ExpOutputType,
ReduceSum, rewriter.getDenseI64ArrayAttr(broadcast_dims));
} else {
mlir::Value OutputDimensions =
rewriter.create<shape::ShapeOfOp>(loc, operand);
SmallVector<int64_t> DimIndex;
for (int64_t i = 0; i < ExpOutputType.getRank(); i++)
DimIndex.push_back(i);
BroadCastOp = rewriter.create<stablehlo::DynamicBroadcastInDimOp>(loc,
ExpOutputType, ReduceSum, OutputDimensions,
rewriter.getDenseI64ArrayAttr(DimIndex));
}
if (BroadCastOp == nullptr)
return failure();

Value Softmax_output = rewriter.create<stablehlo::DivOp>(
Value SoftmaxOutput = rewriter.create<stablehlo::DivOp>(
loc, ElementwiseExpStableHLO, BroadCastOp);
if (Softmax_output == nullptr)

if (SoftmaxOutput == nullptr)
return failure();

rewriter.replaceOp(op, Softmax_output);
rewriter.replaceOp(op, SoftmaxOutput);
return success();
}
};
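
For reference, a minimal hand-written sketch (not part of the commit) of the IR the new dynamic-shape branch produces for a softmax reduced over axis 1, assuming the exponentials and their reduced sum are already available as %exp and %sum: the runtime extents come from shape.shape_of and feed stablehlo.dynamic_broadcast_in_dim in place of the static broadcast_in_dim used on the static-shape path.

func.func @softmax_dynamic_sketch(%exp: tensor<?x20x30xf32>, %sum: tensor<?x1x30xf32>) -> tensor<?x20x30xf32> {
  // Query the runtime shape of the exponentials.
  %shape = shape.shape_of %exp : tensor<?x20x30xf32> -> tensor<3xindex>
  // Broadcast the reduced sum back to the full (dynamic) shape along all dimensions.
  %bcast = stablehlo.dynamic_broadcast_in_dim %sum, %shape, dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
  // Divide elementwise to obtain the softmax output.
  %softmax = stablehlo.divide %exp, %bcast : tensor<?x20x30xf32>
  return %softmax : tensor<?x20x30xf32>
}

The full lowering, including the max-subtraction and reduce-sum steps, is checked by the updated test below.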
80 changes: 40 additions & 40 deletions test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir
@@ -32,46 +32,46 @@ func.func @test_softmax_dynamic(%arg0 : tensor<?x20x30xf32>) -> tensor<?x20x30xf32> {
"func.return"(%0) : (tensor<?x20x30xf32>) -> ()
}

//TODO: Re-enable dynamic shape test
// func.func @test_softmax_dynamic
// ([[PARAM_0_:%.+]]: tensor<?x20x30xf32>) -> tensor<?x20x30xf32> {
// [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// [[CST_2_:%.+]] = arith.constant 2 : index
// [[CST_1_:%.+]] = arith.constant 1 : index
// [[CST_0_:%.+]] = arith.constant 0 : index
// [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
// separator of consecutive DAGs
// [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
// [[VAR_3_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
// separator of consecutive DAGs
// [[VAR_4_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_0_]] : tensor<3xindex>, index -> index
// [[VAR_5_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_2_]] : tensor<3xindex>, index -> index
// [[VAR_6_:%.+]] = shape.from_extents [[VAR_4_]], [[CST_1_]], [[VAR_5_]] : index, index, index
// [[VAR_7_:%.+]] = shape.to_extent_tensor [[VAR_6_]] : !shape.shape -> tensor<3xindex>
// [[VAR_8_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_7_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
// [[VAR_9_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
// [[VAR_10_:%.+]] = shape.shape_of [[VAR_8_]] : tensor<?x1x30xf32> -> tensor<3xindex>
// [[VAR_11_:%.+]] = shape.broadcast [[VAR_9_]], [[VAR_10_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
// [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
// [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_8_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
// [[VAR_14_:%.+]] = stablehlo.subtract [[VAR_12_]], [[VAR_13_]] : tensor<?x20x30xf32>
// [[VAR_15_:%.+]] = stablehlo.exponential [[VAR_14_]] : tensor<?x20x30xf32>
// [[VAR_16_:%.+]] = stablehlo.reduce([[VAR_15_]] init: [[VAR_0_]]) applies stablehlo.add across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
// [[VAR_17_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
// separator of consecutive DAGs
// [[VAR_18_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_0_]] : tensor<3xindex>, index -> index
// [[VAR_19_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_2_]] : tensor<3xindex>, index -> index
// [[VAR_20_:%.+]] = shape.from_extents [[VAR_18_]], [[CST_1_]], [[VAR_19_]] : index, index, index
// [[VAR_21_:%.+]] = shape.to_extent_tensor [[VAR_20_]] : !shape.shape -> tensor<3xindex>
// [[VAR_22_:%.+]] = stablehlo.dynamic_reshape [[VAR_16_]], [[VAR_21_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
// [[VAR_23_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
// [[VAR_24_:%.+]] = shape.shape_of [[VAR_22_]] : tensor<?x1x30xf32> -> tensor<3xindex>
// [[VAR_25_:%.+]] = shape.broadcast [[VAR_23_]], [[VAR_24_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
// [[VAR_26_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_15_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
// [[VAR_27_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_22_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
// [[VAR_28_:%.+]] = stablehlo.divide [[VAR_26_]], [[VAR_27_]] : tensor<?x20x30xf32>
// return [[VAR_28_]] : tensor<?x20x30xf32>
// }
// CHECK-LABEL: func.func @test_softmax_dynamic
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x20x30xf32>) -> tensor<?x20x30xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
// CHECK-DAG: [[VAR_3_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_4_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_0_]] : tensor<3xindex>, index -> index
// CHECK-DAG: [[VAR_5_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_2_]] : tensor<3xindex>, index -> index
// CHECK: [[VAR_6_:%.+]] = shape.from_extents [[VAR_4_]], [[CST_1_]], [[VAR_5_]] : index, index, index
// CHECK: [[VAR_7_:%.+]] = shape.to_extent_tensor [[VAR_6_]] : !shape.shape -> tensor<3xindex>
// CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_7_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
// CHECK-DAG: [[VAR_9_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
// CHECK: [[VAR_10_:%.+]] = shape.shape_of [[VAR_8_]] : tensor<?x1x30xf32> -> tensor<3xindex>
// CHECK: [[VAR_11_:%.+]] = shape.broadcast [[VAR_9_]], [[VAR_10_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
// CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
// CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_8_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
// CHECK: [[VAR_14_:%.+]] = stablehlo.subtract [[VAR_12_]], [[VAR_13_]] : tensor<?x20x30xf32>
// CHECK: [[VAR_15_:%.+]] = stablehlo.exponential [[VAR_14_]] : tensor<?x20x30xf32>
// CHECK-DAG: [[VAR_16_:%.+]] = stablehlo.reduce([[VAR_15_]] init: [[VAR_0_]]) applies stablehlo.add across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
// CHECK-DAG: [[VAR_17_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_18_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_0_]] : tensor<3xindex>, index -> index
// CHECK-DAG: [[VAR_19_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_2_]] : tensor<3xindex>, index -> index
// CHECK: [[VAR_20_:%.+]] = shape.from_extents [[VAR_18_]], [[CST_1_]], [[VAR_19_]] : index, index, index
// CHECK: [[VAR_21_:%.+]] = shape.to_extent_tensor [[VAR_20_]] : !shape.shape -> tensor<3xindex>
// CHECK-DAG: [[VAR_22_:%.+]] = stablehlo.dynamic_reshape [[VAR_16_]], [[VAR_21_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
// CHECK-DAG: [[VAR_23_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
// CHECK: [[VAR_24_:%.+]] = shape.shape_of [[VAR_22_]] : tensor<?x1x30xf32> -> tensor<3xindex>
// CHECK: [[VAR_25_:%.+]] = shape.broadcast [[VAR_23_]], [[VAR_24_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_15_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
// CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_22_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
// CHECK: [[VAR_28_:%.+]] = stablehlo.divide [[VAR_26_]], [[VAR_27_]] : tensor<?x20x30xf32>
// CHECK: return [[VAR_28_]] : tensor<?x20x30xf32>
// CHECK: }


// -----


