Add SIMD optimization for int_to_float conversion #580
base: main
Conversation
Add SIMD fast paths for converting custom bit-depth floats to f32:

- 32-bit float passthrough: simple bitcast using SIMD.
- 16-bit float (f16/half-precision): SIMD conversion with scalar fallback for subnormal values.

The 16-bit float SIMD path handles normal, zero, and inf/nan cases directly, falling back to scalar for the rare subnormal case, which requires variable-iteration normalization.

Also adds a BitDepth::f16() test helper and comprehensive unit tests for the conversion functions.
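For reference, the bit manipulation the f16 path performs can be sketched as a scalar routine. This is a minimal illustration, not the PR's actual code; `f16_bits_to_f32` is a hypothetical helper name. It shows why subnormals are the awkward case: they need a variable-iteration loop to renormalize the mantissa, while every other case is a fixed shift-and-rebias.

```rust
/// Scalar sketch: convert IEEE 754 half-precision bits to f32.
/// Layout of f16: 1 sign bit, 5 exponent bits, 10 fraction bits.
fn f16_bits_to_f32(bits: u16) -> f32 {
    let sign = (bits as u32 & 0x8000) << 16;
    let exp = (bits >> 10) & 0x1f;
    let frac = bits as u32 & 0x3ff;
    let out = match exp {
        0 if frac == 0 => sign, // signed zero
        0 => {
            // Subnormal: renormalize with a variable-iteration loop.
            // This is the case the SIMD path punts to scalar.
            let mut e: i32 = 127 - 15 + 1; // rebias, pre-adjusted for the loop
            let mut f = frac;
            while f & 0x400 == 0 {
                f <<= 1;
                e -= 1;
            }
            sign | ((e as u32) << 23) | ((f & 0x3ff) << 13)
        }
        0x1f => sign | 0x7f80_0000 | (frac << 13), // inf / NaN
        // Normal: rebias exponent (15 -> 127) and widen the fraction.
        _ => sign | ((exp as u32 + 127 - 15) << 23) | (frac << 13),
    };
    f32::from_bits(out)
}
```

The SIMD versions in the PR do the fixed-shift cases in vector registers and only drop to a loop like the one above for subnormal lanes.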
Benchmark @ 85ee297: comparing 352a1543 (Base) vs a1817c3d (PR)
Address veluca93 review: add load_f16_bits() and store_f16() methods to the F32SimdVec trait instead of implementing the conversion in convert.rs.

- AVX2+F16C: hardware _mm256_cvtph_ps/_mm256_cvtps_ph
- AVX-512: hardware _mm512_cvtph_ps/_mm512_cvtps_ph
- SSE4.2/NEON/scalar: scalar fallback

Simplifies convert.rs by ~100 lines.
```rust
fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self {
    assert!(mem.len() >= Self::LEN);
    // Check for F16C at runtime and use hardware conversion if available
    if is_x86_feature_detected!("f16c") {
```
That's not a good idea. Given that f16c is as common as avx2 (if not more), let's just always require f16c for the AVX2 path.
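The reviewer's suggestion amounts to folding the F16C check into the one-time dispatch that already selects the AVX2 path, rather than re-checking per call inside `load_f16_bits`. A minimal sketch, with a hypothetical `use_avx2_path` dispatch helper (the actual dispatch in jxl_simd may look different):

```rust
/// Sketch: require avx2 and f16c together at dispatch time, so the
/// AVX2 code path can unconditionally use the hardware f16 intrinsics.
#[cfg(target_arch = "x86_64")]
fn use_avx2_path() -> bool {
    // f16c is at least as widespread as avx2, so gating on both
    // costs essentially nothing in practice.
    is_x86_feature_detected!("avx2") && is_x86_feature_detected!("f16c")
}

#[cfg(not(target_arch = "x86_64"))]
fn use_avx2_path() -> bool {
    false
}
```

With this gating, the per-call `is_x86_feature_detected!("f16c")` branch (and its scalar fallback arm) disappears from the hot path.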
```rust
fn store_f16(this: F32VecNeon, dest: &mut [u16]) {
    assert!(dest.len() >= F32VecNeon::LEN);
    // TODO: Use vcvt_f16_f32 once Rust stdarch fix lands
```
I think at this point I would just use inline ASM here, but we can do that as a follow-up.
```rust
unsafe fn load_f16_impl(d: Avx512Descriptor, mem: &[u16]) -> F32VecAvx512 {
    // SAFETY: mem.len() >= 16 is checked by caller, and avx512f is available
    unsafe {
        let bits = _mm256_loadu_si256(mem.as_ptr() as *const __m256i);
```
Only the loadu needs to be in an unsafe block.
```rust
unsafe {
    // _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC = 0
    let bits = _mm512_cvtps_ph::<0>(v);
    _mm256_storeu_si256(dest.as_mut_ptr() as *mut __m256i, bits);
```
Similarly, only the store needs to be in an unsafe block.
```rust
// AVX512 implies F16C, so we can always use hardware conversion
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn load_f16_impl(d: Avx512Descriptor, mem: &[u16]) -> F32VecAvx512 {
```
This function does not need to be unsafe if we move the assert inside.
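The pattern the reviewer is pointing at: when the length check lives inside the function, the safety invariant is established locally, so the function itself can be safe and `unsafe` shrinks to just the intrinsic or raw-pointer call. A simplified sketch with hypothetical types (the real function uses AVX-512 loads):

```rust
/// Sketch: a safe entry point whose internal assert justifies the
/// unsafe block, instead of pushing the precondition onto callers.
fn load_16_checked(mem: &[u16]) -> [u16; 16] {
    // This assert establishes the invariant the unsafe code relies on,
    // so no `unsafe fn` (and no caller-side SAFETY obligation) is needed.
    assert!(mem.len() >= 16);
    let mut out = [0u16; 16];
    // SAFETY: the assert above guarantees at least 16 readable elements.
    unsafe {
        core::ptr::copy_nonoverlapping(mem.as_ptr(), out.as_mut_ptr(), 16);
    }
    out
}
```

The same reshuffle applies to `load_f16_impl`: move the caller's length check into the function body, drop `unsafe` from the signature, and keep the `#[target_feature]` dispatch handled by the descriptor as before.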
```rust
// SAFETY: dest.len() >= 16 is checked by caller, and avx512f is available
unsafe {
    // _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC = 0
    let bits = _mm512_cvtps_ph::<0>(v);
```
Let's please use ::<{_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC}>.
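The point of the suggestion is that a braced const expression of named flags is valid as a const-generic argument, so the magic `0` (whose accompanying comment is easy to get out of sync with the actual flag values) can be replaced by names. A self-contained illustration of the syntax, with stand-in constants rather than the real `core::arch` ones:

```rust
// Illustrative stand-ins for the stdarch rounding-control flags;
// the real constants live in core::arch::x86_64.
const FROUND_TO_NEAREST_INT: i32 = 0x00;
const FROUND_NO_EXC: i32 = 0x08;

/// Stand-in for an intrinsic taking its rounding mode as a const generic.
fn rounding_mode<const MODE: i32>() -> i32 {
    MODE
}

fn demo() -> i32 {
    // Braced const expressions are allowed in const-generic position,
    // so the call site reads as intent rather than a magic literal.
    rounding_mode::<{ FROUND_TO_NEAREST_INT | FROUND_NO_EXC }>()
}
```

Note that the named form also makes it visible that the combined value is not necessarily `0`, which the literal-plus-comment style in the diff obscures.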
```rust
use super::{F32SimdVec, I32SimdVec, SimdDescriptor, SimdMask};

/// Convert f16 bits (as u16) to f32.
```
There's already https://github.com/libjxl/jxl-rs/blob/main/jxl/src/util/float16.rs that has conversion code.
I think we should use that type and code (perhaps by moving the code to the jxl_simd crate), instead of using u16.
SIMD fast paths for the int_to_float function, which converts custom bit-depth floats stored as i32 back to f32:

- 32-bit float: straightforward bitcast via SIMD.
- 16-bit float (f16): SIMD handles normal values, zeros, and inf/nan. Subnormals fall back to scalar since they need a variable-iteration normalization loop.
Waiting for perf CI to see the impact.