Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions data_juicer/ops/base_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def __init__(self, *args, **kwargs):
self.image_key = kwargs.get("image_key", "images")
self.audio_key = kwargs.get("audio_key", "audios")
self.video_key = kwargs.get("video_key", "videos")
self.lidar_key = kwargs.get("lidar_key", "lidar")

self.query_key = kwargs.get("query_key", "query")
self.response_key = kwargs.get("response_key", "response")
Expand Down
93 changes: 93 additions & 0 deletions data_juicer/ops/mapper/lidar_detection_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from typing import Dict

from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Mapper

# Handle to the mmdet3d package, created through LazyLoader.
# NOTE(review): presumably the real import is deferred until the first
# attribute access on this handle — confirm against LazyLoader.
mmdet3d = LazyLoader("mmdet3d")

# Name under which the operator below is registered in OPERATORS.
OP_NAME = "lidar_detection_mapper"


@OPERATORS.register_module(OP_NAME)
class LiDARDetectionMapper(Mapper):
    """Mapper to detect ground truth from LiDAR data.

    Every entry stored under the sample's LiDAR key is passed to the prepared
    "mmlab" model, and the predictions are written back as plain dictionaries
    under the ``lidar_detections`` key.
    """

    _batched_op = True

    def __init__(self, model_name="centerpoint", *args, **kwargs):
        """
        Initialization method.

        :param model_name: name of the detection model to use. Only
            ``"centerpoint"`` is supported for now.
        :param args: extra args
        :param kwargs: extra args
        :raises NotImplementedError: if ``model_name`` is not supported.
        """
        super().__init__(*args, **kwargs)

        self.model_name = model_name
        self.device = "cpu"

        if self.model_name == "centerpoint":
            self.deploy_cfg_path = "voxel-detection_onnxruntime_dynamic.py"
            self.model_cfg_path = "centerpoint_pillar02_second_secfpn_head-circlenms_8xb4-cyclic-20e_nus-3d.py"
            self.backend_files = [
                "centerpoint_02pillar_second_secfpn_circlenms_4x8_cyclic_20e_nus_20220811_031844-191a3822.onnx"
            ]
        else:
            raise NotImplementedError(f'Only support "centerpoint" for now, but got {self.model_name}')

        self.model_key = prepare_model(
            "mmlab",
            model_cfg=self.model_cfg_path,
            deploy_cfg=self.deploy_cfg_path,
            backend_files=self.backend_files,
            device=self.device,
        )

    # Maybe should include model name, timestamp, filename, image info etc.
    # The annotation is a string on purpose: a bare
    # ``mmdet3d.structures.Det3DDataSample`` would be evaluated when the class
    # body is executed and force the mmdet3d import at module import time.
    def pred2dict(self, data_sample: "mmdet3d.structures.Det3DDataSample") -> Dict:
        """Extract elements necessary to represent a prediction into a
        dictionary.

        It's better to contain only basic data elements such as strings and
        numbers in order to guarantee it's json-serializable.

        Args:
            data_sample (:obj:`Det3DDataSample`): Predictions of the model.

        Returns:
            dict: Prediction results. May contain ``labels_3d``,
            ``scores_3d``, ``bboxes_3d``, ``pts_semantic_mask`` and
            ``box_type_3d`` depending on what ``data_sample`` carries.
        """
        result = {}
        if "pred_instances_3d" in data_sample:
            pred_instances_3d = data_sample.pred_instances_3d.numpy()
            result = {
                "labels_3d": pred_instances_3d.labels_3d.tolist(),
                "scores_3d": pred_instances_3d.scores_3d.tolist(),
                "bboxes_3d": pred_instances_3d.bboxes_3d.tensor.cpu().tolist(),
            }

        if "pred_pts_seg" in data_sample:
            pred_pts_seg = data_sample.pred_pts_seg.numpy()
            result["pts_semantic_mask"] = pred_pts_seg.pts_semantic_mask.tolist()

        # Translate the box mode enum into a plain string label.
        if data_sample.box_mode_3d == mmdet3d.structures.Box3DMode.LIDAR:
            result["box_type_3d"] = "LiDAR"
        elif data_sample.box_mode_3d == mmdet3d.structures.Box3DMode.CAM:
            result["box_type_3d"] = "Camera"
        elif data_sample.box_mode_3d == mmdet3d.structures.Box3DMode.DEPTH:
            result["box_type_3d"] = "Depth"

        return result

    def _output_to_dict(self, output) -> Dict:
        """Convert the raw output of one model call into a prediction dict.

        The model wrapper returns whatever the backend's ``test_step``
        produces, which is expected to be a list of data samples (one per
        input) rather than a bare data sample. A list/tuple is therefore
        unwrapped to its first element; an empty one yields an empty dict.
        A bare data sample is converted directly.
        """
        if isinstance(output, (list, tuple)):
            if not output:
                return {}
            output = output[0]
        return self.pred2dict(output)

    def process_batched(self, samples):
        """Run detection on every LiDAR entry of a batch.

        :param samples: batched samples; the entries under ``self.lidar_key``
            are fed to the model one by one.
        :return: the same ``samples`` object with the per-entry prediction
            dictionaries stored under ``"lidar_detections"``.
        """
        model = get_model(self.model_key)
        lidars = samples[self.lidar_key]

        outputs = [model(lidar) for lidar in lidars]
        samples["lidar_detections"] = [self._output_to_dict(output) for output in outputs]

        return samples
