
Commit

Minor cleanup
gramalingam committed Dec 27, 2024
1 parent bccee83 commit 0f70bb4
Showing 2 changed files with 41 additions and 26 deletions.
60 changes: 36 additions & 24 deletions onnxscript/backend/onnx_export.py
@@ -247,11 +247,11 @@ def _cond_is_used_in_loop_body(graph: GraphProto) -> bool:
     return False
 
 
-class Exporter:
+class _Exporter:
     """Class used for recursive traversal of Proto structures."""
 
     def __init__(
-        self, rename: bool, use_operators: bool = False, inline_const: bool = False, skip_initializers: bool = False
+        self, *, rename: bool, use_operators: bool, inline_const: bool, skip_initializers: bool
     ) -> None:
         self.use_operators = use_operators
         if rename:
@@ -267,7 +267,7 @@ def __init__(
         # We map the multiple SSA-variants back to the same Python variable name.
         self._name_remappings: list[dict[str, str]] = []
         self.skip_initializers = skip_initializers
-        self.skipped_initializers: list[onnx.TensorProto] = []
+        self.skipped_initializers: dict[str, onnx.TensorProto] = {}
 
     def _handle_attrname_conflict(self, renamer):
         """Add ref-attr-name-conflict handling logic to renaming function."""
@@ -341,7 +341,12 @@ def _translate_graph_body(self, graph, opsets, indent=0):
         if hasattr(graph, "initializer"):
             for init in graph.initializer:
                 if self.skip_initializers:
-                    self.skipped_initializers.append(init)
+                    init_py_name = self._translate_onnx_var(init.name)
+                    if init_py_name in self.skipped_initializers:
+                        raise RuntimeError(
+                            f"Initializer {init.name!r} is already present in skipped_initializers."
+                        )
+                    self.skipped_initializers[init_py_name] = init
                     continue
                 node = make_node(
                     "Constant",
@@ -710,41 +715,39 @@ def add(line: str) -> None:
         return script
 
     def _substitute_initializers(self, script: str, script_function_name: str) -> str:
-        init_names = [self._translate_onnx_var(x.name) for x in self.skipped_initializers]
+        init_names = self.skipped_initializers.keys()
         # Formal parameters representing initializers (single level indentation)
-        initializers_as_params = "\n".join(
-            f"{_SINGLE_INDENT}{x}," for x in init_names
-        )
-        def generate_rand(x: TensorProto) -> str:
-            name = self._translate_onnx_var(x.name)
-            shape = ",".join(str(d) for d in x.dims)
-            if x.data_type != TensorProto.FLOAT:
+        __ = _SINGLE_INDENT
+        initializers_as_params = "\n".join(f"{__}{x}," for x in init_names)
+
+        def generate_rand(name: str, value: TensorProto) -> str:
+            shape = ",".join(str(d) for d in value.dims)
+            if value.data_type != TensorProto.FLOAT:
                 raise NotImplementedError(
-                    f"Unable to generate random initializer for data type {x.data_type}."
+                    f"Unable to generate random initializer for data type {value.data_type}."
                 )
-            return f"{_SINGLE_INDENT}{name} = numpy.random.rand({shape}).astype(numpy.float32)"
+            return f"{__}{name} = numpy.random.rand({shape}).astype(numpy.float32)"
 
         random_initializer_values = "\n".join(
-            generate_rand(x) for x in self.skipped_initializers
+            generate_rand(key, value) for key, value in self.skipped_initializers.items()
         )
         # Actual parameter values for initializers (double level indentation)
-        indented_initializers_as_params = "\n".join(
-            f"{_SINGLE_INDENT}{_SINGLE_INDENT}{x}," for x in init_names
-        )
+        indented_initializers_as_params = "\n".join(f"{__}{__}{x}," for x in init_names)
         return f"""
 def make_model(
 {initializers_as_params}
 ):
 {script}
-{_SINGLE_INDENT}model = {script_function_name}.to_model_proto()
-{_SINGLE_INDENT}return model
+{__}model = {script_function_name}.to_model_proto()
+{__}return model
+
 def make_model_with_random_weights():
 {random_initializer_values}
-{_SINGLE_INDENT}model = make_model(
+{__}model = make_model(
 {indented_initializers_as_params}
-{_SINGLE_INDENT})
-{_SINGLE_INDENT}return model
+{__})
+{__}return model
 
"""

def _import_onnx_types(
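For concreteness: with a single skipped float initializer exported under the Python name `W` with dims (2, 3), and a script function named `main`, the template above would expand to roughly the following (names are illustrative, and `{script}` is the translated model script):

def make_model(
    W,
):
    # ... translated script defining `main` goes here ...
    model = main.to_model_proto()
    return model

def make_model_with_random_weights():
    W = numpy.random.rand(2,3).astype(numpy.float32)
    model = make_model(
        W,
    )
    return model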
@@ -831,6 +834,7 @@ def visit_graph(graph: onnx.GraphProto) -> None:
 def export2python(
     model_onnx,
     function_name: Optional[str] = None,
+    *,
     rename: bool = False,
     use_operators: bool = False,
     inline_const: bool = False,
@@ -844,6 +848,9 @@
         function_name: main function name
         use_operators: use Python operators.
         inline_const: replace ONNX constants inline if compact
+        skip_initializers: generated script will not include initializers.
+            Instead, a function that generates the model, given initializer values, is generated,
+            along with one that generates random values for the initializers.
 
     Returns:
         python code
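A minimal usage sketch under the new keyword-only signature (model path and function name are placeholders):

import onnx
from onnxscript.backend.onnx_export import export2python

model = onnx.load("model.onnx")
code = export2python(model, "my_model", skip_initializers=True)
# With skip_initializers=True, `code` defines make_model(...) and
# make_model_with_random_weights() instead of embedding weights as constants.
with open("my_model.py", "w", encoding="utf-8") as f:
    f.write(code)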
@@ -869,5 +876,10 @@
     if not isinstance(model_onnx, (ModelProto, FunctionProto)):
         raise TypeError(f"The function expects a ModelProto not {type(model_onnx)!r}.")
 
-    exporter = Exporter(rename, use_operators, inline_const, skip_initializers)
+    exporter = _Exporter(
+        rename=rename,
+        use_operators=use_operators,
+        inline_const=inline_const,
+        skip_initializers=skip_initializers,
+    )
     return exporter.export(model_onnx, function_name)
7 changes: 5 additions & 2 deletions tools/onnx2script.py
@@ -32,7 +32,10 @@ def convert2script(
 ) -> None:
     model = onnx.load(input_file_name, load_external_data=False)
     python_code = onnxscript.proto2python(
-        model, use_operators=not verbose, inline_const=not verbose, skip_initializers=not initializers
+        model,
+        use_operators=not verbose,
+        inline_const=not verbose,
+        skip_initializers=not initializers,
     )
 
     # If output file name is not provided, use the input file name with .py extension
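Condensed, convert2script's default behavior amounts to the following end-to-end sketch (paths are placeholders; note how the CLI flags are inverted into the keyword arguments):

import onnx
import onnxscript

model = onnx.load("model.onnx", load_external_data=False)
python_code = onnxscript.proto2python(
    model,
    use_operators=True,      # verbose not set
    inline_const=True,       # verbose not set
    skip_initializers=True,  # --initializers not passed
)
# The tool writes the result next to the input, with a .py extension.
with open("model.py", "w", encoding="utf-8") as f:
    f.write(python_code)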
@@ -60,7 +63,7 @@ convert2script(
         "--initializers",
         action="store_true",
         help="Include initializers in the generated script",
-        default=False
+        default=False,
     )
 
     args = parser.parse_args()
