Description
🐛 Bug
The GPU memory usage keeps growing when using AveragePrecision. This does not happen with other metrics (Accuracy and Precision).
Is this the expected behaviour for this metric?
To Reproduce
Code sample
from torchmetrics import Accuracy, Precision, AveragePrecision
import torch

def get_memory():
    # Currently allocated GPU memory, in MB
    return torch.cuda.memory_allocated(0) / (1024**2)

def loop_metric(metric, repetitions=20, loop_length=100):
    # Call the metric repeatedly and print the allocated memory after each repetition
    print(type(metric).__name__)
    print(f"{get_memory():.3f}MB", end="")
    for i in range(repetitions):
        for j in range(loop_length):
            metric(pred, target)
        print(f" -> {get_memory():.3f}MB", end="")
    print("\n")

N_LABELS = 16
avg_p = AveragePrecision(task="multilabel", num_labels=N_LABELS).to("cuda")
prec = Precision(task="multilabel", num_labels=N_LABELS).to("cuda")
acc = Accuracy(task="multilabel", num_labels=N_LABELS).to("cuda")

BATCH_SIZE = 2056
target = torch.randint(size=(BATCH_SIZE, N_LABELS), low=0, high=2, device="cuda")
pred = torch.rand(size=(BATCH_SIZE, N_LABELS), device="cuda")

print()
loop_metric(acc)
loop_metric(prec)
loop_metric(avg_p)
Environment
- TorchMetrics: 1.8.2
- Python: 3.13.7
- PyTorch: 2.9.0+cu126
- Happens on Linux and Windows 10
Additional context
I ran into an OOM error when using AveragePrecision during model training.
Calling gc.collect() and torch.cuda.empty_cache() did not prevent it, and I tracked the issue down to this metric.
Is there some function that needs to be called to free this memory?
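To illustrate what I mean: is calling metric.reset() after metric.compute() the way to release the accumulated state between epochs? A minimal, self-contained sketch (random tensors stand in for real model outputs; I am not sure this is the intended usage):

import torch
from torchmetrics import AveragePrecision

metric = AveragePrecision(task="multilabel", num_labels=16).to("cuda")
pred = torch.rand(2056, 16, device="cuda")
target = torch.randint(0, 2, (2056, 16), device="cuda")

for epoch in range(5):
    for _ in range(100):
        metric(pred, target)  # memory grows with every call
    print(metric.compute())   # per-"epoch" result
    metric.reset()            # is this required to free the memory between epochs?
    print(f"{torch.cuda.memory_allocated(0) / (1024**2):.3f}MB allocated")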