diff --git a/paddlex/inference/models/object_detection/modeling/__init__.py b/paddlex/inference/models/object_detection/modeling/__init__.py
index acc3858aaf..b68e445b3f 100644
--- a/paddlex/inference/models/object_detection/modeling/__init__.py
+++ b/paddlex/inference/models/object_detection/modeling/__init__.py
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from .pp_doclayout_v2 import PPDocLayoutV2
 from .rt_detr import RTDETR
diff --git a/paddlex/inference/models/object_detection/modeling/pp_doclayout_v2.py b/paddlex/inference/models/object_detection/modeling/pp_doclayout_v2.py
new file mode 100644
index 0000000000..f3d3cb9df0
--- /dev/null
+++ b/paddlex/inference/models/object_detection/modeling/pp_doclayout_v2.py
@@ -0,0 +1,464 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import, division, print_function
+
+import paddle
+import paddle.nn.functional as F
+
+from .pp_doclayout_v2_modules.pp_doclayout_v2_transformer import PPDocLayoutTransformer
+from .rt_detr import RTDETR, DETRPostProcess, RTDETRConfig
+from .rtdetrl_modules.modules.utils import bbox_cxcywh_to_xyxy
+
+__all__ = ["PPDocLayoutV2"]
+
+
+def get_order(order_logits):
+    # order_logits[b, i, j] scores "element i precedes element j". Masking the
+    # pairwise matrix to its upper/lower triangles and summing columns counts
+    # each element's predecessors ("votes"); argsort over the votes then
+    # yields the reading order.
+    order_scores = paddle.nn.functional.sigmoid(order_logits)
+    B, N, _ = order_scores.shape
+    one = paddle.ones([N, N], dtype=order_scores.dtype)
+    upper = paddle.triu(one, 1)
+    lower = paddle.tril(one, -1)
+    Q = order_scores * upper + (1.0 - paddle.transpose(order_scores, [0, 2, 1])) * lower
+    order_votes = paddle.sum(Q, axis=1)
+    order_pointers = paddle.argsort(order_votes, axis=1)
+    order_seq = paddle.full(order_pointers.shape, -1, dtype=order_pointers.dtype)
+    batch_indices = paddle.arange(B).reshape([-1, 1]).expand([B, N])
+    order_seq[batch_indices, order_pointers] = paddle.arange(N).expand([B, N])
+
+    return order_seq, order_votes
+
+
+class PPDocLayoutPostProcess(DETRPostProcess):
+    def __init__(self, **kwargs):
+        kwargs.setdefault("num_classes", 25)
+        super(PPDocLayoutPostProcess, self).__init__(**kwargs)
+
+    def __call__(self, head_out, order_logits, im_shape, scale_factor, pad_shape):
+        """
+        Decode the bbox, reading order, and mask.
+
+        Args:
+            head_out (tuple): bbox_pred, cls_logit and masks of bbox_head output.
+            order_logits (Tensor): Pairwise reading-order logits from the
+                reading-order head.
+            im_shape (Tensor): The shape of the input image without padding.
+            scale_factor (Tensor): The scale factor of the input image.
+            pad_shape (Tensor): The shape of the input image with padding.
+        Returns:
+            bbox_pred (Tensor): The output prediction with shape [N, 8]: label,
+                score, bbox (x1, y1, x2, y2), reading-order index, and
+                reading-order votes. The bboxes are scaled to the input image
+                and may be used by other branches.
+            bbox_num (Tensor): The number of prediction boxes of each batch,
+                with shape [bs]; every entry equals `num_top_queries`.
+ """ + bboxes, logits, masks = head_out + if self.dual_queries: + num_queries = logits.shape[1] + logits, bboxes = ( + logits[:, : int(num_queries // (self.dual_groups + 1)), :], + bboxes[:, : int(num_queries // (self.dual_groups + 1)), :], + ) + + bbox_pred = bbox_cxcywh_to_xyxy(bboxes) + + # calculate the original shape of the image + origin_shape = paddle.floor(im_shape / scale_factor + 0.5) + img_h, img_w = paddle.split(origin_shape, 2, axis=-1) + if self.bbox_decode_type == "pad": + # calculate the shape of the image with padding + out_shape = pad_shape / im_shape * origin_shape + out_shape = out_shape.flip(1).tile([1, 2]).unsqueeze(1) + elif self.bbox_decode_type == "origin": + out_shape = origin_shape.flip(1).tile([1, 2]).unsqueeze(1) + else: + raise Exception(f"Wrong `bbox_decode_type`: {self.bbox_decode_type}.") + bbox_pred *= out_shape + + scores = ( + F.sigmoid(logits) if self.use_focal_loss else F.softmax(logits)[:, :, :-1] + ) + + pad_order_seq, pad_order_votes = get_order(order_logits) + + if not self.use_focal_loss: + scores, labels = scores.max(-1), scores.argmax(-1) + if scores.shape[1] > self.num_top_queries: + scores, index = paddle.topk(scores, self.num_top_queries, axis=-1) + batch_ind = ( + paddle.arange(end=scores.shape[0]) + .unsqueeze(-1) + .tile([1, self.num_top_queries]) + ) + index = paddle.stack([batch_ind, index], axis=-1) + labels = paddle.gather_nd(labels, index) + bbox_pred = paddle.gather_nd(bbox_pred, index) + pad_order_seq = paddle.gather_nd(pad_order_seq, index) + pad_order_votes = paddle.gather_nd(pad_order_votes, index) + else: + scores, index = paddle.topk( + scores.flatten(1), self.num_top_queries, axis=-1 + ) + labels = index % self.num_classes + index = index // self.num_classes + batch_ind = ( + paddle.arange(end=scores.shape[0]) + .unsqueeze(-1) + .tile([1, self.num_top_queries]) + ) + index = paddle.stack([batch_ind, index], axis=-1) + bbox_pred = paddle.gather_nd(bbox_pred, index) + pad_order_seq = paddle.gather_nd(pad_order_seq, index) + pad_order_votes = paddle.gather_nd(pad_order_votes, index) + + mask_pred = None + if self.with_mask: + assert masks is not None + assert masks.shape[0] == 1 + masks = paddle.gather_nd(masks, index) + if self.bbox_decode_type == "pad": + masks = F.interpolate( + masks, + scale_factor=self.mask_stride, + mode="bilinear", + align_corners=False, + ) + # TODO: Support prediction with bs>1. + # remove padding for input image + h, w = im_shape.astype("int32")[0] + masks = masks[..., :h, :w] + # get pred_mask in the original resolution. 
+                img_h = img_h[0].astype("int32")
+                img_w = img_w[0].astype("int32")
+                masks = F.interpolate(
+                    masks, size=[img_h, img_w], mode="bilinear", align_corners=False
+                )
+                mask_pred, scores = self._mask_postprocess(masks, scores)
+
+        bbox_pred = paddle.concat(
+            [
+                labels.unsqueeze(-1).astype("float32"),
+                scores.unsqueeze(-1),
+                bbox_pred,
+                pad_order_seq.unsqueeze(-1).astype("float32"),
+                pad_order_votes.unsqueeze(-1).astype("float32"),
+            ],
+            axis=-1,
+        )
+        bbox_num = paddle.to_tensor(self.num_top_queries, dtype="int32").tile(
+            [bbox_pred.shape[0]]
+        )
+        bbox_pred = bbox_pred.reshape([-1, 8])
+        return bbox_pred, bbox_num, mask_pred
+
+
+class PPDocLayoutV2Config(RTDETRConfig):
+    pass
+
+
+class PPDocLayoutV2(RTDETR):
+
+    config_class = PPDocLayoutV2Config
+
+    def __init__(self, config: PPDocLayoutV2Config):
+        super(PPDocLayoutV2, self).__init__(config)
+
+        self.transformer = PPDocLayoutTransformer(
+            num_queries=self.config.tf_num_queries,
+            position_embed_type=self.config.tf_position_embed_type,
+            feat_strides=self.config.tf_feat_strides,
+            backbone_feat_channels=self.config.tf_backbone_feat_channels,
+            num_levels=self.config.tf_num_levels,
+            nhead=self.config.tf_nhead,
+            num_decoder_layers=self.config.tf_num_decoder_layers,
+            dim_feedforward=self.config.tf_dim_feedforward,
+            dropout=self.config.tf_dropout,
+            activation=self.config.tf_activation,
+            num_denoising=self.config.tf_num_denoising,
+            label_noise_ratio=self.config.tf_label_noise_ratio,
+            box_noise_scale=self.config.tf_box_noise_scale,
+            learnt_init_query=self.config.tf_learnt_init_query,
+        )
+
+        self.post_process = PPDocLayoutPostProcess(
+            num_top_queries=self.config.num_top_queries,
+            use_focal_loss=self.config.use_focal_loss,
+        )
+
+    def forward(self, inputs):
+        x = paddle.to_tensor(inputs[1])
+        x = self.backbone(x)
+        x_neck = self.neck(x)
+        x = self.transformer(x_neck)
+        order_logits = x[-1]
+        preds = self.head(x[:-1], x_neck)
+        bbox, bbox_num, mask = self.post_process(
+            preds,
+            order_logits,
+            paddle.to_tensor(inputs[0]),
+            paddle.to_tensor(inputs[2]),
+            inputs[1].shape[2:],  # pad_shape: spatial (H, W) of the padded input
+        )
+
+        output = [bbox, bbox_num]
+        return output
+
+    def get_transpose_weight_keys(self):
+        t_layers = [
+            "fc",
+            "out_proj",
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "linear_1",
+            "linear_2",
+            "enc_bbox_head",
+            "spatial_proj",
+            "query",
+            "key",
+            "value",
+            "intermediate",
+            "attention",
+            "output",
+            "relative_head",
+            "query_pos_head",
+            "enc_score_head",
+            "in_proj_weight",
+            "linear1",
+            "linear2",
+            "label_features_projection",
+            "reading_order_predictor.encoder.layer",
+            "encoder_attn",
+            "decoder.bbox_embed",
+            "decoder.class_embed",
+        ]
+        keys = []
+        for key, _ in self.get_hf_state_dict().items():
+            for t_layer in t_layers:
+                if (
+                    t_layer in key
+                    and key.endswith("weight")
+                    and "LayerNorm" not in key
+                    and "layer_norm" not in key
+                    and "enc_output.1" not in key
+                ):
+                    keys.append(key)
+
+        return keys
+
+    def set_hf_state_dict(self, state_dict, *args, **kwargs):
+        import re
+
+        mapping = {
+            # --- Backbone ---
+            r"model.backbone.model.embedder.stem(\d+)a.normalization": r"backbone.stem.stem\1a.bn",
+            r"model.backbone.model.embedder.stem(\d+)b.normalization": r"backbone.stem.stem\1b.bn",
+            r"model.backbone.model.embedder.stem(\d+)a.convolution": r"backbone.stem.stem\1a.conv",
+            r"model.backbone.model.embedder.stem(\d+)b.convolution": r"backbone.stem.stem\1b.conv",
+            r"model.backbone.model.embedder.stem(\d+).normalization": r"backbone.stem.stem\1.bn",
+            r"model.backbone.model.embedder.stem(\d+).convolution": r"backbone.stem.stem\1.conv",
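+            # `_convert_key` below applies these rules top-down and returns on
+            # the first regex that substitutes, so the specific patterns here
+            # must stay ahead of the generic fallbacks ("model.encoder",
+            # "model", ...) at the end of this dict.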
+ r"model.backbone.model.encoder.stages.(\d+).blocks.(\d+).layers.(\d+).conv(\d+).normalization": r"backbone.stages.\1.blocks.\2.layers.\3.conv\4.bn", + r"model.backbone.model.encoder.stages.(\d+).blocks.(\d+).layers.(\d+).conv(\d+).convolution": r"backbone.stages.\1.blocks.\2.layers.\3.conv\4.conv", + r"model.backbone.model.encoder.stages.(\d+).blocks.(\d+).layers.(\d+).normalization": r"backbone.stages.\1.blocks.\2.layers.\3.bn", + r"model.backbone.model.encoder.stages.(\d+).blocks.(\d+).layers.(\d+).convolution": r"backbone.stages.\1.blocks.\2.layers.\3.conv", + r"model.backbone.model.encoder.stages.(\d+).blocks.(\d+).aggregation.0.normalization": r"backbone.stages.\1.blocks.\2.aggregation_squeeze_conv.bn", + r"model.backbone.model.encoder.stages.(\d+).blocks.(\d+).aggregation.0.convolution": r"backbone.stages.\1.blocks.\2.aggregation_squeeze_conv.conv", + r"model.backbone.model.encoder.stages.(\d+).blocks.(\d+).aggregation.1.convolution": r"backbone.stages.\1.blocks.\2.aggregation_excitation_conv.conv", + r"model.backbone.model.encoder.stages.(\d+).blocks.(\d+).aggregation.1.normalization": r"backbone.stages.\1.blocks.\2.aggregation_excitation_conv.bn", + r"model.backbone.model.encoder.stages.(\d+).downsample.normalization": r"backbone.stages.\1.downsample.bn", + r"model.backbone.model.encoder.stages.(\d+).downsample.convolution": r"backbone.stages.\1.downsample.conv", + # --- Decoder --- + r"model.decoder_input_proj.(\d+).0": r"transformer.input_proj.\1.conv", + r"model.decoder_input_proj.(\d+).1": r"transformer.input_proj.\1.norm", + r"model.decoder.layers.(\d+).self_attn_layer_norm": r"transformer.decoder.layers.\1.norm1", + r"model.decoder.layers.(\d+).encoder_attn_layer_norm": r"transformer.decoder.layers.\1.norm2", + r"model.decoder.layers.(\d+).final_layer_norm": r"transformer.decoder.layers.\1.norm3", + r"model.decoder.layers.(\d+).encoder_attn": r"transformer.decoder.layers.\1.cross_attn", + r"model.decoder.layers.(\d+).fc(\d+)": r"transformer.decoder.layers.\1.linear\2", + # --- Encoder --- + r"model.encoder.encoder.(\d+).layers.(\d+).self_attn_layer_norm": r"neck.encoder.\1.layers.\2.norm1", + r"model.encoder.encoder.(\d+).layers.(\d+).final_layer_norm": r"neck.encoder.\1.layers.\2.norm2", + r"model.encoder.encoder.(\d+).layers.(\d+).fc(\d+)": r"neck.encoder.\1.layers.\2.linear\3", + r"model.encoder.encoder.(\d+).layers.(\d+).fc(\d+).bias": r"neck.encoder.\1.layers.\2.norm\3.bias", + r"model.encoder.fpn_blocks.(\d+).bottlenecks.(\d+).conv(\d+).norm": r"neck.fpn_blocks.\1.bottlenecks.\2.conv\3.bn", + r"model.encoder.pan_blocks.(\d+).bottlenecks.(\d+).conv(\d+).norm": r"neck.pan_blocks.\1.bottlenecks.\2.conv\3.bn", + r"model.encoder.fpn_blocks.(\d+).bottlenecks.(\d+).conv(\d+).conv": r"neck.fpn_blocks.\1.bottlenecks.\2.conv\3.conv", + r"model.encoder.pan_blocks.(\d+).bottlenecks.(\d+).conv(\d+).conv": r"neck.pan_blocks.\1.bottlenecks.\2.conv\3.conv", + r"model.encoder.fpn_blocks.(\d+).conv(\d+).norm": r"neck.fpn_blocks.\1.conv\2.bn", + r"model.encoder.pan_blocks.(\d+).conv(\d+).norm": r"neck.pan_blocks.\1.conv\2.bn", + r"model.encoder.lateral_convs.(\d+).norm": r"neck.lateral_convs.\1.bn", + r"model.encoder.downsample_convs.(\d+).norm": r"neck.downsample_convs.\1.bn", + # --- General --- + "model.backbone.model.encoder.stages": "backbone.stages", + "model.decoder.layers": "transformer.decoder.layers", + "model.decoder.bbox_embed": "transformer.dec_bbox_head", + "model.decoder.class_embed": "transformer.dec_score_head", + "model.decoder.query_pos_head": 
"transformer.query_pos_head", + "reading_order": "transformer.reading_order_predictor", + "model.encoder_input_proj": "neck.input_proj", + "model.encoder": "neck", + "model": "transformer", + } + + def _convert_key(key): + for pattern, replacement in mapping.items(): + new_key, n = re.subn(pattern, replacement, key) + if n > 0: + return new_key + return key + + def _convert_state_dict(state_dict): + keys = state_dict.keys() + new_tensors = {} + for key in keys: + tensor = state_dict[key] + new_key = _convert_key(key) + + if "q_proj.weight" in new_key or "q_proj.bias" in new_key: + k_proj = state_dict.get(key.replace("q_proj", "k_proj"), None) + v_proj = state_dict.get(key.replace("q_proj", "v_proj"), None) + if k_proj is not None and v_proj is not None: + merged_tensor = paddle.cat([tensor, k_proj, v_proj], dim=-1) + merged_key = new_key.replace("q_proj.", "in_proj_") + new_tensors[merged_key] = merged_tensor + else: + new_tensors[new_key] = tensor + + return new_tensors + + state_dict = _convert_state_dict(state_dict) + key_mapping = {} + rules = self._get_reverse_key_rules() + for old_key in list(state_dict.keys()): + for match_key, old_sub, new_sub in rules: + if match_key in old_key: + key_mapping[old_key] = old_key.replace(old_sub, new_sub) + break + for old_key, new_key in key_mapping.items(): + state_dict[new_key] = state_dict.pop(old_key) + + return self.set_state_dict(state_dict, *args, **kwargs) + + def get_hf_state_dict(self, *args, **kwargs): + import re + + mapping = { + # --- Backbone --- + r"backbone\.stem\.stem(\d+)a\.bn": r"model.backbone.model.embedder.stem\1a.normalization", + r"backbone\.stem\.stem(\d+)b\.bn": r"model.backbone.model.embedder.stem\1b.normalization", + r"backbone\.stem\.stem(\d+)a\.conv": r"model.backbone.model.embedder.stem\1a.convolution", + r"backbone\.stem\.stem(\d+)b\.conv": r"model.backbone.model.embedder.stem\1b.convolution", + r"backbone\.stem\.stem(\d+)\.bn": r"model.backbone.model.embedder.stem\1.normalization", + r"backbone\.stem\.stem(\d+)\.conv": r"model.backbone.model.embedder.stem\1.convolution", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv(\d+)\.bn": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv\4.normalization", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv(\d+)\.conv": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.conv\4.convolution", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.bn": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.normalization", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv\b": r"model.backbone.model.encoder.stages.\1.blocks.\2.layers.\3.convolution", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.aggregation_squeeze_conv\.bn": r"model.backbone.model.encoder.stages.\1.blocks.\2.aggregation.0.normalization", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.aggregation_squeeze_conv\.conv": r"model.backbone.model.encoder.stages.\1.blocks.\2.aggregation.0.convolution", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.aggregation_excitation_conv\.conv": r"model.backbone.model.encoder.stages.\1.blocks.\2.aggregation.1.convolution", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.aggregation_excitation_conv\.bn": r"model.backbone.model.encoder.stages.\1.blocks.\2.aggregation.1.normalization", + r"backbone\.stages\.(\d+)\.downsample\.bn": r"model.backbone.model.encoder.stages.\1.downsample.normalization", + r"backbone\.stages\.(\d+)\.downsample\.conv": r"model.backbone.model.encoder.stages.\1.downsample.convolution", + # --- 
Decoder --- + r"transformer\.input_proj\.(\d+)\.conv": r"model.decoder_input_proj.\1.0", + r"transformer\.input_proj\.(\d+)\.norm": r"model.decoder_input_proj.\1.1", + r"transformer\.decoder\.layers\.(\d+)\.norm1": r"model.decoder.layers.\1.self_attn_layer_norm", + r"transformer\.decoder\.layers\.(\d+)\.norm2": r"model.decoder.layers.\1.encoder_attn_layer_norm", + r"transformer\.decoder\.layers\.(\d+)\.norm3": r"model.decoder.layers.\1.final_layer_norm", + r"transformer\.decoder\.layers\.(\d+)\.cross_attn": r"model.decoder.layers.\1.encoder_attn", + r"transformer\.decoder\.layers\.(\d+)\.linear(\d+)": r"model.decoder.layers.\1.fc\2", + # --- Encoder --- + r"neck\.encoder\.(\d+)\.layers\.(\d+)\.norm1": r"model.encoder.encoder.\1.layers.\2.self_attn_layer_norm", + r"neck\.encoder\.(\d+)\.layers\.(\d+)\.norm2": r"model.encoder.encoder.\1.layers.\2.final_layer_norm", + r"neck\.encoder\.(\d+)\.layers\.(\d+)\.linear(\d+)": r"model.encoder.encoder.\1.layers.\2.fc\3", + r"neck\.encoder\.(\d+)\.layers\.(\d+)\.norm(\d+)\.bias": r"model.encoder.encoder.\1.layers.\2.fc\3.bias", + r"neck\.fpn_blocks\.(\d+)\.bottlenecks\.(\d+)\.conv(\d+)\.bn": r"model.encoder.fpn_blocks.\1.bottlenecks.\2.conv\3.norm", + r"neck\.pan_blocks\.(\d+)\.bottlenecks\.(\d+)\.conv(\d+)\.bn": r"model.encoder.pan_blocks.\1.bottlenecks.\2.conv\3.norm", + r"neck\.fpn_blocks\.(\d+)\.bottlenecks\.(\d+)\.conv(\d+)\.conv": r"model.encoder.fpn_blocks.\1.bottlenecks.\2.conv\3.conv", + r"neck\.pan_blocks\.(\d+)\.bottlenecks\.(\d+)\.conv(\d+)\.conv": r"model.encoder.pan_blocks.\1.bottlenecks.\2.conv\3.conv", + r"neck\.fpn_blocks\.(\d+)\.conv(\d+)\.bn": r"model.encoder.fpn_blocks.\1.conv\2.norm", + r"neck\.pan_blocks\.(\d+)\.conv(\d+)\.bn": r"model.encoder.pan_blocks.\1.conv\2.norm", + r"neck\.lateral_convs\.(\d+)\.bn": r"model.encoder.lateral_convs.\1.norm", + r"neck\.downsample_convs\.(\d+)\.bn": r"model.encoder.downsample_convs.\1.norm", + # --- General --- + "backbone.stages": "model.backbone.model.encoder.stages", + "transformer.decoder.layers": "model.decoder.layers", + "transformer.dec_bbox_head": "model.decoder.bbox_embed", + "transformer.dec_score_head": "model.decoder.class_embed", + "transformer.query_pos_head": "model.decoder.query_pos_head", + "transformer.reading_order_predictor": "reading_order", + "transformer": "model", + "neck.input_proj": "model.encoder_input_proj", + "neck": "model.encoder", + } + + def _convert_key(key): + for pattern, replacement in mapping.items(): + new_key, n = re.subn(pattern, replacement, key) + if n > 0: + return new_key + return key + + def _split_linear(tensor, key): + encoder_hidden_dim = 256 + if "in_proj_weight" in key: + q = key.replace("in_proj_weight", "q_proj.weight") + q_tensor = tensor[:encoder_hidden_dim, :].clone() + k = key.replace("in_proj_weight", "k_proj.weight") + k_tensor = tensor[ + encoder_hidden_dim : 2 * encoder_hidden_dim, : + ].clone() + v = key.replace("in_proj_weight", "v_proj.weight") + v_tensor = tensor[-encoder_hidden_dim:, :].clone() + elif "in_proj_bias" in key: + q = key.replace("in_proj_bias", "q_proj.bias") + q_tensor = tensor[:encoder_hidden_dim].clone() + k = key.replace("in_proj_bias", "k_proj.bias") + k_tensor = tensor[encoder_hidden_dim : 2 * encoder_hidden_dim].clone() + v = key.replace("in_proj_bias", "v_proj.bias") + v_tensor = tensor[-encoder_hidden_dim:].clone() + + return q, k, v, q_tensor, k_tensor, v_tensor + + def _convert_state_dict(current_state_dict): + keys = current_state_dict.keys() + new_tensors = {} + for key in keys: + tensor = 
current_state_dict[key] + new_key = _convert_key(key) + + if "in_proj_weight" in new_key or "in_proj_bias" in new_key: + q, k, v, q_tensor, k_tensor, v_tensor = _split_linear( + tensor, new_key + ) + new_tensors[q] = q_tensor + new_tensors[k] = k_tensor + new_tensors[v] = v_tensor + else: + new_tensors[new_key] = tensor + return new_tensors + + model_state_dict = self.state_dict(*args, **kwargs) + hf_state_dict = {} + rules = self._get_forward_key_rules() + for old_key, value in model_state_dict.items(): + new_key = old_key + for match_key, old_sub, new_sub in rules: + if match_key in old_key: + new_key = old_key.replace(old_sub, new_sub) + break + hf_state_dict[new_key] = value + + hf_state_dict = _convert_state_dict(hf_state_dict) + return hf_state_dict diff --git a/paddlex/inference/models/object_detection/modeling/pp_doclayout_v2_modules/pp_doclayout_v2_transformer.py b/paddlex/inference/models/object_detection/modeling/pp_doclayout_v2_modules/pp_doclayout_v2_transformer.py new file mode 100644 index 0000000000..e69c97621b --- /dev/null +++ b/paddlex/inference/models/object_detection/modeling/pp_doclayout_v2_modules/pp_doclayout_v2_transformer.py @@ -0,0 +1,479 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modified from Deformable-DETR (https://github.com/fundamentalvision/Deformable-DETR) +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Modified from detrex (https://github.com/IDEA-Research/detrex) +# Copyright 2022 The IDEA Authors. All rights reserved. 
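+
+# Reference sketch (not executed): the reading-order head wired up in this
+# module emits pairwise logits which `get_order` in pp_doclayout_v2.py later
+# decodes; for a single image the decode amounts to:
+#
+#     s = paddle.nn.functional.sigmoid(order_logits)  # s[i, j] ~ P(i precedes j)
+#     one = paddle.ones_like(s)
+#     q = s * paddle.triu(one, 1) + (1.0 - s.T) * paddle.tril(one, -1)
+#     votes = paddle.sum(q, axis=0)  # votes[j] ~ number of predecessors of j
+#     order = paddle.argsort(votes)  # fewest predecessors reads first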
+
+from __future__ import absolute_import, division, print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from ..rtdetrl_modules.modules.detr_ops import _get_clones, inverse_sigmoid
+from ..rtdetrl_modules.modules.utils import (
+    bbox_cxcywh_to_xyxy,
+    get_contrastive_denoising_training_group,
+)
+from ..rtdetrl_modules.rtdetr_transformer import (
+    RTDETRTransformer,
+    TransformerDecoderLayer,
+)
+from .reading_order_predictor import ReadingOrderPredictor
+
+__all__ = ["PPDocLayoutTransformer"]
+
+
+DET_TO_ORDER_MAP = {
+    "paragraph_title": ("paragraph_title", 0),
+    "image": ("image", 1),
+    "table": ("image", 1),
+    "chart": ("image", 1),
+    "text": ("text", 2),
+    "reference": ("text", 2),
+    "algorithm": ("text", 2),
+    "reference_content": ("text", 2),
+    "inline_formula": ("text", 2),
+    "number": ("number", 3),
+    "abstract": ("abstract", 4),
+    "content": ("content", 5),
+    "figure_title": ("figure_title", 6),
+    "vision_footnote": ("figure_title", 6),
+    "formula": ("display_formula", 7),
+    "display_formula": ("display_formula", 7),
+    "doc_title": ("doc_title", 8),
+    "footnote": ("footnote", 9),
+    "header": ("header", 10),
+    "header_image": ("header", 10),
+    "footer": ("footer", 11),
+    "footer_image": ("footer", 11),
+    "seal": ("seal", 12),
+    "formula_number": ("formula_number", 13),
+    "aside_text": ("aside_text", 14),
+    "vertical_text": ("vertical_text", 15),
+}
+
+
+def get_label_map():
+    categories = [
+        "abstract",
+        "algorithm",
+        "aside_text",
+        "chart",
+        "content",
+        "display_formula",
+        "doc_title",
+        "figure_title",
+        "footer",
+        "footer_image",
+        "footnote",
+        "formula_number",
+        "header",
+        "header_image",
+        "image",
+        "inline_formula",
+        "number",
+        "paragraph_title",
+        "reference",
+        "reference_content",
+        "seal",
+        "table",
+        "text",
+        "vertical_text",
+        "vision_footnote",
+    ]
+
+    # Map each detection class ID to its reading-order semantic class ID.
+    label_map = []
+    for det_id, det_name in enumerate(categories):
+        order_name, order_id = DET_TO_ORDER_MAP[det_name]
+        label_map.append((det_id, order_id))
+
+    sorted_label_map = sorted(label_map, key=lambda x: x[0])
+    return [i[1] for i in sorted_label_map]
+
+
+def _get_global_visual_feature(memory, spatial_shapes, level_start_index):
+    """
+    Extract a global visual vector from the encoder memory, using global
+    average pooling over the level-0 feature map; shape [bs, hidden_dim=256].
+    """
+    bs, _, hidden_dim = memory.shape
+    h0, w0 = spatial_shapes[0]
+    memory_lvl0 = (
+        memory[:, : level_start_index[1], :]
+        .reshape([bs, h0, w0, hidden_dim])
+        .transpose([0, 3, 1, 2])
+    )
+    # [bs, C, H, W] -> [bs, C]
+    g = F.adaptive_avg_pool2d(memory_lvl0, output_size=1).flatten(1)
+    return g  # [bs, 256]
+
+
+class TransformerDecoder(nn.Layer):
+    def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1):
+        super(TransformerDecoder, self).__init__()
+        self.layers = _get_clones(decoder_layer, num_layers)
+        self.hidden_dim = hidden_dim
+        self.num_layers = num_layers
+        self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
+
+        # Per-class score thresholds, kept here in TransformerDecoder.__init__;
+        # indices follow the detection label space (see get_label_map()).
+        threshold_dict = {
+            0: 0.50,  # abstract
+            1: 0.50,  # algorithm
+            2: 0.50,  # aside_text
+            3: 0.50,  # chart
+            4: 0.50,  # content
+            5: 0.40,  # display_formula
+            6: 0.40,  # doc_title
+            7: 0.50,  # figure_title
+            8: 0.50,  # footer
+            9: 0.50,  # footer_image
+            10: 0.50,  # footnote
+            11: 0.50,  # formula_number
+            12: 0.50,  # header
+            13: 0.50,  # header_image
+            14: 0.50,  # image
+            15: 0.40,  # inline_formula
+            16: 0.50,  # number
+            17: 0.40,  # paragraph_title
+            18: 0.50,  # reference
+            19: 0.50,  # reference_content
+            20: 0.45,  # seal
+            21: 0.50,  # table
+            22: 0.40,  # text
+            23: 0.40,  # vertical_text
+            24: 0.50,  # vision_footnote
+        }
+
+        # Convert to a tensor of shape [25].
+        self.class_thresholds = paddle.to_tensor(
+            [threshold_dict[i] for i in range(25)], dtype="float32"
+        )
+
+        self.class_map = get_label_map()
+
+        self.ro_mask_aug_cfg = dict(
+            enable=True,
+            prob=0.1,  # probability that a sample receives mask augmentation
+            mode_weights={"bernoulli": 0.5, "span": 0.4, "headtail": 0.1},
+            bernoulli_p=0.12,
+            span_max_ratio=0.35,
+            headtail_max_ratio=0.25,
+            min_keep=2,
+        )
+
+    def _ro_semantic_map_anyshape(self, labels: paddle.Tensor) -> paddle.Tensor:
+        idx = paddle.cast(labels, "int64")
+        flat = paddle.reshape(idx, [-1])
+        map_t = paddle.to_tensor(self.class_map, dtype="int64", stop_gradient=True)
+        out = paddle.gather(map_t, flat)
+        out = paddle.reshape(out, paddle.shape(idx))
+
+        return out
+
+    def forward(
+        self,
+        tgt,
+        ref_points_unact,
+        memory,
+        memory_spatial_shapes,
+        memory_level_start_index,
+        bbox_head,
+        score_head,
+        reading_order_predictor,
+        query_pos_head,
+        attn_mask=None,
+        memory_mask=None,
+        query_pos_head_inv_sig=False,
+        gt_meta=None,
+    ):
+
+        output = tgt
+        dec_out_bboxes = []
+        dec_out_logits = []
+
+        ref_points_detach = F.sigmoid(ref_points_unact)
+        for i, layer in enumerate(self.layers):
+            ref_points_input = ref_points_detach.unsqueeze(2)
+            if not query_pos_head_inv_sig:
+                query_pos_embed = query_pos_head(ref_points_detach)
+            else:
+                query_pos_embed = query_pos_head(inverse_sigmoid(ref_points_detach))
+
+            output = layer(
+                output,
+                ref_points_input,
+                memory,
+                memory_spatial_shapes,
+                memory_level_start_index,
+                attn_mask,
+                memory_mask,
+                query_pos_embed,
+            )
+
+            inter_ref_bbox = F.sigmoid(
+                bbox_head[i](output) + inverse_sigmoid(ref_points_detach)
+            )
+            score_logits = score_head[i](output)
+
+            if self.training:
+                dec_out_logits.append(score_logits)
+                if i == 0:
+                    dec_out_bboxes.append(inter_ref_bbox)
+                else:
+                    dec_out_bboxes.append(
+                        F.sigmoid(
+                            bbox_head[i](output) + inverse_sigmoid(ref_points_detach)
+                        )
+                    )
+            elif i == self.eval_idx:
+                dec_out_logits.append(score_logits)
+                dec_out_bboxes.append(inter_ref_bbox)
+                break
+
+            ref_points = inter_ref_bbox
+            ref_points_detach = (
+                inter_ref_bbox.detach() if self.training else inter_ref_bbox
+            )
+
+        bs = output.shape[0]
+
+        if self.training and gt_meta is not None and "gt_bbox" in gt_meta:
+            shuffled_gt_bboxes_list = []
+            shuffled_gt_labels_list = []
+            final_gt_read_order_list = []
+
+            for i in range(bs):
+                gt_bboxes = gt_meta["gt_bbox"][i]
+                num_gt = gt_bboxes.shape[0]
+
+                gt_read_order = gt_meta["gt_read_order"][i][:num_gt]
+                valid_gt_labels = gt_meta["gt_class"][i][:num_gt]
+
+                if num_gt > 0:
+                    shuffle_indices = paddle.randperm(num_gt)
+                    shuffled_gt_bboxes_list.append(
+                        paddle.gather(gt_bboxes, shuffle_indices, axis=0)
+                    )
+                    final_gt_read_order_list.append(
+                        paddle.gather(gt_read_order, shuffle_indices, axis=0)
+                    )
+                    shuffled_gt_labels_list.append(
+                        paddle.gather(valid_gt_labels, shuffle_indices, axis=0)
+                    )
+                else:
+                    shuffled_gt_bboxes_list.append(gt_bboxes)
+                    final_gt_read_order_list.append(gt_read_order)
+                    shuffled_gt_labels_list.append(valid_gt_labels)
+
+            gt_bboxes_list = shuffled_gt_bboxes_list
+
+            batch_boxes_list = []
+            for boxes_cxcywh in gt_bboxes_list:
+                if boxes_cxcywh.shape[0] > 0:
+                    # FIXME: the GT boxes may not be normalized to [0, 1] here.
+                    boxes_xyxy = bbox_cxcywh_to_xyxy(boxes_cxcywh) * 1000
+                    boxes_xyxy = boxes_xyxy.clip(min=0, max=1000)
+                    batch_boxes_list.append(boxes_xyxy.astype("int64").numpy().tolist())
+                else:
+                    batch_boxes_list.append([])
+
+            global_visual = _get_global_visual_feature(
+                memory, memory_spatial_shapes, memory_level_start_index
+            )
+            # NOTE: this training-time call signature (boxes_list/labels_list plus
+            # global memory) does not match ReadingOrderPredictor.forward(boxes,
+            # labels, mask); only the inference branch below is exercised in
+            # PaddleX.
+            padded_ro_logits = reading_order_predictor(
+                boxes_list=batch_boxes_list,
+                labels_list=shuffled_gt_labels_list,
+                global_visual=global_visual,  # still passed through
+                global_memory=memory,  # NEW
+                global_spatial_shapes=memory_spatial_shapes,  # NEW
+                global_level_start_index=memory_level_start_index,  # NEW
+            )
+
+            max_gt_len = max(len(b) for b in batch_boxes_list)
+            num_order_classes = padded_ro_logits.shape[-1]
+
+            final_fg_indices = paddle.zeros([bs, max_gt_len], dtype="int64")
+            final_fg_masks = paddle.zeros([bs, max_gt_len], dtype="bool")
+
+            for i in range(bs):
+                num_gt = len(batch_boxes_list[i])
+                if num_gt > 0:
+                    final_fg_indices[i, :num_gt] = paddle.arange(num_gt)
+                    final_fg_masks[i, :num_gt] = True
+
+            out_read_orders = (
+                padded_ro_logits,
+                final_fg_indices,
+                final_fg_masks,
+                final_gt_read_order_list,
+            )
+
+        else:
+            raw_bboxes = paddle.stack(dec_out_bboxes)[0]  # (batch_size, 300, 4)
+            bboxes = bbox_cxcywh_to_xyxy(raw_bboxes).astype("float32") * 1000
+            bboxes = paddle.clip(bboxes, min=0.0, max=1000.0).astype("int64")
+            logits = paddle.stack(dec_out_logits)[0]  # (batch_size, 300, 25)
+
+            # 1. Max probability and class ID per box.
+            probs = F.sigmoid(logits)
+            max_probs = paddle.max(probs, axis=-1)  # (batch_size, 300)
+            class_ids = paddle.argmax(probs, axis=-1)  # (batch_size, 300)
+
+            # 2. Valid-box mask from the per-class thresholds.
+            inline_formula_id = 15
+            thresholds = paddle.index_select(
+                self.class_thresholds, class_ids.reshape([-1]), 0
+            )
+            thresholds = thresholds.reshape(class_ids.shape)
+            mask = (
+                max_probs >= thresholds
+            )  # & (class_ids != inline_formula_id) # (batch_size, 300)
+            mask = mask.astype("int64")
+
+            # 3. Sort valid boxes to the front; invalid entries trail.
+            sorted_mask = mask.sort(axis=1, descending=True)
+            indices = mask.argsort(axis=1, descending=True)  # (batch_size, 300)
+
+            # 4. Reorder classes and boxes accordingly.
+            sorted_class_ids = paddle.take_along_axis(
+                class_ids, indices, axis=1
+            )  # (batch_size, 300)
+            sorted_boxes = paddle.take_along_axis(
+                bboxes, indices.unsqueeze(-1).expand(shape=[-1, -1, 4]), axis=1
+            )  # (batch_size, 300, 4)
+            sorted_raw_boxes = paddle.take_along_axis(
+                raw_bboxes, indices.unsqueeze(-1).expand(shape=[-1, -1, 4]), axis=1
+            )  # (batch_size, 300, 4)
+            sorted_logits = paddle.take_along_axis(
+                logits, indices.unsqueeze(-1), axis=1
+            )  # (batch_size, 300, 25)
+
+            # 5. Zero out invalid entries.
+            mask_expand = sorted_mask.unsqueeze(-1).expand(
+                shape=[-1, -1, 4]
+            )  # (batch_size, 300, 4)
+
+            pad_boxes = sorted_boxes * mask_expand  # zero the boxes of invalid entries
+            pad_class_ids = sorted_class_ids * sorted_mask  # zero their classes
+
+            pad_class_ids = self._ro_semantic_map_anyshape(pad_class_ids)
+
+            order_logits = reading_order_predictor(  # [B, Nq, C_order] = [B, 300, 510]
+                boxes=pad_boxes,
+                labels=pad_class_ids,
+                mask=mask,
+            )
+            order_logits = order_logits[:, :, :300]
+        return (
+            sorted_raw_boxes.unsqueeze(axis=0),
+            sorted_logits.unsqueeze(axis=0),
+            order_logits,
+        )
+
+
+class PPDocLayoutTransformer(RTDETRTransformer):
+    def __init__(
+        self,
+        dim_feedforward=1024,
+        dropout=0.0,
+        activation="relu",
+        num_decoder_points=4,
+        eval_idx=-1,
+        **kwargs
+    ):
+        kwargs.setdefault("num_classes", 25)
+        super(PPDocLayoutTransformer, self).__init__(**kwargs)
+
+        decoder_layer = TransformerDecoderLayer(
+            self.hidden_dim,
+            self.nhead,
+            dim_feedforward,
+            dropout,
+            activation,
+            self.num_levels,
+            num_decoder_points,
+        )
+        self.decoder = TransformerDecoder(
+            self.hidden_dim, decoder_layer, self.num_decoder_layers, eval_idx
+        )
+        self.reading_order_predictor = ReadingOrderPredictor()
+
+    def forward(self, feats, pad_mask=None, gt_meta=None, is_teacher=False):
+        # input projection and embedding
+        (memory, spatial_shapes, level_start_index) = self._get_encoder_input(feats)
+
+        # prepare denoising training
+        if self.training:
+            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = (
+                get_contrastive_denoising_training_group(
+                    gt_meta,
+                    self.num_classes,
+                    self.num_queries,
+                    self.denoising_class_embed.weight,
+                    self.num_denoising,
+                    self.label_noise_ratio,
+                    self.box_noise_scale,
+                )
+            )
+        else:
+            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = (
+                None,
+                None,
+                None,
+                None,
+            )
+
+        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = (
+            self._get_decoder_input(
+                memory,
+                spatial_shapes,
+                denoising_class,
+                denoising_bbox_unact,
+                is_teacher,
+            )
+        )
+
+        # decoder
+        out_bboxes, out_logits, out_read_orders = self.decoder(
+            target,
+            init_ref_points_unact,
+            memory,
+            spatial_shapes,
+            level_start_index,
+            self.dec_bbox_head,
+            self.dec_score_head,
+            self.reading_order_predictor,
+            self.query_pos_head,
+            attn_mask=attn_mask,
+            memory_mask=None,
+            query_pos_head_inv_sig=self.query_pos_head_inv_sig,
+            gt_meta=gt_meta,
+        )
+        return (
+            out_bboxes,
+            out_logits,
+            enc_topk_bboxes,
+            enc_topk_logits,
+            dn_meta,
+            out_read_orders,
+        )
diff --git a/paddlex/inference/models/object_detection/modeling/pp_doclayout_v2_modules/reading_order_predictor.py b/paddlex/inference/models/object_detection/modeling/pp_doclayout_v2_modules/reading_order_predictor.py
new file mode 100644
index 0000000000..11321183ba
--- /dev/null
+++ b/paddlex/inference/models/object_detection/modeling/pp_doclayout_v2_modules/reading_order_predictor.py
@@ -0,0 +1,747 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
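+
+# Overview: `PositionRelationEmbeddingPD` below turns pairwise box geometry into
+# per-head attention biases. For cxcywh boxes i and j, `box_rel_encoding_pd`
+# computes the 4-d relation (eps terms omitted):
+#
+#     dx = log(|cx_i - cx_j| / w_i + 1)    dy = log(|cy_i - cy_j| / h_i + 1)
+#     dw = log(w_i / w_j)                  dh = log(h_i / h_j)
+#
+# Each scalar is then sine/cosine-embedded by `get_sine_pos_embed_pd` (embed_dim
+# channels per scalar, 4 * embed_dim in total), and a 1x1 conv projects the
+# stack down to one additive bias per attention head.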
+ +from __future__ import absolute_import, division, print_function + +import math + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from ..rtdetrl_modules.modules.deformable_transformer import MSDeformableAttention +from ..rtdetrl_modules.modules.initializer import linear_init_ +from .roor_head_pd import GlobalPointerPD + + +def box_rel_encoding_pd( + src_boxes: paddle.Tensor, tgt_boxes: paddle.Tensor = None, eps: float = 1e-5 +): + if tgt_boxes is None: + tgt_boxes = src_boxes + assert src_boxes.shape[-1] == 4 and tgt_boxes.shape[-1] == 4 + xy1, wh1 = src_boxes[..., :2], src_boxes[..., 2:] + xy2, wh2 = tgt_boxes[..., :2], tgt_boxes[..., 2:] + delta_xy = paddle.abs(xy1.unsqueeze(-2) - xy2.unsqueeze(-3)) + delta_xy = paddle.log(delta_xy / (wh1.unsqueeze(-2) + eps) + 1.0) + delta_wh = paddle.log((wh1.unsqueeze(-2) + eps) / (wh2.unsqueeze(-3) + eps)) + pos = paddle.concat([delta_xy, delta_wh], axis=-1) + return pos + + +def get_sine_pos_embed_pd( + x: paddle.Tensor, + num_pos_feats: int, + temperature: float = 10000.0, + scale: float = 100.0, + exchange_xy: bool = False, +): + if exchange_xy and x.shape[-1] >= 2: + x = paddle.stack( + [x[..., 1], x[..., 0], *([x[..., i] for i in range(2, x.shape[-1])])], + axis=-1, + ) + + half = num_pos_feats // 2 + dim_t = temperature ** (2 * paddle.arange(half, dtype="float32") / half) + + def _encode(t: paddle.Tensor): + t = t * scale + t = t.unsqueeze(-1) / dim_t + sin = paddle.sin(t) + cos = paddle.cos(t) + return paddle.concat([sin, cos], axis=-1) + + embs = [_encode(x[..., i]) for i in range(x.shape[-1])] + out = paddle.concat(embs, axis=-1) + return out + + +class PositionRelationEmbeddingPD(nn.Layer): + def __init__( + self, + embed_dim: int, + num_heads: int, + temperature: float = 10000.0, + scale: float = 100.0, + ): + super().__init__() + in_ch = embed_dim * 4 + self.pos_proj = nn.Conv2D( + in_channels=in_ch, out_channels=num_heads, kernel_size=1 + ) + self.embed_dim = embed_dim + self.temperature = temperature + self.scale = scale + + def forward(self, src_boxes: paddle.Tensor, tgt_boxes: paddle.Tensor = None): + if tgt_boxes is None: + tgt_boxes = src_boxes + with paddle.no_grad(): + rel = box_rel_encoding_pd(src_boxes, tgt_boxes) + pos = get_sine_pos_embed_pd( + rel, + num_pos_feats=self.embed_dim, + temperature=self.temperature, + scale=self.scale, + ) + pos = pos.transpose([0, 3, 1, 2]) + out = self.pos_proj(pos) + return out + + +class MLP(nn.Layer): + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.LayerList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self._reset_parameters() + + def _reset_parameters(self): + for l in self.layers: + linear_init_(l) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class LayoutLMv3SelfAttention(nn.Layer): + def __init__( + self, + hidden_size, + num_attention_heads, + dropout_prob, + has_relative_attention_bias, + has_spatial_attention_bias, + ): + super().__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError( + f"The hidden size ({hidden_size}) is not a multiple of the number of attention " + f"heads ({num_attention_heads})" + ) + + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * 
self.attention_head_size + + self.query = nn.Linear(hidden_size, self.all_head_size) + self.key = nn.Linear(hidden_size, self.all_head_size) + self.value = nn.Linear(hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(dropout_prob) + self.has_relative_attention_bias = has_relative_attention_bias + self.has_spatial_attention_bias = has_spatial_attention_bias + + def transpose_for_scores(self, x): + new_x_shape = x.shape[:-1] + [ + self.num_attention_heads, + self.attention_head_size, + ] + x = x.reshape(new_x_shape) + return x.transpose((0, 2, 1, 3)) + + def cogview_attention(self, attention_scores, alpha=32): + scaled_attention_scores = attention_scores / alpha + max_value = paddle.max(scaled_attention_scores, axis=-1, keepdim=True) + new_attention_scores = (scaled_attention_scores - max_value) * alpha + return nn.Softmax(axis=-1)(new_attention_scores) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + rel_pos=None, + rel_geom_bias=None, + ): + mixed_query_layer = self.query(hidden_states) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + attention_scores = paddle.matmul(query_layer, key_layer.transpose((0, 1, 3, 2))) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + if self.has_relative_attention_bias and (rel_pos is not None): + attention_scores += rel_pos / math.sqrt(self.attention_head_size) + + if rel_geom_bias is not None: + + if rel_geom_bias.dtype != attention_scores.dtype: + rel_geom_bias = rel_geom_bias.astype(attention_scores.dtype) + attention_scores += rel_geom_bias + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask.astype("float32") + + attention_probs = self.cogview_attention(attention_scores) + attention_probs = self.dropout(attention_probs) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = paddle.matmul(attention_probs, value_layer) + context_layer = context_layer.transpose((0, 2, 1, 3)) + new_context_layer_shape = context_layer.shape[:-2] + [self.all_head_size] + context_layer = context_layer.reshape(new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + return outputs + + +class LayoutLMv3SelfOutput(nn.Layer): + def __init__(self, hidden_size, layer_norm_eps, dropout_prob): + super().__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.LayerNorm = nn.LayerNorm(hidden_size, epsilon=layer_norm_eps) + self.dropout = nn.Dropout(dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class LayoutLMv3Attention(nn.Layer): + def __init__(self, config): + super().__init__() + self.self = LayoutLMv3SelfAttention( + config["hidden_size"], + config["num_attention_heads"], + config["attention_probs_dropout_prob"], + config["has_relative_attention_bias"], + config["has_spatial_attention_bias"], + ) + self.output = LayoutLMv3SelfOutput( + config["hidden_size"], + config["layer_norm_eps"], + config["hidden_dropout_prob"], + ) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + rel_pos=None, + rel_geom_bias=None, + ): + 
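+        # `rel_pos` carries the bucketed 1-D relative-position bias and
+        # `rel_geom_bias` the pairwise box-geometry bias; both are added to the
+        # raw attention scores inside `LayoutLMv3SelfAttention` above.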
self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + output_attentions, + rel_pos, + rel_geom_bias, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] + return outputs + + +class LayoutLMv3Intermediate(nn.Layer): + def __init__(self, hidden_size, intermediate_size, hidden_act): + super().__init__() + self.dense = nn.Linear(hidden_size, intermediate_size) + self.intermediate_act_fn = getattr(F, hidden_act) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class LayoutLMv3Output(nn.Layer): + def __init__(self, hidden_size, intermediate_size, layer_norm_eps, dropout_prob): + super().__init__() + self.dense = nn.Linear(intermediate_size, hidden_size) + self.LayerNorm = nn.LayerNorm(hidden_size, epsilon=layer_norm_eps) + self.dropout = nn.Dropout(dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class LayoutLMv3Layer(nn.Layer): + def __init__(self, config): + super().__init__() + self.attention = LayoutLMv3Attention(config) + self.intermediate = LayoutLMv3Intermediate( + config["hidden_size"], config["intermediate_size"], config["hidden_act"] + ) + self.output = LayoutLMv3Output( + config["hidden_size"], + config["intermediate_size"], + config["layer_norm_eps"], + config["hidden_dropout_prob"], + ) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + rel_pos=None, + rel_geom_bias=None, + ): + attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions, + rel_pos, + rel_geom_bias, + ) + attention_output = attention_outputs[0] + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + outputs = (layer_output,) + attention_outputs[1:] + return outputs + + +class LayoutLMv3Encoder(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.LayerList( + [LayoutLMv3Layer(config) for _ in range(config["num_hidden_layers"])] + ) + self.has_relative_attention_bias = config["has_relative_attention_bias"] + + if self.has_relative_attention_bias: + self.rel_pos_bins = config["rel_pos_bins"] + self.max_rel_pos = config["max_rel_pos"] + self.rel_pos_bias = nn.Linear( + self.rel_pos_bins, config["num_attention_heads"], bias_attr=False + ) + + self.rel_bias_module = PositionRelationEmbeddingPD( + embed_dim=16, + num_heads=config["num_attention_heads"], + temperature=10000.0, + scale=100.0, + ) + + def relative_position_bucket( + self, relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ): + ret = 0 + if bidirectional: + num_buckets //= 2 + ret += (relative_position > 0).astype("int64") * num_buckets + n = paddle.abs(relative_position) + else: + n = paddle.maximum(-relative_position, paddle.zeros_like(relative_position)) + max_exact = num_buckets // 2 + is_small = n < max_exact + val_if_large = max_exact + ( + paddle.log(n.astype("float32") / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).astype("int64") + val_if_large = paddle.minimum( + val_if_large, paddle.full_like(val_if_large, num_buckets - 1) + ) + ret += paddle.where(is_small, n, 
val_if_large) + return ret + + def _cal_1d_pos_emb(self, position_ids): + rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1) + rel_pos = self.relative_position_bucket( + rel_pos_mat, num_buckets=self.rel_pos_bins, max_distance=self.max_rel_pos + ) + with paddle.no_grad(): + rel_pos = self.rel_pos_bias.weight.reshape([self.rel_pos_bins, -1])[ + rel_pos + ].transpose((0, 3, 1, 2)) + return rel_pos + + def _boxes_xyxy_to_cxcywh(self, bbox_xyxy: paddle.Tensor): + + x1, y1, x2, y2 = ( + bbox_xyxy[..., 0].astype("float32"), + bbox_xyxy[..., 1].astype("float32"), + bbox_xyxy[..., 2].astype("float32"), + bbox_xyxy[..., 3].astype("float32"), + ) + w = (x2 - x1).clip(min=1e-3) + h = (y2 - y1).clip(min=1e-3) + cx = (x1 + x2) * 0.5 + cy = (y1 + y2) * 0.5 + return paddle.stack([cx, cy, w, h], axis=-1) + + def forward( + self, + hidden_states, + bbox, + attention_mask=None, + head_mask=None, + output_attentions=False, + output_hidden_states=False, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + B, L = hidden_states.shape[0], hidden_states.shape[1] + position_ids = paddle.arange(L, dtype="int64").expand((B, -1)) + + rel_pos = ( + self._cal_1d_pos_emb(position_ids) + if self.has_relative_attention_bias + else None + ) + + boxes_cxcywh = self._boxes_xyxy_to_cxcywh(bbox) + rel_geom_bias = self.rel_bias_module(boxes_cxcywh) + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + layer_head_mask = head_mask[i] if head_mask is not None else None + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + output_attentions, + rel_pos, + rel_geom_bias, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return hidden_states, all_hidden_states, all_self_attentions + + +class LayoutLMv3TextEmbeddings(nn.Layer): + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config["vocab_size"], + config["hidden_size"], + padding_idx=config["pad_token_id"], + ) + self.token_type_embeddings = nn.Embedding( + config["type_vocab_size"], config["hidden_size"] + ) + + self.LayerNorm = nn.LayerNorm( + config["hidden_size"], epsilon=config["layer_norm_eps"] + ) + self.dropout = nn.Dropout(p=config["hidden_dropout_prob"]) + + self.position_ids = paddle.arange(config["max_position_embeddings"]).expand( + (1, -1) + ) + self.position_ids.stop_gradient = True + + self.padding_idx = config["pad_token_id"] + self.position_embeddings = nn.Embedding( + config["max_position_embeddings"], + config["hidden_size"], + padding_idx=self.padding_idx, + ) + + self.x_position_embeddings = nn.Embedding( + config["max_2d_position_embeddings"], config["coordinate_size"] + ) + self.y_position_embeddings = nn.Embedding( + config["max_2d_position_embeddings"], config["coordinate_size"] + ) + self.h_position_embeddings = nn.Embedding( + config["max_2d_position_embeddings"], config["shape_size"] + ) + self.w_position_embeddings = nn.Embedding( + config["max_2d_position_embeddings"], config["shape_size"] + ) + + self.spatial_embed_dim = ( + 4 * config["coordinate_size"] + 2 * config["shape_size"] + ) + self.spatial_proj = nn.Linear(self.spatial_embed_dim, config["hidden_size"]) + + def calculate_spatial_position_embeddings(self, bbox): + try: + bbox = 
paddle.clip(bbox, 0, 1023)
+            left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
+            upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
+            right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
+            lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
+        except IndexError as e:
+            raise IndexError(
+                "The `bbox` coordinate values should be within the 0-1023 range."
+            ) from e
+
+        h_position_embeddings = self.h_position_embeddings(
+            paddle.clip(bbox[:, :, 3] - bbox[:, :, 1], 0, 1023)
+        )
+        w_position_embeddings = self.w_position_embeddings(
+            paddle.clip(bbox[:, :, 2] - bbox[:, :, 0], 0, 1023)
+        )
+
+        spatial_position_embeddings = paddle.concat(
+            [
+                left_position_embeddings,
+                upper_position_embeddings,
+                right_position_embeddings,
+                lower_position_embeddings,
+                h_position_embeddings,
+                w_position_embeddings,
+            ],
+            axis=-1,
+        )
+        return spatial_position_embeddings
+
+    def create_position_ids_from_input_ids(self, input_ids, padding_idx):
+        mask = (input_ids != padding_idx).astype("int64")
+        incremental_indices = paddle.cumsum(mask, axis=1) * mask
+        return incremental_indices.astype("int64") + padding_idx
+
+    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+        input_shape = inputs_embeds.shape[:-1]
+        sequence_length = input_shape[1]
+        position_ids = paddle.arange(
+            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype="int64"
+        )
+        return position_ids.unsqueeze(0).expand(input_shape)
+
+    def forward(
+        self,
+        input_ids=None,
+        bbox=None,
+        token_type_ids=None,
+        position_ids=None,
+        inputs_embeds=None,
+    ):
+        if position_ids is None:
+            if input_ids is not None:
+                position_ids = self.create_position_ids_from_input_ids(
+                    input_ids, self.padding_idx
+                )
+            else:
+                position_ids = self.create_position_ids_from_inputs_embeds(
+                    inputs_embeds
+                )
+
+        if input_ids is not None:
+            input_shape = input_ids.shape
+        else:
+            input_shape = inputs_embeds.shape[:-1]
+
+        if token_type_ids is None:
+            token_type_ids = paddle.zeros(input_shape, dtype="int64")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + token_type_embeddings
+        position_embeddings = self.position_embeddings(position_ids)
+        embeddings += position_embeddings
+
+        spatial_position_embeddings = self.calculate_spatial_position_embeddings(bbox)
+        spatial_position_embeddings = self.spatial_proj(spatial_position_embeddings)
+
+        # Add the projected 2-D spatial position embeddings. (The original
+        # comment here suggested dropping the absolute 1-D sequence positions.)
+        embeddings += spatial_position_embeddings
+
+        return embeddings
+
+
+class LayoutLMv3ClassificationHead(nn.Layer):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config["hidden_size"], config["hidden_size"])
+        classifier_dropout = config["hidden_dropout_prob"]
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.out_proj = nn.Linear(config["hidden_size"], config["num_labels"])
+
+    def forward(self, x):
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = paddle.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
+
+
+class BoxGlobalAggregator(nn.Layer):
+    def __init__(self, d_model, nhead=16, n_levels=3, n_points=4):
+        super().__init__()
+        self.proj_in = nn.Linear(d_model, d_model)
+        self.proj_val = nn.Linear(256, d_model)
+        self.msda = MSDeformableAttention(d_model, nhead, n_levels, n_points, 1.0)
+        self.proj_out = nn.Linear(d_model, d_model)
+
+    def forward(
+        self,
+        query_tokens,
+        boxes_xyxy_norm,
+        global_memory,
+        spatial_shapes,
+
level_start_index, + ): + bs, Lq, D = query_tokens.shape + N = boxes_xyxy_norm.shape[1] + x1, y1, x2, y2 = [boxes_xyxy_norm[:, :, i] for i in range(4)] + cx = (x1 + x2) * 0.5 + cy = (y1 + y2) * 0.5 + w = (x2 - x1).clip(min=1e-6) + h = (y2 - y1).clip(min=1e-6) + ref = paddle.stack([cx, cy, w, h], axis=-1) + n_levels = len(spatial_shapes) + ref = ref.unsqueeze(2).tile([1, 1, n_levels, 1]) + + val = self.proj_val(global_memory) + q_boxes = self.proj_in(query_tokens[:, 1 : N + 1, :]) + out_ctx = self.msda(q_boxes, ref, val, spatial_shapes, level_start_index) + out_ctx = self.proj_out(out_ctx) + + enriched = query_tokens.clone() + enriched[:, 1 : N + 1, :] = enriched[:, 1 : N + 1, :] + out_ctx + return enriched + + +class ReadingOrderPredictor(nn.Layer): + def __init__(self): + super(ReadingOrderPredictor, self).__init__() + + self.config = { + "hidden_size": 512, + "num_attention_heads": 8, + "attention_probs_dropout_prob": 0.1, + "has_relative_attention_bias": False, + "has_spatial_attention_bias": True, + "layer_norm_eps": 1e-5, + "hidden_dropout_prob": 0.1, + "intermediate_size": 2048, + "hidden_act": "gelu", + "num_hidden_layers": 6, + "rel_pos_bins": 32, + "max_rel_pos": 128, + "rel_2d_pos_bins": 64, + "max_rel_2d_pos": 256, + "num_labels": 510, + "max_position_embeddings": 514, + "max_2d_position_embeddings": 1024, + "type_vocab_size": 1, + "vocab_size": 4, + "pad_token_id": 1, + "coordinate_size": 171, + "shape_size": 170, + "num_classes": 20, + } + + self.embeddings = LayoutLMv3TextEmbeddings(self.config) + self.label_embeddings = nn.Embedding( + self.config["num_classes"], self.config["hidden_size"] + ) + self.label_features_projection = nn.Linear( + self.config["hidden_size"], self.config["hidden_size"] + ) + + self.encoder = LayoutLMv3Encoder(self.config) + self.dropout = nn.Dropout(self.config["hidden_dropout_prob"]) + + self.relative_head = GlobalPointerPD( + hidden_size=self.config["hidden_size"], + heads=1, + head_size=64, + use_rope=False, + tril_mask=True, + max_length=512, + ) + + def forward(self, boxes, labels=None, mask=None): + START_TOKEN_ID = 0 + PRED_TOKEN_ID = 3 + END_TOKEN_ID = 2 + PAD_TOKEN_ID = 1 + + batch_size, seq_len = mask.shape + + num_pred = mask.sum(axis=1) + + input_ids = paddle.full((batch_size, seq_len + 2), PAD_TOKEN_ID, dtype="int64") + + input_ids[:, 0] = START_TOKEN_ID + + pred_col_idx = paddle.arange(seq_len + 2).unsqueeze(0) + pred_mask = (pred_col_idx >= 1) & (pred_col_idx <= num_pred.unsqueeze(1)) + input_ids[pred_mask] = PRED_TOKEN_ID + + end_col_indices = num_pred + 1 + input_ids[:, end_col_indices] = END_TOKEN_ID + + pad_box = paddle.zeros( + shape=[boxes.shape[0], 1, boxes.shape[-1]], dtype=boxes.dtype + ) + pad_boxes = paddle.concat([pad_box, boxes, pad_box], axis=1).astype("int64") + bbox_embedding = self.embeddings(input_ids=input_ids, bbox=pad_boxes) + + if labels is not None: + + label_embs = self.label_embeddings(labels) + + label_proj = self.label_features_projection(label_embs).squeeze(-1) + pad = paddle.zeros( + shape=[label_proj.shape[0], 1, label_proj.shape[-1]], + dtype=label_proj.dtype, + ) + label_proj = paddle.concat([pad, label_proj, pad], axis=1) + else: + label_proj = paddle.zeros_like(bbox_embedding) + + final_embddings = bbox_embedding + label_proj + final_embddings = self.embeddings.LayerNorm(final_embddings) + final_embddings = self.embeddings.dropout(final_embddings) + + attention_mask = paddle.zeros( + shape=[mask.shape[0], mask.shape[1] + 2], dtype=mask.dtype + ) + set_ones_mask = pred_col_idx < (num_pred + 
2).unsqueeze(1) + attention_mask[set_ones_mask] = 1 + attention_mask = attention_mask.astype("int64").unsqueeze(axis=[1, 2]) + attention_mask = (1.0 - attention_mask) * -1e9 + + encoder_output, _, _ = self.encoder( + hidden_states=final_embddings, bbox=pad_boxes, attention_mask=attention_mask + ) + + N_max = 300 + tok = encoder_output[:, 1 : 1 + N_max, :] + attn_1d = ( + paddle.arange(N_max)[None, :].tile([tok.shape[0], 1]) < num_pred[:, None] + ).astype("float32") + logits_bh, mask_b = self.relative_head(tok, attn_1d) + read_order_logits = logits_bh[:, 0] + return read_order_logits diff --git a/paddlex/inference/models/object_detection/modeling/pp_doclayout_v2_modules/roor_head_pd.py b/paddlex/inference/models/object_detection/modeling/pp_doclayout_v2_modules/roor_head_pd.py new file mode 100644 index 0000000000..c52f9ce913 --- /dev/null +++ b/paddlex/inference/models/object_detection/modeling/pp_doclayout_v2_modules/roor_head_pd.py @@ -0,0 +1,99 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn + +INF = 1e4 + + +class RotaryPositionEmbeddingPD(nn.Layer): + def __init__(self, dim, max_seq_len=1024): + super().__init__() + self.dim = dim + inv_freq = 1.0 / (10000 ** (paddle.arange(0, dim, 2, dtype="float32") / dim)) + t = paddle.arange(max_seq_len, dtype="float32") + freqs = paddle.einsum("n,d->nd", t, inv_freq) + self.register_buffer("sin", paddle.sin(freqs), persistable=False) + self.register_buffer("cos", paddle.cos(freqs), persistable=False) + + def forward(self, x, seqlen, seq_axis=-2): + nd = x.ndim + if seq_axis < 0: + seq_axis = nd + seq_axis + if seq_axis != nd - 2: + perm = list(range(nd)) + perm[seq_axis], perm[-2] = perm[-2], perm[seq_axis] + x = x.transpose(perm) + + Dh = x.shape[-1] + x1, x2 = x[..., 0::2], x[..., 1::2] + sin = self.sin[:seqlen].reshape([1] * (nd - 2) + [seqlen, Dh // 2]) + cos = self.cos[:seqlen].reshape([1] * (nd - 2) + [seqlen, Dh // 2]) + + y1 = x1 * cos - x2 * sin + y2 = x1 * sin + x2 * cos + y = paddle.stack([y1, y2], axis=-1).reshape(x.shape) + + if seq_axis != nd - 2: + inv = list(range(nd)) + inv[seq_axis], inv[-2] = inv[-2], inv[seq_axis] + y = y.transpose(inv) + return y + + +class GlobalPointerPD(nn.Layer): + def __init__( + self, + hidden_size, + heads=1, + head_size=64, + use_rope=True, + tril_mask=False, + max_length=1024, + ): + super().__init__() + self.heads = heads + self.head_size = head_size + self.use_rope = use_rope + self.tril_mask = tril_mask + self.dense = nn.Linear(hidden_size, heads * 2 * head_size) + self.rotary = ( + RotaryPositionEmbeddingPD(head_size, max_length) if use_rope else None + ) + + def forward(self, inputs, attn_mask_1d): + B, N, _ = inputs.shape + proj = self.dense(inputs).reshape([B, N, self.heads, 2, self.head_size]) + qw, kw = proj[..., 0, :], proj[..., 1, :] + + if self.use_rope: + qw = self.rotary(qw, N, seq_axis=1) + kw = self.rotary(kw, N, seq_axis=1) + + qw_t = qw.transpose([0, 2, 1, 3]) + kw_t = 
kw.transpose([0, 2, 1, 3]) + logits = paddle.einsum("bhmd,bhnd->bhmn", qw_t, kw_t) / (self.head_size**0.5) + + a = attn_mask_1d.astype("float32") + pair_mask = 1.0 - (a.unsqueeze(1).unsqueeze(2) * a.unsqueeze(1).unsqueeze(3)) + logits = logits - pair_mask * INF + + if self.tril_mask: + lower = paddle.tril(paddle.ones([N, N], dtype="float32")) + lower = lower.astype("bool").unsqueeze(0).unsqueeze(0) + logits = logits - lower.astype(logits.dtype) * INF + pair_mask = paddle.logical_or(pair_mask.astype("bool"), lower) + + return logits, pair_mask.astype("bool") diff --git a/paddlex/inference/models/object_detection/predictor.py b/paddlex/inference/models/object_detection/predictor.py index c3c2125f3d..45f32d49e5 100644 --- a/paddlex/inference/models/object_detection/predictor.py +++ b/paddlex/inference/models/object_detection/predictor.py @@ -156,6 +156,17 @@ def _build(self) -> Tuple: dtype="float32", ) infer.eval() + elif self.model_name == "PP-DocLayoutV2": + from .modeling import PPDocLayoutV2 + + with TemporaryDeviceChanger(self.device): + infer = PPDocLayoutV2.from_pretrained( + self.model_dir, + use_safetensors=True, + convert_from_hf=True, + dtype="float32", + ) + infer.eval() else: raise RuntimeError( f"There is no dynamic graph implementation for model {repr(self.model_name)}." diff --git a/paddlex/inference/pipelines/pp_doctranslation/pipeline.py b/paddlex/inference/pipelines/pp_doctranslation/pipeline.py index 8b81b52095..91774244e3 100644 --- a/paddlex/inference/pipelines/pp_doctranslation/pipeline.py +++ b/paddlex/inference/pipelines/pp_doctranslation/pipeline.py @@ -488,6 +488,7 @@ def translate_func(text): "markdown_texts": target_language_texts, } ) + def concatenate_markdown_pages(self, markdown_list: list) -> tuple: """ Concatenate Markdown content from multiple pages into a single document. @@ -616,4 +617,4 @@ def concatenate_latex_pages(self, latex_info_list: list) -> tuple: "images": merged_images, "input_path": latex_info_list[0]["input_path"], } - ) \ No newline at end of file + )