4 changes: 4 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
@@ -360,6 +360,10 @@ class IExecutionProvider {
return {};
}

virtual std::unique_ptr<profiling::EpProfiler> GetRunProfiler() {
return {};
}

virtual DataLayout GetPreferredLayout() const {
// EPs which prefer a different layout should override to return their preferred layout.
return DataLayout::Default;
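A minimal sketch of how an execution provider might override the new hook — `MyExecutionProvider` and `MyEpProfiler` are hypothetical names; no EP in this diff actually implements it:

```cpp
// Hypothetical override (illustration only).
class MyEpProfiler : public profiling::EpProfiler {
  // ... implement the EpProfiler interface ...
};

class MyExecutionProvider : public IExecutionProvider {
 public:
  std::unique_ptr<profiling::EpProfiler> GetRunProfiler() override {
    return std::make_unique<MyEpProfiler>();
  }
};
```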
8 changes: 8 additions & 0 deletions include/onnxruntime/core/framework/run_options.h
@@ -35,6 +35,14 @@ struct OrtRunOptions {
// So it is possible that only some of the nodes are executed.
bool only_execute_path_to_fetches = false;

// Set to 'true' to enable profiling for this run.
bool enable_profiling = false;

// File prefix for profiling result for this run.
// The actual filename will be: <profile_file_prefix>_<timestamp>.json
// Only used when enable_profiling is true.
std::string profile_file_prefix = "onnxruntime_run_profile";

Review comment (Member):
We need the C and C++ API, and Python comes after that.
And we need tests.

#ifdef ENABLE_TRAINING
// Used by onnxruntime::training::TrainingSession. This class is now deprecated.
// Delete training_mode when TrainingSession is deleted.
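A short usage sketch of the new fields. This touches only the internal OrtRunOptions struct shown above; the public C/C++ API plumbing requested in the comment is not part of this hunk:

```cpp
OrtRunOptions run_options;
run_options.enable_profiling = true;          // turn on run-level profiling for this Run only
run_options.profile_file_prefix = "my_run";   // output file: my_run_<timestamp>.json
// ... pass run_options to the Run call as usual ...
```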
16 changes: 14 additions & 2 deletions onnxruntime/core/common/profiler.cc
@@ -20,8 +20,11 @@
#endif

::onnxruntime::TimePoint profiling::Profiler::Start() {
return Start(std::chrono::high_resolution_clock::now());
}

::onnxruntime::TimePoint profiling::Profiler::Start(const TimePoint& start_time) {
ORT_ENFORCE(enabled_);
auto start_time = std::chrono::high_resolution_clock::now();
auto ts = TimeDiffMicroSeconds(profiling_start_time_, start_time);
for (const auto& ep_profiler : ep_profilers_) {
ep_profiler->Start(ts);
@@ -75,8 +78,17 @@
const std::string& event_name,
const TimePoint& start_time,
const std::initializer_list<std::pair<std::string, std::string>>& event_args,
bool sync_gpu) {
EndTimeAndRecordEvent(category, event_name, start_time, std::chrono::high_resolution_clock::now(), event_args, sync_gpu);
}

void Profiler::EndTimeAndRecordEvent(EventCategory category,
const std::string& event_name,
const TimePoint& start_time,
const TimePoint& end_time,
const std::initializer_list<std::pair<std::string, std::string>>& event_args,
bool /*sync_gpu*/) {
long long dur = TimeDiffMicroSeconds(start_time);
long long dur = TimeDiffMicroSeconds(start_time, end_time);

Lint annotation (GitHub Actions / Optional Lint C++, cpplint via reviewdog) on onnxruntime/core/common/profiler.cc:91:
Use int16_t/int64_t/etc, rather than the C type long [runtime/int] [4]
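A minimal fix along the lines the linter suggests (not applied in this diff) would be:

```cpp
int64_t dur = TimeDiffMicroSeconds(start_time, end_time);  // fixed-width type instead of long long
```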
long long ts = TimeDiffMicroSeconds(profiling_start_time_, start_time);

EventRecord event(category, logging::GetProcessId(),
12 changes: 12 additions & 0 deletions onnxruntime/core/common/profiler.h
@@ -54,6 +54,11 @@ class Profiler {
*/
TimePoint Start();

/*
Start profiling with a specific start time.
*/
TimePoint Start(const TimePoint& start_time);

/*
Whether data collection and output from this profiler is enabled.
*/
@@ -80,6 +85,13 @@
const std::initializer_list<std::pair<std::string, std::string>>& event_args = {},
bool sync_gpu = false);

void EndTimeAndRecordEvent(EventCategory category,
const std::string& event_name,
const TimePoint& start_time,
const TimePoint& end_time,
const std::initializer_list<std::pair<std::string, std::string>>& event_args = {},
bool sync_gpu = false);

/*
Write profile data to the given stream in chrome format defined below.
https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview#
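A short sketch of what the new overloads enable, with usage assumed from the signatures above: two profilers can record the same interval from a single pair of timestamps instead of each calling the clock themselves.

```cpp
// Both profilers see identical begin/end timestamps for the event.
auto begin = std::chrono::high_resolution_clock::now();
session_profiler.Start(begin);
run_profiler.Start(begin);
// ... execute the kernel ...
auto end = std::chrono::high_resolution_clock::now();
session_profiler.EndTimeAndRecordEvent(profiling::NODE_EVENT, "MyNode_kernel_time", begin, end);
run_profiler.EndTimeAndRecordEvent(profiling::NODE_EVENT, "MyNode_kernel_time", begin, end);
```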
108 changes: 74 additions & 34 deletions onnxruntime/core/framework/sequential_executor.cc
@@ -155,8 +155,8 @@ std::string ComposeSeriesName(const GraphViewer& graph_viewer) {
class SessionScope {
public:
friend class KernelScope;
SessionScope(const SessionState& session_state, const ExecutionFrame& frame)
: session_state_(session_state)
SessionScope(const SessionState& session_state, const ExecutionFrame& frame, profiling::Profiler* run_profiler = nullptr)
: session_state_(session_state), run_profiler_(run_profiler)
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
,
frame_(frame)
@@ -173,12 +173,18 @@ class SessionScope {
#endif
#ifdef DEBUG_NODE_INPUTS_OUTPUTS
,
dump_context_{
session_state_.GetGraphExecutionCounter(), 0}
dump_context_{session_state_.GetGraphExecutionCounter(), 0}
#endif
{
if (session_state_.Profiler().IsEnabled()) {
session_start_ = session_state.Profiler().Start();
bool session_profiling_enabled = session_state_.Profiler().IsEnabled();
Review comment (Member), on the `bool` declaration above: make it `const`.

Review comment (yuslepukhin, Member, Jan 9, 2026), on `session_state_.Profiler().IsEnabled()`:
I am still not convinced that we should allow both profilers to run in parallel.
Do you have a use case for that? What would be the purpose of collecting the same data?
If someone wants continuous profiling, would it not be the same thing as running it with RunOptions?

Reply (Contributor, PR author):
This depends on how we want to handle the case when both run-level and session-level profiling are enabled.
For example, when a user calls Session::Run with both run-level and session-level profiling enabled, there will be two active profilers: a local run_profiler and the session_profiler_ owned by InferenceSession. The current implementation guarantees that two JSON files are generated, and that the events recorded in the run-level profiling output are a strict subset of those in the session-level profiling output.
In this scenario, each operator execution generates two identical profiling events: one recorded by the session-level profiler and one recorded by the run-level profiler.
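A hedged sketch of that scenario from the caller's side. The session-level call is the existing SessionOptions profiling API; the run-level fields are the ones this PR adds to OrtRunOptions, and their public surface is still open per the comments above:

```cpp
Ort::SessionOptions session_options;
session_options.EnableProfiling(ORT_TSTR("session_profile"));  // session-level: session_profile_<timestamp>.json

OrtRunOptions run_options;                         // fields added in this PR (no public setter yet)
run_options.enable_profiling = true;
run_options.profile_file_prefix = "run_profile";   // run-level: run_profile_<timestamp>.json

// With both enabled, every kernel event is recorded twice: once in the session-level
// output and once in the run-level output (the latter a strict subset of the former).
```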

bool run_profiling_enabled = run_profiler_ && run_profiler_->IsEnabled();
Review comment (Member), on the `bool` declaration above: make it `const`.
session_start_ = std::chrono::high_resolution_clock::now();
if (session_profiling_enabled) {
session_state_.Profiler().Start(session_start_);
}
if (run_profiling_enabled) {
run_profiler_->Start(session_start_);
}

auto& logger = session_state_.Logger();
@@ -225,9 +231,17 @@
}
#endif

if (session_state_.Profiler().IsEnabled()) {
session_state_.Profiler().EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", session_start_);
bool session_profiling_enabled = session_state_.Profiler().IsEnabled();
bool run_profiling_enabled = run_profiler_ && run_profiler_->IsEnabled();

auto now = std::chrono::high_resolution_clock::now();
if (session_profiling_enabled) {
session_state_.Profiler().EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", session_start_, now);
}
if (run_profiling_enabled) {
run_profiler_->EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", session_start_, now);
}

#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
auto& logger = session_state_.Logger();
for (auto i : frame_.GetStaticMemorySizeInfo()) {
@@ -254,6 +268,7 @@

private:
const SessionState& session_state_;
profiling::Profiler* run_profiler_;
TimePoint session_start_;
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
const ExecutionFrame& frame_;
@@ -340,16 +355,23 @@ class KernelScope {
utils::DumpNodeInputs(dump_context_, kernel_context_, kernel_.Node(), session_state_, session_scope_.dump_analysis_);
#endif

#ifdef ENABLE_NVTX_PROFILE
node_compute_range_.Begin();
#endif
bool session_profiling_enabled = session_state_.Profiler().IsEnabled();
bool run_profiling_enabled = session_scope_.run_profiler_ && session_scope_.run_profiler_->IsEnabled();

if (session_state_.Profiler().IsEnabled()) {
if (session_profiling_enabled || run_profiling_enabled) {
auto& node = kernel.Node();
node_name_ = node.Name().empty() ? MakeString(node.OpType(), "_", node.Index()) : node.Name();
concurrency::ThreadPool::StartProfiling(session_state_.GetThreadPool());
VLOGS(session_state_.Logger(), 1) << "Computing kernel: " << node_name_;
kernel_begin_time_ = session_state_.Profiler().Start();

kernel_begin_time_ = std::chrono::high_resolution_clock::now();
if (session_profiling_enabled) {
session_state_.Profiler().Start(kernel_begin_time_);
}
if (run_profiling_enabled) {
session_scope_.run_profiler_->Start(kernel_begin_time_);
}

CalculateTotalInputSizes(&kernel_context, &kernel_,
input_activation_sizes_, input_parameter_sizes_,
node_name_, input_type_shape_);
Expand All @@ -363,26 +385,42 @@ class KernelScope {
node_compute_range_.End();
#endif

if (session_state_.Profiler().IsEnabled()) {
auto& profiler = session_state_.Profiler();
bool session_profiling_enabled = session_state_.Profiler().IsEnabled();
bool run_profiling_enabled = session_scope_.run_profiler_ && session_scope_.run_profiler_->IsEnabled();

if (session_profiling_enabled || run_profiling_enabled) {
std::string output_type_shape_;
CalculateTotalOutputSizes(&kernel_context_, total_output_sizes_, node_name_, output_type_shape_);
profiler.EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name_ + "_kernel_time",
kernel_begin_time_,
// Log additional operation args / info.
{
{"op_name", kernel_.KernelDef().OpName()},
{"provider", kernel_.KernelDef().Provider()},
{"node_index", std::to_string(kernel_.Node().Index())},
{"activation_size", std::to_string(input_activation_sizes_)},
{"parameter_size", std::to_string(input_parameter_sizes_)},
{"output_size", std::to_string(total_output_sizes_)},
{"input_type_shape", input_type_shape_},
{"output_type_shape", output_type_shape_},
{"thread_scheduling_stats",
concurrency::ThreadPool::StopProfiling(session_state_.GetThreadPool())},
});

std::initializer_list<std::pair<std::string, std::string>> event_args = {
Review comment (yuslepukhin, Member, Jan 9, 2026), on the `std::initializer_list` of event args:
Should this be constexpr? Should this be outside the destructor?

{"op_name", kernel_.KernelDef().OpName()},
{"provider", kernel_.KernelDef().Provider()},
{"node_index", std::to_string(kernel_.Node().Index())},
{"activation_size", std::to_string(input_activation_sizes_)},
{"parameter_size", std::to_string(input_parameter_sizes_)},
{"output_size", std::to_string(total_output_sizes_)},
{"input_type_shape", input_type_shape_},
{"output_type_shape", output_type_shape_},
{"thread_scheduling_stats",
concurrency::ThreadPool::StopProfiling(session_state_.GetThreadPool())},
};

auto now = std::chrono::high_resolution_clock::now();
if (session_profiling_enabled) {
session_state_.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name_ + "_kernel_time",
kernel_begin_time_,
now,
event_args);
}

if (run_profiling_enabled) {
session_scope_.run_profiler_->EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name_ + "_kernel_time",
kernel_begin_time_,
now,
event_args);
}
}

#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
@@ -405,7 +443,6 @@ class KernelScope {
} //~KernelScope

private:
TimePoint kernel_begin_time_;
SessionScope& session_scope_;
const SessionState& session_state_;
std::string node_name_;
@@ -417,6 +454,8 @@ class KernelScope {
size_t total_output_sizes_{};
std::string input_type_shape_;

TimePoint kernel_begin_time_;

#ifdef CONCURRENCY_VISUALIZER
diagnostic::span span_;
#endif
@@ -588,7 +627,8 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span<
#endif
const bool& terminate_flag,
const bool only_execute_path_to_fetches,
bool single_thread_mode) {
bool single_thread_mode,
profiling::Profiler* run_profiler) {
auto* execution_plan = session_state.GetExecutionPlan();
VLOGS(logger, 0) << "Number of streams: " << execution_plan->execution_plan.size();
int32_t valid_streams = 0;
@@ -631,7 +671,7 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span<
ORT_UNUSED_PARAMETER(only_execute_path_to_fetches);
#endif

SessionScope session_scope(session_state, ctx.GetExecutionFrame());
SessionScope session_scope(session_state, ctx.GetExecutionFrame(), run_profiler);

auto* tp = single_thread_mode ? nullptr : session_state.GetInterOpThreadPool();

3 changes: 2 additions & 1 deletion onnxruntime/core/framework/sequential_executor.h
@@ -46,7 +46,8 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span<
#endif
const bool& terminate_flag,
const bool only_execute_path_to_fetches,
bool single_thread_mode);
bool single_thread_mode,
profiling::Profiler* run_profiler = nullptr);

#ifdef ENABLE_TRAINING
onnxruntime::Status PartialExecuteThePlan(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
33 changes: 19 additions & 14 deletions onnxruntime/core/framework/utils.cc
@@ -599,7 +599,8 @@ ExecuteGraphImpl(const SessionState& session_state,
DeviceStreamCollection* device_stream_collection,
#endif
const bool only_execute_path_to_fetches = false,
Stream* parent_stream = nullptr) {
Stream* parent_stream = nullptr,
profiling::Profiler* run_profiler = nullptr) {
const auto& feeds_fetches_info = feeds_fetches_manager.GetFeedsFetchesInfo();
const auto& device_copy_checks = feeds_fetches_manager.GetDeviceCopyChecks();
#ifdef ORT_ENABLE_STREAM
@@ -631,7 +632,8 @@
terminate_flag,
only_execute_path_to_fetches,
// single thread mode
single_thread_mode));
single_thread_mode,
run_profiler));
ORT_RETURN_IF_ERROR(status);
} else {
auto feeds_to_use = feeds;
@@ -679,7 +681,8 @@
#endif
terminate_flag,
only_execute_path_to_fetches,
single_thread_mode));
single_thread_mode,
run_profiler));
ORT_RETURN_IF_ERROR(status);
InlinedVector<Stream*> fetches_streams;
fetches_streams.reserve(feeds_fetches_info.fetches_mlvalue_idxs.size());
@@ -717,7 +720,8 @@ common::Status ExecuteGraph(const SessionState& session_state,
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
bool only_execute_path_to_fetches,
Stream* parent_stream) {
Stream* parent_stream,
profiling::Profiler* run_profiler) {
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(session_state, feeds_fetches_manager));

// finalize the copy info using the provided feeds and fetches. will update device_copy_checks in the background
@@ -728,13 +732,15 @@ common::Status ExecuteGraph(const SessionState& session_state,
execution_mode, terminate_flag, logger,
device_stream_collection,
only_execute_path_to_fetches,
parent_stream);
parent_stream,
run_profiler);
return retval;
#else
return ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, {},
execution_mode, terminate_flag, logger,
only_execute_path_to_fetches,
parent_stream);
parent_stream,
run_profiler);
#endif
}

@@ -745,17 +751,16 @@ common::Status ExecuteGraph(const SessionState& session_state,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
const logging::Logger& logger) {
return ExecuteGraph(session_state,
feeds_fetches_manager,
feeds, fetches,
execution_mode,
run_options.terminate,
logger,
const logging::Logger& logger,
profiling::Profiler* run_profiler) {
return ExecuteGraph(session_state, feeds_fetches_manager, feeds, fetches,
execution_mode, run_options.terminate, logger,
#ifdef ORT_ENABLE_STREAM
device_stream_collection_holder,
#endif
run_options.only_execute_path_to_fetches);
run_options.only_execute_path_to_fetches,
nullptr,
run_profiler);
Review comment (Member):
We have a number of things being passed from RunOptions here. Can we modify the signature so that a reference to RunOptions is passed instead?
Then we can instantiate the profiler higher in the stack, inside ExecuteGraph.
I can see that RunOptions is already passed in one of the overloads; that seems sensible.
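A hypothetical shape of that suggestion (not part of this diff): take RunOptions by reference and construct the run-level profiler inside ExecuteGraph, instead of threading a raw profiling::Profiler* through every overload.

```cpp
common::Status ExecuteGraph(const SessionState& session_state,
                            FeedsFetchesManager& feeds_fetches_manager,
                            gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
                            ExecutionMode execution_mode, const RunOptions& run_options,
                            const logging::Logger& logger) {
  std::unique_ptr<profiling::Profiler> run_profiler;
  if (run_options.enable_profiling) {
    run_profiler = std::make_unique<profiling::Profiler>();
    // ... initialize it from run_options.profile_file_prefix ...
  }
  // ... forward run_profiler.get() to ExecuteGraphImpl as the other overload does ...
  return common::Status::OK();
}
```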

}

#ifdef ENABLE_TRAINING
6 changes: 4 additions & 2 deletions onnxruntime/core/framework/utils.h
@@ -83,15 +83,17 @@ common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManag
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
bool only_execute_path_to_fetches = false,
Stream* parent_stream = nullptr);
Stream* parent_stream = nullptr,
profiling::Profiler* run_profiler = nullptr);

common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
ExecutionMode execution_mode, const RunOptions& run_options,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
const logging::Logger& logger);
const logging::Logger& logger,
profiling::Profiler* run_profiler = nullptr);

#ifdef ENABLE_TRAINING
common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,