Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trainers: support binary, multiclass, and multilabel tasks #2219

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/conf/bigearthnet_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ model:
loss: "bce"
model: "resnet18"
in_channels: 14
num_classes: 19
task: "multilabel"
num_labels: 19
data:
class_path: BigEarthNetDataModule
init_args:
Expand Down
3 changes: 2 additions & 1 deletion tests/conf/bigearthnet_s1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ model:
loss: "bce"
model: "resnet18"
in_channels: 2
num_classes: 19
task: "multilabel"
num_labels: 19
data:
class_path: BigEarthNetDataModule
init_args:
Expand Down
3 changes: 2 additions & 1 deletion tests/conf/bigearthnet_s2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ model:
loss: "bce"
model: "resnet18"
in_channels: 12
num_classes: 19
task: "multilabel"
num_labels: 19
data:
class_path: BigEarthNetDataModule
init_args:
Expand Down
105 changes: 6 additions & 99 deletions tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,11 @@
from torch.nn.modules import Module
from torchvision.models._api import WeightsEnum

from torchgeo.datamodules import (
BigEarthNetDataModule,
EuroSATDataModule,
MisconfigurationException,
)
from torchgeo.datasets import BigEarthNet, EuroSAT, RGBBandsMissingError
from torchgeo.datamodules import EuroSATDataModule, MisconfigurationException
from torchgeo.datasets import EuroSAT, RGBBandsMissingError
from torchgeo.main import main
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask
from torchgeo.trainers import ClassificationTask


class ClassificationTestModel(Module):
Expand All @@ -47,11 +43,6 @@ def setup(self, stage: str) -> None:
self.predict_dataset = EuroSAT(split='test', **self.kwargs)


class PredictMultiLabelClassificationDataModule(BigEarthNetDataModule):
def setup(self, stage: str) -> None:
self.predict_dataset = BigEarthNet(split='test', **self.kwargs)


def create_model(*args: Any, **kwargs: Any) -> Module:
return ClassificationTestModel(**kwargs)

Expand All @@ -73,6 +64,9 @@ class TestClassificationTask:
@pytest.mark.parametrize(
'name',
[
'bigearthnet_all',
'bigearthnet_s1',
'bigearthnet_s2',
'eurosat',
'eurosat100',
'eurosatspatial',
Expand Down Expand Up @@ -236,90 +230,3 @@ def test_freeze_backbone(self, model_name: str) -> None:
assert all(
[param.requires_grad for param in model.model.get_classifier().parameters()]
)


class TestMultiLabelClassificationTask:
@pytest.mark.parametrize(
'name', ['bigearthnet_all', 'bigearthnet_s1', 'bigearthnet_s2']
)
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
config = os.path.join('tests', 'conf', name + '.yaml')

monkeypatch.setattr(timm, 'create_model', create_model)

args = [
'--config',
config,
'--trainer.accelerator',
'cpu',
'--trainer.fast_dev_run',
str(fast_dev_run),
'--trainer.max_epochs',
'1',
'--trainer.log_every_n_steps',
'1',
]

main(['fit'] + args)
try:
main(['test'] + args)
except MisconfigurationException:
pass
try:
main(['predict'] + args)
except MisconfigurationException:
pass

def test_invalid_loss(self) -> None:
match = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=match):
MultiLabelClassificationTask(model='resnet18', loss='invalid_loss')

def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(BigEarthNetDataModule, 'plot', plot)
datamodule = BigEarthNetDataModule(
root='tests/data/bigearthnet', batch_size=1, num_workers=0
)
model = MultiLabelClassificationTask(
model='resnet18', in_channels=14, num_classes=19, loss='bce'
)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(BigEarthNetDataModule, 'plot', plot_missing_bands)
datamodule = BigEarthNetDataModule(
root='tests/data/bigearthnet', batch_size=1, num_workers=0
)
model = MultiLabelClassificationTask(
model='resnet18', in_channels=14, num_classes=19, loss='bce'
)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, fast_dev_run: bool) -> None:
datamodule = PredictMultiLabelClassificationDataModule(
root='tests/data/bigearthnet', batch_size=1, num_workers=0
)
model = MultiLabelClassificationTask(
model='resnet18', in_channels=14, num_classes=19, loss='bce'
)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.predict(model=model, datamodule=datamodule)
3 changes: 1 addition & 2 deletions torchgeo/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .base import BaseTask
from .byol import BYOLTask
from .classification import ClassificationTask, MultiLabelClassificationTask
from .classification import ClassificationTask
from .detection import ObjectDetectionTask
from .iobench import IOBenchTask
from .moco import MoCoTask
Expand All @@ -16,7 +16,6 @@
__all__ = (
# Supervised
'ClassificationTask',
'MultiLabelClassificationTask',
'ObjectDetectionTask',
'PixelwiseRegressionTask',
'RegressionTask',
Expand Down
Loading
Loading