Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/onnxruntime/core/framework/run_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ struct OrtRunOptions {
// So it is possible that only some of the nodes are executed.
bool only_execute_path_to_fetches = false;

// Set to 'true' to enable profiling for this run.
bool enable_profiling = false;

// File prefix for profiling result for this run.
// The actual filename will be: <profile_file_prefix>_<timestamp>.json
// Only used when enable_profiling is true.
std::basic_string<ORTCHAR_T> profile_file_prefix = ORT_TSTR("onnxruntime_run_profile");

#ifdef ENABLE_TRAINING
// Used by onnxruntime::training::TrainingSession. This class is now deprecated.
// Delete training_mode when TrainingSession is deleted.
Expand Down
12 changes: 12 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -7013,6 +7013,18 @@ struct OrtApi {
* \since Version 1.24
*/
ORT_API2_STATUS(CreateEnvWithOptions, _In_ const OrtEnvCreationOptions* options, _Outptr_ OrtEnv** out);

/** \brief Enable profiling for this run
*
* \param[in] options The run options instance to update.
* \param[in] profile_file_prefix The prefix for the profile file. The actual filename will be:
* <profile_file_prefix>_<timestamp>.json
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24
*/
ORT_API2_STATUS(RunOptionsEnableProfiling, _Inout_ OrtRunOptions* options, _In_ const ORTCHAR_T* profile_file_prefix);
};

/*
Expand Down
7 changes: 7 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1307,6 +1307,13 @@ struct RunOptions : detail::Base<OrtRunOptions> {
* \param adapter The LoraAdapter to be used as the active adapter
*/
RunOptions& AddActiveLoraAdapter(const LoraAdapter& adapter);

/** \brief Enable profiling for this run
*
* Wraps OrtApi::RunOptionsEnableProfiling
* \param profile_file_prefix The prefix for the profile file; the actual filename will be
*                            <profile_file_prefix>_<timestamp>.json
*/
RunOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix);
};

namespace detail {
Expand Down
5 changes: 5 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,11 @@ inline RunOptions& RunOptions::AddActiveLoraAdapter(const LoraAdapter& adapter)
return *this;
}

inline RunOptions& RunOptions::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
  // Forward to the C API; ThrowOnError converts a failure status into an Ort::Exception.
  const auto& api = GetApi();
  ThrowOnError(api.RunOptionsEnableProfiling(this->p_, profile_file_prefix));
  return *this;
}

inline ModelCompilationOptions::ModelCompilationOptions(const Env& env, const SessionOptions& session_options) {
ThrowOnError(GetCompileApi().CreateModelCompilationOptionsFromSessionOptions(env, session_options, &this->p_));
}
Expand Down
13 changes: 7 additions & 6 deletions onnxruntime/core/common/profiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,17 @@
template void Profiler::StartProfiling<wchar_t>(const std::basic_string<wchar_t>& file_name);
#endif

void Profiler::EndTimeAndRecordEvent(EventCategory category,
const std::string& event_name,
const TimePoint& start_time,
const std::initializer_list<std::pair<std::string, std::string>>& event_args,
bool /*sync_gpu*/) {
void Profiler::EndTimeAndRecordEvent(
EventCategory category,
const std::string& event_name,
const TimePoint& start_time,
std::unordered_map<std::string, std::string> event_args,

Check warning on line 78 in onnxruntime/core/common/profiler.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/common/profiler.cc:78: Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]
bool /*sync_gpu*/) {
long long dur = TimeDiffMicroSeconds(start_time);
long long ts = TimeDiffMicroSeconds(profiling_start_time_, start_time);

EventRecord event(category, logging::GetProcessId(),
logging::GetThreadId(), event_name, ts, dur, {event_args.begin(), event_args.end()});
logging::GetThreadId(), event_name, ts, dur, std::move(event_args));
if (profile_with_logger_) {
custom_logger_->SendProfileEvent(event);
} else {
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/common/profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
void EndTimeAndRecordEvent(EventCategory category,
const std::string& event_name,
const TimePoint& start_time,
const std::initializer_list<std::pair<std::string, std::string>>& event_args = {},
std::unordered_map<std::string, std::string> event_args = {},

Check warning on line 80 in onnxruntime/core/common/profiler.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/common/profiler.h:80: Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]
bool sync_gpu = false);

