From a9451590c2fbfcb2ced206f4531684c1915af360 Mon Sep 17 00:00:00 2001 From: luoyueyuguang Date: Tue, 17 Mar 2026 00:07:11 +0800 Subject: [PATCH] feat: add multi-format frontend importer, deps and e2e tests --- python/pyproject.toml | 30 +- python/src/infinitensor/__init__.py | 14 +- python/src/infinitensor/model_frontend.py | 300 ++++++++++++++++++ .../src/infinitensor/torch_fx_translator.py | 22 +- python/tests/stubs/pyinfinitensor.py | 39 +++ python/tests/test_frontend_e2e.py | 139 ++++++++ python/tests/test_model_frontend.py | 79 +++++ requirements.txt | 11 +- 8 files changed, 630 insertions(+), 4 deletions(-) create mode 100644 python/src/infinitensor/model_frontend.py create mode 100644 python/tests/stubs/pyinfinitensor.py create mode 100644 python/tests/test_frontend_e2e.py create mode 100644 python/tests/test_model_frontend.py diff --git a/python/pyproject.toml b/python/pyproject.toml index 8e1b5bf..35969e6 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -15,5 +15,33 @@ classifiers = [ "Operating System :: OS Independent",] dependencies = [] +[project.optional-dependencies] +test = [ + "pytest", + "numpy", + "torch", +] +frontends = [ + "onnx", + "onnxscript", + "onnx2torch", + "tensorflow-cpu", + "tf2onnx", + "paddlepaddle", + "paddle2onnx", +] +all = [ + "pytest", + "numpy", + "torch", + "onnx", + "onnxscript", + "onnx2torch", + "tensorflow-cpu", + "tf2onnx", + "paddlepaddle", + "paddle2onnx", +] + [tool.setuptools.package-data] -pyinfinitensor = ["*.so"] \ No newline at end of file +pyinfinitensor = ["*.so"] diff --git a/python/src/infinitensor/__init__.py b/python/src/infinitensor/__init__.py index 3f1fbf2..112a0d0 100644 --- a/python/src/infinitensor/__init__.py +++ b/python/src/infinitensor/__init__.py @@ -5,6 +5,18 @@ import pyinfinitensor from pyinfinitensor import Runtime, DeviceType +from .model_frontend import ( + FrontendModelImporter, + ModelFormat, + detect_model_format, + load_model_as_torch, +) from .torch_fx_translator 
import TorchFXTranslator -__all__ = ["TorchFXTranslator"] +__all__ = [ + "TorchFXTranslator", + "FrontendModelImporter", + "ModelFormat", + "detect_model_format", + "load_model_as_torch", +] diff --git a/python/src/infinitensor/model_frontend.py b/python/src/infinitensor/model_frontend.py new file mode 100644 index 0000000..46b86bb --- /dev/null +++ b/python/src/infinitensor/model_frontend.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +import tempfile +import importlib +import subprocess +import sys +from enum import Enum +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Union, cast + +import torch + + +class ModelFormat(str, Enum): + PYTORCH = "pytorch" + ONNX = "onnx" + TENSORFLOW = "tensorflow" + PADDLE = "paddle" + + +_SUFFIX_TO_FORMAT: Dict[str, ModelFormat] = { + ".pt": ModelFormat.PYTORCH, + ".pth": ModelFormat.PYTORCH, + ".torchscript": ModelFormat.PYTORCH, + ".ts": ModelFormat.PYTORCH, + ".jit": ModelFormat.PYTORCH, + ".onnx": ModelFormat.ONNX, + ".pb": ModelFormat.TENSORFLOW, + ".pdmodel": ModelFormat.PADDLE, +} + + +def _normalize_format( + format_hint: Optional[Union[str, ModelFormat]], +) -> Optional[ModelFormat]: + if format_hint is None: + return None + if isinstance(format_hint, ModelFormat): + return format_hint + + normalized = format_hint.strip().lower() + aliases: Dict[str, ModelFormat] = { + "pt": ModelFormat.PYTORCH, + "torch": ModelFormat.PYTORCH, + "pytorch": ModelFormat.PYTORCH, + "onnx": ModelFormat.ONNX, + "tf": ModelFormat.TENSORFLOW, + "tensorflow": ModelFormat.TENSORFLOW, + "paddle": ModelFormat.PADDLE, + "paddlepaddle": ModelFormat.PADDLE, + } + if normalized not in aliases: + raise ValueError(f"Unsupported model format hint: {format_hint}") + return aliases[normalized] + + +def detect_model_format( + model_path: Union[str, Path], + format_hint: Optional[Union[str, ModelFormat]] = None, +) -> ModelFormat: + hinted_format = _normalize_format(format_hint) + if hinted_format is not None: + return 
hinted_format + + path = Path(model_path) + if not path.exists(): + raise FileNotFoundError(f"Model path does not exist: {path}") + + if path.is_dir(): + if (path / "saved_model.pb").exists(): + return ModelFormat.TENSORFLOW + + if ( + (path / "__model__").exists() + or any(path.glob("*.pdmodel")) + or (path / "inference.json").exists() + ): + return ModelFormat.PADDLE + + raise ValueError( + f"Unable to infer model format from directory: {path}. " + "Use format_hint to specify one explicitly." + ) + + suffix = path.suffix.lower() + if suffix in _SUFFIX_TO_FORMAT: + return _SUFFIX_TO_FORMAT[suffix] + + if path.name in {"__model__", "inference.json"}: + return ModelFormat.PADDLE + + raise ValueError( + f"Unable to infer model format from file: {path}. " + "Use format_hint to specify one explicitly." + ) + + +def _load_onnx_as_torch(model_path: Path) -> torch.nn.Module: + try: + onnx = importlib.import_module("onnx") + except ImportError as exc: + raise ImportError("Loading ONNX models requires dependency 'onnx'.") from exc + + try: + onnx2torch = importlib.import_module("onnx2torch") + except ImportError as exc: + raise ImportError( + "Loading ONNX models as torch modules requires dependency 'onnx2torch'." 
+ ) from exc + + onnx_model = onnx.load(model_path.as_posix()) + torch_model = onnx2torch.convert(onnx_model) + torch_model.eval() + return torch_model + + +def _tensorflow_to_onnx_model_path(model_path: Path, opset: int) -> Path: + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp_file: + onnx_path = Path(tmp_file.name) + + try: + if model_path.is_dir(): + result = subprocess.run( + [ + sys.executable, + "-m", + "tf2onnx.convert", + "--saved-model", + model_path.as_posix(), + "--output", + onnx_path.as_posix(), + "--opset", + str(opset), + ], + check=False, + capture_output=True, + text=True, + ) + if result.returncode != 0: + raise RuntimeError( + "Failed to convert TensorFlow SavedModel to ONNX: " + f"{result.stderr.strip() or result.stdout.strip()}" + ) + else: + try: + tf2onnx = importlib.import_module("tf2onnx") + except ImportError as exc: + raise ImportError( + "Loading TensorFlow models requires dependency 'tf2onnx'." + ) from exc + + try: + tf = importlib.import_module("tensorflow") + except ImportError as exc: + raise ImportError( + "Converting TensorFlow .pb graph requires dependency 'tensorflow'." + ) from exc + + graph_def = cast(Any, tf).compat.v1.GraphDef() + graph_def.ParseFromString(model_path.read_bytes()) + cast(Any, tf2onnx).convert.from_graph_def( + graph_def, + name=model_path.stem, + output_path=onnx_path.as_posix(), + opset=opset, + ) + except Exception: + onnx_path.unlink(missing_ok=True) + raise + + return onnx_path + + +def _paddle_to_onnx_model_path(model_path: Path, opset: int) -> Path: + try: + paddle2onnx = importlib.import_module("paddle2onnx") + except ImportError as exc: + raise ImportError( + "Loading Paddle models requires dependency 'paddle2onnx'." 
+ ) from exc + + if model_path.is_dir(): + model_candidates = list(model_path.glob("*.pdmodel")) + if not model_candidates and (model_path / "inference.json").exists(): + model_candidates = [model_path / "inference.json"] + params_candidates = list(model_path.glob("*.pdiparams")) + if not model_candidates or not params_candidates: + raise ValueError( + "Paddle directory must contain model file (*.pdmodel or inference.json) " + "and *.pdiparams file." + ) + model_file = model_candidates[0] + params_file = params_candidates[0] + elif model_path.suffix.lower() in {".pdmodel", ".json"}: + model_file = model_path + params_file = model_path.with_suffix(".pdiparams") + if not params_file.exists(): + raise FileNotFoundError( + f"Expected Paddle parameter file next to model: {params_file}" + ) + else: + raise ValueError( + "Paddle model path must be a directory with model/params files " + "or a model file (*.pdmodel or *.json)." + ) + + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp_file: + onnx_path = Path(tmp_file.name) + + try: + cast(Any, paddle2onnx).export( + model_filename=model_file.as_posix(), + params_filename=params_file.as_posix(), + save_file=onnx_path.as_posix(), + opset_version=opset, + enable_onnx_checker=True, + ) + except Exception: + onnx_path.unlink(missing_ok=True) + raise + + return onnx_path + + +def _load_tensorflow_as_torch(model_path: Path, opset: int) -> torch.nn.Module: + onnx_path = _tensorflow_to_onnx_model_path(model_path, opset=opset) + try: + return _load_onnx_as_torch(onnx_path) + finally: + onnx_path.unlink(missing_ok=True) + + +def _load_paddle_as_torch(model_path: Path, opset: int) -> torch.nn.Module: + onnx_path = _paddle_to_onnx_model_path(model_path, opset=opset) + try: + return _load_onnx_as_torch(onnx_path) + finally: + onnx_path.unlink(missing_ok=True) + + +FrontendLoader = Callable[[Path], torch.nn.Module] + + +class FrontendModelImporter: + def __init__(self): + self._loaders: Dict[ModelFormat, 
FrontendLoader] = { + ModelFormat.ONNX: _load_onnx_as_torch, + ModelFormat.TENSORFLOW: self._tensorflow_loader, + ModelFormat.PADDLE: self._paddle_loader, + } + self._default_opset = 17 + + def _tensorflow_loader(self, model_path: Path) -> torch.nn.Module: + return _load_tensorflow_as_torch(model_path, opset=self._default_opset) + + def _paddle_loader(self, model_path: Path) -> torch.nn.Module: + return _load_paddle_as_torch(model_path, opset=self._default_opset) + + def set_default_opset(self, opset: int) -> None: + if opset < 7: + raise ValueError("ONNX opset must be >= 7") + self._default_opset = opset + + def register_loader( + self, model_format: ModelFormat, loader: FrontendLoader + ) -> None: + self._loaders[model_format] = loader + + def list_supported_formats(self): + return sorted(fmt.value for fmt in self._loaders) + + def load( + self, + model_path: Union[str, Path], + format_hint: Optional[Union[str, ModelFormat]] = None, + ) -> torch.nn.Module: + path = Path(model_path) + model_format = detect_model_format(path, format_hint=format_hint) + if model_format == ModelFormat.PYTORCH: + raise ValueError( + "PyTorch format should be handled directly by passing a torch.nn.Module " + "into TorchFXTranslator.import_from_fx()." 
+ ) + + if model_format not in self._loaders: + raise ValueError( + f"No loader registered for model format: {model_format.value}" + ) + + return self._loaders[model_format](path) + + +default_frontend_importer = FrontendModelImporter() + + +def load_model_as_torch( + model_path: Union[str, Path], + format_hint: Optional[Union[str, ModelFormat]] = None, +) -> torch.nn.Module: + return default_frontend_importer.load(model_path, format_hint=format_hint) diff --git a/python/src/infinitensor/torch_fx_translator.py b/python/src/infinitensor/torch_fx_translator.py index e4320ed..bb9ec32 100644 --- a/python/src/infinitensor/torch_fx_translator.py +++ b/python/src/infinitensor/torch_fx_translator.py @@ -11,9 +11,11 @@ import torch from torch import fx from torch.export import export, Dim -from typing import Callable, Dict, List, Tuple, Optional, Union +from typing import Any, Callable, Dict, List, Tuple, Optional, Union from .converter import registry +from .model_frontend import ModelFormat, load_model_as_torch import inspect +import re class TorchFXTranslator: @@ -146,16 +148,19 @@ def _process_dynamic_shapes(self, fake_inputs): def _process_call_function(self, node): """Handle function call nodes""" target = node.target + func_name = str(target) if hasattr(target, "_overloadpacket"): op_name = str(target._overloadpacket).split(".")[-1] overload = target._overloadname function = registry.get_method_converter(op_name, overload) + func_name = f"{op_name}.{overload}" if overload else op_name else: if hasattr(target, "__name__"): op_base_name = target.__name__ else: op_base_name = str(target) function = registry.get_method_converter(op_base_name) + func_name = op_base_name if function: try: self.nodes_map[node] = function @@ -301,6 +306,21 @@ def import_from_fx( # print(self.builder.to_string()) + def import_from_model_path( + self, + model_path: str, + input_list: List[torch.Tensor], + format_hint: Optional[Union[str, ModelFormat]] = None, + is_real_tensor: bool = False, + 
): + """Import model from ONNX / TensorFlow / Paddle path via frontend loader.""" + torch_model = load_model_as_torch(model_path, format_hint=format_hint) + return self.import_from_fx( + torch_model, + input_list, + is_real_tensor=is_real_tensor, + ) + def run(self, input_list: List[torch.Tensor]): """ Run computation graph diff --git a/python/tests/stubs/pyinfinitensor.py b/python/tests/stubs/pyinfinitensor.py new file mode 100644 index 0000000..aca1640 --- /dev/null +++ b/python/tests/stubs/pyinfinitensor.py @@ -0,0 +1,39 @@ +class DeviceType: + CPU = "cpu" + CUDA = "cuda" + MLU = "mlu" + ASCEND = "ascend" + METAX = "metax" + MOORE = "moore" + ILUVATAR = "iluvatar" + KUNLUN = "kunlun" + HYGON = "hygon" + + +class Runtime: + @staticmethod + def setup(*args, **kwargs): + return Runtime() + + +class Tensor: + pass + + +class GraphBuilder: + def __init__(self, runtime): + self.runtime = runtime + + +class ShapeExpr: + def __init__(self, shape): + self.shape = shape + + +class StrideExpr: + def __init__(self, stride): + self.stride = stride + + +def dtype_from_string(dtype): + return dtype diff --git a/python/tests/test_frontend_e2e.py b/python/tests/test_frontend_e2e.py new file mode 100644 index 0000000..1ae86e9 --- /dev/null +++ b/python/tests/test_frontend_e2e.py @@ -0,0 +1,139 @@ +import importlib.util +import subprocess +import sys +import tempfile +from pathlib import Path + +import numpy as np +import torch + + +def _load_model_frontend_module(): + module_path = ( + Path(__file__).resolve().parents[1] + / "src" + / "infinitensor" + / "model_frontend.py" + ) + spec = importlib.util.spec_from_file_location( + "infinitensor_model_frontend", module_path + ) + assert spec and spec.loader + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +frontend = _load_model_frontend_module() + + +def test_onnx_frontend_e2e(): + class Toy(torch.nn.Module): + def forward(self, x): + return x * 2.0 + 1.0 + + model = Toy().eval() 
def test_onnx_frontend_e2e():
    """End-to-end: torch module -> ONNX file -> load_model_as_torch round trip."""

    class Toy(torch.nn.Module):
        def forward(self, x):
            return x * 2.0 + 1.0

    model = Toy().eval()
    x = torch.randn(2, 3, dtype=torch.float32)

    with tempfile.TemporaryDirectory() as td:
        onnx_path = Path(td) / "toy.onnx"
        torch.onnx.export(
            model,
            (x,),
            onnx_path.as_posix(),
            input_names=["x"],
            output_names=["y"],
            opset_version=17,
        )

        loaded = frontend.load_model_as_torch(onnx_path)
        loaded.eval()

        y_ref = model(x)
        y_loaded = loaded(x)
        max_diff = (y_ref - y_loaded).abs().max().item()
        assert max_diff < 1e-5


def test_tensorflow_savedmodel_frontend_e2e():
    """End-to-end: TF SavedModel -> ONNX -> torch module matches reference math."""
    tf = __import__("tensorflow")

    class ToyTF(tf.Module):
        @tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
        def __call__(self, x):
            return {"y": x * 3.0 - 2.0}

    x_np = np.random.randn(4, 3).astype(np.float32)
    x_torch = torch.from_numpy(x_np)

    with tempfile.TemporaryDirectory() as td:
        saved_dir = Path(td) / "saved_model"
        tf.saved_model.save(ToyTF(), saved_dir.as_posix())

        loaded = frontend.load_model_as_torch(saved_dir)
        loaded.eval()

        y_ref = x_torch * 3.0 - 2.0
        y_loaded = loaded(x_torch)
        max_diff = (y_ref - y_loaded).abs().max().item()
        assert max_diff < 1e-4


# Script run in a fresh interpreter: Paddle's static-graph global state does
# not coexist well with torch/TF already imported in this process.
# __MODULE_PATH__ is substituted with the absolute path to model_frontend.py.
_PADDLE_E2E_SCRIPT = """
import importlib.util
import tempfile
from pathlib import Path
import numpy as np
import torch
import paddle

