Skip to content
126 changes: 93 additions & 33 deletions src/color/oklab.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,51 @@ impl Operations for Reference {
}
}

/// A fast fused multiply-add operation that uses hardware FMA if available.
/// If hardware FMA is not available, it falls back to a regular multiply-add.
#[inline(always)]
fn fma(a: Vec3A, b: Vec3A, c: Vec3A) -> Vec3A {
#[cfg(any(
all(
any(target_arch = "x86", target_arch = "x86_64"),
target_feature = "fma"
),
target_arch = "aarch64"
))]
{
a.mul_add(b, c)
}
#[cfg(not(any(
all(
any(target_arch = "x86", target_arch = "x86_64"),
target_feature = "fma"
),
target_arch = "aarch64"
)))]
{
a * b + c
}
}

struct Fast;
#[allow(clippy::excessive_precision)]
impl Operations for Fast {
fn srgb_to_linear(c: Vec3A) -> Vec3A {
Vec3A::select(
c.cmpge(Vec3A::splat(0.04045)),
{
// This uses a Padé approximant for ((c + 0.055) / 1.055) ^ 2.4:
// (0.000857709 +0.0359438 x+0.524293 x^2+1.31193 x^3)/(1+0.992498 x-0.119725 x^2)
// Polynomial approximation for ((c + 0.055) / 1.055) ^ 2.4
// This has a max error of 0.0001228 and is exact at c=0.04045 and c=1
const A0: f32 = 0.00117465;
const A1: f32 = 0.02381997;
const A2: f32 = 0.58750746;
const A3: f32 = 0.47736490;
const A4: f32 = -0.08986699;
let c2 = c * c;
let c3 = c2 * c;
Vec3A::min(
Vec3A::ONE,
(0.000857709 + 0.0359438 * c + 0.524293 * c2 + 1.31193 * c3)
/ (Vec3A::ONE + 0.992498 * c - 0.119725 * c2),
)
let p01 = fma(c, Vec3A::splat(A1), Vec3A::splat(A0));
let p23 = fma(c, Vec3A::splat(A3), Vec3A::splat(A2));
let t = fma(c2, Vec3A::splat(A4), p23);
fma(c2, t, p01)
},
c * (1.0 / 12.92),
)
Expand All @@ -68,33 +98,47 @@ impl Operations for Fast {
{
// This uses a Padé approximant for 1.055 c^(1/2.4) - 0.055:
// (-0.0117264+21.0897 x+949.46 x^2+2225.62 x^3)/(1+176.398 x+1983.15 x^2+1035.65 x^3)
const P0: f32 = -0.0117264;
const P1: f32 = 21.0897;
const P2: f32 = 949.46;
const P3: f32 = 2225.62;
const Q1: f32 = 176.398;
const Q2: f32 = 1983.15;
const Q3: f32 = 1035.65;
let c2 = c * c;
let c3 = c2 * c;
(-0.0117264 + 21.0897 * c + 949.46 * c2 + 2225.62 * c3)
/ (1.0 + 176.398 * c + 1983.15 * c2 + 1035.65 * c3)
let p01 = fma(c, Vec3A::splat(P1), Vec3A::splat(P0));
let p23 = fma(c, Vec3A::splat(P3), Vec3A::splat(P2));
let p = fma(c2, p23, p01);
let q01 = fma(c, Vec3A::splat(Q1), Vec3A::ONE);
let q23 = fma(c, Vec3A::splat(Q3), Vec3A::splat(Q2));
let q = fma(c2, q23, q01);
p / q
},
c * 12.92,
)
}
#[allow(clippy::excessive_precision)]
fn cbrt(x: Vec3A) -> Vec3A {
// This is the fast cbrt approximation from the oklab crate.
// Source: https://gitlab.com/kornelski/oklab/-/blob/d3c074f154187dd5c0642119a6402a6c0753d70c/oklab/src/lib.rs#L61
// Author: Kornel (https://gitlab.com/kornelski/)
// This is the fast cbrt approximation inspired by the non-std cbrt
// implementation (https://gitlab.com/kornelski/oklab/-/blob/d3c074f154187dd5c0642119a6402a6c0753d70c/oklab/src/lib.rs#L61)
// in the oklab crate by Kornel (https://gitlab.com/kornelski/), which
// in turn seems to be based on the libm implementation.
// In this version, I replaced the part after the initial guess with
// one Halley iteration. This reduces accuracy, but saves 2 divisions
// which helps performance a lot.
const B: u32 = 709957561;
const C: f32 = 5.4285717010e-1;
const D: f32 = -7.0530611277e-1;
const E: f32 = 1.4142856598e+0;
const F: f32 = 1.6071428061e+0;
const G: f32 = 3.5714286566e-1;

let mut t = Vec3A::from_array(
x.to_array()
.map(|x| f32::from_bits((x.to_bits() / 3).wrapping_add(B))),
);
let s = C + (t * t) * (t / x);
t *= G + F / (s + E + D / s);
t
fn initial_guess(x: f32) -> f32 {
let bits = x.to_bits();
// divide by 3 using multiplication and bitshift
// this is only correct if bits <= 2^31, which is true for all
// positive f32 values
let div = ((bits as u64 * 1431655766) >> 32) as u32;
f32::from_bits(div + B)
}
let t = Vec3A::from_array(x.to_array().map(initial_guess));

// one halley iteration
let s = t * t * t;
t * fma(Vec3A::splat(2.0), x, s) / fma(Vec3A::splat(2.0), s, x)
}
}

