Skip to content

Commit

Permalink
BaseTask: explicitly pass batch size to logger
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Mar 4, 2024
1 parent b8f2a2b commit 7963a34
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 28 deletions.
3 changes: 2 additions & 1 deletion torchgeo/trainers/byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ def training_step(
AssertionError: If channel dimensions are incorrect.
"""
x = batch["image"]
batch_size = x.shape[0]

in_channels = self.hparams["in_channels"]
assert x.size(1) == in_channels or x.size(1) == 2 * in_channels
Expand All @@ -387,7 +388,7 @@ def training_step(

loss = torch.mean(normalized_mse(pred1, targ2) + normalized_mse(pred2, targ1))

self.log("train_loss", loss)
self.log("train_loss", loss, batch_size=batch_size)
self.model.update_target()

return loss
Expand Down
28 changes: 17 additions & 11 deletions torchgeo/trainers/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,12 @@ def training_step(
"""
x = batch["image"]
y = batch["label"]
batch_size = x.shape[0]
y_hat = self(x)
loss: Tensor = self.criterion(y_hat, y)
self.log("train_loss", loss)
self.log("train_loss", loss, batch_size=batch_size)
self.train_metrics(y_hat, y)
self.log_dict(self.train_metrics)
self.log_dict(self.train_metrics, batch_size=batch_size)

return loss

Expand All @@ -198,11 +199,12 @@ def validation_step(
"""
x = batch["image"]
y = batch["label"]
batch_size = x.shape[0]
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log("val_loss", loss)
self.log("val_loss", loss, batch_size=batch_size)
self.val_metrics(y_hat, y)
self.log_dict(self.val_metrics)
self.log_dict(self.val_metrics, batch_size=batch_size)

if (
batch_idx < 10
Expand Down Expand Up @@ -241,11 +243,12 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
"""
x = batch["image"]
y = batch["label"]
batch_size = x.shape[0]
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log("test_loss", loss)
self.log("test_loss", loss, batch_size=batch_size)
self.test_metrics(y_hat, y)
self.log_dict(self.test_metrics)
self.log_dict(self.test_metrics, batch_size=batch_size)

def predict_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
Expand Down Expand Up @@ -317,10 +320,11 @@ def training_step(
"""
x = batch["image"]
y = batch["label"]
batch_size = x.shape[0]
y_hat = self(x)
y_hat_hard = torch.sigmoid(y_hat)
loss: Tensor = self.criterion(y_hat, y.to(torch.float))
self.log("train_loss", loss)
self.log("train_loss", loss, batch_size=batch_size)
self.train_metrics(y_hat_hard, y)
self.log_dict(self.train_metrics)

Expand All @@ -338,12 +342,13 @@ def validation_step(
"""
x = batch["image"]
y = batch["label"]
batch_size = x.shape[0]
y_hat = self(x)
y_hat_hard = torch.sigmoid(y_hat)
loss = self.criterion(y_hat, y.to(torch.float))
self.log("val_loss", loss)
self.log("val_loss", loss, batch_size=batch_size)
self.val_metrics(y_hat_hard, y)
self.log_dict(self.val_metrics)
self.log_dict(self.val_metrics, batch_size=batch_size)

if (
batch_idx < 10
Expand Down Expand Up @@ -381,12 +386,13 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
"""
x = batch["image"]
y = batch["label"]
batch_size = x.shape[0]
y_hat = self(x)
y_hat_hard = torch.sigmoid(y_hat)
loss = self.criterion(y_hat, y.to(torch.float))
self.log("test_loss", loss)
self.log("test_loss", loss, batch_size=batch_size)
self.test_metrics(y_hat_hard, y)
self.log_dict(self.test_metrics)
self.log_dict(self.test_metrics, batch_size=batch_size)

def predict_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
Expand Down
5 changes: 3 additions & 2 deletions torchgeo/trainers/moco.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ def training_step(
The loss tensor.
"""
x = batch["image"]
batch_size = x.shape[0]

in_channels = self.hparams["in_channels"]
assert x.size(1) == in_channels or x.size(1) == 2 * in_channels
Expand Down Expand Up @@ -420,8 +421,8 @@ def training_step(
output_std = torch.mean(output_std, dim=0)
self.avg_output_std = 0.9 * self.avg_output_std + (1 - 0.9) * output_std.item()

self.log("train_ssl_std", self.avg_output_std)
self.log("train_loss", loss)
self.log("train_ssl_std", self.avg_output_std, batch_size=batch_size)
self.log("train_loss", loss, batch_size=batch_size)

return loss

Expand Down
15 changes: 9 additions & 6 deletions torchgeo/trainers/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,16 @@ def training_step(
The loss tensor.
"""
x = batch["image"]
batch_size = x.shape[0]
# TODO: remove .to(...) once we have a real pixelwise regression dataset
y = batch[self.target_key].to(torch.float)
y_hat = self(x)
if y_hat.ndim != y.ndim:
y = y.unsqueeze(dim=1)
loss: Tensor = self.criterion(y_hat, y)
self.log("train_loss", loss)
self.log("train_loss", loss, batch_size=batch_size)
self.train_metrics(y_hat, y)
self.log_dict(self.train_metrics)
self.log_dict(self.train_metrics, batch_size=batch_size)

return loss

Expand All @@ -182,15 +183,16 @@ def validation_step(
dataloader_idx: Index of the current dataloader.
"""
x = batch["image"]
batch_size = x.shape[0]
# TODO: remove .to(...) once we have a real pixelwise regression dataset
y = batch[self.target_key].to(torch.float)
y_hat = self(x)
if y_hat.ndim != y.ndim:
y = y.unsqueeze(dim=1)
loss = self.criterion(y_hat, y)
self.log("val_loss", loss)
self.log("val_loss", loss, batch_size=batch_size)
self.val_metrics(y_hat, y)
self.log_dict(self.val_metrics)
self.log_dict(self.val_metrics, batch_size=batch_size)

if (
batch_idx < 10
Expand Down Expand Up @@ -231,15 +233,16 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
dataloader_idx: Index of the current dataloader.
"""
x = batch["image"]
batch_size = x.shape[0]
# TODO: remove .to(...) once we have a real pixelwise regression dataset
y = batch[self.target_key].to(torch.float)
y_hat = self(x)
if y_hat.ndim != y.ndim:
y = y.unsqueeze(dim=1)
loss = self.criterion(y_hat, y)
self.log("test_loss", loss)
self.log("test_loss", loss, batch_size=batch_size)
self.test_metrics(y_hat, y)
self.log_dict(self.test_metrics)
self.log_dict(self.test_metrics, batch_size=batch_size)

def predict_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
Expand Down
15 changes: 9 additions & 6 deletions torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,12 @@ def training_step(
"""
x = batch["image"]
y = batch["mask"]
batch_size = x.shape[0]
y_hat = self(x)
loss: Tensor = self.criterion(y_hat, y)
self.log("train_loss", loss)
self.log("train_loss", loss, batch_size=batch_size)
self.train_metrics(y_hat, y)
self.log_dict(self.train_metrics)
self.log_dict(self.train_metrics, batch_size=batch_size)
return loss

def validation_step(
Expand All @@ -246,11 +247,12 @@ def validation_step(
"""
x = batch["image"]
y = batch["mask"]
batch_size = x.shape[0]
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log("val_loss", loss)
self.log("val_loss", loss, batch_size=batch_size)
self.val_metrics(y_hat, y)
self.log_dict(self.val_metrics)
self.log_dict(self.val_metrics, batch_size=batch_size)

if (
batch_idx < 10
Expand Down Expand Up @@ -289,11 +291,12 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
"""
x = batch["image"]
y = batch["mask"]
batch_size = x.shape[0]
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log("test_loss", loss)
self.log("test_loss", loss, batch_size=batch_size)
self.test_metrics(y_hat, y)
self.log_dict(self.test_metrics)
self.log_dict(self.test_metrics, batch_size=batch_size)

def predict_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
Expand Down
5 changes: 3 additions & 2 deletions torchgeo/trainers/simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def training_step(
AssertionError: If channel dimensions are incorrect.
"""
x = batch["image"]
batch_size = x.shape[0]

in_channels: int = self.hparams["in_channels"]
assert x.size(1) == in_channels or x.size(1) == 2 * in_channels
Expand Down Expand Up @@ -250,8 +251,8 @@ def training_step(
output_std = torch.mean(output_std, dim=0)
self.avg_output_std = 0.9 * self.avg_output_std + (1 - 0.9) * output_std.item()

self.log("train_ssl_std", self.avg_output_std)
self.log("train_loss", loss)
self.log("train_ssl_std", self.avg_output_std, batch_size=batch_size)
self.log("train_loss", loss, batch_size=batch_size)

return loss

Expand Down

0 comments on commit 7963a34

Please sign in to comment.