# from .convert_points_and_boxes import convert_points_to_boxes


def is_vertical_text(box: np.ndarray) -> bool:
    """Determine if the text box is vertical based on its aspect ratio.

    Args:
        box: Array of points indexed as ``box[:, 0]`` (x) and ``box[:, 1]`` (y).

    Returns:
        True when the axis-aligned height of the points is more than 1.5
        times their axis-aligned width, otherwise False.
    """
    width = box[:, 0].max() - box[:, 0].min()
    height = box[:, 1].max() - box[:, 1].min()
    # Cross-multiplied form of ``height / width > 1.5``: gives the same answer
    # for every non-negative extent but never divides, so a zero-width box no
    # longer triggers a divide-by-zero warning (or error).
    return bool(height > 1.5 * width)


def cal_ocr_word_box(rec_str, box, rec_word_info):
    """Calculate the detection frame for each word based on the results of recognition and detection of ocr.

    The text box is split into ``col_num`` equal cells along its reading axis
    (y for vertical boxes, x otherwise) and every word is mapped to the span
    of cells it occupies.

    Args:
        rec_str: Recognized text; its length is only used as a fallback to
            estimate the size of a single "cn" character.
        box: ``np.ndarray`` of box points. ``box[0]`` and ``box[2]`` are read
            as the start/end corners (assumes they are opposite corners,
            e.g. top-left and bottom-right - TODO confirm against caller).
        rec_word_info: Tuple ``(col_num, word_list, word_col_list, state_list)``
            where each word has the list of columns it covers and a state
            string (``"cn"`` words are handled per character).

    Returns:
        Tuple ``(word_box_content_list, word_box_list)``: the word/character
        contents and the boxes (4-tuples of ``(x, y)`` points) after
        ``sort_boxes`` has been applied to the boxes.
    """
    is_vertical = is_vertical_text(box)
    col_num, word_list, word_col_list, state_list = rec_word_info
    box = box.tolist()
    bbox_x_start = box[0][0]
    bbox_x_end = box[2][0]
    bbox_y_start = box[0][1]
    bbox_y_end = box[2][1]

    # Pick the axis along which the columns are laid out and a factory that
    # turns a [start, end] span on that axis into a 4-point box.
    if is_vertical:
        bbox_start = bbox_y_start
        bbox_size = bbox_y_end - bbox_y_start

        def create_box(start, end):
            return (
                (bbox_x_start, start),
                (bbox_x_start, end),
                (bbox_x_end, end),
                (bbox_x_end, start),
            )

    else:
        bbox_start = bbox_x_start
        bbox_size = bbox_x_end - bbox_x_start

        def create_box(start, end):
            return (
                (start, bbox_y_start),
                (end, bbox_y_start),
                (end, bbox_y_end),
                (start, bbox_y_end),
            )

    # Size of one recognition column along the reading axis.
    cell_size = bbox_size / col_num

    word_box_list = []
    word_box_content_list = []
    cn_size_list = []
    cn_col_list = []

    # Process words
    for word, word_col, state in zip(word_list, word_col_list, state_list):
        if state == "cn":
            # "cn" words are emitted one character at a time; their boxes are
            # built after the loop, once an average character size is known.
            if len(word_col) != 1:
                char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_size
                char_size = char_seq_length / (len(word_col) - 1)
                cn_size_list.append(char_size)
            cn_col_list += word_col
            word_box_content_list += word
        else:
            cell_start = bbox_start + int(word_col[0] * cell_size)
            cell_end = bbox_start + int((word_col[-1] + 1) * cell_size)
            word_box_list.append(create_box(cell_start, cell_end))
            word_box_content_list.append("".join(word))

    if len(cn_col_list) != 0:
        # Fall back to an even split of the box when no multi-column "cn"
        # word was available to measure a character size from.
        avg_char_size = (
            np.mean(cn_size_list) if cn_size_list else bbox_size / len(rec_str)
        )
        for center_idx in cn_col_list:
            center = (center_idx + 0.5) * cell_size
            # Clamp the character span to the box before offsetting it.
            cell_start = max(int(center - avg_char_size / 2), 0) + bbox_start
            cell_end = min(int(center + avg_char_size / 2), bbox_size) + bbox_start
            word_box_list.append(create_box(cell_start, cell_end))

    # NOTE(review): only the boxes are re-ordered here; the contents keep
    # their insertion order - confirm sort_boxes keeps the two lists aligned.
    word_box_list = sort_boxes(word_box_list, y_thresh=12)
    return word_box_content_list, word_box_list