diff --git a/.gitignore b/.gitignore index 723ef36..79c2d95 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ -.idea \ No newline at end of file +.idea +__pycache__ +deploy.sh \ No newline at end of file diff --git a/Dockerfile b/Dockerfile.python similarity index 100% rename from Dockerfile rename to Dockerfile.python index 99bced7..a46839d 100644 --- a/Dockerfile +++ b/Dockerfile.python @@ -23,10 +23,6 @@ RUN mkdir /data RUN chmod 775 /data RUN chown -R :1337 /data -COPY src /src -RUN chmod 775 /src -RUN chown -R :1337 /src - RUN pip3 install transformers RUN pip3 install torchcodec RUN pip3 install h5py nilearn nibabel @@ -40,6 +36,10 @@ RUN pip3 install wandb matplotlib scikit-learn RUN pip3 freeze > pip-freeze.txt +COPY src /src +RUN chmod 775 /src +RUN chown -R :1337 /src + WORKDIR /src ENTRYPOINT ["python3"] \ No newline at end of file diff --git a/Dockerfile.sweep b/Dockerfile.sweep new file mode 100644 index 0000000..2b7be2d --- /dev/null +++ b/Dockerfile.sweep @@ -0,0 +1,2 @@ +FROM eidos-service.di.unito.it/barbano/algonauts:latest +ENTRYPOINT ["/opt/conda/bin/wandb", "agent"] \ No newline at end of file diff --git a/build.sh b/build.sh index 024145c..4a7f41a 100755 --- a/build.sh +++ b/build.sh @@ -1,5 +1,8 @@ #!/bin/sh # Script to build & deploy your docker image -docker build -t eidos-service.di.unito.it/barbano/algonauts:latest . -f Dockerfile -docker push eidos-service.di.unito.it/barbano/algonauts:latest \ No newline at end of file +docker build -t eidos-service.di.unito.it/barbano/algonauts:latest . -f Dockerfile.python +docker push eidos-service.di.unito.it/barbano/algonauts:latest + +docker build -t eidos-service.di.unito.it/barbano/algonauts:sweep . 
-f Dockerfile.sweep +docker push eidos-service.di.unito.it/barbano/algonauts:sweep \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 72ce6fb..c6c5b8c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,41 +1,90 @@ +absl-py==2.2.2 +annotated-types==0.7.0 +awscli==1.38.34 +botocore==1.37.34 certifi==2025.1.31 charset-normalizer==3.4.1 -filelock==3.17.0 -fsspec==2025.3.0 +click==8.1.8 +colorama==0.4.6 +contourpy==1.3.2 +cycler==0.12.1 +docker-pycreds==0.4.0 +docutils==0.19 +filelock==3.18.0 +fonttools==4.57.0 +fsspec==2025.3.2 +gitdb==4.0.12 +GitPython==3.1.44 +grpcio==1.71.0 h5py==3.13.0 -huggingface-hub==0.29.3 +huggingface-hub==0.30.2 idna==3.10 +importlib_resources==6.5.2 Jinja2==3.1.6 +jmespath==1.0.1 joblib==1.4.2 -lxml==5.3.1 +kiwisolver==1.4.8 +lxml==5.3.2 +Markdown==3.8 MarkupSafe==3.0.2 +matplotlib==3.10.1 mpmath==1.3.0 networkx==3.4.2 nibabel==5.3.2 nilearn==0.11.1 -numpy==2.2.3 +numpy==2.2.4 +nvidia-cublas-cu12==12.4.5.8 +nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.2.1.3 +nvidia-curand-cu12==10.3.5.147 +nvidia-cusolver-cu12==11.6.1.9 +nvidia-cusparse-cu12==12.3.1.170 +nvidia-cusparselt-cu12==0.6.2 +nvidia-nccl-cu12==2.21.5 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.4.127 packaging==24.2 pandas==2.2.3 -pillow==11.1.0 +pillow==11.2.1 +platformdirs==4.3.7 +protobuf==5.29.4 +psutil==7.0.0 +pyasn1==0.6.1 +pydantic==2.11.3 +pydantic_core==2.33.1 +pyparsing==3.2.3 python-dateutil==2.9.0.post0 -pytz==2025.1 +pytz==2025.2 PyYAML==6.0.2 regex==2024.11.6 requests==2.32.3 +rsa==4.7.2 +s3transfer==0.11.4 safetensors==0.5.3 scikit-learn==1.6.1 scipy==1.15.2 -setuptools==76.0.0 +sentry-sdk==2.26.1 +setproctitle==1.3.5 six==1.17.0 +smmap==5.0.2 sympy==1.13.1 -threadpoolctl==3.5.0 -tokenizers==0.21.0 +tensorboard==2.19.0 +tensorboard-data-server==0.7.2 +threadpoolctl==3.6.0 +tokenizers==0.21.1 torch==2.6.0 
-torchaudio==2.6.0 TorchCodec==0.2.1 torchvision==0.21.0 tqdm==4.67.1 -transformers==4.49.0 -typing_extensions==4.12.2 -tzdata==2025.1 -urllib3==2.3.0 +transformers==4.51.3 +triton==3.2.0 +typing-inspection==0.4.0 +typing_extensions==4.13.2 +tzdata==2025.2 +urllib3==2.4.0 +wandb==0.19.9 +Werkzeug==3.1.3 +natsort \ No newline at end of file diff --git a/src/data/friends.py b/src/data/friends.py index e3c65cc..de6e19f 100644 --- a/src/data/friends.py +++ b/src/data/friends.py @@ -6,6 +6,8 @@ import os import torch import torchvision +import logging +import numpy as np from torch.utils.data.dataset import Dataset from torchcodec.decoders import VideoDecoder @@ -14,14 +16,66 @@ from brainannlib.algonauts_funcs import load_fmri +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + + +def load_video_chunk(path, sample_index, tr=1.49, target_video_len=32, + transform=None, stimulus_window=1, hrf_delay=0, device="cpu"): + + # compute the starting index of the current stimulus, + # automatically adjusts for out of bounds windows + # as long as target_video_len < tr*fps (~44) + start_index = max(0, sample_index - stimulus_window - hrf_delay + 1) + start_t = start_index * tr + end_t = (start_index + 1) * tr + + decoder = VideoDecoder(path, device=device) + if end_t > decoder.metadata.duration_seconds: + end_t = decoder.metadata.duration_seconds + + chunk = decoder.get_frames_played_in_range(start_t, end_t).data + + # if the chunk is shorter than self.target_video_len, pad it + if len(chunk) < target_video_len: + # Pad with last frame + chunk = torch.cat([chunk, chunk[-1].unsqueeze(0).expand(target_video_len - len(chunk), -1, -1, -1)]) + + # if the chunk is longer than self.target_video_len, take N frames uniformly + if len(chunk) > target_video_len: + idx = np.linspace(0, len(chunk) - 1, target_video_len).astype(int) + chunk = chunk[idx] + + if transform is not None: + lst = torch.split(chunk, 1, 0) + lst = [l[0] for l in lst] + chunk = 
transform(lst, return_tensors="pt") + + return chunk + + class FriendsDataset(Dataset): - def __init__(self, root, modalities=["fmri", "video"], image_transform=None, subjects=[1,2,3,5], timesample=1, - target_video_len=32): + def __init__(self, root, modalities=["fmri", "video"], image_transform=None, tr=1.49, seasons=[1,2,3,4,5,6], + subjects=[1,2,3,5], timesample=1, target_video_len=32, stimulus_window=1, hrf_delay=0, + downsampled=False, fmri_window=1): self.root = root self.modalities = modalities + self.seasons = seasons self.image_transform = image_transform self.timesample = timesample self.target_video_len = target_video_len + self.downsampled = downsampled + self.tr = tr + self.stimulus_window = stimulus_window + self.hrf_delay = hrf_delay + self.subjects = subjects + + if fmri_window % 2 == 0: + raise ValueError("fmri_window must be odd, got {}".format(fmri_window)) + self.fmri_window = fmri_window // 2 + # List all .h5 files in the root directory self.fmris = [] @@ -37,7 +91,11 @@ def __init__(self, root, modalities=["fmri", "video"], image_transform=None, sub curr_fmri = fmri[key] fmri_samples = curr_fmri.shape[0] - self.fmris.append({"movie": key, "fmri": curr_fmri, "n_samples": fmri_samples}) + season = int(key[1:3]) + if season not in self.seasons: + continue + + self.fmris.append({"movie": key, "fmri": curr_fmri, "n_samples": fmri_samples, "subject": subject}) # Map all indexes between (tot_samples, tot_samples+fmri_samples) to current fmri index self.scan_idx_map.extend([last_idx] * fmri_samples) @@ -47,13 +105,14 @@ def __init__(self, root, modalities=["fmri", "video"], image_transform=None, sub print("Loaded", len(self.fmris), "fmri files, total samples:", self.tot_samples) - def load_movie(self, movie_name) -> VideoDecoder: - movie_folder = os.path.join(self.root, "algonauts_2025.competitors", "stimuli", "movies", "friends") + def get_movie_path(self, movie_name) -> VideoDecoder: + movie_folder = os.path.join(self.root, 
"algonauts_2025.competitors/stimuli/movies/friends") + if self.downsampled: + movie_folder = os.path.join(self.root, "algonauts_2025.competitors/stimuli/movies_224/friends") + season = int(movie_name[1:3]) episode_path = os.path.join(movie_folder, f"s{season}", f"friends_{movie_name}.mkv") - - decoder = VideoDecoder(episode_path, device="cpu") - return decoder + return episode_path def __len__(self): return self.tot_samples @@ -63,44 +122,129 @@ def __getitem__(self, idx): sample_index = self.time_idx_map[idx] fmri = self.fmris[fmri_index] - fmri_data = fmri["fmri"][sample_index] + fmri_data = None + + if self.fmri_window == 0: + fmri_data = torch.tensor(fmri["fmri"][sample_index]) + + elif sample_index - self.fmri_window - 1 >= 0 and sample_index + self.fmri_window < fmri["n_samples"]: + fmri_data = torch.tensor(fmri["fmri"][sample_index - self.fmri_window - 1:sample_index + self.fmri_window]) + + elif sample_index - self.fmri_window - 1 < 0: + fmri_data = torch.tensor(fmri["fmri"][:sample_index + self.fmri_window]) + # print("Padding with first frame:", fmri_data.shape, fmri_data[0].shape, sample_index, self.fmri_window) + # pad with the first frame at the beginning to reach self.fmri_window + pad = fmri_data[0].unsqueeze(0).repeat(self.fmri_window + 1 - sample_index, 1) + fmri_data = torch.cat([pad, fmri_data], dim=0) + + elif sample_index + self.fmri_window >= fmri["n_samples"]: + fmri_data = torch.tensor(fmri["fmri"][sample_index - self.fmri_window - 1:]) + # print("Padding with last frame:", fmri_data.shape, fmri_data[0].shape, sample_index, self.fmri_window) + # pad with the last frame at the end to reach self.fmri_window + pad = fmri_data[-1].unsqueeze(0).repeat(sample_index + self.fmri_window - fmri["n_samples"], 1) + fmri_data = torch.cat([fmri_data, pad], dim=0) + + assert fmri_data is not None, f"fmri_data is None for sample {sample_index} in movie {fmri['movie']}" + + movie_path = self.get_movie_path(fmri["movie"]) + movie_chunk = 
load_video_chunk(movie_path, sample_index, tr=self.tr, + target_video_len=self.target_video_len, + transform=self.image_transform, + stimulus_window=self.stimulus_window, + hrf_delay=self.hrf_delay) + return movie_chunk, fmri_data, float(fmri["subject"]) + + +class FriendsFeatureDataset(Dataset): + def __init__(self, root, features_root, subjects=[1,2,3,5], seasons=[1,2,3,4,5,6], + stimulus_window=3): + self.root = root + self.features_root = features_root + self.seasons = seasons + self.subjects = subjects + self.stimulus_window = stimulus_window + + # List all .h5 files in the root directory + self.fmris = [] + self.tot_samples = 0 + self.scan_idx_map = [] + self.time_idx_map = [] + + last_idx = 0 + for subject in subjects: + fmri = load_fmri(root, subject, targets=["friends"]) - video = self.load_movie(fmri["movie"]) - n_frames = len(video) - n_samples = fmri["n_samples"] + for key in fmri.keys(): + curr_fmri = fmri[key] + fmri_samples = curr_fmri.shape[0] - window_length = n_frames // n_samples - start_frame = sample_index * window_length - end_frame = start_frame + window_length + season = int(key[1:3]) + if season not in self.seasons: + continue - # Ensure that the end frame does not exceed the number of frames - if end_frame > n_frames: - end_frame = n_frames - # print("Adjusted end frame: ", end_frame) + self.fmris.append({"movie": key, "fmri": curr_fmri, "n_samples": fmri_samples, "subject": subject}) - video_data = video[start_frame:end_frame:self.timesample] # TxCxHxW + # Map all indexes between (tot_samples, tot_samples+fmri_samples) to current fmri index + self.scan_idx_map.extend([last_idx] * fmri_samples) + self.time_idx_map.extend(list(range(fmri_samples))) + self.tot_samples += fmri_samples + last_idx += 1 + + print("Loaded", len(self.fmris), "fmri files, total samples:", self.tot_samples) - # If the video length is less than the target length, pad it - if len(video_data) < self.target_video_len: - video_data = torch.cat([video_data, 
video_data[-1].unsqueeze(0).expand(self.target_video_len - len(video_data), -1, -1, -1)]) + def load_movie_features(self, movie_name) -> torch.Tensor: + movie_folder = os.path.join(self.features_root, "friends") + + season = int(movie_name[1:3]) + episode_path = os.path.join(movie_folder, f"s{season}", f"friends_{movie_name}.pth") + + features = torch.load(episode_path, map_location="cpu") + return features + + def __len__(self): + return self.tot_samples + + def __getitem__(self, idx): + fmri_index = self.scan_idx_map[idx] + sample_index = self.time_idx_map[idx] + + fmri = self.fmris[fmri_index] - # If the video length is greater than the target length, truncate it (from the end) - elif len(video_data) > self.target_video_len: - video_data = video_data[-self.target_video_len:] + if sample_index >= fmri["n_samples"]: + fmri_data = fmri["fmri"][-1] + else: + fmri_data = fmri["fmri"][sample_index] - if self.image_transform is not None: - lst = torch.split(video_data, 1, 0) - lst = [l[0] for l in lst] - video_data = self.image_transform(lst, return_tensors="pt") + features = self.load_movie_features(fmri["movie"]) + # logging.info(f"Loaded features for {fmri['movie']} with shape {features.shape} for subject {fmri['subject']}") - return video_data, fmri_data + # assert len(features) == fmri["n_samples"], f"Features length {len(features)} does not match fmri samples {fmri['n_samples']}" + # get the features for the current sample in the range sample_index-self.stimulus_window:sample_index + if sample_index - self.stimulus_window < 0: + # pad with zeros to reach self.stimulus_window + pad = torch.zeros(self.stimulus_window - sample_index, features.shape[1], device=features.device) + features = torch.cat([pad, features[:sample_index]], dim=0) + elif sample_index >= features.shape[0]: + features = features[-1 - self.stimulus_window:-1] + else: + features = features[sample_index - self.stimulus_window:sample_index] + return features.flatten(), fmri_data +if 
__name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--data_dir', type=str, default="/home/barbano/data") + parser.add_argument('--features_dir', type=str) + args = parser.parse_args() + dataset = FriendsFeatureDataset(root=args.data_dir, features_root=args.features_dir) + sample = dataset[0] + print(sample[0].shape, sample[1].shape) diff --git a/src/models/__init__.py b/src/models/__init__.py index 74c9774..047565f 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -2,4 +2,6 @@ Author: Carlo Alberto Barbano Date: 12/04/25 """ -from . import vivit \ No newline at end of file +from . import vivit +from . import videomae +from . import predictors \ No newline at end of file diff --git a/src/models/predictors.py b/src/models/predictors.py new file mode 100644 index 0000000..67e64f2 --- /dev/null +++ b/src/models/predictors.py @@ -0,0 +1,42 @@ +""" +Author: Carlo Alberto Barbano +Date: 21/04/25 +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from scipy.stats import pearsonr + + +class MLPSmall(nn.Module): + def __init__(self, input_dim, output_dim, dropout=0.0): + super().__init__() + + self.mlp = nn.Sequential( + nn.Dropout(dropout), + nn.BatchNorm1d(input_dim), + nn.Linear(input_dim, input_dim // 2), + nn.ReLU(), + nn.BatchNorm1d(input_dim // 2), + nn.Linear(input_dim // 2, output_dim), + ) + + def forward(self, x, y=None): + x = self.mlp(x) + + if y is None: + return x + + loss = F.l1_loss(x, y) + + # compute average correlation of minibatch + x_ = x.detach().cpu().numpy() + y_ = y.detach().cpu().numpy() + + # r = [] + # for i in range(x_.shape[0]): + # r.append(pearsonr(x_[i], y_[i])[0]) + # r = torch.tensor(r).mean() + + return loss, x \ No newline at end of file diff --git a/src/models/videomae.py b/src/models/videomae.py new file mode 100644 index 0000000..6de705a --- /dev/null +++ b/src/models/videomae.py @@ -0,0 +1,61 @@ +""" +Author: Carlo Alberto Barbano +Date: 
02/05/25 +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers import VideoMAEModel, AutoImageProcessor + + +_SUBJECT_EMBEDDING_SIZE = 10 + +class VideoMAERegression(nn.Module): + def __init__(self, pretrained="MCG-NJU/videomae-base", n_parcels=1000, + torch_dtype=torch.float32, criterion="mae", freeze_encoder=False, use_subject_id=False): + super().__init__() + + self.processor = AutoImageProcessor.from_pretrained(pretrained) + self.video_encoder = VideoMAEModel.from_pretrained( + pretrained, + torch_dtype=torch_dtype + ) + self.video_projection = nn.Linear(768 + (_SUBJECT_EMBEDDING_SIZE if use_subject_id else 0), n_parcels, bias=True) + self.criterion = criterion + + if freeze_encoder: + for param in self.video_encoder.parameters(): + param.requires_grad = False + + self.subject_embedding = nn.Linear(1, _SUBJECT_EMBEDDING_SIZE) # nn.Embedding(1, _SUBJECT_EMBEDDING_SIZE) + + def image_processor(self): + return self.processor + + def encode_video(self, video): + return self.video_encoder(**video)[0][:, 0, :] + + def predict_activation(self, video_features, subject_id=None): + if subject_id is not None: + subject_embedding = self.subject_embedding(subject_id[:, None]) + video_features = torch.cat((video_features, subject_embedding), dim=1) + + logits = self.video_projection(video_features) + return logits + + + def forward(self, video, fmri, subject_id=None): + video_features = self.encode_video(video) + logits = self.predict_activation(video_features, subject_id) + + if len(fmri.shape) > 2: # fmris are stacked as [bsz, 1, 1000] by dataloaders + fmri = fmri.view(fmri.shape[0], -1) + + if self.criterion == "mae": + loss = F.l1_loss(logits, fmri) + elif self.criterion == "mse": + loss = F.mse_loss(logits, fmri) + + return loss, logits + diff --git a/src/models/vivit.py b/src/models/vivit.py index d935eb0..d2fd805 100644 --- a/src/models/vivit.py +++ b/src/models/vivit.py @@ -10,9 +10,48 @@ from transformers import VivitModel, 
VivitImageProcessor +class VivitRegression(nn.Module): + def __init__(self, vivit_pretrained="google/vivit-b-16x2-kinetics400", n_parcels=1000, + torch_dtype=torch.float32, criterion="mae", freeze_encoder=False): + super().__init__() + + self.vivit_processor = VivitImageProcessor.from_pretrained(vivit_pretrained) + self.video_encoder = VivitModel.from_pretrained( + vivit_pretrained, + attn_implementation="sdpa", + torch_dtype=torch_dtype + ) + self.video_projection = nn.Linear(768, n_parcels, bias=True) + self.criterion = criterion + + if freeze_encoder: + for param in self.video_encoder.parameters(): + param.requires_grad = False + + def image_processor(self): + return self.vivit_processor + + def encode_video(self, video): + return self.video_encoder(**video)[0][:, 0, :] + + def forward(self, video, fmri): + video_features = self.encode_video(video) + logits = self.video_projection(video_features) + + if len(fmri.shape) > 2: # fmris are stacked as [bsz, 1, 1000] by dataloaders + fmri = fmri.view(fmri.shape[0], -1) + + if self.criterion == "mae": + loss = F.l1_loss(logits, fmri) + elif self.criterion == "mse": + loss = F.mse_loss(logits, fmri) + + return loss, logits + + class VivitMLPContrastive(nn.Module): def __init__(self, vivit_pretrained="google/vivit-b-16x2-kinetics400", embed_dim=128, - temperature=1.0, torch_dtype=torch.float32): + temperature=1.0, fmri_window=1, torch_dtype=torch.float32, freeze_encoder=False): super().__init__() self.vivit_processor = VivitImageProcessor.from_pretrained(vivit_pretrained) @@ -24,8 +63,8 @@ def __init__(self, vivit_pretrained="google/vivit-b-16x2-kinetics400", embed_dim self.video_projection = nn.Linear(768, embed_dim, bias=False) self.fmri_encoder = torchvision.ops.MLP( - in_channels=1000, - hidden_channels=[512, embed_dim*2, embed_dim], + in_channels=1000*fmri_window, + hidden_channels=[512*fmri_window, embed_dim*2, embed_dim], activation_layer=torch.nn.ReLU, bias=True, dropout=0 @@ -33,20 +72,28 @@ def __init__(self, 
vivit_pretrained="google/vivit-b-16x2-kinetics400", embed_dim self.temperature = temperature + if freeze_encoder: + for param in self.video_encoder.parameters(): + param.requires_grad = False + def image_processor(self): return self.vivit_processor - def encode_video(self, video): - return self.video_encoder(**video).pooler_output + return self.video_encoder(**video)[0][:, 0, :] + def encode_fmri(self, fmri): + if len(fmri.shape) > 2: + fmri = fmri.reshape(fmri.shape[0], fmri.shape[1]*fmri.shape[2]) + fmri_features = self.fmri_encoder(fmri) + return fmri_features def forward(self, video, fmri): video_features = self.video_encoder(**video) - video_features = self.video_projection(video_features.pooler_output) + video_features = self.video_projection(video_features[0][:, 0, :]) - fmri_features = self.fmri_encoder(fmri) + fmri_features = self.encode_fmri(fmri) video_features = F.normalize(video_features, dim=-1) fmri_features = F.normalize(fmri_features, dim=-1) @@ -60,4 +107,29 @@ def forward(self, video, fmri): F.cross_entropy(logits_video, labels) + F.cross_entropy(logits_fmri, labels) ) / 2 - return loss + + return loss, video_features, fmri_features + + +class VivitConvContrastive(VivitMLPContrastive): + def __init__(self, vivit_pretrained="google/vivit-b-16x2-kinetics400", embed_dim=128, + temperature=1.0, fmri_window=1, torch_dtype=torch.float32, freeze_encoder=False): + super().__init__(vivit_pretrained, embed_dim, temperature, fmri_window, torch_dtype, freeze_encoder) + + self.fmri_encoder = nn.Sequential( + nn.Conv1d(1000, 1000, kernel_size=3, stride=1, padding=0, groups=1000), + nn.ReLU(), + nn.Conv1d(1000, 1000, kernel_size=3, stride=1, padding=0, groups=1000), + nn.ReLU(), + nn.AdaptiveAvgPool1d(1), + nn.Flatten(), + nn.Linear(1000, 512), + nn.ReLU(), + nn.Linear(512, embed_dim), + nn.ReLU(), + nn.Linear(embed_dim, embed_dim), + ) + + def encode_fmri(self, fmri): + fmri_features = self.fmri_encoder(fmri.permute(0, 2, 1)) + return fmri_features \ No 
newline at end of file diff --git a/src/scripts/collect_features/collect_vivit_activations.py b/src/scripts/collect_features/collect_vivit_activations.py new file mode 100644 index 0000000..a21d735 --- /dev/null +++ b/src/scripts/collect_features/collect_vivit_activations.py @@ -0,0 +1,161 @@ +""" +Author: Carlo Alberto Barbano +Date: 14/04/25 +""" +import models +import torch +import torch.multiprocessing as mp +import argparse +import os +import numpy as np +import logging + +from torchcodec.decoders import VideoDecoder +from glob import glob +from natsort import natsorted +from tqdm import tqdm + +from data.friends import load_video_chunk + +# mp.set_start_method("spawn", force=True) + +TORCHCODEC_DEVICE = "cpu" + +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + + +class FriendsStimuliVideoDataset(torch.utils.data.Dataset): + def __init__(self, root, transform, tr=1.49, timesample=1, target_video_len=32, downsampled=True, + stimulus_window=1, hrf_delay=0): + self.root = root + self.transform = transform + self.timesample = timesample + self.target_video_len = target_video_len + self.tr = tr + self.stimulus_window = stimulus_window + self.hrf_delay = hrf_delay + + data_dir = os.path.join(self.root, "algonauts_2025.competitors/stimuli/movies/friends/**/*.mkv") + if downsampled: + data_dir = os.path.join(self.root, "algonauts_2025.competitors/stimuli/movies_224/friends/**/*.mkv") + self.movies = natsorted(glob(data_dir)) + + self.chunks = [] + self.chunk_idx_to_movie_idx = {} + for movie_idx, movie in enumerate(self.movies): + decoder = VideoDecoder(movie, device=TORCHCODEC_DEVICE) + duration = decoder.metadata.duration_seconds + num_chunks = int(round(duration / self.tr)) + print(f"Movie {movie} has {duration:.2f} seconds and {num_chunks:.2f} chunks of {self.tr:.2f} seconds") + + # map the chunk index to the movie index + offset = len(self.chunk_idx_to_movie_idx) + + for i in range(num_chunks): + self.chunk_idx_to_movie_idx[i + 
offset] = (movie_idx, i) + + self.chunks.append(num_chunks) + + assert sum(self.chunks) == len(self.chunk_idx_to_movie_idx), f"The number of chunks ({sum(self.chunks)}) does not match the number of movies ({len(self.chunk_idx_to_movie_idx)})" + + def __len__(self): + return sum(self.chunks) + + def __getitem__(self, idx): + movie_idx, chunk_idx = self.chunk_idx_to_movie_idx[idx] + movie_path = self.movies[movie_idx] + + video_data = load_video_chunk(movie_path, chunk_idx, self.tr, self.target_video_len, self.transform, + stimulus_window=self.stimulus_window, hrf_delay=self.hrf_delay) + return video_data, movie_idx, chunk_idx + + +@torch.inference_mode() +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--weights', type=str) + parser.add_argument('--output_dir', type=str, help="output directory (only if weights is None)") + parser.add_argument('--data_dir', type=str) + parser.add_argument('--device', type=str, default='cuda') + parser.add_argument('--batch_size', type=int, default=10) + parser.add_argument('--num_workers', type=int, default=8) + parser.add_argument('--downsampled', action='store_true', help="Use downsampled videos (224x224)") + parser.add_argument('--amp', action='store_true', help="Use automatic mixed precision") + parser.add_argument('--stimulus_window', type=int, default=1) + parser.add_argument('--hrf_delay', type=int, default=0) + args = parser.parse_args() + + # Load the model + if args.weights: + checkpoint = torch.load(args.weights, map_location=args.device, weights_only=False) + model = models.vivit.VivitMLPContrastive(embed_dim=checkpoint['opts'].embed_dim, + temperature=checkpoint['opts'].temperature, + fmri_window=checkpoint['opts'].fmri_window,) + model.load_state_dict(checkpoint['model']) + model = model.to(args.device) + image_processor = model.image_processor() + print("Model loaded from", args.weights) + else: + model = models.vivit.VivitMLPContrastive(embed_dim=128, temperature=1.) 
+ image_processor = model.image_processor() + model = model.to(args.device) + print("Model initialized") + + dataset = FriendsStimuliVideoDataset(args.data_dir, transform=image_processor, + downsampled=args.downsampled, + stimulus_window=args.stimulus_window, + hrf_delay=args.hrf_delay) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, + shuffle=False, num_workers=args.num_workers, + pin_memory=True) + print("Dataset loaded. Tot chunks:", len(dataset)) + + output_dir = f"features_w{args.stimulus_window}_hrf{args.hrf_delay}/friends/" + if args.weights: + output_dir = os.path.join(os.path.dirname(args.weights), output_dir) + else: + output_dir = os.path.join(args.output_dir, output_dir) + os.makedirs(output_dir, exist_ok=True) + + curr_movie_idx = 0 + curr_movie_name = os.path.basename(dataset.movies[0].replace(".mkv", ".pth")) + prev_chunk = -1 + episode_features = [] + + model.eval() + for idx, (video, movie_idx, chunk_idx) in enumerate(tqdm(dataloader)): + video = video.to(args.device, non_blocking=True) + video['pixel_values'] = video['pixel_values'].squeeze(1) + + with torch.amp.autocast("cuda", enabled=args.amp): + features = model.encode_video(video) + + for features_, movie_idx_, chunk_idx_ in zip(features, movie_idx, chunk_idx): + if movie_idx_ != curr_movie_idx: + print(curr_movie_name, curr_movie_name[9:11]) + season = int(curr_movie_name[9:11]) + + episode_features = torch.stack(episode_features, dim=0) + season_path = os.path.join(output_dir, f"s{season}") + os.makedirs(season_path, exist_ok=True) + + episode_path = os.path.join(season_path, curr_movie_name) + logging.info(f"Saving features for episode {curr_movie_name} to: {episode_path} (shape: {episode_features.shape})") + + torch.save(episode_features.cpu(), episode_path) + episode_features = [] + prev_chunk = -1 + curr_movie_name = os.path.basename(dataset.movies[movie_idx_.item()]).replace(".mkv", ".pth") + + + episode_features.append(features_) + + assert chunk_idx_ > 
prev_chunk + curr_movie_idx = movie_idx_ + prev_chunk = chunk_idx_ + + +if __name__ == '__main__': + main() diff --git a/src/scripts/downsample_videos.py b/src/scripts/downsample_videos.py new file mode 100644 index 0000000..0ae150e --- /dev/null +++ b/src/scripts/downsample_videos.py @@ -0,0 +1,46 @@ +""" +Author: Carlo Alberto Barbano +Date: 17/04/25 +""" +import argparse +import os +import multiprocessing +import subprocess + +from glob import glob +from natsort import natsorted +from tqdm import tqdm +from functools import partial + + +def downsample_movie(movie, target, size=224): + target_path = movie.replace(f"movies/{target}", f"movies_{size}/{target}") + os.makedirs(os.path.dirname(target_path), exist_ok=True) + + command = f"ffmpeg -y -i {movie} -vf scale={size}:{size} {target_path}" + + with open(os.devnull, "wb") as devnull: + res = subprocess.call(command.split(" "), stdout=devnull, stderr=devnull) + + if res != 0: + print(f"Error processing {movie}") + return + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--data_dir', type=str) + parser.add_argument('--target', choices=["friends", "movie10"], default="friends") + parser.add_argument('--size', type=int, default=224) + args = parser.parse_args() + + movies = natsorted(glob(os.path.join(args.data_dir, f"algonauts_2025.competitors/stimuli/movies/{args.target}/**/*.mkv"))) + + downsample_fn = partial(downsample_movie, target=args.target, size=args.size) + with multiprocessing.Pool() as pool: + for _ in tqdm(pool.imap_unordered(downsample_fn, movies), total=len(movies)): + pass + + +if __name__ == '__main__': + main() diff --git a/src/train_contrastive.py b/src/scripts/predict/train_predictor.py similarity index 52% rename from src/train_contrastive.py rename to src/scripts/predict/train_predictor.py index 5676606..21eae1c 100644 --- a/src/train_contrastive.py +++ b/src/scripts/predict/train_predictor.py @@ -1,47 +1,48 @@ """ Author: Carlo Alberto Barbano -Date: 11/04/25 +Date: 
20/04/25 """ import argparse import os import math import time -import shutil import datetime import torch +import torch.nn.functional as F import torch.utils.data import torch.utils.tensorboard import wandb +import numpy as np import util import models -from data.friends import FriendsDataset -from util import warmup_learning_rate, adjust_learning_rate, save_model +from data.friends import FriendsFeatureDataset +from util import warmup_learning_rate, adjust_learning_rate, save_model, torch_pearsonr def parse_args(): - parser = argparse.ArgumentParser(description="Train a contrastive video-fmri model on friends dataset", + parser = argparse.ArgumentParser(description="Train movie features->fmri predictors on friends dataset", formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--data_dir", type=str, required=True, help="Directory containing the dataset") - parser.add_argument("--save_dir", type=str, required=True, help="Directory to save the trained model") + parser.add_argument("--features_dir", type=str, required=True, help="Directory containing the features") + parser.add_argument("--log_dir", type=str, required=True, help="Directory to save the logs") # misc parser.add_argument('--device', help="device to use", type=str, default='cuda') parser.add_argument('--trial', help="random seed / trial id", type=int, default=0) parser.add_argument('--amp', action='store_true', help="use automatic mixed precision") - parser.add_argument('--print_freq', type=int, help='print frequency', default=1) + parser.add_argument('--print_freq', type=int, help='print frequency', default=50) parser.add_argument('--minibatch_log_freq', type=int, help='minibatch log frequency', default=50) - parser.add_argument('--restore', action="store_true", help='restore training') # model - parser.add_argument('--model', help="model to use", type=str, default='vivit-mlp') - parser.add_argument('--embed_dim', help="embedding dimension", type=int, default=128) - 
parser.add_argument('--temperature', help="temperature for clip loss", type=float, default=1.0) + parser.add_argument('--model', help="model to use", type=str, default='mlp-small') + parser.add_argument('--dropout', type=float, help="dropout", default=0.0) + parser.add_argument('--stimulus_window', type=int, help="width of stimulus window (num. of chunks)", default=1) # optimization - parser.add_argument('--optimizer', help="optimizer to use", type=str, default='adamw') + parser.add_argument('--optimizer', help="optimizer to use", type=str, default='adam') parser.add_argument('--lr', help="learning rate", type=float, default=1e-3) parser.add_argument('--lr_decay', type=str, help='type of decay', choices=['cosine', 'step'], default='step') parser.add_argument('--lr_decay_rate', type=float, default=0.9, help='decay rate for learning rate (for step)') @@ -52,7 +53,6 @@ def parse_args(): parser.add_argument('--weight_decay', help="weight decay", type=float, default=1e-5) parser.add_argument('--batch_size', help="batch size", type=int, default=256) parser.add_argument('--epochs', help="number of epochs", type=int, default=10) - parser.add_argument('--timesample', help="time downsample factor (reduce memory)", type=int, default=1) opts = parser.parse_args() @@ -84,9 +84,12 @@ def parse_args(): def load_model(opts): - if opts.model == "vivit-mlp": - model = models.vivit.VivitMLPContrastive(embed_dim=opts.embed_dim, temperature=opts.temperature).to(opts.device) - return model.image_processor(), model + if opts.model == "mlp-small": + return models.predictors.MLPSmall( + 768 * opts.stimulus_window, + 1000, + opts.dropout + ).to(opts.device) raise ValueError(f"Model not recognized {opts.model}") @@ -103,24 +106,26 @@ def load_optimizer(model, opts): def train(model, dataloader, optimizer, opts, epoch, writer): loss = util.AverageMeter() + # corr = util.AverageMeter() batch_time = util.AverageMeter() data_time = util.AverageMeter() scaler = torch.amp.GradScaler("cuda", 
enabled=opts.amp) + all_outputs = [] + all_labels = [] + model.train() t1 = time.time() - for idx, (video, fmri) in enumerate(dataloader): - video, fmri = video.to(opts.device), fmri.to(opts.device) + for idx, (features, fmri) in enumerate(dataloader): + features, fmri = features.to(opts.device), fmri.to(opts.device) data_time.update(time.time() - t1) - video['pixel_values'] = video['pixel_values'].squeeze(1) - # print("Video shape:", video['pixel_values'].shape) - bsz = video['pixel_values'].shape[0] + bsz = features.shape[0] warmup_learning_rate(opts, epoch, idx, len(dataloader), optimizer) with torch.amp.autocast("cuda", enabled=opts.amp): - running_loss = model(video, fmri) + running_loss, outputs = model(features, fmri) optimizer.zero_grad() if opts.amp: @@ -132,68 +137,88 @@ def train(model, dataloader, optimizer, opts, epoch, writer): optimizer.step() loss.update(running_loss.item(), bsz) + # corr.update(running_r.item(), bsz) batch_time.update(time.time() - t1) t1 = time.time() eta = batch_time.avg * (len(dataloader) - idx) + all_outputs.append(outputs.detach()) + all_labels.append(fmri) if (idx + 1) % opts.print_freq == 0: print(f"Train: [{epoch}][{idx + 1}/{len(dataloader)}]:\t" + f"DT {data_time.avg:.3f}\t" f"BT {batch_time.avg:.3f}\t" f"ETA {datetime.timedelta(seconds=eta)}\t" - f"loss {loss.avg:.3f}\t") + f"loss {loss.avg:.3f}") - if (idx + 1) % opts.minibatch_log_freq == 0 or idx == 0: - writer.add_scalar("train/MB_loss", loss.avg, idx + epoch * len(dataloader)) - writer.add_scalar("MB_lr", optimizer.param_groups[0]['lr'], idx + epoch * len(dataloader)) - writer.add_scalar("MB_BT", batch_time.avg, idx + epoch * len(dataloader)) - writer.add_scalar("MB_DT", data_time.avg, idx + epoch * len(dataloader)) - writer.add_scalar("MB_step", idx + epoch * len(dataloader), idx + epoch * len(dataloader)) + # if (idx + 1) % opts.minibatch_log_freq == 0 or idx == 0: + # writer.add_scalar("train/MB_loss", loss.avg, idx + epoch * len(dataloader)) + # 
writer.add_scalar("MB_lr", optimizer.param_groups[0]['lr'], idx + epoch * len(dataloader)) + # writer.add_scalar("MB_BT", batch_time.avg, idx + epoch * len(dataloader)) + # writer.add_scalar("MB_DT", data_time.avg, idx + epoch * len(dataloader)) + # writer.add_scalar("MB_step", idx + epoch * len(dataloader), idx + epoch * len(dataloader)) - return loss.avg, batch_time.avg, data_time.avg + all_outputs = torch.cat(all_outputs, dim=0) + all_labels = torch.cat(all_labels, dim=0) + r = torch_pearsonr(all_outputs, all_labels).item() + print("r:", r) -def main(): - opts = parse_args() - util.set_seed(opts.trial) + return loss.avg, r, batch_time.avg, data_time.avg - run_name = (f"{opts.model}_{opts.optimizer}_lr{opts.lr}_decay{opts.lr_decay}_" - f"wd{opts.weight_decay}_bsz{opts.batch_size}_ts{opts.timesample}_" - f"epochs{opts.epochs}_s{opts.trial}") - tb_dir = os.path.join(opts.save_dir, "tensorboard", run_name) - save_dir = os.path.join(opts.save_dir, "models", run_name) - opts.save_dir = save_dir - os.makedirs(tb_dir, exist_ok=True) - os.makedirs(save_dir, exist_ok=True) - wandb.init(project="algonauts-challenge-2025", name=run_name, config=opts, sync_tensorboard=True) - writer = torch.utils.tensorboard.SummaryWriter(tb_dir) +@torch.inference_mode() +def test(model, dataloader, optimizer, opts, epoch, writer): + model.eval() - preprocess, model = load_model(opts) - optimizer = load_optimizer(model, opts) + all_outputs = [] + all_labels = [] - # Load dataset - dataset = FriendsDataset(root=opts.data_dir, timesample=opts.timesample, image_transform=preprocess) - dataloader = torch.utils.data.DataLoader(dataset, batch_size=opts.batch_size, shuffle=True, num_workers=8, - pin_memory=True, prefetch_factor=2) + for idx, (features, fmri) in enumerate(dataloader): + features, fmri = features.to(opts.device), fmri.to(opts.device) - save_file = os.path.join(save_dir, "weights.pth") - start_epoch = 1 - if opts.restore: - print("Restoring training....") - print("Attempting to load", 
save_file) + with torch.amp.autocast("cuda", enabled=opts.amp): + _, outputs = model(features, fmri) + + all_outputs.append(outputs.cpu()) + all_labels.append(fmri.cpu()) + + all_outputs = torch.cat(all_outputs, dim=0) + all_labels = torch.cat(all_labels, dim=0) - checkpoint = torch.load(save_file, map_location=opts.device) - if checkpoint['epoch'] >= opts.epochs: - print(f"Model already trained for {checkpoint['epoch']} epochs") - exit(0) + mae = F.l1_loss(all_outputs, all_labels) - model.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - start_epoch = checkpoint['epoch'] + 1 - print(f"Restored model from epoch {start_epoch}") + # compute average correlation + r = torch_pearsonr(all_outputs, all_labels).item() + print("test r:", r) - # Copy old file to weights.pth.{epoch} - shutil.copyfile(save_file, f"{save_file}.{checkpoint['epoch']}") + return mae, r + + +def run_training(opts, subject, writer): + model = load_model(opts) + optimizer = load_optimizer(model, opts) + + trainable_parameters = filter(lambda p: p.requires_grad, model.parameters()) + tot_trainable = sum([np.prod(p.size()) for p in trainable_parameters]) + tot_parameters = sum([np.prod(p.size()) for p in model.parameters()]) + print("Total parameters:", tot_parameters, "Trainable parameters:", tot_trainable) + + # Load dataset + train_dataset = FriendsFeatureDataset(root=opts.data_dir, features_root=opts.features_dir, + subjects=[subject], seasons=[1,2,3,4,5], + stimulus_window=opts.stimulus_window) + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opts.batch_size, + shuffle=True, num_workers=8, pin_memory=True) + + test_dataset = FriendsFeatureDataset(root=opts.data_dir, features_root=opts.features_dir, + subjects=[subject], seasons=[6], + stimulus_window=opts.stimulus_window) + test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=opts.batch_size, + shuffle=False, num_workers=8, pin_memory=True) + + save_file = 
os.path.join(opts.save_dir, f"predictor_sub-{subject}.pth") + start_epoch = 1 print('Config:', opts) print('Model:', opts.model, model.__class__.__name__) @@ -206,17 +231,54 @@ def main(): adjust_learning_rate(opts, optimizer, epoch) t1 = time.time() - loss, batch_time, data_time = train(model, dataloader, optimizer, opts, epoch, writer) + loss, corr, batch_time, data_time = train(model, train_loader, optimizer, opts, epoch, writer) t2 = time.time() - writer.add_scalar("train/loss", loss, epoch) + test_mae, test_r = test(model, test_loader, optimizer, opts, epoch, writer) + + writer.add_scalar(f"sub-{subject}/train/loss", loss, epoch) + writer.add_scalar(f"sub-{subject}/train/r", corr, epoch) + writer.add_scalar(f"sub-{subject}/test/MAE", test_mae, epoch) + writer.add_scalar(f"sub-{subject}/test/r", test_r, epoch) writer.add_scalar("lr", optimizer.param_groups[0]['lr'], epoch) writer.add_scalar("BT", batch_time, epoch) writer.add_scalar("DT", data_time, epoch) writer.add_scalar("epoch", epoch, epoch) - print(f"epoch {epoch}, total time {t2 - start_time:.2f}, epoch time {t2 - t1:.3f} loss {loss:.4f}") + print(f"epoch {epoch}, total time {t2 - start_time:.2f}, epoch time {t2 - t1:.3f} " + f"loss {loss:.4f} train r {corr:.4f} - test MAE {test_mae:.4f} test r {test_r:.4f}") + + save_model(model, None, None, opts, epoch, save_file) + +def main(): + opts = parse_args() + util.set_seed(opts.trial) + + run_name = (f"predictor_{opts.model}_w{opts.stimulus_window}_{opts.optimizer}_lr{opts.lr}_" + f"decay{opts.lr_decay}_wd{opts.weight_decay}_bsz{opts.batch_size}_dropout{opts.dropout}_" + f"epochs{opts.epochs}_s{opts.trial}") + + tb_dir = os.path.join(opts.log_dir, "tensorboard", run_name) + save_dir = os.path.join(opts.features_dir, "predictors", run_name) + opts.save_dir = save_dir + os.makedirs(tb_dir, exist_ok=True) + os.makedirs(save_dir, exist_ok=True) + + print("Saving weights to", save_dir) + print("Saving logs to", tb_dir) + + 
wandb.init(project="algonauts-challenge-2025", name=run_name, config=opts, sync_tensorboard=True) + writer = torch.utils.tensorboard.SummaryWriter(tb_dir) + + packages = util.get_packages_versions() + wandb.config.update({"env": packages}) + print("Packages:") + for k, v in packages.items(): + print(f"{k}=={v}") + + for subject in [1,2,3,5]: + print(f"Training subject {subject}") + run_training(opts, subject, writer) - save_model(model, optimizer, opts, epoch, save_file) if __name__ == '__main__': main() diff --git a/src/scripts/preprocess_friends_vivit.py b/src/scripts/preprocess_friends_vivit.py new file mode 100644 index 0000000..56a3b93 --- /dev/null +++ b/src/scripts/preprocess_friends_vivit.py @@ -0,0 +1,112 @@ +""" +Author: Carlo Alberto Barbano +Date: 17/04/25 +""" +import models +import torch +import argparse +import os + +from torchcodec.decoders import VideoDecoder +from glob import glob +from natsort import natsorted +from tqdm import tqdm +from transformers import VivitImageProcessor + + +class FriendsStimuliVideoDataset(torch.utils.data.Dataset): + def __init__(self, root, transform, tr=1.49, timesample=1, target_video_len=32): + self.root = root + self.transform = transform + self.timesample = timesample + self.target_video_len = target_video_len + self.tr = tr + + self.movies = natsorted(glob(os.path.join(self.root, "algonauts_2025.competitors/stimuli/movies/friends/**/*.mkv"))) + + self.chunks = [] + self.chunk_idx_to_movie_idx = {} + for movie_idx, movie in enumerate(self.movies): + decoder = VideoDecoder(movie, device="cpu") + duration = decoder.metadata.duration_seconds + num_chunks = int(round(duration / self.tr)) + print(f"Movie {movie} has {duration:.2f} seconds and {num_chunks:.2f} chunks of {self.tr:.2f} seconds") + + # map the chunk index to the movie index + offset = len(self.chunk_idx_to_movie_idx) + + for i in range(num_chunks): + self.chunk_idx_to_movie_idx[i + offset] = (movie_idx, i) + + self.chunks.append(num_chunks) + + assert 
sum(self.chunks) == len(self.chunk_idx_to_movie_idx), f"The number of chunks ({sum(self.chunks)}) does not match the number of movies ({len(self.chunk_idx_to_movie_idx)})" + + def __len__(self): + return sum(self.chunks) + + def __getitem__(self, idx): + movie_idx, chunk_idx = self.chunk_idx_to_movie_idx[idx] + movie_path = self.movies[movie_idx] + + start_t = chunk_idx * self.tr + end_t = (chunk_idx + 1) * self.tr + + decoder = VideoDecoder(movie_path, device="cpu") + if end_t > decoder.metadata.duration_seconds: + end_t = decoder.metadata.duration_seconds + + chunk = decoder.get_frames_played_in_range(start_t, end_t).data + + # if the chunk is shorter than self.target_video_len, pad it + if len(chunk) < self.target_video_len: + # Pad with last frame + chunk = torch.cat([chunk, chunk[-1].unsqueeze(0).expand(self.target_video_len - len(chunk), -1, -1, -1)]) + + # if the chunk is longer than self.target_video_len, take last N frames + if len(chunk) > self.target_video_len: + chunk = chunk[-self.target_video_len:] + + lst = torch.split(chunk, 1, 0) + lst = [l[0] for l in lst] + video_data = self.transform(lst, return_tensors="pt") + + return video_data, movie_idx + + +@torch.inference_mode() +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--data_dir', type=str) + parser.add_argument('--save_dir', type=str) + parser.add_argument('--vivit_pretrained', type=str, default="google/vivit-b-16x2-kinetics400") + args = parser.parse_args() + + os.makedirs(args.save_dir, exist_ok=True) + + vivit_processor = VivitImageProcessor.from_pretrained(args.vivit_pretrained) + + dataset = FriendsStimuliVideoDataset(args.data_dir, transform=vivit_processor) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=False, num_workers=8, pin_memory=True) + print("Dataset loaded. 
Tot chunks:", len(dataset)) + + chunk_idx = 0 + last_movie_idx = 0 + + for idx, (video, movie_idx) in enumerate(tqdm(dataloader)): + video = video['pixel_values'].squeeze(1) + + for chunk, chunk_movie_idx in zip(video, movie_idx): + movie_name = os.path.basename(dataset.movies[chunk_movie_idx.item()]) + output_file = os.path.join(args.save_dir, f"{movie_name}_chunk_{chunk_idx:05d}.pt") + torch.save(chunk, output_file) + + chunk_idx += 1 + if chunk_movie_idx != last_movie_idx: + print(f"Saved {movie_name}") + last_movie_idx = movie_idx + + + +if __name__ == '__main__': + main() diff --git a/src/scripts/pretrain/pretrain_video.py b/src/scripts/pretrain/pretrain_video.py new file mode 100644 index 0000000..de55300 --- /dev/null +++ b/src/scripts/pretrain/pretrain_video.py @@ -0,0 +1,375 @@ +""" +Author: Carlo Alberto Barbano +Date: 11/04/25 +""" +import argparse +import os +import math +import time +import shutil +import datetime + +import torch +import torch.utils.data +import torch.utils.tensorboard +import wandb +import numpy as np + +import util +import models + +from data.friends import FriendsDataset +from util import warmup_learning_rate, adjust_learning_rate, save_model + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train a contrastive video-fmri model on friends dataset", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--data_dir", type=str, required=True, help="Directory containing the dataset") + parser.add_argument("--save_dir", type=str, required=True, help="Directory to save the trained model") + parser.add_argument('--downsampled', action='store_true', help='Use downsampled videos') + + # misc + parser.add_argument('--device', help="device to use", type=str, default='cuda') + parser.add_argument('--trial', help="random seed / trial id", type=int, default=0) + parser.add_argument('--amp', action='store_true', help="use automatic mixed precision") + parser.add_argument('--print_freq', type=int, 
help='print frequency', default=1) + parser.add_argument('--minibatch_log_freq', type=int, help='minibatch log frequency', default=50) + parser.add_argument('--restore', action="store_true", help='restore training') + + # model + parser.add_argument('--model', help="model to use", type=str, default='vivit-mlp') + parser.add_argument('--freeze_encoder', action='store_true', help='freeze encoder') + parser.add_argument('--n_frames', type=int, default=32, help="number of frames to use") + parser.add_argument('--embed_dim', help="embedding dimension", type=int, default=128) + parser.add_argument('--temperature', help="temperature for clip loss", type=float, default=1.0) + parser.add_argument('--stimulus_window', help="stimulus window", type=int, default=1) + parser.add_argument('--hrf_delay', help="hrf delay", type=int, default=0) + parser.add_argument('--fmri_window', help="fmri window", type=int, default=1) + parser.add_argument('--subjects', help="subjects to use", type=int, nargs='+', default=[1,2,3,5]) + parser.add_argument('--encode_subject_id', action='store_true', help='encode subject id in the model') + parser.add_argument('--train_seasons', help="seasons to use for training", type=int, nargs='+', + default=[1, 2, 3, 4, 5]) + parser.add_argument('--test_seasons', help="seasons to use for testing", type=int, nargs='+', default=[6]) + + # optimization + parser.add_argument('--method', type=str, choices=["mae", "mse", "clip"], default="clip", help="method to use") + parser.add_argument('--optimizer', help="optimizer to use", type=str, default='adamw') + parser.add_argument('--lr', help="learning rate", type=float, default=1e-3) + parser.add_argument('--lr_decay', type=str, help='type of decay', choices=['cosine', 'step'], default='step') + parser.add_argument('--lr_decay_rate', type=float, default=0.9, help='decay rate for learning rate (for step)') + parser.add_argument('--lr_decay_epochs', type=str, help='steps of lr decay (list)', default="7,8,9") + 
parser.add_argument('--lr_decay_step', type=int, help='decay rate step (overwrites lr_decay_epochs)', default=10) + parser.add_argument('--warm', action='store_true', help='warmup learning rate') + parser.add_argument('--momentum', type=float, help='momentum', default=0.9) + parser.add_argument('--weight_decay', help="weight decay", type=float, default=1e-5) + parser.add_argument('--batch_size', help="batch size", type=int, default=256) + parser.add_argument('--epochs', help="number of epochs", type=int, default=10) + parser.add_argument('--timesample', help="time downsample factor (reduce memory)", type=int, default=1) + + opts = parser.parse_args() + + if opts.batch_size > 256: + print("Forcing warm") + opts.warm = True + + if opts.lr_decay_step is not None: + opts.lr_decay_epochs = list(range(opts.lr_decay_step, opts.epochs, opts.lr_decay_step)) + print(f"Computed decay epochs based on step ({opts.lr_decay_step}):", opts.lr_decay_epochs) + else: + iterations = opts.lr_decay_epochs.split(',') + opts.lr_decay_epochs = list([]) + for it in iterations: + opts.lr_decay_epochs.append(int(it)) + + if opts.warm: + opts.warmup_from = 0.01 + opts.warm_epochs = 10 + if opts.lr_decay == 'cosine': + eta_min = opts.lr * (opts.lr_decay_rate ** 3) + opts.warmup_to = eta_min + (opts.lr - eta_min) * ( + 1 + math.cos(math.pi * opts.warm_epochs / opts.epochs)) / 2 + else: + opts.milestones = [int(s) for s in opts.lr_decay_epochs] + opts.warmup_to = opts.lr + + if opts.method == "clip" and opts.model not in ["vivit-mlp", "vivit-conv1d"]: + raise ValueError("Model not compatible with CLIP loss") + + if opts.model in ["vivit-mlp", "vivit-conv1d"] and opts.method != "clip": + raise ValueError("Model not compatible with MAE loss") + + return opts + + +def load_model(opts): + if opts.model == "vivit-mlp": + model = models.vivit.VivitMLPContrastive(embed_dim=opts.embed_dim, temperature=opts.temperature, + fmri_window=opts.fmri_window, freeze_encoder=opts.freeze_encoder).to(opts.device) + 
return model.image_processor(), model + + elif opts.model == "vivit-conv1d": + model = models.vivit.VivitConvContrastive(embed_dim=opts.embed_dim, temperature=opts.temperature, + fmri_window=opts.fmri_window, freeze_encoder=opts.freeze_encoder).to(opts.device) + return model.image_processor(), model + + elif opts.model == "vivit": + model = models.vivit.VivitRegression(criterion=opts.method, freeze_encoder=opts.freeze_encoder).to(opts.device) + return model.image_processor(), model + + elif opts.model == "videomae": + model = models.videomae.VideoMAERegression(criterion=opts.method, freeze_encoder=opts.freeze_encoder, + use_subject_id=opts.encode_subject_id).to(opts.device) + return model.image_processor(), model + + raise ValueError(f"Model not recognized {opts.model}") + + +def load_optimizer(model, opts): + parameters = filter(lambda p: p.requires_grad, model.parameters()) + + if opts.optimizer == "adam": + return torch.optim.Adam(parameters, lr=opts.lr, weight_decay=opts.weight_decay) + elif opts.optimizer == "adamw": + return torch.optim.AdamW(parameters, lr=opts.lr, betas=(0.9, 0.999), eps=1e-8) #, weight_decay=opts.weight_decay) + elif opts.optimizer == "sgd": + return torch.optim.SGD(parameters, lr=opts.lr, weight_decay=opts.weight_decay, momentum=opts.momentum) + + raise ValueError("Optimizer not recognized") + + +def train(model, dataloader, optimizer, opts, epoch, writer, scaler): + loss = util.AverageMeter() + batch_time = util.AverageMeter() + data_time = util.AverageMeter() + acc = util.AverageMeter() + + model.train() + + t1 = time.time() + for idx, (video, fmri, subjects) in enumerate(dataloader): + video, fmri = video.to(opts.device), fmri.to(opts.device) + subjects = subjects.to(opts.device) + data_time.update(time.time() - t1) + + video['pixel_values'] = video['pixel_values'].squeeze(1) + # print("Video shape:", video['pixel_values'].shape) + bsz = video['pixel_values'].shape[0] + warmup_learning_rate(opts, epoch, idx, len(dataloader), optimizer) 
+ + with torch.amp.autocast("cuda", enabled=opts.amp): + outputs = model(video, fmri, subjects.half()) + running_loss = outputs[0] + + optimizer.zero_grad() + if opts.amp: + scaler.scale(running_loss).backward() + scaler.step(optimizer) + scaler.update() + else: + running_loss.backward() + optimizer.step() + + loss.update(running_loss.item(), bsz) + batch_time.update(time.time() - t1) + t1 = time.time() + eta = batch_time.avg * (len(dataloader) - idx) + + if opts.method != "clip": + r = util.torch_pearsonr(outputs[1], fmri) + acc.update(r.item(), bsz) + + if (idx + 1) % opts.print_freq == 0: + print(f"Train: [{epoch}][{idx + 1}/{len(dataloader)}]:\t" + f"DT {data_time.avg:.3f}\t" + f"BT {batch_time.avg:.3f}\t" + f"ETA {datetime.timedelta(seconds=eta)}\t" + f"loss {loss.avg:.3f}\t" + f"r {acc.avg:.3f}\t") + + if (idx + 1) % opts.minibatch_log_freq == 0 or idx == 0: + writer.add_scalar("train/MB_loss", loss.avg, idx + epoch * len(dataloader)) + writer.add_scalar("train/MB_r", acc.avg, idx + epoch * len(dataloader)) + writer.add_scalar("MB_lr", optimizer.param_groups[0]['lr'], idx + epoch * len(dataloader)) + writer.add_scalar("MB_BT", batch_time.avg, idx + epoch * len(dataloader)) + writer.add_scalar("MB_DT", data_time.avg, idx + epoch * len(dataloader)) + writer.add_scalar("MB_step", idx + epoch * len(dataloader), idx + epoch * len(dataloader)) + + return loss.avg, batch_time.avg, data_time.avg + + +@torch.inference_mode() +def test(model, dataloader, opts, epoch, writer, scaler): + loss = util.AverageMeter() + batch_time = util.AverageMeter() + data_time = util.AverageMeter() + + all_outputs = [] + all_labels = [] + all_subjects = [] + + model.eval() + + t1 = time.time() + for idx, (video, fmri, subjects) in enumerate(dataloader): + video, fmri = video.to(opts.device), fmri.to(opts.device) + subjects = subjects.to(opts.device) + data_time.update(time.time() - t1) + + video['pixel_values'] = video['pixel_values'].squeeze(1) + bsz = video['pixel_values'].shape[0] + + 
with torch.amp.autocast("cuda", enabled=opts.amp): + running_loss, outputs = model(video, fmri, subjects.half()) + + all_outputs.append(outputs.detach()) + all_labels.append(fmri) + all_subjects.append(subjects.int()) + + loss.update(running_loss.item(), bsz) + batch_time.update(time.time() - t1) + t1 = time.time() + eta = batch_time.avg * (len(dataloader) - idx) + + if (idx + 1) % opts.print_freq == 0: + print(f"Test: [{epoch}][{idx + 1}/{len(dataloader)}]:\t" + f"DT {data_time.avg:.3f}\t" + f"BT {batch_time.avg:.3f}\t" + f"ETA {datetime.timedelta(seconds=eta)}\t" + f"loss {loss.avg:.3f}\t") + + if (idx + 1) % opts.minibatch_log_freq == 0 or idx == 0: + writer.add_scalar("test/MB_loss", loss.avg, idx + epoch * len(dataloader)) + + all_outputs = torch.cat(all_outputs, dim=0) + all_labels = torch.cat(all_labels, dim=0) + all_subjects = torch.cat(all_subjects, dim=0) + + r = {} + for subject in torch.unique(all_subjects.int()).tolist(): + r[f"sub{int(subject):02d}"] = util.torch_pearsonr(all_outputs[all_subjects == subject], all_labels[all_subjects == subject]) + print("r:", r) + + return r, loss.avg, batch_time.avg, data_time.avg + + +def main(): + opts = parse_args() + util.set_seed(opts.trial) + + run_name = (f"{opts.model}_{opts.method}_{'downsampled_' if opts.downsampled else ''}_nframes{opts.n_frames}_" + f"sub{''.join(str(s) for s in opts.subjects)}_id{opts.encode_subject_id}_" + f"w{opts.stimulus_window}_hrf{opts.hrf_delay}_fmriW{opts.fmri_window}_" + f"{opts.optimizer}_lr{opts.lr}_decay{opts.lr_decay}_" + f"wd{opts.weight_decay}_bsz{opts.batch_size}_ts{opts.timesample}_" + f"epochs{opts.epochs}_s{opts.trial}") + + tb_dir = os.path.join(opts.save_dir, "tensorboard", run_name) + save_dir = os.path.join(opts.save_dir, "models", run_name) + opts.save_dir = save_dir + os.makedirs(tb_dir, exist_ok=True) + os.makedirs(save_dir, exist_ok=True) + + print("Saving weights to", save_dir) + print("Saving logs to", tb_dir) + + if opts.restore: # change run name for wandb 
only (local files will be saved in the same folder) + run_name = f"{run_name}_restore" + + wandb.init(project="algonauts-challenge-2025", name=run_name, config=opts, sync_tensorboard=True) + writer = torch.utils.tensorboard.SummaryWriter(tb_dir) + + packages = util.get_packages_versions() + wandb.config.update({"env": packages}) + print("Packages:") + for k, v in packages.items(): + print(f"{k}=={v}") + + preprocess, model = load_model(opts) + if torch.cuda.device_count() > 1: + print(f"Using {torch.cuda.device_count()} GPUs") + model = torch.nn.DataParallel(model) + + optimizer = load_optimizer(model, opts) + scaler = torch.amp.GradScaler("cuda", enabled=opts.amp) + + trainable_parameters = filter(lambda p: p.requires_grad, model.parameters()) + tot_trainable = sum([np.prod(p.size()) for p in trainable_parameters]) + tot_parameters = sum([np.prod(p.size()) for p in model.parameters()]) + print("Total parameters:", tot_parameters, "Trainable parameters:", tot_trainable) + + # Load dataset + dataset = FriendsDataset(root=opts.data_dir, timesample=opts.timesample, image_transform=preprocess, + downsampled=opts.downsampled, stimulus_window=opts.stimulus_window, + hrf_delay=opts.hrf_delay, subjects=opts.subjects, seasons=opts.train_seasons, + fmri_window=opts.fmri_window, target_video_len=opts.n_frames) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=opts.batch_size, shuffle=True, + num_workers=8 * torch.cuda.device_count(), + pin_memory=True, prefetch_factor=2) + + + test_dataset = FriendsDataset(root=opts.data_dir, timesample=opts.timesample, image_transform=preprocess, + downsampled=opts.downsampled, stimulus_window=opts.stimulus_window, + hrf_delay=opts.hrf_delay, subjects=opts.subjects, seasons=opts.test_seasons, + fmri_window=opts.fmri_window, target_video_len=opts.n_frames) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opts.batch_size, shuffle=False, + num_workers=8 * torch.cuda.device_count(), pin_memory=True, + 
prefetch_factor=2) + + save_file = os.path.join(save_dir, "weights.pth") + start_epoch = 1 + if opts.restore: + print("Restoring training....") + print("Attempting to load", save_file) + + checkpoint = torch.load(save_file, map_location=opts.device, weights_only=False) + if checkpoint['epoch'] >= opts.epochs: + print(f"Model already trained for {checkpoint['epoch']} epochs") + exit(0) + + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + scaler.load_state_dict(checkpoint['scaler']) + start_epoch = checkpoint['epoch'] + 1 + print(f"Restored model from epoch {start_epoch}") + + # Copy old file to weights.pth.{epoch} + shutil.copyfile(save_file, f"{save_file}.{checkpoint['epoch']}") + del checkpoint + torch.cuda.empty_cache() + + print('Config:', opts) + print('Model:', opts.model, model.__class__.__name__) + print('Optimizer:', optimizer) + print('Scheduler:', opts.lr_decay) + print("CUDA available:", torch.cuda.is_available(), f"({torch.cuda.device_count()} devices)") + + start_time = time.time() + for epoch in range(start_epoch, opts.epochs + 1): + adjust_learning_rate(opts, optimizer, epoch) + + t1 = time.time() + loss, batch_time, data_time = train(model, dataloader, optimizer, opts, epoch, writer, scaler) + t2 = time.time() + + writer.add_scalar("train/loss", loss, epoch) + writer.add_scalar("lr", optimizer.param_groups[0]['lr'], epoch) + writer.add_scalar("BT", batch_time, epoch) + writer.add_scalar("DT", data_time, epoch) + writer.add_scalar("epoch", epoch, epoch) + print(f"epoch {epoch}, total time {t2 - start_time:.2f}, epoch time {t2 - t1:.3f} loss {loss:.4f}") + + if opts.method != "clip": + r, test_loss, _, _ = test(model, test_dataloader, opts, epoch, writer, scaler) + writer.add_scalar("test/loss", test_loss, epoch) + + for sub, r_value in r.items(): + writer.add_scalar(f"test/{sub}_r", r_value, epoch) + + save_model(model, optimizer, scaler, opts, epoch, save_file) + torch.cuda.empty_cache() + +if 
__name__ == '__main__': + main() + + diff --git a/src/util/__init__.py b/src/util/__init__.py new file mode 100644 index 0000000..e746e6a --- /dev/null +++ b/src/util/__init__.py @@ -0,0 +1,5 @@ +""" +Author: Carlo Alberto Barbano +Date: 21/04/25 +""" +from .util import * \ No newline at end of file diff --git a/src/util.py b/src/util/util.py similarity index 92% rename from src/util.py rename to src/util/util.py index aaf434a..f084635 100644 --- a/src/util.py +++ b/src/util/util.py @@ -10,6 +10,18 @@ from sklearn.metrics import balanced_accuracy_score, roc_auc_score, r2_score + +def get_packages_versions(): + try: + from pip._internal.operations import freeze + except ImportError: + from pip.operations import freeze + + + pkgs = {p[0]: p[1] for p in [pkg.split("==") for pkg in freeze.freeze() if "file://" not in pkg]} + return pkgs + + class NViewTransform: """Create N augmented views of the same image""" @@ -182,7 +194,7 @@ def set_seed(seed): torch.manual_seed(seed) -def save_model(model, optimizer, opts, epoch, save_file): +def save_model(model, optimizer, scaler, opts, epoch, save_file): print('==> Saving...') state_dict = model.state_dict() if torch.cuda.device_count() > 1: @@ -191,7 +203,8 @@ def save_model(model, optimizer, opts, epoch, save_file): state = { 'opts': opts, 'model': state_dict, - 'optimizer': optimizer.state_dict(), + 'optimizer': optimizer.state_dict() if optimizer else None, + 'scaler': scaler.state_dict() if scaler else None, 'epoch': epoch, 'run_id': wandb.run.id } @@ -324,3 +337,13 @@ def compute_site_ba(model, train_loader, test_int, test_ext, opts): ba_ext = site_estimator.score(ext_X, ext_y) return site_estimator, ba_train, ba_int, ba_ext + + +def torch_pearsonr(output, target): + x = output + y = target + + vx = x - torch.mean(x, dim=-1, keepdim=True) + vy = y - torch.mean(y, dim=-1, keepdim=True) + + return torch.mean(torch.sum(vx * vy, dim=-1) / (torch.sqrt(torch.sum(vx ** 2, dim=-1)) * torch.sqrt(torch.sum(vy ** 2, dim=-1)))) diff 
--git a/sweeps/bigmri_train_cv.yaml b/sweeps/bigmri_train_cv.yaml new file mode 100644 index 0000000..8791f4f --- /dev/null +++ b/sweeps/bigmri_train_cv.yaml @@ -0,0 +1,51 @@ +# submit --name algonauts-mae-ft-ds64-w5-hrf2-fw1-allsubseas --gpus 1 --host hssh3 +# eidos-service.di.unito.it/barbano/algonauts:latest -m scripts.pretrain.pretrain_video - +# -data_dir /scratch/data --save_dir /scratch/workspace/algonauts-challenge-2025/output +# --model videomae --method mae --amp --batch_size 64 --optimizer adamw --lr 1e-4 +# --downsampled --stimulus_window 5 --hrf_delay 2 --fmri_window 1 --subjects 1 2 3 5 +# --encode_subject_id --train_seasons 1 2 3 4 5 --n_frames 16 +program: scripts.pretrain.pretrain_video +project: algonauts-challenge-2025 +command: + - ${env} + - python3 + - -m + - ${program} + - ${args} + - --amp + - --downsampled + - --encode_subject_id + - --freeze_encoder +method: grid +name: videomae - hrf stimulus window +parameters: + data_dir: + value: /scratch/data + save_dir: + value: /scratch/algonauts-challenge-2025/output + model: + value: videomae + method: + value: mae + batch_size: + value: 256 + optimizer: + value: adamw + lr: + values: [1e-4, 1e-5] + stimulus_window: + values: [1, 5, 10, 15] + hrf_delay: + values: [0, 2, 4, 8] + fmri_window: + values: [1, 2, 4, 8] + subjects: + value: 1 + train_seasons: + value: 5 + n_frames: + values: [16, 32] + epochs: + value: 1 + + diff --git a/wandb/debug-cli.carloalbertobarbano.log b/wandb/debug-cli.carloalbertobarbano.log new file mode 100644 index 0000000..e69de29