diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h
index 923174cbfe488..e55f8e969efab 100644
--- a/include/onnxruntime/core/framework/execution_provider.h
+++ b/include/onnxruntime/core/framework/execution_provider.h
@@ -360,6 +360,10 @@ class IExecutionProvider {
     return {};
   }

+  virtual std::unique_ptr<profiling::EpProfiler> GetRunProfiler() {
+    return {};
+  }
+
   virtual DataLayout GetPreferredLayout() const {
     // EPs which prefer a different layout should override to return their preferred layout.
     return DataLayout::Default;
diff --git a/include/onnxruntime/core/framework/run_options.h b/include/onnxruntime/core/framework/run_options.h
index e63ab044834f5..21a847dcabd13 100644
--- a/include/onnxruntime/core/framework/run_options.h
+++ b/include/onnxruntime/core/framework/run_options.h
@@ -35,6 +35,14 @@ struct OrtRunOptions {
   // So it is possible that only some of the nodes are executed.
   bool only_execute_path_to_fetches = false;

+  // Set to 'true' to enable profiling for this run.
+  bool enable_profiling = false;
+
+  // File prefix for profiling result for this run.
+  // The actual filename will be: <profile_file_prefix>_<timestamp>.json
+  // Only used when enable_profiling is true.
+  std::string profile_file_prefix = "onnxruntime_run_profile";
+
 #ifdef ENABLE_TRAINING
   // Used by onnxruntime::training::TrainingSession. This class is now deprecated.
   // Delete training_mode when TrainingSession is deleted.
diff --git a/onnxruntime/core/common/profiler.cc b/onnxruntime/core/common/profiler.cc
index 8562e5524af74..731c2eb0d78c1 100644
--- a/onnxruntime/core/common/profiler.cc
+++ b/onnxruntime/core/common/profiler.cc
@@ -20,8 +20,11 @@ profiling::Profiler::~Profiler() {}
 #endif

 ::onnxruntime::TimePoint profiling::Profiler::Start() {
+  return Start(std::chrono::high_resolution_clock::now());
+}
+
+::onnxruntime::TimePoint profiling::Profiler::Start(const TimePoint& start_time) {
   ORT_ENFORCE(enabled_);
-  auto start_time = std::chrono::high_resolution_clock::now();
   auto ts = TimeDiffMicroSeconds(profiling_start_time_, start_time);
   for (const auto& ep_profiler : ep_profilers_) {
     ep_profiler->Start(ts);
@@ -75,8 +78,17 @@ void Profiler::EndTimeAndRecordEvent(EventCategory category,
                                      const std::string& event_name,
                                      const TimePoint& start_time,
                                      const std::initializer_list<std::pair<std::string, std::string>>& event_args,
+                                     bool sync_gpu) {
+  EndTimeAndRecordEvent(category, event_name, start_time, std::chrono::high_resolution_clock::now(), event_args, sync_gpu);
+}
+
+void Profiler::EndTimeAndRecordEvent(EventCategory category,
+                                     const std::string& event_name,
+                                     const TimePoint& start_time,
+                                     const TimePoint& end_time,
+                                     const std::initializer_list<std::pair<std::string, std::string>>& event_args,
                                      bool /*sync_gpu*/) {
-  long long dur = TimeDiffMicroSeconds(start_time);
+  long long dur = TimeDiffMicroSeconds(start_time, end_time);
   long long ts = TimeDiffMicroSeconds(profiling_start_time_, start_time);

   EventRecord event(category, logging::GetProcessId(),
diff --git a/onnxruntime/core/common/profiler.h b/onnxruntime/core/common/profiler.h
index 0103d8abb151f..02f3c738d31d3 100644
--- a/onnxruntime/core/common/profiler.h
+++ b/onnxruntime/core/common/profiler.h
@@ -54,6 +54,11 @@ class Profiler {
   */
   TimePoint Start();

+  /*
+  Start profiling with a specific start time.
+  */
+  TimePoint Start(const TimePoint& start_time);
+
   /*
   Whether data collection and output from this profiler is enabled.
*/ @@ -80,6 +85,13 @@ class Profiler { const std::initializer_list>& event_args = {}, bool sync_gpu = false); + void EndTimeAndRecordEvent(EventCategory category, + const std::string& event_name, + const TimePoint& start_time, + const TimePoint& end_time, + const std::initializer_list>& event_args = {}, + bool sync_gpu = false); + /* Write profile data to the given stream in chrome format defined below. https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview# diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc index 7180d976c1d3c..84a2b051e1a11 100644 --- a/onnxruntime/core/framework/sequential_executor.cc +++ b/onnxruntime/core/framework/sequential_executor.cc @@ -155,8 +155,8 @@ std::string ComposeSeriesName(const GraphViewer& graph_viewer) { class SessionScope { public: friend class KernelScope; - SessionScope(const SessionState& session_state, const ExecutionFrame& frame) - : session_state_(session_state) + SessionScope(const SessionState& session_state, const ExecutionFrame& frame, profiling::Profiler* run_profiler = nullptr) + : session_state_(session_state), run_profiler_(run_profiler) #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) , frame_(frame) @@ -173,12 +173,18 @@ class SessionScope { #endif #ifdef DEBUG_NODE_INPUTS_OUTPUTS , - dump_context_{ - session_state_.GetGraphExecutionCounter(), 0} + dump_context_{session_state_.GetGraphExecutionCounter(), 0} #endif { - if (session_state_.Profiler().IsEnabled()) { - session_start_ = session_state.Profiler().Start(); + bool session_profiling_enabled = session_state_.Profiler().IsEnabled(); + bool run_profiling_enabled = run_profiler_ && run_profiler_->IsEnabled(); + + session_start_ = std::chrono::high_resolution_clock::now(); + if (session_profiling_enabled) { + session_state_.Profiler().Start(session_start_); + } + if (run_profiling_enabled) { + run_profiler_->Start(session_start_); } auto& logger = 
session_state_.Logger(); @@ -225,9 +231,17 @@ class SessionScope { } #endif - if (session_state_.Profiler().IsEnabled()) { - session_state_.Profiler().EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", session_start_); + bool session_profiling_enabled = session_state_.Profiler().IsEnabled(); + bool run_profiling_enabled = run_profiler_ && run_profiler_->IsEnabled(); + + auto now = std::chrono::high_resolution_clock::now(); + if (session_profiling_enabled) { + session_state_.Profiler().EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", session_start_, now); + } + if (run_profiling_enabled) { + run_profiler_->EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", session_start_, now); } + #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) auto& logger = session_state_.Logger(); for (auto i : frame_.GetStaticMemorySizeInfo()) { @@ -254,6 +268,7 @@ class SessionScope { private: const SessionState& session_state_; + profiling::Profiler* run_profiler_; TimePoint session_start_; #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) const ExecutionFrame& frame_; @@ -340,16 +355,23 @@ class KernelScope { utils::DumpNodeInputs(dump_context_, kernel_context_, kernel_.Node(), session_state_, session_scope_.dump_analysis_); #endif -#ifdef ENABLE_NVTX_PROFILE - node_compute_range_.Begin(); -#endif + bool session_profiling_enabled = session_state_.Profiler().IsEnabled(); + bool run_profiling_enabled = session_scope_.run_profiler_ && session_scope_.run_profiler_->IsEnabled(); - if (session_state_.Profiler().IsEnabled()) { + if (session_profiling_enabled || run_profiling_enabled) { auto& node = kernel.Node(); node_name_ = node.Name().empty() ? 
MakeString(node.OpType(), "_", node.Index()) : node.Name(); concurrency::ThreadPool::StartProfiling(session_state_.GetThreadPool()); VLOGS(session_state_.Logger(), 1) << "Computing kernel: " << node_name_; - kernel_begin_time_ = session_state_.Profiler().Start(); + + kernel_begin_time_ = std::chrono::high_resolution_clock::now(); + if (session_profiling_enabled) { + session_state_.Profiler().Start(kernel_begin_time_); + } + if (run_profiling_enabled) { + session_scope_.run_profiler_->Start(kernel_begin_time_); + } + CalculateTotalInputSizes(&kernel_context, &kernel_, input_activation_sizes_, input_parameter_sizes_, node_name_, input_type_shape_); @@ -363,26 +385,42 @@ class KernelScope { node_compute_range_.End(); #endif - if (session_state_.Profiler().IsEnabled()) { - auto& profiler = session_state_.Profiler(); + bool session_profiling_enabled = session_state_.Profiler().IsEnabled(); + bool run_profiling_enabled = session_scope_.run_profiler_ && session_scope_.run_profiler_->IsEnabled(); + + if (session_profiling_enabled || run_profiling_enabled) { std::string output_type_shape_; CalculateTotalOutputSizes(&kernel_context_, total_output_sizes_, node_name_, output_type_shape_); - profiler.EndTimeAndRecordEvent(profiling::NODE_EVENT, - node_name_ + "_kernel_time", - kernel_begin_time_, - // Log additional operation args / info. 
- { - {"op_name", kernel_.KernelDef().OpName()}, - {"provider", kernel_.KernelDef().Provider()}, - {"node_index", std::to_string(kernel_.Node().Index())}, - {"activation_size", std::to_string(input_activation_sizes_)}, - {"parameter_size", std::to_string(input_parameter_sizes_)}, - {"output_size", std::to_string(total_output_sizes_)}, - {"input_type_shape", input_type_shape_}, - {"output_type_shape", output_type_shape_}, - {"thread_scheduling_stats", - concurrency::ThreadPool::StopProfiling(session_state_.GetThreadPool())}, - }); + + std::initializer_list> event_args = { + {"op_name", kernel_.KernelDef().OpName()}, + {"provider", kernel_.KernelDef().Provider()}, + {"node_index", std::to_string(kernel_.Node().Index())}, + {"activation_size", std::to_string(input_activation_sizes_)}, + {"parameter_size", std::to_string(input_parameter_sizes_)}, + {"output_size", std::to_string(total_output_sizes_)}, + {"input_type_shape", input_type_shape_}, + {"output_type_shape", output_type_shape_}, + {"thread_scheduling_stats", + concurrency::ThreadPool::StopProfiling(session_state_.GetThreadPool())}, + }; + + auto now = std::chrono::high_resolution_clock::now(); + if (session_profiling_enabled) { + session_state_.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT, + node_name_ + "_kernel_time", + kernel_begin_time_, + now, + event_args); + } + + if (run_profiling_enabled) { + session_scope_.run_profiler_->EndTimeAndRecordEvent(profiling::NODE_EVENT, + node_name_ + "_kernel_time", + kernel_begin_time_, + now, + event_args); + } } #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT @@ -405,7 +443,6 @@ class KernelScope { } //~KernelScope private: - TimePoint kernel_begin_time_; SessionScope& session_scope_; const SessionState& session_state_; std::string node_name_; @@ -417,6 +454,8 @@ class KernelScope { size_t total_output_sizes_{}; std::string input_type_shape_; + TimePoint kernel_begin_time_; + #ifdef CONCURRENCY_VISUALIZER diagnostic::span span_; #endif @@ -588,7 +627,8 @@ 
onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span< #endif const bool& terminate_flag, const bool only_execute_path_to_fetches, - bool single_thread_mode) { + bool single_thread_mode, + profiling::Profiler* run_profiler) { auto* execution_plan = session_state.GetExecutionPlan(); VLOGS(logger, 0) << "Number of streams: " << execution_plan->execution_plan.size(); int32_t valid_streams = 0; @@ -631,7 +671,7 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span< ORT_UNUSED_PARAMETER(only_execute_path_to_fetches); #endif - SessionScope session_scope(session_state, ctx.GetExecutionFrame()); + SessionScope session_scope(session_state, ctx.GetExecutionFrame(), run_profiler); auto* tp = single_thread_mode ? nullptr : session_state.GetInterOpThreadPool(); diff --git a/onnxruntime/core/framework/sequential_executor.h b/onnxruntime/core/framework/sequential_executor.h index c22ccb041afdf..f42ccc24dabd0 100644 --- a/onnxruntime/core/framework/sequential_executor.h +++ b/onnxruntime/core/framework/sequential_executor.h @@ -46,7 +46,8 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span< #endif const bool& terminate_flag, const bool only_execute_path_to_fetches, - bool single_thread_mode); + bool single_thread_mode, + profiling::Profiler* run_profiler = nullptr); #ifdef ENABLE_TRAINING onnxruntime::Status PartialExecuteThePlan(const SessionState& session_state, gsl::span feed_mlvalue_idxs, diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index ca64c7c7cae89..d28157e16feb4 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -599,7 +599,8 @@ ExecuteGraphImpl(const SessionState& session_state, DeviceStreamCollection* device_stream_collection, #endif const bool only_execute_path_to_fetches = false, - Stream* parent_stream = nullptr) { + Stream* parent_stream = nullptr, + profiling::Profiler* run_profiler = nullptr) { 
const auto& feeds_fetches_info = feeds_fetches_manager.GetFeedsFetchesInfo(); const auto& device_copy_checks = feeds_fetches_manager.GetDeviceCopyChecks(); #ifdef ORT_ENABLE_STREAM @@ -631,7 +632,8 @@ ExecuteGraphImpl(const SessionState& session_state, terminate_flag, only_execute_path_to_fetches, // single thread mode - single_thread_mode)); + single_thread_mode, + run_profiler)); ORT_RETURN_IF_ERROR(status); } else { auto feeds_to_use = feeds; @@ -679,7 +681,8 @@ ExecuteGraphImpl(const SessionState& session_state, #endif terminate_flag, only_execute_path_to_fetches, - single_thread_mode)); + single_thread_mode, + run_profiler)); ORT_RETURN_IF_ERROR(status); InlinedVector fetches_streams; fetches_streams.reserve(feeds_fetches_info.fetches_mlvalue_idxs.size()); @@ -717,7 +720,8 @@ common::Status ExecuteGraph(const SessionState& session_state, DeviceStreamCollectionHolder& device_stream_collection_holder, #endif bool only_execute_path_to_fetches, - Stream* parent_stream) { + Stream* parent_stream, + profiling::Profiler* run_profiler) { ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(session_state, feeds_fetches_manager)); // finalize the copy info using the provided feeds and fetches. 
will update device_copy_checks in the background @@ -728,13 +732,15 @@ common::Status ExecuteGraph(const SessionState& session_state, execution_mode, terminate_flag, logger, device_stream_collection, only_execute_path_to_fetches, - parent_stream); + parent_stream, + run_profiler); return retval; #else return ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, {}, execution_mode, terminate_flag, logger, only_execute_path_to_fetches, - parent_stream); + parent_stream, + run_profiler); #endif } @@ -745,17 +751,16 @@ common::Status ExecuteGraph(const SessionState& session_state, #ifdef ORT_ENABLE_STREAM DeviceStreamCollectionHolder& device_stream_collection_holder, #endif - const logging::Logger& logger) { - return ExecuteGraph(session_state, - feeds_fetches_manager, - feeds, fetches, - execution_mode, - run_options.terminate, - logger, + const logging::Logger& logger, + profiling::Profiler* run_profiler) { + return ExecuteGraph(session_state, feeds_fetches_manager, feeds, fetches, + execution_mode, run_options.terminate, logger, #ifdef ORT_ENABLE_STREAM device_stream_collection_holder, #endif - run_options.only_execute_path_to_fetches); + run_options.only_execute_path_to_fetches, + nullptr, + run_profiler); } #ifdef ENABLE_TRAINING diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index 4b4c483ba1202..32ac09c9bc66d 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -83,7 +83,8 @@ common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManag DeviceStreamCollectionHolder& device_stream_collection_holder, #endif bool only_execute_path_to_fetches = false, - Stream* parent_stream = nullptr); + Stream* parent_stream = nullptr, + profiling::Profiler* run_profiler = nullptr); common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager, gsl::span feeds, std::vector& fetches, @@ -91,7 +92,8 @@ common::Status 
ExecuteGraph(const SessionState& session_state, FeedsFetchesManag #ifdef ORT_ENABLE_STREAM DeviceStreamCollectionHolder& device_stream_collection_holder, #endif - const logging::Logger& logger); + const logging::Logger& logger, + profiling::Profiler* run_profiler = nullptr); #ifdef ENABLE_TRAINING common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager, diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 2f50fd8051b9c..f48f4124bac61 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -27,6 +27,7 @@ #include "core/providers/webgpu/compute_context.h" #include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/webgpu_profiler.h" #include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/webgpu_execution_provider.h" #include "core/providers/webgpu/program.h" @@ -611,7 +612,7 @@ void WebGpuContext::StartProfiling() { } } -void WebGpuContext::CollectProfilingData(profiling::Events& events) { +void WebGpuContext::CollectProfilingData(const std::vector& profilers) { if (!pending_queries_.empty()) { for (const auto& pending_query : pending_queries_) { const auto& pending_kernels = pending_query.kernels; @@ -660,7 +661,16 @@ void WebGpuContext::CollectProfilingData(profiling::Events& events) { static_cast(std::round(start_time / 1000.0)), static_cast(std::round((end_time - start_time) / 1000.0)), event_args); - events.emplace_back(std::move(event)); + + // Distribute the event to all WebGPU EP profilers. + // To minimize copies, we copy the event to all but the last profiler, + // and move it to the last one. + // Typically, there is only one WebGPU EP profiler. When both session-level and run-level + // profiling are enabled, there are two profilers. 
+ for (size_t p = 0; p < profilers.size() - 1; ++p) { + profilers[p]->Events().emplace_back(event); + } + profilers.back()->Events().emplace_back(std::move(event)); } query_read_buffer.Unmap(); @@ -674,12 +684,21 @@ void WebGpuContext::CollectProfilingData(profiling::Events& events) { } void WebGpuContext::EndProfiling(TimePoint /* tp */, profiling::Events& events, profiling::Events& cached_events) { + // Note: + // With run-level profiling enabled for the WebGPU EP, EndProfiling may be called + // while another thread is executing an inference (e.g., t1 finishes and + // ends profiling while t2 is running). + // + // This concurrent scenario is expected and does not affect the correctness of + // profiling data, but it means we can no longer enforce the assumption that no + // active inference is ongoing at this point. + // This function is called when no active inference is ongoing. - ORT_ENFORCE(!is_profiling_, "Profiling is ongoing in an inference run."); + // ORT_ENFORCE(!is_profiling_, "Profiling is ongoing in an inference run."); if (query_type_ != TimestampQueryType::None) { // No pending kernels or queries should be present at this point. They should have been collected in CollectProfilingData. 
- ORT_ENFORCE(pending_kernels_.empty() && pending_queries_.empty(), "Pending kernels or queries are not empty."); + // ORT_ENFORCE(pending_kernels_.empty() && pending_queries_.empty(), "Pending kernels or queries are not empty."); events.insert(events.end(), std::make_move_iterator(cached_events.begin()), diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 8cc513680142d..b6baf0b3cdd93 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -24,6 +24,7 @@ namespace webgpu { class WebGpuContext; class ComputeContextBase; class ProgramBase; +class WebGpuProfiler; // Definition for CapturedCommandInfo in the webgpu namespace struct CapturedCommandInfo { @@ -196,7 +197,7 @@ class WebGpuContext final { } void StartProfiling(); - void CollectProfilingData(profiling::Events& events); + void CollectProfilingData(const std::vector& profilers); void EndProfiling(TimePoint, profiling::Events& events, profiling::Events& cached_events); // @@ -339,6 +340,10 @@ class WebGpuContext final { // External vector to store captured commands, owned by EP std::vector* external_captured_commands_ = nullptr; + +#if defined(ENABLE_PIX_FOR_WEBGPU_EP) + std::unique_ptr pix_frame_generator_ = nullptr; +#endif // ENABLE_PIX_FOR_WEBGPU_EP }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index ee6b7707384e2..9fdc9b724a4db 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -979,7 +979,13 @@ WebGpuExecutionProvider::~WebGpuExecutionProvider() { std::unique_ptr WebGpuExecutionProvider::GetProfiler() { auto profiler = std::make_unique(context_); - profiler_ = profiler.get(); + session_profiler_ = profiler.get(); + return profiler; +} + +std::unique_ptr 
WebGpuExecutionProvider::GetRunProfiler() { + auto profiler = std::make_unique(context_); + tls_run_profiler_ = profiler.get(); return profiler; } @@ -988,7 +994,8 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_op context_.PushErrorScope(); } - if (profiler_->Enabled()) { + // Session-level profiling handling if needed + if (run_options.enable_profiling || (session_profiler_ && session_profiler_->Enabled())) { context_.StartProfiling(); } @@ -1010,7 +1017,7 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_op return Status::OK(); } -Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /* run_options */) { +Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& run_options) { context_.Flush(BufferManager()); if (IsGraphCaptureEnabled() && !IsGraphCaptured(m_current_graph_annotation_id)) { @@ -1023,8 +1030,20 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti } } - if (profiler_->Enabled()) { - context_.CollectProfilingData(profiler_->Events()); + std::vector profilers; + if (session_profiler_ && session_profiler_->Enabled()) { + profilers.push_back(session_profiler_); + } + + if (run_options.enable_profiling && tls_run_profiler_) { + if (tls_run_profiler_->Enabled()) { + profilers.push_back(tls_run_profiler_); + } + tls_run_profiler_ = nullptr; + } + + if (!profilers.empty()) { + context_.CollectProfilingData(profilers); } #if defined(ENABLE_PIX_FOR_WEBGPU_EP) diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index bf0963f67cf1e..a44f0ef623cee 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -75,6 +75,7 @@ class WebGpuExecutionProvider : public IExecutionProvider { int GetDeviceId() const override { 
return context_id_; } std::unique_ptr GetProfiler() override; + std::unique_ptr GetRunProfiler() override; bool IsGraphCaptureEnabled() const override; bool IsGraphCaptured(int graph_annotation_id) const override; @@ -88,7 +89,10 @@ class WebGpuExecutionProvider : public IExecutionProvider { int context_id_; webgpu::WebGpuContext& context_; - webgpu::WebGpuProfiler* profiler_ = nullptr; + + webgpu::WebGpuProfiler* session_profiler_{nullptr}; + inline static thread_local webgpu::WebGpuProfiler* tls_run_profiler_{nullptr}; + DataLayout preferred_data_layout_; std::vector force_cpu_node_names_; bool enable_graph_capture_ = false; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e5523dc78b5d2..fbab8c3aabe2e 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -9,7 +9,9 @@ #include #include #include +#include #include +#include #include "core/common/denormal.h" #include "core/common/logging/isink.h" @@ -118,6 +120,8 @@ template inline std::basic_string GetCurrentTimeString() { auto now = std::chrono::system_clock::now(); auto in_time_t = std::chrono::system_clock::to_time_t(now); + auto ms = std::chrono::duration_cast(now.time_since_epoch()) % 1000; + std::tm local_tm; // NOLINT #ifdef _WIN32 @@ -128,7 +132,10 @@ inline std::basic_string GetCurrentTimeString() { T time_str[32]; OrtStrftime(time_str, sizeof(time_str), GetDateFormatString(), &local_tm); - return std::basic_string(time_str); + + std::basic_stringstream ss; + ss << time_str << T('_') << std::setfill(T('0')) << std::setw(3) << ms.count(); + return ss.str(); } #if !defined(ORT_MINIMAL_BUILD) @@ -868,6 +875,8 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr } p_exec_provider->SetLogger(session_logger_); + // For session-level profiling, we pass a default RunOptions (or nullptr logic if updated further, currently passing default) + // Since session-level 
profiler init shouldn't attach to a run-specific pointer, implementations should handle empty options gracefully. session_profiler_.AddEpProfilers(p_exec_provider->GetProfiler()); return execution_providers_.Add(provider_type, p_exec_provider); } @@ -2925,9 +2934,29 @@ Status InferenceSession::Run(const RunOptions& run_options, gsl::span feed_names, gsl::span feeds, gsl::span output_names, std::vector* p_fetches, const std::vector* p_fetches_device_info) { + std::optional run_profiler; + if (run_options.enable_profiling) { + run_profiler.emplace(); + run_profiler->Initialize(session_logger_); + std::basic_ostringstream oss; + oss << ToPathString(run_options.profile_file_prefix) << "_" << GetCurrentTimeString() << ".json"; + run_profiler->StartProfiling(oss.str()); + + for (auto& ep : execution_providers_) { + auto p = ep->GetRunProfiler(); + if (p) { + run_profiler->AddEpProfilers(std::move(p)); + } + } + } + TimePoint tp = std::chrono::high_resolution_clock::now(); + if (session_profiler_.IsEnabled()) { - tp = session_profiler_.Start(); + session_profiler_.Start(tp); + } + if (run_options.enable_profiling) { + run_profiler->Start(tp); } #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT @@ -3059,7 +3088,8 @@ Status InferenceSession::Run(const RunOptions& run_options, #ifdef ORT_ENABLE_STREAM device_stream_collection_holder, #endif - run_logger); + run_logger, + run_options.enable_profiling ? 
&*run_profiler : nullptr); } // info all execution providers InferenceSession:Run ended @@ -3135,8 +3165,14 @@ Status InferenceSession::Run(const RunOptions& run_options, env.GetTelemetryProvider().LogEvaluationStop(session_id_); // send out profiling events (optional) - if (session_profiler_.IsEnabled()) { - session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_run", tp); + if (session_profiler_.IsEnabled() || run_options.enable_profiling) { + auto now = std::chrono::high_resolution_clock::now(); + if (session_profiler_.IsEnabled()) { + session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_run", tp, now); + } + if (run_options.enable_profiling) { + run_profiler->EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_run", tp, now); + } } #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT TraceLoggingWriteStop(ortrun_activity, "OrtRun"); @@ -3161,7 +3197,14 @@ Status InferenceSession::Run(const RunOptions& run_options, cached_execution_provider_for_graph_replay_.AllowGraphCaptureOnRun(graph_annotation_id) && !cached_execution_provider_for_graph_replay_.IsGraphCaptured(graph_annotation_id)) { LOGS(*session_logger_, INFO) << "Start another run for necessary memory allocation or graph capture."; - ORT_RETURN_IF_ERROR(Run(run_options, feed_names, feeds, output_names, p_fetches, p_fetches_device_info)); + // Disable run-level profiling for internal runs used for memory allocation or graph capture + RunOptions recursive_run_options{run_options}; + recursive_run_options.enable_profiling = false; + ORT_RETURN_IF_ERROR(Run(recursive_run_options, feed_names, feeds, output_names, p_fetches, p_fetches_device_info)); + } + + if (run_options.enable_profiling) { + run_profiler->EndProfiling(); } // Log runtime error telemetry if the return value is not OK diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f0d8906d99c14..957fbd4191dfc 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc 
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -2361,6 +2361,10 @@ RunOptions instance. The individual calls will exit gracefully and return an err
 #endif
       .def_readwrite("only_execute_path_to_fetches", &RunOptions::only_execute_path_to_fetches,
                      R"pbdoc(Only execute the nodes needed by fetch list)pbdoc")
+      .def_readwrite("enable_profiling", &RunOptions::enable_profiling,
+                     R"pbdoc(Enable profiling for this run.)pbdoc")
+      .def_readwrite("profile_file_prefix", &RunOptions::profile_file_prefix,
+                     R"pbdoc(File prefix for profiling result. The actual filename will be: <profile_file_prefix>_<timestamp>.json)pbdoc")
       .def(
           "add_run_config_entry",
           [](RunOptions* options, const char* config_key, const char* config_value) -> void {