diff --git a/docs/video_configuration.md b/docs/video_configuration.md index 45d474d6..429009a3 100644 --- a/docs/video_configuration.md +++ b/docs/video_configuration.md @@ -10,7 +10,7 @@ The following parameters can be configured in the `dataset_config` section of yo - **`video_backend`** (Optional[str], default: "qwen_vl_utils") - Specifies the backend to use for video loading - - Available options: `"decord"`, `"qwen_vl_utils (recommended)"` + - Available options: `"decord"`, `"qwen_vl_utils"`, `"qwen_omni_utils"` - Note: The `"torchvision"` backend has been removed. See [Migration Guide](#migration-from-torchvision-backend) below. - **`video_sampling_strategy`** (Optional[str], default: "fps") @@ -107,6 +107,7 @@ The `torchvision` video backend has been removed since it was implemented as a f 3. **Verify compatibility:** - `decord` naive decord video loading, used in load from cloud storage - `qwen_vl_utils` is optimized for Qwen models and provides additional features + - `qwen_omni_utils` supports audio extraction from videos for Qwen Omni variants ## Training Performance Optimization @@ -151,4 +152,29 @@ This periodically clears the CUDA memory cache to prevent fragmentation during l 4. **Optimize sampling strategy:** - Use `"fps"` for videos with consistent motion - - Use `"frame_num"` when you need exactly N frames regardless of video length \ No newline at end of file + - Use `"frame_num"` when you need exactly N frames regardless of video length + +5. **Audio extraction from videos:** + - Use `video_backend: "qwen_omni_utils"` with `use_audio_in_video: true` in processor config to extract audio from video files + +## Audio from Video Extraction + +When training Qwen Omni models, you can extract audio tracks from video files automatically. 
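+Under the hood, the `qwen_omni_utils` backend delegates decoding to `process_mm_info`, which returns the sampled frames together with the audio waveform when `use_audio_in_video` is enabled; the extracted audio is then chunked and handed to the processor alongside the visual inputs. Below is a minimal sketch of the call the video loader issues (the file path is a placeholder):
+
+```python
+from qwen_omni_utils import process_mm_info
+
+messages = [
+    {
+        "role": "user",
+        "content": [{"type": "video", "video": "file:///path/to/clip.mp4"}],
+    }
+]
+# audios holds the extracted waveform(s), videos the sampled frames
+audios, images, videos = process_mm_info(messages, use_audio_in_video=True)
+```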
+ +### Configuration + +```yaml +dataset_config: + dataset_type: vision_audio + video_backend: "qwen_omni_utils" + video_sampling_strategy: "fps" + fps: 1 + video_max_frames: 60 + + processor_config: + processor_name: "Qwen/Qwen2.5-Omni-7B" + processor_type: "Qwen2_5OmniProcessor" + extra_kwargs: + use_audio_in_video: true + audio_max_length: 60 +``` \ No newline at end of file diff --git a/examples/qwen2_5_omni.yaml b/examples/qwen2_5_omni.yaml new file mode 100644 index 00000000..7d11c64a --- /dev/null +++ b/examples/qwen2_5_omni.yaml @@ -0,0 +1,204 @@ +trainer_type: fsdp2_trainer + +dataset_config: + extra_kwargs: + use_audio_in_video: true + dataset_type: qwen_omni + dataset_format: json + processor_config: + processor_name: Qwen/Qwen2.5-Omni-7B + processor_type: Qwen2_5OmniProcessor + audio_max_length: 60 + video_max_pixels: 602112 + video_min_pixels: 28800 + image_max_pixels: 602112 + image_min_pixels: 28800 + dataset_path: /path/to/your/dataset.json + datasets: null + shuffle: true + eval_dataset_path: null + object_storage: none + bucket_name: null + packing: false + packing_strategy: first_fit + packing_length: 51200 + filter_overlong: true + filter_overlong_workers: 8 + max_length: null + video_sampling_strategy: fps + video_max_pixels: 602112 + video_max_frames: 512 + frame_num: 64 + fps: 1 + video_backend: qwen_omni_utils + +trainer_args: + output_dir: ./output/qwen_omni_debug + overwrite_output_dir: false + do_train: false + do_eval: false + do_predict: false + eval_strategy: steps + prediction_loss_only: false + per_device_train_batch_size: 1 + per_device_eval_batch_size: 8 + per_gpu_train_batch_size: null + per_gpu_eval_batch_size: null + gradient_accumulation_steps: 1 + eval_accumulation_steps: null + eval_delay: 0 + torch_empty_cache_steps: 10 + learning_rate: 0.00001 + weight_decay: 0.0 + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_epsilon: 1.0e-08 + max_grad_norm: 1.0 + num_train_epochs: 1 + max_steps: 5 + lr_scheduler_type: constant + lr_scheduler_kwargs: {} + warmup_ratio: 0.0 + warmup_steps: 0 + log_level: passive + log_level_replica: warning + log_on_each_node: true + logging_dir: ./output/qwen_omni_debug/runs + logging_strategy: steps + logging_first_step: false + logging_steps: 1 + logging_nan_inf_filter: true + save_strategy: steps + save_steps: 1 + save_total_limit: 1 + save_safetensors: true + save_on_each_node: false + save_only_model: false + restore_callback_states_from_checkpoint: false + no_cuda: false + use_cpu: false + use_mps_device: false + seed: 42 + data_seed: null + jit_mode_eval: false + bf16: true + fp16: false + fp16_opt_level: O1 + half_precision_backend: auto + bf16_full_eval: false + fp16_full_eval: false + tf32: null + local_rank: 0 + ddp_backend: null + tpu_num_cores: null + tpu_metrics_debug: false + debug: [] + dataloader_drop_last: false + eval_steps: null + dataloader_num_workers: 0 + dataloader_prefetch_factor: null + past_index: -1 + run_name: qwen2_5_omni + disable_tqdm: false + remove_unused_columns: true + label_names: null + load_best_model_at_end: false + metric_for_best_model: null + greater_is_better: null + ignore_data_skip: false + fsdp: [] + fsdp_min_num_params: 0 + fsdp_config: + transformer_layer_cls_to_wrap: + - Qwen2_5OmniDecoderLayer + reshard_after_forward: false + min_num_params: 0 + xla: false + xla_fsdp_v2: false + xla_fsdp_grad_ckpt: false + fsdp_transformer_layer_cls_to_wrap: null + accelerator_config: + split_batches: false + dispatch_batches: null + even_batches: true + use_seedable_sampler: true + non_blocking: 
false + gradient_accumulation_kwargs: null + parallelism_config: null + deepspeed: null + label_smoothing_factor: 0.0 + optim: adamw_torch_fused + optim_args: null + adafactor: false + group_by_length: false + length_column_name: length + report_to: + - wandb + project: huggingface + trackio_space_id: trackio + ddp_find_unused_parameters: null + ddp_bucket_cap_mb: null + ddp_broadcast_buffers: null + dataloader_pin_memory: true + dataloader_persistent_workers: false + skip_memory_metrics: true + use_legacy_prediction_loop: false + push_to_hub: false + resume_from_checkpoint: null + hub_model_id: null + hub_strategy: every_save + hub_token: + hub_private_repo: null + hub_always_push: false + hub_revision: null + gradient_checkpointing: true + gradient_checkpointing_kwargs: null + include_inputs_for_metrics: false + include_for_metrics: [] + eval_do_concat_batches: true + fp16_backend: auto + push_to_hub_model_id: null + push_to_hub_organization: null + mp_parameters: '' + auto_find_batch_size: false + full_determinism: false + torchdynamo: null + ray_scope: last + ddp_timeout: 1800 + torch_compile: false + torch_compile_backend: null + torch_compile_mode: null + include_tokens_per_second: false + include_num_input_tokens_seen: 'no' + neftune_noise_alpha: null + optim_target_modules: null + batch_eval_metrics: false + eval_on_start: false + use_liger_kernel: true + liger_kernel_config: null + eval_use_gather_object: false + average_tokens_across_devices: true + use_muon: false + freeze_modules: null + use_rmpad: true + fsdp2: true + sp_ulysses_degree: 1 + reduce_dtype: bfloat16 + output_dtype: bfloat16 + print_batch_input_steps: 5 + enable_profiler: false + profiler_config: + start_step: 1 + end_step: 3 + +model_config: + extra_kwargs: {} + load_from_pretrained_path: ngqtrung/Qwen2.5-Omni-Thinker-7B + load_from_config: null + attn_implementation: flash_attention_2 + model_type: qwen2_5_omni + torch_dtype: bfloat16 + overwrite_config: null + monkey_patch_kwargs: null + +extra_kwargs: null diff --git a/pyproject.toml b/pyproject.toml index e04b176f..ca5e14dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "torchdata", "liger-kernel>=0.6.1", "qwen-vl-utils", + "qwen-omni-utils", "einops", "pandas", "rich>=14.1.0", diff --git a/src/lmms_engine/datasets/collator/vision_collator.py b/src/lmms_engine/datasets/collator/vision_collator.py index c0fbde25..0ed75a76 100644 --- a/src/lmms_engine/datasets/collator/vision_collator.py +++ b/src/lmms_engine/datasets/collator/vision_collator.py @@ -52,12 +52,19 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: if "attention_mask" in inputs.keys(): inputs.pop("attention_mask") - attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id) + attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id).long() batched_inputs["attention_mask"] = attention_mask # for the other keys for key, values in inputs.items(): - batched_inputs[key] = torch.concatenate(values, dim=0) + # Handle scalar/boolean values ( use_audio_in_video) + if isinstance(values[0], bool) or ( + isinstance(values[0], (int, float)) + and not isinstance(values[0], torch.Tensor) + ): + batched_inputs[key] = values[0] + else: + batched_inputs[key] = torch.concatenate(values, dim=0) return batched_inputs @property diff --git a/src/lmms_engine/datasets/config.py b/src/lmms_engine/datasets/config.py index 3fa63e73..d9bb59a9 100644 --- a/src/lmms_engine/datasets/config.py +++ b/src/lmms_engine/datasets/config.py @@ -37,7 
+37,9 @@ class DatasetConfig(Args): video_min_pixels: Optional[int] = 3136 frame_num: Optional[int] = 64 fps: Optional[int] = 1 - video_backend: Optional[Literal["decord", "qwen_vl_utils"]] = "qwen_vl_utils" + video_backend: Optional[ + Literal["decord", "qwen_vl_utils", "qwen_omni_utils"] + ] = "qwen_vl_utils" @field_validator( "video_max_pixels", @@ -63,7 +65,7 @@ def validate_video_backend_migration(cls, v): if v == "torchvision": raise ValueError( "The 'torchvision' video backend has been removed. " - "Please use 'decord' or 'qwen_vl_utils' instead. " + "Please use 'decord', 'qwen_vl_utils', or 'qwen_omni_utils' instead. " "Migration guide: If you were using torchvision, 'decord' provides " "similar functionality with better performance." ) diff --git a/src/lmms_engine/datasets/multimodal_mixin.py b/src/lmms_engine/datasets/multimodal_mixin.py index 08446001..14a476f0 100644 --- a/src/lmms_engine/datasets/multimodal_mixin.py +++ b/src/lmms_engine/datasets/multimodal_mixin.py @@ -25,6 +25,12 @@ except ImportError: logger.info("qwen_vl_utils not installed. Skipping import.") +try: + from qwen_omni_utils import process_mm_info +except ImportError: + process_mm_info = None + logger.info("qwen_omni_utils not installed. Skipping import.") + class MultiModalDataLoadingMixin: """ @@ -128,6 +134,8 @@ def load_videos( return self.load_video_decord(video_path, fps) elif self.config.video_backend == "qwen_vl_utils": return self.load_video_qwen_vl_utils(video_path, fps) + elif self.config.video_backend == "qwen_omni_utils": + return self.load_video_qwen_omni_utils(video_path, fps) else: raise ValueError(f"Video backend {self.config.video_backend} not supported") @@ -216,3 +224,58 @@ def load_video_qwen_vl_utils( raise ValueError( f"Invalid video sampling strategy: {self.config.video_sampling_strategy}" ) + + def load_video_qwen_omni_utils( + self, + video_path: str, + fps: int, + ) -> Tuple[np.ndarray, float]: + """ + Load video using Qwen Omni utils with audio extraction support. + + Args: + video_path: Path to video file + fps: Target frames per second + + Returns: + Tuple of (video frames, sample fps) + + Note: + When use_audio_in_video is True, audio is stored in self.video_extracted_audio + for later processing in the dataset. 
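+            The entry is keyed by the video path and is consumed (and removed) by
+            QwenOmniSFTDataset when it assembles the audio inputs for the processor.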
+ """ + messages = [ + { + "role": "user", + "content": [{"type": "video", "video": f"file://{video_path}"}], + } + ] + use_audio_in_video = self.config.extra_kwargs.get("use_audio_in_video", False) + audios, _, videos = process_mm_info( + messages, use_audio_in_video=use_audio_in_video + ) + + if use_audio_in_video and audios and len(audios) > 0: + if not hasattr(self, "video_extracted_audio"): + self.video_extracted_audio = {} + self.video_extracted_audio[video_path] = audios[0] + + if videos and len(videos) > 0: + video_frames = videos[0] + if isinstance(video_frames, torch.Tensor): + video_frames = video_frames.numpy() + elif not isinstance(video_frames, np.ndarray): + video_frames = np.array(video_frames) + if self.config.video_sampling_strategy == "frame_num": + if len(video_frames) > self.config.frame_num: + indices = np.linspace( + 0, len(video_frames) - 1, self.config.frame_num, dtype=int + ) + video_frames = video_frames[indices] + sample_fps = fps + else: + sample_fps = fps + + return video_frames, sample_fps + else: + raise ValueError("No video frames returned from process_mm_info") diff --git a/src/lmms_engine/datasets/naive/__init__.py b/src/lmms_engine/datasets/naive/__init__.py index daa4aa6e..0033842d 100644 --- a/src/lmms_engine/datasets/naive/__init__.py +++ b/src/lmms_engine/datasets/naive/__init__.py @@ -1,5 +1,6 @@ from .base_dataset import BaseDataset from .multimodal_dataset import MultiModalDataset +from .qwen_omni_dataset import QwenOmniSFTDataset from .rae_dataset import RaeDataset from .sit_dataset import SitDataset from .vision_audio_dataset import VisionAudioSFTDataset @@ -10,6 +11,7 @@ "MultiModalDataset", "VisionSFTDataset", "VisionAudioSFTDataset", + "QwenOmniSFTDataset", "RaeDataset", "SitDataset", ] diff --git a/src/lmms_engine/datasets/naive/qwen_omni_dataset.py b/src/lmms_engine/datasets/naive/qwen_omni_dataset.py new file mode 100644 index 00000000..d0dad0e9 --- /dev/null +++ b/src/lmms_engine/datasets/naive/qwen_omni_dataset.py @@ -0,0 +1,136 @@ +import os +from typing import Dict + +import torch + +from lmms_engine.mapping_func import register_dataset +from lmms_engine.utils.train_utils import TrainUtilities + +from .vision_audio_dataset import MAX_AUDIO_LENGTH, VisionAudioSFTDataset + + +@register_dataset("qwen_omni") +class QwenOmniSFTDataset(VisionAudioSFTDataset): + def load_from_json(self, data, data_folder=None) -> Dict[str, torch.Tensor]: + images = [] + audios = [] + videos = [] + messages = data["messages"] + new_messages = [] + kwargs = {} + for message in messages: + new_content = [] + for idx, content in enumerate(message["content"]): + if content["type"] == "image_url": + images.append( + self.load_image( + content["image_url"]["url"], data_folder=data_folder + ) + ) + new_content.append(content) + elif content["type"] == "audio_url": + audio_url = content["audio_url"]["url"] + # Skip placeholders from video extraction, they handled by video processing + if audio_url == "from_video": + continue + + loaded_audios = self.load_audio( + audio_url, + sr=self.processor.sampling_rate, + data_folder=data_folder, + ) + audio_splits = [] + # Split the loaded audio to 30s chunks and extend the messages content + for i in range( + 0, + len(loaded_audios), + MAX_AUDIO_LENGTH * self.processor.sampling_rate, + ): + audio_splits.append( + loaded_audios[ + i : i + MAX_AUDIO_LENGTH * self.processor.sampling_rate + ] + ) + for _ in range(len(audio_splits)): + new_content.append(content) + audios.extend(audio_splits) + elif content["type"] == "video_url": + 
video_url = content["video_url"]["url"] + if data_folder is not None: + video_path = os.path.join(data_folder, video_url) + else: + video_path = video_url + frames, sample_fps = self.load_videos( + video_url, + data_folder=data_folder, + fps=self.config.fps, + ) + videos.append(frames) + kwargs["fps"] = sample_fps + + # check if audio was extracted from video + if ( + hasattr(self, "video_extracted_audio") + and video_path in self.video_extracted_audio + ): + extracted_audio = self.video_extracted_audio[video_path] + kwargs["use_audio_in_video"] = True + + if hasattr(self.processor, "sampling_rate"): + max_audio_samples = ( + MAX_AUDIO_LENGTH * self.processor.sampling_rate + ) + # minimum audio length (2 seconds) to avoid pooling errors + min_audio_samples = 2 * self.processor.sampling_rate + + audio_splits = [] + for i in range(0, len(extracted_audio), max_audio_samples): + audio_chunk = extracted_audio[i : i + max_audio_samples] + if len(audio_chunk) >= min_audio_samples: + audio_splits.append(audio_chunk) + + audios.extend(audio_splits) + + # audio placeholders to content if audio was extracted for processor compatibility + for _ in range(len(audio_splits)): + new_content.append( + { + "type": "audio_url", + "audio_url": {"url": "from_video"}, + } + ) + else: + audios.append(extracted_audio) + new_content.append( + { + "type": "audio_url", + "audio_url": {"url": "from_video"}, + } + ) + del self.video_extracted_audio[video_path] + else: + kwargs["use_audio_in_video"] = False + + new_content.append(content) + else: + new_content.append(content) + message["content"] = new_content + new_messages.append(message) + messages = new_messages + + hf_messages = TrainUtilities.convert_open_to_hf(messages) + if len(images) == 0: + images = None + if len(audios) == 0: + audios = None + if len(videos) == 0: + videos = None + inputs = self.processor.process( + images=images, + hf_messages=hf_messages, + audios=audios, + videos=videos, + sampling_rate=self.processor.sampling_rate, + **kwargs, + ) + return inputs diff --git a/src/lmms_engine/datasets/processor/__init__.py b/src/lmms_engine/datasets/processor/__init__.py index 08101d43..84f20368 100644 --- a/src/lmms_engine/datasets/processor/__init__.py +++ b/src/lmms_engine/datasets/processor/__init__.py @@ -4,6 +4,7 @@ from .config import ProcessorConfig from .llava_processor import LLaVADataProcessor from .pure_text_processor import PureTextDataProcessor +from .qwen2_5_omni_processor import Qwen2_5OmniDataProcessor from .qwen2_5_vl_processor import Qwen2_5_VLDataProcessor from .qwen2_processor import Qwen2DataProcessor from .qwen2_vl_processor import Qwen2VLDataProcessor @@ -17,6 +18,7 @@ "AeroDataProcessor", "BaseQwen2_5_DataProcessor", "LLaVADataProcessor", + "Qwen2_5OmniDataProcessor", "Qwen2_5_VLDataProcessor", "Qwen2VLDataProcessor", "WanVideoDataProcessor", diff --git a/src/lmms_engine/datasets/processor/qwen2_5_omni_processor.py b/src/lmms_engine/datasets/processor/qwen2_5_omni_processor.py new file mode 100644 index 00000000..79a383e3 --- /dev/null +++ b/src/lmms_engine/datasets/processor/qwen2_5_omni_processor.py @@ -0,0 +1,200 @@ +from typing import List, Optional + +import numpy as np +import torch +from PIL.Image import Image +from transformers import Qwen2_5OmniProcessor +from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import ( + Qwen2_5OmniProcessorKwargs, +) + +from lmms_engine.mapping_func import register_processor + +from .base_qwen2_5_processor import BaseQwen2_5_DataProcessor + + 
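+# Thin wrapper around Qwen2_5OmniProcessor: images and videos are run through the HF
+# image/video processors, audio through the feature_extractor, and the per-modality
+# token counts are computed in process() before labels are built via
+# get_qwen_template_labels().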
+@register_processor("Qwen2_5OmniProcessor") +class Qwen2_5OmniDataProcessor(BaseQwen2_5_DataProcessor): + def _build_processor(self): + model_path = getattr(self.config, "processor_path", self.config.processor_name) + processor = Qwen2_5OmniProcessor.from_pretrained( + model_path, trust_remote_code=True, local_files_only=False + ) + + # Set image processor parameters + image_max_pixels = self.config.extra_kwargs.get("image_max_pixels", None) + image_min_pixels = self.config.extra_kwargs.get("image_min_pixels", None) + if image_max_pixels: + processor.image_processor.max_pixels = image_max_pixels + if image_min_pixels: + processor.image_processor.min_pixels = image_min_pixels + + # Set video processor parameters + video_max_pixels = self.config.extra_kwargs.get("video_max_pixels", None) + video_min_pixels = self.config.extra_kwargs.get("video_min_pixels", None) + if video_max_pixels: + processor.video_processor.max_pixels = video_max_pixels + if video_min_pixels: + processor.video_processor.min_pixels = video_min_pixels + + # Set audio processor parameters + audio_max_length = self.config.extra_kwargs.get("audio_max_length", None) + if audio_max_length and hasattr(processor, "audio_processor"): + processor.audio_processor.max_length = audio_max_length + + return processor + + def build(self): + self.processor = self._build_processor() + + @property + def audio_processor(self): + # For Qwen2.5Omni, audio processing is done via feature_extractor + # Create a wrapper to make it compatible with parent's expectations + return self.processor.feature_extractor + + @property + def audio_token_id(self): + # Return the audio token ID if processor has one + if hasattr(self.processor, "audio_token_id"): + return self.processor.audio_token_id + # Fallback: try to get from tokenizer + if hasattr(self.tokenizer, "audio_token_id"): + return self.tokenizer.audio_token_id + # Try to convert the audio token string to ID + if hasattr(self.processor, "audio_token") and self.processor.audio_token: + return self.tokenizer.convert_tokens_to_ids(self.processor.audio_token) + return None + + @property + def tokenizer(self): + return self.processor.tokenizer + + @property + def sampling_rate(self): + # Qwen2.5Omni uses feature_extractor instead of audio_processor + return self.processor.feature_extractor.sampling_rate + + def process( + self, + images: List[Image], + hf_messages, + audios: Optional[List[np.ndarray]] = None, + sampling_rate: Optional[int] = None, + videos=None, + system_message: str = "You are a helpful assistant", + add_system_prompt=True, + add_generation_prompt=False, + **kwargs, + ): + if hasattr(self.processor, "_merge_kwargs"): + output_kwargs = self.processor._merge_kwargs( + Qwen2_5OmniProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + else: + output_kwargs = kwargs + + image_inputs = {} + videos_inputs = {} + audio_inputs = {} + + if images is not None: + new_images = [] + for image in images: + height = image.size[0] + width = image.size[1] + if width < 28 and height < 28: + image = image.resize((28, 28)) + elif height < 28: + image = image.resize((28, width)) + elif width < 28: + image = image.resize((height, 28)) + new_images.append(image) + images = new_images + image_inputs = self.processor.image_processor( + images, return_tensors="pt", **output_kwargs["images_kwargs"] + ) + image_inputs["image_sizes"] = image_inputs.pop("image_grid_thw") + merge_size = self.processor.image_processor.merge_size + num_image_tokens = [ + (image_size[-2] * 
image_size[-1]).item() // (merge_size**2) + for image_size in image_inputs["image_sizes"] + ] + else: + num_image_tokens = None + + if videos is not None: + videos_inputs = self.processor.video_processor( + videos=videos, + **output_kwargs["videos_kwargs"], + return_tensors="pt", + ) + video_grid_thw = videos_inputs["video_grid_thw"] + + fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) + if isinstance(fps, (int, float)): + second_per_grid_ts = [ + self.processor.video_processor.temporal_patch_size / fps + ] * len(video_grid_thw) + elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw): + second_per_grid_ts = [ + self.processor.video_processor.temporal_patch_size / tmp + for tmp in fps + ] + else: + raise ValueError( + f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number." + ) + videos_inputs.update( + {"video_second_per_grid": torch.tensor(second_per_grid_ts)} + ) + merge_length = self.processor.video_processor.merge_size**2 + num_video_tokens = [ + (video_grid_thw[index].prod() // merge_length) + for index in range(len(video_grid_thw)) + ] + else: + num_video_tokens = None + + if audios is not None: + audio_inputs = self.audio_processor( + audios, + sampling_rate=sampling_rate, + return_attention_mask=True, + padding="max_length", + return_tensors="pt", + **kwargs, + ) + audio_inputs["feature_attention_mask"] = audio_inputs.pop("attention_mask") + audio_inputs["audio_feature_lengths"] = ( + audio_inputs["feature_attention_mask"].sum(-1) - 1 + ) // 2 + 1 + num_audio_tokens = (audio_inputs["audio_feature_lengths"] - 2) // 2 + 1 + else: + num_audio_tokens = None + + inputs = self.get_qwen_template_labels( + hf_messages, + num_image_tokens, + num_audio_tokens, + num_video_tokens, + system_message=system_message, + add_system_prompt=add_system_prompt, + add_generation_prompt=add_generation_prompt, + ) + if images is not None: + inputs["pixel_values"] = image_inputs["pixel_values"] + inputs["image_grid_thw"] = image_inputs["image_sizes"] + if audios is not None: + inputs["input_features"] = audio_inputs["input_features"] + inputs["feature_attention_mask"] = audio_inputs["feature_attention_mask"] + inputs["audio_feature_lengths"] = audio_inputs["audio_feature_lengths"] + if videos is not None: + for key, value in videos_inputs.items(): + inputs[key] = value + if "use_audio_in_video" in kwargs: + inputs["use_audio_in_video"] = kwargs["use_audio_in_video"] + + return inputs diff --git a/src/lmms_engine/mapping_func.py b/src/lmms_engine/mapping_func.py index 1f95b216..895894c7 100644 --- a/src/lmms_engine/mapping_func.py +++ b/src/lmms_engine/mapping_func.py @@ -61,13 +61,14 @@ def register_model( "causal_lm", "masked_lm", "image_text_to_text", "general" ] = "causal_lm", ): - AutoConfig.register(model_type, model_config) + AutoConfig.register(model_type, model_config, exist_ok=True) AUTO_REGISTER_MODEL_MAPPING[model_general_type].register(model_config, model_class) def create_model_from_pretrained(load_from_pretrained_path): # Handle both config object and model name/path config = AutoConfig.from_pretrained(load_from_pretrained_path) + if type(config) in AutoModelForCausalLM._model_mapping.keys(): model_class = AutoModelForCausalLM elif type(config) in AutoModelForImageTextToText._model_mapping.keys(): diff --git a/src/lmms_engine/models/__init__.py b/src/lmms_engine/models/__init__.py index 29db94c6..7eba1668 100644 --- a/src/lmms_engine/models/__init__.py +++ 
b/src/lmms_engine/models/__init__.py @@ -4,6 +4,11 @@ from .llava_onevision import apply_liger_kernel_to_llava_onevision from .monkey_patch import MONKEY_PATCHER from .qwen2 import apply_liger_kernel_to_qwen2 +from .qwen2_5_omni import ( + Qwen2_5OmniThinkerConfig, + Qwen2_5OmniThinkerForConditionalGeneration, + apply_liger_kernel_to_qwen2_5_omni, +) from .qwen2_5_vl import apply_liger_kernel_to_qwen2_5_vl from .qwen2_audio import apply_liger_kernel_to_qwen2_audio from .qwen3_dllm import Qwen3DLLMConfig, Qwen3DLLMForMaskedLM @@ -25,6 +30,9 @@ "AeroProcessor", "apply_liger_kernel_to_llava_onevision", "apply_liger_kernel_to_qwen2", + "Qwen2_5OmniThinkerConfig", + "Qwen2_5OmniThinkerForConditionalGeneration", + "apply_liger_kernel_to_qwen2_5_omni", "apply_liger_kernel_to_qwen2_5_vl", "apply_liger_kernel_to_qwen2_audio", "apply_liger_kernel_to_qwen3_vl", diff --git a/src/lmms_engine/models/aero/modeling_aero.py b/src/lmms_engine/models/aero/modeling_aero.py index 67476ef4..dba02626 100644 --- a/src/lmms_engine/models/aero/modeling_aero.py +++ b/src/lmms_engine/models/aero/modeling_aero.py @@ -44,7 +44,9 @@ logger = logging.get_logger(__name__) -AutoConfig.register("qwen2_5_omni_audio_encoder", Qwen2_5OmniAudioEncoderConfig) +AutoConfig.register( + "qwen2_5_omni_audio_encoder", Qwen2_5OmniAudioEncoderConfig, exist_ok=True +) AutoModel.register(Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniAudioEncoder) diff --git a/src/lmms_engine/models/qwen2_5_omni/__init__.py b/src/lmms_engine/models/qwen2_5_omni/__init__.py new file mode 100644 index 00000000..fa1bf05a --- /dev/null +++ b/src/lmms_engine/models/qwen2_5_omni/__init__.py @@ -0,0 +1,23 @@ +from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniThinkerConfig, +) +from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniThinkerForConditionalGeneration, +) + +from lmms_engine.mapping_func import register_model + +from .monkey_patch import apply_liger_kernel_to_qwen2_5_omni + +register_model( + "qwen2_5_omni_thinker", + Qwen2_5OmniThinkerConfig, + Qwen2_5OmniThinkerForConditionalGeneration, + model_general_type="causal_lm", +) + +__all__ = [ + "apply_liger_kernel_to_qwen2_5_omni", + "Qwen2_5OmniThinkerConfig", + "Qwen2_5OmniThinkerForConditionalGeneration", +] diff --git a/src/lmms_engine/models/qwen2_5_omni/monkey_patch.py b/src/lmms_engine/models/qwen2_5_omni/monkey_patch.py new file mode 100644 index 00000000..912b8c0f --- /dev/null +++ b/src/lmms_engine/models/qwen2_5_omni/monkey_patch.py @@ -0,0 +1,162 @@ +import inspect +from functools import partial, wraps +from typing import Callable + +from packaging import version + +try: + from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss + from liger_kernel.transformers.functional import liger_cross_entropy + from liger_kernel.transformers.geglu import LigerGEGLUMLP + from liger_kernel.transformers.layer_norm import LigerLayerNorm + from liger_kernel.transformers.model.qwen2 import ( + lce_forward_deprecated as qwen2_lce_forward_deprecated, + ) + from liger_kernel.transformers.monkey_patch import ( + _patch_layer_norm_module, + _patch_rms_norm_module, + _patch_swiglu_module, + ) + from liger_kernel.transformers.rms_norm import LigerRMSNorm + from liger_kernel.transformers.rope import liger_rotary_pos_emb + from liger_kernel.transformers.swiglu import LigerSwiGLUMLP +except: + print( + "liger kernel not installed, please install it with `pip install liger-kernel`" + ) + +import transformers +from transformers import 
PreTrainedModel +from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniAudioEncoder, + Qwen2_5OmniThinkerForConditionalGeneration, + Qwen2_5OmniThinkerTextModel, + Qwen2_5OmniVisionEncoder, +) + +from lmms_engine.parallel.sequence_parallel.ulysses import ( + get_ulysses_sequence_parallel_world_size, + patch_vlm_for_ulysses_input_slicing, +) + +transformer_version = version.parse(transformers.__version__) +SUPPORTED_TRANSFORMER_VERSION = "4.46.1" +TRANSFORMER_DEPRECATION_WARNING = "Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/34191" + +from lmms_engine.models.monkey_patch import MONKEY_PATCHER +from lmms_engine.utils.logging_utils import Logging + + +@MONKEY_PATCHER.register("qwen2_5_omni", "liger") +@MONKEY_PATCHER.register("qwen2_5_omni_thinker", "liger") +def apply_liger_kernel_to_qwen2_5_omni( + rope: bool = False, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + layer_norm: bool = True, + swiglu: bool = True, + model: PreTrainedModel = None, + use_rmpad: bool = True, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace Qwen2.5-Omni models. + Args: + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been + loaded. Default is None. + use_rmpad (bool): Whether to use remove padding optimization. Default is False. + """ + assert not ( + cross_entropy and fused_linear_cross_entropy + ), "cross_entropy and fused_linear_cross_entropy cannot both be True." 
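+
+    # Rebinding module attributes (Qwen2RMSNorm, Qwen2MLP, CrossEntropyLoss) only affects
+    # models constructed after this call; assigning the forward functions on the classes
+    # applies to existing instances as well. The `model is not None` branch at the end
+    # additionally patches the norm/MLP modules of an already-loaded model in place.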
+ + from transformers.models.qwen2_5_omni import modeling_qwen2_5_omni + + from .qwen2_5_omni_liger import lce_forward as qwen2_5_omni_lce_forward + + def wrap_forward(func): + @wraps(func) + def wrapper(*args, **kwargs): + kwargs.setdefault("use_rmpad", use_rmpad) + return func(*args, **kwargs) + + return wrapper + + qwen2_5_omni_lce_forward = wrap_forward(qwen2_5_omni_lce_forward) + if rope: + Logging.warning("RoPE optimization not supported for Qwen2.5-Omni, skipping") + if rms_norm: + modeling_qwen2_5_omni.Qwen2RMSNorm = LigerRMSNorm + if cross_entropy: + modeling_qwen2_5_omni.CrossEntropyLoss = LigerCrossEntropyLoss + if fused_linear_cross_entropy: + modeling_qwen2_5_omni.Qwen2_5OmniThinkerForConditionalGeneration.forward = ( + qwen2_5_omni_lce_forward + ) + if swiglu: + modeling_qwen2_5_omni.Qwen2MLP = LigerSwiGLUMLP + if use_rmpad: + from .qwen2_5_omni_ops import attn_forward as qwen2_5_omni_attn_forward + from .qwen2_5_omni_ops import ( + decoder_layer_forward as qwen2_5_omni_decoder_layer_forward, + ) + from .qwen2_5_omni_ops import ( + text_model_forward as qwen2_5_omni_text_model_forward, + ) + + modeling_qwen2_5_omni.Qwen2_5OmniThinkerTextModel.forward = ( + qwen2_5_omni_text_model_forward + ) + modeling_qwen2_5_omni.Qwen2_5OmniDecoderLayer.forward = ( + qwen2_5_omni_decoder_layer_forward + ) + modeling_qwen2_5_omni.Qwen2_5OmniAttention.forward = qwen2_5_omni_attn_forward + + if get_ulysses_sequence_parallel_world_size() > 1: + patch_vlm_for_ulysses_input_slicing( + modeling_qwen2_5_omni.Qwen2_5OmniThinkerTextModel + ) + + if model is not None: + if isinstance(model, Qwen2_5OmniThinkerForConditionalGeneration): + text_model: Qwen2_5OmniThinkerTextModel = model.model + vision_model: Qwen2_5OmniVisionEncoder = model.visual + audio_model: Qwen2_5OmniAudioEncoder = model.audio_tower + elif isinstance(model, Qwen2_5OmniThinkerTextModel): + text_model: Qwen2_5OmniThinkerTextModel = model + vision_model = None + audio_model = None + else: + raise TypeError( + f"Unsupported Qwen2.5-Omni model type. `model` must be " + f"`Qwen2_5OmniThinkerForConditionalGeneration` or `Qwen2_5OmniThinkerTextModel`. " + f"Got: {type(model)}. 
" + f"If you have the full model, extract the thinker using scripts/extract_qwen_omni_thinker.py" + ) + + if vision_model is not None and rms_norm: + for vision_block in vision_model.blocks: + _patch_rms_norm_module(vision_block.norm1) + _patch_rms_norm_module(vision_block.norm2) + if audio_model is not None and layer_norm: + if hasattr(audio_model, "layers"): + for audio_layer in audio_model.layers: + _patch_layer_norm_module(audio_layer.self_attn_layer_norm) + _patch_layer_norm_module(audio_layer.final_layer_norm) + if text_model is not None: + if rms_norm: + _patch_rms_norm_module(text_model.norm) + for decoder_layer in text_model.layers: + if swiglu: + _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP) + if rms_norm: + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) diff --git a/src/lmms_engine/models/qwen2_5_omni/qwen2_5_omni_liger.py b/src/lmms_engine/models/qwen2_5_omni/qwen2_5_omni_liger.py new file mode 100644 index 00000000..8f5e1f3e --- /dev/null +++ b/src/lmms_engine/models/qwen2_5_omni/qwen2_5_omni_liger.py @@ -0,0 +1,317 @@ +from typing import List, Optional, Tuple, Union + +import torch +from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniThinkerCausalLMOutputWithPast, + Qwen2_5OmniThinkerForConditionalGeneration, +) +from transformers.utils import is_flash_attn_2_available + +from lmms_engine.parallel.sequence_parallel.ulysses import ( + calculate_seq_len_per_rank, + get_ulysses_sequence_parallel_world_size, + slice_input_tensor, + ulysses_pad, +) +from lmms_engine.utils import Logging + +from ..sequence_packing_utils import _unpad_input + +if is_flash_attn_2_available(): + try: + from einops import rearrange + from flash_attn.bert_padding import index_first_axis + except: + raise ModuleNotFoundError( + "flash_attn is not available. Please install it via `pip install flash_attn`." 
+ ) + +try: + from liger_kernel.transformers.fused_linear_cross_entropy import ( + LigerFusedLinearCrossEntropyLoss, + ) +except: + print("Liger Kernel is not installed, pip install liger-kernel to use this patch") + + +def lce_forward( + self: Qwen2_5OmniThinkerForConditionalGeneration, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + feature_attention_mask: Optional[torch.Tensor] = None, + audio_feature_lengths: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + video_second_per_grid: Optional[torch.Tensor] = None, + use_audio_in_video: Optional[bool] = None, + use_rmpad: Optional[bool] = False, + **kwargs, +) -> Union[Tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + tokens_count = attention_mask.sum().item() + n_image_tokens = ( + (input_ids == self.config.image_token_id).sum().item() + if hasattr(self.config, "image_token_id") + else 0 + ) + n_video_tokens = ( + (input_ids == self.config.video_token_id).sum().item() + if hasattr(self.config, "video_token_id") + else 0 + ) + n_audio_tokens = ( + (input_ids == self.config.audio_token_id).sum().item() + if hasattr(self.config, "audio_token_id") + else 0 + ) + visual_tokens = n_image_tokens + n_video_tokens + + cu_seq_lens = None + indices = None + original_input_ids = None + if use_rmpad and attention_mask is not None: + # input_ids is 2D [batch, seq_len] + original_input_ids = input_ids + # unpad input_ids: 2D [batch, seq_len] -> 1D [total_non_pad_tokens] + input_ids, indices, cu_seq_lens, _ = _unpad_input( + input_ids, attention_mask=attention_mask + ) + if attention_mask is not None and position_ids is None: + if ( + cache_position is None + or (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + ): + batch_size, seq_length = original_input_ids.shape + delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1) + # get_rope_index expects RAW audio feature lengths before any downsampling. + # Processor provides audio_feature_lengths after first downsampling. 
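+                # (the processor derives audio_feature_lengths as
+                #  (feature_attention_mask.sum(-1) - 1) // 2 + 1, so this inverts that step)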
+ # Reconstruct raw length: raw = (audio_feature_lengths - 1) * 2 + 1 + if audio_feature_lengths is not None: + audio_raw_lengths = (audio_feature_lengths - 1) * 2 + 1 + else: + audio_raw_lengths = None + + position_ids, rope_deltas = self.get_rope_index( + original_input_ids, + image_grid_thw, + video_grid_thw, + attention_mask, + use_audio_in_video, + audio_raw_lengths, + video_second_per_grid, + ) + + rope_deltas = rope_deltas - delta0 + self.rope_deltas = rope_deltas + else: + batch_size, seq_length = original_input_ids.shape + delta = ( + cache_position[0] + self.rope_deltas + if cache_position is not None + else 0 + ) + position_ids = torch.arange( + seq_length, device=original_input_ids.device + ) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + position_ids = ( + index_first_axis( + rearrange(position_ids, "c b s ... -> (b s) c ..."), indices + ) + .transpose(0, 1) + .unsqueeze(1) + ) + + if get_ulysses_sequence_parallel_world_size() > 1: + sp_size = get_ulysses_sequence_parallel_world_size() + input_ids, position_ids, pad_size = ulysses_pad( + input_ids.unsqueeze(0), + position_ids, + sp_size=sp_size, + ) + input_ids = input_ids.squeeze(0) + actual_tokens = input_ids.shape[0] + # update the actual seg_len if pad is used + if cu_seq_lens is not None and len(cu_seq_lens) > 0: + cu_seq_lens = torch.tensor( + [0] + [actual_tokens] * (len(cu_seq_lens) - 1), + dtype=cu_seq_lens.dtype, + device=cu_seq_lens.device, + ) + if inputs_embeds is None and input_ids is not None: + inputs_embeds = self.get_input_embeddings()(input_ids) + if input_features is not None: + audio_features = self.get_audio_features( + input_features, + feature_attention_mask=feature_attention_mask, + audio_feature_lengths=audio_feature_lengths, + ) + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + n_audio_tokens_check = (input_ids == self.config.audio_token_id).sum().item() + n_audio_features = audio_features.shape[0] + if n_audio_tokens_check != n_audio_features: + raise ValueError( + f"Audio features and audio tokens do not match: " + f"tokens: {n_audio_tokens_check}, features {n_audio_features}. " + f"This indicates a mismatch between the audio encoder output and placeholder tokens." 
+ ) + audio_mask = ( + (input_ids == self.config.audio_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) + + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + n_image_tokens_check = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens_check != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens_check}, features {n_image_features}" + ) + + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + + n_video_tokens_check = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens_check != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens_check}, features {n_video_features}" + ) + + video_mask = ( + (input_ids == self.config.video_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + outputs = self.model( + input_ids=None, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + rope_deltas=rope_deltas, + use_audio_in_video=use_audio_in_video, + video_second_per_grid=video_second_per_grid, + cu_seq_lens=cu_seq_lens, + indices=indices, + ) + + seq_lens = outputs.get("seq_lens", None) + word_idx = outputs.get("word_idx", None) + hidden_states = outputs[0] + loss = None + logits = None + labels_unpad = labels.view(-1)[word_idx.long()] + if get_ulysses_sequence_parallel_world_size() > 1: + seq_lens = ( + calculate_seq_len_per_rank(seq_lens.tolist()) + if seq_lens is not None + else None + ) + labels_unpad = slice_input_tensor(labels_unpad, dim=0, padding=True) + labels = labels_unpad + if labels is not None: + if use_rmpad and seq_lens is not None: + shift_hidden_states = [] + shift_labels = [] + for i in range(len(seq_lens) - 1): + cur_hidden_states = hidden_states[seq_lens[i] : seq_lens[i + 1], :] + cur_labels = labels[seq_lens[i] : seq_lens[i + 1]] + cur_shift_hidden_states = cur_hidden_states[:-1, :].contiguous() + cur_shift_labels = cur_labels[1:].contiguous() + shift_hidden_states.append(cur_shift_hidden_states) + shift_labels.append(cur_shift_labels) + shift_hidden_states = torch.cat(shift_hidden_states, dim=0) + shift_labels = torch.cat(shift_labels, dim=0) + else: + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + hidden_size = ( + self.config.text_config.hidden_size + if hasattr(self.config, "text_config") + else self.config.hidden_size + ) + shift_hidden_states = shift_hidden_states.view(-1, hidden_size) + shift_labels = shift_labels.view(-1) + + reduction = "sum" 
if "num_items_in_batch" in kwargs else "mean" + lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction) + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + + if reduction == "sum": + loss /= kwargs["num_items_in_batch"] + + else: + logits = self.lm_head(hidden_states) + + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5OmniThinkerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) diff --git a/src/lmms_engine/models/qwen2_5_omni/qwen2_5_omni_ops.py b/src/lmms_engine/models/qwen2_5_omni/qwen2_5_omni_ops.py new file mode 100644 index 00000000..abc29664 --- /dev/null +++ b/src/lmms_engine/models/qwen2_5_omni/qwen2_5_omni_ops.py @@ -0,0 +1,351 @@ +import inspect +import warnings +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniAttention, + Qwen2_5OmniAudioEncoder, + Qwen2_5OmniAudioEncoderLayer, + Qwen2_5OmniDecoderLayer, +) +from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniThinkerCausalLMOutputWithPast as HFQwen2_5OmniModelOutputWithPast, +) +from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniThinkerForConditionalGeneration, + Qwen2_5OmniThinkerTextModel, + Qwen2_5OmniVisionEncoder, + apply_multimodal_rotary_pos_emb, + rotate_half, +) +from transformers.utils import is_flash_attn_2_available + +from lmms_engine.parallel.sequence_parallel.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_group, + get_ulysses_sequence_parallel_rank, + get_ulysses_sequence_parallel_world_size, + repeat_kv, + ulysses_pad, +) +from lmms_engine.utils import Logging + +from ..sequence_packing_utils import ( + BaseModelOutputWithPastAndRmpad, + _get_unpad_data, + _unpad_input, +) + +if is_flash_attn_2_available(): + try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import ( + index_first_axis, + pad_input, + rearrange, + unpad_input, + ) + + _flash_supports_window_size = "window_size" in list( + inspect.signature(flash_attn_func).parameters + ) + except: + raise ModuleNotFoundError( + "flash_attn is not available. Please install it via `pip install flash_attn`." 
+ ) + + +@dataclass +class Qwen2_5OmniModelOutputWithPast(HFQwen2_5OmniModelOutputWithPast): + seq_lens: Optional[torch.IntTensor] = None + word_idx: Optional[torch.IntTensor] = None + + +def text_model_forward( + self: Qwen2_5OmniThinkerTextModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + cu_seq_lens: Optional[torch.IntTensor] = None, + indices: Optional[torch.IntTensor] = None, + **kwargs, +) -> Union[Tuple, BaseModelOutputWithPastAndRmpad]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + Logging.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + if cu_seq_lens is not None and indices is not None: + seq_len_for_cache = inputs_embeds.shape[0] # 1D case, total unpadded tokens + else: + seq_len_for_cache = inputs_embeds.shape[ + 1 + ] # 2D case, sequence length dimension + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + seq_len_for_cache, + device=inputs_embeds.device, + ) + + # the hard coded `3` is for temporal, height and width. 
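+    # The three channels are the temporal/height/width indices used by TMRoPE; when no
+    # multimodal position ids are supplied, all three channels fall back to the same
+    # 1-D text positions.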
+ if position_ids is None: + if cu_seq_lens is not None and indices is not None: + # if use rmpad, position ids is [3, 1, total_non_pad_tokens] + # but lce_forward already provides position_ids + position_ids = cache_position.view(1, 1, -1).expand(3, 1, -1) + else: + position_ids = cache_position.view(1, 1, -1).expand( + 3, inputs_embeds.shape[0], -1 + ) + elif position_ids.dim() == 2: + # if position_ids is provided but only 2D [batch, seq_len], expand to 3D [3, batch, seq_len] + # by adding the TMRoPE dimension at the front + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = torch.utils.checkpoint.checkpoint( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cu_seq_lens, + indices, + position_embeddings, + use_reentrant=False, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cu_seq_lens=cu_seq_lens, + indices=indices, + cache_position=cache_position, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, past_key_values, all_hidden_states, all_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndRmpad( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_attentions, + seq_lens=cu_seq_lens, + word_idx=indices, + ) + + +def decoder_layer_forward( + self: Qwen2_5OmniDecoderLayer, + hidden_states: torch.Tensor, # should be 2D with rmpad + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cu_seq_lens: Optional[torch.IntTensor] = None, + indices: Optional[torch.IntTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None, + **kwargs, +) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cu_seq_lens=cu_seq_lens, + indices=indices, + position_embeddings=position_embeddings, + ) + + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs 
+= (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +def attn_forward( + self: Qwen2_5OmniAttention, + hidden_states: torch.Tensor, # should be 2D with rmpad + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cu_seq_lens: Optional[torch.IntTensor] = None, + indices: Optional[torch.IntTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None, + **kwargs, +): + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + bsz = hidden_states.shape[0] + if cu_seq_lens is not None: + q_len = (cu_seq_lens[1:] - cu_seq_lens[:-1]).max().item() + else: + q_len = ( + hidden_states.shape[0] + if hidden_states.ndim == 2 + else hidden_states.shape[1] + ) + kv_seq_len = q_len + query_states = self.q_proj(hidden_states).view(-1, self.num_heads, self.head_dim) + key_states = self.k_proj(hidden_states).view( + -1, self.num_key_value_heads, self.head_dim + ) + value_states = self.v_proj(hidden_states).view( + -1, self.num_key_value_heads, self.head_dim + ) + ########## AlltoAll for Ulysses ########## + if ulysses_sp_size > 1: + assert ( + position_ids is not None + ), "position_ids is required for Ulysses sequence parallelism" + repeats = max(ulysses_sp_size // key_states.size(1), 1) + key_states = repeat_kv(key_states, repeats) + value_states = repeat_kv(value_states, repeats) + # Testing + # before all to all Q: torch.Size([22541, 28, 128]), K: torch.Size([22541, 4, 128]), V: torch.Size([22541, 4, 128]) + # after all to all Q: torch.Size([45082, 14, 128]), K: torch.Size([45082, 2, 128]), V: torch.Size([45082, 2, 128]) + query_states = gather_seq_scatter_heads(query_states, seq_dim=0, head_dim=1) + key_states = gather_seq_scatter_heads(key_states, seq_dim=0, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=0, head_dim=1) + + query_states = query_states.unsqueeze(0).transpose(1, 2) + key_states = key_states.unsqueeze(0).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + self.rope_scaling["mrope_section"], + ) + + max_seqlen = ( + torch.diff(cu_seq_lens).max().item() if cu_seq_lens is not None else None + ) + + query_states = query_states.transpose(1, 2).squeeze(0) + key_states = key_states.transpose(1, 2).squeeze(0) + + window_size = (-1, -1) + + attn_output = flash_attn_varlen_func( + q=query_states, + k=key_states, + v=value_states, + cu_seqlens_q=cu_seq_lens, + cu_seqlens_k=cu_seq_lens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=window_size, + softmax_scale=self.head_dim**-0.5, + dropout_p=0.0, + ) + + if ulysses_sp_size > 1: + # [45082, 14, 128] -> [22541, 28, 128] + attn_output = gather_heads_scatter_seq(attn_output, seq_dim=0, head_dim=1) + + attn_output = attn_output.reshape(-1, self.hidden_size).contiguous() + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value diff --git a/src/lmms_engine/models/utils.py b/src/lmms_engine/models/utils.py index 3677383a..7e1f18f6 100644 --- a/src/lmms_engine/models/utils.py +++ b/src/lmms_engine/models/utils.py @@ -26,6 +26,8 @@ "qwen2", "qwen2_vl", "qwen2_5_vl", + "qwen2_5_omni", + "qwen2_5_omni_thinker", "qwen3", "qwen3_dllm", "qwen3_moe", @@ -59,6 +61,8 @@ def __init__(self, 
config: PretrainedConfig): "qwen2_moe": self._estimate_qwen2_moe_flops, "qwen2_vl": self._estimate_qwen2_flops, "qwen2_5_vl": self._estimate_qwen2_flops, + "qwen2_5_omni": self._estimate_qwen2_flops, + "qwen2_5_omni_thinker": self._estimate_qwen2_flops, "qwen3": self._estimate_qwen2_flops, "qwen3_dllm": self._estimate_qwen2_flops, "qwen3_moe": self._estimate_qwen2_moe_flops, @@ -71,6 +75,8 @@ def __init__(self, config: PretrainedConfig): } if config.model_type in ["llava_onevision", "qwen3_vl"]: self.config = config.text_config + elif config.model_type in ("qwen2_5_omni", "qwen2_5_omni_thinker"): + self.config = config.text_config self.config.model_type = config.model_type elif config.model_type == "bagel": self.config = config.llm_config @@ -81,17 +87,18 @@ def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time): return 0 def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time): - hidden_size = self.config.hidden_size - vocab_size = self.config.vocab_size - num_hidden_layers = self.config.num_hidden_layers - num_key_value_heads = self.config.num_key_value_heads - num_attention_heads = self.config.num_attention_heads - intermediate_size = self.config.intermediate_size + config = self.config + hidden_size = config.hidden_size + vocab_size = config.vocab_size + num_hidden_layers = config.num_hidden_layers + num_key_value_heads = config.num_key_value_heads + num_attention_heads = config.num_attention_heads + intermediate_size = config.intermediate_size head_dim = getattr( - self.config, + config, "head_dim", - self.config.hidden_size // self.config.num_attention_heads, + config.hidden_size // config.num_attention_heads, ) q_size = num_attention_heads * head_dim k_size = num_key_value_heads * head_dim