Expand Down Expand Up @@ -133,7 +177,8 @@ fn oklab_to_srgb_impl<O: Operations>(lab: Vec3A) -> Vec3A {
lms.dot(Vec3A::new(-0.0041960863, -0.7034186147, 1.7076147010)),
);

O::linear_to_srgb(rgb)
// the clamping is necessary for out-of-gamut colors
O::linear_to_srgb(rgb).clamp(Vec3A::ZERO, Vec3A::ONE)
}

#[allow(unused)]
Expand Down Expand Up @@ -193,14 +238,14 @@ mod tests {
let ref_oklab = srgb_to_oklab(color);

assert!(
(fast_oklab - ref_oklab).abs().max_element() < 1e-3,
(fast_oklab - ref_oklab).abs().max_element() < 0.001,
"{color:?} -> fast: {fast_oklab:?} vs ref: {ref_oklab:?}"
);

let srgb = fast_oklab_to_srgb(fast_oklab);

assert!(
(color - srgb).abs().max_element() < 2.5e-3,
(color - srgb).abs().max_element() < 0.0025,
"{color:?} -> {srgb:?}"
);

Expand All @@ -212,6 +257,14 @@ mod tests {
fast_oklab.min_element() >= 0.0,
"{color:?} -> {fast_oklab:?}"
);
assert!(
srgb.max_element() <= 1.0,
"{color:?} -> {fast_oklab:?} -> {srgb:?}"
);
assert!(
srgb.min_element() >= 0.0,
"{color:?} -> {fast_oklab:?} -> {srgb:?}"
);
}
}
}
Expand Down Expand Up @@ -262,14 +315,21 @@ mod tests {
fn test_error_fast_srgb_to_linear() {
assert_eq!(
get_error_stats(RefScalar::srgb_to_linear, FastScalar::srgb_to_linear),
"Error: avg=0.00002514 max=0.00013047 for 0.999"
"Error: avg=0.00007546 max=0.00012287 for 0.641"
);
}
#[test]
fn test_error_fast_linear_to_srgb() {
assert_eq!(
get_error_stats(RefScalar::linear_to_srgb, FastScalar::linear_to_srgb),
"Error: avg=0.00105457 max=0.00236702 for 0.732"
"Error: avg=0.00105456 max=0.00236708 for 0.730"
);
}
#[test]
fn test_error_fast_cbrt() {
assert_eq!(
get_error_stats(RefScalar::cbrt, FastScalar::cbrt),
"Error: avg=0.00000283 max=0.00001299 for 0.250"
);
}

Expand Down
Loading
Loading