-
Notifications
You must be signed in to change notification settings - Fork 96
[IR] Implement save/load functions in IR and handle external data properly #1801
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
ba9605d
[IR] Implement efficient save/load
justinchuby 4272a1f
lint
justinchuby 2b23254
Save
justinchuby b973bef
all
justinchuby c1c50f9
proto
justinchuby 10e229e
load
justinchuby 6ff6d2a
fix
justinchuby c121639
Refactor
justinchuby 8624f24
format
justinchuby 62a075c
docs
justinchuby 997384f
docs
justinchuby b69c314
basedir
justinchuby e1cbc00
data test
justinchuby d01bb25
lint
justinchuby fcbec0a
lint
justinchuby 7887b9e
test
justinchuby fb65fa4
lint
justinchuby File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT License. | ||
| """External data related utilities.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| __all__ = ["set_base_dir"] | ||
|
|
||
| import os | ||
| from typing import Iterator | ||
|
|
||
| from onnxscript.ir import _core, _enums, _protocols, traversal | ||
|
|
||
|
|
||
| def _all_tensors( | ||
| graph: _core.Graph | _core.GraphView, include_attributes: bool = False | ||
| ) -> Iterator[_protocols.TensorProtocol]: | ||
| """Iterate over all tensors in the graph. | ||
|
|
||
| Args: | ||
| graph: The graph to traverse tensors on. | ||
| include_attributes: Whether to include tensors in attributes. | ||
|
|
||
| Yields: | ||
| Tensors in the graph. | ||
| """ | ||
| # Yield all tensors in initializers | ||
| for value in graph.initializers.values(): | ||
| if value.const_value is not None: | ||
| yield value.const_value | ||
| if not include_attributes: | ||
| return | ||
| # Look at constant attributes in nodes | ||
| for node in traversal.RecursiveGraphIterator(graph): | ||
| for attr in node.attributes.values(): | ||
| if isinstance(attr, _core.RefAttr): | ||
| continue | ||
| if attr.type == _enums.AttributeType.TENSOR and attr.value is not None: | ||
| yield attr.value | ||
| elif attr.type == _enums.AttributeType.TENSORS and attr.value is not None: | ||
| yield from attr.value | ||
|
|
||
|
|
||
| def set_base_dir(graph: _core.Graph | _core.GraphView, base_dir: str | os.PathLike) -> None: | ||
| """Set the base directory for external data in a graph. | ||
|
|
||
| Args: | ||
| graph: The graph to traverse tensors on. | ||
| base_dir: The base directory. This is the directory where the ONNX file is. | ||
| """ | ||
| for tensor in _all_tensors(graph, include_attributes=True): | ||
| if isinstance(tensor, _core.ExternalTensor): | ||
| tensor.base_dir = base_dir |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT License. | ||
| import unittest | ||
|
|
||
| import onnx | ||
| import onnx.external_data_helper | ||
|
|
||
| from onnxscript import ir | ||
| from onnxscript.ir import _external_data | ||
|
|
||
|
|
||
| class ExternalDataTest(unittest.TestCase): | ||
| def test_set_base_dir_sets_base_dir_for_all_external_tensors(self): | ||
| attr_tensor = onnx.helper.make_tensor( | ||
| name="test_constant", | ||
| data_type=onnx.TensorProto.FLOAT, | ||
| dims=[1], | ||
| vals=b"\x01\x00\x00\x00", | ||
| raw=True, | ||
| ) | ||
| graph = onnx.helper.make_graph( | ||
| nodes=[ | ||
| onnx.helper.make_node( | ||
| "Constant", | ||
| [], | ||
| ["test"], | ||
| value=attr_tensor, | ||
| ) | ||
| ], | ||
| name="test", | ||
| inputs=[], | ||
| outputs=[], | ||
| initializer=[ | ||
| onnx.helper.make_tensor( | ||
| name="test_tensor", | ||
| data_type=onnx.TensorProto.FLOAT, | ||
| dims=[1], | ||
| vals=b"\x01\x00\x00\x00", | ||
| raw=True, | ||
| ), | ||
| ], | ||
| ) | ||
| model_proto = onnx.helper.make_model(graph) | ||
| onnx.external_data_helper.convert_model_to_external_data( | ||
| model_proto, location="tempdir", size_threshold=0, convert_attribute=True | ||
| ) | ||
| model = ir.serde.deserialize_model(model_proto) | ||
| expected_dir = "something_else" | ||
| _external_data.set_base_dir(model.graph, expected_dir) | ||
|
|
||
| initializer_tensor = model.graph.initializers["test_tensor"].const_value | ||
| assert isinstance(initializer_tensor, ir.ExternalTensor) | ||
| self.assertEqual(initializer_tensor.base_dir, expected_dir) | ||
| attr_tensor = model.graph.node(0).attributes["value"].value | ||
| self.assertEqual(attr_tensor.base_dir, expected_dir) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT License. | ||
| """Load and save ONNX models.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| __all__ = ["load", "save"] | ||
|
|
||
| import os | ||
|
|
||
| import onnx | ||
|
|
||
| from onnxscript.ir import _core, _external_data, serde | ||
|
|
||
|
|
||
| def load(path: str | os.PathLike, format: str | None = None) -> _core.Model: | ||
| """Load an ONNX model from a file. | ||
|
|
||
| Args: | ||
| path: The path to the ONNX file. | ||
| format: The format of the file (e.g. protobuf, textproto, json, etc.). | ||
| If None, the format is inferred from the file extension. | ||
|
|
||
| Returns: | ||
| The loaded model. | ||
| """ | ||
| # Do not use ONNX to load external data because the IR handles external data | ||
| # by doing memory mapping directly. | ||
| proto = onnx.load(path, format=format, load_external_data=False) | ||
| model = serde.deserialize_model(proto) | ||
| base_dir = os.path.dirname(path) | ||
| # Set the base directory for external data to the directory of the ONNX file | ||
| # so that relative paths are resolved correctly. | ||
| _external_data.set_base_dir(model.graph, base_dir) | ||
| return model | ||
|
|
||
|
|
||
| def save(model: _core.Model, path: str | os.PathLike, format: str | None = None) -> None: | ||
| """Save an ONNX model to a file. | ||
|
|
||
| Args: | ||
| model: The model to save. | ||
| path: The path to save the model to. | ||
| format: The format of the file (e.g. protobuf, textproto, json, etc.). | ||
| If None, the format is inferred from the file extension. | ||
| """ | ||
| proto = serde.serialize_model(model) | ||
| onnx.save(proto, path, format=format) | ||
| # TODO(justinchuby): Handle external data when the relative path has changed | ||
| # TODO(justinchuby): Handle off loading external data to disk when saving |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.