Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added substation segmentation dataset #2352

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,11 @@ PASTIS

.. autoclass:: PASTIS

SubstationDataset
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should be in alphabetical order

^^^^^^^^^^^^^^^^^

.. autoclass:: SubstationDataset

PatternNet
^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/datasets/non_geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`SSL4EO`_-S12,T,Sentinel-1/2,"CC-BY-4.0",1M,-,264x264,10,"SAR, MSI"
`SSL4EO-L Benchmark`_,S,Landsat & CDL,"CC0-1.0",25K,134,264x264,30,MSI
`SSL4EO-L Benchmark`_,S,Landsat & NLCD,"CC0-1.0",25K,17,264x264,30,MSI
`SubstationDataset`_,S,"OpenStreetMap, Sentinel-2","CC-BY-SA-2.0","27,000+",2,228x228,10,MSI
`SustainBench Crop Yield`_,R,MODIS,"CC-BY-SA-4.0",11k,-,32x32,-,MSI
`Tropical Cyclone`_,R,GOES 8--16,"CC-BY-4.0","108,110",-,256x256,4K--8K,MSI
`UC Merced`_,C,USGS National Map,"public domain","2,100",21,256x256,0.3,RGB
Expand Down
62 changes: 62 additions & 0 deletions tests/data/substation_seg/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil

import numpy as np

SIZE = 228
NUM_SAMPLES = 5
np.random.seed(0)

FILENAME_HIERARCHY = dict[str, 'FILENAME_HIERARCHY'] | list[str]

filenames: FILENAME_HIERARCHY = {
'image_stack': ['image'],
'mask': ['mask'],
}

def create_file(path: str) -> None:
for i in range(NUM_SAMPLES):
new_path = f'{path}_{i}.npz'
fn = os.path.basename(new_path)
if fn.startswith('image'):
data = np.random.rand(4, SIZE, SIZE).astype(np.float32) # 4 channels (RGB + NIR)
elif fn.startswith('mask'):
data = np.random.randint(0, 4, size=(SIZE, SIZE)).astype(np.uint8) # Mask with 4 classes
np.savez_compressed(new_path, arr_0=data)

def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:
if isinstance(hierarchy, dict):
# Recursive case
for key, value in hierarchy.items():
path = os.path.join(directory, key)
os.makedirs(path, exist_ok=True)
create_directory(path, value)
else:
# Base case
for value in hierarchy:
path = os.path.join(directory, value)
create_file(path)

if __name__ == '__main__':
create_directory('.', filenames)

# Create a zip archive of the generated dataset
filename_images = 'image_stack.tar.gz'
filename_masks = 'mask.tar.gz'
shutil.make_archive('image_stack', 'gztar', '.', 'image_stack')
shutil.make_archive('mask', 'gztar', '.', 'mask')

# Compute checksums
with open(filename_images, 'rb') as f:
md5_images = hashlib.md5(f.read()).hexdigest()
print(f'{filename_images}: {md5_images}')

with open(filename_masks, 'rb') as f:
md5_masks = hashlib.md5(f.read()).hexdigest()
print(f'{filename_masks}: {md5_masks}')
Binary file added tests/data/substation_seg/image_stack.tar.gz
Binary file not shown.
Binary file added tests/data/substation_seg/mask.tar.gz
Binary file not shown.
142 changes: 142 additions & 0 deletions tests/datasets/test_substation_seg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import os
from pathlib import Path

import numpy as np
import pytest
import torch
from pytest import MonkeyPatch

from torchgeo.datasets import DatasetNotFoundError, SubstationDataset


class TestSubstationDataset:
@pytest.fixture(
params=[
{
'image_files': ['image_1.npz', 'image_2.npz'],
'geo_transforms': None,
'color_transforms': None,
'image_resize': None,
'mask_resize': None,
}
]
)
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path, request: pytest.FixtureRequest) -> SubstationDataset:
"""
Fixture to create a mock dataset with specified parameters.
"""
class Args:
pass

args = Args()
args.data_dir = tmp_path
args.in_channels = 4
args.use_timepoints = False
args.normalizing_type = 'zscore'
args.normalizing_factor = np.array([1.0])
args.means = np.array([0.5])
args.stds = np.array([0.1])
args.mask_2d = True
args.model_type = 'segmentation'

