diff --git a/main_task_retrieval.py b/main_task_retrieval.py
index 412a1ab..871c087 100644
--- a/main_task_retrieval.py
+++ b/main_task_retrieval.py
@@ -265,8 +265,6 @@ def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer,
         input_ids, input_mask, segment_ids, video, video_mask = batch
         loss = model(input_ids, segment_ids, input_mask, video, video_mask)
 
-        if n_gpu > 1:
-            loss = loss.mean()  # mean() to average on multi-gpu.
         if args.gradient_accumulation_steps > 1:
             loss = loss / args.gradient_accumulation_steps
diff --git a/modules/modeling.py b/modules/modeling.py
index 698e888..ce4e25a 100644
--- a/modules/modeling.py
+++ b/modules/modeling.py
@@ -8,6 +8,7 @@ from torch import nn
 
 from modules.until_module import PreTrainedModel, AllGather, CrossEn
+from torch.distributed import all_gather
 from modules.module_cross import CrossModel, CrossConfig, Transformer as TransformerClip
 from modules.module_clip import CLIP, convert_weights
@@ -260,13 +261,19 @@ def forward(self, input_ids, token_type_ids, attention_mask, video, video_mask=None,
         sequence_output, visual_output = self.get_sequence_visual_output(input_ids, token_type_ids, attention_mask,
                                                                          video, video_mask, shaped=True, video_frame=video_frame)
-
+        positive_pos = 0
         if self.training:
             loss = 0.
             sim_matrix, *_tmp = self.get_similarity_logits(sequence_output, visual_output, attention_mask, video_mask,
                                                            shaped=True, loose_type=self.loose_type)
-            sim_loss1 = self.loss_fct(sim_matrix)
-            sim_loss2 = self.loss_fct(sim_matrix.T)
+
+            # When training on multiple GPUs, offset the positives by this rank's position in the global batch
+            # (every GPU except the 0th), so tensor.diag() in the loss function selects the right positive samples.
+            if self.task_config.n_gpu != 1:
+                positive_pos = self.task_config.local_rank * sim_matrix[0].shape[0]
+
+            sim_loss1 = self.loss_fct(sim_matrix[0], positive_pos)
+            sim_loss2 = self.loss_fct(sim_matrix[1], positive_pos)
             sim_loss = (sim_loss1 + sim_loss2) / 2
             loss += sim_loss
@@ -383,12 +390,6 @@ def _loose_similarity(self, sequence_output, visual_output, attention_mask, video_mask,
             visual_output = visual_output.permute(1, 0, 2)  # LND -> NLD
             visual_output = visual_output + visual_output_original
 
-        if self.training:
-            visual_output = allgather(visual_output, self.task_config)
-            video_mask = allgather(video_mask, self.task_config)
-            sequence_output = allgather(sequence_output, self.task_config)
-            torch.distributed.barrier()
-
         visual_output = visual_output / visual_output.norm(dim=-1, keepdim=True)
         visual_output = self._mean_pooling_for_similarity_visual(visual_output, video_mask)
         visual_output = visual_output / visual_output.norm(dim=-1, keepdim=True)
@@ -397,6 +398,21 @@ def _loose_similarity(self, sequence_output, visual_output, attention_mask, video_mask,
         sequence_output = sequence_output / sequence_output.norm(dim=-1, keepdim=True)
 
         logit_scale = self.clip.logit_scale.exp()
+
+        # Gather features from all ranks so each GPU contrasts against the global batch; see https://github.com/openai/CLIP/issues/132
+        if self.training:
+            all_visual_output = [torch.empty_like(visual_output) for _ in range(self.task_config.world_size)]
+            all_gather(all_visual_output, visual_output)
+            all_visual_output = torch.cat(all_visual_output, dim=0)
+            all_sequence_output = [torch.empty_like(sequence_output) for _ in range(self.task_config.world_size)]
+            all_gather(all_sequence_output, sequence_output)
+            all_sequence_output = torch.cat(all_sequence_output, dim=0)
+            torch.distributed.barrier()
+
+            retrieve_logits1 = logit_scale * torch.matmul(sequence_output, all_visual_output.t())
+            retrieve_logits2 = logit_scale * torch.matmul(visual_output, all_sequence_output.t())
+            return [retrieve_logits1, retrieve_logits2]
+
         retrieve_logits = logit_scale * torch.matmul(sequence_output, visual_output.t())
         return retrieve_logits
diff --git a/modules/until_module.py b/modules/until_module.py
index 5ae873a..2abd5c4 100644
--- a/modules/until_module.py
+++ b/modules/until_module.py
@@ -183,9 +183,9 @@ class CrossEn(nn.Module):
     def __init__(self,):
         super(CrossEn, self).__init__()
 
-    def forward(self, sim_matrix):
+    def forward(self, sim_matrix, positive_pos=0):
         logpt = F.log_softmax(sim_matrix, dim=-1)
-        logpt = torch.diag(logpt)
+        logpt = torch.diag(logpt, positive_pos)
         nce_loss = -logpt
         sim_loss = nce_loss.mean()
         return sim_loss
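
Note on the CrossEn change: below is a minimal, single-process sketch (synthetic tensors, no real DDP; world_size, local_batch, and dim are made-up values) of why positive_pos = local_rank * local_batch is the right diagonal offset. Each rank scores its local features against the gathered global batch, so its similarity matrix is (local_batch x global_batch) and its positive pairs sit on the diagonal shifted right by rank * local_batch columns, which is exactly the offset argument of torch.diag.

import torch
import torch.nn.functional as F

world_size, local_batch, dim = 4, 3, 8   # hypothetical sizes
torch.manual_seed(0)

# Per-rank text/video features; matching pairs share an index, so each video
# feature is a small perturbation of its paired text feature.
text = [torch.randn(local_batch, dim) for _ in range(world_size)]
video = [t + 0.01 * torch.randn_like(t) for t in text]
all_video = torch.cat(video, dim=0)      # stands in for the all_gather result

for rank in range(world_size):
    sim = text[rank] @ all_video.t()     # (local_batch, global_batch), as in retrieve_logits1
    positive_pos = rank * local_batch    # the offset the patch derives from local_rank
    logpt = F.log_softmax(sim, dim=-1)
    nce_loss = -torch.diag(logpt, positive_pos)   # same selection as CrossEn.forward
    print(f"rank {rank}: positive columns start at {positive_pos}, "
          f"loss = {nce_loss.mean().item():.4f}")

Since torch.distributed.all_gather does not backpropagate into the gathered copies, gradients in the two matmuls still flow only through each rank's local sequence_output and visual_output, which is the behavior discussed in the CLIP issue linked in the patch.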