Skip to content

Commit f4d8463

Browse files
committed
fix: typecast bf16 to float
1 parent 671c793 commit f4d8463

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/evals/metrics/mia/min_k.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def compute_batch_values(self, batch):
1717

1818
def compute_score(self, sample_stats):
1919
"""Score single sample using min-k negative log probs scores attack."""
20-
lp = sample_stats.cpu().numpy()
20+
lp = sample_stats.float().cpu().numpy()
2121
if lp.size == 0:
2222
return 0
2323

src/evals/metrics/mia/min_k_plus_plus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def compute_score(self, sample_stats):
3030

3131
# Handle numerical stability
3232
sigma = torch.clamp(sigma, min=1e-6)
33-
scores = (target_prob.cpu().numpy() - mu.cpu().numpy()) / torch.sqrt(
33+
scores = (target_prob.float().cpu().numpy() - mu.float().cpu().numpy()) / torch.sqrt(
3434
sigma
3535
).cpu().numpy()
3636

0 commit comments

Comments
 (0)