Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added substation segmentation dataset #2352

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,11 @@ PASTIS

.. autoclass:: PASTIS

SubstationDataset
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should be in alphabetical order

^^^^^^^^^^^^^^^^^

.. autoclass:: SubstationDataset

PatternNet
^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/datasets/non_geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`SSL4EO`_-S12,T,Sentinel-1/2,"CC-BY-4.0",1M,-,264x264,10,"SAR, MSI"
`SSL4EO-L Benchmark`_,S,Landsat & CDL,"CC0-1.0",25K,134,264x264,30,MSI
`SSL4EO-L Benchmark`_,S,Landsat & NLCD,"CC0-1.0",25K,17,264x264,30,MSI
`SubstationDataset`_,S,OpenStreetMap & Sentinel-2, "CC BY-SA 2.0", 27K, 2, 228x228, 10, MSI
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
`SubstationDataset`_,S,OpenStreetMap & Sentinel-2, "CC BY-SA 2.0", 27K, 2, 228x228, 10, MSI
`SubstationDataset`_,S,OpenStreetMap & Sentinel-2,"CC-BY-SA-2.0",27K,2,228x228,10,MSI

Should be a valid SPDX identifier

`SustainBench Crop Yield`_,R,MODIS,"CC-BY-SA-4.0",11k,-,32x32,-,MSI
`Tropical Cyclone`_,R,GOES 8--16,"CC-BY-4.0","108,110",-,256x256,4K--8K,MSI
`UC Merced`_,C,USGS National Map,"public domain","2,100",21,256x256,0.3,RGB
Expand Down
83 changes: 83 additions & 0 deletions tests/data/substation_seg/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil

import numpy as np

# Parameters
SIZE = 228 # Height and width, in pixels, of every generated image and mask
NUM_SAMPLES = 5 # Number of .npz files written for each kind of data
np.random.seed(0)  # Fixed seed so the generated random arrays are the same on every run

# Recursive description of the dataset layout: a dict maps a directory name to
# a nested hierarchy, while a list names the kinds of data to generate there.
FILENAME_HIERARCHY = dict[str, 'FILENAME_HIERARCHY'] | list[str]

# Two sibling directories, each filled with a single kind of data
filenames: FILENAME_HIERARCHY = {'image_stack': ['image'], 'mask': ['mask']}


def create_file(path: str, value: str) -> None:
    """Write NUM_SAMPLES compressed .npz files of random image or mask data.

    Files are named ``{path}_{i}.npz`` for ``i`` in ``range(NUM_SAMPLES)`` and
    store their array under the key ``arr_0``.

    Args:
        path: Path prefix of the output files, without index or extension.
        value: Kind of data to generate, either 'image' or 'mask'.

    Raises:
        ValueError: If *value* is neither 'image' nor 'mask'.
    """
    # Validate up front so an unknown kind fails with a clear error instead of
    # leaving ``data`` unbound when it is time to save.
    if value not in ('image', 'mask'):
        raise ValueError(f"value must be 'image' or 'mask', got {value!r}")

    for i in range(NUM_SAMPLES):
        new_path = f'{path}_{i}.npz'

        if value == 'image':
            # 4 timepoints x 13 channels x SIZE x SIZE, float32 in [0, 1)
            data = np.random.rand(4, 13, SIZE, SIZE).astype(np.float32)
        else:
            # SIZE x SIZE array of uint8 labels drawn from {0, 1, 2, 3}
            data = np.random.randint(0, 4, size=(SIZE, SIZE)).astype(np.uint8)

        np.savez_compressed(new_path, arr_0=data)


def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:
    """Build the directory tree described by *hierarchy* and fill it with data.

    Args:
        directory: Directory under which the hierarchy is created.
        hierarchy: Either a dict mapping sub-directory names to nested
            hierarchies, or a list of data kinds to generate in *directory*.
    """
    if not isinstance(hierarchy, dict):
        # Leaf: every kind of data is written with the same 'image' file
        # prefix, whatever the kind is.
        for kind in hierarchy:
            create_file(os.path.join(directory, 'image'), kind)
        return

    # Branch: create each sub-directory, then descend into it.
    for name, children in hierarchy.items():
        subdir = os.path.join(directory, name)
        os.makedirs(subdir, exist_ok=True)
        create_directory(subdir, children)


if __name__ == '__main__':
    # Write the directory tree and its random .npz files under the current directory
    create_directory('.', filenames)

    # Pack each top-level folder into a gzip-compressed tarball beside it
    archive_roots = ('image_stack', 'mask')
    for root in archive_roots:
        shutil.make_archive(root, 'gztar', '.', root)

    # Print the MD5 checksum of each tarball
    for root in archive_roots:
        archive = f'{root}.tar.gz'
        with open(archive, 'rb') as f:
            digest = hashlib.md5(f.read()).hexdigest()
        print(f'{archive}: {digest}')
