Release backward inputs per static graph ref count #20804

Merged · 9 commits · Jun 14, 2024
10 changes: 10 additions & 0 deletions include/onnxruntime/core/framework/kernel_def_builder.h
@@ -253,6 +253,16 @@ class KernelDefBuilder {
/**
Specify that this kernel's output buffers are passed from external,
i.e. not created or managed by ORT's memory allocator.

An OrtValue set as an external output must be safe to release once its reference
count reaches zero in ORT's allocation/deallocation plan. Such an OrtValue is usually created
via the following flow: torch tensor --> dlpack tensor (whose destructor releases a view of the
original torch tensor, not the torch tensor itself) --> OrtValue.

When the OrtValue is no longer needed in the graph, it is released by invoking the attached
destructor. Because the destructor releases only the view of the original torch tensor, the
original tensor remains safe to use externally even after the OrtValue has been released in
the graph. (Recall that such an OrtValue is also never reused by ORT.)
*/
KernelDefBuilder& ExternalOutputs() {
kernel_def_->external_outputs_ = true;
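To make the contract concrete, here is a hedged sketch of registering a kernel with this flag; the op name, domain, and provider string are illustrative assumptions, not taken from this PR:

```cpp
#include <memory>

#include "core/framework/kernel_def_builder.h"

// Minimal sketch, assuming a hypothetical "MyExternalOp" kernel: mark its outputs
// as externally owned so ORT neither allocates nor reuses their buffers, and
// releases them only via the OrtValue's attached destructor.
std::unique_ptr<onnxruntime::KernelDef> def =
    onnxruntime::KernelDefBuilder()
        .SetName("MyExternalOp")            // hypothetical op name
        .SetDomain("com.microsoft")
        .SinceVersion(1)
        .Provider("CUDAExecutionProvider")
        .ExternalOutputs()                  // outputs not created/managed by ORT's allocator
        .Build();
```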
206 changes: 158 additions & 48 deletions onnxruntime/core/framework/allocation_planner.cc

Large diffs are not rendered by default.
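Since the planner diff itself is not rendered above, the following is only a conceptual, standalone sketch of the mechanism named in the PR title: each OrtValue gets a static use (reference) count derived from the graph, and it becomes releasable once that count drops to zero. The names and structure are assumptions, not the actual allocation_planner.cc code:

```cpp
#include <iostream>
#include <unordered_map>
#include <vector>

// Conceptual sketch (not the actual planner code): every value index gets a
// static use count derived from the graph; each executed node decrements the
// counts of its inputs, and a value is releasable when its count reaches zero.
struct RefCountReleasePlan {
  std::unordered_map<int, int> remaining_uses;  // value index -> pending consumers

  void Build(const std::vector<std::vector<int>>& node_inputs) {
    for (const auto& inputs : node_inputs)
      for (int v : inputs) ++remaining_uses[v];
  }

  // Returns the value indices that become releasable after this node runs.
  std::vector<int> OnNodeExecuted(const std::vector<int>& inputs) {
    std::vector<int> to_release;
    for (int v : inputs)
      if (--remaining_uses[v] == 0) to_release.push_back(v);
    return to_release;
  }
};

int main() {
  // Value 0 feeds nodes A and B; value 1 feeds node B only.
  RefCountReleasePlan plan;
  plan.Build({{0}, {0, 1}});
  for (int v : plan.OnNodeExecuted({0})) std::cout << "free " << v << "\n";     // nothing yet
  for (int v : plan.OnNodeExecuted({0, 1})) std::cout << "free " << v << "\n";  // frees 0 and 1
}
```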

10 changes: 6 additions & 4 deletions onnxruntime/core/framework/partial_graph_execution_state.cc
@@ -59,8 +59,10 @@
return device_stream_collection_.get();
}

StreamExecutionContext& PartialGraphExecutionState::GetExecutionContext(gsl::span<const int>& feed_mlvalue_idxs, gsl::span<const OrtValue>& feeds,
gsl::span<const int>& fetch_mlvalue_idxs, std::vector<OrtValue>& fetches,
StreamExecutionContext& PartialGraphExecutionState::GetExecutionContext(gsl::span<const int>& feed_mlvalue_idxs,
std::vector<OrtValue>& feeds,
gsl::span<const int>& fetch_mlvalue_idxs,
std::vector<OrtValue>& fetches,

[cpplint] Lint C++ warning on line 65: Add #include <vector> for vector<> [build/include_what_you_use]
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
const SessionState& session_state,
const logging::Logger& sess_logger,
@@ -81,15 +83,15 @@
execution_plan->num_barriers,
device_streams,
feed_mlvalue_idxs,
feeds,
std::move(feeds),
fetch_mlvalue_idxs,
fetches,
fetch_allocators,
sess_logger,
// partial executor in training can only be run with single thread
true);
} else {
execution_context_->GetExecutionFrame().UpdateFeeds(feed_mlvalue_idxs, feeds);
execution_context_->GetExecutionFrame().UpdateFeeds(feed_mlvalue_idxs, std::move(feeds));

[cpplint] Lint C++ warning on line 94: Add #include <utility> for move [build/include_what_you_use]
execution_context_->GetExecutionFrame().UpdateFetches(fetch_mlvalue_idxs, fetches, session_state.GetInitializedTensors());
execution_context_->SetLogger(sess_logger);
}
3 changes: 2 additions & 1 deletion onnxruntime/core/framework/partial_graph_execution_state.h
@@ -30,7 +30,8 @@ struct PartialGraphExecutionState {

ProgramRegion& GetProgramRegions(const SessionState& session_state);

StreamExecutionContext& GetExecutionContext(gsl::span<const int>& feed_mlvalue_idxs, gsl::span<const OrtValue>& feeds,
// Note: feeds are std::move'd into the execution context, so the vector will be empty after this call.
StreamExecutionContext& GetExecutionContext(gsl::span<const int>& feed_mlvalue_idxs, std::vector<OrtValue>& feeds,
gsl::span<const int>& fetch_mlvalue_idxs, std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
const SessionState& session_state,
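The comment above documents a real change of contract for callers. A standalone sketch of the pitfall, with std::string standing in for OrtValue (an assumed simplification):

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// The callee takes feeds by non-const reference and moves from them, mirroring
// UpdateFeeds(feed_mlvalue_idxs, std::move(feeds)) in the diff above.
void ConsumeFeeds(std::vector<std::string>& feeds, std::vector<std::string>& stored) {
  stored = std::move(feeds);
}

int main() {
  std::vector<std::string> feeds{"X1", "X2"};
  std::vector<std::string> stored;
  ConsumeFeeds(feeds, stored);
  assert(feeds.empty());       // the caller must not reuse feeds after the call
  assert(stored.size() == 2);  // ownership now lives with the callee
  return 0;
}
```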
5 changes: 4 additions & 1 deletion onnxruntime/core/framework/sequential_executor.cc
@@ -615,7 +615,8 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span<

#ifdef ENABLE_TRAINING
onnxruntime::Status PartialExecuteThePlan(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
std::vector<OrtValue>& feeds,
gsl::span<const int> fetch_mlvalue_idxs,
std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>&
fetch_allocators,
@@ -626,8 +627,10 @@ onnxruntime::Status PartialExecuteThePlan(const SessionState& session_state, gsl
PartialGraphExecutionState& state,
const OrtValueCachePtr& cache,
int32_t partial_graph_index) {
// Note: feeds are std::move'd into the ctx, so the vector will be empty after this call.
auto& ctx = state.GetExecutionContext(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs, fetches,
fetch_allocators, session_state, logger, device_streams);

auto* plan = session_state.GetExecutionPlan();

ctx.SetCurrentRange(&state.GetProgramRegions(session_state));
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/sequential_executor.h
@@ -50,7 +50,7 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span<

#ifdef ENABLE_TRAINING
onnxruntime::Status PartialExecuteThePlan(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
std::vector<OrtValue>& feeds, gsl::span<const int> fetch_mlvalue_idxs,
std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
const logging::Logger& logger,
4 changes: 2 additions & 2 deletions onnxruntime/core/framework/utils.cc
@@ -784,7 +784,7 @@ common::Status ExecuteGraph(const SessionState& session_state,

#ifdef ENABLE_TRAINING
common::Status ExecutePartialGraphImpl(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
const logging::Logger& logger, PartialGraphExecutionState& state,
const OrtValueCachePtr& cache, const bool& terminate_flag,
DeviceStreamCollection* device_stream_collection,
@@ -882,7 +882,7 @@ common::Status ExecutePartialGraphImpl(const SessionState& session_state, FeedsF
}

common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
const logging::Logger& logger, PartialGraphExecutionState& state,
const OrtValueCachePtr& cache, const bool& terminate_flag,
int32_t partial_graph_index,
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/utils.h
@@ -100,7 +100,7 @@ common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManag

#ifdef ENABLE_TRAINING
common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
const logging::Logger& logger, PartialGraphExecutionState& state,
const OrtValueCachePtr& cache,
const bool& terminate_flag,
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/gemm_sum_fusion.cc
@@ -36,7 +36,7 @@ Status GemmSumFusion::Apply(Graph& graph, Node& gemm_node, RewriteRuleEffect& mo
std::vector<NodeArg*> new_gemm_output_defs = sum_node.MutableOutputDefs();
ORT_ENFORCE(new_gemm_output_defs.size() == 1);

Node& new_gemm_node = graph.AddNode(graph.GenerateNodeName(gemm_node.Name() + "_sum_transformed"),
Node& new_gemm_node = graph.AddNode(graph.GenerateNodeName(gemm_node.Name() + "/GemmSumFusion/"),
gemm_node.OpType(),
"Fused Gemm with Sum",
new_gemm_input_defs,
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/matmul_add_fusion.cc
@@ -106,7 +106,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}

Node& gemm_node = graph.AddNode(graph.GenerateNodeName("gemm"),
Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion/"),
"Gemm",
"fused Matmul and Add " + add_node.OpType(),
gemm_input_defs,
2 changes: 1 addition & 1 deletion onnxruntime/core/session/inference_session.cc
@@ -2349,7 +2349,7 @@ common::Status InferenceSession::ValidateOutputs(gsl::span<const std::string> ou

#ifdef ENABLE_TRAINING
Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options,
const std::vector<OrtValue>& feeds,
std::vector<OrtValue>& feeds,
std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state,
FeedsFetchesManager& feeds_fetches_manager,
6 changes: 3 additions & 3 deletions onnxruntime/core/session/inference_session.h
@@ -377,9 +377,9 @@
/**
* Partially run a pre-loaded and pre-intialized model.
* @param run_options run options.
* @param feeds inputs owned by client code and should not be changed during
* execution of this function.
* @param mutable_feeds inputs owned by client code; they are released once the feeds are set in the
* session state, after which the feeds are managed purely by the session state.
* @param fetches outputs produced after the execution of this function.

* @param state State of the graph needed to resume partial graph run.
* @param feeds_fetches_manager Contains feed/fetches name to internal indices mapping and information for device
* copy/checks.
@@ -388,7 +388,7 @@
* @param partial_graph_index Index of the partial graph to run.
*/
common::Status PartialRun(onnxruntime::RunOptions& run_options,
const std::vector<OrtValue>& feeds,
std::vector<OrtValue>& mutable_feeds,
std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state,
FeedsFetchesManager& feeds_fetches_manager,
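For callers of PartialRun, the practical effect is on peak memory: because ownership of each feed moves into the session state, a backward input can be released as soon as its static-graph reference count reaches zero, instead of living until the partial run returns. A hedged call-site sketch follows; the helper name and surrounding variables are illustrative assumptions, not from this PR:

```cpp
// Hypothetical call site: `session`, `run_options`, `state`, `feeds_fetches_manager`,
// `cache`, and CollectSavedActivations() are assumed to exist for illustration.
std::vector<OrtValue> mutable_feeds = CollectSavedActivations();
std::vector<OrtValue> fetches;
ORT_RETURN_IF_ERROR(session.PartialRun(run_options, mutable_feeds, fetches, state,
                                       feeds_fetches_manager, cache,
                                       /*partial_graph_index=*/1));
// mutable_feeds is now empty: each OrtValue was moved into the session state and may
// already have been released mid-run once its reference count hit zero.
```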
94 changes: 72 additions & 22 deletions onnxruntime/test/framework/allocation_planner_test.cc
@@ -370,8 +370,9 @@ class PlannerTest : public ::testing::Test {
EXPECT_EQ(plan_->execution_plan.size(), 1U);
int list_size = static_cast<int>(plan_->node_release_list.size());
EXPECT_GT(list_size, step_number);
for (auto freed : plan_->node_release_list[step_number]) {
plan_result.insert(static_cast<int>(freed));
for (auto action_idx : plan_->node_release_list[step_number]) {
const size_t ortvalue_id = plan_->release_actions[action_idx].value_index;
plan_result.insert(static_cast<int>(ortvalue_id));
}
EXPECT_EQ(plan_result, expected) << "Freed items incorrect for step " << step_number;
}
@@ -433,7 +434,7 @@ TEST_F(PlannerTest, ChainTest) {
CreatePlan();

// Expected plan:
// W: kAllocateStatically; X: kAllocate; B: kAllocate; Y: kReuse (X); post-node3: free(B); X is returned output
// W: kAllocateStatically; X: kAllocate; B: kAllocate; Y: kReuse (X); post-node3: free(B); Z is returned output
CheckAllocKind(W, AllocKind::kAllocateStatically);
CheckAllocKind(X, AllocKind::kAllocate);
CheckAllocKind(B, AllocKind::kAllocate);
Expand All @@ -442,8 +443,8 @@ TEST_F(PlannerTest, ChainTest) {

CheckFreed(0, {});
CheckFreed(1, {});
CheckFreed(2, {"X"});
CheckFreed(3, {"W"});
CheckFreed(2, {B});
CheckFreed(3, {X});
}

/* InputOutputTest: Test that:
@@ -510,7 +511,7 @@ TEST_F(PlannerTest, InPlaceTest) {
// check each ml-value is freed at appropriate step
CheckFreed(0, {});
CheckFreed(1, {});
CheckFreed(2, {X1});
CheckFreed(2, {X2});
}

TEST_F(PlannerTest, ExternalOutputsTest) {
@@ -536,10 +537,42 @@ TEST_F(PlannerTest, ExternalOutputsTest) {
CheckAllocKind(X4, AllocKind::kAllocateOutput);

// check each ml-value is freed at appropriate step
// X2 will not be reused and will not be freed. X3 will be allocated and will be freed.
// X2 will not be reused but will be freed (to release the current reference). X3 will be allocated and will be freed.
CheckFreed(0, {});
CheckFreed(1, {});
CheckFreed(2, {X1});
CheckFreed(1, {X2});
CheckFreed(2, {X3});
}

TEST_F(PlannerTest, ExternalOutputsNoReuseTest) {
// tensor variables:
std::string X1("X1"), X2("X2"), X3("X3"), X4("X4"), X5("X5");

// graph structure:
AddExternalOutputsNode(X1, X2); // external-outputs operator; X1: input; X2: temporary
AddInplaceNode(X2, X3); // may-in-place operator; X3: temporary
AddNormalNode(X3, X4); // normal operator; X4: temporary
AddNormalNode(X4, X5); // normal operator; X5: output

// simulate shape-inference results:
Shape shape1{"M", "N"};
auto shape = &shape1.value;
SetShape({{X1, shape}, {X2, shape}, {X3, shape}, {X4, shape}, {X5, shape}});

CreatePlan();

// check allocation kind:
CheckAllocKind(X1, AllocKind::kPreExisting);
CheckAllocKind(X2, AllocKind::kAllocatedExternally);
CheckAllocKind(X3, AllocKind::kAllocate); // Should not be reused.
CheckAllocKind(X4, AllocKind::kAllocate);
CheckAllocKind(X5, AllocKind::kAllocateOutput);

// check each ml-value is freed at appropriate step
// X2 will not be reused. X3 will be allocated and will be freed.
CheckFreed(0, {});
CheckFreed(1, {X2});
CheckFreed(2, {X3});
CheckFreed(3, {X4});
}

#ifdef ENABLE_STRIDED_TENSORS
@@ -564,19 +597,19 @@ TEST_F(PlannerTest, MayStridedTest1) {
CheckAllocKind(X3, AllocKind::kAllocateOutput);

// check each ml-value is freed at appropriate step
// X2 will not be reused and will not be freed. X3 will be allocated and will be freed.
// X2 will not be reused because X3 is a graph output. X3 will be allocated and will be freed.
CheckFreed(0, {});
CheckFreed(1, {X1});
CheckFreed(1, {X2});
}

TEST_F(PlannerTest, MayStridedTest2) {
// tensor variables:
std::string X1("X1"), X2("X2"), X3("X3"), X4("X4");

// graph structure:
AddMayStridedOutputNode(X1, X2);
AddMayStridedInputNode(X2, X3);
AddMayStridedInputNode(X2, X4);
AddMayStridedOutputNode(X1, X2); // X2 can reuse X1, and is a strided output.
AddMayStridedInputNode(X2, X3); // X3 is a graph output, cannot reuse.
AddMayStridedInputNode(X2, X4); // X4 is a graph output, cannot reuse.

// simulate shape-inference results:
Shape shape1{"M", "N"};
@@ -603,8 +636,9 @@ TEST_F(PlannerTest, MayStridedTest3) {
std::string X1("X1"), X2("X2"), X3("X3"), X4("X4");

// graph structure:
AddMayStridedOutputNode(X1, X2);
AddMayStridedInputNode(X2, X3);
AddMayStridedOutputNode(X1, X2); // X2 cannot reuse X1 as a strided tensor because
// one of X2's consumers is a node that does not support strided input, so X2 is an allocate.
AddMayStridedInputNode(X2, X3); // X3 is a graph output, cannot reuse.
AddNormalNode(X2, X4);

// simulate shape-inference results:
Expand All @@ -620,11 +654,27 @@ TEST_F(PlannerTest, MayStridedTest3) {
CheckAllocKind(X3, AllocKind::kAllocateOutput);
CheckAllocKind(X4, AllocKind::kAllocateOutput);

// Note: the last two nodes added can run in either order, so we need to determine the planned
// order before we know which free order to check.
const GraphViewer& graph_viewer = GetState().GetGraphViewer();
// Normal node index is 2.
bool does_normal_node_run_at_last = graph_viewer.GetNodesInTopologicalOrder()[2] == 2;

// check each ml-value is freed at appropriate step
// X2 will not be reused and will not be freed. X3 will be allocated and will be freed.

CheckFreed(0, {});
CheckFreed(1, {X1});
CheckFreed(2, {});

if (does_normal_node_run_at_last) {
// The normal node has node index 2, but it may be executed after the strided node.
// In that case X2 is released once the normal node has executed.
CheckFreed(1, {});
CheckFreed(2, {X2});
} else {
// The normal node has node index 2 and is executed before the strided node,
// so X2 is released after the strided node (node index 1) has executed.
CheckFreed(2, {});
CheckFreed(1, {X2});
}
}
#endif

@@ -637,7 +687,7 @@ TEST_F(PlannerTest, InPlaceSizeMismatchTest) {
// graph structure:
AddNormalNode(X1, X2); // no in-place operator; X1: input; X2: temporary
AddInplaceNode(X2, X3); // may-in-place operator; X3: temporary
AddNormalNode(X3, X4); // no in-place operator; X4: temporary
AddNormalNode(X3, X4); // no in-place operator; X4: temporary (reuse X2)
AddInplaceNode(X4, X5); // may-in-place operator; X5 output

// simulate shape-inference results:
Expand All @@ -659,8 +709,8 @@ TEST_F(PlannerTest, InPlaceSizeMismatchTest) {
// check each ml-value is freed at appropriate step
CheckFreed(0, {});
CheckFreed(1, {});
CheckFreed(2, {X2});
CheckFreed(3, {X1});
CheckFreed(2, {X3});
CheckFreed(3, {X2});
}

// Test operator<< to output details of an allocation & execution plan.
6 changes: 3 additions & 3 deletions orttraining/orttraining/core/agent/training_agent.cc
@@ -60,7 +60,7 @@ TrainingAgent::TrainingAgent(InferenceSession& session,

TrainingAgent::~TrainingAgent() = default;

common::Status TrainingAgent::RunForward(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
common::Status TrainingAgent::RunForward(std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state, const OrtValueCachePtr& cache) {
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
inference_session_.GetMemoryProfiler().GetMemoryInfo().SetIteration(profile_step_);
@@ -74,15 +74,15 @@ common::Status TrainingAgent::RunForward(const std::vector<OrtValue>& feeds, std
return RunCore(feeds, fetches, state, *fw_feeds_fetches_manager_, cache, partial_graph_index);
}

common::Status TrainingAgent::RunBackward(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
common::Status TrainingAgent::RunBackward(std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state) {
state.SetProgramCounterStart(fw_program_counter_end_);
state.SetProgramCounterEnd(bw_program_counter_end_);
constexpr int32_t partial_graph_index = 1;
return RunCore(feeds, fetches, state, *bw_feeds_fetches_manager_, nullptr, partial_graph_index);
}

common::Status TrainingAgent::RunCore(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
common::Status TrainingAgent::RunCore(std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager,
const OrtValueCachePtr& cache, int32_t partial_graph_index) {
auto fetches_size = feeds_fetches_manager.GetFeedsFetchesInfo().output_names.size();