Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ dependencies = [
"torchaudio>=2.9.0",
"torchvision>=0.24.0",
"transformers>=4.55.0",
"accelerate>=0.20.0",
"Pillow>=10.0.0",
"datasets>=2.14.0",
"huggingface-hub>=0.20.0",
Expand Down
43 changes: 43 additions & 0 deletions src/vlm/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,25 @@
from ..models.llava import LLaVAModel


def _get_model_dtype(model: LLaVAModel) -> torch.dtype:
"""Get the dtype of the model parameters.

Args:
model: LLaVA model instance

Returns:
Model dtype (bfloat16, float16, or float32)
"""
# Check connector dtype first (most likely to be in training dtype)
connector_param = next(model.connector.parameters())
if connector_param.dtype in (torch.bfloat16, torch.float16, torch.float32):
return connector_param.dtype

# Fall back to language model dtype
lm_param = next(model.language_model.parameters())
return lm_param.dtype


def generate_response(
model: LLaVAModel,
image_path: Optional[str] = None,
Expand Down Expand Up @@ -36,6 +55,9 @@ def generate_response(
model.eval()
tokenizer = model.language_model.tokenizer

# Get model dtype to ensure consistency
model_dtype = _get_model_dtype(model)

# Process image if provided
pixel_values = None
if image_path:
Expand All @@ -46,6 +68,9 @@ def generate_response(
return_tensors='pt'
)
pixel_values = processed['pixel_values'].to(device)
# Convert pixel_values to model dtype to avoid dtype mismatches
if pixel_values.dtype != model_dtype and pixel_values.is_floating_point():
pixel_values = pixel_values.to(dtype=model_dtype)

# Tokenize text
text_input = f"Human: {text}\nAssistant:" if text else "Assistant:"
Expand All @@ -63,6 +88,9 @@ def generate_response(

if pixel_values is not None:
visual_embeds = model.encode_images(pixel_values)
# Ensure visual_embeds match text_embeds dtype
if visual_embeds.dtype != text_embeds.dtype:
visual_embeds = visual_embeds.to(dtype=text_embeds.dtype)
# Extend attention mask for visual tokens
visual_mask = torch.ones(
visual_embeds.size()[:-1],
Expand Down Expand Up @@ -111,6 +139,9 @@ def generate_response(

# Update for next iteration
next_embed = embed_layer(next_token_id)
# Ensure next_embed matches inputs_embeds dtype
if next_embed.dtype != inputs_embeds.dtype:
next_embed = next_embed.to(dtype=inputs_embeds.dtype)
inputs_embeds = torch.cat([inputs_embeds, next_embed], dim=1)
attention_mask = torch.cat([
attention_mask,
Expand Down Expand Up @@ -160,6 +191,9 @@ def generate_response_stream(
model.eval()
tokenizer = model.language_model.tokenizer

# Get model dtype to ensure consistency
model_dtype = _get_model_dtype(model)

# Process image if provided
pixel_values = None
if image_path:
Expand All @@ -170,6 +204,9 @@ def generate_response_stream(
return_tensors='pt'
)
pixel_values = processed['pixel_values'].to(device)
# Convert pixel_values to model dtype to avoid dtype mismatches
if pixel_values.dtype != model_dtype and pixel_values.is_floating_point():
pixel_values = pixel_values.to(dtype=model_dtype)

# Tokenize text
text_input = f"Human: {text}\nAssistant:" if text else "Assistant:"
Expand All @@ -187,6 +224,9 @@ def generate_response_stream(

if pixel_values is not None:
visual_embeds = model.encode_images(pixel_values)
# Ensure visual_embeds match text_embeds dtype
if visual_embeds.dtype != text_embeds.dtype:
visual_embeds = visual_embeds.to(dtype=text_embeds.dtype)
# Extend attention mask for visual tokens
visual_mask = torch.ones(
visual_embeds.size()[:-1],
Expand Down Expand Up @@ -247,6 +287,9 @@ def generate_response_stream(

# Update for next iteration
next_embed = embed_layer(next_token_id)
# Ensure next_embed matches inputs_embeds dtype
if next_embed.dtype != inputs_embeds.dtype:
next_embed = next_embed.to(dtype=inputs_embeds.dtype)
inputs_embeds = torch.cat([inputs_embeds, next_embed], dim=1)
attention_mask = torch.cat([
attention_mask,
Expand Down
61 changes: 54 additions & 7 deletions src/vlm/inference/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from pathlib import Path
from typing import Optional
import torch
from PIL import Image

from ..models.llava import LLaVAModel
from ..configs.model_config import LLaVAConfig
Expand All @@ -14,12 +13,12 @@ def load_model_from_checkpoint(
device: Optional[torch.device] = None,
) -> LLaVAModel:
"""Load LLaVA model from checkpoint.

Args:
checkpoint_path: Path to model checkpoint (supports ~ expansion)
config: Model configuration. If None, uses default config.
device: Device to load model on. If None, auto-detects.

Returns:
Loaded model in eval mode
"""
Expand All @@ -30,16 +29,64 @@ def load_model_from_checkpoint(
device = torch.device("mps")
else:
device = torch.device("cpu")

config = config or LLaVAConfig()
model = LLaVAModel(config)

# Expand ~ to home directory if present
expanded_path = Path(checkpoint_path).expanduser()
checkpoint = torch.load(str(expanded_path), map_location=device)
model.load_state_dict(checkpoint)
model.eval()
model.to(device)

return model

# Ensure consistent dtype across all model components
# Check what dtype the connector was saved in
# (most likely to reflect training dtype)
connector_param = next(model.connector.parameters())
target_dtype = connector_param.dtype

# Only convert if it's a mixed precision dtype (bf16 or fp16)
# This ensures all components use the same dtype as trained connector
if target_dtype in (torch.bfloat16, torch.float16):
# Convert vision encoder to match connector dtype
# Use try-except to handle any conversion issues gracefully
if hasattr(model.vision_encoder, 'model'):
try:
# Check current dtype first to avoid unnecessary conversion
vision_param = next(
model.vision_encoder.model.parameters()
)
if vision_param.dtype != target_dtype:
# Use .to() which is safe for inference
# (converts params and buffers)
# For inference, converting buffers is acceptable
model.vision_encoder.model = (
model.vision_encoder.model.to(dtype=target_dtype)
)
except Exception as e:
# If conversion fails, log warning but continue
# Inference code will handle dtype mismatches at runtime
print(
f"Warning: Could not convert vision encoder to "
f"{target_dtype}: {e}. "
"Will handle dtype conversion at inference time."
)

# Language model should already match, but ensure consistency
if hasattr(model.language_model, 'model'):
# Only convert if it's not already in the target dtype
lm_param = next(model.language_model.parameters())
if lm_param.dtype != target_dtype:
try:
model.language_model.model = (
model.language_model.model.to(dtype=target_dtype)
)
except Exception as e:
print(
f"Warning: Could not convert language model to "
f"{target_dtype}: {e}. "
"Will handle dtype conversion at inference time."
)

return model
14 changes: 6 additions & 8 deletions src/vlm/train/phase1_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import argparse
import math
import os
import sys
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
Expand Down Expand Up @@ -292,16 +291,16 @@ def train(args):
# 5. Initialize Trainer
# Validate precision argument
precision = args.precision.lower()
if precision not in ["fp16", "bf16", "fp8", "fp32"]:
if precision not in ["fp16", "bf16", "fp32"]:
if rank == 0:
print(
f"Error: Invalid precision '{precision}'. "
"Must be 'fp16', 'bf16', 'fp8', or 'fp32'."
"Must be 'fp16', 'bf16', or 'fp32'."
)
if ddp_enabled:
cleanup_ddp()
return

if rank == 0:
print(f"Using precision: {precision}")

Expand Down Expand Up @@ -436,12 +435,11 @@ def train(args):
"--precision",
type=str,
default="fp16",
choices=["fp16", "bf16", "fp8", "fp32"],
choices=["fp16", "bf16", "fp32"],
help=(
"Mixed precision mode: 'fp16' (default), 'bf16', 'fp8', or 'fp32'. "
"Mixed precision mode: 'fp16' (default), 'bf16', or 'fp32'. "
"fp16: CUDA (with gradient scaling) or MPS. "
"bf16: CUDA (with bf16 support) or MPS. "
"fp8: CUDA only, requires accelerate with Transformer Engine/MS-AMP."
"bf16: CUDA (with bf16 support) or MPS."
)
)

Expand Down
Loading
Loading