Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions data_juicer/ops/base_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def __init__(self, *args, **kwargs):
self.image_key = kwargs.get("image_key", "images")
self.audio_key = kwargs.get("audio_key", "audios")
self.video_key = kwargs.get("video_key", "videos")
self.lidar_key = kwargs.get("lidar_key", "lidar")

# extra mm bytes keys
self.image_bytes_key = kwargs.get("image_bytes_key", "image_bytes")
Expand Down
68 changes: 68 additions & 0 deletions data_juicer/ops/mapper/lidar_segmentation_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Mapper

mmdet3d = LazyLoader("mmdet3d")

OP_NAME = "lidar_segmentation_mapper"


@OPERATORS.register_module(OP_NAME)
class LiDARSegmentationMapper(Mapper):
"""Mapper to do segmentation from LiDAR data."""

_batched_op = True
_accelerator = "cuda"

def __init__(
self,
model_name="cylinder3d",
model_cfg_name="",
model_path="",
tag_field_name: str = MetaKeys.lidar_segmentation_tags,
*args,
**kwargs,
):
"""
Initialization method.
:param model_name: Name of the model used.
:param model_cfg_name: The config name of the model used.
:param model_path: Path of the model weight.
:param tag_field_name: The field name to store the tags. It's
"lidar_segmentation_tags" in default.
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)

self.model_name = model_name

if self.model_name == "cylinder3d":
self.model_cfg_name = model_cfg_name
self.model_path = model_path
else:
raise NotImplementedError(f'Only support "cylinder3d" for now, but got {self.model_name}')

self.model_key = prepare_model(
"mmlab",
model_cfg=self.model_cfg_name,
model_path=self.model_path,
task="LiDARSegmentation",
model_name=self.model_name,
)
self.tag_field_name = tag_field_name

def process_single(self, sample, rank=None):

# check if it's generated already
if self.tag_field_name in sample[Fields.meta]:
return sample

model = get_model(self.model_key, rank, self.use_cuda())

results = model(dict(points=sample[self.lidar_key]))
sample[Fields.meta][self.tag_field_name] = results[0]

return sample
2 changes: 2 additions & 0 deletions data_juicer/utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class MetaKeys(object):
class_label_tag = DEFAULT_PREFIX + "class_label__"
# # 2D whole-body pose estimation
pose_estimation_tags = "pose_estimation_tags"
# # lidar segmentation tags
lidar_segmentation_tags = "lidar_segmentation_tags"

# === info extraction related tags ===
# # for event extraction
Expand Down
137 changes: 136 additions & 1 deletion data_juicer/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from contextlib import redirect_stderr
from functools import partial
from pickle import UnpicklingError
from typing import Optional, Union
from typing import List, Optional, Union

import httpx
import multiprocess as mp
Expand Down Expand Up @@ -40,6 +40,8 @@
ultralytics = LazyLoader("ultralytics")
tiktoken = LazyLoader("tiktoken")
dashscope = LazyLoader("dashscope")
mmdeploy = LazyLoader("mmdeploy")
mmdet3d = LazyLoader("mmdet3d")
qwen_vl_utils = LazyLoader("qwen_vl_utils", "qwen-vl-utils")
transformers_stream_generator = LazyLoader(
"transformers_stream_generator", "git+https://github.com/HYLcool/transformers-stream-generator.git"
Expand Down Expand Up @@ -76,6 +78,8 @@
# DWPose
"dwpose_onnx_det_model": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx",
"dwpose_onnx_pose_model": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx",
# Cylinder3d
"cylinder3d": "https://download.openmmlab.com/mmdetection3d/v1.1.0_models/cylinder3d/cylinder3d_8xb2-amp-laser-polar-mix-3x_semantickitti_20230425_144950-372cdf69.pth",
}


Expand Down Expand Up @@ -1237,6 +1241,136 @@ def update_sampling_params(sampling_params, pretrained_model_name_or_path, enabl
return sampling_params


class MMLabModel(object):
"""
A wrapper for mmdeploy model.
It is used to load a mmdeploy model and run inference on given images.
"""

def __init__(self, model_cfg_path, deploy_cfg_path, backend_files, device):
self.model_cfg_path = model_cfg_path
self.deploy_cfg_path = deploy_cfg_path
self.backend_files = backend_files
self.device = device

from mmdeploy.apis.utils import build_task_processor
from mmdeploy.utils import get_input_shape, load_config

deploy_cfg, model_cfg = load_config(self.deploy_cfg_path, self.model_cfg_path)
self.task_processor = build_task_processor(model_cfg, deploy_cfg, self.device)

self.model = self.task_processor.build_backend_model(
self.backend_files, data_preprocessor_updater=self.task_processor.update_data_preprocessor
)

self.input_shape = get_input_shape(deploy_cfg)

def __call__(self, image):
model_inputs, _ = self.task_processor.create_input(image, self.input_shape)

with torch.no_grad():
result = self.model.test_step(model_inputs)

return result


class MMLabInferencer(object):
"""
A wrapper for mmdet3d Inferencer.
It is used to load a mmdet3d Inferencer and run inference on given LiDAR data.
"""

def __init__(self, model_cfg_path, model_path, device):
self.model_cfg_path = model_cfg_path
self.model_path = model_path
self.device = device

from mmdet3d.apis import LidarSeg3DInferencer

self.model = LidarSeg3DInferencer(model=self.model_cfg_path, weights=self.model_path, device=self.device)

def __call__(self, lidar_bin_files):
result = self.model(lidar_bin_files, show=False)["predictions"]

return result


def prepare_mmlab_model(
model_cfg: str = "",
deploy_cfg: str = "",
backend_files: List[str] = [],
device: str = "cpu",
task: str = "LiDARDetection",
model_name: str = "",
model_path: str = "",
):
"""Prepare and load a model using mmdeploy.
:param model_cfg: Path to the model config.
:param deploy_cfg: Path to the deployment config.
:param backend_files: Path to the backend model files.
:param device: Device to use.
:param task: Current task. Only support ["LiDARDetection", "LiDARSegmentation"] for now.
:param model_name: Name of the model used.
:param model_path: Path of the model weight.
"""

if task == "LiDARDetection":
model = MMLabModel(
check_model(model_cfg),
check_model(deploy_cfg),
[check_model(backend_file) for backend_file in backend_files],
device,
)
elif task == "LiDARSegmentation":

import subprocess

from data_juicer.utils.cache_utils import DATA_JUICER_ASSETS_CACHE

mmdetection3d_repo_path = os.path.join(DATA_JUICER_ASSETS_CACHE, "mmdetection3d")
if not os.path.exists(mmdetection3d_repo_path):
subprocess.run(
["git", "clone", "https://github.com/open-mmlab/mmdetection3d.git", mmdetection3d_repo_path], check=True
)

original_model_cfg = model_cfg
model_cfg = os.path.splitext(os.path.basename(model_cfg))[0]
model_cfg = os.path.join(mmdetection3d_repo_path, "configs", model_name, model_cfg + ".py")

if not os.path.exists(model_cfg):
raise ValueError(f"{model_cfg} does not exist.")

if not os.path.exists(model_path):
if "cylinder3d_8xb2-laser-polar-mix-3x_semantickitti" in original_model_cfg:
logger.info(
f'The model corresponding to "{original_model_cfg}" does not exist. Model weight is not found at {model_path}. Starting automatic download...'
)
if not os.path.exists(DJMC):
os.makedirs(DJMC)
model_path = os.path.join(
DJMC, "cylinder3d_8xb2-amp-laser-polar-mix-3x_semantickitti_20230425_144950-372cdf69.pth"
)
if not os.path.exists(model_path):
wget.download(BACKUP_MODEL_LINKS["cylinder3d"], DJMC)
else:
raise ValueError(
f'The model corresponding to "{original_model_cfg}" does not exist. Model weight is not found at {model_path}.'
)

model = MMLabInferencer(
model_cfg,
model_path,
device,
)

else:
raise NotImplementedError(
f'Only support task name ["LiDARDetection", "LiDARSegmentation"] for now, but got {task}'
)

return model


def prepare_qwen_vl_inputs_for_vllm(messages, processor):
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# qwen_vl_utils 0.0.14+ required
Expand Down Expand Up @@ -1279,6 +1413,7 @@ def prepare_qwen_vl_inputs_for_vllm(messages, processor):
"wilor": prepare_wilor_model,
"yolo": prepare_yolo_model,
"embedding": prepare_embedding_model,
"mmlab": prepare_mmlab_model,
}

_MODELS_WITHOUT_FILE_LOCK = {"fasttext", "fastsam", "kenlm", "nltk", "recognizeAnything", "sentencepiece", "spacy"}
Expand Down
17 changes: 9 additions & 8 deletions docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ The operators in Data-Juicer are categorized into 7 types.
Data-Juicer 中的算子分为以下 7 种类型。

| Type 类型 | Number 数量 | Description 描述 |
|------|:---------:|-------------|
| [aggregator](#aggregator) | 4 | Aggregate for batched samples, such as summary or conclusion. 对批量样本进行汇总,如得出总结或结论。 |
| [deduplicator](#deduplicator) | 10 | Detects and removes duplicate samples. 识别、删除重复样本。 |
| [filter](#filter) | 54 | Filters out low-quality samples. 过滤低质量样本。 |
| [formatter](#formatter) | 8 | Discovers, loads, and canonicalizes source data. 发现、加载、规范化原始数据。 |
| [grouper](#grouper) | 3 | Group samples to batched samples. 将样本分组,每一组组成一个批量样本。 |
| [mapper](#mapper) | 92 | Edits and transforms samples. 对数据样本进行编辑和转换。 |
| [selector](#selector) | 5 | Selects top samples based on ranking. 基于排序选取高质量样本。 |
|------|:------:|-------------|
| [aggregator](#aggregator) | 4 | Aggregate for batched samples, such as summary or conclusion. 对批量样本进行汇总,如得出总结或结论。 |
| [deduplicator](#deduplicator) | 10 | Detects and removes duplicate samples. 识别、删除重复样本。 |
| [filter](#filter) | 54 | Filters out low-quality samples. 过滤低质量样本。 |
| [formatter](#formatter) | 8 | Discovers, loads, and canonicalizes source data. 发现、加载、规范化原始数据。 |
| [grouper](#grouper) | 3 | Group samples to batched samples. 将样本分组,每一组组成一个批量样本。 |
| [mapper](#mapper) | 94 | Edits and transforms samples. 对数据样本进行编辑和转换。 |
| [selector](#selector) | 5 | Selects top samples based on ranking. 基于排序选取高质量样本。 |

All the specific operators are listed below, each featured with several capability tags.
下面列出所有具体算子,每种算子都通过多个标签来注明其主要功能。
Expand Down Expand Up @@ -218,6 +218,7 @@ All the specific operators are listed below, each featured with several capabili
| image_tagging_mapper | 🏞Image 🚀GPU 🟢Stable | Generates image tags for each image in the sample. 为样本中的每个图像生成图像标记。 | [info](operators/mapper/image_tagging_mapper.md) | - |
| imgdiff_difference_area_generator_mapper | 🚀GPU 🟡Beta | Generates and filters bounding boxes for image pairs based on similarity, segmentation, and text matching. 根据相似性、分割和文本匹配生成和过滤图像对的边界框。 | [info](operators/mapper/imgdiff_difference_area_generator_mapper.md) | [ImgDiff](https://arxiv.org/abs/2408.04594) |
| imgdiff_difference_caption_generator_mapper | 🚀GPU 🟡Beta | Generates difference captions for bounding box regions in two images. 为两个图像中的边界框区域生成差异字幕。 | [info](operators/mapper/imgdiff_difference_caption_generator_mapper.md) | [ImgDiff](https://arxiv.org/abs/2408.04594) |
| lidar_segmentation_mapper | 🚀GPU 🟡Beta | Mapper to do segmentation from LiDAR data. 映射器从激光雷达数据中进行分割。 | - | [tests](../tests/ops/mapper/test_lidar_segmentation_mapper.py) |
| mllm_mapper | 🔮Multimodal 🚀GPU 🧩HF 🟢Stable | Mapper to use MLLMs for visual question answering tasks. Mapper使用MLLMs进行视觉问答任务。 | [info](operators/mapper/mllm_mapper.md) | - |
| nlpaug_en_mapper | 🔤Text 💻CPU 🟢Stable | Augments English text samples using various methods from the nlpaug library. 使用nlpaug库中的各种方法增强英语文本样本。 | [info](operators/mapper/nlpaug_en_mapper.md) | - |
| nlpcda_zh_mapper | 🔤Text 💻CPU 🟢Stable | Augments Chinese text samples using the nlpcda library. 使用nlpcda库扩充中文文本样本。 | [info](operators/mapper/nlpcda_zh_mapper.md) | - |
Expand Down
Binary file added tests/ops/data/lidar_test1.bin
Binary file not shown.
Binary file added tests/ops/data/lidar_test2.bin
Binary file not shown.
Binary file added tests/ops/data/lidar_test3.bin
Binary file not shown.
61 changes: 61 additions & 0 deletions tests/ops/mapper/test_lidar_segmentation_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import unittest
import os

from data_juicer.core import NestedDataset as Dataset
from data_juicer.ops.mapper.lidar_segmentation_mapper import LiDARSegmentationMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
from data_juicer.utils.constant import Fields, MetaKeys


class LiDARSegmentationMapperTest(DataJuicerTestCaseBase):
data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..',
'data')
lidar_test1 = os.path.join(data_path, 'lidar_test1.bin')
lidar_test2 = os.path.join(data_path, 'lidar_test2.bin')
lidar_test3 = os.path.join(data_path, 'lidar_test3.bin')

model_cfg_name = "cylinder3d_8xb2-laser-polar-mix-3x_semantickitti"
model_path = "cylinder3d_8xb2-amp-laser-polar-mix-3x_semantickitti_20230425_144950-372cdf69.pth"


def setUp(self):
super().setUp()
self.source = [
{'lidar': self.lidar_test1},
{'lidar': self.lidar_test2},
{'lidar': self.lidar_test3}
]
self.op = LiDARSegmentationMapper(
model_name="cylinder3d",
model_cfg_name=self.model_cfg_name,
model_path=self.model_path,
)

def _run_and_assert(self, num_proc, with_rank):
dataset = Dataset.from_list(self.source)
if Fields.meta not in dataset.features:
dataset = dataset.add_column(name=Fields.meta,
column=[{}] * dataset.num_rows)
dataset = dataset.map(self.op.process, num_proc=num_proc, with_rank=with_rank)
res_list = dataset.to_list()

self.assertEqual(len(res_list), 3)
self.assertEqual(list(res_list[0][Fields.meta][MetaKeys.lidar_segmentation_tags].keys()), ['box_type_3d', 'pts_semantic_mask'])
self.assertEqual(len(res_list[0][Fields.meta][MetaKeys.lidar_segmentation_tags]["pts_semantic_mask"]), 17238)
self.assertEqual(res_list[0][Fields.meta][MetaKeys.lidar_segmentation_tags]["box_type_3d"], "LiDAR")

def test_cpu(self):
self._run_and_assert(num_proc=1, with_rank=False)

def test_cuda(self):
self._run_and_assert(num_proc=1, with_rank=True)

def test_cpu_mul_proc(self):
self._run_and_assert(num_proc=2, with_rank=False)

def test_cuda_mul_proc(self):
self._run_and_assert(num_proc=2, with_rank=True)


if __name__ == "__main__":
unittest.main()
Loading