module_path = Path(r'__MODULE_PATH__')
spec = importlib.util.spec_from_file_location('model_frontend', module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

paddle.enable_static()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
startup = paddle.static.Program()
main = paddle.static.Program()

with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name='x', shape=[-1, 3], dtype='float32')
    w = paddle.create_parameter(
        shape=[1],
        dtype='float32',
        default_initializer=paddle.nn.initializer.Constant(0.5),
    )
    y = x * 4.0 + w

exe.run(startup)

x_np = np.random.randn(5, 3).astype('float32')
x_torch = torch.from_numpy(x_np)

with tempfile.TemporaryDirectory() as td:
    save_dir = Path(td) / 'paddle_model'
    save_dir.mkdir(parents=True, exist_ok=True)
    paddle.static.save_inference_model(
        path_prefix=(save_dir / 'inference').as_posix(),
        feed_vars=[x],
        fetch_vars=[y],
        executor=exe,
        program=main,
    )
    loaded = module.load_model_as_torch(save_dir)
    loaded.eval()

    y_ref = x_torch * 4.0 + 0.5
    y_loaded = loaded(x_torch)
    max_diff = (y_ref - y_loaded).abs().max().item()
    assert max_diff < 1e-4
"""


def test_paddle_frontend_e2e():
    """End-to-end: Paddle static model -> ONNX -> torch, run in a subprocess.

    Fix over the original: the subprocess script previously hardcoded an
    absolute machine-local path to model_frontend.py, so the test only passed
    on one developer's machine. The path is now derived from this test file's
    location (same layout as _load_model_frontend_module) and injected into
    the script.
    """
    module_path = (
        Path(__file__).resolve().parents[1]
        / "src"
        / "infinitensor"
        / "model_frontend.py"
    )
    script = _PADDLE_E2E_SCRIPT.replace("__MODULE_PATH__", module_path.as_posix())
    result = subprocess.run(
        [sys.executable, "-c", script],
        check=False,
        capture_output=True,
        text=True,
    )
    assert result.returncode == 0, result.stderr or result.stdout
def test_detect_model_format_for_directories(tmp_path):
    """SavedModel and Paddle inference directories are recognized by marker files."""
    tf_dir = tmp_path / "saved_model"
    tf_dir.mkdir()
    (tf_dir / "saved_model.pb").write_bytes(b"pb")
    assert frontend.detect_model_format(tf_dir) == frontend.ModelFormat.TENSORFLOW

    paddle_dir = tmp_path / "paddle_model"
    paddle_dir.mkdir()
    (paddle_dir / "inference.pdmodel").write_bytes(b"pdmodel")
    (paddle_dir / "inference.pdiparams").write_bytes(b"params")
    assert frontend.detect_model_format(paddle_dir) == frontend.ModelFormat.PADDLE


def test_detect_model_format_hint_overrides_path(tmp_path):
    """An explicit format hint wins over an unrecognizable file suffix."""
    blob = tmp_path / "weights.bin"
    blob.write_bytes(b"data")
    detected = frontend.detect_model_format(blob, format_hint="onnx")
    assert detected == frontend.ModelFormat.ONNX


def test_detect_model_format_unknown_raises(tmp_path):
    """Without a hint, an unknown suffix must raise ValueError."""
    blob = tmp_path / "weights.bin"
    blob.write_bytes(b"data")
    with pytest.raises(ValueError):
        frontend.detect_model_format(blob)


def test_frontend_importer_custom_loader_dispatch(tmp_path):
    """A user-registered loader replaces the built-in ONNX loading path."""

    class DummyModule(torch.nn.Module):
        def forward(self, x):
            return x

    onnx_file = tmp_path / "model.onnx"
    onnx_file.write_bytes(b"onnx")

    importer = frontend.FrontendModelImporter()
    importer.register_loader(frontend.ModelFormat.ONNX, lambda _path: DummyModule())
    assert isinstance(importer.load(onnx_file), DummyModule)