diff --git a/tests/conf/bigearthnet_all.yaml b/tests/conf/bigearthnet_all.yaml
index b6218986070..c21a9ed9a8b 100644
--- a/tests/conf/bigearthnet_all.yaml
+++ b/tests/conf/bigearthnet_all.yaml
@@ -1,10 +1,11 @@
 model:
-  class_path: MultiLabelClassificationTask
+  class_path: ClassificationTask
   init_args:
     loss: "bce"
     model: "resnet18"
     in_channels: 14
-    num_classes: 19
+    task: "multilabel"
+    num_labels: 19
 data:
   class_path: BigEarthNetDataModule
   init_args:
diff --git a/tests/conf/bigearthnet_s1.yaml b/tests/conf/bigearthnet_s1.yaml
index 060d45dfd13..ff0d5ed2b76 100644
--- a/tests/conf/bigearthnet_s1.yaml
+++ b/tests/conf/bigearthnet_s1.yaml
@@ -1,10 +1,11 @@
 model:
-  class_path: MultiLabelClassificationTask
+  class_path: ClassificationTask
   init_args:
     loss: "bce"
     model: "resnet18"
     in_channels: 2
-    num_classes: 19
+    task: "multilabel"
+    num_labels: 19
 data:
   class_path: BigEarthNetDataModule
   init_args:
diff --git a/tests/conf/bigearthnet_s2.yaml b/tests/conf/bigearthnet_s2.yaml
index a06fcc52be8..6e59f88901f 100644
--- a/tests/conf/bigearthnet_s2.yaml
+++ b/tests/conf/bigearthnet_s2.yaml
@@ -1,10 +1,11 @@
 model:
-  class_path: MultiLabelClassificationTask
+  class_path: ClassificationTask
   init_args:
     loss: "bce"
     model: "resnet18"
     in_channels: 12
-    num_classes: 19
+    task: "multilabel"
+    num_labels: 19
 data:
   class_path: BigEarthNetDataModule
   init_args:
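These three config diffs are the whole migration story for YAML users: the dedicated multilabel task disappears, and `ClassificationTask` gains `task` and `num_labels` arguments. For readers porting scripts rather than configs, a minimal sketch of the equivalent programmatic setup (keyword names and values taken from `bigearthnet_all.yaml` above; assumes a torchgeo build that includes this change):

```python
from torchgeo.trainers import ClassificationTask

# Formerly MultiLabelClassificationTask(..., num_classes=19); the unified task
# selects the problem type via `task` and sizes the head via `num_labels`.
trainer_task = ClassificationTask(
    model='resnet18',
    loss='bce',
    in_channels=14,
    task='multilabel',
    num_labels=19,
)
```

Note that `num_classes` is now reserved for `task='multiclass'`; multilabel head sizes come from `num_labels`, mirroring the torchmetrics argument names.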
diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py
index cd437f9faed..71b6a44591e 100644
--- a/tests/trainers/test_classification.py
+++ b/tests/trainers/test_classification.py
@@ -15,15 +15,11 @@
 from torch.nn.modules import Module
 from torchvision.models._api import WeightsEnum

-from torchgeo.datamodules import (
-    BigEarthNetDataModule,
-    EuroSATDataModule,
-    MisconfigurationException,
-)
-from torchgeo.datasets import BigEarthNet, EuroSAT, RGBBandsMissingError
+from torchgeo.datamodules import EuroSATDataModule, MisconfigurationException
+from torchgeo.datasets import EuroSAT, RGBBandsMissingError
 from torchgeo.main import main
 from torchgeo.models import ResNet18_Weights
-from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask
+from torchgeo.trainers import ClassificationTask


 class ClassificationTestModel(Module):
@@ -47,11 +43,6 @@ def setup(self, stage: str) -> None:
         self.predict_dataset = EuroSAT(split='test', **self.kwargs)


-class PredictMultiLabelClassificationDataModule(BigEarthNetDataModule):
-    def setup(self, stage: str) -> None:
-        self.predict_dataset = BigEarthNet(split='test', **self.kwargs)
-
-
 def create_model(*args: Any, **kwargs: Any) -> Module:
     return ClassificationTestModel(**kwargs)

@@ -73,6 +64,9 @@ class TestClassificationTask:
     @pytest.mark.parametrize(
         'name',
         [
+            'bigearthnet_all',
+            'bigearthnet_s1',
+            'bigearthnet_s2',
             'eurosat',
             'eurosat100',
             'eurosatspatial',
@@ -231,95 +225,10 @@ def test_predict(self, fast_dev_run: bool) -> None:
         'model_name', ['resnet18', 'efficientnetv2_s', 'vit_base_patch16_384']
     )
     def test_freeze_backbone(self, model_name: str) -> None:
-        model = ClassificationTask(model=model_name, freeze_backbone=True)
+        model = ClassificationTask(
+            model=model_name, num_classes=10, freeze_backbone=True
+        )
         assert not all([param.requires_grad for param in model.model.parameters()])
         assert all(
             [param.requires_grad for param in model.model.get_classifier().parameters()]
         )
-
-
-class TestMultiLabelClassificationTask:
-    @pytest.mark.parametrize(
-        'name', ['bigearthnet_all', 'bigearthnet_s1', 'bigearthnet_s2']
-    )
-    def test_trainer(
-        self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
-    ) -> None:
-        config = os.path.join('tests', 'conf', name + '.yaml')
-
-        monkeypatch.setattr(timm, 'create_model', create_model)
-
-        args = [
-            '--config',
-            config,
-            '--trainer.accelerator',
-            'cpu',
-            '--trainer.fast_dev_run',
-            str(fast_dev_run),
-            '--trainer.max_epochs',
-            '1',
-            '--trainer.log_every_n_steps',
-            '1',
-        ]
-
-        main(['fit'] + args)
-        try:
-            main(['test'] + args)
-        except MisconfigurationException:
-            pass
-        try:
-            main(['predict'] + args)
-        except MisconfigurationException:
-            pass
-
-    def test_invalid_loss(self) -> None:
-        match = "Loss type 'invalid_loss' is not valid."
-        with pytest.raises(ValueError, match=match):
-            MultiLabelClassificationTask(model='resnet18', loss='invalid_loss')
-
-    def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
-        monkeypatch.setattr(BigEarthNetDataModule, 'plot', plot)
-        datamodule = BigEarthNetDataModule(
-            root='tests/data/bigearthnet', batch_size=1, num_workers=0
-        )
-        model = MultiLabelClassificationTask(
-            model='resnet18', in_channels=14, num_classes=19, loss='bce'
-        )
-        trainer = Trainer(
-            accelerator='cpu',
-            fast_dev_run=fast_dev_run,
-            log_every_n_steps=1,
-            max_epochs=1,
-        )
-        trainer.validate(model=model, datamodule=datamodule)
-
-    def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
-        monkeypatch.setattr(BigEarthNetDataModule, 'plot', plot_missing_bands)
-        datamodule = BigEarthNetDataModule(
-            root='tests/data/bigearthnet', batch_size=1, num_workers=0
-        )
-        model = MultiLabelClassificationTask(
-            model='resnet18', in_channels=14, num_classes=19, loss='bce'
-        )
-        trainer = Trainer(
-            accelerator='cpu',
-            fast_dev_run=fast_dev_run,
-            log_every_n_steps=1,
-            max_epochs=1,
-        )
-        trainer.validate(model=model, datamodule=datamodule)
-
-    def test_predict(self, fast_dev_run: bool) -> None:
-        datamodule = PredictMultiLabelClassificationDataModule(
-            root='tests/data/bigearthnet', batch_size=1, num_workers=0
-        )
-        model = MultiLabelClassificationTask(
-            model='resnet18', in_channels=14, num_classes=19, loss='bce'
-        )
-        trainer = Trainer(
-            accelerator='cpu',
-            fast_dev_run=fast_dev_run,
-            log_every_n_steps=1,
-            max_epochs=1,
-        )
-        trainer.predict(model=model, datamodule=datamodule)
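One consequence surfaces in `test_freeze_backbone` above: `num_classes` no longer defaults to 1000, so tests must pin an explicit head size. A small sketch of the new behavior (it relies on `BaseTask` building the model at construction time, and on the `num_classes or num_labels or 2` fallback introduced in `configure_models` later in this diff):

```python
from torchgeo.trainers import ClassificationTask

# With neither num_classes nor num_labels given, the head would fall back to a
# binary (2-way) output rather than the old 1000-way ImageNet default.
model = ClassificationTask(model='resnet18', num_classes=10, freeze_backbone=True)
assert model.model.get_classifier().out_features == 10
```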
diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py
index d8b207d5d2d..ad70c80d306 100644
--- a/tests/trainers/test_segmentation.py
+++ b/tests/trainers/test_segmentation.py
@@ -136,7 +136,7 @@ def mocked_weights(
 ) -> WeightsEnum:
     path = tmp_path / f'{weights}.pth'
     model = timm.create_model(
-        weights.meta['model'], in_chans=weights.meta['in_chans']
+        weights.meta['model'], in_chans=weights.meta['in_chans'], num_classes=10
     )
     torch.save(model.state_dict(), path)
     try:
@@ -154,6 +154,7 @@ def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
             backbone=mocked_weights.meta['model'],
             weights=mocked_weights,
             in_channels=mocked_weights.meta['in_chans'],
+            num_classes=10,
         )

     def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
@@ -161,6 +162,7 @@ def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
             backbone=mocked_weights.meta['model'],
             weights=str(mocked_weights),
             in_channels=mocked_weights.meta['in_chans'],
+            num_classes=10,
         )

     @pytest.mark.slow
@@ -227,7 +229,7 @@ def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
     )
     def test_freeze_backbone(self, model_name: str, backbone: str) -> None:
         model = SemanticSegmentationTask(
-            model=model_name, backbone=backbone, freeze_backbone=True
+            model=model_name, backbone=backbone, num_classes=10, freeze_backbone=True
         )
         assert all(
             [param.requires_grad is False for param in model.model.encoder.parameters()]
@@ -242,7 +244,9 @@ def test_freeze_backbone(self, model_name: str, backbone: str) -> None:

     @pytest.mark.parametrize('model_name', ['unet', 'deeplabv3+'])
     def test_freeze_decoder(self, model_name: str) -> None:
-        model = SemanticSegmentationTask(model=model_name, freeze_decoder=True)
+        model = SemanticSegmentationTask(
+            model=model_name, num_classes=10, freeze_decoder=True
+        )
         assert all(
             [param.requires_grad is False for param in model.model.decoder.parameters()]
         )
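The `num_classes=10` added to the `mocked_weights` fixture keeps the saved checkpoint's classifier head the same shape as the head the task under test will build, since `load_state_dict` is strict about tensor shapes by default. An illustrative sketch with hypothetical path and channel values (not the fixture's actual `meta` entries):

```python
import timm
import torch

# Save a mock checkpoint whose classification head matches the task's head.
model = timm.create_model('resnet18', in_chans=2, num_classes=10)
torch.save(model.state_dict(), 'mock_weights.pth')

# A model built with the same head size loads it cleanly; a 1000-way or 2-way
# head would raise a size-mismatch RuntimeError instead.
same = timm.create_model('resnet18', in_chans=2, num_classes=10)
same.load_state_dict(torch.load('mock_weights.pth'))
```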
diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py
index be4fb4a03db..1b7842fd912 100644
--- a/torchgeo/trainers/__init__.py
+++ b/torchgeo/trainers/__init__.py
@@ -5,7 +5,7 @@

 from .base import BaseTask
 from .byol import BYOLTask
-from .classification import ClassificationTask, MultiLabelClassificationTask
+from .classification import ClassificationTask
 from .detection import ObjectDetectionTask
 from .iobench import IOBenchTask
 from .moco import MoCoTask
@@ -16,7 +16,6 @@
 __all__ = (
     # Supervised
     'ClassificationTask',
-    'MultiLabelClassificationTask',
    'ObjectDetectionTask',
    'PixelwiseRegressionTask',
    'RegressionTask',
diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py
index cc293099519..d9290290b0a 100644
--- a/torchgeo/trainers/classification.py
+++ b/torchgeo/trainers/classification.py
@@ -8,19 +8,12 @@

 import matplotlib.pyplot as plt
 import timm
-import torch
 import torch.nn as nn
 from matplotlib.figure import Figure
 from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss
 from torch import Tensor
 from torchmetrics import MetricCollection
-from torchmetrics.classification import (
-    MulticlassAccuracy,
-    MulticlassFBetaScore,
-    MulticlassJaccardIndex,
-    MultilabelAccuracy,
-    MultilabelFBetaScore,
-)
+from torchmetrics.classification import Accuracy, FBetaScore, JaccardIndex
 from torchvision.models._api import WeightsEnum

 from ..datasets import RGBBandsMissingError, unbind_samples
@@ -37,7 +30,9 @@ def __init__(
         model: str = 'resnet50',
         weights: WeightsEnum | str | bool | None = None,
         in_channels: int = 3,
-        num_classes: int = 1000,
+        task: str = 'multiclass',
+        num_classes: int | None = None,
+        num_labels: int | None = None,
         loss: str = 'ce',
         class_weights: Tensor | None = None,
         lr: float = 1e-3,
@@ -53,7 +48,9 @@ representation of a weight enum, True for ImageNet weights,
                 False or None for random weights, or the path to a saved model
                 state dict.
             in_channels: Number of input channels to model.
-            num_classes: Number of prediction classes.
+            task: One of 'binary', 'multiclass', or 'multilabel'.
+            num_classes: Number of prediction classes (only for ``task='multiclass'``).
+            num_labels: Number of prediction labels (only for ``task='multilabel'``).
             loss: One of 'ce', 'bce', 'jaccard', or 'focal'.
             class_weights: Optional rescaling weight given to each
                 class and used with 'ce' loss.
@@ -62,8 +59,8 @@ class and used with 'ce' loss.
             freeze_backbone: Freeze the backbone network to linear probe
                 the classifier head.

-        .. versionchanged:: 0.4
-           *classification_model* was renamed to *model*.
+        .. versionadded:: 0.6
+           The *task* and *num_labels* parameters.

         .. versionadded:: 0.5
            The *class_weights* and *freeze_backbone* parameters.
@@ -71,6 +68,9 @@ class and used with 'ce' loss.
         .. versionchanged:: 0.5
            *learning_rate* and *learning_rate_schedule_patience* were renamed to
           *lr* and *patience*.
+
+        .. versionchanged:: 0.4
+           *classification_model* was renamed to *model*.
         """
         self.weights = weights
         super().__init__(ignore='weights')
@@ -82,7 +82,7 @@ def configure_models(self) -> None:
         # Create model
         self.model = timm.create_model(
             self.hparams['model'],
-            num_classes=self.hparams['num_classes'],
+            num_classes=self.hparams['num_classes'] or self.hparams['num_labels'] or 2,
             in_chans=self.hparams['in_channels'],
             pretrained=weights is True,
         )
@@ -127,13 +127,13 @@ def configure_losses(self) -> None:
     def configure_metrics(self) -> None:
         """Initialize the performance metrics.

-        * :class:`~torchmetrics.classification.MulticlassAccuracy`: The number of
+        * :class:`~torchmetrics.Accuracy`: The number of
          true positives divided by the dataset size. Both overall accuracy (OA)
          using 'micro' averaging and average accuracy (AA) using 'macro' averaging
          are reported. Higher values are better.
-        * :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection
+        * :class:`~torchmetrics.JaccardIndex`: Intersection
          over union (IoU). Uses 'macro' averaging. Higher values are better.
-        * :class:`~torchmetrics.classification.MulticlassFBetaScore`: F1 score.
+        * :class:`~torchmetrics.FBetaScore`: F1 score.
          The harmonic mean of precision and recall. Uses 'micro' averaging.
          Higher values are better.
@@ -143,20 +143,17 @@ def configure_metrics(self) -> None:
            * 'Macro' averaging gives equal weight to each class, and is useful
              for balanced performance assessment across imbalanced classes.
         """
+        kwargs = {
+            'task': self.hparams['task'],
+            'num_classes': self.hparams['num_classes'],
+            'num_labels': self.hparams['num_labels'],
+        }
         metrics = MetricCollection(
             {
-                'OverallAccuracy': MulticlassAccuracy(
-                    num_classes=self.hparams['num_classes'], average='micro'
-                ),
-                'AverageAccuracy': MulticlassAccuracy(
-                    num_classes=self.hparams['num_classes'], average='macro'
-                ),
-                'JaccardIndex': MulticlassJaccardIndex(
-                    num_classes=self.hparams['num_classes']
-                ),
-                'F1Score': MulticlassFBetaScore(
-                    num_classes=self.hparams['num_classes'], beta=1.0, average='micro'
-                ),
+                'OverallAccuracy': Accuracy(average='micro', **kwargs),
+                'AverageAccuracy': Accuracy(average='macro', **kwargs),
+                'JaccardIndex': JaccardIndex(**kwargs),
+                'F1Score': FBetaScore(beta=1.0, average='micro', **kwargs),
             }
         )
         self.train_metrics = metrics.clone(prefix='train_')
@@ -266,147 +263,3 @@ def predict_step(
         x = batch['image']
         y_hat: Tensor = self(x).softmax(dim=-1)
         return y_hat
-
-
-class MultiLabelClassificationTask(ClassificationTask):
-    """Multi-label image classification."""
-
-    def configure_metrics(self) -> None:
-        """Initialize the performance metrics.
-
-        * :class:`~torchmetrics.classification.MultilabelAccuracy`: The number of
-          true positives divided by the dataset size. Both overall accuracy (OA)
-          using 'micro' averaging and average accuracy (AA) using 'macro' averaging
-          are reported. Higher values are better.
-        * :class:`~torchmetrics.classification.MultilabelFBetaScore`: F1 score.
-          The harmonic mean of precision and recall. Uses 'micro' averaging.
-          Higher values are better.
-
-        .. note::
-           * 'Micro' averaging suits overall performance evaluation but may not
-             reflect minority class accuracy.
-           * 'Macro' averaging gives equal weight to each class, and is useful
-             for balanced performance assessment across imbalanced classes.
-        """
-        metrics = MetricCollection(
-            {
-                'OverallAccuracy': MultilabelAccuracy(
-                    num_labels=self.hparams['num_classes'], average='micro'
-                ),
-                'AverageAccuracy': MultilabelAccuracy(
-                    num_labels=self.hparams['num_classes'], average='macro'
-                ),
-                'F1Score': MultilabelFBetaScore(
-                    num_labels=self.hparams['num_classes'], beta=1.0, average='micro'
-                ),
-            }
-        )
-        self.train_metrics = metrics.clone(prefix='train_')
-        self.val_metrics = metrics.clone(prefix='val_')
-        self.test_metrics = metrics.clone(prefix='test_')
-
-    def training_step(
-        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
-    ) -> Tensor:
-        """Compute the training loss and additional metrics.
-
-        Args:
-            batch: The output of your DataLoader.
-            batch_idx: Integer displaying index of this batch.
-            dataloader_idx: Index of the current dataloader.
-
-        Returns:
-            The loss tensor.
-        """
-        x = batch['image']
-        y = batch['label']
-        batch_size = x.shape[0]
-        y_hat = self(x)
-        y_hat_hard = torch.sigmoid(y_hat)
-        loss: Tensor = self.criterion(y_hat, y.to(torch.float))
-        self.log('train_loss', loss, batch_size=batch_size)
-        self.train_metrics(y_hat_hard, y)
-        self.log_dict(self.train_metrics)
-
-        return loss
-
-    def validation_step(
-        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
-    ) -> None:
-        """Compute the validation loss and additional metrics.
-
-        Args:
-            batch: The output of your DataLoader.
-            batch_idx: Integer displaying index of this batch.
-            dataloader_idx: Index of the current dataloader.
-        """
-        x = batch['image']
-        y = batch['label']
-        batch_size = x.shape[0]
-        y_hat = self(x)
-        y_hat_hard = torch.sigmoid(y_hat)
-        loss = self.criterion(y_hat, y.to(torch.float))
-        self.log('val_loss', loss, batch_size=batch_size)
-        self.val_metrics(y_hat_hard, y)
-        self.log_dict(self.val_metrics, batch_size=batch_size)
-
-        if (
-            batch_idx < 10
-            and hasattr(self.trainer, 'datamodule')
-            and hasattr(self.trainer.datamodule, 'plot')
-            and self.logger
-            and hasattr(self.logger, 'experiment')
-            and hasattr(self.logger.experiment, 'add_figure')
-        ):
-            datamodule = self.trainer.datamodule
-            batch['prediction'] = y_hat_hard
-            for key in ['image', 'label', 'prediction']:
-                batch[key] = batch[key].cpu()
-            sample = unbind_samples(batch)[0]
-
-            fig: Figure | None = None
-            try:
-                fig = datamodule.plot(sample)
-            except RGBBandsMissingError:
-                pass
-
-            if fig:
-                summary_writer = self.logger.experiment
-                summary_writer.add_figure(
-                    f'image/{batch_idx}', fig, global_step=self.global_step
-                )
-
-    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
-        """Compute the test loss and additional metrics.
-
-        Args:
-            batch: The output of your DataLoader.
-            batch_idx: Integer displaying index of this batch.
-            dataloader_idx: Index of the current dataloader.
-        """
-        x = batch['image']
-        y = batch['label']
-        batch_size = x.shape[0]
-        y_hat = self(x)
-        y_hat_hard = torch.sigmoid(y_hat)
-        loss = self.criterion(y_hat, y.to(torch.float))
-        self.log('test_loss', loss, batch_size=batch_size)
-        self.test_metrics(y_hat_hard, y)
-        self.log_dict(self.test_metrics, batch_size=batch_size)
-
-    def predict_step(
-        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
-    ) -> Tensor:
-        """Compute the predicted class probabilities.
-
-        Args:
-            batch: The output of your DataLoader.
-            batch_idx: Integer displaying index of this batch.
-            dataloader_idx: Index of the current dataloader.
-
-        Returns:
-            Output predicted probabilities.
-        """
-        x = batch['image']
-        y_hat = torch.sigmoid(self(x))
-        return y_hat
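The rewritten `configure_metrics` leans on the torchmetrics wrapper classes (available since torchmetrics 0.11): constructing `Accuracy`, `FBetaScore`, or `JaccardIndex` with a `task` argument returns the matching task-specific metric, which is what lets a single `kwargs` dict drive multiclass and multilabel alike. A minimal sketch of that dispatch:

```python
from torchmetrics.classification import Accuracy

# The wrapper's __new__ picks the concrete metric class based on `task`.
oa_multiclass = Accuracy(task='multiclass', num_classes=10, average='micro')
oa_multilabel = Accuracy(task='multilabel', num_labels=19, average='micro')
print(type(oa_multiclass).__name__)  # MulticlassAccuracy
print(type(oa_multilabel).__name__)  # MultilabelAccuracy
```

Arguments irrelevant to the chosen task, such as a `None`-valued `num_labels` under `task='multiclass'`, are ignored by the wrapper, so the shared dict above is safe.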
- """ - metrics = MetricCollection( - { - 'OverallAccuracy': MultilabelAccuracy( - num_labels=self.hparams['num_classes'], average='micro' - ), - 'AverageAccuracy': MultilabelAccuracy( - num_labels=self.hparams['num_classes'], average='macro' - ), - 'F1Score': MultilabelFBetaScore( - num_labels=self.hparams['num_classes'], beta=1.0, average='micro' - ), - } - ) - self.train_metrics = metrics.clone(prefix='train_') - self.val_metrics = metrics.clone(prefix='val_') - self.test_metrics = metrics.clone(prefix='test_') - - def training_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 - ) -> Tensor: - """Compute the training loss and additional metrics. - - Args: - batch: The output of your DataLoader. - batch_idx: Integer displaying index of this batch. - dataloader_idx: Index of the current dataloader. - - Returns: - The loss tensor. - """ - x = batch['image'] - y = batch['label'] - batch_size = x.shape[0] - y_hat = self(x) - y_hat_hard = torch.sigmoid(y_hat) - loss: Tensor = self.criterion(y_hat, y.to(torch.float)) - self.log('train_loss', loss, batch_size=batch_size) - self.train_metrics(y_hat_hard, y) - self.log_dict(self.train_metrics) - - return loss - - def validation_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 - ) -> None: - """Compute the validation loss and additional metrics. - - Args: - batch: The output of your DataLoader. - batch_idx: Integer displaying index of this batch. - dataloader_idx: Index of the current dataloader. - """ - x = batch['image'] - y = batch['label'] - batch_size = x.shape[0] - y_hat = self(x) - y_hat_hard = torch.sigmoid(y_hat) - loss = self.criterion(y_hat, y.to(torch.float)) - self.log('val_loss', loss, batch_size=batch_size) - self.val_metrics(y_hat_hard, y) - self.log_dict(self.val_metrics, batch_size=batch_size) - - if ( - batch_idx < 10 - and hasattr(self.trainer, 'datamodule') - and hasattr(self.trainer.datamodule, 'plot') - and self.logger - and hasattr(self.logger, 'experiment') - and hasattr(self.logger.experiment, 'add_figure') - ): - datamodule = self.trainer.datamodule - batch['prediction'] = y_hat_hard - for key in ['image', 'label', 'prediction']: - batch[key] = batch[key].cpu() - sample = unbind_samples(batch)[0] - - fig: Figure | None = None - try: - fig = datamodule.plot(sample) - except RGBBandsMissingError: - pass - - if fig: - summary_writer = self.logger.experiment - summary_writer.add_figure( - f'image/{batch_idx}', fig, global_step=self.global_step - ) - - def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: - """Compute the test loss and additional metrics. - - Args: - batch: The output of your DataLoader. - batch_idx: Integer displaying index of this batch. - dataloader_idx: Index of the current dataloader. - """ - x = batch['image'] - y = batch['label'] - batch_size = x.shape[0] - y_hat = self(x) - y_hat_hard = torch.sigmoid(y_hat) - loss = self.criterion(y_hat, y.to(torch.float)) - self.log('test_loss', loss, batch_size=batch_size) - self.test_metrics(y_hat_hard, y) - self.log_dict(self.test_metrics, batch_size=batch_size) - - def predict_step( - self, batch: Any, batch_idx: int, dataloader_idx: int = 0 - ) -> Tensor: - """Compute the predicted class probabilities. - - Args: - batch: The output of your DataLoader. - batch_idx: Integer displaying index of this batch. - dataloader_idx: Index of the current dataloader. - - Returns: - Output predicted probabilities. 
- """ - x = batch['image'] - y_hat = torch.sigmoid(self(x)) - return y_hat diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index afd71521002..eddd2e57ce4 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -11,8 +11,7 @@ import torch.nn as nn from matplotlib.figure import Figure from torch import Tensor -from torchmetrics import MetricCollection -from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex +from torchmetrics import Accuracy, JaccardIndex, MetricCollection from torchvision.models._api import WeightsEnum from ..datasets import RGBBandsMissingError, unbind_samples @@ -30,7 +29,9 @@ def __init__( backbone: str = 'resnet50', weights: WeightsEnum | str | bool | None = None, in_channels: int = 3, - num_classes: int = 1000, + task: str = 'multiclass', + num_classes: int | None = None, + num_labels: int | None = None, num_filters: int = 3, loss: str = 'ce', class_weights: Tensor | None = None, @@ -54,7 +55,9 @@ def __init__( model does not support pretrained weights. Pretrained ViT weight enums are not supported yet. in_channels: Number of input channels to model. - num_classes: Number of prediction classes (including the background). + task: One of 'binary', 'multiclass', or 'multilabel'. + num_classes: Number of prediction classes (only for ``task='multiclass'``). + num_labels: Number of prediction labels (only for ``task='multilabel'``). num_filters: Number of filters. Only applicable when model='fcn'. loss: Name of the loss function, currently supports 'ce', 'jaccard' or 'focal' loss. @@ -69,23 +72,26 @@ class and used with 'ce' loss. freeze_decoder: Freeze the decoder network to linear probe the segmentation head. - .. versionchanged:: 0.3 - *ignore_zeros* was renamed to *ignore_index*. + .. versionadded:: 0.6 + The *task* and *num_labels* parameters. - .. versionchanged:: 0.4 - *segmentation_model*, *encoder_name*, and *encoder_weights* - were renamed to *model*, *backbone*, and *weights*. + .. versionchanged:: 0.6 + The *ignore_index* parameter now works for jaccard loss. .. versionadded:: 0.5 - The *class_weights*, *freeze_backbone*, and *freeze_decoder* parameters. + The *class_weights*, *freeze_backbone*, and *freeze_decoder* parameters. .. versionchanged:: 0.5 The *weights* parameter now supports WeightEnums and checkpoint paths. *learning_rate* and *learning_rate_schedule_patience* were renamed to *lr* and *patience*. - .. versionchanged:: 0.6 - The *ignore_index* parameter now works for jaccard loss. + .. versionchanged:: 0.4 + *segmentation_model*, *encoder_name*, and *encoder_weights* + were renamed to *model*, *backbone*, and *weights*. + + .. versionchanged:: 0.3 + *ignore_zeros* was renamed to *ignore_index*. """ self.weights = weights super().__init__(ignore='weights') @@ -100,7 +106,9 @@ def configure_models(self) -> None: backbone: str = self.hparams['backbone'] weights = self.weights in_channels: int = self.hparams['in_channels'] - num_classes: int = self.hparams['num_classes'] + num_classes: int = ( + self.hparams['num_classes'] or self.hparams['num_labels'] or 2 + ) num_filters: int = self.hparams['num_filters'] if model == 'unet': @@ -181,10 +189,10 @@ def configure_losses(self) -> None: def configure_metrics(self) -> None: """Initialize the performance metrics. - * :class:`~torchmetrics.classification.MulticlassAccuracy`: Overall accuracy + * :class:`~torchmetrics.Accuracy`: Overall accuracy (OA) using 'micro' averaging. 
@@ -181,10 +189,10 @@ def configure_losses(self) -> None:
     def configure_metrics(self) -> None:
         """Initialize the performance metrics.

-        * :class:`~torchmetrics.classification.MulticlassAccuracy`: Overall accuracy
+        * :class:`~torchmetrics.Accuracy`: Overall accuracy
          (OA) using 'micro' averaging. The number of true positives divided by the
          dataset size. Higher values are better.
-        * :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection
+        * :class:`~torchmetrics.JaccardIndex`: Intersection
          over union (IoU). Uses 'micro' averaging. Higher values are better.

         .. note::
@@ -193,19 +201,16 @@ def configure_metrics(self) -> None:
            * 'Macro' averaging, not used here, gives equal weight to each class,
              useful for balanced performance assessment across imbalanced classes.
         """
-        num_classes: int = self.hparams['num_classes']
-        ignore_index: int | None = self.hparams['ignore_index']
+        kwargs = {
+            'task': self.hparams['task'],
+            'num_classes': self.hparams['num_classes'],
+            'num_labels': self.hparams['num_labels'],
+            'ignore_index': self.hparams['ignore_index'],
+        }
         metrics = MetricCollection(
             [
-                MulticlassAccuracy(
-                    num_classes=num_classes,
-                    ignore_index=ignore_index,
-                    multidim_average='global',
-                    average='micro',
-                ),
-                MulticlassJaccardIndex(
-                    num_classes=num_classes, ignore_index=ignore_index, average='micro'
-                ),
+                Accuracy(multidim_average='global', average='micro', **kwargs),
+                JaccardIndex(average='micro', **kwargs),
             ]
         )
         self.train_metrics = metrics.clone(prefix='train_')
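For completeness, the segmentation trainer migrates the same way; a minimal sketch of usage under the new API (values are illustrative):

```python
from torchgeo.trainers import SemanticSegmentationTask

# `task` defaults to 'multiclass'; num_classes is now optional and falls back
# to num_labels, then to a binary default of 2, per configure_models above.
task = SemanticSegmentationTask(
    model='unet',
    backbone='resnet50',
    in_channels=3,
    num_classes=10,
)
```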