diff --git a/jxl/src/headers/bit_depth.rs b/jxl/src/headers/bit_depth.rs
index 898e721e..e83217da 100644
--- a/jxl/src/headers/bit_depth.rs
+++ b/jxl/src/headers/bit_depth.rs
@@ -65,6 +65,14 @@ impl BitDepth {
             exponent_bits_per_sample: 8,
         }
     }
+    #[cfg(test)]
+    pub fn f16() -> BitDepth {
+        BitDepth {
+            floating_point_sample: true,
+            bits_per_sample: 16,
+            exponent_bits_per_sample: 5,
+        }
+    }
     pub fn bits_per_sample(&self) -> u32 {
         self.bits_per_sample
     }
diff --git a/jxl/src/image/internal.rs b/jxl/src/image/internal.rs
index 822842c9..0c77b4b3 100644
--- a/jxl/src/image/internal.rs
+++ b/jxl/src/image/internal.rs
@@ -165,8 +165,7 @@ impl RawImageBuffer {
         // invariant.
         let start = unsafe { self.buf.add(start) };
         // SAFETY: due to the struct safety invariant, we know the entire slice is in a range of
-        // memory valid for writes. Moreover, the caller promises not to write uninitialized data
-        // in the returned slice. Finally, the caller guarantees aliasing rules will not be violated.
+        // memory valid for reads. The caller guarantees aliasing rules will not be violated.
         unsafe { std::slice::from_raw_parts(start, self.bytes_per_row) }
     }
diff --git a/jxl/src/render/stages/convert.rs b/jxl/src/render/stages/convert.rs
index 3ded74f2..1ddd46f5 100644
--- a/jxl/src/render/stages/convert.rs
+++ b/jxl/src/render/stages/convert.rs
@@ -8,7 +8,7 @@ use crate::{
     headers::bit_depth::BitDepth,
     render::{Channels, ChannelsMut, RenderPipelineInOutStage},
 };
-use jxl_simd::{F32SimdVec, simd_function};
+use jxl_simd::{F32SimdVec, I32SimdVec, simd_function};
 
 pub struct ConvertU8F32Stage {
     channel: usize,
@@ -135,20 +135,82 @@ impl std::fmt::Display for ConvertModularToF32Stage {
     }
 }
 
+// SIMD 32-bit float passthrough (bitcast i32 to f32)
+simd_function!(
+    int_to_float_32bit_simd_dispatch,
+    d: D,
+    fn int_to_float_32bit_simd(input: &[i32], output: &mut [f32], xsize: usize) {
+        let simd_width = D::I32Vec::LEN;
+
+        // Process complete SIMD vectors
+        for (in_chunk, out_chunk) in input
+            .chunks_exact(simd_width)
+            .zip(output.chunks_exact_mut(simd_width))
+            .take(xsize.div_ceil(simd_width))
+        {
+            let val = D::I32Vec::load(d, in_chunk);
+            val.bitcast_to_f32().store(out_chunk);
+        }
+    }
+);
+
+// SIMD 16-bit float (half-precision) to 32-bit float conversion
+// Uses hardware F16C/NEON instructions when available via F32Vec::load_f16_bits()
+simd_function!(
+    int_to_float_16bit_simd_dispatch,
+    d: D,
+    fn int_to_float_16bit_simd(input: &[i32], output: &mut [f32], xsize: usize) {
+        let simd_width = D::F32Vec::LEN;
+
+        // Temporary buffer for i32->u16 conversion via SIMD
+        // Note: Using constant 16 (max AVX-512 width) because D::F32Vec::LEN
+        // cannot be used as array size in Rust (const generics limitation)
+        const { assert!(D::F32Vec::LEN <= 16) }
+        let mut u16_buf = [0u16; 16];
+
+        // Process complete SIMD vectors
+        for (in_chunk, out_chunk) in input
+            .chunks_exact(simd_width)
+            .zip(output.chunks_exact_mut(simd_width))
+            .take(xsize.div_ceil(simd_width))
+        {
+            // Use SIMD to extract lower 16 bits from each i32 lane
+            let i32_vec = D::I32Vec::load(d, in_chunk);
+            i32_vec.store_u16(&mut u16_buf[..simd_width]);
+            // Use hardware f16->f32 conversion
+            let result = D::F32Vec::load_f16_bits(d, &u16_buf[..simd_width]);
+            result.store(out_chunk);
+        }
+    }
+);
+
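// --- Illustrative sketch (not part of the patch) ----------------------------
// Scalar model of what `int_to_float_16bit_simd` above computes per lane,
// written against the `jxl_simd::f16` helper that this patch adds: each i32
// sample carries an f16 bit pattern in its low 16 bits, which is truncated to
// u16 and then widened to f32. The SIMD path does the same thing with
// `I32SimdVec::store_u16` followed by `F32SimdVec::load_f16_bits`.
// (`int_to_float_16bit_scalar_model` is a hypothetical name, for illustration only.)
fn int_to_float_16bit_scalar_model(input: &[i32], output: &mut [f32]) {
    for (&in_val, out_val) in input.iter().zip(output.iter_mut()) {
        let half_bits = in_val as u16; // keep only the low 16 bits
        *out_val = jxl_simd::f16::from_bits(half_bits).to_f32();
    }
}
// -----------------------------------------------------------------------------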
 // Converts custom [bits]-bit float (with [exp_bits] exponent bits) stored as
 // int back to binary32 float.
-// TODO(sboukortt): SIMD
-fn int_to_float(input: &[i32], output: &mut [f32], bit_depth: &BitDepth) {
+fn int_to_float(input: &[i32], output: &mut [f32], bit_depth: &BitDepth, xsize: usize) {
     assert_eq!(input.len(), output.len());
     let bits = bit_depth.bits_per_sample();
     let exp_bits = bit_depth.exponent_bits_per_sample();
-    if bits == 32 {
-        assert_eq!(exp_bits, 8);
-        for (&in_val, out_val) in input.iter().zip(output) {
-            *out_val = f32::from_bits(in_val as u32);
-        }
+
+    // Use SIMD fast paths for common formats
+    if bits == 32 && exp_bits == 8 {
+        // 32-bit float passthrough
+        int_to_float_32bit_simd_dispatch(input, output, xsize);
         return;
     }
+
+    if bits == 16 && exp_bits == 5 {
+        // IEEE 754 half-precision (f16) - common HDR format
+        int_to_float_16bit_simd_dispatch(input, output, xsize);
+        return;
+    }
+
+    // Generic scalar path for other custom float formats
+    int_to_float_generic(input, output, bits, exp_bits);
+}
+
+// Generic scalar conversion for arbitrary bit-depth floats
+// TODO: SIMD optimization for custom float formats
+fn int_to_float_generic(input: &[i32], output: &mut [f32], bits: u32, exp_bits: u32) {
     let exp_bias = (1 << (exp_bits - 1)) - 1;
     let sign_shift = bits - 1;
     let mant_bits = bits - exp_bits - 1;
@@ -215,12 +277,9 @@ impl RenderPipelineInOutStage for ConvertModularToF32Stage {
     ) {
         let input = &input_rows[0];
         if self.bit_depth.floating_point_sample() {
-            int_to_float(
-                &input[0][..xsize],
-                &mut output_rows[0][0][..xsize],
-                &self.bit_depth,
-            );
+            int_to_float(input[0], output_rows[0][0], &self.bit_depth, xsize);
         } else {
+            // TODO(veluca): SIMDfy this code.
             let scale = 1.0 / ((1u64 << self.bit_depth.bits_per_sample()) - 1) as f32;
             for i in 0..xsize {
                 output_rows[0][0][i] = input[0][i] as f32 * scale;
@@ -419,6 +478,7 @@ impl RenderPipelineInOutStage for ConvertF32ToF16Stage {
 mod test {
     use super::*;
     use crate::error::Result;
+    use crate::headers::bit_depth::BitDepth;
     use test_log::test;
 
     #[test]
@@ -467,4 +527,86 @@ mod test {
             1,
         )
     }
+
+    #[test]
+    fn test_int_to_float_32bit() {
+        // Test 32-bit float passthrough
+        let bit_depth = BitDepth::f32();
+        let test_values: Vec<f32> = vec![
+            0.0,
+            1.0,
+            -1.0,
+            0.5,
+            -0.5,
+            f32::INFINITY,
+            f32::NEG_INFINITY,
+            1e-30,
+            1e30,
+        ];
+        let input: Vec<i32> = test_values
+            .iter()
+            .map(|&f| f.to_bits() as i32)
+            .chain(std::iter::repeat(0))
+            .take(16)
+            .collect();
+        let mut output = vec![0.0f32; 16];
+
+        int_to_float(&input, &mut output, &bit_depth, test_values.len());
+
+        for (i, (&expected, &actual)) in test_values.iter().zip(output.iter()).enumerate() {
+            if expected.is_nan() {
+                assert!(actual.is_nan(), "index {}: expected NaN, got {}", i, actual);
+            } else {
+                assert_eq!(expected, actual, "index {}: mismatch", i);
+            }
+        }
+    }
+
+    #[test]
+    fn test_int_to_float_16bit() {
+        // Test 16-bit float (f16) conversion for normal values
+        let bit_depth = BitDepth::f16();
+
+        // f16 format: 1 sign, 5 exp, 10 mantissa
+        // Test cases: (f16_bits, expected_f32)
+        let test_cases: Vec<(u16, f32)> = vec![
+            (0x0000, 0.0),               // +0
+            (0x8000, -0.0),              // -0
+            (0x3C00, 1.0),               // 1.0
+            (0xBC00, -1.0),              // -1.0
+            (0x3800, 0.5),               // 0.5
+            (0x4000, 2.0),               // 2.0
+            (0x4400, 4.0),               // 4.0
+            (0x7BFF, 65504.0),           // max normal f16
+            (0x7C00, f32::INFINITY),     // +inf
+            (0xFC00, f32::NEG_INFINITY), // -inf
+            (0x0001, 5.960_464_5e-8),    // smallest positive subnormal
+            (0x03FF, 6.097_555e-5),      // largest positive subnormal
+            (0x8001, -5.960_464_5e-8),   // smallest negative subnormal
+        ];
+
+        let input: Vec<i32> = test_cases
+            .iter()
+            .map(|(bits, _)| *bits as i32)
+            .chain(std::iter::repeat(0))
+            .take(16)
+            .collect();
+        let mut output = vec![0.0f32; 16];
+
+        int_to_float(&input, &mut output, &bit_depth, test_cases.len());
+
+        for (i, (&(_, expected), &actual)) in test_cases.iter().zip(output.iter()).enumerate() {
+            assert!(
+                (expected - actual).abs() < 1e-6
+                    || expected == actual
+                    || (expected.is_sign_negative() == actual.is_sign_negative()
+                        && expected == 0.0
+                        && actual == 0.0),
+                "index {}: expected {}, got {}",
+                i,
+                expected,
+                actual
+            );
+        }
+    }
 }
diff --git a/jxl_simd/src/aarch64/neon.rs b/jxl_simd/src/aarch64/neon.rs
index 66f20d15..c0d64993 100644
--- a/jxl_simd/src/aarch64/neon.rs
+++ b/jxl_simd/src/aarch64/neon.rs
@@ -441,6 +441,42 @@ unsafe impl F32SimdVec for F32VecNeon {
             vst1_u16(dest.as_mut_ptr(), u16s);
         }
     }
+
+    fn store_f16_bits(self, dest: &mut [u16]) {
+        assert!(dest.len() >= F32VecNeon::LEN);
+        // Use inline asm because Rust stdarch incorrectly requires the fp16 target feature
+        // for vcvt_f16_f32 (fixed in https://github.com/rust-lang/stdarch/pull/1978).
+        let f16_bits: uint16x4_t;
+        // SAFETY: NEON is available (guaranteed by the descriptor), and `dest` has enough space.
+        unsafe {
+            std::arch::asm!(
+                "fcvtn {out:v}.4h, {inp:v}.4s",
+                inp = in(vreg) self.0,
+                out = out(vreg) f16_bits,
+                options(pure, nomem, nostack),
+            );
+            vst1_u16(dest.as_mut_ptr(), f16_bits);
+        }
+    }
+
+    #[inline(always)]
+    fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self {
+        assert!(mem.len() >= Self::LEN);
+        // Use inline asm because Rust stdarch incorrectly requires the fp16 target feature
+        // for vcvt_f32_f16 (fixed in https://github.com/rust-lang/stdarch/pull/1978).
+        let result: float32x4_t;
+        // SAFETY: NEON is available (guaranteed by the descriptor), and `mem` has enough space.
+        unsafe {
+            let f16_bits = vld1_u16(mem.as_ptr());
+            std::arch::asm!(
+                "fcvtl {out:v}.4s, {inp:v}.4h",
+                inp = in(vreg) f16_bits,
+                out = out(vreg) result,
+                options(pure, nomem, nostack),
+            );
+        }
+        F32VecNeon(result, d)
+    }
 
     #[inline(always)]
@@ -450,7 +486,8 @@ unsafe impl F32SimdVec for F32VecNeon {
     fn prepare_impl(table: &[f32; 8]) -> uint8x16_t {
         // Convert f32 table to BF16 packed in 128 bits (16 bytes for 8 entries)
         // BF16 is the high 16 bits of f32
-        // SAFETY: neon is available from target_feature
+        // SAFETY: neon is available from target_feature, and `table` is large
+        // enough for the loads.
         let (table_lo, table_hi) =
             unsafe { (vld1q_f32(table.as_ptr()), vld1q_f32(table.as_ptr().add(4))) };
@@ -653,6 +690,18 @@ impl I32SimdVec for I32VecNeon {
         // SAFETY: We know neon is available from the safety invariant on `self.1`.
         unsafe { Self(vshrq_n_s32::(self.0), self.1) }
     }
+
+    #[inline(always)]
+    fn store_u16(self, dest: &mut [u16]) {
+        assert!(dest.len() >= Self::LEN);
+        // SAFETY: We know neon is available from the safety invariant on `self.1`,
+        // and we just checked that `dest` has enough space.
+        unsafe {
+            // vmovn narrows i32 to i16 by taking the lower 16 bits
+            let narrowed = vmovn_s32(self.0);
+            vst1_u16(dest.as_mut_ptr(), vreinterpret_u16_s16(narrowed));
+        }
+    }
 }
 
 impl Add for I32VecNeon {
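// --- Illustrative sketch (not part of the patch) ----------------------------
// The NEON `fcvtn`/`fcvtl` inline-asm paths above can be cross-checked against
// the portable `f16` fallback introduced later in this patch. A hypothetical
// round-trip check (the `NeonDescriptor` type name is an assumption here):
#[cfg(all(test, target_arch = "aarch64"))]
fn check_neon_f16_roundtrip(d: NeonDescriptor) {
    // 0.0, 1.0, -1.0 and 65504.0 (the largest finite f16) as raw bit patterns.
    let bits: [u16; 4] = [0x0000, 0x3C00, 0xBC00, 0x7BFF];
    let mut out = [0u16; 4];
    let v = F32VecNeon::load_f16_bits(d, &bits);
    v.store_f16_bits(&mut out);
    // f16 -> f32 -> f16 is exact for every finite f16 value.
    assert_eq!(bits, out);
}
// -----------------------------------------------------------------------------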
diff --git a/jxl_simd/src/float16.rs b/jxl_simd/src/float16.rs
new file mode 100644
index 00000000..8fb07e0f
--- /dev/null
+++ b/jxl_simd/src/float16.rs
@@ -0,0 +1,312 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//! IEEE 754 half-precision (binary16) floating-point type.
+//!
+//! This is a minimal implementation providing only the operations needed for JPEG XL decoding,
+//! avoiding external dependencies like `half`, which pulls in `zerocopy`.
+
+/// IEEE 754 binary16 half-precision floating-point type.
+///
+/// Format: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits.
+#[allow(non_camel_case_types)]
+#[derive(Copy, Clone, Default, PartialEq, Eq, Hash)]
+#[repr(transparent)]
+pub struct f16(u16);
+
+impl f16 {
+    /// Positive zero.
+    pub const ZERO: Self = Self(0);
+
+    /// Creates an f16 from its raw bit representation.
+    #[inline]
+    pub const fn from_bits(bits: u16) -> Self {
+        Self(bits)
+    }
+
+    /// Returns the raw bit representation.
+    #[inline]
+    pub const fn to_bits(self) -> u16 {
+        self.0
+    }
+
+    /// Converts to f32.
+    #[inline]
+    pub fn to_f32(self) -> f32 {
+        let bits = self.0;
+        let sign = ((bits >> 15) & 1) as u32;
+        let exp = ((bits >> 10) & 0x1F) as u32;
+        let mant = (bits & 0x3FF) as u32;
+
+        let f32_bits = if exp == 0 {
+            if mant == 0 {
+                // Zero (signed)
+                sign << 31
+            } else {
+                // Denormal f16 -> normalized f32
+                // Find the leading 1 bit in the mantissa
+                let mut m = mant;
+                let mut e = 0u32;
+                while (m & 0x400) == 0 {
+                    m <<= 1;
+                    e += 1;
+                }
+                m &= 0x3FF; // Remove the implicit leading 1
+                // f16 denormal exponent is -14 (not -15), adjust by shift count
+                let new_exp = 127 - 14 - e;
+                (sign << 31) | (new_exp << 23) | (m << 13)
+            }
+        } else if exp == 31 {
+            // Infinity or NaN
+            if mant == 0 {
+                // Infinity
+                (sign << 31) | (0xFF << 23)
+            } else {
+                // NaN - preserve some payload bits, ensure quiet NaN
+                (sign << 31) | (0xFF << 23) | (mant << 13) | 0x0040_0000
+            }
+        } else {
+            // Normal number
+            // Rebias: f16 uses bias 15, f32 uses bias 127
+            // new_exp = exp - 15 + 127 = exp + 112
+            let new_exp = exp + 112;
+            (sign << 31) | (new_exp << 23) | (mant << 13)
+        };
+
+        f32::from_bits(f32_bits)
+    }
+
+    /// Creates an f16 from an f32.
+    #[inline]
+    pub fn from_f32(f: f32) -> Self {
+        let bits = f.to_bits();
+        let sign = ((bits >> 31) & 1) as u16;
+        let exp = ((bits >> 23) & 0xFF) as i32;
+        let mant = bits & 0x007F_FFFF;
+
+        let h_bits = if exp == 0 {
+            // Zero or f32 denormal -> f16 zero (too small)
+            sign << 15
+        } else if exp == 255 {
+            // Infinity or NaN
+            if mant == 0 {
+                (sign << 15) | (0x1F << 10) // Infinity
+            } else {
+                (sign << 15) | (0x1F << 10) | 0x0200 // Quiet NaN
+            }
+        } else {
+            let unbiased = exp - 127;
+
+            if unbiased < -24 {
+                // Too small, underflow to zero
+                sign << 15
+            } else if unbiased < -14 {
+                // Denormal f16: the result encodes m * 2^-24; discarded bits are truncated.
+                let shift = (-14 - unbiased) as u32;
+                let m = ((mant | 0x0080_0000) >> (shift + 13)) as u16;
+                (sign << 15) | m
+            } else if unbiased > 15 {
+                // Overflow to infinity
+                (sign << 15) | (0x1F << 10)
+            } else {
+                // Normal f16
+                let h_exp = (unbiased + 15) as u16;
+                let h_mant = (mant >> 13) as u16;
+
+                // Round to nearest, ties to even
+                let round_bit = (mant >> 12) & 1;
+                let sticky = mant & 0x0FFF;
+                let h_mant = if round_bit == 1 && (sticky != 0 || (h_mant & 1) == 1) {
+                    h_mant + 1
+                } else {
+                    h_mant
+                };
+
+                // Handle mantissa overflow from rounding
+                if h_mant > 0x3FF {
+                    if h_exp >= 30 {
+                        // Overflow to infinity
+                        (sign << 15) | (0x1F << 10)
+                    } else {
+                        (sign << 15) | ((h_exp + 1) << 10)
+                    }
+                } else {
+                    (sign << 15) | (h_exp << 10) | h_mant
+                }
+            }
+        };
+
+        Self(h_bits)
+    }
+
+    /// Creates an f16 from an f64.
+    #[inline]
+    pub fn from_f64(f: f64) -> Self {
+        // Convert via f32 - sufficient precision for f16
+        Self::from_f32(f as f32)
+    }
+
+    /// Converts to f64.
+    #[inline]
+    pub fn to_f64(self) -> f64 {
+        self.to_f32() as f64
+    }
+
+    /// Returns true if this is neither infinite nor NaN.
+    #[inline]
+    pub fn is_finite(self) -> bool {
+        // Exponent of 31 means infinity or NaN
+        ((self.0 >> 10) & 0x1F) != 31
+    }
+
+    /// Returns the bytes in little-endian order.
+    #[inline]
+    pub const fn to_le_bytes(self) -> [u8; 2] {
+        self.0.to_le_bytes()
+    }
+
+    /// Returns the bytes in big-endian order.
+    #[inline]
+    pub const fn to_be_bytes(self) -> [u8; 2] {
+        self.0.to_be_bytes()
+    }
+}
+
+impl From<f16> for f32 {
+    #[inline]
+    fn from(f: f16) -> f32 {
+        f.to_f32()
+    }
+}
+
+impl From<f16> for f64 {
+    #[inline]
+    fn from(f: f16) -> f64 {
+        f.to_f64()
+    }
+}
+
+impl core::fmt::Debug for f16 {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        write!(f, "{}", self.to_f32())
+    }
+}
+
+impl core::fmt::Display for f16 {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        write!(f, "{}", self.to_f32())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_zero() {
+        let z = f16::ZERO;
+        assert_eq!(z.to_bits(), 0);
+        assert_eq!(z.to_f32(), 0.0);
+        assert!(z.is_finite());
+    }
+
+    #[test]
+    fn test_one() {
+        // 1.0 in f16: sign=0, exp=15 (biased), mant=0 -> 0x3C00
+        let one = f16::from_bits(0x3C00);
+        assert!((one.to_f32() - 1.0).abs() < 1e-6);
+        assert!(one.is_finite());
+    }
+
+    #[test]
+    fn test_negative_one() {
+        // -1.0 in f16: sign=1, exp=15, mant=0 -> 0xBC00
+        let neg_one = f16::from_bits(0xBC00);
+        assert!((neg_one.to_f32() - (-1.0)).abs() < 1e-6);
+    }
+
+    #[test]
+    fn test_infinity() {
+        // +Inf: sign=0, exp=31, mant=0 -> 0x7C00
+        let inf = f16::from_bits(0x7C00);
+        assert!(inf.to_f32().is_infinite());
+        assert!(!inf.is_finite());
+
+        // -Inf: 0xFC00
+        let neg_inf = f16::from_bits(0xFC00);
+        assert!(neg_inf.to_f32().is_infinite());
+        assert!(!neg_inf.is_finite());
+    }
+
+    #[test]
+    fn test_nan() {
+        // NaN: exp=31, mant!=0 -> 0x7C01 (or any mant != 0)
+        let nan = f16::from_bits(0x7C01);
+        assert!(nan.to_f32().is_nan());
+        assert!(!nan.is_finite());
+    }
+
+    #[test]
+    fn test_denormal() {
+        // Smallest positive denormal: 0x0001
+        let tiny = f16::from_bits(0x0001);
+        let val = tiny.to_f32();
+        assert!(val > 0.0);
+        assert!(val < 1e-6);
+        assert!(tiny.is_finite());
+    }
+
+    #[test]
+    fn test_roundtrip_normal() {
+        let test_values: [f32; 8] = [0.5, 1.0, 2.0, 100.0, 0.001, -0.5, -1.0, -100.0];
+        for &v in &test_values {
+            let h = f16::from_f32(v);
+            let back = h.to_f32();
+            // f16 has limited precision, allow ~0.1% error for normal values
+            let rel_err = ((v - back) / v).abs();
+            assert!(
+                rel_err < 0.002,
+                "Roundtrip failed for {}: got {}, rel_err {}",
+                v,
+                back,
+                rel_err
+            );
+        }
+    }
+
+    #[test]
+    fn test_roundtrip_special() {
+        // Zero
+        assert_eq!(f16::from_f32(0.0).to_f32(), 0.0);
+
+        // Infinity
+        assert!(f16::from_f32(f32::INFINITY).to_f32().is_infinite());
+        assert!(f16::from_f32(f32::NEG_INFINITY).to_f32().is_infinite());
+
+        // NaN
+        assert!(f16::from_f32(f32::NAN).to_f32().is_nan());
+    }
+
+    #[test]
+    fn test_overflow_to_infinity() {
+        // f16 max is ~65504, values above should overflow to infinity
+        let big = f16::from_f32(100000.0);
+        assert!(big.to_f32().is_infinite());
+    }
+
+    #[test]
+    fn test_underflow_to_zero() {
+        // Very small values should underflow to zero
+        let tiny = f16::from_f32(1e-10);
+        assert_eq!(tiny.to_f32(), 0.0);
+    }
+
+    #[test]
+    fn test_bytes() {
+        let h = f16::from_bits(0x1234);
+        assert_eq!(h.to_le_bytes(), [0x34, 0x12]);
+        assert_eq!(h.to_be_bytes(), [0x12, 0x34]);
+    }
+}
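// --- Illustrative sketch (not part of the patch) ----------------------------
// Minimal usage of the `f16` type added above, as seen from a crate that
// depends on `jxl_simd`; only bit-level constructors and widening conversions
// are needed for decoding. (`f16_usage_example` is a hypothetical name.)
fn f16_usage_example() {
    use jxl_simd::f16;
    let one = f16::from_bits(0x3C00); // 1.0 in binary16
    assert_eq!(one.to_f32(), 1.0);
    let quarter = f16::from_f32(0.25); // 0.25 = 2^-2 -> biased exponent 13, mantissa 0
    assert_eq!(quarter.to_bits(), 0x3400);
    assert_eq!(f32::from(quarter), 0.25);
}
// -----------------------------------------------------------------------------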
diff --git a/jxl_simd/src/lib.rs b/jxl_simd/src/lib.rs
index c9aa8fb9..4f06dbdd 100644
--- a/jxl_simd/src/lib.rs
+++ b/jxl_simd/src/lib.rs
@@ -20,7 +20,10 @@ mod x86_64;
 #[cfg(target_arch = "aarch64")]
 mod aarch64;
 
-mod scalar;
+pub mod float16;
+pub mod scalar;
+
+pub use float16::f16;
 
 #[cfg(all(target_arch = "x86_64", feature = "avx"))]
 pub use x86_64::avx::AvxDescriptor;
@@ -270,6 +273,16 @@ pub unsafe trait F32SimdVec:
     /// Transposes the Self::LEN x Self::LEN matrix formed by array elements
     /// `data[stride * i]` for i = 0..Self::LEN.
     fn transpose_square(d: Self::Descriptor, data: &mut [Self::UnderlyingArray], stride: usize);
+
+    /// Loads f16 values (stored as u16 bit patterns) and converts them to f32.
+    /// Uses hardware conversion instructions when available (F16C on x86, NEON fp16 on ARM).
+    /// Requires `mem.len() >= Self::LEN` or it will panic.
+    fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self;
+
+    /// Converts f32 values to f16 and stores them as u16 bit patterns.
+    /// Uses hardware conversion instructions when available (F16C on x86, NEON fp16 on ARM).
+    /// Requires `dest.len() >= Self::LEN` or it will panic.
+    fn store_f16_bits(self, dest: &mut [u16]);
 }
 
 pub trait I32SimdVec:
@@ -327,6 +340,10 @@ pub trait I32SimdVec:
     fn shr(self) -> Self;
 
     fn mul_wide_take_high(self, rhs: Self) -> Self;
+
+    /// Stores the lower 16 bits of each i32 lane as u16 values.
+    /// Requires `dest.len() >= Self::LEN` or it will panic.
+    fn store_u16(self, dest: &mut [u16]);
 }
 
 pub trait U32SimdVec: Sized + Copy + Debug + Send + Sync {
@@ -1162,4 +1179,40 @@ mod test {
         }
     }
     test_all_instruction_sets!(test_i32_mul_all_elements);
+
+    fn test_store_u16<D: SimdDescriptor>(d: D) {
+        let data = [
+            0xbabau32 as i32,
+            0x1234u32 as i32,
+            0xdeadbabau32 as i32,
+            0xdead1234u32 as i32,
+            0x1111babau32 as i32,
+            0x11111234u32 as i32,
+            0x76543210u32 as i32,
+            0x01234567u32 as i32,
+            0x00000000u32 as i32,
+            0xffffffffu32 as i32,
+            0x23949289u32 as i32,
+            0xf9371913u32 as i32,
+            0xdeadbeefu32 as i32,
+            0xbeefdeadu32 as i32,
+            0xaaaaaaaau32 as i32,
+            0xbbbbbbbbu32 as i32,
+        ];
+        let mut output = [0u16; 16];
+        for i in (0..16).step_by(D::I32Vec::LEN) {
+            let vec = D::I32Vec::load(d, &data[i..]);
+            vec.store_u16(&mut output[i..]);
+        }
+
+        for i in 0..16 {
+            let expected = data[i] as u16;
+            assert_eq!(
+                output[i], expected,
+                "store_u16 failed at index {}: expected {}, got {}",
+                i, expected, output[i]
+            );
+        }
+    }
+    test_all_instruction_sets!(test_store_u16);
 }
diff --git a/jxl_simd/src/scalar.rs b/jxl_simd/src/scalar.rs
index 9667e7e3..f0444c34 100644
--- a/jxl_simd/src/scalar.rs
+++ b/jxl_simd/src/scalar.rs
@@ -6,7 +6,7 @@
 use std::mem::MaybeUninit;
 use std::num::Wrapping;
 
-use crate::{U32SimdVec, impl_f32_array_interface};
+use crate::{U32SimdVec, f16, impl_f32_array_interface};
 
 use super::{F32SimdVec, I32SimdVec, SimdDescriptor, SimdMask};
 
@@ -213,6 +213,16 @@ unsafe impl F32SimdVec for f32 {
         dest[0] = self.round() as u16;
     }
 
+    #[inline(always)]
+    fn load_f16_bits(_d: Self::Descriptor, mem: &[u16]) -> Self {
+        f16::from_bits(mem[0]).to_f32()
+    }
+
+    #[inline(always)]
+    fn store_f16_bits(self, dest: &mut [u16]) {
+        dest[0] = f16::from_f32(self).to_bits();
+    }
+
     impl_f32_array_interface!();
 
     #[inline(always)]
@@ -295,6 +305,11 @@ impl I32SimdVec for Wrapping<i32> {
     fn mul_wide_take_high(self, rhs: Self) -> Self {
         Wrapping(((self.0 as i64 * rhs.0 as i64) >> 32) as i32)
     }
+
+    #[inline(always)]
+    fn store_u16(self, dest: &mut [u16]) {
+        dest[0] = self.0 as u16;
+    }
 }
 
 impl U32SimdVec for Wrapping<u32> {
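// --- Illustrative sketch (not part of the patch) ----------------------------
// How the new trait methods compose for code that is generic over a SIMD
// descriptor (assuming the `SimdDescriptor`, `I32SimdVec` and `F32SimdVec`
// traits are in scope): truncate each i32 lane to its low 16 bits with
// `store_u16`, then widen those f16 bit patterns to f32 with `load_f16_bits`.
// This mirrors the loop body of `int_to_float_16bit_simd` in convert.rs.
// (`widen_f16_lanes` is a hypothetical helper, for illustration only.)
fn widen_f16_lanes<D: SimdDescriptor>(d: D, samples: &[i32], out: &mut [f32]) {
    let mut bits = [0u16; 16]; // 16 covers the widest supported vector
    let n = D::F32Vec::LEN;
    D::I32Vec::load(d, &samples[..n]).store_u16(&mut bits[..n]);
    D::F32Vec::load_f16_bits(d, &bits[..n]).store(&mut out[..n]);
}
// -----------------------------------------------------------------------------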
diff --git a/jxl_simd/src/x86_64/avx.rs b/jxl_simd/src/x86_64/avx.rs
index 6de9f53a..0da8ec9f 100644
--- a/jxl_simd/src/x86_64/avx.rs
+++ b/jxl_simd/src/x86_64/avx.rs
@@ -96,13 +96,13 @@ fn transpose_8x8_core(
     (c0, c1, c2, c3, c4, c5, c6, c7)
 }
 
-// Safety invariant: this type is only ever constructed if avx2 and fma are available.
+// Safety invariant: this type is only ever constructed if avx2, fma, and f16c are available.
 #[derive(Clone, Copy, Debug)]
 pub struct AvxDescriptor(());
 
 impl AvxDescriptor {
     /// # Safety
-    /// The caller must guarantee that the "avx2" and "fma" target features are available.
+    /// The caller must guarantee that the "avx2", "fma", and "f16c" target features are available.
     pub unsafe fn new_unchecked() -> Self {
         Self(())
     }
@@ -139,8 +139,11 @@ impl SimdDescriptor for AvxDescriptor {
     }
 
     fn new() -> Option<Self> {
-        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
-            // SAFETY: we just checked avx2 and fma.
+        if is_x86_feature_detected!("avx2")
+            && is_x86_feature_detected!("fma")
+            && is_x86_feature_detected!("f16c")
+        {
+            // SAFETY: we just checked avx2, fma, and f16c.
             Some(unsafe { Self::new_unchecked() })
         } else {
             None
@@ -148,12 +151,12 @@ impl SimdDescriptor for AvxDescriptor {
     }
 
     fn call<R>(self, f: impl FnOnce(Self) -> R) -> R {
-        #[target_feature(enable = "avx2,fma")]
+        #[target_feature(enable = "avx2,fma,f16c")]
        #[inline(never)]
        unsafe fn inner<R>(d: AvxDescriptor, f: impl FnOnce(AvxDescriptor) -> R) -> R {
            f(d)
        }
-        // SAFETY: the safety invariant on `self` guarantees avx2 and fma.
+        // SAFETY: the safety invariant on `self` guarantees avx2, fma, and f16c.
         unsafe { inner(self, f) }
     }
 }
@@ -165,12 +168,12 @@ macro_rules! fn_avx {
     fn $name:ident($($arg:ident: $ty:ty),* $(,)?) $(-> $ret:ty )? $body: block) => {
         #[inline(always)]
         fn $name(self: $self_ty, $($arg: $ty),*) $(-> $ret)? {
-            #[target_feature(enable = "fma,avx2")]
+            #[target_feature(enable = "fma,avx2,f16c")]
            #[inline]
            fn inner($this: $self_ty, $($arg: $ty),*) $(-> $ret)? {
                $body
            }
-            // SAFETY: `self.1` is constructed iff avx2 and fma are available.
+            // SAFETY: `self.1` is constructed iff avx2, fma, and f16c are available.
            unsafe { inner(self, $($arg),*) }
        }
    };
@@ -604,7 +607,8 @@ unsafe impl F32SimdVec for F32VecAvx {
     #[inline(always)]
     fn prepare_table_bf16_8(_d: AvxDescriptor, table: &[f32; 8]) -> Bf16Table8Avx {
         // For AVX2, vpermps is exact and fast, so we just load the table as-is
-        // SAFETY: avx2 is available from the safety invariant on the descriptor
+        // SAFETY: avx2 is available from the safety invariant on the descriptor,
+        // and `table` has 8 elements, exactly as many as we load.
         Bf16Table8Avx(unsafe { _mm256_loadu_ps(table.as_ptr()) })
     }
 
@@ -668,6 +672,34 @@ unsafe impl F32SimdVec for F32VecAvx {
 
     impl_f32_array_interface!();
 
+    #[inline(always)]
+    fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self {
+        #[target_feature(enable = "avx2,f16c")]
+        #[inline]
+        fn load_f16_impl(d: AvxDescriptor, mem: &[u16]) -> F32VecAvx {
+            assert!(mem.len() >= F32VecAvx::LEN);
+            // SAFETY: mem.len() >= 8 is checked above
+            let bits = unsafe { _mm_loadu_si128(mem.as_ptr() as *const __m128i) };
+            F32VecAvx(_mm256_cvtph_ps(bits), d)
+        }
+        // SAFETY: avx2 and f16c are available from the safety invariant on the descriptor
+        unsafe { load_f16_impl(d, mem) }
+    }
+
+    #[inline(always)]
+    fn store_f16_bits(self, dest: &mut [u16]) {
+        #[target_feature(enable = "avx2,f16c")]
+        #[inline]
+        fn store_f16_bits_impl(v: __m256, dest: &mut [u16]) {
+            assert!(dest.len() >= F32VecAvx::LEN);
+            let bits = _mm256_cvtps_ph::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(v);
+            // SAFETY: dest.len() >= 8 is checked above
+            unsafe { _mm_storeu_si128(dest.as_mut_ptr() as *mut __m128i, bits) };
+        }
+        // SAFETY: avx2 and f16c are available from the safety invariant on the descriptor
+        unsafe { store_f16_bits_impl(self.0, dest) }
+    }
+
     #[inline(always)]
     fn transpose_square(d: Self::Descriptor, data: &mut [Self::UnderlyingArray], stride: usize) {
         #[target_feature(enable = "avx2")]
@@ -846,6 +878,29 @@ impl I32SimdVec for I32VecAvx {
         let p1 = _mm256_unpackhi_epi32(l, h);
         I32VecAvx(_mm256_unpackhi_epi64(p0, p1), this.1)
     });
+
+    #[inline(always)]
+    fn store_u16(self, dest: &mut [u16]) {
+        #[target_feature(enable = "avx2")]
+        #[inline]
+        fn store_u16_impl(v: __m256i, dest: &mut [u16]) {
+            assert!(dest.len() >= I32VecAvx::LEN);
+            let tmp = _mm256_shuffle_epi8(
+                v,
+                _mm256_setr_epi8(
+                    0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15, //
+                    0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15,
+                ),
+            );
+            let tmp = _mm256_permute4x64_epi64(tmp, 0xD8);
+            // SAFETY: we just checked that `dest` has enough space.
+            unsafe {
+                _mm_storeu_si128(dest.as_mut_ptr().cast(), _mm256_extracti128_si256::<0>(tmp))
+            };
+        }
+        // SAFETY: avx2 is available from the safety invariant on the descriptor.
+        unsafe { store_u16_impl(self.0, dest) }
+    }
 }
 
 impl Add for I32VecAvx {
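// --- Illustrative sketch (not part of the patch) ----------------------------
// What the AVX2 `store_u16` above computes: `_mm256_shuffle_epi8` gathers the
// low u16 of each i32 into the low 8 bytes of its 128-bit lane, and
// `_mm256_permute4x64_epi64(tmp, 0xD8)` places those two 8-byte groups next to
// each other so a single 128-bit store writes all eight values. The result
// must match plain per-lane truncation, as in this scalar reference
// (`store_u16_reference` is a hypothetical name, for illustration only):
fn store_u16_reference(v: [i32; 8]) -> [u16; 8] {
    let mut out = [0u16; 8];
    for (o, &x) in out.iter_mut().zip(v.iter()) {
        *o = x as u16; // `as` keeps the low 16 bits, same as the byte shuffle
    }
    out
}
// -----------------------------------------------------------------------------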
diff --git a/jxl_simd/src/x86_64/avx512.rs b/jxl_simd/src/x86_64/avx512.rs
index e64c5f08..89086c50 100644
--- a/jxl_simd/src/x86_64/avx512.rs
+++ b/jxl_simd/src/x86_64/avx512.rs
@@ -665,7 +665,8 @@ unsafe impl F32SimdVec for F32VecAvx512 {
     #[target_feature(enable = "avx512f")]
     #[inline]
     fn prepare_impl(table: &[f32; 8]) -> __m512 {
-        // SAFETY: avx512f is available from target_feature
+        // SAFETY: avx512f is available from target_feature, and we load 8 elements,
+        // exactly as many as are present in `table`.
         let table_256 = unsafe { _mm256_loadu_ps(table.as_ptr()) };
         // Zero-extend to 512-bit; vpermutexvar with indices 0-7 only reads first 256 bits
         _mm512_castps256_ps512(table_256)
@@ -730,6 +731,36 @@ unsafe impl F32SimdVec for F32VecAvx512 {
 
     impl_f32_array_interface!();
 
+    #[inline(always)]
+    fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self {
+        // AVX512 implies F16C, so we can always use hardware conversion
+        #[target_feature(enable = "avx512f")]
+        #[inline]
+        fn load_f16_impl(d: Avx512Descriptor, mem: &[u16]) -> F32VecAvx512 {
+            assert!(mem.len() >= F32VecAvx512::LEN);
+            // SAFETY: mem.len() >= 16 is checked above
+            let bits = unsafe { _mm256_loadu_si256(mem.as_ptr() as *const __m256i) };
+            F32VecAvx512(_mm512_cvtph_ps(bits), d)
+        }
+        // SAFETY: avx512f is available from the safety invariant on the descriptor
+        unsafe { load_f16_impl(d, mem) }
+    }
+
+    #[inline(always)]
+    fn store_f16_bits(self, dest: &mut [u16]) {
+        // AVX512 implies F16C, so we can always use hardware conversion
+        #[target_feature(enable = "avx512f")]
+        #[inline]
+        fn store_f16_bits_impl(v: __m512, dest: &mut [u16]) {
+            assert!(dest.len() >= F32VecAvx512::LEN);
+            let bits = _mm512_cvtps_ph::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(v);
+            // SAFETY: dest.len() >= 16 is checked above
+            unsafe { _mm256_storeu_si256(dest.as_mut_ptr() as *mut __m256i, bits) };
+        }
+        // SAFETY: avx512f is available from the safety invariant on the descriptor
+        unsafe { store_f16_bits_impl(self.0, dest) }
+    }
+
     #[inline(always)]
     fn transpose_square(d: Self::Descriptor, data: &mut [Self::UnderlyingArray], stride: usize) {
         #[target_feature(enable = "avx512f")]
@@ -1025,6 +1056,20 @@ impl I32SimdVec for I32VecAvx512 {
         let idx = _mm512_setr_epi32(1, 17, 3, 19, 5, 21, 7, 23, 9, 25, 11, 27, 13, 29, 15, 31);
         I32VecAvx512(_mm512_permutex2var_epi32(l, idx, h), this.1)
     });
+
+    #[inline(always)]
+    fn store_u16(self, dest: &mut [u16]) {
+        #[target_feature(enable = "avx512f")]
+        #[inline]
+        fn store_u16_impl(v: __m512i, dest: &mut [u16]) {
+            assert!(dest.len() >= I32VecAvx512::LEN);
+            let tmp = _mm512_cvtepi32_epi16(v);
+            // SAFETY: We just checked `dest` has enough space.
+            unsafe { _mm256_storeu_epi32(dest.as_mut_ptr().cast(), tmp) };
+        }
+        // SAFETY: avx512f is available from the safety invariant on the descriptor.
+        unsafe { store_u16_impl(self.0, dest) }
+    }
 }
 
 impl Add for I32VecAvx512 {
diff --git a/jxl_simd/src/x86_64/mod.rs b/jxl_simd/src/x86_64/mod.rs
index 939dd24a..1a5463b1 100644
--- a/jxl_simd/src/x86_64/mod.rs
+++ b/jxl_simd/src/x86_64/mod.rs
@@ -62,16 +62,16 @@ macro_rules! simd_function_body_sse42 {
 #[macro_export]
 macro_rules! simd_function_body_avx {
     ($name:ident($($arg:ident: $ty:ty),* $(,)?) $(-> $ret:ty )?; ($($val:expr),* $(,)?)) => {
-        if cfg!(all(target_feature = "avx2", target_feature = "fma")) {
-            // SAFETY: we just checked for avx2 and fma.
+        if cfg!(all(target_feature = "avx2", target_feature = "fma", target_feature = "f16c")) {
+            // SAFETY: we just checked for avx2, fma and f16c.
             let d = unsafe { $crate::AvxDescriptor::new_unchecked() };
             return $name(d, $($val),*);
         } else if let Some(d) = $crate::AvxDescriptor::new() {
-            #[target_feature(enable = "avx2,fma")]
+            #[target_feature(enable = "avx2,fma,f16c")]
             fn avx(d: $crate::AvxDescriptor, $($arg: $ty),*) $(-> $ret)? {
                 $name(d, $($val),*)
             }
-            // SAFETY: we just checked for avx2 and fma.
+            // SAFETY: we just checked for avx2, fma and f16c.
            return unsafe { avx(d, $($arg),*) };
        }
    };
@@ -170,11 +170,11 @@ macro_rules! test_avx {
             fn [<$name _avx>]() {
                 use $crate::SimdDescriptor;
                 let Some(d) = $crate::AvxDescriptor::new() else { return; };
-                #[target_feature(enable = "avx2,fma")]
+                #[target_feature(enable = "avx2,fma,f16c")]
                 fn inner(d: $crate::AvxDescriptor) {
                     $name(d)
                 }
-                // SAFETY: we just checked for avx2 and fma.
+                // SAFETY: we just checked for avx2, fma and f16c.
                 return unsafe { inner(d) };
             }
         }
diff --git a/jxl_simd/src/x86_64/sse42.rs b/jxl_simd/src/x86_64/sse42.rs
index 7fef8c92..b4021570 100644
--- a/jxl_simd/src/x86_64/sse42.rs
+++ b/jxl_simd/src/x86_64/sse42.rs
@@ -609,6 +609,28 @@ unsafe impl F32SimdVec for F32VecSse42 {
 
     impl_f32_array_interface!();
 
+    #[inline(always)]
+    fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self {
+        assert!(mem.len() >= Self::LEN);
+        // SSE4.2 doesn't have F16C, use scalar conversion
+        let mut result = [0.0f32; 4];
+        for i in 0..4 {
+            result[i] = crate::f16::from_bits(mem[i]).to_f32();
+        }
+        Self::load(d, &result)
+    }
+
+    #[inline(always)]
+    fn store_f16_bits(self, dest: &mut [u16]) {
+        assert!(dest.len() >= Self::LEN);
+        // SSE4.2 doesn't have F16C, use scalar conversion
+        let mut tmp = [0.0f32; 4];
+        self.store(&mut tmp);
+        for i in 0..4 {
+            dest[i] = crate::f16::from_f32(tmp[i]).to_bits();
+        }
+    }
+
     #[inline(always)]
     fn transpose_square(d: Self::Descriptor, data: &mut [Self::UnderlyingArray], stride: usize) {
         #[target_feature(enable = "sse4.2")]
@@ -789,6 +811,26 @@ impl I32SimdVec for I32VecSse42 {
         let p1 = _mm_unpackhi_epi32(l, h);
         I32VecSse42(_mm_unpackhi_epi64(p0, p1), this.1)
     });
+
+    #[inline(always)]
+    fn store_u16(self, dest: &mut [u16]) {
+        // SSE4.2 has no truncating i32 -> u16 pack (_mm_packs_epi32/_mm_packus_epi32
+        // saturate instead of truncating), so store to a temporary array and take
+        // the low 16 bits of each lane with a scalar loop.
+        #[target_feature(enable = "sse4.2")]
+        #[inline]
+        fn store_u16_impl(v: __m128i, dest: &mut [u16]) {
+            assert!(dest.len() >= I32VecSse42::LEN);
+            let mut tmp = [0i32; 4];
+            // SAFETY: tmp has 4 elements, matching LEN
+            unsafe { _mm_storeu_si128(tmp.as_mut_ptr() as *mut __m128i, v) };
+            for i in 0..4 {
+                dest[i] = tmp[i] as u16;
+            }
+        }
+        // SAFETY: sse4.2 is available from the safety invariant on the descriptor.
+        unsafe { store_u16_impl(self.0, dest) }
+    }
 }
 
 impl Add for I32VecSse42 {