Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TreeSatAI: Add new dataset #2402

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,11 @@ SustainBench Crop Yield

.. autoclass:: SustainBenchCropYieldDataModule

TreeSatAI
^^^^^^^^^

.. autoclass:: TreeSatAIDataModule

Tropical Cyclone
^^^^^^^^^^^^^^^^

Expand Down
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,11 @@ SustainBench Crop Yield

.. autoclass:: SustainBenchCropYield

TreeSatAI
^^^^^^^^^

.. autoclass:: TreeSatAI

Tropical Cyclone
^^^^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/datasets/non_geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`SSL4EO-L Benchmark`_,S,Lansat & CDL,"CC0-1.0",25K,134,264x264,30,MSI
`SSL4EO-L Benchmark`_,S,Lansat & NLCD,"CC0-1.0",25K,17,264x264,30,MSI
`SustainBench Crop Yield`_,R,MODIS,"CC-BY-SA-4.0",11k,-,32x32,-,MSI
`TreeSatAI`_,"C, R, S","Aerial, Sentinel-1/2",CC-BY-4.0,50K,"12, 15, 20","6, 20, 304","0.2, 10","CIR, MSI, SAR"
`Tropical Cyclone`_,R,GOES 8--16,"CC-BY-4.0","108,110",-,256x256,4K--8K,MSI
`UC Merced`_,C,USGS National Map,"public domain","2,100",21,256x256,0.3,RGB
`USAVars`_,R,NAIP Aerial,"CC-BY-4.0",100K,-,-,4,"RGB, NIR"
Expand Down
13 changes: 13 additions & 0 deletions tests/conf/treesatai.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
model:
class_path: MultiLabelClassificationTask
init_args:
model: 'resnet18'
in_channels: 19
num_classes: 15
loss: 'bce'
data:
class_path: TreeSatAIDataModule
init_args:
batch_size: 1
dict_kwargs:
root: 'tests/data/treesatai'
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/treesatai/aerial_60m_alnus_spec.zip
Binary file not shown.
Binary file not shown.
Binary file added tests/data/treesatai/aerial_60m_picea_abies.zip
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/treesatai/aerial_60m_quercus_rubra.zip
Binary file not shown.
129 changes: 129 additions & 0 deletions tests/data/treesatai/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import glob
import json
import os
import random
import shutil
import zipfile

import numpy as np
import rasterio
from rasterio import Affine
from rasterio.crs import CRS

SIZE = 32

random.seed(0)
np.random.seed(0)

classes = (
'Abies',
'Acer',
'Alnus',
'Betula',
'Cleared',
'Fagus',
'Fraxinus',
'Larix',
'Picea',
'Pinus',
'Populus',
'Prunus',
'Pseudotsuga',
'Quercus',
'Tilia',
)

species = (
'Acer_pseudoplatanus',
'Alnus_spec',
'Fagus_sylvatica',
'Picea_abies',
'Pseudotsuga_menziesii',
'Quercus_petraea',
'Quercus_rubra',
)

profile = {
'aerial': {
'driver': 'GTiff',
'dtype': 'uint8',
'nodata': None,
'width': SIZE,
'height': SIZE,
'count': 4,
'crs': CRS.from_epsg(25832),
'transform': Affine(
0.19999999999977022, 0.0, 552245.4, 0.0, -0.19999999999938728, 5728215.0
),
},
's1': {
'driver': 'GTiff',
'dtype': 'float32',
'nodata': -9999.0,
'width': SIZE // 16,
'height': SIZE // 16,
'count': 3,
'crs': CRS.from_epsg(32632),
'transform': Affine(10.0, 0.0, 552245.0, 0.0, -10.0, 5728215.0),
},
's2': {
'driver': 'GTiff',
'dtype': 'uint16',
'nodata': None,
'width': SIZE // 16,
'height': SIZE // 16,
'count': 12,
'crs': CRS.from_epsg(32632),
'transform': Affine(10.0, 0.0, 552241.6565, 0.0, -10.0, 5728211.6251),
},
}

multi_labels = {}
for split in ['train', 'test']:
with open(f'{split}_filenames.lst') as f:
for filename in f:
filename = filename.strip()
for sensor in ['aerial', 's1', 's2']:
kwargs = profile[sensor]
directory = os.path.join(sensor, '60m')
os.makedirs(directory, exist_ok=True)
if 'int' in kwargs['dtype']:
Z = np.random.randint(
np.iinfo(kwargs['dtype']).min,
np.iinfo(kwargs['dtype']).max,
size=(kwargs['height'], kwargs['width']),
dtype=kwargs['dtype'],
)
else:
Z = np.random.rand(kwargs['height'], kwargs['width'])

path = os.path.join(directory, filename)
with rasterio.open(path, 'w', **kwargs) as src:
for i in range(1, kwargs['count'] + 1):
src.write(Z, i)

k = random.randrange(1, 4)
labels = random.choices(classes, k=k)
pcts = np.random.rand(k)
pcts /= np.sum(pcts)
multi_labels[filename] = list(map(list, zip(labels, map(float, pcts))))

os.makedirs('labels', exist_ok=True)
path = os.path.join('labels', 'TreeSatBA_v9_60m_multi_labels.json')
with open(path, 'w') as f:
json.dump(multi_labels, f)

for sensor in ['s1', 's2', 'labels']:
shutil.make_archive(sensor, 'zip', '.', sensor)

