14 changes: 9 additions & 5 deletions sam3/model/box_ops.py
@@ -112,7 +112,8 @@ def box_iou(boxes1, boxes2):

union = area1[..., None] + area2[..., None, :] - inter

iou = inter / union
# Add epsilon to prevent division by zero for degenerate boxes (e.g., empty masks)
iou = inter / (union + 1e-7)
return iou, union


@@ -139,7 +140,8 @@ def generalized_box_iou(boxes1, boxes2):
wh = (rb - lt).clamp(min=0) # (..., N, M, 2)
area = wh[..., 0] * wh[..., 1] # (..., N, M)

return iou - (area - union) / area
# Add epsilon to prevent division by zero for degenerate boxes (e.g., empty masks)
return iou - (area - union) / (area + 1e-7)


@torch.jit.script
@@ -164,9 +166,10 @@ def fast_diag_generalized_box_iou(boxes1, boxes2):

union = area1 + area2 - inter

iou = inter / union
# Add epsilon to prevent division by zero for degenerate boxes
iou = inter / (union + 1e-7)

return iou - (tot_area - union) / tot_area
return iou - (tot_area - union) / (tot_area + 1e-7)


@torch.jit.script
@@ -188,7 +191,8 @@ def fast_diag_box_iou(boxes1, boxes2):

union = area1 + area2 - inter

iou = inter / union
# Add epsilon to prevent division by zero for degenerate boxes
iou = inter / (union + 1e-7)

return iou

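Why the epsilon matters: a degenerate box (zero width or height, as produced by an empty mask) has zero area, so intersection and union are both zero and inter / union evaluates to 0/0 = NaN, which then propagates into the matcher and every downstream loss. A minimal sketch, using a simplified element-wise stand-in for fast_diag_box_iou rather than the repository code:

import torch

def iou_diag(boxes1, boxes2, eps=1e-7):
    # Element-wise IoU for paired (x1, y1, x2, y2) boxes, with the same epsilon guard.
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    lt = torch.max(boxes1[:, :2], boxes2[:, :2])
    rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[:, 0] * wh[:, 1]
    union = area1 + area2 - inter
    return inter / (union + eps)  # 0 / (0 + eps) = 0 instead of NaN

degenerate = torch.tensor([[0.5, 0.5, 0.5, 0.5]])  # zero-area box from an empty mask
print(iou_diag(degenerate, degenerate))  # tensor([0.]); without eps this would be 0/0 = NaN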
26 changes: 26 additions & 0 deletions sam3/train/loss/loss_fns.py
@@ -116,6 +116,8 @@ def _dice_loss(inputs, targets, num_boxes, loss_on_multimask=False, reduce=True)
numerator = 2 * (inputs * targets).sum(1)
denominator = inputs.sum(-1) + targets.sum(-1)
loss = 1 - (numerator + 1) / (denominator + 1)
# Replace NaN/Inf in loss (dice loss should be in [0, 1])
loss = torch.nan_to_num(loss, nan=1.0, posinf=1.0, neginf=0.0)
if loss_on_multimask:
return loss / num_boxes
if not reduce:
@@ -166,6 +168,9 @@ def sigmoid_focal_loss(
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss

# Replace NaN/Inf in loss (can occur with extreme logits)
loss = torch.nan_to_num(loss, nan=0.0, posinf=100.0, neginf=0.0)

if not reduce:
return loss
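The replacement values passed to torch.nan_to_num throughout this file follow the natural range of each loss: the dice loss lives in [0, 1], so NaN and +inf map to the worst case 1.0 and -inf to 0.0, while the unbounded focal loss is simply capped. A small illustration of that convention (values are made up):

import torch

bad = torch.tensor([0.3, float("nan"), float("inf"), -float("inf")])

# Bounded loss (dice, range [0, 1]): map invalid entries to the range limits.
print(torch.nan_to_num(bad, nan=1.0, posinf=1.0, neginf=0.0))   # -> 0.3, 1.0, 1.0, 0.0

# Unbounded loss (focal): drop NaN to 0 and cap overflow at a large finite value.
print(torch.nan_to_num(bad, nan=0.0, posinf=100.0, neginf=0.0))  # -> 0.3, 0.0, 100.0, 0.0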

@@ -364,6 +369,8 @@ def get_loss(self, outputs, targets, indices, num_boxes):
)

iou = box_ops.fast_diag_box_iou(src_boxes_xyxy, target_boxes_giou)
# Replace NaN/Inf in IoU (IoU should be in [0, 1])
iou = torch.nan_to_num(iou, nan=0.0, posinf=1.0, neginf=0.0)
t = prob[(indices[0], indices[1])] ** self.alpha * iou ** (1 - self.alpha)
t = torch.clamp(t, 0.01).detach()
positive_target_classes = target_classes.clone()
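The sanitized IoU then feeds the quality-style classification target t = prob**alpha * iou**(1 - alpha); a single NaN IoU would make the corresponding target NaN and corrupt the classification loss. A quick illustration with made-up values (alpha here is just an example, not the configured self.alpha):

import torch

alpha = 0.25
prob = torch.tensor([0.9, 0.8])
iou = torch.tensor([0.7, float("nan")])  # NaN from a degenerate box pair

print(prob ** alpha * iou ** (1 - alpha))  # second target is NaN and would poison the loss

iou = torch.nan_to_num(iou, nan=0.0, posinf=1.0, neginf=0.0)
t = (prob ** alpha * iou ** (1 - alpha)).clamp(0.01).detach()
print(t)  # tensor([0.7454, 0.0100]) -- NaN replaced by 0, then lifted to the 0.01 floor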
@@ -506,6 +513,10 @@ def get_loss(self, outputs, targets, indices, num_boxes):
task="binary",
)

# Replace NaN/Inf in losses
loss_bce = torch.nan_to_num(loss_bce, nan=0.0, posinf=100.0, neginf=0.0)
presence_loss = torch.nan_to_num(presence_loss, nan=0.0, posinf=100.0, neginf=0.0)

losses = {
"loss_ce": loss_bce,
"ce_f1": bce_f1,
@@ -554,13 +565,17 @@ def get_loss(self, outputs, targets, indices, num_boxes):
)

loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
# Replace NaN/Inf in loss_bbox (can occur with extreme predicted boxes)
loss_bbox = torch.nan_to_num(loss_bbox, nan=1.0, posinf=1.0, neginf=0.0)

losses = {}
losses["loss_bbox"] = loss_bbox.sum() / num_boxes

loss_giou = 1 - box_ops.fast_diag_generalized_box_iou(
src_boxes_xyxy, target_boxes_giou
)
# Replace NaN/Inf in loss_giou (GIoU loss should be in [0, 2])
loss_giou = torch.nan_to_num(loss_giou, nan=2.0, posinf=2.0, neginf=0.0)
losses["loss_giou"] = loss_giou.sum() / num_boxes
return losses

@@ -1011,6 +1026,7 @@ def __init__(
# we could still set presence_head to True so that
# losses are not propagated to masks when there is no GT mask
presence_loss: bool = True,
fallback_resolution: int = 1008,
):
super().__init__(weight_dict, False)
self.focal = focal
@@ -1019,6 +1035,7 @@ def __init__(
self.downsample = downsample
self.presence_head = presence_head
self.presence_loss = presence_loss
self.fallback_resolution = fallback_resolution

def get_loss(self, out_dict, targets):
outputs = out_dict["semantic_seg"]
@@ -1066,6 +1083,9 @@ def get_loss(self, out_dict, targets):
segments, targets["num_boxes"]
)

if sum(targets["num_boxes"]) == 0:
semantic_targets = torch.zeros((outputs.shape[0], self.fallback_resolution, self.fallback_resolution), dtype=torch.bool, device=outputs.device)

if not self.downsample:
# upsample predictions to the target size
size = semantic_targets.shape[-2:]
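The new fallback_resolution covers batches that contain no ground-truth boxes at all: when sum(targets["num_boxes"]) == 0 the normal target construction yields nothing usable, so an all-background mask of that resolution is substituted and the BCE/dice terms stay well defined. A minimal sketch of the idea, assuming an all-zero target is the right substitute; rasterize_segments is a hypothetical stand-in for the normal target-building path, not code from this file:

import torch

def build_semantic_targets(segments, num_boxes, outputs, fallback_resolution=1008):
    if sum(num_boxes) == 0:
        # No instances anywhere in the batch: all-background target per image.
        return torch.zeros(
            (outputs.shape[0], fallback_resolution, fallback_resolution),
            dtype=torch.bool,
            device=outputs.device,
        )
    return rasterize_segments(segments, num_boxes)  # hypothetical helper for the usual path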
@@ -1126,6 +1146,8 @@ def get_loss(self, out_dict, targets):
# should also track presence_acc
presence_acc = torch.tensor(0.0, device=loss.device)

# Replace NaN/Inf in presence loss
loss_presence = torch.nan_to_num(loss_presence, nan=0.0, posinf=100.0, neginf=0.0)
loss_dict["loss_semantic_presence"] = loss_presence
loss_dict["presence_acc"] = presence_acc

@@ -1139,6 +1161,10 @@ def get_loss(self, out_dict, targets):
loss = (loss * mask.float()).sum() / (nb_valid + 1e-6)
loss_dice = (loss_dice * mask.float()).sum() / (nb_valid + 1e-6)

# Replace NaN/Inf in semantic segmentation losses
loss = torch.nan_to_num(loss, nan=0.0, posinf=100.0, neginf=0.0)
loss_dice = torch.nan_to_num(loss_dice, nan=1.0, posinf=1.0, neginf=0.0)

loss_dict.update(
{
"loss_semantic_seg": loss,
7 changes: 7 additions & 0 deletions sam3/train/loss/sam3_loss.py
@@ -200,4 +200,11 @@ def forward(self, find_stages: SAM3Output, find_targets):
else:
total_losses[k] += v

# Final safety check: replace any NaN/Inf in the core loss
# This catches any NaN that slipped through individual loss guards
if isinstance(total_losses[CORE_LOSS_KEY], torch.Tensor):
total_losses[CORE_LOSS_KEY] = torch.nan_to_num(
total_losses[CORE_LOSS_KEY], nan=0.0, posinf=1e6, neginf=0.0
)

return total_losses
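The per-term guards above should already have removed NaNs, so this is a last line of defence on the single value that gets backpropagated; the isinstance check matters because the accumulated core loss can still be a plain Python number when no tensor term contributed. A minimal sketch of the pattern (the "core_loss" string is an assumed value for CORE_LOSS_KEY, used only for illustration):

import torch

CORE_LOSS_KEY = "core_loss"  # assumed value, for illustration only

def sanitize_core_loss(total_losses):
    core = total_losses[CORE_LOSS_KEY]
    if isinstance(core, torch.Tensor):  # may still be a plain float at this point
        total_losses[CORE_LOSS_KEY] = torch.nan_to_num(core, nan=0.0, posinf=1e6, neginf=0.0)
    return total_losses

losses = {CORE_LOSS_KEY: torch.tensor(float("nan"))}
print(sanitize_core_loss(losses)[CORE_LOSS_KEY])  # tensor(0.)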
9 changes: 9 additions & 0 deletions sam3/train/matcher.py
@@ -569,11 +569,15 @@ def forward(

# Compute the L1 cost between boxes
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
# Replace NaN/Inf with large values to prevent linear_sum_assignment failure
cost_bbox = torch.nan_to_num(cost_bbox, nan=1e6, posinf=1e6, neginf=1e6)

# Compute the giou cost between boxes
cost_giou = -generalized_box_iou(
box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
)
# Replace NaN/Inf (GIoU should be in [-1, 1], so cost_giou in [-1, 1])
cost_giou = torch.nan_to_num(cost_giou, nan=0.0, posinf=2.0, neginf=-2.0)

out_prob = self.norm(out_score)
if not self.focal:
@@ -596,6 +600,9 @@ def forward(
if not self.stable:
cost_class = cost_class.unsqueeze(-1).expand_as(cost_bbox)

# Replace NaN/Inf in cost_class (can occur with extreme logits)
cost_class = torch.nan_to_num(cost_class, nan=1e6, posinf=1e6, neginf=-1e6)

assert cost_class.shape == cost_bbox.shape

# Final cost matrix
@@ -604,6 +611,8 @@ def forward(
+ self.cost_class * cost_class
+ self.cost_giou * cost_giou
)
# Final safety check: replace any remaining NaN/Inf in the cost matrix
C = torch.nan_to_num(C, nan=1e9, posinf=1e9, neginf=-1e9)
# assign a very high cost (1e9) to invalid outputs and targets, so that we can
# filter them out (in `_do_matching`) from bipartite matching results
do_filtering = out_is_valid is not None or target_is_valid_padded is not None
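The reason the cost matrix has to be scrubbed is that the Hungarian solver used for matching, scipy.optimize.linear_sum_assignment, raises a ValueError when the matrix contains NaN or infinite entries, which aborts training outright. Mapping invalid entries to a very large finite cost keeps the assignment solvable while making those pairs effectively unmatchable, consistent with the existing 1e9 convention for invalid outputs and targets. A small sketch:

import torch
from scipy.optimize import linear_sum_assignment

C = torch.tensor([[0.2, float("nan")],
                  [0.5, 0.1]])

try:
    linear_sum_assignment(C.numpy())
except ValueError as err:
    print(err)  # scipy rejects cost matrices with NaN/inf entries

C = torch.nan_to_num(C, nan=1e9, posinf=1e9, neginf=-1e9)
rows, cols = linear_sum_assignment(C.numpy())
print(rows, cols)  # [0 1] [0 1] -- the NaN pair is never selected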
2 changes: 2 additions & 0 deletions sam3/train/trainer.py
@@ -355,6 +355,8 @@ def save_checkpoint(self, epoch, checkpoint_names=None):
checkpoint_paths.append(os.path.join(checkpoint_folder, f"{ckpt_name}.pt"))

state_dict = unwrap_ddp_if_wrapped(self.model).state_dict()
# Add 'detector.' prefix to match checkpoint format expected by model loading code
state_dict = {"detector." + k: v for k, v in state_dict.items()}
state_dict = exclude_params_matching_unix_pattern(
patterns=self.checkpoint_conf.skip_saving_parameters, state_dict=state_dict
)
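The checkpoint change namespaces every saved parameter under a detector. prefix so the file layout matches what the model-loading code expects. A minimal sketch of the remap, together with a hypothetical load-side counterpart that strips the prefix again (the load-side helper is an assumption, not part of this PR):

import torch

def add_detector_prefix(state_dict):
    # Save side: "backbone.layer1.weight" -> "detector.backbone.layer1.weight".
    return {"detector." + k: v for k, v in state_dict.items()}

def strip_detector_prefix(state_dict, prefix="detector."):
    # Hypothetical load side: undo the prefix before load_state_dict on the detector module.
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

sd = {"backbone.layer1.weight": torch.zeros(1)}
assert strip_detector_prefix(add_detector_prefix(sd)).keys() == sd.keys()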
1 change: 1 addition & 0 deletions tests/__init__.py
@@ -0,0 +1 @@
# Copyright (c) Meta, Inc. and its affiliates. All Rights Reserved
98 changes: 98 additions & 0 deletions tests/test_instance_segmentation_finetune.py
@@ -0,0 +1,98 @@
"""
Minimal test for instance segmentation finetuning.
This test reproduces the issue from the external project.
"""
import os
import tempfile
from pathlib import Path

# Path to the test dataset
DATASET_DIR = Path(__file__).parent / "testdata" / "dataset-instance-seg"


def test_instance_segmentation_finetune_minimal():
"""Test instance segmentation finetuning with the provided dataset."""
from sam3.train.utils.train_utils import register_omegaconf_resolvers, makedir
from hydra import initialize_config_dir, compose
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
import random

# Create temporary cache directory
with tempfile.TemporaryDirectory() as cache_dir:
print(f"Cache directory: {cache_dir}")

try:
# BPE vocab is in the assets directory
bpe_vocab = str(Path(__file__).parent.parent /
"assets" / "bpe_simple_vocab_16e6.txt.gz")
assert os.path.exists(
bpe_vocab), f"BPE vocab not found at {bpe_vocab}"

# Use default checkpoint path (should exist in the devcontainer)
checkpoint_path = "sam3_checkpoint.pt"

experiment_log_dir = os.path.join(cache_dir, "sam3_logs")

# Initialize Hydra
GlobalHydra.instance().clear()
register_omegaconf_resolvers()

initialize_config_dir(config_dir=str(
Path(__file__).parent / "testdata"), version_base="1.2")
# read num_images from train dir
num_images = len(list(DATASET_DIR.glob("train/*.jpg")))
cfg = compose(
config_name="sam3_template-seg",
overrides=[
f"paths.experiment_log_dir={experiment_log_dir}",
f"paths.checkpoint_path={checkpoint_path}",
f"paths.bpe_path={bpe_vocab}",
f"paths.dataset_path={DATASET_DIR}",
f"roboflow_train.num_images={num_images}",
"roboflow_train.supercategory=cars-In1I",
"roboflow_train.max_epochs=3",
]
)

makedir(cfg.launcher.experiment_log_dir)

# Configure for single GPU test
cfg.launcher.num_nodes = 1
cfg.launcher.gpus_per_node = 1

# Set environment variables for distributed training
main_port = random.randint(10000, 65000)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(main_port)
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"

# Instantiate trainer
trainer = instantiate(cfg.trainer, _recursive_=False)

# Run training (max_epochs is set via the overrides above)
try:
trainer.run()
print("Training completed successfully!")
except Exception as e:
# Print the full error for debugging
import traceback
print(f"Training failed with error: {e}")
print(traceback.format_exc())
raise
finally:
# Copy cache directory to testdata/last_run for debugging
import shutil
last_run_dir = Path(__file__).parent / "testdata" / "last_run"
if last_run_dir.exists():
shutil.rmtree(last_run_dir)
shutil.copytree(cache_dir, last_run_dir)
print(f"Cache directory saved to: {last_run_dir}")


if __name__ == "__main__":
# For manual testing
os.environ["HYDRA_FULL_ERROR"] = "1"
test_instance_segmentation_finetune_minimal()