diff --git a/setup.py b/setup.py index 1f743d41..421257e5 100644 --- a/setup.py +++ b/setup.py @@ -64,6 +64,7 @@ "cachetools>=3.1.0", "click", "click-spinner", + "color-operations", "flask", "flask_cors", "marshmallow>=3.0.0", diff --git a/terracotta/handlers/rgb.py b/terracotta/handlers/rgb.py index 6b75661b..44d70362 100644 --- a/terracotta/handlers/rgb.py +++ b/terracotta/handlers/rgb.py @@ -23,6 +23,7 @@ def rgb( tile_xyz: Optional[Tuple[int, int, int]] = None, *, stretch_ranges: Optional[ListOfRanges] = None, + color_transform: Optional[str] = None, tile_size: Optional[Tuple[int, int]] = None ) -> BinaryIO: """Return RGB image as PNG @@ -80,6 +81,7 @@ def get_band_future(band_key: str) -> Future: futures = [get_band_future(key) for key in rgb_values] band_items = zip(rgb_values, stretch_ranges_, futures) + out_ranges = [] out_arrays = [] for i, (band_key, band_stretch_override, band_data_future) in enumerate( @@ -88,7 +90,8 @@ def get_band_future(band_key: str) -> Future: keys = (*some_keys, band_key) metadata = driver.get_metadata(keys) - band_stretch_range = list(metadata["range"]) + band_range = list(metadata["range"]) + band_stretch_range = band_range.copy() scale_min, scale_max = band_stretch_override percentiles = metadata.get("percentiles", []) @@ -104,7 +107,22 @@ def get_band_future(band_key: str) -> Future: ) band_data = band_data_future.result() - out_arrays.append(image.to_uint8(band_data, *band_stretch_range)) + + out_ranges.append(band_stretch_range) + out_arrays.append(band_data) + + band_data = np.ma.stack(out_arrays, axis=0) + + if color_transform: + out_ranges = [np.array(band_rng, dtype=band_data.dtype) for band_rng in out_ranges] + out_ranges = np.ma.stack(out_ranges, axis=0) + + out_ranges = image.apply_color_transform(out_ranges, color_transform, band_range) + band_data = image.apply_color_transform(band_data, color_transform, band_range) + + out_arrays = [] + for k in range(band_data.shape[0]): + out_arrays.append(image.to_uint8(band_data[k], *out_ranges[k])) out = np.ma.stack(out_arrays, axis=-1) return image.array_to_png(out) diff --git a/terracotta/handlers/singleband.py b/terracotta/handlers/singleband.py index 15311fc8..58f58d3a 100644 --- a/terracotta/handlers/singleband.py +++ b/terracotta/handlers/singleband.py @@ -8,6 +8,8 @@ import collections +import numpy as np + from terracotta import get_settings, get_driver, image, xyz from terracotta.profile import trace @@ -26,6 +28,7 @@ def singleband( *, colormap: Union[str, Mapping[Number, RGBA], None] = None, stretch_range: Optional[Tuple[NumberOrString, NumberOrString]] = None, + color_transform: Optional[str] = None, tile_size: Optional[Tuple[int, int]] = None ) -> BinaryIO: """Return singleband image as PNG""" @@ -61,7 +64,8 @@ def singleband( out = image.label(tile_data, labels) else: # determine stretch range from metadata and arguments - stretch_range_ = list(metadata["range"]) + band_range = list(metadata["range"]) + stretch_range_ = band_range.copy() percentiles = metadata.get("percentiles", []) if stretch_min is not None: @@ -71,6 +75,15 @@ def singleband( stretch_range_[1] = image.get_stretch_scale(stretch_max, percentiles) cmap_or_palette = cast(Optional[str], colormap) + + if color_transform: + stretch_range_ = np.array(stretch_range_, dtype=tile_data.dtype) + stretch_range_ = np.ma.stack(stretch_range_, axis=0) + + stretch_range_ = image.apply_color_transform(stretch_range_, color_transform, band_range) + tile_data = np.expand_dims(tile_data, axis=0) + tile_data = image.apply_color_transform(tile_data, color_transform, band_range)[0] + out = image.to_uint8(tile_data, *stretch_range_) return image.array_to_png(out, colormap=cmap_or_palette) diff --git a/terracotta/image.py b/terracotta/image.py index 44c5b664..5d511ad9 100755 --- a/terracotta/image.py +++ b/terracotta/image.py @@ -10,6 +10,8 @@ import numpy as np from PIL import Image +from color_operations import parse_operations +from color_operations.utils import to_math_type from terracotta.profile import trace from terracotta import exceptions, get_settings @@ -162,6 +164,27 @@ def to_uint8(data: Array, lower_bound: Number, upper_bound: Number) -> Array: return rescaled.astype(np.uint8) +def apply_color_transform( + masked_data: Array, + color_transform: str, + band_range: list, +) -> Array: + """Apply gamma correction to the input array and scale it to the output dtype.""" + + if band_range: + arr = contrast_stretch(masked_data, band_range, (0, 1)) + elif np.issubdtype(masked_data.dtype, np.integer): + arr = to_math_type(masked_data) + else: + raise exceptions.InvalidArgumentsError("No band range given and array is not of integer type") + + + for func in parse_operations(color_transform): + arr = func(arr) + + return arr + + def label(data: Array, labels: Sequence[Number]) -> Array: """Create a labelled uint8 version of data, with output values starting at 1. diff --git a/terracotta/server/fields.py b/terracotta/server/fields.py index 401610af..8f93eb6c 100644 --- a/terracotta/server/fields.py +++ b/terracotta/server/fields.py @@ -8,6 +8,8 @@ from typing import Any, Union +from color_operations import parse_operations + class StringOrNumber(fields.Field): """ @@ -45,3 +47,19 @@ def validate_stretch_range(data: Any) -> None: if isinstance(data, str): if not re.match("^p\\d+$", data): raise ValidationError("Percentile format is `p`") + + +def validate_color_transform(data: Any) -> None: + """ + Validate that the color transform is a string and can be parsed by `color_operations`. + """ + if not isinstance(data, str): + raise ValidationError("Color transform needs to be a string") + + if "saturation" in data: + raise ValidationError("Saturation is currently not supported") + + try: + parse_operations(data) + except (ValueError, KeyError): + raise ValidationError("Invalid color transform") diff --git a/terracotta/server/rgb.py b/terracotta/server/rgb.py index fe9c7a1c..25103497 100644 --- a/terracotta/server/rgb.py +++ b/terracotta/server/rgb.py @@ -9,7 +9,7 @@ from marshmallow import Schema, fields, validate, pre_load, ValidationError, EXCLUDE from flask import request, send_file, Response -from terracotta.server.fields import StringOrNumber, validate_stretch_range +from terracotta.server.fields import StringOrNumber, validate_stretch_range, validate_color_transform from terracotta.server.flask_api import TILE_API @@ -48,7 +48,7 @@ class Meta: example="[0,1]", missing=None, description=( - "Stretch range [min, max] to use for the gren band as JSON array. " + "Stretch range [min, max] to use for the green band as JSON array. " "Min and max may be numbers to use as absolute range, or strings " "of the format `p` with an integer between 0 and 100 " "to use percentiles of the image instead. " @@ -68,6 +68,11 @@ class Meta: "Null values indicate global minimum / maximum." ), ) + color_transform = fields.String( + validate=validate_color_transform, + missing=None, + description="Color transform DSL string from color-operations.", + ) tile_size = fields.List( fields.Integer(), validate=validate.Length(equal=2), diff --git a/terracotta/server/singleband.py b/terracotta/server/singleband.py index 642527e7..39db6227 100644 --- a/terracotta/server/singleband.py +++ b/terracotta/server/singleband.py @@ -17,7 +17,7 @@ ) from flask import request, send_file, Response -from terracotta.server.fields import StringOrNumber, validate_stretch_range +from terracotta.server.fields import StringOrNumber, validate_stretch_range, validate_color_transform from terracotta.server.flask_api import TILE_API from terracotta.cmaps import AVAILABLE_CMAPS @@ -65,6 +65,14 @@ class Meta: "hex strings.", ) + color_transform = fields.String( + validate=validate_color_transform, + missing=None, + example="gamma 1 1.5, sigmoidal 1 15 0.5", + description="Color transform DSL string from color-operations." + "All color operations for singleband should specify band 1.", + ) + tile_size = fields.List( fields.Integer(), validate=validate.Length(equal=2),