From 77319cfc6119184813b601e4b595f52bf2450e35 Mon Sep 17 00:00:00 2001 From: Dojo2024 Date: Tue, 28 Oct 2025 23:32:11 +0530 Subject: [PATCH 1/7] feat: Add initial SimCLR implementation --- main_pretrain.py | 110 +++++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- src/simclr/data.py | 24 ++++++++++ src/simclr/loss.py | 45 ++++++++++++++++++ src/simclr/models.py | 84 +++++++++++++++++++++++++++++++++ 5 files changed, 264 insertions(+), 1 deletion(-) create mode 100644 main_pretrain.py create mode 100644 src/simclr/data.py create mode 100644 src/simclr/loss.py create mode 100644 src/simclr/models.py diff --git a/main_pretrain.py b/main_pretrain.py new file mode 100644 index 0000000..87e8d15 --- /dev/null +++ b/main_pretrain.py @@ -0,0 +1,110 @@ +# src/simclr/main_pretrain.py + +import torch +import torch.optim as optim +from torch.utils.data import DataLoader +from pathlib import Path +from tqdm import tqdm + +# --- Import components using the new modular structure --- +import src.data.flat_data as flat_data +import src.simclr.data as simclr_data +import src.flat_mae.models_mae as models_mae +import src.simclr.models as simclr_models +import src.simclr.loss as simclr_loss + +def main(): + """ + Main function to run a simplified SimCLR pre-training loop. + """ + print("--- Starting SimCLR Pre-training Verification Script ---") + + # --- 1. Configuration --- + # File Paths + # IMPORTANT: Paths are now relative to the project root, not the script location. + data_folder_path = Path("nsd-train-task-clips-16t") # Assuming this folder is in the project root + + # Model Hyperparameters + backbone_embed_dim = 384 # For ViT-Small + projection_hidden_dim = 512 + projection_output_dim = 128 # As used in the SimCLR paper + + # Training Hyperparameters + batch_size = 4 + learning_rate = 1e-4 + temperature = 0.5 + num_epochs = 3 # Run for a few epochs to see it work + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # --- 2. Data Pipeline Setup --- + print("\n--- Setting up Data Pipeline ---") + base_transform = flat_data.make_flat_transform( + img_size=(224, 560), + normalize='global', + random_crop=True, + crop_kwargs={'scale': (0.8, 1.0), 'ratio': (2.4, 2.6)} + ) + transform = simclr_data.SimCLRTransform(base_transform) + dataset = flat_data.FlatClipsDataset(root=data_folder_path, transform=transform) + data_loader = DataLoader( + dataset, + batch_size=batch_size, + collate_fn=simclr_data.simclr_collate, + shuffle=True, + num_workers=0 + ) + print("Data pipeline setup complete.") + + # --- 3. Model and Optimizer Setup --- + print("\n--- Setting up Model and Optimizer ---") + # Base Encoder (Backbone) + backbone = models_mae.mae_vit_small(img_size=(224, 560), in_chans=1) + + # Projection Head + projection_head = simclr_models.ProjectionHead( + input_dim=backbone_embed_dim, + hidden_dim=projection_hidden_dim, + output_dim=projection_output_dim + ) + + # Full SimCLR Model + model = simclr_models.SimCLRModel(backbone, projection_head).to(device) + + # Optimizer + optimizer = optim.Adam(model.parameters(), lr=learning_rate) + + # Loss Function + criterion = simclr_loss.NTXentLoss(temperature=temperature).to(device) + + print("Model, optimizer, and loss function setup complete.") + + # --- 4. The Training Loop --- + print(f"\n--- Starting Training for {num_epochs} Epochs ---") + + for epoch in range(num_epochs): + loop = tqdm(data_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False) + total_loss = 0.0 + + for batch_view_1, batch_view_2 in loop: + view_1_images = batch_view_1['image'].to(device) + view_2_images = batch_view_2['image'].to(device) + + optimizer.zero_grad() + z1, z2 = model(view_1_images, view_2_images) + loss = criterion(z1, z2) + loss.backward() + optimizer.step() + + total_loss += loss.item() + loop.set_postfix(loss=loss.item()) + + avg_loss = total_loss / len(data_loader) + print(f"Epoch [{epoch+1}/{num_epochs}] completed. Average Loss: {avg_loss:.4f}") + + print("\n--- Training Verification Complete ---") + print("Refactoring successful. The script ran correctly from its new location.") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index ffcccfc..fe472b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "numpy", "omegaconf", "pre-commit", - "pydantic<2.7", # silence pydantic v2 warning + "pydantic", # silence pydantic v2 warning "pytest", "pycortex", "scikit-learn", diff --git a/src/simclr/data.py b/src/simclr/data.py new file mode 100644 index 0000000..d84860a --- /dev/null +++ b/src/simclr/data.py @@ -0,0 +1,24 @@ +import torch +from torch.utils.data import default_collate + +class SimCLRTransform: + + def __init__(self, base_transform): + + self.base_transform = base_transform + + def __call__(self, raw_sample): + + view_1 = self.base_transform(raw_sample) + view_2 = self.base_transform(raw_sample) + return view_1, view_2 + +def simclr_collate(batch): + + views_1 = [item[0] for item in batch] + views_2 = [item[1] for item in batch] + + collated_view_1 = default_collate(views_1) + collated_view_2 = default_collate(views_2) + + return collated_view_1, collated_view_2 diff --git a/src/simclr/loss.py b/src/simclr/loss.py new file mode 100644 index 0000000..ff0d81c --- /dev/null +++ b/src/simclr/loss.py @@ -0,0 +1,45 @@ +# src/flat_mae/simclr_loss.py + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class NTXentLoss(nn.Module): + + def __init__(self, temperature: float = 0.5): + + super().__init__() + self.temperature = temperature + self.criterion = nn.CrossEntropyLoss(reduction="sum") + self.similarity_f = nn.CosineSimilarity(dim=2) + + def forward(self, z1: torch.Tensor, z2: torch.Tensor): + + batch_size = z1.shape[0] + + representations = torch.cat([z1, z2], dim=0) + + similarity_matrix = self.similarity_f(representations.unsqueeze(1), representations.unsqueeze(0)) + + labels = torch.cat([ + torch.arange(batch_size) + batch_size, + torch.arange(batch_size) + ]).to(similarity_matrix.device) + + + mask = torch.eye(batch_size * 2, dtype=torch.bool).to(similarity_matrix.device) + + similarity_matrix = similarity_matrix[~mask].view(batch_size * 2, -1) + + labels_adjusted = labels.clone() + for i in range(batch_size, batch_size * 2): + if labels[i] > i: + labels_adjusted[i] -= 1 + + logits = similarity_matrix / self.temperature + + loss = self.criterion(logits, labels_adjusted) + + loss = loss / (2 * batch_size) + + return loss \ No newline at end of file diff --git a/src/simclr/models.py b/src/simclr/models.py new file mode 100644 index 0000000..83cec34 --- /dev/null +++ b/src/simclr/models.py @@ -0,0 +1,84 @@ +# src/flat_mae/simclr_models.py + +import torch +import torch.nn as nn + +class ProjectionHead(nn.Module): + """ + The Projection Head (g(·)) for the SimCLR framework. + As described in the paper, this is a small MLP that maps the representation (h) + to the latent space where the contrastive loss is applied. + """ + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int): + """ + Args: + input_dim (int): The feature dimension of the output from the base encoder (h). + hidden_dim (int): The dimension of the hidden layer in the MLP. + output_dim (int): The final output dimension of the projection (z). + """ + super().__init__() + + # The MLP consists of a linear layer, a non-linearity (ReLU), and another linear layer. + self.mlp = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, output_dim) + ) + + def forward(self, x): + """ + Passes the representation vector through the MLP. + Args: + x (torch.Tensor): The representation vector (h) from the base encoder. + Returns: + torch.Tensor: The projected vector (z). + """ + return self.mlp(x) + + +class SimCLRModel(nn.Module): + """ + The complete SimCLR model, combining the base encoder and the projection head. + """ + def __init__(self, backbone: nn.Module, projection_head: nn.Module): + """ + Args: + backbone (nn.Module): The base encoder network (f(·)), e.g., a ViT. + It is expected to have a `forward_embedding` method. + projection_head (nn.Module): The MLP projection head (g(·)). + """ + super().__init__() + self.backbone = backbone + self.projection_head = projection_head + + def forward(self, view_1: torch.Tensor, view_2: torch.Tensor): + """ + Performs the forward pass for both augmented views as shown in Figure 2 of the paper. + + Args: + view_1 (torch.Tensor): The first batch of augmented images. + view_2 (torch.Tensor): The second batch of augmented images. + + Returns: + tuple[torch.Tensor, torch.Tensor]: A tuple containing the two projected vectors (z1, z2). + """ + + # --- Process View 1 --- + # 1. Get the representation vector (h1) from the backbone encoder. + # We use `forward_embedding` to get the feature representation. The output is a tuple, + # and we'll typically use the CLS token as the representation. + _, _, h1_patches = self.backbone.forward_embedding(view_1) + + # For a ViT, a common representation is the average of all patch tokens. + h1 = h1_patches.mean(dim=1) + + # 2. Get the projection (z1) by passing h1 through the projection head. + z1 = self.projection_head(h1) + + # --- Process View 2 --- + # Repeat the exact same process for the second view. + _, _, h2_patches = self.backbone.forward_embedding(view_2) + h2 = h2_patches.mean(dim=1) + z2 = self.projection_head(h2) + + return z1, z2 \ No newline at end of file From 78a5cba2765843faabb3f126addb3109579ba13f Mon Sep 17 00:00:00 2001 From: Dojo2024 Date: Wed, 29 Oct 2025 00:24:56 +0530 Subject: [PATCH 2/7] feat: Add modular SimCLR and SimSiam implementation --- src/simclr/loss.py | 139 ++++++++++++++++++++++++++---------- src/simclr/main_pretrain.py | 119 ++++++++++++++++++++++++++++++ src/simclr/models.py | 121 ++++++++++++++++--------------- 3 files changed, 282 insertions(+), 97 deletions(-) create mode 100644 src/simclr/main_pretrain.py diff --git a/src/simclr/loss.py b/src/simclr/loss.py index ff0d81c..1f466e2 100644 --- a/src/simclr/loss.py +++ b/src/simclr/loss.py @@ -1,45 +1,106 @@ -# src/flat_mae/simclr_loss.py - import torch import torch.nn as nn import torch.nn.functional as F +import os -class NTXentLoss(nn.Module): - - def __init__(self, temperature: float = 0.5): - +class SimCLRProjectionHead(nn.Module): + """ The g(·) projection head for SimCLR. """ + def __init__(self, in_dim, hidden_dim=4096, out_dim=1024): super().__init__() - self.temperature = temperature - self.criterion = nn.CrossEntropyLoss(reduction="sum") - self.similarity_f = nn.CosineSimilarity(dim=2) - - def forward(self, z1: torch.Tensor, z2: torch.Tensor): - - batch_size = z1.shape[0] - - representations = torch.cat([z1, z2], dim=0) - - similarity_matrix = self.similarity_f(representations.unsqueeze(1), representations.unsqueeze(0)) - - labels = torch.cat([ - torch.arange(batch_size) + batch_size, - torch.arange(batch_size) - ]).to(similarity_matrix.device) - - - mask = torch.eye(batch_size * 2, dtype=torch.bool).to(similarity_matrix.device) - - similarity_matrix = similarity_matrix[~mask].view(batch_size * 2, -1) - - labels_adjusted = labels.clone() - for i in range(batch_size, batch_size * 2): - if labels[i] > i: - labels_adjusted[i] -= 1 - - logits = similarity_matrix / self.temperature - - loss = self.criterion(logits, labels_adjusted) - - loss = loss / (2 * batch_size) + self.head = nn.Sequential( + nn.Linear(in_dim, hidden_dim), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, out_dim), + ) + + def forward(self, x): + return self.head(x) + +class SimSiamProjectionHead(nn.Module): + """ The g(·) projection head for SimSiam. """ + def __init__(self, in_dim, hidden_dim=2048, out_dim=2048): + super().__init__() + self.head = nn.Sequential( + nn.Linear(in_dim, hidden_dim, bias=False), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, hidden_dim, bias=False), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, out_dim, bias=False), + nn.BatchNorm1d(out_dim, affine=False), + ) + + def forward(self, x): + return self.head(x) + +class SimSiamPredictionHead(nn.Module): + """ The h(·) prediction head for SimSiam. """ + def __init__(self, in_dim=2048, hidden_dim=512, out_dim=2048): + super().__init__() + self.head = nn.Sequential( + nn.Linear(in_dim, hidden_dim, bias=False), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, out_dim), + ) + + def forward(self, x): + return self.head(x) + +def nt_xent_loss(z1, z2, temperature=0.5, distributed=False): + """ The NT-Xent loss for SimCLR. """ + z1 = F.normalize(z1, dim=-1) + z2 = F.normalize(z2, dim=-1) + + # Concatenate all representations + all_z = torch.cat([z1, z2], dim=0) + + if distributed: + # Gather representations from all GPUs in a distributed setting + all_z_dist = torch.cat(torch.distributed.nn.all_gather(all_z), dim=0) + world_size = int(os.getenv("WORLD_SIZE", 1)) + rank = int(os.getenv("LOCAL_RANK", 0)) + else: + all_z_dist = all_z + world_size = 1 + rank = 0 + + # Calculate pairwise similarity + logits = torch.matmul(all_z, all_z_dist.T) / temperature + + # Create labels + batch_size = z1.shape[0] + labels = torch.arange(batch_size, device=z1.device) + labels = labels + (rank * batch_size) + + # The positive pair for z1[i] is z2[i], which is at index (i + batch_size) in all_z + # And the positive pair for z2[i] is z1[i], which is at index i in all_z + labels = torch.cat([labels + batch_size, labels], dim=0) + + # We need to mask out the similarity of an embedding with itself + mask = ~torch.eye(2 * batch_size, device=z1.device, dtype=torch.bool) + logits = logits[mask].view(2 * batch_size, -1) + + # Adjust labels because of the mask + labels_adjusted = labels.clone() + for i in range(2 * batch_size): + if labels[i] > i: + labels_adjusted[i] -= 1 + + loss = F.cross_entropy(logits, labels_adjusted, reduction="sum") + return loss / (2 * batch_size) + + +def simsiam_loss(p1, z2, p2, z1): + """ The loss for SimSiam. """ + # Stop-gradient: we don't want gradients to flow from z to the encoder + z1 = z1.detach() + z2 = z2.detach() + + loss1 = -F.cosine_similarity(p1, z2, dim=-1).mean() + loss2 = -F.cosine_similarity(p2, z1, dim=-1).mean() + + return (loss1 + loss2) / 2 - return loss \ No newline at end of file diff --git a/src/simclr/main_pretrain.py b/src/simclr/main_pretrain.py new file mode 100644 index 0000000..80b5b23 --- /dev/null +++ b/src/simclr/main_pretrain.py @@ -0,0 +1,119 @@ +import torch +import torch.optim as optim +from torch.utils.data import DataLoader +from pathlib import Path +from tqdm import tqdm + +# --- Import our custom and project-level components --- +# Data-related imports +import src.data.flat_data as flat_data +from simclr.data import SimCLRTransform, simclr_collate # Our new data components + +# Model-related imports +from flat_mae.models_mae import MaskedViT # Reusing the backbone +from simclr.models import ContrastiveModel # Our new main model + +def main(): + """ + Main function to run a simplified contrastive pre-training loop. + Supports both 'simclr' and 'simsiam' modes. + """ + print("--- Starting Contrastive Pre-training Script ---") + + # --- 1. Configuration --- + # File Paths + data_folder_path = Path("nsd-train-task-clips-16t") # Assumes this is in the project root + + # Model Hyperparameters + CONTRASTIVE_MODE = "simclr" # <-- CHANGE THIS TO "simsiam" TO TEST THE OTHER MODE + BACKBONE_EMBED_DIM = 384 # For ViT-Small + + # Training Hyperparameters + BATCH_SIZE = 4 + LEARNING_RATE = 1e-4 + NUM_EPOCHS = 5 # Run for a few epochs for verification + MASK_RATIO = 0.75 # Ratio of patches to mask in the encoder + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + print(f"Running in mode: '{CONTRASTIVE_MODE}'") + + # --- 2. Data Pipeline Setup --- + print("\n--- Setting up Data Pipeline ---") + # Base transform with stochastic augmentations (e.g., random crop) + base_transform = flat_data.make_flat_transform( + img_size=(224, 560), + normalize='global', + random_crop=True, + crop_kwargs={'scale': (0.8, 1.0), 'ratio': (2.4, 2.6)} + ) + # Wrap it to produce two views + simclr_transform = SimCLRTransform(base_transform) + + dataset = flat_data.FlatClipsDataset(root=data_folder_path, transform=simclr_transform) + + data_loader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + collate_fn=simclr_collate, # Use our custom collate function + shuffle=True, + num_workers=0 + ) + print("Data pipeline setup complete.") + + # --- 3. Model and Optimizer Setup --- + print("\n--- Setting up Model and Optimizer ---") + # Initialize the backbone encoder + backbone = MaskedViT( + img_size=(224, 560), + in_chans=1, + embed_dim=BACKBONE_EMBED_DIM, + depth=12, + num_heads=6 + ) + + # Initialize our main ContrastiveModel + model = ContrastiveModel( + backbone=backbone, + mode=CONTRASTIVE_MODE, + embed_dim=BACKBONE_EMBED_DIM + ).to(device) + + # Set up the optimizer + optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE) + + print("Model and optimizer setup complete.") + print(f"Total model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M") + + # --- 4. The Training Loop --- + print(f"\n--- Starting Training for {NUM_EPOCHS} Epochs ---") + + model.train() # Set the model to training mode + for epoch in range(NUM_EPOCHS): + loop = tqdm(data_loader, desc=f"Epoch [{epoch+1}/{NUM_EPOCHS}]", leave=False) + total_loss = 0.0 + + for batch_view_1, batch_view_2 in loop: + # Move data to the configured device + view_1_images = batch_view_1['image'].to(device) + view_2_images = batch_view_2['image'].to(device) + + # --- Core Training Steps --- + optimizer.zero_grad() + loss = model(view_1_images, view_2_images, mask_ratio=MASK_RATIO) + loss.backward() + optimizer.step() + # ------------------------- + + total_loss += loss.item() + loop.set_postfix(loss=loss.item()) + + avg_loss = total_loss / len(data_loader) + print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] completed. Average Loss: {avg_loss:.4f}") + + print("\n--- Training Verification Complete ---") + print("The script successfully ran a few training epochs without crashing.") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/simclr/models.py b/src/simclr/models.py index 83cec34..a8dd12f 100644 --- a/src/simclr/models.py +++ b/src/simclr/models.py @@ -1,84 +1,89 @@ -# src/flat_mae/simclr_models.py - import torch import torch.nn as nn +from src.flat_mae.models_mae import MaskedViT # We reuse the encoder from the MAE implementation +from simclr.loss import ( + SimCLRProjectionHead, + SimSiamProjectionHead, + SimSiamPredictionHead, + nt_xent_loss, + simsiam_loss, +) -class ProjectionHead(nn.Module): +class ContrastiveModel(nn.Module): """ - The Projection Head (g(·)) for the SimCLR framework. - As described in the paper, this is a small MLP that maps the representation (h) - to the latent space where the contrastive loss is applied. + A unified model for contrastive learning, supporting both SimCLR and SimSiam. """ - def __init__(self, input_dim: int, hidden_dim: int, output_dim: int): + def __init__(self, backbone: MaskedViT, mode: str = "simclr", embed_dim: int = 384): """ Args: - input_dim (int): The feature dimension of the output from the base encoder (h). - hidden_dim (int): The dimension of the hidden layer in the MLP. - output_dim (int): The final output dimension of the projection (z). + backbone (MaskedViT): The pre-trained or randomly initialized backbone encoder. + mode (str): The contrastive learning mode. Can be "simclr" or "simsiam". + embed_dim (int): The output dimension of the backbone encoder. """ super().__init__() + if mode not in ["simclr", "simsiam"]: + raise ValueError(f"Invalid contrastive mode: {mode}") - # The MLP consists of a linear layer, a non-linearity (ReLU), and another linear layer. - self.mlp = nn.Sequential( - nn.Linear(input_dim, hidden_dim), - nn.ReLU(inplace=True), - nn.Linear(hidden_dim, output_dim) - ) - - def forward(self, x): - """ - Passes the representation vector through the MLP. - Args: - x (torch.Tensor): The representation vector (h) from the base encoder. - Returns: - torch.Tensor: The projected vector (z). - """ - return self.mlp(x) + self.mode = mode + self.backbone = backbone + if self.mode == "simclr": + # For SimCLR, we only need a projection head. + self.projection_head = SimCLRProjectionHead(in_dim=embed_dim) + + elif self.mode == "simsiam": + # For SimSiam, we need both a projection head and a prediction head. + self.projection_head = SimSiamProjectionHead(in_dim=embed_dim) + self.prediction_head = SimSiamPredictionHead() -class SimCLRModel(nn.Module): - """ - The complete SimCLR model, combining the base encoder and the projection head. - """ - def __init__(self, backbone: nn.Module, projection_head: nn.Module): + def get_representation(self, x: torch.Tensor, mask_ratio: float): """ - Args: - backbone (nn.Module): The base encoder network (f(·)), e.g., a ViT. - It is expected to have a `forward_embedding` method. - projection_head (nn.Module): The MLP projection head (g(·)). + A helper function to pass an input through the backbone and get the CLS token. """ - super().__init__() - self.backbone = backbone - self.projection_head = projection_head + # The MAE backbone returns (cls_token, reg_tokens, patch_tokens, mask, ids_keep) + # We only need the cls_token for contrastive learning. + cls_embeds, _, _, _, _ = self.backbone(x, mask_ratio=mask_ratio) + # The cls_embeds has a shape of [Batch, 1, Dim], so we squeeze it. + return cls_embeds.squeeze(1) - def forward(self, view_1: torch.Tensor, view_2: torch.Tensor): + def forward(self, view_1: torch.Tensor, view_2: torch.Tensor, mask_ratio: float): """ - Performs the forward pass for both augmented views as shown in Figure 2 of the paper. + The main forward pass. It takes two augmented views and computes the final loss. Args: view_1 (torch.Tensor): The first batch of augmented images. view_2 (torch.Tensor): The second batch of augmented images. + mask_ratio (float): The ratio of patches to mask in the encoder. Returns: - tuple[torch.Tensor, torch.Tensor]: A tuple containing the two projected vectors (z1, z2). + torch.Tensor: The final calculated loss for the batch. """ - # --- Process View 1 --- - # 1. Get the representation vector (h1) from the backbone encoder. - # We use `forward_embedding` to get the feature representation. The output is a tuple, - # and we'll typically use the CLS token as the representation. - _, _, h1_patches = self.backbone.forward_embedding(view_1) - - # For a ViT, a common representation is the average of all patch tokens. - h1 = h1_patches.mean(dim=1) - - # 2. Get the projection (z1) by passing h1 through the projection head. - z1 = self.projection_head(h1) + # Get the representations (h1, h2) from the backbone for each view + h1 = self.get_representation(view_1, mask_ratio) + h2 = self.get_representation(view_2, mask_ratio) + + if self.mode == "simclr": + # --- SimCLR Forward Pass --- + # 1. Get the projections (z1, z2) + z1 = self.projection_head(h1) + z2 = self.projection_head(h2) + + # 2. Calculate the loss + loss = nt_xent_loss(z1, z2) + return loss + + elif self.mode == "simsiam": + # --- SimSiam Forward Pass --- + # 1. Get the projections (z1, z2) + z1 = self.projection_head(h1) + z2 = self.projection_head(h2) + + # 2. Get the predictions (p1, p2) + p1 = self.prediction_head(z1) + p2 = self.prediction_head(z2) - # --- Process View 2 --- - # Repeat the exact same process for the second view. - _, _, h2_patches = self.backbone.forward_embedding(view_2) - h2 = h2_patches.mean(dim=1) - z2 = self.projection_head(h2) + # 3. Calculate the loss + loss = simsiam_loss(p1, z2, p2, z1) + return loss - return z1, z2 \ No newline at end of file From 9bbd98e6906aad74e165cd6f9942f4cb44d4636f Mon Sep 17 00:00:00 2001 From: Dojo2024 Date: Thu, 30 Oct 2025 23:50:09 +0530 Subject: [PATCH 3/7] refactor(simclr): Refactor pre-training script This commit refactors the SimCLR/SimSiam pre-training script based on the feedback. --- src/simclr/loss.py | 66 ++++----- src/simclr/main_pretrain.py | 286 +++++++++++++++++++++++------------- src/simclr/models.py | 73 +++------ 3 files changed, 226 insertions(+), 199 deletions(-) diff --git a/src/simclr/loss.py b/src/simclr/loss.py index 1f466e2..16c6b6c 100644 --- a/src/simclr/loss.py +++ b/src/simclr/loss.py @@ -1,15 +1,20 @@ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# SimSiam: https://github.com/facebookresearch/simsiam +# -------------------------------------------------------- + import torch import torch.nn as nn import torch.nn.functional as F import os class SimCLRProjectionHead(nn.Module): - """ The g(·) projection head for SimCLR. """ - def __init__(self, in_dim, hidden_dim=4096, out_dim=1024): + def __init__(self, in_dim: int, hidden_dim: int = 2048, out_dim: int = 128): super().__init__() self.head = nn.Sequential( nn.Linear(in_dim, hidden_dim), - nn.BatchNorm1d(hidden_dim), nn.ReLU(inplace=True), nn.Linear(hidden_dim, out_dim), ) @@ -18,9 +23,9 @@ def forward(self, x): return self.head(x) class SimSiamProjectionHead(nn.Module): - """ The g(·) projection head for SimSiam. """ - def __init__(self, in_dim, hidden_dim=2048, out_dim=2048): + def __init__(self, in_dim: int, hidden_dim: int = 2048, out_dim: int = 2048): super().__init__() + self.head = nn.Sequential( nn.Linear(in_dim, hidden_dim, bias=False), nn.BatchNorm1d(hidden_dim), @@ -36,8 +41,7 @@ def forward(self, x): return self.head(x) class SimSiamPredictionHead(nn.Module): - """ The h(·) prediction head for SimSiam. """ - def __init__(self, in_dim=2048, hidden_dim=512, out_dim=2048): + def __init__(self, in_dim: int = 2048, hidden_dim: int = 512, out_dim: int = 2048): super().__init__() self.head = nn.Sequential( nn.Linear(in_dim, hidden_dim, bias=False), @@ -49,53 +53,34 @@ def __init__(self, in_dim=2048, hidden_dim=512, out_dim=2048): def forward(self, x): return self.head(x) -def nt_xent_loss(z1, z2, temperature=0.5, distributed=False): - """ The NT-Xent loss for SimCLR. """ + +def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5, distributed: bool = False): + z1 = F.normalize(z1, dim=-1) z2 = F.normalize(z2, dim=-1) - - # Concatenate all representations + all_z = torch.cat([z1, z2], dim=0) if distributed: - # Gather representations from all GPUs in a distributed setting all_z_dist = torch.cat(torch.distributed.nn.all_gather(all_z), dim=0) - world_size = int(os.getenv("WORLD_SIZE", 1)) rank = int(os.getenv("LOCAL_RANK", 0)) else: all_z_dist = all_z - world_size = 1 rank = 0 - # Calculate pairwise similarity logits = torch.matmul(all_z, all_z_dist.T) / temperature - # Create labels batch_size = z1.shape[0] - labels = torch.arange(batch_size, device=z1.device) - labels = labels + (rank * batch_size) - - # The positive pair for z1[i] is z2[i], which is at index (i + batch_size) in all_z - # And the positive pair for z2[i] is z1[i], which is at index i in all_z - labels = torch.cat([labels + batch_size, labels], dim=0) - - # We need to mask out the similarity of an embedding with itself - mask = ~torch.eye(2 * batch_size, device=z1.device, dtype=torch.bool) - logits = logits[mask].view(2 * batch_size, -1) - - # Adjust labels because of the mask - labels_adjusted = labels.clone() - for i in range(2 * batch_size): - if labels[i] > i: - labels_adjusted[i] -= 1 - - loss = F.cross_entropy(logits, labels_adjusted, reduction="sum") - return loss / (2 * batch_size) - - -def simsiam_loss(p1, z2, p2, z1): - """ The loss for SimSiam. """ - # Stop-gradient: we don't want gradients to flow from z to the encoder + labels_v1 = torch.arange(batch_size, device=z1.device) + batch_size + labels_v2 = torch.arange(batch_size, device=z1.device) + labels = torch.cat([labels_v1, labels_v2], dim=0) + + labels = labels + (rank * 2 * batch_size) + + return F.cross_entropy(logits, labels) + +def simsiam_loss(p1: torch.Tensor, z2: torch.Tensor, p2: torch.Tensor, z1: torch.Tensor): + z1 = z1.detach() z2 = z2.detach() @@ -103,4 +88,3 @@ def simsiam_loss(p1, z2, p2, z1): loss2 = -F.cosine_similarity(p2, z1, dim=-1).mean() return (loss1 + loss2) / 2 - diff --git a/src/simclr/main_pretrain.py b/src/simclr/main_pretrain.py index 80b5b23..18ba4f0 100644 --- a/src/simclr/main_pretrain.py +++ b/src/simclr/main_pretrain.py @@ -1,119 +1,197 @@ -import torch -import torch.optim as optim -from torch.utils.data import DataLoader +import argparse +import datetime +import json +import os +import random +import time from pathlib import Path -from tqdm import tqdm - -# --- Import our custom and project-level components --- -# Data-related imports -import src.data.flat_data as flat_data -from simclr.data import SimCLRTransform, simclr_collate # Our new data components - -# Model-related imports -from flat_mae.models_mae import MaskedViT # Reusing the backbone -from simclr.models import ContrastiveModel # Our new main model - -def main(): - """ - Main function to run a simplified contrastive pre-training loop. - Supports both 'simclr' and 'simsiam' modes. - """ - print("--- Starting Contrastive Pre-training Script ---") - - # --- 1. Configuration --- - # File Paths - data_folder_path = Path("nsd-train-task-clips-16t") # Assumes this is in the project root - - # Model Hyperparameters - CONTRASTIVE_MODE = "simclr" # <-- CHANGE THIS TO "simsiam" TO TEST THE OTHER MODE - BACKBONE_EMBED_DIM = 384 # For ViT-Small - - # Training Hyperparameters - BATCH_SIZE = 4 - LEARNING_RATE = 1e-4 - NUM_EPOCHS = 5 # Run for a few epochs for verification - MASK_RATIO = 0.75 # Ratio of patches to mask in the encoder - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") - print(f"Running in mode: '{CONTRASTIVE_MODE}'") - - # --- 2. Data Pipeline Setup --- - print("\n--- Setting up Data Pipeline ---") - # Base transform with stochastic augmentations (e.g., random crop) - base_transform = flat_data.make_flat_transform( - img_size=(224, 560), - normalize='global', - random_crop=True, - crop_kwargs={'scale': (0.8, 1.0), 'ratio': (2.4, 2.6)} - ) - # Wrap it to produce two views - simclr_transform = SimCLRTransform(base_transform) - - dataset = flat_data.FlatClipsDataset(root=data_folder_path, transform=simclr_transform) - - data_loader = DataLoader( - dataset, - batch_size=BATCH_SIZE, - collate_fn=simclr_collate, # Use our custom collate function - shuffle=True, - num_workers=0 - ) - print("Data pipeline setup complete.") - - # --- 3. Model and Optimizer Setup --- - print("\n--- Setting up Model and Optimizer ---") - # Initialize the backbone encoder - backbone = MaskedViT( - img_size=(224, 560), - in_chans=1, - embed_dim=BACKBONE_EMBED_DIM, - depth=12, - num_heads=6 +import math + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import wandb +from omegaconf import DictConfig, OmegaConf +from torch.utils.data.distributed import DistributedSampler + +import flat_mae.utils as ut +import flat_mae.models_mae as models_mae + +import data.flat_data as flat_data + +from simclr.data import SimCLRTransform, simclr_collate +from simclr.models import ContrastiveModel +from simclr.loss import nt_xent_loss, simsiam_loss + +import flat_mae.models_mae as models_mae + +PROJECT = "fMRI-foundation-model" + + +BACKBONE_MODELS_DICT = models_mae.__dict__ + +def main(args: DictConfig): + ut.init_distributed_mode(args) + global_rank = ut.get_rank() + is_master = global_rank == 0 + world_size = ut.get_world_size() + device = torch.device(args.device) + ut.random_seed(args.seed, rank=global_rank) + + if args.name and not args.output_dir.endswith(args.name): + args.output_dir = f"{args.output_dir}/{args.name}" + output_dir = Path(args.output_dir) + + if is_master: + output_dir.mkdir(parents=True, exist_ok=True) + OmegaConf.save(args, output_dir / "config.yaml") + + ut.setup_for_distributed(log_path=output_dir / "log.txt") + print(f"pretraining with {args.model.contrastive_mode}") + print("config:", OmegaConf.to_yaml(args), sep="\n") + + train_loader, eval_loaders = create_data_loaders(args) + + print(f"Creating backbone: {args.model.backbone_name}") + backbone = BACKBONE_MODELS_DICT[args.model.backbone_name]( + img_size=args.data.img_size, + in_chans=args.data.in_chans, + **args.model.get("backbone_kwargs", {}), ) - - # Initialize our main ContrastiveModel model = ContrastiveModel( backbone=backbone, - mode=CONTRASTIVE_MODE, - embed_dim=BACKBONE_EMBED_DIM + mode=args.model.contrastive_mode, + embed_dim=args.model.backbone_kwargs.embed_dim, + model_kwargs=args.model.get("head_kwargs"), ).to(device) - - # Set up the optimizer - optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE) - - print("Model and optimizer setup complete.") - print(f"Total model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M") - # --- 4. The Training Loop --- - print(f"\n--- Starting Training for {NUM_EPOCHS} Epochs ---") + model_without_ddp = model + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + print("model:", model, sep="\n") + + total_batch_size = args.optim.batch_size * args.optim.accum_iter * world_size + if not args.optim.get("lr"): args.optim.lr = args.optim.base_lr * total_batch_size / 256 + + param_groups = ut.get_param_groups(model) + ut.update_lr(param_groups, args.optim.lr) + ut.update_wd(param_groups, args.optim.weight_decay) + optimizer = torch.optim.AdamW(param_groups, betas=tuple(args.optim.betas)) + + epoch_num_batches = len(train_loader) + steps_per_epoch = epoch_num_batches // args.optim.accum_iter + total_steps = args.optim.epochs * steps_per_epoch + warmup_steps = args.optim.warmup_epochs * steps_per_epoch + lr_schedule = ut.WarmupThenCosine( + base_value=args.optim.lr, final_value=args.optim.min_lr, + total_iters=total_steps, warmup_iters=warmup_steps + ) + + loss_scaler = ut.GradScaler() if args.amp and args.amp_dtype != 'bfloat16' else None + + ut.load_model(args, model_without_ddp, optimizer, loss_scaler) + + print(f"start training for {args.optim.epochs} epochs") + for epoch in range(args.start_epoch, args.optim.epochs): + if args.distributed: train_loader.sampler.set_epoch(epoch) + + train_stats = train_one_epoch(args, model, train_loader, optimizer, loss_scaler, lr_schedule, epoch, device) + + + if args.output_dir: + ut.save_model(args, epoch, model_without_ddp, optimizer, loss_scaler) + + +def create_data_loaders(args: DictConfig): + base_transform = flat_data.make_flat_transform( + img_size=args.data.img_size, + clip_vmax=args.data.get("clip_vmax"), + normalize=args.data.get("normalize"), + random_crop=args.data.get("random_crop", False), + crop_kwargs=args.data.get("crop_kwargs"), + ) + + transform = SimCLRTransform(base_transform) + + data_loaders = {} + dataset_names = [args.train_dataset] + args.eval_datasets + for name in dataset_names: + config = args.datasets[name] + dataset = flat_data.FlatClipsDataset(root=config.root, transform=transform) + sampler = DistributedSampler(dataset, shuffle=config.shuffle) if args.distributed else None + + loader = flat_data.DataLoader( + dataset, batch_size=args.optim.batch_size, + collate_fn=simclr_collate, sampler=sampler, + shuffle=sampler is None and config.shuffle, + num_workers=args.num_workers, pin_memory=True, drop_last=True + ) + data_loaders[name] = loader - model.train() # Set the model to training mode - for epoch in range(NUM_EPOCHS): - loop = tqdm(data_loader, desc=f"Epoch [{epoch+1}/{NUM_EPOCHS}]", leave=False) - total_loss = 0.0 + train_loader = data_loaders.pop(args.train_dataset) + return train_loader, data_loaders + + +def train_one_epoch(args, model, data_loader, optimizer, loss_scaler, lr_schedule, epoch, device): + # --- This is the training engine, adapted for SimCLR/SimSiam --- + model.train() + metric_logger = ut.MetricLogger(delimiter=" ") + header = f'Train: [{epoch}]' + + epoch_num_batches = len(data_loader) + steps_per_epoch = epoch_num_batches // args.optim.accum_iter + + optimizer.zero_grad() + + for batch_idx, (batch_view_1, batch_view_2) in enumerate(metric_logger.log_every(data_loader, 100, header)): + + global_step = epoch * steps_per_epoch + (batch_idx + 1) // args.optim.accum_iter + lr = lr_schedule[global_step] + need_update = (batch_idx + 1) % args.optim.accum_iter == 0 + if need_update: ut.update_lr(optimizer.param_groups, lr) - for batch_view_1, batch_view_2 in loop: - # Move data to the configured device - view_1_images = batch_view_1['image'].to(device) - view_2_images = batch_view_2['image'].to(device) + view_1 = batch_view_1['image'].to(device, non_blocking=True) + view_2 = batch_view_2['image'].to(device, non_blocking=True) - # --- Core Training Steps --- - optimizer.zero_grad() - loss = model(view_1_images, view_2_images, mask_ratio=MASK_RATIO) - loss.backward() - optimizer.step() - # ------------------------- + with torch.autocast(device_type=device.type, dtype=getattr(torch, args.amp_dtype), enabled=args.amp): - total_loss += loss.item() - loop.set_postfix(loss=loss.item()) + outputs = model(view_1, view_2, mask_ratio=args.model.mask_ratio) - avg_loss = total_loss / len(data_loader) - print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] completed. Average Loss: {avg_loss:.4f}") + if args.model.contrastive_mode == "simclr": + z1, z2 = outputs + loss = nt_xent_loss(z1, z2, temperature=args.model.get("temperature", 0.5), distributed=args.distributed) + elif args.model.contrastive_mode == "simsiam": + p1, z2, p2, z1 = outputs + loss = simsiam_loss(p1, z2, p2, z1) + else: + raise ValueError(f"Unknown contrastive mode: {args.model.contrastive_mode}") - print("\n--- Training Verification Complete ---") - print("The script successfully ran a few training epochs without crashing.") + loss_value = loss.item() + if not math.isfinite(loss_value): raise RuntimeError(f"Loss is {loss_value}, stopping training") + + ut.backward_step(loss / args.optim.accum_iter, optimizer, scaler=loss_scaler, + need_update=need_update, max_norm=args.optim.get("clip_grad")) + + metric_logger.update(loss=loss_value) + + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + +@torch.no_grad() +def evaluate(model, data_loader, name, device, args): + model.eval() + print(f"--- Running evaluation on {name} ---") if __name__ == "__main__": - main() \ No newline at end of file + parser = argparse.ArgumentParser() + parser.add_argument("--cfg-path", type=str, required=True) + parser.add_argument("--overrides", type=str, default=None, nargs="+") + cli_args = parser.parse_args() + + cfg = OmegaConf.load(cli_args.cfg_path) + if cli_args.overrides: + cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(cli_args.overrides)) + + main(cfg) \ No newline at end of file diff --git a/src/simclr/models.py b/src/simclr/models.py index a8dd12f..8a39da9 100644 --- a/src/simclr/models.py +++ b/src/simclr/models.py @@ -1,89 +1,54 @@ +# This source code is licensed under the CC-BY-NC license +# found in the LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# SimSiam: https://github.com/facebookresearch/simsiam +# -------------------------------------------------------- + import torch import torch.nn as nn -from src.flat_mae.models_mae import MaskedViT # We reuse the encoder from the MAE implementation + + +from flat_mae.models_mae import MaskedViT + from simclr.loss import ( SimCLRProjectionHead, SimSiamProjectionHead, SimSiamPredictionHead, - nt_xent_loss, - simsiam_loss, ) class ContrastiveModel(nn.Module): - """ - A unified model for contrastive learning, supporting both SimCLR and SimSiam. - """ def __init__(self, backbone: MaskedViT, mode: str = "simclr", embed_dim: int = 384): - """ - Args: - backbone (MaskedViT): The pre-trained or randomly initialized backbone encoder. - mode (str): The contrastive learning mode. Can be "simclr" or "simsiam". - embed_dim (int): The output dimension of the backbone encoder. - """ super().__init__() if mode not in ["simclr", "simsiam"]: raise ValueError(f"Invalid contrastive mode: {mode}") - + self.mode = mode self.backbone = backbone if self.mode == "simclr": - # For SimCLR, we only need a projection head. self.projection_head = SimCLRProjectionHead(in_dim=embed_dim) - + elif self.mode == "simsiam": - # For SimSiam, we need both a projection head and a prediction head. self.projection_head = SimSiamProjectionHead(in_dim=embed_dim) self.prediction_head = SimSiamPredictionHead() def get_representation(self, x: torch.Tensor, mask_ratio: float): - """ - A helper function to pass an input through the backbone and get the CLS token. - """ - # The MAE backbone returns (cls_token, reg_tokens, patch_tokens, mask, ids_keep) - # We only need the cls_token for contrastive learning. - cls_embeds, _, _, _, _ = self.backbone(x, mask_ratio=mask_ratio) - # The cls_embeds has a shape of [Batch, 1, Dim], so we squeeze it. + cls_embeds, _, _ = self.backbone.forward_embedding(x, mask_ratio=mask_ratio) return cls_embeds.squeeze(1) def forward(self, view_1: torch.Tensor, view_2: torch.Tensor, mask_ratio: float): - """ - The main forward pass. It takes two augmented views and computes the final loss. - - Args: - view_1 (torch.Tensor): The first batch of augmented images. - view_2 (torch.Tensor): The second batch of augmented images. - mask_ratio (float): The ratio of patches to mask in the encoder. - - Returns: - torch.Tensor: The final calculated loss for the batch. - """ - - # Get the representations (h1, h2) from the backbone for each view h1 = self.get_representation(view_1, mask_ratio) h2 = self.get_representation(view_2, mask_ratio) + z1 = self.projection_head(h1) + z2 = self.projection_head(h2) + if self.mode == "simclr": - # --- SimCLR Forward Pass --- - # 1. Get the projections (z1, z2) - z1 = self.projection_head(h1) - z2 = self.projection_head(h2) - - # 2. Calculate the loss - loss = nt_xent_loss(z1, z2) - return loss + return z1, z2 elif self.mode == "simsiam": - # --- SimSiam Forward Pass --- - # 1. Get the projections (z1, z2) - z1 = self.projection_head(h1) - z2 = self.projection_head(h2) - - # 2. Get the predictions (p1, p2) p1 = self.prediction_head(z1) p2 = self.prediction_head(z2) - # 3. Calculate the loss - loss = simsiam_loss(p1, z2, p2, z1) - return loss - + return p1, z2, p2, z1 \ No newline at end of file From b578fc30a35d7f4ca13ba34e7fe5fcb9b92983b5 Mon Sep 17 00:00:00 2001 From: Dojo2024 <119791514+Dojo2024@users.noreply.github.com> Date: Thu, 30 Oct 2025 23:52:04 +0530 Subject: [PATCH 4/7] Refactor main_pretrain.py --- main_pretrain.py | 273 +++++++++++++++++++++++++++++++---------------- 1 file changed, 180 insertions(+), 93 deletions(-) diff --git a/main_pretrain.py b/main_pretrain.py index 87e8d15..bdefe91 100644 --- a/main_pretrain.py +++ b/main_pretrain.py @@ -1,110 +1,197 @@ -# src/simclr/main_pretrain.py +import argparse +import datetime +import json +import os +import random +import time +from pathlib import Path +import math +import numpy as np import torch -import torch.optim as optim -from torch.utils.data import DataLoader -from pathlib import Path -from tqdm import tqdm - -# --- Import components using the new modular structure --- -import src.data.flat_data as flat_data -import src.simclr.data as simclr_data -import src.flat_mae.models_mae as models_mae -import src.simclr.models as simclr_models -import src.simclr.loss as simclr_loss - -def main(): - """ - Main function to run a simplified SimCLR pre-training loop. - """ - print("--- Starting SimCLR Pre-training Verification Script ---") - - # --- 1. Configuration --- - # File Paths - # IMPORTANT: Paths are now relative to the project root, not the script location. - data_folder_path = Path("nsd-train-task-clips-16t") # Assuming this folder is in the project root - - # Model Hyperparameters - backbone_embed_dim = 384 # For ViT-Small - projection_hidden_dim = 512 - projection_output_dim = 128 # As used in the SimCLR paper - - # Training Hyperparameters - batch_size = 4 - learning_rate = 1e-4 - temperature = 0.5 - num_epochs = 3 # Run for a few epochs to see it work +import torch.backends.cudnn as cudnn +import wandb +from omegaconf import DictConfig, OmegaConf +from torch.utils.data.distributed import DistributedSampler - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") +import flat_mae.utils as ut +import flat_mae.models_mae as models_mae - # --- 2. Data Pipeline Setup --- - print("\n--- Setting up Data Pipeline ---") - base_transform = flat_data.make_flat_transform( - img_size=(224, 560), - normalize='global', - random_crop=True, - crop_kwargs={'scale': (0.8, 1.0), 'ratio': (2.4, 2.6)} +import data.flat_data as flat_data + +from simclr.data import SimCLRTransform, simclr_collate +from simclr.models import ContrastiveModel +from simclr.loss import nt_xent_loss, simsiam_loss + +import flat_mae.models_mae as models_mae + +PROJECT = "fMRI-foundation-model" + + +BACKBONE_MODELS_DICT = models_mae.__dict__ + +def main(args: DictConfig): + ut.init_distributed_mode(args) + global_rank = ut.get_rank() + is_master = global_rank == 0 + world_size = ut.get_world_size() + device = torch.device(args.device) + ut.random_seed(args.seed, rank=global_rank) + + if args.name and not args.output_dir.endswith(args.name): + args.output_dir = f"{args.output_dir}/{args.name}" + output_dir = Path(args.output_dir) + + if is_master: + output_dir.mkdir(parents=True, exist_ok=True) + OmegaConf.save(args, output_dir / "config.yaml") + + ut.setup_for_distributed(log_path=output_dir / "log.txt") + print(f"pretraining with {args.model.contrastive_mode}") + print("config:", OmegaConf.to_yaml(args), sep="\n") + + train_loader, eval_loaders = create_data_loaders(args) + + print(f"Creating backbone: {args.model.backbone_name}") + backbone = BACKBONE_MODELS_DICT[args.model.backbone_name]( + img_size=args.data.img_size, + in_chans=args.data.in_chans, + **args.model.get("backbone_kwargs", {}), ) - transform = simclr_data.SimCLRTransform(base_transform) - dataset = flat_data.FlatClipsDataset(root=data_folder_path, transform=transform) - data_loader = DataLoader( - dataset, - batch_size=batch_size, - collate_fn=simclr_data.simclr_collate, - shuffle=True, - num_workers=0 + model = ContrastiveModel( + backbone=backbone, + mode=args.model.contrastive_mode, + embed_dim=args.model.backbone_kwargs.embed_dim, + model_kwargs=args.model.get("head_kwargs"), + ).to(device) + + model_without_ddp = model + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + print("model:", model, sep="\n") + + total_batch_size = args.optim.batch_size * args.optim.accum_iter * world_size + if not args.optim.get("lr"): args.optim.lr = args.optim.base_lr * total_batch_size / 256 + + param_groups = ut.get_param_groups(model) + ut.update_lr(param_groups, args.optim.lr) + ut.update_wd(param_groups, args.optim.weight_decay) + optimizer = torch.optim.AdamW(param_groups, betas=tuple(args.optim.betas)) + + epoch_num_batches = len(train_loader) + steps_per_epoch = epoch_num_batches // args.optim.accum_iter + total_steps = args.optim.epochs * steps_per_epoch + warmup_steps = args.optim.warmup_epochs * steps_per_epoch + lr_schedule = ut.WarmupThenCosine( + base_value=args.optim.lr, final_value=args.optim.min_lr, + total_iters=total_steps, warmup_iters=warmup_steps ) - print("Data pipeline setup complete.") - # --- 3. Model and Optimizer Setup --- - print("\n--- Setting up Model and Optimizer ---") - # Base Encoder (Backbone) - backbone = models_mae.mae_vit_small(img_size=(224, 560), in_chans=1) - - # Projection Head - projection_head = simclr_models.ProjectionHead( - input_dim=backbone_embed_dim, - hidden_dim=projection_hidden_dim, - output_dim=projection_output_dim + loss_scaler = ut.GradScaler() if args.amp and args.amp_dtype != 'bfloat16' else None + + ut.load_model(args, model_without_ddp, optimizer, loss_scaler) + + print(f"start training for {args.optim.epochs} epochs") + for epoch in range(args.start_epoch, args.optim.epochs): + if args.distributed: train_loader.sampler.set_epoch(epoch) + + train_stats = train_one_epoch(args, model, train_loader, optimizer, loss_scaler, lr_schedule, epoch, device) + + + if args.output_dir: + ut.save_model(args, epoch, model_without_ddp, optimizer, loss_scaler) + + +def create_data_loaders(args: DictConfig): + base_transform = flat_data.make_flat_transform( + img_size=args.data.img_size, + clip_vmax=args.data.get("clip_vmax"), + normalize=args.data.get("normalize"), + random_crop=args.data.get("random_crop", False), + crop_kwargs=args.data.get("crop_kwargs"), ) - - # Full SimCLR Model - model = simclr_models.SimCLRModel(backbone, projection_head).to(device) - - # Optimizer - optimizer = optim.Adam(model.parameters(), lr=learning_rate) - - # Loss Function - criterion = simclr_loss.NTXentLoss(temperature=temperature).to(device) - - print("Model, optimizer, and loss function setup complete.") - # --- 4. The Training Loop --- - print(f"\n--- Starting Training for {num_epochs} Epochs ---") + transform = SimCLRTransform(base_transform) + + data_loaders = {} + dataset_names = [args.train_dataset] + args.eval_datasets + for name in dataset_names: + config = args.datasets[name] + dataset = flat_data.FlatClipsDataset(root=config.root, transform=transform) + sampler = DistributedSampler(dataset, shuffle=config.shuffle) if args.distributed else None + + loader = flat_data.DataLoader( + dataset, batch_size=args.optim.batch_size, + collate_fn=simclr_collate, sampler=sampler, + shuffle=sampler is None and config.shuffle, + num_workers=args.num_workers, pin_memory=True, drop_last=True + ) + data_loaders[name] = loader - for epoch in range(num_epochs): - loop = tqdm(data_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False) - total_loss = 0.0 + train_loader = data_loaders.pop(args.train_dataset) + return train_loader, data_loaders + - for batch_view_1, batch_view_2 in loop: - view_1_images = batch_view_1['image'].to(device) - view_2_images = batch_view_2['image'].to(device) +def train_one_epoch(args, model, data_loader, optimizer, loss_scaler, lr_schedule, epoch, device): + # --- This is the training engine, adapted for SimCLR/SimSiam --- + model.train() + metric_logger = ut.MetricLogger(delimiter=" ") + header = f'Train: [{epoch}]' - optimizer.zero_grad() - z1, z2 = model(view_1_images, view_2_images) - loss = criterion(z1, z2) - loss.backward() - optimizer.step() + epoch_num_batches = len(data_loader) + steps_per_epoch = epoch_num_batches // args.optim.accum_iter - total_loss += loss.item() - loop.set_postfix(loss=loss.item()) + optimizer.zero_grad() - avg_loss = total_loss / len(data_loader) - print(f"Epoch [{epoch+1}/{num_epochs}] completed. Average Loss: {avg_loss:.4f}") + for batch_idx, (batch_view_1, batch_view_2) in enumerate(metric_logger.log_every(data_loader, 100, header)): + + global_step = epoch * steps_per_epoch + (batch_idx + 1) // args.optim.accum_iter + lr = lr_schedule[global_step] + need_update = (batch_idx + 1) % args.optim.accum_iter == 0 + if need_update: ut.update_lr(optimizer.param_groups, lr) + + view_1 = batch_view_1['image'].to(device, non_blocking=True) + view_2 = batch_view_2['image'].to(device, non_blocking=True) + + with torch.autocast(device_type=device.type, dtype=getattr(torch, args.amp_dtype), enabled=args.amp): + + outputs = model(view_1, view_2, mask_ratio=args.model.mask_ratio) + + if args.model.contrastive_mode == "simclr": + z1, z2 = outputs + loss = nt_xent_loss(z1, z2, temperature=args.model.get("temperature", 0.5), distributed=args.distributed) + elif args.model.contrastive_mode == "simsiam": + p1, z2, p2, z1 = outputs + loss = simsiam_loss(p1, z2, p2, z1) + else: + raise ValueError(f"Unknown contrastive mode: {args.model.contrastive_mode}") + + loss_value = loss.item() + if not math.isfinite(loss_value): raise RuntimeError(f"Loss is {loss_value}, stopping training") + + ut.backward_step(loss / args.optim.accum_iter, optimizer, scaler=loss_scaler, + need_update=need_update, max_norm=args.optim.get("clip_grad")) + + metric_logger.update(loss=loss_value) + + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + +@torch.no_grad() +def evaluate(model, data_loader, name, device, args): + model.eval() + print(f"--- Running evaluation on {name} ---") - print("\n--- Training Verification Complete ---") - print("Refactoring successful. The script ran correctly from its new location.") if __name__ == "__main__": - main() \ No newline at end of file + parser = argparse.ArgumentParser() + parser.add_argument("--cfg-path", type=str, required=True) + parser.add_argument("--overrides", type=str, default=None, nargs="+") + cli_args = parser.parse_args() + + cfg = OmegaConf.load(cli_args.cfg_path) + if cli_args.overrides: + cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(cli_args.overrides)) + + main(cfg) From ee4ff53977c12f170f14b8bb20ab805c3be22b61 Mon Sep 17 00:00:00 2001 From: Dojo2024 <119791514+Dojo2024@users.noreply.github.com> Date: Thu, 30 Oct 2025 23:52:37 +0530 Subject: [PATCH 5/7] Delete main_pretrain.py --- main_pretrain.py | 197 ----------------------------------------------- 1 file changed, 197 deletions(-) delete mode 100644 main_pretrain.py diff --git a/main_pretrain.py b/main_pretrain.py deleted file mode 100644 index bdefe91..0000000 --- a/main_pretrain.py +++ /dev/null @@ -1,197 +0,0 @@ -import argparse -import datetime -import json -import os -import random -import time -from pathlib import Path -import math - -import numpy as np -import torch -import torch.backends.cudnn as cudnn -import wandb -from omegaconf import DictConfig, OmegaConf -from torch.utils.data.distributed import DistributedSampler - -import flat_mae.utils as ut -import flat_mae.models_mae as models_mae - -import data.flat_data as flat_data - -from simclr.data import SimCLRTransform, simclr_collate -from simclr.models import ContrastiveModel -from simclr.loss import nt_xent_loss, simsiam_loss - -import flat_mae.models_mae as models_mae - -PROJECT = "fMRI-foundation-model" - - -BACKBONE_MODELS_DICT = models_mae.__dict__ - -def main(args: DictConfig): - ut.init_distributed_mode(args) - global_rank = ut.get_rank() - is_master = global_rank == 0 - world_size = ut.get_world_size() - device = torch.device(args.device) - ut.random_seed(args.seed, rank=global_rank) - - if args.name and not args.output_dir.endswith(args.name): - args.output_dir = f"{args.output_dir}/{args.name}" - output_dir = Path(args.output_dir) - - if is_master: - output_dir.mkdir(parents=True, exist_ok=True) - OmegaConf.save(args, output_dir / "config.yaml") - - ut.setup_for_distributed(log_path=output_dir / "log.txt") - print(f"pretraining with {args.model.contrastive_mode}") - print("config:", OmegaConf.to_yaml(args), sep="\n") - - train_loader, eval_loaders = create_data_loaders(args) - - print(f"Creating backbone: {args.model.backbone_name}") - backbone = BACKBONE_MODELS_DICT[args.model.backbone_name]( - img_size=args.data.img_size, - in_chans=args.data.in_chans, - **args.model.get("backbone_kwargs", {}), - ) - model = ContrastiveModel( - backbone=backbone, - mode=args.model.contrastive_mode, - embed_dim=args.model.backbone_kwargs.embed_dim, - model_kwargs=args.model.get("head_kwargs"), - ).to(device) - - model_without_ddp = model - if args.distributed: - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) - model_without_ddp = model.module - - print("model:", model, sep="\n") - - total_batch_size = args.optim.batch_size * args.optim.accum_iter * world_size - if not args.optim.get("lr"): args.optim.lr = args.optim.base_lr * total_batch_size / 256 - - param_groups = ut.get_param_groups(model) - ut.update_lr(param_groups, args.optim.lr) - ut.update_wd(param_groups, args.optim.weight_decay) - optimizer = torch.optim.AdamW(param_groups, betas=tuple(args.optim.betas)) - - epoch_num_batches = len(train_loader) - steps_per_epoch = epoch_num_batches // args.optim.accum_iter - total_steps = args.optim.epochs * steps_per_epoch - warmup_steps = args.optim.warmup_epochs * steps_per_epoch - lr_schedule = ut.WarmupThenCosine( - base_value=args.optim.lr, final_value=args.optim.min_lr, - total_iters=total_steps, warmup_iters=warmup_steps - ) - - loss_scaler = ut.GradScaler() if args.amp and args.amp_dtype != 'bfloat16' else None - - ut.load_model(args, model_without_ddp, optimizer, loss_scaler) - - print(f"start training for {args.optim.epochs} epochs") - for epoch in range(args.start_epoch, args.optim.epochs): - if args.distributed: train_loader.sampler.set_epoch(epoch) - - train_stats = train_one_epoch(args, model, train_loader, optimizer, loss_scaler, lr_schedule, epoch, device) - - - if args.output_dir: - ut.save_model(args, epoch, model_without_ddp, optimizer, loss_scaler) - - -def create_data_loaders(args: DictConfig): - base_transform = flat_data.make_flat_transform( - img_size=args.data.img_size, - clip_vmax=args.data.get("clip_vmax"), - normalize=args.data.get("normalize"), - random_crop=args.data.get("random_crop", False), - crop_kwargs=args.data.get("crop_kwargs"), - ) - - transform = SimCLRTransform(base_transform) - - data_loaders = {} - dataset_names = [args.train_dataset] + args.eval_datasets - for name in dataset_names: - config = args.datasets[name] - dataset = flat_data.FlatClipsDataset(root=config.root, transform=transform) - sampler = DistributedSampler(dataset, shuffle=config.shuffle) if args.distributed else None - - loader = flat_data.DataLoader( - dataset, batch_size=args.optim.batch_size, - collate_fn=simclr_collate, sampler=sampler, - shuffle=sampler is None and config.shuffle, - num_workers=args.num_workers, pin_memory=True, drop_last=True - ) - data_loaders[name] = loader - - train_loader = data_loaders.pop(args.train_dataset) - return train_loader, data_loaders - - -def train_one_epoch(args, model, data_loader, optimizer, loss_scaler, lr_schedule, epoch, device): - # --- This is the training engine, adapted for SimCLR/SimSiam --- - model.train() - metric_logger = ut.MetricLogger(delimiter=" ") - header = f'Train: [{epoch}]' - - epoch_num_batches = len(data_loader) - steps_per_epoch = epoch_num_batches // args.optim.accum_iter - - optimizer.zero_grad() - - for batch_idx, (batch_view_1, batch_view_2) in enumerate(metric_logger.log_every(data_loader, 100, header)): - - global_step = epoch * steps_per_epoch + (batch_idx + 1) // args.optim.accum_iter - lr = lr_schedule[global_step] - need_update = (batch_idx + 1) % args.optim.accum_iter == 0 - if need_update: ut.update_lr(optimizer.param_groups, lr) - - view_1 = batch_view_1['image'].to(device, non_blocking=True) - view_2 = batch_view_2['image'].to(device, non_blocking=True) - - with torch.autocast(device_type=device.type, dtype=getattr(torch, args.amp_dtype), enabled=args.amp): - - outputs = model(view_1, view_2, mask_ratio=args.model.mask_ratio) - - if args.model.contrastive_mode == "simclr": - z1, z2 = outputs - loss = nt_xent_loss(z1, z2, temperature=args.model.get("temperature", 0.5), distributed=args.distributed) - elif args.model.contrastive_mode == "simsiam": - p1, z2, p2, z1 = outputs - loss = simsiam_loss(p1, z2, p2, z1) - else: - raise ValueError(f"Unknown contrastive mode: {args.model.contrastive_mode}") - - loss_value = loss.item() - if not math.isfinite(loss_value): raise RuntimeError(f"Loss is {loss_value}, stopping training") - - ut.backward_step(loss / args.optim.accum_iter, optimizer, scaler=loss_scaler, - need_update=need_update, max_norm=args.optim.get("clip_grad")) - - metric_logger.update(loss=loss_value) - - return {k: meter.global_avg for k, meter in metric_logger.meters.items()} - -@torch.no_grad() -def evaluate(model, data_loader, name, device, args): - model.eval() - print(f"--- Running evaluation on {name} ---") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--cfg-path", type=str, required=True) - parser.add_argument("--overrides", type=str, default=None, nargs="+") - cli_args = parser.parse_args() - - cfg = OmegaConf.load(cli_args.cfg_path) - if cli_args.overrides: - cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(cli_args.overrides)) - - main(cfg) From 80b73827b289ad240c72825ee0090140e13af7e8 Mon Sep 17 00:00:00 2001 From: Dojo2024 Date: Fri, 31 Oct 2025 09:13:56 +0530 Subject: [PATCH 6/7] refactor(simclr): pre-training script and resolve runtime errors This commit. --- src/simclr/config/hcp_pretrain.yaml | 94 +++++++++++++++++++++++++++++ src/simclr/main_pretrain.py | 92 ++++++++++++++++++++++------ src/simclr/models.py | 2 +- 3 files changed, 167 insertions(+), 21 deletions(-) create mode 100644 src/simclr/config/hcp_pretrain.yaml diff --git a/src/simclr/config/hcp_pretrain.yaml b/src/simclr/config/hcp_pretrain.yaml new file mode 100644 index 0000000..19c6208 --- /dev/null +++ b/src/simclr/config/hcp_pretrain.yaml @@ -0,0 +1,94 @@ +# In src/simclr/config/hcp_pretrain.yaml + +# A professional config for pre-training SimCLR on the full HCP dataset. +# Adapted from the flat_mae/config/default_pretrain.yaml. + +name: "simclr_pretrain_hcp_100_shards" +notes: "Pre-training run of SimCLR with a ViT-Small backbone on 100 local shards of the HCP-YA dataset." +output_dir: "checkpoints/${name}" + +# --- Data Configuration --- +# Inherited from the MAE config +data: + img_size: [224, 560] + in_chans: 1 + num_frames: 16 # <-- ADD THIS LINE + clip_vmax: 3.0 + normalize: 'frame' + random_crop: true + crop_kwargs: + scale: [0.8, 1.0] + ratio: [2.4, 2.6] + +# --- Datasets --- +# Using the full dataset definitions from the MAE config +train_dataset: 'hcp-flat' + +eval_datasets: ['hcp-val', 'nsd-val'] # Evaluate on both in-domain and out-of-domain data + +datasets: + hcp-flat: + type: flat-wds + # This now uses brace expansion to read all 100 local files. + url: "datasets/hcp-flat/hcp-flat_{0000..0099}.tar" + clipping: random + clipping_kwargs: {oversample: 4.0} # Oversample to get more data per epoch + shuffle: true + buffer_size: 1000 + # Update samples per epoch: ~100 samples/shard * 100 shards = 10000 + samples_per_epoch: 10000 + + hcp-val: + type: flat-clips + root: "/teamspace/gcs_folders/share/fmri-fm/datasets/flat-clips/hcp-val-clips-16t" # Assumes this path on Lightning + shuffle: false + + nsd-val: + type: flat-clips + root: "/teamspace/gcs_folders/share/fmri-fm/datasets/flat-clips/nsd-subj01-clips-16t" # Assumes this path on Lightning + shuffle: false + +# --- Model Configuration (Specific to our ContrastiveModel) --- +model: + contrastive_mode: "simclr" # Can be switched to "simsiam" + mask_ratio: 0.0 # IMPORTANT: Set to 0.0 for standard SimCLR, which doesn't use masking in the encoder + temperature: 0.1 # A common temperature for SimCLR + + # Define the backbone to use + backbone_name: "mae_vit_small" + backbone_kwargs: {} + # Add other ViT-Base specific args if needed + + # Define the projection head architecture + head_kwargs: + simclr_head_kwargs: + hidden_dim: 4096 # Example hidden dimension + out_dim: 128 # Standard output dimension for SimCLR + +# --- Optimization Configuration --- +optim: + epochs: 100 + batch_size: 128 # A realistic batch size for a single large GPU + accum_iter: 1 + base_lr: 1e-3 + lr: null # Let the script calculate the final LR based on batch size + min_lr: 0.0 + warmup_epochs: 10 + weight_decay: 0.05 + betas: [0.9, 0.95] + +# --- General Settings --- +num_workers: 8 +amp: true +amp_dtype: float16 # <-- ADD THIS LINE +seed: 42 +start_epoch: 0 +checkpoint_period: 10 # Save a checkpoint every 10 epochs +device: 'cuda' +ckpt: null # <-- ADD THIS LINE + + +# --- Logging --- +wandb: true +wandb_entity: "medarc" # Use the team's entity +wandb_project: "fMRI-foundation-model" \ No newline at end of file diff --git a/src/simclr/main_pretrain.py b/src/simclr/main_pretrain.py index 18ba4f0..90bedc6 100644 --- a/src/simclr/main_pretrain.py +++ b/src/simclr/main_pretrain.py @@ -6,6 +6,7 @@ import time from pathlib import Path import math +from webdataset import WebLoader import numpy as np import torch @@ -27,10 +28,10 @@ PROJECT = "fMRI-foundation-model" - BACKBONE_MODELS_DICT = models_mae.__dict__ def main(args: DictConfig): + ut.init_distributed_mode(args) global_rank = ut.get_rank() is_master = global_rank == 0 @@ -50,7 +51,12 @@ def main(args: DictConfig): print(f"pretraining with {args.model.contrastive_mode}") print("config:", OmegaConf.to_yaml(args), sep="\n") - train_loader, eval_loaders = create_data_loaders(args) + train_loader, eval_loaders, samplers = make_data_loaders(args) + + + + + print(f"Creating backbone: {args.model.backbone_name}") backbone = BACKBONE_MODELS_DICT[args.model.backbone_name]( @@ -58,10 +64,16 @@ def main(args: DictConfig): in_chans=args.data.in_chans, **args.model.get("backbone_kwargs", {}), ) + + + backbone_embed_dim = backbone.encoder.patch_embed.out_features + print(f"Backbone created with embedding dimension: {backbone_embed_dim}") + + model = ContrastiveModel( backbone=backbone, mode=args.model.contrastive_mode, - embed_dim=args.model.backbone_kwargs.embed_dim, + embed_dim=backbone_embed_dim, model_kwargs=args.model.get("head_kwargs"), ).to(device) @@ -104,7 +116,8 @@ def main(args: DictConfig): ut.save_model(args, epoch, model_without_ddp, optimizer, loss_scaler) -def create_data_loaders(args: DictConfig): + +def make_data_loaders(args: DictConfig): base_transform = flat_data.make_flat_transform( img_size=args.data.img_size, clip_vmax=args.data.get("clip_vmax"), @@ -112,30 +125,69 @@ def create_data_loaders(args: DictConfig): random_crop=args.data.get("random_crop", False), crop_kwargs=args.data.get("crop_kwargs"), ) - transform = SimCLRTransform(base_transform) data_loaders = {} - dataset_names = [args.train_dataset] + args.eval_datasets - for name in dataset_names: - config = args.datasets[name] - dataset = flat_data.FlatClipsDataset(root=config.root, transform=transform) - sampler = DistributedSampler(dataset, shuffle=config.shuffle) if args.distributed else None - - loader = flat_data.DataLoader( - dataset, batch_size=args.optim.batch_size, - collate_fn=simclr_collate, sampler=sampler, - shuffle=sampler is None and config.shuffle, - num_workers=args.num_workers, pin_memory=True, drop_last=True - ) - data_loaders[name] = loader + samplers = {} + + world_size = ut.get_world_size() + + all_dataset_names = [args.train_dataset] + args.eval_datasets + + for dataset_name in all_dataset_names: + if not dataset_name: continue + + dataset_config = args.datasets[dataset_name].copy() + print(f"loading dataset: {dataset_name}\n\n{OmegaConf.to_yaml(dataset_config)}") + + dataset_type = dataset_config.pop("type") + + if dataset_type == "flat-wds": + samples_per_epoch = dataset_config.pop("samples_per_epoch") + dataset = flat_data.make_flat_wds_dataset( + num_frames=args.data.num_frames, **dataset_config + ) + dataset = dataset.map(transform) + sampler = None + shuffle = False + + elif dataset_type == "flat-clips": + dataset = flat_data.FlatClipsDataset(dataset_config.root, transform=transform) + samples_per_epoch = len(dataset) + sampler = DistributedSampler(dataset, shuffle=dataset_config.shuffle) if args.distributed else None + shuffle = sampler is None and dataset_config.shuffle + + else: + raise ValueError(f"Unknown dataset type {dataset_type}.") + + + collate_fn = simclr_collate + + loader = WebLoader( + dataset, + batch_size=args.optim.batch_size, + collate_fn=collate_fn, + sampler=sampler, + shuffle=shuffle, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True, + ) + + num_batches = samples_per_epoch // (world_size * args.optim.batch_size) + loader = loader.with_epoch(num_batches) + loader = loader.with_length(num_batches, silent=True) + + data_loaders[dataset_name] = loader + samplers[dataset_name] = sampler + train_loader = data_loaders.pop(args.train_dataset) - return train_loader, data_loaders + eval_loaders = data_loaders + return train_loader, eval_loaders, samplers def train_one_epoch(args, model, data_loader, optimizer, loss_scaler, lr_schedule, epoch, device): - # --- This is the training engine, adapted for SimCLR/SimSiam --- model.train() metric_logger = ut.MetricLogger(delimiter=" ") header = f'Train: [{epoch}]' diff --git a/src/simclr/models.py b/src/simclr/models.py index 8a39da9..1e05bc4 100644 --- a/src/simclr/models.py +++ b/src/simclr/models.py @@ -18,7 +18,7 @@ ) class ContrastiveModel(nn.Module): - def __init__(self, backbone: MaskedViT, mode: str = "simclr", embed_dim: int = 384): + def __init__(self, backbone: MaskedViT, mode: str = "simclr", embed_dim: int = 384, model_kwargs: dict = None): super().__init__() if mode not in ["simclr", "simsiam"]: raise ValueError(f"Invalid contrastive mode: {mode}") From 8b7d00fb568107c44b9cc8e7b3e23967b9c908d6 Mon Sep 17 00:00:00 2001 From: Dojo2024 Date: Fri, 31 Oct 2025 19:47:33 +0530 Subject: [PATCH 7/7] feat: Implement and debug SimCLR/SimSiam pre-training script --- src/simclr/config/hcp_pretrain.yaml | 135 ++++++++++++++++------------ 1 file changed, 80 insertions(+), 55 deletions(-) diff --git a/src/simclr/config/hcp_pretrain.yaml b/src/simclr/config/hcp_pretrain.yaml index 19c6208..7cc9668 100644 --- a/src/simclr/config/hcp_pretrain.yaml +++ b/src/simclr/config/hcp_pretrain.yaml @@ -1,94 +1,119 @@ -# In src/simclr/config/hcp_pretrain.yaml +# Name of the run. For output directory base name and wandb. +name: pretrain_simclr_vit_base -# A professional config for pre-training SimCLR on the full HCP dataset. -# Adapted from the flat_mae/config/default_pretrain.yaml. +# Description of the run. Goes in wandb notes. +notes: "SimCLR pre-training with a ViT-Base backbone on the HCP dataset." -name: "simclr_pretrain_hcp_100_shards" -notes: "Pre-training run of SimCLR with a ViT-Small backbone on 100 local shards of the HCP-YA dataset." -output_dir: "checkpoints/${name}" +# Root output directory +output_dir: ./checkpoints -# --- Data Configuration --- -# Inherited from the MAE config +# How often to print logs to the console during training/evaluation. +print_freq: 100 + +# --- Data Config --- +# All parameters related to data shape and transformations. data: - img_size: [224, 560] + # The 2D spatial size of the input images [height, width]. + # The model gets the time dimension from `num_frames`. + img_size: [224, 560] in_chans: 1 - num_frames: 16 # <-- ADD THIS LINE + patch_size: 16 + num_frames: 16 + t_patch_size: 16 clip_vmax: 3.0 - normalize: 'frame' - random_crop: true + normalize: frame + random_crop: false crop_kwargs: - scale: [0.8, 1.0] - ratio: [2.4, 2.6] - -# --- Datasets --- -# Using the full dataset definitions from the MAE config -train_dataset: 'hcp-flat' + scale: [0.9, 1.0] + ratio: [2.5, 2.5] + interpolation: 3 -eval_datasets: ['hcp-val', 'nsd-val'] # Evaluate on both in-domain and out-of-domain data +# --- Model Config --- +model: + contrastive_mode: simclr + backbone_name: mae_vit_base + mask_ratio: 0.9 + temperature: 0.1 + + # Arguments passed to the backbone model constructor. + # Architectural details like embed_dim are set by the `backbone_name` preset. + backbone_kwargs: + pos_embed: sep + class_token: true + drop_path_rate: 0.0 + + # Arguments passed to the projection/prediction head constructors. + head_kwargs: + hidden_dim: 2048 + out_dim: 128 +# --- Datasets --- +# Replace the placeholder paths with the actual locations of your datasets. datasets: - hcp-flat: + hcp-train: type: flat-wds - # This now uses brace expansion to read all 100 local files. - url: "datasets/hcp-flat/hcp-flat_{0000..0099}.tar" + url: "path/to/your/hcp-flat/hcp-flat_{0000..1799}.tar" clipping: random - clipping_kwargs: {oversample: 4.0} # Oversample to get more data per epoch + clipping_kwargs: {oversample: 4.0} shuffle: true buffer_size: 1000 - # Update samples per epoch: ~100 samples/shard * 100 shards = 10000 - samples_per_epoch: 10000 + samples_per_epoch: 200000 + + hcp-train-subset: + type: flat-clips + root: "path/to/your/flat-clips/hcp-train-clips-16t" + shuffle: false hcp-val: type: flat-clips - root: "/teamspace/gcs_folders/share/fmri-fm/datasets/flat-clips/hcp-val-clips-16t" # Assumes this path on Lightning + root: "path/to/your/flat-clips/hcp-val-clips-16t" shuffle: false nsd-val: type: flat-clips - root: "/teamspace/gcs_folders/share/fmri-fm/datasets/flat-clips/nsd-subj01-clips-16t" # Assumes this path on Lightning + root: "path/to/your/flat-clips/nsd-subj01-clips-16t" shuffle: false -# --- Model Configuration (Specific to our ContrastiveModel) --- -model: - contrastive_mode: "simclr" # Can be switched to "simsiam" - mask_ratio: 0.0 # IMPORTANT: Set to 0.0 for standard SimCLR, which doesn't use masking in the encoder - temperature: 0.1 # A common temperature for SimCLR - - # Define the backbone to use - backbone_name: "mae_vit_small" - backbone_kwargs: {} - # Add other ViT-Base specific args if needed +# Which datasets to use for training and evaluation. +train_dataset: hcp-train +eval_datasets: + - hcp-val + - nsd-val - # Define the projection head architecture - head_kwargs: - simclr_head_kwargs: - hidden_dim: 4096 # Example hidden dimension - out_dim: 128 # Standard output dimension for SimCLR +# --- Data Loader --- +num_workers: 8 -# --- Optimization Configuration --- +# --- Optimization --- optim: epochs: 100 - batch_size: 128 # A realistic batch size for a single large GPU + batch_size: 32 accum_iter: 1 base_lr: 1e-3 - lr: null # Let the script calculate the final LR based on batch size min_lr: 0.0 - warmup_epochs: 10 + warmup_epochs: 5 + start_warmup_lr: 1e-6 weight_decay: 0.05 betas: [0.9, 0.95] + clip_grad: 1.0 -# --- General Settings --- -num_workers: 8 +# --- Training Settings --- amp: true -amp_dtype: float16 # <-- ADD THIS LINE -seed: 42 +amp_dtype: float16 + +# --- Checkpointing --- +ckpt: null +resume: true +auto_resume: true start_epoch: 0 -checkpoint_period: 10 # Save a checkpoint every 10 epochs -device: 'cuda' -ckpt: null # <-- ADD THIS LINE +max_checkpoints: 1 +checkpoint_period: 1 +# --- Misc --- +device: cuda +seed: 7338 +debug: false # --- Logging --- wandb: true -wandb_entity: "medarc" # Use the team's entity -wandb_project: "fMRI-foundation-model" \ No newline at end of file +wandb_entity: null +wandb_project: fMRI-foundation-model \ No newline at end of file