8 changes: 8 additions & 0 deletions include/onnxruntime/core/framework/run_options.h
@@ -35,6 +35,14 @@ struct OrtRunOptions {
// So it is possible that only some of the nodes are executed.
bool only_execute_path_to_fetches = false;

// Set to 'true' to enable profiling for this run.
bool enable_profiling = false;

// File prefix for profiling result for this run.
// The actual filename will be: <profile_file_prefix>_<timestamp>.json
// Only used when enable_profiling is true.
std::basic_string<ORTCHAR_T> profile_file_prefix = ORT_TSTR("onnxruntime_run_profile");

#ifdef ENABLE_TRAINING
// Used by onnxruntime::training::TrainingSession. This class is now deprecated.
// Delete training_mode when TrainingSession is deleted.
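The new fields mirror session-level profiling: per the comment above, the run profile lands in `<profile_file_prefix>_<timestamp>.json`. A minimal sketch of that naming shape, assuming a seconds-resolution timestamp (the helper below is illustrative only and not part of the ONNX Runtime sources):

```cpp
#include <chrono>
#include <string>

// Illustrative helper only: shows the "<prefix>_<timestamp>.json" shape the
// comment above describes; the real timestamp format may differ.
std::string MakeRunProfileFileName(const std::string& prefix) {
  const auto ts = std::chrono::duration_cast<std::chrono::seconds>(
                      std::chrono::system_clock::now().time_since_epoch())
                      .count();
  return prefix + "_" + std::to_string(ts) + ".json";
  // e.g. "onnxruntime_run_profile_1718000000.json"
}
```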
12 changes: 12 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -7013,6 +7013,18 @@ struct OrtApi {
* \since Version 1.24
*/
ORT_API2_STATUS(CreateEnvWithOptions, _In_ const OrtEnvCreationOptions* options, _Outptr_ OrtEnv** out);

/** \brief Enable profiling for this run
*
* \param[in] options The run options to enable profiling on
* \param[in] profile_file_prefix The prefix for the profile file. The actual filename will be:
* <profile_file_prefix>_<timestamp>.json
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24
*/
ORT_API2_STATUS(RunOptionsEnableProfiling, _Inout_ OrtRunOptions* options, _In_ const ORTCHAR_T* profile_file_prefix);
};

/*
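A hedged usage sketch for the new C-API entry point; the prefix string and the error handling are placeholders, and only calls already present in the C API (`CreateRunOptions`, `ReleaseRunOptions`, `ReleaseStatus`) are used alongside it:

```cpp
#include <onnxruntime_c_api.h>

// Sketch only: enable per-run profiling through the C API.
void EnableRunProfilingSketch() {
  const OrtApi* ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  OrtRunOptions* run_options = nullptr;
  OrtStatus* status = ort->CreateRunOptions(&run_options);
  if (status != nullptr) {
    ort->ReleaseStatus(status);  // handle/log the error as appropriate
    return;
  }

  // The profile for this run will be written to my_run_profile_<timestamp>.json.
  status = ort->RunOptionsEnableProfiling(run_options, ORT_TSTR("my_run_profile"));
  if (status != nullptr) {
    ort->ReleaseStatus(status);
  }

  // ... pass run_options to the Run() call, then release it ...
  ort->ReleaseRunOptions(run_options);
}
```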
7 changes: 7 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -1307,6 +1307,13 @@ struct RunOptions : detail::Base<OrtRunOptions> {
* \param adapter The LoraAdapter to be used as the active adapter
*/
RunOptions& AddActiveLoraAdapter(const LoraAdapter& adapter);

/** \brief Enable profiling for this run
*
* Wraps OrtApi::RunOptionsEnableProfiling
* \param profile_file_prefix The prefix for the profile file
*/
RunOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix);
};

namespace detail {
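For the C++ wrapper, a minimal usage sketch under the assumption that a session and its input/output names already exist (all names below are placeholders):

```cpp
#include <vector>
#include <onnxruntime_cxx_api.h>

// Sketch only: per-run profiling with the C++ API.
std::vector<Ort::Value> RunWithRunProfiling(Ort::Session& session,
                                            const char* const* input_names,
                                            const Ort::Value* inputs, size_t input_count,
                                            const char* const* output_names, size_t output_count) {
  Ort::RunOptions run_options;
  // Wraps OrtApi::RunOptionsEnableProfiling; the profile for this Run() call
  // is written to my_run_profile_<timestamp>.json.
  run_options.EnableProfiling(ORT_TSTR("my_run_profile"));
  return session.Run(run_options, input_names, inputs, input_count,
                     output_names, output_count);
}
```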
5 changes: 5 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -1003,6 +1003,11 @@ inline RunOptions& RunOptions::AddActiveLoraAdapter(const LoraAdapter& adapter)
return *this;
}

inline RunOptions& RunOptions::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
ThrowOnError(GetApi().RunOptionsEnableProfiling(p_, profile_file_prefix));
return *this;
}

inline ModelCompilationOptions::ModelCompilationOptions(const Env& env, const SessionOptions& session_options) {
ThrowOnError(GetCompileApi().CreateModelCompilationOptionsFromSessionOptions(env, session_options, &this->p_));
}
13 changes: 7 additions & 6 deletions onnxruntime/core/common/profiler.cc
@@ -71,16 +71,17 @@
template void Profiler::StartProfiling<wchar_t>(const std::basic_string<wchar_t>& file_name);
#endif

void Profiler::EndTimeAndRecordEvent(EventCategory category,
const std::string& event_name,
const TimePoint& start_time,
const std::initializer_list<std::pair<std::string, std::string>>& event_args,
bool /*sync_gpu*/) {
void Profiler::EndTimeAndRecordEvent(
EventCategory category,
const std::string& event_name,
const TimePoint& start_time,
std::unordered_map<std::string, std::string> event_args,

GitHub Actions / Optional Lint C++ (cpplint, via reviewdog) — onnxruntime/core/common/profiler.cc:78: Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]
bool /*sync_gpu*/) {
long long dur = TimeDiffMicroSeconds(start_time);
long long ts = TimeDiffMicroSeconds(profiling_start_time_, start_time);

EventRecord event(category, logging::GetProcessId(),
logging::GetThreadId(), event_name, ts, dur, {event_args.begin(), event_args.end()});
logging::GetThreadId(), event_name, ts, dur, std::move(event_args));
if (profile_with_logger_) {
custom_logger_->SendProfileEvent(event);
} else {
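With this change, callers hand event arguments over as an `std::unordered_map` that is moved into the `EventRecord` rather than copied from an initializer list. A hedged sketch of a call site under that assumption (the profiler instance and the event details are illustrative):

```cpp
#include <string>
#include <unordered_map>
#include <utility>

#include "core/common/profiler.h"  // internal ONNX Runtime header

// Sketch only: recording a node event with the map-based signature.
void RecordKernelTimeSketch(onnxruntime::profiling::Profiler& profiler) {
  const auto start = profiler.Start();

  // ... timed work ...

  std::unordered_map<std::string, std::string> args = {
      {"op_name", "MatMul"},
      {"provider", "CPUExecutionProvider"},
  };
  // The map is moved into the event record, so the argument strings are not copied again.
  profiler.EndTimeAndRecordEvent(onnxruntime::profiling::NODE_EVENT,
                                 "MatMul_kernel_time", start, std::move(args));
}
```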
2 changes: 1 addition & 1 deletion onnxruntime/core/common/profiler.h
@@ -77,7 +77,7 @@
void EndTimeAndRecordEvent(EventCategory category,
const std::string& event_name,
const TimePoint& start_time,
const std::initializer_list<std::pair<std::string, std::string>>& event_args = {},
std::unordered_map<std::string, std::string> event_args = {},

GitHub Actions / Optional Lint C++ (cpplint, via reviewdog) — onnxruntime/core/common/profiler.h:80: Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]
Member (review comment on `std::unordered_map<std::string, std::string> event_args`): consider using `InlinedHashMap`.

bool sync_gpu = false);

/*
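If the reviewer's `InlinedHashMap` suggestion were taken, the declaration would presumably change along these lines; this is a sketch of the suggested alternative, not the code in this PR (`InlinedHashMap` is ONNX Runtime's small-map alias from its inlined_containers headers):

```cpp
#include <string>

#include "core/common/inlined_containers.h"

// Hypothetical alternative signature following the review suggestion.
void EndTimeAndRecordEvent(EventCategory category,
                           const std::string& event_name,
                           const TimePoint& start_time,
                           InlinedHashMap<std::string, std::string> event_args = {},
                           bool sync_gpu = false);
```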
7 changes: 7 additions & 0 deletions onnxruntime/core/framework/run_options.cc
@@ -80,3 +80,10 @@ ORT_API_STATUS_IMPL(OrtApis::RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptio
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::RunOptionsEnableProfiling, _Inout_ OrtRunOptions* options,
_In_ const ORTCHAR_T* profile_file_prefix) {
options->enable_profiling = true;
options->profile_file_prefix = profile_file_prefix;
return nullptr;
}
107 changes: 75 additions & 32 deletions onnxruntime/core/framework/sequential_executor.cc
@@ -155,8 +155,8 @@ std::string ComposeSeriesName(const GraphViewer& graph_viewer) {
class SessionScope {
public:
friend class KernelScope;
SessionScope(const SessionState& session_state, const ExecutionFrame& frame)
: session_state_(session_state)
SessionScope(const SessionState& session_state, const ExecutionFrame& frame, profiling::Profiler* run_profiler = nullptr)
Member (review comment on `profiling::Profiler* run_profiler = nullptr`): Let's get rid of the default value and pass it explicitly.

: session_state_(session_state), run_profiler_(run_profiler)
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
,
frame_(frame)
@@ -173,12 +173,16 @@ class SessionScope {
#endif
#ifdef DEBUG_NODE_INPUTS_OUTPUTS
,
dump_context_{
session_state_.GetGraphExecutionCounter(), 0}
dump_context_{session_state_.GetGraphExecutionCounter(), 0}
#endif
{
if (session_state_.Profiler().IsEnabled()) {
session_start_ = session_state.Profiler().Start();
const bool session_profiling_enabled = session_state_.Profiler().IsEnabled();
const bool run_profiling_enabled = IsRunProfilingEnabled();

if (session_profiling_enabled) {
session_start_ = session_state_.Profiler().Start();
} else if (run_profiling_enabled) {
session_start_ = run_profiler_->Start();
}

auto& logger = session_state_.Logger();
@@ -225,9 +229,15 @@ class SessionScope {
}
#endif

if (session_state_.Profiler().IsEnabled()) {
const bool session_profiling_enabled = session_state_.Profiler().IsEnabled();
const bool run_profiling_enabled = run_profiler_ && run_profiler_->IsEnabled();

if (session_profiling_enabled) {
session_state_.Profiler().EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", session_start_);
} else if (run_profiling_enabled) {
StopEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", session_start_);
}
Member (review comment): Can we wrap this into a function StopProfilingIfEnabled()?


#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
auto& logger = session_state_.Logger();
for (auto i : frame_.GetStaticMemorySizeInfo()) {
@@ -252,8 +262,24 @@
}
#endif

bool IsRunProfilingEnabled() const {
return run_profiler_ && run_profiler_->IsEnabled();
}

void StopEvent(profiling::EventCategory category,
const std::string& event_name,
const TimePoint& start_time,
std::unordered_map<std::string, std::string> event_args = {}) {
if (!run_profiler_) return;
run_profiler_->EndTimeAndRecordEvent(category,
event_name,
start_time,
std::move(event_args));
}

private:
const SessionState& session_state_;
profiling::Profiler* run_profiler_;
TimePoint session_start_;
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
const ExecutionFrame& frame_;
@@ -340,16 +366,21 @@ class KernelScope {
utils::DumpNodeInputs(dump_context_, kernel_context_, kernel_.Node(), session_state_, session_scope_.dump_analysis_);
#endif

#ifdef ENABLE_NVTX_PROFILE
node_compute_range_.Begin();
#endif
const bool session_profiling_enabled = session_state_.Profiler().IsEnabled();
const bool run_profiling_enabled = session_scope_.IsRunProfilingEnabled();

if (session_state_.Profiler().IsEnabled()) {
if (session_profiling_enabled || run_profiling_enabled) {
auto& node = kernel.Node();
node_name_ = node.Name().empty() ? MakeString(node.OpType(), "_", node.Index()) : node.Name();
concurrency::ThreadPool::StartProfiling(session_state_.GetThreadPool());
VLOGS(session_state_.Logger(), 1) << "Computing kernel: " << node_name_;
kernel_begin_time_ = session_state_.Profiler().Start();

if (session_profiling_enabled) {
kernel_begin_time_ = session_state_.Profiler().Start();
} else {
kernel_begin_time_ = session_scope_.run_profiler_->Start();
}
Member (review comment): Can we wrap this into a SessionScope method?

Contributor Author (reply): Thanks. Added StopEvent in SessionScope.


CalculateTotalInputSizes(&kernel_context, &kernel_,
input_activation_sizes_, input_parameter_sizes_,
node_name_, input_type_shape_);
@@ -363,26 +394,37 @@
node_compute_range_.End();
#endif

if (session_state_.Profiler().IsEnabled()) {
auto& profiler = session_state_.Profiler();
const bool session_profiling_enabled = session_state_.Profiler().IsEnabled();
const bool run_profiling_enabled = session_scope_.IsRunProfilingEnabled();

if (session_profiling_enabled || run_profiling_enabled) {
std::string output_type_shape_;
CalculateTotalOutputSizes(&kernel_context_, total_output_sizes_, node_name_, output_type_shape_);
profiler.EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name_ + "_kernel_time",
kernel_begin_time_,
// Log additional operation args / info.
{
{"op_name", kernel_.KernelDef().OpName()},
{"provider", kernel_.KernelDef().Provider()},
{"node_index", std::to_string(kernel_.Node().Index())},
{"activation_size", std::to_string(input_activation_sizes_)},
{"parameter_size", std::to_string(input_parameter_sizes_)},
{"output_size", std::to_string(total_output_sizes_)},
{"input_type_shape", input_type_shape_},
{"output_type_shape", output_type_shape_},
{"thread_scheduling_stats",
concurrency::ThreadPool::StopProfiling(session_state_.GetThreadPool())},
});

std::unordered_map<std::string, std::string> event_args = {
{"op_name", kernel_.KernelDef().OpName()},
{"provider", kernel_.KernelDef().Provider()},
{"node_index", std::to_string(kernel_.Node().Index())},
{"activation_size", std::to_string(input_activation_sizes_)},
{"parameter_size", std::to_string(input_parameter_sizes_)},
{"output_size", std::to_string(total_output_sizes_)},
{"input_type_shape", input_type_shape_},
{"output_type_shape", output_type_shape_},
{"thread_scheduling_stats",
concurrency::ThreadPool::StopProfiling(session_state_.GetThreadPool())},
};

if (session_profiling_enabled) {
session_state_.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name_ + "_kernel_time",
kernel_begin_time_,
std::move(event_args));
} else if (run_profiling_enabled) {
session_scope_.StopEvent(profiling::NODE_EVENT,
node_name_ + "_kernel_time",
kernel_begin_time_,
std::move(event_args));
}
}

#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
@@ -588,7 +630,8 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span<
#endif
const bool& terminate_flag,
const bool only_execute_path_to_fetches,
bool single_thread_mode) {
bool single_thread_mode,
profiling::Profiler* run_profiler) {
auto* execution_plan = session_state.GetExecutionPlan();
VLOGS(logger, 0) << "Number of streams: " << execution_plan->execution_plan.size();
int32_t valid_streams = 0;
@@ -631,7 +674,7 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span<
ORT_UNUSED_PARAMETER(only_execute_path_to_fetches);
#endif

SessionScope session_scope(session_state, ctx.GetExecutionFrame());
SessionScope session_scope(session_state, ctx.GetExecutionFrame(), run_profiler);

auto* tp = single_thread_mode ? nullptr : session_state.GetInterOpThreadPool();

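Both SessionScope and KernelScope above apply the same precedence rule: the session-level profiler is used when enabled, otherwise the optional per-run profiler. A condensed, standalone sketch of that selection (the function name here is illustrative, not part of the diff):

```cpp
#include "core/common/profiler.h"  // internal ONNX Runtime header

// Sketch only: session-level profiling takes precedence over run-level profiling.
onnxruntime::TimePoint StartTimingIfEnabledSketch(
    onnxruntime::profiling::Profiler& session_profiler,
    onnxruntime::profiling::Profiler* run_profiler) {
  if (session_profiler.IsEnabled()) {
    return session_profiler.Start();  // session profiling wins
  }
  if (run_profiler != nullptr && run_profiler->IsEnabled()) {
    return run_profiler->Start();     // fall back to the per-run profiler
  }
  return onnxruntime::TimePoint{};    // profiling disabled for this run
}
```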
3 changes: 2 additions & 1 deletion onnxruntime/core/framework/sequential_executor.h
@@ -46,7 +46,8 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span<
#endif
const bool& terminate_flag,
const bool only_execute_path_to_fetches,
bool single_thread_mode);
bool single_thread_mode,
profiling::Profiler* run_profiler = nullptr);

#ifdef ENABLE_TRAINING
onnxruntime::Status PartialExecuteThePlan(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
33 changes: 19 additions & 14 deletions onnxruntime/core/framework/utils.cc
@@ -599,7 +599,8 @@ ExecuteGraphImpl(const SessionState& session_state,
DeviceStreamCollection* device_stream_collection,
#endif
const bool only_execute_path_to_fetches = false,
Stream* parent_stream = nullptr) {
Stream* parent_stream = nullptr,
profiling::Profiler* run_profiler = nullptr) {
const auto& feeds_fetches_info = feeds_fetches_manager.GetFeedsFetchesInfo();
const auto& device_copy_checks = feeds_fetches_manager.GetDeviceCopyChecks();
#ifdef ORT_ENABLE_STREAM
@@ -631,7 +632,8 @@
terminate_flag,
only_execute_path_to_fetches,
// single thread mode
single_thread_mode));
single_thread_mode,
run_profiler));
ORT_RETURN_IF_ERROR(status);
} else {
auto feeds_to_use = feeds;
@@ -679,7 +681,8 @@
#endif
terminate_flag,
only_execute_path_to_fetches,
single_thread_mode));
single_thread_mode,
run_profiler));
ORT_RETURN_IF_ERROR(status);
InlinedVector<Stream*> fetches_streams;
fetches_streams.reserve(feeds_fetches_info.fetches_mlvalue_idxs.size());
@@ -717,7 +720,8 @@ common::Status ExecuteGraph(const SessionState& session_state,
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
bool only_execute_path_to_fetches,
Stream* parent_stream) {
Stream* parent_stream,
profiling::Profiler* run_profiler) {
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(session_state, feeds_fetches_manager));

// finalize the copy info using the provided feeds and fetches. will update device_copy_checks in the background
@@ -728,13 +732,15 @@ common::Status ExecuteGraph(const SessionState& session_state,
execution_mode, terminate_flag, logger,
device_stream_collection,
only_execute_path_to_fetches,
parent_stream);
parent_stream,
run_profiler);
return retval;
#else
return ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, {},
execution_mode, terminate_flag, logger,
only_execute_path_to_fetches,
parent_stream);
parent_stream,
run_profiler);
#endif
}

@@ -745,17 +751,16 @@ common::Status ExecuteGraph(const SessionState& session_state,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
const logging::Logger& logger) {
return ExecuteGraph(session_state,
feeds_fetches_manager,
feeds, fetches,
execution_mode,
run_options.terminate,
logger,
const logging::Logger& logger,
profiling::Profiler* run_profiler) {
return ExecuteGraph(session_state, feeds_fetches_manager, feeds, fetches,
execution_mode, run_options.terminate, logger,
#ifdef ORT_ENABLE_STREAM
device_stream_collection_holder,
#endif
run_options.only_execute_path_to_fetches);
run_options.only_execute_path_to_fetches,
nullptr,
run_profiler);
}

#ifdef ENABLE_TRAINING
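This diff threads a `run_profiler` pointer down to `ExecuteThePlan`, but the code that creates and owns the per-run `Profiler` (presumably in the session's Run() path) is not shown here. A rough caller-side sketch, where every name and the exact wiring are assumptions about code outside this diff:

```cpp
// Sketch only: a caller creating a per-run profiler and passing it into ExecuteGraph.
// The surrounding session code, naming, and filename composition are assumptions.
profiling::Profiler run_profiler;
profiling::Profiler* run_profiler_ptr = nullptr;

if (run_options.enable_profiling) {
  // Documented behaviour is "<profile_file_prefix>_<timestamp>.json";
  // ComposeRunProfileFileName is a hypothetical helper for that composition.
  run_profiler.StartProfiling(ComposeRunProfileFileName(run_options.profile_file_prefix));
  run_profiler_ptr = &run_profiler;
}

ORT_RETURN_IF_ERROR(utils::ExecuteGraph(session_state, feeds_fetches_manager, feeds, fetches,
                                        execution_mode, run_options,
#ifdef ORT_ENABLE_STREAM
                                        device_stream_collection_holder,
#endif
                                        logger, run_profiler_ptr));

if (run_profiler_ptr != nullptr) {
  run_profiler.EndProfiling();  // writes out the per-run JSON profile
}
```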
6 changes: 4 additions & 2 deletions onnxruntime/core/framework/utils.h
@@ -83,15 +83,17 @@ common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManag
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
bool only_execute_path_to_fetches = false,
Stream* parent_stream = nullptr);
Stream* parent_stream = nullptr,
profiling::Profiler* run_profiler = nullptr);

common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
ExecutionMode execution_mode, const RunOptions& run_options,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
const logging::Logger& logger);
const logging::Logger& logger,
profiling::Profiler* run_profiler = nullptr);

#ifdef ENABLE_TRAINING
common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,