Skip to content

Commit

Permalink
refactor checks
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhambhokare1 committed Nov 22, 2024
1 parent c38c2cb commit 26c1020
Showing 1 changed file with 26 additions and 49 deletions.
75 changes: 26 additions & 49 deletions onnxscript/version_converter/_version_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,11 @@ class Replacement:
AdapterFunction = Callable[[ir.Node, orp.RewriterContext], ReturnValue]


@dataclasses.dataclass
class VersionAdapter:
"""A class that represents a version checker for a particular op.
It is applicable for a specific version upgrade (orignal_version -> original_version + 1)
or downgrade (orignal_version -> original_version - 1)of the op.
"""

node_version: int
up_conversion: bool
function: AdapterFunction


class AdapterRegistry:
"""A class that maintains a registry of adapters for ops."""

def __init__(self):
self.op_adapters: dict[tuple[str, str, int, bool], VersionAdapter] = {}
self.op_adapters: dict[tuple[str, str, int, bool], AdapterFunction] = {}

def lookup_adapters(
self,
Expand All @@ -64,9 +51,9 @@ def lookup_adapters(
original_version: int,
up_conversion: bool = True,
) -> AdapterFunction | None:
adapter = self.op_adapters.get((domain, opname, original_version, up_conversion))
if adapter is not None:
return adapter.function
adapter_func = self.op_adapters.get((domain, opname, original_version, up_conversion))
if adapter_func is not None:
return adapter_func
return None

def register(
Expand All @@ -79,9 +66,7 @@ def decorator(function: AdapterFunction) -> AdapterFunction:
def wrapped_function(*args, **kwargs):
return function(*args, **kwargs)

Check warning on line 67 in onnxscript/version_converter/_version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/_version_converter.py#L67

Added line #L67 was not covered by tests

self.op_adapters[(domain, opname, node_version, up_conversion)] = VersionAdapter(
node_version, up_conversion, function
)
self.op_adapters[(domain, opname, node_version, up_conversion)] = function
return wrapped_function

return decorator
Expand Down Expand Up @@ -229,7 +214,7 @@ def _upgrade_version(self, node: ir.Node, opset_version: int, up_conversion: boo
def process_node(
self, node: ir.Node, opset_version: int, up_conversion: bool = True
) -> Replacement | None:
if node.domain not in self.opset_imports:
if node.domain not in {"", "ai.onnx"}:
return None

Check warning on line 218 in onnxscript/version_converter/_version_converter.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/version_converter/_version_converter.py#L218

Added line #L218 was not covered by tests
adapter = registry.lookup_adapters(
node.domain, node.op_type, opset_version, up_conversion
Expand Down Expand Up @@ -277,23 +262,29 @@ def visit_node(
return None

def visit_graph(self, graph: ir.Graph) -> None:
if self.target_version > CURRENT_MAX_ONNX_OPSET:
logger.warning(
"Conversion to target opset: %s not currently supported.",
self.target_version,
)
return None
for node in graph:
up_conversion = True
if node.version is None:
node.version = self.model_version
# Iterate each node from current node version -> target version
# and updating node based on the correct adapter
# Up-conversion [ver->ver+1] or down-conversion [ver->ver-1]
# TODO (shubhambhokare1) : Remove once down-conversion adapters are supoorted
if self.target_version < node.version:
up_conversion = False

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable up_conversion is not used.
logger.warning(
"Target opset: %s less than %s, downstream version conversion not currently handled.",
self.target_version,
self.model_version,
)
return None
for opset_version in range(node.version, self.target_version):
up_conversion = True
if self.target_version < opset_version:
up_conversion = False
if up_conversion is True and opset_version == CURRENT_MAX_ONNX_OPSET:
logger.warning(
"Conversion from opset: %s to target opset: %s not currently supported.",
opset_version,
opset_version + 1,
)
return None
try:
self.visit_node(node, graph, opset_version, up_conversion)
self._upgrade_version(node, opset_version, up_conversion)
Expand All @@ -307,26 +298,12 @@ def visit_graph(self, graph: ir.Graph) -> None:

def visit_model(self, model: ir.Model) -> None:
self.opset_imports = model.opset_imports
for opset_import in self.opset_imports:
if opset_import == "":
model_version = model.opset_imports.get("")
elif opset_import == "ai.onnx":
model_version = model.opset_imports.get("ai.onnx")
else:
return None
model_version = self.opset_imports.get("")
if model_version is None:
return None
model_version = model.opset_imports.get("ai.onnx")
if model_version is None:
return None
self.model_version = model_version

# TODO (shubhambhokare1) : Remove once down-conversion adapters are supoorted
if self.target_version < model_version:
logger.warning(
"Target opset: %s less than %s, downstream version conversion not currently handled.",
self.target_version,
self.model_version,
)
return None

self.visit_graph(model.graph)
return None

Expand Down

0 comments on commit 26c1020

Please sign in to comment.