
ConstantFolder fails shape inference on second iteration #1471

@thiagocrepaldi

Description


During the ONNX export process, onnxscript.optimizer.optimize is called, which applies all optimizations in two passes.
When the pass modified = fold_constants(model, external_data_folder, onnx_shape_inference=onnx_shape_inference) is executed the second time, the following shape inference error is raised:

Traceback (most recent call last):
  File "/opt/pytorch/torch/onnx/_internal/exporter.py", line 1509, in dynamo_export
    ).export()
      ^^^^^^^^
  File "/opt/pytorch/torch/onnx/_internal/exporter.py", line 1280, in export
    onnx_model = optimizer.optimize(onnx_model)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/onnxscript/onnxscript/optimizer/__init__.py", line 61, in optimize
    model = onnx.shape_inference.infer_shapes(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/ptca/lib/python3.11/site-packages/onnx/shape_inference.py", line 46, in infer_shapes
    inferred_model_str = C.infer_shapes(
                         ^^^^^^^^^^^^^^^
onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:MatMul, node name: torch_nn_modules_linear_Linear_fc1_1_0_aten_mm_1_n0): B has inconsistent type tensor(float)
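For context, the failing sequence can be paraphrased roughly as follows; the loop shape and argument values are assumptions pieced together from the traceback and the description above, not the actual onnxscript source:

import onnx
import onnxscript

# Rough paraphrase (not the actual onnxscript code): optimize() applies all
# passes twice; the second constant-folding + shape-inference round is the
# one that raises the MatMul "inconsistent type" error.
model = onnx.load("/path/to/model.onnx")  # placeholder path
for _ in range(2):
    onnxscript.optimizer.fold_constants(model, "/path/to/", onnx_shape_inference=True)
    model = onnx.shape_inference.infer_shapes(
        model, check_type=True, strict_mode=True, data_prop=True
    )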

Repro option 1:

import torch
import torch.nn as nn

class MyLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5, bias=False, dtype=torch.bfloat16)

    def forward(self, tensor_x: torch.Tensor):
        return self.fc1(tensor_x)

model = MyLinear()
tensor_x = torch.rand((3, 10), dtype=torch.bfloat16)
onnx_program = torch.onnx.dynamo_export(model, tensor_x)
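To see what the error is complaining about, the sketch below dumps the declared element types in the exported model; the path is a placeholder, and the suspicion that the folded MatMul weight (input B) ends up as float while the rest of the graph is bfloat16 is my assumption, not something confirmed above:

import onnx

# Hypothetical diagnostic: list element types of initializers and value_infos
# to spot a float32 constant feeding a bfloat16 MatMul input.
m = onnx.load("/path/to/model.onnx")  # placeholder path to the exported model
for init in m.graph.initializer:
    print("initializer", init.name, onnx.TensorProto.DataType.Name(init.data_type))
for vi in list(m.graph.value_info) + list(m.graph.input) + list(m.graph.output):
    elem_type = vi.type.tensor_type.elem_type
    print("value_info ", vi.name, onnx.TensorProto.DataType.Name(elem_type))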

Repro option 2:

Use the ONNX model attached to the issue and run the passes manually:

onnx_model.zip

import onnx
import onnxscript
model = onnx.load("/path/to/model.onnx")
_ = onnxscript.optimizer.fold_constants(model, "/path/to/", onnx_shape_inference=True)
model = onnx.shape_inference.infer_shapes(model, check_type=True, strict_mode=True, data_prop=True)
# The second round reproduces the MatMul "inconsistent type" error
_ = onnxscript.optimizer.fold_constants(model, "/opt/dev/no_op/", onnx_shape_inference=True)
model = onnx.shape_inference.infer_shapes(model, check_type=True, strict_mode=True, data_prop=True)
