Skip to content
4 changes: 4 additions & 0 deletions onnxscript/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@
# Pass infrastructure
"passes",
"traversal",
# IO
"load",
"save",
]

from onnxscript.ir import passes, serde, traversal
Expand Down Expand Up @@ -114,6 +117,7 @@
AttributeType,
DataType,
)
from onnxscript.ir._io import load, save
from onnxscript.ir._protocols import (
ArrayCompatible,
AttributeProtocol,
Expand Down
30 changes: 26 additions & 4 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,8 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=

Attributes:
path: The path to the data file. This can be a relative path or an absolute path.
base_dir: The base directory for the external data. It is used to resolve relative paths.
At serialization, only the ``path`` is serialized into the "location" field of the TensorProto.
offset: The offset in bytes from the start of the file.
length: The length of the data in bytes.
dtype: The data type of the tensor.
Expand Down Expand Up @@ -509,8 +511,15 @@ def __init__(
name: str,
doc_string: str | None = None,
metadata_props: dict[str, str] | None = None,
base_dir: os.PathLike | str = "",
) -> None:
self._path = path
if os.path.isabs(path):
self._base_dir = os.path.dirname(path)
self._path = os.path.basename(path)
else:
self._base_dir = base_dir
self._path = path

self._offset: int | None = offset
self._length: int | None = length
self._dtype: _enums.DataType = dtype
Expand All @@ -528,6 +537,15 @@ def path(self) -> str | os.PathLike:
# Immutable
return self._path

@property
def base_dir(self) -> str | os.PathLike:
# Mutable
return self._base_dir

@base_dir.setter
def base_dir(self, value: str | os.PathLike) -> None:
self._base_dir = value

@property
def offset(self) -> int | None:
# Immutable
Expand Down Expand Up @@ -556,7 +574,8 @@ def _load(self):
return
# Map the whole file into the memory
# TODO(justinchuby): Verify if this would exhaust the memory address space
with open(self._path, "rb") as f:
file_path = os.path.join(self._base_dir, self._path)
with open(file_path, "rb") as f:
self.raw = mmap.mmap(
f.fileno(),
0,
Expand Down Expand Up @@ -599,7 +618,10 @@ def __dlpack_device__(self) -> tuple[int, int]:
)

def __repr__(self) -> str:
return f"{self._repr_base()}(path='{self._path}', name={self.name!r}, offset={self._offset!r}), length={self._length!r})"
return (
f"{self._repr_base()}(path='{self._path}', name={self.name!r}, "
f"offset={self._offset!r}, length={self._length!r}, base_dir={self._base_dir!r})"
)

def numpy(self) -> np.ndarray:
"""Return the tensor as a numpy array.
Expand Down Expand Up @@ -2069,7 +2091,7 @@ def __init__(
outputs: Sequence[Value],
*,
nodes: Iterable[Node],
initializers: Sequence[_protocols.TensorProtocol] = (),
initializers: Sequence[_protocols.ValueProtocol] = (),
doc_string: str | None = None,
opset_imports: dict[str, int] | None = None,
name: str | None = None,
Expand Down
17 changes: 17 additions & 0 deletions onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,23 @@ def test_initialize(self):
# Ensure repeated reads are consistent
np.testing.assert_equal(tensor, self.data)

def test_initialize_with_relative_path(self):
external_tensor = self.model.graph.initializer[0]
external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor)
tensor = _core.ExternalTensor(
path=external_info.location,
offset=external_info.offset,
length=external_info.length,
dtype=ir.DataType.FLOAT,
name="input",
shape=_core.Shape(external_tensor.dims),
base_dir=pathlib.Path(self.base_path),
)
self.assertEqual(tensor.dtype, ir.DataType.FLOAT)
np.testing.assert_equal(tensor, self.data)
# Ensure repeated reads are consistent
np.testing.assert_equal(tensor, self.data)

def test_totypes_returns_correct_data_in(self):
external_tensor = self.model.graph.initializer[0]
external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor)
Expand Down
53 changes: 53 additions & 0 deletions onnxscript/ir/_external_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""External data related utilities."""

from __future__ import annotations

__all__ = ["set_base_dir"]

import os
from typing import Iterator

from onnxscript.ir import _core, _enums, _protocols, traversal


def _all_tensors(
graph: _core.Graph | _core.GraphView, include_attributes: bool = False
) -> Iterator[_protocols.TensorProtocol]:
"""Iterate over all tensors in the graph.

Args:
graph: The graph to traverse tensors on.
include_attributes: Whether to include tensors in attributes.

Yields:
Tensors in the graph.
"""
# Yield all tensors in initializers
for value in graph.initializers.values():
if value.const_value is not None:
yield value.const_value
if not include_attributes:
return
# Look at constant attributes in nodes
for node in traversal.RecursiveGraphIterator(graph):
for attr in node.attributes.values():
if isinstance(attr, _core.RefAttr):
continue
if attr.type == _enums.AttributeType.TENSOR and attr.value is not None:
yield attr.value
elif attr.type == _enums.AttributeType.TENSORS and attr.value is not None:
yield from attr.value


def set_base_dir(graph: _core.Graph | _core.GraphView, base_dir: str | os.PathLike) -> None:
"""Set the base directory for external data in a graph.

Args:
graph: The graph to traverse tensors on.
base_dir: The base directory. This is the directory where the ONNX file is.
"""
for tensor in _all_tensors(graph, include_attributes=True):
if isinstance(tensor, _core.ExternalTensor):
tensor.base_dir = base_dir
59 changes: 59 additions & 0 deletions onnxscript/ir/_external_data_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest

import onnx
import onnx.external_data_helper

from onnxscript import ir
from onnxscript.ir import _external_data


class ExternalDataTest(unittest.TestCase):
def test_set_base_dir_sets_base_dir_for_all_external_tensors(self):
attr_tensor = onnx.helper.make_tensor(
name="test_constant",
data_type=onnx.TensorProto.FLOAT,
dims=[1],
vals=b"\x01\x00\x00\x00",
raw=True,
)
graph = onnx.helper.make_graph(
nodes=[
onnx.helper.make_node(
"Constant",
[],
["test"],
value=attr_tensor,
)
],
name="test",
inputs=[],
outputs=[],
initializer=[
onnx.helper.make_tensor(
name="test_tensor",
data_type=onnx.TensorProto.FLOAT,
dims=[1],
vals=b"\x01\x00\x00\x00",
raw=True,
),
],
)
model_proto = onnx.helper.make_model(graph)
onnx.external_data_helper.convert_model_to_external_data(
model_proto, location="tempdir", size_threshold=0, convert_attribute=True
)
model = ir.serde.deserialize_model(model_proto)
expected_dir = "something_else"
_external_data.set_base_dir(model.graph, expected_dir)

initializer_tensor = model.graph.initializers["test_tensor"].const_value
assert isinstance(initializer_tensor, ir.ExternalTensor)
self.assertEqual(initializer_tensor.base_dir, expected_dir)
attr_tensor = model.graph.node(0).attributes["value"].value
self.assertEqual(attr_tensor.base_dir, expected_dir)


if __name__ == "__main__":
unittest.main()
50 changes: 50 additions & 0 deletions onnxscript/ir/_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Load and save ONNX models."""

from __future__ import annotations

__all__ = ["load", "save"]

import os

import onnx

from onnxscript.ir import _core, _external_data, serde


def load(path: str | os.PathLike, format: str | None = None) -> _core.Model:
"""Load an ONNX model from a file.

Args:
path: The path to the ONNX file.
format: The format of the file (e.g. protobuf, textproto, json, etc.).
If None, the format is inferred from the file extension.

Returns:
The loaded model.
"""
# Do not use ONNX to load external data because the IR handles external data
# by doing memory mapping directly.
proto = onnx.load(path, format=format, load_external_data=False)
model = serde.deserialize_model(proto)
base_dir = os.path.dirname(path)
# Set the base directory for external data to the directory of the ONNX file
# so that relative paths are resolved correctly.
_external_data.set_base_dir(model.graph, base_dir)
return model


def save(model: _core.Model, path: str | os.PathLike, format: str | None = None) -> None:
"""Save an ONNX model to a file.

Args:
model: The model to save.
path: The path to save the model to.
format: The format of the file (e.g. protobuf, textproto, json, etc.).
If None, the format is inferred from the file extension.
"""
proto = serde.serialize_model(model)
onnx.save(proto, path, format=format)
# TODO(justinchuby): Handle external data when the relative path has changed
# TODO(justinchuby): Handle off loading external data to disk when saving
12 changes: 7 additions & 5 deletions onnxscript/ir/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,35 @@
"RecursiveGraphIterator",
]

from typing import Callable, Iterator, Reversible
from typing import Callable, Iterator, Reversible, Union

from typing_extensions import Self

from onnxscript.ir import _core, _enums

GraphLike = Union[_core.Graph, _core.Function, _core.GraphView]


class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]):
def __init__(
self,
graph: _core.Graph | _core.Function | _core.GraphView,
graph_like: GraphLike,
*,
recursive: Callable[[_core.Node], bool] | None = None,
reverse: bool = False,
):
"""Iterate over the nodes in the graph, recursively visiting subgraphs.

Args:
graph: The graph to traverse.
graph_like: The graph to traverse.
recursive: A callback that determines whether to recursively visit the subgraphs
contained in a node. If not provided, all nodes in subgraphs are visited.
reverse: Whether to iterate in reverse order.
"""
self._graph = graph
self._graph = graph_like
self._recursive = recursive
self._reverse = reverse
self._iterator = self._recursive_node_iter(graph)
self._iterator = self._recursive_node_iter(graph_like)

def __iter__(self) -> Self:
self._iterator = self._recursive_node_iter(self._graph)
Expand Down