Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
instrumentActions |= (1 << 3) - 1;
// Also enable instrumentation of signatures.
instrumentSignatures = "onnx.*,zhigh.*";
if (enableConstantOpProfiling) {
instrumentOps += ",krnl.global";
instrumentSignatures += ",krnl.global";
}
}

// Insert an instrumentation after lowering onnx to zhigh to get profiling /
Expand All @@ -198,10 +202,10 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
// not to include timing of the signature printing.
if (hasSignatureInstrumentation(onnx_mlir::InstrumentStages::ZHigh))
pm.addNestedPass<func::FuncOp>(onnx_mlir::createInstrumentONNXSignaturePass(
instrumentSignatures, instrumentOnnxNode));
instrumentSignatures, instrumentOnnxNode, enableConstantOpProfiling));
if (hasInstrumentation(onnx_mlir::InstrumentStages::ZHigh))
pm.addNestedPass<func::FuncOp>(
onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions));
pm.addNestedPass<func::FuncOp>(onnx_mlir::createInstrumentPass(
instrumentOps, instrumentActions, enableConstantOpProfiling));
}

void normalizeMemRefsPasses(mlir::PassManager &pm) {
Expand Down Expand Up @@ -298,7 +302,7 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
"stage is currently unsupported");
if (hasInstrumentation(onnx_mlir::InstrumentStages::ZLow))
pm.addNestedPass<func::FuncOp>(onnx_mlir::createInstrumentPass(
instrumentOps, instrumentControlBits));
instrumentOps, instrumentControlBits, enableConstantOpProfiling));
}
}

Expand Down
9 changes: 9 additions & 0 deletions src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ bool disableConstantProp; // onnx-mlir only
std::vector<std::string> extraLibPaths; // onnx-mlir only
std::vector<std::string> extraLibs; // onnx-mlir only
ProfileIRs profileIR; // onnx-mlir only
bool enableConstantOpProfiling; // onnx-mlir only
OptReport optReport; // onnx-mlir only
bool enableTiming; // onnx-mlir only
bool enableBoundCheck; // onnx-mlir only
Expand Down Expand Up @@ -700,6 +701,14 @@ static llvm::cl::opt<ProfileIRs, true> profileIROpt("profile-ir",
APPLY_TO_ACCELERATORS(ACCEL_PROFILEIR_CL_ENUM)),
llvm::cl::init(ProfileIRs::None), llvm::cl::cat(OnnxMlirOptions));

static llvm::cl::opt<bool, true> enableConstantOpProfilingOpt(
"enable-constant-op-profiling",
llvm::cl::desc("Normally we disable generating profiling information for\n"
"constant-generating operations.\nUse this option to force"
" profiling for all ops (default is false)."),
llvm::cl::location(enableConstantOpProfiling), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));

static llvm::cl::opt<OptReport, true> optReportOpt("opt-report",
llvm::cl::desc("Provide information on a specific compiler optimization:"),
llvm::cl::location(optReport),
Expand Down
1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ extern bool disableConstantProp; // onnx-mlir only
extern std::vector<std::string> extraLibPaths; // onnx-mlir only
extern std::vector<std::string> extraLibs; // onnx-mlir only
extern ProfileIRs profileIR; // onnx-mlir only
extern bool enableConstantOpProfiling; // onnx-mlir only
extern OptReport optReport; // onnx-mlir only
extern bool enableTiming; // onnx-mlir only
extern bool enableBoundCheck; // onnx-mlir only
Expand Down
12 changes: 8 additions & 4 deletions src/Compiler/CompilerPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,19 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
instrumentActions |= (1 << 3) - 1;
// Also enable instrumentation of signatures.
instrumentSignatures = "onnx.*";
if (enableConstantOpProfiling) {
instrumentOps += ",krnl.global";
instrumentSignatures += ",krnl.global";
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious about why we need to add krnl.global here while this is for profiling onnx ops. I though there were no krnl.global ops at this IR level.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me double check

}
// Add createInstrument (timing) second so that it will guarantee not to
// include timing of the signature printing.
if (hasSignatureInstrumentation(onnx_mlir::InstrumentStages::Onnx))
pm.addNestedPass<func::FuncOp>(onnx_mlir::createInstrumentONNXSignaturePass(
instrumentSignatures, instrumentOnnxNode));
instrumentSignatures, instrumentOnnxNode, enableConstantOpProfiling));
if (hasInstrumentation(onnx_mlir::InstrumentStages::Onnx))
pm.addNestedPass<func::FuncOp>(
onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions));
pm.addNestedPass<func::FuncOp>(onnx_mlir::createInstrumentPass(
instrumentOps, instrumentActions, enableConstantOpProfiling));
}

