Skip to content

Commit d41a0cd

Browse files
Faster approximation for Oklab conversions (#76)
1 parent 0497e6f commit d41a0cd

File tree

3 files changed

+184
-124
lines changed

3 files changed

+184
-124
lines changed

src/color/oklab.rs

Lines changed: 93 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,51 @@ impl Operations for Reference {
4343
}
4444
}
4545

46+
/// A fast fused multiply-add operation that uses hardware FMA if available.
47+
/// If hardware FMA is not available, it falls back to a regular multiply-add.
48+
#[inline(always)]
49+
fn fma(a: Vec3A, b: Vec3A, c: Vec3A) -> Vec3A {
50+
#[cfg(any(
51+
all(
52+
any(target_arch = "x86", target_arch = "x86_64"),
53+
target_feature = "fma"
54+
),
55+
target_arch = "aarch64"
56+
))]
57+
{
58+
a.mul_add(b, c)
59+
}
60+
#[cfg(not(any(
61+
all(
62+
any(target_arch = "x86", target_arch = "x86_64"),
63+
target_feature = "fma"
64+
),
65+
target_arch = "aarch64"
66+
)))]
67+
{
68+
a * b + c
69+
}
70+
}
71+
4672
struct Fast;
73+
#[allow(clippy::excessive_precision)]
4774
impl Operations for Fast {
4875
fn srgb_to_linear(c: Vec3A) -> Vec3A {
4976
Vec3A::select(
5077
c.cmpge(Vec3A::splat(0.04045)),
5178
{
52-
// This uses a Padé approximant for ((c + 0.055) / 1.055) ^ 2.4:
53-
// (0.000857709 +0.0359438 x+0.524293 x^2+1.31193 x^3)/(1+0.992498 x-0.119725 x^2)
79+
// Polynomial approximation for ((c + 0.055) / 1.055) ^ 2.4
80+
// This has a max error of 0.0001228 and is exact at c=0.04045 and c=1
81+
const A0: f32 = 0.00117465;
82+
const A1: f32 = 0.02381997;
83+
const A2: f32 = 0.58750746;
84+
const A3: f32 = 0.47736490;
85+
const A4: f32 = -0.08986699;
5486
let c2 = c * c;
55-
let c3 = c2 * c;
56-
Vec3A::min(
57-
Vec3A::ONE,
58-
(0.000857709 + 0.0359438 * c + 0.524293 * c2 + 1.31193 * c3)
59-
/ (Vec3A::ONE + 0.992498 * c - 0.119725 * c2),
60-
)
87+
let p01 = fma(c, Vec3A::splat(A1), Vec3A::splat(A0));
88+
let p23 = fma(c, Vec3A::splat(A3), Vec3A::splat(A2));
89+
let t = fma(c2, Vec3A::splat(A4), p23);
90+
fma(c2, t, p01)
6191
},
6292
c * (1.0 / 12.92),
6393
)
@@ -68,33 +98,47 @@ impl Operations for Fast {
6898
{
6999
// This uses a Padé approximant for 1.055 c^(1/2.4) - 0.055:
70100
// (-0.0117264+21.0897 x+949.46 x^2+2225.62 x^3)/(1+176.398 x+1983.15 x^2+1035.65 x^3)
101+
const P0: f32 = -0.0117264;
102+
const P1: f32 = 21.0897;
103+
const P2: f32 = 949.46;
104+
const P3: f32 = 2225.62;
105+
const Q1: f32 = 176.398;
106+
const Q2: f32 = 1983.15;
107+
const Q3: f32 = 1035.65;
71108
let c2 = c * c;
72-
let c3 = c2 * c;
73-
(-0.0117264 + 21.0897 * c + 949.46 * c2 + 2225.62 * c3)
74-
/ (1.0 + 176.398 * c + 1983.15 * c2 + 1035.65 * c3)
109+
let p01 = fma(c, Vec3A::splat(P1), Vec3A::splat(P0));
110+
let p23 = fma(c, Vec3A::splat(P3), Vec3A::splat(P2));
111+
let p = fma(c2, p23, p01);
112+
let q01 = fma(c, Vec3A::splat(Q1), Vec3A::ONE);
113+
let q23 = fma(c, Vec3A::splat(Q3), Vec3A::splat(Q2));
114+
let q = fma(c2, q23, q01);
115+
p / q
75116
},
76117
c * 12.92,
77118
)
78119
}
79-
#[allow(clippy::excessive_precision)]
80120
fn cbrt(x: Vec3A) -> Vec3A {
81-
// This is the fast cbrt approximation from the oklab crate.
82-
// Source: https://gitlab.com/kornelski/oklab/-/blob/d3c074f154187dd5c0642119a6402a6c0753d70c/oklab/src/lib.rs#L61
83-
// Author: Kornel (https://gitlab.com/kornelski/)
121+
// This is the fast cbrt approximation inspired by the non-std cbrt
122+
// implementation (https://gitlab.com/kornelski/oklab/-/blob/d3c074f154187dd5c0642119a6402a6c0753d70c/oklab/src/lib.rs#L61)
123+
// in the oklab crate by Kornel (https://gitlab.com/kornelski/), which
124+
// in turn seems to be based on the libm implementation.
125+
// In this version, I replaced the part after the initial guess with
126+
// one Halley iteration. This reduces accuracy, but saves 2 divisions
127+
// which helps performance a lot.
84128
const B: u32 = 709957561;
85-
const C: f32 = 5.4285717010e-1;
86-
const D: f32 = -7.0530611277e-1;
87-
const E: f32 = 1.4142856598e+0;
88-
const F: f32 = 1.6071428061e+0;
89-
const G: f32 = 3.5714286566e-1;
90-
91-
let mut t = Vec3A::from_array(
92-
x.to_array()
93-
.map(|x| f32::from_bits((x.to_bits() / 3).wrapping_add(B))),
94-
);
95-
let s = C + (t * t) * (t / x);
96-
t *= G + F / (s + E + D / s);
97-
t
129+
fn initial_guess(x: f32) -> f32 {
130+
let bits = x.to_bits();
131+
// divide by 3 using multiplication and bitshift
132+
// this is only correct if bits <= 2^31, which is true for all
133+
// positive f32 values
134+
let div = ((bits as u64 * 1431655766) >> 32) as u32;
135+
f32::from_bits(div + B)
136+
}
137+
let t = Vec3A::from_array(x.to_array().map(initial_guess));
138+
139+
// one halley iteration
140+
let s = t * t * t;
141+
t * fma(Vec3A::splat(2.0), x, s) / fma(Vec3A::splat(2.0), s, x)
98142
}
99143
}
100144

