-
Notifications
You must be signed in to change notification settings - Fork 96
[IR] Implement save/load functions in IR and handle external data properly #1801
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 6 commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
ba9605d
[IR] Implement efficient save/load
justinchuby 4272a1f
lint
justinchuby 2b23254
Save
justinchuby b973bef
all
justinchuby c1c50f9
proto
justinchuby 10e229e
load
justinchuby 6ff6d2a
fix
justinchuby c121639
Refactor
justinchuby 8624f24
format
justinchuby 62a075c
docs
justinchuby 997384f
docs
justinchuby b69c314
basedir
justinchuby e1cbc00
data test
justinchuby d01bb25
lint
justinchuby fcbec0a
lint
justinchuby 7887b9e
test
justinchuby fb65fa4
lint
justinchuby File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT License. | ||
| """Load and save ONNX models.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| __all__ = ["load", "save"] | ||
|
|
||
| import os | ||
| from typing import Iterator | ||
|
|
||
| import onnx | ||
|
|
||
| from onnxscript.ir import _core, _enums, _protocols, serde, traversal | ||
|
|
||
|
|
||
| def _all_tensors( | ||
| graph: _core.Graph | _core.GraphView, include_constants: bool = False | ||
| ) -> Iterator[_protocols.TensorProtocol]: | ||
| """Iterate over all tensors in the graph.""" | ||
|
|
||
| # Yield all tensors in initializers | ||
| for value in graph.initializers.values(): | ||
| if value.const_value is not None: | ||
| yield value.const_value | ||
| if not include_constants: | ||
| return | ||
| # Look at constant attributes in nodes | ||
| for node in traversal.RecursiveGraphIterator(graph): | ||
| for attr in node.attributes.values(): | ||
| if isinstance(attr, _core.RefAttr): | ||
| continue | ||
| if attr.type == _enums.AttributeType.TENSOR and attr.value is not None: | ||
| yield attr.value | ||
| elif attr.type == _enums.AttributeType.TENSORS and attr.value is not None: | ||
| for value in attr.value: | ||
| yield value | ||
|
||
|
|
||
|
|
||
| def load(path: str | os.PathLike, format: str | None = None) -> _core.Model: | ||
| """Load an ONNX model from a file. | ||
|
|
||
| Args: | ||
| path: The path to the ONNX file. | ||
| format: The format of the file (e.g. protobuf, textproto, json, etc.). | ||
| If None, the format is inferred from the file extension. | ||
|
|
||
| Returns: | ||
| The loaded model. | ||
| """ | ||
| # Do not use ONNX to load external data because the IR handles external data | ||
| # by doing memory mapping directly. | ||
| proto = onnx.load(path, format=format, load_external_data=False) | ||
| model = serde.deserialize_model(proto) | ||
| base_dir = os.path.dirname(path) | ||
| # Set the base directory for external data to the directory of the ONNX file | ||
| # so that relative paths are resolved correctly. | ||
| for tensor in _all_tensors(model.graph, include_constants=True): | ||
| if isinstance(tensor, _core.ExternalTensor): | ||
| tensor.base_dir = base_dir | ||
| return model | ||
|
|
||
|
|
||
| def save(model: _core.Model, path: str | os.PathLike, format: str | None = None) -> None: | ||
| """Save an ONNX model to a file. | ||
|
|
||
| Args: | ||
| model: The model to save. | ||
| path: The path to save the model to. | ||
| format: The format of the file (e.g. protobuf, textproto, json, etc.). | ||
| If None, the format is inferred from the file extension. | ||
| """ | ||
| proto = serde.serialize_model(model) | ||
| onnx.save(proto, path, format=format) | ||
| # TODO(justinchuby): Handle external data | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.