From 92d272938041a0445565ec4a223ff4e271069a75 Mon Sep 17 00:00:00 2001 From: Ian Turton Date: Wed, 27 Mar 2024 14:09:08 +0000 Subject: [PATCH 01/12] initial version of WMSDataset --- pyproject.toml | 1 + requirements/datasets.txt | 1 + tests/datasets/test_wms.py | 65 +++++++++++++++++++++ torchgeo/datasets/__init__.py | 2 + torchgeo/datasets/wms.py | 104 ++++++++++++++++++++++++++++++++++ 5 files changed, 173 insertions(+) create mode 100644 tests/datasets/test_wms.py create mode 100644 torchgeo/datasets/wms.py diff --git a/pyproject.toml b/pyproject.toml index 350d9273bdc..9962c7e744a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -276,6 +276,7 @@ filterwarnings = [ ] markers = [ "slow: marks tests as slow", + "online: marks a test as needing to be online" ] norecursedirs = [ ".ipynb_checkpoints", diff --git a/requirements/datasets.txt b/requirements/datasets.txt index f183132801f..6200bd6280c 100644 --- a/requirements/datasets.txt +++ b/requirements/datasets.txt @@ -2,6 +2,7 @@ h5py==3.10.0 laspy==2.5.3 opencv-python==4.9.0.80 +OWSLib=0.30.0 pycocotools==2.0.7 pyvista==0.43.4 radiant-mlhub==0.4.1 diff --git a/tests/datasets/test_wms.py b/tests/datasets/test_wms.py new file mode 100644 index 00000000000..7d38b6b5f6d --- /dev/null +++ b/tests/datasets/test_wms.py @@ -0,0 +1,65 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import pytest + +from torchgeo.datasets import ( + WMSDataset, +) + +import requests +SERVICE_URL = 'https://mesonet.agron.iastate.edu/cgi-bin/wms/nexrad/n0r-t.cgi?' + + +class TestWMSDataset: + + def service_ok(url, timeout=5): + try: + resp = requests.head(url, allow_redirects=True, timeout=timeout) + ok = resp.ok + except requests.exceptions.ReadTimeout: + print('No 2') + ok = False + except requests.exceptions.ConnectTimeout: + print('No 3') + ok = False + except Exception: + print('No 4') + ok = False + return ok + + @pytest.mark.online + @pytest.mark.skipif(not service_ok(SERVICE_URL), + reason="WMS service is unreachable") + def test_wms_no_layer(self): + """MESONET GetMap 1.1.1""" + wms = WMSDataset(SERVICE_URL, 10.0,) + print(wms.layers()) + assert('nexrad_base_reflect' in wms.layers()) + assert(4326 == wms.crs.to_epsg()) + wms.layer('nexrad_base_reflect', crs=4326) + assert(-126 == wms.index.bounds[0]) + assert(-66 == wms.index.bounds[1]) + assert(24 == wms.index.bounds[2]) + assert(50 == wms.index.bounds[3]) + assert('image/png' == wms._format) + + def test_wms_layer(self): + """MESONET GetMap 1.1.1""" + wms = WMSDataset(SERVICE_URL, 10.0, layer='nexrad_base_reflect', crs=4326) + assert(4326 == wms.crs.to_epsg()) + assert(-126 == wms.index.bounds[0]) + assert(-66 == wms.index.bounds[1]) + assert(24 == wms.index.bounds[2]) + assert(50 == wms.index.bounds[3]) + assert('image/png' == wms._format) + + def test_wms_layer_nocrs(self): + """MESONET GetMap 1.1.1""" + wms = WMSDataset(SERVICE_URL, 10.0, layer='nexrad_base_reflect') + assert(4326 == wms.crs.to_epsg()) + assert(-126 == wms.index.bounds[0]) + assert(-66 == wms.index.bounds[1]) + assert(24 == wms.index.bounds[2]) + assert(50 == wms.index.bounds[3]) + assert('image/png' == wms._format) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 5f3f974e2b2..fdb89bdd09f 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -135,6 +135,7 @@ from .vaihingen import Vaihingen2D from .vhr10 import VHR10 from .western_usa_live_fuel_moisture import WesternUSALiveFuelMoisture +from .wms import WMSDataset from .xview import XView2 from .zuericrop import ZueriCrop @@ -263,6 +264,7 @@ "RasterDataset", "UnionDataset", "VectorDataset", + "WMSDataset", # Utilities "BoundingBox", "concat_samples", diff --git a/torchgeo/datasets/wms.py b/torchgeo/datasets/wms.py new file mode 100644 index 00000000000..b621f148123 --- /dev/null +++ b/torchgeo/datasets/wms.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# Author: Ian Turton, Glasgow University ian.turton@gla.ac.uk + +from typing import Any + +from owslib.wms import WebMapService + +from rasterio.coords import BoundingBox +from rasterio.crs import CRS +from rasterio.errors import CRSError + +from torchgeo.datasets import GeoDataset +from io import BytesIO +from PIL import Image +import torchvision.transforms as transforms +from rtree.index import Index, Property + + +class WMSDataset(GeoDataset): + """ + Allow models to fetch images from a WMS (at a good resolution) + """ + _url = None + _wms = None + + _layers = [] + _layer = None + _layer_name = "" + is_image = True + + def __init__(self, url, res, layer=None, transforms=None, crs=None): + super().__init__(transforms) + self._url = url + self._res = res + if crs is not None: + self._crs = CRS.from_epsg(crs) + self._wms = WebMapService(url) + self._format = self._wms.getOperationByName('GetMap').formatOptions[0] + self._layers = list(self._wms.contents) + + if layer in self._layers: + self.layer(layer, crs) + + def layer(self, layer, crs=None): + self._layer = self._wms[layer] + self._layer_name = layer + coords = self._wms[layer].boundingBox + self.index = Index(interleaved=False, properties=Property(dimension=3),) + self.index.insert(0, (float(coords[0]), float(coords[2]), float(coords[1]), + float(coords[3]), 0, 9.223372036854776e+18)) + if crs is None: + i = 0 + while self._crs is None: + crs_str = sorted(self._layer.crsOptions)[i].upper() + if 'EPSG:' in crs_str: + crs_str = crs_str[5:] + elif 'CRS:84': + crs_str = '4326' + try: + self._crs = CRS.from_epsg(crs_str) + except CRSError: + pass + + def getlayer(self): + return self._layer + + def layers(self): + return self._layers + + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + """Retrieve image/mask and metadata indexed by query. + + Args: + query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index + + Returns: + sample of image/mask and metadata at that index + + Raises: + IndexError: if query is not found in the index + """ + img = self._wms.getmap(layers=[self._layer_name], + srs="epsg:"+str(self.crs.to_epsg()), + bbox=(query.minx, query.miny, query.maxx, query.maxy), + # TODO fix size + size=(500, 500), + format=self._format, + transparent=True + ) + sample = {"crs": self.crs, "bbox": query} + + transform = transforms.Compose([transforms.ToTensor()]) + # Convert the PIL image to Torch tensor + img_tensor = transform(Image.open(BytesIO(img.read()))) + if self.is_image: + sample["image"] = img_tensor + else: + sample["mask"] = img_tensor + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample From 4e1ee5380ac3469ce5813f630ff2d11da3271d33 Mon Sep 17 00:00:00 2001 From: Ian Turton Date: Wed, 27 Mar 2024 14:14:45 +0000 Subject: [PATCH 02/12] fix requirements --- requirements/datasets.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/datasets.txt b/requirements/datasets.txt index 6200bd6280c..211be7599ec 100644 --- a/requirements/datasets.txt +++ b/requirements/datasets.txt @@ -2,7 +2,7 @@ h5py==3.10.0 laspy==2.5.3 opencv-python==4.9.0.80 -OWSLib=0.30.0 +OWSLib==0.30.0 pycocotools==2.0.7 pyvista==0.43.4 radiant-mlhub==0.4.1 From d11231be14e44abc65ff8336cad568611e5301bd Mon Sep 17 00:00:00 2001 From: Ian Turton Date: Wed, 27 Mar 2024 14:25:53 +0000 Subject: [PATCH 03/12] fix formatting --- tests/datasets/test_wms.py | 63 ++++++++++++++++++-------------------- torchgeo/datasets/wms.py | 38 ++++++++++++++--------- 2 files changed, 53 insertions(+), 48 deletions(-) diff --git a/tests/datasets/test_wms.py b/tests/datasets/test_wms.py index 7d38b6b5f6d..1caa8ccb8b6 100644 --- a/tests/datasets/test_wms.py +++ b/tests/datasets/test_wms.py @@ -2,13 +2,11 @@ # Licensed under the MIT License. import pytest +from torchgeo.datasets import WMSDataset +import requests -from torchgeo.datasets import ( - WMSDataset, -) -import requests -SERVICE_URL = 'https://mesonet.agron.iastate.edu/cgi-bin/wms/nexrad/n0r-t.cgi?' +SERVICE_URL = "https://mesonet.agron.iastate.edu/cgi-bin/wms/nexrad/n0r-t.cgi?" class TestWMSDataset: @@ -18,48 +16,45 @@ def service_ok(url, timeout=5): resp = requests.head(url, allow_redirects=True, timeout=timeout) ok = resp.ok except requests.exceptions.ReadTimeout: - print('No 2') ok = False except requests.exceptions.ConnectTimeout: - print('No 3') ok = False except Exception: - print('No 4') ok = False return ok @pytest.mark.online - @pytest.mark.skipif(not service_ok(SERVICE_URL), - reason="WMS service is unreachable") + @pytest.mark.skipif( + not service_ok(SERVICE_URL), reason="WMS service is unreachable" + ) def test_wms_no_layer(self): """MESONET GetMap 1.1.1""" - wms = WMSDataset(SERVICE_URL, 10.0,) - print(wms.layers()) - assert('nexrad_base_reflect' in wms.layers()) - assert(4326 == wms.crs.to_epsg()) - wms.layer('nexrad_base_reflect', crs=4326) - assert(-126 == wms.index.bounds[0]) - assert(-66 == wms.index.bounds[1]) - assert(24 == wms.index.bounds[2]) - assert(50 == wms.index.bounds[3]) - assert('image/png' == wms._format) + wms = WMSDataset(SERVICE_URL, 10.0) + assert "nexrad_base_reflect" in wms.layers() + assert 4326 == wms.crs.to_epsg() + wms.layer("nexrad_base_reflect", crs=4326) + assert -126 == wms.index.bounds[0] + assert -66 == wms.index.bounds[1] + assert 24 == wms.index.bounds[2] + assert 50 == wms.index.bounds[3] + assert "image/png" == wms._format def test_wms_layer(self): """MESONET GetMap 1.1.1""" - wms = WMSDataset(SERVICE_URL, 10.0, layer='nexrad_base_reflect', crs=4326) - assert(4326 == wms.crs.to_epsg()) - assert(-126 == wms.index.bounds[0]) - assert(-66 == wms.index.bounds[1]) - assert(24 == wms.index.bounds[2]) - assert(50 == wms.index.bounds[3]) - assert('image/png' == wms._format) + wms = WMSDataset(SERVICE_URL, 10.0, layer="nexrad_base_reflect", crs=4326) + assert 4326 == wms.crs.to_epsg() + assert -126 == wms.index.bounds[0] + assert -66 == wms.index.bounds[1] + assert 24 == wms.index.bounds[2] + assert 50 == wms.index.bounds[3] + assert "image/png" == wms._format def test_wms_layer_nocrs(self): """MESONET GetMap 1.1.1""" - wms = WMSDataset(SERVICE_URL, 10.0, layer='nexrad_base_reflect') - assert(4326 == wms.crs.to_epsg()) - assert(-126 == wms.index.bounds[0]) - assert(-66 == wms.index.bounds[1]) - assert(24 == wms.index.bounds[2]) - assert(50 == wms.index.bounds[3]) - assert('image/png' == wms._format) + wms = WMSDataset(SERVICE_URL, 10.0, layer="nexrad_base_reflect") + assert 4326 == wms.crs.to_epsg() + assert -126 == wms.index.bounds[0] + assert -66 == wms.index.bounds[1] + assert 24 == wms.index.bounds[2] + assert 50 == wms.index.bounds[3] + assert "image/png" == wms._format diff --git a/torchgeo/datasets/wms.py b/torchgeo/datasets/wms.py index b621f148123..bf93c2f286a 100644 --- a/torchgeo/datasets/wms.py +++ b/torchgeo/datasets/wms.py @@ -36,7 +36,7 @@ def __init__(self, url, res, layer=None, transforms=None, crs=None): if crs is not None: self._crs = CRS.from_epsg(crs) self._wms = WebMapService(url) - self._format = self._wms.getOperationByName('GetMap').formatOptions[0] + self._format = self._wms.getOperationByName("GetMap").formatOptions[0] self._layers = list(self._wms.contents) if layer in self._layers: @@ -47,16 +47,25 @@ def layer(self, layer, crs=None): self._layer_name = layer coords = self._wms[layer].boundingBox self.index = Index(interleaved=False, properties=Property(dimension=3),) - self.index.insert(0, (float(coords[0]), float(coords[2]), float(coords[1]), - float(coords[3]), 0, 9.223372036854776e+18)) + self.index.insert( + 0, + ( + float(coords[0]), + float(coords[2]), + float(coords[1]), + float(coords[3]), + 0, + 9.223372036854776e+18 + ) + ) if crs is None: i = 0 while self._crs is None: crs_str = sorted(self._layer.crsOptions)[i].upper() - if 'EPSG:' in crs_str: + if "EPSG:" in crs_str: crs_str = crs_str[5:] - elif 'CRS:84': - crs_str = '4326' + elif "CRS:84": + crs_str = "4326" try: self._crs = CRS.from_epsg(crs_str) except CRSError: @@ -80,14 +89,15 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: Raises: IndexError: if query is not found in the index """ - img = self._wms.getmap(layers=[self._layer_name], - srs="epsg:"+str(self.crs.to_epsg()), - bbox=(query.minx, query.miny, query.maxx, query.maxy), - # TODO fix size - size=(500, 500), - format=self._format, - transparent=True - ) + img = self._wms.getmap( + layers=[self._layer_name], + srs="epsg:"+str(self.crs.to_epsg()), + bbox=(query.minx, query.miny, query.maxx, query.maxy), + # TODO fix size + size=(500, 500), + format=self._format, + transparent=True + ) sample = {"crs": self.crs, "bbox": query} transform = transforms.Compose([transforms.ToTensor()]) From b8698accd09d0c50cbec6aefe26868f07cfa08ac Mon Sep 17 00:00:00 2001 From: Ian Turton Date: Wed, 27 Mar 2024 14:37:14 +0000 Subject: [PATCH 04/12] fix black issues --- torchgeo/datasets/wms.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/torchgeo/datasets/wms.py b/torchgeo/datasets/wms.py index bf93c2f286a..483d9dd311e 100644 --- a/torchgeo/datasets/wms.py +++ b/torchgeo/datasets/wms.py @@ -21,6 +21,7 @@ class WMSDataset(GeoDataset): """ Allow models to fetch images from a WMS (at a good resolution) """ + _url = None _wms = None @@ -46,7 +47,7 @@ def layer(self, layer, crs=None): self._layer = self._wms[layer] self._layer_name = layer coords = self._wms[layer].boundingBox - self.index = Index(interleaved=False, properties=Property(dimension=3),) + self.index = Index(interleaved=False, properties=Property(dimension=3)) self.index.insert( 0, ( @@ -55,8 +56,8 @@ def layer(self, layer, crs=None): float(coords[1]), float(coords[3]), 0, - 9.223372036854776e+18 - ) + 9.223372036854776e18, + ), ) if crs is None: i = 0 @@ -91,12 +92,12 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """ img = self._wms.getmap( layers=[self._layer_name], - srs="epsg:"+str(self.crs.to_epsg()), + srs="epsg:" + str(self.crs.to_epsg()), bbox=(query.minx, query.miny, query.maxx, query.maxy), # TODO fix size size=(500, 500), format=self._format, - transparent=True + transparent=True, ) sample = {"crs": self.crs, "bbox": query} From dc99c2b97b4b7e5739dd56019aa25bccd82936cb Mon Sep 17 00:00:00 2001 From: Ian Turton Date: Wed, 27 Mar 2024 15:22:43 +0000 Subject: [PATCH 05/12] isort --- tests/datasets/test_wms.py | 2 +- torchgeo/datasets/wms.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/datasets/test_wms.py b/tests/datasets/test_wms.py index 1caa8ccb8b6..c581b65770d 100644 --- a/tests/datasets/test_wms.py +++ b/tests/datasets/test_wms.py @@ -2,9 +2,9 @@ # Licensed under the MIT License. import pytest -from torchgeo.datasets import WMSDataset import requests +from torchgeo.datasets import WMSDataset SERVICE_URL = "https://mesonet.agron.iastate.edu/cgi-bin/wms/nexrad/n0r-t.cgi?" diff --git a/torchgeo/datasets/wms.py b/torchgeo/datasets/wms.py index 483d9dd311e..5465dd4ddd4 100644 --- a/torchgeo/datasets/wms.py +++ b/torchgeo/datasets/wms.py @@ -2,19 +2,18 @@ # Licensed under the MIT License. # Author: Ian Turton, Glasgow University ian.turton@gla.ac.uk +from io import BytesIO from typing import Any +import torchvision.transforms as transforms from owslib.wms import WebMapService - +from PIL import Image from rasterio.coords import BoundingBox from rasterio.crs import CRS from rasterio.errors import CRSError +from rtree.index import Index, Property from torchgeo.datasets import GeoDataset -from io import BytesIO -from PIL import Image -import torchvision.transforms as transforms -from rtree.index import Index, Property class WMSDataset(GeoDataset): From 31fc997a44114b4371901ab0b622a98e11e6092e Mon Sep 17 00:00:00 2001 From: Ian Turton Date: Wed, 27 Mar 2024 16:53:04 +0000 Subject: [PATCH 06/12] QA fixes --- requirements/required.txt | 1 + tests/datasets/test_wms.py | 31 ++++++++++---------- torchgeo/datasets/wms.py | 59 +++++++++++++++++++++++++++++--------- 3 files changed, 62 insertions(+), 29 deletions(-) diff --git a/requirements/required.txt b/requirements/required.txt index afebd521697..030e02b97a9 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -9,6 +9,7 @@ lightly==1.5.2 lightning[pytorch-extra]==2.2.1 matplotlib==3.8.3 numpy==1.26.4 +owsLib==0.30.0 pandas==2.2.1 pillow==10.2.0 pyproj==3.6.1 diff --git a/tests/datasets/test_wms.py b/tests/datasets/test_wms.py index c581b65770d..a2cf174535d 100644 --- a/tests/datasets/test_wms.py +++ b/tests/datasets/test_wms.py @@ -9,25 +9,26 @@ SERVICE_URL = "https://mesonet.agron.iastate.edu/cgi-bin/wms/nexrad/n0r-t.cgi?" -class TestWMSDataset: +def service_ok(url: str, timeout: int = 5) -> bool: + try: + resp = requests.head(url, allow_redirects=True, timeout=timeout) + ok = bool(resp.ok) + except requests.exceptions.ReadTimeout: + ok = False + except requests.exceptions.ConnectTimeout: + ok = False + except Exception: + ok = False + return ok + - def service_ok(url, timeout=5): - try: - resp = requests.head(url, allow_redirects=True, timeout=timeout) - ok = resp.ok - except requests.exceptions.ReadTimeout: - ok = False - except requests.exceptions.ConnectTimeout: - ok = False - except Exception: - ok = False - return ok +class TestWMSDataset: @pytest.mark.online @pytest.mark.skipif( not service_ok(SERVICE_URL), reason="WMS service is unreachable" ) - def test_wms_no_layer(self): + def test_wms_no_layer(self) -> None: """MESONET GetMap 1.1.1""" wms = WMSDataset(SERVICE_URL, 10.0) assert "nexrad_base_reflect" in wms.layers() @@ -39,7 +40,7 @@ def test_wms_no_layer(self): assert 50 == wms.index.bounds[3] assert "image/png" == wms._format - def test_wms_layer(self): + def test_wms_layer(self) -> None: """MESONET GetMap 1.1.1""" wms = WMSDataset(SERVICE_URL, 10.0, layer="nexrad_base_reflect", crs=4326) assert 4326 == wms.crs.to_epsg() @@ -49,7 +50,7 @@ def test_wms_layer(self): assert 50 == wms.index.bounds[3] assert "image/png" == wms._format - def test_wms_layer_nocrs(self): + def test_wms_layer_nocrs(self) -> None: """MESONET GetMap 1.1.1""" wms = WMSDataset(SERVICE_URL, 10.0, layer="nexrad_base_reflect") assert 4326 == wms.crs.to_epsg() diff --git a/torchgeo/datasets/wms.py b/torchgeo/datasets/wms.py index 5465dd4ddd4..501e89e028c 100644 --- a/torchgeo/datasets/wms.py +++ b/torchgeo/datasets/wms.py @@ -2,10 +2,12 @@ # Licensed under the MIT License. # Author: Ian Turton, Glasgow University ian.turton@gla.ac.uk +"""A simple class to fetch WMS images.""" from io import BytesIO -from typing import Any +from typing import Any, Callable, Optional, Union import torchvision.transforms as transforms +from owslib.map.wms111 import ContentMetadata from owslib.wms import WebMapService from PIL import Image from rasterio.coords import BoundingBox @@ -17,19 +19,39 @@ class WMSDataset(GeoDataset): - """ - Allow models to fetch images from a WMS (at a good resolution) - """ + """Allow models to fetch images from a WMS (at a good resolution).""" - _url = None - _wms = None + _url: str = "" + _wms: WebMapService = None - _layers = [] - _layer = None - _layer_name = "" + _layers: list[ContentMetadata] = [] + _layer: ContentMetadata = None + _layer_name: str = "" is_image = True - def __init__(self, url, res, layer=None, transforms=None, crs=None): + def __init__( + self, + url: str, + res: float, + layer: Optional[str] = None, + transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, + crs: Optional[Union[str, int]] = None, + ) -> None: + """Initialize the data set instance. + + Args: + url: String pointing to the WMS Server (no need for + request=getcapabilities). + res: The best resolution to make requests at + (should probably match your other datasets). + layer: An optional string with the name of the + layer you want to fetch. + transforms: a function/transform that takes an input sample + and returns a transformed version + crs: An optional string or integer code that is a + valid EPSG code (without the 'EPSG:') + + """ super().__init__(transforms) self._url = url self._res = res @@ -39,10 +61,17 @@ def __init__(self, url, res, layer=None, transforms=None, crs=None): self._format = self._wms.getOperationByName("GetMap").formatOptions[0] self._layers = list(self._wms.contents) - if layer in self._layers: + if layer is not None and layer in self._layers: self.layer(layer, crs) - def layer(self, layer, crs=None): + def layer(self, layer: str, crs: Optional[Union[str, int]] = None) -> None: + """Set the layer to be fetched. + + Args: + layer: A string with the name of the layer you want to fetch. + crs: An optional string or integer code that is a valid EPSG + code (without the 'EPSG:') + """ self._layer = self._wms[layer] self._layer_name = layer coords = self._wms[layer].boundingBox @@ -71,10 +100,12 @@ def layer(self, layer, crs=None): except CRSError: pass - def getlayer(self): + def getlayer(self) -> ContentMetadata: + """Return the selected layer object.""" return self._layer - def layers(self): + def layers(self) -> list[str]: + """Return a list of availiable layers.""" return self._layers def __getitem__(self, query: BoundingBox) -> dict[str, Any]: From 7182c968e08f7ebc9d902365629bb877e13346f7 Mon Sep 17 00:00:00 2001 From: Ian Turton Date: Wed, 27 Mar 2024 17:06:17 +0000 Subject: [PATCH 07/12] add requests, might fix mypy issue? --- requirements/datasets.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements/datasets.txt b/requirements/datasets.txt index 211be7599ec..a8401e2b9e2 100644 --- a/requirements/datasets.txt +++ b/requirements/datasets.txt @@ -7,6 +7,8 @@ pycocotools==2.0.7 pyvista==0.43.4 radiant-mlhub==0.4.1 rarfile==4.1 +requests==2.31.0 +types-requests==2.31.0.20240311 scikit-image==0.22.0 scipy==1.12.0 zipfile-deflate64==0.2.0 From 2d9000cb8574031842f1b4b4ef4370ed8c4b7311 Mon Sep 17 00:00:00 2001 From: Ian Turton Date: Fri, 5 Apr 2024 17:10:55 +0100 Subject: [PATCH 08/12] remove requests --- tests/datasets/test_wms.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/datasets/test_wms.py b/tests/datasets/test_wms.py index a2cf174535d..3380e59f0a5 100644 --- a/tests/datasets/test_wms.py +++ b/tests/datasets/test_wms.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import pytest -import requests +import urllib3 from torchgeo.datasets import WMSDataset @@ -11,11 +11,10 @@ def service_ok(url: str, timeout: int = 5) -> bool: try: - resp = requests.head(url, allow_redirects=True, timeout=timeout) - ok = bool(resp.ok) - except requests.exceptions.ReadTimeout: - ok = False - except requests.exceptions.ConnectTimeout: + http = urllib3.PoolManager() + resp = http.request("HEAD", url, allow_redirects=True, timeout=timeout) + ok = 200 == resp.status + except urllib3.exceptions.NewConnectionError: ok = False except Exception: ok = False From 7752e95b7ad12f68d631908a29c7b7bde1fa694d Mon Sep 17 00:00:00 2001 From: Ian Turton Date: Fri, 5 Apr 2024 17:11:16 +0100 Subject: [PATCH 09/12] remove requests --- requirements/datasets.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/requirements/datasets.txt b/requirements/datasets.txt index a8401e2b9e2..211be7599ec 100644 --- a/requirements/datasets.txt +++ b/requirements/datasets.txt @@ -7,8 +7,6 @@ pycocotools==2.0.7 pyvista==0.43.4 radiant-mlhub==0.4.1 rarfile==4.1 -requests==2.31.0 -types-requests==2.31.0.20240311 scikit-image==0.22.0 scipy==1.12.0 zipfile-deflate64==0.2.0 From 95b49f163415b73109ec4d1f4a4f668020e35973 Mon Sep 17 00:00:00 2001 From: Ian Turton Date: Sat, 6 Apr 2024 12:24:26 +0100 Subject: [PATCH 10/12] remove online tag and check each test is online --- pyproject.toml | 1 - tests/datasets/test_wms.py | 12 ++++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9962c7e744a..350d9273bdc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -276,7 +276,6 @@ filterwarnings = [ ] markers = [ "slow: marks tests as slow", - "online: marks a test as needing to be online" ] norecursedirs = [ ".ipynb_checkpoints", diff --git a/tests/datasets/test_wms.py b/tests/datasets/test_wms.py index 3380e59f0a5..31c942fd78c 100644 --- a/tests/datasets/test_wms.py +++ b/tests/datasets/test_wms.py @@ -12,7 +12,7 @@ def service_ok(url: str, timeout: int = 5) -> bool: try: http = urllib3.PoolManager() - resp = http.request("HEAD", url, allow_redirects=True, timeout=timeout) + resp = http.request("HEAD", url, timeout=timeout) ok = 200 == resp.status except urllib3.exceptions.NewConnectionError: ok = False @@ -23,7 +23,7 @@ def service_ok(url: str, timeout: int = 5) -> bool: class TestWMSDataset: - @pytest.mark.online + @pytest.mark.slow @pytest.mark.skipif( not service_ok(SERVICE_URL), reason="WMS service is unreachable" ) @@ -39,6 +39,10 @@ def test_wms_no_layer(self) -> None: assert 50 == wms.index.bounds[3] assert "image/png" == wms._format + @pytest.mark.slow + @pytest.mark.skipif( + not service_ok(SERVICE_URL), reason="WMS service is unreachable" + ) def test_wms_layer(self) -> None: """MESONET GetMap 1.1.1""" wms = WMSDataset(SERVICE_URL, 10.0, layer="nexrad_base_reflect", crs=4326) @@ -49,6 +53,10 @@ def test_wms_layer(self) -> None: assert 50 == wms.index.bounds[3] assert "image/png" == wms._format + @pytest.mark.slow + @pytest.mark.skipif( + not service_ok(SERVICE_URL), reason="WMS service is unreachable" + ) def test_wms_layer_nocrs(self) -> None: """MESONET GetMap 1.1.1""" wms = WMSDataset(SERVICE_URL, 10.0, layer="nexrad_base_reflect") From a8acd0abccb95aecde4a7c1ba7d5bf059d242c06 Mon Sep 17 00:00:00 2001 From: Ian Turton Date: Mon, 8 Apr 2024 11:53:52 +0100 Subject: [PATCH 11/12] update to 3.10 syntax --- .pre-commit-config.yaml | 2 +- torchgeo/datasets/wms.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7e825fcce91..8aa299712bc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,6 +33,6 @@ repos: rev: v1.8.0 hooks: - id: mypy - args: [--strict, --ignore-missing-imports, --show-error-codes] + args: [--python-version=3.10, --strict, --ignore-missing-imports, --show-error-codes] additional_dependencies: [einops>=0.6.0, kornia>=0.6.9, lightning>=2.0.9, matplotlib>=3.8.1, numpy>=1.22, pytest>=6.1.2, pyvista>=0.34.2, scikit-image>=0.18.0, torch>=2.2, torchmetrics>=0.10] exclude: (build|data|dist|logo|logs|output)/ diff --git a/torchgeo/datasets/wms.py b/torchgeo/datasets/wms.py index 501e89e028c..391ce7fb4bf 100644 --- a/torchgeo/datasets/wms.py +++ b/torchgeo/datasets/wms.py @@ -3,8 +3,9 @@ # Author: Ian Turton, Glasgow University ian.turton@gla.ac.uk """A simple class to fetch WMS images.""" +from collections.abc import Callable from io import BytesIO -from typing import Any, Callable, Optional, Union +from typing import Any import torchvision.transforms as transforms from owslib.map.wms111 import ContentMetadata @@ -33,9 +34,9 @@ def __init__( self, url: str, res: float, - layer: Optional[str] = None, - transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, - crs: Optional[Union[str, int]] = None, + layer: str | None = None, + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + crs: str | int | None = None, ) -> None: """Initialize the data set instance. @@ -64,7 +65,7 @@ def __init__( if layer is not None and layer in self._layers: self.layer(layer, crs) - def layer(self, layer: str, crs: Optional[Union[str, int]] = None) -> None: + def layer(self, layer: str, crs: str | int | None = None) -> None: """Set the layer to be fetched. Args: From 2bf3c2a416cb72c2c0b3071dcf3dbaef277fc643 Mon Sep 17 00:00:00 2001 From: Ian Turton Date: Tue, 9 Apr 2024 11:18:08 +0100 Subject: [PATCH 12/12] make OWSlib more optional --- pyproject.toml | 2 ++ requirements/min-reqs.old | 1 + requirements/required.txt | 1 - torchgeo/datasets/wms.py | 15 ++++++++++++--- 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 350d9273bdc..547603eaad0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,8 @@ datasets = [ "laspy>=2", # opencv-python 4.4.0.46+ required for Python 3.9 wheels "opencv-python>=4.4.0.46", + # OWSLib is required for WMS Dataset + "OWSLib>=0.30.0", # pycocotools 2.0.5+ required for cython 3+ support "pycocotools>=2.0.5", # pyvista 0.34.2+ required to avoid ImportError in CI diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index 34b22649ceb..646525c1c20 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -25,6 +25,7 @@ torchvision==0.14.0 h5py==3.0.0 laspy==2.0.0 opencv-python==4.4.0.46 +OWSLib==0.21.0 pycocotools==2.0.5 pyvista==0.34.2 radiant-mlhub==0.3.0 diff --git a/requirements/required.txt b/requirements/required.txt index 030e02b97a9..afebd521697 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -9,7 +9,6 @@ lightly==1.5.2 lightning[pytorch-extra]==2.2.1 matplotlib==3.8.3 numpy==1.26.4 -owsLib==0.30.0 pandas==2.2.1 pillow==10.2.0 pyproj==3.6.1 diff --git a/torchgeo/datasets/wms.py b/torchgeo/datasets/wms.py index 391ce7fb4bf..1ce46752ed3 100644 --- a/torchgeo/datasets/wms.py +++ b/torchgeo/datasets/wms.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -# Author: Ian Turton, Glasgow University ian.turton@gla.ac.uk """A simple class to fetch WMS images.""" from collections.abc import Callable @@ -8,8 +7,6 @@ from typing import Any import torchvision.transforms as transforms -from owslib.map.wms111 import ContentMetadata -from owslib.wms import WebMapService from PIL import Image from rasterio.coords import BoundingBox from rasterio.crs import CRS @@ -22,6 +19,11 @@ class WMSDataset(GeoDataset): """Allow models to fetch images from a WMS (at a good resolution).""" + try: + from owslib.map.wms111 import ContentMetadata + from owslib.wms import WebMapService + except ImportError: + raise ImportError("OWSLib is not installed and is required to use this dataset") _url: str = "" _wms: WebMapService = None @@ -53,9 +55,16 @@ def __init__( valid EPSG code (without the 'EPSG:') """ + try: + from owslib.wms import WebMapService + except ImportError: + raise ImportError( + "OWSLib is not installed and is required to use this dataset" + ) super().__init__(transforms) self._url = url self._res = res + if crs is not None: self._crs = CRS.from_epsg(crs) self._wms = WebMapService(url)