/*
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/framework/run_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,10 @@ ORT_API_STATUS_IMPL(OrtApis::RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptio
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::RunOptionsEnableProfiling, _Inout_ OrtRunOptions* options,
                    _In_ const ORTCHAR_T* profile_file_prefix) {
  API_IMPL_BEGIN
  // Turn on per-run profiling. A null prefix keeps the default
  // profile_file_prefix already stored in the run options rather than
  // assigning from a null pointer.
  options->enable_profiling = true;
  if (profile_file_prefix != nullptr) {
    options->profile_file_prefix = profile_file_prefix;
  }
  return nullptr;
  API_IMPL_END
}
107 changes: 75 additions & 32 deletions onnxruntime/core/framework/sequential_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ std::string ComposeSeriesName(const GraphViewer& graph_viewer) {
class SessionScope {
public:
friend class KernelScope;
SessionScope(const SessionState& session_state, const ExecutionFrame& frame)
: session_state_(session_state)
SessionScope(const SessionState& session_state, const ExecutionFrame& frame, profiling::Profiler* run_profiler = nullptr)
: session_state_(session_state), run_profiler_(run_profiler)
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
,
frame_(frame)
Expand All @@ -173,12 +173,16 @@ class SessionScope {
#endif
#ifdef DEBUG_NODE_INPUTS_OUTPUTS
,
dump_context_{
session_state_.GetGraphExecutionCounter(), 0}
dump_context_{session_state_.GetGraphExecutionCounter(), 0}
#endif
{
if (session_state_.Profiler().IsEnabled()) {
session_start_ = session_state.Profiler().Start();
const bool session_profiling_enabled = session_state_.Profiler().IsEnabled();
const bool run_profiling_enabled = IsRunProfilingEnabled();

if (session_profiling_enabled) {
session_start_ = session_state_.Profiler().Start();
} else if (run_profiling_enabled) {
session_start_ = run_profiler_->Start();
}

auto& logger = session_state_.Logger();
Expand Down Expand Up @@ -225,9 +229,15 @@ class SessionScope {
}
#endif

if (session_state_.Profiler().IsEnabled()) {
const bool session_profiling_enabled = session_state_.Profiler().IsEnabled();
const bool run_profiling_enabled = run_profiler_ && run_profiler_->IsEnabled();

if (session_profiling_enabled) {
session_state_.Profiler().EndTimeAndRecordEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", session_start_);
} else if (run_profiling_enabled) {
StopEvent(profiling::SESSION_EVENT, "SequentialExecutor::Execute", session_start_);
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we wrap this into a function StopProfilingIfEnabled()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added StopProfilingIfEnabled and StartProfilingIfEnabled as suggested. Done!


#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
auto& logger = session_state_.Logger();
for (auto i : frame_.GetStaticMemorySizeInfo()) {
Expand All @@ -252,8 +262,24 @@ class SessionScope {
}
#endif

// True only when a per-run profiler was supplied to this scope and it is enabled.
bool IsRunProfilingEnabled() const {
  return run_profiler_ != nullptr && run_profiler_->IsEnabled();
}

// Record a finished event on the per-run profiler, if one is attached.
// Silently does nothing when no per-run profiler was supplied.
void StopEvent(profiling::EventCategory category,
               const std::string& event_name,
               const TimePoint& start_time,
               std::unordered_map<std::string, std::string> event_args = {}) {
  if (run_profiler_ == nullptr) {
    return;
  }
  run_profiler_->EndTimeAndRecordEvent(category, event_name, start_time, std::move(event_args));
}

private:
const SessionState& session_state_;
profiling::Profiler* run_profiler_;
TimePoint session_start_;
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
const ExecutionFrame& frame_;
Expand Down Expand Up @@ -340,16 +366,21 @@ class KernelScope {
utils::DumpNodeInputs(dump_context_, kernel_context_, kernel_.Node(), session_state_, session_scope_.dump_analysis_);
#endif

#ifdef ENABLE_NVTX_PROFILE
node_compute_range_.Begin();
#endif
const bool session_profiling_enabled = session_state_.Profiler().IsEnabled();
const bool run_profiling_enabled = session_scope_.IsRunProfilingEnabled();

if (session_state_.Profiler().IsEnabled()) {
if (session_profiling_enabled || run_profiling_enabled) {
auto& node = kernel.Node();
node_name_ = node.Name().empty() ? MakeString(node.OpType(), "_", node.Index()) : node.Name();
concurrency::ThreadPool::StartProfiling(session_state_.GetThreadPool());
VLOGS(session_state_.Logger(), 1) << "Computing kernel: " << node_name_;
kernel_begin_time_ = session_state_.Profiler().Start();

if (session_profiling_enabled) {
kernel_begin_time_ = session_state_.Profiler().Start();
} else {
kernel_begin_time_ = session_scope_.run_profiler_->Start();
}

CalculateTotalInputSizes(&kernel_context, &kernel_,
input_activation_sizes_, input_parameter_sizes_,
node_name_, input_type_shape_);
Expand All @@ -363,26 +394,37 @@ class KernelScope {
node_compute_range_.End();
#endif

if (session_state_.Profiler().IsEnabled()) {
auto& profiler = session_state_.Profiler();
const bool session_profiling_enabled = session_state_.Profiler().IsEnabled();
const bool run_profiling_enabled = session_scope_.IsRunProfilingEnabled();

if (session_profiling_enabled || run_profiling_enabled) {
std::string output_type_shape_;
CalculateTotalOutputSizes(&kernel_context_, total_output_sizes_, node_name_, output_type_shape_);
profiler.EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name_ + "_kernel_time",
kernel_begin_time_,
// Log additional operation args / info.
{
{"op_name", kernel_.KernelDef().OpName()},
{"provider", kernel_.KernelDef().Provider()},
{"node_index", std::to_string(kernel_.Node().Index())},
{"activation_size", std::to_string(input_activation_sizes_)},
{"parameter_size", std::to_string(input_parameter_sizes_)},
{"output_size", std::to_string(total_output_sizes_)},
{"input_type_shape", input_type_shape_},
{"output_type_shape", output_type_shape_},
{"thread_scheduling_stats",
concurrency::ThreadPool::StopProfiling(session_state_.GetThreadPool())},
});

std::unordered_map<std::string, std::string> event_args = {
{"op_name", kernel_.KernelDef().OpName()},
{"provider", kernel_.KernelDef().Provider()},
{"node_index", std::to_string(kernel_.Node().Index())},
{"activation_size", std::to_string(input_activation_sizes_)},
{"parameter_size", std::to_string(input_parameter_sizes_)},
{"output_size", std::to_string(total_output_sizes_)},
{"input_type_shape", input_type_shape_},
{"output_type_shape", output_type_shape_},
{"thread_scheduling_stats",
concurrency::ThreadPool::StopProfiling(session_state_.GetThreadPool())},
};

if (session_profiling_enabled) {
session_state_.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
node_name_ + "_kernel_time",
kernel_begin_time_,
std::move(event_args));
} else if (run_profiling_enabled) {
session_scope_.StopEvent(profiling::NODE_EVENT,
node_name_ + "_kernel_time",
kernel_begin_time_,
std::move(event_args));
}
}

#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
Expand Down Expand Up @@ -588,7 +630,8 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span<
#endif
const bool& terminate_flag,
const bool only_execute_path_to_fetches,
bool single_thread_mode) {
bool single_thread_mode,
profiling::Profiler* run_profiler) {
auto* execution_plan = session_state.GetExecutionPlan();
VLOGS(logger, 0) << "Number of streams: " << execution_plan->execution_plan.size();
int32_t valid_streams = 0;
Expand Down Expand Up @@ -631,7 +674,7 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span<
ORT_UNUSED_PARAMETER(only_execute_path_to_fetches);
#endif

SessionScope session_scope(session_state, ctx.GetExecutionFrame());
SessionScope session_scope(session_state, ctx.GetExecutionFrame(), run_profiler);

auto* tp = single_thread_mode ? nullptr : session_state.GetInterOpThreadPool();

Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/framework/sequential_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span<
#endif
const bool& terminate_flag,
const bool only_execute_path_to_fetches,
bool single_thread_mode);
bool single_thread_mode,
profiling::Profiler* run_profiler = nullptr);

#ifdef ENABLE_TRAINING
onnxruntime::Status PartialExecuteThePlan(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
Expand Down
33 changes: 19 additions & 14 deletions onnxruntime/core/framework/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,8 @@ ExecuteGraphImpl(const SessionState& session_state,
DeviceStreamCollection* device_stream_collection,
#endif
const bool only_execute_path_to_fetches = false,
Stream* parent_stream = nullptr) {
Stream* parent_stream = nullptr,
profiling::Profiler* run_profiler = nullptr) {
const auto& feeds_fetches_info = feeds_fetches_manager.GetFeedsFetchesInfo();
const auto& device_copy_checks = feeds_fetches_manager.GetDeviceCopyChecks();
#ifdef ORT_ENABLE_STREAM
Expand Down Expand Up @@ -631,7 +632,8 @@ ExecuteGraphImpl(const SessionState& session_state,
terminate_flag,
only_execute_path_to_fetches,
// single thread mode
single_thread_mode));
single_thread_mode,
run_profiler));
ORT_RETURN_IF_ERROR(status);
} else {
auto feeds_to_use = feeds;
Expand Down Expand Up @@ -679,7 +681,8 @@ ExecuteGraphImpl(const SessionState& session_state,
#endif
terminate_flag,
only_execute_path_to_fetches,
single_thread_mode));
single_thread_mode,
run_profiler));
ORT_RETURN_IF_ERROR(status);
InlinedVector<Stream*> fetches_streams;
fetches_streams.reserve(feeds_fetches_info.fetches_mlvalue_idxs.size());
Expand Down Expand Up @@ -717,7 +720,8 @@ common::Status ExecuteGraph(const SessionState& session_state,
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
bool only_execute_path_to_fetches,
Stream* parent_stream) {
Stream* parent_stream,
profiling::Profiler* run_profiler) {
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(session_state, feeds_fetches_manager));

// finalize the copy info using the provided feeds and fetches. will update device_copy_checks in the background
Expand All @@ -728,13 +732,15 @@ common::Status ExecuteGraph(const SessionState& session_state,
execution_mode, terminate_flag, logger,
device_stream_collection,
only_execute_path_to_fetches,
parent_stream);
parent_stream,
run_profiler);
return retval;
#else
return ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, {},
execution_mode, terminate_flag, logger,
only_execute_path_to_fetches,
parent_stream);
parent_stream,
run_profiler);
#endif
}

Expand All @@ -745,17 +751,16 @@ common::Status ExecuteGraph(const SessionState& session_state,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
const logging::Logger& logger) {
return ExecuteGraph(session_state,
feeds_fetches_manager,
feeds, fetches,
execution_mode,
run_options.terminate,
logger,
const logging::Logger& logger,
profiling::Profiler* run_profiler) {
return ExecuteGraph(session_state, feeds_fetches_manager, feeds, fetches,
execution_mode, run_options.terminate, logger,
#ifdef ORT_ENABLE_STREAM
device_stream_collection_holder,
#endif
run_options.only_execute_path_to_fetches);
run_options.only_execute_path_to_fetches,
nullptr,
run_profiler);
}

#ifdef ENABLE_TRAINING
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/framework/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,17 @@ common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManag
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
bool only_execute_path_to_fetches = false,
Stream* parent_stream = nullptr);
Stream* parent_stream = nullptr,
profiling::Profiler* run_profiler = nullptr);

common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
ExecutionMode execution_mode, const RunOptions& run_options,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
const logging::Logger& logger);
const logging::Logger& logger,
profiling::Profiler* run_profiler = nullptr);

#ifdef ENABLE_TRAINING
common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
Expand Down
Loading
Loading