Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Aug 24, 2024
1 parent 0df2fc9 commit 33e7070
Show file tree
Hide file tree
Showing 12 changed files with 93 additions and 44 deletions.
5 changes: 3 additions & 2 deletions tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
NonGeoClassificationDataset,
NonGeoDataset,
RasterDataset,
Sample,
Sentinel2,
UnionDataset,
VectorDataset,
Expand All @@ -46,7 +47,7 @@ def __init__(
self.res = res
self.paths = paths or []

def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
def __getitem__(self, query: BoundingBox) -> Sample:
hits = self.index.intersection(tuple(query), objects=True)
hit = next(iter(hits))
bounds = BoundingBox(*hit.bounds)
Expand Down Expand Up @@ -77,7 +78,7 @@ class CustomSentinelDataset(Sentinel2):


class CustomNonGeoDataset(NonGeoDataset):
def __getitem__(self, index: int) -> dict[str, int]:
def __getitem__(self, index: int) -> Sample:
return {'index': index}

def __len__(self) -> int:
Expand Down
4 changes: 2 additions & 2 deletions tests/datasets/test_splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

from collections.abc import Sequence
from math import floor, isclose
from typing import Any

import pytest
from rasterio.crs import CRS

from torchgeo.datasets import (
BoundingBox,
GeoDataset,
Sample,
random_bbox_assignment,
random_bbox_splitting,
random_grid_cell_assignment,
Expand Down Expand Up @@ -49,7 +49,7 @@ def __init__(
self._crs = crs
self.res = res

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
def __getitem__(self, query: BoundingBox) -> Sample:
hits = self.index.intersection(tuple(query), objects=True)
hit = next(iter(hits))
return {'content': hit.object}
Expand Down
18 changes: 9 additions & 9 deletions tests/datasets/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch
from rasterio.crs import CRS

from torchgeo.datasets import BoundingBox, DependencyNotFoundError
from torchgeo.datasets import BoundingBox, DependencyNotFoundError, Sample
from torchgeo.datasets.utils import (
Executable,
array_to_tensor,
Expand Down Expand Up @@ -381,13 +381,13 @@ def test_disambiguate_timestamp(

class TestCollateFunctionsMatchingKeys:
@pytest.fixture(scope='class')
def samples(self) -> list[dict[str, Any]]:
def samples(self) -> list[Sample]:
return [
{'image': torch.tensor([1, 2, 0]), 'crs': CRS.from_epsg(2000)},
{'image': torch.tensor([0, 0, 3]), 'crs': CRS.from_epsg(2001)},
]

def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None:
def test_stack_unbind_samples(self, samples: list[Sample]) -> None:
sample = stack_samples(samples)
assert sample['image'].size() == torch.Size([2, 3])
assert torch.allclose(sample['image'], torch.tensor([[1, 2, 0], [0, 0, 3]]))
Expand All @@ -398,13 +398,13 @@ def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None:
assert torch.allclose(samples[i]['image'], new_samples[i]['image'])
assert samples[i]['crs'] == new_samples[i]['crs']

def test_concat_samples(self, samples: list[dict[str, Any]]) -> None:
def test_concat_samples(self, samples: list[Sample]) -> None:
sample = concat_samples(samples)
assert sample['image'].size() == torch.Size([6])
assert torch.allclose(sample['image'], torch.tensor([1, 2, 0, 0, 0, 3]))
assert sample['crs'] == CRS.from_epsg(2000)

def test_merge_samples(self, samples: list[dict[str, Any]]) -> None:
def test_merge_samples(self, samples: list[Sample]) -> None:
sample = merge_samples(samples)
assert sample['image'].size() == torch.Size([3])
assert torch.allclose(sample['image'], torch.tensor([1, 2, 3]))
Expand All @@ -413,13 +413,13 @@ def test_merge_samples(self, samples: list[dict[str, Any]]) -> None:

class TestCollateFunctionsDifferingKeys:
@pytest.fixture(scope='class')
def samples(self) -> list[dict[str, Any]]:
def samples(self) -> list[Sample]:
return [
{'image': torch.tensor([1, 2, 0]), 'crs1': CRS.from_epsg(2000)},
{'mask': torch.tensor([0, 0, 3]), 'crs2': CRS.from_epsg(2001)},
]

def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None:
def test_stack_unbind_samples(self, samples: list[Sample]) -> None:
sample = stack_samples(samples)
assert sample['image'].size() == torch.Size([1, 3])
assert sample['mask'].size() == torch.Size([1, 3])
Expand All @@ -434,7 +434,7 @@ def test_stack_unbind_samples(self, samples: list[dict[str, Any]]) -> None:
assert torch.allclose(samples[1]['mask'], new_samples[0]['mask'])
assert samples[1]['crs2'] == new_samples[0]['crs2']

def test_concat_samples(self, samples: list[dict[str, Any]]) -> None:
def test_concat_samples(self, samples: list[Sample]) -> None:
sample = concat_samples(samples)
assert sample['image'].size() == torch.Size([3])
assert sample['mask'].size() == torch.Size([3])
Expand All @@ -443,7 +443,7 @@ def test_concat_samples(self, samples: list[dict[str, Any]]) -> None:
assert sample['crs1'] == CRS.from_epsg(2000)
assert sample['crs2'] == CRS.from_epsg(2001)

def test_merge_samples(self, samples: list[dict[str, Any]]) -> None:
def test_merge_samples(self, samples: list[Sample]) -> None:
sample = merge_samples(samples)
assert sample['image'].size() == torch.Size([3])
assert sample['mask'].size() == torch.Size([3])
Expand Down
4 changes: 2 additions & 2 deletions tests/samplers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from rasterio.crs import CRS
from torch.utils.data import DataLoader

from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
from torchgeo.datasets import BoundingBox, GeoDataset, Sample, stack_samples
from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler, Units


Expand All @@ -32,7 +32,7 @@ def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
self._crs = crs
self.res = res

def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
def __getitem__(self, query: BoundingBox) -> Sample:
return {'index': query}


Expand Down
4 changes: 2 additions & 2 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from rasterio.crs import CRS
from torch.utils.data import DataLoader

from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
from torchgeo.datasets import BoundingBox, GeoDataset, Sample, stack_samples
from torchgeo.samplers import (
GeoSampler,
GridGeoSampler,
Expand Down Expand Up @@ -39,7 +39,7 @@ def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
self._crs = crs
self.res = res

def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
def __getitem__(self, query: BoundingBox) -> Sample:
return {'index': query}


Expand Down
26 changes: 15 additions & 11 deletions torchgeo/datasets/enviroatlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import sys
from collections.abc import Callable, Sequence
from typing import ClassVar, cast
from typing import Any, ClassVar, cast

import fiona
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -347,8 +347,8 @@ def __getitem__(self, query: BoundingBox) -> Sample:
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = cast(list[dict[str, str]], [hit.object for hit in hits])

sample: Sample = {'image': [], 'mask': [], 'crs': self.crs, 'bounds': query}
images: list[np.typing.NDArray[Any]] = []
masks: list[np.typing.NDArray[Any]] = []

if len(filepaths) == 0:
raise IndexError(
Expand Down Expand Up @@ -389,23 +389,27 @@ def __getitem__(self, query: BoundingBox) -> Sample:
'waterbodies',
'water',
]:
sample['image'].append(data)
images.append(data)
elif layer in ['prior', 'prior_no_osm_no_buildings']:
if self.prior_as_input:
sample['image'].append(data)
images.append(data)
else:
sample['mask'].append(data)
masks.append(data)
elif layer in ['lc']:
data = self.raw_enviroatlas_to_idx_map[data]
sample['mask'].append(data)
masks.append(data)
else:
raise IndexError(f'query: {query} spans multiple tiles which is not valid')

sample['image'] = np.concatenate(sample['image'], axis=0)
sample['mask'] = np.concatenate(sample['mask'], axis=0)
image = torch.from_numpy(np.concatenate(images, axis=0))
mask = torch.from_numpy(np.concatenate(masks, axis=0))

sample['image'] = torch.from_numpy(sample['image'])
sample['mask'] = torch.from_numpy(sample['mask'])
sample: Sample = {
'image': image,
'mask': mask,
'crs': self.crs,
'bounds': query,
}

if self.transforms is not None:
sample = self.transforms(sample)
Expand Down
6 changes: 2 additions & 4 deletions torchgeo/datasets/eurocrops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any

import fiona
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import VectorDataset
Expand Down Expand Up @@ -247,9 +247,7 @@ def plot(

fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4, 4))

def apply_cmap(
arr: 'np.typing.NDArray[Any]',
) -> 'np.typing.NDArray[np.float64]':
def apply_cmap(arr: Tensor) -> 'np.typing.NDArray[np.float64]':
# Color 0 as black, while applying default color map for the class indices.
cmap = plt.get_cmap('viridis')
im: np.typing.NDArray[np.float64] = cmap(arr / len(self.class_map))
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/fair1m.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def __getitem__(self, index: int) -> Sample:
label_path = label_path.replace('.tif', '.xml')
voc = parse_pascal_voc(label_path)
boxes, labels = self._load_target(voc['points'], voc['labels'])
sample: Sample = {'image': image, 'boxes': boxes, 'label': labels}
sample = {'image': image, 'boxes': boxes, 'label': labels}

if self.transforms is not None:
sample = self.transforms(sample)
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/gid15.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def __getitem__(self, index: int) -> Sample:
mask = self._load_target(files['mask'])
sample: Sample = {'image': image, 'mask': mask}
else:
sample: Sample = {'image': image}
sample = {'image': image}

if self.transforms is not None:
sample = self.transforms(sample)
Expand Down
8 changes: 4 additions & 4 deletions torchgeo/datasets/skippd.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def __len__(self) -> int:

return num_datapoints

def __getitem__(self, index: int) -> dict[str, str | Tensor]:
def __getitem__(self, index: int) -> Sample:
"""Return an index within the dataset.
Args:
Expand All @@ -143,7 +143,7 @@ def __getitem__(self, index: int) -> dict[str, str | Tensor]:
Returns:
data and label at that index
"""
sample: dict[str, str | Tensor] = {'image': self._load_image(index)}
sample: Sample = {'image': self._load_image(index)}
sample.update(self._load_features(index))

if self.transforms is not None:
Expand Down Expand Up @@ -176,7 +176,7 @@ def _load_image(self, index: int) -> Tensor:
tensor = torch.from_numpy(arr).to(torch.float32)
return tensor

def _load_features(self, index: int) -> dict[str, str | Tensor]:
def _load_features(self, index: int) -> Sample:
"""Load label.
Args:
Expand All @@ -194,7 +194,7 @@ def _load_features(self, index: int) -> dict[str, str | Tensor]:
path = os.path.join(self.root, f'times_{self.split}_{self.task}.npy')
datestring = np.load(path, allow_pickle=True)[index].strftime(self.dateformat)

features: dict[str, str | Tensor] = {
features: Sample = {
'label': torch.tensor(label, dtype=torch.float32),
'date': datestring,
}
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/sustainbench_crop_yield.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(
self._verify()

self.images = []
self.features = []
self.features: list[Sample] = []

for country in self.countries:
image_file_path = os.path.join(
Expand All @@ -122,7 +122,7 @@ def __init__(
year = year_npz_file[idx]
ndvi = ndvi_npz_file[idx]

features = {
features: Sample = {
'label': torch.tensor(target).to(torch.float32),
'year': torch.tensor(int(year)),
'ndvi': torch.from_numpy(ndvi).to(dtype=torch.float32),
Expand Down
54 changes: 50 additions & 4 deletions torchgeo/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,54 @@ class Sample(TypedDict, total=False):
bounds: BoundingBox
crs: CRS

# TODO: remove
# TODO: Additional dataset-specific keys that should be subclasses
images: Tensor
input: Tensor
boxes: Tensor
bboxes: Tensor
masks: Tensor
labels: Tensor
prediction_masks: Tensor
prediction_boxes: Tensor
prediction_labels: Tensor
prediction_label: Tensor
prediction_scores: Tensor
audio: Tensor
points: Tensor
x: Tensor
y: Tensor
relative_time: Tensor
ocean: Tensor
array: Tensor
chm: Tensor
hsi: Tensor
las: Tensor
image1: Tensor
image2: Tensor
crs1: Tensor
crs2: Tensor
magnitude: Tensor
agb: Tensor
key: Tensor
patch: Tensor
geometry: Tensor
properties: Tensor
id: int
centroid_lat: Tensor
centroid_lon: Tensor
content: Tensor
year: Tensor
ndvi: Tensor
filename: str
category: str
field_ids: Tensor
tile_index: Tensor
transform: Tensor
src: Tensor
dst: Tensor
input_size: Tensor
output_size: Tensor
index: BoundingBox


class Batch(Sample):
Expand Down Expand Up @@ -456,7 +502,7 @@ def stack_samples(samples: Iterable[Sample]) -> Batch:
.. versionadded:: 0.2
"""
collated: dict[Any, Any] = _list_dict_to_dict_list(samples)
collated: Batch = _list_dict_to_dict_list(samples)
for key, value in collated.items():
if isinstance(value[0], Tensor):
collated[key] = torch.stack(value)
Expand All @@ -476,7 +522,7 @@ def concat_samples(samples: Iterable[Sample]) -> Batch:
.. versionadded:: 0.2
"""
collated: dict[Any, Any] = _list_dict_to_dict_list(samples)
collated: Batch = _list_dict_to_dict_list(samples)
for key, value in collated.items():
if isinstance(value[0], Tensor):
collated[key] = torch.cat(value)
Expand All @@ -498,7 +544,7 @@ def merge_samples(samples: Iterable[Sample]) -> Batch:
.. versionadded:: 0.2
"""
collated: dict[Any, Any] = {}
collated: Batch = {}
for sample in samples:
for key, value in sample.items():
if key in collated and isinstance(value, Tensor):
Expand Down

0 comments on commit 33e7070

Please sign in to comment.