Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional gamma correction parameter #350

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"cachetools>=3.1.0",
"click",
"click-spinner",
"color-operations",
"flask",
"flask_cors",
"marshmallow>=3.0.0",
Expand Down
23 changes: 21 additions & 2 deletions terracotta/handlers/rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def rgb(
tile_xyz: Optional[Tuple[int, int, int]] = None,
*,
stretch_ranges: Optional[ListOfRanges] = None,
gamma_factor: Optional[float] = None,
color_transform: Optional[str] = None,
tile_size: Optional[Tuple[int, int]] = None
) -> BinaryIO:
"""Return RGB image as PNG
Expand Down Expand Up @@ -80,6 +82,7 @@ def get_band_future(band_key: str) -> Future:
futures = [get_band_future(key) for key in rgb_values]
band_items = zip(rgb_values, stretch_ranges_, futures)

out_ranges = []
out_arrays = []

for i, (band_key, band_stretch_override, band_data_future) in enumerate(
Expand All @@ -88,7 +91,8 @@ def get_band_future(band_key: str) -> Future:
keys = (*some_keys, band_key)
metadata = driver.get_metadata(keys)

band_stretch_range = list(metadata["range"])
band_range = list(metadata["range"])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why introduce this variable? Looks unused.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh now I see, you're using this further below. But beware, since you're setting this in a loop only the value for the last band will be used in the color transform!

band_stretch_range = band_range.copy()
scale_min, scale_max = band_stretch_override

percentiles = metadata.get("percentiles", [])
Expand All @@ -104,7 +108,22 @@ def get_band_future(band_key: str) -> Future:
)

band_data = band_data_future.result()
out_arrays.append(image.to_uint8(band_data, *band_stretch_range))

out_ranges.append(band_stretch_range)
out_arrays.append(band_data)

out = np.ma.stack(out_arrays, axis=0)

if color_transform:
band_stretch_range_arr = [np.array(band_rng, dtype=band_data.dtype) for band_rng in out_ranges]
band_stretch_range_arr = np.ma.stack(band_stretch_range_arr, axis=0)

band_stretch_range_arr = image.apply_color_transform(band_stretch_range_arr, color_transform)
band_data = image.apply_color_transform(out, color_transform)

out_arrays = []
for k in range(band_data.shape[0]):
out_arrays.append(image.to_uint8(band_data[k], *band_stretch_range_arr[k]))

out = np.ma.stack(out_arrays, axis=-1)
return image.array_to_png(out)
14 changes: 13 additions & 1 deletion terracotta/handlers/singleband.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import collections

import numpy as np

from terracotta import get_settings, get_driver, image, xyz
from terracotta.profile import trace

Expand All @@ -26,6 +28,7 @@ def singleband(
*,
colormap: Union[str, Mapping[Number, RGBA], None] = None,
stretch_range: Optional[Tuple[NumberOrString, NumberOrString]] = None,
gamma_factor: Optional[float] = None,
tile_size: Optional[Tuple[int, int]] = None
) -> BinaryIO:
"""Return singleband image as PNG"""
Expand Down Expand Up @@ -61,7 +64,8 @@ def singleband(
out = image.label(tile_data, labels)
else:
# determine stretch range from metadata and arguments
stretch_range_ = list(metadata["range"])
band_range = list(metadata["range"])
stretch_range_ = band_range.copy()

percentiles = metadata.get("percentiles", [])
if stretch_min is not None:
Expand All @@ -71,6 +75,14 @@ def singleband(
stretch_range_[1] = image.get_stretch_scale(stretch_max, percentiles)

cmap_or_palette = cast(Optional[str], colormap)

if gamma_factor:
# gamma correction is monotonic and preserves percentiles
band_stretch_range_arr = np.array(stretch_range_, dtype=tile_data.dtype)
stretch_range_ = list(image.gamma_correction(band_stretch_range_arr, gamma_factor, band_range))
# gamma correct band data
tile_data = image.gamma_correction(tile_data, gamma_factor, band_range)

out = image.to_uint8(tile_data, *stretch_range_)

return image.array_to_png(out, colormap=cmap_or_palette)
43 changes: 42 additions & 1 deletion terracotta/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
Utilities to create and manipulate images.
"""

from typing import List, Sequence, Tuple, TypeVar, Union
from typing import List, Sequence, Tuple, TypeVar, Union, Optional
from typing.io import BinaryIO