55 changes: 54 additions & 1 deletion data_juicer/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from contextlib import redirect_stderr
from functools import partial
from pickle import UnpicklingError
from typing import Optional, Union
from typing import List, Optional, Union

import httpx
import multiprocess as mp
Expand Down Expand Up @@ -38,6 +38,7 @@
ultralytics = LazyLoader("ultralytics")
tiktoken = LazyLoader("tiktoken")
dashscope = LazyLoader("dashscope")
mmdeploy = LazyLoader("mmdeploy")

MODEL_ZOO = {}

Expand Down Expand Up @@ -942,6 +943,57 @@ def update_sampling_params(sampling_params, pretrained_model_name_or_path, enabl
return sampling_params


class MMLabModel(object):
    """
    A wrapper for mmdeploy model.
    It is used to load a mmdeploy backend model (e.g. an exported ONNX model)
    and run inference on given inputs.
    """

    def __init__(self, model_cfg, deploy_cfg, backend_files, device):
        """
        Load the configs and build the backend model.

        :param model_cfg: path to the model config.
        :param deploy_cfg: path to the deployment config.
        :param backend_files: list of paths to the backend model files.
        :param device: device to run the backend model on.
        """
        self.model_cfg = model_cfg
        self.deploy_cfg = deploy_cfg
        self.backend_files = backend_files
        self.device = device

        # Imported here so mmdeploy is only required when a model is built.
        from mmdeploy.apis.utils import build_task_processor
        from mmdeploy.utils import get_input_shape, load_config

        # The task processor must receive the loaded config objects, not the
        # raw path strings kept on ``self`` for reference.
        loaded_deploy_cfg, loaded_model_cfg = load_config(self.deploy_cfg, self.model_cfg)
        self.task_processor = build_task_processor(loaded_model_cfg, loaded_deploy_cfg, self.device)

        self.model = self.task_processor.build_backend_model(
            self.backend_files, data_preprocessor_updater=self.task_processor.update_data_preprocessor
        )

        self.input_shape = get_input_shape(loaded_deploy_cfg)

    def __call__(self, image):
        """
        Run inference on a single input.

        :param image: the input handed to the task processor's
            ``create_input`` (the parameter name is historical; the mapper in
            this change passes LiDAR entries).
        :return: whatever the backend model's ``test_step`` returns.
        """
        model_inputs, _ = self.task_processor.create_input(image, self.input_shape)

        with torch.no_grad():
            result = self.model.test_step(model_inputs)

        return result


def prepare_mmlab_model(model_cfg: str, deploy_cfg: str, backend_files: List[str], device: str = "cpu"):
    """Build an ``MMLabModel`` from mmdeploy configs and backend files.

    Every given path is passed through ``check_model`` before it is handed to
    the wrapper.

    :param model_cfg: Path to the model config.
    :param deploy_cfg: Path to the deployment config.
    :param backend_files: Path to the backend model files.
    :param device: Device to use.
    :return: The constructed ``MMLabModel`` instance.
    """
    checked_model_cfg = check_model(model_cfg)
    checked_deploy_cfg = check_model(deploy_cfg)

    checked_backend_files = []
    for backend_file in backend_files:
        checked_backend_files.append(check_model(backend_file))

    return MMLabModel(checked_model_cfg, checked_deploy_cfg, checked_backend_files, device)


MODEL_FUNCTION_MAPPING = {
"api": prepare_api_model,
"diffusion": prepare_diffusion_model,
Expand All @@ -960,6 +1012,7 @@ def update_sampling_params(sampling_params, pretrained_model_name_or_path, enabl
"video_blip": prepare_video_blip_model,
"vllm": prepare_vllm_model,
"embedding": prepare_embedding_model,
"mmlab": prepare_mmlab_model,
}

_MODELS_WITHOUT_FILE_LOCK = {"fasttext", "fastsam", "kenlm", "nltk", "recognizeAnything", "sentencepiece", "spacy"}
Expand Down
3 changes: 2 additions & 1 deletion docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Data-Juicer 中的算子分为以下 7 种类型。
| [filter](#filter) | 49 | Filters out low-quality samples. 过滤低质量样本。 |
| [formatter](#formatter) | 8 | Discovers, loads, and canonicalizes source data. 发现、加载、规范化原始数据。 |
| [grouper](#grouper) | 3 | Group samples to batched samples. 将样本分组,每一组组成一个批量样本。 |
| [mapper](#mapper) | 81 | Edits and transforms samples. 对数据样本进行编辑和转换。 |
| [mapper](#mapper) | 82 | Edits and transforms samples. 对数据样本进行编辑和转换。 |
| [selector](#selector) | 5 | Selects top samples based on ranking. 基于排序选取高质量样本。 |

All the specific operators are listed below, each featured with several capability tags.
Expand Down Expand Up @@ -200,6 +200,7 @@ All the specific operators are listed below, each featured with several capabili
| image_tagging_mapper | 🏞Image 💻CPU 🟢Stable | Mapper to generate image tags. 映射器生成图像标签。 | [code](../data_juicer/ops/mapper/image_tagging_mapper.py) | [tests](../tests/ops/mapper/test_image_tagging_mapper.py) |
| imgdiff_difference_area_generator_mapper | 💻CPU 🟡Beta | A fused operator for OPs that is used to run sequential OPs on the same batch to allow fine-grained control on data processing. OPs的融合操作符,用于在同一批次上运行顺序OPs,以实现对数据处理的细粒度控制。 | [code](../data_juicer/ops/mapper/imgdiff_difference_area_generator_mapper.py) | [tests](../tests/ops/mapper/test_imgdiff_difference_area_generator_mapper.py) |
| imgdiff_difference_caption_generator_mapper | 💻CPU 🟡Beta | A fused operator for OPs that is used to run sequential OPs on the same batch to allow fine-grained control on data processing. OPs的融合操作符,用于在同一批次上运行顺序OPs,以实现对数据处理的细粒度控制。 | [code](../data_juicer/ops/mapper/imgdiff_difference_caption_generator_mapper.py) | [tests](../tests/ops/mapper/test_imgdiff_difference_caption_generator_mapper.py) |
| lidar_detection_mapper | 💻CPU 🔴Alpha | Mapper to detect ground truth from LiDAR data. 映射器从激光雷达数据中检测真值(ground truth)。 | [code](../data_juicer/ops/mapper/lidar_detection_mapper.py) | - |
| mllm_mapper | 🔮Multimodal 💻CPU 🧩HF 🟢Stable | Mapper to use MLLMs for visual question answering tasks. Mapper使用MLLMs进行视觉问答任务。 | [code](../data_juicer/ops/mapper/mllm_mapper.py) | [tests](../tests/ops/mapper/test_mllm_mapper.py) |
| nlpaug_en_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to simply augment samples in English based on nlpaug library. 映射器基于nlpaug库简单地增加英语样本。 | [code](../data_juicer/ops/mapper/nlpaug_en_mapper.py) | [tests](../tests/ops/mapper/test_nlpaug_en_mapper.py) |
| nlpcda_zh_mapper | 🔤Text 💻CPU 🟢Stable | Mapper to simply augment samples in Chinese based on nlpcda library. 基于nlpcda库的映射器可以简单地增加中文样本。 | [code](../data_juicer/ops/mapper/nlpcda_zh_mapper.py) | [tests](../tests/ops/mapper/test_nlpcda_zh_mapper.py) |
Expand Down
Loading