Skip to content

Commit

Permalink
Pre-commit fixes, /singleband stretch, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
atanas-balevsky committed Dec 22, 2023
1 parent 1775496 commit 7c4c20e
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 69 deletions.
26 changes: 5 additions & 21 deletions terracotta/handlers/rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from terracotta import get_settings, get_driver, image, xyz, exceptions
from terracotta.profile import trace

Number = TypeVar("Number", int, float)
NumberOrString = TypeVar("NumberOrString", int, float, str)
ListOfRanges = Sequence[Optional[Tuple[Optional[NumberOrString], Optional[NumberOrString]]]]
ListOfRanges = Sequence[
Optional[Tuple[Optional[image.NumberOrString], Optional[image.NumberOrString]]]
]


@trace("rgb_handler")
Expand Down Expand Up @@ -91,10 +92,10 @@ def get_band_future(band_key: str) -> Future:
scale_min, scale_max = band_stretch_override

if scale_min is not None:
band_stretch_range[0] = get_scale(scale_min, metadata)
band_stretch_range[0] = image.get_stretch_scale(scale_min, metadata)

if scale_max is not None:
band_stretch_range[1] = get_scale(scale_max, metadata)
band_stretch_range[1] = image.get_stretch_scale(scale_max, metadata)

if band_stretch_range[1] < band_stretch_range[0]:
raise exceptions.InvalidArgumentsError(
Expand All @@ -106,20 +107,3 @@ def get_band_future(band_key: str) -> Future:

out = np.ma.stack(out_arrays, axis=-1)
return image.array_to_png(out)


def get_scale(scale: NumberOrString, metadata) -> Number:
if isinstance(scale, (int, float)):
return scale
if isinstance(scale, str):
# can be a percentile
if scale.startswith("p"):
# TODO check if percentile is in range
percentile = int(scale[1:]) - 1
return metadata["percentiles"][percentile]

# can be a number
return float(scale)
raise exceptions.InvalidArgumentsError(
"Invalid scale value: %s" % scale
)
10 changes: 7 additions & 3 deletions terracotta/handlers/singleband.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
from terracotta.profile import trace

Number = TypeVar("Number", int, float)
NumberOrString = TypeVar("NumberOrString", int, float, str)
ListOfRanges = Sequence[
Optional[Tuple[Optional[image.NumberOrString], Optional[image.NumberOrString]]]
]
RGBA = Tuple[Number, Number, Number, Number]


Expand All @@ -21,7 +25,7 @@ def singleband(
tile_xyz: Optional[Tuple[int, int, int]] = None,
*,
colormap: Union[str, Mapping[Number, RGBA], None] = None,
stretch_range: Optional[Tuple[Number, Number]] = None,
stretch_range: Optional[Tuple[NumberOrString, NumberOrString]] = None,
tile_size: Optional[Tuple[int, int]] = None
) -> BinaryIO:
"""Return singleband image as PNG"""
Expand Down Expand Up @@ -60,10 +64,10 @@ def singleband(
stretch_range_ = list(metadata["range"])

if stretch_min is not None:
stretch_range_[0] = stretch_min
stretch_range_[0] = image.get_stretch_scale(stretch_min, metadata)

if stretch_max is not None:
stretch_range_[1] = stretch_max
stretch_range_[1] = image.get_stretch_scale(stretch_max, metadata)

cmap_or_palette = cast(Optional[str], colormap)
out = image.to_uint8(tile_data, *stretch_range_)
Expand Down
26 changes: 25 additions & 1 deletion terracotta/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Utilities to create and manipulate images.
"""

from typing import Sequence, Tuple, TypeVar, Union
from typing import Any, Dict, Sequence, Tuple, TypeVar, Union
from typing.io import BinaryIO

from io import BytesIO
Expand All @@ -15,6 +15,7 @@
from terracotta import exceptions, get_settings

Number = TypeVar("Number", int, float)
NumberOrString = TypeVar("NumberOrString", int, float, str)
RGBA = Tuple[Number, Number, Number, Number]
Palette = Sequence[RGBA]
Array = Union[np.ndarray, np.ma.MaskedArray]
Expand Down Expand Up @@ -179,3 +180,26 @@ def label(data: Array, labels: Sequence[Number]) -> Array:
out_data[data == label] = i

return out_data


def get_stretch_scale(scale: NumberOrString, metadata: Dict[str, Any]) -> int | float:
if isinstance(scale, (int, float)):
return scale
if isinstance(scale, str):
# can be a percentile
if scale.startswith("p"):
try:
percentile = int(scale[1:]) - 1
except ValueError:
raise exceptions.InvalidArgumentsError(
f"Invalid percentile value: {scale}"
)

if 0 <= percentile < len(metadata["percentiles"]):
return metadata["percentiles"][percentile]

raise exceptions.InvalidArgumentsError(
f"Invalid percentile, out of range: {scale}"
)

raise exceptions.InvalidArgumentsError(f"Invalid scale value: {scale}")
37 changes: 37 additions & 0 deletions terracotta/server/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import re
from marshmallow import ValidationError, fields

from typing import Any


class StringOrNumber(fields.Field):
def _serialize(
self, value: Any, attr: Any, obj: Any, **kwargs: Any
) -> str | float | None:
if isinstance(value, (str, bytes)):
return fields.String()._serialize(value, attr, obj, **kwargs)
elif isinstance(value, (int, float)):
return fields.Float()._serialize(value, attr, obj, **kwargs)
else:
raise ValidationError("Must be a string or a number")

def _deserialize(
self, value: Any, attr: Any, data: Any, **kwargs: Any
) -> str | float | None:
if isinstance(value, (str, bytes)):
return fields.String()._deserialize(value, attr, data, **kwargs)
elif isinstance(value, (int, float)):
return fields.Float()._deserialize(value, attr, data, **kwargs)
else:
raise ValidationError("Must be a string or a number")


def validate_stretch_range(data: Any) -> None:
if isinstance(data, str) and data.startswith("p"):
if not re.match("^p\\d+$", data):
raise ValidationError("Percentile format is `p<digits>`")
else:
try:
float(data)
except ValueError:
raise ValidationError("Must be a number")
38 changes: 4 additions & 34 deletions terracotta/server/rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

from typing import Optional, Any, Mapping, Dict, Tuple
import json
import re

from marshmallow import Schema, fields, validate, pre_load, ValidationError, EXCLUDE
from flask import request, send_file, Response

from terracotta.server.fields import StringOrNumber, validate_stretch_range
from terracotta.server.flask_api import TILE_API


Expand All @@ -22,36 +22,6 @@ class RGBQuerySchema(Schema):
tile_x = fields.Int(required=True, description="x coordinate")


def validate_range(data):
if isinstance(data, str) and data.startswith("p"):
if not re.match("^p\d+$", data):
raise ValidationError("Percentile format is `p<digits>`")
else:
try:
float(data)
except ValueError:
raise ValidationError("Must be a number")



class StringOrNumber(fields.Field):
def _serialize(self, value, attr, obj, **kwargs):
if isinstance(value, (str, bytes)):
return fields.String()._serialize(value, attr, obj, **kwargs)
elif isinstance(value, (int, float)):
return fields.Float()._serialize(value, attr, obj, **kwargs)
else:
raise ValidationError("Must be a string or a number")

def _deserialize(self, value, attr, data, **kwargs):
if isinstance(value, (str, bytes)):
return fields.String()._deserialize(value, attr, data, **kwargs)
elif isinstance(value, (int, float)):
return fields.Float()._deserialize(value, attr, data, **kwargs)
else:
raise ValidationError("Must be a string or a number")


class RGBOptionSchema(Schema):
class Meta:
unknown = EXCLUDE
Expand All @@ -60,7 +30,7 @@ class Meta:
g = fields.String(required=True, description="Key value for green band")
b = fields.String(required=True, description="Key value for blue band")
r_range = fields.List(
StringOrNumber(allow_none=True, validate=validate_range),
StringOrNumber(allow_none=True, validate=validate_stretch_range),
validate=validate.Length(equal=2),
example="[0,1]",
missing=None,
Expand All @@ -70,7 +40,7 @@ class Meta:
),
)
g_range = fields.List(
StringOrNumber(allow_none=True, validate=validate_range),
StringOrNumber(allow_none=True, validate=validate_stretch_range),
validate=validate.Length(equal=2),
example="[0,1]",
missing=None,
Expand All @@ -80,7 +50,7 @@ class Meta:
),
)
b_range = fields.List(
StringOrNumber(allow_none=True, validate=validate_range),
StringOrNumber(allow_none=True, validate=validate_stretch_range),
validate=validate.Length(equal=2),
example="[0,1]",
missing=None,
Expand Down
3 changes: 2 additions & 1 deletion terracotta/server/singleband.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from flask import request, send_file, Response

from terracotta.server.fields import StringOrNumber, validate_stretch_range
from terracotta.server.flask_api import TILE_API
from terracotta.cmaps import AVAILABLE_CMAPS

Expand All @@ -35,7 +36,7 @@ class Meta:
unknown = EXCLUDE

stretch_range = fields.List(
fields.Number(allow_none=True),
StringOrNumber(allow_none=True, validate=validate_stretch_range),
validate=validate.Length(equal=2),
example="[0,1]",
description="Stretch range to use as JSON array, uses full range by default. "
Expand Down
17 changes: 11 additions & 6 deletions tests/handlers/test_rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,13 @@ def test_rgb_lowzoom(use_testdb, raster_file, raster_file_xyz_lowzoom):


@pytest.mark.parametrize(
"stretch_range", [
[0, 20000], [10000, 20000], [-50000, 50000], [100, 100],
["0", "20000"], ["10000", "20000"], ["-50000", "50000"], ["100", "100"],
]
"stretch_range",
[
[0, 20000],
[10000, 20000],
[-50000, 50000],
[100, 100],
],
)
def test_rgb_stretch(stretch_range, use_testdb, testdb, raster_file_xyz):
import terracotta
Expand Down Expand Up @@ -109,7 +112,6 @@ def test_rgb_stretch(stretch_range, use_testdb, testdb, raster_file_xyz):
valid_img = img_data[valid_mask]
valid_data = tile_data.compressed()

stretch_range = [float(stretch_range[0]), float(stretch_range[1])]
assert np.all(valid_img[valid_data < stretch_range[0]] == 1)
stretch_range_mask = (valid_data > stretch_range[0]) & (
valid_data < stretch_range[1]
Expand Down Expand Up @@ -161,7 +163,10 @@ def test_rgb_percentile_stretch(use_testdb, testdb, raster_file_xyz):
)
band_metadata = driver.get_metadata(ds_keys)

stretch_range = [band_metadata["percentiles"][1], band_metadata["percentiles"][97]]
stretch_range = [
band_metadata["percentiles"][1],
band_metadata["percentiles"][97],
]

# filter transparent values
valid_mask = ~tile_data.mask
Expand Down
84 changes: 84 additions & 0 deletions tests/handlers/test_singleband.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,87 @@ def test_singleband_noxyz(use_testdb):
img_data = np.asarray(Image.open(raw_img))

assert img_data.shape == settings.DEFAULT_TILE_SIZE


def test_singleband_stretch(use_testdb, testdb, raster_file_xyz):
import terracotta
from terracotta.xyz import get_tile_data
from terracotta.handlers import singleband

ds_keys = ["val21", "x", "val22"]
stretch_range = [0, 10000]

raw_img = singleband.singleband(
ds_keys,
tile_xyz=raster_file_xyz,
stretch_range=stretch_range,
)
img_data = np.asarray(Image.open(raw_img))

# get unstretched data to compare to
driver = terracotta.get_driver(testdb)

with driver.connect():
tile_data = get_tile_data(
driver, ds_keys, tile_xyz=raster_file_xyz, tile_size=img_data.shape
)

# filter transparent values
valid_mask = ~tile_data.mask
assert np.all(img_data[~valid_mask] == 0)

valid_img = img_data[valid_mask]
valid_data = tile_data.compressed()

assert np.all(valid_img[valid_data < stretch_range[0]] == 1)
stretch_range_mask = (valid_data > stretch_range[0]) & (
valid_data < stretch_range[1]
)
assert np.all(valid_img[stretch_range_mask] >= 1)
assert np.all(valid_img[stretch_range_mask] <= 255)
assert np.all(valid_img[valid_data > stretch_range[1]] == 255)


def test_singleband_stretch_percentile(use_testdb, testdb, raster_file_xyz):
import terracotta
from terracotta.xyz import get_tile_data
from terracotta.handlers import singleband

ds_keys = ["val21", "x", "val22"]
pct_stretch_range = ["p2", "p98"]

raw_img = singleband.singleband(
ds_keys,
tile_xyz=raster_file_xyz,
stretch_range=pct_stretch_range,
)
img_data = np.asarray(Image.open(raw_img))

# get unstretched data to compare to
driver = terracotta.get_driver(testdb)

with driver.connect():
tile_data = get_tile_data(
driver, ds_keys, tile_xyz=raster_file_xyz, tile_size=img_data.shape
)
band_metadata = driver.get_metadata(ds_keys)

stretch_range = [
band_metadata["percentiles"][1],
band_metadata["percentiles"][97],
]

# filter transparent values
valid_mask = ~tile_data.mask
assert np.all(img_data[~valid_mask] == 0)

valid_img = img_data[valid_mask]
valid_data = tile_data.compressed()

assert np.all(valid_img[valid_data < stretch_range[0]] == 1)
stretch_range_mask = (valid_data > stretch_range[0]) & (
valid_data < stretch_range[1]
)
assert np.all(valid_img[stretch_range_mask] >= 1)
assert np.all(valid_img[stretch_range_mask] <= 255)
assert np.all(valid_img[valid_data > stretch_range[1]] == 255)
Loading

0 comments on commit 7c4c20e

Please sign in to comment.