-
Notifications
You must be signed in to change notification settings - Fork 243
[5725362] AutoCast Fixes for models with external data #731
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
d4e15ed
6532522
7a2d91a
581d686
cd96fa5
00ea80c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -24,11 +24,13 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import copy | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import io | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import sys | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import tempfile | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from collections import OrderedDict | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import onnx | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from modelopt.onnx import utils as onnx_utils | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from modelopt.onnx.autocast.logging_config import configure_logging, logger | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from modelopt.onnx.quantization.ort_utils import _prepare_ep_list | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -118,13 +120,65 @@ def _load_inputs(self, inputs): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return data_loader | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _get_ort_runner(self, model): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import onnxruntime as ort | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from polygraphy.backend.onnx import BytesFromOnnx | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Check if model has external data by checking: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # 1. If any initializer has data_location set to EXTERNAL (even if data is loaded) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # 2. If model size would exceed 2GB (indicating need for external data) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| has_external_data = any( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for init in self.model.graph.initializer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def has_external_data(onnx_model_path: str): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ajrasane I'm not sure I want to introduce dependencies from modelopt.torch here just for this.
Since modelopt/torch/_deploy/utils/torch_onnx.py is already importing quite a few utils from modelopt.onnx.utils, how about I move this function to modelopt.onnx.utils and import it in modelopt.torch?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ajrasane please revisit - since I edited torch utils, I now need a review from modelopt-torch-deploy-codeowners 🙏
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| model_size = len(self.model.SerializeToString()) | |
| model_size = model.ByteSize() |
galagam marked this conversation as resolved.
Show resolved
Hide resolved
galagam marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI Polygraphy's SaveOnnx can handle models with external data. Also SessionFromOnnx can accept paths.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @pranavm-nvidia . I'll take a look and see if I can refactor.
If it's an quick fix for you - feel free to push a commit to this PR, and I'll review.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left a suggestion with the change. One thing I'm not sure about - does your onnx_utils.save_onnx do anything special besides saving the model? On quick inspection, it seems like it's also setting a custom IR version. If that's still required, you'll probably need to add a line like:
modified_model.ir_version = 10prior to calling Polygraphy's save_onnx.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| if has_external_data: | |
| logger.debug("Model has external data, using file-based approach") | |
| # Get the actual ONNX ModelProto from ModifyOutputs wrapper | |
| modified_model = model() | |
| # Use a persistent temp file to handle external data files properly | |
| tmp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) | |
| tmp_file.close() | |
| tmp_file_path = tmp_file.name | |
| onnx_utils.save_onnx(modified_model, tmp_file_path, save_as_external_data=True) | |
| logger.debug(f"Model with all outputs saved to {tmp_file_path}") | |
| session = ort.InferenceSession(tmp_file_path, providers=self.providers) | |
| runners = [OnnxrtRunner(lambda: session)] | |
| else: | |
| # For models without external data, use the original BytesFromOnnx approach (no tmp files) | |
| logger.debug("Model has no external data, using BytesFromOnnx approach") | |
| serialize_onnx = BytesFromOnnx(model) | |
| build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers) | |
| runners = [OnnxrtRunner(build_onnxrt_session)] | |
| if has_external_data: | |
| logger.debug("Model has external data, using file-based approach") | |
| # Get the actual ONNX ModelProto from ModifyOutputs wrapper | |
| modified_model = model() | |
| # Use a persistent temp file to handle external data files properly | |
| outdir = tempfile.TemporaryDirectory() | |
| tmp_file_path = os.path.join(outdir.name, "tmp_model.onnx") | |
| save_onnx(modified_model, tmp_file_path, external_data_path="ext.data") | |
| logger.debug(f"Model with all outputs saved to {tmp_file_path}") | |
| build_onnxrt_session = SessionFromOnnx(tmp_file_path, providers=self.providers) | |
| else: | |
| # For models without external data, use the original BytesFromOnnx approach (no tmp files) | |
| logger.debug("Model has no external data, using BytesFromOnnx approach") | |
| serialize_onnx = BytesFromOnnx(model) | |
| build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers) | |
| runners = [OnnxrtRunner(build_onnxrt_session)] |
Uh oh!
There was an error while loading. Please reload this page.