Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
upgrade base
Browse files Browse the repository at this point in the history
TingquanGao committed Nov 4, 2024
1 parent 4a0fa12 commit 01509c9
Showing 8 changed files with 25 additions and 73 deletions.
25 changes: 16 additions & 9 deletions paddlex/inference/components/base.py
Original file line number Diff line number Diff line change
@@ -40,10 +40,6 @@ def __init__(self):
self.apply = self.timer.watch_func(self.apply)

def __call__(self, input_list):
# use list type for batched data
if not isinstance(input_list, list):
input_list = [input_list]

output_list = []
for args, input_ in self._check_input(input_list):
output = self.apply(**args)
@@ -107,10 +103,15 @@ def _check_args_key(args):
f"The parameter ({param.name}) is needed by {self.__class__.__name__}, but {list(args.keys())} only found!"
)

# use list type for batched data
if not isinstance(input_list, list):
input_list = [input_list]

if self.need_batch_input:
args = {}
for input_ in input_list:
input_ = _check_type(input_)
for idx in range(len(input_list)):
input_ = _check_type(input_list[idx])
input_list[idx] = input_
for k, v in self.inputs.items():
if v not in input_:
raise Exception(
@@ -141,7 +142,9 @@ def _check_output(self, output, ori_data):
# when the output data is list type, reassemble each of that
if isinstance(output, list):
if self.need_batch_input:
assert isinstance(ori_data, list) and len(ori_data) == len(output)
assert isinstance(ori_data, list) and len(ori_data) == len(
output
), f"Error in {self.__class__.__name__}"
output_list = []
for ori_item, output_item in zip(ori_data, output):
data = ori_item.copy() if self.keep_input else {}
@@ -179,8 +182,12 @@ def _check_output(self, output, ori_data):
output_list.append(data)
return output_list
else:
assert isinstance(ori_data, dict) and isinstance(output, dict)
data = ori_data.copy() if self.keep_input else {}
assert isinstance(output, dict), f"Error in {self.__class__.__name__}"
if self.keep_input:
assert isinstance(ori_data, dict), f"Error in {self.__class__.__name__}"
data = ori_data.copy()
else:
data = {}
if isinstance(self.outputs, type(None)):
logging.debug(
f"The `output_key` of {self.__class__.__name__} is None, so would not inspect!"
2 changes: 2 additions & 0 deletions paddlex/inference/components/task_related/text_det.py
Original file line number Diff line number Diff line change
@@ -425,6 +425,8 @@ def apply(self, pred, img_shape):
class CropByPolys(BaseComponent):
"""Crop Image by Polys"""

YIELD_BATCH = False

INPUT_KEYS = ["input_path", "dt_polys"]
OUTPUT_KEYS = ["img"]
DEAULT_INPUTS = {"input_path": "input_path", "dt_polys": "dt_polys"}
1 change: 0 additions & 1 deletion paddlex/inference/models/anomaly_detection.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,6 @@
from ...modules.anomaly_detection.model_list import MODELS
from ..components import *
from ..results import SegResult
from ..utils.process_hook import batchable_method
from .base import BasicPredictor


2 changes: 1 addition & 1 deletion paddlex/inference/models/base/base_predictor.py
Original file line number Diff line number Diff line change
@@ -18,10 +18,10 @@
from abc import abstractmethod

from ...components.base import BaseComponent
from ...utils.process_hook import generatorable_method


class BasePredictor(BaseComponent):
ENABLE_BATCH = True

KEEP_INPUT = False
YIELD_BATCH = False
9 changes: 3 additions & 6 deletions paddlex/inference/models/base/basic_predictor.py
Original file line number Diff line number Diff line change
@@ -24,7 +24,6 @@
from ....utils import logging
from ...components.base import BaseComponent, ComponentsEngine
from ...utils.pp_option import PaddlePredictorOption
from ...utils.process_hook import generatorable_method
from ...utils.benchmark import Benchmark
from .base_predictor import BasePredictor

@@ -68,11 +67,9 @@ def __call__(self, input, **kwargs):

def apply(self, input):
"""predict"""
yield from self._generate_res(self.engine(input))

@generatorable_method
def _generate_res(self, batch_data):
return [{"result": self._pack_res(data)} for data in batch_data]
for batch_data in self.engine(input):
for single_data in batch_data:
yield {"result": self._pack_res(single_data)}

def _add_component(self, cmps):
if not isinstance(cmps, list):
1 change: 0 additions & 1 deletion paddlex/inference/models/multilabel_classification.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,6 @@
from ...modules.multilabel_classification.model_list import MODELS
from ..components import *
from ..results import MLClassResult
from ..utils.process_hook import batchable_method
from .image_classification import ClasPredictor


4 changes: 3 additions & 1 deletion paddlex/inference/pipelines/ocr.py
Original file line number Diff line number Diff line change
@@ -73,7 +73,9 @@ def predict(self, input, **kwargs):
single_img_res["rec_text"] = []
single_img_res["rec_score"] = []
if len(single_img_res["dt_polys"]) > 0:
all_subs_of_img = list(self._crop_by_polys(single_img_res))
all_subs_of_img = [
sub["img"] for sub in self._crop_by_polys(single_img_res)
]
for rec_res in self.text_rec_model(all_subs_of_img):
single_img_res["rec_text"].append(rec_res["rec_text"])
single_img_res["rec_score"].append(rec_res["rec_score"])
54 changes: 0 additions & 54 deletions paddlex/inference/utils/process_hook.py

This file was deleted.

0 comments on commit 01509c9

Please sign in to comment.