diff --git a/mmpose/datasets/transforms/common_transforms.py b/mmpose/datasets/transforms/common_transforms.py index 33f9c560c0..b29417f045 100644 --- a/mmpose/datasets/transforms/common_transforms.py +++ b/mmpose/datasets/transforms/common_transforms.py @@ -973,7 +973,7 @@ def transform(self, results: Dict) -> Optional[dict]: # For single encoding, the encoded items will be directly added # into results. auxiliary_encode_kwargs = { - key: results[key] + key: results.get(key, None) for key in self.encoder.auxiliary_encode_keys } encoded = self.encoder.encode( diff --git a/mmpose/datasets/transforms/topdown_transforms.py b/mmpose/datasets/transforms/topdown_transforms.py index 3480c5b38c..c76d45e46a 100644 --- a/mmpose/datasets/transforms/topdown_transforms.py +++ b/mmpose/datasets/transforms/topdown_transforms.py @@ -126,6 +126,9 @@ def transform(self, results: Dict) -> Optional[dict]: transformed_keypoints[..., :2] = cv2.transform( results['keypoints'][..., :2], warp_mat) results['transformed_keypoints'] = transformed_keypoints + else: + results['transformed_keypoints'] = np.zeros([]) + results['keypoints_visible'] = np.ones((1, 1, 1)) results['input_size'] = (w, h) results['input_center'] = center diff --git a/projects/rtmpose3d/body3d_img2pose_demo.py b/projects/rtmpose3d/body3d_img2pose_demo.py new file mode 100644 index 0000000000..200043d7d4 --- /dev/null +++ b/projects/rtmpose3d/body3d_img2pose_demo.py @@ -0,0 +1,439 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +import mimetypes +import os +import time +from argparse import ArgumentParser +from typing import List + +import cv2 +import json_tricks as json +import mmcv +import mmengine +import numpy as np +from mmengine.logging import print_log + +from mmpose.apis import inference_topdown, init_model +from mmpose.registry import VISUALIZERS +from mmpose.structures import (PoseDataSample, merge_data_samples, + split_instances) +from mmpose.utils import adapt_mmdet_pipeline +from mmpose.visualization import Pose3dLocalVisualizer +from rtmpose3d import * + +try: + from mmdet.apis import inference_detector, init_detector + has_mmdet = True +except (ImportError, ModuleNotFoundError): + has_mmdet = False + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument('det_config', help='Config file for detection') + parser.add_argument('det_checkpoint', help='Checkpoint file for detection') + parser.add_argument( + 'pose3d_estimator_config', + type=str, + default=None, + help='Config file for the 3D pose estimator') + parser.add_argument( + 'pose3d_estimator_checkpoint', + type=str, + default=None, + help='Checkpoint file for the 3D pose estimator') + parser.add_argument('--input', type=str, default='', help='Video path') + parser.add_argument( + '--show', + action='store_true', + default=False, + help='Whether to show visualizations') + parser.add_argument( + '--disable-rebase-keypoint', + action='store_true', + default=False, + help='Whether to disable rebasing the predicted 3D pose so its ' + 'lowest keypoint has a height of 0 (landing on the ground). Rebase ' + 'is useful for visualization when the model do not predict the ' + 'global position of the 3D pose.') + parser.add_argument( + '--disable-norm-pose-2d', + action='store_true', + default=False, + help='Whether to scale the bbox (along with the 2D pose) to the ' + 'average bbox scale of the dataset, and move the bbox (along with the ' + '2D pose) to the average bbox center of the dataset. 
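# Editor's sketch (hedged): the results.get(key, None) change above lets
# GenerateTarget tolerate auxiliary encode keys that a 3D codec may request
# but that are absent from a purely 2D sample, instead of raising KeyError.
# The encoder below is a stand-in invented for illustration only.
class _DummyEncoder:
    auxiliary_encode_keys = ['keypoints_3d', 'camera_param']

results = {'keypoints_3d': None}          # 'camera_param' missing for a 2D sample
encoder = _DummyEncoder()
auxiliary_encode_kwargs = {
    key: results.get(key, None)           # old code: results[key] -> KeyError
    for key in encoder.auxiliary_encode_keys
}
assert auxiliary_encode_kwargs['camera_param'] is None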
This is useful ' + 'when bbox is small, especially in multi-person scenarios.') + parser.add_argument( + '--num-instances', + type=int, + default=1, + help='The number of 3D poses to be visualized in every frame. If ' + 'less than 0, it will be set to the number of pose results in the ' + 'first frame.') + parser.add_argument( + '--output-root', + type=str, + default='', + help='Root of the output video file. ' + 'Default not saving the visualization video.') + parser.add_argument( + '--save-predictions', + action='store_true', + default=False, + help='Whether to save predicted results') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + parser.add_argument( + '--det-cat-id', + type=int, + default=0, + help='Category id for bounding box detection model') + parser.add_argument( + '--bbox-thr', + type=float, + default=0.5, + help='Bounding box score threshold') + parser.add_argument('--kpt-thr', type=float, default=0.3) + parser.add_argument( + '--use-oks-tracking', action='store_true', help='Using OKS tracking') + parser.add_argument( + '--tracking-thr', type=float, default=0.3, help='Tracking threshold') + parser.add_argument( + '--show-interval', type=int, default=0, help='Sleep seconds per frame') + parser.add_argument( + '--thickness', + type=int, + default=1, + help='Link thickness for visualization') + parser.add_argument( + '--radius', + type=int, + default=3, + help='Keypoint radius for visualization') + parser.add_argument( + '--online', + action='store_true', + default=False, + help='Inference mode. If set to True, can not use future frame' + 'information when using multi frames for inference in the 2D pose' + 'detection stage. Default: False.') + + args = parser.parse_args() + return args + + +def process_one_image(args, detector, frame: np.ndarray, frame_idx: int, + pose_estimator: TopdownPoseEstimator3D, + pose_est_results_last: List[PoseDataSample], + pose_est_results_list: List[List[PoseDataSample]], + next_id: int, visualize_frame: np.ndarray, + visualizer: Pose3dLocalVisualizer): + """Visualize detected and predicted keypoints of one image. + + Pipeline of this function: + + frame + | + V + +-----------------+ + | detector | + +-----------------+ + | det_result + V + +-----------------+ + | pose_estimator | + +-----------------+ + | pose_est_results + V + +-----------------+ + | post-processing | + +-----------------+ + | pred_3d_data_samples + V + +------------+ + | visualizer | + +------------+ + + Args: + args (Argument): Custom command-line arguments. + detector (mmdet.BaseDetector): The mmdet detector. + frame (np.ndarray): The image frame read from input image or video. + frame_idx (int): The index of current frame. + pose_estimator (TopdownPoseEstimator): The pose estimator for 2d pose. + pose_est_results_last (list(PoseDataSample)): The results of pose + estimation from the last frame for tracking instances. + pose_est_results_list (list(list(PoseDataSample))): The list of all + pose estimation results converted by + ``convert_keypoint_definition`` from previous frames. In + pose-lifting stage it is used to obtain the 2d estimation sequence. + next_id (int): The next track id to be used. + pose_lifter (PoseLifter): The pose-lifter for estimating 3d pose. + visualize_frame (np.ndarray): The image for drawing the results on. + visualizer (Visualizer): The visualizer for visualizing the 2d and 3d + pose estimation results. + + Returns: + pose_est_results (list(PoseDataSample)): The pose estimation result of + the current frame. 
+ pose_est_results_list (list(list(PoseDataSample))): The list of all + converted pose estimation results until the current frame. + pred_3d_instances (InstanceData): The result of pose-lifting. + Specifically, the predicted keypoints and scores are saved at + ``pred_3d_instances.keypoints`` and + ``pred_3d_instances.keypoint_scores``. + next_id (int): The next track id to be used. + """ + # pose_dataset = pose_estimator.cfg.test_dataloader.dataset + pose_det_dataset_name = pose_estimator.dataset_meta['dataset_name'] + + # First stage: conduct 2D pose detection in a Topdown manner + # use detector to obtain person bounding boxes + det_result = inference_detector(detector, frame) + pred_instance = det_result.pred_instances.cpu().numpy() + + # filter out the person instances with category and bbox threshold + # e.g. 0 for person in COCO + bboxes = pred_instance.bboxes + bboxes = bboxes[np.logical_and(pred_instance.labels == args.det_cat_id, + pred_instance.scores > args.bbox_thr)] + + # estimate pose results for current image + pose_est_results = inference_topdown(pose_estimator, frame, bboxes) + + # post-processing + for idx, pose_est_result in enumerate(pose_est_results): + pose_est_result.track_id = pose_est_results[idx].get('track_id', 1e4) + + pred_instances = pose_est_result.pred_instances + keypoints = pred_instances.keypoints + keypoint_scores = pred_instances.keypoint_scores + if keypoint_scores.ndim == 3: + keypoint_scores = np.squeeze(keypoint_scores, axis=1) + pose_est_results[ + idx].pred_instances.keypoint_scores = keypoint_scores + if keypoints.ndim == 4: + keypoints = np.squeeze(keypoints, axis=1) + + keypoints = -keypoints[..., [0, 2, 1]] + # keypoints[..., 0] = -keypoints[..., 0] + # keypoints[..., 2] = -keypoints[..., 2] + + # rebase height (z-axis) + if not args.disable_rebase_keypoint: + keypoints[..., 2] -= np.min( + keypoints[..., 2], axis=-1, keepdims=True) + + pose_est_results[idx].pred_instances.keypoints = keypoints + + pose_est_results = sorted( + pose_est_results, key=lambda x: x.get('track_id', 1e4)) + + pred_3d_data_samples = merge_data_samples(pose_est_results) + pred_3d_instances = pred_3d_data_samples.get('pred_instances', None) + + if args.num_instances < 0: + args.num_instances = len(pose_est_results) + + # Visualization + if visualizer is not None: + visualizer.add_datasample( + 'result', + visualize_frame, + data_sample=pred_3d_data_samples, + det_data_sample=pred_3d_data_samples, + draw_gt=False, + draw_2d=True, + dataset_2d=pose_det_dataset_name, + dataset_3d=pose_det_dataset_name, + show=args.show, + draw_bbox=True, + kpt_thr=args.kpt_thr, + convert_keypoint=False, + axis_limit=400, + axis_azimuth=70, + axis_elev=15, + num_instances=args.num_instances, + wait_time=args.show_interval, + root_index=[11, 12]) + + return pose_est_results, pose_est_results_list, pred_3d_instances, next_id + + +def main(): + assert has_mmdet, 'Please install mmdet to run the demo.' 
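# Hedged illustration of the keypoint post-processing in process_one_image
# above: predictions are remapped with -keypoints[..., [0, 2, 1]] (negate and
# swap the y/z axes for the 3D visualizer) and, unless
# --disable-rebase-keypoint is passed, shifted so each instance's lowest
# joint lands on z = 0. Toy values only.
import numpy as np

keypoints = np.random.rand(2, 133, 3)      # (num_instances, num_keypoints, xyz)
keypoints = -keypoints[..., [0, 2, 1]]     # same axis permutation as the demo
keypoints[..., 2] -= np.min(keypoints[..., 2], axis=-1, keepdims=True)
assert np.allclose(keypoints[..., 2].min(axis=-1), 0.0)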
+ + args = parse_args() + + assert args.show or (args.output_root != '') + assert args.input != '' + assert args.det_config is not None + assert args.det_checkpoint is not None + + detector = init_detector( + args.det_config, args.det_checkpoint, device=args.device.lower()) + detector.cfg = adapt_mmdet_pipeline(detector.cfg) + + pose_estimator = init_model( + args.pose3d_estimator_config, + args.pose3d_estimator_checkpoint, + device=args.device.lower()) + + det_kpt_color = pose_estimator.dataset_meta.get('keypoint_colors', None) + det_dataset_skeleton = pose_estimator.dataset_meta.get( + 'skeleton_links', None) + det_dataset_link_color = pose_estimator.dataset_meta.get( + 'skeleton_link_colors', None) + + pose_estimator.cfg.model.test_cfg.mode = 'simcc' + pose_estimator.cfg.visualizer.radius = args.radius + pose_estimator.cfg.visualizer.line_width = args.thickness + pose_estimator.cfg.visualizer.det_kpt_color = det_kpt_color + pose_estimator.cfg.visualizer.det_dataset_skeleton = det_dataset_skeleton + pose_estimator.cfg.visualizer.det_dataset_link_color = det_dataset_link_color # noqa: E501 + pose_estimator.cfg.visualizer.skeleton = det_dataset_skeleton + pose_estimator.cfg.visualizer.link_color = det_dataset_link_color + pose_estimator.cfg.visualizer.kpt_color = det_kpt_color + visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer) + + if args.input == 'webcam': + input_type = 'webcam' + else: + input_type = mimetypes.guess_type(args.input)[0].split('/')[0] + + if args.output_root == '': + save_output = False + else: + mmengine.mkdir_or_exist(args.output_root) + output_file = os.path.join(args.output_root, + os.path.basename(args.input)) + if args.input == 'webcam': + output_file += '.mp4' + save_output = True + + if args.save_predictions: + assert args.output_root != '' + args.pred_save_path = f'{args.output_root}/results_' \ + f'{os.path.splitext(os.path.basename(args.input))[0]}.json' + + if save_output: + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + + pose_est_results_list = [] + pred_instances_list = [] + if input_type == 'image': + frame = mmcv.imread(args.input, channel_order='rgb') + _, _, pred_3d_instances, _ = process_one_image( + args=args, + detector=detector, + frame=args.input, + frame_idx=0, + pose_estimator=pose_estimator, + pose_est_results_last=[], + pose_est_results_list=pose_est_results_list, + next_id=0, + visualize_frame=frame, + visualizer=visualizer) + + if args.save_predictions: + # save prediction results + pred_instances_list = split_instances(pred_3d_instances) + + if save_output: + frame_vis = visualizer.get_image() + mmcv.imwrite(mmcv.rgb2bgr(frame_vis), output_file) + + elif input_type in ['webcam', 'video']: + next_id = 0 + pose_est_results = [] + + if args.input == 'webcam': + video = cv2.VideoCapture(0) + else: + video = cv2.VideoCapture(args.input) + + (major_ver, minor_ver, subminor_ver) = (cv2.__version__).split('.') + if int(major_ver) < 3: + fps = video.get(cv2.cv.CV_CAP_PROP_FPS) + else: + fps = video.get(cv2.CAP_PROP_FPS) + + video_writer = None + frame_idx = 0 + + while video.isOpened(): + success, frame = video.read() + frame_idx += 1 + + if not success: + break + + pose_est_results_last = pose_est_results + + # First stage: 2D pose detection + # make person results for current image + (pose_est_results, pose_est_results_list, pred_3d_instances, + next_id) = process_one_image( + args=args, + detector=detector, + frame=frame, + frame_idx=frame_idx, + pose_estimator=pose_estimator, + pose_est_results_last=pose_est_results_last, + 
pose_est_results_list=pose_est_results_list, + next_id=next_id, + visualize_frame=mmcv.bgr2rgb(frame), + visualizer=visualizer) + + if args.save_predictions: + # save prediction results + pred_instances_list.append( + dict( + frame_id=frame_idx, + instances=split_instances(pred_3d_instances))) + + if save_output: + frame_vis = visualizer.get_image() + if video_writer is None: + # the size of the image with visualization may vary + # depending on the presence of heatmaps + video_writer = cv2.VideoWriter(output_file, fourcc, fps, + (frame_vis.shape[1], + frame_vis.shape[0])) + video_writer.write(mmcv.rgb2bgr(frame_vis)) + + if args.show: + # press ESC to exit + if cv2.waitKey(5) & 0xFF == 27: + break + time.sleep(args.show_interval) + + video.release() + + if video_writer: + video_writer.release() + else: + args.save_predictions = False + raise ValueError( + f'file {os.path.basename(args.input)} has invalid format.') + + if args.save_predictions: + with open(args.pred_save_path, 'w') as f: + json.dump( + dict( + meta_info=pose_estimator.dataset_meta, + instance_info=pred_instances_list), + f, + indent='\t') + print(f'predictions have been saved at {args.pred_save_path}') + + if save_output: + input_type = input_type.replace('webcam', 'video') + print_log( + f'the output {input_type} has been saved at {output_file}', + logger='current', + level=logging.INFO) + + +if __name__ == '__main__': + main() diff --git a/projects/rtmpose3d/configs/default_runtime.py b/projects/rtmpose3d/configs/default_runtime.py new file mode 100644 index 0000000000..6f27c0345a --- /dev/null +++ b/projects/rtmpose3d/configs/default_runtime.py @@ -0,0 +1,54 @@ +default_scope = 'mmpose' + +# hooks +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=50), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', interval=10), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='PoseVisualizationHook', enable=False), + badcase=dict( + type='BadCaseAnalysisHook', + enable=False, + out_dir='badcase', + metric_type='loss', + badcase_thr=5)) + +# custom hooks +custom_hooks = [ + # Synchronize model buffers such as running_mean and running_var in BN + # at the end of each epoch + dict(type='SyncBuffersHook') +] + +# multi-processing backend +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) + +# visualizer +vis_backends = [ + dict(type='LocalVisBackend'), + # dict(type='TensorboardVisBackend'), + # dict(type='WandbVisBackend'), +] +visualizer = dict( + type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# logger +log_processor = dict( + type='LogProcessor', window_size=50, by_epoch=True, num_digits=6) +log_level = 'INFO' +load_from = None +resume = False + +# file I/O backend +backend_args = dict(backend='local') + +# training/validation/testing progress +train_cfg = dict(by_epoch=True) +val_cfg = dict() +test_cfg = dict() diff --git a/projects/rtmpose3d/configs/rtmdet_m_640-8xb32_coco-person.py b/projects/rtmpose3d/configs/rtmdet_m_640-8xb32_coco-person.py new file mode 100644 index 0000000000..620de8dc8f --- /dev/null +++ b/projects/rtmpose3d/configs/rtmdet_m_640-8xb32_coco-person.py @@ -0,0 +1,20 @@ +_base_ = 'mmdet::rtmdet/rtmdet_m_8xb32-300e_coco.py' + +checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-m_8xb256-rsb-a1-600e_in1k-ecb3bbd9.pth' # noqa + +model = 
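# Hedged usage sketch for the body3d_img2pose_demo.py script added above.
# All config and checkpoint paths below are placeholders, not files shipped
# in this patch; substitute real RTMDet and RTMW3D configs and weights.
import subprocess

subprocess.run([
    'python', 'projects/rtmpose3d/body3d_img2pose_demo.py',
    'rtmdet_m_640-8xb32_coco-person.py', 'rtmdet_m_person.pth',      # detector (placeholder paths)
    'rtmw3d-l_8xb64_cocktail14-384x288.py', 'rtmw3d_l.pth',          # 3D pose estimator (placeholder paths)
    '--input', 'demo.mp4',
    '--output-root', 'vis_results',
    '--save-predictions',
], check=True)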
dict( + backbone=dict( + init_cfg=dict( + type='Pretrained', prefix='backbone.', checkpoint=checkpoint)), + bbox_head=dict(num_classes=1), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100)) + +train_dataloader = dict(dataset=dict(metainfo=dict(classes=('person', )))) + +val_dataloader = dict(dataset=dict(metainfo=dict(classes=('person', )))) +test_dataloader = val_dataloader diff --git a/projects/rtmpose3d/configs/rtmw3d-l_8xb64_cocktail14-384x288.py b/projects/rtmpose3d/configs/rtmw3d-l_8xb64_cocktail14-384x288.py new file mode 100644 index 0000000000..832742788d --- /dev/null +++ b/projects/rtmpose3d/configs/rtmw3d-l_8xb64_cocktail14-384x288.py @@ -0,0 +1,706 @@ +_base_ = ['./default_runtime.py'] + +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# runtime +max_epochs = 270 +stage2_num_epochs = 10 +base_lr = 5e-4 +num_keypoints = 133 + +train_cfg = dict(max_epochs=max_epochs, val_interval=10) +randomness = dict(seed=2024) + +# optimizer +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05), + paramwise_cfg=dict( + norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True)) + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1.0e-5, + by_epoch=False, + begin=0, + end=1000), + dict( + type='CosineAnnealingLR', + eta_min=base_lr * 0.05, + begin=max_epochs // 2, + end=max_epochs, + T_max=max_epochs // 2, + by_epoch=True, + convert_to_iter_based=True), +] + +# automatically scaling LR based on the actual training batch size +auto_scale_lr = dict(base_batch_size=4096) + +# codec settings +codec = dict( + type='SimCC3DLabel', + input_size=(288, 384, 288), + sigma=(6., 6.93, 6.), + simcc_split_ratio=2.0, + normalize=False, + use_dark=False, + root_index=(11, 12)) + +# model settings +model = dict( + type='TopdownPoseEstimator3D', + data_preprocessor=dict( + type='PoseDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + type='CSPNeXt', + arch='P5', + expand_ratio=0.5, + deepen_factor=1., + widen_factor=1., + channel_attention=True, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='SiLU'), + init_cfg=dict( + type='Pretrained', + prefix='backbone.', + checkpoint='checkpoints/rtmpose-l_simcc-ucoco_dw-ucoco_270e-256x192-4d6dfc62_20230728.pth' # noqa + )), + neck=dict( + type='CSPNeXtPAFPN', + in_channels=[256, 512, 1024], + out_channels=None, + out_indices=( + 1, + 2, + ), + num_csp_blocks=2, + expand_ratio=0.5, + norm_cfg=dict(type='SyncBN'), + act_cfg=dict(type='SiLU', inplace=True)), + head=dict( + type='RTMW3DHead', + in_channels=1024, + out_channels=133, + input_size=codec['input_size'], + in_featuremap_size=tuple([s // 32 for s in codec['input_size']]), + simcc_split_ratio=codec['simcc_split_ratio'], + final_layer_kernel_size=7, + gau_cfg=dict( + hidden_dims=256, + s=128, + expansion_factor=2, + dropout_rate=0.1, + drop_path=0., + act_fn='SiLU', + use_rel_bias=False, + pos_enc=False), + loss=[ + dict( + type='KLDiscretLoss2', + use_target_weight=True, + beta=10., + label_softmax=True), + dict( + type='BoneLoss', + joint_parents=[0, 1, 2, 3, 4, 5, 6, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 50, 50, 51, 52, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3, 
3, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 91, 92, 93, 94, 91, 96, 97, 98, 91, 100, 101, 102, 91, 104, 105, 106, 91, 108, 109, 110, 8, 112, 113, 114, 113, 112, 117, 118, 117, 112, 121, 122, 123, 112, 125, 126, 127, 112, 129, 130, 131], + use_target_weight=True, + loss_weight=2.0 + ) + ], + decoder=codec), + # test_cfg=dict(flip_test=False, mode='2d') + test_cfg=dict(flip_test=False) +) + +# base dataset settings +data_mode = 'topdown' +dataset_type = 'H36MWholeBodyDataset' +backend_args = dict(backend='local') + +# pipelines +train_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='RandomBackground', + bg_dir='/mnt/data/oss_beijing/mmseg/obj365v1_images', + bg_prob=0.5, + ), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80), + dict(type='TopdownAffine', input_size=(288, 384)), + dict(type='YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + dict( + type='CoarseDropout', + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=1.0), + ]), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] +val_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='TopdownAffine', input_size=(288, 384)), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] +train_pipeline_stage2 = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', + shift_factor=0., + scale_factor=[0.5, 1.5], + rotate_factor=90), + dict(type='TopdownAffine', input_size=(288, 384)), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + ]), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] + +# h3wb dataset +h3wb_dataset = dict( + type='H36MWholeBodyDataset', + ann_file='annotation_body3d/h3wb_train_bbox.npz', + seq_len=1, + causal=True, + data_root='data/h36m/', + data_prefix=dict(img='images/'), + test_mode=False, + pipeline=[]) + + +# dna rendering dataset +dna_rendering_dataset = dict( + type='DNARenderingDataset', + data_root='data/dna_rendering_part1', + data_mode='topdown', + ann_file='instances.npz', + subset_frac=0.1, + pipeline=[ + dict(type='LoadMask', backend_args=backend_args) + ], +) + +# mapping + +aic_coco133 = [(0, 6), (1, 8), (2, 10), (3, 5), (4, 7), (5, 9), (6, 12), + (7, 14), (8, 16), (9, 11), (10, 13), (11, 15)] + +crowdpose_coco133 = [(0, 5), (1, 6), (2, 7), (3, 8), (4, 9), (5, 10), (6, 11), + (7, 12), (8, 13), (9, 14), (10, 15), (11, 16)] + +mpii_coco133 = [ + (0, 16), + (1, 14), + (2, 12), + (3, 11), + (4, 13), + (5, 15), + (10, 10), + (11, 8), + (12, 6), + (13, 5), + (14, 7), + (15, 9), +] + +jhmdb_coco133 = [ + (3, 6), + (4, 5), + (5, 12), + (6, 11), + (7, 8), + (8, 7), + (9, 14), + (10, 13), + (11, 10), + (12, 9), + (13, 16), + (14, 15), +] + +halpe_coco133 = [(i, i) + for i in range(17)] + [(20, 17), (21, 20), (22, 18), (23, 21), + (24, 19), + (25, 22)] + [(i, i - 3) + for i in range(26, 136)] + +posetrack_coco133 = [ + (0, 0), + (3, 3), + (4, 4), + (5, 5), + (6, 6), + (7, 7), + (8, 8), + (9, 9), + (10, 10), + (11, 11), + (12, 12), + 
(13, 13), + (14, 14), + (15, 15), + (16, 16), +] + +humanart_coco133 = [(i, i) for i in range(17)] + [(17, 99), (18, 120), + (19, 17), (20, 20)] + +data_mode = 'topdown' +data_root = 'data/' + +# train datasets +dataset_coco = dict( + type='CocoWholeBodyDataset', + data_root='data/coco/', + data_mode='topdown', + ann_file='annotations/coco_wholebody_train_v1.0.json', + data_prefix=dict(img='train2017/'), + pipeline=[], +) + +dataset_aic = dict( + type='AicDataset', + data_root=data_root, + data_mode=data_mode, + ann_file='aic/annotations/aic_train.json', + data_prefix=dict(img='pose/ai_challenge/ai_challenger_keypoint' + '_train_20170902/keypoint_train_images_20170902/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=aic_coco133) + ], +) + +dataset_crowdpose = dict( + type='CrowdPoseDataset', + data_root=data_root, + data_mode=data_mode, + ann_file='crowdpose/annotations/mmpose_crowdpose_trainval.json', + data_prefix=dict(img='pose/CrowdPose/images/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=crowdpose_coco133) + ], +) + +dataset_mpii = dict( + type='MpiiDataset', + data_root=data_root, + data_mode=data_mode, + ann_file='mpii/annotations/mpii_train.json', + data_prefix=dict(img='pose/MPI/images/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=mpii_coco133) + ], +) + +dataset_jhmdb = dict( + type='JhmdbDataset', + data_root=data_root, + data_mode=data_mode, + ann_file='jhmdb/annotations/Sub1_train.json', + data_prefix=dict(img='pose/JHMDB/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=jhmdb_coco133) + ], +) + +dataset_halpe = dict( + type='HalpeDataset', + data_root=data_root, + data_mode=data_mode, + ann_file='halpe/annotations/halpe_train_v1.json', + data_prefix=dict(img='pose/Halpe/hico_20160224_det/images/train2015'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=halpe_coco133) + ], +) + +dataset_posetrack = dict( + type='PoseTrack18Dataset', + data_root=data_root, + data_mode=data_mode, + ann_file='posetrack18/annotations/posetrack18_train.json', + data_prefix=dict(img='pose/PoseChallenge2018/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=posetrack_coco133) + ], +) + +dataset_humanart = dict( + type='HumanArt21Dataset', + data_root=data_root, + data_mode=data_mode, + ann_file='HumanArt/annotations/training_humanart.json', + filter_cfg=dict(scenes=['real_human']), + data_prefix=dict(img='pose/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=humanart_coco133) + ]) + +face_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale', padding=1.25), + dict( + type='RandomBBoxTransform', + shift_factor=0., + scale_factor=[1.5, 2.0], + rotate_factor=0), +] + +wflw_coco133 = [(i * 2, 23 + i) + for i in range(17)] + [(33 + i, 40 + i) for i in range(5)] + [ + (42 + i, 45 + i) for i in range(5) + ] + [(51 + i, 50 + i) + for i in range(9)] + [(60, 59), (61, 60), (63, 61), + (64, 62), (65, 63), (67, 64), + (68, 65), (69, 66), (71, 67), + (72, 68), (73, 69), + (75, 70)] + [(76 + i, 71 + i) + for i in range(20)] +dataset_wflw = dict( + type='WFLWDataset', + data_root=data_root, + data_mode=data_mode, + ann_file='wflw/annotations/face_landmarks_wflw_train.json', + data_prefix=dict(img='pose/WFLW/images/'), + pipeline=[ + dict( + 
type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=wflw_coco133), *face_pipeline + ], +) + +mapping_300w_coco133 = [(i, 23 + i) for i in range(68)] +dataset_300w = dict( + type='Face300WDataset', + data_root=data_root, + data_mode=data_mode, + ann_file='300w/annotations/face_landmarks_300w_train.json', + data_prefix=dict(img='pose/300w/images/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=mapping_300w_coco133), *face_pipeline + ], +) + +cofw_coco133 = [(0, 40), (2, 44), (4, 42), (1, 49), (3, 45), (6, 47), (8, 59), + (10, 62), (9, 68), (11, 65), (18, 54), (19, 58), (20, 53), + (21, 56), (22, 71), (23, 77), (24, 74), (25, 85), (26, 89), + (27, 80), (28, 31)] +dataset_cofw = dict( + type='COFWDataset', + data_root=data_root, + data_mode=data_mode, + ann_file='cofw/annotations/cofw_train.json', + data_prefix=dict(img='pose/COFW/images/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=cofw_coco133), *face_pipeline + ], +) + +lapa_coco133 = [(i * 2, 23 + i) for i in range(17)] + [ + (33 + i, 40 + i) for i in range(5) +] + [(42 + i, 45 + i) for i in range(5)] + [ + (51 + i, 50 + i) for i in range(4) +] + [(58 + i, 54 + i) for i in range(5)] + [(66, 59), (67, 60), (69, 61), + (70, 62), (71, 63), (73, 64), + (75, 65), (76, 66), (78, 67), + (79, 68), (80, 69), + (82, 70)] + [(84 + i, 71 + i) + for i in range(20)] +dataset_lapa = dict( + type='LapaDataset', + data_root=data_root, + data_mode=data_mode, + ann_file='LaPa/annotations/lapa_trainval.json', + data_prefix=dict(img='pose/LaPa/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=lapa_coco133), *face_pipeline + ], +) + +dataset_wb = dict( + type='CombinedDataset', + metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), + datasets=[ + dataset_coco, + dataset_halpe + ], + pipeline=[], + test_mode=False, +) + +dataset_body = dict( + type='CombinedDataset', + metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), + datasets=[ + dataset_aic, + dataset_crowdpose, + dataset_mpii, + dataset_jhmdb, + dataset_posetrack, + # dataset_humanart, + ], + pipeline=[], + test_mode=False, +) + +dataset_face = dict( + type='CombinedDataset', + metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), + datasets=[ + dataset_wflw, + dataset_300w, + dataset_cofw, + dataset_lapa, + ], + pipeline=[], + test_mode=False, +) + +hand_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict( + type='RandomBBoxTransform', + shift_factor=0., + scale_factor=[1.5, 2.0], + rotate_factor=0), +] + +interhand_left = [(21, 95), (22, 94), (23, 93), (24, 92), (25, 99), (26, 98), + (27, 97), (28, 96), (29, 103), (30, 102), (31, 101), + (32, 100), (33, 107), (34, 106), (35, 105), (36, 104), + (37, 111), (38, 110), (39, 109), (40, 108), (41, 91)] +interhand_right = [(i - 21, j + 21) for i, j in interhand_left] +interhand_coco133 = interhand_right + interhand_left + +dataset_interhand2d = dict( + type='InterHand2DDoubleDataset', + data_root='data/interhand2.6m/', + data_mode='topdown', + ann_file='annotations/all/InterHand2.6M_train_data.json', + camera_param_file='annotations/all/InterHand2.6M_train_camera.json', + joint_file='annotations/all/InterHand2.6M_train_joint_3d.json', + data_prefix=dict(img='images/train/'), + sample_interval=10, + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=interhand_coco133, + 
), *hand_pipeline + ], +) + +dataset_interhand3d = dict( + type='InterHand3DDataset', + data_root='data/interhand2.6m/', + data_mode='topdown', + ann_file='annotations/all/InterHand2.6M_train_data.json', + camera_param_file='annotations/all/InterHand2.6M_train_camera.json', + joint_file='annotations/all/InterHand2.6M_train_joint_3d.json', + use_gt_root_depth=True, + rootnet_result_file=None, + data_prefix=dict(img='images/train/'), + sample_interval=10, + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=interhand_coco133, + ), *hand_pipeline + ], +) + +dataset_hand = dict( + type='CombinedDataset', + metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), + datasets=[dataset_interhand3d], + pipeline=[], + test_mode=False, +) + + +# ubody dataset +scenes = [ + 'Magic_show', + 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow', + 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing', + 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference' +] +ubody_datasets = [] +for scene in scenes: + train_ann = f'annotations/{scene}/train_3dkeypoint_annotation.json' + ubody = dict( + type='UBody3dDataset', + data_root='data/UBody/', + ann_file=train_ann, + data_mode='topdown', + causal=True, + seq_len=1, + data_prefix=dict(img='images/'), + subset_frac=0.1, + pipeline=[]) + ubody_datasets.append(ubody) + + +train_datasets = [ + dataset_wb, + dataset_body, + dataset_face, + # dataset_hand, + *ubody_datasets, + h3wb_dataset, + # dna_rendering_dataset +] + + +# data loaders +train_dataloader = dict( + batch_size=64, + num_workers=10, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type='CombinedDataset', + datasets=train_datasets, + pipeline=train_pipeline, + metainfo=dict(from_file='configs/_base_/datasets/h3wb.py'), + test_mode=False)) + +# hooks +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + save_best='MPJPE', + rule='less', + max_keep_ckpts=1)) + +# hooks +# default_hooks = dict( +# checkpoint=dict( +# save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1)) + +custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type='mmdet.PipelineSwitchHook', + switch_epoch=max_epochs - stage2_num_epochs, + switch_pipeline=train_pipeline_stage2) +] + +# eval h3wb +# val_dataloader = dict( +# batch_size=64, +# num_workers=10, +# persistent_workers=True, +# drop_last=False, +# sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), +# dataset=dict( +# type='H36MWholeBodyDataset', +# ann_file='annotation_body3d/h3wb_train_bbox.npz', +# seq_len=1, +# causal=True, +# data_root='data/h36m/', +# data_prefix=dict(img='images/'), +# test_mode=True, +# pipeline=val_pipeline)) +# test_dataloader = val_dataloader + +# # evaluators +# val_evaluator = [ +# dict(type='SimpleMPJPE', mode='mpjpe'), +# dict(type='SimpleMPJPE', mode='p-mpjpe') +# ] +# test_evaluator = val_evaluator + +# eval coco +val_dataloader = dict( + batch_size=64, + num_workers=10, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type='CocoWholeBodyDataset', + data_root='data/coco/', + data_mode='topdown', + ann_file='annotations/coco_wholebody_val_v1.0.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + bbox_file='data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json', + pipeline=val_pipeline, + 
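# Hedged note on the two-stage schedule configured above: mmdet's
# PipelineSwitchHook swaps the training pipeline for train_pipeline_stage2
# (which drops YOLOXHSVRandomAug and CoarseDropout) during the final
# stage2_num_epochs epochs.
max_epochs = 270
stage2_num_epochs = 10
switch_epoch = max_epochs - stage2_num_epochs
assert switch_epoch == 260   # augmentation is relaxed from epoch 260 onward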
)) +test_dataloader = val_dataloader + +# evaluators +val_evaluator = dict( + type='CocoWholeBodyMetric', + ann_file='data/coco/' + 'annotations/coco_wholebody_val_v1.0.json') +test_evaluator = val_evaluator diff --git a/projects/rtmpose3d/configs/rtmw3d-x_8xb64_cocktail14-384x288.py b/projects/rtmpose3d/configs/rtmw3d-x_8xb64_cocktail14-384x288.py new file mode 100644 index 0000000000..3a822f50b8 --- /dev/null +++ b/projects/rtmpose3d/configs/rtmw3d-x_8xb64_cocktail14-384x288.py @@ -0,0 +1,706 @@ +_base_ = ['../../_base_/default_runtime.py'] + +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# runtime +max_epochs = 270 +stage2_num_epochs = 10 +base_lr = 5e-4 +num_keypoints = 133 + +train_cfg = dict(max_epochs=max_epochs, val_interval=10) +randomness = dict(seed=2024) + +# optimizer +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05), + paramwise_cfg=dict( + norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True)) + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1.0e-5, + by_epoch=False, + begin=0, + end=1000), + dict( + type='CosineAnnealingLR', + eta_min=base_lr * 0.05, + begin=max_epochs // 2, + end=max_epochs, + T_max=max_epochs // 2, + by_epoch=True, + convert_to_iter_based=True), +] + +# automatically scaling LR based on the actual training batch size +auto_scale_lr = dict(base_batch_size=4096) + +# codec settings +codec = dict( + type='SimCC3DLabel', + input_size=(288, 384, 288), + sigma=(6., 6.93, 6.), + simcc_split_ratio=2.0, + normalize=False, + use_dark=False, + root_index=(11, 12)) + +# model settings +model = dict( + type='TopdownPoseEstimator3D', + data_preprocessor=dict( + type='PoseDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + type='CSPNeXt', + arch='P5', + expand_ratio=0.5, + deepen_factor=1.33, + widen_factor=1.25, + channel_attention=True, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='SiLU'), + init_cfg=dict( + type='Pretrained', + prefix='backbone.', + checkpoint='checkpoints/rtmpose-x_simcc-ucoco_pt-aic-coco_270e-384x288-f5b50679_20230822.pth' # noqa + )), + neck=dict( + type='CSPNeXtPAFPN', + in_channels=[320, 640, 1280], + out_channels=None, + out_indices=( + 1, + 2, + ), + num_csp_blocks=2, + expand_ratio=0.5, + norm_cfg=dict(type='SyncBN'), + act_cfg=dict(type='SiLU', inplace=True)), + head=dict( + type='RTMW3DHead', + in_channels=1280, + out_channels=133, + input_size=codec['input_size'], + in_featuremap_size=tuple([s // 32 for s in codec['input_size']]), + simcc_split_ratio=codec['simcc_split_ratio'], + final_layer_kernel_size=7, + gau_cfg=dict( + hidden_dims=256, + s=128, + expansion_factor=2, + dropout_rate=0., + drop_path=0., + act_fn='SiLU', + use_rel_bias=False, + pos_enc=False), + loss=[ + dict( + type='KLDiscretLoss2', + use_target_weight=True, + beta=10., + label_softmax=True), + dict( + type='BoneLoss', + joint_parents=[0, 1, 2, 3, 4, 5, 6, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 50, 50, 51, 52, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 91, 92, 93, 94, 91, 96, 97, 98, 91, 100, 101, 102, 91, 104, 105, 106, 91, 108, 109, 110, 8, 112, 113, 114, 113, 112, 117, 118, 117, 112, 121, 122, 123, 112, 
125, 126, 127, 112, 129, 130, 131], + use_target_weight=True, + loss_weight=2.0 + ) + ], + decoder=codec), + test_cfg=dict(flip_test=False, mode='2d') + # test_cfg=dict(flip_test=False) +) + +# base dataset settings +data_mode = 'topdown' + +backend_args = dict(backend='local') + +# pipelines +train_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='RandomBackground', + bg_dir='/mnt/data/oss_beijing/mmseg/obj365v1_images', + bg_prob=0.5, + ), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80), + dict(type='TopdownAffine', input_size=(288, 384)), + dict(type='YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + dict( + type='CoarseDropout', + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=1.0), + ]), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] +val_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='TopdownAffine', input_size=(288, 384)), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] +train_pipeline_stage2 = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', + shift_factor=0., + scale_factor=[0.5, 1.5], + rotate_factor=90), + dict(type='TopdownAffine', input_size=(288, 384)), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + ]), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] + +# h3wb dataset +h3wb_dataset = dict( + type='H36MWholeBodyDataset', + ann_file='annotation_body3d/h3wb_train_bbox.npz', + seq_len=1, + causal=True, + data_root='data/h36m/', + data_prefix=dict(img='images/'), + test_mode=False, + pipeline=[]) + + +# dna rendering dataset +dna_rendering_dataset = dict( + type='DNARenderingDataset', + data_root='data/dna_rendering_part1', + data_mode='topdown', + ann_file='instances.npz', + subset_frac=0.1, + pipeline=[ + dict(type='LoadMask', backend_args=backend_args) + ], +) + +# mapping + +aic_coco133 = [(0, 6), (1, 8), (2, 10), (3, 5), (4, 7), (5, 9), (6, 12), + (7, 14), (8, 16), (9, 11), (10, 13), (11, 15)] + +crowdpose_coco133 = [(0, 5), (1, 6), (2, 7), (3, 8), (4, 9), (5, 10), (6, 11), + (7, 12), (8, 13), (9, 14), (10, 15), (11, 16)] + +mpii_coco133 = [ + (0, 16), + (1, 14), + (2, 12), + (3, 11), + (4, 13), + (5, 15), + (10, 10), + (11, 8), + (12, 6), + (13, 5), + (14, 7), + (15, 9), +] + +jhmdb_coco133 = [ + (3, 6), + (4, 5), + (5, 12), + (6, 11), + (7, 8), + (8, 7), + (9, 14), + (10, 13), + (11, 10), + (12, 9), + (13, 16), + (14, 15), +] + +halpe_coco133 = [(i, i) + for i in range(17)] + [(20, 17), (21, 20), (22, 18), (23, 21), + (24, 19), + (25, 22)] + [(i, i - 3) + for i in range(26, 136)] + +posetrack_coco133 = [ + (0, 0), + (3, 3), + (4, 4), + (5, 5), + (6, 6), + (7, 7), + (8, 8), + (9, 9), + (10, 10), + (11, 11), + (12, 12), + (13, 13), + (14, 14), + (15, 15), + (16, 16), +] + +humanart_coco133 = [(i, i) for i in range(17)] + [(17, 99), (18, 120), + (19, 17), (20, 20)] + +data_mode = 'topdown' +data_root = 'data/' + +# train datasets +dataset_coco = dict( + type='CocoWholeBodyDataset', + 
data_root='data/coco/', + data_mode='topdown', + ann_file='annotations/coco_wholebody_train_v1.0.json', + data_prefix=dict(img='train2017/'), + pipeline=[], +) + +dataset_aic = dict( + type='AicDataset', + data_root=data_root, + data_mode=data_mode, + ann_file='aic/annotations/aic_train.json', + data_prefix=dict(img='pose/ai_challenge/ai_challenger_keypoint' + '_train_20170902/keypoint_train_images_20170902/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=aic_coco133) + ], +) + +dataset_crowdpose = dict( + type='CrowdPoseDataset', + data_root=data_root, + data_mode=data_mode, + ann_file='crowdpose/annotations/mmpose_crowdpose_trainval.json', + data_prefix=dict(img='pose/CrowdPose/images/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=crowdpose_coco133) + ], +) + +dataset_mpii = dict( + type='MpiiDataset', + data_root=data_root, + data_mode=data_mode, + ann_file='mpii/annotations/mpii_train.json', + data_prefix=dict(img='pose/MPI/images/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=mpii_coco133) + ], +) + +dataset_jhmdb = dict( + type='JhmdbDataset', + data_root=data_root, + data_mode=data_mode, + ann_file='jhmdb/annotations/Sub1_train.json', + data_prefix=dict(img='pose/JHMDB/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=jhmdb_coco133) + ], +) + +dataset_halpe = dict( + type='HalpeDataset', + data_root=data_root, + data_mode=data_mode, + ann_file='halpe/annotations/halpe_train_v1.json', + data_prefix=dict(img='pose/Halpe/hico_20160224_det/images/train2015'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=halpe_coco133) + ], +) + +dataset_posetrack = dict( + type='PoseTrack18Dataset', + data_root=data_root, + data_mode=data_mode, + ann_file='posetrack18/annotations/posetrack18_train.json', + data_prefix=dict(img='pose/PoseChallenge2018/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=posetrack_coco133) + ], +) + +dataset_humanart = dict( + type='HumanArt21Dataset', + data_root=data_root, + data_mode=data_mode, + ann_file='HumanArt/annotations/training_humanart.json', + filter_cfg=dict(scenes=['real_human']), + data_prefix=dict(img='pose/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=humanart_coco133) + ]) + +face_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale', padding=1.25), + dict( + type='RandomBBoxTransform', + shift_factor=0., + scale_factor=[1.5, 2.0], + rotate_factor=0), +] + +wflw_coco133 = [(i * 2, 23 + i) + for i in range(17)] + [(33 + i, 40 + i) for i in range(5)] + [ + (42 + i, 45 + i) for i in range(5) + ] + [(51 + i, 50 + i) + for i in range(9)] + [(60, 59), (61, 60), (63, 61), + (64, 62), (65, 63), (67, 64), + (68, 65), (69, 66), (71, 67), + (72, 68), (73, 69), + (75, 70)] + [(76 + i, 71 + i) + for i in range(20)] +dataset_wflw = dict( + type='WFLWDataset', + data_root=data_root, + data_mode=data_mode, + ann_file='wflw/annotations/face_landmarks_wflw_train.json', + data_prefix=dict(img='pose/WFLW/images/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=wflw_coco133), *face_pipeline + ], +) + +mapping_300w_coco133 = [(i, 23 + i) for i in range(68)] +dataset_300w = dict( + type='Face300WDataset', + data_root=data_root, + data_mode=data_mode, + 
ann_file='300w/annotations/face_landmarks_300w_train.json', + data_prefix=dict(img='pose/300w/images/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=mapping_300w_coco133), *face_pipeline + ], +) + +cofw_coco133 = [(0, 40), (2, 44), (4, 42), (1, 49), (3, 45), (6, 47), (8, 59), + (10, 62), (9, 68), (11, 65), (18, 54), (19, 58), (20, 53), + (21, 56), (22, 71), (23, 77), (24, 74), (25, 85), (26, 89), + (27, 80), (28, 31)] +dataset_cofw = dict( + type='COFWDataset', + data_root=data_root, + data_mode=data_mode, + ann_file='cofw/annotations/cofw_train.json', + data_prefix=dict(img='pose/COFW/images/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=cofw_coco133), *face_pipeline + ], +) + +lapa_coco133 = [(i * 2, 23 + i) for i in range(17)] + [ + (33 + i, 40 + i) for i in range(5) +] + [(42 + i, 45 + i) for i in range(5)] + [ + (51 + i, 50 + i) for i in range(4) +] + [(58 + i, 54 + i) for i in range(5)] + [(66, 59), (67, 60), (69, 61), + (70, 62), (71, 63), (73, 64), + (75, 65), (76, 66), (78, 67), + (79, 68), (80, 69), + (82, 70)] + [(84 + i, 71 + i) + for i in range(20)] +dataset_lapa = dict( + type='LapaDataset', + data_root=data_root, + data_mode=data_mode, + ann_file='LaPa/annotations/lapa_trainval.json', + data_prefix=dict(img='pose/LaPa/'), + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=lapa_coco133), *face_pipeline + ], +) + +dataset_wb = dict( + type='CombinedDataset', + metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), + datasets=[ + dataset_coco, + dataset_halpe + ], + pipeline=[], + test_mode=False, +) + +dataset_body = dict( + type='CombinedDataset', + metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), + datasets=[ + dataset_aic, + dataset_crowdpose, + dataset_mpii, + dataset_jhmdb, + dataset_posetrack, + # dataset_humanart, + ], + pipeline=[], + test_mode=False, +) + +dataset_face = dict( + type='CombinedDataset', + metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), + datasets=[ + dataset_wflw, + dataset_300w, + dataset_cofw, + dataset_lapa, + ], + pipeline=[], + test_mode=False, +) + +hand_pipeline = [ + dict(type='LoadImage', backend_args=backend_args), + dict(type='GetBBoxCenterScale'), + dict( + type='RandomBBoxTransform', + shift_factor=0., + scale_factor=[1.5, 2.0], + rotate_factor=0), +] + +interhand_left = [(21, 95), (22, 94), (23, 93), (24, 92), (25, 99), (26, 98), + (27, 97), (28, 96), (29, 103), (30, 102), (31, 101), + (32, 100), (33, 107), (34, 106), (35, 105), (36, 104), + (37, 111), (38, 110), (39, 109), (40, 108), (41, 91)] +interhand_right = [(i - 21, j + 21) for i, j in interhand_left] +interhand_coco133 = interhand_right + interhand_left + +dataset_interhand2d = dict( + type='InterHand2DDoubleDataset', + data_root='data/interhand2.6m/', + data_mode='topdown', + ann_file='annotations/all/InterHand2.6M_train_data.json', + camera_param_file='annotations/all/InterHand2.6M_train_camera.json', + joint_file='annotations/all/InterHand2.6M_train_joint_3d.json', + data_prefix=dict(img='images/train/'), + sample_interval=10, + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=interhand_coco133, + ), *hand_pipeline + ], +) + +dataset_interhand3d = dict( + type='InterHand3DDataset', + data_root='data/interhand2.6m/', + data_mode='topdown', + ann_file='annotations/all/InterHand2.6M_train_data.json', + 
camera_param_file='annotations/all/InterHand2.6M_train_camera.json', + joint_file='annotations/all/InterHand2.6M_train_joint_3d.json', + use_gt_root_depth=True, + rootnet_result_file=None, + data_prefix=dict(img='images/train/'), + sample_interval=10, + pipeline=[ + dict( + type='KeypointConverter', + num_keypoints=num_keypoints, + mapping=interhand_coco133, + ), *hand_pipeline + ], +) + +dataset_hand = dict( + type='CombinedDataset', + metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'), + datasets=[dataset_interhand3d], + pipeline=[], + test_mode=False, +) + + +# ubody dataset +scenes = [ + 'Magic_show', + 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow', + 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing', + 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference' +] +ubody_datasets = [] +for scene in scenes: + train_ann = f'annotations/{scene}/train_3dkeypoint_annotation.json' + ubody = dict( + type='UBody3dDataset', + data_root='data/UBody/', + ann_file=train_ann, + data_mode='topdown', + causal=True, + seq_len=1, + data_prefix=dict(img='images/'), + subset_frac=0.1, + pipeline=[]) + ubody_datasets.append(ubody) + + +train_datasets = [ + dataset_wb, + dataset_body, + dataset_face, + dataset_hand, + *ubody_datasets, + h3wb_dataset, + # dna_rendering_dataset +] + + +# data loaders +train_dataloader = dict( + batch_size=32, + num_workers=10, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type='CombinedDataset', + datasets=train_datasets, + pipeline=train_pipeline, + metainfo=dict(from_file='configs/_base_/datasets/h3wb.py'), + test_mode=False)) + +# hooks +# default_hooks = dict( +# checkpoint=dict( +# type='CheckpointHook', +# save_best='MPJPE', +# rule='less', +# max_keep_ckpts=1)) + +# hooks +default_hooks = dict( + checkpoint=dict( + save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1)) + +custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type='mmdet.PipelineSwitchHook', + switch_epoch=max_epochs - stage2_num_epochs, + switch_pipeline=train_pipeline_stage2) +] + +# eval h3wb +# val_dataloader = dict( +# batch_size=64, +# num_workers=10, +# persistent_workers=True, +# drop_last=False, +# sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), +# dataset=dict( +# type='H36MWholeBodyDataset', +# ann_file='annotation_body3d/h3wb_train_bbox.npz', +# seq_len=1, +# causal=True, +# data_root='data/h36m/', +# data_prefix=dict(img='images/'), +# test_mode=True, +# pipeline=val_pipeline)) +# test_dataloader = val_dataloader + +# # evaluators +# val_evaluator = [ +# dict(type='SimpleMPJPE', mode='mpjpe'), +# dict(type='SimpleMPJPE', mode='p-mpjpe') +# ] +# test_evaluator = val_evaluator + +# eval coco +val_dataloader = dict( + batch_size=64, + num_workers=10, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type='CocoWholeBodyDataset', + data_root='data/coco/', + data_mode='topdown', + ann_file='annotations/coco_wholebody_val_v1.0.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + bbox_file='data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json', + pipeline=val_pipeline, + )) +test_dataloader = val_dataloader + +# evaluators +val_evaluator = dict( + type='CocoWholeBodyMetric', + ann_file='data/coco/' + 'annotations/coco_wholebody_val_v1.0.json') +test_evaluator = 
val_evaluator diff --git a/projects/rtmpose3d/rtmpose3d/__init__.py b/projects/rtmpose3d/rtmpose3d/__init__.py new file mode 100644 index 0000000000..eec926b2c8 --- /dev/null +++ b/projects/rtmpose3d/rtmpose3d/__init__.py @@ -0,0 +1,6 @@ +from .pose_estimator import TopdownPoseEstimator3D +from .rtmw3d_head import RTMW3DHead +from .simcc_3d_label import SimCC3DLabel +from .loss import KLDiscretLoss2 + +__all__ = ['TopdownPoseEstimator3D', 'RTMW3DHead', 'SimCC3DLabel', 'KLDiscretLoss2'] diff --git a/projects/rtmpose3d/rtmpose3d/loss.py b/projects/rtmpose3d/rtmpose3d/loss.py new file mode 100644 index 0000000000..499befa5a0 --- /dev/null +++ b/projects/rtmpose3d/rtmpose3d/loss.py @@ -0,0 +1,37 @@ +from mmpose.registry import MODELS +from mmpose.models.losses import KLDiscretLoss + +@MODELS.register_module() +class KLDiscretLoss2(KLDiscretLoss): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._loss_name = 'loss_kld' + + def forward(self, pred_simcc, gt_simcc, target_weight): + N, K, _ = pred_simcc[0].shape + loss = 0 + + for pred, target, weight in zip(pred_simcc, gt_simcc, target_weight): + pred = pred.reshape(-1, pred.size(-1)) + target = target.reshape(-1, target.size(-1)) + weight = weight.reshape(-1) + + t_loss = self.criterion(pred, target).mul(weight) + + if self.mask is not None: + t_loss = t_loss.reshape(N, K) + t_loss[:, self.mask] = t_loss[:, self.mask] * self.mask_weight + + loss = loss + t_loss.sum() + + return loss / K + + @property + def loss_name(self): + """Loss Name. + + Returns: + str: The name of this loss item. + """ + return self._loss_name \ No newline at end of file diff --git a/projects/rtmpose3d/rtmpose3d/pose_estimator.py b/projects/rtmpose3d/rtmpose3d/pose_estimator.py new file mode 100644 index 0000000000..6854205b4b --- /dev/null +++ b/projects/rtmpose3d/rtmpose3d/pose_estimator.py @@ -0,0 +1,116 @@ +from itertools import zip_longest +from typing import Optional + +import numpy as np + +from mmpose.utils.typing import InstanceList, PixelDataList, SampleList +from mmpose.registry import MODELS +from mmpose.models.pose_estimators import TopdownPoseEstimator + +@MODELS.register_module() +class TopdownPoseEstimator3D(TopdownPoseEstimator): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.camera_param = { + 'c': [512.54150496, 515.45148698], + 'f': [1145.04940459, 1143.78109572], + } + + def add_pred_to_datasample(self, batch_pred_instances: InstanceList, + batch_pred_fields: Optional[PixelDataList], + batch_data_samples: SampleList) -> SampleList: + """Add predictions into data samples. + + Args: + batch_pred_instances (List[InstanceData]): The predicted instances + of the input data batch + batch_pred_fields (List[PixelData], optional): The predicted + fields (e.g. heatmaps) of the input batch + batch_data_samples (List[PoseDataSample]): The input data batch + + Returns: + List[PoseDataSample]: A list of data samples where the predictions + are stored in the ``pred_instances`` field of each data sample. 
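# Hedged, torch-only sketch of the tensor shapes KLDiscretLoss2.forward above
# expects: pred_simcc and gt_simcc are (x, y, z) tuples of per-axis logits of
# shape (N, K, bins), target_weight is a matching tuple of (N, K) weights, and
# the summed loss is averaged over the K keypoints. The KL term below is a
# stand-in for the parent KLDiscretLoss criterion, used only to show shapes.
import torch
import torch.nn.functional as F

N, K = 2, 133
bins = (576, 768, 576)        # x/y/z bins for a 288x384x288 input at split ratio 2.0
pred_simcc = tuple(torch.randn(N, K, b) for b in bins)
gt_simcc = tuple(torch.softmax(torch.randn(N, K, b), dim=-1) for b in bins)
target_weight = tuple(torch.ones(N, K) for _ in bins)

loss = 0.
for pred, target, weight in zip(pred_simcc, gt_simcc, target_weight):
    pred = pred.reshape(-1, pred.size(-1))
    target = target.reshape(-1, target.size(-1))
    weight = weight.reshape(-1)
    t_loss = F.kl_div(F.log_softmax(pred, dim=-1), target,
                      reduction='none').sum(-1).mul(weight)
    loss = loss + t_loss.sum()
loss = loss / K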
+ """ + assert len(batch_pred_instances) == len(batch_data_samples) + if batch_pred_fields is None: + batch_pred_fields = [] + output_keypoint_indices = self.test_cfg.get('output_keypoint_indices', + None) + mode = self.test_cfg.get('mode', '3d') + assert mode in ['2d', '3d', 'vis', 'simcc'] + for pred_instances, pred_fields, data_sample in zip_longest( + batch_pred_instances, batch_pred_fields, batch_data_samples): + + gt_instances = data_sample.gt_instances + + # convert keypoint coordinates from input space to image space + input_center = data_sample.metainfo['input_center'] + input_scale = data_sample.metainfo['input_scale'] + input_size = data_sample.metainfo['input_size'] + keypoints_3d = pred_instances.keypoints + keypoints_2d = pred_instances.keypoints_2d + keypoints_simcc = pred_instances.keypoints_simcc + keypoints_2d = keypoints_2d / input_size * input_scale \ + + input_center - 0.5 * input_scale + + if gt_instances.get('camera_params', None) is not None: + camera_params = gt_instances.camera_params[0] + f = np.array(camera_params['f']) + c = np.array(camera_params['c']) + else: + f = np.array([1145.04940459, 1143.78109572]) + c = np.array(data_sample.ori_shape) + + kpts_pixel = np.concatenate([ + keypoints_2d, + (keypoints_3d[..., 2] + gt_instances.root_z)[..., None] + ], + axis=-1) + kpts_cam = kpts_pixel.copy() + kpts_cam[..., :2] = (kpts_pixel[..., :2] - c) / f * kpts_pixel[..., + 2:] + if mode == '3d': + pred_instances.keypoints = kpts_cam + pred_instances.transformed_keypoints = keypoints_2d + elif mode == 'vis': + pred_instances.keypoints = keypoints_3d + pred_instances.transformed_keypoints = keypoints_2d + elif mode == 'simcc': + pred_instances.keypoints = keypoints_simcc + pred_instances.transformed_keypoints = keypoints_2d + else: + pred_instances.keypoints = keypoints_2d + pred_instances.transformed_keypoints = keypoints_2d + + if 'keypoints_visible' not in pred_instances: + pred_instances.keypoints_visible = \ + pred_instances.keypoint_scores + + if output_keypoint_indices is not None: + # select output keypoints with given indices + num_keypoints = pred_instances.keypoints.shape[1] + for key, value in pred_instances.all_items(): + if key.startswith('keypoint'): + pred_instances.set_field( + value[:, output_keypoint_indices], key) + + # add bbox information into pred_instances + pred_instances.bboxes = gt_instances.bboxes + pred_instances.bbox_scores = gt_instances.bbox_scores + + data_sample.pred_instances = pred_instances + + if pred_fields is not None: + if output_keypoint_indices is not None: + # select output heatmap channels with keypoint indices + # when the number of heatmap channel matches num_keypoints + for key, value in pred_fields.all_items(): + if value.shape[0] != num_keypoints: + continue + pred_fields.set_field(value[output_keypoint_indices], + key) + data_sample.pred_fields = pred_fields + + return batch_data_samples diff --git a/projects/rtmpose3d/rtmpose3d/rtmw3d_head.py b/projects/rtmpose3d/rtmpose3d/rtmw3d_head.py new file mode 100644 index 0000000000..bbf6bd2b48 --- /dev/null +++ b/projects/rtmpose3d/rtmpose3d/rtmw3d_head.py @@ -0,0 +1,444 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from mmcv.cnn import ConvModule +from mmengine.structures import InstanceData +from torch import Tensor, nn + +from mmpose.codecs.utils import get_simcc_maximum as get_2d_simcc_maximum +from mmpose.evaluation.functional import keypoint_mpjpe +from mmpose.models.utils.rtmcc_block import RTMCCBlock, ScaleNorm +from mmpose.registry import KEYPOINT_CODECS, MODELS +from mmpose.utils.tensor_utils import to_numpy +from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType, + OptSampleList) +from mmpose.models.heads import BaseHead +from .utils import get_simcc_maximum + +OptIntSeq = Optional[Sequence[int]] + + +@MODELS.register_module() +class RTMW3DHead(BaseHead): + """Top-down head introduced in RTMPose-Wholebody (2023). + + Args: + in_channels (int | sequence[int]): Number of channels in the input + feature map. + out_channels (int): Number of channels in the output heatmap. + input_size (tuple): Size of input image in shape [w, h]. + in_featuremap_size (int | sequence[int]): Size of input feature map. + simcc_split_ratio (float): Split ratio of pixels. + Default: 2.0. + final_layer_kernel_size (int): Kernel size of the convolutional layer. + Default: 1. + gau_cfg (Config): Config dict for the Gated Attention Unit. + Default: dict( + hidden_dims=256, + s=128, + expansion_factor=2, + dropout_rate=0., + drop_path=0., + act_fn='ReLU', + use_rel_bias=False, + pos_enc=False). + loss (Config): Config of the keypoint loss. Defaults to use + :class:`KLDiscretLoss` + decoder (Config, optional): The decoder config that controls decoding + keypoint coordinates from the network output. Defaults to ``None`` + init_cfg (Config, optional): Config to control the initialization. 
See + :attr:`default_init_cfg` for default settings + """ + + def __init__( + self, + in_channels: Union[int, Sequence[int]], + out_channels: int, + input_size: Tuple[int, int], + in_featuremap_size: Tuple[int, int], + simcc_split_ratio: float = 2.0, + final_layer_kernel_size: int = 1, + gau_cfg: ConfigType = dict( + hidden_dims=256, + s=128, + expansion_factor=2, + dropout_rate=0., + drop_path=0., + act_fn='ReLU', + use_rel_bias=False, + pos_enc=False), + loss: ConfigType = dict(type='KLDiscretLoss', use_target_weight=True), + decoder: OptConfigType = None, + init_cfg: OptConfigType = None, + ): + + if init_cfg is None: + init_cfg = self.default_init_cfg + + super().__init__(init_cfg) + + self.in_channels = in_channels + self.out_channels = out_channels + self.input_size = input_size + self.in_featuremap_size = in_featuremap_size + self.simcc_split_ratio = simcc_split_ratio + + self.loss_module = nn.ModuleList() + if isinstance(loss, dict): + self.loss_module.append(MODELS.build(loss)) + elif isinstance(loss, (list, tuple)): + for cfg in loss: + self.loss_module.append(MODELS.build(cfg)) + else: + raise TypeError(f'loss_decode must be a dict or sequence of dict,\ + but got {type(loss)}') + + if decoder is not None: + self.decoder = KEYPOINT_CODECS.build(decoder) + else: + self.decoder = None + + if isinstance(in_channels, (tuple, list)): + raise ValueError( + f'{self.__class__.__name__} does not support selecting ' + 'multiple input features.') + + # Define SimCC layers + flatten_dims = self.in_featuremap_size[0] * self.in_featuremap_size[1] + + ps = 2 + self.ps = nn.PixelShuffle(ps) + self.conv_dec = ConvModule( + in_channels // ps**2, + in_channels // 4, + kernel_size=final_layer_kernel_size, + stride=1, + padding=final_layer_kernel_size // 2, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU')) + + self.final_layer = ConvModule( + in_channels, + out_channels, + kernel_size=final_layer_kernel_size, + stride=1, + padding=final_layer_kernel_size // 2, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU')) + self.final_layer2 = ConvModule( + in_channels // ps + in_channels // 4, + out_channels, + kernel_size=final_layer_kernel_size, + stride=1, + padding=final_layer_kernel_size // 2, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU')) + + self.mlp = nn.Sequential( + ScaleNorm(flatten_dims), + nn.Linear(flatten_dims, gau_cfg['hidden_dims'] // 2, bias=False)) + + self.mlp2 = nn.Sequential( + ScaleNorm(flatten_dims * ps**2), + nn.Linear( + flatten_dims * ps**2, gau_cfg['hidden_dims'] // 2, bias=False)) + + W = int(self.input_size[0] * self.simcc_split_ratio) + H = int(self.input_size[1] * self.simcc_split_ratio) + D = int(self.input_size[2] * self.simcc_split_ratio) + + self.gau = RTMCCBlock( + self.out_channels, + gau_cfg['hidden_dims'], + gau_cfg['hidden_dims'], + s=gau_cfg['s'], + expansion_factor=gau_cfg['expansion_factor'], + dropout_rate=gau_cfg['dropout_rate'], + drop_path=gau_cfg['drop_path'], + attn_type='self-attn', + act_fn=gau_cfg['act_fn'], + use_rel_bias=gau_cfg['use_rel_bias'], + pos_enc=gau_cfg['pos_enc']) + + self.cls_x = nn.Linear(gau_cfg['hidden_dims'], W, bias=False) + self.cls_y = nn.Linear(gau_cfg['hidden_dims'], H, bias=False) + self.cls_z = nn.Linear(gau_cfg['hidden_dims'], D, bias=False) + + def forward(self, feats: Tuple[Tensor, + Tensor]) -> Tuple[Tensor, Tensor, Tensor]: + """Forward the network. + + The input is the feature map extracted by backbone and the + output is the simcc representation. 
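+
+        A rough sketch of the computation below: the top-level feature
+        ``enc_t`` is passed through ``final_layer``, flattened and projected
+        by ``mlp``; a pixel-shuffled copy of ``enc_t`` is concatenated with
+        the lower-level feature ``enc_b``, passed through ``final_layer2``,
+        flattened and projected by ``mlp2``; the two token sets are
+        concatenated, refined by the gated attention (GAU) block and mapped
+        by ``cls_x``/``cls_y``/``cls_z`` to the x, y and z SimCC vectors.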
+ + Args: + feats (Tuple[Tensor]): Multi scale feature maps. + + Returns: + pred_x (Tensor): 1d representation of x. + pred_y (Tensor): 1d representation of y. + """ + # enc_b n / 2, h, w + # enc_t n, h, w + enc_b, enc_t = feats + + feats_t = self.final_layer(enc_t) + feats_t = torch.flatten(feats_t, 2) + feats_t = self.mlp(feats_t) + + dec_t = self.ps(enc_t) + dec_t = self.conv_dec(dec_t) + enc_b = torch.cat([dec_t, enc_b], dim=1) + + feats_b = self.final_layer2(enc_b) + feats_b = torch.flatten(feats_b, 2) + feats_b = self.mlp2(feats_b) + + feats = torch.cat([feats_t, feats_b], dim=2) + + feats = self.gau(feats) + + pred_x = self.cls_x(feats) + pred_y = self.cls_y(feats) + pred_z = self.cls_z(feats) + + return pred_x, pred_y, pred_z + + def decode(self, batch_outputs: Union[Tensor, + Tuple[Tensor]]) -> InstanceList: + """Decode keypoints from outputs. + + Args: + batch_outputs (Tensor | Tuple[Tensor]): The network outputs of + a data batch + + Returns: + List[InstanceData]: A list of InstanceData, each contains the + decoded pose information of the instances of one data sample. + """ + + def _pack_and_call(args, func): + if not isinstance(args, tuple): + args = (args, ) + return func(*args) + + if self.decoder is None: + raise RuntimeError( + f'The decoder has not been set in {self.__class__.__name__}. ' + 'Please set the decoder configs in the init parameters to ' + 'enable head methods `head.predict()` and `head.decode()`') + + batch_output_np = to_numpy(batch_outputs, unzip=True) + batch_keypoints = [] + batch_keypoints2d = [] + batch_keypoints_simcc = [] + batch_scores = [] + for outputs in batch_output_np: + keypoints_2d, keypoints, keypoints_simcc, scores = _pack_and_call( + outputs, self.decoder.decode) + batch_keypoints2d.append(keypoints_2d) + batch_keypoints.append(keypoints) + batch_keypoints_simcc.append(keypoints_simcc) + batch_scores.append(scores) + + preds = [] + for keypoints_2d, keypoints, keypoints_simcc, scores in zip(batch_keypoints2d, + batch_keypoints, + batch_keypoints_simcc, + batch_scores): + pred = InstanceData( + keypoints_2d=keypoints_2d, + keypoints=keypoints, + keypoints_simcc=keypoints_simcc, + keypoint_scores=scores) + preds.append(pred) + + return preds + + def predict( + self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + test_cfg: OptConfigType = {}, + ) -> InstanceList: + """Predict results from features. + + Args: + feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage + features (or multiple multi-stage features in TTA) + batch_data_samples (List[:obj:`PoseDataSample`]): The batch + data samples + test_cfg (dict): The runtime config for testing process. 
Defaults + to {} + + Returns: + List[InstanceData]: The pose predictions, each contains + the following fields: + - keypoints (np.ndarray): predicted keypoint coordinates in + shape (num_instances, K, D) where K is the keypoint number + and D is the keypoint dimension + - keypoint_scores (np.ndarray): predicted keypoint scores in + shape (num_instances, K) + - keypoint_x_labels (np.ndarray, optional): The predicted 1-D + intensity distribution in the x direction + - keypoint_y_labels (np.ndarray, optional): The predicted 1-D + intensity distribution in the y direction + """ + x, y, z = self.forward(feats) + + preds = self.decode((x, y, z)) + + if test_cfg.get('output_heatmaps', False): + raise NotImplementedError + else: + return preds + + def loss( + self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + train_cfg: OptConfigType = {}, + ) -> dict: + """Calculate losses from a batch of inputs and data samples.""" + + pred_x, pred_y, pred_z = self.forward(feats) + + gt_x = torch.cat([ + d.gt_instance_labels.keypoint_x_labels for d in batch_data_samples + ], + dim=0) + gt_y = torch.cat([ + d.gt_instance_labels.keypoint_y_labels for d in batch_data_samples + ], + dim=0) + gt_z = torch.cat([ + d.gt_instance_labels.keypoint_z_labels for d in batch_data_samples + ], + dim=0) + keypoint_weights = torch.cat( + [ + d.gt_instance_labels.keypoint_weights + for d in batch_data_samples + ], + dim=0, + ) + + weight_z = torch.cat( + [d.gt_instance_labels.weight_z for d in batch_data_samples], + dim=0, + ) + + with_z_labels = [ + d.gt_instance_labels.with_z_label[0] for d in batch_data_samples + ] + + N, K, _ = pred_x.shape + keypoint_weights_ = keypoint_weights.clone() + pred_simcc = (pred_x, pred_y, pred_z) + gt_simcc = (gt_x, gt_y, gt_z) + + keypoint_weights = torch.cat([ + keypoint_weights[None, ...], keypoint_weights[None, ...], + weight_z[None, ...] + ]) + + # calculate losses + losses = dict() + for i, loss_ in enumerate(self.loss_module): + if loss_.loss_name == 'loss_bone' or loss_.loss_name == 'loss_mpjpe': + pred_coords = get_3d_coord(pred_x, pred_y, pred_z, + with_z_labels) + gt_coords = get_3d_coord(gt_x, gt_y, gt_z, with_z_labels) + loss = loss_(pred_coords, gt_coords, keypoint_weights_) + else: + loss = loss_(pred_simcc, gt_simcc, keypoint_weights) + losses[loss_.loss_name] = loss + + # calculate accuracy + error = simcc_mpjpe( + output=to_numpy(pred_simcc), + target=to_numpy(gt_simcc), + simcc_split_ratio=self.simcc_split_ratio, + mask=to_numpy(keypoint_weights_) > 0, + ) + + mpjpe = torch.tensor(error, device=gt_x.device) + losses.update(mpjpe=mpjpe) + + return losses + + @property + def default_init_cfg(self): + init_cfg = [ + dict(type='Normal', layer=['Conv2d'], std=0.001), + dict(type='Constant', layer='BatchNorm2d', val=1), + dict(type='Normal', layer=['Linear'], std=0.01, bias=0), + ] + return init_cfg + + +def simcc_mpjpe(output: Tuple[np.ndarray, np.ndarray, np.ndarray], + target: Tuple[np.ndarray, np.ndarray, np.ndarray], + simcc_split_ratio: float, + mask: np.ndarray, + thr: float = 0.05) -> float: + """Calculate the pose accuracy of PCK for each individual keypoint and the + averaged accuracy across all keypoints from 3D SimCC. + + Note: + - PCK metric measures accuracy of the localization of the body joints. + - The distances between predicted positions and the ground-truth ones + are typically normalized by the bounding box size. + + Args: + output (Tuple[np.ndarray, np.ndarray, np.ndarray]): Model predicted + 3D SimCC (x, y, z). 
+ target (Tuple[np.ndarray, np.ndarray, np.ndarray]): Groundtruth + 3D SimCC (x, y, z). + simcc_split_ratio (float): SimCC split ratio for recovering actual + coordinates. + mask (np.ndarray[N, K]): Visibility mask for the target. False for + invisible joints, and True for visible. + thr (float): Threshold for PCK calculation. Default 0.05. + normalize (Optional[np.ndarray[N, 3]]): Normalization factor for + H, W, and Depth. + + Returns: + Tuple[np.ndarray, float, int]: + - np.ndarray[K]: Accuracy of each keypoint. + - float: Averaged accuracy across all keypoints. + - int: Number of valid keypoints. + """ + if len(output) == 3: + pred_x, pred_y, pred_z = output + gt_x, gt_y, gt_z = target + pred_coords, _ = get_simcc_maximum(pred_x, pred_y, pred_z) + gt_coords, _ = get_simcc_maximum(gt_x, gt_y, gt_z) + + else: + pred_x, pred_y = output + gt_x, gt_y = target + pred_coords, _ = get_2d_simcc_maximum(pred_x, pred_y) + gt_coords, _ = get_2d_simcc_maximum(gt_x, gt_y) + + pred_coords /= simcc_split_ratio + gt_coords /= simcc_split_ratio + + return keypoint_mpjpe(pred_coords, gt_coords, mask) + + +def get_3d_coord(simcc_x, simcc_y, simcc_z, with_z_labels): + N, K, W = simcc_x.shape + # 过滤 z 轴 + for i, with_z in enumerate(with_z_labels): + if not with_z: + simcc_z[i] = torch.zeros_like(simcc_z[i]) + x_locs = simcc_x.reshape(N * K, -1).argmax(dim=1) + y_locs = simcc_y.reshape(N * K, -1).argmax(dim=1) + z_locs = simcc_z.reshape(N * K, -1).argmax(dim=1) + + locs = torch.stack((x_locs, y_locs, z_locs), + dim=-1).to(simcc_x).reshape(N, K, 3) + return locs diff --git a/projects/rtmpose3d/rtmpose3d/simcc_3d_label.py b/projects/rtmpose3d/rtmpose3d/simcc_3d_label.py new file mode 100644 index 0000000000..4440caa667 --- /dev/null +++ b/projects/rtmpose3d/rtmpose3d/simcc_3d_label.py @@ -0,0 +1,335 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import product +from typing import Any, Optional, Tuple, Union + +import numpy as np +from numpy import ndarray + +from mmpose.registry import KEYPOINT_CODECS +from mmpose.codecs.base import BaseKeypointCodec + +from .utils import get_simcc_maximum + + +@KEYPOINT_CODECS.register_module() +class SimCC3DLabel(BaseKeypointCodec): + r"""Generate keypoint representation via "SimCC" approach. + See the paper: `SimCC: a Simple Coordinate Classification Perspective for + Human Pose Estimation`_ by Li et al (2022) for more details. + Old name: SimDR + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - image size: [w, h] + + Encoded: + + - keypoint_x_labels (np.ndarray): The generated SimCC label for x-axis. + The label shape is (N, K, Wx) if ``smoothing_type=='gaussian'`` + and (N, K) if `smoothing_type=='standard'``, where + :math:`Wx=w*simcc_split_ratio` + - keypoint_y_labels (np.ndarray): The generated SimCC label for y-axis. + The label shape is (N, K, Wy) if ``smoothing_type=='gaussian'`` + and (N, K) if `smoothing_type=='standard'``, where + :math:`Wy=h*simcc_split_ratio` + - keypoint_weights (np.ndarray): The target weights in shape (N, K) + + Args: + input_size (tuple): Input image size in [w, h] + smoothing_type (str): The SimCC label smoothing strategy. Options are + ``'gaussian'`` and ``'standard'``. Defaults to ``'gaussian'`` + sigma (float | int | tuple): The sigma value in the Gaussian SimCC + label. Defaults to 6.0 + simcc_split_ratio (float): The ratio of the label size to the input + size. For example, if the input width is ``w``, the x label size + will be :math:`w*simcc_split_ratio`. 
Defaults to 2.0 + label_smooth_weight (float): Label Smoothing weight. Defaults to 0.0 + normalize (bool): Whether to normalize the heatmaps. Defaults to True. + use_dark (bool): Whether to use the DARK post processing. Defaults to + False. + decode_visibility (bool): Whether to decode the visibility. Defaults + to False. + decode_beta (float): The beta value for decoding visibility. Defaults + to 150.0. + + .. _`SimCC: a Simple Coordinate Classification Perspective for Human Pose + Estimation`: https://arxiv.org/abs/2107.03332 + """ + + auxiliary_encode_keys = {'keypoints_3d'} + + label_mapping_table = dict( + keypoint_x_labels='keypoint_x_labels', + keypoint_y_labels='keypoint_y_labels', + keypoint_z_labels='keypoint_z_labels', + keypoint_weights='keypoint_weights', + weight_z='weight_z', + with_z_label='with_z_label') + + instance_mapping_table = dict( + bbox='bboxes', + bbox_score='bbox_scores', + bbox_scale='bbox_scales', + lifting_target='lifting_target', + lifting_target_visible='lifting_target_visible', + camera_param='camera_params', + root_z='root_z') + + def __init__(self, + input_size: Tuple[int, int, int], + smoothing_type: str = 'gaussian', + sigma: Union[float, int, Tuple[float]] = 6.0, + simcc_split_ratio: float = 2.0, + label_smooth_weight: float = 0.0, + normalize: bool = True, + use_dark: bool = False, + decode_visibility: bool = False, + decode_beta: float = 150.0, + root_index: Union[int, Tuple[int]] = 0, + z_range: Optional[int] = None, + sigmoid_z: bool = False) -> None: + super().__init__() + + self.input_size = input_size + self.smoothing_type = smoothing_type + self.simcc_split_ratio = simcc_split_ratio + self.label_smooth_weight = label_smooth_weight + self.normalize = normalize + self.use_dark = use_dark + self.decode_visibility = decode_visibility + self.decode_beta = decode_beta + + if isinstance(sigma, (float, int)): + self.sigma = np.array([sigma, sigma, sigma]) + else: + self.sigma = np.array(sigma) + + if self.smoothing_type not in {'gaussian', 'standard'}: + raise ValueError( + f'{self.__class__.__name__} got invalid `smoothing_type` value' + f'{self.smoothing_type}. 
Should be one of ' + '{"gaussian", "standard"}') + + if self.smoothing_type == 'gaussian' and self.label_smooth_weight > 0: + raise ValueError('Attribute `label_smooth_weight` is only ' + 'used for `standard` mode.') + + if self.label_smooth_weight < 0.0 or self.label_smooth_weight > 1.0: + raise ValueError('`label_smooth_weight` should be in range [0, 1]') + + self.root_index = list(root_index) if isinstance( + root_index, tuple) else [root_index] + self.z_range = z_range if z_range is not None else 2.1744869 + self.sigmoid_z = sigmoid_z + self.root_z = [5.14388] + + def encode(self, + keypoints: np.ndarray, + keypoints_3d: Optional[np.ndarray] = None, + keypoints_visible: Optional[np.ndarray] = None) -> dict: + + if keypoints_visible is None: + keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) + lifting_target = [None] + root_z = self.root_z + with_z_label = False + if keypoints_3d is not None: + lifting_target = keypoints_3d.copy() + root_z = keypoints_3d[..., self.root_index, 2].mean(1) + keypoints_3d[..., 2] -= root_z + if self.sigmoid_z: + keypoints_z = (1 / (1 + np.exp(-(3 * keypoints_3d[..., 2]))) + ) * self.input_size[2] + else: + keypoints_z = (keypoints_3d[..., 2] / self.z_range + 1) * ( + self.input_size[2] / 2) + + keypoints_3d = np.concatenate([keypoints, keypoints_z[..., None]], + axis=-1) + x, y, z, keypoint_weights = self._generate_gaussian( + keypoints_3d, keypoints_visible) + weight_z = keypoint_weights + with_z_label = True + else: + if keypoints.shape != np.zeros([]).shape: + keypoints_z = np.ones((keypoints.shape[0], + keypoints.shape[1], 1), dtype=np.float32) + keypoints = np.concatenate([keypoints, keypoints_z], axis=-1) + x, y, z, keypoint_weights = self._generate_gaussian( + keypoints, keypoints_visible) + else: + x, y, z = np.zeros((3, 1), dtype=np.float32) + keypoint_weights = np.ones((1, )) + weight_z = np.zeros_like(keypoint_weights) + with_z_label = False + + encoded = dict( + keypoint_x_labels=x, + keypoint_y_labels=y, + keypoint_z_labels=z, + lifting_target=lifting_target, + root_z=root_z, + keypoint_weights=keypoint_weights, + weight_z=weight_z, + with_z_label=[with_z_label]) + + return encoded + + def decode(self, x: np.ndarray, y: np.ndarray, z: np.ndarray): + """Decode SimCC labels into 3D keypoints. + + Args: + encoded (Tuple[np.ndarray, np.ndarray]): SimCC labels for x-axis, + y-axis and z-axis in shape (N, K, Wx), (N, K, Wy) and (N, K, Wz) + + Returns: + tuple: + - keypoints (np.ndarray): Decoded coordinates in shape (N, K, D) + - scores (np.ndarray): The keypoint scores in shape (N, K). 
+ It usually represents the confidence of the keypoint prediction + """ + + keypoints, scores = get_simcc_maximum(x, y, z) + + # Unsqueeze the instance dimension for single-instance results + if keypoints.ndim == 2: + keypoints = keypoints[None, :] + scores = scores[None, :] + + keypoints /= self.simcc_split_ratio + keypoints_simcc = keypoints.copy() + keypoints_2d = keypoints[..., :2] + keypoints_z = keypoints[..., 2:3] + if self.sigmoid_z: + keypoints_z /= self.input_size[2] + keypoints_z[keypoints_z <= 0] = 1e-8 + scores[(keypoints_z <= 0).squeeze(-1)] = 0 + keypoints[..., 2:3] = np.log(keypoints_z / (1 - keypoints_z)) / 3 + else: + keypoints[..., + 2:3] = (keypoints_z / + (self.input_size[-1] / 2) - 1) * self.z_range + return keypoints_2d, keypoints, keypoints_simcc, scores + + def _map_coordinates( + self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None + ) -> Tuple[np.ndarray, np.ndarray]: + """Mapping keypoint coordinates into SimCC space.""" + + keypoints_split = keypoints.copy() + keypoints_split = np.around(keypoints_split * self.simcc_split_ratio) + keypoints_split = keypoints_split.astype(np.int64) + keypoint_weights = keypoints_visible.copy() + + return keypoints_split, keypoint_weights + + def _generate_gaussian( + self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None + ) -> tuple[ndarray, ndarray, ndarray, ndarray]: + """Encoding keypoints into SimCC labels with Gaussian Label Smoothing + strategy.""" + + N, K, _ = keypoints.shape + w, h, d = self.input_size + W = np.around(w * self.simcc_split_ratio).astype(int) + H = np.around(h * self.simcc_split_ratio).astype(int) + D = np.around(d * self.simcc_split_ratio).astype(int) + + keypoints_split, keypoint_weights = self._map_coordinates( + keypoints, keypoints_visible) + + target_x = np.zeros((N, K, W), dtype=np.float32) + target_y = np.zeros((N, K, H), dtype=np.float32) + target_z = np.zeros((N, K, D), dtype=np.float32) + + # 3-sigma rule + radius = self.sigma * 3 + + # xy grid + x = np.arange(0, W, 1, dtype=np.float32) + y = np.arange(0, H, 1, dtype=np.float32) + z = np.arange(0, D, 1, dtype=np.float32) + + for n, k in product(range(N), range(K)): + # skip unlabled keypoints + if keypoints_visible[n, k] < 0.5: + continue + + mu = keypoints_split[n, k] + + # check that the gaussian has in-bounds part + left, top, near = mu - radius + right, bottom, far = mu + radius + 1 + + if left >= W or top >= H or near >= D or right < 0 or bottom < 0 or far < 0: # noqa: E501 + keypoint_weights[n, k] = 0 + continue + + mu_x, mu_y, mu_z = mu + + target_x[n, k] = np.exp(-((x - mu_x)**2) / (2 * self.sigma[0]**2)) + target_y[n, k] = np.exp(-((y - mu_y)**2) / (2 * self.sigma[1]**2)) + target_z[n, k] = np.exp(-((z - mu_z)**2) / (2 * self.sigma[2]**2)) + + if self.normalize: + norm_value = self.sigma * np.sqrt(np.pi * 2) + target_x /= norm_value[0] + target_y /= norm_value[1] + target_z /= norm_value[2] + + return target_x, target_y, target_z, keypoint_weights + + def _generate_standard( + self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None + ) -> tuple[ndarray, ndarray, ndarray, Any]: + """Encoding keypoints into SimCC labels with Standard Label Smoothing + strategy. 
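+
+        For each visible keypoint, the target bin on every axis is assigned
+        ``1 - label_smooth_weight`` while the remaining bins share
+        ``label_smooth_weight`` uniformly (a plain reading of the assignment
+        below).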
+ + Labels will be one-hot vectors if self.label_smooth_weight==0.0 + """ + + N, K, _ = keypoints.shape + w, h, d = self.input_size + w = np.around(w * self.simcc_split_ratio).astype(int) + h = np.around(h * self.simcc_split_ratio).astype(int) + d = np.around(d * self.simcc_split_ratio).astype(int) + + keypoints_split, keypoint_weights = self._map_coordinates( + keypoints, keypoints_visible) + + x = np.zeros((N, K, w), dtype=np.float32) + y = np.zeros((N, K, h), dtype=np.float32) + z = np.zeros((N, K, d), dtype=np.float32) + + for n, k in product(range(N), range(K)): + # skip unlabled keypoints + if keypoints_visible[n, k] < 0.5: + continue + + # get center coordinates + mu_x, mu_y, mu_z = keypoints_split[n, k].astype(np.int64) + + # detect abnormal coords and assign the weight 0 + if mu_x >= w or mu_y >= h or mu_x < 0 or mu_y < 0: + keypoint_weights[n, k] = 0 + continue + + if self.label_smooth_weight > 0: + x[n, k] = self.label_smooth_weight / (w - 1) + y[n, k] = self.label_smooth_weight / (h - 1) + z[n, k] = self.label_smooth_weight / (d - 1) + + x[n, k, mu_x] = 1.0 - self.label_smooth_weight + y[n, k, mu_y] = 1.0 - self.label_smooth_weight + z[n, k, mu_z] = 1.0 - self.label_smooth_weight + + return x, y, z, keypoint_weights diff --git a/projects/rtmpose3d/rtmpose3d/utils.py b/projects/rtmpose3d/rtmpose3d/utils.py new file mode 100644 index 0000000000..8dab90de20 --- /dev/null +++ b/projects/rtmpose3d/rtmpose3d/utils.py @@ -0,0 +1,76 @@ +from typing import Tuple + +import numpy as np + + +def get_simcc_maximum(simcc_x: np.ndarray, + simcc_y: np.ndarray, + simcc_z: np.ndarray, + apply_softmax: bool = False + ) -> Tuple[np.ndarray, np.ndarray]: + """Get maximum response location and value from simcc representations. + + Note: + instance number: N + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + encoded_keypoints (dict): encoded keypoints with simcc representations. + apply_softmax (bool): whether to apply softmax on the heatmap. + Defaults to False. 
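+
+    The decoded location of each keypoint is the per-axis argmax of the three
+    vectors; the returned confidence follows the 2D SimCC convention and is
+    ``min(max(simcc_x), max(simcc_y))``, i.e. the z response does not affect
+    the score (see the implementation below).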
+ + Returns: + tuple: + - locs (np.ndarray): locations of maximum heatmap responses in shape + (K, 2) or (N, K, 2) + - vals (np.ndarray): values of maximum heatmap responses in shape + (K,) or (N, K) + """ + assert isinstance(simcc_x, np.ndarray), 'simcc_x should be numpy.ndarray' + assert isinstance(simcc_y, np.ndarray), 'simcc_y should be numpy.ndarray' + assert isinstance(simcc_z, np.ndarray), 'simcc_z should be numpy.ndarray' + assert simcc_x.ndim == 2 or simcc_x.ndim == 3, ( + f'Invalid shape {simcc_x.shape}') + assert simcc_y.ndim == 2 or simcc_y.ndim == 3, ( + f'Invalid shape {simcc_y.shape}') + assert simcc_z.ndim == 2 or simcc_z.ndim == 3, ( + f'Invalid shape {simcc_z.shape}') + assert simcc_x.ndim == simcc_y.ndim == simcc_z.ndim, ( + f'{simcc_x.shape} != {simcc_y.shape} or {simcc_z.shape}') + + if simcc_x.ndim == 3: + n, k, _ = simcc_x.shape + simcc_x = simcc_x.reshape(n * k, -1) + simcc_y = simcc_y.reshape(n * k, -1) + simcc_z = simcc_z.reshape(n * k, -1) + else: + n = None + + if apply_softmax: + simcc_x = simcc_x - np.max(simcc_x, axis=1, keepdims=True) + simcc_y = simcc_y - np.max(simcc_y, axis=1, keepdims=True) + simcc_z = simcc_z - np.max(simcc_z, axis=1, keepdims=True) + ex, ey, ez = np.exp(simcc_x), np.exp(simcc_y), np.exp(simcc_z) + simcc_x = ex / np.sum(ex, axis=1, keepdims=True) + simcc_y = ey / np.sum(ey, axis=1, keepdims=True) + simcc_z = ez / np.sum(ez, axis=1, keepdims=True) + + x_locs = np.argmax(simcc_x, axis=1) + y_locs = np.argmax(simcc_y, axis=1) + z_locs = np.argmax(simcc_z, axis=1) + locs = np.stack((x_locs, y_locs, z_locs), axis=-1).astype(np.float32) + max_val_x = np.amax(simcc_x, axis=1) + max_val_y = np.amax(simcc_y, axis=1) + + mask = max_val_x > max_val_y + max_val_x[mask] = max_val_y[mask] + vals = max_val_x + locs[vals <= 0.] = -1 + + if n is not None: + locs = locs.reshape(n, k, 3) + vals = vals.reshape(n, k) + + return locs, vals \ No newline at end of file