Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trainers: support binary, multiclass, and multilabel tasks #2219

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions tests/conf/bigearthnet_all.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
model:
class_path: MultiLabelClassificationTask
class_path: ClassificationTask
init_args:
loss: "bce"
model: "resnet18"
in_channels: 14
num_classes: 19
task: "multilabel"
num_labels: 19
data:
class_path: BigEarthNetDataModule
init_args:
Expand Down
5 changes: 3 additions & 2 deletions tests/conf/bigearthnet_s1.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
model:
class_path: MultiLabelClassificationTask
class_path: ClassificationTask
init_args:
loss: "bce"
model: "resnet18"
in_channels: 2
num_classes: 19
task: "multilabel"
num_labels: 19
data:
class_path: BigEarthNetDataModule
init_args:
Expand Down
5 changes: 3 additions & 2 deletions tests/conf/bigearthnet_s2.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
model:
class_path: MultiLabelClassificationTask
class_path: ClassificationTask
init_args:
loss: "bce"
model: "resnet18"
in_channels: 12
num_classes: 19
task: "multilabel"
num_labels: 19
data:
class_path: BigEarthNetDataModule
init_args:
Expand Down
109 changes: 9 additions & 100 deletions tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,11 @@
from torch.nn.modules import Module
from torchvision.models._api import WeightsEnum

from torchgeo.datamodules import (
BigEarthNetDataModule,
EuroSATDataModule,
MisconfigurationException,
)
from torchgeo.datasets import BigEarthNet, EuroSAT, RGBBandsMissingError
from torchgeo.datamodules import EuroSATDataModule, MisconfigurationException
from torchgeo.datasets import EuroSAT, RGBBandsMissingError
from torchgeo.main import main
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask
from torchgeo.trainers import ClassificationTask


class ClassificationTestModel(Module):
Expand All @@ -47,11 +43,6 @@ def setup(self, stage: str) -> None:
self.predict_dataset = EuroSAT(split='test', **self.kwargs)


class PredictMultiLabelClassificationDataModule(BigEarthNetDataModule):
def setup(self, stage: str) -> None:
self.predict_dataset = BigEarthNet(split='test', **self.kwargs)


def create_model(*args: Any, **kwargs: Any) -> Module:
return ClassificationTestModel(**kwargs)

Expand All @@ -73,6 +64,9 @@ class TestClassificationTask:
@pytest.mark.parametrize(
'name',
[
'bigearthnet_all',
'bigearthnet_s1',
'bigearthnet_s2',
'eurosat',
'eurosat100',
'eurosatspatial',
Expand Down Expand Up @@ -231,95 +225,10 @@ def test_predict(self, fast_dev_run: bool) -> None:
'model_name', ['resnet18', 'efficientnetv2_s', 'vit_base_patch16_384']
)
def test_freeze_backbone(self, model_name: str) -> None:
model = ClassificationTask(model=model_name, freeze_backbone=True)
model = ClassificationTask(
model=model_name, num_classes=10, freeze_backbone=True
)
assert not all([param.requires_grad for param in model.model.parameters()])
assert all(
[param.requires_grad for param in model.model.get_classifier().parameters()]
)


class TestMultiLabelClassificationTask:
@pytest.mark.parametrize(
'name', ['bigearthnet_all', 'bigearthnet_s1', 'bigearthnet_s2']
)
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
config = os.path.join('tests', 'conf', name + '.yaml')

monkeypatch.setattr(timm, 'create_model', create_model)

args = [
'--config',
config,
'--trainer.accelerator',
'cpu',
'--trainer.fast_dev_run',
str(fast_dev_run),
'--trainer.max_epochs',
'1',
'--trainer.log_every_n_steps',
'1',
]

main(['fit'] + args)
try:
main(['test'] + args)
except MisconfigurationException:
pass
try:
main(['predict'] + args)
except MisconfigurationException:
pass

def test_invalid_loss(self) -> None:
match = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=match):
MultiLabelClassificationTask(model='resnet18', loss='invalid_loss')

def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(BigEarthNetDataModule, 'plot', plot)
datamodule = BigEarthNetDataModule(
root='tests/data/bigearthnet', batch_size=1, num_workers=0
)
model = MultiLabelClassificationTask(
model='resnet18', in_channels=14, num_classes=19, loss='bce'
)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(BigEarthNetDataModule, 'plot', plot_missing_bands)
datamodule = BigEarthNetDataModule(
root='tests/data/bigearthnet', batch_size=1, num_workers=0
)
model = MultiLabelClassificationTask(
model='resnet18', in_channels=14, num_classes=19, loss='bce'
)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, fast_dev_run: bool) -> None:
datamodule = PredictMultiLabelClassificationDataModule(
root='tests/data/bigearthnet', batch_size=1, num_workers=0
)
model = MultiLabelClassificationTask(
model='resnet18', in_channels=14, num_classes=19, loss='bce'
)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.predict(model=model, datamodule=datamodule)
10 changes: 7 additions & 3 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def mocked_weights(
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model(
weights.meta['model'], in_chans=weights.meta['in_chans']
weights.meta['model'], in_chans=weights.meta['in_chans'], num_classes=10
)
torch.save(model.state_dict(), path)
try:
Expand All @@ -154,13 +154,15 @@ def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
backbone=mocked_weights.meta['model'],
weights=mocked_weights,
in_channels=mocked_weights.meta['in_chans'],
num_classes=10,
)

def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
SemanticSegmentationTask(
backbone=mocked_weights.meta['model'],
weights=str(mocked_weights),
in_channels=mocked_weights.meta['in_chans'],
num_classes=10,
)

@pytest.mark.slow
Expand Down Expand Up @@ -227,7 +229,7 @@ def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
)
def test_freeze_backbone(self, model_name: str, backbone: str) -> None:
model = SemanticSegmentationTask(
model=model_name, backbone=backbone, freeze_backbone=True
model=model_name, backbone=backbone, num_classes=10, freeze_backbone=True
)
assert all(
[param.requires_grad is False for param in model.model.encoder.parameters()]
Expand All @@ -242,7 +244,9 @@ def test_freeze_backbone(self, model_name: str, backbone: str) -> None:

@pytest.mark.parametrize('model_name', ['unet', 'deeplabv3+'])
def test_freeze_decoder(self, model_name: str) -> None:
model = SemanticSegmentationTask(model=model_name, freeze_decoder=True)
model = SemanticSegmentationTask(
model=model_name, num_classes=10, freeze_decoder=True
)
assert all(
[param.requires_grad is False for param in model.model.decoder.parameters()]
)
Expand Down
3 changes: 1 addition & 2 deletions torchgeo/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .base import BaseTask
from .byol import BYOLTask
from .classification import ClassificationTask, MultiLabelClassificationTask
from .classification import ClassificationTask
from .detection import ObjectDetectionTask
from .iobench import IOBenchTask
from .moco import MoCoTask
Expand All @@ -16,7 +16,6 @@
__all__ = (
# Supervised
'ClassificationTask',
'MultiLabelClassificationTask',
'ObjectDetectionTask',
'PixelwiseRegressionTask',
'RegressionTask',
Expand Down
Loading
Loading