void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
Expand Down Expand Up @@ -276,7 +280,7 @@ void addKrnlToLLVMPasses(

pm.addPass(mlir::memref::createFoldMemRefAliasOpsPass());

if (profileIR)
if (profileIR && !enableConstantOpProfiling)
pm.addNestedPass<func::FuncOp>(onnx_mlir::createInstrumentCleanupPass());

if (enableBoundCheck)
Expand Down
17 changes: 12 additions & 5 deletions src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,17 @@ class InstrumentONNXSignaturePass
OperationPass<func::FuncOp>>() {
signaturePattern = pass.signaturePattern;
nodeNamePattern = pass.nodeNamePattern;
skipConstants = pass.skipConstants;
}
InstrumentONNXSignaturePass(
const std::string opPattern, const std::string nodePattern)
: signaturePattern(opPattern), nodeNamePattern(nodePattern) {}
InstrumentONNXSignaturePass(const std::string opPattern,
const std::string nodePattern, bool skipConstants)
: signaturePattern(opPattern), nodeNamePattern(nodePattern),
skipConstants(skipConstants) {}

private:
std::string signaturePattern;
std::string nodeNamePattern;
bool skipConstants;

public:
StringRef getArgument() const override {
Expand Down Expand Up @@ -113,6 +116,8 @@ class InstrumentONNXSignaturePass
isa<ONNXPrintSignatureOp, KrnlInstrumentOp>(op)) {
// Always skip function dialects (such as function call/return), as well
// as ONNX instrument operations.
} else if (skipConstants && (isa<ONNXConstantOp, KrnlGlobalOp>(op))) {
// Asked to skip constants, do nothing.
} else {
// If has both a print of data and print of signature, favor the
// printing of data as it also will print the signature.
Expand Down Expand Up @@ -141,6 +146,8 @@ class InstrumentONNXSignaturePass
* Create an instrumentation pass.
*/
std::unique_ptr<mlir::Pass> onnx_mlir::createInstrumentONNXSignaturePass(
const std::string pattern, const std::string nodePattern) {
return std::make_unique<InstrumentONNXSignaturePass>(pattern, nodePattern);
const std::string pattern, const std::string nodePattern,
bool profileConstantOps) {
return std::make_unique<InstrumentONNXSignaturePass>(
pattern, nodePattern, !profileConstantOps);
}
5 changes: 3 additions & 2 deletions src/Pass/Passes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,15 @@ std::unique_ptr<mlir::Pass> createConstPropONNXToONNXPass();
/// Pass for instrument the ops in specific stage.
std::unique_ptr<mlir::Pass> createInstrumentPass();
std::unique_ptr<mlir::Pass> createInstrumentPass(
const std::string &ops, unsigned actions);
const std::string &ops, unsigned actions, bool profileConstantOps);
/// Pass for instrument cleanup.
std::unique_ptr<mlir::Pass> createInstrumentCleanupPass();

/// Passes for instrumenting the ONNX ops to print their operand type
/// signatures at runtime.
std::unique_ptr<mlir::Pass> createInstrumentONNXSignaturePass(
const std::string opPattern, const std::string nodePattern);
const std::string opPattern, const std::string nodePattern,
bool profileConstantOps);

/// Pass for simplifying shape-related ONNX operations.
std::unique_ptr<mlir::Pass> createSimplifyShapeRelatedOpsPass();
Expand Down
2 changes: 1 addition & 1 deletion src/Tools/onnx-mlir-opt/RegisterPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ void registerOMPasses(int optLevel) {
});

mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return createInstrumentONNXSignaturePass("NONE", "NONE");
return createInstrumentONNXSignaturePass("NONE", "NONE", false);
});

mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
Expand Down
15 changes: 11 additions & 4 deletions src/Transform/InstrumentPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,22 @@ class InstrumentPass
llvm::cl::desc("instrument runtime reports memory usage"),
llvm::cl::init(false)};

Option<bool> skipConstants{*this, "skip-constants",
llvm::cl::desc("do not instrument constant ops"), llvm::cl::init(true)};

InstrumentPass() : allowedOps(/*emptyIsNone*/ true){};
InstrumentPass(const InstrumentPass &pass)
: mlir::PassWrapper<InstrumentPass, OperationPass<func::FuncOp>>(),
allowedOps(/*emptyIsNone*/ true) {}
InstrumentPass(const std::string &ops, unsigned actions)
InstrumentPass(const std::string &ops, unsigned actions, bool skipConstants)
: allowedOps(/*emptyIsNone*/ true) {
this->instrumentOps = ops;
unsigned long long tag = actions;
this->instrumentBefore = IS_INSTRUMENT_BEFORE_OP(tag);
this->instrumentAfter = IS_INSTRUMENT_AFTER_OP(tag);
this->reportTime = IS_INSTRUMENT_REPORT_TIME(tag);
this->reportMemory = IS_INSTRUMENT_REPORT_MEMORY(tag);
this->skipConstants = skipConstants;
}

private:
Expand Down Expand Up @@ -144,7 +148,10 @@ class InstrumentPass
if (op->getNumResults() == 1 && isa<NoneType>(op->getResult(0).getType()))
return WalkResult::advance();
// Skip other instrument ops.
if (isa<KrnlInstrumentOp>(op) || isa<ONNXPrintSignatureOp>(op))
if (isa<ONNXPrintSignatureOp, KrnlInstrumentOp>(op))
return WalkResult::advance();
if (skipConstants && isa<ONNXConstantOp, KrnlGlobalOp>(op))
// Asked to skip constants, do nothing.
return WalkResult::advance();

std::string opName = op->getName().getStringRef().str();
Expand Down Expand Up @@ -185,6 +192,6 @@ std::unique_ptr<mlir::Pass> onnx_mlir::createInstrumentPass() {
}

std::unique_ptr<mlir::Pass> onnx_mlir::createInstrumentPass(
const std::string &ops, unsigned actions) {
return std::make_unique<InstrumentPass>(ops, actions);
const std::string &ops, unsigned actions, bool profileConstantOps) {
return std::make_unique<InstrumentPass>(ops, actions, !profileConstantOps);
}
Loading