Skip to content

Commit

Permalink
Handle missing initializers in allocation planner to fix crashes with…
Browse files Browse the repository at this point in the history
… DML provider (#5244)

* Fix memory planning bug with DML EP

* Address PR comments

* Fix typo
  • Loading branch information
jeffbloo authored and tianleiwu committed Sep 23, 2020
1 parent b648fe5 commit 389cca7
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions onnxruntime/core/framework/allocation_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,13 @@ class PlannerImpl {
// Planner bookkeeping kept for a single OrtValue (one entry per OrtValueIndex).
struct OrtValueInfo {
const onnxruntime::NodeArg* p_def_site; // the (unique) NodeArg corresponding to the MLValue
int usecount = 0; // static reference-count

// Index of the original buffer to reuse.
//
// This is initialized to -1 to ensure that if ProcessDef is somehow not called, planning
// will fail more cleanly. This is also used as a temporary workaround to detect the
// case that the DML provider has removed initializers from the graph during partitioning.
// Removing initializers is a temporary measure needed to limit the number of copies of
// tensors in GPU memory.
//
// NOTE: this member must be declared exactly once; a second, uninitialized declaration
// would be a redeclaration error and would defeat the -1 sentinel.
OrtValueIndex reused_buffer_index = -1;
};

// ort_value_info_ is indexed by an OrtValueIndex
Expand Down Expand Up @@ -177,6 +183,12 @@ class PlannerImpl {
}
int& UseCount(const OrtValueName& name) { return UseCount(Index(name)); }

// Drops one reference from the use count stored for OrtValue index `n` and
// returns the count that remains. The assert flags a count that has gone
// negative (i.e. more decrements than recorded uses) in builds where assert is active.
int DecrementUseCount(OrtValueIndex n) {
int& remaining = UseCount(n);
--remaining;
assert(remaining >= 0);
return remaining;
}

OrtValueIndex& Buffer(OrtValueIndex n) {
ORT_ENFORCE(n >= 0 && static_cast<size_t>(n) < ort_value_info_.size());
return ort_value_info_[n].reused_buffer_index;
Expand Down Expand Up @@ -643,7 +655,9 @@ class PlannerImpl {
if (node_input->Exists()) {
auto& sym = node_input->Name();
auto original = Buffer(Index(sym));
if (0 == --UseCount(original))
// The index will be -1 if it's an initializer that was removed as part of a temporary workaround.
// See comments in the OrtValueInfo definition.
if ((original != -1) && (0 == DecrementUseCount(original)))
freelist_.push_front(FreeBufferInfo(original, program_counter));
}
}
Expand All @@ -652,7 +666,9 @@ class PlannerImpl {
if (node_input->Exists()) {
auto& sym = node_input->Name();
auto original = Buffer(Index(sym));
if (0 == --UseCount(original))
// The index will be -1 if it's an initializer that was removed as part of a temporary workaround.
// See comments in the OrtValueInfo definition.
if ((original != -1) && (0 == DecrementUseCount(original)))
freelist_.push_front(FreeBufferInfo(original, program_counter));
}
}
Expand All @@ -662,7 +678,7 @@ class PlannerImpl {
if (node_output->Exists()) {
auto& sym = node_output->Name();
auto original = Buffer(Index(sym));
if (0 == --UseCount(original))
if (0 == DecrementUseCount(original))
freelist_.push_front(FreeBufferInfo(original, program_counter));
}
}
Expand Down

0 comments on commit 389cca7

Please sign in to comment.