Thanks for your amazing work on ALBEF! While reviewing the paper and analyzing the provided implementation, I noticed a potential inconsistency between the paper's description of KL divergence computation and the implementation. I would appreciate it if you could clarify the following:
In the paper, the KL divergence loss is defined as KL(q || p), where:

- q: the pseudo-targets generated by the momentum model.
- p: the predictions from the student model.

The KL divergence measures the discrepancy between the pseudo-target q (from the momentum model) and the prediction p (from the student model).
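For concreteness, this is how I read that definition (a minimal PyTorch sketch; the function and variable names are my own assumptions, not the repo's): q comes from momentum-encoder similarities, and p would be computed only from the student's similarities.

```python
import torch
import torch.nn.functional as F

# Minimal sketch of my reading of the paper's KL term (names are assumptions, not repo code).
# image_feat, text_feat:     L2-normalized student embeddings, shape (B, D)
# image_feat_m, text_feat_m: L2-normalized momentum embeddings, shape (B, D)
# temp:                      temperature
def itc_kd_loss(image_feat, text_feat, image_feat_m, text_feat_m, temp):
    # Pseudo-targets q from the momentum model only (no gradient).
    with torch.no_grad():
        q_i2t = F.softmax(image_feat_m @ text_feat_m.t() / temp, dim=1)
        q_t2i = F.softmax(text_feat_m @ image_feat_m.t() / temp, dim=1)

    # Predictions p computed purely from student features.
    log_p_i2t = F.log_softmax(image_feat @ text_feat.t() / temp, dim=1)
    log_p_t2i = F.log_softmax(text_feat @ image_feat.t() / temp, dim=1)

    # KL(q || p) = sum_k q_k * (log q_k - log p_k); the log q_k term is constant
    # w.r.t. the student, so this is equivalent to cross-entropy with soft targets q.
    kl_i2t = F.kl_div(log_p_i2t, q_i2t, reduction="batchmean")
    kl_t2i = F.kl_div(log_p_t2i, q_t2i, reduction="batchmean")
    return 0.5 * (kl_i2t + kl_t2i)
```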
In your code, the student's prediction p is calculated as follows:
sim_i2t = image_feat @ text_feat_all / self.temp
sim_t2i = text_feat @ image_feat_all / self.temp
Where:

- image_feat and text_feat are the student model's embeddings.
- text_feat_all and image_feat_all are built by concatenating the momentum model's outputs for the current batch with the feature queue.
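For context, this is how I understand the surrounding loss code (paraphrased from my reading of the repo, so please treat it as a sketch rather than verbatim code): the same similarities against the momentum/queue features appear both in the softened targets and in the student prediction.

```python
import torch
import torch.nn.functional as F

# Paraphrased sketch of the surrounding loss, as I read it (not verbatim repo code).
# sim_i2t_m / sim_t2i_m: momentum model's similarities against the same *_all features.
# sim_targets: one-hot ground-truth matches for the batch.
with torch.no_grad():
    sim_i2t_m = image_feat_m @ text_feat_all / self.temp
    sim_t2i_m = text_feat_m @ image_feat_all / self.temp
    sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
    sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets

# The student "prediction" p is softmax(sim_i2t) / softmax(sim_t2i), whose negative
# columns come from text_feat_all / image_feat_all, i.e. momentum + queue features.
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1).mean()
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1).mean()
loss_ita = (loss_i2t + loss_t2i) / 2
```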
This means that p is partially influenced by the momentum model outputs (via text_feat_all and image_feat_all), which seems inconsistent with the definition in the paper, where p should be computed purely from the student model, without any involvement of the momentum model.
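To make the distinction concrete, this is what I would have expected for a p that depends only on the student (a hypothetical variant with in-batch negatives only; names are mine, not the repo's):

```python
# Hypothetical "student-only" prediction p (F is torch.nn.functional): both the query
# and all candidates come from the current student encoders -- in-batch negatives only,
# no queue and no momentum features.
sim_i2t_student = image_feat @ text_feat.t() / self.temp   # (B, B)
sim_t2i_student = text_feat @ image_feat.t() / self.temp   # (B, B)
p_i2t = F.softmax(sim_i2t_student, dim=1)
p_t2i = F.softmax(sim_t2i_student, dim=1)
```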
So my questions are:

1. Does this affect the optimization of the KL divergence KL(q||p), given that p partially depends on q?
2. How does this influence the model's training dynamics, especially in the later stages where α increases and the KL divergence term becomes dominant?
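For reference, my current understanding of the α schedule (a rough sketch of the warm-up I believe the pretraining script uses; the exact form and names may differ):

```python
# Sketch of the alpha warm-up as I understand it (may not match the repo exactly):
# alpha ramps from 0 to its configured value over the first epoch, then stays fixed.
if epoch == 0:
    alpha = config['alpha'] * min(1.0, step / steps_per_epoch)
else:
    alpha = config['alpha']
```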