Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions angelslim/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@

from .dataloader import DataLoaderFactory # noqa: F401
from .multimodal_dataset import MultiModalDataset # noqa: F401
from .omni_dataset import OmniDataset # noqa: F401
from .text2image_dataset import Text2ImageDataset # noqa: F401
from .text_dataset import TextDataset # noqa: F401
12 changes: 12 additions & 0 deletions angelslim/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from .base_dataset import BaseDataset
from .multimodal_dataset import MultiModalDataset
from .omni_dataset import OmniDataset
from .text2image_dataset import Text2ImageDataset
from .text_dataset import TextDataset

Expand All @@ -39,6 +40,7 @@ def create_data_loader(
data_type: str = "auto",
num_workers: int = 0,
inference_settings: Dict = None,
use_audio_in_video: bool = False,
model_name: str = None,
) -> DataLoader:
"""
Expand Down Expand Up @@ -98,6 +100,16 @@ def create_data_loader(
num_samples=num_samples,
inference_settings=inference_settings,
)
elif data_type == "OmniDataset":
dataset = OmniDataset(
processor=processor,
device=device,
max_length=max_length,
num_samples=num_samples,
data_source=data_source,
is_hf_dataset=not os.path.isfile(data_source),
use_audio_in_video=use_audio_in_video,
)
else:
raise ValueError(f"Unsupported data type: {data_type}")

Expand Down
127 changes: 127 additions & 0 deletions angelslim/data/omni_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright 2025 Tencent Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from pathlib import Path
from typing import Dict, List, Union

from qwen_omni_utils import process_mm_info
from transformers import ProcessorMixin

from .base_dataset import BaseDataset


class OmniDataset(BaseDataset):
"""Dataset for multimodal (text + image) data"""

def __init__(
self,
processor: ProcessorMixin,
device: str = "cpu",
max_length: int = 4096,
num_samples: int = -1,
data_source: Union[str, Dict] = None,
is_hf_dataset: bool = False,
use_audio_in_video: bool = False,
):
super().__init__(processor, device, max_length)
self.is_hf_dataset = is_hf_dataset
self.use_audio_in_video = use_audio_in_video

self._load_file_based_dataset(data_source, num_samples)

def _load_file_based_dataset(self, data_path: str, num_samples: int):
"""Load dataset from local file system"""
path_obj = Path(data_path)
data_dir = path_obj.parent

line_count = 0
with open(data_path, "r") as f:
for line in f:
if num_samples > 0 and line_count >= num_samples:
break
data = json.loads(line.strip())
video_path = None
audio_path = None
image_path = None

if "video_path" in data:
video_path = os.path.normpath(
os.path.join(data_dir, data["video_path"])
)
if "audio_path" in data:
audio_path = os.path.normpath(
os.path.join(data_dir, data["audio_path"])
)
if "image_path" in data:
image_path = os.path.normpath(
os.path.join(data_dir, data["image_path"])
)

ms = data.get("messages")

conversation = []
for m in ms:
if m["role"] == "system":
conversation.append(
{
"role": "system",
"content": [{"type": "text", "text": m["content"]}],
}
)
elif m["role"] == "user":
content = []
text_content = m["content"]
text_content = (
text_content.replace("<video>", "")
.replace("<audio>", "")
.replace("<image>", "")
)
content.append({"type": "text", "text": text_content})
if video_path:
content.append({"type": "video", "video": video_path})
if audio_path:
content.append({"type": "audio", "audio": audio_path})
if image_path:
content.append({"type": "image", "image": image_path})
conversation.append(
{
"role": "user",
"content": content,
}
)
self._process_and_append(conversation)
line_count += 1

def _process_and_append(self, messages: List[Dict]):
"""Process messages and append to dataset"""
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
audios, images, videos = process_mm_info(
messages, use_audio_in_video=self.use_audio_in_video
)

# Process inputs
inputs = self.processor(
text=text,
images=images,
audios=audios,
videos=videos,
padding=True,
return_tensors="pt",
use_audio_in_video=self.use_audio_in_video,
)
self.data.append(inputs)
17 changes: 15 additions & 2 deletions angelslim/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def prepare_model(
cache_dir=None,
deploy_backend="vllm",
using_multi_nodes=False,
use_audio_in_video=False,
) -> Any:
"""Load pretrained model and tokenizer
Args:
Expand Down Expand Up @@ -116,6 +117,16 @@ def prepare_model(
using_multi_nodes=using_multi_nodes,
)
self.model_path = model_path
elif self.series in ["Omni"]:
if not model:
self.slim_model.from_pretrained(
model_path,
torch_dtype=torch_dtype,
device_map=device_map,
trust_remote_code=trust_remote_code,
use_audio_in_video=use_audio_in_video,
)
self.model_path = model_path
else:
raise ValueError(f"Unsupported series: {self.series}")

Expand All @@ -131,6 +142,7 @@ def prepare_data(
num_samples=128,
shuffle=True,
inference_settings=None,
use_audio_in_video=False,
model_name=None,
) -> Optional[Any]:
"""Prepare compression dataset"""
Expand All @@ -145,7 +157,7 @@ def prepare_data(
data_type=data_type,
processor=(
self.slim_model.processor
if self.series == "VLM"
if self.series == "VLM" or self.series == "Omni"
else self.slim_model.tokenizer
),
device=self.slim_model.model.device,
Expand All @@ -155,6 +167,7 @@ def prepare_data(
num_samples=num_samples,
data_source=data_path,
inference_settings=inference_settings,
use_audio_in_video=use_audio_in_video,
model_name=model_name,
)
self.max_seq_length = max_length
Expand Down Expand Up @@ -187,7 +200,7 @@ def prepare_compressor(
f"Compression method '{method_name}' not registered. "
f"Available methods: {CompressorFactory.get_available_compressor()}"
)
if self.series in ["LLM", "VLM"]:
if self.series in ["LLM", "VLM", "Omni"]:
global_config.update(self.model_path, self.max_seq_length)

if default_method:
Expand Down
1 change: 1 addition & 0 deletions angelslim/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@
from .diffusion import * # noqa: F401 F403
from .llm import * # noqa: F401 F403
from .model_factory import SlimModelFactory # noqa: F401
from .omni import * # noqa: F401 F403
from .vlm import * # noqa: F401 F403
4 changes: 3 additions & 1 deletion angelslim/models/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class SlimModelFactory:
registry: Dict[str, Type] = {}
series_registry: Dict[str, str] = {}

ALLOWED_SERIES = ("LLM", "VLM", "Diffusion")
ALLOWED_SERIES = ("LLM", "VLM", "Diffusion", "Omni")

@classmethod
def register(cls, model_class: Type) -> Type:
Expand All @@ -39,6 +39,8 @@ def register(cls, model_class: Type) -> Type:
series = "VLM"
elif "diffusion" in module_path:
series = "Diffusion"
elif "omni" in module_path:
series = "Omni"
else:
raise ValueError(
f"model_class '{class_name}' is not in a valid series: {cls.ALLOWED_SERIES}" # noqa: E501
Expand Down
16 changes: 16 additions & 0 deletions angelslim/models/omni/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2025 Tencent Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .qwen_omni import Qwen_Omni # noqa: F401
Loading