diff --git a/terracotta/server/rgb.py b/terracotta/server/rgb.py index a7cf1bc5..5e0cbaff 100644 --- a/terracotta/server/rgb.py +++ b/terracotta/server/rgb.py @@ -5,6 +5,7 @@ from typing import Optional, Any, Mapping, Dict, Tuple import json +import re from marshmallow import Schema, fields, validate, pre_load, ValidationError, EXCLUDE from flask import request, send_file, Response @@ -21,6 +22,36 @@ class RGBQuerySchema(Schema): tile_x = fields.Int(required=True, description="x coordinate") +def validate_range(data): + if isinstance(data, str) and data.startswith("p"): + if not re.match("^p\d+$", data): + raise ValidationError("Percentile format is `p`") + else: + try: + float(data) + except ValueError: + raise ValidationError("Must be a number") + + + +class StringOrNumber(fields.Field): + def _serialize(self, value, attr, obj, **kwargs): + if isinstance(value, (str, bytes)): + return fields.String()._serialize(value, attr, obj, **kwargs) + elif isinstance(value, (int, float)): + return fields.Float()._serialize(value, attr, obj, **kwargs) + else: + raise ValidationError("Must be a string or a number") + + def _deserialize(self, value, attr, data, **kwargs): + if isinstance(value, (str, bytes)): + return fields.String()._deserialize(value, attr, data, **kwargs) + elif isinstance(value, (int, float)): + return fields.Float()._deserialize(value, attr, data, **kwargs) + else: + raise ValidationError("Must be a string or a number") + + class RGBOptionSchema(Schema): class Meta: unknown = EXCLUDE @@ -29,7 +60,7 @@ class Meta: g = fields.String(required=True, description="Key value for green band") b = fields.String(required=True, description="Key value for blue band") r_range = fields.List( - fields.String(allow_none=True, validate=validate.Regexp("^p?(\d*\.)?\d+$")), + StringOrNumber(allow_none=True, validate=validate_range), validate=validate.Length(equal=2), example="[0,1]", missing=None, @@ -39,7 +70,7 @@ class Meta: ), ) g_range = fields.List( - fields.String(allow_none=True, validate=validate.Regexp("^p?(\d*\.)?\d+$")), + StringOrNumber(allow_none=True, validate=validate_range), validate=validate.Length(equal=2), example="[0,1]", missing=None, @@ -49,7 +80,7 @@ class Meta: ), ) b_range = fields.List( - fields.String(allow_none=True, validate=validate.Regexp("^p?(\d*\.)?\d+$")), + StringOrNumber(allow_none=True, validate=validate_range), validate=validate.Length(equal=2), example="[0,1]", missing=None, diff --git a/tests/server/test_flask_api.py b/tests/server/test_flask_api.py index f7e5aff8..ac884ea4 100644 --- a/tests/server/test_flask_api.py +++ b/tests/server/test_flask_api.py @@ -410,6 +410,8 @@ def test_get_rgb_stretch(client, use_testdb, raster_file_xyz): for stretch_range in ( "[0,10000]", + "[\"1.0e%2B01\",\"1.0e%2B04\"]", + "[\"p2\",\"p98\"]", "[0,null]", "[null, 10000]", "[null,null]",