-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Add enable_profiling in runoptions #26846
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -155,8 +155,8 @@ std::string ComposeSeriesName(const GraphViewer& graph_viewer) { | |
| class SessionScope { | ||
| public: | ||
| friend class KernelScope; | ||
| SessionScope(const SessionState& session_state, const ExecutionFrame& frame) | ||
| : session_state_(session_state) | ||
| SessionScope(const SessionState& session_state, const ExecutionFrame& frame, profiling::Profiler* run_profiler = nullptr) | ||
| : session_state_(session_state), run_profiler_(run_profiler) | ||
| #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) | ||
| , | ||
| frame_(frame) | ||
|
|
@@ -173,12 +173,18 @@ class SessionScope { | |
| #endif | ||
| #ifdef DEBUG_NODE_INPUTS_OUTPUTS | ||
| , | ||
| dump_context_{ | ||
| session_state_.GetGraphExecutionCounter(), 0} | ||
| dump_context_{session_state_.GetGraphExecutionCounter(), 0} | ||
| #endif | ||
| { | ||
| if (session_state_.Profiler().IsEnabled()) { | ||
| session_start_ = session_state.Profiler().Start(); | ||
| bool session_profiling_enabled = session_state_.Profiler().IsEnabled(); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am still not convinced that we should allow both profilers to run in parallel. Do you have a use case for that? What would be the purpose of collecting the same data twice? If someone wants continuous profiling, would it not be the same thing as running it with RunOptions?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This depends on how we want to handle the case when both run-level and session-level profiling are enabled. For example, when a user calls Session::Run with both run-level and session-level profiling enabled, there will be two profilers active: a local run_profiler and the session_profiler_ owned by InferenceSession. The current implementation guarantees that two JSON files are generated, and that the events recorded in the run-level profiling output are a strict subset of those in the session-level profiling output. In this scenario, each operator execution generates two identical profiling events: one is recorded by the session-level profiler, and the other is recorded by the run-level profiler. |
||
| bool run_profiling_enabled = run_profiler_ && run_profiler_->IsEnabled(); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| session_start_ = std::chrono::high_resolution_clock::now(); | ||
| if (session_profiling_enabled) { | ||
| session_state_.Profiler().Start(session_start_); | ||
| } | ||
| if (run_profiling_enabled) { | ||
| run_profiler_->Start(session_start_); | ||
| } | ||
|
|
||
| auto& logger = session_state_.Logger(); | ||
|
|
@@ -225,9 +231,17 @@ class SessionScope { | |
| } | ||
| #endif | ||
|
|
||
| if (session_state_.Profiler().IsEnabled()) { | ||
| session_state_.Profiler().EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", session_start_); | ||
| bool session_profiling_enabled = session_state_.Profiler().IsEnabled(); | ||
| bool run_profiling_enabled = run_profiler_ && run_profiler_->IsEnabled(); | ||
|
|
||
| auto now = std::chrono::high_resolution_clock::now(); | ||
| if (session_profiling_enabled) { | ||
| session_state_.Profiler().EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", session_start_, now); | ||
| } | ||
| if (run_profiling_enabled) { | ||
| run_profiler_->EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", session_start_, now); | ||
| } | ||
|
|
||
| #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) | ||
| auto& logger = session_state_.Logger(); | ||
| for (auto i : frame_.GetStaticMemorySizeInfo()) { | ||
|
|
@@ -254,6 +268,7 @@ class SessionScope { | |
|
|
||
| private: | ||
| const SessionState& session_state_; | ||
| profiling::Profiler* run_profiler_; | ||
| TimePoint session_start_; | ||
| #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) | ||
| const ExecutionFrame& frame_; | ||
|
|
@@ -340,16 +355,23 @@ class KernelScope { | |
| utils::DumpNodeInputs(dump_context_, kernel_context_, kernel_.Node(), session_state_, session_scope_.dump_analysis_); | ||
| #endif | ||
|
|
||
| #ifdef ENABLE_NVTX_PROFILE | ||
| node_compute_range_.Begin(); | ||
| #endif | ||
| bool session_profiling_enabled = session_state_.Profiler().IsEnabled(); | ||
| bool run_profiling_enabled = session_scope_.run_profiler_ && session_scope_.run_profiler_->IsEnabled(); | ||
|
|
||
| if (session_state_.Profiler().IsEnabled()) { | ||
| if (session_profiling_enabled || run_profiling_enabled) { | ||
| auto& node = kernel.Node(); | ||
| node_name_ = node.Name().empty() ? MakeString(node.OpType(), "_", node.Index()) : node.Name(); | ||
| concurrency::ThreadPool::StartProfiling(session_state_.GetThreadPool()); | ||
| VLOGS(session_state_.Logger(), 1) << "Computing kernel: " << node_name_; | ||
| kernel_begin_time_ = session_state_.Profiler().Start(); | ||
|
|
||
| kernel_begin_time_ = std::chrono::high_resolution_clock::now(); | ||
| if (session_profiling_enabled) { | ||
| session_state_.Profiler().Start(kernel_begin_time_); | ||
| } | ||
| if (run_profiling_enabled) { | ||
| session_scope_.run_profiler_->Start(kernel_begin_time_); | ||
| } | ||
|
|
||
| CalculateTotalInputSizes(&kernel_context, &kernel_, | ||
| input_activation_sizes_, input_parameter_sizes_, | ||
| node_name_, input_type_shape_); | ||
|
|
@@ -363,26 +385,42 @@ class KernelScope { | |
| node_compute_range_.End(); | ||
| #endif | ||
|
|
||
| if (session_state_.Profiler().IsEnabled()) { | ||
| auto& profiler = session_state_.Profiler(); | ||
| bool session_profiling_enabled = session_state_.Profiler().IsEnabled(); | ||
| bool run_profiling_enabled = session_scope_.run_profiler_ && session_scope_.run_profiler_->IsEnabled(); | ||
|
|
||
| if (session_profiling_enabled || run_profiling_enabled) { | ||
| std::string output_type_shape_; | ||
| CalculateTotalOutputSizes(&kernel_context_, total_output_sizes_, node_name_, output_type_shape_); | ||
| profiler.EndTimeAndRecordEvent(profiling::NODE_EVENT, | ||
| node_name_ + "_kernel_time", | ||
| kernel_begin_time_, | ||
| // Log additional operation args / info. | ||
| { | ||
| {"op_name", kernel_.KernelDef().OpName()}, | ||
| {"provider", kernel_.KernelDef().Provider()}, | ||
| {"node_index", std::to_string(kernel_.Node().Index())}, | ||
| {"activation_size", std::to_string(input_activation_sizes_)}, | ||
| {"parameter_size", std::to_string(input_parameter_sizes_)}, | ||
| {"output_size", std::to_string(total_output_sizes_)}, | ||
| {"input_type_shape", input_type_shape_}, | ||
| {"output_type_shape", output_type_shape_}, | ||
| {"thread_scheduling_stats", | ||
| concurrency::ThreadPool::StopProfiling(session_state_.GetThreadPool())}, | ||
| }); | ||
|
|
||
| std::initializer_list<std::pair<std::string, std::string>> event_args = { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| {"op_name", kernel_.KernelDef().OpName()}, | ||
| {"provider", kernel_.KernelDef().Provider()}, | ||
| {"node_index", std::to_string(kernel_.Node().Index())}, | ||
| {"activation_size", std::to_string(input_activation_sizes_)}, | ||
| {"parameter_size", std::to_string(input_parameter_sizes_)}, | ||
| {"output_size", std::to_string(total_output_sizes_)}, | ||
| {"input_type_shape", input_type_shape_}, | ||
| {"output_type_shape", output_type_shape_}, | ||
| {"thread_scheduling_stats", | ||
| concurrency::ThreadPool::StopProfiling(session_state_.GetThreadPool())}, | ||
| }; | ||
|
|
||
| auto now = std::chrono::high_resolution_clock::now(); | ||
| if (session_profiling_enabled) { | ||
| session_state_.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT, | ||
| node_name_ + "_kernel_time", | ||
| kernel_begin_time_, | ||
| now, | ||
| event_args); | ||
| } | ||
|
|
||
| if (run_profiling_enabled) { | ||
| session_scope_.run_profiler_->EndTimeAndRecordEvent(profiling::NODE_EVENT, | ||
| node_name_ + "_kernel_time", | ||
| kernel_begin_time_, | ||
| now, | ||
| event_args); | ||
| } | ||
| } | ||
|
|
||
| #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT | ||
|
|
@@ -405,7 +443,6 @@ class KernelScope { | |
| } //~KernelScope | ||
|
|
||
| private: | ||
| TimePoint kernel_begin_time_; | ||
| SessionScope& session_scope_; | ||
| const SessionState& session_state_; | ||
| std::string node_name_; | ||
|
|
@@ -417,6 +454,8 @@ class KernelScope { | |
| size_t total_output_sizes_{}; | ||
| std::string input_type_shape_; | ||
|
|
||
| TimePoint kernel_begin_time_; | ||
|
|
||
| #ifdef CONCURRENCY_VISUALIZER | ||
| diagnostic::span span_; | ||
| #endif | ||
|
|
@@ -588,7 +627,8 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span< | |
| #endif | ||
| const bool& terminate_flag, | ||
| const bool only_execute_path_to_fetches, | ||
| bool single_thread_mode) { | ||
| bool single_thread_mode, | ||
| profiling::Profiler* run_profiler) { | ||
| auto* execution_plan = session_state.GetExecutionPlan(); | ||
| VLOGS(logger, 0) << "Number of streams: " << execution_plan->execution_plan.size(); | ||
| int32_t valid_streams = 0; | ||
|
|
@@ -631,7 +671,7 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span< | |
| ORT_UNUSED_PARAMETER(only_execute_path_to_fetches); | ||
| #endif | ||
|
|
||
| SessionScope session_scope(session_state, ctx.GetExecutionFrame()); | ||
| SessionScope session_scope(session_state, ctx.GetExecutionFrame(), run_profiler); | ||
|
|
||
| auto* tp = single_thread_mode ? nullptr : session_state.GetInterOpThreadPool(); | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -599,7 +599,8 @@ ExecuteGraphImpl(const SessionState& session_state, | |
| DeviceStreamCollection* device_stream_collection, | ||
| #endif | ||
| const bool only_execute_path_to_fetches = false, | ||
| Stream* parent_stream = nullptr) { | ||
| Stream* parent_stream = nullptr, | ||
| profiling::Profiler* run_profiler = nullptr) { | ||
| const auto& feeds_fetches_info = feeds_fetches_manager.GetFeedsFetchesInfo(); | ||
| const auto& device_copy_checks = feeds_fetches_manager.GetDeviceCopyChecks(); | ||
| #ifdef ORT_ENABLE_STREAM | ||
|
|
@@ -631,7 +632,8 @@ ExecuteGraphImpl(const SessionState& session_state, | |
| terminate_flag, | ||
| only_execute_path_to_fetches, | ||
| // single thread mode | ||
| single_thread_mode)); | ||
| single_thread_mode, | ||
| run_profiler)); | ||
| ORT_RETURN_IF_ERROR(status); | ||
| } else { | ||
| auto feeds_to_use = feeds; | ||
|
|
@@ -679,7 +681,8 @@ ExecuteGraphImpl(const SessionState& session_state, | |
| #endif | ||
| terminate_flag, | ||
| only_execute_path_to_fetches, | ||
| single_thread_mode)); | ||
| single_thread_mode, | ||
| run_profiler)); | ||
| ORT_RETURN_IF_ERROR(status); | ||
| InlinedVector<Stream*> fetches_streams; | ||
| fetches_streams.reserve(feeds_fetches_info.fetches_mlvalue_idxs.size()); | ||
|
|
@@ -717,7 +720,8 @@ common::Status ExecuteGraph(const SessionState& session_state, | |
| DeviceStreamCollectionHolder& device_stream_collection_holder, | ||
| #endif | ||
| bool only_execute_path_to_fetches, | ||
| Stream* parent_stream) { | ||
| Stream* parent_stream, | ||
| profiling::Profiler* run_profiler) { | ||
| ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(session_state, feeds_fetches_manager)); | ||
|
|
||
| // finalize the copy info using the provided feeds and fetches. will update device_copy_checks in the background | ||
|
|
@@ -728,13 +732,15 @@ common::Status ExecuteGraph(const SessionState& session_state, | |
| execution_mode, terminate_flag, logger, | ||
| device_stream_collection, | ||
| only_execute_path_to_fetches, | ||
| parent_stream); | ||
| parent_stream, | ||
| run_profiler); | ||
| return retval; | ||
| #else | ||
| return ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, {}, | ||
| execution_mode, terminate_flag, logger, | ||
| only_execute_path_to_fetches, | ||
| parent_stream); | ||
| parent_stream, | ||
| run_profiler); | ||
| #endif | ||
| } | ||
|
|
||
|
|
@@ -745,17 +751,16 @@ common::Status ExecuteGraph(const SessionState& session_state, | |
| #ifdef ORT_ENABLE_STREAM | ||
| DeviceStreamCollectionHolder& device_stream_collection_holder, | ||
| #endif | ||
| const logging::Logger& logger) { | ||
| return ExecuteGraph(session_state, | ||
| feeds_fetches_manager, | ||
| feeds, fetches, | ||
| execution_mode, | ||
| run_options.terminate, | ||
| logger, | ||
| const logging::Logger& logger, | ||
| profiling::Profiler* run_profiler) { | ||
| return ExecuteGraph(session_state, feeds_fetches_manager, feeds, fetches, | ||
| execution_mode, run_options.terminate, logger, | ||
| #ifdef ORT_ENABLE_STREAM | ||
| device_stream_collection_holder, | ||
| #endif | ||
| run_options.only_execute_path_to_fetches); | ||
| run_options.only_execute_path_to_fetches, | ||
| nullptr, | ||
| run_profiler); | ||
xiaofeihan1 marked this conversation as resolved.
Show resolved
Hide resolved
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We have a number of things being passed from RunOptions here. Can we modify the signature so that a reference to RunOptions is passed instead? Then we could instantiate the profiler higher in the stack, inside ExecuteGraph. I can see that RunOptions is already passed in one of the overloads, which seems sensible. |
||
| } | ||
|
|
||
| #ifdef ENABLE_TRAINING | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need the C and C++ APIs first; the Python API comes after that.
We also need tests.