### Release backward inputs per static graph ref count (#20804)

For output buffers marked as external outputs:
1. Remove the additional ref count we used to avoid reusing the buffer. Instead, when we look for a reusable input/output buffer, we now make sure the reused buffer is not generated by a node that has external outputs.
2. Remove the ref count on pybind feed inputs, which previously lived until run_backward completed. Instead, we pass a mutable feeds vector and clear it once the feeds have been copied into the session state and are no longer needed, before running the graph sequentially (a minimal sketch of both points follows this list).
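
A minimal C++ sketch of both points, using hypothetical names (`SafeToReuse`, `MoveFeedsIntoFrame`, `ToyValue`) rather than the actual ORT internals; the real logic lives in `allocation_planner.cc` and the execution frame/context code touched by this PR:

```cpp
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical stand-in for a ref-counted tensor handle (the real type is OrtValue).
struct ToyValue {};

// Point 1: never offer a buffer for reuse if its producing node has external outputs,
// since that buffer is owned outside ORT (e.g., by a torch tensor).
bool SafeToReuse(int producer_node_index,
                 const std::unordered_set<int>& nodes_with_external_outputs) {
  return nodes_with_external_outputs.count(producer_node_index) == 0;
}

// Point 2: the pybind layer hands over a mutable feeds vector; the execution frame
// takes ownership and the caller's vector is cleared, so no extra reference survives
// until run_backward completes. From then on, the static plan's ref counts decide lifetime.
std::vector<ToyValue> MoveFeedsIntoFrame(std::vector<ToyValue>& mutable_feeds) {
  std::vector<ToyValue> owned = std::move(mutable_feeds);
  mutable_feeds.clear();  // guarantee the caller-side vector is empty afterwards
  return owned;
}
```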

#### Before the change:

One of the backward inputs is 3.9 GB; it lives until the backward pass ends.

![image](https://github.com/microsoft/onnxruntime/assets/10530022/e71e2072-eaaa-4be3-a39f-0ca74b507265)

#### With the change:
The 3.9 GB tensor is released when the last node that depends on it completes.


![image](https://github.com/microsoft/onnxruntime/assets/10530022/7b27d01f-c675-4faf-9a3e-f886b31b2afe)


Note: the peak memory did not change, though; more work is needed to reduce the peak.


#### Others

A few tests were updated to use incorrect expected values during an earlier code refactoring
(a81faee#diff-9e8fbae7d3dff24106cd17564949f320e943cb3048eae07813c7de144f140419L382).

This PR restores the correct expected values; all test cases should now be back to normal.

pengwa authored Jun 14, 2024
1 parent fff68c3 commit 87b14ac
Showing 16 changed files with 274 additions and 94 deletions.
10 changes: 10 additions & 0 deletions include/onnxruntime/core/framework/kernel_def_builder.h
@@ -253,6 +253,16 @@ class KernelDefBuilder {
/**
Specify that this kernel's output buffers are passed from external,
i.e. not created or managed by ORT's memory allocator.
An OrtValue set as an external output must be safe to release as soon as its reference
count reaches zero in ORT's allocation/deallocation plan. We usually create such an OrtValue
via the following flow: torch tensor --> DLPack tensor (whose destructor releases a view of the
original torch tensor, not the original tensor itself) --> OrtValue.
When the OrtValue is no longer needed in the graph, it is released by calling the attached
destructor, which releases only the view of the original torch tensor. This ensures the
original torch tensor can still be used externally even after the OrtValue is released in the
graph. (Recall that such an OrtValue is also never reused by ORT.)
*/
KernelDefBuilder& ExternalOutputs() {
kernel_def_->external_outputs_ = true;
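
To make the comment above concrete, here is a hedged sketch of a DLPack wrapper whose deleter releases only a view of the externally owned storage. It is illustrative only: `TensorView` and `MakeDlpackView` are invented names and this is not ORT's actual torch/DLPack bridge; only the `dlpack.h` types are real.

```cpp
#include <dlpack/dlpack.h>
#include <memory>
#include <utility>

// The DLManagedTensor's deleter drops one shared reference (the "view") to the
// original storage instead of freeing the storage itself, so releasing the OrtValue
// built from it never invalidates the original torch tensor.
struct TensorView {
  std::shared_ptr<void> storage;  // shares ownership with the original tensor's storage
  DLManagedTensor managed{};
};

DLManagedTensor* MakeDlpackView(std::shared_ptr<void> storage, DLTensor tensor_meta) {
  auto* view = new TensorView{std::move(storage)};
  view->managed.dl_tensor = tensor_meta;
  view->managed.manager_ctx = view;
  view->managed.deleter = [](DLManagedTensor* self) {
    // Releases the view only; the original tensor stays valid for external use.
    delete static_cast<TensorView*>(self->manager_ctx);
  };
  return &view->managed;
}
```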
202 changes: 156 additions & 46 deletions onnxruntime/core/framework/allocation_planner.cc

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions onnxruntime/core/framework/partial_graph_execution_state.cc
@@ -59,8 +59,10 @@ DeviceStreamCollection* PartialGraphExecutionState::GetDeviceStreamCollection(co
return device_stream_collection_.get();
}

StreamExecutionContext& PartialGraphExecutionState::GetExecutionContext(gsl::span<const int>& feed_mlvalue_idxs, gsl::span<const OrtValue>& feeds,
gsl::span<const int>& fetch_mlvalue_idxs, std::vector<OrtValue>& fetches,
StreamExecutionContext& PartialGraphExecutionState::GetExecutionContext(gsl::span<const int>& feed_mlvalue_idxs,
std::vector<OrtValue>& feeds,
gsl::span<const int>& fetch_mlvalue_idxs,
std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
const SessionState& session_state,
const logging::Logger& sess_logger,
@@ -81,15 +83,15 @@ StreamExecutionContext& PartialGraphExecutionState::GetExecutionContext(gsl::spa
execution_plan->num_barriers,
device_streams,
feed_mlvalue_idxs,
feeds,
std::move(feeds),
fetch_mlvalue_idxs,
fetches,
fetch_allocators,
sess_logger,
// partial executor in training can only be run with single thread
true);
} else {
execution_context_->GetExecutionFrame().UpdateFeeds(feed_mlvalue_idxs, feeds);
execution_context_->GetExecutionFrame().UpdateFeeds(feed_mlvalue_idxs, std::move(feeds));
execution_context_->GetExecutionFrame().UpdateFetches(fetch_mlvalue_idxs, fetches, session_state.GetInitializedTensors());
execution_context_->SetLogger(sess_logger);
}
3 changes: 2 additions & 1 deletion onnxruntime/core/framework/partial_graph_execution_state.h
@@ -30,7 +30,8 @@ struct PartialGraphExecutionState {

ProgramRegion& GetProgramRegions(const SessionState& session_state);

StreamExecutionContext& GetExecutionContext(gsl::span<const int>& feed_mlvalue_idxs, gsl::span<const OrtValue>& feeds,
// Note: feeds is std::move'd into the ctx, so it will be empty after this function returns.
StreamExecutionContext& GetExecutionContext(gsl::span<const int>& feed_mlvalue_idxs, std::vector<OrtValue>& feeds,
gsl::span<const int>& fetch_mlvalue_idxs, std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
const SessionState& session_state,
5 changes: 4 additions & 1 deletion onnxruntime/core/framework/sequential_executor.cc
@@ -615,7 +615,8 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span<

#ifdef ENABLE_TRAINING
onnxruntime::Status PartialExecuteThePlan(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
std::vector<OrtValue>& feeds,
gsl::span<const int> fetch_mlvalue_idxs,
std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>&
fetch_allocators,
@@ -626,8 +627,10 @@ onnxruntime::Status PartialExecuteThePlan(const SessionState& session_state, gsl
PartialGraphExecutionState& state,
const OrtValueCachePtr& cache,
int32_t partial_graph_index) {
// Note: feeds is std::move'd into the ctx, so it will be empty after this function returns.
auto& ctx = state.GetExecutionContext(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs, fetches,
fetch_allocators, session_state, logger, device_streams);

auto* plan = session_state.GetExecutionPlan();

ctx.SetCurrentRange(&state.GetProgramRegions(session_state));
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/sequential_executor.h
@@ -50,7 +50,7 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span<

#ifdef ENABLE_TRAINING
onnxruntime::Status PartialExecuteThePlan(const SessionState& session_state, gsl::span<const int> feed_mlvalue_idxs,
gsl::span<const OrtValue> feeds, gsl::span<const int> fetch_mlvalue_idxs,
std::vector<OrtValue>& feeds, gsl::span<const int> fetch_mlvalue_idxs,
std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
const logging::Logger& logger,
4 changes: 2 additions & 2 deletions onnxruntime/core/framework/utils.cc
@@ -784,7 +784,7 @@ common::Status ExecuteGraph(const SessionState& session_state,

#ifdef ENABLE_TRAINING
common::Status ExecutePartialGraphImpl(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
const logging::Logger& logger, PartialGraphExecutionState& state,
const OrtValueCachePtr& cache, const bool& terminate_flag,
DeviceStreamCollection* device_stream_collection,
@@ -882,7 +882,7 @@ common::Status ExecutePartialGraphImpl(const SessionState& session_state, FeedsF
}

common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
const logging::Logger& logger, PartialGraphExecutionState& state,
const OrtValueCachePtr& cache, const bool& terminate_flag,
int32_t partial_graph_index,
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/utils.h
@@ -100,7 +100,7 @@ common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManag

#ifdef ENABLE_TRAINING
common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
const logging::Logger& logger, PartialGraphExecutionState& state,
const OrtValueCachePtr& cache,
const bool& terminate_flag,
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/gemm_sum_fusion.cc
@@ -36,7 +36,7 @@ Status GemmSumFusion::Apply(Graph& graph, Node& gemm_node, RewriteRuleEffect& mo
std::vector<NodeArg*> new_gemm_output_defs = sum_node.MutableOutputDefs();
ORT_ENFORCE(new_gemm_output_defs.size() == 1);

Node& new_gemm_node = graph.AddNode(graph.GenerateNodeName(gemm_node.Name() + "_sum_transformed"),
Node& new_gemm_node = graph.AddNode(graph.GenerateNodeName(gemm_node.Name() + "/GemmSumFusion/"),
gemm_node.OpType(),
"Fused Gemm with Sum",
new_gemm_input_defs,
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/matmul_add_fusion.cc
@@ -106,7 +106,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}

Node& gemm_node = graph.AddNode(graph.GenerateNodeName("gemm"),
Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion/"),
"Gemm",
"fused Matmul and Add " + add_node.OpType(),
gemm_input_defs,
2 changes: 1 addition & 1 deletion onnxruntime/core/session/inference_session.cc
@@ -2349,7 +2349,7 @@ common::Status InferenceSession::ValidateOutputs(gsl::span<const std::string> ou

#ifdef ENABLE_TRAINING
Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options,
const std::vector<OrtValue>& feeds,
std::vector<OrtValue>& feeds,
std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state,
FeedsFetchesManager& feeds_fetches_manager,
6 changes: 3 additions & 3 deletions onnxruntime/core/session/inference_session.h
@@ -377,8 +377,8 @@ class InferenceSession {
/**
* Partially run a pre-loaded and pre-initialized model.
* @param run_options run options.
* @param feeds inputs owned by client code and should not be changed during
* execution of this function.
* @param mutable_feeds inputs owned by client code; they are released as soon as they have been set in the session state.
* From then on, the feeds are managed purely by the session state.
* @param fetches outputs produced after the execution of this function.
* @param state State of the graph needed to resume partial graph run.
* @param feeds_fetches_manager Contains feed/fetches name to internal indices mapping and information for device
@@ -388,7 +388,7 @@
* @param partial_graph_index Index of the partial graph to run.
*/
common::Status PartialRun(onnxruntime::RunOptions& run_options,
const std::vector<OrtValue>& feeds,
std::vector<OrtValue>& mutable_feeds,
std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state,
FeedsFetchesManager& feeds_fetches_manager,
94 changes: 72 additions & 22 deletions onnxruntime/test/framework/allocation_planner_test.cc
@@ -370,8 +370,9 @@ class PlannerTest : public ::testing::Test {
EXPECT_EQ(plan_->execution_plan.size(), 1U);
int list_size = static_cast<int>(plan_->node_release_list.size());
EXPECT_GT(list_size, step_number);
for (auto freed : plan_->node_release_list[step_number]) {
plan_result.insert(static_cast<int>(freed));
for (auto action_idx : plan_->node_release_list[step_number]) {
const size_t ortvalue_id = plan_->release_actions[action_idx].value_index;
plan_result.insert(static_cast<int>(ortvalue_id));
}
EXPECT_EQ(plan_result, expected) << "Freed items incorrect for step " << step_number;
}
@@ -433,7 +434,7 @@ TEST_F(PlannerTest, ChainTest) {
CreatePlan();

// Expected plan:
// W: kAllocateStatically; X: kAllocate; B: kAllocate; Y: kReuse (X); post-node3: free(B); X is returned output
// W: kAllocateStatically; X: kAllocate; B: kAllocate; Y: kReuse (X); post-node3: free(B); Z is returned output
CheckAllocKind(W, AllocKind::kAllocateStatically);
CheckAllocKind(X, AllocKind::kAllocate);
CheckAllocKind(B, AllocKind::kAllocate);
Expand All @@ -442,8 +443,8 @@ TEST_F(PlannerTest, ChainTest) {

CheckFreed(0, {});
CheckFreed(1, {});
CheckFreed(2, {"X"});
CheckFreed(3, {"W"});
CheckFreed(2, {B});
CheckFreed(3, {X});
}

/* InputOutputTest: Test that:
Expand Down Expand Up @@ -510,7 +511,7 @@ TEST_F(PlannerTest, InPlaceTest) {
// check each ml-value is freed at appropriate step
CheckFreed(0, {});
CheckFreed(1, {});
CheckFreed(2, {X1});
CheckFreed(2, {X2});
}

TEST_F(PlannerTest, ExternalOutputsTest) {
@@ -536,10 +537,42 @@ TEST_F(PlannerTest, ExternalOutputsTest) {
CheckAllocKind(X4, AllocKind::kAllocateOutput);

// check each ml-value is freed at appropriate step
// X2 will not be reused and will not be freed. X3 will be allocated and will be freed.
// X2 will not be reused but will be freed (to release the current reference). X3 will be allocated and will be freed.
CheckFreed(0, {});
CheckFreed(1, {});
CheckFreed(2, {X1});
CheckFreed(1, {X2});
CheckFreed(2, {X3});
}

TEST_F(PlannerTest, ExternalOutputsNoReuseTest) {
// tensor variables:
std::string X1("X1"), X2("X2"), X3("X3"), X4("X4"), X5("X5");

// graph structure:
AddExternalOutputsNode(X1, X2); // external-outputs operator; X1: input; X2: temporary
AddInplaceNode(X2, X3); // may-in-place operator; X3: temporary
AddNormalNode(X3, X4); // normal operator; X4: temporary
AddNormalNode(X4, X5); // normal operator; X5: output

// simulate shape-inference results:
Shape shape1{"M", "N"};
auto shape = &shape1.value;
SetShape({{X1, shape}, {X2, shape}, {X3, shape}, {X4, shape}, {X5, shape}});

CreatePlan();

// check allocation kind:
CheckAllocKind(X1, AllocKind::kPreExisting);
CheckAllocKind(X2, AllocKind::kAllocatedExternally);
CheckAllocKind(X3, AllocKind::kAllocate); // Should not be Reused.
CheckAllocKind(X4, AllocKind::kAllocate);
CheckAllocKind(X5, AllocKind::kAllocateOutput);

// check each ml-value is freed at appropriate step
// X2 will not be reused. X3 will be allocated and will be freed.
CheckFreed(0, {});
CheckFreed(1, {X2});
CheckFreed(2, {X3});
CheckFreed(3, {X4});
}

#ifdef ENABLE_STRIDED_TENSORS
@@ -564,19 +597,19 @@ TEST_F(PlannerTest, MayStridedTest1) {
CheckAllocKind(X3, AllocKind::kAllocateOutput);

// check each ml-value is freed at appropriate step
// X2 will not be reused and will not be freed. X3 will be allocated and will be freed.
// X2 will not be reused because X3 is a graph output. X3 will be allocated and will be freed.
CheckFreed(0, {});
CheckFreed(1, {X1});
CheckFreed(1, {X2});
}

TEST_F(PlannerTest, MayStridedTest2) {
// tensor variables:
std::string X1("X1"), X2("X2"), X3("X3"), X4("X4");

// graph structure:
AddMayStridedOutputNode(X1, X2);
AddMayStridedInputNode(X2, X3);
AddMayStridedInputNode(X2, X4);
AddMayStridedOutputNode(X1, X2); // X2 can reuse X1, and is a strided output.
AddMayStridedInputNode(X2, X3); // X3 is a graph output, cannot reuse.
AddMayStridedInputNode(X2, X4); // X4 is a graph output, cannot reuse.

// simulate shape-inference results:
Shape shape1{"M", "N"};
@@ -603,8 +636,9 @@ TEST_F(PlannerTest, MayStridedTest3) {
std::string X1("X1"), X2("X2"), X3("X3"), X4("X4");

// graph structure:
AddMayStridedOutputNode(X1, X2);
AddMayStridedInputNode(X2, X3);
AddMayStridedOutputNode(X1, X2);  // X2 cannot reuse X1 as a strided tensor because one of X2's
                                  // consumers does not support strided input, so X2 is allocated.
AddMayStridedInputNode(X2, X3); // X3 is a graph output, cannot reuse.
AddNormalNode(X2, X4);

// simulate shape-inference results:
@@ -620,11 +654,27 @@ TEST_F(PlannerTest, MayStridedTest3) {
CheckAllocKind(X3, AllocKind::kAllocateOutput);
CheckAllocKind(X4, AllocKind::kAllocateOutput);

// Note: the last two nodes added can run in two different orders; we need to figure out which
// order was actually planned so we know how to check the free order.
const GraphViewer& graph_viewer = GetState().GetGraphViewer();
// Normal node index is 2.
bool does_normal_node_run_at_last = graph_viewer.GetNodesInTopologicalOrder()[2] == 2;

// check each ml-value is freed at the appropriate step
// X2 will not be reused; it is freed once its last consumer has executed.

CheckFreed(0, {});
CheckFreed(1, {X1});
CheckFreed(2, {});

if (does_normal_node_run_at_last) {
// The normal node has node index 2, but it may be executed after the strided node.
// In that case X2 is released once the normal node has executed.
CheckFreed(1, {});
CheckFreed(2, {X2});
} else {
// The normal node has node index 2 and is executed before the strided node,
// so X2 is released after the strided node (node index 1) has executed.
CheckFreed(2, {});
CheckFreed(1, {X2});
}
}
#endif

@@ -637,7 +687,7 @@ TEST_F(PlannerTest, InPlaceSizeMismatchTest) {
// graph structure:
AddNormalNode(X1, X2); // no in-place operator; X1: input; X2: temporary
AddInplaceNode(X2, X3); // may-in-place operator; X3: temporary
AddNormalNode(X3, X4); // no in-place operator; X4: temporary
AddNormalNode(X3, X4); // no in-place operator; X4: temporary (reuse X2)
AddInplaceNode(X4, X5); // may-in-place operator; X5 output

// simulate shape-inference results:
@@ -659,8 +709,8 @@ TEST_F(PlannerTest, InPlaceSizeMismatchTest) {
// check each ml-value is freed at appropriate step
CheckFreed(0, {});
CheckFreed(1, {});
CheckFreed(2, {X2});
CheckFreed(3, {X1});
CheckFreed(2, {X3});
CheckFreed(3, {X2});
}

// Test operator<< to output details of an allocation & execution plan.
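
The test helper change above relies on `node_release_list` entries being indices into the plan's `release_actions` rather than OrtValue ids directly. Below is a simplified, self-contained model of that indirection; `ToyReleaseAction` and `FreedValuesAtStep` are illustrative names, not the actual plan types:

```cpp
#include <cstddef>
#include <set>
#include <vector>

// Simplified model: node_release_list[step] holds indices into release_actions,
// and each release action records the OrtValue id it frees.
struct ToyReleaseAction {
  size_t value_index;  // OrtValue id to release
};

std::set<int> FreedValuesAtStep(const std::vector<std::vector<size_t>>& node_release_list,
                                const std::vector<ToyReleaseAction>& release_actions,
                                size_t step) {
  std::set<int> freed;
  for (size_t action_idx : node_release_list[step]) {
    freed.insert(static_cast<int>(release_actions[action_idx].value_index));
  }
  return freed;
}
```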
6 changes: 3 additions & 3 deletions orttraining/orttraining/core/agent/training_agent.cc
@@ -60,7 +60,7 @@ TrainingAgent::TrainingAgent(InferenceSession& session,

TrainingAgent::~TrainingAgent() = default;

common::Status TrainingAgent::RunForward(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
common::Status TrainingAgent::RunForward(std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state, const OrtValueCachePtr& cache) {
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
inference_session_.GetMemoryProfiler().GetMemoryInfo().SetIteration(profile_step_);
@@ -74,15 +74,15 @@ common::Status TrainingAgent::RunForward(const std::vector<OrtValue>& feeds, std
return RunCore(feeds, fetches, state, *fw_feeds_fetches_manager_, cache, partial_graph_index);
}

common::Status TrainingAgent::RunBackward(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
common::Status TrainingAgent::RunBackward(std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state) {
state.SetProgramCounterStart(fw_program_counter_end_);
state.SetProgramCounterEnd(bw_program_counter_end_);
constexpr int32_t partial_graph_index = 1;
return RunCore(feeds, fetches, state, *bw_feeds_fetches_manager_, nullptr, partial_graph_index);
}

common::Status TrainingAgent::RunCore(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
common::Status TrainingAgent::RunCore(std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state, FeedsFetchesManager& feeds_fetches_manager,
const OrtValueCachePtr& cache, int32_t partial_graph_index) {
auto fetches_size = feeds_fetches_manager.GetFeedsFetchesInfo().output_names.size();