diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp index 07e38a57da..fcb887fae5 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp @@ -190,6 +190,10 @@ void addONNXToZHighPasses(mlir::PassManager &pm) { instrumentActions |= (1 << 3) - 1; // Also enable instrumentation of signatures. instrumentSignatures = "onnx.*,zhigh.*"; + if (enableConstantOpProfiling) { + instrumentOps += ",krnl.global"; + instrumentSignatures += ",krnl.global"; + } } // Insert an instrumentation after lowering onnx to zhigh to get profiling / @@ -198,10 +202,10 @@ void addONNXToZHighPasses(mlir::PassManager &pm) { // not to include timing of the signature printing. if (hasSignatureInstrumentation(onnx_mlir::InstrumentStages::ZHigh)) pm.addNestedPass(onnx_mlir::createInstrumentONNXSignaturePass( - instrumentSignatures, instrumentOnnxNode)); + instrumentSignatures, instrumentOnnxNode, enableConstantOpProfiling)); if (hasInstrumentation(onnx_mlir::InstrumentStages::ZHigh)) - pm.addNestedPass( - onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions)); + pm.addNestedPass(onnx_mlir::createInstrumentPass( + instrumentOps, instrumentActions, enableConstantOpProfiling)); } void normalizeMemRefsPasses(mlir::PassManager &pm) { @@ -298,7 +302,7 @@ void addPassesNNPA(mlir::OwningOpRef &module, "stage is currently unsupported"); if (hasInstrumentation(onnx_mlir::InstrumentStages::ZLow)) pm.addNestedPass(onnx_mlir::createInstrumentPass( - instrumentOps, instrumentControlBits)); + instrumentOps, instrumentControlBits, enableConstantOpProfiling)); } } diff --git a/src/Compiler/CompilerOptions.cpp b/src/Compiler/CompilerOptions.cpp index 6a92557749..c49d3f6b8d 100644 --- a/src/Compiler/CompilerOptions.cpp +++ b/src/Compiler/CompilerOptions.cpp @@ -93,6 +93,7 @@ bool disableConstantProp; // onnx-mlir only std::vector extraLibPaths; // onnx-mlir only std::vector extraLibs; // onnx-mlir only ProfileIRs profileIR; // onnx-mlir only +bool enableConstantOpProfiling; // onnx-mlir only OptReport optReport; // onnx-mlir only bool enableTiming; // onnx-mlir only bool enableBoundCheck; // onnx-mlir only @@ -700,6 +701,14 @@ static llvm::cl::opt profileIROpt("profile-ir", APPLY_TO_ACCELERATORS(ACCEL_PROFILEIR_CL_ENUM)), llvm::cl::init(ProfileIRs::None), llvm::cl::cat(OnnxMlirOptions)); +static llvm::cl::opt enableConstantOpProfilingOpt( + "enable-constant-op-profiling", + llvm::cl::desc("Normally we disable generating profiling information for\n" + "constant-generating operations.\nUse this option to force" + " profiling for all ops (default is false)."), + llvm::cl::location(enableConstantOpProfiling), llvm::cl::init(false), + llvm::cl::cat(OnnxMlirOptions)); + static llvm::cl::opt optReportOpt("opt-report", llvm::cl::desc("Provide information on a specific compiler optimization:"), llvm::cl::location(optReport), diff --git a/src/Compiler/CompilerOptions.hpp b/src/Compiler/CompilerOptions.hpp index 294839eb62..c790727948 100644 --- a/src/Compiler/CompilerOptions.hpp +++ b/src/Compiler/CompilerOptions.hpp @@ -138,6 +138,7 @@ extern bool disableConstantProp; // onnx-mlir only extern std::vector extraLibPaths; // onnx-mlir only extern std::vector extraLibs; // onnx-mlir only extern ProfileIRs profileIR; // onnx-mlir only +extern bool enableConstantOpProfiling; // onnx-mlir only extern OptReport optReport; // onnx-mlir only extern bool enableTiming; // onnx-mlir only extern bool enableBoundCheck; // onnx-mlir only diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp index c975ce31f6..422f0b511a 100644 --- a/src/Compiler/CompilerPasses.cpp +++ b/src/Compiler/CompilerPasses.cpp @@ -167,15 +167,19 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU, instrumentActions |= (1 << 3) - 1; // Also enable instrumentation of signatures. instrumentSignatures = "onnx.*"; + if (enableConstantOpProfiling) { + instrumentOps += ",krnl.global"; + instrumentSignatures += ",krnl.global"; + } } // Add createInstrument (timing) second so that it will guarantee not to // include timing of the signature printing. if (hasSignatureInstrumentation(onnx_mlir::InstrumentStages::Onnx)) pm.addNestedPass(onnx_mlir::createInstrumentONNXSignaturePass( - instrumentSignatures, instrumentOnnxNode)); + instrumentSignatures, instrumentOnnxNode, enableConstantOpProfiling)); if (hasInstrumentation(onnx_mlir::InstrumentStages::Onnx)) - pm.addNestedPass( - onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions)); + pm.addNestedPass(onnx_mlir::createInstrumentPass( + instrumentOps, instrumentActions, enableConstantOpProfiling)); } void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, @@ -276,7 +280,7 @@ void addKrnlToLLVMPasses( pm.addPass(mlir::memref::createFoldMemRefAliasOpsPass()); - if (profileIR) + if (profileIR && !enableConstantOpProfiling) pm.addNestedPass(onnx_mlir::createInstrumentCleanupPass()); if (enableBoundCheck) diff --git a/src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp b/src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp index 23c75be41a..5dd234370d 100644 --- a/src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp +++ b/src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp @@ -52,14 +52,17 @@ class InstrumentONNXSignaturePass OperationPass>() { signaturePattern = pass.signaturePattern; nodeNamePattern = pass.nodeNamePattern; + skipConstants = pass.skipConstants; } - InstrumentONNXSignaturePass( - const std::string opPattern, const std::string nodePattern) - : signaturePattern(opPattern), nodeNamePattern(nodePattern) {} + InstrumentONNXSignaturePass(const std::string opPattern, + const std::string nodePattern, bool skipConstants) + : signaturePattern(opPattern), nodeNamePattern(nodePattern), + skipConstants(skipConstants) {} private: std::string signaturePattern; std::string nodeNamePattern; + bool skipConstants; public: StringRef getArgument() const override { @@ -113,6 +116,8 @@ class InstrumentONNXSignaturePass isa(op)) { // Always skip function dialects (such as function call/return), as well // as ONNX instrument operations. + } else if (skipConstants && (isa(op))) { + // Asked to skip constants, do nothing. } else { // If has both a print of data and print of signature, favor the // printing of data as it also will print the signature. @@ -141,6 +146,8 @@ class InstrumentONNXSignaturePass * Create an instrumentation pass. */ std::unique_ptr onnx_mlir::createInstrumentONNXSignaturePass( - const std::string pattern, const std::string nodePattern) { - return std::make_unique(pattern, nodePattern); + const std::string pattern, const std::string nodePattern, + bool profileConstantOps) { + return std::make_unique( + pattern, nodePattern, !profileConstantOps); } diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index c696845b59..114db9eb53 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -56,14 +56,15 @@ std::unique_ptr createConstPropONNXToONNXPass(); /// Pass for instrument the ops in specific stage. std::unique_ptr createInstrumentPass(); std::unique_ptr createInstrumentPass( - const std::string &ops, unsigned actions); + const std::string &ops, unsigned actions, bool profileConstantOps); /// Pass for instrument cleanup. std::unique_ptr createInstrumentCleanupPass(); /// Passes for instrumenting the ONNX ops to print their operand type /// signatures at runtime. std::unique_ptr createInstrumentONNXSignaturePass( - const std::string opPattern, const std::string nodePattern); + const std::string opPattern, const std::string nodePattern, + bool profileConstantOps); /// Pass for simplifying shape-related ONNX operations. std::unique_ptr createSimplifyShapeRelatedOpsPass(); diff --git a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp index 1af70b866e..55f73c309b 100644 --- a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp +++ b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp @@ -75,7 +75,7 @@ void registerOMPasses(int optLevel) { }); mlir::registerPass([]() -> std::unique_ptr { - return createInstrumentONNXSignaturePass("NONE", "NONE"); + return createInstrumentONNXSignaturePass("NONE", "NONE", false); }); mlir::registerPass([]() -> std::unique_ptr { diff --git a/src/Transform/InstrumentPass.cpp b/src/Transform/InstrumentPass.cpp index 9ee74cbcfe..280aad592a 100644 --- a/src/Transform/InstrumentPass.cpp +++ b/src/Transform/InstrumentPass.cpp @@ -70,11 +70,14 @@ class InstrumentPass llvm::cl::desc("instrument runtime reports memory usage"), llvm::cl::init(false)}; + Option skipConstants{*this, "skip-constants", + llvm::cl::desc("do not instrument constant ops"), llvm::cl::init(true)}; + InstrumentPass() : allowedOps(/*emptyIsNone*/ true){}; InstrumentPass(const InstrumentPass &pass) : mlir::PassWrapper>(), allowedOps(/*emptyIsNone*/ true) {} - InstrumentPass(const std::string &ops, unsigned actions) + InstrumentPass(const std::string &ops, unsigned actions, bool skipConstants) : allowedOps(/*emptyIsNone*/ true) { this->instrumentOps = ops; unsigned long long tag = actions; @@ -82,6 +85,7 @@ class InstrumentPass this->instrumentAfter = IS_INSTRUMENT_AFTER_OP(tag); this->reportTime = IS_INSTRUMENT_REPORT_TIME(tag); this->reportMemory = IS_INSTRUMENT_REPORT_MEMORY(tag); + this->skipConstants = skipConstants; } private: @@ -144,7 +148,10 @@ class InstrumentPass if (op->getNumResults() == 1 && isa(op->getResult(0).getType())) return WalkResult::advance(); // Skip other instrument ops. - if (isa(op) || isa(op)) + if (isa(op)) + return WalkResult::advance(); + if (skipConstants && isa(op)) + // Asked to skip constants, do nothing. return WalkResult::advance(); std::string opName = op->getName().getStringRef().str(); @@ -185,6 +192,6 @@ std::unique_ptr onnx_mlir::createInstrumentPass() { } std::unique_ptr onnx_mlir::createInstrumentPass( - const std::string &ops, unsigned actions) { - return std::make_unique(ops, actions); + const std::string &ops, unsigned actions, bool profileConstantOps) { + return std::make_unique(ops, actions, !profileConstantOps); }