from io import BytesIO
import numbers

import numpy as np
from PIL import Image
from color_operations import parse_operations
from color_operations.operations import gamma
from color_operations.utils import to_math_type, scale_dtype

from terracotta.profile import trace
from terracotta import exceptions, get_settings
Expand Down Expand Up @@ -162,6 +166,43 @@ def to_uint8(data: Array, lower_bound: Number, upper_bound: Number) -> Array:
return rescaled.astype(np.uint8)


def gamma_correction(
masked_data: Array,
gamma_factor: float,
band_range: list,
out_dtype: type = np.uint16,
) -> Array:
"""Apply gamma correction to the input array and scale it to the output dtype."""
if not isinstance(gamma_factor, numbers.Number) or gamma_factor <= 0:
raise exceptions.InvalidArgumentsError("Invalid gamma factor")

if band_range:
arr = contrast_stretch(masked_data, band_range, (0, 1))
elif np.issubdtype(masked_data.dtype, np.integer):
arr = to_math_type(masked_data)
else:
raise exceptions.InvalidArgumentsError("No band range given and array is not of integer type")

arr = gamma(arr, gamma_factor)
arr = scale_dtype(arr, out_dtype)
return arr


def apply_color_transform(
masked_data: Array,
color_transform: str,
out_dtype: type = np.uint16,
) -> Array:
"""Apply gamma correction to the input array and scale it to the output dtype."""
arr = to_math_type(masked_data)

for func in parse_operations(color_transform):
arr = func(arr)
Comment on lines +182 to +183
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this assume a particular scaling already?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. It may be easier to apply stretching before color transform then. I.e., normalize to [0,1] using stretch_range and clamp spillover to 0 / 1, then apply color transform, then convert to uint8. That way we don't rely on a transformation of the stretch range?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it may not be a good idea since it would change the result.

The color stretch fills the [0, 1] range whereas the normalization to_math_type does is relative to the dtype range. So when you apply the color transform I think that would shift where values fall along the curve.

This is what chatgpt said in the case of gamma correction:

Gamma correction is typically applied before color stretching, particularly when working with image data in scientific imaging, photography, or graphic design.

Here's why:

    Gamma Correction First: Gamma correction adjusts the image data to a linear color space, which compensates for the nonlinear response of human vision and many display systems. This linearization step ensures that subsequent adjustments, like color stretching, are applied to data that more accurately represents real-world light intensities.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but who says that the dtype range is appropriate? The user-provided stretch range is telling us how to map the values in the raster to a linear scale.

Summoning @vincentsarago in case you want to weigh in on the appropriate order of linear scaling (and clamping out of range values) vs. color correction :)


arr = scale_dtype(arr, out_dtype)
biserhong marked this conversation as resolved.
Show resolved Hide resolved
return arr


