4 changes: 4 additions & 0 deletions onnxscript/ir/__init__.py
@@ -71,6 +71,9 @@
# Pass infrastructure
"passes",
"traversal",
# IO
"load",
"save",
]

from onnxscript.ir import passes, serde, traversal
@@ -114,6 +117,7 @@
AttributeType,
DataType,
)
from onnxscript.ir._io import load, save
from onnxscript.ir._protocols import (
ArrayCompatible,
AttributeProtocol,
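
With ``load`` and ``save`` now exported from ``onnxscript.ir``, the new entry points can be exercised with the minimal round trip below (a sketch; the file names are illustrative and not part of this change):

from onnxscript import ir

# Deserialize the ONNX file into the in-memory IR. External data is
# memory-mapped lazily by the IR rather than loaded through onnx.load.
model = ir.load("model.onnx")
# ... inspect or transform the model here ...
# Serialize the IR back to an ONNX protobuf file.
ir.save(model, "model_copy.onnx")
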
30 changes: 26 additions & 4 deletions onnxscript/ir/_core.py
@@ -475,6 +475,8 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=

Attributes:
path: The path to the data file. This can be a relative path or an absolute path.
base_dir: The base directory for the external data. It is used to resolve relative paths.
When the tensor is serialized, only ``path`` is written to the "location" field of the TensorProto.
offset: The offset in bytes from the start of the file.
length: The length of the data in bytes.
dtype: The data type of the tensor.
@@ -509,8 +511,15 @@ def __init__(
name: str,
doc_string: str | None = None,
metadata_props: dict[str, str] | None = None,
base_dir: os.PathLike | str = "",
) -> None:
self._path = path
if os.path.isabs(path):
self._base_dir = os.path.dirname(path)
self._path = os.path.basename(path)
else:
self._base_dir = base_dir
self._path = path

self._offset: int | None = offset
self._length: int | None = length
self._dtype: _enums.DataType = dtype
@@ -528,6 +537,15 @@ def path(self) -> str | os.PathLike:
# Immutable
return self._path

@property
def base_dir(self) -> str | os.PathLike:
# Mutable
return self._base_dir

@base_dir.setter
def base_dir(self, value: str | os.PathLike) -> None:
self._base_dir = value

@property
def offset(self) -> int | None:
# Immutable
@@ -556,7 +574,8 @@ def _load(self):
return
# Map the whole file into the memory
# TODO(justinchuby): Verify if this would exhaust the memory address space
with open(self._path, "rb") as f:
file_path = os.path.join(self._base_dir, self._path)
with open(file_path, "rb") as f:
self.raw = mmap.mmap(
f.fileno(),
0,
@@ -599,7 +618,10 @@ def __dlpack_device__(self) -> tuple[int, int]:
)

def __repr__(self) -> str:
return f"{self._repr_base()}(path='{self._path}', name={self.name!r}, offset={self._offset!r}), length={self._length!r})"
return (
f"{self._repr_base()}(path='{self._path}', name={self.name!r}, "
f"offset={self._offset!r}, length={self._length!r}, base_dir={self._base_dir!r})"
)

def numpy(self) -> np.ndarray:
"""Return the tensor as a numpy array.
@@ -2069,7 +2091,7 @@ def __init__(
outputs: Sequence[Value],
*,
nodes: Iterable[Node],
initializers: Sequence[_protocols.TensorProtocol] = (),
initializers: Sequence[_protocols.ValueProtocol] = (),
doc_string: str | None = None,
opset_imports: dict[str, int] | None = None,
name: str | None = None,
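
For clarity, the path handling introduced above can be restated as a small standalone sketch (illustrative only, not library code): an absolute ``path`` is split into ``base_dir`` plus its basename at construction time, while a relative ``path`` is joined with ``base_dir`` only when ``_load`` opens the file.

import os

def resolve_external_path(path: str, base_dir: str = "") -> str:
    # Mirrors ExternalTensor's rule: an absolute path carries its own base
    # directory; a relative path is resolved against base_dir at load time.
    if os.path.isabs(path):
        base_dir, path = os.path.dirname(path), os.path.basename(path)
    return os.path.join(base_dir, path)

# Both spellings point at the same file on disk.
assert resolve_external_path("/models/weights.bin") == resolve_external_path(
    "weights.bin", base_dir="/models"
)
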
75 changes: 75 additions & 0 deletions onnxscript/ir/_io.py
@@ -0,0 +1,75 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Load and save ONNX models."""

from __future__ import annotations

__all__ = ["load", "save"]

import os
from typing import Iterator

import onnx

from onnxscript.ir import _core, _enums, _protocols, serde, traversal


def _all_tensors(
graph: _core.Graph | _core.GraphView, include_constants: bool = False
) -> Iterator[_protocols.TensorProtocol]:
"""Iterate over all tensors in the graph."""

# Yield all tensors in initializers
for value in graph.initializers.values():
if value.const_value is not None:
yield value.const_value
if not include_constants:
return
# Look at constant attributes in nodes
for node in traversal.RecursiveGraphIterator(graph):
for attr in node.attributes.values():
if isinstance(attr, _core.RefAttr):
continue
if attr.type == _enums.AttributeType.TENSOR and attr.value is not None:
yield attr.value
elif attr.type == _enums.AttributeType.TENSORS and attr.value is not None:
for value in attr.value:
yield value


def load(path: str | os.PathLike, format: str | None = None) -> _core.Model:
"""Load an ONNX model from a file.

Args:
path: The path to the ONNX file.
format: The format of the file (e.g. protobuf, textproto, json, etc.).
If None, the format is inferred from the file extension.

Returns:
The loaded model.
"""
# Do not use ONNX to load external data because the IR handles external data
# by doing memory mapping directly.
proto = onnx.load(path, format=format, load_external_data=False)
model = serde.deserialize_model(proto)
base_dir = os.path.dirname(path)
# Set the base directory for external data to the directory of the ONNX file
# so that relative paths are resolved correctly.
for tensor in _all_tensors(model.graph, include_constants=True):
if isinstance(tensor, _core.ExternalTensor):
tensor.base_dir = base_dir
return model


def save(model: _core.Model, path: str | os.PathLike, format: str | None = None) -> None:
"""Save an ONNX model to a file.

Args:
model: The model to save.
path: The path to save the model to.
format: The format of the file (e.g. protobuf, textproto, json, etc.).
If None, the format is inferred from the file extension.
"""
proto = serde.serialize_model(model)
onnx.save(proto, path, format=format)
# TODO(justinchuby): Handle external data
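
A hedged sketch of what ``load`` does for external data: after deserialization, every ``ExternalTensor`` found by ``_all_tensors`` has its ``base_dir`` pointed at the model file's directory, so a relative ``location`` stored in the TensorProto resolves correctly when the tensor bytes are first read. The file name below is illustrative.

import os

from onnxscript import ir
from onnxscript.ir import _core

model = ir.load(os.path.join("models", "model_with_external_data.onnx"))
for name, value in model.graph.initializers.items():
    tensor = value.const_value
    if isinstance(tensor, _core.ExternalTensor):
        # base_dir was set by load(); the data file is memory-mapped only
        # when the tensor's contents are actually accessed (e.g. via numpy()).
        print(name, tensor.base_dir, tensor.path)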