@@ -133,7 +177,8 @@ fn oklab_to_srgb_impl<O: Operations>(lab: Vec3A) -> Vec3A {
133177
lms.dot(Vec3A::new(-0.0041960863, -0.7034186147, 1.7076147010)),
134178
);
135179

136-
O::linear_to_srgb(rgb)
180+
// the clamping is necessary for out-of-gamut colors
181+
O::linear_to_srgb(rgb).clamp(Vec3A::ZERO, Vec3A::ONE)
137182
}
138183

139184
#[allow(unused)]
@@ -193,14 +238,14 @@ mod tests {
193238
let ref_oklab = srgb_to_oklab(color);
194239

195240
assert!(
196-
(fast_oklab - ref_oklab).abs().max_element() < 1e-3,
241+
(fast_oklab - ref_oklab).abs().max_element() < 0.001,
197242
"{color:?} -> fast: {fast_oklab:?} vs ref: {ref_oklab:?}"
198243
);
199244

200245
let srgb = fast_oklab_to_srgb(fast_oklab);
201246

202247
assert!(
203-
(color - srgb).abs().max_element() < 2.5e-3,
248+
(color - srgb).abs().max_element() < 0.0025,
204249
"{color:?} -> {srgb:?}"
205250
);
206251

@@ -212,6 +257,14 @@ mod tests {
212257
fast_oklab.min_element() >= 0.0,
213258
"{color:?} -> {fast_oklab:?}"
214259
);
260+
assert!(
261+
srgb.max_element() <= 1.0,
262+
"{color:?} -> {fast_oklab:?} -> {srgb:?}"
263+
);
264+
assert!(
265+
srgb.min_element() >= 0.0,
266+
"{color:?} -> {fast_oklab:?} -> {srgb:?}"
267+
);
215268
}
216269
}
217270
}
@@ -262,14 +315,21 @@ mod tests {
262315
fn test_error_fast_srgb_to_linear() {
263316
assert_eq!(
264317
get_error_stats(RefScalar::srgb_to_linear, FastScalar::srgb_to_linear),
265-
"Error: avg=0.00002514 max=0.00013047 for 0.999"
318+
"Error: avg=0.00007546 max=0.00012287 for 0.641"
266319
);
267320
}
268321
#[test]
269322
fn test_error_fast_linear_to_srgb() {
270323
assert_eq!(
271324
get_error_stats(RefScalar::linear_to_srgb, FastScalar::linear_to_srgb),
272-
"Error: avg=0.00105457 max=0.00236702 for 0.732"
325+
"Error: avg=0.00105456 max=0.00236708 for 0.730"
326+
);
327+
}
328+
#[test]
329+
fn test_error_fast_cbrt() {
330+
assert_eq!(
331+
get_error_stats(RefScalar::cbrt, FastScalar::cbrt),
332+
"Error: avg=0.00000283 max=0.00001299 for 0.250"
273333
);
274334
}
275335

0 commit comments

Comments
 (0)