diff --git a/terracotta/handlers/rgb.py b/terracotta/handlers/rgb.py index b7d41471..e001024f 100644 --- a/terracotta/handlers/rgb.py +++ b/terracotta/handlers/rgb.py @@ -10,9 +10,10 @@ from terracotta import get_settings, get_driver, image, xyz, exceptions from terracotta.profile import trace -Number = TypeVar("Number", int, float) NumberOrString = TypeVar("NumberOrString", int, float, str) -ListOfRanges = Sequence[Optional[Tuple[Optional[NumberOrString], Optional[NumberOrString]]]] +ListOfRanges = Sequence[ + Optional[Tuple[Optional[image.NumberOrString], Optional[image.NumberOrString]]] +] @trace("rgb_handler") @@ -91,10 +92,10 @@ def get_band_future(band_key: str) -> Future: scale_min, scale_max = band_stretch_override if scale_min is not None: - band_stretch_range[0] = get_scale(scale_min, metadata) + band_stretch_range[0] = image.get_stretch_scale(scale_min, metadata) if scale_max is not None: - band_stretch_range[1] = get_scale(scale_max, metadata) + band_stretch_range[1] = image.get_stretch_scale(scale_max, metadata) if band_stretch_range[1] < band_stretch_range[0]: raise exceptions.InvalidArgumentsError( @@ -106,20 +107,3 @@ def get_band_future(band_key: str) -> Future: out = np.ma.stack(out_arrays, axis=-1) return image.array_to_png(out) - - -def get_scale(scale: NumberOrString, metadata) -> Number: - if isinstance(scale, (int, float)): - return scale - if isinstance(scale, str): - # can be a percentile - if scale.startswith("p"): - # TODO check if percentile is in range - percentile = int(scale[1:]) - 1 - return metadata["percentiles"][percentile] - - # can be a number - return float(scale) - raise exceptions.InvalidArgumentsError( - "Invalid scale value: %s" % scale - ) diff --git a/terracotta/handlers/singleband.py b/terracotta/handlers/singleband.py index c153dea2..780b1eb4 100644 --- a/terracotta/handlers/singleband.py +++ b/terracotta/handlers/singleband.py @@ -12,6 +12,10 @@ from terracotta.profile import trace Number = TypeVar("Number", int, float) +NumberOrString = TypeVar("NumberOrString", int, float, str) +ListOfRanges = Sequence[ + Optional[Tuple[Optional[image.NumberOrString], Optional[image.NumberOrString]]] +] RGBA = Tuple[Number, Number, Number, Number] @@ -21,7 +25,7 @@ def singleband( tile_xyz: Optional[Tuple[int, int, int]] = None, *, colormap: Union[str, Mapping[Number, RGBA], None] = None, - stretch_range: Optional[Tuple[Number, Number]] = None, + stretch_range: Optional[Tuple[NumberOrString, NumberOrString]] = None, tile_size: Optional[Tuple[int, int]] = None ) -> BinaryIO: """Return singleband image as PNG""" @@ -60,10 +64,10 @@ def singleband( stretch_range_ = list(metadata["range"]) if stretch_min is not None: - stretch_range_[0] = stretch_min + stretch_range_[0] = image.get_stretch_scale(stretch_min, metadata) if stretch_max is not None: - stretch_range_[1] = stretch_max + stretch_range_[1] = image.get_stretch_scale(stretch_max, metadata) cmap_or_palette = cast(Optional[str], colormap) out = image.to_uint8(tile_data, *stretch_range_) diff --git a/terracotta/image.py b/terracotta/image.py index 229f2c2a..a59ad194 100755 --- a/terracotta/image.py +++ b/terracotta/image.py @@ -3,7 +3,7 @@ Utilities to create and manipulate images. """ -from typing import Sequence, Tuple, TypeVar, Union +from typing import Any, Dict, Sequence, Tuple, TypeVar, Union from typing.io import BinaryIO from io import BytesIO @@ -15,6 +15,7 @@ from terracotta import exceptions, get_settings Number = TypeVar("Number", int, float) +NumberOrString = TypeVar("NumberOrString", int, float, str) RGBA = Tuple[Number, Number, Number, Number] Palette = Sequence[RGBA] Array = Union[np.ndarray, np.ma.MaskedArray] @@ -179,3 +180,26 @@ def label(data: Array, labels: Sequence[Number]) -> Array: out_data[data == label] = i return out_data + + +def get_stretch_scale(scale: NumberOrString, metadata: Dict[str, Any]) -> int | float: + if isinstance(scale, (int, float)): + return scale + if isinstance(scale, str): + # can be a percentile + if scale.startswith("p"): + try: + percentile = int(scale[1:]) - 1 + except ValueError: + raise exceptions.InvalidArgumentsError( + "Invalid percentile value: %s" % scale + ) + + if 0 <= percentile < len(metadata["percentiles"]): + return metadata["percentiles"][percentile] + + raise exceptions.InvalidArgumentsError( + "Invalid percentile, out of range: %s" % scale + ) + + raise exceptions.InvalidArgumentsError("Invalid scale value: %s" % scale) diff --git a/terracotta/server/fields.py b/terracotta/server/fields.py new file mode 100644 index 00000000..defbd08b --- /dev/null +++ b/terracotta/server/fields.py @@ -0,0 +1,37 @@ +import re +from marshmallow import ValidationError, fields + +from typing import Any + + +class StringOrNumber(fields.Field): + def _serialize( + self, value: Any, attr: Any, obj: Any, **kwargs: Any + ) -> str | float | None: + if isinstance(value, (str, bytes)): + return fields.String()._serialize(value, attr, obj, **kwargs) + elif isinstance(value, (int, float)): + return fields.Float()._serialize(value, attr, obj, **kwargs) + else: + raise ValidationError("Must be a string or a number") + + def _deserialize( + self, value: Any, attr: Any, data: Any, **kwargs: Any + ) -> str | float | None: + if isinstance(value, (str, bytes)): + return fields.String()._deserialize(value, attr, data, **kwargs) + elif isinstance(value, (int, float)): + return fields.Float()._deserialize(value, attr, data, **kwargs) + else: + raise ValidationError("Must be a string or a number") + + +def validate_stretch_range(data: Any) -> None: + if isinstance(data, str) and data.startswith("p"): + if not re.match("^p\\d+$", data): + raise ValidationError("Percentile format is `p`") + else: + try: + float(data) + except ValueError: + raise ValidationError("Must be a number") diff --git a/terracotta/server/rgb.py b/terracotta/server/rgb.py index 5e0cbaff..72fcaf1c 100644 --- a/terracotta/server/rgb.py +++ b/terracotta/server/rgb.py @@ -5,11 +5,11 @@ from typing import Optional, Any, Mapping, Dict, Tuple import json -import re from marshmallow import Schema, fields, validate, pre_load, ValidationError, EXCLUDE from flask import request, send_file, Response +from terracotta.server.fields import StringOrNumber, validate_stretch_range from terracotta.server.flask_api import TILE_API @@ -22,36 +22,6 @@ class RGBQuerySchema(Schema): tile_x = fields.Int(required=True, description="x coordinate") -def validate_range(data): - if isinstance(data, str) and data.startswith("p"): - if not re.match("^p\d+$", data): - raise ValidationError("Percentile format is `p`") - else: - try: - float(data) - except ValueError: - raise ValidationError("Must be a number") - - - -class StringOrNumber(fields.Field): - def _serialize(self, value, attr, obj, **kwargs): - if isinstance(value, (str, bytes)): - return fields.String()._serialize(value, attr, obj, **kwargs) - elif isinstance(value, (int, float)): - return fields.Float()._serialize(value, attr, obj, **kwargs) - else: - raise ValidationError("Must be a string or a number") - - def _deserialize(self, value, attr, data, **kwargs): - if isinstance(value, (str, bytes)): - return fields.String()._deserialize(value, attr, data, **kwargs) - elif isinstance(value, (int, float)): - return fields.Float()._deserialize(value, attr, data, **kwargs) - else: - raise ValidationError("Must be a string or a number") - - class RGBOptionSchema(Schema): class Meta: unknown = EXCLUDE @@ -60,7 +30,7 @@ class Meta: g = fields.String(required=True, description="Key value for green band") b = fields.String(required=True, description="Key value for blue band") r_range = fields.List( - StringOrNumber(allow_none=True, validate=validate_range), + StringOrNumber(allow_none=True, validate=validate_stretch_range), validate=validate.Length(equal=2), example="[0,1]", missing=None, @@ -70,7 +40,7 @@ class Meta: ), ) g_range = fields.List( - StringOrNumber(allow_none=True, validate=validate_range), + StringOrNumber(allow_none=True, validate=validate_stretch_range), validate=validate.Length(equal=2), example="[0,1]", missing=None, @@ -80,7 +50,7 @@ class Meta: ), ) b_range = fields.List( - StringOrNumber(allow_none=True, validate=validate_range), + StringOrNumber(allow_none=True, validate=validate_stretch_range), validate=validate.Length(equal=2), example="[0,1]", missing=None, diff --git a/terracotta/server/singleband.py b/terracotta/server/singleband.py index 1aeb5670..b0295bba 100644 --- a/terracotta/server/singleband.py +++ b/terracotta/server/singleband.py @@ -17,6 +17,7 @@ ) from flask import request, send_file, Response +from terracotta.server.fields import StringOrNumber, validate_stretch_range from terracotta.server.flask_api import TILE_API from terracotta.cmaps import AVAILABLE_CMAPS @@ -35,7 +36,7 @@ class Meta: unknown = EXCLUDE stretch_range = fields.List( - fields.Number(allow_none=True), + StringOrNumber(allow_none=True, validate=validate_stretch_range), validate=validate.Length(equal=2), example="[0,1]", description="Stretch range to use as JSON array, uses full range by default. " diff --git a/tests/handlers/test_rgb.py b/tests/handlers/test_rgb.py index 4a484af0..07a59eee 100644 --- a/tests/handlers/test_rgb.py +++ b/tests/handlers/test_rgb.py @@ -74,10 +74,13 @@ def test_rgb_lowzoom(use_testdb, raster_file, raster_file_xyz_lowzoom): @pytest.mark.parametrize( - "stretch_range", [ - [0, 20000], [10000, 20000], [-50000, 50000], [100, 100], - ["0", "20000"], ["10000", "20000"], ["-50000", "50000"], ["100", "100"], - ] + "stretch_range", + [ + [0, 20000], + [10000, 20000], + [-50000, 50000], + [100, 100], + ], ) def test_rgb_stretch(stretch_range, use_testdb, testdb, raster_file_xyz): import terracotta @@ -109,7 +112,6 @@ def test_rgb_stretch(stretch_range, use_testdb, testdb, raster_file_xyz): valid_img = img_data[valid_mask] valid_data = tile_data.compressed() - stretch_range = [float(stretch_range[0]), float(stretch_range[1])] assert np.all(valid_img[valid_data < stretch_range[0]] == 1) stretch_range_mask = (valid_data > stretch_range[0]) & ( valid_data < stretch_range[1] @@ -161,7 +163,10 @@ def test_rgb_percentile_stretch(use_testdb, testdb, raster_file_xyz): ) band_metadata = driver.get_metadata(ds_keys) - stretch_range = [band_metadata["percentiles"][1], band_metadata["percentiles"][97]] + stretch_range = [ + band_metadata["percentiles"][1], + band_metadata["percentiles"][97], + ] # filter transparent values valid_mask = ~tile_data.mask diff --git a/tests/handlers/test_singleband.py b/tests/handlers/test_singleband.py index a6448000..09dbf51f 100644 --- a/tests/handlers/test_singleband.py +++ b/tests/handlers/test_singleband.py @@ -112,3 +112,87 @@ def test_singleband_noxyz(use_testdb): img_data = np.asarray(Image.open(raw_img)) assert img_data.shape == settings.DEFAULT_TILE_SIZE + + +def test_singleband_stretch(use_testdb, testdb, raster_file_xyz): + import terracotta + from terracotta.xyz import get_tile_data + from terracotta.handlers import singleband + + ds_keys = ["val21", "x", "val22"] + stretch_range = [0, 10000] + + raw_img = singleband.singleband( + ds_keys, + tile_xyz=raster_file_xyz, + stretch_range=stretch_range, + ) + img_data = np.asarray(Image.open(raw_img)) + + # get unstretched data to compare to + driver = terracotta.get_driver(testdb) + + with driver.connect(): + tile_data = get_tile_data( + driver, ds_keys, tile_xyz=raster_file_xyz, tile_size=img_data.shape + ) + + # filter transparent values + valid_mask = ~tile_data.mask + assert np.all(img_data[~valid_mask] == 0) + + valid_img = img_data[valid_mask] + valid_data = tile_data.compressed() + + assert np.all(valid_img[valid_data < stretch_range[0]] == 1) + stretch_range_mask = (valid_data > stretch_range[0]) & ( + valid_data < stretch_range[1] + ) + assert np.all(valid_img[stretch_range_mask] >= 1) + assert np.all(valid_img[stretch_range_mask] <= 255) + assert np.all(valid_img[valid_data > stretch_range[1]] == 255) + + +def test_singleband_stretch_percentile(use_testdb, testdb, raster_file_xyz): + import terracotta + from terracotta.xyz import get_tile_data + from terracotta.handlers import singleband + + ds_keys = ["val21", "x", "val22"] + pct_stretch_range = ["p2", "p98"] + + raw_img = singleband.singleband( + ds_keys, + tile_xyz=raster_file_xyz, + stretch_range=pct_stretch_range, + ) + img_data = np.asarray(Image.open(raw_img)) + + # get unstretched data to compare to + driver = terracotta.get_driver(testdb) + + with driver.connect(): + tile_data = get_tile_data( + driver, ds_keys, tile_xyz=raster_file_xyz, tile_size=img_data.shape + ) + band_metadata = driver.get_metadata(ds_keys) + + stretch_range = [ + band_metadata["percentiles"][1], + band_metadata["percentiles"][97], + ] + + # filter transparent values + valid_mask = ~tile_data.mask + assert np.all(img_data[~valid_mask] == 0) + + valid_img = img_data[valid_mask] + valid_data = tile_data.compressed() + + assert np.all(valid_img[valid_data < stretch_range[0]] == 1) + stretch_range_mask = (valid_data > stretch_range[0]) & ( + valid_data < stretch_range[1] + ) + assert np.all(valid_img[stretch_range_mask] >= 1) + assert np.all(valid_img[stretch_range_mask] <= 255) + assert np.all(valid_img[valid_data > stretch_range[1]] == 255) diff --git a/tests/server/test_flask_api.py b/tests/server/test_flask_api.py index ac884ea4..e153c2f0 100644 --- a/tests/server/test_flask_api.py +++ b/tests/server/test_flask_api.py @@ -331,7 +331,14 @@ def test_get_singleband_stretch(client, use_testdb, raster_file_xyz): x, y, z = raster_file_xyz - for stretch_range in ("[0,1]", "[0,null]", "[null, 1]", "[null,null]", "null"): + for stretch_range in ( + "[0,1]", + "[0,null]", + "[null, 1]", + "[null,null]", + "null", + '["p2","p98"]', + ): rv = client.get( f"/singleband/val11/x/val12/{z}/{x}/{y}.png?stretch_range={stretch_range}" ) @@ -410,8 +417,8 @@ def test_get_rgb_stretch(client, use_testdb, raster_file_xyz): for stretch_range in ( "[0,10000]", - "[\"1.0e%2B01\",\"1.0e%2B04\"]", - "[\"p2\",\"p98\"]", + "[1.0e%2B01,1.0e%2B04]", + '["p2","p98"]', "[0,null]", "[null, 10000]", "[null,null]",