for spec in species:
path = f'aerial_60m_{spec}.zip'.lower()
with zipfile.ZipFile(path, 'w') as f:
for path in glob.iglob(os.path.join('aerial', '60m', f'{spec}_*.tif')):
filename = os.path.split(path)[-1]
f.write(path, arcname=filename)
Binary file added tests/data/treesatai/labels.zip
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"Picea_abies_3_46636_WEFL_NLF.tif": [["Prunus", 0.20692122963708826], ["Fraxinus", 0.7930787703629117]], "Pseudotsuga_menziesii_1_339575_BI_NLF.tif": [["Tilia", 0.4243067837573989], ["Larix", 0.5756932162426011]], "Quercus_rubra_1_92184_WEFL_NLF.tif": [["Tilia", 0.5816157697641007], ["Fagus", 0.4183842302358993]], "Fagus_sylvatica_9_29995_WEFL_NLF.tif": [["Larix", 1.0]], "Quercus_petraea_5_80549_WEFL_NLF.tif": [["Alnus", 0.5749721529276662], ["Acer", 0.4250278470723338]], "Acer_pseudoplatanus_3_5758_WEFL_NLF.tif": [["Tilia", 0.8430361090251272], ["Larix", 0.1569638909748729]], "Alnus_spec._5_13114_WEFL_NLF.tif": [["Pseudotsuga", 0.17881149698366108], ["Quercus", 0.38732907538618866], ["Cleared", 0.4338594276301503]], "Quercus_petraea_2_84375_WEFL_NLF.tif": [["Acer", 0.3909090505343164], ["Pseudotsuga", 0.2628926194326892], ["Cleared", 0.34619833003299444]], "Picea_abies_2_46896_WEFL_NLF.tif": [["Acer", 0.4953810312272686], ["Fraxinus", 0.0006659055704136941], ["Pinus", 0.5039530632023177]], "Acer_pseudoplatanus_4_6058_WEFL_NLF.tif": [["Tilia", 1.0]]}
Binary file added tests/data/treesatai/s1.zip
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/treesatai/s2.zip
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1 change: 1 addition & 0 deletions tests/data/treesatai/test_filenames.lst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Acer_pseudoplatanus_4_6058_WEFL_NLF.tif
9 changes: 9 additions & 0 deletions tests/data/treesatai/train_filenames.lst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Picea_abies_3_46636_WEFL_NLF.tif
Pseudotsuga_menziesii_1_339575_BI_NLF.tif
Quercus_rubra_1_92184_WEFL_NLF.tif
Fagus_sylvatica_9_29995_WEFL_NLF.tif
Quercus_petraea_5_80549_WEFL_NLF.tif
Acer_pseudoplatanus_3_5758_WEFL_NLF.tif
Alnus_spec._5_13114_WEFL_NLF.tif
Quercus_petraea_2_84375_WEFL_NLF.tif
Picea_abies_2_46896_WEFL_NLF.tif
62 changes: 62 additions & 0 deletions tests/datasets/test_treesatai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import glob
import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch.nn as nn
from pytest import MonkeyPatch
from torch import Tensor

from torchgeo.datasets import DatasetNotFoundError, TreeSatAI

root = os.path.join('tests', 'data', 'treesatai')
md5s = {
'aerial_60m_acer_pseudoplatanus.zip': '',
'labels.zip': '',
's1.zip': '',
's2.zip': '',
'test_filenames.lst': '',
'train_filenames.lst': '',
}


class TestTreeSatAI:
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch) -> TreeSatAI:
monkeypatch.setattr(TreeSatAI, 'url', root + os.sep)
monkeypatch.setattr(TreeSatAI, 'md5s', md5s)
transforms = nn.Identity()
return TreeSatAI(root, transforms=transforms)

def test_getitem(self, dataset: TreeSatAI) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x['label'], Tensor)
for sensor in dataset.sensors:
assert isinstance(x[f'image_{sensor}'], Tensor)

def test_len(self, dataset: TreeSatAI) -> None:
assert len(dataset) == 9

def test_download(self, dataset: TreeSatAI, tmp_path: Path) -> None:
TreeSatAI(tmp_path, download=True)

def test_extract(self, dataset: TreeSatAI, tmp_path: Path) -> None:
for file in glob.iglob(os.path.join(root, '*.*')):
shutil.copy(file, tmp_path)
TreeSatAI(tmp_path)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
TreeSatAI(tmp_path)

def test_plot(self, dataset: TreeSatAI) -> None:
x = dataset[0]
x['prediction'] = x['label']
dataset.plot(x)
plt.close()
2 changes: 1 addition & 1 deletion tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def test_freeze_backbone(self, model_name: str) -> None:

class TestMultiLabelClassificationTask:
@pytest.mark.parametrize(
'name', ['bigearthnet_all', 'bigearthnet_s1', 'bigearthnet_s2']
'name', ['bigearthnet_all', 'bigearthnet_s1', 'bigearthnet_s2', 'treesatai']
)
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from .ssl4eo import SSL4EOLDataModule, SSL4EOS12DataModule
from .ssl4eo_benchmark import SSL4EOLBenchmarkDataModule
from .sustainbench_crop_yield import SustainBenchCropYieldDataModule
from .treesatai import TreeSatAIDataModule
from .ucmerced import UCMercedDataModule
from .usavars import USAVarsDataModule
from .utils import MisconfigurationException
Expand Down Expand Up @@ -106,6 +107,7 @@
'SSL4EOLDataModule',
'SSL4EOS12DataModule',
'SustainBenchCropYieldDataModule',
'TreeSatAIDataModule',
'TropicalCycloneDataModule',
'UCMercedDataModule',
'USAVarsDataModule',
Expand Down
Loading