Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add batch=1 log #3516

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions paddlex/inference/models/formula_recognition/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from ....utils import logging
from ....utils.func_register import FuncRegister
from ....modules.formula_recognition.model_list import MODELS
Expand All @@ -38,14 +39,31 @@


class FormulaRecPredictor(BasicPredictor):
    """FormulaRecPredictor that inherits from BasicPredictor."""

    # Registry of model names this predictor can serve (from
    # paddlex.modules.formula_recognition.model_list).
    entities = MODELS

    # Maps preprocessing-op names to builder functions; populated via the
    # @register decorator on methods defined later in this class.
    _FUNC_MAP = {}
    register = FuncRegister(_FUNC_MAP)

def __init__(self, *args, **kwargs):
"""Initializes FormulaRecPredictor.
Args:
*args: Arbitrary positional arguments passed to the superclass.
**kwargs: Arbitrary keyword arguments passed to the superclass.
"""
super().__init__(*args, **kwargs)

self.model_names_only_supports_batchsize_of_one = {
"LaTeX_OCR_rec",
}
if self.model_name in self.model_names_only_supports_batchsize_of_one:
logging.warning(
f"Formula Recognition Models: \"{', '.join(list(self.model_names_only_supports_batchsize_of_one))}\" only supports prediction with a batch_size of one, "
"if you set the predictor with a batch_size larger than one, no error will occur, however, it will actually inference with a batch_size of one, "
f"which will lead to a slower inference speed. You are now using {self.config['Global']['model_name']}."
)

self.pre_tfs, self.infer, self.post_op = self._build()

def _build_batch_sampler(self):
Expand Down Expand Up @@ -91,9 +109,25 @@ def process(self, batch_data):
batch_imgs = self.pre_tfs["UniMERNetTestTransform"](imgs=batch_imgs)
batch_imgs = self.pre_tfs["LatexImageFormat"](imgs=batch_imgs)

x = self.pre_tfs["ToBatch"](imgs=batch_imgs)
batch_preds = self.infer(x=x)
batch_preds = [p.reshape([-1]) for p in batch_preds[0]]
if self.model_name in self.model_names_only_supports_batchsize_of_one:
batch_preds = []
max_length = 0
for batch_img in batch_imgs:
batch_pred_ = self.infer([batch_img])[0].reshape([-1])
max_length = max(max_length, batch_pred_.shape[0])
batch_preds.append(batch_pred_)
for i in range(len(batch_preds)):
batch_preds[i] = np.pad(
batch_preds[i],
(0, max_length - batch_preds[i].shape[0]),
mode="constant",
constant_values=0,
)
else:
x = self.pre_tfs["ToBatch"](imgs=batch_imgs)
batch_preds = self.infer(x=x)
batch_preds = [p.reshape([-1]) for p in batch_preds[0]]

rec_formula = self.post_op(batch_preds)
return {
"input_path": batch_data.input_paths,
Expand Down