Description
During the ONNX export process, onnxscript.optimizer.optimize is called, which applies all of its optimizations in two passes.
When the constant-folding step, modified = fold_constants(model, external_data_folder, onnx_shape_inference=onnx_shape_inference), is executed twice, the following shape inference error is raised:
Traceback (most recent call last):
  File "/opt/pytorch/torch/onnx/_internal/exporter.py", line 1509, in dynamo_export
    ).export()
    ^^^^^^^^
  File "/opt/pytorch/torch/onnx/_internal/exporter.py", line 1280, in export
    onnx_model = optimizer.optimize(onnx_model)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/onnxscript/onnxscript/optimizer/__init__.py", line 61, in optimize
    model = onnx.shape_inference.infer_shapes(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/ptca/lib/python3.11/site-packages/onnx/shape_inference.py", line 46, in infer_shapes
    inferred_model_str = C.infer_shapes(
                         ^^^^^^^^^^^^^^^
onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:MatMul, node name: torch_nn_modules_linear_Linear_fc1_1_0_aten_mm_1_n0): B has inconsistent type tensor(float)
Repro option 1:
import torch
from torch import nn


class MyLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5, bias=False, dtype=torch.bfloat16)

    def forward(self, tensor_x: torch.Tensor):
        return self.fc1(tensor_x)


model = MyLinear()
tensor_x = torch.rand((3, 10), dtype=torch.bfloat16)
# The export raises the ShapeInferenceError shown above.
onnx_program = torch.onnx.dynamo_export(model, tensor_x)
Repro option 2:
Use the ONNX model attached to the issue and run the pass manually:
import onnx
import onnxscript

model = onnx.load("/path/to/model.onnx")
# First round: constant folding followed by strict shape inference.
_ = onnxscript.optimizer.fold_constants(model, "/path/to/", onnx_shape_inference=True)
model = onnx.shape_inference.infer_shapes(model, check_type=True, strict_mode=True, data_prop=True)
# Second round: the same two steps again.
_ = onnxscript.optimizer.fold_constants(model, "/opt/dev/no_op/", onnx_shape_inference=True)
model = onnx.shape_inference.infer_shapes(model, check_type=True, strict_mode=True, data_prop=True)
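To narrow down where the type mismatch comes from, it may help to compare initializer element types before and after each fold_constants call. The sketch below is only a diagnostic aid, not part of onnxscript. It assumes the problem shows up as an initializer whose element type changes (for example from BFLOAT16 to FLOAT), which the error message suggests but does not confirm. The helper names initializer_dtypes, changed_dtypes and fake_model, and the tensor name "fc1.weight", are hypothetical. The numeric codes are the onnx.TensorProto.DataType values, and the helpers accept any object exposing graph.initializer, so the example runs without a model file.

```python
from types import SimpleNamespace

# Subset of the element-type codes defined by onnx.TensorProto.DataType.
ELEM_TYPE_NAMES = {1: "FLOAT", 10: "FLOAT16", 11: "DOUBLE", 16: "BFLOAT16"}


def initializer_dtypes(model):
    """Map each initializer name to a readable element-type name."""
    return {
        init.name: ELEM_TYPE_NAMES.get(init.data_type, str(init.data_type))
        for init in model.graph.initializer
    }


def changed_dtypes(before, after):
    """Return {name: (old, new)} for initializers whose element type differs."""
    return {
        name: (before[name], after[name])
        for name in before.keys() & after.keys()
        if before[name] != after[name]
    }


def fake_model(data_type):
    """Stand-in for a ModelProto with one initializer (hypothetical name)."""
    init = SimpleNamespace(name="fc1.weight", data_type=data_type)
    return SimpleNamespace(graph=SimpleNamespace(initializer=[init]))


# Simulate a weight that starts as BFLOAT16 (16) and ends up as FLOAT (1).
before = initializer_dtypes(fake_model(16))
after = initializer_dtypes(fake_model(1))
print(changed_dtypes(before, after))
```

With the attached model, the snapshots would be taken by calling initializer_dtypes(model) right before and right after each fold_constants call. If folding renames tensors or emits Constant nodes instead of initializers, this comparison will not see them, and the node inputs would need to be inspected as well.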