This repository was archived by the owner on May 1, 2025. It is now read-only.

Question about KL Divergence Implementation in Momentum Distillation #148

@ThomaswellY

Thanks for your amazing work on ALBEF! While reviewing the paper and the provided implementation, I noticed a potential inconsistency between the paper's description of the KL divergence computation and the code. I would appreciate it if you could clarify the following.
In the paper, the KL divergence loss is defined as:

Image

Where:
q: pseudo-targets generated by the momentum model.
p: predictions from the student model.
The KL divergence measures the difference between the pseudo-target q (from the momentum model) and the prediction p (from the student model).
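For reference, the standard discrete KL divergence between these two distributions can be written as below. This is only the textbook definition; the index j over candidate texts (or images) is notation assumed here, and the paper's exact equation is not reproduced in this text.

```latex
\mathrm{KL}(q \,\|\, p) = \sum_{j} q_j \log \frac{q_j}{p_j}
```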
In your code, the student's prediction p is computed from the following similarities:
sim_i2t = image_feat @ text_feat_all / self.temp
sim_t2i = text_feat @ image_feat_all / self.temp
Where:
image_feat and text_feat are the student model's embeddings.
text_feat_all and image_feat_all include a concatenation of momentum model outputs with the current batch.
This means that p is partially influenced by the momentum model outputs (via text_feat_all and image_feat_all), which seems inconsistent with the theoretical definition in the paper. In the paper, p should be computed purely from the student model, without any involvement of the momentum model.
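To make the concern concrete, here is a minimal dependency-free sketch of the image-to-text direction. It is not the repository's code: the toy vectors, the temperature value, and the helper names `softmax`, `dot`, and `kl` are all hypothetical, and it assumes that `text_feat_all` acts as a fixed set of keys shared by both the student and the momentum similarities.

```python
import math

def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def softmax(row, temp):
    """Numerically stable softmax of row / temp."""
    scaled = [x / temp for x in row]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl(q, p):
    """Discrete KL(q || p); terms with q_j == 0 contribute nothing."""
    return sum(qj * math.log(qj / pj) for qj, pj in zip(q, p) if qj > 0)

temp = 0.07  # hypothetical temperature

# Hypothetical toy embeddings (unit norm).
image_feat = [0.6, 0.8]      # student image embedding
image_feat_m = [0.8, 0.6]    # momentum image embedding
# Assumed: keys shared by student and momentum similarities, held constant.
text_feat_all = [[0.6, 0.8], [1.0, 0.0], [0.0, 1.0]]

# Student similarities: student query against the shared keys.
sim_i2t = [dot(image_feat, t) for t in text_feat_all]
# Momentum similarities: momentum query against the same keys.
sim_i2t_m = [dot(image_feat_m, t) for t in text_feat_all]

p = softmax(sim_i2t, temp)    # student prediction
q = softmax(sim_i2t_m, temp)  # pseudo-target
loss = kl(q, p)
print("KL(q || p) =", loss)
```

In this sketch, p and q differ only in the query vector (`image_feat` versus `image_feat_m`), while both are scored against the same keys. Whether those keys carry gradients in the actual implementation is exactly what the question asks the authors to confirm.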
So my questions are:
Does this affect the optimization of the KL divergence KL(q||p), given that p partially depends on the momentum model outputs that also produce q?
How does this influence the model's training dynamics, especially in the later stages where α increases and the KL divergence term becomes dominant?
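One way to frame the role of α is the identity below. It assumes the training target is a convex mix of a one-hot label y and the pseudo-target q with weight α; the symbol y and this mixing form are assumptions for illustration, not something confirmed in this thread.

```latex
\begin{aligned}
\mathrm{CE}\big(\alpha q + (1-\alpha)\, y,\; p\big)
  &= -\sum_{j} \big(\alpha q_j + (1-\alpha)\, y_j\big) \log p_j \\
  &= (1-\alpha)\,\mathrm{CE}(y, p) + \alpha \big[\mathrm{KL}(q \,\|\, p) + H(q)\big],
\qquad H(q) = -\sum_{j} q_j \log q_j .
\end{aligned}
```

If q is treated as a constant with respect to the student parameters, H(q) contributes no gradient, so under these assumptions the KL term's share of the gradient scales with α.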
