From ced80cf1cf12b330c20011a041ca4ae9cd3cd624 Mon Sep 17 00:00:00 2001 From: liuhongen1234567 <2998388548@qq.com> Date: Tue, 4 Mar 2025 06:55:38 +0000 Subject: [PATCH 1/2] add batch=1 log --- .../models/formula_recognition/predictor.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/paddlex/inference/models/formula_recognition/predictor.py b/paddlex/inference/models/formula_recognition/predictor.py index 3f41b1a4f..d092fdb1b 100644 --- a/paddlex/inference/models/formula_recognition/predictor.py +++ b/paddlex/inference/models/formula_recognition/predictor.py @@ -38,6 +38,7 @@ class FormulaRecPredictor(BasicPredictor): + """FormulaRecPredictor that inherits from BasicPredictor.""" entities = MODELS @@ -45,7 +46,23 @@ class FormulaRecPredictor(BasicPredictor): register = FuncRegister(_FUNC_MAP) def __init__(self, *args, **kwargs): + """Initializes FormulaRecPredictor. + Args: + *args: Arbitrary positional arguments passed to the superclass. + **kwargs: Arbitrary keyword arguments passed to the superclass. + """ super().__init__(*args, **kwargs) + + self.model_names_only_supports_batchsize_of_one = { + "LaTeX_OCR_rec", + } + if self.model_name in self.model_names_only_supports_batchsize_of_one: + logging.warning( + f"Formula Recognition Models: \"{', '.join(list(self.model_names_only_supports_batchsize_of_one))}\" only supports prediction with a batch_size of one, " + "if you set the predictor with a batch_size larger than one, no error will occur, however, it will actually inference with a batch_size of one, " + f"which will lead to a slower inference speed. You are now using {self.config['Global']['model_name']}." + ) + self.pre_tfs, self.infer, self.post_op = self._build() def _build_batch_sampler(self): From 76cdde4721d4f3b21a9d7223b05966cbef20b8d0 Mon Sep 17 00:00:00 2001 From: liuhongen1234567 <2998388548@qq.com> Date: Wed, 5 Mar 2025 08:58:04 +0000 Subject: [PATCH 2/2] fix batch=1 bug --- .../models/formula_recognition/predictor.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/paddlex/inference/models/formula_recognition/predictor.py b/paddlex/inference/models/formula_recognition/predictor.py index d092fdb1b..58bdc48c5 100644 --- a/paddlex/inference/models/formula_recognition/predictor.py +++ b/paddlex/inference/models/formula_recognition/predictor.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np from ....utils import logging from ....utils.func_register import FuncRegister from ....modules.formula_recognition.model_list import MODELS @@ -108,9 +109,25 @@ def process(self, batch_data): batch_imgs = self.pre_tfs["UniMERNetTestTransform"](imgs=batch_imgs) batch_imgs = self.pre_tfs["LatexImageFormat"](imgs=batch_imgs) - x = self.pre_tfs["ToBatch"](imgs=batch_imgs) - batch_preds = self.infer(x=x) - batch_preds = [p.reshape([-1]) for p in batch_preds[0]] + if self.model_name in self.model_names_only_supports_batchsize_of_one: + batch_preds = [] + max_length = 0 + for batch_img in batch_imgs: + batch_pred_ = self.infer([batch_img])[0].reshape([-1]) + max_length = max(max_length, batch_pred_.shape[0]) + batch_preds.append(batch_pred_) + for i in range(len(batch_preds)): + batch_preds[i] = np.pad( + batch_preds[i], + (0, max_length - batch_preds[i].shape[0]), + mode="constant", + constant_values=0, + ) + else: + x = self.pre_tfs["ToBatch"](imgs=batch_imgs) + batch_preds = self.infer(x=x) + batch_preds = [p.reshape([-1]) for p in batch_preds[0]] + rec_formula = self.post_op(batch_preds) return { "input_path": batch_data.input_paths,