8 changes: 8 additions & 0 deletions jxl/src/headers/bit_depth.rs
@@ -65,6 +65,14 @@ impl BitDepth {
             exponent_bits_per_sample: 8,
         }
     }
+    #[cfg(test)]
+    pub fn f16() -> BitDepth {
+        BitDepth {
+            floating_point_sample: true,
+            bits_per_sample: 16,
+            exponent_bits_per_sample: 5,
+        }
+    }
     pub fn bits_per_sample(&self) -> u32 {
         self.bits_per_sample
     }
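For orientation (not part of the PR): these parameters encode the IEEE 754 binary16 layout, and the conversion code below derives its field constants from them. A minimal sketch:

fn main() {
    let (bits, exp_bits) = (16u32, 5u32);
    let exp_bias = (1u32 << (exp_bits - 1)) - 1; // 15
    let mant_bits = bits - exp_bits - 1; // 10 (the remaining 1 bit is the sign)
    assert_eq!((exp_bias, mant_bits), (15, 10));
    // e.g. 0x3C00: biased exponent 15, mantissa 0 => 1.0 * 2^(15 - 15) = 1.0
}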
3 changes: 1 addition & 2 deletions jxl/src/image/internal.rs
@@ -165,8 +165,7 @@ impl RawImageBuffer {
         // invariant.
         let start = unsafe { self.buf.add(start) };
         // SAFETY: due to the struct safety invariant, we know the entire slice is in a range of
-        // memory valid for writes. Moreover, the caller promises not to write uninitialized data
-        // in the returned slice. Finally, the caller guarantees aliasing rules will not be violated.
+        // memory valid for reads. The caller guarantees aliasing rules will not be violated.
         unsafe { std::slice::from_raw_parts(start, self.bytes_per_row) }
     }

168 changes: 155 additions & 13 deletions jxl/src/render/stages/convert.rs
@@ -8,7 +8,7 @@ use crate::{
     headers::bit_depth::BitDepth,
     render::{Channels, ChannelsMut, RenderPipelineInOutStage},
 };
-use jxl_simd::{F32SimdVec, simd_function};
+use jxl_simd::{F32SimdVec, I32SimdVec, simd_function};
 
 pub struct ConvertU8F32Stage {
     channel: usize,
@@ -135,20 +135,82 @@ impl std::fmt::Display for ConvertModularToF32Stage {
     }
 }
 
+// SIMD 32-bit float passthrough (bitcast i32 to f32)
+simd_function!(
+    int_to_float_32bit_simd_dispatch,
+    d: D,
+    fn int_to_float_32bit_simd(input: &[i32], output: &mut [f32], xsize: usize) {
+        let simd_width = D::I32Vec::LEN;
+
+        // Process complete SIMD vectors
+        for (in_chunk, out_chunk) in input
+            .chunks_exact(simd_width)
+            .zip(output.chunks_exact_mut(simd_width))
+            .take(xsize.div_ceil(simd_width))
+        {
+            let val = D::I32Vec::load(d, in_chunk);
+            val.bitcast_to_f32().store(out_chunk);
+        }
+    }
+);
+
+// SIMD 16-bit float (half-precision) to 32-bit float conversion
+// Uses hardware F16C/NEON instructions when available via F32Vec::load_f16_bits()
+simd_function!(
+    int_to_float_16bit_simd_dispatch,
+    d: D,
+    fn int_to_float_16bit_simd(input: &[i32], output: &mut [f32], xsize: usize) {
+        let simd_width = D::F32Vec::LEN;
+
+        // Temporary buffer for i32->u16 conversion via SIMD
+        // Note: Using constant 16 (max AVX-512 width) because D::F32Vec::LEN
+        // cannot be used as array size in Rust (const generics limitation)
+        const { assert!(D::F32Vec::LEN <= 16) }
+        let mut u16_buf = [0u16; 16];
+
+        // Process complete SIMD vectors
+        for (in_chunk, out_chunk) in input
+            .chunks_exact(simd_width)
+            .zip(output.chunks_exact_mut(simd_width))
+            .take(xsize.div_ceil(simd_width))
+        {
+            // Use SIMD to extract lower 16 bits from each i32 lane
+            let i32_vec = D::I32Vec::load(d, in_chunk);
+            i32_vec.store_u16(&mut u16_buf[..simd_width]);
+            // Use hardware f16->f32 conversion
+            let result = D::F32Vec::load_f16_bits(d, &u16_buf[..simd_width]);
+            result.store(out_chunk);
+        }
+    }
+);
+
 // Converts custom [bits]-bit float (with [exp_bits] exponent bits) stored as
 // int back to binary32 float.
-// TODO(sboukortt): SIMD
-fn int_to_float(input: &[i32], output: &mut [f32], bit_depth: &BitDepth) {
+fn int_to_float(input: &[i32], output: &mut [f32], bit_depth: &BitDepth, xsize: usize) {
     assert_eq!(input.len(), output.len());
     let bits = bit_depth.bits_per_sample();
     let exp_bits = bit_depth.exponent_bits_per_sample();
-    if bits == 32 {
-        assert_eq!(exp_bits, 8);
-        for (&in_val, out_val) in input.iter().zip(output) {
-            *out_val = f32::from_bits(in_val as u32);
-        }
+
+    // Use SIMD fast paths for common formats
+    if bits == 32 && exp_bits == 8 {
+        // 32-bit float passthrough
+        int_to_float_32bit_simd_dispatch(input, output, xsize);
         return;
     }
+
+    if bits == 16 && exp_bits == 5 {
+        // IEEE 754 half-precision (f16) - common HDR format
+        int_to_float_16bit_simd_dispatch(input, output, xsize);
+        return;
+    }
+
+    // Generic scalar path for other custom float formats
+    int_to_float_generic(input, output, bits, exp_bits);
+}
+
+// Generic scalar conversion for arbitrary bit-depth floats
+// TODO: SIMD optimization for custom float formats
+fn int_to_float_generic(input: &[i32], output: &mut [f32], bits: u32, exp_bits: u32) {
    let exp_bias = (1 << (exp_bits - 1)) - 1;
     let sign_shift = bits - 1;
     let mant_bits = bits - exp_bits - 1;
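The rest of int_to_float_generic is collapsed behind the fold above. As a reference for what a (bits, exp_bits) decode involves, here is a minimal standalone sketch, assuming the usual IEEE-style sign/exponent/mantissa field layout; this is hypothetical illustration code, not the PR's actual body:

// Hypothetical scalar decode of a custom float laid out as sign | exponent | mantissa.
fn decode_custom_float(raw: u32, bits: u32, exp_bits: u32) -> f32 {
    let mant_bits = bits - exp_bits - 1;
    let exp_bias = (1i32 << (exp_bits - 1)) - 1;
    let sign = (raw >> (bits - 1)) & 1;
    let biased_exp = ((raw >> mant_bits) & ((1 << exp_bits) - 1)) as i32;
    let mantissa = raw & ((1 << mant_bits) - 1);
    let magnitude = if biased_exp == 0 {
        // Subnormal: mantissa * 2^(1 - bias - mant_bits)
        mantissa as f32 * 2.0f32.powi(1 - exp_bias - mant_bits as i32)
    } else if biased_exp == (1 << exp_bits) - 1 {
        // All-ones exponent encodes infinity (or NaN if the mantissa is nonzero)
        if mantissa == 0 { f32::INFINITY } else { f32::NAN }
    } else {
        // Normal: (1 + mantissa / 2^mant_bits) * 2^(biased_exp - bias)
        (1.0 + mantissa as f32 / (1u32 << mant_bits) as f32)
            * 2.0f32.powi(biased_exp - exp_bias)
    };
    if sign == 1 { -magnitude } else { magnitude }
}

fn main() {
    // f16 checks: 0x3C00 -> 1.0, 0x0001 -> 2^-24 (smallest subnormal)
    assert_eq!(decode_custom_float(0x3C00, 16, 5), 1.0);
    assert_eq!(decode_custom_float(0x0001, 16, 5), 2.0f32.powi(-24));
}

The PR's two fast paths above simply bypass this decode for the layouts with hardware support.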
@@ -215,12 +277,9 @@ impl RenderPipelineInOutStage for ConvertModularToF32Stage {
     ) {
         let input = &input_rows[0];
         if self.bit_depth.floating_point_sample() {
-            int_to_float(
-                &input[0][..xsize],
-                &mut output_rows[0][0][..xsize],
-                &self.bit_depth,
-            );
+            int_to_float(input[0], output_rows[0][0], &self.bit_depth, xsize);
         } else {
+            // TODO(veluca): SIMDfy this code.
             let scale = 1.0 / ((1u64 << self.bit_depth.bits_per_sample()) - 1) as f32;
             for i in 0..xsize {
                 output_rows[0][0][i] = input[0][i] as f32 * scale;
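The integer branch scales by 1/(2^bits − 1) so that the maximum sample lands at 1.0 (up to f32 rounding of the scale factor). A worked example with 8-bit samples, for illustration only:

fn main() {
    // bits_per_sample = 8 => scale = 1 / (2^8 - 1) = 1/255
    let scale = 1.0 / ((1u64 << 8) - 1) as f32;
    // 0 maps to exactly 0.0; 255 maps to 1.0 within an ulp, since 1/255 is rounded in f32
    assert_eq!(0i32 as f32 * scale, 0.0);
    assert!((255i32 as f32 * scale - 1.0).abs() < 1e-6);
}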
@@ -419,6 +478,7 @@ impl RenderPipelineInOutStage for ConvertF32ToF16Stage {
 mod test {
     use super::*;
     use crate::error::Result;
+    use crate::headers::bit_depth::BitDepth;
     use test_log::test;
 
     #[test]
@@ -467,4 +527,86 @@ mod test {
             1,
         )
     }
+
+    #[test]
+    fn test_int_to_float_32bit() {
+        // Test 32-bit float passthrough
+        let bit_depth = BitDepth::f32();
+        let test_values: Vec<f32> = vec![
+            0.0,
+            1.0,
+            -1.0,
+            0.5,
+            -0.5,
+            f32::INFINITY,
+            f32::NEG_INFINITY,
+            1e-30,
+            1e30,
+        ];
+        let input: Vec<i32> = test_values
+            .iter()
+            .map(|&f| f.to_bits() as i32)
+            .chain(std::iter::repeat(0))
+            .take(16)
+            .collect();
+        let mut output = vec![0.0f32; 16];
+
+        int_to_float(&input, &mut output, &bit_depth, test_values.len());
+
+        for (i, (&expected, &actual)) in test_values.iter().zip(output.iter()).enumerate() {
+            if expected.is_nan() {
+                assert!(actual.is_nan(), "index {}: expected NaN, got {}", i, actual);
+            } else {
+                assert_eq!(expected, actual, "index {}: mismatch", i);
+            }
+        }
+    }
+
+    #[test]
+    fn test_int_to_float_16bit() {
+        // Test 16-bit float (f16) conversion across normal, special, and subnormal values
+        let bit_depth = BitDepth::f16();
+
+        // f16 format: 1 sign, 5 exp, 10 mantissa
+        // Test cases: (f16_bits, expected_f32)
+        let test_cases: Vec<(u16, f32)> = vec![
+            (0x0000, 0.0),               // +0
+            (0x8000, -0.0),              // -0
+            (0x3C00, 1.0),               // 1.0
+            (0xBC00, -1.0),              // -1.0
+            (0x3800, 0.5),               // 0.5
+            (0x4000, 2.0),               // 2.0
+            (0x4400, 4.0),               // 4.0
+            (0x7BFF, 65504.0),           // max normal f16
+            (0x7C00, f32::INFINITY),     // +inf
+            (0xFC00, f32::NEG_INFINITY), // -inf
+            (0x0001, 5.960_464_5e-8),    // smallest positive subnormal
+            (0x03FF, 6.097_555e-5),      // largest positive subnormal
+            (0x8001, -5.960_464_5e-8),   // smallest negative subnormal
+        ];
+
+        let input: Vec<i32> = test_cases
+            .iter()
+            .map(|(bits, _)| *bits as i32)
+            .chain(std::iter::repeat(0))
+            .take(16)
+            .collect();
+        let mut output = vec![0.0f32; 16];
+
+        int_to_float(&input, &mut output, &bit_depth, test_cases.len());
+
+        for (i, (&(_, expected), &actual)) in test_cases.iter().zip(output.iter()).enumerate() {
+            assert!(
+                (expected - actual).abs() < 1e-6
+                    || expected == actual
+                    || (expected.is_sign_negative() == actual.is_sign_negative()
+                        && expected == 0.0
+                        && actual == 0.0),
+                "index {}: expected {}, got {}",
+                i,
+                expected,
+                actual
+            );
+        }
+    }
 }
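A quick arithmetic check of the subnormal constants used above (illustrative, not part of the PR): an f16 subnormal decodes as mantissa × 2^(1 − 15 − 10) = mantissa × 2^−24.

fn main() {
    assert_eq!(5.960_464_5e-8_f32, 2.0f32.powi(-24)); // 0x0001: mantissa = 1
    assert_eq!(6.097_555e-5_f32, 1023.0 * 2.0f32.powi(-24)); // 0x03FF: mantissa = 1023
}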
51 changes: 50 additions & 1 deletion jxl_simd/src/aarch64/neon.rs
@@ -441,6 +441,42 @@ unsafe impl F32SimdVec for F32VecNeon {
             vst1_u16(dest.as_mut_ptr(), u16s);
         }
     }
+
+    fn store_f16_bits(this: F32VecNeon, dest: &mut [u16]) {
+        assert!(dest.len() >= F32VecNeon::LEN);
+        // Use inline asm because Rust stdarch incorrectly requires fp16 target feature
+        // for vcvt_f16_f32 (fixed in https://github.com/rust-lang/stdarch/pull/1978)
+        let f16_bits: uint16x4_t;
+        // SAFETY: NEON is available (guaranteed by descriptor), dest has enough space
+        unsafe {
+            std::arch::asm!(
+                "fcvtn {out:v}.4h, {inp:v}.4s",
+                inp = in(vreg) this.0,
+                out = out(vreg) f16_bits,
+                options(pure, nomem, nostack),
+            );
+            vst1_u16(dest.as_mut_ptr(), f16_bits);
+        }
+    }
+
+    #[inline(always)]
+    fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self {
+        assert!(mem.len() >= Self::LEN);
+        // Use inline asm because Rust stdarch incorrectly requires fp16 target feature
+        // for vcvt_f32_f16 (fixed in https://github.com/rust-lang/stdarch/pull/1978)
+        let result: float32x4_t;
+        // SAFETY: NEON is available (guaranteed by descriptor), mem has enough space
+        unsafe {
+            let f16_bits = vld1_u16(mem.as_ptr());
+            std::arch::asm!(
+                "fcvtl {out:v}.4s, {inp:v}.4h",
+                inp = in(vreg) f16_bits,
+                out = out(vreg) result,
+                options(pure, nomem, nostack),
+            );
+        }
+        F32VecNeon(result, d)
+    }
 
     #[inline(always)]
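As a scalar model of the two asm paths (illustrative; the `half` crate here is an assumption used as a reference implementation, not a jxl-rs dependency): fcvtl widens exactly, since every binary16 value is representable in binary32, and fcvtn narrows with round-to-nearest-even.

use half::f16;

fn main() {
    assert_eq!(f16::from_bits(0x3C00).to_f32(), 1.0);
    assert_eq!(f16::from_bits(0x0001).to_f32(), 5.960_464_5e-8); // subnormals widen exactly too
    assert_eq!(f16::from_f32(1.0).to_bits(), 0x3C00); // narrowing round-trips exact f16 values
}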
@@ -450,7 +486,8 @@ unsafe impl F32SimdVec for F32VecNeon {
     fn prepare_impl(table: &[f32; 8]) -> uint8x16_t {
         // Convert f32 table to BF16 packed in 128 bits (16 bytes for 8 entries)
         // BF16 is the high 16 bits of f32
-        // SAFETY: neon is available from target_feature
+        // SAFETY: neon is available from target_feature, and `table` is large
+        // enough for the loads.
         let (table_lo, table_hi) =
             unsafe { (vld1q_f32(table.as_ptr()), vld1q_f32(table.as_ptr().add(4))) };
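Since BF16 is just the high half of an f32's bit pattern, the table preparation amounts to a 16-bit truncation; a one-line check (illustrative, not part of the PR):

fn main() {
    // 1.0f32 = 0x3F80_0000, so its BF16 bits are the top half, 0x3F80
    assert_eq!((1.0f32.to_bits() >> 16) as u16, 0x3F80);
}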

@@ -653,6 +690,18 @@ impl I32SimdVec for I32VecNeon {
         // SAFETY: We know neon is available from the safety invariant on `self.1`.
         unsafe { Self(vshrq_n_s32::<AMOUNT_I>(self.0), self.1) }
     }
+
+    #[inline(always)]
+    fn store_u16(self, dest: &mut [u16]) {
+        assert!(dest.len() >= Self::LEN);
+        // SAFETY: We know neon is available from the safety invariant on `self.1`,
+        // and we just checked that `dest` has enough space.
+        unsafe {
+            // vmovn narrows i32 to i16 by taking the lower 16 bits
+            let narrowed = vmovn_s32(self.0);
+            vst1_u16(dest.as_mut_ptr(), vreinterpret_u16_s16(narrowed));
+        }
+    }
 }
 
 impl Add<I32VecNeon> for I32VecNeon {
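A scalar model of the narrowing store (illustrative, not part of the PR): each i32 lane keeps its low 16 bits, which is exactly the `as u16` cast.

fn main() {
    let lanes = [0x0001_2345_i32, -1, 0x7C00, 65504];
    let narrowed: Vec<u16> = lanes.iter().map(|&v| v as u16).collect();
    assert_eq!(narrowed, vec![0x2345, 0xFFFF, 0x7C00, 0xFFE0]);
}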