diff --git a/pyproject.toml b/pyproject.toml
index 7094766..a57eb20 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,6 @@ dependencies = [
     "torchaudio>=2.9.0",
     "torchvision>=0.24.0",
     "transformers>=4.55.0",
-    "accelerate>=0.20.0",
     "Pillow>=10.0.0",
     "datasets>=2.14.0",
     "huggingface-hub>=0.20.0",
diff --git a/src/vlm/inference/inference.py b/src/vlm/inference/inference.py
index 3e5f971..1129bdb 100644
--- a/src/vlm/inference/inference.py
+++ b/src/vlm/inference/inference.py
@@ -6,6 +6,25 @@
 from ..models.llava import LLaVAModel
 
 
+def _get_model_dtype(model: LLaVAModel) -> torch.dtype:
+    """Get the dtype of the model parameters.
+
+    Args:
+        model: LLaVA model instance
+
+    Returns:
+        Model dtype (bfloat16, float16, or float32)
+    """
+    # Check connector dtype first (most likely to be in training dtype)
+    connector_param = next(model.connector.parameters())
+    if connector_param.dtype in (torch.bfloat16, torch.float16, torch.float32):
+        return connector_param.dtype
+
+    # Fall back to language model dtype
+    lm_param = next(model.language_model.parameters())
+    return lm_param.dtype
+
+
 def generate_response(
     model: LLaVAModel,
     image_path: Optional[str] = None,
@@ -36,6 +55,9 @@ def generate_response(
     model.eval()
     tokenizer = model.language_model.tokenizer
 
+    # Get model dtype to ensure consistency
+    model_dtype = _get_model_dtype(model)
+
     # Process image if provided
     pixel_values = None
     if image_path:
@@ -46,6 +68,9 @@
             return_tensors='pt'
         )
         pixel_values = processed['pixel_values'].to(device)
+        # Convert pixel_values to model dtype to avoid dtype mismatches
+        if pixel_values.dtype != model_dtype and pixel_values.is_floating_point():
+            pixel_values = pixel_values.to(dtype=model_dtype)
 
     # Tokenize text
     text_input = f"Human: {text}\nAssistant:" if text else "Assistant:"
@@ -63,6 +88,9 @@
 
     if pixel_values is not None:
         visual_embeds = model.encode_images(pixel_values)
+        # Ensure visual_embeds match text_embeds dtype
+        if visual_embeds.dtype != text_embeds.dtype:
+            visual_embeds = visual_embeds.to(dtype=text_embeds.dtype)
         # Extend attention mask for visual tokens
         visual_mask = torch.ones(
             visual_embeds.size()[:-1],
@@ -111,6 +139,9 @@
 
             # Update for next iteration
             next_embed = embed_layer(next_token_id)
+            # Ensure next_embed matches inputs_embeds dtype
+            if next_embed.dtype != inputs_embeds.dtype:
+                next_embed = next_embed.to(dtype=inputs_embeds.dtype)
             inputs_embeds = torch.cat([inputs_embeds, next_embed], dim=1)
             attention_mask = torch.cat([
                 attention_mask,
@@ -160,6 +191,9 @@ def generate_response_stream(
     model.eval()
     tokenizer = model.language_model.tokenizer
 
+    # Get model dtype to ensure consistency
+    model_dtype = _get_model_dtype(model)
+
     # Process image if provided
     pixel_values = None
     if image_path:
@@ -170,6 +204,9 @@
             return_tensors='pt'
         )
         pixel_values = processed['pixel_values'].to(device)
+        # Convert pixel_values to model dtype to avoid dtype mismatches
+        if pixel_values.dtype != model_dtype and pixel_values.is_floating_point():
+            pixel_values = pixel_values.to(dtype=model_dtype)
 
     # Tokenize text
     text_input = f"Human: {text}\nAssistant:" if text else "Assistant:"
@@ -187,6 +224,9 @@
 
     if pixel_values is not None:
         visual_embeds = model.encode_images(pixel_values)
+        # Ensure visual_embeds match text_embeds dtype
+        if visual_embeds.dtype != text_embeds.dtype:
+            visual_embeds = visual_embeds.to(dtype=text_embeds.dtype)
         # Extend attention mask for visual tokens
         visual_mask = torch.ones(
             visual_embeds.size()[:-1],
@@ -247,6 +287,9 @@
 
             # Update for next iteration
             next_embed = embed_layer(next_token_id)
+            # Ensure next_embed matches inputs_embeds dtype
+            if next_embed.dtype != inputs_embeds.dtype:
+                next_embed = next_embed.to(dtype=inputs_embeds.dtype)
             inputs_embeds = torch.cat([inputs_embeds, next_embed], dim=1)
             attention_mask = torch.cat([
                 attention_mask,
diff --git a/src/vlm/inference/model_loader.py b/src/vlm/inference/model_loader.py
index cc949c1..6038e33 100644
--- a/src/vlm/inference/model_loader.py
+++ b/src/vlm/inference/model_loader.py
@@ -2,7 +2,6 @@
 from pathlib import Path
 from typing import Optional
 
 import torch
-from PIL import Image
 
 from ..models.llava import LLaVAModel
 from ..configs.model_config import LLaVAConfig
@@ -14,12 +13,12 @@ def load_model_from_checkpoint(
     device: Optional[torch.device] = None,
 ) -> LLaVAModel:
     """Load LLaVA model from checkpoint.
-    
+
     Args:
         checkpoint_path: Path to model checkpoint (supports ~ expansion)
        config: Model configuration. If None, uses default config.
         device: Device to load model on. If None, auto-detects.
-    
+
     Returns:
         Loaded model in eval mode
     """
@@ -30,16 +29,64 @@
             device = torch.device("mps")
         else:
             device = torch.device("cpu")
-    
+
     config = config or LLaVAConfig()
     model = LLaVAModel(config)
-    
+
     # Expand ~ to home directory if present
     expanded_path = Path(checkpoint_path).expanduser()
     checkpoint = torch.load(str(expanded_path), map_location=device)
     model.load_state_dict(checkpoint)
     model.eval()
     model.to(device)
-    
-    return model
+    # Ensure consistent dtype across all model components
+    # Check what dtype the connector was saved in
+    # (most likely to reflect training dtype)
+    connector_param = next(model.connector.parameters())
+    target_dtype = connector_param.dtype
+
+    # Only convert if it's a mixed precision dtype (bf16 or fp16)
+    # This ensures all components use the same dtype as trained connector
+    if target_dtype in (torch.bfloat16, torch.float16):
+        # Convert vision encoder to match connector dtype
+        # Use try-except to handle any conversion issues gracefully
+        if hasattr(model.vision_encoder, 'model'):
+            try:
+                # Check current dtype first to avoid unnecessary conversion
+                vision_param = next(
+                    model.vision_encoder.model.parameters()
+                )
+                if vision_param.dtype != target_dtype:
+                    # Use .to() which is safe for inference
+                    # (converts params and buffers)
+                    # For inference, converting buffers is acceptable
+                    model.vision_encoder.model = (
+                        model.vision_encoder.model.to(dtype=target_dtype)
+                    )
+            except Exception as e:
+                # If conversion fails, log warning but continue
+                # Inference code will handle dtype mismatches at runtime
+                print(
+                    f"Warning: Could not convert vision encoder to "
+                    f"{target_dtype}: {e}. "
+                    "Will handle dtype conversion at inference time."
+                )
+
+        # Language model should already match, but ensure consistency
+        if hasattr(model.language_model, 'model'):
+            # Only convert if it's not already in the target dtype
+            lm_param = next(model.language_model.parameters())
+            if lm_param.dtype != target_dtype:
+                try:
+                    model.language_model.model = (
+                        model.language_model.model.to(dtype=target_dtype)
+                    )
+                except Exception as e:
+                    print(
+                        f"Warning: Could not convert language model to "
+                        f"{target_dtype}: {e}. "
+                        "Will handle dtype conversion at inference time."
+                    )
+
+    return model
 
diff --git a/src/vlm/train/phase1_run.py b/src/vlm/train/phase1_run.py
index 1afa0cb..a9458fa 100644
--- a/src/vlm/train/phase1_run.py
+++ b/src/vlm/train/phase1_run.py
@@ -20,7 +20,6 @@
 import argparse
 import math
 import os
-import sys
 import torch
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -292,16 +291,16 @@ def train(args):
 
     # 5. Initialize Trainer
     # Validate precision argument
     precision = args.precision.lower()
-    if precision not in ["fp16", "bf16", "fp8", "fp32"]:
+    if precision not in ["fp16", "bf16", "fp32"]:
         if rank == 0:
             print(
                 f"Error: Invalid precision '{precision}'. "
-                "Must be 'fp16', 'bf16', 'fp8', or 'fp32'."
+                "Must be 'fp16', 'bf16', or 'fp32'."
             )
         if ddp_enabled:
             cleanup_ddp()
         return
-    
+
     if rank == 0:
         print(f"Using precision: {precision}")
@@ -436,12 +435,11 @@
         "--precision",
         type=str,
         default="fp16",
-        choices=["fp16", "bf16", "fp8", "fp32"],
+        choices=["fp16", "bf16", "fp32"],
         help=(
-            "Mixed precision mode: 'fp16' (default), 'bf16', 'fp8', or 'fp32'. "
+            "Mixed precision mode: 'fp16' (default), 'bf16', or 'fp32'. "
             "fp16: CUDA (with gradient scaling) or MPS. "
-            "bf16: CUDA (with bf16 support) or MPS. "
-            "fp8: CUDA only, requires accelerate with Transformer Engine/MS-AMP."
+            "bf16: CUDA (with bf16 support) or MPS."
         )
     )
 
diff --git a/src/vlm/train/phase1_trainer.py b/src/vlm/train/phase1_trainer.py
index b790c2e..487054f 100644
--- a/src/vlm/train/phase1_trainer.py
+++ b/src/vlm/train/phase1_trainer.py
@@ -7,7 +7,6 @@
 
 import torch
 import torch.distributed as dist
-from accelerate import Accelerator
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
@@ -55,8 +54,8 @@
             (e.g., learning_rate, batch_size)
         save_checkpoint_interval: Interval for saving checkpoints
             (default: 500 steps)
-        precision: Mixed precision mode: "bf16", "fp8", or "fp32"
-            (default: "bf16")
+        precision: Mixed precision mode: "fp16", "bf16", or "fp32"
+            (default: "fp16")
         gradient_accumulation_steps: Number of gradient accumulation steps
             (default: 1, no accumulation)
     """
@@ -91,46 +90,12 @@
             type(model).__name__ == 'DistributedDataParallel'
         )
 
-        # Setup mixed precision training (fp16, bf16, fp8, or fp32)
-        self.accelerator = None
+        # Setup mixed precision training (fp16, bf16, or fp32)
         self.scaler = None
         self.amp_dtype = None
         self.device_type = "cuda" if device.type == "cuda" else device.type
-
-        if precision == "fp8":
-            # Use accelerate for fp8 (requires Transformer Engine or MS-AMP)
-            if not torch.cuda.is_available():
-                if self.rank == 0:
-                    print(
-                        "⚠️ FP8 requires CUDA. Falling back to bf16."
-                    )
-                precision = "bf16"
-            else:
-                try:
-                    self.accelerator = Accelerator(mixed_precision="fp8")
-                    if self.rank == 0:
-                        print("✅ FP8 training enabled (using accelerate)")
-                    # Prepare model, optimizer, and dataloader with accelerate
-                    # Note: When using accelerate, it handles DDP internally
-                    # so don't pre-wrap with DDP in the run script
-                    self.model = self.accelerator.prepare(self.model)
-                    self.optimizer = self.accelerator.prepare(self.optimizer)
-                    self.train_dataloader = (
-                        self.accelerator.prepare(self.train_dataloader)
-                    )
-                    # Update underlying_model after accelerate preparation
-                    if hasattr(self.model, 'module'):
-                        self.underlying_model = self.model.module
-                    else:
-                        self.underlying_model = self.model
-                except (ImportError, RuntimeError) as e:
-                    if self.rank == 0:
-                        print(
-                            f"⚠️ FP8 not available ({e}). "
-                            "Falling back to bf16."
-                        )
-                    precision = "bf16"
-        elif precision == "fp16":
+
+        if precision == "fp16":
             # FP16 support: CUDA (with gradient scaling) or MPS/CPU (limited)
             if device.type == "cuda":
                 self.amp_dtype = torch.float16
@@ -199,15 +164,14 @@
         else:
             raise ValueError(
                 f"Invalid precision: {precision}. "
-                "Must be 'fp16', 'bf16', 'fp8', or 'fp32'"
+                "Must be 'fp16', 'bf16', or 'fp32'"
             )
 
-        # Get underlying model if wrapped with DDP (only if not using accelerate)
-        if self.accelerator is None:
-            if self.ddp_enabled:
-                self.underlying_model = model.module
-            else:
-                self.underlying_model = model
+        # Get underlying model if wrapped with DDP
+        if self.ddp_enabled:
+            self.underlying_model = model.module
+        else:
+            self.underlying_model = model
 
         # Phase 1: Freeze VLM/LLM, Train Connector
         # Set training stage on underlying model
@@ -315,16 +279,7 @@
                 # 2. pixel_values → CLIP encoder → visual features
                 # 3. visual features → connector → visual embeddings
                 # 4. visual embeddings concatenated with text embeddings
-                if self.accelerator is not None:
-                    # FP8 training with accelerate
-                    outputs = self.model(
-                        input_ids=input_ids,
-                        attention_mask=attention_mask,
-                        labels=labels,
-                        images=pixel_values
-                    )
-                    loss = outputs.loss
-                elif self.amp_dtype is not None:
+                if self.amp_dtype is not None:
                     # Mixed precision training with autocast (fp16 or bf16)
                     # Use device-appropriate autocast
                     if self.device_type == "cuda":
@@ -384,10 +339,7 @@
                 loss = loss / self.gradient_accumulation_steps
 
                 # Backward pass
-                if self.accelerator is not None:
-                    # FP8 training with accelerate
-                    self.accelerator.backward(loss)
-                elif self.scaler is not None:
+                if self.scaler is not None:
                     # FP16 on CUDA requires gradient scaling
                     self.scaler.scale(loss).backward()
                 else:
@@ -397,16 +349,7 @@
                 # Only update optimizer and scheduler after accumulating all steps
                 grad_norm = None
                 if accumulation_step == self.gradient_accumulation_steps:
-                    if self.accelerator is not None:
-                        # FP8 training with accelerate
-                        # Gradient clipping
-                        grad_norm = torch.nn.utils.clip_grad_norm_(
-                            self.model.parameters(),
-                            max_norm=self.max_grad_norm
-                        )
-                        # Optimizer step with accelerate
-                        self.accelerator.step(self.optimizer)
-                    elif self.scaler is not None:
+                    if self.scaler is not None:
                         # FP16 on CUDA - unscale before clipping
                         self.scaler.unscale_(self.optimizer)
                         # Gradient clipping
@@ -536,14 +479,7 @@
         # Handle final optimizer step if we have accumulated gradients
         # but haven't stepped yet
         if accumulation_step > 0:
-            if self.accelerator is not None:
-                # FP8 training with accelerate
-                grad_norm = torch.nn.utils.clip_grad_norm_(
-                    self.model.parameters(),
-                    max_norm=self.max_grad_norm
-                )
-                self.accelerator.step(self.optimizer)
-            elif self.scaler is not None:
+            if self.scaler is not None:
                 # FP16 on CUDA - unscale before clipping
                 self.scaler.unscale_(self.optimizer)
                 grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -586,7 +522,7 @@
         output_dir = Path(self.output_dir).expanduser()
         os.makedirs(output_dir, exist_ok=True)
         checkpoint_path = output_dir / filename
-        
+
         # Get state dict and convert to training precision if needed
         # (autocast doesn't change parameter dtype, so we convert on save)
         state_dict = self.underlying_model.state_dict()
@@ -597,11 +533,10 @@
                 else v
                 for k, v in state_dict.items()
             }
-        
+
         # Save state dict with correct precision
         torch.save(state_dict, str(checkpoint_path))
         print(
             f"Saved checkpoint to {checkpoint_path} "
f"(precision: {self.precision})" ) - diff --git a/src/vlm/train/phase2_run.py b/src/vlm/train/phase2_run.py index 503eb70..20777d8 100644 --- a/src/vlm/train/phase2_run.py +++ b/src/vlm/train/phase2_run.py @@ -10,10 +10,10 @@ Distributed training (automatically enabled when using torchrun): torchrun --nproc_per_node=2 src/vlm/train/phase2_run.py \ - --checkpoint ~/models/llava/checkpoint_phase1_fp16.pt \ + --checkpoint ~/models/llava/checkpoint_phase1_bf16.pt \ --data_path ~/dataset/llava-instruct-mix/data \ --max_steps 10000 --batch_size 16 --use_cosine_schedule \ - --gradient_accumulation_steps 2 --precision fp16 \ + --gradient_accumulation_steps 4 --precision bf16 \ --output_dir ~/models/llava --learning_rate 2e-5 Note: --data_path should point to a folder containing parquet files. @@ -260,12 +260,9 @@ def train(args): print("=" * 80) try: # Import validation function from scripts - from pathlib import Path scripts_path = Path(__file__).parent.parent.parent / "scripts" sys.path.insert(0, str(scripts_path)) - from inspect_phase2_data import ( - validate_masking_and_prepending - ) + from inspect_phase2_data import validate_masking_and_prepending # Determine device for validation if torch.cuda.is_available(): val_device = torch.device("cuda") @@ -273,7 +270,7 @@ def train(args): val_device = torch.device("mps") else: val_device = torch.device("cpu") - + validation_passed = validate_masking_and_prepending( dataset, model, @@ -281,17 +278,28 @@ def train(args): num_samples=args.validation_samples, device=val_device, ) - + if not validation_passed: - print("\n❌ Data validation failed. Please fix issues before training.") + print( + "\n❌ Data validation failed. " + "Please fix issues before training." + ) if ddp_enabled: cleanup_ddp() return else: - print("\n✅ Data validation passed. Proceeding with training.") + print( + "\n✅ Data validation passed. " + "Proceeding with training." + ) except ImportError as e: - print(f"⚠️ Warning: Could not import validation function: {e}") - print(" Validation skipped. Install required dependencies if needed.") + print( + f"⚠️ Warning: Could not import validation function: {e}" + ) + print( + " Validation skipped. " + "Install required dependencies if needed." + ) except Exception as e: print(f"⚠️ Warning: Validation failed with error: {e}") print(" Proceeding with training anyway.") @@ -363,16 +371,16 @@ def train(args): # 5. Initialize Trainer # Validate precision argument precision = args.precision.lower() - if precision not in ["fp16", "bf16", "fp8", "fp32"]: + if precision not in ["fp16", "bf16", "fp32"]: if rank == 0: print( f"Error: Invalid precision '{precision}'. " - "Must be 'fp16', 'bf16', 'fp8', or 'fp32'." + "Must be 'fp16', 'bf16', or 'fp32'." ) if ddp_enabled: cleanup_ddp() return - + if rank == 0: print(f"Using precision: {precision}") @@ -511,12 +519,11 @@ def train(args): "--precision", type=str, default="fp16", - choices=["fp16", "bf16", "fp8", "fp32"], + choices=["fp16", "bf16", "fp32"], help=( - "Mixed precision mode: 'fp16' (default), 'bf16', 'fp8', or 'fp32'. " + "Mixed precision mode: 'fp16' (default), 'bf16', or 'fp32'. " "fp16: CUDA (with gradient scaling) or MPS. " - "bf16: CUDA (with bf16 support) or MPS. " - "fp8: CUDA only, requires accelerate with Transformer Engine/MS-AMP." + "bf16: CUDA (with bf16 support) or MPS." 
         )
     )
 
diff --git a/src/vlm/train/phase2_trainer.py b/src/vlm/train/phase2_trainer.py
index 59802fe..db6eb9d 100644
--- a/src/vlm/train/phase2_trainer.py
+++ b/src/vlm/train/phase2_trainer.py
@@ -31,7 +31,7 @@ def __init__(
         wandb_run_name: Optional[str] = None,
         scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
         hyperparams: Optional[dict] = None,
-        save_checkpoint_interval: int = 500,
+        save_checkpoint_interval: int = 1000,
         precision: str = "fp16",
         gradient_accumulation_steps: int = 1,
     ):
@@ -54,8 +54,8 @@
             (e.g., learning_rate, batch_size)
         save_checkpoint_interval: Interval for saving checkpoints
-            (default: 500 steps)
-        precision: Mixed precision mode: "bf16", "fp8", or "fp32"
-            (default: "bf16")
+            (default: 1000 steps)
+        precision: Mixed precision mode: "fp16", "bf16", or "fp32"
+            (default: "fp16")
         gradient_accumulation_steps: Number of gradient accumulation steps
             (default: 1, no accumulation)
     """
@@ -90,47 +90,12 @@
             type(model).__name__ == 'DistributedDataParallel'
         )
 
-        # Setup mixed precision training (fp16, bf16, fp8, or fp32)
-        self.accelerator = None
+        # Setup mixed precision training (fp16, bf16, or fp32)
         self.scaler = None
         self.amp_dtype = None
         self.device_type = "cuda" if device.type == "cuda" else device.type
-
-        if precision == "fp8":
-            # Use accelerate for fp8 (requires Transformer Engine or MS-AMP)
-            if not torch.cuda.is_available():
-                if self.rank == 0:
-                    print(
-                        "⚠️ FP8 requires CUDA. Falling back to bf16."
-                    )
-                precision = "bf16"
-            else:
-                try:
-                    from accelerate import Accelerator
-                    self.accelerator = Accelerator(mixed_precision="fp8")
-                    if self.rank == 0:
-                        print("✅ FP8 training enabled (using accelerate)")
-                    # Prepare model, optimizer, and dataloader with accelerate
-                    # Note: When using accelerate, it handles DDP internally
-                    # so don't pre-wrap with DDP in the run script
-                    self.model = self.accelerator.prepare(self.model)
-                    self.optimizer = self.accelerator.prepare(self.optimizer)
-                    self.train_dataloader = (
-                        self.accelerator.prepare(self.train_dataloader)
-                    )
-                    # Update underlying_model after accelerate preparation
-                    if hasattr(self.model, 'module'):
-                        self.underlying_model = self.model.module
-                    else:
-                        self.underlying_model = self.model
-                except (ImportError, RuntimeError) as e:
-                    if self.rank == 0:
-                        print(
-                            f"⚠️ FP8 not available ({e}). "
-                            "Falling back to bf16."
-                        )
-                    precision = "bf16"
-        elif precision == "fp16":
+
+        if precision == "fp16":
             # FP16 support: CUDA (with gradient scaling) or MPS/CPU (limited)
             if device.type == "cuda":
                 self.amp_dtype = torch.float16
@@ -199,15 +164,14 @@
         else:
             raise ValueError(
                 f"Invalid precision: {precision}. "
-                "Must be 'fp16', 'bf16', 'fp8', or 'fp32'"
+                "Must be 'fp16', 'bf16', or 'fp32'"
             )
 
-        # Get underlying model if wrapped with DDP (only if not using accelerate)
-        if self.accelerator is None:
-            if self.ddp_enabled:
-                self.underlying_model = model.module
-            else:
-                self.underlying_model = model
+        # Get underlying model if wrapped with DDP
+        if self.ddp_enabled:
+            self.underlying_model = model.module
+        else:
+            self.underlying_model = model
 
         # Phase 2: Freeze Vision Encoder, Train Connector + LLM
         # Set training stage on underlying model
@@ -316,16 +280,7 @@
                 # 2. pixel_values → CLIP encoder → visual features
                 # 3. visual features → connector → visual embeddings
                 # 4. visual embeddings concatenated with text embeddings
-                if self.accelerator is not None:
-                    # FP8 training with accelerate
-                    outputs = self.model(
-                        input_ids=input_ids,
-                        attention_mask=attention_mask,
-                        labels=labels,
-                        images=pixel_values
-                    )
-                    loss = outputs.loss
-                elif self.amp_dtype is not None:
+                if self.amp_dtype is not None:
                     # Mixed precision training with autocast (fp16 or bf16)
                     # Use device-appropriate autocast
                     if self.device_type == "cuda":
@@ -385,10 +340,7 @@
                 loss = loss / self.gradient_accumulation_steps
 
                 # Backward pass
-                if self.accelerator is not None:
-                    # FP8 training with accelerate
-                    self.accelerator.backward(loss)
-                elif self.scaler is not None:
+                if self.scaler is not None:
                     # FP16 on CUDA requires gradient scaling
                     self.scaler.scale(loss).backward()
                 else:
@@ -398,16 +350,7 @@
                 # Only update optimizer and scheduler after accumulating all steps
                 grad_norm = None
                 if accumulation_step == self.gradient_accumulation_steps:
-                    if self.accelerator is not None:
-                        # FP8 training with accelerate
-                        # Gradient clipping
-                        grad_norm = torch.nn.utils.clip_grad_norm_(
-                            self.model.parameters(),
-                            max_norm=self.max_grad_norm
-                        )
-                        # Optimizer step with accelerate
-                        self.accelerator.step(self.optimizer)
-                    elif self.scaler is not None:
+                    if self.scaler is not None:
                         # FP16 on CUDA - unscale before clipping
                         self.scaler.unscale_(self.optimizer)
                         # Gradient clipping
@@ -466,7 +409,7 @@
             # Only print logs on rank 0
             # Only log grad_norm when we actually step (after accumulation)
             if self.rank == 0:
-                if not self.use_wandb and step % 100 == 0:
+                if not self.use_wandb and step % 1000 == 0:
                     grad_norm_str = (
                         f"{grad_norm:.4f}"
                         if grad_norm is not None else "accumulating"
@@ -537,14 +480,7 @@
         # Handle final optimizer step if we have accumulated gradients
         # but haven't stepped yet
         if accumulation_step > 0:
-            if self.accelerator is not None:
-                # FP8 training with accelerate
-                grad_norm = torch.nn.utils.clip_grad_norm_(
-                    self.model.parameters(),
-                    max_norm=self.max_grad_norm
-                )
-                self.accelerator.step(self.optimizer)
-            elif self.scaler is not None:
+            if self.scaler is not None:
                 # FP16 on CUDA - unscale before clipping
                 self.scaler.unscale_(self.optimizer)
                 grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -587,7 +523,7 @@
         output_dir = Path(self.output_dir).expanduser()
         os.makedirs(output_dir, exist_ok=True)
         checkpoint_path = output_dir / filename
-        
+
         # Get state dict and convert to training precision if needed
         # (autocast doesn't change parameter dtype, so we convert on save)
         state_dict = self.underlying_model.state_dict()
@@ -598,11 +534,10 @@
                 else v
                 for k, v in state_dict.items()
             }
-        
+
         # Save state dict with correct precision
         torch.save(state_dict, str(checkpoint_path))
         print(
             f"Saved checkpoint to {checkpoint_path} "
             f"(precision: {self.precision})"
         )
-
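
Reviewer note: a small smoke test makes the dtype-alignment behavior this patch introduces easy to verify end to end. The sketch below is a suggested follow-up, not part of the patch: the file name `scripts/check_dtype_consistency.py` and the checkpoint path are hypothetical, and it assumes only what the diff already shows (`model.connector`, `model.vision_encoder`, and `model.language_model` are torch modules, and `load_model_from_checkpoint` aligns them to the connector's dtype).

diff --git a/scripts/check_dtype_consistency.py b/scripts/check_dtype_consistency.py
new file mode 100644
--- /dev/null
+++ b/scripts/check_dtype_consistency.py
@@ -0,0 +1,25 @@
+"""Smoke test: all components share one dtype after checkpoint load.
+
+Hypothetical helper, not part of this patch; adjust the checkpoint path.
+"""
+from vlm.inference.model_loader import load_model_from_checkpoint
+
+
+def main() -> None:
+    # Hypothetical checkpoint path; replace with a real phase 1/2 checkpoint
+    model = load_model_from_checkpoint(
+        "~/models/llava/checkpoint_phase1_bf16.pt"
+    )
+    # One representative parameter dtype per component; after loading,
+    # the vision encoder and LLM should match the connector's dtype.
+    dtypes = {
+        "connector": next(model.connector.parameters()).dtype,
+        "vision_encoder": next(model.vision_encoder.parameters()).dtype,
+        "language_model": next(model.language_model.parameters()).dtype,
+    }
+    print(dtypes)
+    assert len(set(dtypes.values())) == 1, f"dtype mismatch: {dtypes}"
+
+
+if __name__ == "__main__":
+    main()

If the assert trips (for example, because the vision encoder conversion hit the except path in model_loader.py), generation should still work thanks to the runtime casts added in inference.py, but the loader conversion path deserves a closer look.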