Binary file added tests/data/substation_seg/image_stack.tar.gz
Binary file not shown.
Binary file added tests/data/substation_seg/image_stack/image_0.npz
Binary file not shown.
Binary file added tests/data/substation_seg/image_stack/image_1.npz
Binary file not shown.
Binary file added tests/data/substation_seg/image_stack/image_2.npz
Binary file not shown.
Binary file added tests/data/substation_seg/image_stack/image_3.npz
Binary file not shown.
Binary file added tests/data/substation_seg/image_stack/image_4.npz
Binary file not shown.
Binary file added tests/data/substation_seg/mask.tar.gz
Binary file not shown.
Binary file added tests/data/substation_seg/mask/image_0.npz
Binary file not shown.
Binary file added tests/data/substation_seg/mask/image_1.npz
Binary file not shown.
Binary file added tests/data/substation_seg/mask/image_2.npz
Binary file not shown.
Binary file added tests/data/substation_seg/mask/image_3.npz
Binary file not shown.
Binary file added tests/data/substation_seg/mask/image_4.npz
Binary file not shown.
267 changes: 267 additions & 0 deletions tests/datasets/test_substation_seg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
import os
import shutil
from collections.abc import Generator
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock

import matplotlib.pyplot as plt
import numpy as np
import pytest
import torch
import torchvision.transforms as transforms

from torchgeo.datasets import SubstationDataset


class Args:
    """Default argument bundle handed to SubstationDataset in these tests."""

    def __init__(self) -> None:
        # One statistic per input channel (13 in total); each value is wrapped
        # in two singleton axes below, giving arrays of shape (13, 1, 1).
        band_means = (
            1431, 1233, 1209, 1192, 1448, 2238, 2609,
            2537, 2828, 884, 20, 2226, 1537,
        )  # fmt: skip
        band_stds = (
            157, 254, 290, 420, 363, 457, 575,
            606, 630, 156, 3, 554, 523,
        )  # fmt: skip

        self.data_dir: str = os.path.join(os.getcwd(), 'tests', 'data')
        self.in_channels: int = 13
        self.use_timepoints: bool = True
        self.normalizing_type: str = 'percentile'
        self.mask_2d: bool = True
        self.model_type: str = 'vanilla_unet'
        self.timepoint_aggregation: str = 'median'
        self.color_transforms: bool = False
        self.geo_transforms: bool = False
        self.normalizing_factor: Any = np.array([[0, 0.5, 1.0]], dtype=np.float32)
        self.means: Any = np.array([[[m]] for m in band_means], dtype=np.float32)
        self.stds: Any = np.array([[[s]] for s in band_stds], dtype=np.float32)


