From 161eb1386340c37b814672308ae73f6c06678f4d Mon Sep 17 00:00:00 2001 From: amancini-N Date: Tue, 19 Nov 2024 12:26:35 +0000 Subject: [PATCH] #22890 Fix profiling on empty Optional --- .../core/framework/sequential_executor.cc | 4 +- .../test/framework/inference_session_test.cc | 41 ++++++++++++++++++ .../test/testdata/relu_with_optional.onnx | Bin 0 -> 722 bytes 3 files changed, 43 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/test/testdata/relu_with_optional.onnx diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc index 6ea12c7f3336b..2185b8332b9cf 100644 --- a/onnxruntime/core/framework/sequential_executor.cc +++ b/onnxruntime/core/framework/sequential_executor.cc @@ -68,7 +68,7 @@ static void CalculateTotalOutputSizes(OpKernelContextInternal* op_kernel_context int output_count = op_kernel_context->OutputCount(); for (auto i = 0; i < output_count; i++) { const OrtValue* p_output = op_kernel_context->GetOutputMLValue(i); - if (p_output != nullptr && p_output->IsTensor()) { + if (p_output != nullptr && p_output->IsTensor() && p_output->IsAllocated()) { const auto& tensor = p_output->Get(); size_t tensor_size = tensor.SizeInBytes(); #if defined(TRACE_EXECUTION) @@ -104,7 +104,7 @@ static void CalculateTotalInputSizes(const OpKernelContextInternal* op_kernel_co const int input_count = op_kernel_context->InputCount(); for (auto i = 0; i < input_count; i++) { const OrtValue* p_input = op_kernel_context->GetInputMLValue(i); - if (p_input != nullptr && p_input->IsTensor()) { + if (p_input != nullptr && p_input->IsTensor() && p_input->IsAllocated()) { const OpKernelInfo& op_kernel_info = p_op_kernel->Info(); const Tensor* p_tensor = nullptr; bool is_param = op_kernel_info.TryGetConstantInput(i, &p_tensor); diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index c6a81e8a1c1ad..7f4616c964e33 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -818,6 +818,47 @@ TEST(InferenceSessionTests, CheckRunProfilerStartTime) { ASSERT_TRUE(before_start_time <= profiling_start_time && profiling_start_time <= after_start_time); } +TEST(InferenceSessionTests, CheckRunProfilerWithOptionalValues) { + // Test whether the profiler can work on model with optional values + SessionOptions so; + + so.session_logid = "CheckRunProfiler"; + so.enable_profiling = true; + so.profile_file_prefix = ORT_TSTR("onnxprofile_profile_test"); + + InferenceSession session_object(so, GetEnvironment()); + ASSERT_STATUS_OK(session_object.Load(ORT_TSTR("testdata/relu_with_optional.onnx"))); + ASSERT_STATUS_OK(session_object.Initialize()); + + RunOptions run_options; + run_options.run_tag = "RunTag"; + + // prepare inputs + std::vector dims_x = {1}; + std::vector values_x = {-4}; + OrtValue ml_value; + CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dims_x, values_x, &ml_value); + NameMLValMap feeds; + feeds.insert(std::make_pair("input", ml_value)); + + // prepare outputs + std::vector output_names; + output_names.push_back("output"); + std::vector fetches; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {1}; + std::vector expected_values_y = {0}; + + // Now run + common::Status st = session_object.Run(run_options, feeds, output_names, &fetches); + if (!st.IsOK()) { + std::cout << "Run returned status: " << st.ErrorMessage() << std::endl; + } + ASSERT_TRUE(st.IsOK()); + VerifyOutputs(fetches.at(0).Get(), expected_dims_y, expected_values_y); +} + TEST(InferenceSessionTests, MultipleSessionsNoTimeout) { SessionOptions session_options; diff --git a/onnxruntime/test/testdata/relu_with_optional.onnx b/onnxruntime/test/testdata/relu_with_optional.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b52c6927527bda7a5f183120fc3f3baed5e30088 GIT binary patch literal 722 zcmZva-%i3X6vi2JI1hgURWF2_Wj7`hZ@4rjCK!WoV=!JAW3qHupvl%QYlp$7;bnXb zAHjEU8;mjWuKmu}p7ZNh1Ntvd)Qe2!2Ojk)S41(v%}6pnHy?v!L%iN@^+4qP#+PeD zs&?9FW{QA@M2S#d{~)-;=#z%RGVTPIig0D0gT5qQMD2#5wW_vO8l;Z3hTsNt#)C*f zs^B)0PP%sjA(>)J+BuZ0qhUyiq9J^1a9GB5#^a(--&t(3%A=V*g&OrFb=;8f>7qY@ z9zLdipjM6ulx1Yh_8B0xidg^*qlyuf4Q;J8FiH-)UC7Vi=C}?8mu^p>rWTZEO>UvP zWy40?#rX(Hkhj7p@wy$@)&98>E}+1ID+?kmoQ<0Y178{>HCz>uieEE@ek($c;R@^1 zB$B;kI$1K!ESZfZj%G`J0viSYp?xmW4i0CXQNvf3#{nOSO5S3 literal 0 HcmV?d00001