-
Notifications
You must be signed in to change notification settings - Fork 237
[5725362] AutoCast Fixes for models with external data #731
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
galagam
wants to merge
2
commits into
NVIDIA:main
Choose a base branch
from
galagam:dev-gagam-narrowing-error
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,11 +24,13 @@ | |
| import copy | ||
| import io | ||
| import sys | ||
| import tempfile | ||
| from collections import OrderedDict | ||
|
|
||
| import numpy as np | ||
| import onnx | ||
|
|
||
| from modelopt.onnx import utils as onnx_utils | ||
| from modelopt.onnx.autocast.logging_config import configure_logging, logger | ||
| from modelopt.onnx.quantization.ort_utils import _prepare_ep_list | ||
|
|
||
|
|
@@ -118,13 +120,65 @@ def _load_inputs(self, inputs): | |
|
|
||
| return data_loader | ||
|
|
||
| def _get_ort_runner(self, model): | ||
| import onnxruntime as ort | ||
| from polygraphy.backend.onnx import BytesFromOnnx | ||
| from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx | ||
|
|
||
| # Check if model has external data by checking: | ||
| # 1. If any initializer has data_location set to EXTERNAL (even if data is loaded) | ||
| # 2. If model size would exceed 2GB (indicating need for external data) | ||
| has_external_data = any( | ||
| init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL | ||
| for init in self.model.graph.initializer | ||
| ) | ||
|
|
||
| # Also check if model would be too large (>2GB) for SerializeToString | ||
| # This handles cases where model was loaded with external data already loaded | ||
| if not has_external_data: | ||
| try: | ||
| # Try to estimate size by serializing the model | ||
| # If it fails or exceeds 2GB, we need file-based approach | ||
| model_size = len(self.model.SerializeToString()) | ||
| if model_size > 2 * (1024**3): # 2GB threshold | ||
| has_external_data = True | ||
| logger.debug( | ||
| f"Model size ({model_size / (1024**3):.2f} GB) exceeds 2GB, using file-based approach" | ||
| ) | ||
| except (ValueError, AttributeError) as e: | ||
| # SerializeToString failed (likely >2GB limit), use file-based approach | ||
| if "exceeds maximum protobuf size" in str(e) or "2GB" in str(e): | ||
| has_external_data = True | ||
| logger.debug("Model exceeds protobuf 2GB limit, using file-based approach") | ||
|
|
||
| if has_external_data: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we reuse this logic to save the model depending on wether or not it has external data? |
||
| logger.debug("Model has external data, using file-based approach") | ||
| # Get the actual ONNX ModelProto from ModifyOutputs wrapper | ||
| modified_model = model() | ||
|
|
||
| # Use a persistent temp file to handle external data files properly | ||
| tmp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) | ||
| tmp_file.close() | ||
| tmp_file_path = tmp_file.name | ||
| onnx_utils.save_onnx(modified_model, tmp_file_path, save_as_external_data=True) | ||
| logger.debug(f"Model with all outputs saved to {tmp_file_path}") | ||
| session = ort.InferenceSession(tmp_file_path, providers=self.providers) | ||
| runners = [OnnxrtRunner(lambda: session)] | ||
|
|
||
| else: | ||
| # For models without external data, use the original BytesFromOnnx approach (no tmp files) | ||
| logger.debug("Model has no external data, using BytesFromOnnx approach") | ||
| serialize_onnx = BytesFromOnnx(model) | ||
| build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers) | ||
| runners = [OnnxrtRunner(build_onnxrt_session)] | ||
|
|
||
| return runners | ||
|
|
||
| def run(self, inputs=None): | ||
| """Run FP32 inference with provided or random inputs.""" | ||
| import onnxruntime as ort | ||
| from polygraphy import constants | ||
| from polygraphy.backend.onnx import BytesFromOnnx | ||
| from polygraphy.backend.onnx import ModifyOutputs as ModifyOnnxOutputs | ||
| from polygraphy.backend.onnxrt import OnnxrtRunner, SessionFromOnnx | ||
| from polygraphy.comparator import Comparator | ||
|
|
||
| logger.info("Running ONNX Runtime to obtain reference outputs (this may take a while)...") | ||
|
|
@@ -133,9 +187,9 @@ def run(self, inputs=None): | |
|
|
||
| model_copy = copy.deepcopy(self.model) | ||
| modify_outputs = ModifyOnnxOutputs(model_copy, outputs=constants.MARK_ALL) | ||
| serialize_onnx = BytesFromOnnx(modify_outputs) | ||
| build_onnxrt_session = SessionFromOnnx(serialize_onnx, providers=self.providers) | ||
| runners = [OnnxrtRunner(build_onnxrt_session)] | ||
|
|
||
| # Load the modified model and create an inference session | ||
| runners = self._get_ort_runner(modify_outputs) | ||
|
|
||
| # Comparator is used despite the fact that we are using ONNXRuntime | ||
| # because it provides the ability to generate random inputs using DataLoader | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we reuse this function to check for external data?
Model-Optimizer/modelopt/torch/_deploy/utils/torch_onnx.py
Line 629 in 307fe71