diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 665bdb73bf..1bf3f31c66 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -160,6 +160,7 @@ def __init__(self, *args, **kwargs): self.image_key = kwargs.get("image_key", "images") self.audio_key = kwargs.get("audio_key", "audios") self.video_key = kwargs.get("video_key", "videos") + self.lidar_key = kwargs.get("lidar_key", "lidar") self.query_key = kwargs.get("query_key", "query") self.response_key = kwargs.get("response_key", "response") diff --git a/data_juicer/ops/mapper/lidar_detection_mapper.py b/data_juicer/ops/mapper/lidar_detection_mapper.py new file mode 100644 index 0000000000..c80c93b069 --- /dev/null +++ b/data_juicer/ops/mapper/lidar_detection_mapper.py @@ -0,0 +1,92 @@ +from typing import Dict + +from data_juicer.utils.lazy_loader import LazyLoader +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, Mapper + +mmdet3d = LazyLoader("mmdet3d") + +OP_NAME = "lidar_detection_mapper" + + +@OPERATORS.register_module(OP_NAME) +class LiDARDetectionMapper(Mapper): + """Mapper to detect ground truth from LiDAR data.""" + + _batched_op = True + _accelerator = "cuda" + + def __init__(self, model_name="centerpoint", *args, **kwargs): + """ + Initialization method. + + :param mode: + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + + self.model_name = model_name + + if self.model_name == "centerpoint": + self.deploy_cfg_path = "voxel-detection_onnxruntime_dynamic.py" + self.model_cfg_path = "centerpoint_pillar02_second_secfpn_head-circlenms_8xb4-cyclic-20e_nus-3d.py" + self.backend_files = [ + "centerpoint_02pillar_second_secfpn_circlenms_4x8_cyclic_20e_nus_20220811_031844-191a3822.onnx" + ] + else: + raise NotImplementedError(f'Only support "centerpoint" for now, but got {self.model_name}') + + self.model_key = prepare_model( + "mmlab", + model_cfg=self.model_cfg_path, + deploy_cfg=self.deploy_cfg_path, + backend_files=self.backend_files, + ) + + # Maybe should include model name, timestamp, filename, image info etc. + def pred2dict(self, data_sample: mmdet3d.structures.Det3DDataSample) -> Dict: + """Extract elements necessary to represent a prediction into a + dictionary. + + It's better to contain only basic data elements such as strings and + numbers in order to guarantee it's json-serializable. + + Args: + data_sample (:obj:`DetDataSample`): Predictions of the model. + + Returns: + dict: Prediction results. + """ + result = {} + if "pred_instances_3d" in data_sample: + pred_instances_3d = data_sample.pred_instances_3d.numpy() + result = { + "labels_3d": pred_instances_3d.labels_3d.tolist(), + "scores_3d": pred_instances_3d.scores_3d.tolist(), + "bboxes_3d": pred_instances_3d.bboxes_3d.tensor.cpu().tolist(), + } + + if "pred_pts_seg" in data_sample: + pred_pts_seg = data_sample.pred_pts_seg.numpy() + result["pts_semantic_mask"] = pred_pts_seg.pts_semantic_mask.tolist() + + if data_sample.box_mode_3d == mmdet3d.structures.Box3DMode.LIDAR: + result["box_type_3d"] = "LiDAR" + elif data_sample.box_mode_3d == mmdet3d.structures.Box3DMode.CAM: + result["box_type_3d"] = "Camera" + elif data_sample.box_mode_3d == mmdet3d.structures.Box3DMode.DEPTH: + result["box_type_3d"] = "Depth" + + return result + + def process_batched(self, samples, rank=None): + model = get_model(self.model_key, rank, self.use_cuda()) + lidars = samples[self.lidar_key] + + results = [model(lidar) for lidar in lidars] + results = [self.pred2dict(result[0]) for result in results] + samples["lidar_detections"] = results + + return samples diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index 138d6711ea..ee440b93a3 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -5,7 +5,7 @@ from contextlib import redirect_stderr from functools import partial from pickle import UnpicklingError -from typing import Optional, Union +from typing import List, Optional, Union import httpx import multiprocess as mp @@ -38,6 +38,7 @@ ultralytics = LazyLoader("ultralytics") tiktoken = LazyLoader("tiktoken") dashscope = LazyLoader("dashscope") +mmdeploy = LazyLoader("mmdeploy") MODEL_ZOO = {} @@ -942,6 +943,57 @@ def update_sampling_params(sampling_params, pretrained_model_name_or_path, enabl return sampling_params +class MMLabModel(object): + """ + A wrapper for mmdeploy model. + It is used to load a mmdeploy model and run inference on given images. + """ + + def __init__(self, model_cfg_path, deploy_cfg_path, backend_files, device): + self.model_cfg_path = model_cfg_path + self.deploy_cfg_path = deploy_cfg_path + self.backend_files = backend_files + self.device = device + + from mmdeploy.apis.utils import build_task_processor + from mmdeploy.utils import get_input_shape, load_config + + deploy_cfg, model_cfg = load_config(self.deploy_cfg_path, self.model_cfg_path) + self.task_processor = build_task_processor(model_cfg, deploy_cfg, self.device) + + self.model = self.task_processor.build_backend_model( + self.backend_files, data_preprocessor_updater=self.task_processor.update_data_preprocessor + ) + + self.input_shape = get_input_shape(deploy_cfg) + + def __call__(self, image): + model_inputs, _ = self.task_processor.create_input(image, self.input_shape) + + with torch.no_grad(): + result = self.model.test_step(model_inputs) + + return result + + +def prepare_mmlab_model(model_cfg: str, deploy_cfg: str, backend_files: List[str], device: str = "cpu"): + """Prepare and load a model using mmdeploy. + + :param model_cfg: Path to the model config. + :param deploy_cfg: Path to the deployment config. + :param backend_files: Path to the backend model files. + :param device: Device to use. + """ + model = MMLabModel( + check_model(model_cfg), + check_model(deploy_cfg), + [check_model(backend_file) for backend_file in backend_files], + device, + ) + + return model + + MODEL_FUNCTION_MAPPING = { "api": prepare_api_model, "diffusion": prepare_diffusion_model, @@ -960,6 +1012,7 @@ def update_sampling_params(sampling_params, pretrained_model_name_or_path, enabl "video_blip": prepare_video_blip_model, "vllm": prepare_vllm_model, "embedding": prepare_embedding_model, + "mmlab": prepare_mmlab_model, } _MODELS_WITHOUT_FILE_LOCK = {"fasttext", "fastsam", "kenlm", "nltk", "recognizeAnything", "sentencepiece", "spacy"} diff --git a/docs/Operators.md b/docs/Operators.md index e54c07bc02..1bc6ef9361 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -37,7 +37,7 @@ Data-Juicer 中的算子分为以下 7 种类型。 | [filter](#filter) | 49 | Filters out low-quality samples. 过滤低质量样本。 | | [formatter](#formatter) | 8 | Discovers, loads, and canonicalizes source data. 发现、加载、规范化原始数据。 | | [grouper](#grouper) | 3 | Group samples to batched samples. 将样本分组,每一组组成一个批量样本。 | -| [mapper](#mapper) | 81 | Edits and transforms samples. 对数据样本进行编辑和转换。 | +| [mapper](#mapper) | 82 | Edits and transforms samples. 对数据样本进行编辑和转换。 | | [selector](#selector) | 5 | Selects top samples based on ranking. 基于排序选取高质量样本。 | All the specific operators are listed below, each featured with several capability tags. @@ -200,6 +200,7 @@ All the specific operators are listed below, each featured with several capabili | image_tagging_mapper | 🏞Image 💻CPU 🟢Stable | Mapper to generate image tags. 映射器生成图像标签。 | [code](../data_juicer/ops/mapper/image_tagging_mapper.py) | [tests](../tests/ops/mapper/test_image_tagging_mapper.py) | | imgdiff_difference_area_generator_mapper | 💻CPU 🟡Beta | A fused operator for OPs that is used to run sequential OPs on the same batch to allow fine-grained control on data processing. OPs的融合操作符,用于在同一批次上运行顺序OPs,以实现对数据处理的细粒度控制。 | [code](../data_juicer/ops/mapper/imgdiff_difference_area_generator_mapper.py) | [tests](../tests/ops/mapper/test_imgdiff_difference_area_generator_mapper.py) | | imgdiff_difference_caption_generator_mapper | 💻CPU 🟡Beta | A fused operator for OPs that is used to run sequential OPs on the same batch to allow fine-grained control on data processing. OPs的融合操作符,用于在同一批次上运行顺序OPs,以实现对数据处理的细粒度控制。 | [code](../data_juicer/ops/mapper/imgdiff_difference_caption_generator_mapper.py) | [tests](../tests/ops/mapper/test_imgdiff_difference_caption_generator_mapper.py) | +| lidar_detection_mapper | 💻CPU 🔴Alpha | Mapper to detect ground truth from LiDAR data. 映射器从激光雷达数据中检测地面真相。 | [code](../data_juicer/ops/mapper/lidar_detection_mapper.py) | - | | mllm_mapper | 🔮Multimodal 💻CPU 🧩HF 🟢Stable | Mapper to use MLLMs for visual question answering tasks. Mapper使用MLLMs进行视觉问答任务。 | [code](../data_juicer/ops/mapper/mllm_mapper.py) | [tests](../tests/ops/mapper/test_mllm_mapper.py) | | nlpaug_en_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to simply augment samples in English based on nlpaug library. 映射器基于nlpaug库简单地增加英语样本。 | [code](../data_juicer/ops/mapper/nlpaug_en_mapper.py) | [tests](../tests/ops/mapper/test_nlpaug_en_mapper.py) | | nlpcda_zh_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to simply augment samples in Chinese based on nlpcda library. 基于nlpcda库的映射器可以简单地增加中文样本。 | [code](../data_juicer/ops/mapper/nlpcda_zh_mapper.py) | [tests](../tests/ops/mapper/test_nlpcda_zh_mapper.py) |