Skip to content

Commit

Permalink
Fix handling of tensor rank in concat data propagation (onnx#6570)
Browse files Browse the repository at this point in the history
### Description
Fix issue onnx#6276 (data propagation
fails on Concat when first tensor is initialized empty).

---------

Signed-off-by: Ganesan Ramalingam <[email protected]>
  • Loading branch information
gramalingam authored Dec 11, 2024
1 parent 96a0ca4 commit 25a134a
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 6 deletions.
15 changes: 9 additions & 6 deletions onnx/defs/data_propagators.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,21 @@ inline bool axisIsZero(DataPropagationContext& ctx, bool defaultZero = false) {
}
}
int axis = static_cast<int>(axisAttr->i());
auto input_data_0 = ctx.getInputData(0);
if (input_data_0 == nullptr) {
if (axis >= 0) {
return axis == 0;
}
// For negative axes, we need rank information to determine if it is equivalent to axis 0
const TypeProto* type = ctx.getInputType(0);
if ((type == nullptr) || (!type->has_tensor_type()) || (!type->tensor_type().has_shape())) {
return false;
}
int rank = input_data_0->dim_size();

int rank = type->tensor_type().shape().dim_size();
if (axis < -rank || axis >= rank) {
fail_shape_inference("axis must be in [-rank, rank-1].");
return false;
}
if (axis < 0) {
axis += rank;
}
axis += rank;
// Only supports axis = 0 since the data comes from Shape
return axis == 0;
}
Expand Down
34 changes: 34 additions & 0 deletions onnx/test/data_propagation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,40 @@ def test_shape_arithmetic_with_zero_broadcast(self) -> None:
data_prop=True,
) # type: ignore

def test_empty_tensor(self) -> None:
"""Test that a Concat with an empty tensor as input is handled correctly by data-propagation."""
model = onnx.parser.parse_model(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[256] y) => (float[N] z)
<float[0] x = {}>
{
z = Concat <axis=0> (x, y)
}
"""
)
inferred_model = onnx.shape_inference.infer_shapes(model, True, True, True)
output = inferred_model.graph.output[0]
self.assertEqual(output.type.tensor_type.shape.dim[0].dim_value, 256)

def test_empty_tensor_negative_axis(self) -> None:
"""Test that a Concat with an empty tensor as input is handled correctly by data-propagation.
This time with a negative axis.
"""
model = onnx.parser.parse_model(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[256] y) => (float[N] z)
<float[0] x = {}>
{
z = Concat <axis=-1> (x, y)
}
"""
)
inferred_model = onnx.shape_inference.infer_shapes(model, True, True, True)
output = inferred_model.graph.output[0]
self.assertEqual(output.type.tensor_type.shape.dim[0].dim_value, 256)


if __name__ == "__main__":
unittest.main()

0 comments on commit 25a134a

Please sign in to comment.