4 changes: 2 additions & 2 deletions src/vlm/configs/model_config.py
@@ -39,13 +39,13 @@ class VisionEncoderConfig:
 class ConnectorConfig:
     """Configuration for the connector/projection layer."""
 
-    num_layers: int = 1
+    num_layers: int = 2
     """Number of MLP layers.
 
     1 = linear projection, 2 = MLP with hidden layer.
     """
 
-    hidden_dim: Optional[int] = None
+    hidden_dim: Optional[int] = 1024
     """Hidden dimension for MLP. Only used if num_layers > 1."""
 
     activation: str = "gelu"
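For context on what this default change does downstream: with `num_layers = 2` and `hidden_dim = 1024`, the connector becomes a small MLP rather than a single linear projection. A minimal sketch of how a config like this could map to a module; the `build_connector` helper and the `vision_dim`/`llm_dim` arguments are illustrative, not this repo's actual code:

```python
import torch.nn as nn

def build_connector(cfg, vision_dim: int, llm_dim: int) -> nn.Module:
    """Illustrative only: project vision features into the LLM embedding space."""
    if cfg.num_layers == 1:
        # Old default: plain linear projection
        return nn.Linear(vision_dim, llm_dim)
    # New default: 2-layer MLP with a 1024-dim hidden layer
    hidden = cfg.hidden_dim if cfg.hidden_dim is not None else llm_dim
    return nn.Sequential(
        nn.Linear(vision_dim, hidden),
        nn.GELU(),  # cfg.activation == "gelu"
        nn.Linear(hidden, llm_dim),
    )
```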
71 changes: 60 additions & 11 deletions src/vlm/train/phase2_run.py
@@ -181,24 +181,42 @@ def _train_impl(
 ):
     """Internal training implementation wrapped in DDP sync context."""
 
-    # Map precision to torch dtype for model parameters
-    # Model dtype should match training precision to avoid dtype mismatches
-    # Autocast will handle mixed precision conversion during forward/backward
-    if precision == "fp16":
+    # Determine model parameter dtype based on precision and device:
+    # - FP16 on CUDA: requires GradScaler, so parameters must be fp32
+    # - FP16 on MPS: no GradScaler, can use fp16 parameters
+    # - BF16: no GradScaler, can use bf16 parameters
+    # - FP32: use fp32 parameters
+    # Note: When using GradScaler, parameters must be fp32 so gradients
+    # are fp32 (GradScaler.unscale_() requires fp32 gradients)
+    if precision == "bf16" and device.type == "cuda":
+        if torch.cuda.is_bf16_supported():
+            # BF16 without GradScaler - can use bf16 parameters
+            model_dtype = torch.bfloat16
+        else:
+            # Fall back to fp32 if bf16 not supported
+            model_dtype = torch.float32
+    elif precision == "fp16" and device.type == "cuda":
+        # FP16 on CUDA requires GradScaler, so parameters must be fp32
+        model_dtype = torch.float32
+    elif precision == "fp16" and device.type == "mps":
+        # FP16 on MPS doesn't use GradScaler, can use fp16 parameters
         model_dtype = torch.float16
     elif precision == "bf16":
         model_dtype = torch.bfloat16
-    else:  # fp32
+    else:
+        # FP32 or other cases - use fp32 parameters
         model_dtype = torch.float32
 
     if rank == 0:
-        print(f"Using precision: {precision} (model dtype: {model_dtype})")
+        print(
+            f"Using precision: {precision} "
+            f"(model dtype: {model_dtype}, "
+            f"autocast will handle {precision} conversion)"
+        )
 
     # 2. Initialize Model
     if rank == 0:
         print("Initializing LLaVA model...")
     config = LLaVAConfig()
-    # Set language model dtype to match training precision
+    # Set language model dtype (fp32 for fp16/fp32, bf16 for bf16)
     config.language_model.torch_dtype = model_dtype
     model = LLaVAModel(config)
 
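The dtype table in this hunk hinges on one constraint: `GradScaler.unscale_()` rejects fp16 gradients, and gradients take the dtype of the parameters, so fp16-on-CUDA training must keep fp32 parameters and let autocast do the fp16 compute. A self-contained sketch of that path, with a toy model and data standing in for the real trainer:

```python
import torch
import torch.nn as nn

# Toy stand-ins: any fp32 model/optimizer behaves the same way.
model = nn.Linear(8, 1).cuda()                  # parameters stay fp32
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()            # needed for fp16, not for bf16

x = torch.randn(4, 8, device="cuda")
y = torch.randn(4, 1, device="cuda")

optimizer.zero_grad(set_to_none=True)
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = nn.functional.mse_loss(model(x), y)  # forward runs in fp16
scaler.scale(loss).backward()                   # grads land in fp32 (param dtype)
scaler.unscale_(optimizer)                      # would raise on fp16 grads
scaler.step(optimizer)
scaler.update()
```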
@@ -220,8 +238,39 @@
     # Load checkpoint to CPU first, then move to device
     checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
     model.load_state_dict(checkpoint, strict=False)
-    if rank == 0:
-        print("✅ Checkpoint loaded successfully")
+
+    # Convert model parameters to match model_dtype if needed
+    # Key distinction:
+    # - FP16 on CUDA: requires fp32 parameters (for GradScaler)
+    # - FP16 on MPS: can use fp16 parameters (no GradScaler)
+    # - BF16: can use bf16 parameters (no GradScaler)
+    checkpoint_dtype = None
+    for param in model.parameters():
+        if param.is_floating_point():
+            checkpoint_dtype = param.dtype
+            break
+
+    if checkpoint_dtype != model_dtype:
+        # Convert to match desired model_dtype
+        model = model.to(dtype=model_dtype)
+        if rank == 0:
+            print(
+                f"✅ Checkpoint loaded and converted from "
+                f"{checkpoint_dtype} to {model_dtype}"
+            )
+            if (precision == "fp16" and device.type == "cuda" and
+                    model_dtype == torch.float32):
+                print(
+                    " (fp32 required for GradScaler, "
+                    "autocast uses fp16 during computation)"
+                )
+    else:
+        if rank == 0:
+            print(
+                f"✅ Checkpoint loaded "
+                f"(parameters: {checkpoint_dtype}, "
+                f"autocast uses {precision})"
+            )
 
     # Wrap model with DDP if enabled
     ddp_model = model  # Default to model if DDP not enabled
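For reference, the DDP wrap that this default feeds into usually looks like the sketch below; the `use_ddp` flag and `wrap_ddp` helper are hypothetical, since the PR's actual wrapping code is outside this hunk:

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_ddp(model: torch.nn.Module, device: torch.device, use_ddp: bool):
    """Sketch: wrap the model for distributed training when enabled."""
    if not use_ddp:
        return model  # mirrors the diff's default: ddp_model = model
    # device_ids is only valid for single-device CUDA modules (assumes a
    # fully specified device like cuda:0); CPU/gloo processes omit it.
    device_ids = [device.index] if device.type == "cuda" else None
    return DDP(model, device_ids=device_ids)
```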