
about class_loss and matcher #14

Open
Pujin823 opened this issue Mar 17, 2023 · 1 comment

Comments

@Pujin823

Hi, while reading the source code I noticed that the one-hot encoding in the loss_labels function of your losses.py seems to have a problem. sigmoid_focal_loss requires target and input to have the same shape, with a value of 1 denoting text and 0 denoting background. In your code the input shape should be [bs, num_queries, num_pts, 1], and the one-hot ground-truth encoding produced by your code does have shape [bs, num_queries, num_pts, 1], but all of its values are 0.
I see that the target_classes matrix is initialized with num_classes, i.e. 1. The same question arises when the matcher builds the classification cost matrix: in the BoxHungarianMatcher class, cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids], where pos_cost_class.shape = [bs * num_queries, 1]. Since the targets are text instances, tgt_ids should all be 1, which also triggers an index-out-of-bounds error here. I'm not sure whether I've misunderstood something; I'd really appreciate it if you could clarify. Thanks a lot.
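To make the indexing concern concrete, here is a minimal sketch of the situation described above. It assumes the matcher computes the focal-style classification cost the same way as Deformable DETR's matcher; the sizes and the tgt_ids values are hypothetical and only illustrate why a text label of 1 overruns a [bs * num_queries, 1] cost tensor whose only valid column is 0.

```python
import torch

# Minimal sketch of the matcher indexing issue (hypothetical sizes; the cost
# formulas follow the focal-style classification cost used in Deformable DETR).
bs, num_queries, num_classes = 2, 100, 1
out_prob = torch.randn(bs * num_queries, num_classes).sigmoid()  # [200, 1]

alpha, gamma = 0.25, 2.0
neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())

# If every text instance is labelled 1 while num_classes == 1, the only valid
# column index of pos_cost_class / neg_cost_class is 0, so this lookup overruns.
tgt_ids = torch.ones(5, dtype=torch.int64)  # 5 hypothetical text instances
try:
    cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
except IndexError as err:
    print("out of bounds:", err)

# Labelling text as class 0 keeps the same line in bounds.
tgt_ids = torch.zeros(5, dtype=torch.int64)
cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]  # [200, 5]
```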

@Pujin823
Author

In the original Deformable DETR, num_classes is 91, but here there is only 1 class. When the one-hot encoding is built, taking the encoder's matcher and classification loss as an example: target_classes = torch.full(src_logits.shape[:-1], self.num_classes, dtype=torch.int64, device=src_logits.device) gives target_classes.shape = [bs, num_queries] with every value equal to 1, and after the subsequent target_classes[idx] = target_classes_o the values are still all 1 (the text labels are also 1).
The final target_classes_onehot.scatter_(-1, target_classes.unsqueeze(-1), 1) produces shape [bs, num_queries, 2], but only the channel at index 1 is set to 1, so after target_classes_onehot = target_classes_onehot[..., :-1] every value is 0 and the one-hot encoding has no effect.
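For reference, a small runnable trace of the steps described above, with toy batch/query sizes and hypothetical matched indices; it assumes the construction mirrors Deformable DETR's loss_labels.

```python
import torch

# Toy-size trace of the one-hot construction described above (hypothetical
# sizes and matches; assumes the code mirrors Deformable DETR's loss_labels).
bs, num_queries, num_classes = 2, 4, 1
src_logits = torch.randn(bs, num_queries, num_classes)      # [2, 4, 1]

# Every query starts at the background index num_classes == 1; matched queries
# then receive the gt label, which for text is also 1, so nothing changes.
target_classes = torch.full(src_logits.shape[:-1], num_classes, dtype=torch.int64)
idx = (torch.tensor([0, 0, 1]), torch.tensor([1, 3, 2]))    # hypothetical matches
target_classes_o = torch.ones(3, dtype=torch.int64)         # gt text labels == 1
target_classes[idx] = target_classes_o                      # still all ones

# One-hot with an extra background channel: [2, 4, 2]. The scatter writes the 1
# into channel index 1, i.e. the background channel.
target_classes_onehot = torch.zeros(bs, num_queries, num_classes + 1)
target_classes_onehot.scatter_(-1, target_classes.unsqueeze(-1), 1)

# Dropping that background channel removes the only channel that was ever set,
# so the target handed to sigmoid_focal_loss is identically zero.
target_classes_onehot = target_classes_onehot[..., :-1]     # [2, 4, 1]
print(target_classes_onehot.sum())                          # tensor(0.)
```

Because the foreground label coincides with the background index num_classes, the scatter only ever writes into the channel that gets sliced away.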
