
about the training code #6

Open
AZZMM opened this issue May 5, 2024 · 13 comments

Comments

@AZZMM commented May 5, 2024

Hello, thanks for the excellent work! May I ask when the training code will be released? Thank you!

@limuloo (Owner) commented May 6, 2024

@AZZMM
Thank you for your interest in our work.
We are currently expanding MIGC into MIGC++ with more comprehensive functionality. We expect to submit a report to arXiv in May 2024 and plan to release the training code together at that time.

@AZZMM (Author) commented May 6, 2024

Thank you for the quick response; I look forward to the updates.

@AZZMM (Author) commented May 15, 2024

Excuse me, I've implemented the training code and I'm confused about a detail of the inhibition loss. From the paper, I understand that the loss is based on the attention map in the frozen cross-attention layer, with shape (batch_size*head_dim, HW, text_len). May I ask how to handle the text_len and head_dim dimensions?

Summing along the text_len dimension gives 1, which I believe is due to the softmax().

My implementation looks like this:

```python
# batch_size: instance_num + 2
attn_map = attn_map.reshape(batch_size // head_size, head_size, HW, text_len)[2:, ...]
attn_map = attn_map.permute(0, 2, 1, 3).reshape(-1, HW, text_len * head_size)
avg = torch.where(background_masks, attn_map, 0).sum(dim=1) / background_masks.sum(dim=1)
# avg shape: (instance_num, text_len * head_dim)
ihbt_loss = (torch.abs(attn_map - avg[..., None, :]) * background_masks).sum() / background_masks.sum()
```

Is this right?

Looking forward to your reply, thank you!

@limuloo (Owner) commented May 15, 2024

```python
pre_attn = fuser_info['pre_attn']  # (B*PN, heads, HW, 77)
BPN, heads, HW, _ = pre_attn.shape
pre_attn = torch.sum(pre_attn[:, :, :, 1:], dim=-1)  # (B*PN, heads, HW)
H = W = int(math.sqrt(HW))
pre_attn = pre_attn.view(bsz, -1, HW)  # (B, PN*heads, HW)

supplement_mask_inter = F.interpolate(supplement_mask, (H, W), mode=args.inter_mode)  # supplement_mask is the mask of the BG
supplement_mask_inter = supplement_mask_inter[:, 0, ...].view(bsz, 1, HW)  # (B, 1, HW)

pre_attn_mean = (pre_attn * supplement_mask_inter).sum(dim=-1) / \
    ((supplement_mask_inter).sum(dim=-1) + 1e-6)  # (B, PN*heads)
aug_scale = 1
now_pre_attn_loss = (abs(pre_attn - pre_attn_mean[..., None].detach()) * supplement_mask_inter).sum(dim=-1) / \
    (supplement_mask_inter.sum(dim=-1) + 1e-6)
now_pre_attn_loss = (now_pre_attn_loss * aug_scale).mean()
pre_attn_loss = pre_attn_loss + now_pre_attn_loss
pre_attn_loss_cnt = pre_attn_loss_cnt + 1
```

@AZZMM You can refer to this code to implement the inhibition loss. If you still have questions, feel free to ask here.
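
For anyone following along, here is a self-contained way to wrap the snippet above into a function, with dummy inputs to check the shapes. The function name `inhibition_loss`, its arguments, and the bilinear `inter_mode` are assumptions filled in for illustration (the snippet comes from a larger training loop), so treat this as a sketch rather than the official training code.

```python
import math
import torch
import torch.nn.functional as F

def inhibition_loss(pre_attn, supplement_mask, bsz, inter_mode='bilinear'):
    """Sketch of the inhibition loss on one frozen cross-attention layer.

    pre_attn:        (B*PN, heads, HW, 77) attention map of the frozen layer.
    supplement_mask: (B, 1, H0, W0) background mask.
    bsz:             batch size B.
    """
    BPN, heads, HW, _ = pre_attn.shape
    # Drop the <BOS> token and sum attention over the remaining text tokens.
    attn = torch.sum(pre_attn[:, :, :, 1:], dim=-1)   # (B*PN, heads, HW)
    H = W = int(math.sqrt(HW))
    attn = attn.view(bsz, -1, HW)                     # (B, PN*heads, HW)

    # Resize the background mask to the attention resolution.
    mask = F.interpolate(supplement_mask, (H, W), mode=inter_mode)
    mask = mask[:, 0, ...].view(bsz, 1, HW)           # (B, 1, HW)

    # Mean attention over the background, then penalise deviation from that
    # mean inside the background region only.
    mean = (attn * mask).sum(dim=-1) / (mask.sum(dim=-1) + 1e-6)  # (B, PN*heads)
    loss = (torch.abs(attn - mean[..., None].detach()) * mask).sum(dim=-1) / \
           (mask.sum(dim=-1) + 1e-6)
    return loss.mean()

# Dummy check: B=2, PN=4 (global prompt + 3 instances), heads=8, 16x16 map.
attn_maps = torch.rand(2 * 4, 8, 16 * 16, 77).softmax(dim=-1)
bg_mask = (torch.rand(2, 1, 64, 64) > 0.5).float()
print(inhibition_loss(attn_maps, bg_mask, bsz=2))
```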

@AZZMM (Author) commented May 15, 2024

Thank you very much for your quick reply; the code is really helpful to me.
May I ask what BPN represents? Does it contain the negative and global prompt attention maps?
Also, there are three attention maps at the 16*16 resolution in the frozen cross-attention layers. Is the final loss the sum of the three results?

@limuloo (Owner) commented May 15, 2024

@AZZMM During training, we don't need negative prompts. BPN means Batch * Phase_num, where Phase_num covers {global prompt, instance1_desc, instance2_desc, ..., instanceN_desc}. We use the first two 16*16 attention maps to calculate the inhibition loss.
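
To make the per-layer accumulation concrete, a hedged sketch of how the two 16*16 attention maps might be combined is shown below, reusing the `inhibition_loss` sketch from earlier in this thread. `collected_pre_attns`, `supplement_mask`, and `bsz` are assumed names for context not shown here (e.g. attention maps gathered with forward hooks), not part of the official code.

```python
# Hypothetical accumulation over the two 16x16 frozen cross-attention layers.
pre_attn_loss = 0.0
pre_attn_loss_cnt = 0
for pre_attn in collected_pre_attns:  # each: (B*PN, heads, 16*16, 77)
    pre_attn_loss = pre_attn_loss + inhibition_loss(pre_attn, supplement_mask, bsz)
    pre_attn_loss_cnt += 1
inhibit_loss = pre_attn_loss / max(pre_attn_loss_cnt, 1)
```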

@AZZMM (Author) commented May 16, 2024

Thank you very much for your help! I'll try it.

@AZZMM (Author) commented May 19, 2024

Hello @limuloo! There is one detail I am not very sure about: for the cross-attention layers without MIGC, do you use vanilla cross attention or the naive fuser during training?

@limuloo (Owner) commented May 19, 2024

@AZZMM vanilla cross attention
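
For readers wiring this up themselves, a rough sketch of that routing is below; the wrapper class, its attributes, and the call signatures are placeholders rather than the repository's actual API, so adapt them to your own attention processors.

```python
import torch.nn as nn

class CrossAttnRouter(nn.Module):
    """Placeholder sketch: apply the MIGC fuser only where it is attached and
    fall back to vanilla cross attention everywhere else during training."""

    def __init__(self, attn, migc_fuser=None):
        super().__init__()
        self.attn = attn              # the frozen cross-attention module
        self.migc_fuser = migc_fuser  # None for layers without MIGC

    def forward(self, hidden_states, global_emb, instance_embs=None, masks=None):
        if self.migc_fuser is None:
            # Layers without MIGC: vanilla cross attention on the global prompt.
            return self.attn(hidden_states, encoder_hidden_states=global_emb)
        # Layers with MIGC: fuse per-instance attention via the MIGC fuser.
        return self.migc_fuser(self.attn, hidden_states, global_emb, instance_embs, masks)
```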

@AZZMM (Author) commented May 20, 2024

I see, thank you!

@WUyinwei-hah commented:

> @AZZMM Thank you for your interest in our work. We are currently expanding MIGC into MIGC++ with more comprehensive functionality. We expect to submit a report to arXiv in May 2024 and plan to release the training code together at that time.

Hi, may I ask when MIGC++ will be released?

@limuloo (Owner) commented Jun 28, 2024

@WUyinwei-hah We have already finished writing the MIGC++ paper and will submit it in the next few days. After that, we will move on to the open-source work for MIGC++.

@WUyinwei-hah commented:

> @WUyinwei-hah We have already finished writing the MIGC++ paper and will submit it in the next few days. After that, we will move on to the open-source work for MIGC++.

Looking forward to it!
