Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -1726,7 +1726,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
#if !defined(ORT_MINIMAL_BUILD)
// Build and verify node connection (edges).
// Verify NodeArg name/type/shape matching correctly.
common::Status BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed);
common::Status BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed,
bool& removed_node_with_subgraph);

common::Status VerifyNoDuplicateName();

Expand Down
25 changes: 21 additions & 4 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1344,7 +1344,12 @@ Graph::Graph(const Model& owning_model,
ORT_THROW("This is an invalid model. Tensor does not have type information.");
}

if (utils::HasDataType(tensor) && (tensor.data_type() < TensorProto_DataType_DataType_ARRAYSIZE)) {
// all initializers must have a valid data type
if (!utils::HasDataType(tensor) || !tensor.DataType_IsValid(tensor.data_type())) {
ORT_THROW("This is an invalid model. Tensor '", tensor.name(), "' does not have valid data type.");
}

if ((tensor.data_type() < TensorProto_DataType_DataType_ARRAYSIZE)) {
weight_data_type_freq_[tensor.data_type()]++;
}

Expand Down Expand Up @@ -1669,13 +1674,14 @@ void Graph::RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int s

#if !defined(ORT_MINIMAL_BUILD)
GSL_SUPPRESS(es .84) // ignoring return value from unordered_map::insert causes noisy complaint
Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed) {
Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed,
bool& removed_node_with_subgraph) {
// recurse into subgraphs first so we can update any nodes in this graph that are used by those subgraphs
if (!resolve_context_.nodes_with_subgraphs.empty()) {
for (auto* node : resolve_context_.nodes_with_subgraphs) {
for (auto& subgraph : node->MutableSubgraphs()) {
std::unordered_set<std::string> node_args_consumed;
ORT_RETURN_IF_ERROR(subgraph->BuildConnections(node_args_consumed));
ORT_RETURN_IF_ERROR(subgraph->BuildConnections(node_args_consumed, removed_node_with_subgraph));

for (auto& node_arg_name : node_args_consumed) {
auto node_arg = GetNodeArg(node_arg_name);
Expand Down Expand Up @@ -1805,6 +1811,10 @@ Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node
} else if (node.OutputDefs().empty()) {
// This is a useless node.
// It has no input/output.
if (node.ContainsSubgraph()) {
removed_node_with_subgraph = true;
}

RemoveNode(node.Index());
}
}
Expand Down Expand Up @@ -3683,10 +3693,17 @@ Status Graph::Resolve(const ResolveOptions& options) {
std::unordered_set<std::string> outer_scope_node_args_consumed;

// recursively build connections between nodes in this graph and all subgraphs
ORT_RETURN_IF_ERROR(BuildConnections(outer_scope_node_args_consumed));
bool removed_node_with_subgraph = false;
ORT_RETURN_IF_ERROR(BuildConnections(outer_scope_node_args_consumed, removed_node_with_subgraph));
ORT_ENFORCE(outer_scope_node_args_consumed.empty(),
"Shouldn't be possible to have NodeArgs that haven't been handled already.");

// if we removed any nodes with subgraphs, we need to refresh the list of subgraphs.
if (removed_node_with_subgraph) {
all_subgraphs.clear();
FindAllSubgraphs(all_subgraphs);
}

// topological sort of this and any subgraphs is non-recursive
auto topo_sort_func = [](Graph& graph) { return graph.PerformTopologicalSortAndCheckIsAcyclic(); };
ORT_RETURN_IF_ERROR(ForThisAndAllSubgraphs(all_subgraphs, topo_sort_func));
Expand Down
11 changes: 4 additions & 7 deletions onnxruntime/core/session/custom_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1257,12 +1257,8 @@ common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domai
std::unordered_map<std::string, std::vector<const OrtCustomOp*>> domain_kernels;
for (const auto* op : domain->custom_ops_) {
// define kernel
auto it = domain_kernels.find(op->GetName(op));
if (it == domain_kernels.end()) {
domain_kernels[op->GetName(op)] = {op};
} else {
domain_kernels[op->GetName(op)].push_back(op);
}
const auto* name = op->GetName(op);
domain_kernels[name].push_back(op);
}

// Creation of the schemas, one per unique name.
Expand All @@ -1276,7 +1272,8 @@ common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domai
for (const auto* op : ops) {
// define kernel
auto kernel_create_info = CreateKernelCreateInfo(domain->domain_, op);
kernel_def_map[op->GetName(op)].push_back(kernel_create_info.kernel_def.get());
const auto* op_name = op->GetName(op);
kernel_def_map[op_name].push_back(kernel_create_info.kernel_def.get());
ORT_RETURN_IF_ERROR(output->RegisterCustomKernel(kernel_create_info));
// If IsCompatible returns false, then all custom operators named
// 'op->GetName(op)' are not compatible among themselves.
Expand Down
21 changes: 21 additions & 0 deletions onnxruntime/test/framework/inference_session_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3069,5 +3069,26 @@ TEST(InferenceSessionTests, InterThreadPoolWithDenormalAsZero) {
}
#endif

TEST(InferenceSessionTests, BadDataTypeInInitializerIsHandled) {
  // The model has an initializer with a bogus data type. The Graph constructor should
  // detect this and throw, which surfaces here as a failed Load whose error message
  // contains the substring asserted below.
  auto model_uri = ORT_TSTR("testdata/msrc-31000000518082.onnx");

  SessionOptions so;
  // Use the test name as the log id (the previous value was a copy/paste placeholder).
  so.session_logid = "InferenceSessionTests.BadDataTypeInInitializerIsHandled";
  InferenceSession session{so, GetEnvironment()};
  ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Load(model_uri), "does not have valid data type");
}

TEST(InferenceSessionTests, GraphResolveHandlesNodeWithSubgraphBeingRemoved) {
  // The model contains a node with a subgraph whose output is not consumed. The node
  // should be removed in Graph::BuildConnections, and Graph::Resolve should refresh its
  // list of subgraphs so the removed subgraph is never accessed. Success is simply a
  // clean Load.
  auto model_uri = ORT_TSTR("testdata/msrc-31000000518483.onnx");

  SessionOptions so;
  // Use the test name as the log id (the previous value was a copy/paste placeholder).
  so.session_logid = "InferenceSessionTests.GraphResolveHandlesNodeWithSubgraphBeingRemoved";
  InferenceSession session{so, GetEnvironment()};
  ASSERT_STATUS_OK(session.Load(model_uri));
}

} // namespace test
} // namespace onnxruntime
Binary file added onnxruntime/test/testdata/msrc-31000000518082.onnx
Binary file not shown.
Binary file added onnxruntime/test/testdata/msrc-31000000518483.onnx
Binary file not shown.
53 changes: 53 additions & 0 deletions onnxruntime/test/testdata/msrc-31000000518483.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import onnx
from onnx import TensorProto, helper

# Builds msrc-31000000518483.onnx: a model containing an If node whose subgraphs have no
# inputs or outputs and whose own input/output lists are empty. Such a node is useless,
# so Graph::BuildConnections should remove it, and the list of subgraphs held by
# Graph::Resolve should be updated so the removed subgraph is not touched afterwards.
# Everything else about the model is incidental (adapted from ort_github_issue_10305.py).

# Branch body shared by both sides of the If node: no inputs, one constant output.
# Its only job is to make the If node carry subgraphs.
constant_node = helper.make_node("Constant", inputs=[], outputs=["zero"], name="Constant", value_int=0)

branch_graph = helper.make_graph(
    [constant_node],
    "if_branch_body",
    [],  # no explicit inputs
    [helper.make_tensor_value_info("zero", TensorProto.BOOL, [1])],
)

# The Transpose gives the main graph a consumed output so the graph itself is valid.
transpose_node = helper.make_node(
    "Transpose",
    inputs=["input:0"],
    outputs=["transpose:0"],
    name="transpose0",
    perm=[1, 0, 2],
)

# The If node has empty input and output lists, making it the useless node that
# Graph::BuildConnections is expected to remove.
if_node = helper.make_node(
    "If",
    [],
    [],
    "If1",
    then_branch=branch_graph,
    else_branch=branch_graph,
)

main_graph = helper.make_graph(
    [transpose_node, if_node],
    "Main_graph",
    [
        helper.make_tensor_value_info("input:0", TensorProto.FLOAT, [2, 2, 2]),
        helper.make_tensor_value_info("state_var_in", TensorProto.FLOAT, [1]),
    ],
    [
        helper.make_tensor_value_info("transpose:0", TensorProto.FLOAT, [2, 2]),
    ],
)

onnx_model = helper.make_model(main_graph)
# The model is intentionally not run through the ONNX checker; it only needs to load far
# enough to exercise the node-removal path.
# onnx.checker.check_model(onnx_model, True)
onnx.save(onnx_model, "msrc-31000000518483.onnx")
Loading