Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions docs/video_configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ The following parameters can be configured in the `dataset_config` section of yo

- **`video_backend`** (Optional[str], default: "qwen_vl_utils")
- Specifies the backend to use for video loading
- Available options: `"decord"`, `"qwen_vl_utils (recommended)"`
- Available options: `"decord"`, `"qwen_vl_utils"`, `"qwen_omni_utils"`
- Note: The `"torchvision"` backend has been removed. See [Migration Guide](#migration-from-torchvision-backend) below.

- **`video_sampling_strategy`** (Optional[str], default: "fps")
Expand Down Expand Up @@ -107,6 +107,7 @@ The `torchvision` video backend has been removed since it was implemented as a f
3. **Verify compatibility:**
- `decord` naive decord video loading, used in load from cloud storage
- `qwen_vl_utils` is optimized for Qwen models and provides additional features
- `qwen_omni_utils` supports audio extraction from videos for Qwen Omni variants

## Training Performance Optimization

Expand Down Expand Up @@ -151,4 +152,29 @@ This periodically clears the CUDA memory cache to prevent fragmentation during l

4. **Optimize sampling strategy:**
- Use `"fps"` for videos with consistent motion
- Use `"frame_num"` when you need exactly N frames regardless of video length
- Use `"frame_num"` when you need exactly N frames regardless of video length

5. **Audio extraction from videos:**
- Use `video_backend: "qwen_omni_utils"` with `use_audio_in_video: true` in processor config to extract audio from video files

## Audio from Video Extraction

When training Qwen Omni models, you can extract audio tracks from video files automatically.

### Configuration

```yaml
dataset_config:
dataset_type: vision_audio
video_backend: "qwen_omni_utils"
video_sampling_strategy: "fps"
fps: 1
video_max_frames: 60

processor_config:
processor_name: "Qwen/Qwen2.5-Omni-7B"
processor_type: "Qwen2_5OmniProcessor"
extra_kwargs:
use_audio_in_video: true
audio_max_length: 60
```
204 changes: 204 additions & 0 deletions examples/qwen2_5_omni.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
trainer_type: fsdp2_trainer

dataset_config:
extra_kwargs:
use_audio_in_video: true
dataset_type: qwen_omni
dataset_format: json
processor_config:
processor_name: Qwen/Qwen2.5-Omni-7B
processor_type: Qwen2_5OmniProcessor
audio_max_length: 60
video_max_pixels: 602112
video_min_pixels: 28800
image_max_pixels: 602112
image_min_pixels: 28800
dataset_path: /path/to/your/dataset.json
datasets: null
shuffle: true
eval_dataset_path: null
object_storage: none
bucket_name: null
packing: false
packing_strategy: first_fit
packing_length: 51200
filter_overlong: true
filter_overlong_workers: 8
max_length: null
video_sampling_strategy: fps
video_max_pixels: 602112
video_max_frames: 512
frame_num: 64
fps: 1
video_backend: qwen_omni_utils

trainer_args:
output_dir: ./output/qwen_omni_debug
overwrite_output_dir: false
do_train: false
do_eval: false
do_predict: false
eval_strategy: steps
prediction_loss_only: false
per_device_train_batch_size: 1
per_device_eval_batch_size: 8
per_gpu_train_batch_size: null
per_gpu_eval_batch_size: null
gradient_accumulation_steps: 1
eval_accumulation_steps: null
eval_delay: 0
torch_empty_cache_steps: 10
learning_rate: 0.00001
weight_decay: 0.0
adam_beta1: 0.9
adam_beta2: 0.999
adam_epsilon: 1.0e-08
max_grad_norm: 1.0
num_train_epochs: 1
max_steps: 5
lr_scheduler_type: constant
lr_scheduler_kwargs: {}
warmup_ratio: 0.0
warmup_steps: 0
log_level: passive
log_level_replica: warning
log_on_each_node: true
logging_dir: ./output/qwen_omni_debug/runs
logging_strategy: steps
logging_first_step: false
logging_steps: 1
logging_nan_inf_filter: true
save_strategy: steps
save_steps: 1
save_total_limit: 1
save_safetensors: true
save_on_each_node: false
save_only_model: false
restore_callback_states_from_checkpoint: false
no_cuda: false
use_cpu: false
use_mps_device: false
seed: 42
data_seed: null
jit_mode_eval: false
bf16: true
fp16: false
fp16_opt_level: O1
half_precision_backend: auto
bf16_full_eval: false
fp16_full_eval: false
tf32: null
local_rank: 0
ddp_backend: null
tpu_num_cores: null
tpu_metrics_debug: false
debug: []
dataloader_drop_last: false
eval_steps: null
dataloader_num_workers: 0
dataloader_prefetch_factor: null
past_index: -1
run_name: qwen2_5_omni
disable_tqdm: false
remove_unused_columns: true
label_names: null
load_best_model_at_end: false
metric_for_best_model: null
greater_is_better: null
ignore_data_skip: false
fsdp: []
fsdp_min_num_params: 0
fsdp_config:
transformer_layer_cls_to_wrap:
- Qwen2_5OmniDecoderLayer
reshard_after_forward: false
min_num_params: 0
xla: false
xla_fsdp_v2: false
xla_fsdp_grad_ckpt: false
fsdp_transformer_layer_cls_to_wrap: null
accelerator_config:
split_batches: false
dispatch_batches: null
even_batches: true
use_seedable_sampler: true
non_blocking: false
gradient_accumulation_kwargs: null
parallelism_config: null
deepspeed: null
label_smoothing_factor: 0.0
optim: adamw_torch_fused
optim_args: null
adafactor: false
group_by_length: false
length_column_name: length
report_to:
- wandb
project: huggingface
trackio_space_id: trackio
ddp_find_unused_parameters: null
ddp_bucket_cap_mb: null
ddp_broadcast_buffers: null
dataloader_pin_memory: true
dataloader_persistent_workers: false
skip_memory_metrics: true
use_legacy_prediction_loop: false
push_to_hub: false
resume_from_checkpoint: null
hub_model_id: null
hub_strategy: every_save
hub_token: <HUB_TOKEN>
hub_private_repo: null
hub_always_push: false
hub_revision: null
gradient_checkpointing: true
gradient_checkpointing_kwargs: null
include_inputs_for_metrics: false
include_for_metrics: []
eval_do_concat_batches: true
fp16_backend: auto
push_to_hub_model_id: null
push_to_hub_organization: null
mp_parameters: ''
auto_find_batch_size: false
full_determinism: false
torchdynamo: null
ray_scope: last
ddp_timeout: 1800
torch_compile: false
torch_compile_backend: null
torch_compile_mode: null
include_tokens_per_second: false
include_num_input_tokens_seen: 'no'
neftune_noise_alpha: null
optim_target_modules: null
batch_eval_metrics: false
eval_on_start: false
use_liger_kernel: true
liger_kernel_config: null
eval_use_gather_object: false
average_tokens_across_devices: true
use_muon: false
freeze_modules: null
use_rmpad: true
fsdp2: true
sp_ulysses_degree: 1
reduce_dtype: bfloat16
output_dtype: bfloat16
print_batch_input_steps: 5
enable_profiler: false
profiler_config:
start_step: 1
end_step: 3

model_config:
extra_kwargs: {}
load_from_pretrained_path: ngqtrung/Qwen2.5-Omni-Thinker-7B
load_from_config: null
attn_implementation: flash_attention_2
model_type: qwen2_5_omni
torch_dtype: bfloat16
overwrite_config: null
monkey_patch_kwargs: null

extra_kwargs: null
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ dependencies = [
"torchdata",
"liger-kernel>=0.6.1",
"qwen-vl-utils",
"qwen-omni-utils",
"einops",
"pandas",
"rich>=14.1.0",
Expand Down
11 changes: 9 additions & 2 deletions src/lmms_engine/datasets/collator/vision_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,19 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
if "attention_mask" in inputs.keys():
inputs.pop("attention_mask")

attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id)
attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id).long()
batched_inputs["attention_mask"] = attention_mask

