Skip to content

Commit

Permalink
Merge branch 'dev_int8_conv' of https://github.com/Oneflow-Inc/oneflow
Browse files Browse the repository at this point in the history
…into dev_int8_conv
  • Loading branch information
hjchen2 committed Sep 5, 2023
2 parents d414feb + c384f31 commit 1388490
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 112 deletions.
1 change: 0 additions & 1 deletion oneflow/core/job/job_build_and_infer_ctx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1027,7 +1027,6 @@ Maybe<void> LazyJobBuildAndInferCtx::Complete() {
JUST(DoPass("DoParallelCastBeforeWideningTypeCast"));
JUST(DoPass("FuseCastScalePass"));
JUST(DoPass("PruneParallelCastOpsPass"));
JUST(DoPass("PruneRedundantQuantizationOpsPass"));
JUST(DoPass("FuseUpdateOpsPass"));
JUST(DoPass("FuseModelUpdateCastOpsPass"));
JUST(DoPass("MultiTensorModelUpdatePass"));
Expand Down
111 changes: 0 additions & 111 deletions oneflow/core/job_rewriter/prune_redundant_quantization_op_pass.cpp

This file was deleted.

42 changes: 42 additions & 0 deletions oneflow/ir/lib/OneFlow/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,47 @@ LogicalResult FusedConsecutiveAddPattern<Add2Op>::matchAndRewrite(Add2Op op,
return TryFusedConsecutiveAdd<Add2Op>(op, {op.getIn0(), op.getIn1()}, rewriter);
}

struct PruneReduntantQuantizationOpsPattern : public OpInterfaceRewritePattern<UserOpCompatible> {
explicit PruneReduntantQuantizationOpsPattern(mlir::MLIRContext* context)
: OpInterfaceRewritePattern<UserOpCompatible>(context, /*benefit=*/1) {}

public:
LogicalResult matchAndRewrite(UserOpCompatible op, PatternRewriter& rewriter) const override {
DenseMap<Value, SmallVector<QuantizationOp, 4>> quantOps;
DenseMap<Value, SmallVector<DynamicQuantizationOp, 4>> dynamic_quantOps;
for (auto result : op->getResults()) {
for (auto u : result.getUsers()) {
if (auto q = llvm::dyn_cast<QuantizationOp>(u)) { quantOps[result].push_back(q); }
if (auto q = llvm::dyn_cast<DynamicQuantizationOp>(u)) {
dynamic_quantOps[result].push_back(q);
}
}
}
bool pruned = false;
for (const auto& it : quantOps) {
auto q0 = it.second[0];
for (auto q : it.second) {
if (q != q0) {
q->replaceAllUsesWith(q0->getResults());
q->erase();
pruned = true;
}
}
}
for (const auto& it : dynamic_quantOps) {
auto q0 = it.second[0];
for (auto q : it.second) {
if (q != q0) {
q->replaceAllUsesWith(q0->getResults());
q->erase();
pruned = true;
}
}
}
return success(pruned);
}
};

struct AutoNhwcPattern : public OpInterfaceRewritePattern<NCHWCompatible> {
explicit AutoNhwcPattern(mlir::MLIRContext* context)
: OpInterfaceRewritePattern<NCHWCompatible>(context, /*benefit=*/1) {}
Expand Down Expand Up @@ -1155,6 +1196,7 @@ void populatePreConvertInferenceOp(::mlir::RewritePatternSet& patterns) {

void populateConvertInferenceOp(::mlir::RewritePatternSet& patterns) {
populateFuseConv2DBatchNormPattern(patterns);
patterns.add<PruneReduntantQuantizationOpsPattern>(patterns.getContext());
}

void populatePostConvertInferenceOp(::mlir::RewritePatternSet& patterns) {
Expand Down

0 comments on commit 1388490

Please sign in to comment.