diff --git a/paddlex/inference/pipelines/__init__.py b/paddlex/inference/pipelines/__init__.py index 1ae8758244..36d9c7c517 100644 --- a/paddlex/inference/pipelines/__init__.py +++ b/paddlex/inference/pipelines/__init__.py @@ -164,6 +164,14 @@ def create_pipeline( else: config.pop("hpi_config", None) + # 支持自定义pipeline类的加载,根据pipeline_cls + pipeline_cls = config.pop("pipeline_cls", None) + if pipeline_cls and isinstance(pipeline_cls, str): + # 支持"module.submodule:ClassName"格式的字符串导入 + if ":" in pipeline_cls: + module_path, class_name = pipeline_cls.rsplit(":", 1) + __import__(module_path, fromlist=[class_name]) + pipeline = BasePipeline.get(pipeline_name)( config=config, device=device, diff --git a/paddlex/inference/serving/basic_serving/_pipeline_apps/__init__.py b/paddlex/inference/serving/basic_serving/_pipeline_apps/__init__.py index f940933524..d484ef8b2f 100644 --- a/paddlex/inference/serving/basic_serving/_pipeline_apps/__init__.py +++ b/paddlex/inference/serving/basic_serving/_pipeline_apps/__init__.py @@ -36,8 +36,18 @@ def _pipeline_name_to_mod_name(pipeline_name: str) -> str: @function_requires_deps("fastapi") def create_pipeline_app(pipeline: Any, pipeline_config: Dict[str, Any]) -> "FastAPI": pipeline_name = pipeline_config["pipeline_name"] - mod_name = _pipeline_name_to_mod_name(pipeline_name) - mod = importlib.import_module(f".{mod_name}", package=__package__) + # 支持自定义pipeline类的加载,根据pipeline_cls + pipeline_cls = pipeline_config["pipeline_cls"] + if pipeline_cls and isinstance(pipeline_cls, str): + # 支持"module.submodule:ClassName"格式的字符串导入 + if ":" in pipeline_cls: + module_path, class_name = pipeline_cls.rsplit(":", 1) + mod=importlib.import_module(module_path, package=__package__) + else: + mod=importlib.import_module(pipeline_cls, package=__package__) + else: + mod_name = _pipeline_name_to_mod_name(pipeline_name) + mod = importlib.import_module(f".{mod_name}", package=__package__) app_config = create_app_config(pipeline_config) app_creator = getattr(mod, "create_pipeline_app") app = app_creator(pipeline, app_config)