-
Notifications
You must be signed in to change notification settings - Fork 350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add an initial version of WMSDataset #1965
base: main
Are you sure you want to change the base?
Changes from 13 commits
92d2729
4e1ee53
3dd0a04
d11231b
1e819a6
b8698ac
dc99c2b
31fc997
7182c96
2d9000c
7752e95
95b49f1
a8acd0a
2bf3c2a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import pytest | ||
import urllib3 | ||
|
||
from torchgeo.datasets import WMSDataset | ||
|
||
SERVICE_URL = "https://mesonet.agron.iastate.edu/cgi-bin/wms/nexrad/n0r-t.cgi?" | ||
|
||
|
||
def service_ok(url: str, timeout: int = 5) -> bool:
    """Report whether ``url`` answers a HEAD request with HTTP status 200.

    Serves as a reachability probe so that tests depending on a live
    service can be skipped rather than fail when it is down.

    Args:
        url: Address to probe.
        timeout: Timeout handed to the urllib3 request.

    Returns:
        True if the response status is 200; False for any other status or
        if the request raises.
    """
    try:
        response = urllib3.PoolManager().request("HEAD", url, timeout=timeout)
    except Exception:
        # Deliberately broad: this is a best-effort probe, and any failure
        # (connection refused, DNS, timeout, malformed URL, ...) simply
        # means the service cannot be used. A dedicated handler for
        # ``NewConnectionError`` would do exactly the same thing.
        return False
    return response.status == 200
|
||
|
||
@pytest.mark.slow
@pytest.mark.skipif(not service_ok(SERVICE_URL), reason="WMS service is unreachable")
class TestWMSDataset:
    """Live checks of WMSDataset against the MESONET NEXRAD service (GetMap 1.1.1).

    Both marks sit on the class so that they apply to every test method and
    the reachability probe runs once at import time instead of once per test.
    """

    layer_name = "nexrad_base_reflect"

    @staticmethod
    def _assert_layer_selected(wms: WMSDataset) -> None:
        """Check the index extent and request format once a layer is set."""
        minx, maxx, miny, maxy = wms.index.bounds[:4]
        assert (minx, maxx, miny, maxy) == (-126, -66, 24, 50)
        assert wms._format == "image/png"

    def test_wms_no_layer(self) -> None:
        """Select the layer after construction."""
        wms = WMSDataset(SERVICE_URL, 10.0)
        assert self.layer_name in wms.layers()
        assert wms.crs.to_epsg() == 4326
        wms.layer(self.layer_name, crs=4326)
        self._assert_layer_selected(wms)

    def test_wms_layer(self) -> None:
        """Pass both the layer and the CRS to the constructor."""
        wms = WMSDataset(SERVICE_URL, 10.0, layer=self.layer_name, crs=4326)
        assert wms.crs.to_epsg() == 4326
        self._assert_layer_selected(wms)

    def test_wms_layer_nocrs(self) -> None:
        """Pass only the layer to the constructor and rely on the default CRS."""
        wms = WMSDataset(SERVICE_URL, 10.0, layer=self.layer_name)
        assert wms.crs.to_epsg() == 4326
        self._assert_layer_selected(wms)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
# Author: Ian Turton, Glasgow University [email protected] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We don't normally edit the license header. Your contribution will still be included in the git history though. |
||
|
||
"""A simple class to fetch WMS images.""" | ||
from collections.abc import Callable | ||
from io import BytesIO | ||
from typing import Any | ||
|
||
import torchvision.transforms as transforms | ||
from owslib.map.wms111 import ContentMetadata | ||
from owslib.wms import WebMapService | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Imports should be done lazily inside of the class so that the rest of TorchGeo still works even if owslib isn't installed. If you grep for any of our other optional dataset dependencies, you should see example code for this that you can copy and paste. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
from PIL import Image | ||
from rasterio.coords import BoundingBox | ||
from rasterio.crs import CRS | ||
from rasterio.errors import CRSError | ||
from rtree.index import Index, Property | ||
|
||
from torchgeo.datasets import GeoDataset | ||
|
||
|
||
class WMSDataset(GeoDataset): | ||
"""Allow models to fetch images from a WMS (at a good resolution).""" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you clarify what "at a good resolution" means? It would also be good to explain what WMS is and why it is useful. For example, see the documentation for some of our other existing dataset base classes. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also, can you add `.. versionadded:: 0.6` to the docstring? This will document what version of TorchGeo is required to use this feature. |
||
|
||
# Address of the WMS server; replaced per instance in __init__.
_url: str = ""
# Service client; created from the URL in __init__.
_wms: WebMapService = None

