diff --git a/onnx/defs/data_propagators.h b/onnx/defs/data_propagators.h index f87379b44f8..8948953708b 100644 --- a/onnx/defs/data_propagators.h +++ b/onnx/defs/data_propagators.h @@ -33,18 +33,21 @@ inline bool axisIsZero(DataPropagationContext& ctx, bool defaultZero = false) { } } int axis = static_cast(axisAttr->i()); - auto input_data_0 = ctx.getInputData(0); - if (input_data_0 == nullptr) { + if (axis >= 0) { + return axis == 0; + } + // For negative axes, we need rank information to determine if it is equivalent to axis 0 + const TypeProto* type = ctx.getInputType(0); + if ((type == nullptr) || (!type->has_tensor_type()) || (!type->tensor_type().has_shape())) { return false; } - int rank = input_data_0->dim_size(); + + int rank = type->tensor_type().shape().dim_size(); if (axis < -rank || axis >= rank) { fail_shape_inference("axis must be in [-rank, rank-1]."); return false; } - if (axis < 0) { - axis += rank; - } + axis += rank; // Only supports axis = 0 since the data comes from Shape return axis == 0; } diff --git a/onnx/test/data_propagation_test.py b/onnx/test/data_propagation_test.py index 83f7fa7396b..3638c12598b 100644 --- a/onnx/test/data_propagation_test.py +++ b/onnx/test/data_propagation_test.py @@ -212,6 +212,40 @@ def test_shape_arithmetic_with_zero_broadcast(self) -> None: data_prop=True, ) # type: ignore + def test_empty_tensor(self) -> None: + """Test that a Concat with an empty tensor as input is handled correctly by data-propagation.""" + model = onnx.parser.parse_model( + """ + + agraph (float[256] y) => (float[N] z) + + { + z = Concat (x, y) + } + """ + ) + inferred_model = onnx.shape_inference.infer_shapes(model, True, True, True) + output = inferred_model.graph.output[0] + self.assertEqual(output.type.tensor_type.shape.dim[0].dim_value, 256) + + def test_empty_tensor_negative_axis(self) -> None: + """Test that a Concat with an empty tensor as input is handled correctly by data-propagation. + This time with a negative axis. + """ + model = onnx.parser.parse_model( + """ + + agraph (float[256] y) => (float[N] z) + + { + z = Concat (x, y) + } + """ + ) + inferred_model = onnx.shape_inference.infer_shapes(model, True, True, True) + output = inferred_model.graph.output[0] + self.assertEqual(output.type.tensor_type.shape.dim[0].dim_value, 256) + if __name__ == "__main__": unittest.main()