# Creating mock image and mask files
for filename in request.param['image_files']:
os.makedirs(os.path.join(tmp_path, 'image_stack'), exist_ok=True)
os.makedirs(os.path.join(tmp_path, 'mask'), exist_ok=True)
np.savez_compressed(os.path.join(tmp_path, 'image_stack', filename), arr_0=np.random.rand(4, 128, 128))
np.savez_compressed(os.path.join(tmp_path, 'mask', filename), arr_0=np.random.randint(0, 4, (128, 128)))

image_files = request.param['image_files']
geo_transforms = request.param['geo_transforms']
color_transforms = request.param['color_transforms']
image_resize = request.param['image_resize']
mask_resize = request.param['mask_resize']

return SubstationDataset(
args,
image_files=image_files,
geo_transforms=geo_transforms,
color_transforms=color_transforms,
image_resize=image_resize,
mask_resize=mask_resize,
)

def test_getitem(self, dataset: SubstationDataset) -> None:
image, mask = dataset[0]
assert isinstance(image, torch.Tensor)
assert isinstance(mask, torch.Tensor)
assert image.shape[0] == 4 # Checking number of channels
assert mask.shape == (1, 128, 128)

def test_len(self, dataset: SubstationDataset) -> None:
assert len(dataset) == 2

def test_already_downloaded(self, tmp_path: Path) -> None:
# Test to ensure dataset initialization doesn't download if data already exists
class Args:
pass

args = Args()
args.data_dir = tmp_path
args.in_channels = 4
args.use_timepoints = False
args.normalizing_type = 'zscore'
args.normalizing_factor = np.array([1.0])
args.means = np.array([0.5])
args.stds = np.array([0.1])
args.mask_2d = True
args.model_type = 'segmentation'

os.makedirs(os.path.join(tmp_path, 'image_stack'))
os.makedirs(os.path.join(tmp_path, 'mask'))

# No need to assign `dataset` variable, just assert
SubstationDataset(args, image_files=[])
assert os.path.exists(os.path.join(tmp_path, 'image_stack'))
assert os.path.exists(os.path.join(tmp_path, 'mask'))

def test_not_downloaded(self, tmp_path: Path) -> None:
class Args:
pass

args = Args()
args.data_dir = tmp_path
args.in_channels = 4
args.use_timepoints = False
args.normalizing_type = 'zscore'
args.normalizing_factor = np.array([1.0])
args.means = np.array([0.5])
args.stds = np.array([0.1])
args.mask_2d = True
args.model_type = 'segmentation'

with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SubstationDataset(args, image_files=[])

def test_plot(self, dataset: SubstationDataset) -> None:
dataset.plot()
# No assertion, just ensuring that the plotting does not throw any exceptions.

def test_corrupted(self, tmp_path: Path) -> None:
class Args:
pass

args = Args()
args.data_dir = tmp_path
args.in_channels = 4
args.use_timepoints = False
args.normalizing_type = 'zscore'
args.normalizing_factor = np.array([1.0])
args.means = np.array([0.5])
args.stds = np.array([0.1])
args.mask_2d = True
args.model_type = 'segmentation'

# Creating corrupted files
os.makedirs(os.path.join(tmp_path, 'image_stack'))
os.makedirs(os.path.join(tmp_path, 'mask'))
with open(os.path.join(tmp_path, 'image_stack', 'image_1.npz'), 'w') as f:
f.write('corrupted')

with pytest.raises(Exception):
SubstationDataset(args, image_files=['image_1.npz'])
3 changes: 3 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
)
from .ssl4eo import SSL4EO, SSL4EOL, SSL4EOS12
from .ssl4eo_benchmark import SSL4EOLBenchmark
from .substation_seg import SubstationDataset
from .sustainbench_crop_yield import SustainBenchCropYield
from .ucmerced import UCMerced
from .usavars import USAVars
Expand All @@ -152,6 +153,7 @@
'AsterGDEM',
'CanadianBuildingFootprints',
'CDL',
'ChaBuD',
'Chesapeake',
'Chesapeake7',
'Chesapeake13',
Expand Down Expand Up @@ -263,6 +265,7 @@
'SSL4EOLBenchmark',
'SSL4EOL',
'SSL4EOS12',
'SubstationDataset',
'SustainBenchCropYield',
'TropicalCyclone',
'UCMerced',
Expand Down
Loading
Loading