diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 3468e2e55c7b6..dd0ed75f4782f 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -1016,37 +1016,29 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); input_names_[i] = input_names_str_[i].c_str(); } - auto transform_fcn = std::function(); - auto new_value = std::function&, Ort::ConstTensorTypeAndShapeInfo&)>(); - if (device_memory_name_.empty()) { - transform_fcn = [](int64_t input) { return input; }; - new_value = [](OrtAllocator*, const std::vector&, Ort::ConstTensorTypeAndShapeInfo&) { - return Ort::Value(nullptr); - }; - } else { + if (!device_memory_name_.empty()) { + Ort::MemoryInfo memory_info(nullptr); // Default initialize, will be overwritten if (device_memory_name_ == CUDA) { - memory_info = Ort::MemoryInfo(device_memory_name_.data(), OrtArenaAllocator, 0, OrtMemTypeDefault); + memory_info = Ort::MemoryInfo(device_memory_name_.data(), OrtArenaAllocator, 0, OrtMemTypeDefault); } else { - memory_info = Ort::MemoryInfo(device_memory_name_.data(), OrtArenaAllocator, 0, OrtMemTypeCPUOutput); + memory_info = Ort::MemoryInfo(device_memory_name_.data(), OrtArenaAllocator, 0, OrtMemTypeCPUOutput); } custom_allocator_ = Ort::Allocator(session_, memory_info); - // Switch to custom + // Switch to custom allocator allocator_ = Ort::UnownedAllocator(custom_allocator_); - - // free dimensions are treated as 1 if not overridden - transform_fcn = [](int64_t input) { return (input == -1) ? -input : input; }; - new_value = [](OrtAllocator* allocator, const std::vector& output_shape, Ort::ConstTensorTypeAndShapeInfo& tensor_info) { - return Ort::Value::CreateTensor(allocator, output_shape.data(), output_shape.size(), tensor_info.GetElementType()); - }; } - for (size_t i = 0; i < output_names_raw_ptr.size(); i++) { - Ort::TypeInfo type_info = session_.GetOutputTypeInfo(i); - auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); - std::vector output_shape = tensor_info.GetShape(); - std::transform(output_shape.begin(), output_shape.end(), output_shape.begin(), transform_fcn); - outputs_.emplace_back(new_value(allocator_, output_shape, tensor_info)); + Ort::TypeInfo type_info = session_.GetOutputTypeInfo(i); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + std::vector output_shape = tensor_info.GetShape(); + auto is_dynamic = std::find(output_shape.begin(), output_shape.end(), -1) != output_shape.end(); + if (is_dynamic || device_memory_name_.empty()) { + outputs_.emplace_back(Ort::Value(nullptr)); + } else { + auto new_value = Ort::Value::CreateTensor(allocator_, output_shape.data(), output_shape.size(), tensor_info.GetElementType()); + outputs_.emplace_back(std::move(new_value)); + } } }