diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py
index 9a57e2ecef..0f4aaa9d48 100644
--- a/data_juicer/ops/base_op.py
+++ b/data_juicer/ops/base_op.py
@@ -164,6 +164,7 @@ def __init__(self, *args, **kwargs):
         self.image_key = kwargs.get("image_key", "images")
         self.audio_key = kwargs.get("audio_key", "audios")
         self.video_key = kwargs.get("video_key", "videos")
+        self.lidar_key = kwargs.get("lidar_key", "lidar")
 
         # extra mm bytes keys
         self.image_bytes_key = kwargs.get("image_bytes_key", "image_bytes")
diff --git a/data_juicer/ops/mapper/lidar_segmentation_mapper.py b/data_juicer/ops/mapper/lidar_segmentation_mapper.py
new file mode 100644
index 0000000000..ff376504eb
--- /dev/null
+++ b/data_juicer/ops/mapper/lidar_segmentation_mapper.py
@@ -0,0 +1,68 @@
+from data_juicer.utils.constant import Fields, MetaKeys
+from data_juicer.utils.lazy_loader import LazyLoader
+from data_juicer.utils.model_utils import get_model, prepare_model
+
+from ..base_op import OPERATORS, Mapper
+
+mmdet3d = LazyLoader("mmdet3d")
+
+OP_NAME = "lidar_segmentation_mapper"
+
+
+@OPERATORS.register_module(OP_NAME)
+class LiDARSegmentationMapper(Mapper):
+    """Mapper to perform semantic segmentation on LiDAR point cloud data."""
+
+    _batched_op = True
+    _accelerator = "cuda"
+
+    def __init__(
+        self,
+        model_name="cylinder3d",
+        model_cfg_name="",
+        model_path="",
+        tag_field_name: str = MetaKeys.lidar_segmentation_tags,
+        *args,
+        **kwargs,
+    ):
+        """
+        Initialization method.
+        :param model_name: Name of the model used.
+        :param model_cfg_name: The config name of the model used.
+        :param model_path: Path to the model weights.
+        :param tag_field_name: The field name to store the tags. It's
+            "lidar_segmentation_tags" by default.
+        :param args: extra args
+        :param kwargs: extra args
+        """
+        super().__init__(*args, **kwargs)
+
+        self.model_name = model_name
+
+        if self.model_name == "cylinder3d":
+            self.model_cfg_name = model_cfg_name
+            self.model_path = model_path
+        else:
+            raise NotImplementedError(f'Only support "cylinder3d" for now, but got {self.model_name}')
+
+        self.model_key = prepare_model(
+            "mmlab",
+            model_cfg=self.model_cfg_name,
+            model_path=self.model_path,
+            task="LiDARSegmentation",
+            model_name=self.model_name,
+        )
+        self.tag_field_name = tag_field_name
+
+    def process_single(self, sample, rank=None):
+
+        # check if it's generated already
+        if self.tag_field_name in sample[Fields.meta]:
+            return sample
+
+        model = get_model(self.model_key, rank, self.use_cuda())
+
+        results = model(dict(points=sample[self.lidar_key]))
+        sample[Fields.meta][self.tag_field_name] = results[0]
+
+        return sample
diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py
index b32e82784d..f0f4027e63 100644
--- a/data_juicer/utils/constant.py
+++ b/data_juicer/utils/constant.py
@@ -80,6 +80,8 @@ class MetaKeys(object):
     class_label_tag = DEFAULT_PREFIX + "class_label__"
     # # 2D whole-body pose estimation
     pose_estimation_tags = "pose_estimation_tags"
+    # # lidar segmentation tags
+    lidar_segmentation_tags = "lidar_segmentation_tags"
 
     # === info extraction related tags ===
     # # for event extraction
diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py
index 3a9369a423..a18aa9e246 100644
--- a/data_juicer/utils/model_utils.py
+++ b/data_juicer/utils/model_utils.py
@@ -5,7 +5,7 @@
 from contextlib import redirect_stderr
 from functools import partial
 from pickle import UnpicklingError
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 import httpx
 import multiprocess as mp
@@ -40,6 +40,8 @@
 ultralytics = LazyLoader("ultralytics")
 tiktoken = LazyLoader("tiktoken")
 dashscope = LazyLoader("dashscope")
+mmdeploy = LazyLoader("mmdeploy")
+mmdet3d = LazyLoader("mmdet3d")
 qwen_vl_utils = LazyLoader("qwen_vl_utils", "qwen-vl-utils")
 transformers_stream_generator = LazyLoader(
     "transformers_stream_generator", "git+https://github.com/HYLcool/transformers-stream-generator.git"
@@ -76,6 +78,8 @@
     # DWPose
     "dwpose_onnx_det_model": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx",
     "dwpose_onnx_pose_model": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx",
+    # Cylinder3d
+    "cylinder3d": "https://download.openmmlab.com/mmdetection3d/v1.1.0_models/cylinder3d/cylinder3d_8xb2-amp-laser-polar-mix-3x_semantickitti_20230425_144950-372cdf69.pth",
 }
@@ -1237,6 +1241,136 @@ def update_sampling_params(sampling_params, pretrained_model_name_or_path, enabl
     return sampling_params
 
 
+class MMLabModel(object):
+    """
+    A wrapper for an mmdeploy model.
+    It is used to load an mmdeploy backend model and run inference on the given inputs.
+    """
+
+    def __init__(self, model_cfg_path, deploy_cfg_path, backend_files, device):
+        self.model_cfg_path = model_cfg_path
+        self.deploy_cfg_path = deploy_cfg_path
+        self.backend_files = backend_files
+        self.device = device
+
+        from mmdeploy.apis.utils import build_task_processor
+        from mmdeploy.utils import get_input_shape, load_config
+
+        deploy_cfg, model_cfg = load_config(self.deploy_cfg_path, self.model_cfg_path)
+        self.task_processor = build_task_processor(model_cfg, deploy_cfg, self.device)
+
+        self.model = self.task_processor.build_backend_model(
+            self.backend_files, data_preprocessor_updater=self.task_processor.update_data_preprocessor
+        )
+
+        self.input_shape = get_input_shape(deploy_cfg)
+
+    def __call__(self, image):
+        model_inputs, _ = self.task_processor.create_input(image, self.input_shape)
+
+        with torch.no_grad():
+            result = self.model.test_step(model_inputs)
+
+        return result
+
+
+class MMLabInferencer(object):
+    """
+    A wrapper for the mmdet3d Inferencer.
+    It is used to load an mmdet3d Inferencer and run inference on given LiDAR data.
+    """
+
+    def __init__(self, model_cfg_path, model_path, device):
+        self.model_cfg_path = model_cfg_path
+        self.model_path = model_path
+        self.device = device
+
+        from mmdet3d.apis import LidarSeg3DInferencer
+
+        self.model = LidarSeg3DInferencer(model=self.model_cfg_path, weights=self.model_path, device=self.device)
+
+    def __call__(self, lidar_bin_files):
+        result = self.model(lidar_bin_files, show=False)["predictions"]
+
+        return result
+
+
+def prepare_mmlab_model(
+    model_cfg: str = "",
+    deploy_cfg: str = "",
+    backend_files: List[str] = [],
+    device: str = "cpu",
+    task: str = "LiDARDetection",
+    model_name: str = "",
+    model_path: str = "",
+):
+    """Prepare and load a model using mmdeploy or mmdet3d.
+    :param model_cfg: Path to the model config.
+    :param deploy_cfg: Path to the deployment config.
+    :param backend_files: Paths to the backend model files.
+    :param device: Device to use.
+    :param task: Current task. Only support ["LiDARDetection", "LiDARSegmentation"] for now.
+    :param model_name: Name of the model used.
+    :param model_path: Path to the model weights.
+ """ + + if task == "LiDARDetection": + model = MMLabModel( + check_model(model_cfg), + check_model(deploy_cfg), + [check_model(backend_file) for backend_file in backend_files], + device, + ) + elif task == "LiDARSegmentation": + + import subprocess + + from data_juicer.utils.cache_utils import DATA_JUICER_ASSETS_CACHE + + mmdetection3d_repo_path = os.path.join(DATA_JUICER_ASSETS_CACHE, "mmdetection3d") + if not os.path.exists(mmdetection3d_repo_path): + subprocess.run( + ["git", "clone", "https://github.com/open-mmlab/mmdetection3d.git", mmdetection3d_repo_path], check=True + ) + + original_model_cfg = model_cfg + model_cfg = os.path.splitext(os.path.basename(model_cfg))[0] + model_cfg = os.path.join(mmdetection3d_repo_path, "configs", model_name, model_cfg + ".py") + + if not os.path.exists(model_cfg): + raise ValueError(f"{model_cfg} does not exist.") + + if not os.path.exists(model_path): + if "cylinder3d_8xb2-laser-polar-mix-3x_semantickitti" in original_model_cfg: + logger.info( + f'The model corresponding to "{original_model_cfg}" does not exist. Model weight is not found at {model_path}. Starting automatic download...' + ) + if not os.path.exists(DJMC): + os.makedirs(DJMC) + model_path = os.path.join( + DJMC, "cylinder3d_8xb2-amp-laser-polar-mix-3x_semantickitti_20230425_144950-372cdf69.pth" + ) + if not os.path.exists(model_path): + wget.download(BACKUP_MODEL_LINKS["cylinder3d"], DJMC) + else: + raise ValueError( + f'The model corresponding to "{original_model_cfg}" does not exist. Model weight is not found at {model_path}.' + ) + + model = MMLabInferencer( + model_cfg, + model_path, + device, + ) + + else: + raise NotImplementedError( + f'Only support task name ["LiDARDetection", "LiDARSegmentation"] for now, but got {task}' + ) + + return model + + def prepare_qwen_vl_inputs_for_vllm(messages, processor): text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) # qwen_vl_utils 0.0.14+ required @@ -1279,6 +1413,7 @@ def prepare_qwen_vl_inputs_for_vllm(messages, processor): "wilor": prepare_wilor_model, "yolo": prepare_yolo_model, "embedding": prepare_embedding_model, + "mmlab": prepare_mmlab_model, } _MODELS_WITHOUT_FILE_LOCK = {"fasttext", "fastsam", "kenlm", "nltk", "recognizeAnything", "sentencepiece", "spacy"} diff --git a/docs/Operators.md b/docs/Operators.md index 706978f5dd..234233678a 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -40,14 +40,14 @@ The operators in Data-Juicer are categorized into 7 types. Data-Juicer 中的算子分为以下 7 种类型。 | Type 类型 | Number 数量 | Description 描述 | -|------|:---------:|-------------| -| [aggregator](#aggregator) | 4 | Aggregate for batched samples, such as summary or conclusion. 对批量样本进行汇总,如得出总结或结论。 | -| [deduplicator](#deduplicator) | 10 | Detects and removes duplicate samples. 识别、删除重复样本。 | -| [filter](#filter) | 54 | Filters out low-quality samples. 过滤低质量样本。 | -| [formatter](#formatter) | 8 | Discovers, loads, and canonicalizes source data. 发现、加载、规范化原始数据。 | -| [grouper](#grouper) | 3 | Group samples to batched samples. 将样本分组,每一组组成一个批量样本。 | -| [mapper](#mapper) | 92 | Edits and transforms samples. 对数据样本进行编辑和转换。 | -| [selector](#selector) | 5 | Selects top samples based on ranking. 基于排序选取高质量样本。 | +|------|:------:|-------------| +| [aggregator](#aggregator) | 4 | Aggregate for batched samples, such as summary or conclusion. 对批量样本进行汇总,如得出总结或结论。 | +| [deduplicator](#deduplicator) | 10 | Detects and removes duplicate samples. 识别、删除重复样本。 | +| [filter](#filter) | 54 | Filters out low-quality samples. 
+| [formatter](#formatter) | 8 | Discovers, loads, and canonicalizes source data. 发现、加载、规范化原始数据。 |
+| [grouper](#grouper) | 3 | Group samples to batched samples. 将样本分组,每一组组成一个批量样本。 |
+| [mapper](#mapper) | 94 | Edits and transforms samples. 对数据样本进行编辑和转换。 |
+| [selector](#selector) | 5 | Selects top samples based on ranking. 基于排序选取高质量样本。 |
 
 All the specific operators are listed below, each featured with several capability tags.
 下面列出所有具体算子，每种算子都通过多个标签来注明其主要功能。
@@ -218,6 +218,7 @@ All the specific operators are listed below, each featured with several capabili
 | image_tagging_mapper | 🏞Image 🚀GPU 🟢Stable | Generates image tags for each image in the sample. 为样本中的每个图像生成图像标记。 | [info](operators/mapper/image_tagging_mapper.md) | - |
 | imgdiff_difference_area_generator_mapper | 🚀GPU 🟡Beta | Generates and filters bounding boxes for image pairs based on similarity, segmentation, and text matching. 根据相似性、分割和文本匹配生成和过滤图像对的边界框。 | [info](operators/mapper/imgdiff_difference_area_generator_mapper.md) | [ImgDiff](https://arxiv.org/abs/2408.04594) |
 | imgdiff_difference_caption_generator_mapper | 🚀GPU 🟡Beta | Generates difference captions for bounding box regions in two images. 为两个图像中的边界框区域生成差异字幕。 | [info](operators/mapper/imgdiff_difference_caption_generator_mapper.md) | [ImgDiff](https://arxiv.org/abs/2408.04594) |
+| lidar_segmentation_mapper | 🚀GPU 🟡Beta | Mapper to perform semantic segmentation on LiDAR point cloud data. 对激光雷达点云数据进行语义分割。 | - | [tests](../tests/ops/mapper/test_lidar_segmentation_mapper.py) |
 | mllm_mapper | 🔮Multimodal 🚀GPU 🧩HF 🟢Stable | Mapper to use MLLMs for visual question answering tasks. Mapper使用MLLMs进行视觉问答任务。 | [info](operators/mapper/mllm_mapper.md) | - |
 | nlpaug_en_mapper | 🔤Text 💻CPU 🟢Stable | Augments English text samples using various methods from the nlpaug library. 使用nlpaug库中的各种方法增强英语文本样本。 | [info](operators/mapper/nlpaug_en_mapper.md) | - |
 | nlpcda_zh_mapper | 🔤Text 💻CPU 🟢Stable | Augments Chinese text samples using the nlpcda library. 使用nlpcda库扩充中文文本样本。 | [info](operators/mapper/nlpcda_zh_mapper.md) | - |
diff --git a/tests/ops/data/lidar_test1.bin b/tests/ops/data/lidar_test1.bin
new file mode 100644
index 0000000000..24cefd327f
Binary files /dev/null and b/tests/ops/data/lidar_test1.bin differ
diff --git a/tests/ops/data/lidar_test2.bin b/tests/ops/data/lidar_test2.bin
new file mode 100644
index 0000000000..3d3694929d
Binary files /dev/null and b/tests/ops/data/lidar_test2.bin differ
diff --git a/tests/ops/data/lidar_test3.bin b/tests/ops/data/lidar_test3.bin
new file mode 100644
index 0000000000..b6388d5416
Binary files /dev/null and b/tests/ops/data/lidar_test3.bin differ
diff --git a/tests/ops/mapper/test_lidar_segmentation_mapper.py b/tests/ops/mapper/test_lidar_segmentation_mapper.py
new file mode 100644
index 0000000000..ef3737c3ad
--- /dev/null
+++ b/tests/ops/mapper/test_lidar_segmentation_mapper.py
@@ -0,0 +1,61 @@
+import unittest
+import os
+
+from data_juicer.core import NestedDataset as Dataset
+from data_juicer.ops.mapper.lidar_segmentation_mapper import LiDARSegmentationMapper
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
+from data_juicer.utils.constant import Fields, MetaKeys
+
+
+class LiDARSegmentationMapperTest(DataJuicerTestCaseBase):
+    data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..',
+                             'data')
+    lidar_test1 = os.path.join(data_path, 'lidar_test1.bin')
+    lidar_test2 = os.path.join(data_path, 'lidar_test2.bin')
+    lidar_test3 = os.path.join(data_path, 'lidar_test3.bin')
+
+    model_cfg_name = "cylinder3d_8xb2-laser-polar-mix-3x_semantickitti"
+    model_path = "cylinder3d_8xb2-amp-laser-polar-mix-3x_semantickitti_20230425_144950-372cdf69.pth"
+
+
+    def setUp(self):
+        super().setUp()
+        self.source = [
+            {'lidar': self.lidar_test1},
+            {'lidar': self.lidar_test2},
+            {'lidar': self.lidar_test3}
+        ]
+        self.op = LiDARSegmentationMapper(
+            model_name="cylinder3d",
+            model_cfg_name=self.model_cfg_name,
+            model_path=self.model_path,
+        )
+
+    def _run_and_assert(self, num_proc, with_rank):
+        dataset = Dataset.from_list(self.source)
+        if Fields.meta not in dataset.features:
+            dataset = dataset.add_column(name=Fields.meta,
+                                         column=[{}] * dataset.num_rows)
+        dataset = dataset.map(self.op.process, num_proc=num_proc, with_rank=with_rank)
+        res_list = dataset.to_list()
+
+        self.assertEqual(len(res_list), 3)
+        self.assertEqual(list(res_list[0][Fields.meta][MetaKeys.lidar_segmentation_tags].keys()), ['box_type_3d', 'pts_semantic_mask'])
+        self.assertEqual(len(res_list[0][Fields.meta][MetaKeys.lidar_segmentation_tags]["pts_semantic_mask"]), 17238)
+        self.assertEqual(res_list[0][Fields.meta][MetaKeys.lidar_segmentation_tags]["box_type_3d"], "LiDAR")
+
+    def test_cpu(self):
+        self._run_and_assert(num_proc=1, with_rank=False)
+
+    def test_cuda(self):
+        self._run_and_assert(num_proc=1, with_rank=True)
+
+    def test_cpu_mul_proc(self):
+        self._run_and_assert(num_proc=2, with_rank=False)
+
+    def test_cuda_mul_proc(self):
+        self._run_and_assert(num_proc=2, with_rank=True)
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
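
For reviewers who want to try the new operator outside the unit tests, here is a minimal sketch of driving it directly from Python. It mirrors the setup in `tests/ops/mapper/test_lidar_segmentation_mapper.py`; the config name, checkpoint filename, and `.bin` path are the ones used by the test and act as placeholders here, so point them at files available locally (on first use, `prepare_mmlab_model` clones mmdetection3d and downloads the Cylinder3D checkpoint if the weights are missing).

```python
from data_juicer.core import NestedDataset as Dataset
from data_juicer.ops.mapper.lidar_segmentation_mapper import LiDARSegmentationMapper
from data_juicer.utils.constant import Fields, MetaKeys

# Config/checkpoint names below are borrowed from the unit test; adjust paths as needed.
op = LiDARSegmentationMapper(
    model_name="cylinder3d",
    model_cfg_name="cylinder3d_8xb2-laser-polar-mix-3x_semantickitti",
    model_path="cylinder3d_8xb2-amp-laser-polar-mix-3x_semantickitti_20230425_144950-372cdf69.pth",
)

dataset = Dataset.from_list([{"lidar": "tests/ops/data/lidar_test1.bin"}])

# The operator reads and writes Fields.meta, so make sure the column exists first,
# exactly as the unit test does.
if Fields.meta not in dataset.features:
    dataset = dataset.add_column(name=Fields.meta, column=[{}] * dataset.num_rows)

dataset = dataset.map(op.process, num_proc=1, with_rank=True)

tags = dataset.to_list()[0][Fields.meta][MetaKeys.lidar_segmentation_tags]
print(tags["box_type_3d"], len(tags["pts_semantic_mask"]))
```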
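
And a small, hedged example of consuming the stored tags for downstream inspection or filtering. It assumes `pts_semantic_mask` deserializes to a flat sequence of integer class ids (the unit test only checks its length); mapping ids to SemanticKITTI class names is config-dependent and not part of this PR, so only raw ids are counted.

```python
from collections import Counter

# `tags` is the per-sample dict written by lidar_segmentation_mapper
# (see the snippet above). Count how many points fall into each predicted class id.
class_counts = Counter(tags["pts_semantic_mask"])
for class_id, num_points in class_counts.most_common(5):
    print(f"class {class_id}: {num_points} points")
```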