# Names of the layers the service advertises (built from ``_wms.contents``).
_layers: list[str] = []
# Metadata of the currently selected layer; assigned by ``layer()``.
_layer: ContentMetadata | None = None
# Name of the currently selected layer; assigned by ``layer()``.
_layer_name: str = ""
# When true, __getitem__ stores the tensor under "image", otherwise "mask".
is_image = True
|
||
def __init__(
    self,
    url: str,
    res: float,
    layer: str | None = None,
    transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    crs: str | int | None = None,
) -> None:
    """Connect to a WMS server and optionally select one of its layers.

    Args:
        url: Address of the WMS server (the ``request=getcapabilities``
            part does not need to be included).
        res: Resolution at which requests should be made; it should
            probably match that of your other datasets.
        layer: Optional name of the layer to fetch. A name the service
            does not advertise is ignored.
        transforms: A function/transform that takes an input sample and
            returns a transformed version.
        crs: Optional EPSG code, as a string or integer and without the
            'EPSG:' prefix.
    """
    super().__init__(transforms)
    self._url, self._res = url, res
    if crs is not None:
        self._crs = CRS.from_epsg(crs)

    service = WebMapService(url)
    self._wms = service
    # Requests use the first format the server lists for GetMap.
    get_map = service.getOperationByName("GetMap")
    self._format = get_map.formatOptions[0]
    self._layers = list(service.contents)

    # Nothing to select when no layer was requested or it is not offered.
    if layer is None or layer not in self._layers:
        return
    self.layer(layer, crs)
|
||
def layer(self, layer: str, crs: str | int | None = None) -> None:
    """Select the layer to fetch and rebuild the spatial index from its extent.

    Args:
        layer: A string with the name of the layer you want to fetch.
        crs: An optional string or integer code that is a valid EPSG
            code (without the 'EPSG:'). When given, it becomes the CRS of
            the dataset. Otherwise, if no CRS is set yet, the first usable
            entry of the layer's advertised CRS options is used.

    Raises:
        ValueError: if no CRS is set and none of the layer's advertised
            CRS options can be turned into an EPSG code.
    """
    metadata = self._wms[layer]
    self._layer = metadata
    self._layer_name = layer

    # The layer's bounding box is read as (minx, miny, maxx, maxy, ...),
    # whereas a non-interleaved index takes
    # (minx, maxx, miny, maxy, mint, maxt).
    box = metadata.boundingBox
    minx, miny, maxx, maxy = (float(box[position]) for position in range(4))
    self.index = Index(interleaved=False, properties=Property(dimension=3))
    # 9.223372036854776e18 is float(2**63 - 1): the entry spans all of time.
    self.index.insert(0, (minx, maxx, miny, maxy, 0, 9.223372036854776e18))

    if crs is not None:
        # An explicit code wins over whatever the layer advertises.
        self._crs = CRS.from_epsg(crs)
        return
    if self._crs is not None:
        return

    # Try each advertised CRS once, in sorted order, and keep the first one
    # that resolves to an EPSG code.
    for option in sorted(metadata.crsOptions):
        name = option.upper()
        if name.startswith("EPSG:"):
            code = name[5:]
        elif name == "CRS:84":
            code = "4326"
        else:
            continue
        try:
            self._crs = CRS.from_epsg(code)
        except (CRSError, ValueError):
            # Unknown or non-numeric code: move on to the next candidate.
            continue
        return

    raise ValueError(
        f"Layer {layer!r} does not advertise a CRS with a usable EPSG code"
    )
|
||
def getlayer(self) -> ContentMetadata:
    """Return the metadata object of the currently selected layer."""
    return self._layer
|
||
def layers(self) -> list[str]:
    """Return the list of available layers."""
    return self._layers
Comment on lines
+113
to
+119
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Getters/setters are generally considered to be bad practice in Python, especially if you don't actually need to do anything special. In this case, if we want the ability to access the attribute, let's just make it a public attribute. |
||
|
||
def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
    """Retrieve image/mask and metadata indexed by query.

    The selected layer is requested from the WMS server for the spatial
    extent of ``query`` and decoded into a tensor.

    Args:
        query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

    Returns:
        sample holding "crs", "bbox" and the fetched tensor under "image"
        (or "mask" when ``is_image`` is false), after ``transforms`` has
        been applied if one was given

    Raises:
        IndexError: if query is not found in the index
    """
    # Honour the documented contract before paying for a network request.
    # NOTE(review): assumes ``query`` iterates in the six-value order given
    # in the docstring, which is the order the index was built with.
    if not list(self.index.intersection(tuple(query))):
        raise IndexError(
            f"query: {query} not found in index with bounds: {self.index.bounds}"
        )

    response = self._wms.getmap(
        layers=[self._layer_name],
        srs=f"epsg:{self.crs.to_epsg()}",
        bbox=(query.minx, query.miny, query.maxx, query.maxy),
        # TODO fix size
        size=(500, 500),
        format=self._format,
        transparent=True,
    )
    sample: dict[str, Any] = {"crs": self.crs, "bbox": query}

    # NOTE(review): the server answers with an encoded image (JPEG or PNG).
    # Reviewers pointed out that ToTensor() also divides by 255 and asked
    # for a conversion that avoids torchvision and leaves scaling to the
    # user; the conversion is left unchanged here.
    to_tensor = transforms.Compose([transforms.ToTensor()])
    with Image.open(BytesIO(response.read())) as picture:
        tensor = to_tensor(picture)

    sample["image" if self.is_image else "mask"] = tensor

    if self.transforms is not None:
        sample = self.transforms(sample)

    return sample
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's make this optional instead of required. So it should be listed in
pyproject.toml,
requirements/datasets.txt,
and requirements/min-reqs.old.
The lower bound will have to be tested. The easiest way is to pip-install older and older versions until the tests fail.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also need to add to
requirements/min-reqs.old
in order to get the minimum tests to pass.