Skip to content

Commit 80dbb88

Browse files
yingufanmeta-codesync[bot]
authored andcommitted
Fix allow_missing_label_with_zero_weight for NE metric (#3465)
Summary: Pull Request resolved: #3465 This fixes two issues when enabling allow_missing_label_with_zero_weight for NE metric: 1. The returned tensor should have the same shape even if some weights are zero, to prevent errors in downstream processing. 1. Correct NE should be returned for the tasks with non-zero weights. Reviewed By: iamzainhuda Differential Revision: D84975421 fbshipit-source-id: aaf217d497a2b0857f4ea438f2e3f37abbe2b4c9
1 parent af25076 commit 80dbb88

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

torchrec/metrics/ne.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,14 @@ def compute_ne(
5555
eta: float,
5656
allow_missing_label_with_zero_weight: bool = False,
5757
) -> torch.Tensor:
58-
if allow_missing_label_with_zero_weight and not weighted_num_samples.all():
59-
# If nan were to occur, return a dummy value instead of nan if
60-
# allow_missing_label_with_zero_weight is True
61-
return torch.tensor([eta])
62-
63-
# Goes into this block if all elements in weighted_num_samples > 0
64-
weighted_num_samples = weighted_num_samples.double().clamp(min=eta)
65-
mean_label = pos_labels / weighted_num_samples
58+
clamped_weighted_num_samples = weighted_num_samples.double().clamp(min=eta)
59+
mean_label = pos_labels / clamped_weighted_num_samples
6660
ce_norm = _compute_cross_entropy_norm(mean_label, pos_labels, neg_labels, eta)
67-
return ce_sum / ce_norm
61+
ne = ce_sum / ce_norm
62+
if allow_missing_label_with_zero_weight and not weighted_num_samples.all():
63+
# If inf were to occur, return a dummy value instead.
64+
return torch.where(weighted_num_samples > 0, ne, eta)
65+
return ne
6866

6967

7068
def compute_logloss(

torchrec/metrics/tests/test_ne.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,20 @@ def test_ne_zero_weights(self) -> None:
194194
zero_weights=True,
195195
)
196196

197+
def test_ne_allow_missing_label_with_zero_weight(self) -> None:
198+
eta = 1e-12
199+
ne = compute_ne(
200+
ce_sum=torch.rand(3),
201+
weighted_num_samples=torch.tensor([3, 0, 2]),
202+
pos_labels=torch.tensor([1, 0, 2]),
203+
neg_labels=torch.tensor([2, 0, 0]),
204+
eta=eta,
205+
allow_missing_label_with_zero_weight=True,
206+
)
207+
self.assertTrue(torch.all(~ne.isinf()))
208+
self.assertTrue(torch.all(~ne.isnan()))
209+
self.assertTrue(torch.equal(ne.eq(eta), torch.tensor([False, True, False])))
210+
197211
_logloss_metric_test_helper: Callable[..., None] = partial(
198212
metric_test_helper, include_logloss=True
199213
)

0 commit comments

Comments
 (0)