
Potential Bug in TextMotionMatchTrainer #39

Open
mpiseno opened this issue Jan 19, 2024 · 0 comments

I am training my own text and motion embedding models for evaluation. I noticed a potential bug in the `TextMotionMatchTrainer` class, in the index shift used to create negative examples for the contrastive loss.

```python
def backward(self):
    batch_size = self.text_embedding.shape[0]

    '''Positive pairs'''
    pos_labels = torch.zeros(batch_size).to(self.text_embedding.device)
    self.loss_pos = self.contrastive_loss(self.text_embedding, self.motion_embedding, pos_labels)

    '''Negative pairs, shifting index'''
    neg_labels = torch.ones(batch_size).to(self.text_embedding.device)
    shift = np.random.randint(0, batch_size - 1)  # BUG: the lower bound is inclusive, so shift can be 0
    new_idx = np.arange(shift, batch_size + shift) % batch_size
    self.mis_motion_embedding = self.motion_embedding.clone()[new_idx]
    self.loss_neg = self.contrastive_loss(self.text_embedding, self.mis_motion_embedding, neg_labels)
    self.loss = self.loss_pos + self.loss_neg

    loss_logs = OrderedDict({})
    loss_logs['loss'] = self.loss.item()
    loss_logs['loss_pos'] = self.loss_pos.item()
    loss_logs['loss_neg'] = self.loss_neg.item()
    return loss_logs
```

If the shift is 0, each "negative" example is compared with itself, so the pair is actually a positive one trained with a negative label. This is especially problematic with small batch sizes, such as the batch size of 8 used in the README. The correction is

```python
shift = np.random.randint(1, batch_size - 1)
```
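To make the failure mode concrete, here is a minimal sketch (assuming only NumPy; `motion_ids` is a hypothetical stand-in for a batch of motion embeddings) showing that a shift of 0 reproduces the original batch exactly, while any nonzero shift mismatches every pair:

```python
import numpy as np

batch_size = 8
motion_ids = np.arange(batch_size)  # stand-in for one embedding per batch element

# np.random.randint(0, batch_size - 1) has an inclusive lower bound, so it can return 0.
buggy_shift = 0
new_idx = np.arange(buggy_shift, batch_size + buggy_shift) % batch_size
# With shift = 0, the "mismatched" batch is identical to the original:
# every "negative" is paired with its own positive.
assert (motion_ids[new_idx] == motion_ids).all()

# With the corrected call, the lower bound 1 excludes the identity shift.
fixed_shift = np.random.randint(1, batch_size - 1)
new_idx = np.arange(fixed_shift, batch_size + fixed_shift) % batch_size
# Every text embedding is now paired with a different element's motion embedding.
assert (motion_ids[new_idx] != motion_ids).all()
```

Note that `np.random.randint`'s upper bound is exclusive, so the corrected call samples shifts in `[1, batch_size - 2]`, all of which avoid the identity permutation.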

After making this change, I see improved training curves. Below, the grey curve uses the bug fix and the purple curve uses the original code; the batch size is 8.

[Screenshot: training-curve comparison, captured 2024-01-18]