@pytest.fixture
def dataset(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> Generator[SubstationDataset, None, None]:
    """Yield a SubstationDataset built from default Args and two sample files."""
    substation = SubstationDataset(Args(), ['image_0.npz', 'image_1.npz'])
    yield substation


@pytest.mark.parametrize(
    'config',
    [
        {
            'normalizing_type': 'percentile',
            'in_channels': 3,
            'use_timepoints': False,
            'mask_2d': True,
        },
        {
            'normalizing_type': 'zscore',
            'in_channels': 9,
            'model_type': 'swin',
            'use_timepoints': True,
            'timepoint_aggregation': 'concat',
            'mask_2d': False,
        },
        {
            'normalizing_type': None,
            'in_channels': 12,
            'use_timepoints': True,
            'timepoint_aggregation': 'median',
            'mask_2d': True,
            'normalizing_factor': 1.0,
        },
        {
            'normalizing_type': None,
            'in_channels': 5,
            'use_timepoints': True,
            'timepoint_aggregation': 'first',
            'mask_2d': False,
            'normalizing_factor': 1.0,
        },
        {
            'normalizing_type': None,
            'in_channels': 4,
            'use_timepoints': True,
            'timepoint_aggregation': 'random',
            'mask_2d': True,
            'normalizing_factor': 1.0,
        },
        {
            'normalizing_type': 'zscore',
            'in_channels': 2,
            'use_timepoints': False,
            'mask_2d': False,
            'color_transforms': True,
            'geo_transforms': True,
        },
        {
            'normalizing_type': None,
            'in_channels': 5,
            'use_timepoints': False,
            'timepoint_aggregation': 'first',
            'mask_2d': False,
            'normalizing_factor': 1.0,
        },
        {
            'normalizing_type': None,
            'in_channels': 4,
            'use_timepoints': False,
            'timepoint_aggregation': 'random',
            'mask_2d': True,
            'normalizing_factor': 1.0,
        },
    ],
)
def test_getitem_semantic(config: dict[str, Any]) -> None:
    """Check that indexing yields a dict of tensors for each configuration."""
    # Start from the default arguments and apply this case's overrides
    args = Args()
    for option, override in config.items():
        setattr(args, option, override)

    # Resize to 228: bicubic interpolation for images, nearest for masks
    resizers = {
        'image_resize': transforms.Compose(
            [transforms.Resize(228, transforms.InterpolationMode.BICUBIC)]
        ),
        'mask_resize': transforms.Compose(
            [transforms.Resize(228, transforms.InterpolationMode.NEAREST)]
        ),
    }
    substation = SubstationDataset(args, ['image_0.npz', 'image_1.npz'], **resizers)

    sample = substation[0]
    assert isinstance(sample, dict), f'Expected dict, got {type(sample)}'
    for key in ('image', 'mask'):
        assert isinstance(sample[key], torch.Tensor), (
            f'Expected {key} to be a torch.Tensor'
        )


def test_len(dataset: SubstationDataset) -> None:
    """The dataset under test must report a length of 2."""
    n_samples = len(dataset)
    assert n_samples == 2


def test_output_shape(dataset: SubstationDataset) -> None:
    """Check the image and mask shapes of the first sample."""
    sample = dataset[0]
    expected_shapes = {
        'image': torch.Size([13, 228, 228]),
        'mask': torch.Size([2, 228, 228]),
    }
    for key, shape in expected_shapes.items():
        assert sample[key].shape == shape


def test_plot(dataset: SubstationDataset, monkeypatch: pytest.MonkeyPatch) -> None:
    """Smoke-test ``plot()``: it must run without raising on a mocked sample."""
    # Keep the test headless: replace plt.show with a no-op mock
    monkeypatch.setattr(plt, 'show', MagicMock())

    # Pin np.random.randint so any index drawn through it is 0
    monkeypatch.setattr(np.random, 'randint', lambda low, high: 0)

    # Stand-in sample: a 3-channel image and an integer mask with values in [0, 4)
    mock_image = torch.rand(3, 228, 228)
    mock_mask = torch.randint(0, 4, (228, 228))
    # NOTE(review): this patches the instance attribute only; Python resolves
    # ``dataset[idx]`` on the type, so the stub is used only if plot() calls
    # ``self.__getitem__(idx)`` explicitly -- confirm against plot().
    monkeypatch.setattr(
        dataset, '__getitem__', lambda idx: {'image': mock_image, 'mask': mock_mask}
    )

    try:
        dataset.plot()
    finally:
        # plt.show is mocked, so nothing else disposes of the figures; close
        # them here to avoid accumulating open figures across tests.
        plt.close('all')


def test_already_downloaded(
    dataset: SubstationDataset, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Call a mocked ``_download`` with both archives staged in tmp_path."""
    # Stage the two sample archives in the temporary directory
    archive_dir = os.path.join('tests', 'data', 'substation_seg')
    for archive in ('image_stack.tar.gz', 'mask.tar.gz'):
        shutil.copy(os.path.join(archive_dir, archive), tmp_path)

    # NOTE(review): _download is replaced by a mock before it is called and
    # nothing is asserted afterwards, so this only checks the call does not
    # raise -- confirm whether a stronger check was intended.
    monkeypatch.setattr(dataset, '_download', MagicMock())
    dataset._download()


def test_verify(dataset: SubstationDataset, monkeypatch: pytest.MonkeyPatch) -> None:
    """``_verify`` must call ``_download`` exactly once when no path exists."""
    # Make every path look missing
    monkeypatch.setattr(os.path, 'exists', lambda path: False)

    # Swap in a mock so nothing is actually downloaded
    download_spy = MagicMock()
    monkeypatch.setattr(dataset, '_download', download_spy)

    dataset._verify()

    download_spy.assert_called_once()


def test_download(
    dataset: SubstationDataset, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """``_download`` must call download_url and extract_archive twice each."""
    # Replace both module-level helpers with mocks so no network or disk I/O occurs
    helpers = {name: MagicMock() for name in ('download_url', 'extract_archive')}
    for name, helper_mock in helpers.items():
        monkeypatch.setattr(f'torchgeo.datasets.substation_seg.{name}', helper_mock)

    dataset._download()

    # Each helper must have been invoked exactly twice
    for helper_mock in helpers.values():
        helper_mock.assert_called()
        assert helper_mock.call_count == 2


if __name__ == '__main__':
    # Run pytest on this file when it is executed directly as a script
    pytest.main(args=[__file__])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not needed, the file is run by pytest, not the other way around.

3 changes: 3 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
)
from .ssl4eo import SSL4EO, SSL4EOL, SSL4EOS12
from .ssl4eo_benchmark import SSL4EOLBenchmark
from .substation_seg import SubstationDataset
from .sustainbench_crop_yield import SustainBenchCropYield
from .ucmerced import UCMerced
from .usavars import USAVars
Expand All @@ -152,6 +153,7 @@
'AsterGDEM',
'CanadianBuildingFootprints',
'CDL',
'ChaBuDx',
'Chesapeake',
'Chesapeake7',
'Chesapeake13',
Expand Down Expand Up @@ -263,6 +265,7 @@
'SSL4EOLBenchmark',
'SSL4EOL',
'SSL4EOS12',
'SubstationDataset',
'SustainBenchCropYield',
'TropicalCyclone',
'UCMerced',
Expand Down
Loading