# for the other keys
for key, values in inputs.items():
batched_inputs[key] = torch.concatenate(values, dim=0)
# Handle scalar/boolean values ( use_audio_in_video)
if isinstance(values[0], bool) or (
isinstance(values[0], (int, float))
and not isinstance(values[0], torch.Tensor)
):
batched_inputs[key] = values[0]
else:
batched_inputs[key] = torch.concatenate(values, dim=0)
return batched_inputs

@property
Expand Down
6 changes: 4 additions & 2 deletions src/lmms_engine/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ class DatasetConfig(Args):
video_min_pixels: Optional[int] = 3136
frame_num: Optional[int] = 64
fps: Optional[int] = 1
video_backend: Optional[Literal["decord", "qwen_vl_utils"]] = "qwen_vl_utils"
video_backend: Optional[
Literal["decord", "qwen_vl_utils", "qwen_omni_utils"]
] = "qwen_vl_utils"

@field_validator(
"video_max_pixels",
Expand All @@ -63,7 +65,7 @@ def validate_video_backend_migration(cls, v):
if v == "torchvision":
raise ValueError(
"The 'torchvision' video backend has been removed. "
"Please use 'decord' or 'qwen_vl_utils' instead. "
"Please use 'decord', 'qwen_vl_utils', or 'qwen_omni_utils' instead. "
"Migration guide: If you were using torchvision, 'decord' provides "
"similar functionality with better performance."
)
Expand Down
63 changes: 63 additions & 0 deletions src/lmms_engine/datasets/multimodal_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
except ImportError:
logger.info("qwen_vl_utils not installed. Skipping import.")

try:
from qwen_omni_utils import process_mm_info
except ImportError:
process_mm_info = None
Logging.info("qwen_omni_utils not installed. Skipping import.")


class MultiModalDataLoadingMixin:
"""
Expand Down Expand Up @@ -128,6 +134,8 @@ def load_videos(
return self.load_video_decord(video_path, fps)
elif self.config.video_backend == "qwen_vl_utils":
return self.load_video_qwen_vl_utils(video_path, fps)
elif self.config.video_backend == "qwen_omni_utils":
return self.load_video_qwen_omni_utils(video_path, fps)
else:
raise ValueError(f"Video backend {self.config.video_backend} not supported")

Expand Down Expand Up @@ -216,3 +224,58 @@ def load_video_qwen_vl_utils(
raise ValueError(
f"Invalid video sampling strategy: {self.config.video_sampling_strategy}"
)

def load_video_qwen_omni_utils(
self,
video_path: str,
fps: int,
) -> Tuple[np.ndarray, float]:
"""
Load video using Qwen Omni utils with audio extraction support.

Args:
video_path: Path to video file
fps: Target frames per second

Returns:
Tuple of (video frames, sample fps)

Note:
When use_audio_in_video is True, audio is stored in self.video_extracted_audio
for later processing in the dataset.
"""
messages = [
{
"role": "user",
"content": [{"type": "video", "video": f"file://{video_path}"}],
}
]
use_audio_in_video = self.config.extra_kwargs.get("use_audio_in_video", False)
audios, _, videos = process_mm_info(
messages, use_audio_in_video=use_audio_in_video
)

if use_audio_in_video and audios and len(audios) > 0:
if not hasattr(self, "video_extracted_audio"):
self.video_extracted_audio = {}
self.video_extracted_audio[video_path] = audios[0]

if videos and len(videos) > 0:
video_frames = videos[0]
if isinstance(video_frames, torch.Tensor):
video_frames = video_frames.numpy()
elif not isinstance(video_frames, np.ndarray):
video_frames = np.array(video_frames)
if self.config.video_sampling_strategy == "frame_num":
if len(video_frames) > self.config.frame_num:
indices = np.linspace(
0, len(video_frames) - 1, self.config.frame_num, dtype=int
)
video_frames = video_frames[indices]
sample_fps = fps
else:
sample_fps = fps

return video_frames, sample_fps
else:
raise ValueError("No video frames returned from process_mm_info")
Loading