From 26c10201b2c14cafbafced931879e0be9d6fe4d8 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Fri, 22 Nov 2024 20:13:50 +0000 Subject: [PATCH] refactor checks --- .../version_converter/_version_converter.py | 75 +++++++------------ 1 file changed, 26 insertions(+), 49 deletions(-) diff --git a/onnxscript/version_converter/_version_converter.py b/onnxscript/version_converter/_version_converter.py index cccbdaa36..570347a95 100644 --- a/onnxscript/version_converter/_version_converter.py +++ b/onnxscript/version_converter/_version_converter.py @@ -38,24 +38,11 @@ class Replacement: AdapterFunction = Callable[[ir.Node, orp.RewriterContext], ReturnValue] -@dataclasses.dataclass -class VersionAdapter: - """A class that represents a version checker for a particular op. - - It is applicable for a specific version upgrade (orignal_version -> original_version + 1) - or downgrade (orignal_version -> original_version - 1)of the op. - """ - - node_version: int - up_conversion: bool - function: AdapterFunction - - class AdapterRegistry: """A class that maintains a registry of adapters for ops.""" def __init__(self): - self.op_adapters: dict[tuple[str, str, int, bool], VersionAdapter] = {} + self.op_adapters: dict[tuple[str, str, int, bool], AdapterFunction] = {} def lookup_adapters( self, @@ -64,9 +51,9 @@ def lookup_adapters( original_version: int, up_conversion: bool = True, ) -> AdapterFunction | None: - adapter = self.op_adapters.get((domain, opname, original_version, up_conversion)) - if adapter is not None: - return adapter.function + adapter_func = self.op_adapters.get((domain, opname, original_version, up_conversion)) + if adapter_func is not None: + return adapter_func return None def register( @@ -79,9 +66,7 @@ def decorator(function: AdapterFunction) -> AdapterFunction: def wrapped_function(*args, **kwargs): return function(*args, **kwargs) - self.op_adapters[(domain, opname, node_version, up_conversion)] = VersionAdapter( - node_version, up_conversion, function - ) + self.op_adapters[(domain, opname, node_version, up_conversion)] = function return wrapped_function return decorator @@ -229,7 +214,7 @@ def _upgrade_version(self, node: ir.Node, opset_version: int, up_conversion: boo def process_node( self, node: ir.Node, opset_version: int, up_conversion: bool = True ) -> Replacement | None: - if node.domain not in self.opset_imports: + if node.domain not in {"", "ai.onnx"}: return None adapter = registry.lookup_adapters( node.domain, node.op_type, opset_version, up_conversion @@ -277,23 +262,29 @@ def visit_node( return None def visit_graph(self, graph: ir.Graph) -> None: + if self.target_version > CURRENT_MAX_ONNX_OPSET: + logger.warning( + "Conversion to target opset: %s not currently supported.", + self.target_version, + ) + return None for node in graph: + up_conversion = True if node.version is None: node.version = self.model_version # Iterate each node from current node version -> target version # and updating node based on the correct adapter # Up-conversion [ver->ver+1] or down-conversion [ver->ver-1] + # TODO (shubhambhokare1) : Remove once down-conversion adapters are supoorted + if self.target_version < node.version: + up_conversion = False + logger.warning( + "Target opset: %s less than %s, downstream version conversion not currently handled.", + self.target_version, + self.model_version, + ) + return None for opset_version in range(node.version, self.target_version): - up_conversion = True - if self.target_version < opset_version: - up_conversion = False - if up_conversion is True and opset_version == CURRENT_MAX_ONNX_OPSET: - logger.warning( - "Conversion from opset: %s to target opset: %s not currently supported.", - opset_version, - opset_version + 1, - ) - return None try: self.visit_node(node, graph, opset_version, up_conversion) self._upgrade_version(node, opset_version, up_conversion) @@ -307,26 +298,12 @@ def visit_graph(self, graph: ir.Graph) -> None: def visit_model(self, model: ir.Model) -> None: self.opset_imports = model.opset_imports - for opset_import in self.opset_imports: - if opset_import == "": - model_version = model.opset_imports.get("") - elif opset_import == "ai.onnx": - model_version = model.opset_imports.get("ai.onnx") - else: - return None + model_version = self.opset_imports.get("") if model_version is None: - return None + model_version = model.opset_imports.get("ai.onnx") + if model_version is None: + return None self.model_version = model_version - - # TODO (shubhambhokare1) : Remove once down-conversion adapters are supoorted - if self.target_version < model_version: - logger.warning( - "Target opset: %s less than %s, downstream version conversion not currently handled.", - self.target_version, - self.model_version, - ) - return None - self.visit_graph(model.graph) return None