diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index ef5cd49334133..804f4557fd321 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1726,7 +1726,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi #if !defined(ORT_MINIMAL_BUILD) // Build and verify node connection (edges). // Verify NodeArg name/type/shape matching correctly. - common::Status BuildConnections(std::unordered_set& outer_scope_node_args_consumed); + common::Status BuildConnections(std::unordered_set& outer_scope_node_args_consumed, + bool& removed_node_with_subgraph); common::Status VerifyNoDuplicateName(); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index e9bc83c25ff4f..dd3eb59b7fafb 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1344,7 +1344,12 @@ Graph::Graph(const Model& owning_model, ORT_THROW("This is an invalid model. Tensor does not have type information."); } - if (utils::HasDataType(tensor) && (tensor.data_type() < TensorProto_DataType_DataType_ARRAYSIZE)) { + // all initializers must have a valid data type + if (!utils::HasDataType(tensor) || !tensor.DataType_IsValid(tensor.data_type())) { + ORT_THROW("This is an invalid model. Tensor '", tensor.name(), "' does not have valid data type."); + } + + if ((tensor.data_type() < TensorProto_DataType_DataType_ARRAYSIZE)) { weight_data_type_freq_[tensor.data_type()]++; } @@ -1669,13 +1674,14 @@ void Graph::RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int s #if !defined(ORT_MINIMAL_BUILD) GSL_SUPPRESS(es .84) // ignoring return value from unordered_map::insert causes noisy complaint -Status Graph::BuildConnections(std::unordered_set& outer_scope_node_args_consumed) { +Status Graph::BuildConnections(std::unordered_set& outer_scope_node_args_consumed, + bool& removed_node_with_subgraph) { // recurse into subgraphs first so we can update any nodes in this graph that are used by those subgraphs if (!resolve_context_.nodes_with_subgraphs.empty()) { for (auto* node : resolve_context_.nodes_with_subgraphs) { for (auto& subgraph : node->MutableSubgraphs()) { std::unordered_set node_args_consumed; - ORT_RETURN_IF_ERROR(subgraph->BuildConnections(node_args_consumed)); + ORT_RETURN_IF_ERROR(subgraph->BuildConnections(node_args_consumed, removed_node_with_subgraph)); for (auto& node_arg_name : node_args_consumed) { auto node_arg = GetNodeArg(node_arg_name); @@ -1805,6 +1811,10 @@ Status Graph::BuildConnections(std::unordered_set& outer_scope_node } else if (node.OutputDefs().empty()) { // This is a useless node. // It has no input/output. + if (node.ContainsSubgraph()) { + removed_node_with_subgraph = true; + } + RemoveNode(node.Index()); } } @@ -3683,10 +3693,17 @@ Status Graph::Resolve(const ResolveOptions& options) { std::unordered_set outer_scope_node_args_consumed; // recursively build connections between nodes in this graph and all subgraphs - ORT_RETURN_IF_ERROR(BuildConnections(outer_scope_node_args_consumed)); + bool removed_node_with_subgraph = false; + ORT_RETURN_IF_ERROR(BuildConnections(outer_scope_node_args_consumed, removed_node_with_subgraph)); ORT_ENFORCE(outer_scope_node_args_consumed.empty(), "Shouldn't be possible to have NodeArgs that haven't been handled already."); + // if we removed any nodes with subgraphs, we need to refresh the list of subgraphs. + if (removed_node_with_subgraph) { + all_subgraphs.clear(); + FindAllSubgraphs(all_subgraphs); + } + // topological sort of this and any subgraphs is non-recursive auto topo_sort_func = [](Graph& graph) { return graph.PerformTopologicalSortAndCheckIsAcyclic(); }; ORT_RETURN_IF_ERROR(ForThisAndAllSubgraphs(all_subgraphs, topo_sort_func)); diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 612497ccfd845..500d44074876b 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -1257,12 +1257,8 @@ common::Status CreateCustomRegistry(gsl::span op_domai std::unordered_map> domain_kernels; for (const auto* op : domain->custom_ops_) { // define kernel - auto it = domain_kernels.find(op->GetName(op)); - if (it == domain_kernels.end()) { - domain_kernels[op->GetName(op)] = {op}; - } else { - domain_kernels[op->GetName(op)].push_back(op); - } + const auto* name = op->GetName(op); + domain_kernels[name].push_back(op); } // Creation of the schemas, one per unique name. @@ -1276,7 +1272,8 @@ common::Status CreateCustomRegistry(gsl::span op_domai for (const auto* op : ops) { // define kernel auto kernel_create_info = CreateKernelCreateInfo(domain->domain_, op); - kernel_def_map[op->GetName(op)].push_back(kernel_create_info.kernel_def.get()); + const auto* op_name = op->GetName(op); + kernel_def_map[op_name].push_back(kernel_create_info.kernel_def.get()); ORT_RETURN_IF_ERROR(output->RegisterCustomKernel(kernel_create_info)); // If IsCompatible returns false, then all custom operators named // 'op->GetName(op)' are not compatible among themselves. diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 8b66009c0c72f..07f2cc8581ed5 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -255,8 +255,7 @@ void RunModel(InferenceSession& session_object, if (is_preallocate_output_vec) { fetches.resize(output_names.size()); for (auto& elem : fetches) { - CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x, - &elem); + AllocateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dims_mul_x, &elem); } } @@ -3069,5 +3068,26 @@ TEST(InferenceSessionTests, InterThreadPoolWithDenormalAsZero) { } #endif +TEST(InferenceSessionTests, BadDataTypeInInitializerIsHandled) { + // model has an initializer with a bogus data type. Graph ctor should detect and throw. + auto model_uri = ORT_TSTR("testdata/icm-31000000518082.onnx"); + + SessionOptions so; + so.session_logid = "TempTest.LoadModel"; + InferenceSession session{so, GetEnvironment()}; + ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Load(model_uri), "does not have valid data type"); +} + +TEST(InferenceSessionTests, GraphResolveHandlesNodeWithSubgraphBeingRemoved) { + // model has a subgraph with output that is not consumed. the node with the subgraph should get removed in + // Graph::BuildConnections and Graph::Resolve should adjust its list of subgraphs to not access the removed subgraph. + auto model_uri = ORT_TSTR("testdata/icm-31000000518483.onnx"); + + SessionOptions so; + so.session_logid = "TempTest.LoadModel"; + InferenceSession session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/icm-31000000518082.onnx b/onnxruntime/test/testdata/icm-31000000518082.onnx new file mode 100644 index 0000000000000..6578223ea8b3b Binary files /dev/null and b/onnxruntime/test/testdata/icm-31000000518082.onnx differ diff --git a/onnxruntime/test/testdata/icm-31000000518483.onnx b/onnxruntime/test/testdata/icm-31000000518483.onnx new file mode 100644 index 0000000000000..9c5caa79bdb6a Binary files /dev/null and b/onnxruntime/test/testdata/icm-31000000518483.onnx differ diff --git a/onnxruntime/test/testdata/icm-31000000518483.py b/onnxruntime/test/testdata/icm-31000000518483.py new file mode 100644 index 0000000000000..38ff5926693cd --- /dev/null +++ b/onnxruntime/test/testdata/icm-31000000518483.py @@ -0,0 +1,53 @@ +from onnx import TensorProto, helper, save_model + +# Add node with a subgraph that has no inputs or outputs. +# Graph::BuildConnections should remove and the list of subgraphs in Graph::Resolve should be updated. +# Other details here don't matter. Copied from ort_github_issue_10305.py +if_body = helper.make_graph( + [ + # need to use main_graph_initializer in a way that can't be constant folded + helper.make_node("Constant", inputs=[], outputs=["zero"], name="Constant", value_int=0), + ], + "if_branch_body", + [ + # no explicit inputs + ], + [ + helper.make_tensor_value_info("zero", TensorProto.BOOL, [1]), + ], +) + +# Create the main graph +graph_proto = helper.make_graph( + [ + # add a Transpose that can be moved past the Slice + helper.make_node( + "Transpose", + inputs=["input:0"], + outputs=["transpose:0"], + name="transpose0", + perm=[1, 0, 2], + ), + helper.make_node( + "If", + [], + [], + "If1", + then_branch=if_body, + else_branch=if_body, + ), + ], + "Main_graph", + [ + helper.make_tensor_value_info("input:0", TensorProto.FLOAT, [2, 2, 2]), + helper.make_tensor_value_info("state_var_in", TensorProto.FLOAT, [1]), + ], + [ + helper.make_tensor_value_info("transpose:0", TensorProto.FLOAT, [2, 2]), + ], +) + +model = helper.make_model(graph_proto) +# model to repro issue is invalid. can't run checker. +# onnx.checker.check_model(model, True) +save_model(model, "icm-31000000518483.onnx")