support scalar reduction
Yancey1989 committed Mar 1, 2024
1 parent ddcc491 commit 5009a3f
Showing 7 changed files with 463 additions and 18 deletions.
2 changes: 1 addition & 1 deletion tao_compiler/mlir/disc/disc_compiler.cc
100755 → 100644
@@ -584,7 +584,7 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
pm.addNestedPass<FuncOp>(createCSEPass());
pm.addNestedPass<FuncOp>(
createCanonicalizerPass(cano_rewrite_config, disablePatterns));
- pm.addNestedPass<FuncOp>(disc_ral::createDiscMemRefCSEPass());
+ // pm.addNestedPass<FuncOp>(disc_ral::createDiscMemRefCSEPass());
// convert linearizeOp/delinearizeOp to std dialect.
pm.addNestedPass<FuncOp>(disc_ral::createDiscConvertShapeToStandardPass());
pm.addNestedPass<FuncOp>(
38 changes: 34 additions & 4 deletions tao_compiler/mlir/disc/transforms/fusion_utils.cc
@@ -175,6 +175,8 @@ StringRef fusionTypeToString(FusionType ft) {
return "kRowReduction";
case FusionType::kColReduction:
return "kColReduction";
+ case FusionType::kScalarReduction:
+   return "kScalarReduction";
case FusionType::kInput:
return "kInput";
case FusionType::kStitch:
@@ -205,6 +207,8 @@ FusionType fusionTypeFromString(StringRef ft) {
return FusionType::kRowReduction;
} else if (ft == "kColReduction") {
return FusionType::kColReduction;
+ } else if (ft == "kScalarReduction") {
+   return FusionType::kScalarReduction;
} else if (ft == "kInput") {
return FusionType::kInput;
} else if (ft == "kStitch") {
@@ -479,6 +483,21 @@ bool isRowReduction(Operation* op) {
return true;
}

+ bool isRank2ScalarReduction(Operation* op) {
+   auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
+   if (!reduce_op || reduce_op.getDimensions().getNumElements() != 1)
+     return false;
+   int rank = op->getOperand(2).getType().cast<MemRefType>().getRank();
+   // TODO(yancey): rewrite the scalar reduction result to a scalar tensor to
+   // avoid the reshape to a scalar tensor behind the reduce op
+   Operation* reshapeOp = *op->getOperand(2).getUsers().begin();
+   if (isa<ReshapeOp>(reshapeOp) &&
+       reshapeOp->getOperand(1).getType().cast<MemRefType>().getRank() == 0) {
+     return true;
+   }
+   return false;
+ }

// Returns true if this op is a rank-2 column reduction.
bool isRank2ColReduction(Operation* op) {
auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
@@ -487,7 +506,8 @@ bool isRank2ColReduction(Operation* op) {

int rank = op->getOperand(0).getType().cast<MemRefType>().getRank();
auto dimensions = reduce_op.getDimensions().getValues<int64_t>();
- return ((*dimensions.begin() == 0) && (rank == 2));
+ return ((*dimensions.begin() == 0) && (rank == 2)) &&
+        !isRank2ScalarReduction(op);
}

// Return true if this op is a rank-2 transpose
@@ -558,6 +578,9 @@ bool initFusionPatternBase(ShapeAnalysis& shapeAnalysis,
inferredFusionType = FusionType::kColReduction;
inferredDominantOp = op;
}
+ } else if (isRank2ScalarReduction(op)) {
+   inferredFusionType = FusionType::kScalarReduction;
+   inferredDominantOp = op;
} else if (isFusible(op)) {
// Ignore if already a kRowReduction or kColReduction, otherwise update
// the fusion type to kLoop and dominant op to current op. This supposes
@@ -750,6 +773,7 @@ FusionPattern::FusionPattern(lmhlo::FusionOp op, ShapeAnalysis* shape_analysis)
FusionType fusionType = FusionType::kNone;
auto deviceAttr = op->getAttrOfType<StringAttr>(kDiscPlaceAssignment);
auto fusionTypeAttr = op->getAttrOfType<StringAttr>(kDiscFusionTypeAttrName);

if (fusionTypeAttr) {
fusionType = fusionTypeFromString(fusionTypeAttr.getValue());
}
@@ -773,6 +797,7 @@ FusionPattern::FusionPattern(lmhlo::FusionOp op, ShapeAnalysis* shape_analysis)
FusionStrategy& strategy =
getFusionStrategy(deviceAttr.getValue(), strategyStr);
bool status = strategy.initFusionPattern(*shape_analysis, *this);
+ fusion_type_ = fusionType;
assert(status);
(void)(status);
}
@@ -1451,7 +1476,8 @@ bool BaseCpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
bool BaseGpuFusionStrategy::isFusible(Operation* op) {
// Only rank-2 tensor -> rank-1 tensor reduction are supported now.
if (isa<lmhlo::ReduceOp>(op) &&
-     (!isRank2RowReduction(op) && !isRank2ColReduction(op)))
+     (!isRank2RowReduction(op) && !isRank2ColReduction(op) &&
+      !isRank2ScalarReduction(op)))  // || isScalarReduction(op)))
return false;

if (isa<lmhlo::TransposeOp>(op) && isRank2or3Transpose(op)) return false;
@@ -1481,8 +1507,12 @@ bool BaseGpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
bool has_rank2_col_reduction =
llvm::any_of(target.getOpList(),
[](Operation* op) { return isRank2ColReduction(op); });

- if (has_rank2_row_reduction && has_rank2_col_reduction) {
+ bool has_rank2_scalar_reduction =
+     llvm::any_of(target.getOpList(),
+                  [](Operation* op) { return isRank2ScalarReduction(op); });
+ int cnt = has_rank2_row_reduction + has_rank2_col_reduction +
+           has_rank2_scalar_reduction;
+ if (cnt >= 2) {
return false;
}

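For orientation, here is a small standalone sketch of the condition the new isRank2ScalarReduction predicate encodes: the reduce op collapses exactly one dimension, and the user of its result buffer is a reshape down to a rank-0 (scalar) memref. The struct and helper names below are hypothetical stand-ins for illustration only, not repository code; the real predicate inspects lmhlo::ReduceOp and the users of its result buffer.

#include <cassert>

// Hypothetical stand-in for the properties the predicate reads off the IR.
struct ReduceInfo {
  int num_reduced_dims;       // entries in the reduce op's `dimensions` attribute
  bool only_user_is_reshape;  // the result buffer feeds a single reshape op
  int reshape_result_rank;    // rank of that reshape's output buffer
};

// Mirrors the decision in isRank2ScalarReduction: one reduced dimension whose
// rank-1 result is immediately reshaped to a scalar (rank-0) buffer.
bool looksLikeScalarReduction(const ReduceInfo& r) {
  return r.num_reduced_dims == 1 && r.only_user_is_reshape &&
         r.reshape_result_rank == 0;
}

int main() {
  assert(looksLikeScalarReduction({1, true, 0}));    // reduce + reshape-to-scalar
  assert(!looksLikeScalarReduction({1, false, 1}));  // plain row/column reduction
  return 0;
}

As the TODO in the diff notes, this detection leans on the reshape that the current lowering emits after the reduce; rewriting the reduction to produce a scalar tensor directly would remove the need to inspect that user.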
1 change: 1 addition & 0 deletions tao_compiler/mlir/disc/transforms/fusion_utils.h
@@ -111,6 +111,7 @@ enum FusionType {
// kInput fusion pattern and all reduce ops of the fused pattern are column
// reduction
kColReduction,
+ kScalarReduction,
// kInput fusion pattern
kInput,
// Stitch Fusion pattern
54 changes: 47 additions & 7 deletions tao_compiler/mlir/disc/transforms/fusion_utils_stitch_gpu.cc
@@ -17,18 +17,32 @@ namespace disc_ral {

////////////////////// Stitch GPU FusionStrategy Implementation /////////
////////////////////////////////////////////////////////////////////////

+ bool isScalarReduction(Operation* op) {
+   auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
+   if (!reduce_op || reduce_op.getDimensions().getNumElements() != 1)
+     return false;
+   int rank = op->getOperand(2).getType().cast<MemRefType>().getRank();
+   // TODO(yancey): rewrite the scalar reduction result to a scalar tensor to
+   // avoid the reshape to a scalar tensor behind the reduce op
+   Operation* reshapeOp = *op->getOperand(2).getUsers().begin();
+   if (reshapeOp && isa<lmhlo::ReshapeOp>(reshapeOp) &&
+       reshapeOp->getOperand(1).getType().cast<MemRefType>().getRank() == 0) {
+     return true;
+   }
+   return false;
+ }
bool findValidReductionOps(FusionPatternBase& target,
SmallVectorImpl<Operation*>& row_reductions,
-                            SmallVectorImpl<Operation*>& col_reductions) {
+                            SmallVectorImpl<Operation*>& col_reductions,
+                            SmallVectorImpl<Operation*>& scalar_reductions) {
row_reductions.clear();
col_reductions.clear();
auto& op_list = target.getOpList();
for (Operation* op : op_list) {
if (!isa<lmhlo::ReduceOp>(op)) continue;
if (isRank2RowReduction(op)) {
row_reductions.push_back(op);
-     } else if (isRank2ColReduction(op)) {
+     } else if (isRank2ColReduction(op) || isScalarReduction(op)) {
// Middle col-reduction is not supported currently. We may support it with
// AStitch technique in the future.
int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
@@ -41,7 +55,11 @@ bool findValidReductionOps(FusionPatternBase& target,
}
}
}
-       col_reductions.push_back(op);
+       if (isScalarReduction(op)) {
+         scalar_reductions.push_back(op);
+       } else {
+         col_reductions.push_back(op);
+       }
} else {
// Non supported reduction type.
return false;
Expand All @@ -65,8 +83,12 @@ bool StitchGpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
bool has_rank2_col_reduction =
llvm::any_of(target.getOpList(),
[](Operation* op) { return isRank2ColReduction(op); });
+ bool has_rank2_scalar_reduction = llvm::any_of(
+     target.getOpList(), [](Operation* op) { return isScalarReduction(op); });

- if (has_rank2_row_reduction && has_rank2_col_reduction) {
+ int cnt = has_rank2_row_reduction + has_rank2_col_reduction +
+           has_rank2_scalar_reduction;
+ if (cnt >= 2) {
return false;
}

@@ -371,7 +393,9 @@ bool StitchGpuFusionStrategy::findFusionPatternTypeAndSubroot(

SmallVector<Operation*, 4> row_reductions;
SmallVector<Operation*, 4> col_reductions;
- if (!findValidReductionOps(fusion_pattern, row_reductions, col_reductions)) {
+ SmallVector<Operation*, 4> scalar_reductions;
+ if (!findValidReductionOps(fusion_pattern, row_reductions, col_reductions,
+                            scalar_reductions)) {
LLVM_DEBUG(llvm::dbgs() << "Check reduction ops failed.");
return false;
}
@@ -440,7 +464,23 @@ bool StitchGpuFusionStrategy::findFusionPatternTypeAndSubroot(
return true;
}
Value shape = getEffectiveShape(fusion_pattern, result);
-          return isRank2ColReduction(op) &&
+          return (isRank2ColReduction(op)) &&
shapeAnalysis.isShapeEqual(ref_shape, shape);
})) {
return false;
}
+ } else if (!scalar_reductions.empty()) {
+   fusion_type = FusionType::kScalarReduction;
+   dominant_op = scalar_reductions.back();
+   Value ref = cast<lmhlo::LmhloOp>(dominant_op).getResultBuffer();
+   Value ref_shape = getEffectiveShape(fusion_pattern, ref);
+   if (!llvm::all_of(results, [&](Value result) {
+         auto op = fusion_pattern.findLastWriter(result);
+         if (op == dominant_op) {
+           return true;
+         }
+         Value shape = getEffectiveShape(fusion_pattern, result);
+         return (isRank2ColReduction(op)) &&
+                shapeAnalysis.isShapeEqual(ref_shape, shape);
+       })) {
+     return false;
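Both BaseGpuFusionStrategy::tryFuse and StitchGpuFusionStrategy::tryFuse now refuse to fuse a pattern that mixes reduction flavors: the three booleans are summed and the fusion is rejected once two or more flavors appear. A minimal standalone sketch of that guard follows; the function name and booleans are illustrative stand-ins for the llvm::any_of scans over the fusion's op list.

#include <iostream>

// Each bool promotes to 0 or 1, so the sum counts how many reduction flavors
// (row, column, scalar) are present in one fusion pattern.
bool reductionKindsAreCompatible(bool has_row, bool has_col, bool has_scalar) {
  int cnt = has_row + has_col + has_scalar;
  return cnt < 2;  // at most one flavor may appear in a single fusion
}

int main() {
  std::cout << reductionKindsAreCompatible(true, false, false) << "\n";   // 1: fusible
  std::cout << reductionKindsAreCompatible(true, false, true) << "\n";    // 0: rejected
  std::cout << reductionKindsAreCompatible(false, false, false) << "\n";  // 1: fusible
  return 0;
}

Keeping the flavors separate presumably lets each fusion pattern map to a single reduction schedule (kRowReduction, kColReduction, or the new kScalarReduction) rather than a hybrid one.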
