Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -1726,7 +1726,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
#if !defined(ORT_MINIMAL_BUILD)
// Build and verify node connection (edges).
// Verify NodeArg name/type/shape matching correctly.
common::Status BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed);
common::Status BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed,
bool& removed_node_with_subgraph);

common::Status VerifyNoDuplicateName();

Expand Down
25 changes: 21 additions & 4 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1344,7 +1344,12 @@ Graph::Graph(const Model& owning_model,
ORT_THROW("This is an invalid model. Tensor does not have type information.");
}

if (utils::HasDataType(tensor) && (tensor.data_type() < TensorProto_DataType_DataType_ARRAYSIZE)) {
// all initializers must have a valid data type
if (!utils::HasDataType(tensor) || !tensor.DataType_IsValid(tensor.data_type())) {
ORT_THROW("This is an invalid model. Tensor '", tensor.name(), "' does not have valid data type.");
}

if ((tensor.data_type() < TensorProto_DataType_DataType_ARRAYSIZE)) {
weight_data_type_freq_[tensor.data_type()]++;
}

Expand Down Expand Up @@ -1669,13 +1674,14 @@ void Graph::RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int s

#if !defined(ORT_MINIMAL_BUILD)
GSL_SUPPRESS(es .84) // ignoring return value from unordered_map::insert causes noisy complaint
Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed) {
Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed,
bool& removed_node_with_subgraph) {
// recurse into subgraphs first so we can update any nodes in this graph that are used by those subgraphs
if (!resolve_context_.nodes_with_subgraphs.empty()) {
for (auto* node : resolve_context_.nodes_with_subgraphs) {
for (auto& subgraph : node->MutableSubgraphs()) {
std::unordered_set<std::string> node_args_consumed;
ORT_RETURN_IF_ERROR(subgraph->BuildConnections(node_args_consumed));
ORT_RETURN_IF_ERROR(subgraph->BuildConnections(node_args_consumed, removed_node_with_subgraph));

for (auto& node_arg_name : node_args_consumed) {
auto node_arg = GetNodeArg(node_arg_name);
Expand Down Expand Up @@ -1805,6 +1811,10 @@ Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node
} else if (node.OutputDefs().empty()) {
// This is a useless node.
// It has no input/output.
if (node.ContainsSubgraph()) {
removed_node_with_subgraph = true;
}

RemoveNode(node.Index());
}
}
Expand Down Expand Up @@ -3683,10 +3693,17 @@ Status Graph::Resolve(const ResolveOptions& options) {
std::unordered_set<std::string> outer_scope_node_args_consumed;

// recursively build connections between nodes in this graph and all subgraphs
ORT_RETURN_IF_ERROR(BuildConnections(outer_scope_node_args_consumed));
bool removed_node_with_subgraph = false;
ORT_RETURN_IF_ERROR(BuildConnections(outer_scope_node_args_consumed, removed_node_with_subgraph));
ORT_ENFORCE(outer_scope_node_args_consumed.empty(),
"Shouldn't be possible to have NodeArgs that haven't been handled already.");

// if we removed any nodes with subgraphs, we need to refresh the list of subgraphs.
if (removed_node_with_subgraph) {
all_subgraphs.clear();
FindAllSubgraphs(all_subgraphs);
}

// topological sort of this and any subgraphs is non-recursive
auto topo_sort_func = [](Graph& graph) { return graph.PerformTopologicalSortAndCheckIsAcyclic(); };
ORT_RETURN_IF_ERROR(ForThisAndAllSubgraphs(all_subgraphs, topo_sort_func));
Expand Down
11 changes: 4 additions & 7 deletions onnxruntime/core/session/custom_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1257,12 +1257,8 @@ common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domai
std::unordered_map<std::string, std::vector<const OrtCustomOp*>> domain_kernels;
for (const auto* op : domain->custom_ops_) {
// define kernel
auto it = domain_kernels.find(op->GetName(op));
if (it == domain_kernels.end()) {
domain_kernels[op->GetName(op)] = {op};
} else {
domain_kernels[op->GetName(op)].push_back(op);
}
const auto* name = op->GetName(op);
domain_kernels[name].push_back(op);
}

// Creation of the schemas, one per unique name.
Expand All @@ -1276,7 +1272,8 @@ common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domai
for (const auto* op : ops) {
// define kernel
auto kernel_create_info = CreateKernelCreateInfo(domain->domain_, op);
kernel_def_map[op->GetName(op)].push_back(kernel_create_info.kernel_def.get());
const auto* op_name = op->GetName(op);
kernel_def_map[op_name].push_back(kernel_create_info.kernel_def.get());
ORT_RETURN_IF_ERROR(output->RegisterCustomKernel(kernel_create_info));
// If IsCompatible returns false, then all custom operators named
// 'op->GetName(op)' are not compatible among themselves.
Expand Down
24 changes: 22 additions & 2 deletions onnxruntime/test/framework/inference_session_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,7 @@ void RunModel(InferenceSession& session_object,
if (is_preallocate_output_vec) {
fetches.resize(output_names.size());
for (auto& elem : fetches) {
CreateMLValue<float>(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x,
&elem);
AllocateMLValue<float>(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dims_mul_x, &elem);
}
}

Expand Down Expand Up @@ -3069,5 +3068,26 @@ TEST(InferenceSessionTests, InterThreadPoolWithDenormalAsZero) {
}
#endif

TEST(InferenceSessionTests, BadDataTypeInInitializerIsHandled) {
// model has an initializer with a bogus data type. Graph ctor should detect and throw.
auto model_uri = ORT_TSTR("testdata/icm-31000000518082.onnx");

SessionOptions so;
so.session_logid = "TempTest.LoadModel";
InferenceSession session{so, GetEnvironment()};
ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Load(model_uri), "does not have valid data type");
}

TEST(InferenceSessionTests, GraphResolveHandlesNodeWithSubgraphBeingRemoved) {
// model has a subgraph with output that is not consumed. the node with the subgraph should get removed in
// Graph::BuildConnections and Graph::Resolve should adjust its list of subgraphs to not access the removed subgraph.
auto model_uri = ORT_TSTR("testdata/icm-31000000518483.onnx");

SessionOptions so;
so.session_logid = "TempTest.LoadModel";
InferenceSession session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.Load(model_uri));
}

} // namespace test
} // namespace onnxruntime
Binary file added onnxruntime/test/testdata/icm-31000000518082.onnx
Binary file not shown.
Binary file added onnxruntime/test/testdata/icm-31000000518483.onnx
Binary file not shown.
53 changes: 53 additions & 0 deletions onnxruntime/test/testdata/icm-31000000518483.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from onnx import TensorProto, helper, save_model

# Add node with a subgraph that has no inputs or outputs.
# Graph::BuildConnections should remove and the list of subgraphs in Graph::Resolve should be updated.
# Other details here don't matter. Copied from ort_github_issue_10305.py
if_body = helper.make_graph(
[
# need to use main_graph_initializer in a way that can't be constant folded
helper.make_node("Constant", inputs=[], outputs=["zero"], name="Constant", value_int=0),
],
"if_branch_body",
[
# no explicit inputs
],
[
helper.make_tensor_value_info("zero", TensorProto.BOOL, [1]),
],
)

# Create the main graph
graph_proto = helper.make_graph(
[
# add a Transpose that can be moved past the Slice
helper.make_node(
"Transpose",
inputs=["input:0"],
outputs=["transpose:0"],
name="transpose0",
perm=[1, 0, 2],
),
helper.make_node(
"If",
[],
[],
"If1",
then_branch=if_body,
else_branch=if_body,
),
],
"Main_graph",
[
helper.make_tensor_value_info("input:0", TensorProto.FLOAT, [2, 2, 2]),
helper.make_tensor_value_info("state_var_in", TensorProto.FLOAT, [1]),
],
[
helper.make_tensor_value_info("transpose:0", TensorProto.FLOAT, [2, 2]),
],
)

model = helper.make_model(graph_proto)
# model to repro issue is invalid. can't run checker.
# onnx.checker.check_model(model, True)
save_model(model, "icm-31000000518483.onnx")
Loading