Skip to content

Commit

Permalink
Add South America Soybean DataModule (microsoft#1959)
Browse files Browse the repository at this point in the history
* Add South America Soybean DataModule

* Add train_aug

* Regenerate data

---------

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
cookie-kyu and adamjstewart authored Mar 25, 2024
1 parent 2849944 commit 5d253c5
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Sentinel

.. autoclass:: Sentinel2CDLDataModule
.. autoclass:: Sentinel2NCCMDataModule
.. autoclass:: Sentinel2SouthAmericaSoybeanDataModule

Non-geospatial DataModules
--------------------------
Expand Down
17 changes: 17 additions & 0 deletions tests/conf/sentinel2_south_america_soybean.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: "ce"
model: "deeplabv3+"
backbone: "resnet18"
in_channels: 13
num_classes: 2
num_filters: 1
data:
class_path: Sentinel2SouthAmericaSoybeanDataModule
init_args:
batch_size: 2
patch_size: 16
dict_kwargs:
south_america_soybean_paths: "tests/data/south_america_soybean"
sentinel2_paths: "tests/data/sentinel2"
Binary file modified tests/data/south_america_soybean/SouthAmericaSoybean.zip
Binary file not shown.
Binary file not shown.
Binary file not shown.
13 changes: 3 additions & 10 deletions tests/data/south_america_soybean/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rasterio.crs import CRS
from rasterio.transform import Affine

SIZE = 32
SIZE = 128


np.random.seed(0)
Expand All @@ -24,15 +24,8 @@ def create_file(path: str, dtype: str):
"driver": "GTiff",
"dtype": dtype,
"count": 1,
"crs": CRS.from_epsg(4326),
"transform": Affine(
0.0002499999999999943131,
0.0,
-82.0005000000000024,
0.0,
-0.0002499999999999943131,
0.0005000000000000,
),
"crs": CRS.from_epsg(32616),
"transform": Affine(10, 0.0, 399960.0, 0.0, -10, 4500000.0),
"height": SIZE,
"width": SIZE,
"compress": "lzw",
Expand Down
1 change: 1 addition & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class TestSemanticSegmentationTask:
"sen12ms_s2_reduced",
"sentinel2_cdl",
"sentinel2_nccm",
"sentinel2_south_america_soybean",
"spacenet1",
"ssl4eo_l_benchmark_cdl",
"ssl4eo_l_benchmark_nlcd",
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .sen12ms import SEN12MSDataModule
from .sentinel2_cdl import Sentinel2CDLDataModule
from .sentinel2_nccm import Sentinel2NCCMDataModule
from .sentinel2_south_america_soybean import Sentinel2SouthAmericaSoybeanDataModule
from .skippd import SKIPPDDataModule
from .so2sat import So2SatDataModule
from .spacenet import SpaceNet1DataModule
Expand All @@ -53,6 +54,7 @@
"NAIPChesapeakeDataModule",
"Sentinel2CDLDataModule",
"Sentinel2NCCMDataModule",
"Sentinel2SouthAmericaSoybeanDataModule",
# NonGeoDataset
"BigEarthNetDataModule",
"ChaBuDDataModule",
Expand Down
123 changes: 123 additions & 0 deletions torchgeo/datamodules/sentinel2_south_america_soybean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.


"""South America Soybean datamodule."""


from typing import Any, Optional, Union

import kornia.augmentation as K
import torch
from kornia.constants import DataKey, Resample
from matplotlib.figure import Figure

from ..datasets import Sentinel2, SouthAmericaSoybean, random_grid_cell_assignment
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from .geo import GeoDataModule


class Sentinel2SouthAmericaSoybeanDataModule(GeoDataModule):
"""LightningDataModule for SouthAmericaSoybean and Sentinel2 datasets.
.. versionadded:: 0.6
"""

def __init__(
self,
batch_size: int = 64,
patch_size: Union[int, tuple[int, int]] = 64,
length: Optional[int] = None,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a new Sentinel2SouthAmericaSoybeanDataModule instance.
Args:
batch_size: Size of each mini-batch.
patch_size: Size of each patch, either ``size`` or ``(height, width)``.
length: Length of each training epoch.
num_workers: Number of workers for parallel data loading.
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.SouthAmericaSoybean`
(prefix keys with ``south_america_soybean_``) and
:class:`~torchgeo.datasets.Sentinel2`
(prefix keys with ``sentinel2_``).
"""
self.south_america_soybean_kwargs = {}
self.sentinel2_kwargs = {}
for key, val in kwargs.items():
if key.startswith("south_america_soybean_"):
self.south_america_soybean_kwargs[key[22:]] = val
elif key.startswith("sentinel2_"):
self.sentinel2_kwargs[key[10:]] = val

super().__init__(
SouthAmericaSoybean,
batch_size=batch_size,
patch_size=patch_size,
length=length,
num_workers=num_workers,
**kwargs,
)

self.train_aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)),
K.RandomVerticalFlip(p=0.5),
K.RandomHorizontalFlip(p=0.5),
data_keys=["image", "mask"],
extra_args={
DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None}
},
)

self.aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"]
)

def setup(self, stage: str) -> None:
"""Set up datasets and samplers.
Args:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
self.sentinel2 = Sentinel2(**self.sentinel2_kwargs)
self.south_america_soybean = SouthAmericaSoybean(
**self.south_america_soybean_kwargs
)
self.dataset = self.sentinel2 & self.south_america_soybean

generator = torch.Generator().manual_seed(1)
(self.train_dataset, self.val_dataset, self.test_dataset) = (
random_grid_cell_assignment(
self.dataset, [0.8, 0.1, 0.1], grid_size=8, generator=generator
)
)

if stage in ["fit"]:
self.train_batch_sampler = RandomBatchGeoSampler(
self.train_dataset, self.patch_size, self.batch_size, self.length
)
if stage in ["fit", "validate"]:
self.val_sampler = GridGeoSampler(
self.val_dataset, self.patch_size, self.patch_size
)
if stage in ["test"]:
self.test_sampler = GridGeoSampler(
self.test_dataset, self.patch_size, self.patch_size
)

def plot(self, *args: Any, **kwargs: Any) -> Figure:
"""Run SouthAmericaSoybean plot method.
Args:
*args: Arguments passed to plot method.
**kwargs: Keyword arguments passed to plot method.
Returns:
A matplotlib Figure with the image, ground truth, and predictions.
"""
return self.south_america_soybean.plot(*args, **kwargs)

0 comments on commit 5d253c5

Please sign in to comment.