30 changes: 28 additions & 2 deletions docs/video_configuration.md
@@ -10,7 +10,7 @@ The following parameters can be configured in the `dataset_config` section of yo

- **`video_backend`** (Optional[str], default: "qwen_vl_utils")
- Specifies the backend to use for video loading
- Available options: `"decord"`, `"qwen_vl_utils (recommended)"`
- Available options: `"decord"`, `"qwen_vl_utils"`, `"qwen_omni_utils"`
- Note: The `"torchvision"` backend has been removed. See [Migration Guide](#migration-from-torchvision-backend) below.

- **`video_sampling_strategy`** (Optional[str], default: "fps")
@@ -107,6 +107,7 @@ The `torchvision` video backend has been removed since it was implemented as a f
3. **Verify compatibility:**
- `decord` provides naive video decoding and is used when loading from cloud storage
- `qwen_vl_utils` is optimized for Qwen models and provides additional features
- `qwen_omni_utils` supports audio extraction from videos for Qwen Omni variants
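
For reference, the `video_backend` value is consumed by the mixin's `load_videos` dispatcher. The sketch below is a simplified paraphrase of the dispatch added in this PR (see `multimodal_mixin.py`), not a drop-in snippet; the real method also takes a `data_folder` argument:

```python
# Simplified sketch of the backend dispatch in MultiModalDataLoadingMixin.load_videos
# (paraphrased from this PR; data_folder handling omitted).
def load_videos(self, video_path: str, fps: int):
    if self.config.video_backend == "decord":
        return self.load_video_decord(video_path, fps)
    elif self.config.video_backend == "qwen_vl_utils":
        return self.load_video_qwen_vl_utils(video_path, fps)
    elif self.config.video_backend == "qwen_omni_utils":
        return self.load_video_qwen_omni_utils(video_path, fps)
    raise ValueError(f"Video backend {self.config.video_backend} not supported")
```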

## Training Performance Optimization

@@ -151,4 +152,29 @@ This periodically clears the CUDA memory cache to prevent fragmentation during l

4. **Optimize sampling strategy:**
- Use `"fps"` for videos with consistent motion
- Use `"frame_num"` when you need exactly N frames regardless of video length
- Use `"frame_num"` when you need exactly N frames regardless of video length

5. **Audio extraction from videos:**
- Use `video_backend: "qwen_omni_utils"` together with `use_audio_in_video: true` in the processor config to extract audio from video files

## Extracting Audio from Video

When training Qwen Omni models, you can extract audio tracks from video files automatically.

### Configuration

```yaml
dataset_config:
dataset_type: vision_audio
video_backend: "qwen_omni_utils"
video_sampling_strategy: "fps"
fps: 1
video_max_frames: 60

processor_config:
processor_name: "Qwen/Qwen2.5-Omni-7B"
processor_type: "Qwen2_5OmniProcessor"
extra_kwargs:
use_audio_in_video: true
audio_max_length: 60
```
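
Under the hood, the `qwen_omni_utils` backend calls `process_mm_info` to decode the video frames and, when `use_audio_in_video` is enabled, the audio track. The sketch below is a minimal illustration based on the `load_video_qwen_omni_utils` loader added in this PR; the file path is hypothetical and the dataset bookkeeping is omitted:

```python
# Minimal sketch of the audio-from-video extraction path (assumes
# qwen-omni-utils is installed; /path/to/clip.mp4 is a hypothetical file).
from qwen_omni_utils import process_mm_info

video_path = "/path/to/clip.mp4"
messages = [
    {
        "role": "user",
        "content": [{"type": "video", "video": f"file://{video_path}"}],
    }
]

# With use_audio_in_video=True the helper returns the decoded audio track(s)
# alongside the sampled video frames; the dataset later attaches the audio to
# the sample via the "from_video" placeholder.
audios, images, videos = process_mm_info(messages, use_audio_in_video=True)
video_frames = videos[0]
extracted_audio = audios[0] if audios else None
```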
57 changes: 57 additions & 0 deletions examples/qwen2_5_omni.yaml
@@ -0,0 +1,57 @@
- type: trainer
config:
trainer_type: fsdp2_trainer

dataset_config:
dataset_type: qwen_omni
dataset_format: json
dataset_path:

processor_config:
processor_name: "Qwen/Qwen2.5-Omni-7B"
processor_type: "Qwen2_5OmniProcessor"
extra_kwargs:
audio_max_length: 60
use_audio_in_video: true
video_max_pixels: 602112
video_min_pixels: 28800
image_max_pixels: 602112
image_min_pixels: 28800

extra_kwargs:
use_audio_in_video: true

video_backend: "qwen_omni_utils"
video_sampling_strategy: "fps"
fps: 1

model_config:
load_from_pretrained_path: "ngqtrung/Qwen2.5-Omni-Thinker-7B"
attn_implementation: "flash_attention_2"
model_type: "qwen2_5_omni"
torch_dtype: "bfloat16"

per_device_train_batch_size: 1
learning_rate: 1.0e-05
weight_decay: 0.0
gradient_accumulation_steps: 1
gradient_checkpointing: true
num_train_epochs: 1
save_steps: 1
save_total_limit: 1
report_to: "wandb"
output_dir: ""
warmup_ratio: 0.0
run_name: "qwen2_5_omni"
eval_strategy: "steps"
logging_steps: 1
group_by_length: false
dataloader_num_workers: 0
bf16: true
lr_scheduler_type: "constant"
max_steps: 5
torch_empty_cache_steps: 10
use_liger_kernel: true
use_rmpad: true
sp_ulysses_degree: 1
fsdp2: true
1 change: 1 addition & 0 deletions pyproject.toml
@@ -45,6 +45,7 @@ dependencies = [
"flash-attn>=2.8.3",
"liger-kernel>=0.6.1",
"qwen-vl-utils",
"qwen-omni-utils",
"einops",
"pandas",
"rich>=14.1.0",
11 changes: 9 additions & 2 deletions src/lmms_engine/datasets/collator/vision_collator.py
@@ -52,12 +52,19 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
if "attention_mask" in inputs.keys():
inputs.pop("attention_mask")

attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id)
attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id).long()
batched_inputs["attention_mask"] = attention_mask

# for the other keys
for key, values in inputs.items():
batched_inputs[key] = torch.concatenate(values, dim=0)
# Handle scalar/boolean values (e.g. use_audio_in_video), which cannot be concatenated
if isinstance(values[0], bool) or (
isinstance(values[0], (int, float))
and not isinstance(values[0], torch.Tensor)
):
batched_inputs[key] = values[0]
else:
batched_inputs[key] = torch.concatenate(values, dim=0)
return batched_inputs

@property
6 changes: 4 additions & 2 deletions src/lmms_engine/datasets/config.py
@@ -36,7 +36,9 @@ class DatasetConfig(Args):
video_max_frames: Optional[int] = 768
frame_num: Optional[int] = 64
fps: Optional[int] = 1
video_backend: Optional[Literal["decord", "qwen_vl_utils"]] = "qwen_vl_utils"
video_backend: Optional[
Literal["decord", "qwen_vl_utils", "qwen_omni_utils"]
] = "qwen_vl_utils"

@field_validator(
"video_max_pixels",
@@ -62,7 +64,7 @@ def validate_video_backend_migration(cls, v):
if v == "torchvision":
raise ValueError(
"The 'torchvision' video backend has been removed. "
"Please use 'decord' or 'qwen_vl_utils' instead. "
"Please use 'decord', 'qwen_vl_utils', or 'qwen_omni_utils' instead. "
"Migration guide: If you were using torchvision, 'decord' provides "
"similar functionality with better performance."
)
63 changes: 63 additions & 0 deletions src/lmms_engine/datasets/multimodal_mixin.py
@@ -25,6 +25,12 @@
except ImportError:
logger.info("qwen_vl_utils not installed. Skipping import.")

try:
from qwen_omni_utils import process_mm_info
except ImportError:
process_mm_info = None
logger.info("qwen_omni_utils not installed. Skipping import.")


class MultiModalDataLoadingMixin:
"""
@@ -128,6 +134,8 @@ def load_videos(
return self.load_video_decord(video_path, fps)
elif self.config.video_backend == "qwen_vl_utils":
return self.load_video_qwen_vl_utils(video_path, fps)
elif self.config.video_backend == "qwen_omni_utils":
return self.load_video_qwen_omni_utils(video_path, fps)
else:
raise ValueError(f"Video backend {self.config.video_backend} not supported")

@@ -215,3 +223,58 @@ def load_video_qwen_vl_utils(
raise ValueError(
f"Invalid video sampling strategy: {self.config.video_sampling_strategy}"
)

def load_video_qwen_omni_utils(
self,
video_path: str,
fps: int,
) -> Tuple[np.ndarray, float]:
"""
Load video using Qwen Omni utils with audio extraction support.

Args:
video_path: Path to video file
fps: Target frames per second

Returns:
Tuple of (video frames, sample fps)

Note:
When use_audio_in_video is True, audio is stored in self.video_extracted_audio
for later processing in the dataset.
"""
messages = [
{
"role": "user",
"content": [{"type": "video", "video": f"file://{video_path}"}],
}
]
use_audio_in_video = self.config.extra_kwargs.get("use_audio_in_video", False)
audios, _, videos = process_mm_info(
messages, use_audio_in_video=use_audio_in_video
)

if use_audio_in_video and audios and len(audios) > 0:
if not hasattr(self, "video_extracted_audio"):
self.video_extracted_audio = {}
self.video_extracted_audio[video_path] = audios[0]

if videos and len(videos) > 0:
video_frames = videos[0]
if isinstance(video_frames, torch.Tensor):
video_frames = video_frames.numpy()
elif not isinstance(video_frames, np.ndarray):
video_frames = np.array(video_frames)
if self.config.video_sampling_strategy == "frame_num":
if len(video_frames) > self.config.frame_num:
indices = np.linspace(
0, len(video_frames) - 1, self.config.frame_num, dtype=int
)
video_frames = video_frames[indices]
sample_fps = fps
else:
sample_fps = fps

return video_frames, sample_fps
else:
raise ValueError("No video frames returned from process_mm_info")
2 changes: 2 additions & 0 deletions src/lmms_engine/datasets/naive/__init__.py
@@ -1,5 +1,6 @@
from .base_dataset import BaseDataset
from .multimodal_dataset import MultiModalDataset
from .qwen_omni_dataset import QwenOmniSFTDataset
from .vision_audio_dataset import VisionAudioSFTDataset
from .vision_dataset import VisionSFTDataset

@@ -8,4 +9,5 @@
"MultiModalDataset",
"VisionSFTDataset",
"VisionAudioSFTDataset",
"QwenOmniSFTDataset",
]
136 changes: 136 additions & 0 deletions src/lmms_engine/datasets/naive/qwen_omni_dataset.py
@@ -0,0 +1,136 @@
import os
from typing import Dict

import torch

from lmms_engine.mapping_func import register_dataset
from lmms_engine.utils.train_utils import TrainUtilities

from .vision_audio_dataset import MAX_AUDIO_LENGTH, VisionAudioSFTDataset


@register_dataset("qwen_omni")
class QwenOmniSFTDataset(VisionAudioSFTDataset):
def load_from_json(self, data, data_folder=None) -> Dict[str, torch.Tensor]:
images = []
audios = []
videos = []
messages = data["messages"]
new_messages = []
kwargs = {}
for message in messages:
new_content = []
for idx, content in enumerate(message["content"]):
if content["type"] == "image_url":
images.append(
self.load_image(
content["image_url"]["url"], data_folder=data_folder
)
)
new_content.append(content)
elif content["type"] == "audio_url":
audio_url = content["audio_url"]["url"]
# Skip placeholders from video extraction; they are handled by the video processing branch
if audio_url == "from_video":
continue

loaded_audios = self.load_audio(
audio_url,
sr=self.processor.sampling_rate,
data_folder=data_folder,
)
audio_splits = []
# Split the loaded audio to 30s chunks and extend the messages content
for i in range(
0,
len(loaded_audios),
MAX_AUDIO_LENGTH * self.processor.sampling_rate,
):
audio_splits.append(
loaded_audios[
i : i + MAX_AUDIO_LENGTH * self.processor.sampling_rate
]
)
for _ in range(len(audio_splits)):
new_content.append(content)
audios.extend(audio_splits)
elif content["type"] == "video_url":
video_url = content["video_url"]["url"]
if data_folder is not None:
video_path = os.path.join(data_folder, video_url)
else:
video_path = video_url
frames, sample_fps = self.load_videos(
video_url,
data_folder=data_folder,
fps=self.config.fps,
)
videos.append(frames)
kwargs["fps"] = sample_fps

# check if audio was extracted from video
if (
hasattr(self, "video_extracted_audio")
and video_path in self.video_extracted_audio
):
extracted_audio = self.video_extracted_audio[video_path]
kwargs["use_audio_in_video"] = True

if hasattr(self.processor, "sampling_rate"):
max_audio_samples = (
MAX_AUDIO_LENGTH * self.processor.sampling_rate
)
# minimum audio length (2 seconds) to avoid pooling errors
min_audio_samples = 2 * self.processor.sampling_rate

audio_splits = []
for i in range(0, len(extracted_audio), max_audio_samples):
audio_chunk = extracted_audio[i : i + max_audio_samples]
if len(audio_chunk) >= min_audio_samples:
audio_splits.append(audio_chunk)

audios.extend(audio_splits)

# Add audio placeholders to the content for processor compatibility when audio was extracted from the video
for _ in range(len(audio_splits)):
new_content.append(
{
"type": "audio_url",
"audio_url": {"url": "from_video"},
}
)
else:
audios.append(extracted_audio)
new_content.append(
{
"type": "audio_url",
"audio_url": {"url": "from_video"},
}
)
del self.video_extracted_audio[video_path]
else:
kwargs["use_audio_in_video"] = False

new_content.append(content)
else:
new_content.append(content)
message["content"] = new_content
new_messages.append(message)
messages = new_messages

hf_messages = TrainUtilities.convert_open_to_hf(messages)
if len(images) == 0:
images = None
if len(audios) == 0:
audios = None
if len(videos) == 0:
videos = None
inputs = self.processor.process(
images=images,
hf_messages=hf_messages,
audios=audios,
videos=videos,
sampling_rate=self.processor.sampling_rate,
**kwargs,
)
return inputs