30 changes: 28 additions & 2 deletions docs/video_configuration.md
@@ -10,7 +10,7 @@ The following parameters can be configured in the `dataset_config` section of your…

- **`video_backend`** (Optional[str], default: "qwen_vl_utils")
  - Specifies the backend to use for video loading
  - Available options: `"decord"`, `"qwen_vl_utils"`, `"qwen_omni_utils"`
  - Note: The `"torchvision"` backend has been removed. See [Migration Guide](#migration-from-torchvision-backend) below.

- **`video_sampling_strategy`** (Optional[str], default: "fps")
@@ -107,6 +107,7 @@ The `torchvision` video backend has been removed since it was implemented as a f…
3. **Verify compatibility:**
   - `decord` provides plain decord-based video decoding; it is used when loading from cloud storage
   - `qwen_vl_utils` is optimized for Qwen models and provides additional features
   - `qwen_omni_utils` supports audio extraction from videos for Qwen Omni variants
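
If you are migrating from `torchvision`, the switch is typically a one-line change in the dataset config. A minimal sketch, assuming an otherwise working config:

```yaml
dataset_config:
  # video_backend: "torchvision"   # removed; now raises a ValueError
  video_backend: "decord"          # or "qwen_vl_utils" / "qwen_omni_utils"
```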

## Training Performance Optimization

@@ -151,4 +152,29 @@ This periodically clears the CUDA memory cache to prevent fragmentation during long…
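
The setting being described is presumably `torch_empty_cache_steps`; the example config added in this PR sets it as follows:

```yaml
torch_empty_cache_steps: 10  # clear the CUDA cache every 10 training steps
```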

4. **Optimize sampling strategy:**
   - Use `"fps"` for videos with consistent motion
   - Use `"frame_num"` when you need exactly N frames regardless of video length

5. **Audio extraction from videos:**
   - Use `video_backend: "qwen_omni_utils"` with `use_audio_in_video: true` under the dataset config's `extra_kwargs` to extract audio from video files (see the configuration below)

## Audio Extraction from Videos

When training Qwen Omni models, you can extract audio tracks from video files automatically.

### Configuration

```yaml
dataset_config:
  dataset_type: vision_audio
  video_backend: "qwen_omni_utils"
  video_sampling_strategy: "fps"
  fps: 1
  video_max_frames: 60
  extra_kwargs:
    use_audio_in_video: true

  processor_config:
    processor_name: "Qwen/Qwen2.5-Omni-7B"
    processor_type: "Qwen2_5OmniProcessor"
    extra_kwargs:
      audio_max_length: 60
```
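
A dataset entry that exercises this path might look like the following. This is an illustrative sample only: the content-item shapes (`video_url`, `audio_url`, `text`) are taken from `vision_audio_dataset.py`, while the top-level layout is assumed.

```json
{
  "messages": [
    {
      "role": "user",
      "content": [
        {"type": "video_url", "video_url": {"url": "videos/clip_0001.mp4"}},
        {"type": "text", "text": "Transcribe the speech in this clip."}
      ]
    },
    {
      "role": "assistant",
      "content": [{"type": "text", "text": "..."}]
    }
  ]
}
```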
64 changes: 64 additions & 0 deletions examples/qwen2_5_omni_test.yaml
@@ -0,0 +1,64 @@
- type: trainer
  config:
    trainer_type: hf_trainer

    dataset_config:
      dataset_type: vision_audio
      dataset_format: json
      dataset_path:
      eval_dataset_path:

      processor_config:
        processor_name: "Qwen/Qwen2.5-Omni-7B"
        processor_type: "Qwen2_5OmniProcessor"
        extra_kwargs:
          audio_max_length: 60
          video_max_pixels: 602112
          video_min_pixels: 28800
          image_max_pixels: 602112
          image_min_pixels: 28800

      extra_kwargs:
        use_audio_in_video: true

      video_backend: "qwen_omni_utils"
      video_sampling_strategy: "fps"
      fps: 1
      video_max_frames: 60

    model_config:
      load_from_pretrained_path: "Qwen/Qwen2.5-Omni-7B"
      attn_implementation: "flash_attention_2"
      model_type: "qwen2_5_omni"
      torch_dtype: "bfloat16"

    per_device_train_batch_size: 1
    per_device_eval_batch_size: 1
    learning_rate: 1.0e-05
    weight_decay: 0.0
    gradient_accumulation_steps: 16
    gradient_checkpointing: true
    num_train_epochs: 1
    save_steps: 100
    save_total_limit: 3
    report_to: "wandb"
    output_dir: ""
    warmup_ratio: 0.0
    run_name: ""
    eval_strategy: "steps"
    eval_steps: 100
    logging_steps: 5
    group_by_length: false
    dataloader_num_workers: 0
    bf16: true
    lr_scheduler_type: "constant"
    torch_empty_cache_steps: 10
    use_liger_kernel: true
    use_rmpad: true

    fsdp: "full_shard"
    fsdp_config:
      fsdp_transformer_layer_cls_to_wrap: ["Qwen2_5OmniDecoderLayer", "Qwen2_5OmniAudioEncoderLayer", "Qwen2_5OmniVisionBlock"]
      fsdp_backward_prefetch: "backward_pre"
      fsdp_state_dict_type: "sharded_state_dict"
      fsdp_use_orig_params: true
1 change: 1 addition & 0 deletions pyproject.toml
@@ -45,6 +45,7 @@ dependencies = [
"flash-attn>=2.8.3",
"liger-kernel>=0.6.1",
"qwen-vl-utils",
"qwen-omni-utils",
"einops",
"pandas",
"rich>=14.1.0",
11 changes: 9 additions & 2 deletions src/lmms_engine/datasets/collator/vision_collator.py
@@ -52,12 +52,19 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
if "attention_mask" in inputs.keys():
inputs.pop("attention_mask")

attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id)
attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id).long()
batched_inputs["attention_mask"] = attention_mask

# for the other keys
for key, values in inputs.items():
batched_inputs[key] = torch.concatenate(values, dim=0)
# Handle scalar/boolean values ( use_audio_in_video)
if isinstance(values[0], bool) or (
isinstance(values[0], (int, float))
and not isinstance(values[0], torch.Tensor)
):
batched_inputs[key] = values[0]
else:
batched_inputs[key] = torch.concatenate(values, dim=0)
return batched_inputs

    @property
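As an aside, a standalone sketch of why that branch is needed: `torch.concatenate` rejects Python bools, so scalar kwargs such as `use_audio_in_video` must be passed through untouched (illustrative code, not part of this PR):

```python
import torch

def collate_value(values):
    # Scalars/bools are forwarded as-is; tensors are concatenated along dim 0.
    if isinstance(values[0], bool) or (
        isinstance(values[0], (int, float))
        and not isinstance(values[0], torch.Tensor)
    ):
        return values[0]
    return torch.concatenate(values, dim=0)

print(collate_value([True, True]))                                # True
print(collate_value([torch.ones(1, 4), torch.ones(1, 4)]).shape)  # torch.Size([2, 4])
```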
6 changes: 4 additions & 2 deletions src/lmms_engine/datasets/config.py
@@ -36,7 +36,9 @@ class DatasetConfig(Args):
    video_max_frames: Optional[int] = 768
    frame_num: Optional[int] = 64
    fps: Optional[int] = 1
    video_backend: Optional[
        Literal["decord", "qwen_vl_utils", "qwen_omni_utils"]
    ] = "qwen_vl_utils"

    @field_validator(
        "video_max_pixels",
@@ -62,7 +64,7 @@ def validate_video_backend_migration(cls, v):
if v == "torchvision":
raise ValueError(
"The 'torchvision' video backend has been removed. "
"Please use 'decord' or 'qwen_vl_utils' instead. "
"Please use 'decord', 'qwen_vl_utils', or 'qwen_omni_utils' instead. "
"Migration guide: If you were using torchvision, 'decord' provides "
"similar functionality with better performance."
)
63 changes: 63 additions & 0 deletions src/lmms_engine/datasets/multimodal_mixin.py
@@ -25,6 +25,12 @@
except ImportError:
    logger.info("qwen_vl_utils not installed. Skipping import.")

try:
    from qwen_omni_utils import process_mm_info
except ImportError:
    process_mm_info = None
    logger.info("qwen_omni_utils not installed. Skipping import.")


class MultiModalDataLoadingMixin:
"""
@@ -128,6 +134,8 @@ def load_videos(
            return self.load_video_decord(video_path, fps)
        elif self.config.video_backend == "qwen_vl_utils":
            return self.load_video_qwen_vl_utils(video_path, fps)
        elif self.config.video_backend == "qwen_omni_utils":
            return self.load_video_qwen_omni_utils(video_path, fps)
        else:
            raise ValueError(f"Video backend {self.config.video_backend} not supported")

@@ -215,3 +223,58 @@ def load_video_qwen_vl_utils(
            raise ValueError(
                f"Invalid video sampling strategy: {self.config.video_sampling_strategy}"
            )

    def load_video_qwen_omni_utils(
        self,
        video_path: str,
        fps: int,
    ) -> Tuple[np.ndarray, float]:
        """
        Load video using Qwen Omni utils with audio extraction support.

        Args:
            video_path: Path to video file
            fps: Target frames per second

        Returns:
            Tuple of (video frames, sample fps)

        Note:
            When use_audio_in_video is True, audio is stored in self.video_extracted_audio
            for later processing in the dataset.
        """
        if process_mm_info is None:
            raise ImportError(
                "qwen_omni_utils is not installed but is required for the "
                "'qwen_omni_utils' video backend."
            )
        messages = [
            {
                "role": "user",
                "content": [{"type": "video", "video": f"file://{video_path}"}],
            }
        ]
        use_audio_in_video = self.config.extra_kwargs.get("use_audio_in_video", False)
        audios, _, videos = process_mm_info(
            messages, use_audio_in_video=use_audio_in_video
        )

        if use_audio_in_video and audios:
            if not hasattr(self, "video_extracted_audio"):
                self.video_extracted_audio = {}
            self.video_extracted_audio[video_path] = audios[0]

        if not videos:
            raise ValueError("No video frames returned from process_mm_info")

        video_frames = videos[0]
        if isinstance(video_frames, torch.Tensor):
            video_frames = video_frames.numpy()
        elif not isinstance(video_frames, np.ndarray):
            video_frames = np.array(video_frames)

        # Subsample evenly when a fixed frame count is requested
        if (
            self.config.video_sampling_strategy == "frame_num"
            and len(video_frames) > self.config.frame_num
        ):
            indices = np.linspace(
                0, len(video_frames) - 1, self.config.frame_num, dtype=int
            )
            video_frames = video_frames[indices]
        sample_fps = fps

        return video_frames, sample_fps
59 changes: 57 additions & 2 deletions src/lmms_engine/datasets/naive/vision_audio_dataset.py
> **Reviewer (Collaborator):** These changes could probably be refactored into a new Qwen Omni dataset that inherits from this one and overrides the relevant methods, keeping `vision_audio_dataset` unchanged.
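
A rough shape for that suggestion, with hypothetical names (sketched against the hooks visible in this diff, not code from this PR):

```python
# Hypothetical sketch of the suggested refactor; all names are illustrative.
from typing import Dict

import torch

from lmms_engine.datasets.naive.vision_audio_dataset import (  # assumed import path
    VisionAudioDataset,
)


class QwenOmniVisionAudioDataset(VisionAudioDataset):
    """Adds the Qwen Omni audio-from-video handling; base dataset stays unchanged."""

    def load_from_json(self, data, data_folder=None) -> Dict[str, torch.Tensor]:
        # The qwen_omni_utils backend stashes extracted audio on
        # self.video_extracted_audio as a side effect of load_videos.
        batch = super().load_from_json(data, data_folder=data_folder)
        # ... the audio-chunking and placeholder logic from this diff would move here ...
        return batch
```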

@@ -1,3 +1,4 @@
import os
from typing import Dict

import torch
@@ -31,8 +32,13 @@ def load_from_json(self, data, data_folder=None) -> Dict[str, torch.Tensor]:
                )
                new_content.append(content)
            elif content["type"] == "audio_url":
                audio_url = content["audio_url"]["url"]
                # Skip placeholders from video extraction - they're handled by video processing
                if audio_url == "from_video":
                    continue

                loaded_audios = self.load_audio(
                    audio_url,
                    sr=self.processor.sampling_rate,
                    data_folder=data_folder,
                )
@@ -53,14 +59,63 @@ def load_from_json(self, data, data_folder=None) -> Dict[str, torch.Tensor]:
                audios.extend(audio_splits)
            elif content["type"] == "video_url":
                # Loading videos with fps
                video_url = content["video_url"]["url"]
                if data_folder is not None:
                    video_path = os.path.join(data_folder, video_url)
                else:
                    video_path = video_url
                frames, sample_fps = self.load_videos(
                    video_url,
                    data_folder=data_folder,
                    fps=self.config.fps,
                )
                videos.append(frames)
                # Update kwargs
                kwargs["fps"] = sample_fps

                # Check whether audio was extracted from this video (Qwen2.5-Omni)
                if (
                    hasattr(self, "video_extracted_audio")
                    and video_path in self.video_extracted_audio
                ):
                    extracted_audio = self.video_extracted_audio[video_path]
                    kwargs["use_audio_in_video"] = True

                    if hasattr(self.processor, "sampling_rate"):
                        max_audio_samples = (
                            MAX_AUDIO_LENGTH * self.processor.sampling_rate
                        )
                        # Minimum audio length (2 seconds) to avoid pooling errors
                        min_audio_samples = 2 * self.processor.sampling_rate

                        audio_splits = []
                        for i in range(0, len(extracted_audio), max_audio_samples):
                            audio_chunk = extracted_audio[i : i + max_audio_samples]
                            if len(audio_chunk) >= min_audio_samples:
                                audio_splits.append(audio_chunk)

                        audios.extend(audio_splits)

                        # Add audio placeholders to the content so the processor
                        # sees one audio entry per extracted chunk
                        for _ in range(len(audio_splits)):
                            new_content.append(
                                {
                                    "type": "audio_url",
                                    "audio_url": {"url": "from_video"},
                                }
                            )
                    else:
                        audios.append(extracted_audio)
                        new_content.append(
                            {
                                "type": "audio_url",
                                "audio_url": {"url": "from_video"},
                            }
                        )
                    del self.video_extracted_audio[video_path]
                else:
                    kwargs["use_audio_in_video"] = False

                new_content.append(content)
            else:
                new_content.append(content)
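To make the chunking arithmetic concrete, here is a worked example with assumed values (a 16 kHz sampling rate and `MAX_AUDIO_LENGTH` of 30 s; neither constant is shown in this diff):

```python
sampling_rate = 16_000          # assumed; comes from processor.sampling_rate
MAX_AUDIO_LENGTH = 30           # assumed; seconds per chunk
max_audio_samples = MAX_AUDIO_LENGTH * sampling_rate  # 480_000 samples
min_audio_samples = 2 * sampling_rate                 # 32_000 samples (2 s floor)

total = 970_000  # e.g. a ~60.6 s extracted track
chunk_lengths = [
    min(max_audio_samples, total - start)
    for start in range(0, total, max_audio_samples)
]
# [480000, 480000, 10000]: the 0.625 s tail falls below the 2 s floor
kept = [n for n in chunk_lengths if n >= min_audio_samples]
print(kept)  # [480000, 480000]
```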
2 changes: 2 additions & 0 deletions src/lmms_engine/datasets/processor/__init__.py
@@ -4,6 +4,7 @@
from .config import ProcessorConfig
from .llava_processor import LLaVADataProcessor
from .pure_text_processor import PureTextDataProcessor
from .qwen2_5_omni_processor import Qwen2_5OmniDataProcessor
from .qwen2_5_vl_processor import Qwen2_5_VLDataProcessor
from .qwen2_processor import Qwen2DataProcessor
from .qwen2_vl_processor import Qwen2VLDataProcessor
@@ -14,6 +15,7 @@
"AeroDataProcessor",
"BaseQwen2_5_DataProcessor",
"LLaVADataProcessor",
"Qwen2_5OmniDataProcessor",
"Qwen2_5_VLDataProcessor",
"Qwen2VLDataProcessor",
"WanVideoDataProcessor",