log_dict to support ClasswiseWrapper #2683
Hi @robmarkcole, thanks for raising this issue, I finally had some time to tackle it. I have some bad and good news. The bad news is that I do not think it will be possible to directly integrate Why integrating
class TestModel(BoringModel):
    def __init__(self) -> None:
        super().__init__()
        self.train_metrics = MetricCollection(
            {
                "macro_accuracy": MulticlassAccuracy(num_classes=5, average="macro"),
                "classwise_accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=5, average=None)),
            },
            prefix="train_",
        )
        self.val_metrics = MetricCollection(
            {
                "macro_accuracy": MulticlassAccuracy(num_classes=5, average="macro"),
                "classwise_accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=5, average=None)),
            },
            prefix="val_",
        )

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        preds = torch.randint(0, 5, (100,), device=batch.device)
        target = torch.randint(0, 5, (100,), device=batch.device)
        self.train_metrics.update(preds, target)
        batch_values = self.train_metrics.compute()
        self.log_dict(batch_values, on_step=True, on_epoch=False)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        preds = torch.randint(0, 5, (100,), device=batch.device)
        target = torch.randint(0, 5, (100,), device=batch.device)
        self.val_metrics.update(preds, target)

    def on_validation_epoch_end(self):
        self.log_dict(self.val_metrics.compute(), on_step=False, on_epoch=True)
The snippet above is an example of what this manual logging looks like in PyTorch Lightning using MetricCollection and ClasswiseWrapper.
I hope this is enough feedback on this issue. Otherwise, feel free to reopen it.
@SkafteNicki thanks for the explanation!

def validation_step(self, batch, batch_idx):
    preds = torch.randint(0, 5, (100,), device=batch.device)
    target = torch.randint(0, 5, (100,), device=batch.device)
    self.val_metrics.update(preds, target)
    self.log_dict(self.val_metrics.compute(), on_step=False, on_epoch=True)  # Only called on_epoch end, and reset is called after
@robmarkcole I am pretty sure you will need the
The correct way is therefore (for manual logging):

def validation_step(self, batch, batch_idx):
    preds = torch.randint(0, 5, (100,), device=batch.device)
    target = torch.randint(0, 5, (100,), device=batch.device)
    self.val_metrics.update(preds, target)

def on_validation_epoch_end(self):
    self.log_dict(self.val_metrics.compute())
    self.val_metrics.reset()

(I actually missed the
Edit: The docs say this is wrong.
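To make the update / compute / reset lifecycle discussed in this thread concrete, here is a minimal pure-Python sketch. RunningAccuracy and run_epoch are hypothetical toy stand-ins written for illustration; they are not the torchmetrics or Lightning API. The sketch only shows why state that is accumulated in update() needs a reset() between epochs when logging manually.

```python
class RunningAccuracy:
    """Toy stateful metric (hypothetical, not the torchmetrics API)."""

    def __init__(self):
        self.reset()

    def update(self, preds, target):
        # Accumulate counts across batches.
        self.correct += sum(p == t for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self):
        # Value over everything seen since the last reset.
        return self.correct / self.total if self.total else 0.0

    def reset(self):
        self.correct = 0
        self.total = 0


def run_epoch(metric, batches, reset_after):
    """Mimic one validation epoch: update per batch, compute once at the end."""
    for preds, target in batches:
        metric.update(preds, target)
    value = metric.compute()
    if reset_after:
        metric.reset()
    return value


epoch_1 = [([0, 1, 1], [0, 1, 0]), ([2, 2], [2, 2])]  # 4 of 5 correct
epoch_2 = [([0, 0, 0], [1, 1, 1])]                    # 0 of 3 correct

with_reset = RunningAccuracy()
reset_values = [run_epoch(with_reset, e, reset_after=True) for e in (epoch_1, epoch_2)]

without_reset = RunningAccuracy()
leaky_values = [run_epoch(without_reset, e, reset_after=False) for e in (epoch_1, epoch_2)]

print(reset_values)  # [0.8, 0.0]
print(leaky_values)  # [0.8, 0.5]  (epoch 1 state leaks into epoch 2)
```

Without the reset, the second epoch reports 4 of 8 correct instead of 0 of 3, because the counts from the first epoch are still in the metric's state.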
🚀 Feature
I want to be able to call:
where my_metrics is a MetricCollection that includes ClasswiseWrapper metrics. Currently this will fail with an error like:
An example of how to generate metrics that will currently fail from this PR: