Add an initial version of WMSDataset #1965

Open: wants to merge 14 commits into main
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -33,6 +33,6 @@ repos:
     rev: v1.8.0
     hooks:
       - id: mypy
-        args: [--strict, --ignore-missing-imports, --show-error-codes]
+        args: [--python-version=3.10, --strict, --ignore-missing-imports, --show-error-codes]
         additional_dependencies: [einops>=0.6.0, kornia>=0.6.9, lightning>=2.0.9, matplotlib>=3.8.1, numpy>=1.22, pytest>=6.1.2, pyvista>=0.34.2, scikit-image>=0.18.0, torch>=2.2, torchmetrics>=0.10]
         exclude: (build|data|dist|logo|logs|output)/
1 change: 1 addition & 0 deletions requirements/datasets.txt
@@ -2,6 +2,7 @@
h5py==3.10.0
laspy==2.5.3
opencv-python==4.9.0.80
OWSLib==0.30.0
pycocotools==2.0.7
pyvista==0.43.4
radiant-mlhub==0.4.1
1 change: 1 addition & 0 deletions requirements/required.txt
@@ -9,6 +9,7 @@ lightly==1.5.2
lightning[pytorch-extra]==2.2.1
matplotlib==3.8.3
numpy==1.26.4
owsLib==0.30.0
Collaborator

Let's make this optional instead of required. So it should be listed in pyproject.toml, requirements/datasets.txt, and requirements/min-reqs.old. The lower bound will have to be tested. The easiest way is to pip-install older and older versions until the tests fail.

Collaborator

Also need to add to requirements/min-reqs.old in order to get the minimum tests to pass.

pandas==2.2.1
pillow==10.2.0
pyproj==3.6.1
68 changes: 68 additions & 0 deletions tests/datasets/test_wms.py
@@ -0,0 +1,68 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import pytest
import urllib3

from torchgeo.datasets import WMSDataset

SERVICE_URL = "https://mesonet.agron.iastate.edu/cgi-bin/wms/nexrad/n0r-t.cgi?"


def service_ok(url: str, timeout: int = 5) -> bool:
    try:
        http = urllib3.PoolManager()
        resp = http.request("HEAD", url, timeout=timeout)
        ok = 200 == resp.status
    except urllib3.exceptions.NewConnectionError:
        ok = False
    except Exception:
        ok = False
    return ok


class TestWMSDataset:

    @pytest.mark.slow
    @pytest.mark.skipif(
        not service_ok(SERVICE_URL), reason="WMS service is unreachable"
    )
    def test_wms_no_layer(self) -> None:
        """MESONET GetMap 1.1.1"""
        wms = WMSDataset(SERVICE_URL, 10.0)
        assert "nexrad_base_reflect" in wms.layers()
        assert 4326 == wms.crs.to_epsg()
        wms.layer("nexrad_base_reflect", crs=4326)
        assert -126 == wms.index.bounds[0]
        assert -66 == wms.index.bounds[1]
        assert 24 == wms.index.bounds[2]
        assert 50 == wms.index.bounds[3]
        assert "image/png" == wms._format

    @pytest.mark.slow
    @pytest.mark.skipif(
        not service_ok(SERVICE_URL), reason="WMS service is unreachable"
    )
    def test_wms_layer(self) -> None:
        """MESONET GetMap 1.1.1"""
        wms = WMSDataset(SERVICE_URL, 10.0, layer="nexrad_base_reflect", crs=4326)
        assert 4326 == wms.crs.to_epsg()
        assert -126 == wms.index.bounds[0]
        assert -66 == wms.index.bounds[1]
        assert 24 == wms.index.bounds[2]
        assert 50 == wms.index.bounds[3]
        assert "image/png" == wms._format

    @pytest.mark.slow
    @pytest.mark.skipif(
        not service_ok(SERVICE_URL), reason="WMS service is unreachable"
    )
    def test_wms_layer_nocrs(self) -> None:
        """MESONET GetMap 1.1.1"""
        wms = WMSDataset(SERVICE_URL, 10.0, layer="nexrad_base_reflect")
        assert 4326 == wms.crs.to_epsg()
        assert -126 == wms.index.bounds[0]
        assert -66 == wms.index.bounds[1]
        assert 24 == wms.index.bounds[2]
        assert 50 == wms.index.bounds[3]
        assert "image/png" == wms._format
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
@@ -135,6 +135,7 @@
from .vaihingen import Vaihingen2D
from .vhr10 import VHR10
from .western_usa_live_fuel_moisture import WesternUSALiveFuelMoisture
from .wms import WMSDataset
from .xview import XView2
from .zuericrop import ZueriCrop

@@ -263,6 +264,7 @@
"RasterDataset",
"UnionDataset",
"VectorDataset",
"WMSDataset",
# Utilities
"BoundingBox",
"concat_samples",
146 changes: 146 additions & 0 deletions torchgeo/datasets/wms.py
@@ -0,0 +1,146 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# Author: Ian Turton, Glasgow University [email protected]
Collaborator

We don't normally edit the license header. Your contribution will still be included in the git history though.


"""A simple class to fetch WMS images."""
from collections.abc import Callable
from io import BytesIO
from typing import Any

import torchvision.transforms as transforms
from owslib.map.wms111 import ContentMetadata
from owslib.wms import WebMapService
Collaborator

Imports should be done lazily inside of the class so that the rest of TorchGeo still works even if owslib isn't installed. If you grep for any of our other optional dataset dependencies, you should see example code for this you can copy-n-paste.

Collaborator

@ianturton you can check torchgeo.datasets.chabud as an example. See how we check if h5py is available in the constructor and then import h5py in the method that uses it.
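
As a rough illustration of the lazy-import pattern described in the two comments above, the sketch below defers importing an optional dependency until it is needed and raises a descriptive error when it is missing. The names `lazy_import` and `LazyDataset` and the error wording are hypothetical, not TorchGeo's actual code; `json` stands in for an installed optional package such as `owslib`.

```python
import importlib
from types import ModuleType


def lazy_import(name: str) -> ModuleType:
    """Import ``name`` on demand, with a clear message if it is missing.

    Hypothetical helper for illustration; not part of TorchGeo.
    """
    try:
        return importlib.import_module(name)
    except ImportError as err:
        raise ImportError(
            f"{name} is not installed and is required to use this dataset"
        ) from err


class LazyDataset:
    """Toy dataset whose optional dependency is only imported when used."""

    def __init__(self, dependency: str) -> None:
        # Fail early, in the constructor, if the optional package is missing.
        lazy_import(dependency)
        self.dependency = dependency

    def load(self) -> ModuleType:
        # Import again inside the method that actually uses the package.
        return lazy_import(self.dependency)


# ``json`` stands in for an installed optional dependency.
dataset = LazyDataset("json")
module = dataset.load()

try:
    LazyDataset("package_that_does_not_exist_xyz")
    missing_error = ""
except ImportError as err:
    missing_error = str(err)
```

With this shape, importing the package that defines the dataset never touches the optional dependency; only constructing or using the dataset does.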

from PIL import Image
from rasterio.coords import BoundingBox
from rasterio.crs import CRS
from rasterio.errors import CRSError
from rtree.index import Index, Property

from torchgeo.datasets import GeoDataset


class WMSDataset(GeoDataset):
    """Allow models to fetch images from a WMS (at a good resolution)."""
Collaborator

Can you clarify what "at a good resolution" means? Also would be good to explain what WMS is and why it is useful. For example, see the documentation for some of our other existing dataset base classes.

Collaborator

Also, can you add:

.. versionadded:: 0.6

to the docstring? This will document what version of TorchGeo is required to use this feature.
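
One possible shape for the expanded docstring, combining both requests above. The wording is a suggestion only; `WMSDatasetDocSketch` is a placeholder class used so the snippet runs on its own, and the description of WMS reflects the OGC Web Map Service standard in general terms.

```python
class WMSDatasetDocSketch:
    """Dataset that fetches images from a Web Map Service (WMS).

    WMS is an Open Geospatial Consortium (OGC) standard for requesting
    georeferenced map images from a server over HTTP. It is useful when
    imagery should be fetched on demand for a bounding box instead of
    being downloaded in full beforehand.

    .. versionadded:: 0.6
    """


docstring = WMSDatasetDocSketch.__doc__ or ""
```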


    _url: str = ""
    _wms: WebMapService = None

    _layers: list[ContentMetadata] = []
    _layer: ContentMetadata = None
    _layer_name: str = ""
    is_image = True

    def __init__(
        self,
        url: str,
        res: float,
        layer: str | None = None,
        transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
        crs: str | int | None = None,
    ) -> None:
        """Initialize the data set instance.

        Args:
            url: String pointing to the WMS Server (no need for
                request=getcapabilities).
            res: The best resolution to make requests at
                (should probably match your other datasets).
            layer: An optional string with the name of the
                layer you want to fetch.
            transforms: a function/transform that takes an input sample
                and returns a transformed version
            crs: An optional string or integer code that is a
                valid EPSG code (without the 'EPSG:')

        """
        super().__init__(transforms)
        self._url = url
        self._res = res
        if crs is not None:
            self._crs = CRS.from_epsg(crs)
        self._wms = WebMapService(url)
        self._format = self._wms.getOperationByName("GetMap").formatOptions[0]
        self._layers = list(self._wms.contents)

        if layer is not None and layer in self._layers:
            self.layer(layer, crs)

    def layer(self, layer: str, crs: str | int | None = None) -> None:
        """Set the layer to be fetched.

        Args:
            layer: A string with the name of the layer you want to fetch.
            crs: An optional string or integer code that is a valid EPSG
                code (without the 'EPSG:')
        """
        self._layer = self._wms[layer]
        self._layer_name = layer
        coords = self._wms[layer].boundingBox
        self.index = Index(interleaved=False, properties=Property(dimension=3))
        self.index.insert(
            0,
            (
                float(coords[0]),
                float(coords[2]),
                float(coords[1]),
                float(coords[3]),
                0,
                9.223372036854776e18,
            ),
        )
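        if crs is None:
            i = 0
            while self._crs is None:
                crs_str = sorted(self._layer.crsOptions)[i].upper()
                if "EPSG:" in crs_str:
                    crs_str = crs_str[5:]
                elif crs_str == "CRS:84":
                    # CRS:84 is WGS 84 with longitude-first axis order; map it to EPSG:4326.
                    crs_str = "4326"
                try:
                    self._crs = CRS.from_epsg(crs_str)
                except CRSError:
                    pass
                # Advance to the next advertised CRS so a failed parse cannot loop forever.
                i += 1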

    def getlayer(self) -> ContentMetadata:
        """Return the selected layer object."""
        return self._layer

    def layers(self) -> list[str]:
        """Return a list of available layers."""
        return self._layers
Comment on lines +113 to +119
Collaborator

Getters/setters are generally considered to be bad practice in Python, especially if you don't actually need to do anything special. In this case, if we want the ability to access the attribute, let's just make it a public attribute.
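
A minimal sketch of the suggestion above, using a stand-in class (`LayerHolder` is hypothetical and not the PR's code): the values are stored as public attributes and read directly, and a `property` is only worth adding later if access needs extra logic.

```python
from typing import Optional


class LayerHolder:
    """Stand-in class exposing state through public attributes."""

    def __init__(self, layers: list[str], layer: Optional[str] = None) -> None:
        # Public attributes replace getlayer()/layers()-style accessors.
        self.layers = layers
        self.layer = layer


holder = LayerHolder(["nexrad_base_reflect"])
holder.layer = holder.layers[0]  # plain assignment, no setter needed
```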


    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
        """Retrieve image/mask and metadata indexed by query.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample of image/mask and metadata at that index

        Raises:
            IndexError: if query is not found in the index
        """
        img = self._wms.getmap(
            layers=[self._layer_name],
            srs="epsg:" + str(self.crs.to_epsg()),
            bbox=(query.minx, query.miny, query.maxx, query.maxy),
            # TODO fix size
            size=(500, 500),
            format=self._format,
            transparent=True,
        )
        sample = {"crs": self.crs, "bbox": query}

        transform = transforms.Compose([transforms.ToTensor()])
Collaborator

We don't normally use torchvision transforms. What is the output from wms, a numpy array? Let's just convert that directly to PyTorch without relying on PIL

Author

It's a JPEG or PNG image. Can I convert that directly to a tensor?

Collaborator

T.ToTensor() also automatically divides by 255, which we should leave up to the user to decide. You can convert the image to a tensor using torch.from_numpy(np.array(Image.open(path)).astype(np.float32)). You can break this up over multiple lines so it's not too verbose.
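
The conversion suggested above can be split over a few lines. The following is a sketch, not the PR's final code: the function name `image_bytes_to_tensor` is hypothetical, converting to RGB (which drops any alpha channel) and reordering to channels-first are assumptions, and an in-memory PNG stands in for the bytes returned by the WMS request.

```python
from io import BytesIO

import numpy as np
import torch
from PIL import Image


def image_bytes_to_tensor(data: bytes) -> torch.Tensor:
    """Decode PNG/JPEG bytes into a float32 tensor shaped (channels, height, width)."""
    with Image.open(BytesIO(data)) as img:
        # Assumption: force three channels; pixel values stay in the 0..255 range.
        array = np.array(img.convert("RGB"), dtype=np.float32)  # (H, W, 3)
    tensor = torch.from_numpy(array)
    return tensor.permute(2, 0, 1)  # channels first, no division by 255


# Build a tiny in-memory PNG so the sketch runs without a WMS server.
buffer = BytesIO()
Image.new("RGB", (4, 2), color=(255, 0, 128)).save(buffer, format="PNG")
tensor = image_bytes_to_tensor(buffer.getvalue())
```

Unlike `ToTensor()`, this leaves any scaling to the caller's own transforms.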

        # Convert the PIL image to Torch tensor
        img_tensor = transform(Image.open(BytesIO(img.read())))
        if self.is_image:
            sample["image"] = img_tensor
        else:
            sample["mask"] = img_tensor

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample
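
One way the `# TODO fix size` in `__getitem__` could be approached is to derive the requested pixel size from the query extent and the dataset resolution, instead of hard-coding `(500, 500)`. This is only a sketch under the assumption that `res` is expressed in CRS units per pixel; the function name `request_size` is hypothetical and the 0.1 degree resolution is an arbitrary example value.

```python
def request_size(
    minx: float, miny: float, maxx: float, maxy: float, res: float
) -> tuple[int, int]:
    """Return (width, height) in pixels for a bounding box at a given resolution.

    Assumes ``res`` is in CRS units per pixel (an assumption, not confirmed by the PR).
    """
    if res <= 0:
        raise ValueError("res must be positive")
    # Round to the nearest pixel and request at least one pixel per axis.
    width = max(1, round((maxx - minx) / res))
    height = max(1, round((maxy - miny) / res))
    return width, height


# Extent asserted for nexrad_base_reflect in the tests, at 0.1 degrees per pixel.
size = request_size(-126.0, 24.0, -66.0, 50.0, 0.1)
```

A real implementation would probably also need to respect any maximum image size the server advertises.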