Commit 0f70bb4 (1 parent: bccee83)

Showing 2 changed files with 41 additions and 26 deletions.
```diff
@@ -247,11 +247,11 @@ def _cond_is_used_in_loop_body(graph: GraphProto) -> bool:
     return False


-class Exporter:
+class _Exporter:
     """Class used for recursive traversal of Proto structures."""

     def __init__(
-        self, rename: bool, use_operators: bool = False, inline_const: bool = False, skip_initializers: bool = False
+        self, *, rename: bool, use_operators: bool, inline_const: bool, skip_initializers: bool
     ) -> None:
        self.use_operators = use_operators
        if rename:
```
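For context, a minimal standalone sketch of the keyword-only pattern this hunk adopts (a toy class, not the real `_Exporter`): flags can no longer be passed positionally, so call sites must name each one.

```python
# Toy illustration of keyword-only constructor parameters (`*` in the signature).
class Toy:
    def __init__(self, *, rename: bool, use_operators: bool) -> None:
        self.rename = rename
        self.use_operators = use_operators

Toy(rename=True, use_operators=False)  # OK: flags are self-documenting
# Toy(True, False)                     # TypeError: takes 1 positional argument
```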
```diff
@@ -267,7 +267,7 @@ def __init__(
         # We map the multiple SSA-variants back to the same Python variable name.
         self._name_remappings: list[dict[str, str]] = []
         self.skip_initializers = skip_initializers
-        self.skipped_initializers: list[onnx.TensorProto] = []
+        self.skipped_initializers: dict[str, onnx.TensorProto] = {}

     def _handle_attrname_conflict(self, renamer):
         """Add ref-attr-name-conflict handling logic to renaming function."""
```
```diff
@@ -341,7 +341,12 @@ def _translate_graph_body(self, graph, opsets, indent=0):
         if hasattr(graph, "initializer"):
             for init in graph.initializer:
                 if self.skip_initializers:
-                    self.skipped_initializers.append(init)
+                    init_py_name = self._translate_onnx_var(init.name)
+                    if init_py_name in self.skipped_initializers:
+                        raise RuntimeError(
+                            f"Initializer {init.name!r} is already present in skipped_initializers."
+                        )
+                    self.skipped_initializers[init_py_name] = init
                     continue
                 node = make_node(
                     "Constant",
```
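A small standalone sketch of the collision check this hunk introduces. Here `sanitize` is a stand-in for `_translate_onnx_var` (not the real implementation); the point is that two distinct ONNX names can map to the same Python identifier, which the dict now detects instead of silently keeping both tensors in a list.

```python
import re

def sanitize(onnx_name: str) -> str:
    # Stand-in for _Exporter._translate_onnx_var: build a Python-safe identifier.
    return re.sub(r"\W", "_", onnx_name)

skipped_initializers: dict[str, str] = {}
for name in ["layer.0/weight", "layer_0_weight"]:  # both sanitize to "layer_0_weight"
    key = sanitize(name)
    if key in skipped_initializers:
        # Mirrors the RuntimeError raised in the hunk above.
        raise RuntimeError(f"Initializer {name!r} is already present in skipped_initializers.")
    skipped_initializers[key] = name
```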
```diff
@@ -710,41 +715,39 @@ def add(line: str) -> None:
         return script

     def _substitute_initializers(self, script: str, script_function_name: str) -> str:
-        init_names = [self._translate_onnx_var(x.name) for x in self.skipped_initializers]
+        init_names = self.skipped_initializers.keys()
         # Formal parameters representing initializers (single level indentation)
-        initializers_as_params = "\n".join(
-            f"{_SINGLE_INDENT}{x}," for x in init_names
-        )
-        def generate_rand(x: TensorProto) -> str:
-            name = self._translate_onnx_var(x.name)
-            shape = ",".join(str(d) for d in x.dims)
-            if x.data_type != TensorProto.FLOAT:
+        __ = _SINGLE_INDENT
+        initializers_as_params = "\n".join(f"{__}{x}," for x in init_names)
+
+        def generate_rand(name: str, value: TensorProto) -> str:
+            shape = ",".join(str(d) for d in value.dims)
+            if value.data_type != TensorProto.FLOAT:
                 raise NotImplementedError(
-                    f"Unable to generate random initializer for data type {x.data_type}."
+                    f"Unable to generate random initializer for data type {value.data_type}."
                 )
-            return f"{_SINGLE_INDENT}{name} = numpy.random.rand({shape}).astype(numpy.float32)"
+            return f"{__}{name} = numpy.random.rand({shape}).astype(numpy.float32)"

         random_initializer_values = "\n".join(
-            generate_rand(x) for x in self.skipped_initializers
+            generate_rand(key, value) for key, value in self.skipped_initializers.items()
         )
         # Actual parameter values for initializers (double level indentation)
-        indented_initializers_as_params = "\n".join(
-            f"{_SINGLE_INDENT}{_SINGLE_INDENT}{x}," for x in init_names
-        )
+        indented_initializers_as_params = "\n".join(f"{__}{__}{x}," for x in init_names)
         return f"""
 def make_model(
 {initializers_as_params}
 ):
 {script}
-{_SINGLE_INDENT}model = {script_function_name}.to_model_proto()
-{_SINGLE_INDENT}return model
+{__}model = {script_function_name}.to_model_proto()
+{__}return model

 def make_model_with_random_weights():
 {random_initializer_values}
-{_SINGLE_INDENT}model = make_model(
+{__}model = make_model(
 {indented_initializers_as_params}
-{_SINGLE_INDENT})
-{_SINGLE_INDENT}return model
+{__})
+{__}return model
+
 """

     def _import_onnx_types(
```

Code scanning (lintrunner) flagged trailing whitespace on an added line in this hunk: EDITORCONFIG-CHECKER/editorconfig and RUFF/W291. See https://docs.astral.sh/ruff/rules/trailing-whitespace
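To make the generated wrapper concrete, here is roughly what `_substitute_initializers` would emit for a hypothetical model with one FLOAT initializer named `weight` of shape [4, 8] and a script function named `main` (illustrative output assembled from the template above, not verbatim):

```python
def make_model(
    weight,
):
    # ... exported script body (elided); `numpy` is imported by the script preamble ...
    model = main.to_model_proto()
    return model

def make_model_with_random_weights():
    # Shape string "4,8" comes from the initializer's dims, per generate_rand above.
    weight = numpy.random.rand(4,8).astype(numpy.float32)
    model = make_model(
        weight,
    )
    return model
```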
```diff
@@ -831,6 +834,7 @@ def visit_graph(graph: onnx.GraphProto) -> None:
 def export2python(
     model_onnx,
     function_name: Optional[str] = None,
+    *,
     rename: bool = False,
     use_operators: bool = False,
     inline_const: bool = False,
```
```diff
@@ -844,6 +848,9 @@ def export2python(
         function_name: main function name
         use_operators: use Python operators.
         inline_const: replace ONNX constants inline if compact
+        skip_initializers: generated script will not include initializers.
+            Instead, a function that generates the model, given initializer values, is generated,
+            along with one that generates random values for the initializers.

     Returns:
         python code
```
```diff
@@ -869,5 +876,10 @@ def export2python(
     if not isinstance(model_onnx, (ModelProto, FunctionProto)):
         raise TypeError(f"The function expects a ModelProto not {type(model_onnx)!r}.")

-    exporter = Exporter(rename, use_operators, inline_const, skip_initializers)
+    exporter = _Exporter(
+        rename=rename,
+        use_operators=use_operators,
+        inline_const=inline_const,
+        skip_initializers=skip_initializers,
+    )
     return exporter.export(model_onnx, function_name)
```
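Putting it together, a usage sketch of the public entry point with the new flag. The module path and input file name are assumptions for illustration, not taken from this diff.

```python
import onnx

# Module path assumed; adjust to wherever export2python lives in this repository.
from onnxscript.backend.onnx_export import export2python

model = onnx.load("model.onnx")  # hypothetical input model
code = export2python(model, function_name="main", skip_initializers=True)
# With skip_initializers=True, `code` defines make_model(...) and
# make_model_with_random_weights() instead of embedding initializer tensors.
print(code)
```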