diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 0d18b639b54ed..4e79618da8e96 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -144,7 +144,13 @@ class PlannerImpl { struct OrtValueInfo { const onnxruntime::NodeArg* p_def_site; // the (unique) NodeArg corresponding to the MLValue int usecount = 0; // static reference-count - OrtValueIndex reused_buffer_index; // index of original buffer to reuse + + // This is initialized to -1 to ensure that if ProcessDef is somehow not called, planning + // will fail more cleanly. This is also used as a temporary workaround to detect the + // case that the DML provider has removed initializers from the graph during partitioning. + // Removing initializers is a temporary measure needed to limit the number of copies of + // tensors in GPU memory. + OrtValueIndex reused_buffer_index = -1; // index of original buffer to reuse }; // ort_value_info_ is indexed by an OrtValueIndex @@ -177,6 +183,12 @@ class PlannerImpl { } int& UseCount(const OrtValueName& name) { return UseCount(Index(name)); } + int DecrementUseCount(OrtValueIndex n) { + int& use_count = --UseCount(n); + assert(use_count >= 0); + return use_count; + } + OrtValueIndex& Buffer(OrtValueIndex n) { ORT_ENFORCE(n >= 0 && static_cast<size_t>(n) < ort_value_info_.size()); return ort_value_info_[n].reused_buffer_index; @@ -643,7 +655,9 @@ class PlannerImpl { if (node_input->Exists()) { auto& sym = node_input->Name(); auto original = Buffer(Index(sym)); - if (0 == --UseCount(original)) + // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. + // See comments in the OrtValueInfo definition. 
+ if ((original != -1) && (0 == DecrementUseCount(original))) freelist_.push_front(FreeBufferInfo(original, program_counter)); } } @@ -652,7 +666,9 @@ class PlannerImpl { if (node_input->Exists()) { auto& sym = node_input->Name(); auto original = Buffer(Index(sym)); - if (0 == --UseCount(original)) + // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. + // See comments in the OrtValueInfo definition. + if ((original != -1) && (0 == DecrementUseCount(original))) freelist_.push_front(FreeBufferInfo(original, program_counter)); } } @@ -662,7 +678,7 @@ class PlannerImpl { if (node_output->Exists()) { auto& sym = node_output->Name(); auto original = Buffer(Index(sym)); - if (0 == --UseCount(original)) + if (0 == DecrementUseCount(original)) freelist_.push_front(FreeBufferInfo(original, program_counter)); } }