def label(data: Array, labels: Sequence[Number]) -> Array:
"""Create a labelled uint8 version of data, with output values starting at 1.

Expand Down
13 changes: 12 additions & 1 deletion terracotta/server/rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Meta:
example="[0,1]",
missing=None,
description=(
"Stretch range [min, max] to use for the gren band as JSON array. "
"Stretch range [min, max] to use for the green band as JSON array. "
"Min and max may be numbers to use as absolute range, or strings "
"of the format `p<digits>` with an integer between 0 and 100 "
"to use percentiles of the image instead. "
Expand All @@ -68,6 +68,15 @@ class Meta:
"Null values indicate global minimum / maximum."
),
)
gamma_factor = fields.Float(
validate=validate.Range(min=0, min_inclusive=False),
missing=None,
description="Gamma factor to perform gamma correction."
)
color_transform = fields.String(
missing=None,
description="Gamma factor to perform gamma correction."
)
tile_size = fields.List(
fields.Integer(),
validate=validate.Length(equal=2),
Expand Down Expand Up @@ -165,11 +174,13 @@ def _get_rgb_image(

rgb_values = (options.pop("r"), options.pop("g"), options.pop("b"))
stretch_ranges = tuple(options.pop(k) for k in ("r_range", "g_range", "b_range"))
gamma_factor = options.pop("gamma_factor")

image = rgb(
some_keys,
rgb_values,
stretch_ranges=stretch_ranges,
gamma_factor=gamma_factor,
tile_xyz=tile_xyz,
**options,
)
Expand Down
6 changes: 6 additions & 0 deletions terracotta/server/singleband.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ class Meta:
missing=None,
)

gamma_factor = fields.Float(
validate=validate.Range(min=0, min_inclusive=False),
missing=None,
description="Gamma factor to perform gamma correction."
)

colormap = fields.String(
description="Colormap to apply to image (see /colormap)",
validate=validate.OneOf(("explicit", *AVAILABLE_CMAPS)),
Expand Down
117 changes: 117 additions & 0 deletions tests/handlers/test_rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,120 @@ def test_rgb_preview(use_testdb):
raw_img = rgb.rgb(["val21", "x"], ["val22", "val23", "val24"])
img_data = np.asarray(Image.open(raw_img))
assert img_data.shape == (*terracotta.get_settings().DEFAULT_TILE_SIZE, 3)


def test_rgb_gamma_correction(use_testdb, testdb, raster_file_xyz):
import terracotta
from terracotta.xyz import get_tile_data
from terracotta.handlers import rgb
from terracotta import image

ds_keys = ["val21", "x", "val22"]
bands = ["val22", "val23", "val24"]
gamma_factor = 2

raw_img = rgb.rgb(
ds_keys[:2],
bands,
raster_file_xyz,
gamma_factor=gamma_factor,
)
img_data = np.asarray(Image.open(raw_img))[..., 0]

# get non-gamma corrected data to compare to
driver = terracotta.get_driver(testdb)

with driver.connect():
tile_data = get_tile_data(
driver, ds_keys, tile_xyz=raster_file_xyz, tile_size=img_data.shape
)

tile_metadata = driver.get_metadata(ds_keys)

# non-gamma corrected uint8 data
tile_uint8 = image.to_uint8(tile_data, *tile_metadata["range"])

# filter transparent values
valid_mask = ~tile_data.mask
assert np.all(img_data[~valid_mask] == 0)

valid_img = img_data[valid_mask]
valid_data = tile_uint8.compressed()

# gamma factor of 2 is sqrt(x) in [0, 1]
assert np.all(valid_img > valid_data)


@pytest.mark.parametrize(
"gamma_factor_params",
[
['-1', "Invalid gamma factor"],
['2,2', "Invalid gamma factor"],
['[1]', "Invalid gamma factor"],
['0', "Invalid gamma factor"],
],
)
def test_rgb_invalid_gamma_factor(use_testdb, raster_file_xyz, gamma_factor_params):
from terracotta.handlers import rgb

ds_keys = ["val21", "x", "val22"]
bands = ["val22", "val23", "val24"]

gamma_factor = gamma_factor_params[:2]
with pytest.raises(exceptions.InvalidArgumentsError) as err:
rgb.rgb(
ds_keys[:2],
bands,
raster_file_xyz,
gamma_factor=gamma_factor,
)
assert gamma_factor[1] in str(err.value)


def test_rgb_stretch_gamma_correction(use_testdb, testdb, raster_file_xyz):
import terracotta
from terracotta.xyz import get_tile_data
from terracotta.handlers import rgb

ds_keys = ["val21", "x", "val22"]
bands = ["val22", "val23", "val24"]
gamma_factor = 2
pct_stretch_range = ["p2", "p98"]

raw_img = rgb.rgb(
ds_keys[:2],
bands,
raster_file_xyz,
gamma_factor=gamma_factor,
stretch_ranges=[pct_stretch_range] * 3,
)
img_data = np.asarray(Image.open(raw_img))[..., 0]

# get unstretched data to compare to
driver = terracotta.get_driver(testdb)

with driver.connect():
tile_data = get_tile_data(
driver, ds_keys, tile_xyz=raster_file_xyz, tile_size=img_data.shape
)
band_metadata = driver.get_metadata(ds_keys)

stretch_range = [
band_metadata["percentiles"][1],
band_metadata["percentiles"][97],
]

# filter transparent values
valid_mask = ~tile_data.mask
assert np.all(img_data[~valid_mask] == 0)

valid_img = img_data[valid_mask]
valid_data = tile_data.compressed()

assert np.all(valid_img[valid_data < stretch_range[0]] == 1)
stretch_range_mask = (valid_data > stretch_range[0]) & (
valid_data < stretch_range[1]
)
assert np.all(valid_img[stretch_range_mask] >= 1)
assert np.all(valid_img[stretch_range_mask] <= 255)
assert np.all(valid_img[valid_data > stretch_range[1]] == 255)
Loading