Skip to content

Commit 7a2d91a

Browse files
committed
fix check_model for external data
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
1 parent 6532522 commit 7a2d91a

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

modelopt/onnx/utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
"""Utility functions related to onnx."""
1717

18+
import copy
1819
import io
1920
import os
2021
import tempfile
@@ -555,14 +556,13 @@ def _get_unique_name(old_name):
555556
def check_model(model: onnx.ModelProto) -> None:
556557
"""Checks if the given model is valid."""
557558
if model.ByteSize() > (2 * (1024**3)): # 2GB limit
558-
logger.warning("Model exceeds 2GB limit, skipping check_model")
559-
# with tempfile.TemporaryDirectory() as temp_dir:
560-
# # ONNX also looks in CWD, so we need to use a unique id
561-
# unique_id = str(uuid.uuid4())[:8]
562-
# onnx_tmp_path = os.path.join(temp_dir, f"model_{unique_id}.onnx")
563-
# save_onnx(model, onnx_tmp_path, save_as_external_data=True)
564-
# onnx.checker.check_model(onnx_tmp_path)
565-
559+
with tempfile.TemporaryDirectory() as temp_dir:
560+
# ONNX also looks in CWD, so we need to use a unique id
561+
unique_id = str(uuid.uuid4())[:8]
562+
onnx_tmp_path = os.path.join(temp_dir, f"model_{unique_id}.onnx")
563+
model_copy = copy.deepcopy(model)
564+
save_onnx(model_copy, onnx_tmp_path, save_as_external_data=True)
565+
onnx.checker.check_model(onnx_tmp_path)
566566
else:
567567
onnx.checker.check_model(model)
568568

0 commit comments

Comments
 (0)